TreeNode.cpp (4040B)
1 #include "TreeNode.h" 2 #include "stdexcept" 3 #include "Criterion.h" 4 #include "math.h" 5 #include "iostream" 6 #include <string> 7 #include <sstream> 8 9 TreeNode::TreeNode(int classification){ 10 leaf = true; 11 this->classification = classification; 12 } 13 14 TreeNode::TreeNode(float splittingVal, int featureIndex){ 15 splitValue = splittingVal; 16 index = featureIndex; 17 leaf = false; 18 } 19 20 void TreeNode::setSplit(float splittingVal, int featureIndex){ 21 splitValue = splittingVal; 22 index = featureIndex; 23 leaf = false; 24 } 25 26 bool TreeNode::isLeaf(){ 27 return leaf; 28 } 29 30 float TreeNode::evalSplit(float* X, int* y, int samples, int features, std::string criterion){ 31 32 if(isLeaf()){ 33 throw std::logic_error("Cannot evaluate split on leaf node."); 34 } 35 36 if(criterion != "gini"){ 37 throw std::invalid_argument("Gini impurity is the only supported criterion."); 38 } 39 40 Criterion evalCriterion= Criterion(); 41 42 return evalCriterion.giniImpurity(X, y, samples, features, this->index, this->splitValue); 43 } 44 45 46 void TreeNode::setLeftChild(TreeNode* child){ 47 leftChild = child; 48 } 49 50 void TreeNode::setRightChild(TreeNode* child){ 51 rightChild = child; 52 } 53 54 TreeNode* TreeNode::getLeftChild(){ 55 return leftChild; 56 } 57 58 TreeNode* TreeNode::getRightChild(){ 59 return rightChild; 60 } 61 62 float TreeNode::getSplitVal(){ 63 return splitValue; 64 } 65 66 int TreeNode::getIndexSplit(){ 67 return index; 68 } 69 70 SplitResults TreeNode::splitOnNode(float* X, int* y, int samples, int features){ 71 72 SplitResults result = SplitResults(); 73 74 int ltCount = 0; 75 int gteqCount = 0; 76 77 for(int i = 0 ; i < samples; ++i){ 78 if(X[(i*features) + index] < splitValue){ 79 ltCount += 1; 80 } 81 else{ 82 gteqCount += 1; 83 } 84 } 85 86 // Create X arrays to return 87 88 float* ltArr = new float[ltCount * features]; 89 float* gteqArr = new float[gteqCount * features]; 90 91 // Create array ptr next open 92 93 float* nextLtX = ltArr; 94 float* nextGteqX = gteqArr; 95 96 // Create y arrays to return 97 98 int* ltYArr = new int[ltCount]; 99 int* gteqYArr = new int[gteqCount]; 100 101 // Create array ptr next open 102 103 int* nextLtY = ltYArr; 104 int* nextGteqY = gteqYArr; 105 106 // Set pointers for return to the new arrays 107 108 result.XLeft = ltArr; 109 result.yLeft = ltYArr; 110 111 result.XRight = gteqArr; 112 result.yRight = gteqYArr; 113 114 result.leftSize = ltCount; 115 result.rightSize = gteqCount; 116 117 // Set arrays with correct values 118 119 for(int i = 0 ; i < samples; ++i){ 120 if(X[(i*features) + index] < splitValue){ 121 for(int x = 0; x < features; ++x){ 122 nextLtX[x] = X[(i*features) + x]; 123 } 124 125 nextLtX += features; 126 127 nextLtY[0] = y[i]; 128 nextLtY += 1; 129 } 130 else{ 131 for(int x = 0; x < features; ++x){ 132 nextGteqX[x] = X[(i*features) + x]; 133 } 134 135 nextGteqX += features; 136 137 nextGteqY[0] = y[i]; 138 nextGteqY += 1; 139 } 140 } 141 142 //for(int x = 0 ; x < ltCount; ++x){ 143 // for(int i = 0 ; i < features; ++i){ 144 // std::cout << ltArr[x*features + i]; 145 // } 146 // std::cout << std::endl; 147 //} 148 149 //for(int x = 0 ; x < ltCount; ++x){ 150 // std::cout << ltYArr[x] << std::endl; 151 //} 152 153 return result; 154 } 155 156 157 158 159 160 161 162 std::string TreeNode::getDotEdges(){ 163 164 if(isLeaf()){ 165 return ""; 166 } 167 168 std::string current = getDotLabel() + "->" + leftChild->getDotLabel() + ";\n"; 169 current += getDotLabel() + "->" + rightChild->getDotLabel() + ";\n"; 170 171 current += rightChild->getDotEdges(); 172 current += leftChild->getDotEdges(); 173 174 return current; 175 } 176 177 std::string TreeNode::getDotLabel(){ 178 const void * address = static_cast<const void*>(this); 179 std::stringstream ss; 180 ss << address; 181 std::string name = ss.str(); 182 if (isLeaf()){ 183 return "\"" + name + "\nCLASSIFICATION: " + std::to_string(classification) + "\""; 184 } 185 186 return "\"" + name + "\nINDEX: " + std::to_string(index) + "\nVALUE:" + std::to_string(splitValue) + "\""; 187 } 188 189 int TreeNode::getClassification(){ 190 if(isLeaf()){ 191 return classification; 192 } 193 throw std::logic_error("Unable to call getClassification() on internal vertices."); 194 } 195 196 bool TreeNode::lessThan(float* sample, int features){ 197 198 if(features < this->index){ 199 throw std::invalid_argument("Attempting to evaluate split with input that contains less features."); 200 } 201 202 return(sample[index] < splitValue); 203 }