SplittingNode.py (3504B)
import numpy as np
import math
import ctypes


class SplittingNode:
    """Decision-tree node that partitions samples on a single feature.

    Attributes:
        index: Column index of the feature the node tests.
        val: Threshold; a sample satisfies the test when
            ``sample[index] < val``.
        rightChild: Optional child node (``None`` when absent).
        leftChild: Optional child node (``None`` when absent).
    """

    def __init__(self, feature_index, value, rightChild=None, leftChild=None):
        self.index = feature_index
        self.val = value

        self.rightChild = rightChild
        self.leftChild = leftChild

    # TODO: maybe add input validation.
    # TODO: do an in-place weighted gini calculation.

    def split(self, arr):
        """Split the rows of ``arr`` by this node's threshold test.

        Args:
            arr: 2-D array-like of samples, one sample per row.

        Returns:
            A ``(less_than, rest)`` tuple of float64 arrays. ``less_than``
            holds the rows where ``row[self.index] < self.val``; ``rest``
            holds every other row. Row order is preserved in both.
        """
        return self.__split_py(arr)

    def __split_py(self, arr):
        """Pure-Python/NumPy implementation backing :meth:`split`."""
        data = np.asarray(arr)

        # Evaluate the predicate once per row through _lessThan so the
        # partition always agrees with the test giniPy applies.
        mask = np.array(
            [_lessThan(row, self.index, self.val) for row in data],
            dtype=bool,
        )

        # Boolean indexing copies the selected rows in their original order;
        # the cast keeps the float64 output the row-by-row version produced.
        lt_rows = data[mask].astype(np.float64)
        ge_rows = data[~mask].astype(np.float64)

        return lt_rows, ge_rows

    def __str__(self):
        return f"Splitting index: {self.index}\nSplitting value: {round(self.val,2)}"


class GiniResult(ctypes.Structure):
    """Return struct of the native ``gini`` function (three C floats)."""

    _fields_ = [("weighted", ctypes.c_float),
                ("ltGini", ctypes.c_float),
                ("gtGini", ctypes.c_float)]


# Cached handle to the native library; populated on the first gini() call.
_GINI_LIB = None


def _load_gini_lib():
    """Load ``./cpp/libgini.so`` once and configure the ``gini`` prototype."""
    global _GINI_LIB
    if _GINI_LIB is None:
        # NOTE(review): the path is relative to the process working
        # directory, exactly as before - confirm callers run from the
        # project root.
        lib = ctypes.CDLL('./cpp/libgini.so')
        lib.gini.restype = GiniResult
        lib.gini.argtypes = [
            ctypes.POINTER(ctypes.c_float),
            ctypes.POINTER(ctypes.c_int),
            ctypes.c_int,
            ctypes.c_float,
            ctypes.POINTER(ctypes.c_int),
            ctypes.c_int,
        ]
        _GINI_LIB = lib
    return _GINI_LIB


def gini(eles, values, val, classes, sample_count, vals):
    """Compute gini impurities for a split using the native library.

    Args:
        eles: ctypes pointer to C floats, forwarded to the native call.
        values: Sized collection; only ``len(values)`` is forwarded.
        val: Split threshold, converted to a C float.
        classes: ctypes pointer to C ints, forwarded to the native call.
        sample_count: Integer forwarded as the native call's third argument.
        vals: ctypes pointer to C ints, forwarded to the native call.

    Returns:
        ``(weighted, ltGini, gtGini)`` as read from the returned struct.

    Raises:
        OSError: If ``./cpp/libgini.so`` cannot be loaded.
    """
    gini_lib = _load_gini_lib()

    split_val = ctypes.c_float(val)
    result = gini_lib.gini(eles, classes, sample_count, split_val, vals, len(values))
    return (result.weighted, result.ltGini, result.gtGini)


def _gini_impurity(class_counts, total):
    """Return ``1 - sum((count / total) ** 2)``, or 0.0 when ``total`` is 0."""
    if total == 0:
        return 0.0
    impurity = 1
    for count in class_counts.values():
        impurity -= (count / total) ** 2
    return impurity


def giniPy(combined, values, index, val):
    """Compute gini impurities for a threshold split in pure Python.

    Args:
        combined: Indexable collection of samples; the last element of each
            sample is its class label (converted with ``int``).
        values: Iterable of indices into ``combined`` selecting the samples
            to evaluate.
        index: Feature index tested against ``val``.
        val: Split threshold; a sample is on the "less than" side when
            ``sample[index] < val``.

    Returns:
        ``(weighted_gini, lt_gini, gt_gini)``. A side with no samples has
        impurity 0.0, and an empty ``values`` yields ``(0.0, 0.0, 0.0)``
        instead of raising ``ZeroDivisionError``.
    """
    total = len(values)
    if total == 0:
        # Nothing to weigh; dividing by the sample count would fail.
        return (0.0, 0.0, 0.0)

    lt_counts = {}
    geq_counts = {}
    lt_total = 0
    geq_total = 0

    for i in values:
        sample = combined[i]
        label = int(sample[-1])
        if _lessThan(sample, index, val):
            lt_total += 1
            lt_counts[label] = lt_counts.get(label, 0) + 1
        else:
            geq_total += 1
            geq_counts[label] = geq_counts.get(label, 0) + 1

    lt_gini = _gini_impurity(lt_counts, lt_total)
    gt_gini = _gini_impurity(geq_counts, geq_total)

    # Weight each side's impurity by its share of the evaluated samples.
    weighted_gini = (lt_gini * (lt_total / total)) + (gt_gini * (geq_total / total))

    return (weighted_gini, lt_gini, gt_gini)


def _lessThan(sample, index, val):
    """Return True when ``sample[index]`` is strictly less than ``val``."""
    return bool(sample[index] < val)