In this hands-on guide to on-board SVM training we'll see a classifier in action: we'll train it on the Iris dataset and evaluate its performance.

What we'll make

In this demo project we're going to take a well-known dataset (iris flowers) and interactively train an SVM classifier on it, adjusting the number of training samples to see the effect on training time, inference time and accuracy.

Definitions

#ifdef ESP32
#define min(a, b) ((a) < (b) ? (a) : (b))
#define max(a, b) ((a) > (b) ? (a) : (b))
#define abs(x)    ((x) > 0 ? (x) : -(x))
#endif

#include <EloquentSVMSMO.h>
#include "iris.h"

#define TOTAL_SAMPLES (POSITIVE_SAMPLES + NEGATIVE_SAMPLES)

using namespace Eloquent::ML;

float X_train[TOTAL_SAMPLES][FEATURES_DIM];
float X_test[TOTAL_SAMPLES][FEATURES_DIM];
int y_train[TOTAL_SAMPLES];
int y_test[TOTAL_SAMPLES];
SVMSMO<FEATURES_DIM> classifier(linearKernel);

First of all we need to include a couple of files, namely EloquentSVMSMO.h for the SVM classifier and iris.h for the dataset.

iris.h defines a couple constants:

  • FEATURES_DIM: the number of features each sample has (4 in this case)
  • POSITIVE_SAMPLES: the number of samples that belong to the positive class (50)
  • NEGATIVE_SAMPLES: the number of samples that belong to the negative class (50)

Then we declare the arrays that hold the data: X_train and y_train for the training process, X_test and y_test for the inference process.

Setup

void setup() {
    Serial.begin(115200);
    delay(5000);

    // configure classifier
    classifier.setC(5);
    classifier.setTol(1e-5);
    classifier.setMaxIter(10000);
}

Here we just set a few parameters for the classifier. You could actually skip this step in this demo, since the defaults will work well. Those lines are there so you know you can tweak them, if needed.

Please refer to the demo for color classification for an explanation of each parameter.
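In short (the color classification demo covers them in more depth; the values here are the ones from the sketch, not recommendations):

```cpp
// C: regularization strength. A larger C penalizes misclassifications
// more heavily, fitting the training data tighter at the risk of overfitting
classifier.setC(5);

// tol: numerical tolerance the SMO optimizer uses as a stopping criterion;
// smaller values mean a more precise (and slower) fit
classifier.setTol(1e-5);

// maxIter: hard cap on optimizer iterations, so training can't
// run unbounded on a microcontroller
classifier.setMaxIter(10000);
```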

Interactivity

void loop() {
    int positiveSamples = readSerialNumber("How many positive samples will you use for training? ", POSITIVE_SAMPLES);

    if (positiveSamples > POSITIVE_SAMPLES - 1) {
        Serial.println("Too many positive samples entered. All but one will be used instead");
        positiveSamples = POSITIVE_SAMPLES - 1;
    }

    int negativeSamples = readSerialNumber("How many negative samples will you use for training? ", NEGATIVE_SAMPLES);

    if (negativeSamples > NEGATIVE_SAMPLES - 1) {
        Serial.println("Too many negative samples entered. All but one will be used instead");
        negativeSamples = NEGATIVE_SAMPLES - 1;
    }

    loadDataset(positiveSamples, negativeSamples);

    // ...
}

/**
 * Ask the user to enter a numeric value
 */
int readSerialNumber(String prompt, int maxAllowed) {
    Serial.print(prompt);
    Serial.print(" (");
    Serial.print(maxAllowed);
    Serial.print(" max) ");

    while (!Serial.available()) delay(1);

    int n = Serial.readStringUntil('\n').toInt();

    Serial.println(n);

    return n;
}

/**
 * Divide training and test data
 */
void loadDataset(int positiveSamples, int negativeSamples) {
    int positiveTestSamples = POSITIVE_SAMPLES - positiveSamples;

    // memcpy counts bytes, not elements, so we multiply by sizeof(float)
    for (int i = 0; i < positiveSamples; i++) {
        memcpy(X_train[i], X_positive[i], FEATURES_DIM * sizeof(float));
        y_train[i] = 1;
    }

    for (int i = 0; i < negativeSamples; i++) {
        memcpy(X_train[i + positiveSamples], X_negative[i], FEATURES_DIM * sizeof(float));
        y_train[i + positiveSamples] = -1;
    }

    for (int i = 0; i < positiveTestSamples; i++) {
        memcpy(X_test[i], X_positive[i + positiveSamples], FEATURES_DIM * sizeof(float));
        y_test[i] = 1;
    }

    for (int i = 0; i < NEGATIVE_SAMPLES - negativeSamples; i++) {
        memcpy(X_test[i + positiveTestSamples], X_negative[i + negativeSamples], FEATURES_DIM * sizeof(float));
        y_test[i + positiveTestSamples] = -1;
    }
}

The code above is a preliminary step where you're asked how many samples of each class (positive and negative) you want to use for training.

This way you can run the benchmark multiple times without re-compiling and re-uploading the sketch.

It also shows that the training process can be "dynamic", in the sense that you can tweak it at runtime as needed.

Training

unsigned long start = millis();
classifier.fit(X_train, y_train, positiveSamples + negativeSamples);
Serial.print("It took ");
Serial.print(millis() - start);
Serial.print("ms to train on ");
Serial.print(positiveSamples + negativeSamples);
Serial.println(" samples");

Training is actually a one-line operation. Here we also log how much time it takes to train.

Predicting

void loop() {
    // ...

    int tp = 0;
    int tn = 0;
    int fp = 0;
    int fn = 0;

    start = millis();

    for (int i = 0; i < TOTAL_SAMPLES - positiveSamples - negativeSamples; i++) {
        int y_pred = classifier.predict(X_train, X_test[i]);
        int y_true = y_test[i];

        if (y_pred == y_true && y_pred ==  1) tp += 1;
        if (y_pred == y_true && y_pred == -1) tn += 1;
        if (y_pred != y_true && y_pred ==  1) fp += 1;
        if (y_pred != y_true && y_pred == -1) fn += 1;
    }

    Serial.print("It took ");
    Serial.print(millis() - start);
    Serial.print("ms to test on ");
    Serial.print(TOTAL_SAMPLES - positiveSamples - negativeSamples);
    Serial.println(" samples");

    printConfusionMatrix(tp, tn, fp, fn);
}

/**
 * Dump confusion matrix to Serial monitor
 */
void printConfusionMatrix(int tp, int tn, int fp, int fn) {
    Serial.print("Overall accuracy ");
    Serial.print(100.0 * (tp + tn) / (tp + tn + fp + fn));
    Serial.println("%");
    Serial.println("Confusion matrix");
    Serial.print("          | Predicted 1 | Predicted -1 |\n");
    Serial.print("----------------------------------------\n");
    Serial.print("Actual  1 |      ");
    Serial.print(tp);
    Serial.print("     |      ");
    Serial.print(fn);
    Serial.print("       |\n");
    Serial.print("----------------------------------------\n");
    Serial.print("Actual -1 |      ");
    Serial.print(fp);
    Serial.print("      |      ");
    Serial.print(tn);
    Serial.print("       |\n");
    Serial.print("----------------------------------------\n\n\n");
}

Finally we can run the classification on our test set and get the overall accuracy.

We also print the confusion matrix to double-check the per-class performance.

Want to learn more?


Check the full project code on Github, where you'll also find another dataset to test, with a much higher number of features (30 instead of 4).

Help the blog grow