cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

testing.cpp (16349B)


#include <cmath>
#include <memory>
#include <random>
#include <stdexcept>
#include <vector>
#include "iostream"
#include "../include/ELCClassifier.h"
      4 
      5 
      6 void testHyperplaneCreationAndEvaluation(){
      7 
      8 	// test square matricies 100 times each
      9 	// for size 1-20.
     10 
     11 	for(int i = 0 ; i < 2100; ++i){
     12 		int features = int(i / 100) + 1;
     13 		int points = features;
     14 		float samples[features * features];
     15 
     16 		std::uniform_int_distribution<> dist(0, 99);  // Uniform distribution between 0 and 99
     17 		std::random_device rd;  // Get a random seed from hardware
     18 		for(int i = 0 ; i < features * features ; ++i){
     19 			samples[i] = dist(rd);
     20 		}
     21 
     22 		int indicesOrder[features - 1];
     23 		for(int i = 0 ; i < features; ++i){
     24 			indicesOrder[i] = i;
     25 		}
     26 
     27 		TreeNode node = TreeNode(samples, features, points, indicesOrder, features);
     28 		
     29 		float* equ = node.getEquation();
     30 
     31 		// evaluate to verify that all samples are on the plane
     32 		for(int x = 0; x < points; ++x){
     33 			float result = 0;
     34 
     35 			for(int i = 0 ; i < features; ++i){
     36 				result += equ[i] * samples[i + x*features];
     37 			}
     38 
     39 			// some margin of error; these are floats after all.
     40 			if(fabs(result - equ[features]) > .001){
     41 				throw std::logic_error("Problem with hyperplane equation computation.");
     42 			}
     43 		}
     44 	}
     45 	std::cout << "Hyperplane equations computed as expected" << std::endl;
     46 
     47 
     48 	// validate that it can take in indicesCounts and orders that aren't the same
     49 	// as the features. 
     50 	//
     51 	// Instead of enforcing that we have square matricies, we specify which features we care about right now.
     52 	// This means the indices count should still be equal to the points, but still.
     53 	//
     54 	// This is useful for evaluating splits with respect to the hyperplane because we can easily understand 
     55 	// which features we are doing the linear combination of.
     56 	
     57 	float samples[] = {
     58 		22, 2, 2, 8, 4,
     59 		40, 4, 5, 4, 1
     60 	};
     61 
     62 	int points = 2;
     63 	int indicesOrder[2] = {1,2};
     64 	int features = 5;
     65 	int indicesCount = 2;
     66 	TreeNode node = TreeNode(samples, features, points, indicesOrder, indicesCount);
     67 
     68 	// we already know this works because of the above test
     69 	float samples1[] = {
     70 		2, 2, 
     71 		4, 5,
     72 	};
     73 	int points1 = 2;
     74 	int indicesOrder1[2] = {0,1};
     75 	int features1 = 2;
     76 	int indicesCount1= 2;
     77 	TreeNode node1 = TreeNode(samples1, features1, points1, indicesOrder1, indicesCount1);
     78 
     79 	// by verifying these are the same, we have shown that the hyperplane respects the indices
     80 	// we specified.
     81 
     82 	for(int i = 0 ; i < indicesCount1 + 1; ++i){
     83 		if(node1.getEquation()[i] != node.getEquation()[i]){
     84 			throw std::logic_error("Indices order or count not working as expected.");
     85 		}
     86 	}
     87 	std::cout << "Verified indices order and count work as expected" << std::endl;
     88 
     89 	// equation for plane is:
     90 	//-0.83205 = a
     91 	//0.5547 = b
     92 	//-0.5547 = d
     93 	
     94 	int numSamples = 6;
     95 
     96 	// these can be evaluated manually with the equation shown above and the formula:
     97 	// ax + by = d
     98 	// WHERE:
     99 	// x = first feature
    100 	// y = second feature
    101 
    102 	float evalSamples[] = {
    103 		2, 2, // on plane
    104 		4, 5, // on plane
    105 		2.1, 2,// below
    106 		2, 2.1,// above
    107 		10, 13.99,// below
    108 		10, 14, // on plane 
    109 	};
    110 
    111 	float* eval = evalSamples;
    112 
    113 	bool answers[] = {
    114 		true,
    115 		true,
    116 		false,
    117 		true,
    118 		false,
    119 		true,
    120 	};
    121 
    122 	for(int i = 0 ; i < numSamples; ++i){
    123 
    124 		if(node1.aboveOrOnPlane(eval, 2) != answers[i]){
    125 			throw std::logic_error("Above or on plane not working as expected.");
    126 		}
    127 		// get next sample
    128 		eval += 2;
    129 	}
    130 
    131 	std::cout << "Verified evaluation for point on or above works as expected." << std::endl;
    132 }
    133 
    134 
    135 
    136 
    137 void testInformationGain(){
    138 
    139 	// this should result in x = 1.
    140 	float train[] = {
    141 		1
    142 	};
    143 
    144 	float validation[] = {
    145 		// these three are left
    146 		0, 5, 2,
    147 		0, 6, 3,
    148 		0, 7, 4,
    149 
    150 		// these three are right
    151 		4, 8, 5,
    152 		5, 9, 6,
    153 		6, 10, 7
    154 	};
    155 
    156 	int validationClasses[] = {
    157 		// these three are left
    158 		0,
    159 		0,
    160 		1,
    161 
    162 		// these three are right
    163 		1,
    164 		0,
    165 		1
    166 	};
    167 
    168 	// before = [0, 0, 1, 1, 0, 1]
    169 	// left = [0,0,1]
    170 	// right = [1,0,1]
    171 	//
    172 	// before = -(.5 log_2 .5 + .5 log_2 .5)
    173 	// = 1
    174 	//
    175 	// left H(X) = 0.9183
    176 	// right H(X) = 0.9183
    177 	//
    178 	// weighted = .9183
    179 	//
    180 	// IG = before - after
    181 	// = .0817 (bigger is better)
    182 	// 
    183 	// INV(IG) = .0817 * -1
    184 	// = -.0817
    185 
    186 
    187 	int points1 = 1;
    188 	int indicesOrder1[1] = {0};
    189 	int features1 = 1;
    190 	int indicesCount1= 1;
    191 
    192 	TreeNode node1 = TreeNode(train, features1, points1, indicesOrder1, indicesCount1);
    193 
    194 	float infGain = node1.evalSplit(validation, validationClasses, 6, 3, "information gain");
    195 	
    196 	// we return negative as the rest of our logic is set up to minimize
    197 	// the objetive function
    198 	
    199 	std::cout << "FINAL EVAL: " << infGain << std::endl;
    200 
    201 	float expected = -.0817f;
    202 	float diff = expected - infGain;
    203 
    204 	diff = fabs(diff);
    205 
    206 	std::cout << "DIFF: " << diff << std::endl;
    207 
    208 	if(diff > 0.0001f){
    209 		throw std::logic_error("information gain not working as expected");
    210 	}
    211 
    212 	std::cout << "Verified information gain computed correctly" << std::endl;
    213 }
    214 
    215 void testTwoing(){
    216 
    217 	// this should result in x = 1.
    218 	float train[] = {
    219 		1
    220 	};
    221 
    222 	float validation[] = {
    223 		// these three are left
    224 		0, 5, 2,
    225 		0, 6, 3,
    226 		0, 7, 4,
    227 
    228 		// these three are right
    229 		4, 8, 5,
    230 		5, 9, 6,
    231 		6, 10, 7
    232 	};
    233 
    234 	int validationClasses[] = {
    235 		// these three are left
    236 		0,
    237 		0,
    238 		1,
    239 
    240 		// these three are right
    241 		1,
    242 		0,
    243 		1
    244 	};
    245 
    246 	// left = [0,0,1]
    247 	// right = [1,0,1]
    248 	//
    249 	// p_L = .5
    250 	// p_R = .5
    251 	// p_L * p_R = .25
    252 	// (p_L * p_R) / 4 = .0625 = outer
    253 	//
    254 	// p(0 | t_L) = 2/3
    255 	// p(1 | t_L) = 1/3
    256 	//
    257 	// p(0 | t_R) = 1/3
    258 	// p(1 | t_R) = 2/3
    259 	//
    260 	// 2/3 - 1/3 = 1/3
    261 	// 1/3 - 2/3 = -1/3
    262 	//
    263 	// 1/3 + 1/3 = 2/3
    264 	// 
    265 	// 2/3 ^ 2 = 4/9 = inner
    266 	//
    267 	// 4/9 * .0625 = .02777777777
    268 
    269 	int points1 = 1;
    270 	int indicesOrder1[1] = {0};
    271 	int features1 = 1;
    272 	int indicesCount1= 1;
    273 
    274 	TreeNode node1 = TreeNode(train, features1, points1, indicesOrder1, indicesCount1);
    275 
    276 	float twoingEval = node1.evalSplit(validation, validationClasses, 6, 3, "twoing");
    277 	
    278 	// we return negative as the rest of our logic is set up to minimize
    279 	// the objetive function
    280 	
    281 	twoingEval *= -1;
    282 	std::cout << "FINAL EVAL: " << twoingEval << std::endl;
    283 
    284 	float expected = 0.0277777f;
    285 	// remember this returns the negative twoing
    286 	float diff = twoingEval - expected;
    287 
    288 	diff = fabs(diff);
    289 
    290 	std::cout << "DIFF: " << diff << std::endl;
    291 
    292 	if(diff > 0.0001f){
    293 		throw std::logic_error("twoing rule not working as expected");
    294 	}
    295 
    296 	std::cout << "Verified twoing rule computed correctly" << std::endl;
    297 }
    298 
    299 
    300 void testGini(){
    301 
    302 
    303 	// this should result in x = 1.
    304 	float train[] = {
    305 		1
    306 	};
    307 
    308 
    309 	// we notice that all of the ones classified with 0 as the y value
    310 	// have x > 2 and the one with classification 1 has x < 2.
    311 
    312 	// as such, the gini impurity should be 0 (best).
    313 	
    314 	float validation[] = {
    315 		0, 2, 5,
    316 		4, 5, 5,
    317 		4, 5, 5,
    318 		4, 5, 5,
    319 		4, 5, 5,
    320 	};
    321 
    322 	int validationClasses[] = {
    323 		1,
    324 		0,
    325 		0,
    326 		0,
    327 		0
    328 	};
    329 
    330 	int points1 = 1;
    331 	int indicesOrder1[1] = {0};
    332 	int features1 = 1;
    333 	int indicesCount1= 1;
    334 
    335 	TreeNode node1 = TreeNode(train, features1, points1, indicesOrder1, indicesCount1);
    336 
    337 	float giniImpurity = node1.evalSplit(validation, validationClasses, 5, 3, "gini");
    338 
    339 	if(giniImpurity != 0.0f){
    340 		throw std::logic_error("Gini impurity not working as expected");
    341 	}
    342 
    343 	// so long as len >= 1 we are good to evaluate
    344 	float val2[] = {
    345 		10, 2,
    346 		10, 2,
    347 		0, 2,
    348 		0, 10
    349 	};
    350 
    351 	// this means we should have one of each below and one of each above.
    352 	// as such:
    353 	// 1 - (.5)^2 = .75
    354 	// .75 - (.5)^2 = .5
    355 	//
    356 	// This is what we should find for both sides. Then since they are equally weighted, the final
    357 	// impurity should be .5
    358 
    359 	int val2Classes[] = {
    360 		0,
    361 		1,
    362 		0,
    363 		1
    364 	};
    365 
    366 	float gini = node1.evalSplit(val2, val2Classes, 4, 2, "gini");
    367 	if(gini != .5f){
    368 		throw std::logic_error("Gini impurity not working as expected");
    369 	}
    370 
    371 	// now let's verify weighting works as expected.
    372 
    373 	// lt split:
    374 	// 3, 2, 5, 7, 7
    375 	//
    376 	// 1/5^2 + 1/5^2 + 1/5^2 + 2/5^2 = ~.28
    377 	// gini for lt split = 1 - .28 = .72
    378 	//
    379 	// gt split:
    380 	// 1, 0
    381 	// GT Split gini = .5
    382 	//
    383 	// Weighted gini:
    384 	//
    385 	// (2 * .5) + (5*.72) = 4.6
    386 	// 4.6 / 7 = ~.657
    387 	
    388 	float val3[] = {
    389 		0,
    390 		0,
    391 		0,
    392 		0,
    393 		0,
    394 		2,
    395 		1
    396 	};
    397 
    398 	int val3Classes[] = {
    399 		// less than below
    400 		3,
    401 		2,
    402 		5,
    403 		7,
    404 		7,
    405 		// greater than below
    406 		1,
    407 		0
    408 	};
    409 
    410 	float finalGini = node1.evalSplit(val3, val3Classes, 7, 1, "gini");
    411 	float expectedGini = .657142857;
    412 
    413 
    414 	float diff = finalGini - expectedGini;
    415 	if(diff > .001 || diff < -.001){
    416 		throw std::logic_error("Gini impurity not working as expected");
    417 	}
    418 	
    419 	std::cout << "Verified gini impurity computed correctly" << std::endl;
    420 }
    421 
    422 void testSplitting(){
    423 
    424 	int points1 = 1;
    425 	int indicesOrder1[1] = {0};
    426 	int features1 = 1;
    427 	int indicesCount1= 1;
    428 	float train[] = {
    429 		1.0f
    430 	};
    431 
    432 	TreeNode node1 = TreeNode(train, features1, points1, indicesOrder1, indicesCount1);
    433 
    434 	// x >= 1
    435 	float val[] = {
    436 		0.0f, 2.0f,
    437 		0.0f, 2.0f,
    438 		0.0f, 5.0f,
    439 		0.0f, 7.0f,
    440 		0.0f, 8.0f,
    441 		1.0f, 9.0f,
    442 		7.0f, 10.0f
    443 	};
    444 
    445 	int valClasses[] = {
    446 		3,
    447 		2,
    448 		5,
    449 		7,
    450 		7,
    451 		// greater than below
    452 		1,
    453 		0
    454 	};
    455 
    456 	int featureCount = 2;
    457 	int sampleCount = 7;
    458 
    459 	SplitResults res = node1.splitOnNode(val, valClasses, sampleCount, featureCount);
    460 
    461 
    462 
    463 	int expLeftClasses[] = {
    464 		3,
    465 		2,
    466 		5,
    467 		7,
    468 		7
    469 	};
    470 
    471 	int expRightClasses[] = {
    472 		1,
    473 		0
    474 	};
    475 
    476 	float expLeftSamples[] = {
    477 		0, 2,
    478 		0, 2,
    479 		0, 5,
    480 		0, 7,
    481 		0, 8,
    482 	};
    483 
    484 	float expRightSamples[] = {
    485 		1, 9,
    486 		7, 10
    487 	};
    488 
    489 	// verify proper splitting of labels.
    490 
    491 	for(int i = 0 ; i < 5; ++i){
    492 		if(res.yLeft[i] != expLeftClasses[i]){
    493 			throw new std::logic_error("Splitting not working as expected.");
    494 		}
    495 	}
    496 
    497 	for(int i = 0 ; i < 2; ++i){
    498 		if(res.yRight[i] != expRightClasses[i]){
    499 			throw new std::logic_error("Splitting not working as expected.");
    500 		}
    501 	}
    502 
    503 	for(int i = 0 ; i < 2; ++i){
    504 		for(int x = 0 ; x < 2; ++x){
    505 			if(res.XRight[i*2 + x] != expRightSamples[i*2 + x]){
    506 				throw new std::logic_error("Splitting not working as expected.");
    507 			}
    508 		}
    509 	}
    510 
    511 	for(int i = 0 ; i < 5; ++i){
    512 		for(int x = 0 ; x < 2; ++x){
    513 			if(res.XLeft[i*2 + x] != expLeftSamples[i*2 + x]){
    514 				throw new std::logic_error("Splitting not working as expected.");
    515 			}
    516 		}
    517 	}
    518 
    519 	delete[] res.XRight;
    520 	delete[] res.XLeft;
    521 	delete[] res.yRight;
    522 	delete[] res.yLeft;
    523 
    524 	std::cout << "Verified splitting on tree nodes works as expected." << std::endl;
    525 }
    526 
    527 void testBestSplit(){
    528 
    529 	
    530 	float val[] = {
    531 
    532 		4.1f,  4.5f,  4.5f,  4.2f,  4.4f,  7.1f,  5.5f,  2.2f,  9.4f,  1.2f,
    533 		5.5f,  2.2f,  9.4f,  4.0f,  1.25f, 3.5f,  4.1f,  4.5f,  4.5f,  4.2f,
    534 		4.4f,  7.1f,  5.5f,  2.2f,  9.4f,  1.2f,  5.5f,  2.2f,  9.4f,  4.0f,
    535 		1.25f, 3.5f,  4.1f,  4.5f,  4.5f,  4.2f,  4.4f,  7.1f,  5.5f,  2.2f,
    536 		9.4f,  1.2f,  5.5f,  2.2f,  9.4f,  4.0f,  1.25f, 3.5f,  4.1f,  4.5f
    537 
    538 	};
    539 
    540 
    541 	// 0 1 2
    542 	// 0 1 2
    543 	
    544 	int valClasses[] = {
    545 		0,
    546 		1,
    547 		0,
    548 		0,
    549 		2
    550 	};
    551 
    552 	int MAX_DEPTH = 1;
    553 	int LINEAR_COMBINATIONS = 3;
    554 	int featureCount = 10;
    555 	int sampleCount = 5;
    556 
    557 
    558 	ELCClassifier clf(MAX_DEPTH, LINEAR_COMBINATIONS);
    559 	std::vector<int> input = std::vector<int>();
    560 	TreeNode* split = clf.bestSplit(val, sampleCount, valClasses, featureCount);
    561 
    562 	float value = split->evalSplit(val, valClasses, sampleCount, featureCount, "gini");
    563 
    564 	delete split;
    565 	
    566 
    567 
    568 	// still need to add proper validation.
    569 
    570 	std::cout << "Verified best by points computes without errors." << std::endl;
    571 
    572 }
    573 
    574 void testFit(){
    575 	float val[] = {
    576 		6.83f, 5.94f, 7.06f, 3.46f, 7.79f, 9.11f, 4.73f, 7.32f, 5.98f, 0.09f, 
    577 		9.93f, 1.45f, 6.1f, 6.72f, 1.36f, 7.3f, 0.52f, 9.62f, 0.98f, 1.89f, 
    578 		2.82f, 1.78f, 7.21f, 4.6f, 8.54f, 3.48f, 9.18f, 3.64f, 0.8f, 5.4f, 
    579 		6.31f, 9.6f, 9.37f, 6.15f, 9.36f, 9.53f, 7.54f, 0.63f, 0.77f, 2.67f, 
    580 		5.71f, 1.74f, 0.62f, 0.34f, 7.39f, 1.05f, 3.97f, 3.66f, 2.99f, 8.95f, 
    581 		6.44f, 3.39f, 7.11f, 5.85f, 7.68f, 9.16f, 6.73f, 1.69f, 6.91f, 1.54f, 
    582 		9.93f, 6.32f, 4.66f, 3.22f, 2.17f, 6.22f, 3.04f, 6.83f, 5.7f, 1.22f, 
    583 		5.31f, 4.43f, 0.93f, 3.89f, 6.92f, 5.31f, 2.05f, 5.73f, 5.83f, 1.62f, 
    584 		2.14f, 0.43f, 3.81f, 5.4f, 0.07f, 7.11f, 4.94f, 3.3f, 4.04f, 4.7f, 
    585 		2.0f, 1.63f, 7.04f, 3.45f, 1.72f, 2.1f, 1.83f, 8.61f, 0.21f, 5.11f, 
    586 		7.63f, 4.94f, 4.69f, 6.11f, 3.44f, 2.91f, 5.12f, 9.61f, 9.43f, 9.64f, 
    587 		6.17f, 8.1f, 4.18f, 3.57f, 3.02f, 4.94f, 6.52f, 9.97f, 0.68f, 4.64f, 
    588 		2.29f, 7.01f, 1.31f, 3.47f, 5.54f, 8.22f, 7.63f, 2.42f, 8.67f, 1.8f, 
    589 		5.23f, 2.3f, 9.51f, 8.93f, 1.75f, 1.5f, 1.04f, 0.24f, 3.26f, 8.5f, 
    590 	};
    591 
    592 
    593 	// 0 1 2
    594 	// 0 1 2
    595 	
    596 	int valClasses[] = {
    597 		0,
    598 		1,
    599 		10,
    600 		7,
    601 		7,
    602 		0,
    603 		6,
    604 		4,
    605 		10,
    606 		5,
    607 		8,
    608 		6,
    609 		9,
    610 		9,
    611 	};
    612 
    613 	int MAX_DEPTH = 50;
    614 	int LINEAR_COMBINATIONS = 2;
    615 	int featureCount = 10;
    616 	int sampleCount = 14;
    617 
    618 	ELCClassifier clf = ELCClassifier(MAX_DEPTH, LINEAR_COMBINATIONS);
    619 
    620 	clf.fit(val, sampleCount, valClasses, featureCount);
    621 	std::cout << "Fitting runs successfully" << std::endl;
    622 
    623 }
    624 
    625 
    626 void getDot(){
    627 
    628 	float val[] = {
    629 		1.2f, 2.0f,
    630 		1.0f, 2.0f,
    631 		1.0f, 2.1f,
    632 		1.2f, 9.0f,
    633 		5.2f, 8.0f,
    634 		3.2f, 2.0f,
    635 		4.2f, 7.0f,
    636 		0.2f, 7.0f,
    637 		1.0f, 2.0f,
    638 		0.0f, 7.1f,
    639         2.3f, 6.4f,
    640         8.1f, 3.5f,
    641         7.2f, 4.3f,
    642         3.9f, 1.5f,
    643         5.0f, 6.7f,
    644 	};
    645 
    646 
    647 	// 0 1 2
    648 	// 0 1 2
    649 	
    650 	int valClasses[] = {
    651 		0,
    652 		1,
    653 		10,
    654 		1,
    655 		1,
    656 		5,
    657 		7,
    658 		8,
    659 		6,
    660 		9,
    661 		5,
    662 		7,
    663 		8,
    664 		6,
    665 		9,
    666 	};
    667 
    668 	int MAX_DEPTH = 50;
    669 	int LINEAR_COMBINATIONS = 1;
    670 	int featureCount = 2;
    671 	int sampleCount = 15;
    672 
    673 
    674 	ELCClassifier clf = ELCClassifier(MAX_DEPTH, LINEAR_COMBINATIONS);
    675 	clf.fit(val, sampleCount, valClasses, featureCount);
    676 	std::cout << clf.getDot();
    677 	std::cout << std::endl;
    678 }
    679 
    680 void testPrediction(){
    681 
    682 	float val[] = {
    683 		1.2f, 9.0f,
    684 		5.2f, 8.0f,
    685 		3.2f, 2.0f,
    686 		4.2f, 7.0f,
    687 		0.2f, 7.0f,
    688 		1.0f, 2.0f,
    689 		0.0f, 7.1f,
    690         2.3f, 6.4f,
    691         8.1f, 3.5f,
    692         7.2f, 4.3f,
    693         3.9f, 1.5f,
    694         5.0f, 6.7f,
    695         4.8f, 9.1f,
    696         2.1f, 8.0f,
    697         1.4f, 3.3f,
    698         6.2f, 5.5f,
    699         7.9f, 2.4f,
    700         0.9f, 8.3f,
    701         4.4f, 1.8f,
    702         3.0f, 6.1f,
    703         2.8f, 7.9f,
    704         5.3f, 4.2f
    705 	};
    706 
    707 	int valClasses[] = {
    708 		1,
    709 		1,
    710 		5,
    711 		7,
    712 		8,
    713 		6,
    714 		9,
    715 		5,
    716 		7,
    717 		8,
    718 		6,
    719 		9,
    720 		0,
    721 		1,
    722 		5,
    723 		7,
    724 		8,
    725 		6,
    726 		9,
    727 		10,
    728 		2,
    729 		1
    730 	};
    731 
    732 	int MAX_DEPTH = 50;
    733 	int LINEAR_COMBINATIONS = 1;
    734 	int featureCount = 2;
    735 	int sampleCount = 22;
    736 
    737 	ELCClassifier clf = ELCClassifier(MAX_DEPTH, LINEAR_COMBINATIONS);
    738 	clf.fit(val, sampleCount, valClasses, featureCount);
    739 
    740 	int* preds = clf.predict(val, sampleCount, featureCount);
    741 	for(int i = 0 ; i < sampleCount; ++i){
    742 		if(preds[i] != valClasses[i]){
    743 			throw std::logic_error("Computing axis splits (LC=1) not working properly.");
    744 		}
    745 
    746 	}
    747 	delete[] preds;
    748 
    749 	std::cout << "Verified axis splits are working as expected" << std::endl;
    750 }
    751 
    752 void getRNDTree(int featureCount, int sampleCount, int MAX_DEPTH, int LINEAR_COMBINATIONS) {
    753 
    754     std::vector<float> val(featureCount * sampleCount);
    755     std::random_device rd;
    756     std::mt19937 gen(rd());
    757     std::uniform_real_distribution<float> dist(1.0f, 5.0f); // Random values between 1.0 and 5.0
    758 
    759     for (auto& v : val) {
    760         v = dist(gen);
    761     }
    762 
    763     // Generate random class labels
    764     std::vector<int> valClasses(sampleCount);
    765     std::uniform_int_distribution<int> classDist(0, 10); // Random classes between 0 and 10
    766 
    767     for (auto& c : valClasses) {
    768         c = classDist(gen);
    769     }
    770 
    771     // Create and fit the classifier
    772     ELCClassifier clf = ELCClassifier(MAX_DEPTH, LINEAR_COMBINATIONS, 50, "twoing");
    773 
    774 	for(int i = 0 ; i < 2; ++i){
    775     	clf.fit(val.data(), sampleCount, valClasses.data(), featureCount);
    776 		std::cout << std::endl;
    777 		std::cout << std::endl;
    778 	}
    779     // std::cout << std::endl;
    780     // std::cout << std::endl;
    781     // std::cout << std::endl;
    782     // std::cout << std::endl;
    783     // 
    784 	// std::cout << clf.getDot() << std::endl;
    785 
    786 	// std::cout << std::endl;
    787 	// 
    788 
    789 	// std::cout << "SPLITS: "<< clf.getSplits() << std::endl;
    790 
    791 }
    792 
    793 // TODO:
    794 // x Implement split on node
    795 // x Build logic for fitting :)
    796 // x Build logic for prediction
    797 // x Build dot logic for graphing
    798 // x Multicore support
    799 // - More tests
    800 // - Benchmarking
    801 
    802 //==3641677== 
    803 //==3641677== HEAP SUMMARY:
    804 //==3641677==     in use at exit: 0 bytes in 0 blocks
    805 //==3641677==   total heap usage: 652,186 allocs, 652,186 frees, 28,881,639 bytes allocated
    806 //==3641677== 
    807 //==3641677== All heap blocks were freed -- no leaks are possible
    808 //==3641677== 
    809 //==3641677== For lists of detected and suppressed errors, rerun with: -s
    810 //==3641677== ERROR SUMMARY: 0 errors from 0 contexts (suppressed: 0 from 0)
    811 // haha
    812 
    813 
    814 
    815 int main(){
    816 
    817 	//testHyperplaneCreationAndEvaluation();
    818 	testGini();
    819 	testTwoing();
    820 	testInformationGain();
    821 	//testSplitting();
    822 	//testBestSplit();
    823 	//testFit();
    824 	//getDot();
    825 	//testPrediction();
    826 
    827 	//getRNDTree(3, 10, 10, 2);
    828 	return 0;
    829 }