decision-tree-classifier

Decision tree classifier implementation in C++
git clone git://git.laack.co/decision-tree-classifier.git
Log | Files | Refs | README | LICENSE

Podtc.py (10042B)


      1 from warnings import warn
      2 import ctypes
      3 import numpy as np
      4 from tqdm import tqdm
      5 from SplittingNode import SplittingNode
      6 from SplittingNode import gini
      7 from LeafNode import LeafNode
      8 import math
      9 import graphviz
     10 
     11 class PseudoOptimalDecisionTreeClassifier():
     12 
     13 
     14     # First is first split, last is ... well yeah.
     15 
     16     def __init__(self, pruneThreshold = .2, maxDepth = 5, proportionToTrainOn=.5, proportionToValidateSplits=.5, proportionOfDimsToTrainOn=.5):
     17         self.threshold = pruneThreshold 
     18         self.maxDepth = maxDepth
     19 
     20         # I guess allow > 1, just d
     21 
     22         if(proportionToTrainOn > 1 or proportionToTrainOn <= 0):
     23             raise Exception(f"Proportion to train on {proportionToTrainOn}, is not valid. Select a proportion in the range of (0,1]")
     24 
     25         self.proportionToTrainOn = proportionToTrainOn
     26 
     27         if(proportionToValidateSplits > 1 or proportionToValidateSplits <= 0):
     28             raise Exception(f"Proportion to validate splits with {proportionToValidateSplits}, is not valid. Select a proportion in the range of (0,1]")
     29 
     30         self.propValSplits = proportionToValidateSplits
     31 
     32         if(proportionOfDimsToTrainOn > 1 or proportionOfDimsToTrainOn <= 0):
     33             raise Exception(f"Proportion of dimensions to train on {proportionToValidateSplits}, is not valid. Select a proportion in the range of (0,1]")
     34 
     35         self.propDimsTrain = proportionOfDimsToTrainOn 
     36         return
     37 
     38     def fit(self, X,  y):
     39 
     40         X,y = self.__validateInput(X,y)
     41         y_re = y.reshape(-1,1)
     42 
     43         self.sampleSize = X.shape[1]
     44         self.categories = np.unique(y)
     45         
     46         # together [:,-1] == y
     47         together = np.append(X,y_re, axis=1)
     48 
     49         dims = self.findDims(together)
     50 
     51         self.bestSplit = self.recurse(together, self.maxDepth, dims)
     52 
     53         return
     54 
     55     def findDims(self, together):
     56 
     57         dimCount = len(together[0]) - 1
     58         dims = np.arange(dimCount)
     59         dimsToSample = math.ceil(dimCount * self.propDimsTrain)
     60         if(dimsToSample != dimCount):
     61             dims = dimsWithMostVar(dimsToSample, together)
     62 
     63         return dims
     64 
     65 
     66 
     67     def __classification(self, together):
     68         lastCol = together[:, -1].astype('int')
     69         counts = np.bincount(lastCol, minlength=len(self.categories))
     70         majority_label = np.argmax(counts)
     71         if(len(counts) == 0):
     72             assert False
     73         return majority_label, counts
     74 
     75 
     76 
     77     def recurse(self, together, depth, dims):
     78         
     79         if(depth == 0):
     80             classification, elements = self.__classification(together)
     81             return LeafNode(classification, len(together), elements)
     82 
     83         bestSplit, ltGini, gtGini = self._best_split(together, self.proportionToTrainOn, dims)
     84 
     85         if bestSplit is None:
     86             raise ValueError("bestSplit cannot be None")
     87 
     88         ltArr, gtArr = bestSplit.split(arr=together)
     89 
     90         if len(ltArr) == len(together):
     91             classification, elements = self.__classification(ltArr)
     92             return LeafNode(classification, len(ltArr), elements)
     93 
     94         if len(gtArr) == len(together):
     95             classification, elements = self.__classification(gtArr)
     96             return LeafNode(classification, len(gtArr), elements)
     97 
     98         # might make sense to simply stop
     99         # if the length of either array is 0
    100         # because that means splits aren't doing anything..
    101         # just a thought
    102 
    103         if len(ltArr) > 1 and ltGini > 0:
    104             blt = self.recurse(ltArr, depth - 1, dims)
    105             bestSplit.leftChild = blt
    106         else:
    107             classification, elements = self.__classification(ltArr)
    108             bestSplit.leftChild = LeafNode(classification, len(ltArr), elements)
    109 
    110         if len(gtArr) > 1 and gtGini > 0:
    111             bgt = self.recurse(gtArr, depth - 1, dims)
    112             bestSplit.rightChild= bgt
    113         else:
    114             classification, elements = self.__classification(gtArr)
    115             bestSplit.rightChild = LeafNode(classification, len(gtArr), elements)
    116 
    117         return bestSplit
    118 
    119     # pass in current root
    120     # find best options from then on
    121 
    122     # Find best split
    123     def _best_split(self, together, proportionUsed, dims):
    124 
    125         bestGini = float("inf")
    126         bestNode  = None
    127         blg = float("inf")
    128         bgg = float("inf")
    129 
    130         # indices for evals. This decides which indices to check upon splitting
    131         values = np.round(np.linspace(0, len(together) - 1, math.ceil(self.propValSplits * len(together)))).astype(np.int32)
    132         vals = values.ctypes.data_as(ctypes.POINTER(ctypes.c_int))
    133 
    134         indices = np.round(np.linspace(0, len(together) - 2, math.ceil(proportionUsed * len(together)))).astype(int)
    135         sample_count = len(together[:, 0])
    136 
    137         # columns (excluding y)
    138         for x in tqdm(dims):
    139 
    140             # random sampling would be a lot faster
    141 
    142             together = together[together[:,x].argsort()]
    143 
    144 
    145             # each row (sample)
    146             # also, we are interpolating between samples
    147 
    148             # indices for splits (this decides how many splits to test)
    149 
    150 
    151             eles = together[:, x].astype(np.float32)
    152             eles = eles.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    153 
    154             classes = together[:, -1].astype(np.int32).ctypes.data_as(ctypes.POINTER(ctypes.c_int))
    155 
    156             
    157             for currentSample in indices:
    158                 splitOn = ((together[currentSample+1][x] - together[currentSample][x]) / 2) + together[currentSample][x]
    159                 split = SplittingNode(x, splitOn)
    160 
    161                 current = gini(eles , values, split.val, classes, sample_count, vals)
    162 
    163 
    164                 if current[0] < bestGini: # type: ignore
    165                     bestNode = split
    166                     bestGini = current[0]  # type: ignore
    167                     blg = current[1]  # type: ignore
    168                     bgg = current[2]  # type: ignore
    169 
    170 
    171         # Return the best node, the left gini impurity, and right gini impurity.
    172         # These impurities allow for us to stop if we have a pure node.
    173         return (bestNode, blg, bgg)
    174 
    175 
    176     # OPTIMIZE THIS... THIS SHOULD BE DONE WITH BATCHES NOT SINGLE INSTANCES
    177     def predict(self, X):
    178 
    179         self.__validatePrediction(X)
    180 
    181         y = np.zeros(shape=(len(X),))
    182 
    183         leaf = LeafNode(0,0,[])
    184 
    185         for i in range(0, len(y)):
    186 
    187             done = False
    188             current = self.bestSplit
    189 
    190             while not done:
    191                 if isinstance(current,type(leaf)):
    192                     y[i] = current.classification
    193                     done = True
    194                     continue
    195 
    196                 if self._lessThan(current, X[i]): #type: ignore
    197                     current = current.leftChild #type: ignore
    198                 else:
    199                     current = current.rightChild #type: ignore
    200         return y
    201 
    202     def _lessThan(self, split : SplittingNode, sample):
    203         if(sample[split.index] < split.val):
    204             return True
    205         return False
    206 
    207 
    208     def __str__(self):
    209         return "TODO"
    210 
    211     def __validatePrediction(self, X):
    212 
    213         # check if bestSplit has been set
    214         try:
    215             self.bestSplit
    216         except:
    217             raise Exception("Tree must be fit prior to prediction")
    218 
    219         X = np.asarray(X)
    220 
    221         if len(X.shape) != 2:
    222             raise Exception(f"X shape {X.shape} not supported. Ensure input array is 2d.")
    223 
    224         if np.issubdtype(X.dtype, np.str_):
    225             raise Exception(f"X contains strings which is not allowed.")
    226         
    227         if X.shape[1] != self.sampleSize:
    228             raise Exception(f"Prediction sample of size {X.shape[1]}, not compatible with fitted size of {self.sampleSize}")
    229 
    230 
    231         return 0
    232 
    233     def __validateInput(self, X,y):
    234 
    235         y = np.asarray(y)
    236         X = np.asarray(X)
    237 
    238         if len(X.shape) != 2:
    239             raise Exception(f"X shape {X.shape} not supported. Ensure input array is 2d.")
    240 
    241         if np.issubdtype(y.dtype, np.str_):
    242             raise Exception(f"y contains strings which is not allowed.")
    243 
    244         if np.issubdtype(X.dtype, np.str_):
    245             raise Exception(f"X contains strings which is not allowed.")
    246 
    247         if np.issubdtype(y.dtype, np.floating):  
    248             if not np.all(np.equal(np.floor(y), y)):  # Check if all values are whole numbers
    249                 raise Exception("y array contains continuous values, but classification requires discrete values")
    250 
    251         if X.shape[0] != y.shape[0]:
    252             raise Exception(f"Incongruent array sizes. X has shape {X.shape} and y has shape {y.shape}.")
    253         
    254 
    255         if X.shape[0] <= 1:
    256             raise Exception(f"X must contain more than one sample.")
    257         return X,y
    258 
    259     def graph(self):    
    260         if self.bestSplit == None:
    261             raise Exception(f"Unable to create graph of classifier, call fit first.")
    262 
    263         graph = graphviz.Digraph()
    264         graph = createGraph(self.bestSplit, graph)
    265         graph.render('whatever', format='png', view=True)
    266 
    267 def createGraph(node, graph):
    268     graph.node(str(node.__hash__()), str(node))
    269     traverseForGraph(node, graph)
    270     return graph
    271 
    272 def traverseForGraph(node, graph):
    273 
    274     cid = str(node.__hash__())
    275 
    276     if node.leftChild == None:
    277         raise Exception("left child should never be none")
    278 
    279     if node.rightChild == None:
    280         raise Exception("right child should never be none")
    281 
    282     graph.node(str(node.leftChild.__hash__()), str(node.leftChild))
    283     graph.node(str(node.rightChild.__hash__()), str(node.rightChild))
    284 
    285     graph.edge(cid, str(node.leftChild.__hash__()))
    286     graph.edge(cid, str(node.rightChild.__hash__()))
    287 
    288 
    289     # will be false in case where child is leaf node
    290     if type(node.leftChild) == type(node):
    291         traverseForGraph(node.leftChild, graph)
    292 
    293     if type(node.rightChild) == type(node):
    294         traverseForGraph(node.rightChild, graph)
    295 
    296 
    297 
    298     return graph
    299 
    300 # use np.var
    301 def dimsWithMostVar(dimCount, arr):
    302     
    303     assert dimCount < len(arr[0]) - 1
    304     
    305     vars = np.var(arr[:, :-1], axis=0)
    306     retArr = np.argsort(vars)[::-1]
    307     retArr = retArr[:dimCount]
    308 
    309     assert dimCount == retArr.shape[0]
    310     return retArr
    311