machinelearning

Machine learning code
git clone git://git.laack.co/machinelearning.git
Log | Files | Refs

MiddleSchoolExamCheating.ipynb (44609B)


      1 {
      2  "cells": [
      3   {
      4    "cell_type": "markdown",
      5    "metadata": {},
      6    "source": [
      7     "A neural network (NN) is generally better at this, but there is so little data (127 students did not cheat, 17 did) that reliable metrics are hard to obtain."
      8    ]
      9   },
     10   {
     11    "cell_type": "code",
     12    "execution_count": 680,
     13    "metadata": {},
     14    "outputs": [],
     15    "source": [
     16     "import pandas as pd\n",
     17     "df = pd.read_csv('../datasets/middleSchoolExam/MiddleSchoolExam.csv')\n",
     18     "df.head(1)\n",
     19     "\n",
     20     "df = df.dropna()"
     21    ]
     22   },
     23   {
     24    "cell_type": "code",
     25    "execution_count": 681,
     26    "metadata": {},
     27    "outputs": [],
     28    "source": [
     29     "def cheated(val):\n",
     30     "    return val == 1\n",
     31     "\n",
     32     "X = df.drop(columns=['cheated'], axis=1)\n",
     33     "y = df['cheated']\n",
     34     "y = y.apply(cheated)"
     35    ]
     36   },
     37   {
     38    "cell_type": "code",
     39    "execution_count": 682,
     40    "metadata": {},
     41    "outputs": [
     42     {
     43      "data": {
     44       "text/plain": [
     45        "cheated\n",
     46        "False    127\n",
     47        "True      17\n",
     48        "Name: count, dtype: int64"
     49       ]
     50      },
     51      "execution_count": 682,
     52      "metadata": {},
     53      "output_type": "execute_result"
     54     }
     55    ],
     56    "source": [
     57     "y.value_counts()"
     58    ]
     59   },
     60   {
     61    "cell_type": "markdown",
     62    "metadata": {},
     63    "source": [
     64     "Feature engineering plan:\n",
     65     "\n",
     66     "- Convert date of birth (DOB) to age\n",
     67     "\n",
     68     "- Convert sex to a binary feature\n",
     69     "\n",
     70     "- Ratio between the two final exam scores\n",
     71     "\n",
     72     "- Ratio between participation and final exam scores\n",
     73     "\n",
     74     "- Ratio between test and final exam scores\n",
     75     "\n",
     76     "- Standardize columns\n"
     77    ]
     78   },
     79   {
     80    "cell_type": "code",
     81    "execution_count": 683,
     82    "metadata": {},
     83    "outputs": [],
     84    "source": [
     85     "# DOB To Age\n",
     86     "\n",
     87     "from datetime import date\n",
     88     "\n",
     89     "def calculate_age(born):\n",
     90     "    born = date.fromisoformat(born)\n",
     91     "    today = date.today()\n",
     92     "    return today.year - born.year - ((today.month, today.day) < (born.month, born.day))\n",
     93     "\n",
     94     "X['age'] = X['date_of_birth'].apply(calculate_age)\n",
     95     "X = X.drop(columns=['date_of_birth'], axis=1)"
     96    ]
     97   },
     98   {
     99    "cell_type": "code",
    100    "execution_count": 684,
    101    "metadata": {},
    102    "outputs": [
    103     {
    104      "data": {
    105       "text/plain": [
    106        "sex\n",
    107        "False    72\n",
    108        "True     72\n",
    109        "Name: count, dtype: int64"
    110       ]
    111      },
    112      "execution_count": 684,
    113      "metadata": {},
    114      "output_type": "execute_result"
    115     }
    116    ],
    117    "source": [
    118     "# Sex to binary\n",
    119     "X['sex'] = X['sex'] == 'M'\n",
    120     "X['sex'].value_counts()"
    121    ]
    122   },
    123   {
    124    "cell_type": "code",
    125    "execution_count": 685,
    126    "metadata": {},
    127    "outputs": [
    128     {
    129      "data": {
    130       "text/plain": [
    131        "cheated\n",
    132        "False    127\n",
    133        "True      17\n",
    134        "Name: count, dtype: int64"
    135       ]
    136      },
    137      "execution_count": 685,
    138      "metadata": {},
    139      "output_type": "execute_result"
    140     }
    141    ],
    142    "source": [
    143     "y.value_counts()"
    144    ]
    145   },
    146   {
    147    "cell_type": "code",
    148    "execution_count": 686,
    149    "metadata": {},
    150    "outputs": [
    151     {
    152      "data": {
    153       "text/html": [
    154        "<div>\n",
    155        "<style scoped>\n",
    156        "    .dataframe tbody tr th:only-of-type {\n",
    157        "        vertical-align: middle;\n",
    158        "    }\n",
    159        "\n",
    160        "    .dataframe tbody tr th {\n",
    161        "        vertical-align: top;\n",
    162        "    }\n",
    163        "\n",
    164        "    .dataframe thead th {\n",
    165        "        text-align: right;\n",
    166        "    }\n",
    167        "</style>\n",
    168        "<table border=\"1\" class=\"dataframe\">\n",
    169        "  <thead>\n",
    170        "    <tr style=\"text-align: right;\">\n",
    171        "      <th></th>\n",
    172        "      <th>participation_1</th>\n",
    173        "      <th>test_1</th>\n",
    174        "      <th>final_exam_1</th>\n",
    175        "      <th>participation_2</th>\n",
    176        "      <th>test_2</th>\n",
    177        "      <th>final_exam_2</th>\n",
    178        "      <th>class</th>\n",
    179        "      <th>year</th>\n",
    180        "      <th>age</th>\n",
    181        "    </tr>\n",
    182        "  </thead>\n",
    183        "  <tbody>\n",
    184        "    <tr>\n",
    185        "      <th>count</th>\n",
    186        "      <td>144.000000</td>\n",
    187        "      <td>144.000000</td>\n",
    188        "      <td>144.000000</td>\n",
    189        "      <td>144.000000</td>\n",
    190        "      <td>144.000000</td>\n",
    191        "      <td>144.000000</td>\n",
    192        "      <td>144.000000</td>\n",
    193        "      <td>144.000000</td>\n",
    194        "      <td>144.000000</td>\n",
    195        "    </tr>\n",
    196        "    <tr>\n",
    197        "      <th>mean</th>\n",
    198        "      <td>12.590278</td>\n",
    199        "      <td>11.097222</td>\n",
    200        "      <td>10.763889</td>\n",
    201        "      <td>10.569444</td>\n",
    202        "      <td>10.597222</td>\n",
    203        "      <td>9.958333</td>\n",
    204        "      <td>2.534722</td>\n",
    205        "      <td>2.027778</td>\n",
    206        "      <td>12.819444</td>\n",
    207        "    </tr>\n",
    208        "    <tr>\n",
    209        "      <th>std</th>\n",
    210        "      <td>3.472999</td>\n",
    211        "      <td>4.280905</td>\n",
    212        "      <td>5.139979</td>\n",
    213        "      <td>4.386331</td>\n",
    214        "      <td>4.398537</td>\n",
    215        "      <td>4.962270</td>\n",
    216        "      <td>1.115142</td>\n",
    217        "      <td>1.003103</td>\n",
    218        "      <td>1.417438</td>\n",
    219        "    </tr>\n",
    220        "    <tr>\n",
    221        "      <th>min</th>\n",
    222        "      <td>5.000000</td>\n",
    223        "      <td>4.000000</td>\n",
    224        "      <td>1.000000</td>\n",
    225        "      <td>3.000000</td>\n",
    226        "      <td>4.000000</td>\n",
    227        "      <td>2.000000</td>\n",
    228        "      <td>1.000000</td>\n",
    229        "      <td>1.000000</td>\n",
    230        "      <td>10.000000</td>\n",
    231        "    </tr>\n",
    232        "    <tr>\n",
    233        "      <th>25%</th>\n",
    234        "      <td>10.000000</td>\n",
    235        "      <td>8.000000</td>\n",
    236        "      <td>6.000000</td>\n",
    237        "      <td>6.000000</td>\n",
    238        "      <td>6.000000</td>\n",
    239        "      <td>5.375000</td>\n",
    240        "      <td>2.000000</td>\n",
    241        "      <td>1.000000</td>\n",
    242        "      <td>12.000000</td>\n",
    243        "    </tr>\n",
    244        "    <tr>\n",
    245        "      <th>50%</th>\n",
    246        "      <td>11.500000</td>\n",
    247        "      <td>10.000000</td>\n",
    248        "      <td>11.000000</td>\n",
    249        "      <td>11.000000</td>\n",
    250        "      <td>10.500000</td>\n",
    251        "      <td>10.000000</td>\n",
    252        "      <td>3.000000</td>\n",
    253        "      <td>3.000000</td>\n",
    254        "      <td>13.000000</td>\n",
    255        "    </tr>\n",
    256        "    <tr>\n",
    257        "      <th>75%</th>\n",
    258        "      <td>15.000000</td>\n",
    259        "      <td>15.000000</td>\n",
    260        "      <td>15.000000</td>\n",
    261        "      <td>14.000000</td>\n",
    262        "      <td>14.000000</td>\n",
    263        "      <td>14.500000</td>\n",
    264        "      <td>4.000000</td>\n",
    265        "      <td>3.000000</td>\n",
    266        "      <td>14.000000</td>\n",
    267        "    </tr>\n",
    268        "    <tr>\n",
    269        "      <th>max</th>\n",
    270        "      <td>20.000000</td>\n",
    271        "      <td>20.000000</td>\n",
    272        "      <td>19.500000</td>\n",
    273        "      <td>20.000000</td>\n",
    274        "      <td>20.000000</td>\n",
    275        "      <td>19.000000</td>\n",
    276        "      <td>4.000000</td>\n",
    277        "      <td>3.000000</td>\n",
    278        "      <td>18.000000</td>\n",
    279        "    </tr>\n",
    280        "  </tbody>\n",
    281        "</table>\n",
    282        "</div>"
    283       ],
    284       "text/plain": [
    285        "       participation_1      test_1  final_exam_1  participation_2      test_2  \\\n",
    286        "count       144.000000  144.000000    144.000000       144.000000  144.000000   \n",
    287        "mean         12.590278   11.097222     10.763889        10.569444   10.597222   \n",
    288        "std           3.472999    4.280905      5.139979         4.386331    4.398537   \n",
    289        "min           5.000000    4.000000      1.000000         3.000000    4.000000   \n",
    290        "25%          10.000000    8.000000      6.000000         6.000000    6.000000   \n",
    291        "50%          11.500000   10.000000     11.000000        11.000000   10.500000   \n",
    292        "75%          15.000000   15.000000     15.000000        14.000000   14.000000   \n",
    293        "max          20.000000   20.000000     19.500000        20.000000   20.000000   \n",
    294        "\n",
    295        "       final_exam_2       class        year         age  \n",
    296        "count    144.000000  144.000000  144.000000  144.000000  \n",
    297        "mean       9.958333    2.534722    2.027778   12.819444  \n",
    298        "std        4.962270    1.115142    1.003103    1.417438  \n",
    299        "min        2.000000    1.000000    1.000000   10.000000  \n",
    300        "25%        5.375000    2.000000    1.000000   12.000000  \n",
    301        "50%       10.000000    3.000000    3.000000   13.000000  \n",
    302        "75%       14.500000    4.000000    3.000000   14.000000  \n",
    303        "max       19.000000    4.000000    3.000000   18.000000  "
    304       ]
    305      },
    306      "execution_count": 686,
    307      "metadata": {},
    308      "output_type": "execute_result"
    309     }
    310    ],
    311    "source": [
    312     "X.describe()"
    313    ]
    314   },
    315   {
    316    "cell_type": "code",
    317    "execution_count": 687,
    318    "metadata": {},
    319    "outputs": [],
    320    "source": [
    321     "X['Exam/Exam Ratio'] = X['final_exam_2'] / X['final_exam_1'] \n",
    322     "X['Participation/Exam Ratio 1'] = X['participation_1'] / X['final_exam_1'] \n",
    323     "X['Participation/Exam Ratio 2'] = X['participation_2'] / X['final_exam_2']\n",
    324     "X['Test Exam Ratio 1'] = X['test_1'] / X['final_exam_1'] \n",
    325     "X['Test Exam Ratio 2'] = X['test_2'] / X['final_exam_2'] \n"
    326    ]
    327   },
    328   {
    329    "cell_type": "code",
    330    "execution_count": 688,
    331    "metadata": {},
    332    "outputs": [
    333     {
    334      "data": {
    335       "text/html": [
    336        "<div>\n",
    337        "<style scoped>\n",
    338        "    .dataframe tbody tr th:only-of-type {\n",
    339        "        vertical-align: middle;\n",
    340        "    }\n",
    341        "\n",
    342        "    .dataframe tbody tr th {\n",
    343        "        vertical-align: top;\n",
    344        "    }\n",
    345        "\n",
    346        "    .dataframe thead th {\n",
    347        "        text-align: right;\n",
    348        "    }\n",
    349        "</style>\n",
    350        "<table border=\"1\" class=\"dataframe\">\n",
    351        "  <thead>\n",
    352        "    <tr style=\"text-align: right;\">\n",
    353        "      <th></th>\n",
    354        "      <th>participation_1</th>\n",
    355        "      <th>test_1</th>\n",
    356        "      <th>final_exam_1</th>\n",
    357        "      <th>participation_2</th>\n",
    358        "      <th>test_2</th>\n",
    359        "      <th>final_exam_2</th>\n",
    360        "      <th>class</th>\n",
    361        "      <th>year</th>\n",
    362        "      <th>age</th>\n",
    363        "      <th>Exam/Exam Ratio</th>\n",
    364        "      <th>Participation/Exam Ratio 1</th>\n",
    365        "      <th>Participation/Exam Ratio 2</th>\n",
    366        "      <th>Test Exam Ratio 1</th>\n",
    367        "      <th>Test Exam Ratio 2</th>\n",
    368        "    </tr>\n",
    369        "  </thead>\n",
    370        "  <tbody>\n",
    371        "    <tr>\n",
    372        "      <th>count</th>\n",
    373        "      <td>144.000000</td>\n",
    374        "      <td>144.000000</td>\n",
    375        "      <td>144.000000</td>\n",
    376        "      <td>144.000000</td>\n",
    377        "      <td>144.000000</td>\n",
    378        "      <td>144.000000</td>\n",
    379        "      <td>144.000000</td>\n",
    380        "      <td>144.000000</td>\n",
    381        "      <td>144.000000</td>\n",
    382        "      <td>144.000000</td>\n",
    383        "      <td>144.000000</td>\n",
    384        "      <td>144.000000</td>\n",
    385        "      <td>144.000000</td>\n",
    386        "      <td>144.000000</td>\n",
    387        "    </tr>\n",
    388        "    <tr>\n",
    389        "      <th>mean</th>\n",
    390        "      <td>12.590278</td>\n",
    391        "      <td>11.097222</td>\n",
    392        "      <td>10.763889</td>\n",
    393        "      <td>10.569444</td>\n",
    394        "      <td>10.597222</td>\n",
    395        "      <td>9.958333</td>\n",
    396        "      <td>2.534722</td>\n",
    397        "      <td>2.027778</td>\n",
    398        "      <td>12.819444</td>\n",
    399        "      <td>1.153825</td>\n",
    400        "      <td>1.492997</td>\n",
    401        "      <td>1.243145</td>\n",
    402        "      <td>1.204897</td>\n",
    403        "      <td>1.251639</td>\n",
    404        "    </tr>\n",
    405        "    <tr>\n",
    406        "      <th>std</th>\n",
    407        "      <td>3.472999</td>\n",
    408        "      <td>4.280905</td>\n",
    409        "      <td>5.139979</td>\n",
    410        "      <td>4.386331</td>\n",
    411        "      <td>4.398537</td>\n",
    412        "      <td>4.962270</td>\n",
    413        "      <td>1.115142</td>\n",
    414        "      <td>1.003103</td>\n",
    415        "      <td>1.417438</td>\n",
    416        "      <td>0.955162</td>\n",
    417        "      <td>1.011789</td>\n",
    418        "      <td>0.639878</td>\n",
    419        "      <td>0.609706</td>\n",
    420        "      <td>0.646539</td>\n",
    421        "    </tr>\n",
    422        "    <tr>\n",
    423        "      <th>min</th>\n",
    424        "      <td>5.000000</td>\n",
    425        "      <td>4.000000</td>\n",
    426        "      <td>1.000000</td>\n",
    427        "      <td>3.000000</td>\n",
    428        "      <td>4.000000</td>\n",
    429        "      <td>2.000000</td>\n",
    430        "      <td>1.000000</td>\n",
    431        "      <td>1.000000</td>\n",
    432        "      <td>10.000000</td>\n",
    433        "      <td>0.166667</td>\n",
    434        "      <td>0.611111</td>\n",
    435        "      <td>0.444444</td>\n",
    436        "      <td>0.540541</td>\n",
    437        "      <td>0.444444</td>\n",
    438        "    </tr>\n",
    439        "    <tr>\n",
    440        "      <th>25%</th>\n",
    441        "      <td>10.000000</td>\n",
    442        "      <td>8.000000</td>\n",
    443        "      <td>6.000000</td>\n",
    444        "      <td>6.000000</td>\n",
    445        "      <td>6.000000</td>\n",
    446        "      <td>5.375000</td>\n",
    447        "      <td>2.000000</td>\n",
    448        "      <td>1.000000</td>\n",
    449        "      <td>12.000000</td>\n",
    450        "      <td>0.631250</td>\n",
    451        "      <td>0.965368</td>\n",
    452        "      <td>0.851190</td>\n",
    453        "      <td>0.909091</td>\n",
    454        "      <td>0.851190</td>\n",
    455        "    </tr>\n",
    456        "    <tr>\n",
    457        "      <th>50%</th>\n",
    458        "      <td>11.500000</td>\n",
    459        "      <td>10.000000</td>\n",
    460        "      <td>11.000000</td>\n",
    461        "      <td>11.000000</td>\n",
    462        "      <td>10.500000</td>\n",
    463        "      <td>10.000000</td>\n",
    464        "      <td>3.000000</td>\n",
    465        "      <td>3.000000</td>\n",
    466        "      <td>13.000000</td>\n",
    467        "      <td>0.928571</td>\n",
    468        "      <td>1.240385</td>\n",
    469        "      <td>1.090909</td>\n",
    470        "      <td>1.057566</td>\n",
    471        "      <td>1.090909</td>\n",
    472        "    </tr>\n",
    473        "    <tr>\n",
    474        "      <th>75%</th>\n",
    475        "      <td>15.000000</td>\n",
    476        "      <td>15.000000</td>\n",
    477        "      <td>15.000000</td>\n",
    478        "      <td>14.000000</td>\n",
    479        "      <td>14.000000</td>\n",
    480        "      <td>14.500000</td>\n",
    481        "      <td>4.000000</td>\n",
    482        "      <td>3.000000</td>\n",
    483        "      <td>14.000000</td>\n",
    484        "      <td>1.135417</td>\n",
    485        "      <td>1.675000</td>\n",
    486        "      <td>1.387424</td>\n",
    487        "      <td>1.333333</td>\n",
    488        "      <td>1.387424</td>\n",
    489        "    </tr>\n",
    490        "    <tr>\n",
    491        "      <th>max</th>\n",
    492        "      <td>20.000000</td>\n",
    493        "      <td>20.000000</td>\n",
    494        "      <td>19.500000</td>\n",
    495        "      <td>20.000000</td>\n",
    496        "      <td>20.000000</td>\n",
    497        "      <td>19.000000</td>\n",
    498        "      <td>4.000000</td>\n",
    499        "      <td>3.000000</td>\n",
    500        "      <td>18.000000</td>\n",
    501        "      <td>7.000000</td>\n",
    502        "      <td>10.000000</td>\n",
    503        "      <td>5.000000</td>\n",
    504        "      <td>5.000000</td>\n",
    505        "      <td>5.000000</td>\n",
    506        "    </tr>\n",
    507        "  </tbody>\n",
    508        "</table>\n",
    509        "</div>"
    510       ],
    511       "text/plain": [
    512        "       participation_1      test_1  final_exam_1  participation_2      test_2  \\\n",
    513        "count       144.000000  144.000000    144.000000       144.000000  144.000000   \n",
    514        "mean         12.590278   11.097222     10.763889        10.569444   10.597222   \n",
    515        "std           3.472999    4.280905      5.139979         4.386331    4.398537   \n",
    516        "min           5.000000    4.000000      1.000000         3.000000    4.000000   \n",
    517        "25%          10.000000    8.000000      6.000000         6.000000    6.000000   \n",
    518        "50%          11.500000   10.000000     11.000000        11.000000   10.500000   \n",
    519        "75%          15.000000   15.000000     15.000000        14.000000   14.000000   \n",
    520        "max          20.000000   20.000000     19.500000        20.000000   20.000000   \n",
    521        "\n",
    522        "       final_exam_2       class        year         age  Exam/Exam Ratio  \\\n",
    523        "count    144.000000  144.000000  144.000000  144.000000       144.000000   \n",
    524        "mean       9.958333    2.534722    2.027778   12.819444         1.153825   \n",
    525        "std        4.962270    1.115142    1.003103    1.417438         0.955162   \n",
    526        "min        2.000000    1.000000    1.000000   10.000000         0.166667   \n",
    527        "25%        5.375000    2.000000    1.000000   12.000000         0.631250   \n",
    528        "50%       10.000000    3.000000    3.000000   13.000000         0.928571   \n",
    529        "75%       14.500000    4.000000    3.000000   14.000000         1.135417   \n",
    530        "max       19.000000    4.000000    3.000000   18.000000         7.000000   \n",
    531        "\n",
    532        "       Participation/Exam Ratio 1  Participation/Exam Ratio 2  \\\n",
    533        "count                  144.000000                  144.000000   \n",
    534        "mean                     1.492997                    1.243145   \n",
    535        "std                      1.011789                    0.639878   \n",
    536        "min                      0.611111                    0.444444   \n",
    537        "25%                      0.965368                    0.851190   \n",
    538        "50%                      1.240385                    1.090909   \n",
    539        "75%                      1.675000                    1.387424   \n",
    540        "max                     10.000000                    5.000000   \n",
    541        "\n",
    542        "       Test Exam Ratio 1  Test Exam Ratio 2  \n",
    543        "count         144.000000         144.000000  \n",
    544        "mean            1.204897           1.251639  \n",
    545        "std             0.609706           0.646539  \n",
    546        "min             0.540541           0.444444  \n",
    547        "25%             0.909091           0.851190  \n",
    548        "50%             1.057566           1.090909  \n",
    549        "75%             1.333333           1.387424  \n",
    550        "max             5.000000           5.000000  "
    551       ]
    552      },
    553      "execution_count": 688,
    554      "metadata": {},
    555      "output_type": "execute_result"
    556     }
    557    ],
    558    "source": [
    559     "X.describe()"
    560    ]
    561   },
    562   {
    563    "cell_type": "code",
    564    "execution_count": 689,
    565    "metadata": {},
    566    "outputs": [],
    567    "source": [
    568     "from sklearn.preprocessing import StandardScaler\n",
    569     "\n",
    570     "scl = StandardScaler()\n",
    571     "\n",
    572     "X = scl.fit_transform(X)"
    573    ]
    574   },
    575   {
    576    "cell_type": "markdown",
    577    "metadata": {},
    578    "source": [
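    579     "Preprocessing should be complete now. Caveats: the `StandardScaler` above is fit on all rows before the split below, so test-set statistics leak into the scaling (fitting it on the training rows only would avoid this); and the split is not stratified, so with only 17 positive (cheated) rows the validation and test sets may contain very few or no cheaters."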
    580    ]
    581   },
    582   {
    583    "cell_type": "code",
    584    "execution_count": 690,
    585    "metadata": {},
    586    "outputs": [],
    587    "source": [
    588     "from sklearn.model_selection import train_test_split\n",
    589     "\n",
    590     "X_train, X_test, y_train, y_test = train_test_split(X,y,random_state=10)\n",
    591     "X_val, X_test, y_val, y_test = train_test_split(X_test,y_test,random_state=10, train_size=.5)"
    592    ]
    593   },
    594   {
    595    "cell_type": "markdown",
    596    "metadata": {},
    597    "source": [
    598     "Now to train models.\n",
    599     "\n",
    600     "Random Forest:"
    601    ]
    602   },
    603   {
    604    "cell_type": "code",
    605    "execution_count": 691,
    606    "metadata": {},
    607    "outputs": [
    608     {
    609      "data": {
    610       "text/html": [
    611        "<style>#sk-container-id-49 {\n",
    612        "  /* Definition of color scheme common for light and dark mode */\n",
    613        "  --sklearn-color-text: black;\n",
    614        "  --sklearn-color-line: gray;\n",
    615        "  /* Definition of color scheme for unfitted estimators */\n",
    616        "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
    617        "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
    618        "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
    619        "  --sklearn-color-unfitted-level-3: chocolate;\n",
    620        "  /* Definition of color scheme for fitted estimators */\n",
    621        "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
    622        "  --sklearn-color-fitted-level-1: #d4ebff;\n",
    623        "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
    624        "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
    625        "\n",
    626        "  /* Specific color for light theme */\n",
    627        "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
    628        "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
    629        "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
    630        "  --sklearn-color-icon: #696969;\n",
    631        "\n",
    632        "  @media (prefers-color-scheme: dark) {\n",
    633        "    /* Redefinition of color scheme for dark theme */\n",
    634        "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
    635        "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
    636        "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
    637        "    --sklearn-color-icon: #878787;\n",
    638        "  }\n",
    639        "}\n",
    640        "\n",
    641        "#sk-container-id-49 {\n",
    642        "  color: var(--sklearn-color-text);\n",
    643        "}\n",
    644        "\n",
    645        "#sk-container-id-49 pre {\n",
    646        "  padding: 0;\n",
    647        "}\n",
    648        "\n",
    649        "#sk-container-id-49 input.sk-hidden--visually {\n",
    650        "  border: 0;\n",
    651        "  clip: rect(1px 1px 1px 1px);\n",
    652        "  clip: rect(1px, 1px, 1px, 1px);\n",
    653        "  height: 1px;\n",
    654        "  margin: -1px;\n",
    655        "  overflow: hidden;\n",
    656        "  padding: 0;\n",
    657        "  position: absolute;\n",
    658        "  width: 1px;\n",
    659        "}\n",
    660        "\n",
    661        "#sk-container-id-49 div.sk-dashed-wrapped {\n",
    662        "  border: 1px dashed var(--sklearn-color-line);\n",
    663        "  margin: 0 0.4em 0.5em 0.4em;\n",
    664        "  box-sizing: border-box;\n",
    665        "  padding-bottom: 0.4em;\n",
    666        "  background-color: var(--sklearn-color-background);\n",
    667        "}\n",
    668        "\n",
    669        "#sk-container-id-49 div.sk-container {\n",
    670        "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
    671        "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
    672        "     so we also need the `!important` here to be able to override the\n",
    673        "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
    674        "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
    675        "  display: inline-block !important;\n",
    676        "  position: relative;\n",
    677        "}\n",
    678        "\n",
    679        "#sk-container-id-49 div.sk-text-repr-fallback {\n",
    680        "  display: none;\n",
    681        "}\n",
    682        "\n",
    683        "div.sk-parallel-item,\n",
    684        "div.sk-serial,\n",
    685        "div.sk-item {\n",
    686        "  /* draw centered vertical line to link estimators */\n",
    687        "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
    688        "  background-size: 2px 100%;\n",
    689        "  background-repeat: no-repeat;\n",
    690        "  background-position: center center;\n",
    691        "}\n",
    692        "\n",
    693        "/* Parallel-specific style estimator block */\n",
    694        "\n",
    695        "#sk-container-id-49 div.sk-parallel-item::after {\n",
    696        "  content: \"\";\n",
    697        "  width: 100%;\n",
    698        "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
    699        "  flex-grow: 1;\n",
    700        "}\n",
    701        "\n",
    702        "#sk-container-id-49 div.sk-parallel {\n",
    703        "  display: flex;\n",
    704        "  align-items: stretch;\n",
    705        "  justify-content: center;\n",
    706        "  background-color: var(--sklearn-color-background);\n",
    707        "  position: relative;\n",
    708        "}\n",
    709        "\n",
    710        "#sk-container-id-49 div.sk-parallel-item {\n",
    711        "  display: flex;\n",
    712        "  flex-direction: column;\n",
    713        "}\n",
    714        "\n",
    715        "#sk-container-id-49 div.sk-parallel-item:first-child::after {\n",
    716        "  align-self: flex-end;\n",
    717        "  width: 50%;\n",
    718        "}\n",
    719        "\n",
    720        "#sk-container-id-49 div.sk-parallel-item:last-child::after {\n",
    721        "  align-self: flex-start;\n",
    722        "  width: 50%;\n",
    723        "}\n",
    724        "\n",
    725        "#sk-container-id-49 div.sk-parallel-item:only-child::after {\n",
    726        "  width: 0;\n",
    727        "}\n",
    728        "\n",
    729        "/* Serial-specific style estimator block */\n",
    730        "\n",
    731        "#sk-container-id-49 div.sk-serial {\n",
    732        "  display: flex;\n",
    733        "  flex-direction: column;\n",
    734        "  align-items: center;\n",
    735        "  background-color: var(--sklearn-color-background);\n",
    736        "  padding-right: 1em;\n",
    737        "  padding-left: 1em;\n",
    738        "}\n",
    739        "\n",
    740        "\n",
    741        "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
    742        "clickable and can be expanded/collapsed.\n",
    743        "- Pipeline and ColumnTransformer use this feature and define the default style\n",
    744        "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
    745        "*/\n",
    746        "\n",
    747        "/* Pipeline and ColumnTransformer style (default) */\n",
    748        "\n",
    749        "#sk-container-id-49 div.sk-toggleable {\n",
    750        "  /* Default theme specific background. It is overwritten whether we have a\n",
    751        "  specific estimator or a Pipeline/ColumnTransformer */\n",
    752        "  background-color: var(--sklearn-color-background);\n",
    753        "}\n",
    754        "\n",
    755        "/* Toggleable label */\n",
    756        "#sk-container-id-49 label.sk-toggleable__label {\n",
    757        "  cursor: pointer;\n",
    758        "  display: block;\n",
    759        "  width: 100%;\n",
    760        "  margin-bottom: 0;\n",
    761        "  padding: 0.5em;\n",
    762        "  box-sizing: border-box;\n",
    763        "  text-align: center;\n",
    764        "}\n",
    765        "\n",
    766        "#sk-container-id-49 label.sk-toggleable__label-arrow:before {\n",
    767        "  /* Arrow on the left of the label */\n",
    768        "  content: \"▸\";\n",
    769        "  float: left;\n",
    770        "  margin-right: 0.25em;\n",
    771        "  color: var(--sklearn-color-icon);\n",
    772        "}\n",
    773        "\n",
    774        "#sk-container-id-49 label.sk-toggleable__label-arrow:hover:before {\n",
    775        "  color: var(--sklearn-color-text);\n",
    776        "}\n",
    777        "\n",
    778        "/* Toggleable content - dropdown */\n",
    779        "\n",
    780        "#sk-container-id-49 div.sk-toggleable__content {\n",
    781        "  max-height: 0;\n",
    782        "  max-width: 0;\n",
    783        "  overflow: hidden;\n",
    784        "  text-align: left;\n",
    785        "  /* unfitted */\n",
    786        "  background-color: var(--sklearn-color-unfitted-level-0);\n",
    787        "}\n",
    788        "\n",
    789        "#sk-container-id-49 div.sk-toggleable__content.fitted {\n",
    790        "  /* fitted */\n",
    791        "  background-color: var(--sklearn-color-fitted-level-0);\n",
    792        "}\n",
    793        "\n",
    794        "#sk-container-id-49 div.sk-toggleable__content pre {\n",
    795        "  margin: 0.2em;\n",
    796        "  border-radius: 0.25em;\n",
    797        "  color: var(--sklearn-color-text);\n",
    798        "  /* unfitted */\n",
    799        "  background-color: var(--sklearn-color-unfitted-level-0);\n",
    800        "}\n",
    801        "\n",
    802        "#sk-container-id-49 div.sk-toggleable__content.fitted pre {\n",
    803        "  /* unfitted */\n",
    804        "  background-color: var(--sklearn-color-fitted-level-0);\n",
    805        "}\n",
    806        "\n",
    807        "#sk-container-id-49 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
    808        "  /* Expand drop-down */\n",
    809        "  max-height: 200px;\n",
    810        "  max-width: 100%;\n",
    811        "  overflow: auto;\n",
    812        "}\n",
    813        "\n",
    814        "#sk-container-id-49 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
    815        "  content: \"▾\";\n",
    816        "}\n",
    817        "\n",
    818        "/* Pipeline/ColumnTransformer-specific style */\n",
    819        "\n",
    820        "#sk-container-id-49 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
    821        "  color: var(--sklearn-color-text);\n",
    822        "  background-color: var(--sklearn-color-unfitted-level-2);\n",
    823        "}\n",
    824        "\n",
    825        "#sk-container-id-49 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
    826        "  background-color: var(--sklearn-color-fitted-level-2);\n",
    827        "}\n",
    828        "\n",
    829        "/* Estimator-specific style */\n",
    830        "\n",
    831        "/* Colorize estimator box */\n",
    832        "#sk-container-id-49 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
    833        "  /* unfitted */\n",
    834        "  background-color: var(--sklearn-color-unfitted-level-2);\n",
    835        "}\n",
    836        "\n",
    837        "#sk-container-id-49 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
    838        "  /* fitted */\n",
    839        "  background-color: var(--sklearn-color-fitted-level-2);\n",
    840        "}\n",
    841        "\n",
    842        "#sk-container-id-49 div.sk-label label.sk-toggleable__label,\n",
    843        "#sk-container-id-49 div.sk-label label {\n",
    844        "  /* The background is the default theme color */\n",
    845        "  color: var(--sklearn-color-text-on-default-background);\n",
    846        "}\n",
    847        "\n",
    848        "/* On hover, darken the color of the background */\n",
    849        "#sk-container-id-49 div.sk-label:hover label.sk-toggleable__label {\n",
    850        "  color: var(--sklearn-color-text);\n",
    851        "  background-color: var(--sklearn-color-unfitted-level-2);\n",
    852        "}\n",
    853        "\n",
    854        "/* Label box, darken color on hover, fitted */\n",
    855        "#sk-container-id-49 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
    856        "  color: var(--sklearn-color-text);\n",
    857        "  background-color: var(--sklearn-color-fitted-level-2);\n",
    858        "}\n",
    859        "\n",
    860        "/* Estimator label */\n",
    861        "\n",
    862        "#sk-container-id-49 div.sk-label label {\n",
    863        "  font-family: monospace;\n",
    864        "  font-weight: bold;\n",
    865        "  display: inline-block;\n",
    866        "  line-height: 1.2em;\n",
    867        "}\n",
    868        "\n",
    869        "#sk-container-id-49 div.sk-label-container {\n",
    870        "  text-align: center;\n",
    871        "}\n",
    872        "\n",
    873        "/* Estimator-specific */\n",
    874        "#sk-container-id-49 div.sk-estimator {\n",
    875        "  font-family: monospace;\n",
    876        "  border: 1px dotted var(--sklearn-color-border-box);\n",
    877        "  border-radius: 0.25em;\n",
    878        "  box-sizing: border-box;\n",
    879        "  margin-bottom: 0.5em;\n",
    880        "  /* unfitted */\n",
    881        "  background-color: var(--sklearn-color-unfitted-level-0);\n",
    882        "}\n",
    883        "\n",
    884        "#sk-container-id-49 div.sk-estimator.fitted {\n",
    885        "  /* fitted */\n",
    886        "  background-color: var(--sklearn-color-fitted-level-0);\n",
    887        "}\n",
    888        "\n",
    889        "/* on hover */\n",
    890        "#sk-container-id-49 div.sk-estimator:hover {\n",
    891        "  /* unfitted */\n",
    892        "  background-color: var(--sklearn-color-unfitted-level-2);\n",
    893        "}\n",
    894        "\n",
    895        "#sk-container-id-49 div.sk-estimator.fitted:hover {\n",
    896        "  /* fitted */\n",
    897        "  background-color: var(--sklearn-color-fitted-level-2);\n",
    898        "}\n",
    899        "\n",
    900        "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
    901        "\n",
    902        "/* Common style for \"i\" and \"?\" */\n",
    903        "\n",
    904        ".sk-estimator-doc-link,\n",
    905        "a:link.sk-estimator-doc-link,\n",
    906        "a:visited.sk-estimator-doc-link {\n",
    907        "  float: right;\n",
    908        "  font-size: smaller;\n",
    909        "  line-height: 1em;\n",
    910        "  font-family: monospace;\n",
    911        "  background-color: var(--sklearn-color-background);\n",
    912        "  border-radius: 1em;\n",
    913        "  height: 1em;\n",
    914        "  width: 1em;\n",
    915        "  text-decoration: none !important;\n",
    916        "  margin-left: 1ex;\n",
    917        "  /* unfitted */\n",
    918        "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
    919        "  color: var(--sklearn-color-unfitted-level-1);\n",
    920        "}\n",
    921        "\n",
    922        ".sk-estimator-doc-link.fitted,\n",
    923        "a:link.sk-estimator-doc-link.fitted,\n",
    924        "a:visited.sk-estimator-doc-link.fitted {\n",
    925        "  /* fitted */\n",
    926        "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
    927        "  color: var(--sklearn-color-fitted-level-1);\n",
    928        "}\n",
    929        "\n",
    930        "/* On hover */\n",
    931        "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
    932        ".sk-estimator-doc-link:hover,\n",
    933        "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
    934        ".sk-estimator-doc-link:hover {\n",
    935        "  /* unfitted */\n",
    936        "  background-color: var(--sklearn-color-unfitted-level-3);\n",
    937        "  color: var(--sklearn-color-background);\n",
    938        "  text-decoration: none;\n",
    939        "}\n",
    940        "\n",
    941        "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
    942        ".sk-estimator-doc-link.fitted:hover,\n",
    943        "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
    944        ".sk-estimator-doc-link.fitted:hover {\n",
    945        "  /* fitted */\n",
    946        "  background-color: var(--sklearn-color-fitted-level-3);\n",
    947        "  color: var(--sklearn-color-background);\n",
    948        "  text-decoration: none;\n",
    949        "}\n",
    950        "\n",
    951        "/* Span, style for the box shown on hovering the info icon */\n",
    952        ".sk-estimator-doc-link span {\n",
    953        "  display: none;\n",
    954        "  z-index: 9999;\n",
    955        "  position: relative;\n",
    956        "  font-weight: normal;\n",
    957        "  right: .2ex;\n",
    958        "  padding: .5ex;\n",
    959        "  margin: .5ex;\n",
    960        "  width: min-content;\n",
    961        "  min-width: 20ex;\n",
    962        "  max-width: 50ex;\n",
    963        "  color: var(--sklearn-color-text);\n",
    964        "  box-shadow: 2pt 2pt 4pt #999;\n",
    965        "  /* unfitted */\n",
    966        "  background: var(--sklearn-color-unfitted-level-0);\n",
    967        "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
    968        "}\n",
    969        "\n",
    970        ".sk-estimator-doc-link.fitted span {\n",
    971        "  /* fitted */\n",
    972        "  background: var(--sklearn-color-fitted-level-0);\n",
    973        "  border: var(--sklearn-color-fitted-level-3);\n",
    974        "}\n",
    975        "\n",
    976        ".sk-estimator-doc-link:hover span {\n",
    977        "  display: block;\n",
    978        "}\n",
    979        "\n",
    980        "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
    981        "\n",
    982        "#sk-container-id-49 a.estimator_doc_link {\n",
    983        "  float: right;\n",
    984        "  font-size: 1rem;\n",
    985        "  line-height: 1em;\n",
    986        "  font-family: monospace;\n",
    987        "  background-color: var(--sklearn-color-background);\n",
    988        "  border-radius: 1rem;\n",
    989        "  height: 1rem;\n",
    990        "  width: 1rem;\n",
    991        "  text-decoration: none;\n",
    992        "  /* unfitted */\n",
    993        "  color: var(--sklearn-color-unfitted-level-1);\n",
    994        "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
    995        "}\n",
    996        "\n",
    997        "#sk-container-id-49 a.estimator_doc_link.fitted {\n",
    998        "  /* fitted */\n",
    999        "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
   1000        "  color: var(--sklearn-color-fitted-level-1);\n",
   1001        "}\n",
   1002        "\n",
   1003        "/* On hover */\n",
   1004        "#sk-container-id-49 a.estimator_doc_link:hover {\n",
   1005        "  /* unfitted */\n",
   1006        "  background-color: var(--sklearn-color-unfitted-level-3);\n",
   1007        "  color: var(--sklearn-color-background);\n",
   1008        "  text-decoration: none;\n",
   1009        "}\n",
   1010        "\n",
   1011        "#sk-container-id-49 a.estimator_doc_link.fitted:hover {\n",
   1012        "  /* fitted */\n",
   1013        "  background-color: var(--sklearn-color-fitted-level-3);\n",
   1014        "}\n",
   1015        "</style><div id=\"sk-container-id-49\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier(max_depth=2, n_estimators=10)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-49\" type=\"checkbox\" checked><label for=\"sk-estimator-id-49\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;RandomForestClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.ensemble.RandomForestClassifier.html\">?<span>Documentation for RandomForestClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>RandomForestClassifier(max_depth=2, n_estimators=10)</pre></div> </div></div></div></div>"
   1016       ],
   1017       "text/plain": [
   1018        "RandomForestClassifier(max_depth=2, n_estimators=10)"
   1019       ]
   1020      },
   1021      "execution_count": 691,
   1022      "metadata": {},
   1023      "output_type": "execute_result"
   1024     }
   1025    ],
   1026    "source": [
   1027     "from sklearn.ensemble import RandomForestClassifier\n",
   1028     "rnd_clf = RandomForestClassifier(max_depth=2, n_estimators=10)\n",
   1029     "rnd_clf.fit(X_train,y_train)"
   1030    ]
   1031   },
   1032   {
   1033    "cell_type": "code",
   1034    "execution_count": 692,
   1035    "metadata": {},
   1036    "outputs": [
   1037     {
   1038      "data": {
   1039       "text/plain": [
   1040        "1.0"
   1041       ]
   1042      },
   1043      "execution_count": 692,
   1044      "metadata": {},
   1045      "output_type": "execute_result"
   1046     }
   1047    ],
   1048    "source": [
   1049     "from sklearn.metrics import accuracy_score\n",
   1050     "\n",
   1051     "y_pred = rnd_clf.predict(X=X_val)\n",
   1052     "accuracy_score(y_pred=y_pred, y_true=y_val)"
   1053    ]
   1054   },
   1055   {
   1056    "cell_type": "markdown",
   1057    "metadata": {},
   1058    "source": [
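   1059     "Problem: the dataset is far too small for the validation accuracy above to mean much. In the run shown here the validation split contains only non-cheaters (see `y_val` below), so a model that always predicts \"not cheated\" would also score 1.0.\n",
   1060     "Regardless, let's continue."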
   1061    ]
   1062   },
   1063   {
   1064    "cell_type": "code",
   1065    "execution_count": 693,
   1066    "metadata": {},
   1067    "outputs": [],
   1068    "source": [
   1069     "import keras\n",
   1070     "import tensorflow as tf \n",
   1071     "\n",
   1072     "model = keras.Sequential(layers=[\n",
   1073     "    keras.layers.Input(shape=[15,]),\n",
   1074     "    keras.layers.Dense(128, activation='relu'),\n",
   1075     "    keras.layers.Dense(128, activation='relu'),\n",
   1076     "    keras.layers.Dense(1, activation='sigmoid')\n",
   1077     "])\n",
   1078     "\n",
   1079     "model.compile(loss=keras.losses.binary_crossentropy, optimizer='adam', metrics=['accuracy'])"
   1080    ]
   1081   },
   1082   {
   1083    "cell_type": "code",
   1084    "execution_count": 694,
   1085    "metadata": {},
   1086    "outputs": [
   1087     {
   1088      "name": "stdout",
   1089      "output_type": "stream",
   1090      "text": [
   1091       "Epoch 1/10\n"
   1092      ]
   1093     },
   1094     {
   1095      "name": "stdout",
   1096      "output_type": "stream",
   1097      "text": [
   1098       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 39ms/step - accuracy: 0.7644 - loss: 0.5846 - val_accuracy: 1.0000 - val_loss: 0.3691\n",
   1099       "Epoch 2/10\n",
   1100       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8537 - loss: 0.4351 - val_accuracy: 1.0000 - val_loss: 0.2381\n",
   1101       "Epoch 3/10\n",
   1102       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8464 - loss: 0.4046 - val_accuracy: 1.0000 - val_loss: 0.1746\n",
   1103       "Epoch 4/10\n",
   1104       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8422 - loss: 0.3757 - val_accuracy: 1.0000 - val_loss: 0.1501\n",
   1105       "Epoch 5/10\n",
   1106       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8527 - loss: 0.3416 - val_accuracy: 1.0000 - val_loss: 0.1443\n",
   1107       "Epoch 6/10\n",
   1108       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8589 - loss: 0.3011 - val_accuracy: 1.0000 - val_loss: 0.1481\n",
   1109       "Epoch 7/10\n",
   1110       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8433 - loss: 0.2969 - val_accuracy: 1.0000 - val_loss: 0.1561\n",
   1111       "Epoch 8/10\n",
   1112       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.9003 - loss: 0.2562 - val_accuracy: 1.0000 - val_loss: 0.1634\n",
   1113       "Epoch 9/10\n",
   1114       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8962 - loss: 0.2561 - val_accuracy: 1.0000 - val_loss: 0.1694\n",
   1115       "Epoch 10/10\n",
   1116       "\u001b[1m4/4\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.8853 - loss: 0.2665 - val_accuracy: 0.9444 - val_loss: 0.1725\n"
   1117      ]
   1118     },
   1119     {
   1120      "data": {
   1121       "text/plain": [
   1122        "<keras.src.callbacks.history.History at 0x7fdddb016fd0>"
   1123       ]
   1124      },
   1125      "execution_count": 694,
   1126      "metadata": {},
   1127      "output_type": "execute_result"
   1128     }
   1129    ],
   1130    "source": [
   1131     "model.fit(X_train,y_train, validation_data=[X_val,y_val], epochs=10)"
   1132    ]
   1133   },
   1134   {
   1135    "cell_type": "code",
   1136    "execution_count": 695,
   1137    "metadata": {},
   1138    "outputs": [
   1139     {
   1140      "data": {
   1141       "text/plain": [
   1142        "141    False\n",
   1143        "122    False\n",
   1144        "138    False\n",
   1145        "84     False\n",
   1146        "48     False\n",
   1147        "80     False\n",
   1148        "136    False\n",
   1149        "110    False\n",
   1150        "137    False\n",
   1151        "71     False\n",
   1152        "67     False\n",
   1153        "20     False\n",
   1154        "56     False\n",
   1155        "117    False\n",
   1156        "106    False\n",
   1157        "100    False\n",
   1158        "103    False\n",
   1159        "59     False\n",
   1160        "Name: cheated, dtype: bool"
   1161       ]
   1162      },
   1163      "execution_count": 695,
   1164      "metadata": {},
   1165      "output_type": "execute_result"
   1166     }
   1167    ],
   1168    "source": [
   1169     "y_val"
   1170    ]
   1171   },
   1172   {
   1173    "cell_type": "markdown",
   1174    "metadata": {},
   1175    "source": [
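   1176     "Compare the neural network (NN) and the random forest on the held-out test set. The test set is small, so a single misclassification shifts accuracy noticeably; treat the difference as indicative only."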
   1177    ]
   1178   },
   1179   {
   1180    "cell_type": "code",
   1181    "execution_count": 696,
   1182    "metadata": {},
   1183    "outputs": [
   1184     {
   1185      "name": "stdout",
   1186      "output_type": "stream",
   1187      "text": [
   1188       "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 34ms/step\n"
   1189      ]
   1190     },
   1191     {
   1192      "data": {
   1193       "text/plain": [
   1194        "0.9444444444444444"
   1195       ]
   1196      },
   1197      "execution_count": 696,
   1198      "metadata": {},
   1199      "output_type": "execute_result"
   1200     }
   1201    ],
   1202    "source": [
   1203     "import numpy as np\n",
   1204     "y_pred = model.predict(X_test)\n",
   1205     "binary_predictions = (y_pred >= 0.5).astype(np.bool_)\n",
   1206     "\n",
   1207     "accuracy_score(y_pred=binary_predictions, y_true=y_test)"
   1208    ]
   1209   },
   1210   {
   1211    "cell_type": "code",
   1212    "execution_count": 697,
   1213    "metadata": {},
   1214    "outputs": [
   1215     {
   1216      "data": {
   1217       "text/plain": [
   1218        "0.8888888888888888"
   1219       ]
   1220      },
   1221      "execution_count": 697,
   1222      "metadata": {},
   1223      "output_type": "execute_result"
   1224     }
   1225    ],
   1226    "source": [
   1227     "y_pred = rnd_clf.predict(X=X_test)\n",
   1228     "accuracy_score(y_pred=y_pred, y_true=y_test)"
   1229    ]
   1230   }
   1231  ],
   1232  "metadata": {
   1233   "kernelspec": {
   1234    "display_name": ".venv",
   1235    "language": "python",
   1236    "name": "python3"
   1237   },
   1238   "language_info": {
   1239    "codemirror_mode": {
   1240     "name": "ipython",
   1241     "version": 3
   1242    },
   1243    "file_extension": ".py",
   1244    "mimetype": "text/x-python",
   1245    "name": "python",
   1246    "nbconvert_exporter": "python",
   1247    "pygments_lexer": "ipython3",
   1248    "version": "3.11.2"
   1249   }
   1250  },
   1251  "nbformat": 4,
   1252  "nbformat_minor": 2
   1253 }