machinelearning

Machine learning code
git clone git://git.laack.co/machinelearning.git
Log | Files | Refs

MNISTRegressionClassificationNN.ipynb (22840B)


      1 {
      2  "cells": [
      3   {
      4    "cell_type": "markdown",
      5    "metadata": {},
      6    "source": [
      7     "This model is dumb: I used regression for this... smh.\n",
      8     "\n",
      9     "I should have treated it as multi-class classification, with a 10-unit softmax output layer (one unit per digit), instead of a regression whose single output is rounded to the predicted digit.\n",
     10     "\n",
     11     "Despite this, it is still fairly accurate."
     12    ]
     13   },
     14   {
     15    "cell_type": "code",
     16    "execution_count": 1,
     17    "metadata": {},
     18    "outputs": [
     19     {
     20      "name": "stderr",
     21      "output_type": "stream",
     22      "text": [
     23       "2024-06-14 11:57:09.924151: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
     24       "2024-06-14 11:57:09.927539: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
     25       "2024-06-14 11:57:09.970132: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
     26       "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
     27       "2024-06-14 11:57:10.704842: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     28      ]
     29     }
     30    ],
     31    "source": [
     32     "import tensorflow as tf\n",
     33     "from sklearn.datasets import fetch_openml\n",
     34     "from tensorflow.keras import Sequential\n",
     35     "\n",
     36     "mnist = fetch_openml(\"mnist_784\", as_frame=False)"
     37    ]
     38   },
     39   {
     40    "cell_type": "code",
     41    "execution_count": 2,
     42    "metadata": {},
     43    "outputs": [
     44     {
     45      "data": {
     46       "text/plain": [
     47        "(70000, 784)"
     48       ]
     49      },
     50      "execution_count": 2,
     51      "metadata": {},
     52      "output_type": "execute_result"
     53     }
     54    ],
     55    "source": [
     56     "y = mnist.target\n",
     57     "X = mnist.data\n",
     58     "X.shape"
     59    ]
     60   },
     61   {
     62    "cell_type": "code",
     63    "execution_count": 3,
     64    "metadata": {},
     65    "outputs": [],
     66    "source": [
     67     "from sklearn.decomposition import PCA\n",
     68     "from sklearn.model_selection import train_test_split\n",
     69     "\n",
     70     "pca = PCA(n_components=.95)\n",
     71     "X = pca.fit_transform(X)\n",
     72     "\n",
     73     "count = 0\n",
     74     "for i in y:\n",
     75     "    y[count] = float(i)\n",
     76     "    count += 1\n",
     77     "\n",
     78     "X_train , X_test, y_train, y_test = train_test_split(X,y)"
     79    ]
     80   },
     81   {
     82    "cell_type": "code",
     83    "execution_count": 4,
     84    "metadata": {},
     85    "outputs": [
     86     {
     87      "name": "stderr",
     88      "output_type": "stream",
     89      "text": [
     90       "2024-06-14 11:57:17.182720: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
     91       "2024-06-14 11:57:17.184240: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
     92       "Skipping registering GPU devices...\n"
     93      ]
     94     }
     95    ],
     96    "source": [
     97     "model = Sequential()\n",
     98     "input_1 = tf.keras.layers.Input([len(X[0])])\n",
     99     "hidden_1 = tf.keras.layers.Dense(100 , 'relu')\n",
    100     "hidden_2 = tf.keras.layers.Dense(100 , 'relu')\n",
    101     "hidden_3 = tf.keras.layers.Dense(100 , 'relu')\n",
    102     "output_1 = tf.keras.layers.Dense(1)\n",
    103     "\n",
    104     "model.add(input_1)\n",
    105     "model.add(hidden_1)\n",
    106     "model.add(hidden_2)\n",
    107     "model.add(hidden_3)\n",
    108     "model.add(output_1)\n"
    109    ]
    110   },
    111   {
    112    "cell_type": "code",
    113    "execution_count": 5,
    114    "metadata": {},
    115    "outputs": [
    116     {
    117      "data": {
    118       "text/html": [
    119        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
    120        "</pre>\n"
    121       ],
    122       "text/plain": [
    123        "\u001b[1mModel: \"sequential\"\u001b[0m\n"
    124       ]
    125      },
    126      "metadata": {},
    127      "output_type": "display_data"
    128     },
    129     {
    130      "data": {
    131       "text/html": [
    132        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
    133        "┃<span style=\"font-weight: bold\"> Layer (type)                    </span>┃<span style=\"font-weight: bold\"> Output Shape           </span>┃<span style=\"font-weight: bold\">       Param # </span>┃\n",
    134        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
    135        "│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                   │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>)            │        <span style=\"color: #00af00; text-decoration-color: #00af00\">15,500</span> │\n",
    136        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    137        "│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>)            │        <span style=\"color: #00af00; text-decoration-color: #00af00\">10,100</span> │\n",
    138        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    139        "│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>)            │        <span style=\"color: #00af00; text-decoration-color: #00af00\">10,100</span> │\n",
    140        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    141        "│ dense_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>)              │           <span style=\"color: #00af00; text-decoration-color: #00af00\">101</span> │\n",
    142        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
    143        "</pre>\n"
    144       ],
    145       "text/plain": [
    146        "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
    147        "┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                   \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape          \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
    148        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
    149        "│ dense (\u001b[38;5;33mDense\u001b[0m)                   │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m)            │        \u001b[38;5;34m15,500\u001b[0m │\n",
    150        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    151        "│ dense_1 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m)            │        \u001b[38;5;34m10,100\u001b[0m │\n",
    152        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    153        "│ dense_2 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m)            │        \u001b[38;5;34m10,100\u001b[0m │\n",
    154        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
    155        "│ dense_3 (\u001b[38;5;33mDense\u001b[0m)                 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m)              │           \u001b[38;5;34m101\u001b[0m │\n",
    156        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
    157       ]
    158      },
    159      "metadata": {},
    160      "output_type": "display_data"
    161     },
    162     {
    163      "data": {
    164       "text/html": [
    165        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">35,801</span> (139.85 KB)\n",
    166        "</pre>\n"
    167       ],
    168       "text/plain": [
    169        "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m35,801\u001b[0m (139.85 KB)\n"
    170       ]
    171      },
    172      "metadata": {},
    173      "output_type": "display_data"
    174     },
    175     {
    176      "data": {
    177       "text/html": [
    178        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">35,801</span> (139.85 KB)\n",
    179        "</pre>\n"
    180       ],
    181       "text/plain": [
    182        "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m35,801\u001b[0m (139.85 KB)\n"
    183       ]
    184      },
    185      "metadata": {},
    186      "output_type": "display_data"
    187     },
    188     {
    189      "data": {
    190       "text/html": [
    191        "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
    192        "</pre>\n"
    193       ],
    194       "text/plain": [
    195        "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
    196       ]
    197      },
    198      "metadata": {},
    199      "output_type": "display_data"
    200     }
    201    ],
    202    "source": [
    203     "model.summary()  # prints each layer's output shape and parameter count"
    204    ]
    205   },
    206   {
    207    "cell_type": "code",
    208    "execution_count": 6,
    209    "metadata": {},
    210    "outputs": [],
    211    "source": [
    212     "optimize = tf.keras.optimizers.Adam()\n",
    213     "model.compile(optimizer=optimize,loss='mse', metrics=['mae'])"
    214    ]
    215   },
    216   {
    217    "cell_type": "code",
    218    "execution_count": 7,
    219    "metadata": {},
    220    "outputs": [
    221     {
    222      "name": "stdout",
    223      "output_type": "stream",
    224      "text": [
    225       "Epoch 1/30\n",
    226       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 1ms/step - loss: 202.9985 - mae: 7.8408 - val_loss: 4.2281 - val_mae: 1.6003\n",
    227       "Epoch 2/30\n",
    228       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 4.2868 - mae: 1.6209 - val_loss: 3.1366 - val_mae: 1.3650\n",
    229       "Epoch 3/30\n",
    230       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 3.0048 - mae: 1.3345 - val_loss: 2.1976 - val_mae: 1.1122\n",
    231       "Epoch 4/30\n",
    232       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 2.1743 - mae: 1.1217 - val_loss: 1.8770 - val_mae: 1.0113\n",
    233       "Epoch 5/30\n",
    234       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 1.6681 - mae: 0.9647 - val_loss: 1.4420 - val_mae: 0.8738\n",
    235       "Epoch 6/30\n",
    236       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 1.3077 - mae: 0.8395 - val_loss: 1.5200 - val_mae: 0.9104\n",
    237       "Epoch 7/30\n",
    238       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.9894 - mae: 0.7081 - val_loss: 1.4996 - val_mae: 0.9626\n",
    239       "Epoch 8/30\n",
    240       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 1ms/step - loss: 0.8075 - mae: 0.6240 - val_loss: 1.0141 - val_mae: 0.6623\n",
    241       "Epoch 9/30\n",
    242       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.6778 - mae: 0.5564 - val_loss: 0.8471 - val_mae: 0.5888\n",
    243       "Epoch 10/30\n",
    244       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.5963 - mae: 0.5153 - val_loss: 0.7413 - val_mae: 0.5223\n",
    245       "Epoch 11/30\n",
    246       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.5044 - mae: 0.4551 - val_loss: 0.6787 - val_mae: 0.4632\n",
    247       "Epoch 12/30\n",
    248       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.4623 - mae: 0.4353 - val_loss: 0.6637 - val_mae: 0.4663\n",
    249       "Epoch 13/30\n",
    250       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.3952 - mae: 0.3990 - val_loss: 0.6840 - val_mae: 0.4617\n",
    251       "Epoch 14/30\n",
    252       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.3724 - mae: 0.3854 - val_loss: 0.6865 - val_mae: 0.4951\n",
    253       "Epoch 15/30\n",
    254       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 2ms/step - loss: 0.3377 - mae: 0.3649 - val_loss: 0.5628 - val_mae: 0.3864\n",
    255       "Epoch 16/30\n",
    256       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 2ms/step - loss: 0.2954 - mae: 0.3337 - val_loss: 0.6157 - val_mae: 0.4369\n",
    257       "Epoch 17/30\n",
    258       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 2ms/step - loss: 0.2833 - mae: 0.3199 - val_loss: 0.5496 - val_mae: 0.3746\n",
    259       "Epoch 18/30\n",
    260       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2465 - mae: 0.3095 - val_loss: 0.6086 - val_mae: 0.4012\n",
    261       "Epoch 19/30\n",
    262       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2436 - mae: 0.3003 - val_loss: 0.5564 - val_mae: 0.3431\n",
    263       "Epoch 20/30\n",
    264       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2413 - mae: 0.2940 - val_loss: 0.5232 - val_mae: 0.3490\n",
    265       "Epoch 21/30\n",
    266       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 1ms/step - loss: 0.2087 - mae: 0.2784 - val_loss: 0.5680 - val_mae: 0.3854\n",
    267       "Epoch 22/30\n",
    268       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2286 - mae: 0.2898 - val_loss: 0.5172 - val_mae: 0.3380\n",
    269       "Epoch 23/30\n",
    270       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2041 - mae: 0.2718 - val_loss: 0.4947 - val_mae: 0.3167\n",
    271       "Epoch 24/30\n",
    272       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1863 - mae: 0.2601 - val_loss: 0.5392 - val_mae: 0.3530\n",
    273       "Epoch 25/30\n",
    274       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1917 - mae: 0.2661 - val_loss: 0.5393 - val_mae: 0.3282\n",
    275       "Epoch 26/30\n",
    276       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1691 - mae: 0.2516 - val_loss: 0.4788 - val_mae: 0.3174\n",
    277       "Epoch 27/30\n",
    278       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1597 - mae: 0.2428 - val_loss: 0.5443 - val_mae: 0.3862\n",
    279       "Epoch 28/30\n",
    280       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1663 - mae: 0.2417 - val_loss: 0.4760 - val_mae: 0.3027\n",
    281       "Epoch 29/30\n",
    282       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1673 - mae: 0.2456 - val_loss: 0.5956 - val_mae: 0.3623\n",
    283       "Epoch 30/30\n",
    284       "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1619 - mae: 0.2360 - val_loss: 0.4810 - val_mae: 0.3082\n"
    285      ]
    286     },
    287     {
    288      "data": {
    289       "text/plain": [
    290        "<keras.src.callbacks.history.History at 0x7fb8eca64ad0>"
    291       ]
    292      },
    293      "execution_count": 7,
    294      "metadata": {},
    295      "output_type": "execute_result"
    296     }
    297    ],
    298    "source": [
    299     "import numpy as np\n",
    300     "X_train = np.asarray(X_train).astype('float32')\n",
    301     "y_train = np.asarray(y_train).astype('float32')\n",
    302     "X_test = np.asarray(X_test).astype('float32')\n",
    303     "y_test = np.asarray(y_test).astype('float32')\n",
    304     "\n",
    305     "\n",
    306     "model.fit(epochs=30, x=X_train, y=y_train, validation_data=(X_test, y_test))"
    307    ]
    308   },
    309   {
    310    "cell_type": "code",
    311    "execution_count": 8,
    312    "metadata": {},
    313    "outputs": [
    314     {
    315      "name": "stdout",
    316      "output_type": "stream",
    317      "text": [
    318       "\u001b[1m547/547\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 627us/step\n"
    319      ]
    320     }
    321    ],
    322    "source": [
    323     "y_pred = model.predict(X_test)  # one real-valued prediction per test row (the model has a single output unit)"
    324    ]
    325   },
    326   {
    327    "cell_type": "code",
    328    "execution_count": 9,
    329    "metadata": {},
    330    "outputs": [],
    331    "source": [
    332     "count = 0\n",
    333     "for i in y_pred:\n",
    334     "    y_pred[count] = i.round()\n",
    335     "    count += 1"
    336    ]
    337   },
    338   {
    339    "cell_type": "code",
    340    "execution_count": 10,
    341    "metadata": {},
    342    "outputs": [],
    343    "source": [
    344     "correct = 0\n",
    345     "wrong = 0\n",
    346     "\n",
    347     "count = 0\n",
    348     "for i in y_pred:\n",
    349     "    if i != y_test[count]:\n",
    350     "        wrong += 1\n",
    351     "    else:\n",
    352     "        correct += 1\n",
    353     "    count += 1\n"
    354    ]
    355   },
    356   {
    357    "cell_type": "code",
    358    "execution_count": 11,
    359    "metadata": {},
    360    "outputs": [
    361     {
    362      "name": "stdout",
    363      "output_type": "stream",
    364      "text": [
    365       "15254\n",
    366       "2246\n",
    367       "0.8716571428571429\n"
    368      ]
    369     }
    370    ],
    371    "source": [
    372     "print(correct)\n",
    373     "print(wrong)\n",
    374     "print(correct / (correct + wrong))"
    375    ]
    376   },
    377   {
    378    "cell_type": "code",
    379    "execution_count": 13,
    380    "metadata": {},
    381    "outputs": [],
    382    "source": [
    383     "model.save('../models/MNISTClassificationModel.keras')  # NOTE(review): despite the file name, this is the regression model; the path is relative to the kernel's working directory"
    384    ]
    385   },
    386   {
    387    "cell_type": "code",
    388    "execution_count": 15,
    389    "metadata": {},
    390    "outputs": [],
    391    "source": [
    392     "loadedModel = tf.keras.models.load_model('../models/MNISTClassificationModel.keras')  # reload the saved .keras file (round-trip check; loadedModel is not used further here)"
    393    ]
    394   }
    395  ],
    396  "metadata": {
    397   "kernelspec": {
    398    "display_name": "myvenv",
    399    "language": "python",
    400    "name": "python3"
    401   },
    402   "language_info": {
    403    "codemirror_mode": {
    404     "name": "ipython",
    405     "version": 3
    406    },
    407    "file_extension": ".py",
    408    "mimetype": "text/x-python",
    409    "name": "python",
    410    "nbconvert_exporter": "python",
    411    "pygments_lexer": "ipython3",
    412    "version": "3.11.2"
    413   }
    414  },
    415  "nbformat": 4,
    416  "nbformat_minor": 2
    417 }