machinelearning

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs

commit 0fdeb834773677c2c28705ff33ae8fcc7c5340da
parent 799b34cbf762b55879ee1c9066b109524efdcbd8
Author: Andrew <andrewlaack1@gmail.com>
Date:   Tue, 25 Jun 2024 16:39:25 -0500

Naive Bayes Custom Spam

Diffstat:
AspamFilter/NaiveBayesSpamFilterScratch.ipynb | 744+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 744 insertions(+), 0 deletions(-)

diff --git a/spamFilter/NaiveBayesSpamFilterScratch.ipynb b/spamFilter/NaiveBayesSpamFilterScratch.ipynb @@ -0,0 +1,744 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This dataset is from kaggle:\n", + "\n", + "https://www.kaggle.com/datasets/abdallahwagih/spam-emails\n", + "Download this, move it to the correct location and then start working. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The implementation from sklearn uses a more sophisticated tokenization strategy which makes it better than mine. As such, while I got 80% accuracy with Naive Bayes, sklearn got 97%" + ] + }, + { + "cell_type": "code", + "execution_count": 474, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "emails = pd.read_csv('../datasets/spamEmails/emails.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 475, + "metadata": {}, + "outputs": [], + "source": [ + "X = emails['Message']\n", + "y = emails['Category']\n", + "\n", + "y = y == 'spam'" + ] + }, + { + "cell_type": "code", + "execution_count": 476, + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "\n", + "\n", + "def removePunctuation(text):\n", + " return re.sub(r'[^a-zA-Z\\s]', '', text)\n", + "\n", + "X = X.apply(removePunctuation)" + ] + }, + { + "cell_type": "code", + "execution_count": 477, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 478, + "metadata": {}, + "outputs": [], + "source": [ + "def uniqueWords(text : str):\n", + " text = text.lower()\n", + " textLs = text.split(' ')\n", + " # Init set with ''\n", + " opts = {''}\n", + " for i in textLs:\n", + " opts.add(i)\n", + " # Remove ''\n", + " opts.remove('')\n", + " return opts\n", + "\n", + "\n", + "wcHam = {}\n", + "wcSpam = {}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 479, + "metadata": {}, + "outputs": [], + "source": [ + "i = 0\n", + "while i < len(X_train):\n", + " if(y_train.iloc[i]) == 1:\n", + " words = uniqueWords(X_train.iloc[i])\n", + " for w in words:\n", + " count = 1\n", + " if w in wcSpam:\n", + " count = wcSpam[w] + 1\n", + " wcSpam[w] = count\n", + "\n", + " else:\n", + " words = uniqueWords(X_train.iloc[i])\n", + " for w in words:\n", + " count = 1\n", + " if w in wcHam:\n", + " count = wcHam[w] + 1\n", + " wcHam[w] = count\n", + " i += 1" + ] + }, + { + "cell_type": "code", + "execution_count": 480, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6238\n", + "1966\n" + ] + } + ], + "source": [ + "print(len(wcHam))\n", + "print(len(wcSpam))" + ] + }, + { + "cell_type": "code", + "execution_count": 481, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7322\n", + "7322\n" + ] + } + ], + "source": [ + "# Add one to each value to ensure there are no 0 probabilities (pseudo count)\n", + "# With a zero probability it would mess up naive bayes calculations\n", + "\n", + "keys_combined = set(wcHam.keys()).union(wcSpam.keys())\n", + "\n", + "for i in keys_combined:\n", + " if i in wcHam:\n", + " wcHam[i] = wcHam[i] + 1\n", + " else:\n", + " wcHam[i] = 1\n", + " \n", + " if i in wcSpam:\n", + " wcSpam[i] = wcSpam[i] + 1\n", + " else:\n", + " wcSpam[i] = 1\n", + "\n", + "print(len(wcSpam))\n", + "print(len(wcHam))" + ] + }, + { + "cell_type": "code", + "execution_count": 482, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_5255/3310267757.py:2: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n", + " hamCount = y_train.value_counts()[0]\n", + "/tmp/ipykernel_5255/3310267757.py:3: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n", + " spamCount = y_train.value_counts()[1]\n" + ] + } + ], + "source": [ + "# Assign probabilities of each word occuring in any given spam/ham message\n", + "hamCount = y_train.value_counts()[0]\n", + "spamCount = y_train.value_counts()[1]\n", + "\n", + "for key in wcHam:\n", + " wcHam[key] = wcHam[key] / hamCount\n", + "\n", + "for key in wcSpam:\n", + " wcSpam[key] = wcSpam[key] / spamCount" + ] + }, + { + "cell_type": "code", + "execution_count": 483, + "metadata": {}, + "outputs": [], + "source": [ + "# Dumbass forgot to multiply by probability of each class at the start\n", + "\n", + "def predict(message):\n", + " words = message.split(' ')\n", + " spamPercent = spamCount / (spamCount + hamCount)\n", + " hamPercent = hamCount / (spamCount + hamCount)\n", + "\n", + " for word in words:\n", + " if word in wcSpam:\n", + " spamPercent = spamPercent * wcSpam[word]\n", + " hamPercent = hamPercent * wcHam[word]\n", + " \n", + " return spamPercent > hamPercent" + ] + }, + { + "cell_type": "code", + "execution_count": 484, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "predictions = []\n", + "\n", + "for i in X_test:\n", + " predictions.append(predict(i))\n", + "\n", + "y_test = y_test.to_list()" + ] + }, + { + "cell_type": "code", + "execution_count": 485, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total Percent Correct: 0.7982770997846375\n" + ] + } + ], + "source": [ + "correct = 0\n", + "while count < len(predictions):\n", + " if y_test[count] == predictions[count]:\n", + " correct += 1\n", + " \n", + " count += 1\n", + "\n", + "print('Total Percent Correct: ', correct/count)" + ] + }, + { + "cell_type": "code", + "execution_count": 486, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<style>#sk-container-id-15 {\n", + " /* Definition of color scheme common for light and dark mode */\n", + " --sklearn-color-text: black;\n", + " --sklearn-color-line: gray;\n", + " /* Definition of color scheme for unfitted estimators */\n", + " --sklearn-color-unfitted-level-0: #fff5e6;\n", + " --sklearn-color-unfitted-level-1: #f6e4d2;\n", + " --sklearn-color-unfitted-level-2: #ffe0b3;\n", + " --sklearn-color-unfitted-level-3: chocolate;\n", + " /* Definition of color scheme for fitted estimators */\n", + " --sklearn-color-fitted-level-0: #f0f8ff;\n", + " --sklearn-color-fitted-level-1: #d4ebff;\n", + " --sklearn-color-fitted-level-2: #b3dbfd;\n", + " --sklearn-color-fitted-level-3: cornflowerblue;\n", + "\n", + " /* Specific color for light theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", + " --sklearn-color-icon: #696969;\n", + "\n", + " @media (prefers-color-scheme: dark) {\n", + " /* Redefinition of color scheme for dark theme */\n", + " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", + " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", + " --sklearn-color-icon: #878787;\n", + " }\n", + "}\n", + "\n", + "#sk-container-id-15 {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "#sk-container-id-15 pre {\n", + " padding: 0;\n", + "}\n", + "\n", + "#sk-container-id-15 input.sk-hidden--visually {\n", + " border: 0;\n", + " clip: rect(1px 1px 1px 1px);\n", + " clip: rect(1px, 1px, 1px, 1px);\n", + " height: 1px;\n", + " margin: -1px;\n", + " overflow: hidden;\n", + " padding: 0;\n", + " position: absolute;\n", + " width: 1px;\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-dashed-wrapped {\n", + " border: 1px dashed var(--sklearn-color-line);\n", + " margin: 0 0.4em 0.5em 0.4em;\n", + " box-sizing: border-box;\n", + " padding-bottom: 0.4em;\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-container {\n", + " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", + " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", + " so we also need the `!important` here to be able to override the\n", + " default hidden behavior on the sphinx rendered scikit-learn.org.\n", + " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", + " display: inline-block !important;\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-text-repr-fallback {\n", + " display: none;\n", + "}\n", + "\n", + "div.sk-parallel-item,\n", + "div.sk-serial,\n", + "div.sk-item {\n", + " /* draw centered vertical line to link estimators */\n", + " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", + " background-size: 2px 100%;\n", + " background-repeat: no-repeat;\n", + " background-position: center center;\n", + "}\n", + "\n", + "/* Parallel-specific style estimator block */\n", + "\n", + "#sk-container-id-15 div.sk-parallel-item::after {\n", + " content: \"\";\n", + " width: 100%;\n", + " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", + " flex-grow: 1;\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-parallel {\n", + " display: flex;\n", + " align-items: stretch;\n", + " justify-content: center;\n", + " background-color: var(--sklearn-color-background);\n", + " position: relative;\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-parallel-item {\n", + " display: flex;\n", + " flex-direction: column;\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-parallel-item:first-child::after {\n", + " align-self: flex-end;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-parallel-item:last-child::after {\n", + " align-self: flex-start;\n", + " width: 50%;\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-parallel-item:only-child::after {\n", + " width: 0;\n", + "}\n", + "\n", + "/* Serial-specific style estimator block */\n", + "\n", + "#sk-container-id-15 div.sk-serial {\n", + " display: flex;\n", + " flex-direction: column;\n", + " align-items: center;\n", + " background-color: var(--sklearn-color-background);\n", + " padding-right: 1em;\n", + " padding-left: 1em;\n", + "}\n", + "\n", + "\n", + "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", + "clickable and can be expanded/collapsed.\n", + "- Pipeline and ColumnTransformer use this feature and define the default style\n", + "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", + "*/\n", + "\n", + "/* Pipeline and ColumnTransformer style (default) */\n", + "\n", + "#sk-container-id-15 div.sk-toggleable {\n", + " /* Default theme specific background. It is overwritten whether we have a\n", + " specific estimator or a Pipeline/ColumnTransformer */\n", + " background-color: var(--sklearn-color-background);\n", + "}\n", + "\n", + "/* Toggleable label */\n", + "#sk-container-id-15 label.sk-toggleable__label {\n", + " cursor: pointer;\n", + " display: block;\n", + " width: 100%;\n", + " margin-bottom: 0;\n", + " padding: 0.5em;\n", + " box-sizing: border-box;\n", + " text-align: center;\n", + "}\n", + "\n", + "#sk-container-id-15 label.sk-toggleable__label-arrow:before {\n", + " /* Arrow on the left of the label */\n", + " content: \"▸\";\n", + " float: left;\n", + " margin-right: 0.25em;\n", + " color: var(--sklearn-color-icon);\n", + "}\n", + "\n", + "#sk-container-id-15 label.sk-toggleable__label-arrow:hover:before {\n", + " color: var(--sklearn-color-text);\n", + "}\n", + "\n", + "/* Toggleable content - dropdown */\n", + "\n", + "#sk-container-id-15 div.sk-toggleable__content {\n", + " max-height: 0;\n", + " max-width: 0;\n", + " overflow: hidden;\n", + " text-align: left;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-toggleable__content.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-toggleable__content pre {\n", + " margin: 0.2em;\n", + " border-radius: 0.25em;\n", + " color: var(--sklearn-color-text);\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-toggleable__content.fitted pre {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-15 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", + " /* Expand drop-down */\n", + " max-height: 200px;\n", + " max-width: 100%;\n", + " overflow: auto;\n", + "}\n", + "\n", + "#sk-container-id-15 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", + " content: \"▾\";\n", + "}\n", + "\n", + "/* Pipeline/ColumnTransformer-specific style */\n", + "\n", + "#sk-container-id-15 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator-specific style */\n", + "\n", + "/* Colorize estimator box */\n", + "#sk-container-id-15 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-label label.sk-toggleable__label,\n", + "#sk-container-id-15 div.sk-label label {\n", + " /* The background is the default theme color */\n", + " color: var(--sklearn-color-text-on-default-background);\n", + "}\n", + "\n", + "/* On hover, darken the color of the background */\n", + "#sk-container-id-15 div.sk-label:hover label.sk-toggleable__label {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "/* Label box, darken color on hover, fitted */\n", + "#sk-container-id-15 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", + " color: var(--sklearn-color-text);\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Estimator label */\n", + "\n", + "#sk-container-id-15 div.sk-label label {\n", + " font-family: monospace;\n", + " font-weight: bold;\n", + " display: inline-block;\n", + " line-height: 1.2em;\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-label-container {\n", + " text-align: center;\n", + "}\n", + "\n", + "/* Estimator-specific */\n", + "#sk-container-id-15 div.sk-estimator {\n", + " font-family: monospace;\n", + " border: 1px dotted var(--sklearn-color-border-box);\n", + " border-radius: 0.25em;\n", + " box-sizing: border-box;\n", + " margin-bottom: 0.5em;\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-0);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-estimator.fitted {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-0);\n", + "}\n", + "\n", + "/* on hover */\n", + "#sk-container-id-15 div.sk-estimator:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-2);\n", + "}\n", + "\n", + "#sk-container-id-15 div.sk-estimator.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-2);\n", + "}\n", + "\n", + "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n", + "\n", + "/* Common style for \"i\" and \"?\" */\n", + "\n", + ".sk-estimator-doc-link,\n", + "a:link.sk-estimator-doc-link,\n", + "a:visited.sk-estimator-doc-link {\n", + " float: right;\n", + " font-size: smaller;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1em;\n", + " height: 1em;\n", + " width: 1em;\n", + " text-decoration: none !important;\n", + " margin-left: 1ex;\n", + " /* unfitted */\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted,\n", + "a:link.sk-estimator-doc-link.fitted,\n", + "a:visited.sk-estimator-doc-link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", + ".sk-estimator-doc-link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover,\n", + "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", + ".sk-estimator-doc-link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "/* Span, style for the box shown on hovering the info icon */\n", + ".sk-estimator-doc-link span {\n", + " display: none;\n", + " z-index: 9999;\n", + " position: relative;\n", + " font-weight: normal;\n", + " right: .2ex;\n", + " padding: .5ex;\n", + " margin: .5ex;\n", + " width: min-content;\n", + " min-width: 20ex;\n", + " max-width: 50ex;\n", + " color: var(--sklearn-color-text);\n", + " box-shadow: 2pt 2pt 4pt #999;\n", + " /* unfitted */\n", + " background: var(--sklearn-color-unfitted-level-0);\n", + " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link.fitted span {\n", + " /* fitted */\n", + " background: var(--sklearn-color-fitted-level-0);\n", + " border: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "\n", + ".sk-estimator-doc-link:hover span {\n", + " display: block;\n", + "}\n", + "\n", + "/* \"?\"-specific style due to the `<a>` HTML tag */\n", + "\n", + "#sk-container-id-15 a.estimator_doc_link {\n", + " float: right;\n", + " font-size: 1rem;\n", + " line-height: 1em;\n", + " font-family: monospace;\n", + " background-color: var(--sklearn-color-background);\n", + " border-radius: 1rem;\n", + " height: 1rem;\n", + " width: 1rem;\n", + " text-decoration: none;\n", + " /* unfitted */\n", + " color: var(--sklearn-color-unfitted-level-1);\n", + " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", + "}\n", + "\n", + "#sk-container-id-15 a.estimator_doc_link.fitted {\n", + " /* fitted */\n", + " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", + " color: var(--sklearn-color-fitted-level-1);\n", + "}\n", + "\n", + "/* On hover */\n", + "#sk-container-id-15 a.estimator_doc_link:hover {\n", + " /* unfitted */\n", + " background-color: var(--sklearn-color-unfitted-level-3);\n", + " color: var(--sklearn-color-background);\n", + " text-decoration: none;\n", + "}\n", + "\n", + "#sk-container-id-15 a.estimator_doc_link.fitted:hover {\n", + " /* fitted */\n", + " background-color: var(--sklearn-color-fitted-level-3);\n", + "}\n", + "</style><div id=\"sk-container-id-15\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>Pipeline(steps=[(&#x27;vect&#x27;, CountVectorizer()), (&#x27;clf&#x27;, MultinomialNB())])</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-43\" type=\"checkbox\" ><label for=\"sk-estimator-id-43\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;Pipeline<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.pipeline.Pipeline.html\">?<span>Documentation for Pipeline</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>Pipeline(steps=[(&#x27;vect&#x27;, CountVectorizer()), (&#x27;clf&#x27;, MultinomialNB())])</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-44\" type=\"checkbox\" ><label for=\"sk-estimator-id-44\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;CountVectorizer<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html\">?<span>Documentation for CountVectorizer</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>CountVectorizer()</pre></div> </div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-45\" type=\"checkbox\" ><label for=\"sk-estimator-id-45\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;MultinomialNB<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.naive_bayes.MultinomialNB.html\">?<span>Documentation for MultinomialNB</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>MultinomialNB()</pre></div> </div></div></div></div></div></div>" + ], + "text/plain": [ + "Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])" + ] + }, + "execution_count": 486, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.naive_bayes import MultinomialNB\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.feature_extraction.text import CountVectorizer\n", + "\n", + "pipeline = Pipeline([\n", + " ('vect', CountVectorizer()), # Use CountVectorizer to convert text into token counts\n", + " ('clf', MultinomialNB()), # Naive Bayes classifier\n", + "])\n", + "\n", + "# Fit the model on the training data\n", + "pipeline.fit(X_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 487, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total Percent Correct: 0.9777458722182341\n" + ] + } + ], + "source": [ + "predictions = pipeline.predict(X_test)\n", + "\n", + "correct = 0\n", + "count = 0\n", + "while count < len(predictions):\n", + " if y_test[count] == predictions[count]:\n", + " correct += 1\n", + "\n", + " count += 1\n", + "\n", + "print('Total Percent Correct: ', correct/count)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}