decision-tree-classifier

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit 0b82e16a80a312beb9bfb6f1309f51f363c4dd58
parent 333e4207d7e436660537738c05c0284893d5717e
Author: Andrew <andrewlaack1@gmail.com>
Date:   Fri, 20 Dec 2024 10:16:55 -0600

Holy crap... this is insane

Diffstat:
M rewrite/DecisionTreeClassifier.cpp | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++-
M rewrite/DecisionTreeClassifier.h | 4 ++++
M rewrite/Tests.cpp | 14 ++++++++------
M rewrite/TreeNode.cpp | 110 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
M rewrite/TreeNode.h | 20 ++++++++++++++++++++
5 files changed, 191 insertions(+), 8 deletions(-)

diff --git a/rewrite/DecisionTreeClassifier.cpp b/rewrite/DecisionTreeClassifier.cpp @@ -1,5 +1,6 @@ #include "DecisionTreeClassifier.h" -#include "TreeNode.h" +#include <limits> +#include <iostream> using namespace std; @@ -8,4 +9,52 @@ DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){ } void DecisionTreeClassifier::fit(float* X, int rows, int* y, int columns){ + // IMPORTANT: MUST DEALLOCATE CHOSEN AFTER USE... + cout << "COMPUTING BEST" << endl; + TreeNode* chosen = bestSplit(X, rows, y, columns); + cout << "SPLIT VAL: " << chosen->getSplitVal() << endl; + cout << "INDEX: "<< chosen->getIndexSplit() << endl; + + SplitResults res = chosen->splitOnNode(X,y,rows, columns); + + // create recursive helper method. + + +} + +// 1 1 0 +// 3 3 0 +// 2 1 1 +// 4 1 3 + +// consider adding interpolation to this and sorting the list first. +// Also, no reason to consider the 0th split if that is the case. + +TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int columns){ + + TreeNode* bestNode = nullptr; + float bestGini = std::numeric_limits<float>::max(); + + for(int col = 0 ; col < columns; ++col){ + for(int row = 0; row < rows; ++row){ + + float val = X[row*columns + col]; + TreeNode* current = new TreeNode(val, col); + float gini = current->evalSplit(X, y, rows, columns, "gini"); + if (gini < bestGini){ + + TreeNode* prevBest = bestNode; + delete prevBest; + + bestNode = current; + bestGini = gini; + } + else{ + delete current; + } + } + } + + return bestNode; + } diff --git a/rewrite/DecisionTreeClassifier.h b/rewrite/DecisionTreeClassifier.h @@ -1,3 +1,5 @@ +#include "TreeNode.h" + class DecisionTreeClassifier{ public: DecisionTreeClassifier(int depth); @@ -5,4 +7,6 @@ class DecisionTreeClassifier{ int* predict(); private: int depth; + TreeNode* bestSplit(float* X, int rows, int* y, int columns); + }; diff --git a/rewrite/Tests.cpp b/rewrite/Tests.cpp @@ -1,12 +1,9 @@ #include "DecisionTreeClassifier.h" -#include "TreeNode.h" #include 
"iostream" #include "assert.h" using namespace std; - - void testTreeNode(){ int labels[] = {10, 10, 10, 1, 2, 3}; float samples[][4] = { @@ -22,7 +19,6 @@ void testTreeNode(){ bool isLeaf = tn.isLeaf(); assert(!isLeaf); - cout << "Is Leaf Passed" << "\n"; float giniVal = tn.evalSplit(*samples, labels, 6, 4, "gini"); assert(abs(giniVal - .5833333) < .0001 ); @@ -31,11 +27,13 @@ void testTreeNode(){ float giniVal2 = tn.evalSplit(*samples, labels, 6, 4, "gini"); assert(abs(giniVal2 - .6666666) < .0001); - cout << "Gini Calculation Passed" << "\n"; } + int main(){ + testTreeNode(); + DecisionTreeClassifier clf = DecisionTreeClassifier(10); int labels[] = {10, 10, 10, 1, 2, 3}; @@ -45,8 +43,12 @@ int main(){ {1,7,5,3}, {1,3,5,3}, {1,7,5,3}, - {1,1,5,3} + {1,1.1,5,3} }; + //index 1, split val 3 + //ltCount = 3 + //gteqCount = 3 + clf.fit(*samples, 6, labels, 4); } diff --git a/rewrite/TreeNode.cpp b/rewrite/TreeNode.cpp @@ -38,10 +38,118 @@ float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::stri return giniImpurity(X, y, samples, features); } -float TreeNode::giniImpurity(float* X, int* y, int samples, int features){ +void TreeNode::setLeftChild(TreeNode* child){ + leftChild = child; +} + +void TreeNode::setRightChild(TreeNode* child){ + rightChild = child; +} + +TreeNode* TreeNode::getLeftChild(){ + return leftChild; +} + +TreeNode* TreeNode::getRightChild(){ + return rightChild; +} + +float TreeNode::getSplitVal(){ + return splitValue; +} + +int TreeNode::getIndexSplit(){ + return index; +} + +SplitResults TreeNode::splitOnNode(float* X, int* y, int samples, int features){ + + SplitResults result = SplitResults(); + + int ltCount = 0; + int gteqCount = 0; + + for(int i = 0 ; i < samples; ++i){ + if(X[(i*features) + index] < splitValue){ + ltCount += 1; + } + else{ + gteqCount += 1; + } + } + + // Create X arrays to return + + float* ltArr = new float[ltCount * features]; + float* gteqArr = new float[gteqCount * features]; + // Create array ptr 
next open + float* nextLtX = ltArr; + float* nextGteqX = gteqArr; + + // Create y arrays to return + + int* ltYArr = new int[ltCount]; + int* gteqYArr = new int[gteqCount]; + + // Create array ptr next open + + int* nextLtY = ltYArr; + int* nextGteqY = gteqYArr; + + // Set pointers for return to the new arrays + + result.XLeft = ltArr; + result.yLeft = ltYArr; + + result.XRight = gteqArr; + result.yRight = gteqYArr; + + result.leftSize = ltCount; + result.rightSize = gteqCount; + + // Set arrays with correct values + + for(int i = 0 ; i < samples; ++i){ + if(X[(i*features) + index] < splitValue){ + for(int x = 0; x < features; ++x){ + nextLtX[0] = X[(i*features) + x]; + nextLtX += 1; + } + + nextLtY[0] = y[i]; + nextLtY += 1; + } + else{ + for(int x = 0; x < features; ++x){ + nextGteqX[0] = X[(i*features) + x]; + nextGteqX += 1; + } + + nextGteqY[0] = y[i]; + nextGteqY += 1; + } + } + + //for(int x = 0 ; x < ltCount; ++x){ + // for(int i = 0 ; i < features; ++i){ + // std::cout << ltArr[x*features + i]; + // } + // std::cout << std::endl; + //} + + //for(int x = 0 ; x < ltCount; ++x){ + // std::cout << ltYArr[x] << std::endl; + //} + + return result; +} + + + +float TreeNode::giniImpurity(float* X, int* y, int samples, int features){ std::unordered_map<int, int> ltMap; std::unordered_map<int, int> gtMap; diff --git a/rewrite/TreeNode.h b/rewrite/TreeNode.h @@ -1,5 +1,14 @@ #include "string" +struct SplitResults{ + float* XLeft; + float* XRight; + int* yLeft; + int* yRight; + int leftSize; + int rightSize; +}; + class TreeNode{ public: TreeNode(); @@ -7,10 +16,21 @@ class TreeNode{ bool isLeaf(); void setSplit(float splittingValue, int featureIndex); float evalSplit(float* X, int* y, int samples, int features, std::string criterion); + TreeNode* getLeftChild(); + TreeNode* getRightChild(); + void setLeftChild(TreeNode* child); + void setRightChild(TreeNode* child); + float getSplitVal(); + int getIndexSplit(); + SplitResults splitOnNode(float* X, int* y, int 
samples, int features); private: bool leaf; float splitValue; int index; + TreeNode* leftChild; + TreeNode* rightChild; float giniImpurity(float* X, int* y, int samples, int features); }; + +