gini.cpp (1368B)
#include <unordered_map>
#include <cmath>
#include <iostream>

using namespace std;

namespace {

// Gini impurity 1 - sum_k p_k^2 of one side of a split, where p_k is the
// fraction of that side's `total` samples whose class label is k.
// An empty side (total <= 0) has no impurity and yields 0.
float giniImpurity(const unordered_map<int, int>& classCounts, int total) {
    if (total <= 0) {
        return 0.0f;
    }
    // Square by multiplication rather than std::pow(x, 2), which is a general
    // power routine and promotes to double anyway; accumulate in double.
    double sumSquares = 0.0;
    for (const auto& entry : classCounts) {
        const double p = static_cast<double>(entry.second) / total;
        sumSquares += p * p;
    }
    return static_cast<float>(1.0 - sumSquares);
}

}  // namespace

extern "C" {

// Plain aggregate of three floats, returned by value from gini().
struct GiniResult {
    float weighted;  // sum of each side's impurity weighted by its count / sampleCount
    float ltGini;    // impurity of samples with eles[id] < split
    float gtGini;    // impurity of the remaining samples (eles[id] >= split, or NaN)
};

// Evaluates the Gini impurity of splitting a subset of samples at `split`.
//
// eles        - feature value per sample, indexed by the ids in `indices`
// classes     - class label per sample, indexed by the ids in `indices`
// sampleCount - denominator used for `weighted` (presumably the total number
//               of samples; TODO confirm with caller). If it is <= 0, the
//               subset size (indexCt) is used instead so the result is never
//               NaN/inf.
// split       - threshold: samples with eles[id] < split go to the "lt" side,
//               all others (including NaN values) go to the "gt" side
// indices     - ids of the samples in the subset being split
// indexCt     - number of entries in `indices`
//
// Returns each side's impurity (0 for an empty side) and the weighted sum.
GiniResult gini(float* eles, int* classes, int sampleCount, float split, int* indices, int indexCt) {
    unordered_map<int, int> ltMap;
    unordered_map<int, int> gtMap;
    int ltCount = 0;
    int gtCount = 0;

    // Partition the subset by the threshold, counting class labels per side.
    for (int i = 0; i < indexCt; ++i) {
        const int id = indices[i];
        if (eles[id] < split) {
            ++ltMap[classes[id]];
            ++ltCount;
        } else {
            ++gtMap[classes[id]];
            ++gtCount;
        }
    }

    GiniResult result;
    result.ltGini = giniImpurity(ltMap, ltCount);
    result.gtGini = giniImpurity(gtMap, gtCount);

    // Without this guard, sampleCount == 0 would make the weighting 0/0 (NaN)
    // or x/0 (inf). Fall back to the subset size, and to 0 if even that is 0.
    const int denom = sampleCount > 0 ? sampleCount : ltCount + gtCount;
    if (denom > 0) {
        result.weighted = result.gtGini * static_cast<float>(gtCount) / denom;
        result.weighted += result.ltGini * static_cast<float>(ltCount) / denom;
    } else {
        result.weighted = 0.0f;
    }
    return result;
}

}  // extern "C"