decision-tree-classifier

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

commit fc97d111faf5f8f0a79d83ae3734f5a7d72a4de4
parent 98cfd66ea326623abf7766c2564ae4256d8c94aa
Author: Andrew <andrewlaack1@gmail.com>
Date:   Sun, 29 Dec 2024 20:24:32 -0600

Did stuff

Diffstat:
Arewrite/cpp/ObliqueNode.h | 40++++++++++++++++++++++++++++++++++++++++
Arewrite/cpp/OptimalObliqueTree.h | 17+++++++++++++++++
2 files changed, 57 insertions(+), 0 deletions(-)

diff --git a/rewrite/cpp/ObliqueNode.h b/rewrite/cpp/ObliqueNode.h @@ -0,0 +1,40 @@ +#include "string" + +struct SplitResults{ + float* XLeft; + float* XRight; + int* yLeft; + int* yRight; + int leftSize; + int rightSize; +}; + +class ObliqueNode{ + public: + ObliqueNode(int classification); + ObliqueNode(float splittingVal, int featureIndex); + bool isLeaf(); + void setSplit(float splittingValue, int featureIndex); + float evalSplit(float* X, int* y, int samples, int features, std::string criterion); + ObliqueNode* getLeftChild(); + ObliqueNode* getRightChild(); + void setLeftChild(ObliqueNode* child); + void setRightChild(ObliqueNode* child); + float getSplitVal(); + int getIndexSplit(); + SplitResults splitOnNode(float* X, int* y, int samples, int features); + std::string getDotEdges(); + int getClassification(); + bool lessThan(float* sample, int features); + + private: + bool leaf; + float splitValue; + int index; + ObliqueNode* leftChild = nullptr; + ObliqueNode* rightChild = nullptr; + std::string getDotLabel(); + int classification; +}; + + diff --git a/rewrite/cpp/OptimalObliqueTree.h b/rewrite/cpp/OptimalObliqueTree.h @@ -0,0 +1,17 @@ +#include "ObliqueNode.h" + +class OptimalObliqueTree{ + public: + OptimalObliqueTree(int depth); + void fit(float* X, int samples, int* y, int features); + int* predict(float* X, int samples, int features); + std::string getDot(); + private: + int depth; + int featureCount = -1; + ObliqueNode* splittingTree = nullptr; + ObliqueNode* bestSplit(float* X, int samples, int* y, int features); + ObliqueNode* recurse(float* X, int samples, int* y, int features, int depth); + int primaryClass(int* y, int labelCount); + void deleteTree(ObliqueNode* node); +};