commit 0fdeb834773677c2c28705ff33ae8fcc7c5340da
parent 799b34cbf762b55879ee1c9066b109524efdcbd8
Author: Andrew <andrewlaack1@gmail.com>
Date: Tue, 25 Jun 2024 16:39:25 -0500
Naive Bayes Custom Spam
Diffstat:
1 file changed, 744 insertions(+), 0 deletions(-)
diff --git a/spamFilter/NaiveBayesSpamFilterScratch.ipynb b/spamFilter/NaiveBayesSpamFilterScratch.ipynb
@@ -0,0 +1,744 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This dataset is from kaggle:\n",
+ "\n",
+ "https://www.kaggle.com/datasets/abdallahwagih/spam-emails\n",
+ "Download this, move it to the correct location and then start working. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The implementation from sklearn uses a more sophisticated tokenization strategy which makes it better than mine. As such, while I got 80% accuracy with Naive Bayes, sklearn got 97%"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 474,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "emails = pd.read_csv('../datasets/spamEmails/emails.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 475,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X = emails['Message']\n",
+ "y = emails['Category']\n",
+ "\n",
+ "y = y == 'spam'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 476,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "\n",
+ "\n",
+ "def removePunctuation(text):\n",
+ " return re.sub(r'[^a-zA-Z\\s]', '', text)\n",
+ "\n",
+ "X = X.apply(removePunctuation)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 477,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 478,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def uniqueWords(text : str):\n",
+ " text = text.lower()\n",
+ " textLs = text.split(' ')\n",
+ " # Init set with ''\n",
+ " opts = {''}\n",
+ " for i in textLs:\n",
+ " opts.add(i)\n",
+ " # Remove ''\n",
+ " opts.remove('')\n",
+ " return opts\n",
+ "\n",
+ "\n",
+ "wcHam = {}\n",
+ "wcSpam = {}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 479,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "i = 0\n",
+ "while i < len(X_train):\n",
+ " if(y_train.iloc[i]) == 1:\n",
+ " words = uniqueWords(X_train.iloc[i])\n",
+ " for w in words:\n",
+ " count = 1\n",
+ " if w in wcSpam:\n",
+ " count = wcSpam[w] + 1\n",
+ " wcSpam[w] = count\n",
+ "\n",
+ " else:\n",
+ " words = uniqueWords(X_train.iloc[i])\n",
+ " for w in words:\n",
+ " count = 1\n",
+ " if w in wcHam:\n",
+ " count = wcHam[w] + 1\n",
+ " wcHam[w] = count\n",
+ " i += 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 480,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6238\n",
+ "1966\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(len(wcHam))\n",
+ "print(len(wcSpam))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 481,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7322\n",
+ "7322\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Add one to each value to ensure there are no 0 probabilities (pseudo count)\n",
+ "# With a zero probability it would mess up naive bayes calculations\n",
+ "\n",
+ "keys_combined = set(wcHam.keys()).union(wcSpam.keys())\n",
+ "\n",
+ "for i in keys_combined:\n",
+ " if i in wcHam:\n",
+ " wcHam[i] = wcHam[i] + 1\n",
+ " else:\n",
+ " wcHam[i] = 1\n",
+ " \n",
+ " if i in wcSpam:\n",
+ " wcSpam[i] = wcSpam[i] + 1\n",
+ " else:\n",
+ " wcSpam[i] = 1\n",
+ "\n",
+ "print(len(wcSpam))\n",
+ "print(len(wcHam))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 482,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_5255/3310267757.py:2: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n",
+ " hamCount = y_train.value_counts()[0]\n",
+ "/tmp/ipykernel_5255/3310267757.py:3: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n",
+ " spamCount = y_train.value_counts()[1]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Assign probabilities of each word occuring in any given spam/ham message\n",
+ "hamCount = y_train.value_counts()[0]\n",
+ "spamCount = y_train.value_counts()[1]\n",
+ "\n",
+ "for key in wcHam:\n",
+ " wcHam[key] = wcHam[key] / hamCount\n",
+ "\n",
+ "for key in wcSpam:\n",
+ " wcSpam[key] = wcSpam[key] / spamCount"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 483,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Dumbass forgot to multiply by probability of each class at the start\n",
+ "\n",
+ "def predict(message):\n",
+ " words = message.split(' ')\n",
+ " spamPercent = spamCount / (spamCount + hamCount)\n",
+ " hamPercent = hamCount / (spamCount + hamCount)\n",
+ "\n",
+ " for word in words:\n",
+ " if word in wcSpam:\n",
+ " spamPercent = spamPercent * wcSpam[word]\n",
+ " hamPercent = hamPercent * wcHam[word]\n",
+ " \n",
+ " return spamPercent > hamPercent"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 484,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "predictions = []\n",
+ "\n",
+ "for i in X_test:\n",
+ " predictions.append(predict(i))\n",
+ "\n",
+ "y_test = y_test.to_list()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 485,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total Percent Correct: 0.7982770997846375\n"
+ ]
+ }
+ ],
+ "source": [
+ "correct = 0\n",
+ "while count < len(predictions):\n",
+ " if y_test[count] == predictions[count]:\n",
+ " correct += 1\n",
+ " \n",
+ " count += 1\n",
+ "\n",
+ "print('Total Percent Correct: ', correct/count)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 486,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<style>#sk-container-id-15 {\n",
+ " /* Definition of color scheme common for light and dark mode */\n",
+ " --sklearn-color-text: black;\n",
+ " --sklearn-color-line: gray;\n",
+ " /* Definition of color scheme for unfitted estimators */\n",
+ " --sklearn-color-unfitted-level-0: #fff5e6;\n",
+ " --sklearn-color-unfitted-level-1: #f6e4d2;\n",
+ " --sklearn-color-unfitted-level-2: #ffe0b3;\n",
+ " --sklearn-color-unfitted-level-3: chocolate;\n",
+ " /* Definition of color scheme for fitted estimators */\n",
+ " --sklearn-color-fitted-level-0: #f0f8ff;\n",
+ " --sklearn-color-fitted-level-1: #d4ebff;\n",
+ " --sklearn-color-fitted-level-2: #b3dbfd;\n",
+ " --sklearn-color-fitted-level-3: cornflowerblue;\n",
+ "\n",
+ " /* Specific color for light theme */\n",
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
+ " --sklearn-color-icon: #696969;\n",
+ "\n",
+ " @media (prefers-color-scheme: dark) {\n",
+ " /* Redefinition of color scheme for dark theme */\n",
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
+ " --sklearn-color-icon: #878787;\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 {\n",
+ " color: var(--sklearn-color-text);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 pre {\n",
+ " padding: 0;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 input.sk-hidden--visually {\n",
+ " border: 0;\n",
+ " clip: rect(1px 1px 1px 1px);\n",
+ " clip: rect(1px, 1px, 1px, 1px);\n",
+ " height: 1px;\n",
+ " margin: -1px;\n",
+ " overflow: hidden;\n",
+ " padding: 0;\n",
+ " position: absolute;\n",
+ " width: 1px;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-dashed-wrapped {\n",
+ " border: 1px dashed var(--sklearn-color-line);\n",
+ " margin: 0 0.4em 0.5em 0.4em;\n",
+ " box-sizing: border-box;\n",
+ " padding-bottom: 0.4em;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-container {\n",
+ " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
+ " but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
+ " so we also need the `!important` here to be able to override the\n",
+ " default hidden behavior on the sphinx rendered scikit-learn.org.\n",
+ " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
+ " display: inline-block !important;\n",
+ " position: relative;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-text-repr-fallback {\n",
+ " display: none;\n",
+ "}\n",
+ "\n",
+ "div.sk-parallel-item,\n",
+ "div.sk-serial,\n",
+ "div.sk-item {\n",
+ " /* draw centered vertical line to link estimators */\n",
+ " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
+ " background-size: 2px 100%;\n",
+ " background-repeat: no-repeat;\n",
+ " background-position: center center;\n",
+ "}\n",
+ "\n",
+ "/* Parallel-specific style estimator block */\n",
+ "\n",
+ "#sk-container-id-15 div.sk-parallel-item::after {\n",
+ " content: \"\";\n",
+ " width: 100%;\n",
+ " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
+ " flex-grow: 1;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-parallel {\n",
+ " display: flex;\n",
+ " align-items: stretch;\n",
+ " justify-content: center;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " position: relative;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-parallel-item {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-parallel-item:first-child::after {\n",
+ " align-self: flex-end;\n",
+ " width: 50%;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-parallel-item:last-child::after {\n",
+ " align-self: flex-start;\n",
+ " width: 50%;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-parallel-item:only-child::after {\n",
+ " width: 0;\n",
+ "}\n",
+ "\n",
+ "/* Serial-specific style estimator block */\n",
+ "\n",
+ "#sk-container-id-15 div.sk-serial {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ " align-items: center;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " padding-right: 1em;\n",
+ " padding-left: 1em;\n",
+ "}\n",
+ "\n",
+ "\n",
+ "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
+ "clickable and can be expanded/collapsed.\n",
+ "- Pipeline and ColumnTransformer use this feature and define the default style\n",
+ "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
+ "*/\n",
+ "\n",
+ "/* Pipeline and ColumnTransformer style (default) */\n",
+ "\n",
+ "#sk-container-id-15 div.sk-toggleable {\n",
+ " /* Default theme specific background. It is overwritten whether we have a\n",
+ " specific estimator or a Pipeline/ColumnTransformer */\n",
+ " background-color: var(--sklearn-color-background);\n",
+ "}\n",
+ "\n",
+ "/* Toggleable label */\n",
+ "#sk-container-id-15 label.sk-toggleable__label {\n",
+ " cursor: pointer;\n",
+ " display: block;\n",
+ " width: 100%;\n",
+ " margin-bottom: 0;\n",
+ " padding: 0.5em;\n",
+ " box-sizing: border-box;\n",
+ " text-align: center;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 label.sk-toggleable__label-arrow:before {\n",
+ " /* Arrow on the left of the label */\n",
+ " content: \"▸\";\n",
+ " float: left;\n",
+ " margin-right: 0.25em;\n",
+ " color: var(--sklearn-color-icon);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 label.sk-toggleable__label-arrow:hover:before {\n",
+ " color: var(--sklearn-color-text);\n",
+ "}\n",
+ "\n",
+ "/* Toggleable content - dropdown */\n",
+ "\n",
+ "#sk-container-id-15 div.sk-toggleable__content {\n",
+ " max-height: 0;\n",
+ " max-width: 0;\n",
+ " overflow: hidden;\n",
+ " text-align: left;\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-toggleable__content.fitted {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-toggleable__content pre {\n",
+ " margin: 0.2em;\n",
+ " border-radius: 0.25em;\n",
+ " color: var(--sklearn-color-text);\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-toggleable__content.fitted pre {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
+ " /* Expand drop-down */\n",
+ " max-height: 200px;\n",
+ " max-width: 100%;\n",
+ " overflow: auto;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
+ " content: \"▾\";\n",
+ "}\n",
+ "\n",
+ "/* Pipeline/ColumnTransformer-specific style */\n",
+ "\n",
+ "#sk-container-id-15 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Estimator-specific style */\n",
+ "\n",
+ "/* Colorize estimator box */\n",
+ "#sk-container-id-15 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-label label.sk-toggleable__label,\n",
+ "#sk-container-id-15 div.sk-label label {\n",
+ " /* The background is the default theme color */\n",
+ " color: var(--sklearn-color-text-on-default-background);\n",
+ "}\n",
+ "\n",
+ "/* On hover, darken the color of the background */\n",
+ "#sk-container-id-15 div.sk-label:hover label.sk-toggleable__label {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Label box, darken color on hover, fitted */\n",
+ "#sk-container-id-15 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Estimator label */\n",
+ "\n",
+ "#sk-container-id-15 div.sk-label label {\n",
+ " font-family: monospace;\n",
+ " font-weight: bold;\n",
+ " display: inline-block;\n",
+ " line-height: 1.2em;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-label-container {\n",
+ " text-align: center;\n",
+ "}\n",
+ "\n",
+ "/* Estimator-specific */\n",
+ "#sk-container-id-15 div.sk-estimator {\n",
+ " font-family: monospace;\n",
+ " border: 1px dotted var(--sklearn-color-border-box);\n",
+ " border-radius: 0.25em;\n",
+ " box-sizing: border-box;\n",
+ " margin-bottom: 0.5em;\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-estimator.fitted {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "/* on hover */\n",
+ "#sk-container-id-15 div.sk-estimator:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 div.sk-estimator.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
+ "\n",
+ "/* Common style for \"i\" and \"?\" */\n",
+ "\n",
+ ".sk-estimator-doc-link,\n",
+ "a:link.sk-estimator-doc-link,\n",
+ "a:visited.sk-estimator-doc-link {\n",
+ " float: right;\n",
+ " font-size: smaller;\n",
+ " line-height: 1em;\n",
+ " font-family: monospace;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " border-radius: 1em;\n",
+ " height: 1em;\n",
+ " width: 1em;\n",
+ " text-decoration: none !important;\n",
+ " margin-left: 1ex;\n",
+ " /* unfitted */\n",
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-unfitted-level-1);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link.fitted,\n",
+ "a:link.sk-estimator-doc-link.fitted,\n",
+ "a:visited.sk-estimator-doc-link.fitted {\n",
+ " /* fitted */\n",
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-fitted-level-1);\n",
+ "}\n",
+ "\n",
+ "/* On hover */\n",
+ "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
+ ".sk-estimator-doc-link:hover,\n",
+ "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
+ ".sk-estimator-doc-link:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
+ ".sk-estimator-doc-link.fitted:hover,\n",
+ "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
+ ".sk-estimator-doc-link.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "/* Span, style for the box shown on hovering the info icon */\n",
+ ".sk-estimator-doc-link span {\n",
+ " display: none;\n",
+ " z-index: 9999;\n",
+ " position: relative;\n",
+ " font-weight: normal;\n",
+ " right: .2ex;\n",
+ " padding: .5ex;\n",
+ " margin: .5ex;\n",
+ " width: min-content;\n",
+ " min-width: 20ex;\n",
+ " max-width: 50ex;\n",
+ " color: var(--sklearn-color-text);\n",
+ " box-shadow: 2pt 2pt 4pt #999;\n",
+ " /* unfitted */\n",
+ " background: var(--sklearn-color-unfitted-level-0);\n",
+ " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link.fitted span {\n",
+ " /* fitted */\n",
+ " background: var(--sklearn-color-fitted-level-0);\n",
+ " border: var(--sklearn-color-fitted-level-3);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link:hover span {\n",
+ " display: block;\n",
+ "}\n",
+ "\n",
+ "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
+ "\n",
+ "#sk-container-id-15 a.estimator_doc_link {\n",
+ " float: right;\n",
+ " font-size: 1rem;\n",
+ " line-height: 1em;\n",
+ " font-family: monospace;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " border-radius: 1rem;\n",
+ " height: 1rem;\n",
+ " width: 1rem;\n",
+ " text-decoration: none;\n",
+ " /* unfitted */\n",
+ " color: var(--sklearn-color-unfitted-level-1);\n",
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 a.estimator_doc_link.fitted {\n",
+ " /* fitted */\n",
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-fitted-level-1);\n",
+ "}\n",
+ "\n",
+ "/* On hover */\n",
+ "#sk-container-id-15 a.estimator_doc_link:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-15 a.estimator_doc_link.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
+ "}\n",
+ "</style><div id=\"sk-container-id-15\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-43\" type=\"checkbox\" ><label for=\"sk-estimator-id-43\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> Pipeline<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.pipeline.Pipeline.html\">?<span>Documentation for Pipeline</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-44\" type=\"checkbox\" ><label for=\"sk-estimator-id-44\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> CountVectorizer<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html\">?<span>Documentation for CountVectorizer</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>CountVectorizer()</pre></div> </div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-45\" type=\"checkbox\" ><label for=\"sk-estimator-id-45\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> MultinomialNB<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.naive_bayes.MultinomialNB.html\">?<span>Documentation for MultinomialNB</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>MultinomialNB()</pre></div> </div></div></div></div></div></div>"
+ ],
+ "text/plain": [
+ "Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])"
+ ]
+ },
+ "execution_count": 486,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from sklearn.naive_bayes import MultinomialNB\n",
+ "from sklearn.pipeline import Pipeline\n",
+ "from sklearn.feature_extraction.text import CountVectorizer\n",
+ "\n",
+ "pipeline = Pipeline([\n",
+ " ('vect', CountVectorizer()), # Use CountVectorizer to convert text into token counts\n",
+ " ('clf', MultinomialNB()), # Naive Bayes classifier\n",
+ "])\n",
+ "\n",
+ "# Fit the model on the training data\n",
+ "pipeline.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 487,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total Percent Correct: 0.9777458722182341\n"
+ ]
+ }
+ ],
+ "source": [
+ "predictions = pipeline.predict(X_test)\n",
+ "\n",
+ "correct = 0\n",
+ "count = 0\n",
+ "while count < len(predictions):\n",
+ " if y_test[count] == predictions[count]:\n",
+ " correct += 1\n",
+ "\n",
+ " count += 1\n",
+ "\n",
+ "print('Total Percent Correct: ', correct/count)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}