MNISTAutoencoder.ipynb (13518B)
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook trains a dense autoencoder that compresses MNIST handwritten digits from 784 pixels down to a 128-dimensional code and then reconstructs them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_openml\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "# Fetch the MNIST dataset\n",
    "mnist = fetch_openml(\"mnist_784\", as_frame=False)\n",
    "\n",
    "minMax = MinMaxScaler()\n",
    "\n",
    "\n",
    "# Extract data and labels\n",
    "X, y = mnist.data, mnist.target\n",
    "\n",
    "# Rescale each pixel column with MinMaxScaler (default target range is [0, 1]),\n",
    "# which matches the sigmoid output of the last Dense layer.\n",
    "X = minMax.fit_transform(X)\n",
    "\n",
    "# Reshape images to 28x28\n",
    "X = X.reshape(-1, 28, 28)\n",
    "\n",
    "# Split the data into training and test sets\n",
    "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]\n",
    "\n",
    "# Function to plot an image\n",
    "def plot(image_data):\n",
    "    \"\"\"Display `image_data` with the 'binary' colormap and hidden axes.\"\"\"\n",
    "    plt.imshow(image_data, cmap='binary')\n",
    "    plt.axis('off')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "import tensorflow as tf \n",
    "\n",
    "\n",
    "# Layers are created once and kept in variables so the same (trained)\n",
    "# instances can be reused to build the encoder and decoder later on.\n",
    "# Widths: 784 (flattened image) -> 784 -> 256 -> 128 (bottleneck) -> 256 -> 784.\n",
    "input_layer = keras.layers.Input(shape=(28,28,1))\n",
    "flatten_layer = keras.layers.Flatten()\n",
    "first_layer = keras.layers.Dense(784, activation='relu')\n",
    "second_layer = keras.layers.Dense(256, activation='relu')\n",
    "third_layer = keras.layers.Dense(128, activation='relu')\n",
    "fourth_layer = keras.layers.Dense(256, activation='relu')\n",
    "fifth_layer = keras.layers.Dense(784, activation='sigmoid')\n",
    "unflatten_layer = keras.layers.Reshape(target_shape=(28,28,1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "autoencoder = keras.Sequential(layers=[\n",
    "    input_layer,\n",
    "    flatten_layer,\n",
    "    first_layer,\n",
    "    second_layer,\n",
    "    third_layer,\n",
    "    fourth_layer,\n",
    "    fifth_layer,\n",
    "    unflatten_layer\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass a loss instance (not the class), which is the documented form for compile().\n",
    "autoencoder.compile(loss=keras.losses.MeanSquaredError(), optimizer='adam')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 7ms/step - loss: 0.0315\n",
      "Epoch 2/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0074\n",
      "Epoch 3/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 7ms/step - loss: 0.0057\n",
      "Epoch 4/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 6ms/step - loss: 0.0049\n",
      "Epoch 5/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 7ms/step - loss: 0.0044\n",
      "Epoch 6/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 7ms/step - loss: 0.0041\n",
      "Epoch 7/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0038\n",
      "Epoch 8/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0036\n",
      "Epoch 9/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0034\n",
      "Epoch 10/10\n",
      "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0033\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.src.callbacks.history.History at 0x7f225c1903d0>"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The images serve as both input and target: the network learns to reproduce them.\n",
    "autoencoder.fit(X_train, X_train, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder: the first half of the trained autoencoder, ending at the\n",
    "# 128-unit bottleneck (`third_layer`).\n",
    "encoder = keras.Sequential(layers=[\n",
    "    input_layer,\n",
    "    flatten_layer,\n",
    "    first_layer,\n",
    "    second_layer,\n",
    "    third_layer,\n",
    "])\n",
    "\n",
    "# Decoder: consumes the code produced by `third_layer`, so it starts from its\n",
    "# own input of that width. Re-using `third_layer` here would make the decoder\n",
    "# expect the 256-wide output of `second_layer` instead of the 128-wide code.\n",
    "decoder = keras.Sequential(layers=[\n",
    "    keras.layers.Input(shape=(third_layer.units,)),\n",
    "    fourth_layer,\n",
    "    fifth_layer,\n",
    "    unflatten_layer\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruct the first test digit. predict() works on batches, hence the\n",
    "# extra leading axis around the single image.\n",
    "prediction = autoencoder.predict(np.array([X_test[0]]))\n",
    "print('Before:')\n",
    "plot(X_test[0])  # original test image\n",
    "print('After:')\n",
    "plot(prediction[0])  # reconstruction that went through the 128-unit bottleneck"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These look virtually identical, even though the second one was reconstructed from a 128-dimensional encoding instead of the original 784 pixels."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}