Podtc.py (10042B)
import ctypes
import math
from warnings import warn

import graphviz
import numpy as np
from tqdm import tqdm

from LeafNode import LeafNode
from SplittingNode import SplittingNode
from SplittingNode import gini


def _checkProportion(value, label):
    """Raise ``Exception`` unless *value* lies in the half-open range (0, 1].

    *label* is the human-readable prefix placed in front of the offending
    value in the error message.
    """
    # Written as a positive range test so that NaN is rejected as well.
    if not (0 < value <= 1):
        raise Exception(f"{label} {value}, is not valid. Select a proportion in the range of (0,1]")


def _nodeId(node):
    """Return the string identifier used for *node* in the graphviz output."""
    return str(hash(node))


class PseudoOptimalDecisionTreeClassifier():
    """Binary decision-tree classifier that searches a reduced set of splits.

    Instead of testing every threshold on every feature, the tree tests an
    evenly spaced subset of candidate thresholds and, optionally, only the
    features with the highest variance. Candidate splits are scored by the
    ``gini`` helper imported from ``SplittingNode``.

    Labels are counted with ``np.bincount``, so ``y`` must hold non-negative
    whole numbers.
    """

    # First is first split, last is ... well yeah.

    def __init__(self, pruneThreshold = .2, maxDepth = 5, proportionToTrainOn=.5, proportionToValidateSplits=.5, proportionOfDimsToTrainOn=.5):
        """Store and validate the hyper-parameters.

        Args:
            pruneThreshold: Stored as ``self.threshold``; not read anywhere
                else in this module.
            maxDepth: Maximum number of split levels below the root.
            proportionToTrainOn: Fraction, in (0, 1], of the rows used as
                candidate split positions for each feature.
            proportionToValidateSplits: Fraction, in (0, 1], of the rows
                whose indices are handed to ``gini`` for scoring a split.
            proportionOfDimsToTrainOn: Fraction, in (0, 1], of the feature
                columns (highest variance first) considered for splitting.

        Raises:
            Exception: If any proportion lies outside (0, 1].
        """
        self.threshold = pruneThreshold
        self.maxDepth = maxDepth

        _checkProportion(proportionToTrainOn, "Proportion to train on")
        self.proportionToTrainOn = proportionToTrainOn

        _checkProportion(proportionToValidateSplits, "Proportion to validate splits with")
        self.propValSplits = proportionToValidateSplits

        # The message reports the dimension proportion itself; it previously
        # echoed proportionToValidateSplits by mistake.
        _checkProportion(proportionOfDimsToTrainOn, "Proportion of dimensions to train on")
        self.propDimsTrain = proportionOfDimsToTrainOn

    def fit(self, X, y):
        """Build the tree from samples ``X`` (2-D) and integer labels ``y``.

        Sets ``self.sampleSize`` (number of feature columns),
        ``self.categories`` (unique labels) and ``self.bestSplit`` (the root
        node). Returns ``None``.

        Raises:
            Exception: If the input fails ``__validateInput``.
            ValueError: If no usable split can be found while recursing.
        """
        X, y = self.__validateInput(X, y)
        labelColumn = y.reshape(-1, 1)

        self.sampleSize = X.shape[1]
        self.categories = np.unique(y)

        # Features and labels travel together; together[:, -1] == y.
        together = np.append(X, labelColumn, axis=1)

        dims = self.findDims(together)
        self.bestSplit = self.recurse(together, self.maxDepth, dims)

    def findDims(self, together):
        """Return the indices of the feature columns to consider for splits.

        All feature columns are returned unless ``self.propDimsTrain``
        selects fewer, in which case the highest-variance columns are used.
        The last column of *together* is the label and is never included.
        """
        dimCount = len(together[0]) - 1
        dimsToSample = math.ceil(dimCount * self.propDimsTrain)
        if dimsToSample == dimCount:
            return np.arange(dimCount)
        return dimsWithMostVar(dimsToSample, together)

    def __classification(self, together):
        """Return ``(majority_label, counts)`` for the rows of *together*.

        ``counts`` is the ``np.bincount`` of the label column, padded to at
        least ``len(self.categories)`` entries.
        """
        labels = together[:, -1].astype('int')
        counts = np.bincount(labels, minlength=len(self.categories))
        # np.argmax raises on an empty array, so no separate emptiness check
        # is needed afterwards.
        majority_label = np.argmax(counts)
        return majority_label, counts

    def __makeLeaf(self, rows):
        """Build a ``LeafNode`` that predicts the majority label of *rows*."""
        classification, elements = self.__classification(rows)
        return LeafNode(classification, len(rows), elements)

    def recurse(self, together, depth, dims):
        """Recursively build and return the subtree for *together*.

        Args:
            together: 2-D array whose last column is the label.
            depth: Remaining number of split levels; 0 forces a leaf.
            dims: Feature column indices that may be split on.

        Returns:
            A ``LeafNode``, or the chosen ``SplittingNode`` with both
            children attached.

        Raises:
            ValueError: If ``_best_split`` yields no node.
        """
        if depth == 0:
            return self.__makeLeaf(together)

        bestSplit, ltGini, gtGini = self._best_split(together, self.proportionToTrainOn, dims)

        if bestSplit is None:
            raise ValueError("bestSplit cannot be None")

        ltArr, gtArr = bestSplit.split(arr=together)

        # A split that sends every row to one side separates nothing, so the
        # node collapses into a leaf.
        if len(ltArr) == len(together):
            return self.__makeLeaf(ltArr)

        if len(gtArr) == len(together):
            return self.__makeLeaf(gtArr)

        # Keep splitting a side only while it has more than one row and is
        # still impure.
        if len(ltArr) > 1 and ltGini > 0:
            bestSplit.leftChild = self.recurse(ltArr, depth - 1, dims)
        else:
            bestSplit.leftChild = self.__makeLeaf(ltArr)

        if len(gtArr) > 1 and gtGini > 0:
            bestSplit.rightChild = self.recurse(gtArr, depth - 1, dims)
        else:
            bestSplit.rightChild = self.__makeLeaf(gtArr)

        return bestSplit

    # Find best split
    def _best_split(self, together, proportionUsed, dims):
        """Search *dims* for the candidate split with the lowest gini score.

        Args:
            together: 2-D array whose last column is the label; must hold at
                least two rows.
            proportionUsed: Fraction of rows used as candidate split
                positions.
            dims: Feature column indices to search.

        Returns:
            ``(bestNode, leftGini, rightGini)``. ``bestNode`` is ``None``
            when no candidate scored below infinity. The impurities let the
            caller stop once a side is pure.
        """
        bestGini = float("inf")
        bestNode = None
        blg = float("inf")
        bgg = float("inf")

        rowCount = len(together)

        # indices for evals. This decides which indices to check upon splitting
        values = np.round(np.linspace(0, rowCount - 1, math.ceil(self.propValSplits * rowCount))).astype(np.int32)
        vals = values.ctypes.data_as(ctypes.POINTER(ctypes.c_int))

        # indices for splits (this decides how many splits to test); the
        # upper bound is rowCount - 2 because each candidate looks at the
        # row after it.
        indices = np.round(np.linspace(0, rowCount - 2, math.ceil(proportionUsed * rowCount))).astype(int)
        sample_count = len(together[:, 0])

        # columns (excluding y)
        for x in tqdm(dims):

            # Sort the rows by the current feature so neighbouring rows
            # bracket each candidate threshold.
            together = together[together[:, x].argsort()]

            eles = together[:, x].astype(np.float32)
            eles = eles.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

            classes = together[:, -1].astype(np.int32).ctypes.data_as(ctypes.POINTER(ctypes.c_int))

            for currentSample in indices:
                # Threshold is the midpoint between this row and the next.
                lower = together[currentSample][x]
                splitOn = ((together[currentSample + 1][x] - lower) / 2) + lower
                split = SplittingNode(x, splitOn)

                # NOTE(review): gini is defined in SplittingNode; its result
                # is read here as (overall, left, right) impurity — confirm.
                current = gini(eles, values, split.val, classes, sample_count, vals)

                if current[0] < bestGini:  # type: ignore
                    bestNode = split
                    bestGini = current[0]  # type: ignore
                    blg = current[1]  # type: ignore
                    bgg = current[2]  # type: ignore

        return (bestNode, blg, bgg)

    # OPTIMIZE THIS... THIS SHOULD BE DONE WITH BATCHES NOT SINGLE INSTANCES
    def predict(self, X):
        """Return a float array with the predicted label for each row of *X*.

        Raises:
            Exception: If the tree is not fitted or *X* fails validation.
        """
        self.__validatePrediction(X)

        # Work on an ndarray so row indexing behaves the same for every
        # array-like input.
        X = np.asarray(X)
        y = np.zeros(shape=(len(X),))

        for i in range(len(y)):
            current = self.bestSplit
            # Walk down from the root until a leaf is reached.
            while not isinstance(current, LeafNode):
                if self._lessThan(current, X[i]):  # type: ignore
                    current = current.leftChild  # type: ignore
                else:
                    current = current.rightChild  # type: ignore
            y[i] = current.classification

        return y

    def _lessThan(self, split : SplittingNode, sample):
        """Return True when *sample* falls strictly below the split value."""
        return bool(sample[split.index] < split.val)

    def __str__(self):
        return "TODO"

    def __validatePrediction(self, X):
        """Check that the tree is fitted and *X* is a compatible 2-D array.

        Returns 0 on success.

        Raises:
            Exception: If the tree is unfitted, *X* is not 2-D, *X* holds
                strings, or its column count differs from the fitted data.
        """
        # bestSplit only exists once fit() has run.
        if not hasattr(self, "bestSplit"):
            raise Exception("Tree must be fit prior to prediction")

        X = np.asarray(X)

        if len(X.shape) != 2:
            raise Exception(f"X shape {X.shape} not supported. Ensure input array is 2d.")

        if np.issubdtype(X.dtype, np.str_):
            raise Exception("X contains strings which is not allowed.")

        if X.shape[1] != self.sampleSize:
            raise Exception(f"Prediction sample of size {X.shape[1]}, not compatible with fitted size of {self.sampleSize}")

        return 0

    def __validateInput(self, X, y):
        """Convert *X* and *y* to arrays and check they are usable for fit.

        Returns:
            ``(X, y)`` as numpy arrays.

        Raises:
            Exception: If *X* is not 2-D, either array holds strings, *y*
                holds non-whole floats, the row counts differ, or *X* has
                fewer than two rows.
        """
        y = np.asarray(y)
        X = np.asarray(X)

        if len(X.shape) != 2:
            raise Exception(f"X shape {X.shape} not supported. Ensure input array is 2d.")

        if np.issubdtype(y.dtype, np.str_):
            raise Exception("y contains strings which is not allowed.")

        if np.issubdtype(X.dtype, np.str_):
            raise Exception("X contains strings which is not allowed.")

        if np.issubdtype(y.dtype, np.floating):
            # Float labels are accepted only when every value is whole.
            if not np.all(np.equal(np.floor(y), y)):
                raise Exception("y array contains continuous values, but classification requires discrete values")

        if X.shape[0] != y.shape[0]:
            raise Exception(f"Incongruent array sizes. X has shape {X.shape} and y has shape {y.shape}.")

        if X.shape[0] <= 1:
            raise Exception("X must contain more than one sample.")

        return X, y

    def graph(self):
        """Render the fitted tree to ``whatever.png`` and open the viewer.

        Raises:
            Exception: If the tree has not been fitted.
        """
        # getattr covers the unfitted case, where the attribute is absent
        # rather than None.
        root = getattr(self, "bestSplit", None)
        if root is None:
            raise Exception("Unable to create graph of classifier, call fit first.")

        graph = graphviz.Digraph()
        graph = createGraph(root, graph)
        graph.render('whatever', format='png', view=True)


def createGraph(node, graph):
    """Add the tree rooted at *node* to *graph* and return *graph*."""
    graph.node(_nodeId(node), str(node))
    # recurse() can return a bare LeafNode as the root (e.g. maxDepth == 0);
    # a leaf has no children to walk.
    if not isinstance(node, LeafNode):
        traverseForGraph(node, graph)
    return graph


def traverseForGraph(node, graph):
    """Add both children of splitting *node* to *graph*, then descend.

    Raises:
        Exception: If either child of *node* is ``None``.
    """
    left = node.leftChild
    right = node.rightChild

    if left is None:
        raise Exception("left child should never be none")

    if right is None:
        raise Exception("right child should never be none")

    parentId = _nodeId(node)

    graph.node(_nodeId(left), str(left))
    graph.node(_nodeId(right), str(right))

    graph.edge(parentId, _nodeId(left))
    graph.edge(parentId, _nodeId(right))

    # will be false in case where child is leaf node
    if type(left) is type(node):
        traverseForGraph(left, graph)

    if type(right) is type(node):
        traverseForGraph(right, graph)

    return graph


# use np.var
def dimsWithMostVar(dimCount, arr):
    """Return the indices of the *dimCount* highest-variance feature columns.

    The last column of *arr* is treated as the label and excluded. Indices
    are ordered from highest to lowest variance.
    """
    assert dimCount < len(arr[0]) - 1

    variances = np.var(arr[:, :-1], axis=0)
    byVarianceDesc = np.argsort(variances)[::-1]
    retArr = byVarianceDesc[:dimCount]

    assert dimCount == retArr.shape[0]
    return retArr