commit 24565f8473080c94b247a7360bc30b0c9a4a9fea
parent 5cc446afedd3319a2c71cff8c797aba4bc3368ba
Author: Andrew <andrewlaack1@gmail.com>
Date: Thu, 27 Jun 2024 10:42:03 -0500
Predict cheating for students based on scores.
Diffstat:
1 file changed, 1253 insertions(+), 0 deletions(-)
diff --git a/middleSchoolExamCheating/MiddleSchoolExamCheating.ipynb b/middleSchoolExamCheating/MiddleSchoolExamCheating.ipynb
@@ -0,0 +1,1253 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A neural network generally performs better on this task, but there is so little data (127 students who did not cheat and 17 who did) that reliable metrics are hard to obtain."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 680,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "# Load the exam dataset and discard every row that contains a missing value.\n",
+ "df = pd.read_csv('../datasets/middleSchoolExam/MiddleSchoolExam.csv')\n",
+ "df = df.dropna()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 681,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def cheated(val):\n",
+ "    \"\"\"Return True when a raw label value equals 1.\"\"\"\n",
+ "    return val == 1\n",
+ "\n",
+ "# The label becomes a boolean Series; the features are every remaining column.\n",
+ "y = df['cheated'].apply(cheated)\n",
+ "X = df.drop(columns=['cheated'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 682,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "cheated\n",
+ "False 127\n",
+ "True 17\n",
+ "Name: count, dtype: int64"
+ ]
+ },
+ "execution_count": 682,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y.value_counts()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Thoughts:\n",
+ "\n",
+ "Convert DOB to age\n",
+ "\n",
+ "Convert Sex To Binary\n",
+ "\n",
+ "Ratio Between Exam Scores\n",
+ "\n",
+ "Ratio Between Participation and Exam Scores\n",
+ "\n",
+ "Ratio Between Test and Exam Scores\n",
+ "\n",
+ "Standardize Columns\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 683,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# DOB To Age\n",
+ "\n",
+ "from datetime import date\n",
+ "\n",
+ "# Fixed reference date: ages measured against date.today() change whenever the\n",
+ "# notebook is re-run, which silently alters this feature. 2024-06-27 is the date\n",
+ "# this notebook was committed.\n",
+ "AGE_REFERENCE_DATE = date(2024, 6, 27)\n",
+ "\n",
+ "def calculate_age(born, today=None):\n",
+ "    \"\"\"Return the age in completed years for an ISO-format birth date string.\n",
+ "\n",
+ "    born:  date of birth as an ISO 8601 string, e.g. 'YYYY-MM-DD'.\n",
+ "    today: date at which the age is measured; defaults to the current date.\n",
+ "    \"\"\"\n",
+ "    born = date.fromisoformat(born)\n",
+ "    if today is None:\n",
+ "        today = date.today()\n",
+ "    # The tuple comparison is True (counts as 1) when the birthday has not yet\n",
+ "    # occurred in the reference year, so one year is subtracted.\n",
+ "    return today.year - born.year - ((today.month, today.day) < (born.month, born.day))\n",
+ "\n",
+ "X['age'] = X['date_of_birth'].apply(calculate_age, today=AGE_REFERENCE_DATE)\n",
+ "X = X.drop(columns=['date_of_birth'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 684,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "sex\n",
+ "False 72\n",
+ "True 72\n",
+ "Name: count, dtype: int64"
+ ]
+ },
+ "execution_count": 684,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Sex to binary: True where the value is 'M', False for every other value\n",
+ "is_male = X['sex'].eq('M')\n",
+ "X['sex'] = is_male\n",
+ "X['sex'].value_counts()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 685,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "cheated\n",
+ "False 127\n",
+ "True 17\n",
+ "Name: count, dtype: int64"
+ ]
+ },
+ "execution_count": 685,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y.value_counts()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 686,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>participation_1</th>\n",
+ " <th>test_1</th>\n",
+ " <th>final_exam_1</th>\n",
+ " <th>participation_2</th>\n",
+ " <th>test_2</th>\n",
+ " <th>final_exam_2</th>\n",
+ " <th>class</th>\n",
+ " <th>year</th>\n",
+ " <th>age</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>count</th>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>mean</th>\n",
+ " <td>12.590278</td>\n",
+ " <td>11.097222</td>\n",
+ " <td>10.763889</td>\n",
+ " <td>10.569444</td>\n",
+ " <td>10.597222</td>\n",
+ " <td>9.958333</td>\n",
+ " <td>2.534722</td>\n",
+ " <td>2.027778</td>\n",
+ " <td>12.819444</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>std</th>\n",
+ " <td>3.472999</td>\n",
+ " <td>4.280905</td>\n",
+ " <td>5.139979</td>\n",
+ " <td>4.386331</td>\n",
+ " <td>4.398537</td>\n",
+ " <td>4.962270</td>\n",
+ " <td>1.115142</td>\n",
+ " <td>1.003103</td>\n",
+ " <td>1.417438</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>min</th>\n",
+ " <td>5.000000</td>\n",
+ " <td>4.000000</td>\n",
+ " <td>1.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>4.000000</td>\n",
+ " <td>2.000000</td>\n",
+ " <td>1.000000</td>\n",
+ " <td>1.000000</td>\n",
+ " <td>10.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>25%</th>\n",
+ " <td>10.000000</td>\n",
+ " <td>8.000000</td>\n",
+ " <td>6.000000</td>\n",
+ " <td>6.000000</td>\n",
+ " <td>6.000000</td>\n",
+ " <td>5.375000</td>\n",
+ " <td>2.000000</td>\n",
+ " <td>1.000000</td>\n",
+ " <td>12.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>50%</th>\n",
+ " <td>11.500000</td>\n",
+ " <td>10.000000</td>\n",
+ " <td>11.000000</td>\n",
+ " <td>11.000000</td>\n",
+ " <td>10.500000</td>\n",
+ " <td>10.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>13.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>75%</th>\n",
+ " <td>15.000000</td>\n",
+ " <td>15.000000</td>\n",
+ " <td>15.000000</td>\n",
+ " <td>14.000000</td>\n",
+ " <td>14.000000</td>\n",
+ " <td>14.500000</td>\n",
+ " <td>4.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>14.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>max</th>\n",
+ " <td>20.000000</td>\n",
+ " <td>20.000000</td>\n",
+ " <td>19.500000</td>\n",
+ " <td>20.000000</td>\n",
+ " <td>20.000000</td>\n",
+ " <td>19.000000</td>\n",
+ " <td>4.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>18.000000</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " participation_1 test_1 final_exam_1 participation_2 test_2 \\\n",
+ "count 144.000000 144.000000 144.000000 144.000000 144.000000 \n",
+ "mean 12.590278 11.097222 10.763889 10.569444 10.597222 \n",
+ "std 3.472999 4.280905 5.139979 4.386331 4.398537 \n",
+ "min 5.000000 4.000000 1.000000 3.000000 4.000000 \n",
+ "25% 10.000000 8.000000 6.000000 6.000000 6.000000 \n",
+ "50% 11.500000 10.000000 11.000000 11.000000 10.500000 \n",
+ "75% 15.000000 15.000000 15.000000 14.000000 14.000000 \n",
+ "max 20.000000 20.000000 19.500000 20.000000 20.000000 \n",
+ "\n",
+ " final_exam_2 class year age \n",
+ "count 144.000000 144.000000 144.000000 144.000000 \n",
+ "mean 9.958333 2.534722 2.027778 12.819444 \n",
+ "std 4.962270 1.115142 1.003103 1.417438 \n",
+ "min 2.000000 1.000000 1.000000 10.000000 \n",
+ "25% 5.375000 2.000000 1.000000 12.000000 \n",
+ "50% 10.000000 3.000000 3.000000 13.000000 \n",
+ "75% 14.500000 4.000000 3.000000 14.000000 \n",
+ "max 19.000000 4.000000 3.000000 18.000000 "
+ ]
+ },
+ "execution_count": 686,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X.describe()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 687,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Ratio features: each score is divided by the final exam score of the same term,\n",
+ "# plus the second final exam relative to the first.\n",
+ "exam_1 = X['final_exam_1']\n",
+ "exam_2 = X['final_exam_2']\n",
+ "\n",
+ "X['Exam/Exam Ratio'] = exam_2 / exam_1\n",
+ "X['Participation/Exam Ratio 1'] = X['participation_1'] / exam_1\n",
+ "X['Participation/Exam Ratio 2'] = X['participation_2'] / exam_2\n",
+ "X['Test Exam Ratio 1'] = X['test_1'] / exam_1\n",
+ "X['Test Exam Ratio 2'] = X['test_2'] / exam_2\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 688,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>participation_1</th>\n",
+ " <th>test_1</th>\n",
+ " <th>final_exam_1</th>\n",
+ " <th>participation_2</th>\n",
+ " <th>test_2</th>\n",
+ " <th>final_exam_2</th>\n",
+ " <th>class</th>\n",
+ " <th>year</th>\n",
+ " <th>age</th>\n",
+ " <th>Exam/Exam Ratio</th>\n",
+ " <th>Participation/Exam Ratio 1</th>\n",
+ " <th>Participation/Exam Ratio 2</th>\n",
+ " <th>Test Exam Ratio 1</th>\n",
+ " <th>Test Exam Ratio 2</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>count</th>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " <td>144.000000</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>mean</th>\n",
+ " <td>12.590278</td>\n",
+ " <td>11.097222</td>\n",
+ " <td>10.763889</td>\n",
+ " <td>10.569444</td>\n",
+ " <td>10.597222</td>\n",
+ " <td>9.958333</td>\n",
+ " <td>2.534722</td>\n",
+ " <td>2.027778</td>\n",
+ " <td>12.819444</td>\n",
+ " <td>1.153825</td>\n",
+ " <td>1.492997</td>\n",
+ " <td>1.243145</td>\n",
+ " <td>1.204897</td>\n",
+ " <td>1.251639</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>std</th>\n",
+ " <td>3.472999</td>\n",
+ " <td>4.280905</td>\n",
+ " <td>5.139979</td>\n",
+ " <td>4.386331</td>\n",
+ " <td>4.398537</td>\n",
+ " <td>4.962270</td>\n",
+ " <td>1.115142</td>\n",
+ " <td>1.003103</td>\n",
+ " <td>1.417438</td>\n",
+ " <td>0.955162</td>\n",
+ " <td>1.011789</td>\n",
+ " <td>0.639878</td>\n",
+ " <td>0.609706</td>\n",
+ " <td>0.646539</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>min</th>\n",
+ " <td>5.000000</td>\n",
+ " <td>4.000000</td>\n",
+ " <td>1.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>4.000000</td>\n",
+ " <td>2.000000</td>\n",
+ " <td>1.000000</td>\n",
+ " <td>1.000000</td>\n",
+ " <td>10.000000</td>\n",
+ " <td>0.166667</td>\n",
+ " <td>0.611111</td>\n",
+ " <td>0.444444</td>\n",
+ " <td>0.540541</td>\n",
+ " <td>0.444444</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>25%</th>\n",
+ " <td>10.000000</td>\n",
+ " <td>8.000000</td>\n",
+ " <td>6.000000</td>\n",
+ " <td>6.000000</td>\n",
+ " <td>6.000000</td>\n",
+ " <td>5.375000</td>\n",
+ " <td>2.000000</td>\n",
+ " <td>1.000000</td>\n",
+ " <td>12.000000</td>\n",
+ " <td>0.631250</td>\n",
+ " <td>0.965368</td>\n",
+ " <td>0.851190</td>\n",
+ " <td>0.909091</td>\n",
+ " <td>0.851190</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>50%</th>\n",
+ " <td>11.500000</td>\n",
+ " <td>10.000000</td>\n",
+ " <td>11.000000</td>\n",
+ " <td>11.000000</td>\n",
+ " <td>10.500000</td>\n",
+ " <td>10.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>13.000000</td>\n",
+ " <td>0.928571</td>\n",
+ " <td>1.240385</td>\n",
+ " <td>1.090909</td>\n",
+ " <td>1.057566</td>\n",
+ " <td>1.090909</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>75%</th>\n",
+ " <td>15.000000</td>\n",
+ " <td>15.000000</td>\n",
+ " <td>15.000000</td>\n",
+ " <td>14.000000</td>\n",
+ " <td>14.000000</td>\n",
+ " <td>14.500000</td>\n",
+ " <td>4.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>14.000000</td>\n",
+ " <td>1.135417</td>\n",
+ " <td>1.675000</td>\n",
+ " <td>1.387424</td>\n",
+ " <td>1.333333</td>\n",
+ " <td>1.387424</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>max</th>\n",
+ " <td>20.000000</td>\n",
+ " <td>20.000000</td>\n",
+ " <td>19.500000</td>\n",
+ " <td>20.000000</td>\n",
+ " <td>20.000000</td>\n",
+ " <td>19.000000</td>\n",
+ " <td>4.000000</td>\n",
+ " <td>3.000000</td>\n",
+ " <td>18.000000</td>\n",
+ " <td>7.000000</td>\n",
+ " <td>10.000000</td>\n",
+ " <td>5.000000</td>\n",
+ " <td>5.000000</td>\n",
+ " <td>5.000000</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " participation_1 test_1 final_exam_1 participation_2 test_2 \\\n",
+ "count 144.000000 144.000000 144.000000 144.000000 144.000000 \n",
+ "mean 12.590278 11.097222 10.763889 10.569444 10.597222 \n",
+ "std 3.472999 4.280905 5.139979 4.386331 4.398537 \n",
+ "min 5.000000 4.000000 1.000000 3.000000 4.000000 \n",
+ "25% 10.000000 8.000000 6.000000 6.000000 6.000000 \n",
+ "50% 11.500000 10.000000 11.000000 11.000000 10.500000 \n",
+ "75% 15.000000 15.000000 15.000000 14.000000 14.000000 \n",
+ "max 20.000000 20.000000 19.500000 20.000000 20.000000 \n",
+ "\n",
+ " final_exam_2 class year age Exam/Exam Ratio \\\n",
+ "count 144.000000 144.000000 144.000000 144.000000 144.000000 \n",
+ "mean 9.958333 2.534722 2.027778 12.819444 1.153825 \n",
+ "std 4.962270 1.115142 1.003103 1.417438 0.955162 \n",
+ "min 2.000000 1.000000 1.000000 10.000000 0.166667 \n",
+ "25% 5.375000 2.000000 1.000000 12.000000 0.631250 \n",
+ "50% 10.000000 3.000000 3.000000 13.000000 0.928571 \n",
+ "75% 14.500000 4.000000 3.000000 14.000000 1.135417 \n",
+ "max 19.000000 4.000000 3.000000 18.000000 7.000000 \n",
+ "\n",
+ " Participation/Exam Ratio 1 Participation/Exam Ratio 2 \\\n",
+ "count 144.000000 144.000000 \n",
+ "mean 1.492997 1.243145 \n",
+ "std 1.011789 0.639878 \n",
+ "min 0.611111 0.444444 \n",
+ "25% 0.965368 0.851190 \n",
+ "50% 1.240385 1.090909 \n",
+ "75% 1.675000 1.387424 \n",
+ "max 10.000000 5.000000 \n",
+ "\n",
+ " Test Exam Ratio 1 Test Exam Ratio 2 \n",
+ "count 144.000000 144.000000 \n",
+ "mean 1.204897 1.251639 \n",
+ "std 0.609706 0.646539 \n",
+ "min 0.540541 0.444444 \n",
+ "25% 0.909091 0.851190 \n",
+ "50% 1.057566 1.090909 \n",
+ "75% 1.333333 1.387424 \n",
+ "max 5.000000 5.000000 "
+ ]
+ },
+ "execution_count": 688,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X.describe()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 689,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "\n",
+ "# Fit the scaler on the training rows only. Fitting on every row would leak the\n",
+ "# mean/variance of the validation and test rows into the training features.\n",
+ "# The arguments here must mirror the first train_test_split call in the split\n",
+ "# cell (random_state=10, default sizes) so that the same rows are selected.\n",
+ "train_rows, _ = train_test_split(X, random_state=10)\n",
+ "\n",
+ "scl = StandardScaler()\n",
+ "scl.fit(train_rows)\n",
+ "\n",
+ "# Every row is transformed with the training statistics; X becomes a NumPy array.\n",
+ "X = scl.transform(X)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Preprocessing should now be complete."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 690,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "# First split off the training set (default train/test sizes), then halve the\n",
+ "# held-out remainder into a validation set and a test set.\n",
+ "X_train, X_rest, y_train, y_rest = train_test_split(X, y, random_state=10)\n",
+ "X_val, X_test, y_val, y_test = train_test_split(\n",
+ "    X_rest, y_rest, train_size=0.5, random_state=10\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now to train models.\n",
+ "\n",
+ "Random Forest:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 691,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<style>#sk-container-id-49 {\n",
+ " /* Definition of color scheme common for light and dark mode */\n",
+ " --sklearn-color-text: black;\n",
+ " --sklearn-color-line: gray;\n",
+ " /* Definition of color scheme for unfitted estimators */\n",
+ " --sklearn-color-unfitted-level-0: #fff5e6;\n",
+ " --sklearn-color-unfitted-level-1: #f6e4d2;\n",
+ " --sklearn-color-unfitted-level-2: #ffe0b3;\n",
+ " --sklearn-color-unfitted-level-3: chocolate;\n",
+ " /* Definition of color scheme for fitted estimators */\n",
+ " --sklearn-color-fitted-level-0: #f0f8ff;\n",
+ " --sklearn-color-fitted-level-1: #d4ebff;\n",
+ " --sklearn-color-fitted-level-2: #b3dbfd;\n",
+ " --sklearn-color-fitted-level-3: cornflowerblue;\n",
+ "\n",
+ " /* Specific color for light theme */\n",
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
+ " --sklearn-color-icon: #696969;\n",
+ "\n",
+ " @media (prefers-color-scheme: dark) {\n",
+ " /* Redefinition of color scheme for dark theme */\n",
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
+ " --sklearn-color-icon: #878787;\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 {\n",
+ " color: var(--sklearn-color-text);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 pre {\n",
+ " padding: 0;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 input.sk-hidden--visually {\n",
+ " border: 0;\n",
+ " clip: rect(1px 1px 1px 1px);\n",
+ " clip: rect(1px, 1px, 1px, 1px);\n",
+ " height: 1px;\n",
+ " margin: -1px;\n",
+ " overflow: hidden;\n",
+ " padding: 0;\n",
+ " position: absolute;\n",
+ " width: 1px;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-dashed-wrapped {\n",
+ " border: 1px dashed var(--sklearn-color-line);\n",
+ " margin: 0 0.4em 0.5em 0.4em;\n",
+ " box-sizing: border-box;\n",
+ " padding-bottom: 0.4em;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-container {\n",
+ " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
+ " but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
+ " so we also need the `!important` here to be able to override the\n",
+ " default hidden behavior on the sphinx rendered scikit-learn.org.\n",
+ " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
+ " display: inline-block !important;\n",
+ " position: relative;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-text-repr-fallback {\n",
+ " display: none;\n",
+ "}\n",
+ "\n",
+ "div.sk-parallel-item,\n",
+ "div.sk-serial,\n",
+ "div.sk-item {\n",
+ " /* draw centered vertical line to link estimators */\n",
+ " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
+ " background-size: 2px 100%;\n",
+ " background-repeat: no-repeat;\n",
+ " background-position: center center;\n",
+ "}\n",
+ "\n",
+ "/* Parallel-specific style estimator block */\n",
+ "\n",
+ "#sk-container-id-49 div.sk-parallel-item::after {\n",
+ " content: \"\";\n",
+ " width: 100%;\n",
+ " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
+ " flex-grow: 1;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-parallel {\n",
+ " display: flex;\n",
+ " align-items: stretch;\n",
+ " justify-content: center;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " position: relative;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-parallel-item {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-parallel-item:first-child::after {\n",
+ " align-self: flex-end;\n",
+ " width: 50%;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-parallel-item:last-child::after {\n",
+ " align-self: flex-start;\n",
+ " width: 50%;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-parallel-item:only-child::after {\n",
+ " width: 0;\n",
+ "}\n",
+ "\n",
+ "/* Serial-specific style estimator block */\n",
+ "\n",
+ "#sk-container-id-49 div.sk-serial {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ " align-items: center;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " padding-right: 1em;\n",
+ " padding-left: 1em;\n",
+ "}\n",
+ "\n",
+ "\n",
+ "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
+ "clickable and can be expanded/collapsed.\n",
+ "- Pipeline and ColumnTransformer use this feature and define the default style\n",
+ "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
+ "*/\n",
+ "\n",
+ "/* Pipeline and ColumnTransformer style (default) */\n",
+ "\n",
+ "#sk-container-id-49 div.sk-toggleable {\n",
+ " /* Default theme specific background. It is overwritten whether we have a\n",
+ " specific estimator or a Pipeline/ColumnTransformer */\n",
+ " background-color: var(--sklearn-color-background);\n",
+ "}\n",
+ "\n",
+ "/* Toggleable label */\n",
+ "#sk-container-id-49 label.sk-toggleable__label {\n",
+ " cursor: pointer;\n",
+ " display: block;\n",
+ " width: 100%;\n",
+ " margin-bottom: 0;\n",
+ " padding: 0.5em;\n",
+ " box-sizing: border-box;\n",
+ " text-align: center;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 label.sk-toggleable__label-arrow:before {\n",
+ " /* Arrow on the left of the label */\n",
+ " content: \"▸\";\n",
+ " float: left;\n",
+ " margin-right: 0.25em;\n",
+ " color: var(--sklearn-color-icon);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 label.sk-toggleable__label-arrow:hover:before {\n",
+ " color: var(--sklearn-color-text);\n",
+ "}\n",
+ "\n",
+ "/* Toggleable content - dropdown */\n",
+ "\n",
+ "#sk-container-id-49 div.sk-toggleable__content {\n",
+ " max-height: 0;\n",
+ " max-width: 0;\n",
+ " overflow: hidden;\n",
+ " text-align: left;\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-toggleable__content.fitted {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-toggleable__content pre {\n",
+ " margin: 0.2em;\n",
+ " border-radius: 0.25em;\n",
+ " color: var(--sklearn-color-text);\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-toggleable__content.fitted pre {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
+ " /* Expand drop-down */\n",
+ " max-height: 200px;\n",
+ " max-width: 100%;\n",
+ " overflow: auto;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
+ " content: \"▾\";\n",
+ "}\n",
+ "\n",
+ "/* Pipeline/ColumnTransformer-specific style */\n",
+ "\n",
+ "#sk-container-id-49 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Estimator-specific style */\n",
+ "\n",
+ "/* Colorize estimator box */\n",
+ "#sk-container-id-49 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-label label.sk-toggleable__label,\n",
+ "#sk-container-id-49 div.sk-label label {\n",
+ " /* The background is the default theme color */\n",
+ " color: var(--sklearn-color-text-on-default-background);\n",
+ "}\n",
+ "\n",
+ "/* On hover, darken the color of the background */\n",
+ "#sk-container-id-49 div.sk-label:hover label.sk-toggleable__label {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Label box, darken color on hover, fitted */\n",
+ "#sk-container-id-49 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Estimator label */\n",
+ "\n",
+ "#sk-container-id-49 div.sk-label label {\n",
+ " font-family: monospace;\n",
+ " font-weight: bold;\n",
+ " display: inline-block;\n",
+ " line-height: 1.2em;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-label-container {\n",
+ " text-align: center;\n",
+ "}\n",
+ "\n",
+ "/* Estimator-specific */\n",
+ "#sk-container-id-49 div.sk-estimator {\n",
+ " font-family: monospace;\n",
+ " border: 1px dotted var(--sklearn-color-border-box);\n",
+ " border-radius: 0.25em;\n",
+ " box-sizing: border-box;\n",
+ " margin-bottom: 0.5em;\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-estimator.fitted {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "/* on hover */\n",
+ "#sk-container-id-49 div.sk-estimator:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 div.sk-estimator.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
+ "\n",
+ "/* Common style for \"i\" and \"?\" */\n",
+ "\n",
+ ".sk-estimator-doc-link,\n",
+ "a:link.sk-estimator-doc-link,\n",
+ "a:visited.sk-estimator-doc-link {\n",
+ " float: right;\n",
+ " font-size: smaller;\n",
+ " line-height: 1em;\n",
+ " font-family: monospace;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " border-radius: 1em;\n",
+ " height: 1em;\n",
+ " width: 1em;\n",
+ " text-decoration: none !important;\n",
+ " margin-left: 1ex;\n",
+ " /* unfitted */\n",
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-unfitted-level-1);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link.fitted,\n",
+ "a:link.sk-estimator-doc-link.fitted,\n",
+ "a:visited.sk-estimator-doc-link.fitted {\n",
+ " /* fitted */\n",
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-fitted-level-1);\n",
+ "}\n",
+ "\n",
+ "/* On hover */\n",
+ "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
+ ".sk-estimator-doc-link:hover,\n",
+ "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
+ ".sk-estimator-doc-link:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
+ ".sk-estimator-doc-link.fitted:hover,\n",
+ "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
+ ".sk-estimator-doc-link.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "/* Span, style for the box shown on hovering the info icon */\n",
+ ".sk-estimator-doc-link span {\n",
+ " display: none;\n",
+ " z-index: 9999;\n",
+ " position: relative;\n",
+ " font-weight: normal;\n",
+ " right: .2ex;\n",
+ " padding: .5ex;\n",
+ " margin: .5ex;\n",
+ " width: min-content;\n",
+ " min-width: 20ex;\n",
+ " max-width: 50ex;\n",
+ " color: var(--sklearn-color-text);\n",
+ " box-shadow: 2pt 2pt 4pt #999;\n",
+ " /* unfitted */\n",
+ " background: var(--sklearn-color-unfitted-level-0);\n",
+ " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link.fitted span {\n",
+ " /* fitted */\n",
+ " background: var(--sklearn-color-fitted-level-0);\n",
+ " border: var(--sklearn-color-fitted-level-3);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link:hover span {\n",
+ " display: block;\n",
+ "}\n",
+ "\n",
+ "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
+ "\n",
+ "#sk-container-id-49 a.estimator_doc_link {\n",
+ " float: right;\n",
+ " font-size: 1rem;\n",
+ " line-height: 1em;\n",
+ " font-family: monospace;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " border-radius: 1rem;\n",
+ " height: 1rem;\n",
+ " width: 1rem;\n",
+ " text-decoration: none;\n",
+ " /* unfitted */\n",
+ " color: var(--sklearn-color-unfitted-level-1);\n",
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 a.estimator_doc_link.fitted {\n",
+ " /* fitted */\n",
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-fitted-level-1);\n",
+ "}\n",
+ "\n",
+ "/* On hover */\n",
+ "#sk-container-id-49 a.estimator_doc_link:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-49 a.estimator_doc_link.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
+ "}\n",
+ "</style><div id=\"sk-container-id-49\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier(max_depth=2, n_estimators=10)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-49\" type=\"checkbox\" checked><label for=\"sk-estimator-id-49\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> RandomForestClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.ensemble.RandomForestClassifier.html\">?<span>Documentation for RandomForestClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>RandomForestClassifier(max_depth=2, n_estimators=10)</pre></div> </div></div></div></div>"
+ ],
+ "text/plain": [
+ "RandomForestClassifier(max_depth=2, n_estimators=10)"
+ ]
+ },
+ "execution_count": 691,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "\n",
+ "# Seeded so the fitted forest, and every score derived from it, is identical on each run.\n",
+ "rnd_clf = RandomForestClassifier(max_depth=2, n_estimators=10, random_state=10)\n",
+ "rnd_clf.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 692,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1.0"
+ ]
+ },
+ "execution_count": 692,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from sklearn.metrics import accuracy_score\n",
+ "\n",
+ "# Validation accuracy of the random forest. Predicting 'not cheated' for everyone\n",
+ "# would already score about 0.88 on the full data (127 of 144), so read this with care.\n",
+ "y_pred = rnd_clf.predict(X_val)\n",
+ "accuracy_score(y_val, y_pred)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "I have found a problem: the dataset is far too small, so a perfect validation score on so few samples is not meaningful.\n",
+ "Regardless, let's continue."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 693,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import keras\n",
+ "import tensorflow as tf \n",
+ "\n",
+ "# Seed the random generators so weight initialisation and shuffling are\n",
+ "# repeatable between runs.\n",
+ "keras.utils.set_random_seed(10)\n",
+ "\n",
+ "# Binary classifier: two hidden ReLU layers and a sigmoid output.\n",
+ "# The input width is read from the training matrix instead of being hard-coded,\n",
+ "# so adding or removing a feature column does not break the model definition.\n",
+ "model = keras.Sequential(layers=[\n",
+ "    keras.layers.Input(shape=(X_train.shape[1],)),\n",
+ "    keras.layers.Dense(128, activation='relu'),\n",
+ "    keras.layers.Dense(128, activation='relu'),\n",
+ "    keras.layers.Dense(1, activation='sigmoid')\n",
+ "])\n",
+ "\n",
+ "model.compile(loss=keras.losses.binary_crossentropy, optimizer='adam', metrics=['accuracy'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 694,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/10\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 39ms/step - accuracy: 0.7644 - loss: 0.5846 - val_accuracy: 1.0000 - val_loss: 0.3691\n",
+ "Epoch 2/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8537 - loss: 0.4351 - val_accuracy: 1.0000 - val_loss: 0.2381\n",
+ "Epoch 3/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8464 - loss: 0.4046 - val_accuracy: 1.0000 - val_loss: 0.1746\n",
+ "Epoch 4/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8422 - loss: 0.3757 - val_accuracy: 1.0000 - val_loss: 0.1501\n",
+ "Epoch 5/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8527 - loss: 0.3416 - val_accuracy: 1.0000 - val_loss: 0.1443\n",
+ "Epoch 6/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8589 - loss: 0.3011 - val_accuracy: 1.0000 - val_loss: 0.1481\n",
+ "Epoch 7/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8433 - loss: 0.2969 - val_accuracy: 1.0000 - val_loss: 0.1561\n",
+ "Epoch 8/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.9003 - loss: 0.2562 - val_accuracy: 1.0000 - val_loss: 0.1634\n",
+ "Epoch 9/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8962 - loss: 0.2561 - val_accuracy: 1.0000 - val_loss: 0.1694\n",
+ "Epoch 10/10\n",
+ "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8853 - loss: 0.2665 - val_accuracy: 0.9444 - val_loss: 0.1725\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "<keras.src.callbacks.history.History at 0x7fdddb016fd0>"
+ ]
+ },
+ "execution_count": 694,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Train for 10 epochs, scoring on the validation split after each one.\n",
+ "# validation_data is passed as a tuple (x_val, y_val), the form Keras documents.\n",
+ "model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 695,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "141 False\n",
+ "122 False\n",
+ "138 False\n",
+ "84 False\n",
+ "48 False\n",
+ "80 False\n",
+ "136 False\n",
+ "110 False\n",
+ "137 False\n",
+ "71 False\n",
+ "67 False\n",
+ "20 False\n",
+ "56 False\n",
+ "117 False\n",
+ "106 False\n",
+ "100 False\n",
+ "103 False\n",
+ "59 False\n",
+ "Name: cheated, dtype: bool"
+ ]
+ },
+ "execution_count": 695,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_val"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compare the neural network and the random forest on the test set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 696,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 34ms/step\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "0.9444444444444444"
+ ]
+ },
+ "execution_count": 696,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# The network ends in a sigmoid unit, so threshold its output at 0.5 to get boolean labels.\n",
+ "y_pred = model.predict(X_test)\n",
+ "binary_predictions = (y_pred >= 0.5).astype(np.bool_)\n",
+ "\n",
+ "accuracy_score(y_test, binary_predictions)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 697,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.8888888888888888"
+ ]
+ },
+ "execution_count": 697,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Same accuracy metric for the random forest, this time on the test split.\n",
+ "y_pred = rnd_clf.predict(X_test)\n",
+ "accuracy_score(y_test, y_pred)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}