commit 0b82e16a80a312beb9bfb6f1309f51f363c4dd58
parent 333e4207d7e436660537738c05c0284893d5717e
Author: Andrew <andrewlaack1@gmail.com>
Date: Fri, 20 Dec 2024 10:16:55 -0600
Holy crap... this is insane
Diffstat:
5 files changed, 191 insertions(+), 8 deletions(-)
diff --git a/rewrite/DecisionTreeClassifier.cpp b/rewrite/DecisionTreeClassifier.cpp
@@ -1,5 +1,6 @@
#include "DecisionTreeClassifier.h"
-#include "TreeNode.h"
+#include <limits>
+#include <iostream>
using namespace std;
@@ -8,4 +9,52 @@ DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){
}
void DecisionTreeClassifier::fit(float* X, int rows, int* y, int columns){
+ // IMPORTANT: MUST DEALLOCATE CHOSEN AFTER USE...
+ cout << "COMPUTING BEST" << endl;
+ TreeNode* chosen = bestSplit(X, rows, y, columns);
+ cout << "SPLIT VAL: " << chosen->getSplitVal() << endl;
+ cout << "INDEX: "<< chosen->getIndexSplit() << endl;
+
+ SplitResults res = chosen->splitOnNode(X,y,rows, columns);
+
+ // create recursive helper method.
+
+
+}
+
+// 1 1 0
+// 3 3 0
+// 2 1 1
+// 4 1 3
+
+// consider adding interpolation to this and sorting the list first.
+// Also, no reason to consider the 0th split if that is the case.
+
+TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int columns){
+
+ TreeNode* bestNode = nullptr;
+ float bestGini = std::numeric_limits<float>::max();
+
+ for(int col = 0 ; col < columns; ++col){
+ for(int row = 0; row < rows; ++row){
+
+ float val = X[row*columns + col];
+ TreeNode* current = new TreeNode(val, col);
+ float gini = current->evalSplit(X, y, rows, columns, "gini");
+ if (gini < bestGini){
+
+ TreeNode* prevBest = bestNode;
+ delete prevBest;
+
+ bestNode = current;
+ bestGini = gini;
+ }
+ else{
+ delete current;
+ }
+ }
+ }
+
+ return bestNode;
+
}
diff --git a/rewrite/DecisionTreeClassifier.h b/rewrite/DecisionTreeClassifier.h
@@ -1,3 +1,5 @@
+#include "TreeNode.h"
+
class DecisionTreeClassifier{
public:
DecisionTreeClassifier(int depth);
@@ -5,4 +7,6 @@ class DecisionTreeClassifier{
int* predict();
private:
int depth;
+ TreeNode* bestSplit(float* X, int rows, int* y, int columns);
+
};
diff --git a/rewrite/Tests.cpp b/rewrite/Tests.cpp
@@ -1,12 +1,9 @@
#include "DecisionTreeClassifier.h"
-#include "TreeNode.h"
#include "iostream"
#include "assert.h"
using namespace std;
-
-
void testTreeNode(){
int labels[] = {10, 10, 10, 1, 2, 3};
float samples[][4] = {
@@ -22,7 +19,6 @@ void testTreeNode(){
bool isLeaf = tn.isLeaf();
assert(!isLeaf);
- cout << "Is Leaf Passed" << "\n";
float giniVal = tn.evalSplit(*samples, labels, 6, 4, "gini");
assert(abs(giniVal - .5833333) < .0001 );
@@ -31,11 +27,13 @@ void testTreeNode(){
float giniVal2 = tn.evalSplit(*samples, labels, 6, 4, "gini");
assert(abs(giniVal2 - .6666666) < .0001);
- cout << "Gini Calculation Passed" << "\n";
}
+
int main(){
+
testTreeNode();
+
DecisionTreeClassifier clf = DecisionTreeClassifier(10);
int labels[] = {10, 10, 10, 1, 2, 3};
@@ -45,8 +43,12 @@ int main(){
{1,7,5,3},
{1,3,5,3},
{1,7,5,3},
- {1,1,5,3}
+ {1,1.1,5,3}
};
+ //index 1, split val 3
+ //ltCount = 3
+ //gteqCount = 3
+
clf.fit(*samples, 6, labels, 4);
}
diff --git a/rewrite/TreeNode.cpp b/rewrite/TreeNode.cpp
@@ -38,10 +38,118 @@ float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::stri
return giniImpurity(X, y, samples, features);
}
-float TreeNode::giniImpurity(float* X, int* y, int samples, int features){
+void TreeNode::setLeftChild(TreeNode* child){
+ leftChild = child;
+}
+
+void TreeNode::setRightChild(TreeNode* child){
+ rightChild = child;
+}
+
+TreeNode* TreeNode::getLeftChild(){
+ return leftChild;
+}
+
+TreeNode* TreeNode::getRightChild(){
+ return rightChild;
+}
+
+float TreeNode::getSplitVal(){
+ return splitValue;
+}
+
+int TreeNode::getIndexSplit(){
+ return index;
+}
+
+SplitResults TreeNode::splitOnNode(float* X, int* y, int samples, int features){
+
+ SplitResults result = SplitResults();
+
+ int ltCount = 0;
+ int gteqCount = 0;
+
+ for(int i = 0 ; i < samples; ++i){
+ if(X[(i*features) + index] < splitValue){
+ ltCount += 1;
+ }
+ else{
+ gteqCount += 1;
+ }
+ }
+
+ // Create X arrays to return
+
+ float* ltArr = new float[ltCount * features];
+ float* gteqArr = new float[gteqCount * features];
+ // Create array ptr next open
+ float* nextLtX = ltArr;
+ float* nextGteqX = gteqArr;
+
+ // Create y arrays to return
+
+ int* ltYArr = new int[ltCount];
+ int* gteqYArr = new int[gteqCount];
+
+ // Create array ptr next open
+
+ int* nextLtY = ltYArr;
+ int* nextGteqY = gteqYArr;
+
+ // Set pointers for return to the new arrays
+
+ result.XLeft = ltArr;
+ result.yLeft = ltYArr;
+
+ result.XRight = gteqArr;
+ result.yRight = gteqYArr;
+
+ result.leftSize = ltCount;
+ result.rightSize = gteqCount;
+
+ // Set arrays with correct values
+
+ for(int i = 0 ; i < samples; ++i){
+ if(X[(i*features) + index] < splitValue){
+ for(int x = 0; x < features; ++x){
+ nextLtX[0] = X[(i*features) + x];
+ nextLtX += 1;
+ }
+
+ nextLtY[0] = y[i];
+ nextLtY += 1;
+ }
+ else{
+ for(int x = 0; x < features; ++x){
+ nextGteqX[0] = X[(i*features) + x];
+ nextGteqX += 1;
+ }
+
+ nextGteqY[0] = y[i];
+ nextGteqY += 1;
+ }
+ }
+
+ //for(int x = 0 ; x < ltCount; ++x){
+ // for(int i = 0 ; i < features; ++i){
+ // std::cout << ltArr[x*features + i];
+ // }
+ // std::cout << std::endl;
+ //}
+
+ //for(int x = 0 ; x < ltCount; ++x){
+ // std::cout << ltYArr[x] << std::endl;
+ //}
+
+ return result;
+}
+
+
+
+float TreeNode::giniImpurity(float* X, int* y, int samples, int features){
std::unordered_map<int, int> ltMap;
std::unordered_map<int, int> gtMap;
diff --git a/rewrite/TreeNode.h b/rewrite/TreeNode.h
@@ -1,5 +1,14 @@
#include "string"
+struct SplitResults{
+ float* XLeft;
+ float* XRight;
+ int* yLeft;
+ int* yRight;
+ int leftSize;
+ int rightSize;
+};
+
class TreeNode{
public:
TreeNode();
@@ -7,10 +16,21 @@ class TreeNode{
bool isLeaf();
void setSplit(float splittingValue, int featureIndex);
float evalSplit(float* X, int* y, int samples, int features, std::string criterion);
+ TreeNode* getLeftChild();
+ TreeNode* getRightChild();
+ void setLeftChild(TreeNode* child);
+ void setRightChild(TreeNode* child);
+ float getSplitVal();
+ int getIndexSplit();
+ SplitResults splitOnNode(float* X, int* y, int samples, int features);
private:
bool leaf;
float splitValue;
int index;
+ TreeNode* leftChild;
+ TreeNode* rightChild;
float giniImpurity(float* X, int* y, int samples, int features);
};
+
+