MNISTRegressionClassificationNN.ipynb (22840B)
1 { 2 "cells": [ 3 { 4 "cell_type": "markdown", 5 "metadata": {}, 6 "source": [ 7 "This model is dumb: I framed digit recognition as a regression problem.\n", 8 "\n", 9 "I should have used multi-class classification (10 output units with a softmax activation) instead of a single regression output whose rounded value is taken as the predicted digit.\n", 10 "\n", 11 "Despite this, it is still fairly accurate (about 87% on the test split in the recorded run)." 12 ] 13 }, 14 { 15 "cell_type": "code", 16 "execution_count": 1, 17 "metadata": {}, 18 "outputs": [ 19 { 20 "name": "stderr", 21 "output_type": "stream", 22 "text": [ 23 "2024-06-14 11:57:09.924151: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", 24 "2024-06-14 11:57:09.927539: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", 25 "2024-06-14 11:57:09.970132: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", 26 "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 27 "2024-06-14 11:57:10.704842: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" 28 ] 29 } 30 ], 31 "source": [ 32 "import tensorflow as tf\n", 33 "from sklearn.datasets import fetch_openml\n", 34 "from tensorflow.keras import Sequential\n", 35 "\n", 36 "mnist = fetch_openml(\"mnist_784\", as_frame=False)" 37 ] 38 }, 39 { 40 "cell_type": "code", 41 "execution_count": 2, 42 "metadata": {}, 43 "outputs": [ 44 { 45 "data": { 46 "text/plain": [ 47 "(70000, 784)" 48 ] 49 }, 50 "execution_count": 2, 51 "metadata": {}, 52 "output_type": "execute_result" 53 } 54 ], 55 "source": [ 56 "y = mnist.target\n", 57 "X = mnist.data\n", 58 "X.shape" 59 ] 60 }, 61 { 62 "cell_type": "code", 63 "execution_count": 3, 64 "metadata": {}, 65 "outputs": [], 66 "source": [ 67 "from sklearn.decomposition import PCA\n", 68 
"from sklearn.model_selection import train_test_split\n", 69 "\n", 70 "pca = PCA(n_components=.95)\n", 71 "X = pca.fit_transform(X)\n", 72 "\n", 73 "count = 0\n", 74 "for i in y:\n", 75 " y[count] = float(i)\n", 76 " count += 1\n", 77 "\n", 78 "X_train , X_test, y_train, y_test = train_test_split(X,y)" 79 ] 80 }, 81 { 82 "cell_type": "code", 83 "execution_count": 4, 84 "metadata": {}, 85 "outputs": [ 86 { 87 "name": "stderr", 88 "output_type": "stream", 89 "text": [ 90 "2024-06-14 11:57:17.182720: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", 91 "2024-06-14 11:57:17.184240: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. 
Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", 92 "Skipping registering GPU devices...\n" 93 ] 94 } 95 ], 96 "source": [ 97 "model = Sequential()\n", 98 "input_1 = tf.keras.layers.Input([len(X[0])])\n", 99 "hidden_1 = tf.keras.layers.Dense(100 , 'relu')\n", 100 "hidden_2 = tf.keras.layers.Dense(100 , 'relu')\n", 101 "hidden_3 = tf.keras.layers.Dense(100 , 'relu')\n", 102 "output_1 = tf.keras.layers.Dense(1)\n", 103 "\n", 104 "model.add(input_1)\n", 105 "model.add(hidden_1)\n", 106 "model.add(hidden_2)\n", 107 "model.add(hidden_3)\n", 108 "model.add(output_1)\n" 109 ] 110 }, 111 { 112 "cell_type": "code", 113 "execution_count": 5, 114 "metadata": {}, 115 "outputs": [ 116 { 117 "data": { 118 "text/html": [ 119 "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n", 120 "</pre>\n" 121 ], 122 "text/plain": [ 123 "\u001b[1mModel: \"sequential\"\u001b[0m\n" 124 ] 125 }, 126 "metadata": {}, 127 "output_type": "display_data" 128 }, 129 { 130 "data": { 131 "text/html": [ 132 "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", 133 "┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n", 134 "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", 135 "│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>) │ <span style=\"color: #00af00; text-decoration-color: 
#00af00\">15,500</span> │\n", 136 "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", 137 "│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">10,100</span> │\n", 138 "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", 139 "│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">10,100</span> │\n", 140 "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", 141 "│ dense_3 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">101</span> │\n", 142 "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", 143 "</pre>\n" 144 ], 145 "text/plain": [ 146 "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", 147 "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", 148 "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", 149 "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m15,500\u001b[0m │\n", 150 
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n", 151 "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m10,100\u001b[0m │\n", 152 "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", 153 "│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m10,100\u001b[0m │\n", 154 "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", 155 "│ dense_3 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n", 156 "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" 157 ] 158 }, 159 "metadata": {}, 160 "output_type": "display_data" 161 }, 162 { 163 "data": { 164 "text/html": [ 165 "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">35,801</span> (139.85 KB)\n", 166 "</pre>\n" 167 ], 168 "text/plain": [ 169 "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m35,801\u001b[0m (139.85 KB)\n" 170 ] 171 }, 172 "metadata": {}, 173 "output_type": "display_data" 174 }, 175 { 176 "data": { 177 "text/html": [ 178 "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">35,801</span> (139.85 KB)\n", 179 "</pre>\n" 180 ], 181 "text/plain": [ 182 "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m35,801\u001b[0m (139.85 KB)\n" 183 ] 184 }, 185 "metadata": {}, 186 "output_type": "display_data" 187 }, 188 { 189 "data": { 190 "text/html": [ 191 "<pre 
style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n", 192 "</pre>\n" 193 ], 194 "text/plain": [ 195 "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" 196 ] 197 }, 198 "metadata": {}, 199 "output_type": "display_data" 200 } 201 ], 202 "source": [ 203 "model.summary()" 204 ] 205 }, 206 { 207 "cell_type": "code", 208 "execution_count": 6, 209 "metadata": {}, 210 "outputs": [], 211 "source": [ 212 "optimize = tf.keras.optimizers.Adam()\n", 213 "model.compile(optimizer=optimize,loss='mse', metrics=['mae'])" 214 ] 215 }, 216 { 217 "cell_type": "code", 218 "execution_count": 7, 219 "metadata": {}, 220 "outputs": [ 221 { 222 "name": "stdout", 223 "output_type": "stream", 224 "text": [ 225 "Epoch 1/30\n", 226 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 1ms/step - loss: 202.9985 - mae: 7.8408 - val_loss: 4.2281 - val_mae: 1.6003\n", 227 "Epoch 2/30\n", 228 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 4.2868 - mae: 1.6209 - val_loss: 3.1366 - val_mae: 1.3650\n", 229 "Epoch 3/30\n", 230 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 3.0048 - mae: 1.3345 - val_loss: 2.1976 - val_mae: 1.1122\n", 231 "Epoch 4/30\n", 232 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 2.1743 - mae: 1.1217 - val_loss: 1.8770 - val_mae: 1.0113\n", 233 "Epoch 5/30\n", 234 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 1.6681 - mae: 0.9647 - val_loss: 1.4420 - val_mae: 
0.8738\n", 235 "Epoch 6/30\n", 236 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 1.3077 - mae: 0.8395 - val_loss: 1.5200 - val_mae: 0.9104\n", 237 "Epoch 7/30\n", 238 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.9894 - mae: 0.7081 - val_loss: 1.4996 - val_mae: 0.9626\n", 239 "Epoch 8/30\n", 240 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 1ms/step - loss: 0.8075 - mae: 0.6240 - val_loss: 1.0141 - val_mae: 0.6623\n", 241 "Epoch 9/30\n", 242 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.6778 - mae: 0.5564 - val_loss: 0.8471 - val_mae: 0.5888\n", 243 "Epoch 10/30\n", 244 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.5963 - mae: 0.5153 - val_loss: 0.7413 - val_mae: 0.5223\n", 245 "Epoch 11/30\n", 246 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.5044 - mae: 0.4551 - val_loss: 0.6787 - val_mae: 0.4632\n", 247 "Epoch 12/30\n", 248 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.4623 - mae: 0.4353 - val_loss: 0.6637 - val_mae: 0.4663\n", 249 "Epoch 13/30\n", 250 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.3952 - mae: 0.3990 - val_loss: 0.6840 - val_mae: 0.4617\n", 251 "Epoch 14/30\n", 252 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.3724 - mae: 0.3854 - val_loss: 0.6865 - val_mae: 0.4951\n", 253 "Epoch 15/30\n", 254 "\u001b[1m1641/1641\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 2ms/step - loss: 0.3377 - mae: 0.3649 - val_loss: 0.5628 - val_mae: 0.3864\n", 255 "Epoch 16/30\n", 256 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 2ms/step - loss: 0.2954 - mae: 0.3337 - val_loss: 0.6157 - val_mae: 0.4369\n", 257 "Epoch 17/30\n", 258 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 2ms/step - loss: 0.2833 - mae: 0.3199 - val_loss: 0.5496 - val_mae: 0.3746\n", 259 "Epoch 18/30\n", 260 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2465 - mae: 0.3095 - val_loss: 0.6086 - val_mae: 0.4012\n", 261 "Epoch 19/30\n", 262 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2436 - mae: 0.3003 - val_loss: 0.5564 - val_mae: 0.3431\n", 263 "Epoch 20/30\n", 264 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2413 - mae: 0.2940 - val_loss: 0.5232 - val_mae: 0.3490\n", 265 "Epoch 21/30\n", 266 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 1ms/step - loss: 0.2087 - mae: 0.2784 - val_loss: 0.5680 - val_mae: 0.3854\n", 267 "Epoch 22/30\n", 268 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2286 - mae: 0.2898 - val_loss: 0.5172 - val_mae: 0.3380\n", 269 "Epoch 23/30\n", 270 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.2041 - mae: 0.2718 - val_loss: 0.4947 - val_mae: 0.3167\n", 271 "Epoch 24/30\n", 272 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 
1ms/step - loss: 0.1863 - mae: 0.2601 - val_loss: 0.5392 - val_mae: 0.3530\n", 273 "Epoch 25/30\n", 274 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1917 - mae: 0.2661 - val_loss: 0.5393 - val_mae: 0.3282\n", 275 "Epoch 26/30\n", 276 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1691 - mae: 0.2516 - val_loss: 0.4788 - val_mae: 0.3174\n", 277 "Epoch 27/30\n", 278 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1597 - mae: 0.2428 - val_loss: 0.5443 - val_mae: 0.3862\n", 279 "Epoch 28/30\n", 280 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1663 - mae: 0.2417 - val_loss: 0.4760 - val_mae: 0.3027\n", 281 "Epoch 29/30\n", 282 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1673 - mae: 0.2456 - val_loss: 0.5956 - val_mae: 0.3623\n", 283 "Epoch 30/30\n", 284 "\u001b[1m1641/1641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 1ms/step - loss: 0.1619 - mae: 0.2360 - val_loss: 0.4810 - val_mae: 0.3082\n" 285 ] 286 }, 287 { 288 "data": { 289 "text/plain": [ 290 "<keras.src.callbacks.history.History at 0x7fb8eca64ad0>" 291 ] 292 }, 293 "execution_count": 7, 294 "metadata": {}, 295 "output_type": "execute_result" 296 } 297 ], 298 "source": [ 299 "import numpy as np\n", 300 "X_train = np.asarray(X_train).astype('float32')\n", 301 "y_train = np.asarray(y_train).astype('float32')\n", 302 "X_test = np.asarray(X_test).astype('float32')\n", 303 "y_test = np.asarray(y_test).astype('float32')\n", 304 "\n", 305 "\n", 306 "model.fit(epochs=30, x=X_train, y=y_train, validation_data=(X_test, y_test))" 307 ] 308 }, 309 { 310 "cell_type": "code", 311 
"execution_count": 8, 312 "metadata": {}, 313 "outputs": [ 314 { 315 "name": "stdout", 316 "output_type": "stream", 317 "text": [ 318 "\u001b[1m547/547\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 627us/step\n" 319 ] 320 } 321 ], 322 "source": [ 323 "y_pred = model.predict(X_test)" 324 ] 325 }, 326 { 327 "cell_type": "code", 328 "execution_count": 9, 329 "metadata": {}, 330 "outputs": [], 331 "source": [ 332 "count = 0\n", 333 "for i in y_pred:\n", 334 " y_pred[count] = i.round()\n", 335 " count += 1" 336 ] 337 }, 338 { 339 "cell_type": "code", 340 "execution_count": 10, 341 "metadata": {}, 342 "outputs": [], 343 "source": [ 344 "correct = 0\n", 345 "wrong = 0\n", 346 "\n", 347 "count = 0\n", 348 "for i in y_pred:\n", 349 " if i != y_test[count]:\n", 350 " wrong += 1\n", 351 " else:\n", 352 " correct += 1\n", 353 " count += 1\n" 354 ] 355 }, 356 { 357 "cell_type": "code", 358 "execution_count": 11, 359 "metadata": {}, 360 "outputs": [ 361 { 362 "name": "stdout", 363 "output_type": "stream", 364 "text": [ 365 "15254\n", 366 "2246\n", 367 "0.8716571428571429\n" 368 ] 369 } 370 ], 371 "source": [ 372 "print(correct)\n", 373 "print(wrong)\n", 374 "print(correct / (correct + wrong))" 375 ] 376 }, 377 { 378 "cell_type": "code", 379 "execution_count": 13, 380 "metadata": {}, 381 "outputs": [], 382 "source": [ 383 "model.save('../models/MNISTClassificationModel.keras')" 384 ] 385 }, 386 { 387 "cell_type": "code", 388 "execution_count": 15, 389 "metadata": {}, 390 "outputs": [], 391 "source": [ 392 "loadedModel = tf.keras.models.load_model('../models/MNISTClassificationModel.keras')" 393 ] 394 } 395 ], 396 "metadata": { 397 "kernelspec": { 398 "display_name": "myvenv", 399 "language": "python", 400 "name": "python3" 401 }, 402 "language_info": { 403 "codemirror_mode": { 404 "name": "ipython", 405 "version": 3 406 }, 407 "file_extension": ".py", 408 "mimetype": "text/x-python", 409 "name": "python", 410 "nbconvert_exporter": 
"python", 411 "pygments_lexer": "ipython3", 412 "version": "3.11.2" 413 } 414 }, 415 "nbformat": 4, 416 "nbformat_minor": 2 417 }