decision-tree-classifier

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit 333e4207d7e436660537738c05c0284893d5717e
parent fd65d662d1afca9132e08c51d2958600a37200ec
Author: Andrew <andrewlaack1@gmail.com>
Date:   Fri, 20 Dec 2024 04:02:32 -0600

Started over in c++

Diffstat:
A rewrite/DecisionTreeClassifier.cpp | 11 +++++++++++
A rewrite/DecisionTreeClassifier.h | 8 ++++++++
A rewrite/Makefile | 12 ++++++++++++
A rewrite/Tests.cpp | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
A rewrite/TreeNode.cpp | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A rewrite/TreeNode.h | 16 ++++++++++++++++
6 files changed, 186 insertions(+), 0 deletions(-)

diff --git a/rewrite/DecisionTreeClassifier.cpp b/rewrite/DecisionTreeClassifier.cpp
@@ -0,0 +1,11 @@
+#include "DecisionTreeClassifier.h"
+#include "TreeNode.h"
+
+using namespace std;
+// Constructs a classifier and stores maxDepth in the depth member.
+DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){
+    depth = maxDepth;
+}
+// Placeholder: the body is empty here, so none of the arguments are read yet.
+void DecisionTreeClassifier::fit(float* X, int rows, int* y, int columns){
+}
diff --git a/rewrite/DecisionTreeClassifier.h b/rewrite/DecisionTreeClassifier.h
@@ -0,0 +1,8 @@
+class DecisionTreeClassifier{ // NOTE(review): this header has no include guard or #pragma once.
+  public:
+    DecisionTreeClassifier(int depth); // stores the given value in the private depth member
+    void fit(float* X, int rows, int* y, int columns); // defined with an empty body in DecisionTreeClassifier.cpp
+    int* predict(); // declared only; no definition is visible in this diff
+  private:
+    int depth; // value handed to the constructor; presumably the maximum tree depth (TODO confirm)
+};
diff --git a/rewrite/Makefile b/rewrite/Makefile
@@ -0,0 +1,12 @@
+clean:
+	rm *.o
+	rm *.out
+
+node:
+	g++ -c TreeNode.cpp
+
+tests: DecisionTreeClassifier.o TreeNode.o
+	g++ Tests.cpp DecisionTreeClassifier.o TreeNode.o
+
+decisionTree:
+	g++ -c DecisionTreeClassifier.cpp
diff --git a/rewrite/Tests.cpp b/rewrite/Tests.cpp
@@ -0,0 +1,52 @@
+#include "DecisionTreeClassifier.h"
+#include "TreeNode.h"
+#include "iostream"
+#include "assert.h"
+
+using namespace std;
+
+
+// Checks TreeNode::isLeaf and the value returned by evalSplit on a 6x4 sample matrix with 6 labels.
+void testTreeNode(){
+    int labels[] = {10, 10, 10, 1, 2, 3};
+    float samples[][4] = {
+        {1,1,5,3},
+        {1,2,5,3},
+        {1,7,5,3},
+        {1,3,5,3},
+        {1,7,5,3},
+        {1,1,5,3}
+    };
+    // A node built from a split value and a feature index must not report itself as a leaf.
+    TreeNode tn = TreeNode(5.0f ,1);
+    bool isLeaf = tn.isLeaf();
+
+    assert(!isLeaf);
+    cout << "Is Leaf Passed" << "\n";
+    // Split on feature 1 at 5.0: four samples fall below the split and two do not; expected weighted gini is about 0.5833.
+    float giniVal = tn.evalSplit(*samples, labels, 6, 4, "gini"); // *samples decays to a pointer to the first float, so the rows are passed as one flat array
+    assert(abs(giniVal - .5833333) < .0001 ); // NOTE(review): relies on a floating-point abs overload being in scope; no math header is included in this file - confirm
+    // Split on feature 0 at 0.0: every sample holds 1 in that column, so none fall below the split; expected gini is about 0.6667.
+    tn.setSplit(0.0f, 0);
+    float giniVal2 = tn.evalSplit(*samples, labels, 6, 4, "gini");
+    assert(abs(giniVal2 - .6666666) < .0001);
+
+    cout << "Gini Calculation Passed" << "\n";
+}
+// Runs the TreeNode checks, then calls fit on an identical data set; fit has an empty body in DecisionTreeClassifier.cpp, so nothing is verified about its result.
+int main(){
+    testTreeNode();
+    DecisionTreeClassifier clf = DecisionTreeClassifier(10);
+    int labels[] = {10, 10, 10, 1, 2, 3};
+
+    float samples[][4] = {
+        {1,1,5,3},
+        {1,2,5,3},
+        {1,7,5,3},
+        {1,3,5,3},
+        {1,7,5,3},
+        {1,1,5,3}
+    };
+
+    clf.fit(*samples, 6, labels, 4);
+}
diff --git 
a/rewrite/TreeNode.cpp b/rewrite/TreeNode.cpp
@@ -0,0 +1,87 @@
+#include "TreeNode.h"
+#include "stdexcept"
+#include "unordered_map"
+#include "math.h"
+#include "iostream"
+
+// Default constructor: marks the node as a leaf. NOTE(review): splitValue and index are not assigned here.
+TreeNode::TreeNode(){
+    leaf = true;
+}
+// Builds a non-leaf node that splits on column featureIndex at splittingVal.
+TreeNode::TreeNode(float splittingVal, int featureIndex){
+    splitValue = splittingVal;
+    index = featureIndex;
+    leaf = false;
+}
+// Replaces the split value and feature index and marks the node as a non-leaf.
+void TreeNode::setSplit(float splittingVal, int featureIndex){
+    splitValue = splittingVal;
+    index = featureIndex;
+    leaf = false;
+}
+// Returns the leaf flag.
+bool TreeNode::isLeaf(){
+    return leaf;
+}
+// Returns the impurity of this node's split over the given samples. Throws std::logic_error on a leaf node and std::invalid_argument when criterion is anything other than gini.
+float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::string criterion){
+
+    if(isLeaf()){
+        throw std::logic_error("Cannot evaluate split on leaf node.");
+    }
+
+    if(criterion != "gini"){
+        throw std::invalid_argument("Gini impurity is the only supported criterion.");
+    }
+
+    return giniImpurity(X, y, samples, features);
+}
+// Weighted gini impurity of partitioning the samples by (feature value < splitValue).
+float TreeNode::giniImpurity(float* X, int* y, int samples, int features){
+    // X is read as a flat array with features values per sample (element index + i * features); y[i] is the key counted for sample i.
+    // NOTE(review): index is not range-checked against features, and samples == 0 makes the final divisions 0/0.
+
+
+    std::unordered_map<int, int> ltMap;
+    std::unordered_map<int, int> gtMap; // holds the side where the value is >= splitValue
+
+    int ltCount = 0;
+    int gteqCount = 0;
+    // Tally the y values on each side of the split and the size of each side.
+    for(int i = 0; i < samples; ++i){
+        if(X[index + (i * features)] < splitValue){
+            ltMap[y[i]]++;
+            ltCount++;
+        }
+        else{
+            gtMap[y[i]]++;
+            gteqCount++;
+        }
+    }
+
+    // Gini of one side: 1 minus the sum over distinct y values of (count / side size) squared.
+    float ltGini= 1.0f;
+
+    for (const auto& pair : ltMap) {
+        ltGini -= pow(float(pair.second) / ltCount, 2);
+    }
+
+    float gteqGini = 1.0f;
+
+    for (const auto& pair : gtMap) {
+        gteqGini -= pow(float(pair.second) / gteqCount, 2);
+    }
+    // An empty side never entered its loop above and still holds 1.0; reset it to 0.
+    if(gteqCount == 0){
+        gteqGini = 0.0f;
+    }
+    if(ltCount == 0){
+        ltGini = 0.0f;
+    }
+    // Weight each side by its share of the samples and add the two parts.
+    float gini = gteqGini * float(gteqCount) / samples;
+    gini += ltGini * float(ltCount) / samples;
+
+    return gini;
+}
diff --git a/rewrite/TreeNode.h b/rewrite/TreeNode.h
@@ -0,0 +1,16 @@
+#include "string"
+// A tree node that is either a leaf or a split on one feature column at a threshold value.
+class TreeNode{ // NOTE(review): this header has no include guard or #pragma once.
+  public:
+    TreeNode(); // creates a leaf node
+    TreeNode(float splittingVal, int featureIndex); // creates a non-leaf node with the given split
+    bool isLeaf();
+    void setSplit(float splittingValue, int featureIndex); // also clears the leaf flag
+    float evalSplit(float* X, int* y, int samples, int features, std::string criterion); // gini is the only criterion accepted
+
+  private:
+    bool leaf;
+    float splitValue; // threshold compared against the feature value
+    int index; // column of X used for the split
+    float giniImpurity(float* X, int* y, int samples, int features);
+};