commit 4a51fdcf8f570deb5dba6b14aa4649c359eabb47
parent 20f0de1bb557d065931e6932b1fa570f539598f8
Author: Andrew <andrewlaack1@gmail.com>
Date: Mon, 1 Jul 2024 15:59:18 -0500
Created autoencoder for mnist.
Diffstat:
2 files changed, 245 insertions(+), 0 deletions(-)
diff --git a/fashionMNIST/CNNFashionMNIST.ipynb b/fashionMNIST/CNNFashionMNIST.ipynb
@@ -78,6 +78,13 @@
]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "I should have used SPARSE categorical cross-entropy. It removes the need to one-hot encode the expected outputs, e.g. as [0,0,1] to represent the third class. "
+ ]
+ },
+ {
"cell_type": "code",
"execution_count": 103,
"metadata": {},
diff --git a/mnist/MNISTAutoencoder.ipynb b/mnist/MNISTAutoencoder.ipynb
@@ -0,0 +1,238 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is an autoencoder trained to compress MNIST handwritten digits. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.datasets import fetch_openml\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from sklearn.preprocessing import MinMaxScaler\n",
+ "\n",
+ "# Fetch the MNIST dataset\n",
+ "mnist = fetch_openml(\"mnist_784\", as_frame=False)\n",
+ "\n",
+ "minMax = MinMaxScaler()\n",
+ "\n",
+ "\n",
+ "# Extract data and labels\n",
+ "X, y = mnist.data, mnist.target\n",
+ "\n",
+ "X = minMax.fit_transform(X)\n",
+ "\n",
+ "# Reshape images to 28x28\n",
+ "X = X.reshape(-1, 28, 28)\n",
+ "\n",
+ "# Split the data into training and test sets\n",
+ "X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]\n",
+ "\n",
+ "# Function to plot an image\n",
+ "def plot(image_data):\n",
+ " plt.imshow(image_data, cmap='binary')\n",
+ " plt.axis('off')\n",
+ " plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 86,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import keras\n",
+ "import tensorflow as tf \n",
+ "\n",
+ "\n",
+ "input_layer = keras.layers.Input(shape=(28,28,1))\n",
+ "flatten_layer = keras.layers.Flatten()\n",
+ "first_layer = keras.layers.Dense(784, activation='relu')\n",
+ "second_layer = keras.layers.Dense(256, activation='relu')\n",
+ "third_layer = keras.layers.Dense(128, activation='relu')\n",
+ "fourth_layer = keras.layers.Dense(256, activation='relu')\n",
+ "fifth_layer = keras.layers.Dense(784, activation='sigmoid')\n",
+ "unflatten_layer = keras.layers.Reshape(target_shape=(28,28,1))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 87,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "autoencoder = keras.Sequential(layers=[\n",
+ " input_layer,\n",
+ " flatten_layer,\n",
+ " first_layer,\n",
+ " second_layer,\n",
+ " third_layer,\n",
+ " fourth_layer,\n",
+ " fifth_layer,\n",
+ " unflatten_layer\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 88,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "autoencoder.compile(loss=keras.losses.MeanSquaredError, optimizer='adam')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 7ms/step - loss: 0.0315\n",
+ "Epoch 2/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0074\n",
+ "Epoch 3/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 7ms/step - loss: 0.0057\n",
+ "Epoch 4/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 6ms/step - loss: 0.0049\n",
+ "Epoch 5/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 7ms/step - loss: 0.0044\n",
+ "Epoch 6/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 7ms/step - loss: 0.0041\n",
+ "Epoch 7/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0038\n",
+ "Epoch 8/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0036\n",
+ "Epoch 9/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0034\n",
+ "Epoch 10/10\n",
+ "\u001b[1m1875/1875\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 6ms/step - loss: 0.0033\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "<keras.src.callbacks.history.History at 0x7f225c1903d0>"
+ ]
+ },
+ "execution_count": 89,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "autoencoder.fit(X_train, X_train, epochs=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoder = keras.Sequential(layers=[\n",
+ " input_layer,\n",
+ " flatten_layer,\n",
+ " first_layer,\n",
+ " second_layer,\n",
+ " third_layer,\n",
+ "])\n",
+ "\n",
+ "decoder = keras.Sequential(layers=[\n",
+ " third_layer,\n",
+ " fourth_layer,\n",
+ " fifth_layer,\n",
+ " unflatten_layer\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 14ms/step\n",
+ "Before:\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAJpUlEQVR4nO3cvYuUZxvG4Xt2dlVQcROERKsg2AhCUiQQhCUBsZB0NoKV/4CNAUs7Cxst1Sp1CNqkTZ9WwSaQD4RE4ge46Ca7js6T7qzk5b3uZGeH2ePoT+aJM+G3T3ONhmEYGgC01pZ2+gEAmB+iAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgCxvNMPwLsNw9C129raKm/27t1b3oxGo/IGmH/eFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQDCQbyiyWRS3ty7d6+8uXr1annTWmurq6vlzblz58qb06dPlzcffvhhedNaawcPHixvlpdn89Mej8flTe+xw54jhD2fNZ1Oy5uef4eeTWvzfYxxlt/tdvGmAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECMht6zfrvU5uZmeXPq1Kny5pdffilvWut7vr1795Y3H3zwQXmztbVV3rTW2tJS/W+XlZWVmWzW19fLm7/++qu8aa21PXv2lDc91zcPHTpU3nzxxRflzddff13etNbasWPHypt5ukI677wpABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAMTyTj/AbnD+/Pny5tdff+36rOPHj5c3T58+LW+eP39e3jx48KC8aa21P/74o7zpOTq3sbFR3rx+/bq86T3O9vbt2/LmwIED5c1vv/1W3nz33XflzaefflretNbaRx99VN6Mx+Ouz9qNvCkAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAxGgYhmGnH2LRzfs/cc+Btp7jbOvr6+VNa629ePGivHn16lV58+eff85k89lnn5U3rbW2urpa3ty9e7e8uXLlSnnT8xv//vvvy5vWWltbW+va8f/xpgBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQyzv9ALtBz8G5eTcej8ubnoNurbX23nvvde2qTp48Wd7M8rudTqflzebmZnmztbVV3vR8R5988kl5w/bzpgBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAuJJKl2EYypvei6KLeGW2x4sXL8qb69evlzeTyaS8uXjxYnlz8ODB8obt500BgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEZDz2UzFsp0Op3JZmmp72+Q3t28evPmTdfu0qVL5c2tW7fKm8OHD5c3P//8c3njIN58Wqz/2wD4V0QBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQAiOWdfgD+Wz33DUejUXkzHo/Lm16LdrNxY2Oja3f37t3yZs+ePeXN7du3yxvH7RaHNwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAcBBvwfQc
t5t3PQfxZnUYcDqdljfffvttedNaa+vr6+XNiRMnypuzZ8+WNywObwoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAMRp6LodBh96f2qyO2/V48uRJebO2ttb1Wc+ePStvfvjhh/Lm448/Lm9YHN4UAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIjlnX4Ado/eK6mzung6mUzKm6+++qq8+f3338ub1lo7ffp0eXPy5Mmuz2L38qYAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEA7iMTO9h+16dj3H965du1bePHjwoLx5//33y5vWWrt9+3Z5Mx6Puz6L3cubAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAEA4iDeneg669e56Ds7NatPr0aNH5c2NGzfKmzdv3pQ3X375ZXnTWmuHDx/u2kGFNwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAGA29l9fYVr1fy3Q6LW+Wlup/G8zyuN3Lly/LmzNnzpQ3P/74Y3mzurpa3jx8+LC8aa21o0ePdu1mYTKZlDcrKyvb8CT8W94UAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAGJ5px+Ad+s5bNfabA/VVb19+7Zrd/PmzfLm/v375c2BAwfKm3v37pU3R44cKW/mneN2i8ObAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgDhSuqcGo/HO/0I/9MwDOXN48ePuz7rm2++KW96rnZeuHChvFlbWytv5vmSLXhTACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAgH8eiytbVV3ty5c6frs54/f17e7Nu3r7y5fPlyebO05O8qFotfNAAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAEA4iEeXR48elTfXr1/v+qye43uHDh0qb1ZWVsobWDTeFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQDCQbwZGIahvBmNRtvwJP+dn376qbyZTCbb8CTv9vnnn5c3q6ur5c0ifrfsbt4UAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhXUmeg5yrm5uZm12ft27evvPn777/Lm42NjfJmebnv59ZzifTWrVvlzf79+8ubHj3/Pa25rspseFMAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQAiNHQe50LgIXjTQGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQDiH2QBVWrfZ8kbAAAAAElFTkSuQmCC",
+ "text/plain": [
+ "<Figure size 640x480 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "After:\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAH8klEQVR4nO3csatXdQPH8XMfFANDRBpEUJAuXBGEhgZ1yCEJFBsimvwPHBrbnXXMIepP0CVEXaLEOwQK0uLgVC5CUA0NgSjn2d7LEw9+T15/N+/rtX84XzjDm+/yXZvneZ4AYJqm/6z6AABsH6IAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCANm16gPsBNevXx/efP3114u+dejQoeHNW2+9Nby5ePHi8ObgwYPDm2mapvX19UU7YJybAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkLV5nudVH+JNd/To0eHNzz///OoPsmL79u1btDt+/PgrPgmv2uHDh4c3X3zxxaJvvf/++4t2vBw3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkF2rPsBO8M033wxvfvrpp0XfWvJ43KNHj4Y3Dx8+HN788MMPw5tpmqYff/xxeHPkyJHhzZMnT4Y3r9Pu3buHN++8887w5unTp8ObJf9oySN60+RBvK3mpgBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFALI2z/O86kOwM/zxxx+Ldkse31vyaNr9+/eHN6/Tnj17hjcbGxvDm2PHjg1vfv/99+HNtWvXhjfTNE2XLl1atOPluCkAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYB4EA/eYDdu3BjefPbZZ8ObEydODG++//774c00TdOBAwcW7Xg5bgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEC8kgr/Er/++uvwZsnrpUu+c/369eHNp59+Orxh67kpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGA7Fr1AYCXc+3ateHNksft9u/fP7zZ2NgY3rA9uSkAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCszfM8r/oQsJNsbm4u2n344YfDm2fPng1v7t69O7z54IMPhjdsT24KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgu1Z9ANhpbt26tWi35HG7s2fPDm9OnTo1vOHN4aYAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQDiQTz4B/7666/hzZ07dxZ9a8+ePcOby5cvD2927949vOHN4aYAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgDEK6nwD1y5cmV48/Dhw0XfOnfu3PDm9OnTi77FzuWmAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAsjbP87zqQ8B2cPPmzeHNJ598MrzZu3fv8Gaapun27dvDm1OnTi36FjuX
mwIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAMiuVR8AtsJvv/02vPn888+HN8+fPx/enD9/fngzTR634/VwUwAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCAFmb53le9SHg/3nx4sXw5uTJk8ObBw8eDG/W19eHN3fu3BneTNM0vfvuu4t2MMJNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4bHuPHz8e3mxsbGzBSf7Xt99+O7z5+OOPt+Ak8Gq4KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCANm16gOwc/zyyy+Ldh999NErPsnfu3r16vDmwoULW3ASWB03BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEA/i8dp89dVXi3ZLH9IbdebMmeHN2traFpwEVsdNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4LHLv3r3hzZdffrkFJwFeJTcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQD+KxyObm5vDmzz//3IKT/L319fXhzdtvv70FJ4F/FzcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgXkll23vvvfeGN999993w5sCBA8MbeNO4KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgKzN8zyv+hAAbA9uCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYD8F0WiwuwxnHeiAAAAAElFTkSuQmCC",
+ "text/plain": [
+ "<Figure size 640x480 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "prediction = autoencoder.predict(np.array([X_test[0]]))\n",
+ "print('Before:')\n",
+ "plot(prediction[0])\n",
+ "print('After:')\n",
+ "plot(X_test[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "These look virtually identical, even though the reconstruction (plotted first, under the 'Before:' label) was decoded from a 128-dimensional encoding instead of the original 784 pixel values."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}