decision-tree-classifier

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit abd8f9bdb27dff0e4ed253f17d2dc0eb1bb5e2cb
parent 037c020a5b0397a611b4024036ce68c48f50fbc8
Author: Andrew <andrewlaack1@gmail.com>
Date:   Sun, 22 Dec 2024 10:11:48 -0600

Added prediction

Diffstat:
M rewrite/DecisionTreeClassifier.cpp | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
M rewrite/DecisionTreeClassifier.h | 10 ++++++----
M rewrite/Tests.cpp | 47 +++++++++++++++++++++++++++++++++--------------
M rewrite/TreeNode.cpp | 21 +++++++++++++++++++--
M rewrite/TreeNode.h | 5 ++++-
5 files changed, 132 insertions(+), 26 deletions(-)

diff --git a/rewrite/DecisionTreeClassifier.cpp b/rewrite/DecisionTreeClassifier.cpp @@ -1,6 +1,7 @@ #include "DecisionTreeClassifier.h" #include <limits> #include <iostream> +#include <unordered_map> using namespace std; @@ -8,8 +9,21 @@ DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){ depth = maxDepth; } -void DecisionTreeClassifier::fit(float* X, int rows, int* y, int columns){ - splittingTree = recurse(X, rows, y, columns, depth); +void DecisionTreeClassifier::fit(float* X, int samples, int* y, int features){ + if (splittingTree != nullptr){ + throw logic_error("Decision trees don't support incremental learning, fit can only be called once."); + } + + if(features <= 0){ + throw invalid_argument("Invalid argument, there must be 1 or more features to train on."); + } + + if(samples <= 0){ + throw invalid_argument("Invalid argument, there must be 1 or more samples to train on."); + } + + splittingTree = recurse(X, samples, y, features, depth); + featureCount = features; } @@ -22,19 +36,40 @@ std::string DecisionTreeClassifier::getDot(){ return dot; } +int DecisionTreeClassifier::primaryClass(int* y, int labelCount){ + + unordered_map map = unordered_map<int,int>(); + + for(int i = 0; i < labelCount; ++i){ + map[y[i]] += 1; + } + + int mostElements = 0; + int label = 0; + + for (auto& item : map){ + if(item.second > mostElements){ + mostElements = item.second; + label = item.first; + } + } + + return label; +} + // add depth TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){ if(depthRem == 0){ - TreeNode* ret = new TreeNode(); + TreeNode* ret = new TreeNode(primaryClass(y, rows)); return ret; } // found minimum node if(rows == 1){ - TreeNode* ret = new TreeNode(); + TreeNode* ret = new TreeNode(primaryClass(y, rows)); return ret; } @@ -44,7 +79,7 @@ TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int column // no valid splits, but we still did create some new arrays. 
if(split.rightSize == rows || split.leftSize == rows){ - TreeNode* ret = new TreeNode(); + TreeNode* ret = new TreeNode(primaryClass(y, rows)); delete split.XLeft; delete split.XRight; delete split.yLeft; @@ -107,4 +142,34 @@ TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int colu } +int* DecisionTreeClassifier::predict(float* X, int samples, int features){ + if(featureCount == -1){ + throw logic_error("Unable to predict prior to calling fit()."); + } + + if(features != this->featureCount){ + throw invalid_argument("Incorrect number of features for prediction."); + } + cout << "PREDICTING" << endl; + + int* predictions = new int[samples]; + + for(int i = 0; i < samples; ++i){ + TreeNode* current = splittingTree; + while(!current->isLeaf()){ + float* currentElement = X; + currentElement += features * i; + bool lessThan = current->lessThan(currentElement, features); + if(lessThan){ + current = current->getLeftChild(); + } + else{ + current = current->getRightChild(); + } + } + predictions[i] = current->getClassification(); + } + + return predictions; +} diff --git a/rewrite/DecisionTreeClassifier.h b/rewrite/DecisionTreeClassifier.h @@ -3,12 +3,14 @@ class DecisionTreeClassifier{ public: DecisionTreeClassifier(int depth); - void fit(float* X, int rows, int* y, int columns); - int* predict(); + void fit(float* X, int samples, int* y, int features); + int* predict(float* X, int samples, int features); std::string getDot(); private: int depth; + int featureCount = -1; TreeNode* splittingTree = nullptr; - TreeNode* bestSplit(float* X, int rows, int* y, int columns); - TreeNode* recurse(float* X, int rows, int* y, int columns, int depth); + TreeNode* bestSplit(float* X, int samples, int* y, int features); + TreeNode* recurse(float* X, int samples, int* y, int features, int depth); + int primaryClass(int* y, int labelCount); }; diff --git a/rewrite/Tests.cpp b/rewrite/Tests.cpp @@ -34,31 +34,50 @@ void testTreeNode(){ int main(){ - testTreeNode(); 
+ cout << "STARTING" << endl; + + int SAMPLES = 20000; DecisionTreeClassifier clf(1); - // Large array of labels for 1000 samples - int labels[30000]; - for (int i = 0; i < 30000; ++i) { + int labels[SAMPLES]; + for (int i = 0; i < SAMPLES; ++i){ labels[i] = i % 10; // Example: create labels that cycle through 0 to 9 } - // Large array of samples, 1000 samples with 4 features each - float samples[30000][4]; - for (int i = 0; i < 30000; ++i) { + float samples[SAMPLES][4]; + for (int i = 0; i < SAMPLES; ++i){ + // Fill the samples with some arbitrary data (example: sequential) + samples[i][0] = (i % 157) + 1; // Feature 1 (e.g., 1 to 10) + samples[i][1] = (i % 250) + 1; // Feature 2 (e.g., 1 to 5) + samples[i][2] = (i % 492) + 1; // Feature 3 (e.g., 1 to 7) + samples[i][3] = (i % 481) + 1; // Feature 4 (e.g., 1 to 3) + } + + clf.fit(*samples, SAMPLES, labels, 4); + int PREDS = 10; + + float preds[PREDS][4]; + for (int i = 0; i < PREDS; ++i){ // Fill the samples with some arbitrary data (example: sequential) - samples[i][0] = (i % 10) + 1; // Feature 1 (e.g., 1 to 10) - samples[i][1] = (i % 5) + 1; // Feature 2 (e.g., 1 to 5) - samples[i][2] = (i % 7) + 1; // Feature 3 (e.g., 1 to 7) - samples[i][3] = (i % 3) + 1; // Feature 4 (e.g., 1 to 3) + preds[i][0] = (i % 157) + 1; // Feature 1 (e.g., 1 to 10) + preds[i][1] = (i % 250) + 1; // Feature 2 (e.g., 1 to 5) + preds[i][2] = (i % 492) + 1; // Feature 3 (e.g., 1 to 7) + preds[i][3] = (i % 481) + 1; // Feature 4 (e.g., 1 to 3) } - cout << "FITTING" << endl; - // Fit the classifier to the data - clf.fit(*samples, 30000, labels, 4); + int* predsOut = clf.predict(*preds, PREDS, 4); + + cout << "DONE" << endl; + + for(int i = 0 ; i < PREDS; ++i){ + cout << preds[i][0] << " " << preds[i][1] << " " <<preds[i][2] << " " << preds[i][3] << " " << predsOut[i] << endl; + } + + delete[] predsOut; return 0; + ofstream outFile("decision_tree.dot"); // Check if the file is open diff --git a/rewrite/TreeNode.cpp b/rewrite/TreeNode.cpp 
@@ -6,8 +6,9 @@ #include <sstream> //for std::stringstream #include <string> //for std::string -TreeNode::TreeNode(){ +TreeNode::TreeNode(int classification){ leaf = true; + this->classification = classification; } TreeNode::TreeNode(float splittingVal, int featureIndex){ @@ -220,8 +221,24 @@ std::string TreeNode::getDotLabel(){ ss << address; std::string name = ss.str(); if (isLeaf()){ - return "\"" + name + " - LEAF" + "\""; + return "\"" + name + "\nCLASSIFICATION: " + std::to_string(classification) + "\""; } return "\"" + name + "\nINDEX: " + std::to_string(index) + "\nVALUE:" + std::to_string(splitValue) + "\""; } + +int TreeNode::getClassification(){ + if(isLeaf()){ + return classification; + } + throw std::logic_error("Unable to call getClassification() on internal vertices."); +} + +bool TreeNode::lessThan(float* sample, int features){ + + if(features < this->index){ + throw std::invalid_argument("Attempting to evaluate split with input that contains less features."); + } + + return(sample[index] < splitValue); +} diff --git a/rewrite/TreeNode.h b/rewrite/TreeNode.h @@ -11,7 +11,7 @@ struct SplitResults{ class TreeNode{ public: - TreeNode(); + TreeNode(int classification); TreeNode(float splittingVal, int featureIndex); bool isLeaf(); void setSplit(float splittingValue, int featureIndex); @@ -24,6 +24,8 @@ class TreeNode{ int getIndexSplit(); SplitResults splitOnNode(float* X, int* y, int samples, int features); std::string getDotEdges(); + int getClassification(); + bool lessThan(float* sample, int features); private: bool leaf; @@ -33,6 +35,7 @@ class TreeNode{ TreeNode* rightChild; float giniImpurity(float* X, int* y, int samples, int features); std::string getDotLabel(); + int classification; };