cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

TreeNode.h (1814B)


      1 #include "string"
      2 #include "vector"
      3 
      4 struct SplitResults{
      5 	float* XLeft;
      6 	float* XRight;
      7 	int* yLeft;
      8 	int* yRight;
      9 	int leftSize;
     10 	int rightSize;
     11 };
     12 
     13 class TreeNode{
     14 	public:
     15 		TreeNode(int classification);
     16 
     17 		// featureIndices is an array with size featureCount that
     18 		// specifies which components should be used to generate the split
     19 		//
     20 		// EXAMPLE:
     21 		// [0, 0, 1, 1, 0]
     22 		// 
     23 		// we also need a system of linear equations as the input which
     24 		// will be used to determine the hyperplane that contains them
     25 		// where all coefficients are 0 unless they are included in the
     26 		// features indices list.
     27 		//
     28 		// EXAMPLE:
     29 		// [
     30 		//		[10, 3, 4, 5, 0],
     31 		//		[1, 5, 6, 8, 0],
     32 		// ]
     33 		//
     34 
     35 		TreeNode(float* samples, int features, int points, int* indicesOrder, int indicesCount);
     36 		bool isLeaf();
     37 		void setSplit(float splittingValue, int featureIndex);
     38 		float evalSplit(float* X, int* y, int samples, int features, std::string criterion);
     39 		TreeNode* getLeftChild();
     40 		TreeNode* getRightChild();
     41 		void setLeftChild(TreeNode* child);
     42 		void setRightChild(TreeNode* child);
     43 		float getSplitVal();
     44 		int getIndexSplit();
     45 		SplitResults splitOnNode(float* X, int* y, int samples, int features);
     46 		std::string getDotEdges();
     47 		int getClassification();
     48 		float* getEquation();
     49 		bool aboveOrOnPlane(float* sample, int features);
     50 		~TreeNode();
     51 
     52 	private:
     53 		bool leaf;
     54 		TreeNode* leftChild = nullptr;
     55 		TreeNode* rightChild = nullptr;
     56 		std::string getDotLabel();
     57 		int classification;
     58 		float* equation = nullptr;
     59 		int* indicesOrder = nullptr;
     60 		int indicesCount;
     61 		float giniImpurity(float* X, int* y, int samples, int features);
     62 		float twoingRule(float* X, int* y, int samples, int features);
     63 		float informationGain(float* X, int* y, int samples, int features);
     64 		float entropy(int* y, int samples);
     65 };
     66 
     67