machinelearning

Machine learning code
git clone git://git.laack.co/machinelearning.git
Log | Files | Refs

MNISTClassificationCNN.ipynb (31675B)


      1 {
      2  "cells": [
      3   {
      4    "cell_type": "code",
      5    "execution_count": null,
      6    "metadata": {},
      7    "outputs": [],
      8    "source": []
      9   },
     10   {
     11    "cell_type": "markdown",
     12    "metadata": {},
     13    "source": [
     14     "# MNIST Classification With CNN (99.2% on Test)\n",
     15     "\n",
     16     "## Specs\n",
     17     "\n",
     18     "Data is preprocessed with one hot encoding for labels and the inputs are normalized from 0-255 to 0-1 values.\n",
     19     "\n",
     20     "The model is a CNN with three convolutional layers, the first two having 32 filters and the last having 64 filters. \n",
     21     "\n",
     22     "The final hidden layer is a dense layer with 128 neurons that feeds into a softmax output layer. \n",
     23     "\n",
     24     "To regularize the model I added batch normalization on the outputs of the first two convolutional layers. I then added dropout for the final convolutional layer and the dense layer. \n",
     25     "\n",
     26     "The convolutional layer has 10% dropout and the dense layer has 25% dropout. \n",
     27     "\n",
     28     "To avoid complexity I used pooling (2x2) after the first and third convolutional layers. This also improves model robustness. \n",
     29     "\n",
     30     "___\n",
     31     "\n",
     32     "## Stats\n",
     33     "\n",
     34     "Total Parameters:\n",
     35     "\n",
     36     "160,810\n",
     37     "\n",
     38     "Accuracy:\n",
     39     "\n",
     40     "0.9919999837875366\n",
     41     "\n",
     42     "Loss:\n",
     43     "\n",
     44     "0.05224525183439255\n",
     45     "\n",
     46     "____\n",
     47     "\n",
     48     "## Lessons Learned\n",
     49     "\n",
     50     "Make sure to normalize inputs. I did not do this at first and the model was both inaccurate and slow to converge. In this case it made sense to normalize instead of standardize because the pixel values already lie in a small, fixed range (0-255). \n",
     51     "\n",
     52     "Use 3x3 or 5x5 filters. I have never used CNNs before so this is good info to have. \n",
     53     "\n",
     54     "In convolutional layers use 32 filters as a baseline (most of the time). Also, have increasing numbers of filters as the model goes further (generally).  \n",
     55     "\n",
     56     "Don't mix dropout and batch normalization. This will lead to unexpected results. Use one or the other per layer if needed to fix overfitting and vanishing gradients. Normalization helps to speed up convergence by helping with unstable gradients, and dropout helps to decrease overfitting. \n",
     57     "\n",
     58     "Use pooling to minimize computational overhead and to enhance robustness."
     59    ]
     60   },
     61   {
     62    "cell_type": "code",
     63    "execution_count": 6,
     64    "metadata": {},
     65    "outputs": [],
     66    "source": [
     67     "from keras.datasets import mnist\n",
     68     "import numpy as np\n",
     69     "\n",
     70     "np.random.seed = 10\n",
     71     "\n",
     72     "(X_train,y_train), (X_test, y_test) = mnist.load_data()"
     73    ]
     74   },
     75   {
     76    "cell_type": "code",
     77    "execution_count": 7,
     78    "metadata": {},
     79    "outputs": [],
     80    "source": [
     81     "from sklearn.model_selection import train_test_split\n",
     82     "import numpy as np\n",
     83     "from sklearn.preprocessing import OneHotEncoder\n",
     84     "\n",
     85     "X_test, X_validation, y_test, y_validation = train_test_split(X_test,y_test, random_state=10, test_size=.5)\n",
     86     "hot = OneHotEncoder(sparse_output=False)\n",
     87     "\n",
     88     "\n",
     89     "def transformX(X):\n",
     90     "    X = X / 255\n",
     91     "    return X\n",
     92     "\n",
     93     "\n",
     94     "def transformY(y):\n",
     95     "    y = y.reshape(-1,1)\n",
     96     "    y = hot.fit_transform(y)\n",
     97     "    return y\n",
     98     "\n",
     99     "y_train = transformY(y_train)\n",
    100     "X_train = transformX(X_train)\n",
    101     "\n",
    102     "y_test = transformY(y_test)\n",
    103     "X_test = transformX(X_test)\n",
    104     "\n",
    105     "y_validation = transformY(y_validation)\n",
    106     "X_validation = transformX(X_validation)\n",
    107     "\n"
    108    ]
    109   },
    110   {
    111    "cell_type": "code",
    112    "execution_count": 8,
    113    "metadata": {},
    114    "outputs": [
    115     {
    116      "data": {
    117       "text/html": [
    118        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_1\"</span>\n",
    119        "</pre>\n"
    120       ],
    121       "text/plain": [
    122        "\u001b[1mModel: \"sequential_1\"\u001b[0m\n"
    123       ]
    124      },
    125      "metadata": {},
    126      "output_type": "display_data"
    127     },
    128     {
    129      "data": {
    130       "text/html": [
    131        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
    132        "┃<span style=\"font-weight: bold\"> Layer (type)                    </span>┃<span style=\"font-weight: bold\"> Output Shape           </span>┃<span style=\"font-weight: bold\">       Param # </span>┃\n",
    133        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
    134        "│ conv2d_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>)     │           <span style=\"color: #00af00; text-decoration-color: #00af00\">320</span> │\n",
    135        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    136        "│ batch_normalization_2           │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>)     │           <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span> │\n",
    137        "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">BatchNormalization</span>)            │                        │               │\n",
    138        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    139        "│ max_pooling2d_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>)  │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">13</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>)     │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
    140        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    141        "│ conv2d_4 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>)     │         <span style=\"color: #00af00; text-decoration-color: #00af00\">9,248</span> │\n",
    142        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    143        "│ batch_normalization_3           │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">11</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>)     │           <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span> │\n",
    144        "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">BatchNormalization</span>)            │                        │               │\n",
    145        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    146        "│ conv2d_5 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>)               │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)       │        <span style=\"color: #00af00; text-decoration-color: #00af00\">18,496</span> │\n",
    147        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    148        "│ dropout_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>)             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)       │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
    149        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    150        "│ max_pooling2d_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>)  │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">4</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>)       │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
    151        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    152        "│ flatten_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>)             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1024</span>)           │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
    153        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    154        "│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)            │       <span style=\"color: #00af00; text-decoration-color: #00af00\">131,200</span> │\n",
    155        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    156        "│ dropout_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>)             │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>)            │             <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
    157        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    158        "│ dense_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>)             │         <span style=\"color: #00af00; text-decoration-color: #00af00\">1,290</span> │\n",
    159        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
    160        "</pre>\n"
    161       ],
    162       "text/plain": [
    163        "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
    164        "┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                   \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape          \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
    165        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
    166        "│ conv2d_3 (\u001b[38;5;33mConv2D\u001b[0m)               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m32\u001b[0m)     │           \u001b[38;5;34m320\u001b[0m │\n",
    167        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    168        "│ batch_normalization_2           │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m32\u001b[0m)     │           \u001b[38;5;34m128\u001b[0m │\n",
    169        "│ (\u001b[38;5;33mBatchNormalization\u001b[0m)            │                        │               │\n",
    170        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    171        "│ max_pooling2d_2 (\u001b[38;5;33mMaxPooling2D\u001b[0m)  │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m13\u001b[0m, \u001b[38;5;34m32\u001b[0m)     │             \u001b[38;5;34m0\u001b[0m │\n",
    172        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    173        "│ conv2d_4 (\u001b[38;5;33mConv2D\u001b[0m)               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m)     │         \u001b[38;5;34m9,248\u001b[0m │\n",
    174        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    175        "│ batch_normalization_3           │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m11\u001b[0m, \u001b[38;5;34m32\u001b[0m)     │           \u001b[38;5;34m128\u001b[0m │\n",
    176        "│ (\u001b[38;5;33mBatchNormalization\u001b[0m)            │                        │               │\n",
    177        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    178        "│ conv2d_5 (\u001b[38;5;33mConv2D\u001b[0m)               │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m64\u001b[0m)       │        \u001b[38;5;34m18,496\u001b[0m │\n",
    179        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    180        "│ dropout_2 (\u001b[38;5;33mDropout\u001b[0m)             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m9\u001b[0m, \u001b[38;5;34m64\u001b[0m)       │             \u001b[38;5;34m0\u001b[0m │\n",
    181        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    182        "│ max_pooling2d_3 (\u001b[38;5;33mMaxPooling2D\u001b[0m)  │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m4\u001b[0m, \u001b[38;5;34m64\u001b[0m)       │             \u001b[38;5;34m0\u001b[0m │\n",
    183        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    184        "│ flatten_1 (\u001b[38;5;33mFlatten\u001b[0m)             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m)           │             \u001b[38;5;34m0\u001b[0m │\n",
    185        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    186        "│ dense_2 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m)            │       \u001b[38;5;34m131,200\u001b[0m │\n",
    187        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    188        "│ dropout_3 (\u001b[38;5;33mDropout\u001b[0m)             │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m)            │             \u001b[38;5;34m0\u001b[0m │\n",
    189        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    190        "│ dense_3 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m)             │         \u001b[38;5;34m1,290\u001b[0m │\n",
    191        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
    192       ]
    193      },
    194      "metadata": {},
    195      "output_type": "display_data"
    196     },
    197     {
    198      "data": {
    199       "text/html": [
    200        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">160,810</span> (628.16 KB)\n",
    201        "</pre>\n"
    202       ],
    203       "text/plain": [
    204        "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m160,810\u001b[0m (628.16 KB)\n"
    205       ]
    206      },
    207      "metadata": {},
    208      "output_type": "display_data"
    209     },
    210     {
    211      "data": {
    212       "text/html": [
    213        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">160,682</span> (627.66 KB)\n",
    214        "</pre>\n"
    215       ],
    216       "text/plain": [
    217        "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m160,682\u001b[0m (627.66 KB)\n"
    218       ]
    219      },
    220      "metadata": {},
    221      "output_type": "display_data"
    222     },
    223     {
    224      "data": {
    225       "text/html": [
    226        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">128</span> (512.00 B)\n",
    227        "</pre>\n"
    228       ],
    229       "text/plain": [
    230        "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m128\u001b[0m (512.00 B)\n"
    231       ]
    232      },
    233      "metadata": {},
    234      "output_type": "display_data"
    235     }
    236    ],
    237    "source": [
    238     "import tensorflow as tf\n",
    239     "import keras\n",
    240     "\n",
    241     "model = keras.Sequential(\n",
    242     "    layers=[\n",
    243     "        keras.layers.Input((28,28, 1)),\n",
    244     "        keras.layers.Conv2D(kernel_size=(3,3), filters=32,),\n",
    245     "        keras.layers.BatchNormalization(),\n",
    246     "        keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
    247     "\n",
    248     "        keras.layers.Conv2D(kernel_size=(3,3), filters=32,),\n",
    249     "        keras.layers.BatchNormalization(),\n",
    250     "\n",
    251     "        keras.layers.Conv2D(kernel_size=(3,3), filters=64,),\n",
    252     "        keras.layers.Dropout(0.1),\n",
    253     "        keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
    254     "        keras.layers.Flatten(),\n",
    255     "\n",
    256     "        keras.layers.Dense(128, activation='relu'),\n",
    257     "        keras.layers.Dropout(0.25),\n",
    258     "\n",
    259     "        keras.layers.Dense(10, activation='softmax')\n",
    260     "    ]\n",
    261     "\n",
    262     ")\n",
    263     "\n",
    264     "optimizer = keras.optimizers.Adam()\n",
    265     "\n",
    266     "model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])\n",
    267     "\n",
    268     "model.summary()"
    269    ]
    270   },
    271   {
    272    "cell_type": "code",
    273    "execution_count": 9,
    274    "metadata": {},
    275    "outputs": [
    276     {
    277      "name": "stdout",
    278      "output_type": "stream",
    279      "text": [
    280       "Epoch 1/25\n",
    281       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m25s\u001b[0m 50ms/step - accuracy: 0.8535 - loss: 0.4796 - val_accuracy: 0.9094 - val_loss: 0.2961\n",
    282       "Epoch 2/25\n",
    283       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m23s\u001b[0m 49ms/step - accuracy: 0.9747 - loss: 0.0817 - val_accuracy: 0.9822 - val_loss: 0.0582\n",
    284       "Epoch 3/25\n",
    285       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 46ms/step - accuracy: 0.9801 - loss: 0.0652 - val_accuracy: 0.9880 - val_loss: 0.0389\n",
    286       "Epoch 4/25\n",
    287       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m45s\u001b[0m 54ms/step - accuracy: 0.9840 - loss: 0.0517 - val_accuracy: 0.9770 - val_loss: 0.0934\n",
    288       "Epoch 5/25\n",
    289       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m23s\u001b[0m 48ms/step - accuracy: 0.9854 - loss: 0.0473 - val_accuracy: 0.9902 - val_loss: 0.0354\n",
    290       "Epoch 6/25\n",
    291       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 44ms/step - accuracy: 0.9881 - loss: 0.0386 - val_accuracy: 0.9890 - val_loss: 0.0389\n",
    292       "Epoch 7/25\n",
    293       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m44s\u001b[0m 50ms/step - accuracy: 0.9880 - loss: 0.0394 - val_accuracy: 0.9896 - val_loss: 0.0430\n",
    294       "Epoch 8/25\n",
    295       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m39s\u001b[0m 46ms/step - accuracy: 0.9898 - loss: 0.0321 - val_accuracy: 0.9916 - val_loss: 0.0399\n",
    296       "Epoch 9/25\n",
    297       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 47ms/step - accuracy: 0.9897 - loss: 0.0319 - val_accuracy: 0.9898 - val_loss: 0.0490\n",
    298       "Epoch 10/25\n",
    299       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m23s\u001b[0m 48ms/step - accuracy: 0.9900 - loss: 0.0311 - val_accuracy: 0.9894 - val_loss: 0.0495\n",
    300       "Epoch 11/25\n",
    301       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 47ms/step - accuracy: 0.9913 - loss: 0.0263 - val_accuracy: 0.9908 - val_loss: 0.0386\n",
    302       "Epoch 12/25\n",
    303       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 48ms/step - accuracy: 0.9920 - loss: 0.0246 - val_accuracy: 0.9900 - val_loss: 0.0447\n",
    304       "Epoch 13/25\n",
    305       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m23s\u001b[0m 49ms/step - accuracy: 0.9932 - loss: 0.0225 - val_accuracy: 0.9890 - val_loss: 0.0570\n",
    306       "Epoch 14/25\n",
    307       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 46ms/step - accuracy: 0.9916 - loss: 0.0249 - val_accuracy: 0.9922 - val_loss: 0.0346\n",
    308       "Epoch 15/25\n",
    309       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 47ms/step - accuracy: 0.9922 - loss: 0.0242 - val_accuracy: 0.9898 - val_loss: 0.0485\n",
    310       "Epoch 16/25\n",
    311       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 47ms/step - accuracy: 0.9922 - loss: 0.0236 - val_accuracy: 0.9896 - val_loss: 0.0476\n",
    312       "Epoch 17/25\n",
    313       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m27s\u001b[0m 57ms/step - accuracy: 0.9935 - loss: 0.0212 - val_accuracy: 0.9916 - val_loss: 0.0411\n",
    314       "Epoch 18/25\n",
    315       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m24s\u001b[0m 50ms/step - accuracy: 0.9940 - loss: 0.0186 - val_accuracy: 0.9904 - val_loss: 0.0487\n",
    316       "Epoch 19/25\n",
    317       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 45ms/step - accuracy: 0.9945 - loss: 0.0153 - val_accuracy: 0.9900 - val_loss: 0.0625\n",
    318       "Epoch 20/25\n",
    319       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 45ms/step - accuracy: 0.9946 - loss: 0.0166 - val_accuracy: 0.9888 - val_loss: 0.0753\n",
    320       "Epoch 21/25\n",
    321       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 46ms/step - accuracy: 0.9946 - loss: 0.0171 - val_accuracy: 0.9896 - val_loss: 0.0673\n",
    322       "Epoch 22/25\n",
    323       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 47ms/step - accuracy: 0.9948 - loss: 0.0180 - val_accuracy: 0.9900 - val_loss: 0.0565\n",
    324       "Epoch 23/25\n",
    325       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 45ms/step - accuracy: 0.9947 - loss: 0.0155 - val_accuracy: 0.9926 - val_loss: 0.0535\n",
    326       "Epoch 24/25\n",
    327       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 45ms/step - accuracy: 0.9952 - loss: 0.0159 - val_accuracy: 0.9900 - val_loss: 0.0612\n",
    328       "Epoch 25/25\n",
    329       "\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 47ms/step - accuracy: 0.9952 - loss: 0.0156 - val_accuracy: 0.9936 - val_loss: 0.0561\n"
    330      ]
    331     },
    332     {
    333      "data": {
    334       "text/plain": [
    335        "<keras.src.callbacks.history.History at 0x7f5120805990>"
    336       ]
    337      },
    338      "execution_count": 9,
    339      "metadata": {},
    340      "output_type": "execute_result"
    341     }
    342    ],
    343    "source": [
    344     "model.fit(X_train, y_train, epochs=25, validation_data=[X_validation, y_validation], batch_size=128)"
    345    ]
    346   },
    347   {
    348    "cell_type": "code",
    349    "execution_count": 11,
    350    "metadata": {},
    351    "outputs": [
    352     {
    353      "name": "stdout",
    354      "output_type": "stream",
    355      "text": [
    356       "\u001b[1m157/157\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - accuracy: 0.9922 - loss: 0.0599\n"
    357      ]
    358     },
    359     {
    360      "data": {
    361       "text/plain": [
    362        "[0.05224525183439255, 0.9919999837875366]"
    363       ]
    364      },
    365      "execution_count": 11,
    366      "metadata": {},
    367      "output_type": "execute_result"
    368     }
    369    ],
    370    "source": [
    371     "model.evaluate(X_test, y_test)"
    372    ]
    373   },
    374   {
    375    "cell_type": "code",
    376    "execution_count": 12,
    377    "metadata": {},
    378    "outputs": [],
    379    "source": [
    380     "# model.save('../models/MNISTClassificationCNN99_2.keras')"
    381    ]
    382   }
    383  ],
    384  "metadata": {
    385   "kernelspec": {
    386    "display_name": ".venv",
    387    "language": "python",
    388    "name": "python3"
    389   },
    390   "language_info": {
    391    "codemirror_mode": {
    392     "name": "ipython",
    393     "version": 3
    394    },
    395    "file_extension": ".py",
    396    "mimetype": "text/x-python",
    397    "name": "python",
    398    "nbconvert_exporter": "python",
    399    "pygments_lexer": "ipython3",
    400    "version": "3.11.2"
    401   }
    402  },
    403  "nbformat": 4,
    404  "nbformat_minor": 2
    405 }