cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

ELCClassifier.cpp (12536B)


      1 #include "../include/ELCClassifier.h"
      2 #include <cmath>
      3 #include <unordered_map>
      4 #include "future"
      5 #include <queue>
      6 #include "map"
      7 #include <stack>
      8 
      9 using namespace std;
     10 
     11 ELCClassifier::ELCClassifier(int maxDepth, int linearCombinations, int maxThreadCount, std::string objFunction){
     12 	this->depth = maxDepth;
     13 	this->linearCombinations = linearCombinations;
     14 	this->maxThreads = maxThreadCount;
     15 	this->objectiveFunction = objFunction;
     16 }
     17 
     18 void ELCClassifier::fit(float* X, int samples, int* y, int features){
     19 
     20 	if (splittingTree != nullptr){
     21 		deleteTree(splittingTree);
     22 		splittingTree = nullptr;
     23 	}
     24 
     25 	this->featureCount = features;
     26 	this->splittingTree = recurse(X, samples, y, features, depth);
     27 }
     28 
     29 
     30 std::string ELCClassifier::getDot(){
     31 	if (splittingTree == nullptr){
     32 		throw logic_error("Decision tree must be created prior to generating dot output.");
     33 	}
     34 	std::string edges = splittingTree->getDotEdges();
     35 	std::string dot = "digraph decisionTree {\n" + edges + "}";
     36 	return dot;
     37 }
     38 
     39 int ELCClassifier::primaryClass(int* y, int labelCount){
     40 	unordered_map<int,int> map;
     41 
     42 	for(int i = 0; i < labelCount; ++i){
     43 		map[y[i]] += 1;
     44 	}
     45 
     46 	int mostElements = 0;
     47 	int label = 0;
     48 
     49 	for (auto& item : map){
     50 		if(item.second > mostElements){
     51 			mostElements = item.second;
     52 			label = item.first;
     53 		}
     54 	}
     55 
     56 	return label;
     57 }
     58 
     59 bool ELCClassifier::homogeneous(int* y, int samples){
     60 	
     61 	if(samples == 1){
     62 		return true;
     63 	}
     64 
     65 	for(int i = 1 ; i < samples; ++i){
     66 		if(y[i] != y[0]){
     67 			return false;
     68 		}
     69 	}
     70 	return true;
     71 }
     72 
     73 TreeNode* ELCClassifier::recurse(float* X, int rows, int* y, int columns, int depthRem){
     74 
     75 
     76 	for(int i = 0 ; i < rows * columns; ++i){
     77 
     78 		if(i % columns == 0){
     79 			std::cout << std::endl;
     80 		}
     81 		std::cout << X[i] << " ";
     82 
     83 	}
     84 
     85 	std::cout << endl;
     86 	for(int i = 0 ; i < rows; ++i){
     87 		std::cout << y[i] << " ";
     88 	}
     89 
     90 	std::cout << endl;
     91 
     92 	if(depthRem == 0 || homogeneous(y,rows)){
     93 		TreeNode* ret = new TreeNode(primaryClass(y, rows));
     94 		return ret;
     95 	}
     96 
     97 
     98 	// found minimum node
     99 	if(rows <= this->linearCombinations){
    100 		TreeNode* ret = new TreeNode(primaryClass(y, rows));
    101 		return ret; 
    102 	}
    103 
    104 
    105 	// get best split option 
    106 	TreeNode* chosen = bestSplit(X, rows, y, columns);
    107 	SplitResults split = chosen->splitOnNode(X, y, rows, columns);
    108 
    109 
    110 	//for(int i = 0 ; i < split.leftSize; ++i){
    111 	//	std::cout << split.XLeft[i] << " ";
    112 	//}
    113 
    114 	//std::cout << std::endl;
    115 
    116 	//for(int i = 0 ; i < linearCombinations; ++i){
    117 	//	std::cout << chosen->getEquation()[i] << " ";
    118 	//}
    119 
    120 	//std::cout << std::endl;
    121 	//std::cout << std::endl;
    122 
    123 	// no valid splits, but we still did create some new arrays.
    124 	if(split.rightSize == rows || split.leftSize == rows){
    125 		TreeNode* ret = new TreeNode(primaryClass(y, rows));
    126 		
    127 		// line of code is in prison.
    128 		// he cost me ~8 hours of time. Our battle was valiant,
    129 		// but alas you have been found, your leak patched,
    130 		// and you are now under arrest.
    131 
    132 		// __________//
    133 		delete chosen;
    134 		chosen = nullptr;
    135 		//^^^^^^^^^^^//
    136 
    137 		
    138 
    139 		delete[] split.XLeft;
    140 		delete[] split.XRight;
    141 		delete[] split.yLeft;
    142 		delete[] split.yRight;
    143 
    144 		split.XLeft = nullptr;
    145 		split.XRight = nullptr;
    146 		split.yLeft = nullptr;
    147 		split.yRight = nullptr;
    148 
    149 		return ret; 
    150 	}
    151 
    152 	// traverse lt tree
    153 	TreeNode* left = recurse(split.XLeft, split.leftSize, split.yLeft, columns, depthRem - 1);
    154 	// traverse gt tree
    155 	TreeNode* right = recurse(split.XRight, split.rightSize, split.yRight, columns, depthRem - 1);
    156 
    157 	chosen->setLeftChild(left);
    158 	chosen->setRightChild(right);
    159 
    160 	delete[] split.XLeft;
    161 	delete[] split.XRight;
    162 	delete[] split.yLeft;
    163 	delete[] split.yRight;
    164 
    165 	split.XLeft = nullptr;
    166 	split.XRight = nullptr;
    167 	split.yLeft = nullptr;
    168 	split.yRight = nullptr;
    169 
    170 	return chosen;
    171 }
    172 
    173 
    174 
    175 
    176 
    177 
// steps:
//
// 1) find all combinations of points
// 		combination count = nCr where n is rows and r is this->linearCombinations
// 2) find all combinations of axes for each combination of points
// 		combination count = mCr where m is columns and r is this->linearCombinations
// 3) Evaluate all combinations (impurity)
// 4) Return the tree node with the best split.

// TreeNode(float* samples, int features, int points, int* indicesOrder, int indicesCount);
    188 
    189 TreeNode* ELCClassifier::bestSplit(float* X, int rows, int* y, int columns) {
    190 
    191 //	for(int i = 0 ; i < rows*columns; ++i){
    192 //		if(i % columns == 0){
    193 //			std::cout << std::endl;
    194 //		}
    195 //		std::cout << X[i] << " ";
    196 //
    197 //	}
    198 //	std::cout << std::endl;
    199 //	std::cout << std::endl;
    200 
    201 	float bestImpurity = 0.0f;
    202 	float* ptrImpurity = &bestImpurity;
    203 	std::queue<std::vector<int>> queue;
    204 	TreeNode* best = this->bestSplitHelper(X, y, rows, columns, std::vector<int>(), 0, ptrImpurity, true, queue, false);
    205 	return best;
    206 }
    207 
    208 
// test all combinations of features for the given points and return the best selection.
//
// this needs to be reworked.
//
// this currently passes in all samples and then finds the best indices to split on, which is not what I want.
//
// what this should do instead is find all indices and then call a helper method that computes
// the best points to select with those indices.
//
// To do this I will build another method called bestSplitByPoints which accepts the indices we are splitting on,
// the other information associated with labels, and all other points for validation. It will then go through all
// combinations of points for the specified indices, returning the best option.
    224 
    225 
    226 TreeNode* const ELCClassifier::bestNodeForSelectSamples(
    227 float* allSamples, int* y, 
    228 int sampleCount, int features, 
    229 vector<int> specifiedSamples, int currentFeature,
    230 float* bestImpurity,
    231 std::vector<int> selectedFeatures
    232 ){
    233 
    234 	if((int)selectedFeatures.size() == this->linearCombinations){
    235 
    236 		int* featuresToUse = selectedFeatures.data();
    237 		int size = features * specifiedSamples.size();
    238 		float samplesToTest[size];
    239 		int itr = 0;
    240 
    241 		for (int x = 0; x < (int)specifiedSamples.size(); ++x) {
    242 			for (int y = 0; y < features; ++y) {
    243 				int sampleIndex = specifiedSamples[x];
    244 				int calculatedIndex = (sampleIndex * features) + y;
    245 				samplesToTest[itr] = allSamples[calculatedIndex];
    246 				itr += 1;
    247 
    248 			}
    249 		}
    250 
    251 		TreeNode* node = new TreeNode(samplesToTest,  features, this->linearCombinations, featuresToUse, this->linearCombinations);
    252 		*bestImpurity = node->evalSplit(allSamples, y, sampleCount, features, this->objectiveFunction);
    253 
    254 		return node;
    255 	}
    256 
    257 	if(currentFeature >= features){
    258 		TreeNode* node = nullptr;
    259 		*bestImpurity = INFINITY;
    260 		return node;
    261 	}
    262 	
    263 	// without this one included
    264 	
    265 	float left = 0;
    266 	float right = 0;
    267 	float* leftPtr = &left;
    268 	float* rightPtr = &right;
    269 
    270 	TreeNode* bestWithout = bestNodeForSelectSamples(allSamples, y, sampleCount, features, specifiedSamples, currentFeature + 1, leftPtr, selectedFeatures);
    271 
    272 	// with this one included
    273 	
    274 	selectedFeatures.push_back(currentFeature);
    275 
    276 
    277 	TreeNode* bestWith = bestNodeForSelectSamples(allSamples, y, sampleCount, features, specifiedSamples, currentFeature + 1, rightPtr, selectedFeatures);
    278 
    279 	if(*leftPtr > *rightPtr){
    280 		if(bestWithout != nullptr){
    281 			delete bestWithout;
    282 			bestWithout = nullptr;
    283 		}
    284 		*bestImpurity = *rightPtr;
    285 		return bestWith;
    286 	}
    287 	else{
    288 		if(bestWith != nullptr){
    289 			delete bestWith;
    290 			bestWith = nullptr;
    291 		}
    292 		*bestImpurity = *leftPtr;
    293 		return bestWithout;
    294 	}
    295 
    296 
    297 }
    298 
    299 // init call is used to ensure we clean up the queue
    300 TreeNode* ELCClassifier::bestSplitHelper(float* allSamples, int* y, int sampleCount, int features, vector<int> current, int currentFeature, float* bestImpurity, bool initCall, std::queue<std::vector<int>>& queuedSelections, bool finalPass) {
    301 
    302 	if((int)current.size() == this->linearCombinations || finalPass){
    303 
    304 		// this will be -1 when calling for the final time if it is the init call. 
    305 		// This is messy, but I don't know how to make it better.
    306 
    307 		if(!finalPass){
    308 			queuedSelections.push(current);
    309 		}
    310 
    311 		// this is the only location where we evaluate potential samples to split on.
    312 		if((int)queuedSelections.size() > this->maxThreads or finalPass){
    313 
    314 			float currentBestImpurity = INFINITY;
    315 			TreeNode* bestNode = nullptr;
    316 
    317     		std::vector<std::future<TreeNode*>> futureList;
    318 			std::vector<float> floats = std::vector<float>();
    319 			for(int i = 0 ; i < (int)queuedSelections.size(); ++i){
    320 				floats.push_back(INFINITY);
    321 			}
    322 
    323 		
    324 			int itr = 0;
    325 			while ((int)queuedSelections.size() > 0) {
    326 
    327 				float* tempImpurity = &floats[itr];
    328 				auto currentVec = queuedSelections.front();
    329 				queuedSelections.pop();
    330 
    331 				std::future<TreeNode*> futureNode = std::async(std::launch::async, 
    332 					&ELCClassifier::bestNodeForSelectSamples, 
    333 					*this,  
    334 					allSamples, 
    335 					y, 
    336 					sampleCount, 
    337 					features, 
    338 					currentVec, 
    339 					0, 
    340 					tempImpurity, 
    341 					std::vector<int>());
    342 
    343 				futureList.push_back(std::move(futureNode));
    344 				itr++;
    345 			}
    346 
    347 		itr = 0;
    348 
    349 		for (auto& future : futureList) {
    350 			TreeNode* currentNode = future.get();  // This blocks until the future is ready
    351 			float tempImpurity = floats[itr];
    352 
    353 			if (tempImpurity < currentBestImpurity) {
    354 				currentBestImpurity = tempImpurity;
    355 				
    356 				// Delete previous best node
    357 				if (bestNode != nullptr) {
    358 					delete bestNode;
    359 					bestNode = nullptr;
    360 				}
    361 
    362 				bestNode = currentNode;
    363 			} else {
    364 				delete currentNode;
    365 				currentNode = nullptr;
    366 			}
    367 
    368 			itr++;
    369 		}
    370 
    371 		*bestImpurity = currentBestImpurity;
    372 		return bestNode;
    373 
    374 		}
    375 		else{
    376 			*bestImpurity = INFINITY;
    377 			return nullptr;
    378 		}
    379 	}
    380 
    381 	if(currentFeature >= sampleCount){
    382 		TreeNode* node = nullptr;
    383 		*bestImpurity = INFINITY;
    384 		return node;
    385 	}
    386 
    387 
    388 	// without this one included
    389 	
    390 	float left = 0;
    391 	float right = 0;
    392 	float* leftPtr = &left;
    393 	float* rightPtr = &right;
    394 
    395 
    396 	TreeNode* bestWithout = bestSplitHelper(allSamples, y, sampleCount, features, current, currentFeature + 1, leftPtr, false, queuedSelections, false);
    397 
    398 	// with this one included
    399 	current.push_back(currentFeature);
    400 
    401 
    402 	TreeNode* bestWith = bestSplitHelper(allSamples, y, sampleCount, features, current, currentFeature + 1, rightPtr, false, queuedSelections, false);
    403 	
    404 	// this is used to ensure that even if the total number of evaluated splits is less than the number of allowed threads
    405 	// we still clear out the queue.
    406 	
    407 	if(initCall){
    408 		float curImpurity = INFINITY;
    409 		TreeNode* final = bestSplitHelper(allSamples, y, sampleCount, features, current, currentFeature + 1, &curImpurity, false, queuedSelections, true);
    410 		if(curImpurity < left){
    411 
    412 			leftPtr = &curImpurity;
    413 
    414 			if(bestWithout != nullptr){
    415 				delete bestWithout;
    416 				bestWithout = nullptr;
    417 			}
    418 
    419 			bestWithout = final;
    420 			final = nullptr;
    421 		}
    422 	}
    423 
    424 	if(*leftPtr > *rightPtr){
    425 		if(bestWithout != nullptr){
    426 			delete bestWithout;
    427 			bestWithout = nullptr;
    428 		}
    429 		*bestImpurity = *rightPtr;
    430 		return bestWith;
    431 	}
    432 	else{
    433 		if(bestWith != nullptr){
    434 			delete bestWith;
    435 			bestWith = nullptr;
    436 		}
    437 		*bestImpurity = *leftPtr;
    438 		return bestWithout;
    439 	}
    440 }
    441 
    442 int* ELCClassifier::predict(float* X, int samples, int features) {
    443 
    444 	if(featureCount == -1){
    445 		throw logic_error("Unable to predict prior to calling fit().");
    446 	}
    447 
    448 	if(features != this->featureCount){
    449 		throw invalid_argument("Incorrect number of features for prediction.");
    450 	}
    451 
    452 	int* predictions = new int[samples];
    453 
    454 	for(int i = 0; i < samples; ++i){
    455 		TreeNode* current = splittingTree;
    456 		while(!current->isLeaf()){
    457 			float* currentElement = X;
    458 			currentElement += features * i;
    459 			bool above = current->aboveOrOnPlane(currentElement, features);
    460 			if(above){
    461 				current = current->getRightChild();
    462 			}
    463 			else{
    464 				current = current->getLeftChild();
    465 			}
    466 		}
    467 		predictions[i] = current->getClassification();
    468 	}
    469 
    470 	return predictions;
    471 }
    472 
    473 ELCClassifier::~ELCClassifier(){
    474 
    475 	if(this->splittingTree != nullptr){
    476 		deleteTree(this->splittingTree);
    477 		this->splittingTree = nullptr;
    478 	}
    479 
    480 }
    481 
    482 void ELCClassifier::deleteTree(TreeNode* node){
    483 
    484 	if(node == nullptr){
    485 		return;
    486 	}
    487 
    488 	if(node->getLeftChild() != nullptr){
    489 		deleteTree(node->getLeftChild());
    490 		node->setLeftChild(nullptr);
    491 	}
    492 
    493 	if(node->getRightChild() != nullptr){
    494 		deleteTree(node->getRightChild());
    495 		node->setRightChild(nullptr);
    496 	}
    497 
    498 	delete node;
    499 }
    500 
    501 int ELCClassifier::getSplits(){
    502 	
    503 	TreeNode* current = splittingTree;
    504 
    505 	if(current == nullptr){
    506 		return 0;
    507 	}
    508 
    509 
    510     int count = 0;
    511     std::stack<TreeNode*> stack;
    512     stack.push(splittingTree);
    513 
    514     while (!stack.empty()) {
    515         TreeNode* current = stack.top();
    516         stack.pop();
    517 
    518         if (!current->isLeaf()) {
    519             count++; 
    520             if (current->getLeftChild()) stack.push(current->getLeftChild());
    521             if (current->getRightChild()) stack.push(current->getRightChild());
    522         }
    523     }
    524 
    525     return count;
    526 }
    527