decision-tree-classifier

Decision tree classifier implementation in C++
git clone git://git.laack.co/decision-tree-classifier.git
Log | Files | Refs | README | LICENSE

DecisionTreeClassifier.cpp (4996B)


#include "DecisionTreeClassifier.h"

#include <iostream>
#include <limits>
#include <mutex>
#include <stdexcept>
#include <string>
#include <system_error>
#include <thread>
#include <unordered_map>
#include <vector>
      7 
      8 using namespace std;
      9 
     10 DecisionTreeClassifier::DecisionTreeClassifier(int maxDepth){
     11 	this->depth = maxDepth;
     12 }
     13 
     14 void DecisionTreeClassifier::fit(float* X, int samples, int* y, int features){
     15 
     16 	if (splittingTree != nullptr){
     17 		deleteTree(splittingTree);
     18 	}
     19 
     20 	if(features <= 0){
     21 		throw invalid_argument("Invalid argument, there must be 1 or more features to train on.");
     22 	}
     23 
     24 	if(samples <= 0){
     25 		throw invalid_argument("Invalid argument, there must be 1 or more samples to train on.");
     26 	}
     27 
     28 	splittingTree = recurse(X, samples, y, features, depth);
     29 	featureCount = features;
     30 
     31 }
     32 
     33 
     34 std::string DecisionTreeClassifier::getDot(){
     35 	if (splittingTree == nullptr){
     36 		throw logic_error("Decision tree must be created prior to generating dot output.");
     37 	}
     38 	std::string edges = splittingTree->getDotEdges();
     39 	std::string dot = "digraph decisionTree {\n" + edges + "}";
     40 	return dot;
     41 }
     42 
     43 int DecisionTreeClassifier::primaryClass(int* y, int labelCount){
     44 
     45 	unordered_map map = unordered_map<int,int>();
     46 
     47 	for(int i = 0; i < labelCount; ++i){
     48 		map[y[i]] += 1;
     49 	}
     50 
     51 	int mostElements = 0;
     52 	int label = 0;
     53 
     54 	for (auto& item : map){
     55 		if(item.second > mostElements){
     56 			mostElements = item.second;
     57 			label = item.first;
     58 		}
     59 	}
     60 
     61 	return label;
     62 }
     63 
     64 
     65 
     66 // add depth
     67 TreeNode* DecisionTreeClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){
     68 
     69 	if(depthRem == 0){
     70 		TreeNode* ret = new TreeNode(primaryClass(y, rows));
     71 		return ret;
     72 	}
     73 
     74 	// found minimum node
     75 	if(rows == 1){
     76 		TreeNode* ret = new TreeNode(primaryClass(y, rows));
     77 		return ret; 
     78 	}
     79 
     80 	// get best split option 
     81 	TreeNode* chosen = bestSplit(X, rows, y, columns);
     82 	SplitResults split = chosen->splitOnNode(X, y, rows, columns);
     83 
     84 	// no valid splits, but we still did create some new arrays.
     85 	if(split.rightSize == rows || split.leftSize == rows){
     86 		TreeNode* ret = new TreeNode(primaryClass(y, rows));
     87 		delete split.XLeft;
     88 		delete split.XRight;
     89 		delete split.yLeft;
     90 		delete split.yRight;
     91 		return ret; 
     92 	}
     93 
     94 	// traverse lt tree
     95 	TreeNode* left = recurse(split.XLeft, split.leftSize, split.yLeft, columns, depthRem - 1);
     96 	// traverse gt tree
     97 	TreeNode* right = recurse(split.XRight, split.rightSize, split.yRight, columns, depthRem - 1);
     98 
     99 	chosen->setLeftChild(left);
    100 	chosen->setRightChild(right);
    101 
    102 	delete split.XLeft;
    103 	delete split.XRight;
    104 	delete split.yLeft;
    105 	delete split.yRight;
    106 
    107 	return chosen;
    108 }
    109 
    110 
    111 
    112 
    113 
    114 //	1	1	0
    115 //	3	3	0
    116 //	2	1	1
    117 //	4	1	3
    118 
    119 
    120 
    121 
    122 
    123 // consider adding interpolation to this and sorting the list first.
    124 // Also, no reason to consider the 0th split if that is the case.
    125 
    126 TreeNode* DecisionTreeClassifier::bestSplit(float* X, int rows, int* y, int columns) {
    127     TreeNode* bestNode = nullptr;
    128     float bestGini = std::numeric_limits<float>::max();
    129     std::mutex mtx; 
    130 
    131     auto evalColumn = [&](int col){
    132         TreeNode* localBestNode = nullptr;
    133         float localBestGini = std::numeric_limits<float>::max();
    134 
    135         for (int row = 0; row < rows; ++row){
    136             float val = X[row * columns + col];
    137             TreeNode* current = new TreeNode(val, col);
    138             float gini = current->evalSplit(X, y, rows, columns, "gini");
    139 
    140             if (gini < localBestGini){
    141                 delete localBestNode;
    142                 localBestNode = current;
    143                 localBestGini = gini;
    144             }
    145 			else{
    146                 delete current;
    147             }
    148         }
    149 
    150         std::lock_guard<std::mutex> lock(mtx);
    151         if (localBestGini < bestGini){
    152             delete bestNode;
    153             bestNode = localBestNode;
    154             bestGini = localBestGini;
    155         }
    156 		else{
    157             delete localBestNode;
    158         }
    159     };
    160 
    161     std::vector<std::thread> threads;
    162     for (int col = 0; col < columns; ++col) {
    163         threads.emplace_back(evalColumn, col);
    164     }
    165 
    166     for (auto& thread : threads) {
    167         thread.join();
    168     }
    169 
    170     return bestNode;
    171 }
    172 
    173 int* DecisionTreeClassifier::predict(float* X, int samples, int features) {
    174 
    175 	if(featureCount == -1){
    176 		throw logic_error("Unable to predict prior to calling fit().");
    177 	}
    178 
    179 	if(features != this->featureCount){
    180 		throw invalid_argument("Incorrect number of features for prediction.");
    181 	}
    182 
    183 	int* predictions = new int[samples];
    184 
    185 	for(int i = 0; i < samples; ++i){
    186 		TreeNode* current = splittingTree;
    187 		while(!current->isLeaf()){
    188 			float* currentElement = X;
    189 			currentElement += features * i;
    190 			bool lessThan = current->lessThan(currentElement, features);
    191 			if(lessThan){
    192 				current = current->getLeftChild();
    193 			}
    194 			else{
    195 				current = current->getRightChild();
    196 			}
    197 		}
    198 		predictions[i] = current->getClassification();
    199 	}
    200 
    201 	return predictions;
    202 }
    203 
    204 DecisionTreeClassifier::~DecisionTreeClassifier(){
    205 	deleteTree(splittingTree);
    206 
    207 }
    208 
    209 void DecisionTreeClassifier::deleteTree(TreeNode* node){
    210 
    211 	if(node == nullptr){
    212 		return;
    213 	}
    214 
    215 	deleteTree(node->getLeftChild());
    216 	deleteTree(node->getRightChild());
    217 	delete node;
    218 }