commit fc97d111faf5f8f0a79d83ae3734f5a7d72a4de4
parent 98cfd66ea326623abf7766c2564ae4256d8c94aa
Author: Andrew <andrewlaack1@gmail.com>
Date: Sun, 29 Dec 2024 20:24:32 -0600
Did stuff
Diffstat:
2 files changed, 57 insertions(+), 0 deletions(-)
diff --git a/rewrite/cpp/ObliqueNode.h b/rewrite/cpp/ObliqueNode.h
@@ -0,0 +1,40 @@
+#include "string"
+
+struct SplitResults{
+ float* XLeft;
+ float* XRight;
+ int* yLeft;
+ int* yRight;
+ int leftSize;
+ int rightSize;
+};
+
+class ObliqueNode{
+ public:
+ ObliqueNode(int classification);
+ ObliqueNode(float splittingVal, int featureIndex);
+ bool isLeaf();
+ void setSplit(float splittingValue, int featureIndex);
+ float evalSplit(float* X, int* y, int samples, int features, std::string criterion);
+ ObliqueNode* getLeftChild();
+ ObliqueNode* getRightChild();
+ void setLeftChild(ObliqueNode* child);
+ void setRightChild(ObliqueNode* child);
+ float getSplitVal();
+ int getIndexSplit();
+ SplitResults splitOnNode(float* X, int* y, int samples, int features);
+ std::string getDotEdges();
+ int getClassification();
+ bool lessThan(float* sample, int features);
+
+ private:
+ bool leaf;
+ float splitValue;
+ int index;
+ ObliqueNode* leftChild = nullptr;
+ ObliqueNode* rightChild = nullptr;
+ std::string getDotLabel();
+ int classification;
+};
+
+
diff --git a/rewrite/cpp/OptimalObliqueTree.h b/rewrite/cpp/OptimalObliqueTree.h
@@ -0,0 +1,17 @@
+#include "ObliqueNode.h"
+
+class OptimalObliqueTree{
+ public:
+ OptimalObliqueTree(int depth);
+ void fit(float* X, int samples, int* y, int features);
+ int* predict(float* X, int samples, int features);
+ std::string getDot();
+ private:
+ int depth;
+ int featureCount = -1;
+ ObliqueNode* splittingTree = nullptr;
+ ObliqueNode* bestSplit(float* X, int samples, int* y, int features);
+ ObliqueNode* recurse(float* X, int samples, int* y, int features, int depth);
+ int primaryClass(int* y, int labelCount);
+ void deleteTree(ObliqueNode* node);
+};