commit bc574a4c5fab2abf6bc6eb671ab54f694505ea8c
parent 505b2e05ba4f43008cda0974f61c30d454821115
Author: Andrew <andrewlaack1@gmail.com>
Date: Thu, 19 Dec 2024 14:30:25 -0600
Wrote code for C++ gini
Diffstat:
4 files changed, 198 insertions(+), 66 deletions(-)
diff --git a/classifier/Podtc.py b/classifier/Podtc.py
@@ -2,6 +2,8 @@ from warnings import warn
import numpy as np
from tqdm import tqdm
from SplittingNode import SplittingNode
+from SplittingNode import gini
+
from LeafNode import LeafNode
import math
import graphviz
@@ -35,7 +37,7 @@ class PseudoOptimalDecisionTreeClassifier():
def fit(self, X, y):
- self.__validateInput(X,y)
+ X,y = self.__validateInput(X,y)
y_re = y.reshape(-1,1)
self.sampleSize = X.shape[1]
@@ -119,19 +121,27 @@ class PseudoOptimalDecisionTreeClassifier():
# each row (sample)
# also, we are interpolating between samples
+ # indices for splits (this decides how many splits to test)
indices = np.round(np.linspace(0, len(together) - 2, math.ceil(proportionUsed * len(together)))).astype(int)
+ # indices for evals. This decides which indices to check upon splitting
+ values = np.round(np.linspace(0, len(together) - 1, math.ceil(self.propValSplits * len(together)))).astype(np.int32)
+
for currentSample in indices:
splitOn = ((together[currentSample+1][x] - together[currentSample][x]) / 2) + together[currentSample][x]
split = SplittingNode(x, splitOn)
- current = split.gini(together, self.propValSplits)
+
+ current = gini(together, values, split.index, split.val)
+
+
if current[0] < bestGini:
bestNode = split
bestGini = current[0]
blg = current[1]
bgg = current[2]
+
# Return the best node, the left gini impurity, and right gini impurity.
# These impurities allow for us to stop if we have a pure node.
return (bestNode, blg, bgg)
@@ -218,7 +228,7 @@ class PseudoOptimalDecisionTreeClassifier():
if X.shape[0] <= 1:
raise Exception(f"X must contain more than one sample.")
- return
+ return X,y
def graph(self):
if self.bestSplit == None:
diff --git a/classifier/SplittingNode.py b/classifier/SplittingNode.py
@@ -1,6 +1,6 @@
import numpy as np
import math
-from numba import njit
+import ctypes
class SplittingNode:
@@ -14,57 +14,6 @@ class SplittingNode:
# maybe add input validation???
# do in place weighted gini calculation
- def gini(self, combined, propToValWith):
-
- ltc = {}
- geqc = {}
-
- ltCount = 0
- geqCount = 0
-
- values = np.round(np.linspace(0, len(combined) - 1, math.ceil(propToValWith * len(combined)))).astype(int)
-
- for i in values:
-
- lessThan = self._lessThan(combined[i])
- classification = int(combined[i][-1])
-
- if(lessThan):
- ltCount += 1
- value = ltc.get(classification)
- if(value != None):
- ltc[classification] = value + 1
- else:
- ltc[classification] = 1
- else:
- geqCount += 1
- value = geqc.get(classification)
- if(value != None):
- geqc[classification] = value + 1
- else:
- geqc[classification] = 1
-
- lt_gini = 1
- for key in ltc.keys():
- lt_gini -= (ltc[key] / ltCount)**2
-
- gt_gini = 1
- for key in geqc.keys():
- gt_gini -= (geqc[key] / geqCount)**2
-
- lt_percent = ltCount / len(combined)
- gt_percent = geqCount / len(combined)
- weighted_gini = (lt_gini * lt_percent) + (gt_gini * gt_percent)
-
- return (weighted_gini, lt_gini, gt_gini)
-
-
- def _lessThan(self, sample):
- value = sample[self.index]
- if(value < self.val):
- return True
- return False
-
# split the data by current node
def split(self, arr):
@@ -73,7 +22,7 @@ class SplittingNode:
gtCount = 0
for i in range(0, len(arr)):
- lessThan = self._lessThan(arr[i])
+ lessThan = _lessThan(arr[i], self.index, self.val)
if lessThan:
ltCount += 1
else:
@@ -88,8 +37,8 @@ class SplittingNode:
gtItr = 0
for i in range(0, len(arr)):
- lessThan = self._lessThan(arr[i])
- if lessThan:
+ lt = _lessThan(arr[i], self.index, self.val)
+ if lt:
ltArr[ltItr] = arr[i]
ltItr += 1
else:
@@ -102,3 +51,112 @@ class SplittingNode:
def __str__(self):
return f"Splitting index: {self.index}\nSplitting value: {round(self.val,2)}"
+
+class GiniResult(ctypes.Structure):
+ _fields_ = [("weighted", ctypes.c_float),
+ ("ltGini", ctypes.c_float),
+ ("gtGini", ctypes.c_float)]
+
+def gini(combined, values, index, val):
+
+
+ # implement prop to val with for c++
+
+
+ useCPP = False
+ usePy = True
+
+
+ # add indices and index count
+ if useCPP:
+ gini_lib = ctypes.CDLL('./cpp/libgini.so')
+ gini_lib.gini.restype = GiniResult
+ gini_lib.gini.argtypes = [
+ ctypes.POINTER(ctypes.c_float),
+ ctypes.POINTER(ctypes.c_int),
+ ctypes.c_int,
+ ctypes.c_float,
+ ctypes.POINTER(ctypes.c_int),
+ ctypes.c_int
+ ]
+
+ eles = combined[:, index].astype(np.float32)
+ eles = eles.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
+
+ classes = combined[:, -1].astype(np.int32).ctypes.data_as(ctypes.POINTER(ctypes.c_int))
+ sample_count = len(combined[:, index])
+ split_val = ctypes.c_float(val)
+
+ vals = values.ctypes.data_as(ctypes.POINTER(ctypes.c_int))
+
+ result = gini_lib.gini(eles, classes, sample_count, split_val, vals, len(values))
+
+ weightedGini = result.weighted
+ ltGini = result.ltGini
+ gtGini = result.gtGini
+ return (weightedGini, ltGini, gtGini)
+ if usePy:
+ outPy = giniPy(combined, values, index, val)
+ return outPy
+ # return outPy
+ return
+
+
+
+def giniPy(combined , values, index, val):
+
+ ltc = {}
+ geqc = {}
+
+ ltCount = 0
+ geqCount = 0
+
+
+
+ for i in values:
+
+ lt = _lessThan(combined[i], index, val)
+ classification = int(combined[i][-1])
+
+ if(lt):
+ ltCount += 1
+ value = ltc.get(classification)
+ if(value != None):
+ ltc[classification] = value + 1
+ else:
+ ltc[classification] = 1
+ else:
+ geqCount += 1
+ value = geqc.get(classification)
+ if(value != None):
+ geqc[classification] = value + 1
+ else:
+ geqc[classification] = 1
+
+ lt_gini = 1
+ for key in ltc.keys():
+ lt_gini -= (ltc[key] / ltCount)**2
+
+ gt_gini = 1
+ for key in geqc.keys():
+ gt_gini -= (geqc[key] / geqCount)**2
+
+ if(geqCount == 0):
+ gt_gini = 0
+ if(ltCount == 0):
+ lt_gini = 0
+
+ lt_percent = ltCount / len(values)
+ gt_percent = geqCount / len(values)
+
+ weighted_gini = (lt_gini * lt_percent) + (gt_gini * gt_percent)
+
+ return (weighted_gini, lt_gini, gt_gini)
+
+
+def _lessThan(sample, index, val):
+ value = sample[index]
+ if(value < val):
+ return True
+ return False
+
diff --git a/classifier/Testing.py b/classifier/Testing.py
@@ -6,26 +6,30 @@ from sklearn.datasets import load_digits
import pandas as pd
from keras.datasets import mnist
from sklearn.metrics import accuracy_score
-
+from setuptools import setup
(train_X, train_y), (test_X, test_y) = mnist.load_data()
train_X = train_X.reshape(-1, 784)
test_X = test_X.reshape(-1, 784)
-classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=.09, proportionToValidateSplits=.02, proportionOfDimsToTrainOn=.75, maxDepth=15);
+
+# train_X = [[2,5], [5,2], [3,4], [4,4], [5,5], [10, 10], [2,2], [12,12]]
+# train_y = [1, 1 , 2, 1, 5, 2,1 ,3]
+
+classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=.01, proportionToValidateSplits=.1, proportionOfDimsToTrainOn=.1, maxDepth=1);
classifier.fit(train_X, train_y)
y_pred = classifier.predict(test_X)
print(accuracy_score(y_true=test_y, y_pred=y_pred))
-classifier = DecisionTreeClassifier(max_depth=15)
-classifier.fit(train_X, train_y)
-y_pred = classifier.predict(test_X)
+# classifier = DecisionTreeClassifier(max_depth=15)
+# classifier.fit(train_X, train_y)
+# y_pred = classifier.predict(test_X)
-print("SECOND ACCURACY:")
-print(accuracy_score(y_true=test_y, y_pred=y_pred))
+# print("SECOND ACCURACY:")
+# print(accuracy_score(y_true=test_y, y_pred=y_pred))
assert False
@@ -41,7 +45,6 @@ y = y.round()
classifier = PseudoOptimalDecisionTreeClassifier(proportionToTrainOn=1, proportionToValidateSplits=1, proportionOfDimsToTrainOn=1, maxDepth=2);
-
classifier.fit(X,y)
diff --git a/classifier/cpp/gini.cpp b/classifier/cpp/gini.cpp
@@ -0,0 +1,61 @@
+#include <unordered_map>
+#include <cmath>
+#include <iostream>
+
+using namespace std;
+
+
+
+extern "C" {
+
+ struct GiniResult {
+ float weighted;
+ float ltGini;
+ float gtGini;
+ };
+
+ GiniResult gini(float* eles, int* classes, int sampleCount, float split, int* indices, int indexCt) {
+
+ unordered_map<int, int> ltMap;
+ unordered_map<int, int> gtMap;
+
+ int ltCount = 0;
+ int gtCount = 0;
+
+ // Split the data based on the threshold
+ for(int i = 0; i < indexCt; ++i) {
+ if(eles[indices[i]] < split) {
+ ltMap[classes[indices[i]]]++;
+ ltCount++;
+ } else {
+ gtMap[classes[indices[i]]]++;
+ gtCount++;
+ }
+ }
+
+
+ GiniResult result;
+
+ result.ltGini = 1.0f;
+ for (const auto& pair : ltMap) {
+ result.ltGini -= pow(float(pair.second) / ltCount, 2);
+ }
+
+ result.gtGini = 1.0f;
+ for (const auto& pair : gtMap) {
+ result.gtGini -= pow(float(pair.second) / gtCount, 2);
+ }
+
+ if(gtCount == 0){
+ result.gtGini = 0.0f;
+ }
+ if(ltCount == 0){
+ result.ltGini = 0.0f;
+ }
+
+ result.weighted = result.gtGini * float(gtCount) / sampleCount;
+ result.weighted += result.ltGini * float(ltCount) / sampleCount;
+
+ return result;
+ }
+}