DecisionTreeClassifier.h (592B)
1 #include "TreeNode.h" 2 #include <vector> 3 4 class DecisionTreeClassifier{ 5 public: 6 DecisionTreeClassifier(int depth); 7 void fit(float* X, int samples, int* y, int features); 8 int* predict(float* X, int samples, int features); 9 std::string getDot(); 10 ~DecisionTreeClassifier(); 11 private: 12 int depth; 13 int featureCount = -1; 14 TreeNode* splittingTree = nullptr; 15 TreeNode* bestSplit(float* X, int samples, int* y, int features); 16 TreeNode* recurse(float* X, int samples, int* y, int features, int depth); 17 int primaryClass(int* y, int labelCount); 18 void deleteTree(TreeNode* node); 19 };