In the previous post we learnt that it is possible to train a machine learning classifier directly on a microcontroller. In this post we'll see how to do it, using color classification as the example.

This will be a hands-on guide, so let's walk through each step you need to complete to run the example.

I set up this example as a basis for your future projects, so you can easily swap the color classification task for any other you can think of.

Definitions

#ifdef ESP32
#define min(a, b) ((a) < (b) ? (a) : (b))
#define max(a, b) ((a) > (b) ? (a) : (b))
#define abs(x) ((x) > 0 ? (x) : -(x))
#endif

#include <EloquentSVMSMO.h>
#include "RGB.h"

#define MAX_TRAINING_SAMPLES 20
#define FEATURES_DIM 3

using namespace Eloquent::ML;

int numSamples;
RGB rgb(2, 3, 4);
float X_train[MAX_TRAINING_SAMPLES][FEATURES_DIM];
int y_train[MAX_TRAINING_SAMPLES];
SVMSMO<FEATURES_DIM> classifier(linearKernel);

When training a classifier on your microcontroller there are some things that are mandatory:

  1. #include <EloquentSVMSMO.h>: this is the library that implements the SVM learning algorithm
  2. X_train: this is a matrix where each row represents a training sample. You need to keep this data in memory after training, since it is also required during inference
  3. y_train: this array contains, for each training sample, the class it belongs to: 1 or -1
  4. linearKernel: this is the kernel function for the SVM classifier (you can read more here). You can pass a kernel other than the linear one (for example poly or rbf)
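
To make the kernel idea concrete, here is a minimal sketch of what a linear and an RBF kernel compute on two feature vectors. The function names, the signatures and the gamma parameter are assumptions made for illustration: they are not taken from the EloquentSVMSMO library, whose kernel signature may differ.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical helper (not part of EloquentSVMSMO):
// the linear kernel is the dot product of the two vectors.
float linearKernelSketch(const float *a, const float *b, std::size_t dim) {
    float sum = 0;

    for (std::size_t i = 0; i < dim; i++) {
        sum += a[i] * b[i];
    }

    return sum;
}

// Hypothetical helper (not part of EloquentSVMSMO):
// the RBF kernel is exp(-gamma * squared distance between the vectors).
float rbfKernelSketch(const float *a, const float *b, std::size_t dim, float gamma) {
    float squaredDistance = 0;

    for (std::size_t i = 0; i < dim; i++) {
        float diff = a[i] - b[i];
        squaredDistance += diff * diff;
    }

    return std::exp(-gamma * squaredDistance);
}
```

The linear kernel grows with how well the two vectors align, while the RBF kernel equals 1 for identical vectors and decays towards 0 as they move apart.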

In this specific example, we're using the RGB class to handle the TCS3200 sensor reading, but this will change based on the dataset you want to train on. Also, since our features are going to be the R, G and B components of a color, FEATURES_DIM is set to 3.

Setup

void setup() {
    Serial.begin(115200);
    rgb.begin();

    classifier.setC(5);
    classifier.setTol(1e-5);
    classifier.setMaxIter(10000);
}

The setup contains no real logic. You can use this part to configure the parameters of the classifier:

  • C: "The C parameter tells the SVM optimization how much you want to avoid misclassifying each training example. For large values of C, the optimization will choose a smaller-margin hyperplane if that hyperplane does a better job of getting all the training points classified correctly. Conversely, a very small value of C will cause the optimizer to look for a larger-margin separating hyperplane, even if that hyperplane misclassifies more points" (quoted from stackexchange)
  • tol: "The tol parameter is a setting for the SVM's tolerance in optimization. Recall that yi(xi.w+b)-1 >= 0. For an SVM to be valid, all values must be greater than or equal to 0, and at least one value on each side needs to be "equal" to 0, which will be your support vectors. Since it is highly unlikely that you will actually get values equal perfectly to 0, you set tolerance to allow a bit of wiggle room." (quoted from pythonprogramming)
  • maxIter: sets an upper bound on the number of iterations the algorithm can take to converge
  • passes: the maximum number of consecutive passes over the α's in which none of them changes
  • alphaTol: the alpha coefficients determine which samples from the training set are to be considered support vectors and so be included during the inference procedure. This value discards support vectors with an alpha too small to be noticeable.
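
The condition quoted for tol can be written out for a linear model. The following is only a sketch of that textbook margin check: marginValue and isOnMargin are hypothetical names, and the library's SMO implementation may apply its tolerance differently.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical helper: computes y * (x . w + b) - 1 for one sample.
// A correctly classified sample outside the margin gives a value >= 0.
float marginValue(const float *w, float b, const float *x, int y, std::size_t dim) {
    float score = b;

    for (std::size_t i = 0; i < dim; i++) {
        score += w[i] * x[i];
    }

    return y * score - 1.0f;
}

// Hypothetical helper: a sample counts as lying "on" the margin
// when its value is within tol of zero, rather than exactly zero.
bool isOnMargin(float value, float tol) {
    return std::fabs(value) <= tol;
}
```

A larger tol accepts more samples as lying on the margin, a smaller one demands a more precise solution and usually more iterations.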

Fit

else if (command == "fit") {
        Serial.print("How many samples will you record? ");
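        numSamples = readSerialNumber();

        // X_train and y_train only have room for MAX_TRAINING_SAMPLES rows
        if (numSamples > MAX_TRAINING_SAMPLES) {
            numSamples = MAX_TRAINING_SAMPLES;
        }

        for (int i = 0; i < numSamples; i++) {
            Serial.print(i + 1);
            Serial.print("/");
            Serial.print(numSamples);
            Serial.println(" Which class does the sample belong to, 1 or -1?");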
            y_train[i] = readSerialNumber() > 0 ? 1 : -1;
            getFeatures(X_train[i]);
        }

        Serial.print("Start training... ");
        classifier.fit(X_train, y_train, numSamples);
        Serial.println("Done");
    }

This is the core of the project. Here we are loading the samples to train our classifier "live" on the board.

Since this is an interactive demo, the program prompts us to define how many samples we'll load and, one by one, which class they belong to.

Now there are a few important things to keep in mind:

  • numSamples: a C array does not keep track of how many of its rows are actually filled, so we have to be explicit about it. To train the classifier, you must know how many samples you're passing to it, and that number must not exceed MAX_TRAINING_SAMPLES, the number of rows allocated for X_train
  • getFeatures() is the function that reads the training sample. It is actually a "proxy" to your own custom logic: in this example it reads the TCS3200, in your project it could read an accelerometer or the like.
  • fit(): this is where the magic happens. With this single line of code you're training the SVM on the training data; when the function returns, the classifier will have updated its internal state with the coefficients it needs to classify new samples
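
The compiler knows how many rows an array can hold, but not how many of them you have filled, which is why the sample count travels separately. Here is a minimal sketch of that distinction; the names samples, capacity and clampSampleCount are hypothetical and not part of the project code, and the sizes simply mirror MAX_TRAINING_SAMPLES and FEATURES_DIM.

```cpp
#include <cstddef>

// Sizes mirror MAX_TRAINING_SAMPLES and FEATURES_DIM from the example.
const std::size_t kMaxSamples = 20;
const std::size_t kFeaturesDim = 3;

// Hypothetical training buffer, for illustration only.
float samples[kMaxSamples][kFeaturesDim];

// The capacity (number of rows) is known at compile time...
std::size_t capacity() {
    return sizeof(samples) / sizeof(samples[0]);
}

// ...but the number of filled rows is not, so a count typed by the user
// has to be checked against the capacity before it is used as a loop bound.
std::size_t clampSampleCount(long requested) {
    if (requested < 0) {
        return 0;
    }

    if (static_cast<std::size_t>(requested) > capacity()) {
        return capacity();
    }

    return static_cast<std::size_t>(requested);
}
```

Clamping the count this way keeps a mistyped number from writing past the end of the training arrays.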

Predict

ColorClassificationTrainingExample.ino

else if (command == "predict") {
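        float x[FEATURES_DIM];

        // read a new, unlabelled sample from the sensor
        getFeatures(x);
        Serial.print("Predicted label is ");
        // the training samples are passed again because inference needs them
        Serial.println(classifier.predict(X_train, x));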
    }

Now that our classifier has been trained, we can finally make use of it to classify new samples.

It is as easy as it gets: you just call its predict method.

As you can see, the predict method requires the X_train matrix in addition to the new sample vector.

And that's it: you can now complete your machine learning task on your microcontroller from start to end, without the need for a PC.
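
Why the training matrix is needed at prediction time becomes clearer from the textbook decision rule of a kernel SVM, which sums over the training samples. The sketch below illustrates that rule with a linear kernel; it is not the library's actual code, and predictSketch, alphas and bias are hypothetical names standing in for whatever the classifier stores internally.

```cpp
#include <cstddef>

const std::size_t kDim = 3;

// Linear kernel: dot product of two feature vectors.
float dot(const float *a, const float *b) {
    float sum = 0;

    for (std::size_t i = 0; i < kDim; i++) {
        sum += a[i] * b[i];
    }

    return sum;
}

// Textbook rule: sign(sum_i alpha_i * y_i * K(x_i, x) + bias).
// alphas and bias are hypothetical values produced by training.
int predictSketch(const float X[][kDim], const int *y, const float *alphas,
                  float bias, std::size_t n, const float *x) {
    float score = bias;

    for (std::size_t i = 0; i < n; i++) {
        // samples with a zero alpha are not support vectors and contribute nothing
        if (alphas[i] == 0) {
            continue;
        }

        score += alphas[i] * y[i] * dot(X[i], x);
    }

    return score > 0 ? 1 : -1;
}
```

Every support vector is compared with the new sample through the kernel, so the rows of the training matrix have to stay in memory for as long as you want to classify.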

Want to learn more?


Check the full project code on Github


Full example

ColorClassificationTrainingExample.ino

#include <EloquentSVMSMO.h>
#include "RGB.h"

#define MAX_TRAINING_SAMPLES 20
#define FEATURES_DIM 3

using namespace Eloquent::ML;

int numSamples;
RGB rgb(2, 3, 4);
float X_train[MAX_TRAINING_SAMPLES][FEATURES_DIM];
int y_train[MAX_TRAINING_SAMPLES];
SVMSMO<FEATURES_DIM> classifier(linearKernel);

void setup() {
    Serial.begin(115200);
    rgb.begin();

    classifier.setC(5);
    classifier.setTol(1e-5);
    classifier.setMaxIter(10000);
}

void loop() {
    if (!Serial.available()) {
        delay(100);
        return;
    }

    String command = Serial.readStringUntil('\n');

    if (command == "help") {
        Serial.println("Available commands:");
        Serial.println("\tfit: train the classifier on a new set of samples");
        Serial.println("\tpredict: classify a new sample");
        Serial.println("\tinspect: print X_train and y_train");
    }
    else if (command == "fit") {
        Serial.print("How many samples will you record? ");
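        numSamples = readSerialNumber();

        // X_train and y_train only have room for MAX_TRAINING_SAMPLES rows
        if (numSamples > MAX_TRAINING_SAMPLES) {
            numSamples = MAX_TRAINING_SAMPLES;
        }

        for (int i = 0; i < numSamples; i++) {
            Serial.print(i + 1);
            Serial.print("/");
            Serial.print(numSamples);
            Serial.println(" Which class does the sample belong to, 1 or -1?");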
            y_train[i] = readSerialNumber() > 0 ? 1 : -1;
            getFeatures(X_train[i]);
        }

        Serial.print("Start training... ");
        classifier.fit(X_train, y_train, numSamples);
        Serial.println("Done");
    }
    else if (command == "predict") {
        float x[FEATURES_DIM];

        getFeatures(x);
        Serial.print("Predicted label is ");
        Serial.println(classifier.predict(X_train, x));
    }
    else if (command == "inspect") {
        for (int i = 0; i < numSamples; i++) {
            Serial.print("[");
            Serial.print(y_train[i]);
            Serial.print("] ");

            for (int j = 0; j < FEATURES_DIM; j++) {
                Serial.print(X_train[i][j]);
                Serial.print(", ");
            }

            Serial.println();
        }
    }
}

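/**
 * Wait for a line on the serial port and parse it as an integer
 * @return the parsed number (0 if the line is not a valid number)
 */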
int readSerialNumber() {
    while (!Serial.available()) delay(1);

    return Serial.readStringUntil('\n').toInt();
}

/**
 * Get features for new sample and print them
 * @param x output array, filled with the R, G and B readings
 */
void getFeatures(float x[FEATURES_DIM]) {
    rgb.read(x);

    for (int i = 0; i < FEATURES_DIM; i++) {
        Serial.print(x[i]);
        Serial.print(", ");
    }

    Serial.println();
}

RGB.h

#pragma once

#include <Arduino.h>

/**
 * Wrapper for RGB color sensor
 */
class RGB {
    public:
        RGB(uint8_t s2, uint8_t s3, uint8_t out) :
            _s2(s2),
            _s3(s3),
            _out(out) {

        }

        /**
         * Configure the sensor pins
         */
        void begin() {
            pinMode(_s2, OUTPUT);
            pinMode(_s3, OUTPUT);
            pinMode(_out, INPUT);
        }

        /**
         * Read the three color components
         * @param x output array, filled with the R, G and B readings in this order
         */
        void read(float x[3]) {
            x[0] = readComponent(LOW, LOW);
            x[1] = readComponent(HIGH, HIGH);
            x[2] = readComponent(LOW, HIGH);
        }

    protected:
        uint8_t _s2;
        uint8_t _s3;
        uint8_t _out;

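        /**
         * Select a color filter and measure the sensor output
         * @param s2 level of the S2 filter-selection pin
         * @param s3 level of the S3 filter-selection pin
         * @return duration of a LOW pulse on the output pin, in microseconds
         */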
        int readComponent(bool s2, bool s3) {
            delay(10);
            digitalWrite(_s2, s2);
            digitalWrite(_s3, s3);

            return pulseIn(_out, LOW);
        }
};