commit abd8f9bdb27dff0e4ed253f17d2dc0eb1bb5e2cb
parent 037c020a5b0397a611b4024036ce68c48f50fbc8
Author: Andrew <andrewlaack1@gmail.com>
Date: Sun, 22 Dec 2024 10:11:48 -0600
Added prediction
Diffstat:
5 files changed, 132 insertions(+), 26 deletions(-)
diff --git a/rewrite/DecisionTreeClassifier.cpp b/rewrite/DecisionTreeClassifier.cpp
@@ -1,6 +1,7 @@
#include "DecisionTreeClassifier.h"
#include <limits>
#include <iostream>
+#include <unordered_map>
using namespace std;
@@ -8,8 +9,21 @@ DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){
depth = maxDepth;
}
-void DecisionTreeClassifier::fit(float* X, int rows, int* y, int columns){
- splittingTree = recurse(X, rows, y, columns, depth);
+void DecisionTreeClassifier::fit(float* X, int samples, int* y, int features){
+ if (splittingTree != nullptr){
+ throw logic_error("Decision trees don't support incremental learning, fit can only be called once.");
+ }
+
+ if(features <= 0){
+ throw invalid_argument("Invalid argument, there must be 1 or more features to train on.");
+ }
+
+ if(samples <= 0){
+ throw invalid_argument("Invalid argument, there must be 1 or more samples to train on.");
+ }
+
+ splittingTree = recurse(X, samples, y, features, depth);
+ featureCount = features;
}
@@ -22,19 +36,40 @@ std::string DecisionTreeClassifier::getDot(){
return dot;
}
+int DecisionTreeClassifier::primaryClass(int* y, int labelCount){
+
+ unordered_map map = unordered_map<int,int>();
+
+ for(int i = 0; i < labelCount; ++i){
+ map[y[i]] += 1;
+ }
+
+ int mostElements = 0;
+ int label = 0;
+
+ for (auto& item : map){
+ if(item.second > mostElements){
+ mostElements = item.second;
+ label = item.first;
+ }
+ }
+
+ return label;
+}
+
// add depth
TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){
if(depthRem == 0){
- TreeNode* ret = new TreeNode();
+ TreeNode* ret = new TreeNode(primaryClass(y, rows));
return ret;
}
// found minimum node
if(rows == 1){
- TreeNode* ret = new TreeNode();
+ TreeNode* ret = new TreeNode(primaryClass(y, rows));
return ret;
}
@@ -44,7 +79,7 @@ TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int column
// no valid splits, but we still did create some new arrays.
if(split.rightSize == rows || split.leftSize == rows){
- TreeNode* ret = new TreeNode();
+ TreeNode* ret = new TreeNode(primaryClass(y, rows));
delete split.XLeft;
delete split.XRight;
delete split.yLeft;
@@ -107,4 +142,34 @@ TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int colu
}
+int* DecisionTreeClassifier::predict(float* X, int samples, int features){
+ if(featureCount == -1){
+ throw logic_error("Unable to predict prior to calling fit().");
+ }
+
+ if(features != this->featureCount){
+ throw invalid_argument("Incorrect number of features for prediction.");
+ }
+ cout << "PREDICTING" << endl;
+
+ int* predictions = new int[samples];
+
+ for(int i = 0; i < samples; ++i){
+ TreeNode* current = splittingTree;
+ while(!current->isLeaf()){
+ float* currentElement = X;
+ currentElement += features * i;
+ bool lessThan = current->lessThan(currentElement, features);
+ if(lessThan){
+ current = current->getLeftChild();
+ }
+ else{
+ current = current->getRightChild();
+ }
+ }
+ predictions[i] = current->getClassification();
+ }
+
+ return predictions;
+}
diff --git a/rewrite/DecisionTreeClassifier.h b/rewrite/DecisionTreeClassifier.h
@@ -3,12 +3,14 @@
class DecisionTreeClassifier{
public:
DecisionTreeClassifier(int depth);
- void fit(float* X, int rows, int* y, int columns);
- int* predict();
+ void fit(float* X, int samples, int* y, int features);
+ int* predict(float* X, int samples, int features);
std::string getDot();
private:
int depth;
+ int featureCount = -1;
TreeNode* splittingTree = nullptr;
- TreeNode* bestSplit(float* X, int rows, int* y, int columns);
- TreeNode* recurse(float* X, int rows, int* y, int columns, int depth);
+ TreeNode* bestSplit(float* X, int samples, int* y, int features);
+ TreeNode* recurse(float* X, int samples, int* y, int features, int depth);
+ int primaryClass(int* y, int labelCount);
};
diff --git a/rewrite/Tests.cpp b/rewrite/Tests.cpp
@@ -34,31 +34,50 @@ void testTreeNode(){
int main(){
- testTreeNode();
+ cout << "STARTING" << endl;
+
+ int SAMPLES = 20000;
DecisionTreeClassifier clf(1);
- // Large array of labels for 1000 samples
- int labels[30000];
- for (int i = 0; i < 30000; ++i) {
+ int labels[SAMPLES];
+ for (int i = 0; i < SAMPLES; ++i){
labels[i] = i % 10; // Example: create labels that cycle through 0 to 9
}
- // Large array of samples, 1000 samples with 4 features each
- float samples[30000][4];
- for (int i = 0; i < 30000; ++i) {
+ float samples[SAMPLES][4];
+ for (int i = 0; i < SAMPLES; ++i){
+ // Fill the samples with some arbitrary data (example: sequential)
+ samples[i][0] = (i % 157) + 1; // Feature 1 (e.g., 1 to 10)
+ samples[i][1] = (i % 250) + 1; // Feature 2 (e.g., 1 to 5)
+ samples[i][2] = (i % 492) + 1; // Feature 3 (e.g., 1 to 7)
+ samples[i][3] = (i % 481) + 1; // Feature 4 (e.g., 1 to 3)
+ }
+
+ clf.fit(*samples, SAMPLES, labels, 4);
+ int PREDS = 10;
+
+ float preds[PREDS][4];
+ for (int i = 0; i < PREDS; ++i){
// Fill the samples with some arbitrary data (example: sequential)
- samples[i][0] = (i % 10) + 1; // Feature 1 (e.g., 1 to 10)
- samples[i][1] = (i % 5) + 1; // Feature 2 (e.g., 1 to 5)
- samples[i][2] = (i % 7) + 1; // Feature 3 (e.g., 1 to 7)
- samples[i][3] = (i % 3) + 1; // Feature 4 (e.g., 1 to 3)
+ preds[i][0] = (i % 157) + 1; // Feature 1 (e.g., 1 to 10)
+ preds[i][1] = (i % 250) + 1; // Feature 2 (e.g., 1 to 5)
+ preds[i][2] = (i % 492) + 1; // Feature 3 (e.g., 1 to 7)
+ preds[i][3] = (i % 481) + 1; // Feature 4 (e.g., 1 to 3)
}
- cout << "FITTING" << endl;
- // Fit the classifier to the data
- clf.fit(*samples, 30000, labels, 4);
+ int* predsOut = clf.predict(*preds, PREDS, 4);
+
+ cout << "DONE" << endl;
+
+ for(int i = 0 ; i < PREDS; ++i){
+ cout << preds[i][0] << " " << preds[i][1] << " " <<preds[i][2] << " " << preds[i][3] << " " << predsOut[i] << endl;
+ }
+
+ delete[] predsOut;
return 0;
+
ofstream outFile("decision_tree.dot");
// Check if the file is open
diff --git a/rewrite/TreeNode.cpp b/rewrite/TreeNode.cpp
@@ -6,8 +6,9 @@
#include <sstream> //for std::stringstream
#include <string> //for std::string
-TreeNode::TreeNode(){
+TreeNode::TreeNode(int classification){
leaf = true;
+ this->classification = classification;
}
TreeNode::TreeNode(float splittingVal, int featureIndex){
@@ -220,8 +221,24 @@ std::string TreeNode::getDotLabel(){
ss << address;
std::string name = ss.str();
if (isLeaf()){
- return "\"" + name + " - LEAF" + "\"";
+ return "\"" + name + "\nCLASSIFICATION: " + std::to_string(classification) + "\"";
}
return "\"" + name + "\nINDEX: " + std::to_string(index) + "\nVALUE:" + std::to_string(splitValue) + "\"";
}
+
+int TreeNode::getClassification(){
+ if(isLeaf()){
+ return classification;
+ }
+ throw std::logic_error("Unable to call getClassification() on internal vertices.");
+}
+
+bool TreeNode::lessThan(float* sample, int features){
+
+ if(features < this->index){
+ throw std::invalid_argument("Attempting to evaluate split with input that contains less features.");
+ }
+
+ return(sample[index] < splitValue);
+}
diff --git a/rewrite/TreeNode.h b/rewrite/TreeNode.h
@@ -11,7 +11,7 @@ struct SplitResults{
class TreeNode{
public:
- TreeNode();
+ TreeNode(int classification);
TreeNode(float splittingVal, int featureIndex);
bool isLeaf();
void setSplit(float splittingValue, int featureIndex);
@@ -24,6 +24,8 @@ class TreeNode{
int getIndexSplit();
SplitResults splitOnNode(float* X, int* y, int samples, int features);
std::string getDotEdges();
+ int getClassification();
+ bool lessThan(float* sample, int features);
private:
bool leaf;
@@ -33,6 +35,7 @@ class TreeNode{
TreeNode* rightChild;
float giniImpurity(float* X, int* y, int samples, int features);
std::string getDotLabel();
+ int classification;
};