commit 333e4207d7e436660537738c05c0284893d5717e
parent fd65d662d1afca9132e08c51d2958600a37200ec
Author: Andrew <andrewlaack1@gmail.com>
Date: Fri, 20 Dec 2024 04:02:32 -0600
Started over in c++
Diffstat:
6 files changed, 186 insertions(+), 0 deletions(-)
diff --git a/rewrite/DecisionTreeClassifier.cpp b/rewrite/DecisionTreeClassifier.cpp
@@ -0,0 +1,11 @@
+#include "DecisionTreeClassifier.h"
+#include "TreeNode.h"
+
+using namespace std;
+
+DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){
+ depth = maxDepth;
+}
+
+void DecisionTreeClassifier::fit(float* X, int rows, int* y, int columns){
+}
diff --git a/rewrite/DecisionTreeClassifier.h b/rewrite/DecisionTreeClassifier.h
@@ -0,0 +1,8 @@
+class DecisionTreeClassifier{
+ public:
+ DecisionTreeClassifier(int depth);
+ void fit(float* X, int rows, int* y, int columns);
+ int* predict();
+ private:
+ int depth;
+};
diff --git a/rewrite/Makefile b/rewrite/Makefile
@@ -0,0 +1,12 @@
+clean:
+ rm *.o
+ rm *.out
+
+node:
+ g++ -c TreeNode.cpp
+
+tests: DecisionTreeClassifier.o TreeNode.o
+ g++ Tests.cpp DecisionTreeClassifier.o TreeNode.o
+
+decisionTree:
+ g++ -c DecisionTreeClassifier.cpp
diff --git a/rewrite/Tests.cpp b/rewrite/Tests.cpp
@@ -0,0 +1,52 @@
+#include "DecisionTreeClassifier.h"
+#include "TreeNode.h"
+#include "iostream"
+#include "assert.h"
+
+using namespace std;
+
+
+
+void testTreeNode(){
+ int labels[] = {10, 10, 10, 1, 2, 3};
+ float samples[][4] = {
+ {1,1,5,3},
+ {1,2,5,3},
+ {1,7,5,3},
+ {1,3,5,3},
+ {1,7,5,3},
+ {1,1,5,3}
+ };
+
+ TreeNode tn = TreeNode(5.0f ,1);
+ bool isLeaf = tn.isLeaf();
+
+ assert(!isLeaf);
+ cout << "Is Leaf Passed" << "\n";
+
+ float giniVal = tn.evalSplit(*samples, labels, 6, 4, "gini");
+ assert(abs(giniVal - .5833333) < .0001 );
+
+ tn.setSplit(0.0f, 0);
+ float giniVal2 = tn.evalSplit(*samples, labels, 6, 4, "gini");
+ assert(abs(giniVal2 - .6666666) < .0001);
+
+ cout << "Gini Calculation Passed" << "\n";
+}
+
+int main(){
+ testTreeNode();
+ DecisionTreeClassifier clf = DecisionTreeClassifier(10);
+ int labels[] = {10, 10, 10, 1, 2, 3};
+
+ float samples[][4] = {
+ {1,1,5,3},
+ {1,2,5,3},
+ {1,7,5,3},
+ {1,3,5,3},
+ {1,7,5,3},
+ {1,1,5,3}
+ };
+
+ clf.fit(*samples, 6, labels, 4);
+}
diff --git a/rewrite/TreeNode.cpp b/rewrite/TreeNode.cpp
@@ -0,0 +1,87 @@
+#include "TreeNode.h"
+#include "stdexcept"
+#include "unordered_map"
+#include "math.h"
+#include "iostream"
+
+
+TreeNode::TreeNode(){
+ leaf = true;
+}
+
+TreeNode::TreeNode(float splittingVal, int featureIndex){
+ splitValue = splittingVal;
+ index = featureIndex;
+ leaf = false;
+}
+
+void TreeNode::setSplit(float splittingVal, int featureIndex){
+ splitValue = splittingVal;
+ index = featureIndex;
+ leaf = false;
+}
+
+bool TreeNode::isLeaf(){
+ return leaf;
+}
+
+float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::string criterion){
+
+ if(isLeaf()){
+ throw std::logic_error("Cannot evaluate split on leaf node.");
+ }
+
+ if(criterion != "gini"){
+ throw std::invalid_argument("Gini impurity is the only supported criterion.");
+ }
+
+ return giniImpurity(X, y, samples, features);
+}
+
+float TreeNode::giniImpurity(float* X, int* y, int samples, int features){
+
+
+
+
+ std::unordered_map<int, int> ltMap;
+ std::unordered_map<int, int> gtMap;
+
+ int ltCount = 0;
+ int gteqCount = 0;
+
+ for(int i = 0; i < samples; ++i){
+ if(X[index + (i * features)] < splitValue){
+ ltMap[y[i]]++;
+ ltCount++;
+ }
+ else{
+ gtMap[y[i]]++;
+ gteqCount++;
+ }
+ }
+
+
+ float ltGini= 1.0f;
+
+ for (const auto& pair : ltMap) {
+ ltGini -= pow(float(pair.second) / ltCount, 2);
+ }
+
+ float gteqGini = 1.0f;
+
+ for (const auto& pair : gtMap) {
+ gteqGini -= pow(float(pair.second) / gteqCount, 2);
+ }
+
+ if(gteqCount == 0){
+ gteqGini = 0.0f;
+ }
+ if(ltCount == 0){
+ ltGini = 0.0f;
+ }
+
+ float gini = gteqGini * float(gteqCount) / samples;
+ gini += ltGini * float(ltCount) / samples;
+
+ return gini;
+}
diff --git a/rewrite/TreeNode.h b/rewrite/TreeNode.h
@@ -0,0 +1,16 @@
+#include "string"
+
+class TreeNode{
+ public:
+ TreeNode();
+ TreeNode(float splittingVal, int featureIndex);
+ bool isLeaf();
+ void setSplit(float splittingValue, int featureIndex);
+ float evalSplit(float* X, int* y, int samples, int features, std::string criterion);
+
+ private:
+ bool leaf;
+ float splitValue;
+ int index;
+ float giniImpurity(float* X, int* y, int samples, int features);
+};