decision-tree-classifier

Decision tree classifier implementation in C++
git clone git://git.laack.co/decision-tree-classifier.git
Log | Files | Refs | README | LICENSE

TreeNode.cpp (4040B)


#include "TreeNode.h"

#include <math.h>

#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>

#include "Criterion.h"
      8 
      9 TreeNode::TreeNode(int classification){
     10 	leaf = true;
     11 	this->classification = classification;
     12 }
     13 
     14 TreeNode::TreeNode(float splittingVal, int featureIndex){
     15 	splitValue = splittingVal;
     16 	index = featureIndex;
     17 	leaf = false;
     18 }
     19 
     20 void TreeNode::setSplit(float splittingVal, int featureIndex){
     21 	splitValue = splittingVal;
     22 	index = featureIndex;
     23 	leaf = false;
     24 }
     25 
     26 bool TreeNode::isLeaf(){
     27 	return leaf;
     28 }
     29 
     30 float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::string criterion){
     31 
     32 	if(isLeaf()){
     33 		throw std::logic_error("Cannot evaluate split on leaf node.");
     34 	}
     35 
     36 	if(criterion != "gini"){
     37 		throw std::invalid_argument("Gini impurity is the only supported criterion.");
     38 	}
     39 
     40 	Criterion evalCriterion= Criterion();
     41 
     42 	return evalCriterion.giniImpurity(X, y, samples, features, this->index, this->splitValue);
     43 }
     44 
     45 
     46 void TreeNode::setLeftChild(TreeNode* child){
     47 	leftChild = child;
     48 }
     49 
     50 void TreeNode::setRightChild(TreeNode* child){
     51 	rightChild = child;
     52 }
     53 
     54 TreeNode* TreeNode::getLeftChild(){
     55 	return leftChild;
     56 }
     57 
     58 TreeNode* TreeNode::getRightChild(){
     59 	return rightChild;
     60 }
     61 
     62 float TreeNode::getSplitVal(){
     63 	return splitValue;
     64 }
     65 
     66 int TreeNode::getIndexSplit(){
     67 	return index;
     68 }
     69 
     70 SplitResults TreeNode::splitOnNode(float* X, int* y, int samples, int features){
     71 
     72 	SplitResults result = SplitResults();
     73 
     74 	int ltCount = 0;
     75 	int gteqCount = 0;
     76 
     77 	for(int i = 0 ; i < samples; ++i){
     78 		if(X[(i*features) + index] < splitValue){
     79 			ltCount += 1;
     80 		}
     81 		else{
     82 			gteqCount += 1;
     83 		}
     84 	}
     85 
     86 	// Create X arrays to return
     87 
     88 	float* ltArr = new float[ltCount * features];
     89 	float* gteqArr = new float[gteqCount * features];
     90 
     91 	// Create array ptr next open
     92 
     93 	float* nextLtX = ltArr;
     94 	float* nextGteqX = gteqArr;
     95 
     96 	// Create y arrays to return
     97 
     98 	int* ltYArr = new int[ltCount];
     99 	int* gteqYArr = new int[gteqCount];
    100 
    101 	// Create array ptr next open
    102 
    103 	int* nextLtY = ltYArr;
    104 	int* nextGteqY = gteqYArr;
    105 
    106 	// Set pointers for return to the new arrays
    107 
    108 	result.XLeft = ltArr;
    109 	result.yLeft = ltYArr;
    110 
    111 	result.XRight = gteqArr;
    112 	result.yRight = gteqYArr;
    113 
    114 	result.leftSize = ltCount;
    115 	result.rightSize = gteqCount;
    116 
    117 	// Set arrays with correct values
    118 
    119 	for(int i = 0 ; i < samples; ++i){
    120 		if(X[(i*features) + index] < splitValue){
    121 			for(int x = 0; x < features; ++x){
    122 				nextLtX[x] = X[(i*features) + x];
    123 			}
    124 
    125 			nextLtX += features;
    126 
    127 			nextLtY[0] = y[i];
    128 			nextLtY += 1;
    129 		}
    130 		else{
    131 			for(int x = 0; x < features; ++x){
    132 				nextGteqX[x] = X[(i*features) + x];
    133 			}
    134 
    135 			nextGteqX += features;
    136 
    137 			nextGteqY[0] = y[i];
    138 			nextGteqY += 1;
    139 		}
    140 	}
    141 
    142 	//for(int x = 0 ; x < ltCount; ++x){
    143 	//	for(int i = 0 ; i < features; ++i){
    144 	//		std::cout << ltArr[x*features + i];
    145 	//	}
    146 	//	std::cout << std::endl;
    147 	//}
    148 
    149 	//for(int x = 0 ; x < ltCount; ++x){
    150 	//	std::cout << ltYArr[x] << std::endl;
    151 	//}
    152 
    153 	return result;
    154 }
    155 
    156 
    157 
    158 
    159 
    160 
    161 
    162 std::string TreeNode::getDotEdges(){
    163 
    164 	if(isLeaf()){
    165 		return "";
    166 	}
    167 
    168 	std::string current = getDotLabel() + "->" + leftChild->getDotLabel() + ";\n";
    169 	current += getDotLabel() + "->" + rightChild->getDotLabel() + ";\n";
    170 
    171 	current += rightChild->getDotEdges();
    172 	current += leftChild->getDotEdges();
    173 
    174 	return current;
    175 }
    176 
    177 std::string TreeNode::getDotLabel(){
    178 	const void * address = static_cast<const void*>(this);
    179 	std::stringstream ss;
    180 	ss << address;  
    181 	std::string name = ss.str(); 
    182 	if (isLeaf()){
    183 		return "\"" + name + "\nCLASSIFICATION: " + std::to_string(classification) + "\"";
    184 	}
    185 
    186 	return "\"" + name + "\nINDEX: " +  std::to_string(index) + "\nVALUE:" + std::to_string(splitValue) + "\"";
    187 }
    188 
    189 int TreeNode::getClassification(){
    190 	if(isLeaf()){
    191 		return classification;
    192 	}
    193 	throw std::logic_error("Unable to call getClassification() on internal vertices.");
    194 }
    195 
    196 bool TreeNode::lessThan(float* sample, int features){
    197 
    198 	if(features < this->index){
    199 		throw std::invalid_argument("Attempting to evaluate split with input that contains less features.");
    200 	}
    201 
    202 	return(sample[index] < splitValue);
    203 }