DecisionTreeClassifier.cpp (4996B)
1 #include "DecisionTreeClassifier.h" 2 #include <limits> 3 #include <iostream> 4 #include <unordered_map> 5 #include <thread> 6 #include <mutex> 7 8 using namespace std; 9 10 DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){ 11 this->depth = maxDepth; 12 } 13 14 void DecisionTreeClassifier::fit(float* X, int samples, int* y, int features){ 15 16 if (splittingTree != nullptr){ 17 deleteTree(splittingTree); 18 } 19 20 if(features <= 0){ 21 throw invalid_argument("Invalid argument, there must be 1 or more features to train on."); 22 } 23 24 if(samples <= 0){ 25 throw invalid_argument("Invalid argument, there must be 1 or more samples to train on."); 26 } 27 28 splittingTree = recurse(X, samples, y, features, depth); 29 featureCount = features; 30 31 } 32 33 34 std::string DecisionTreeClassifier::getDot(){ 35 if (splittingTree == nullptr){ 36 throw logic_error("Decision tree must be created prior to generating dot output."); 37 } 38 std::string edges = splittingTree->getDotEdges(); 39 std::string dot = "digraph decisionTree {\n" + edges + "}"; 40 return dot; 41 } 42 43 int DecisionTreeClassifier::primaryClass(int* y, int labelCount){ 44 45 unordered_map map = unordered_map<int,int>(); 46 47 for(int i = 0; i < labelCount; ++i){ 48 map[y[i]] += 1; 49 } 50 51 int mostElements = 0; 52 int label = 0; 53 54 for (auto& item : map){ 55 if(item.second > mostElements){ 56 mostElements = item.second; 57 label = item.first; 58 } 59 } 60 61 return label; 62 } 63 64 65 66 // add depth 67 TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){ 68 69 if(depthRem == 0){ 70 TreeNode* ret = new TreeNode(primaryClass(y, rows)); 71 return ret; 72 } 73 74 // found minimum node 75 if(rows == 1){ 76 TreeNode* ret = new TreeNode(primaryClass(y, rows)); 77 return ret; 78 } 79 80 // get best split option 81 TreeNode* chosen = bestSplit(X, rows, y, columns); 82 SplitResults split = chosen->splitOnNode(X, y, rows, columns); 83 84 // no valid splits, but we still did create some new arrays. 85 if(split.rightSize == rows || split.leftSize == rows){ 86 TreeNode* ret = new TreeNode(primaryClass(y, rows)); 87 delete split.XLeft; 88 delete split.XRight; 89 delete split.yLeft; 90 delete split.yRight; 91 return ret; 92 } 93 94 // traverse lt tree 95 TreeNode* left = recurse(split.XLeft, split.leftSize, split.yLeft, columns, depthRem - 1); 96 // traverse gt tree 97 TreeNode* right = recurse(split.XRight, split.rightSize, split.yRight, columns, depthRem - 1); 98 99 chosen->setLeftChild(left); 100 chosen->setRightChild(right); 101 102 delete split.XLeft; 103 delete split.XRight; 104 delete split.yLeft; 105 delete split.yRight; 106 107 return chosen; 108 } 109 110 111 112 113 114 // 1 1 0 115 // 3 3 0 116 // 2 1 1 117 // 4 1 3 118 119 120 121 122 123 // consider adding interpolation to this and sorting the list first. 124 // Also, no reason to consider the 0th split if that is the case. 125 126 TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int columns) { 127 TreeNode* bestNode = nullptr; 128 float bestGini = std::numeric_limits<float>::max(); 129 std::mutex mtx; 130 131 auto evalColumn = [&](int col){ 132 TreeNode* localBestNode = nullptr; 133 float localBestGini = std::numeric_limits<float>::max(); 134 135 for (int row = 0; row < rows; ++row){ 136 float val = X[row * columns + col]; 137 TreeNode* current = new TreeNode(val, col); 138 float gini = current->evalSplit(X, y, rows, columns, "gini"); 139 140 if (gini < localBestGini){ 141 delete localBestNode; 142 localBestNode = current; 143 localBestGini = gini; 144 } 145 else{ 146 delete current; 147 } 148 } 149 150 std::lock_guard<std::mutex> lock(mtx); 151 if (localBestGini < bestGini){ 152 delete bestNode; 153 bestNode = localBestNode; 154 bestGini = localBestGini; 155 } 156 else{ 157 delete localBestNode; 158 } 159 }; 160 161 std::vector<std::thread> threads; 162 for (int col = 0; col < columns; ++col) { 163 threads.emplace_back(evalColumn, col); 164 } 165 166 for (auto& thread : threads) { 167 thread.join(); 168 } 169 170 return bestNode; 171 } 172 173 int* DecisionTreeClassifier::predict(float* X, int samples, int features) { 174 175 if(featureCount == -1){ 176 throw logic_error("Unable to predict prior to calling fit()."); 177 } 178 179 if(features != this->featureCount){ 180 throw invalid_argument("Incorrect number of features for prediction."); 181 } 182 183 int* predictions = new int[samples]; 184 185 for(int i = 0; i < samples; ++i){ 186 TreeNode* current = splittingTree; 187 while(!current->isLeaf()){ 188 float* currentElement = X; 189 currentElement += features * i; 190 bool lessThan = current->lessThan(currentElement, features); 191 if(lessThan){ 192 current = current->getLeftChild(); 193 } 194 else{ 195 current = current->getRightChild(); 196 } 197 } 198 predictions[i] = current->getClassification(); 199 } 200 201 return predictions; 202 } 203 204 DecisionTreeClassifier::~DecisionTreeClassifier(){ 205 deleteTree(splittingTree); 206 207 } 208 209 void DecisionTreeClassifier::deleteTree(TreeNode* node){ 210 211 if(node == nullptr){ 212 return; 213 } 214 215 deleteTree(node->getLeftChild()); 216 deleteTree(node->getRightChild()); 217 delete node; 218 }