commit d55ce288b320b06cea9d8df9e15fa3718d8a67e9
parent cdd22f910575e5cda0dae06407aadb4ad3d22067
Author: Andrew <andrewlaack1@gmail.com>
Date: Tue, 17 Dec 2024 20:34:08 -0600
Did some work
Diffstat:
4 files changed, 198 insertions(+), 0 deletions(-)
diff --git a/classifier/Podtc.py b/classifier/Podtc.py
@@ -0,0 +1,77 @@
+import numpy as np
+from SplittingNode import SplittingNode
+
+class PseudoOptimalDecisionTreeClassifier():
+    """Greedy decision-tree classifier (work in progress).
+
+    Only the root split is computed so far: ``fit`` searches every feature
+    for the threshold with the lowest weighted Gini impurity and records it
+    in ``splitList``. ``predict`` and ``__str__`` are still placeholders.
+
+    Attributes:
+        splitList: Splits found by the most recent ``fit``; the first entry
+            is the first (root) split.
+        threshold: Value passed as ``pruneThreshold``; stored, not yet used.
+        maxDepth: Value passed as ``maxDepth``; stored, not yet used.
+    """
+
+    def __init__(self, pruneThreshold=.2, maxDepth=5):
+        """Store the hyper-parameters and start with no splits."""
+        # Created per instance: a class-level list would be shared by every
+        # classifier, so one model's splits would leak into all the others.
+        self.splitList = []
+        self.threshold = pruneThreshold
+        self.maxDepth = maxDepth
+
+    def fit(self, X, y):
+        """Find the best root split for the training data and print both sides.
+
+        Args:
+            X: 2-d NumPy array of samples (rows) by features (columns).
+            y: 1-d NumPy array with one class label per row of ``X``.
+
+        Raises:
+            ValueError: If the shapes are unsupported or inconsistent, or if
+                no split can be formed (fewer than two samples, or no
+                feature columns).
+        """
+        self.__validateInput(X, y)
+
+        # Append the labels as the last column so rows keep their class
+        # while being sorted per feature: together[:, -1] == y.
+        together = np.append(X, y.reshape(-1, 1), axis=1)
+
+        root = self._best_split(together)
+        if root is None:
+            raise ValueError(
+                "Cannot fit: at least two samples and one feature are required."
+            )
+
+        # Replace rather than append so that refitting the same instance
+        # does not accumulate splits from earlier calls.
+        self.splitList = [root]
+
+        ltArr, gtArr = root.split(X)
+
+        print(f"LESS THAN \n{ltArr}")
+        print(f"GREATER THAN \n{gtArr}")
+        return
+
+    def _best_split(self, together):
+        """Return the split with the lowest weighted Gini impurity.
+
+        Args:
+            together: 2-d array whose last column holds the class labels and
+                whose other columns are features.
+
+        Returns:
+            The best ``SplittingNode``, or ``None`` when there are no
+            candidate thresholds (fewer than two rows or no feature columns).
+        """
+        best_gini = float("inf")
+        best_node = None
+
+        # Every column except the trailing label column is a feature.
+        for feature in range(together.shape[1] - 1):
+
+            # Sort rows by this feature so consecutive rows are neighbours
+            # in value. Rebinding the local name leaves the caller's array
+            # untouched.
+            together = together[together[:, feature].argsort()]
+
+            # Candidate thresholds are the midpoints between neighbours.
+            # NOTE: equal neighbours give a midpoint equal to their value;
+            # such a candidate cannot beat the earlier candidate that
+            # produces the same partition because the comparison is strict.
+            for row in range(len(together) - 1):
+                lower = together[row][feature]
+                upper = together[row + 1][feature]
+                candidate = SplittingNode(feature, ((upper - lower) / 2) + lower)
+                score = candidate.gini(together)
+                if score < best_gini:
+                    best_node = candidate
+                    best_gini = score
+
+        return best_node
+
+    def predict(self):
+        """Placeholder: prediction is not implemented yet."""
+        print("TODO")
+        return
+
+    def __str__(self):
+        """Placeholder string representation."""
+        return "TODO"
+
+    def __validateInput(self, X, y):
+        """Check that ``X`` is 2-d, ``y`` is 1-d and their row counts match.
+
+        Dimensionality is checked before the row counts so that an input
+        with no first axis is reported with the intended message instead of
+        failing on ``shape[0]``.
+
+        Raises:
+            ValueError: If any of the checks fails.
+        """
+        if len(X.shape) != 2:
+            raise ValueError(f"X shape {X.shape} not supported. Ensure input array is 2d.")
+
+        if len(y.shape) != 1:
+            raise ValueError(f"y shape {y.shape} not supported. Ensure input array is 1d.")
+
+        if X.shape[0] != y.shape[0]:
+            raise ValueError(f"Incongruent array sizes. X has shape {X.shape} and y has shape {y.shape}.")
+        return
diff --git a/classifier/SplittingNode.py b/classifier/SplittingNode.py
@@ -0,0 +1,105 @@
+import numpy as np
+
+class SplittingNode:
+    """Binary split of samples on one feature at a threshold.
+
+    A sample belongs to the "less than" side when its feature value is
+    strictly below the threshold; every other sample (greater than or
+    equal) belongs to the other side.
+
+    Attributes:
+        index: Column index of the feature the split tests.
+        val: Threshold the feature value is compared against.
+        child: Node attached through ``setChild``, or ``None``.
+    """
+
+    def __init__(self, feature_index, value):
+        """Create a split on column ``feature_index`` at threshold ``value``."""
+        self.index = feature_index
+        self.val = value
+        self.child = None
+
+    def setChild(self, child):
+        """Attach ``child`` to this node."""
+        self.child = child
+
+    def gini(self, combined):
+        """Return the weighted Gini impurity of splitting ``combined`` here.
+
+        Args:
+            combined: Sequence of rows whose last element is the class label
+                (converted with ``int``) and whose element at ``self.index``
+                is the feature value to test.
+
+        Returns:
+            The impurity of each side weighted by its share of the samples,
+            or ``0.0`` when ``combined`` is empty (nothing to be impure).
+        """
+        total = len(combined)
+        if total == 0:
+            # Without this guard the side weights would divide by zero.
+            return 0.0
+
+        # Class label -> number of samples, tracked separately per side.
+        below = {}
+        at_or_above = {}
+        for sample in combined:
+            label = int(sample[-1])
+            side = below if self._lessThan(sample) else at_or_above
+            side[label] = side.get(label, 0) + 1
+
+        lt_count = sum(below.values())
+        geq_count = total - lt_count
+
+        lt_gini = self._impurity(below, lt_count)
+        gt_gini = self._impurity(at_or_above, geq_count)
+
+        return (lt_gini * (lt_count / total)) + (gt_gini * (geq_count / total))
+
+    @staticmethod
+    def _impurity(class_counts, size):
+        """Return ``1 - sum(p_k ** 2)`` over the classes in ``class_counts``.
+
+        An empty side yields ``1``; its weight in ``gini`` is zero, so the
+        value never contributes.
+        """
+        impurity = 1
+        for count in class_counts.values():
+            impurity -= (count / size) ** 2
+        return impurity
+
+    def _lessThan(self, sample):
+        """Return True when the sample's feature is strictly below the threshold."""
+        return bool(sample[self.index] < self.val)
+
+    def split(self, arr):
+        """Partition the rows of ``arr`` by this node's test.
+
+        Args:
+            arr: 2-d array-like of samples (rows) by features (columns).
+
+        Returns:
+            A ``(ltArr, gtArr)`` tuple of float arrays: rows whose feature is
+            strictly less than the threshold, and the remaining rows. Row
+            order within each side follows ``arr``.
+        """
+        data = np.asarray(arr)
+        below = data[:, self.index] < self.val
+
+        # One vectorised pass; the mask and its complement together cover
+        # every row exactly once, so no separate count check is needed.
+        # dtype=float keeps the float output the zero-filled buffers gave.
+        ltArr = np.asarray(data[below], dtype=float)
+        gtArr = np.asarray(data[~below], dtype=float)
+
+        return ltArr, gtArr
+
+    def __str__(self):
+        """Describe the feature index and threshold of this split."""
+        return f"Splitting index: {self.index}\nSplitting value: {self.val}"
diff --git a/classifier/Testing.py b/classifier/Testing.py
@@ -0,0 +1,16 @@
+from Podtc import PseudoOptimalDecisionTreeClassifier
+import numpy as np
+import plotly.express as px
+
+# Smoke test: ten random points in the unit square, labelled by whether
+# their two coordinates sum to more than one.
+X = np.random.random(size=(10, 2))
+y = X.sum(axis=1) > 1
+
+# Fit the classifier, then exercise the placeholder predict/str paths.
+classifier = PseudoOptimalDecisionTreeClassifier()
+classifier.fit(X, y)
+classifier.predict()
+print(str(classifier))
+
+# Plot the points coloured by label for a visual check of the split.
+scatter = px.scatter(x=X[:, 0], y=X[:, 1], color=y)
+scatter.show()
diff --git a/classifier/__init__.py b/classifier/__init__.py