decision-tree-classifier

Decision tree classifier implementation in C++
git clone git://git.laack.co/decision-tree-classifier.git
Log | Files | Refs | README | LICENSE

SplittingNode.py (3504B)


      1 import numpy as np
      2 import math
      3 import ctypes
      4 
      5 class SplittingNode:
      6 
      7     def __init__(self, feature_index, value, rightChild = None, leftChild = None):
      8         self.index = feature_index
      9         self.val = value
     10 
     11         self.rightChild = rightChild
     12         self.leftChild = leftChild
     13 
     14     # maybe add input validation???
     15     # do in place weighted gini calculation
     16 
     17     # split the data by current node
     18 
     19 
     20     def split(self, arr):
     21 
     22         pySplit = True
     23 
     24         if pySplit:
     25             return self.__split_py(arr) 
     26         else:
     27             assert False
     28 
     29 
     30 
     31     def __split_py(self, arr):
     32         ltCount = 0
     33         gtCount = 0
     34 
     35         for i in range(0, len(arr)):
     36             lessThan = _lessThan(arr[i], self.index, self.val)
     37             if lessThan:
     38                 ltCount += 1
     39             else:
     40                 gtCount += 1
     41 
     42 
     43 
     44         ltArr = np.zeros(shape=(ltCount, arr.shape[1]))
     45         gtArr = np.zeros(shape=(gtCount, arr.shape[1]))
     46 
     47         ltItr = 0
     48         gtItr = 0
     49 
     50         for i in range(0, len(arr)):
     51             lt = _lessThan(arr[i], self.index, self.val)
     52             if lt:
     53                 ltArr[ltItr] = arr[i]
     54                 ltItr += 1
     55             else:
     56                 gtArr[gtItr] = arr[i]
     57                 gtItr += 1
     58         
     59         assert ltItr + gtItr == len(arr)
     60 
     61         return ltArr, gtArr
     62 
     63 
     64     def __str__(self):
     65         return f"Splitting index: {self.index}\nSplitting value: {round(self.val,2)}"
     66 
     67 class GiniResult(ctypes.Structure):
     68     _fields_ = [("weighted", ctypes.c_float),
     69                 ("ltGini", ctypes.c_float),
     70                 ("gtGini", ctypes.c_float)]
     71 
     72 def gini(eles, values, val, classes, sample_count, vals):
     73 
     74     gini_lib = ctypes.CDLL('./cpp/libgini.so')
     75     gini_lib.gini.restype = GiniResult
     76     gini_lib.gini.argtypes = [
     77             ctypes.POINTER(ctypes.c_float), 
     78             ctypes.POINTER(ctypes.c_int), 
     79             ctypes.c_int, 
     80             ctypes.c_float, 
     81             ctypes.POINTER(ctypes.c_int), 
     82             ctypes.c_int
     83     ]
     84 
     85     split_val = ctypes.c_float(val)
     86     result = gini_lib.gini(eles, classes, sample_count, split_val, vals, len(values))
     87     weightedGini = result.weighted
     88     ltGini = result.ltGini
     89     gtGini = result.gtGini
     90     return (weightedGini, ltGini, gtGini)
     91 
     92 
     93 def giniPy(combined , values, index, val):
     94     ltc = {}
     95     geqc = {}
     96     ltCount = 0
     97     geqCount = 0
     98     for i in values:
     99 
    100         lt = _lessThan(combined[i], index, val)
    101         classification = int(combined[i][-1])
    102 
    103         if(lt):
    104             ltCount += 1
    105             value = ltc.get(classification)
    106             if(value != None):
    107                 ltc[classification] = value + 1
    108             else:
    109                 ltc[classification] = 1
    110         else:
    111             geqCount += 1
    112             value = geqc.get(classification)
    113             if(value != None):
    114                 geqc[classification] = value + 1
    115             else:
    116                 geqc[classification] = 1
    117 
    118     lt_gini = 1
    119     for key in ltc.keys():
    120         lt_gini -= (ltc[key] / ltCount)**2
    121 
    122     gt_gini = 1
    123     for key in geqc.keys():
    124         gt_gini -= (geqc[key] / geqCount)**2
    125 
    126     if(geqCount == 0):
    127         gt_gini = 0
    128     if(ltCount == 0):
    129         lt_gini = 0
    130 
    131     lt_percent = ltCount / len(values)
    132     gt_percent = geqCount / len(values)
    133 
    134     weighted_gini = (lt_gini * lt_percent) + (gt_gini * gt_percent)
    135 
    136     return (weighted_gini, lt_gini, gt_gini)
    137 
    138 
    139 def _lessThan(sample, index, val):
    140     value = sample[index]
    141     if(value < val):
    142         return True
    143     return False
    144