TreeNode.h (1814B)
// TreeNode.h — declarations for a binary decision-tree node (TreeNode) and
// the result type of its data-splitting operation (SplitResults).

// Include guard: without it, including this header twice in one translation
// unit redefines SplitResults and TreeNode and fails to compile.
#ifndef TREENODE_H
#define TREENODE_H

// Standard headers use the angle-bracket form so that a project-local file
// named "string"/"vector" can never shadow them.
#include <string>
#include <vector>

// Return type of TreeNode::splitOnNode.
//
// NOTE(review): presumably XLeft/yLeft hold the leftSize samples/labels
// routed to the left child and XRight/yRight the rightSize routed to the
// right child. Who allocates and frees these arrays is not visible from this
// header; confirm against TreeNode.cpp.
struct SplitResults{
    float* XLeft;
    float* XRight;
    int* yLeft;
    int* yRight;
    int leftSize;
    int rightSize;
};

// A node of a binary decision tree: it is either a leaf (see isLeaf()) or has
// a split with a left and a right child.
//
// NOTE(review): this class declares a destructor and holds raw pointers
// (children, equation, indicesOrder) but declares no copy constructor or copy
// assignment. If ~TreeNode frees any of them, copying a TreeNode would free
// them twice (Rule of Three). Consider deleting the copy operations once it is
// confirmed that no caller copies TreeNode by value.
class TreeNode{
public:
    // Creates a node from a class label only (presumably a leaf — confirm in
    // TreeNode.cpp).
    TreeNode(int classification);

    // Creates a node whose split boundary is a hyperplane (see getEquation()
    // and aboveOrOnPlane()).
    //
    // A list of feature indices specifies which components should be used to
    // generate the split, e.g.
    //   [0, 0, 1, 1, 0]
    //
    // A system of linear equations is also given as input. It is used to
    // determine the hyperplane that contains them, where every coefficient
    // is 0 unless its component is included in the feature-indices list, e.g.
    //   [
    //     [10, 3, 4, 5, 0],
    //     [1, 5, 6, 8, 0],
    //   ]
    //
    // NOTE(review): this description was written for parameters named
    // `featureIndices`/`featureCount`, which no longer appear below.
    // Presumably they correspond to `indicesOrder`/`indicesCount`, and
    // `samples` holds `points` rows of `features` values each. Confirm against
    // TreeNode.cpp.
    TreeNode(float* samples, int features, int points, int* indicesOrder, int indicesCount);

    // True if this node has no split (reads the `leaf` flag).
    bool isLeaf();
    // Sets the node's split threshold and the feature it applies to.
    void setSplit(float splittingValue, int featureIndex);
    // Scores this node's split of (X, y) under `criterion`.
    // NOTE(review): presumably dispatches on `criterion` to giniImpurity /
    // twoingRule / informationGain; the accepted strings are defined in
    // TreeNode.cpp.
    float evalSplit(float* X, int* y, int samples, int features, std::string criterion);
    TreeNode* getLeftChild();
    TreeNode* getRightChild();
    void setLeftChild(TreeNode* child);
    void setRightChild(TreeNode* child);
    float getSplitVal();
    int getIndexSplit();
    // Partitions (X, y) into the two sides of this node's split.
    SplitResults splitOnNode(float* X, int* y, int samples, int features);
    // Edge description of this subtree, presumably in Graphviz DOT syntax.
    std::string getDotEdges();
    int getClassification();
    // Hyperplane coefficients of this node's split (may be nullptr).
    float* getEquation();
    // Whether `sample` lies above or on this node's hyperplane.
    bool aboveOrOnPlane(float* sample, int features);
    ~TreeNode();

private:
    // Every field below has an in-class default, so a constructor that forgets
    // one still leaves it in a defined state instead of holding garbage.
    bool leaf = false;
    TreeNode* leftChild = nullptr;
    TreeNode* rightChild = nullptr;
    std::string getDotLabel();
    int classification = 0;
    float* equation = nullptr;
    int* indicesOrder = nullptr;
    int indicesCount = 0;          // number of entries in indicesOrder
    float giniImpurity(float* X, int* y, int samples, int features);
    float twoingRule(float* X, int* y, int samples, int features);
    float informationGain(float* X, int* y, int samples, int features);
    float entropy(int* y, int samples);
};

#endif // TREENODE_H