machinelearning

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

commit 24565f8473080c94b247a7360bc30b0c9a4a9fea
parent 5cc446afedd3319a2c71cff8c797aba4bc3368ba
Author: Andrew <andrewlaack1@gmail.com>
Date:   Thu, 27 Jun 2024 10:42:03 -0500

Predict cheating for students based on scores.

Diffstat:
A middleSchoolExamCheating/MiddleSchoolExamCheating.ipynb | 1253 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 1253 insertions(+), 0 deletions(-)

diff --git a/middleSchoolExamCheating/MiddleSchoolExamCheating.ipynb b/middleSchoolExamCheating/MiddleSchoolExamCheating.ipynb @@ -0,0 +1,1253 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A neural network (NN) is generally better at this, but there is so little data (127 not cheated, 17 cheated) that reliable metrics are hard to obtain." + ] + }, + { + "cell_type": "code", + "execution_count": 680, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "df = pd.read_csv('../datasets/middleSchoolExam/MiddleSchoolExam.csv')\n", + "df.head(1)\n", + "\n", + "df = df.dropna()" + ] + }, + { + "cell_type": "code", + "execution_count": 681, + "metadata": {}, + "outputs": [], + "source": [ + "def cheated(val):\n", + " return val == 1\n", + "\n", + "X = df.drop(columns=['cheated'], axis=1)\n", + "y = df['cheated']\n", + "y = y.apply(cheated)" + ] + }, + { + "cell_type": "code", + "execution_count": 682, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "cheated\n", + "False 127\n", + "True 17\n", + "Name: count, dtype: int64" + ] + }, + "execution_count": 682, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.value_counts()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Thoughts:\n", + "\n", + "Convert DOB to age\n", + "\n", + "Convert Sex To Binary\n", + "\n", + "Ratio Between Exam Scores\n", + "\n", + "Ratio Between Participation and Exam Scores\n", + "\n", + "Ratio Between Test and Exam Scores\n", + "\n", + "Standardize Columns\n" + ] + }, + { + "cell_type": "code", + "execution_count": 683, + "metadata": {}, + "outputs": [], + "source": [ + "# DOB To Age\n", + "\n", + "from datetime import date\n", + "\n", + "def calculate_age(born):\n", + " born = date.fromisoformat(born)\n", + " today = date.today()\n", + " return today.year - born.year - ((today.month, today.day) < (born.month, born.day))\n", + "\n", + "X['age'] = 
X['date_of_birth'].apply(calculate_age)\n", + "X = X.drop(columns=['date_of_birth'], axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 684, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "sex\n", + "False 72\n", + "True 72\n", + "Name: count, dtype: int64" + ] + }, + "execution_count": 684, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Sex to binary\n", + "X['sex'] = X['sex'] == 'M'\n", + "X['sex'].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 685, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "cheated\n", + "False 127\n", + "True 17\n", + "Name: count, dtype: int64" + ] + }, + "execution_count": 685, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 686, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>participation_1</th>\n", + " <th>test_1</th>\n", + " <th>final_exam_1</th>\n", + " <th>participation_2</th>\n", + " <th>test_2</th>\n", + " <th>final_exam_2</th>\n", + " <th>class</th>\n", + " <th>year</th>\n", + " <th>age</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>count</th>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>mean</th>\n", + 
" <td>12.590278</td>\n", + " <td>11.097222</td>\n", + " <td>10.763889</td>\n", + " <td>10.569444</td>\n", + " <td>10.597222</td>\n", + " <td>9.958333</td>\n", + " <td>2.534722</td>\n", + " <td>2.027778</td>\n", + " <td>12.819444</td>\n", + " </tr>\n", + " <tr>\n", + " <th>std</th>\n", + " <td>3.472999</td>\n", + " <td>4.280905</td>\n", + " <td>5.139979</td>\n", + " <td>4.386331</td>\n", + " <td>4.398537</td>\n", + " <td>4.962270</td>\n", + " <td>1.115142</td>\n", + " <td>1.003103</td>\n", + " <td>1.417438</td>\n", + " </tr>\n", + " <tr>\n", + " <th>min</th>\n", + " <td>5.000000</td>\n", + " <td>4.000000</td>\n", + " <td>1.000000</td>\n", + " <td>3.000000</td>\n", + " <td>4.000000</td>\n", + " <td>2.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>10.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>25%</th>\n", + " <td>10.000000</td>\n", + " <td>8.000000</td>\n", + " <td>6.000000</td>\n", + " <td>6.000000</td>\n", + " <td>6.000000</td>\n", + " <td>5.375000</td>\n", + " <td>2.000000</td>\n", + " <td>1.000000</td>\n", + " <td>12.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>50%</th>\n", + " <td>11.500000</td>\n", + " <td>10.000000</td>\n", + " <td>11.000000</td>\n", + " <td>11.000000</td>\n", + " <td>10.500000</td>\n", + " <td>10.000000</td>\n", + " <td>3.000000</td>\n", + " <td>3.000000</td>\n", + " <td>13.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>75%</th>\n", + " <td>15.000000</td>\n", + " <td>15.000000</td>\n", + " <td>15.000000</td>\n", + " <td>14.000000</td>\n", + " <td>14.000000</td>\n", + " <td>14.500000</td>\n", + " <td>4.000000</td>\n", + " <td>3.000000</td>\n", + " <td>14.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>max</th>\n", + " <td>20.000000</td>\n", + " <td>20.000000</td>\n", + " <td>19.500000</td>\n", + " <td>20.000000</td>\n", + " <td>20.000000</td>\n", + " <td>19.000000</td>\n", + " <td>4.000000</td>\n", + " <td>3.000000</td>\n", + " <td>18.000000</td>\n", + " </tr>\n", + " </tbody>\n", + 
"</table>\n", + "</div>" + ], + "text/plain": [ + " participation_1 test_1 final_exam_1 participation_2 test_2 \\\n", + "count 144.000000 144.000000 144.000000 144.000000 144.000000 \n", + "mean 12.590278 11.097222 10.763889 10.569444 10.597222 \n", + "std 3.472999 4.280905 5.139979 4.386331 4.398537 \n", + "min 5.000000 4.000000 1.000000 3.000000 4.000000 \n", + "25% 10.000000 8.000000 6.000000 6.000000 6.000000 \n", + "50% 11.500000 10.000000 11.000000 11.000000 10.500000 \n", + "75% 15.000000 15.000000 15.000000 14.000000 14.000000 \n", + "max 20.000000 20.000000 19.500000 20.000000 20.000000 \n", + "\n", + " final_exam_2 class year age \n", + "count 144.000000 144.000000 144.000000 144.000000 \n", + "mean 9.958333 2.534722 2.027778 12.819444 \n", + "std 4.962270 1.115142 1.003103 1.417438 \n", + "min 2.000000 1.000000 1.000000 10.000000 \n", + "25% 5.375000 2.000000 1.000000 12.000000 \n", + "50% 10.000000 3.000000 3.000000 13.000000 \n", + "75% 14.500000 4.000000 3.000000 14.000000 \n", + "max 19.000000 4.000000 3.000000 18.000000 " + ] + }, + "execution_count": 686, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 687, + "metadata": {}, + "outputs": [], + "source": [ + "X['Exam/Exam Ratio'] = X['final_exam_2'] / X['final_exam_1'] \n", + "X['Participation/Exam Ratio 1'] = X['participation_1'] / X['final_exam_1'] \n", + "X['Participation/Exam Ratio 2'] = X['participation_2'] / X['final_exam_2']\n", + "X['Test Exam Ratio 1'] = X['test_1'] / X['final_exam_1'] \n", + "X['Test Exam Ratio 2'] = X['test_2'] / X['final_exam_2'] \n" + ] + }, + { + "cell_type": "code", + "execution_count": 688, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " 
.dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>participation_1</th>\n", + " <th>test_1</th>\n", + " <th>final_exam_1</th>\n", + " <th>participation_2</th>\n", + " <th>test_2</th>\n", + " <th>final_exam_2</th>\n", + " <th>class</th>\n", + " <th>year</th>\n", + " <th>age</th>\n", + " <th>Exam/Exam Ratio</th>\n", + " <th>Participation/Exam Ratio 1</th>\n", + " <th>Participation/Exam Ratio 2</th>\n", + " <th>Test Exam Ratio 1</th>\n", + " <th>Test Exam Ratio 2</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>count</th>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " <td>144.000000</td>\n", + " </tr>\n", + " <tr>\n", + " <th>mean</th>\n", + " <td>12.590278</td>\n", + " <td>11.097222</td>\n", + " <td>10.763889</td>\n", + " <td>10.569444</td>\n", + " <td>10.597222</td>\n", + " <td>9.958333</td>\n", + " <td>2.534722</td>\n", + " <td>2.027778</td>\n", + " <td>12.819444</td>\n", + " <td>1.153825</td>\n", + " <td>1.492997</td>\n", + " <td>1.243145</td>\n", + " <td>1.204897</td>\n", + " <td>1.251639</td>\n", + " </tr>\n", + " <tr>\n", + " <th>std</th>\n", + " <td>3.472999</td>\n", + " <td>4.280905</td>\n", + " <td>5.139979</td>\n", + " <td>4.386331</td>\n", + " <td>4.398537</td>\n", + " <td>4.962270</td>\n", + " <td>1.115142</td>\n", + " <td>1.003103</td>\n", + " <td>1.417438</td>\n", + " <td>0.955162</td>\n", + " <td>1.011789</td>\n", + " <td>0.639878</td>\n", + " <td>0.609706</td>\n", + " <td>0.646539</td>\n", + " </tr>\n", + " <tr>\n", + " <th>min</th>\n", + " 
<td>5.000000</td>\n", + " <td>4.000000</td>\n", + " <td>1.000000</td>\n", + " <td>3.000000</td>\n", + " <td>4.000000</td>\n", + " <td>2.000000</td>\n", + " <td>1.000000</td>\n", + " <td>1.000000</td>\n", + " <td>10.000000</td>\n", + " <td>0.166667</td>\n", + " <td>0.611111</td>\n", + " <td>0.444444</td>\n", + " <td>0.540541</td>\n", + " <td>0.444444</td>\n", + " </tr>\n", + " <tr>\n", + " <th>25%</th>\n", + " <td>10.000000</td>\n", + " <td>8.000000</td>\n", + " <td>6.000000</td>\n", + " <td>6.000000</td>\n", + " <td>6.000000</td>\n", + " <td>5.375000</td>\n", + " <td>2.000000</td>\n", + " <td>1.000000</td>\n", + " <td>12.000000</td>\n", + " <td>0.631250</td>\n", + " <td>0.965368</td>\n", + " <td>0.851190</td>\n", + " <td>0.909091</td>\n", + " <td>0.851190</td>\n", + " </tr>\n", + " <tr>\n", + " <th>50%</th>\n", + " <td>11.500000</td>\n", + " <td>10.000000</td>\n", + " <td>11.000000</td>\n", + " <td>11.000000</td>\n", + " <td>10.500000</td>\n", + " <td>10.000000</td>\n", + " <td>3.000000</td>\n", + " <td>3.000000</td>\n", + " <td>13.000000</td>\n", + " <td>0.928571</td>\n", + " <td>1.240385</td>\n", + " <td>1.090909</td>\n", + " <td>1.057566</td>\n", + " <td>1.090909</td>\n", + " </tr>\n", + " <tr>\n", + " <th>75%</th>\n", + " <td>15.000000</td>\n", + " <td>15.000000</td>\n", + " <td>15.000000</td>\n", + " <td>14.000000</td>\n", + " <td>14.000000</td>\n", + " <td>14.500000</td>\n", + " <td>4.000000</td>\n", + " <td>3.000000</td>\n", + " <td>14.000000</td>\n", + " <td>1.135417</td>\n", + " <td>1.675000</td>\n", + " <td>1.387424</td>\n", + " <td>1.333333</td>\n", + " <td>1.387424</td>\n", + " </tr>\n", + " <tr>\n", + " <th>max</th>\n", + " <td>20.000000</td>\n", + " <td>20.000000</td>\n", + " <td>19.500000</td>\n", + " <td>20.000000</td>\n", + " <td>20.000000</td>\n", + " <td>19.000000</td>\n", + " <td>4.000000</td>\n", + " <td>3.000000</td>\n", + " <td>18.000000</td>\n", + " <td>7.000000</td>\n", + " <td>10.000000</td>\n", + " <td>5.000000</td>\n", + " 
<td>5.000000</td>\n", + " <td>5.000000</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " participation_1 test_1 final_exam_1 participation_2 test_2 \\\n", + "count 144.000000 144.000000 144.000000 144.000000 144.000000 \n", + "mean 12.590278 11.097222 10.763889 10.569444 10.597222 \n", + "std 3.472999 4.280905 5.139979 4.386331 4.398537 \n", + "min 5.000000 4.000000 1.000000 3.000000 4.000000 \n", + "25% 10.000000 8.000000 6.000000 6.000000 6.000000 \n", + "50% 11.500000 10.000000 11.000000 11.000000 10.500000 \n", + "75% 15.000000 15.000000 15.000000 14.000000 14.000000 \n", + "max 20.000000 20.000000 19.500000 20.000000 20.000000 \n", + "\n", + " final_exam_2 class year age Exam/Exam Ratio \\\n", + "count 144.000000 144.000000 144.000000 144.000000 144.000000 \n", + "mean 9.958333 2.534722 2.027778 12.819444 1.153825 \n", + "std 4.962270 1.115142 1.003103 1.417438 0.955162 \n", + "min 2.000000 1.000000 1.000000 10.000000 0.166667 \n", + "25% 5.375000 2.000000 1.000000 12.000000 0.631250 \n", + "50% 10.000000 3.000000 3.000000 13.000000 0.928571 \n", + "75% 14.500000 4.000000 3.000000 14.000000 1.135417 \n", + "max 19.000000 4.000000 3.000000 18.000000 7.000000 \n", + "\n", + " Participation/Exam Ratio 1 Participation/Exam Ratio 2 \\\n", + "count 144.000000 144.000000 \n", + "mean 1.492997 1.243145 \n", + "std 1.011789 0.639878 \n", + "min 0.611111 0.444444 \n", + "25% 0.965368 0.851190 \n", + "50% 1.240385 1.090909 \n", + "75% 1.675000 1.387424 \n", + "max 10.000000 5.000000 \n", + "\n", + " Test Exam Ratio 1 Test Exam Ratio 2 \n", + "count 144.000000 144.000000 \n", + "mean 1.204897 1.251639 \n", + "std 0.609706 0.646539 \n", + "min 0.540541 0.444444 \n", + "25% 0.909091 0.851190 \n", + "50% 1.057566 1.090909 \n", + "75% 1.333333 1.387424 \n", + "max 5.000000 5.000000 " + ] + }, + "execution_count": 688, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X.describe()" + ] + }, + { + 
"cell_type": "code", + "execution_count": 689, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "scl = StandardScaler()\n", + "\n", + "X = scl.fit_transform(X)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Should be done with preprocessing now" + ] + }, + { + "cell_type": "code", + "execution_count": 690, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X,y,random_state=10)\n", + "X_val, X_test, y_val, y_test = train_test_split(X_test,y_test,random_state=10, train_size=.5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now to train models.\n", + "\n", + "Random Forest:" + ] + }, + { + "cell_type": "code", + "execution_count": 691, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<style>#sk-container-id-49 {\n", + " /* Definition of color scheme common for light and dark mode */\n", + " --sklearn-color-text: black;\n", + " --sklearn-color-line: gray;\n", + " /* Definition of color scheme for unfitted estimators */\n", + " --sklearn-color-unfitted-level-0: #fff5e6;\n", + " --sklearn-color-unfitted-level-1: #f6e4d2;\n", + " --sklearn-color-unfitted-level-2: #ffe0b3;\n", + " --sklearn-color-unfitted-level-3: chocolate;\n", + " /* Definition of color scheme for fitted estimators */\n", + " --sklearn-color-fitted-level-0: #f0f8ff;\n", + " --sklearn-color-fitted-level-1: #d4ebff;\n", + " --sklearn-color-fitted-level-2: #b3dbfd;\n", + " --sklearn-color-fitted-level-3: cornflowerblue;\n", + "\n", + " /* Specific color for light theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", + " 
--sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-icon: #696969;\n", + "\n", + " @media (prefers-color-scheme: dark) {\n", + " /* Redefinition of color scheme for dark theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-icon: #878787;\n", + " }\n", + "}\n", + "\n", + "#sk-container-id-49 {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "#sk-container-id-49 pre {\n", + " padding: 0;\n", + "}\n", + "\n", + "#sk-container-id-49 input.sk-hidden--visually {\n", + " border: 0;\n", + " clip: rect(1px 1px 1px 1px);\n", + " clip: rect(1px, 1px, 1px, 1px);\n", + " height: 1px;\n", + " margin: -1px;\n", + " overflow: hidden;\n", + " padding: 0;\n", + " position: absolute;\n", + " width: 1px;\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-dashed-wrapped {\n", + " border: 1px dashed var(--sklearn-color-line);\n", + " margin: 0 0.4em 0.5em 0.4em;\n", + " box-sizing: border-box;\n", + " padding-bottom: 0.4em;\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-container {\n", + " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", + " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", + " so we also need the `!important` here to be able to override the\n", + " default hidden behavior on the sphinx rendered scikit-learn.org.\n", + " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", + " display: inline-block !important;\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-49 
div.sk-text-repr-fallback {\n", + " display: none;\n", + "}\n", + "\n", + "div.sk-parallel-item,\n", + "div.sk-serial,\n", + "div.sk-item {\n", + " /* draw centered vertical line to link estimators */\n", + " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", + " background-size: 2px 100%;\n", + " background-repeat: no-repeat;\n", + " background-position: center center;\n", + "}\n", + "\n", + "/* Parallel-specific style estimator block */\n", + "\n", + "#sk-container-id-49 div.sk-parallel-item::after {\n", + " content: \"\";\n", + " width: 100%;\n", + " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", + " flex-grow: 1;\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-parallel {\n", + " display: flex;\n", + " align-items: stretch;\n", + " justify-content: center;\n", + " background-color: var(--sklearn-color-background);\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-parallel-item {\n", + " display: flex;\n", + " flex-direction: column;\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-parallel-item:first-child::after {\n", + " align-self: flex-end;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-parallel-item:last-child::after {\n", + " align-self: flex-start;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-parallel-item:only-child::after {\n", + " width: 0;\n", + "}\n", + "\n", + "/* Serial-specific style estimator block */\n", + "\n", + "#sk-container-id-49 div.sk-serial {\n", + " display: flex;\n", + " flex-direction: column;\n", + " align-items: center;\n", + " background-color: var(--sklearn-color-background);\n", + " padding-right: 1em;\n", + " padding-left: 1em;\n", + "}\n", + "\n", + "\n", + "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", + "clickable and can be expanded/collapsed.\n", + "- Pipeline and ColumnTransformer 
use this feature and define the default style\n", + "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", + "*/\n", + "\n", + "/* Pipeline and ColumnTransformer style (default) */\n", + "\n", + "#sk-container-id-49 div.sk-toggleable {\n", + " /* Default theme specific background. It is overwritten whether we have a\n", + " specific estimator or a Pipeline/ColumnTransformer */\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "/* Toggleable label */\n", + "#sk-container-id-49 label.sk-toggleable__label {\n", + " cursor: pointer;\n", + " display: block;\n", + " width: 100%;\n", + " margin-bottom: 0;\n", + " padding: 0.5em;\n", + " box-sizing: border-box;\n", + " text-align: center;\n", + "}\n", + "\n", + "#sk-container-id-49 label.sk-toggleable__label-arrow:before {\n", + " /* Arrow on the left of the label */\n", + " content: \"▸\";\n", + " float: left;\n", + " margin-right: 0.25em;\n", + " color: var(--sklearn-color-icon);\n", + "}\n", + "\n", + "#sk-container-id-49 label.sk-toggleable__label-arrow:hover:before {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "/* Toggleable content - dropdown */\n", + "\n", + "#sk-container-id-49 div.sk-toggleable__content {\n", + " max-height: 0;\n", + " max-width: 0;\n", + " overflow: hidden;\n", + " text-align: left;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-toggleable__content.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-toggleable__content pre {\n", + " margin: 0.2em;\n", + " border-radius: 0.25em;\n", + " color: var(--sklearn-color-text);\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-toggleable__content.fitted pre {\n", + " /* unfitted */\n", + " 
background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-49 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", + " /* Expand drop-down */\n", + " max-height: 200px;\n", + " max-width: 100%;\n", + " overflow: auto;\n", + "}\n", + "\n", + "#sk-container-id-49 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", + " content: \"▾\";\n", + "}\n", + "\n", + "/* Pipeline/ColumnTransformer-specific style */\n", + "\n", + "#sk-container-id-49 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator-specific style */\n", + "\n", + "/* Colorize estimator box */\n", + "#sk-container-id-49 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-label label.sk-toggleable__label,\n", + "#sk-container-id-49 div.sk-label label {\n", + " /* The background is the default theme color */\n", + " color: var(--sklearn-color-text-on-default-background);\n", + "}\n", + "\n", + "/* On hover, darken the color of the background */\n", + "#sk-container-id-49 div.sk-label:hover label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "/* Label box, darken color on hover, fitted 
*/\n", + "#sk-container-id-49 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator label */\n", + "\n", + "#sk-container-id-49 div.sk-label label {\n", + " font-family: monospace;\n", + " font-weight: bold;\n", + " display: inline-block;\n", + " line-height: 1.2em;\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-label-container {\n", + " text-align: center;\n", + "}\n", + "\n", + "/* Estimator-specific */\n", + "#sk-container-id-49 div.sk-estimator {\n", + " font-family: monospace;\n", + " border: 1px dotted var(--sklearn-color-border-box);\n", + " border-radius: 0.25em;\n", + " box-sizing: border-box;\n", + " margin-bottom: 0.5em;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-estimator.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "/* on hover */\n", + "#sk-container-id-49 div.sk-estimator:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-49 div.sk-estimator.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Specification for estimator info (e.g. 
\"i\" and \"?\") */\n", + "\n", + "/* Common style for \"i\" and \"?\" */\n", + "\n", + ".sk-estimator-doc-link,\n", + "a:link.sk-estimator-doc-link,\n", + "a:visited.sk-estimator-doc-link {\n", + " float: right;\n", + " font-size: smaller;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1em;\n", + " height: 1em;\n", + " width: 1em;\n", + " text-decoration: none !important;\n", + " margin-left: 1ex;\n", + " /* unfitted */\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted,\n", + "a:link.sk-estimator-doc-link.fitted,\n", + "a:visited.sk-estimator-doc-link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "/* Span, style for the box shown on hovering the info icon */\n", + ".sk-estimator-doc-link span {\n", + " display: none;\n", + " z-index: 9999;\n", + " position: relative;\n", + " font-weight: normal;\n", + " right: .2ex;\n", + " 
padding: .5ex;\n", + " margin: .5ex;\n", + " width: min-content;\n", + " min-width: 20ex;\n", + " max-width: 50ex;\n", + " color: var(--sklearn-color-text);\n", + " box-shadow: 2pt 2pt 4pt #999;\n", + " /* unfitted */\n", + " background: var(--sklearn-color-unfitted-level-0);\n", + " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted span {\n", + " /* fitted */\n", + " background: var(--sklearn-color-fitted-level-0);\n", + " border: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link:hover span {\n", + " display: block;\n", + "}\n", + "\n", + "/* \"?\"-specific style due to the `<a>` HTML tag */\n", + "\n", + "#sk-container-id-49 a.estimator_doc_link {\n", + " float: right;\n", + " font-size: 1rem;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1rem;\n", + " height: 1rem;\n", + " width: 1rem;\n", + " text-decoration: none;\n", + " /* unfitted */\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + "}\n", + "\n", + "#sk-container-id-49 a.estimator_doc_link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "#sk-container-id-49 a.estimator_doc_link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "#sk-container-id-49 a.estimator_doc_link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "</style><div id=\"sk-container-id-49\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier(max_depth=2, n_estimators=10)</pre><b>In a Jupyter 
environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-49\" type=\"checkbox\" checked><label for=\"sk-estimator-id-49\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;RandomForestClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.ensemble.RandomForestClassifier.html\">?<span>Documentation for RandomForestClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>RandomForestClassifier(max_depth=2, n_estimators=10)</pre></div> </div></div></div></div>" + ], + "text/plain": [ + "RandomForestClassifier(max_depth=2, n_estimators=10)" + ] + }, + "execution_count": 691, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.ensemble import RandomForestClassifier\n", + "rnd_clf = RandomForestClassifier(max_depth=2, n_estimators=10)\n", + "rnd_clf.fit(X_train,y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 692, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 692, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.metrics import accuracy_score\n", + "\n", + "y_pred = rnd_clf.predict(X=X_val)\n", + "accuracy_score(y_pred=y_pred, y_true=y_val)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "I have found a problem, the dataset is far too small.\n", + "Regardless, let's continue." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 693, + "metadata": {}, + "outputs": [], + "source": [ + "import keras\n", + "import tensorflow as tf \n", + "\n", + "model = keras.Sequential(layers=[\n", + " keras.layers.Input(shape=[15,]),\n", + " keras.layers.Dense(128, activation='relu'),\n", + " keras.layers.Dense(128, activation='relu'),\n", + " keras.layers.Dense(1, activation='sigmoid')\n", + "])\n", + "\n", + "model.compile(loss=keras.losses.binary_crossentropy, optimizer='adam', metrics=['accuracy'])" + ] + }, + { + "cell_type": "code", + "execution_count": 694, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 39ms/step - accuracy: 0.7644 - loss: 0.5846 - val_accuracy: 1.0000 - val_loss: 0.3691\n", + "Epoch 2/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8537 - loss: 0.4351 - val_accuracy: 1.0000 - val_loss: 0.2381\n", + "Epoch 3/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8464 - loss: 0.4046 - val_accuracy: 1.0000 - val_loss: 0.1746\n", + "Epoch 4/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8422 - loss: 0.3757 - val_accuracy: 1.0000 - val_loss: 0.1501\n", + "Epoch 5/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8527 - loss: 0.3416 - val_accuracy: 1.0000 - val_loss: 0.1443\n", + "Epoch 6/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8589 - loss: 0.3011 - val_accuracy: 
1.0000 - val_loss: 0.1481\n", + "Epoch 7/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8433 - loss: 0.2969 - val_accuracy: 1.0000 - val_loss: 0.1561\n", + "Epoch 8/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.9003 - loss: 0.2562 - val_accuracy: 1.0000 - val_loss: 0.1634\n", + "Epoch 9/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8962 - loss: 0.2561 - val_accuracy: 1.0000 - val_loss: 0.1694\n", + "Epoch 10/10\n", + "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8853 - loss: 0.2665 - val_accuracy: 0.9444 - val_loss: 0.1725\n" + ] + }, + { + "data": { + "text/plain": [ + "<keras.src.callbacks.history.History at 0x7fdddb016fd0>" + ] + }, + "execution_count": 694, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(X_train,y_train, validation_data=[X_val,y_val], epochs=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 695, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "141 False\n", + "122 False\n", + "138 False\n", + "84 False\n", + "48 False\n", + "80 False\n", + "136 False\n", + "110 False\n", + "137 False\n", + "71 False\n", + "67 False\n", + "20 False\n", + "56 False\n", + "117 False\n", + "106 False\n", + "100 False\n", + "103 False\n", + "59 False\n", + "Name: cheated, dtype: bool" + ] + }, + "execution_count": 695, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_val" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compare the neural network (NN) against the random forest on the held-out test set." + ] + }, + { + "cell_type": "code", + "execution_count": 696, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": 
[ + "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 34ms/step\n" + ] + }, + { + "data": { + "text/plain": [ + "0.9444444444444444" + ] + }, + "execution_count": 696, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "y_pred = model.predict(X_test)\n", + "binary_predictions = (y_pred >= 0.5).astype(np.bool_)\n", + "\n", + "accuracy_score(y_pred=binary_predictions, y_true=y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 697, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8888888888888888" + ] + }, + "execution_count": 697, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pred = rnd_clf.predict(X=X_test)\n", + "accuracy_score(y_pred=y_pred, y_true=y_test)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}