"""K-Nearest-Neighbors classifier (KNN.py).

A small, dependency-free k-NN classifier that keeps the ``n`` closest
training samples of each query in a bounded max-heap, plus a demo that
classifies two Gaussian clusters and plots the result.
"""
# K-Nearest-Neighbors
import heapq
import random


class Neighbor:
    """A training sample seen from a query: its distance and its index.

    Ordering is deliberately reversed so that ``heapq`` (a min-heap) keeps
    the *farthest* neighbor at position 0, i.e. behaves like a max-heap.
    """

    __slots__ = ("distance", "index")

    def __init__(self, distance, index):
        self.distance = distance
        self.index = index

    # this is done to make min heap simulate a max heap
    def __lt__(self, other):
        return self.distance > other.distance

    def __repr__(self):
        return "Neighbor(distance={!r}, index={!r})".format(
            self.distance, self.index
        )

    def __str__(self):
        return "Distance: " + str(self.distance) + " Index: " + str(self.index)


class KNN:
    """Majority-vote k-nearest-neighbors classifier.

    Args:
        n: Number of neighbors that vote on each prediction (at least 1).

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """

    def __init__(self, n):
        if n < 1:
            raise ValueError("n must be at least 1, got {!r}".format(n))
        self.n = n

    def fit(self, X, y):
        """Store copies of the training samples ``X`` and their labels ``y``.

        Both arguments must provide ``len()`` and ``.copy()`` (lists and
        NumPy arrays do).

        Raises:
            ValueError: If ``X`` and ``y`` have different lengths.
        """
        if len(X) != len(y):
            raise ValueError(
                "X and y must have the same length, got {} and {}".format(
                    len(X), len(y)
                )
            )
        self.X = X.copy()
        self.y = y.copy()

    def predict(self, X):
        """Return a list with the predicted label of every sample in ``X``.

        If ``n`` exceeds the number of training samples, all training
        samples vote.

        Raises:
            ValueError: If the training set is empty.
        """
        if len(self.X) == 0:
            raise ValueError("cannot predict with an empty training set")

        # Never ask for more neighbors than there are training samples.
        k = min(self.n, len(self.X))
        preds = []

        for query in X:
            # Max-heap (via Neighbor.__lt__) of the k closest samples so far;
            # nearest_neighbors[0] is always the farthest of them.
            nearest_neighbors = []

            for index, instance in enumerate(self.X):
                neighbor = Neighbor(euc_dist_squared(instance, query), index)

                if len(nearest_neighbors) < k:
                    heapq.heappush(nearest_neighbors, neighbor)
                elif neighbor.distance < nearest_neighbors[0].distance:
                    # Evict the current farthest neighbor in one operation.
                    heapq.heapreplace(nearest_neighbors, neighbor)

            preds.append(self._most_common_vote(nearest_neighbors))

        return preds

    def _most_common_vote(self, neighbors):
        """Return the label that occurs most often among ``neighbors``.

        Ties go to the label encountered first. An empty ``neighbors``
        collection yields ``""``.
        """
        votes = {}

        for neighbor in neighbors:
            classification = self.y[neighbor.index]
            votes[classification] = votes.get(classification, 0) + 1

        if not votes:
            return ""

        # max() returns the first maximal key, matching first-seen tie-breaks.
        return max(votes, key=votes.get)


def euc_dist_squared(input_1, input_2):
    """Return the squared Euclidean distance between two equal-length vectors.

    The square root is skipped because it does not change the ordering of
    distances.

    Raises:
        ValueError: If the vectors have different lengths.
    """
    if len(input_1) != len(input_2):
        raise ValueError(
            "inputs must have the same length, got {} and {}".format(
                len(input_1), len(input_2)
            )
        )
    return sum((a - b) ** 2 for a, b in zip(input_1, input_2))


def _sample_clusters(count_0, count_1):
    """Draw two 2-D Gaussian clusters and return ``(points, labels)``.

    Class 0 is centered at (0, 0) and class 1 at (10, 10), both with a
    standard deviation of 5 per axis.
    """
    points = [[random.gauss(0, 5), random.gauss(0, 5)] for _ in range(count_0)]
    points.extend(
        [random.gauss(10, 5), random.gauss(10, 5)] for _ in range(count_1)
    )
    labels = [0] * count_0 + [1] * count_1
    return points, labels


def main():
    """Train on two random clusters, plot the predictions, print accuracy."""
    # Imported here so the classifier itself can be imported without the
    # plotting stack installed.
    import numpy as np
    from matplotlib import pyplot as plt

    samples_1 = 500
    samples_2 = 500

    knn = KNN(5)

    X, y = _sample_clusters(samples_1, samples_2)
    knn.fit(X, y)

    test_X, test_y = _sample_clusters(samples_1, samples_2)
    preds = knn.predict(test_X)

    # 1 where the prediction matches the true label, 0 otherwise.
    correctness = [int(pred == truth) for pred, truth in zip(preds, test_y)]

    X = np.array(X)
    test_X = np.array(test_X)

    # ORIGINAL
    plt.scatter(x=X[:, 0], y=X[:, 1], c=y, cmap="jet", alpha=0.1)

    # Misclassified test points are drawn fully opaque so they stand out.
    alphas = [0.1 if correct else 1 for correct in correctness]

    # PREDICTIONS
    plt.scatter(
        x=test_X[:, 0],
        y=test_X[:, 1],
        c=test_y,
        cmap="jet",
        edgecolors="black",
        alpha=alphas,
    )

    plt.show()

    print("Accuracy: ", sum(correctness) / len(correctness))


if __name__ == "__main__":
    main()