machinelearning

Machine learning code
git clone git://git.laack.co/machinelearning.git
Log | Files | Refs

KNN.py (3686B)


      1 # K-Nearest-Neighbors
      2 import random
      3 import heapq
      4 from matplotlib import pyplot as plt
      5 import numpy as np
      6 
      7 
      8 class Neighbor:
      9     def __init__(self, distance, index):
     10         self.distance = distance
     11         self.index = index
     12 
     13     # this is done to make min heap simulate a max heap
     14     def __lt__(self, other):
     15         return -self.distance < -other.distance
     16 
     17     def __str__(self):
     18         return "Distance: " + str(self.distance) + " Index: " + str(self.index)
     19 
     20 
     21 class KNN:
     22     def __init__(self, n):
     23         self.n = n
     24 
     25     def fit(self, X, y):
     26         self.X = X.copy()
     27         self.y = y.copy()
     28 
     29     def predict(self, X):
     30 
     31         preds = []
     32 
     33         for i in range(len(X)):
     34 
     35             nearest_neighbors = []
     36 
     37             for index, instance in enumerate(self.X):
     38                 euc_dist = euc_dist_squared(instance, X[i])
     39                 neighbor = Neighbor(euc_dist, index)
     40 
     41                 if len(nearest_neighbors) > 0:
     42                     if len(nearest_neighbors) < self.n:
     43                         heapq.heappush(nearest_neighbors, neighbor)
     44                     else:
     45                         if neighbor.distance < nearest_neighbors[0].distance:
     46                             heapq.heappop(nearest_neighbors)
     47                             heapq.heappush(nearest_neighbors, neighbor)
     48                 else:
     49                     heapq.heappush(nearest_neighbors, neighbor)
     50 
     51             if self.n > len(self.X):
     52                 assert len(nearest_neighbors) == self.n
     53 
     54             preds.append(self._most_common_vote(nearest_neighbors))
     55 
     56         return preds
     57 
     58     def _most_common_vote(self, neighbors):
     59 
     60         votes = {}
     61 
     62         for neighbor in neighbors:
     63             classification = self.y[neighbor.index]
     64             votes[classification] = votes.get(classification, 0) + 1
     65 
     66         max = 0
     67         most_popular = ""
     68 
     69         for key in votes:
     70             if votes[key] > max:
     71                 max = votes[key]
     72                 most_popular = key
     73 
     74         return most_popular
     75 
     76 
     77 def euc_dist_squared(input_1, input_2):
     78     assert len(input_1) == len(input_2)
     79     sum = 0
     80     for i in range(len(input_1)):
     81         sum += (input_1[i] - input_2[i]) ** 2
     82     return sum
     83 
     84 
     85 if __name__ == "__main__":
     86     knn = KNN(5)
     87 
     88     samples_1 = 500
     89     samples_2 = 500
     90 
     91     cluster_1 = [[random.gauss(0, 5), random.gauss(0, 5)] for _ in range(samples_1)]
     92     cluster_2 = [[random.gauss(10, 5), random.gauss(10, 5)] for _ in range(samples_2)]
     93     cluster_1.extend(cluster_2)
     94     X = cluster_1
     95 
     96     y = [0] * samples_1
     97     y.extend([1] * samples_2)
     98 
     99     knn.fit(X, y)
    100 
    101     test_0 = [[random.gauss(0, 5), random.gauss(0, 5)] for _ in range(samples_1)]
    102     test_1 = [[random.gauss(10, 5), random.gauss(10, 5)] for _ in range(samples_2)]
    103 
    104     test_0.extend(test_1)
    105     test_X = test_0
    106 
    107     preds = knn.predict(test_X)
    108 
    109     X = np.array(X)
    110     test_X = np.array(test_X)
    111     test_y = [0] * samples_1
    112     test_y.extend([1] * samples_2)
    113 
    114     correctness = []
    115 
    116     for i in range(len(test_y)):
    117         if test_y[i] != preds[i]:
    118             correctness.append(0)
    119         else:
    120             correctness.append(1)
    121 
    122     # ORIGINAL
    123     plt.scatter(x=X[:, 0], y=X[:, 1], c=y, cmap="jet", alpha=0.1)
    124 
    125     alphas = []
    126     for correct in correctness:
    127         if correct == 1:
    128             alphas.append(0.1)
    129         else:
    130             alphas.append(1)
    131 
    132     # PREDICTIONS
    133     plt.scatter(
    134         x=test_X[:, 0],
    135         y=test_X[:, 1],
    136         c=test_y,
    137         cmap="jet",
    138         edgecolors="black",
    139         alpha=alphas,
    140     )
    141 
    142     plt.show()
    143 
    144     correct_count = 0
    145     for correct in correctness:
    146         if correct == 1:
    147             correct_count += 1
    148 
    149     print("Accuracy: ", correct_count / len(correctness))