You will be surprised by how much accuracy you can achieve in just a few kilobytes of resources: Decision Tree, Random Forest and XGBoost (Extreme Gradient Boosting) are now available on your microcontrollers, as highly RAM-optimized implementations for super-fast classification on embedded devices.
Decision Tree
Decision Tree is without doubt one of the most well-known classification algorithms out there. It is so simple to understand that it was probably the first classifier you encountered in any Machine Learning course.
I won't go into the details of how a Decision Tree classifier trains and selects the splits for the input features: here I will explain how a RAM-efficient port of such a classifier is implemented.
For an introduction visit Wikipedia; for a more in-depth guide visit KDnuggets.
Since RAM is the scarcest resource on the vast majority of microcontrollers, we're willing to sacrifice program space (a.k.a. flash) in favor of memory (a.k.a. RAM). The smart way to port a Decision Tree classifier from Python to C is therefore to "hard-code" the splits in code, without keeping any reference to them in variables.
Here's what it looks like for a Decision Tree that classifies the Iris dataset (see the Code listings section at the end of this post).
As you can see, we're using 0 bytes of RAM to get the classification result, since no variable is allocated. On the other hand, the program space will grow almost linearly with the number of splits.
Since program space is often much larger than RAM on microcontrollers, this implementation exploits its abundance to deploy larger models. How large? It will depend on the available flash size: many new-generation boards (Arduino Nano 33 BLE Sense, ESP32, ST Nucleo...) have 1 MB of flash, which will hold tens of thousands of splits.
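If you want a rough idea of how many splits your own tree will hard-code before porting it, you can inspect the trained scikit-learn estimator. Here's a quick sketch (the Iris data is just for illustration):

# count the nodes of a trained tree: the generated C code grows
# roughly linearly with this number
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier().fit(X, y)

print("total nodes (splits + leaves):", clf.tree_.node_count)
print("tree depth:", clf.get_depth())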
Random Forest
Random Forest is just many Decision Trees joined together in a voting scheme. The core idea is that of "the wisdom of the crowd": if many trees, each trained on a different subset of the training set, vote for a given class, that class is probably the true one.
Towards Data Science has a more detailed guide on Random Forest and how it balances the trees with the bagging technique.
Just like Decision Tree, Random Forest gets the exact same implementation with 0 bytes of RAM required (well, it actually needs a few bytes per class to store the votes, but that's negligible): it simply hard-codes all of its composing trees.
XGBoost (Extreme Gradient Boosting)
Extreme Gradient Boosting is "Gradient Boosting on steroids" and has gained much attention from the Machine learning community due to its top results in many data competitions.
- "gradient boosting" refers to the process of chaining a number of trees so that each tree tries to learn from the errors of the previous
- "extreme" refers to many software and hardware optimizations that greatly reduce the time it takes to train the model
You can read the original paper about XGBoost here. For a discursive description head to KDnuggets; if you want some more math, refer to this blog post on Medium.
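If the chaining idea sounds abstract, here is a minimal Python sketch of plain gradient boosting on regression trees. It only illustrates the "learn from the errors of the previous tree" mechanism; it is not XGBoost's actual implementation, which adds regularization, second-order gradients and heavy systems-level optimizations.

# minimal illustration of gradient boosting: each new tree is fit on the
# residual errors left by the ensemble built so far
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
learning_rate = 0.1

prediction = np.full(len(y), y.mean())   # start from a constant prediction
trees = []

for _ in range(50):
    residuals = y - prediction                       # what the ensemble still gets wrong
    tree = DecisionTreeRegressor(max_depth=3).fit(X, residuals)
    prediction += learning_rate * tree.predict(X)    # each tree corrects the previous ones
    trees.append(tree)

print("training MSE after boosting:", np.mean((y - prediction) ** 2))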
Porting to plain C
If you followed my earlier posts on Gaussian Naive Bayes, SEFR, Relevance Vector Machine and Support Vector Machines, you already know how to port these new classifiers.
If you're new, you will need a couple things:
- install the micromlgen package with
pip install micromlgen
- (optionally, if you want to use Extreme Gradient Boosting) install the xgboost package with
pip install xgboost
- use the micromlgen.port function to generate your plain C code
from micromlgen import port
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
clf = DecisionTreeClassifier()
X, y = load_iris(return_X_y=True)
clf.fit(X, y)
print(port(clf))
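The very same port function works for the other two classifiers as well. In the sketch below the hyperparameters (n_estimators, max_depth) are just illustrative values, not tuned settings; the XGBoost part assumes you installed the xgboost package as described above.

from micromlgen import port
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

X, y = load_iris(return_X_y=True)

# Random Forest
forest = RandomForestClassifier(n_estimators=10, max_depth=10).fit(X, y)
print(port(forest))

# XGBoost
booster = XGBClassifier(n_estimators=10, max_depth=10).fit(X, y)
print(port(booster))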
You can then copy-paste the C code and import it in your sketch.
Using it in the Arduino sketch
Once you have the classifier code, create a new project named TreeClassifierExample and copy the classifier code into a file named DecisionTree.h (or RandomForest.h or XGBoost.h, depending on the model you chose).
Then copy the following to the main .ino file.
#include "DecisionTree.h"
Eloquent::ML::Port::DecisionTree clf;
void setup() {
    Serial.begin(115200);
    Serial.println("Begin");
}

void loop() {
    float irisSample[4] = {6.2, 2.8, 4.8, 1.8};

    Serial.print("Predicted label (you should see '2'): ");
    Serial.println(clf.predict(irisSample));
    delay(1000);
}
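In a real project you would of course replace the hard-coded irisSample array with the feature values read from your sensors, keeping them in the same order used during training.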
Benchmarks
How do the 3 classifiers compare against each other?
We will evaluate a few key points:
- training time
- accuracy
- needed RAM
- needed Flash
for each classifier on a variety of datasets. I will report the RAM and Flash figures for the old-generation Arduino Nano, so you should focus on the relative numbers rather than the absolute ones.
Dataset | Classifier | Training time (s) | Accuracy | RAM (bytes) | Flash (bytes) |
---|---|---|---|---|---|
Gas Sensor Array Drift Dataset (13910 samples x 128 features, 6 classes) | Decision Tree | 1.6 | 0.781 ± 0.12 | 290 | 5722 |
 | Random Forest | 3 | 0.865 ± 0.083 | 290 | 6438 |
 | XGBoost | 18.8 | 0.878 ± 0.074 | 290 | 6506 |
Gesture Phase Segmentation Dataset (10000 samples x 19 features, 5 classes) | Decision Tree | 0.1 | 0.943 ± 0.005 | 290 | 5638 |
 | Random Forest | 0.7 | 0.970 ± 0.004 | 306 | 6466 |
 | XGBoost | 18.9 | 0.969 ± 0.003 | 306 | 6536 |
Drive Diagnosis Dataset (10000 samples x 48 features, 11 classes) | Decision Tree | 0.6 | 0.946 ± 0.005 | 306 | 5850 |
 | Random Forest | 2.6 | 0.983 ± 0.003 | 306 | 6526 |
 | XGBoost | 68.9 | 0.977 ± 0.005 | 306 | 6698 |
* all datasets are taken from the UCI Machine Learning Repository
I'm collecting more data for a complete benchmark, but in the meantime you can see that Random Forest and XGBoost are on par, except that XGBoost takes 5 to 25 times longer to train.
I've never used XGBoost, so I may be missing some tuning parameters, but for now Random Forest remains my favourite classifier.
Troubleshooting
It can happen that when running micromlgen.port(clf) you get a TemplateNotFound error. To solve the problem, first of all uninstall micromlgen:
pip uninstall micromlgen
Then head to GitHub, download the package as a zip and extract the micromlgen folder into your project.
Code listings
// example IRIS dataset classification with Decision Tree
int predict(float *x) {
    if (x[3] <= 0.800000011920929) {
        return 0;
    }
    else {
        if (x[3] <= 1.75) {
            if (x[2] <= 4.950000047683716) {
                if (x[0] <= 5.049999952316284) {
                    return 1;
                }
                else {
                    return 1;
                }
            }
            else {
                return 2;
            }
        }
        else {
            if (x[2] <= 4.950000047683716) {
                return 2;
            }
            else {
                return 2;
            }
        }
    }
}
// example IRIS dataset classification with Random Forest of 3 trees
int predict(float *x) {
    uint16_t votes[3] = { 0 };

    // tree #1
    if (x[0] <= 5.450000047683716) {
        if (x[1] <= 2.950000047683716) {
            votes[1] += 1;
        }
        else {
            votes[0] += 1;
        }
    }
    else {
        if (x[0] <= 6.049999952316284) {
            if (x[3] <= 1.699999988079071) {
                if (x[2] <= 3.549999952316284) {
                    votes[0] += 1;
                }
                else {
                    votes[1] += 1;
                }
            }
            else {
                votes[2] += 1;
            }
        }
        else {
            if (x[3] <= 1.699999988079071) {
                if (x[3] <= 1.449999988079071) {
                    if (x[0] <= 6.1499998569488525) {
                        votes[1] += 1;
                    }
                    else {
                        votes[1] += 1;
                    }
                }
                else {
                    votes[1] += 1;
                }
            }
            else {
                votes[2] += 1;
            }
        }
    }

    // tree #2
    if (x[0] <= 5.549999952316284) {
        if (x[2] <= 2.449999988079071) {
            votes[0] += 1;
        }
        else {
            if (x[2] <= 3.950000047683716) {
                votes[1] += 1;
            }
            else {
                votes[1] += 1;
            }
        }
    }
    else {
        if (x[3] <= 1.699999988079071) {
            if (x[1] <= 2.649999976158142) {
                if (x[3] <= 1.25) {
                    votes[1] += 1;
                }
                else {
                    votes[1] += 1;
                }
            }
            else {
                if (x[2] <= 4.1499998569488525) {
                    votes[1] += 1;
                }
                else {
                    if (x[0] <= 6.75) {
                        votes[1] += 1;
                    }
                    else {
                        votes[1] += 1;
                    }
                }
            }
        }
        else {
            if (x[0] <= 6.0) {
                votes[2] += 1;
            }
            else {
                votes[2] += 1;
            }
        }
    }

    // tree #3
    if (x[3] <= 1.75) {
        if (x[2] <= 2.449999988079071) {
            votes[0] += 1;
        }
        else {
            if (x[2] <= 4.8500001430511475) {
                if (x[0] <= 5.299999952316284) {
                    votes[1] += 1;
                }
                else {
                    votes[1] += 1;
                }
            }
            else {
                votes[1] += 1;
            }
        }
    }
    else {
        if (x[0] <= 5.950000047683716) {
            votes[2] += 1;
        }
        else {
            votes[2] += 1;
        }
    }

    // return argmax of votes
    uint8_t classIdx = 0;
    float maxVotes = votes[0];

    for (uint8_t i = 1; i < 3; i++) {
        if (votes[i] > maxVotes) {
            classIdx = i;
            maxVotes = votes[i];
        }
    }

    return classIdx;
}