TreeNode.h (936B)
// TreeNode.h
//
// Declarations for a binary tree node and the result of partitioning data at a
// node. Judging by the member names this is a decision-tree node;
// NOTE(review): confirm semantics against TreeNode.cpp, which is not visible here.
#ifndef TREE_NODE_H
#define TREE_NODE_H

#include <string>

/// Value returned by TreeNode::splitOnNode: sample values (X) and labels (y)
/// for the left and right sides of a split, plus the count on each side.
///
/// NOTE(review): this header does not show who allocates or frees the four
/// arrays — confirm ownership in TreeNode.cpp before storing or freeing them.
/// Kept a plain aggregate (no default member initializers) so brace
/// initialization keeps working even under C++11.
struct SplitResults {
    float* XLeft;   ///< Left-side sample values (presumably leftSize x features, row-major — TODO confirm).
    float* XRight;  ///< Right-side sample values (same layout assumption as XLeft).
    int* yLeft;     ///< Left-side labels (presumably leftSize entries).
    int* yRight;    ///< Right-side labels (presumably rightSize entries).
    int leftSize;   ///< Number of samples on the left side (presumed from the name).
    int rightSize;  ///< Number of samples on the right side (presumed from the name).
};

/// Binary tree node. Judging by its two constructors and isLeaf(), a node is
/// presumably either a leaf carrying a class label or an internal node holding
/// a split threshold on one feature index.
///
/// NOTE(review): no destructor, copy, or move operations are declared, so this
/// header does not establish whether a node owns its children — confirm
/// before deleting or copying nodes.
class TreeNode {
public:
    /// Presumably constructs a leaf node labelled with `classification`.
    TreeNode(int classification);

    /// Presumably constructs an internal node that splits on feature
    /// `featureIndex` at threshold `splittingVal`.
    TreeNode(float splittingVal, int featureIndex);

    /// Presumably reports whether this node is a leaf (the `leaf` member).
    bool isLeaf();

    /// Sets this node's split threshold and the feature index it tests.
    void setSplit(float splittingValue, int featureIndex);

    /// Scores this node's current split on a dataset of `samples` rows and
    /// `features` columns, using the named `criterion`.
    /// TODO(review): document accepted criterion names and whether a lower or
    /// higher score is better — neither is visible from this header.
    float evalSplit(float* X, int* y, int samples, int features, std::string criterion);

    /// Returns the left child pointer (nullptr if never set).
    TreeNode* getLeftChild();

    /// Returns the right child pointer (nullptr if never set).
    TreeNode* getRightChild();

    /// Stores `child` as the left child. NOTE(review): ownership transfer is
    /// not visible here.
    void setLeftChild(TreeNode* child);

    /// Stores `child` as the right child. NOTE(review): ownership transfer is
    /// not visible here.
    void setRightChild(TreeNode* child);

    /// Returns the split threshold.
    float getSplitVal();

    /// Returns the feature index this node splits on.
    int getIndexSplit();

    /// Partitions `samples` rows of X (with `features` columns each) and their
    /// labels y into the two sides of this node's split.
    SplitResults splitOnNode(float* X, int* y, int samples, int features);

    /// Returns edge descriptions for this subtree, presumably in Graphviz DOT
    /// syntax (inferred from the name — confirm).
    std::string getDotEdges();

    /// Returns the class label stored on this node.
    int getClassification();

    /// Presumably tests one sample (of `features` values) against this node's
    /// split, i.e. whether it belongs on the "less than" side — confirm in .cpp.
    bool lessThan(float* sample, int features);

private:
    // Scalars are given defaults (matching the existing `= nullptr` style) so
    // a field that some constructor does not set is never read indeterminate.
    // Constructor initializers or assignments still take precedence.
    bool leaf = false;
    float splitValue = 0.0f;
    int index = 0;
    TreeNode* leftChild = nullptr;
    TreeNode* rightChild = nullptr;

    /// Presumably the DOT label text for this single node (used by getDotEdges).
    std::string getDotLabel();

    int classification = 0;
};

#endif  // TREE_NODE_H