commit d40d44c47be6e86d8bf6b91cca6c4eeaa554495f
parent db12c8d974d73a1f955784efa73c85e59e058fc5
Author: Andrew <andrewlaack1@gmail.com>
Date: Mon, 3 Jun 2024 22:17:18 -0500
Ending the day
Diffstat:
1 file changed, 41 insertions(+), 0 deletions(-)
diff --git a/mnist/MNISTClassification.ipynb b/mnist/MNISTClassification.ipynb
@@ -565,6 +565,47 @@
"# This shows that pure false is quite good...\n",
"cross_val_score(dummy_clf, X_train, y_train_5, cv=3, scoring='accuracy')"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import cross_val_predict\n",
+ "\n",
+ "# Returns predictions from each fold instead of percentages (like cross_val_score would)\n",
+ "y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[53481, 1098],\n",
+ " [ 1546, 3875]])"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Create confusion matrix based on actual and predicted values\n",
+ "\n",
+ "# As we can see, in retrospect, there are 1098 images that \n",
+ "# were thought to be 5 that are not and 1546 5's that were \n",
+ "# thought to be not 5.\n",
+ "\n",
+ "from sklearn.metrics import confusion_matrix\n",
+ "cm = confusion_matrix(y_train_5, y_train_pred)\n",
+ "cm"
+ ]
}
],
"metadata": {