NaiveBayesSpamFilterScratch.ipynb (27512B)
1 { 2 "cells": [ 3 { 4 "cell_type": "markdown", 5 "metadata": {}, 6 "source": [ 7 "This dataset is from kaggle:\n", 8 "\n", 9 "https://www.kaggle.com/datasets/abdallahwagih/spam-emails\n", 10 "Download this, move it to the correct location and then start working. " 11 ] 12 }, 13 { 14 "cell_type": "markdown", 15 "metadata": {}, 16 "source": [ 17 "The implementation from sklearn uses a more sophisticated tokenization strategy which makes it better than mine. As such, while I got 80% accuracy with Naive Bayes, sklearn got 97%" 18 ] 19 }, 20 { 21 "cell_type": "code", 22 "execution_count": 474, 23 "metadata": {}, 24 "outputs": [], 25 "source": [ 26 "import pandas as pd\n", 27 "\n", 28 "emails = pd.read_csv('../datasets/spamEmails/emails.csv')" 29 ] 30 }, 31 { 32 "cell_type": "code", 33 "execution_count": 475, 34 "metadata": {}, 35 "outputs": [], 36 "source": [ 37 "X = emails['Message']\n", 38 "y = emails['Category']\n", 39 "\n", 40 "y = y == 'spam'" 41 ] 42 }, 43 { 44 "cell_type": "code", 45 "execution_count": 476, 46 "metadata": {}, 47 "outputs": [], 48 "source": [ 49 "import re\n", 50 "\n", 51 "\n", 52 "def removePunctuation(text):\n", 53 " return re.sub(r'[^a-zA-Z\\s]', '', text)\n", 54 "\n", 55 "X = X.apply(removePunctuation)" 56 ] 57 }, 58 { 59 "cell_type": "code", 60 "execution_count": 477, 61 "metadata": {}, 62 "outputs": [], 63 "source": [ 64 "from sklearn.model_selection import train_test_split\n", 65 "X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=10)" 66 ] 67 }, 68 { 69 "cell_type": "code", 70 "execution_count": 478, 71 "metadata": {}, 72 "outputs": [], 73 "source": [ 74 "def uniqueWords(text : str):\n", 75 " text = text.lower()\n", 76 " textLs = text.split(' ')\n", 77 " # Init set with ''\n", 78 " opts = {''}\n", 79 " for i in textLs:\n", 80 " opts.add(i)\n", 81 " # Remove ''\n", 82 " opts.remove('')\n", 83 " return opts\n", 84 "\n", 85 "\n", 86 "wcHam = {}\n", 87 "wcSpam = {}\n" 88 ] 89 }, 90 { 91 "cell_type": "code", 92 "execution_count": 479, 93 "metadata": {}, 94 "outputs": [], 95 "source": [ 96 "i = 0\n", 97 "while i < len(X_train):\n", 98 " if(y_train.iloc[i]) == 1:\n", 99 " words = uniqueWords(X_train.iloc[i])\n", 100 " for w in words:\n", 101 " count = 1\n", 102 " if w in wcSpam:\n", 103 " count = wcSpam[w] + 1\n", 104 " wcSpam[w] = count\n", 105 "\n", 106 " else:\n", 107 " words = uniqueWords(X_train.iloc[i])\n", 108 " for w in words:\n", 109 " count = 1\n", 110 " if w in wcHam:\n", 111 " count = wcHam[w] + 1\n", 112 " wcHam[w] = count\n", 113 " i += 1" 114 ] 115 }, 116 { 117 "cell_type": "code", 118 "execution_count": 480, 119 "metadata": {}, 120 "outputs": [ 121 { 122 "name": "stdout", 123 "output_type": "stream", 124 "text": [ 125 "6238\n", 126 "1966\n" 127 ] 128 } 129 ], 130 "source": [ 131 "print(len(wcHam))\n", 132 "print(len(wcSpam))" 133 ] 134 }, 135 { 136 "cell_type": "code", 137 "execution_count": 481, 138 "metadata": {}, 139 "outputs": [ 140 { 141 "name": "stdout", 142 "output_type": "stream", 143 "text": [ 144 "7322\n", 145 "7322\n" 146 ] 147 } 148 ], 149 "source": [ 150 "# Add one to each value to ensure there are no 0 probabilities (pseudo count)\n", 151 "# With a zero probability it would mess up naive bayes calculations\n", 152 "\n", 153 "keys_combined = set(wcHam.keys()).union(wcSpam.keys())\n", 154 "\n", 155 "for i in keys_combined:\n", 156 " if i in wcHam:\n", 157 " wcHam[i] = wcHam[i] + 1\n", 158 " else:\n", 159 " wcHam[i] = 1\n", 160 " \n", 161 " if i in wcSpam:\n", 162 " wcSpam[i] = wcSpam[i] + 1\n", 163 " else:\n", 164 " wcSpam[i] = 1\n", 165 "\n", 166 "print(len(wcSpam))\n", 167 "print(len(wcHam))" 168 ] 169 }, 170 { 171 "cell_type": "code", 172 "execution_count": 482, 173 "metadata": {}, 174 "outputs": [ 175 { 176 "name": "stderr", 177 "output_type": "stream", 178 "text": [ 179 "/tmp/ipykernel_5255/3310267757.py:2: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n", 180 " hamCount = y_train.value_counts()[0]\n", 181 "/tmp/ipykernel_5255/3310267757.py:3: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n", 182 " spamCount = y_train.value_counts()[1]\n" 183 ] 184 } 185 ], 186 "source": [ 187 "# Assign probabilities of each word occuring in any given spam/ham message\n", 188 "hamCount = y_train.value_counts()[0]\n", 189 "spamCount = y_train.value_counts()[1]\n", 190 "\n", 191 "for key in wcHam:\n", 192 " wcHam[key] = wcHam[key] / hamCount\n", 193 "\n", 194 "for key in wcSpam:\n", 195 " wcSpam[key] = wcSpam[key] / spamCount" 196 ] 197 }, 198 { 199 "cell_type": "code", 200 "execution_count": 483, 201 "metadata": {}, 202 "outputs": [], 203 "source": [ 204 "# Dumbass forgot to multiply by probability of each class at the start\n", 205 "\n", 206 "def predict(message):\n", 207 " words = message.split(' ')\n", 208 " spamPercent = spamCount / (spamCount + hamCount)\n", 209 " hamPercent = hamCount / (spamCount + hamCount)\n", 210 "\n", 211 " for word in words:\n", 212 " if word in wcSpam:\n", 213 " spamPercent = spamPercent * wcSpam[word]\n", 214 " hamPercent = hamPercent * wcHam[word]\n", 215 " \n", 216 " return spamPercent > hamPercent" 217 ] 218 }, 219 { 220 "cell_type": "code", 221 "execution_count": 484, 222 "metadata": {}, 223 "outputs": [], 224 "source": [ 225 "\n", 226 "predictions = []\n", 227 "\n", 228 "for i in X_test:\n", 229 " predictions.append(predict(i))\n", 230 "\n", 231 "y_test = y_test.to_list()" 232 ] 233 }, 234 { 235 "cell_type": "code", 236 "execution_count": 485, 237 "metadata": {}, 238 "outputs": [ 239 { 240 "name": "stdout", 241 "output_type": "stream", 242 "text": [ 243 "Total Percent Correct: 0.7982770997846375\n" 244 ] 245 } 246 ], 247 "source": [ 248 "correct = 0\n", 249 "while count < len(predictions):\n", 250 " if y_test[count] == predictions[count]:\n", 251 " correct += 1\n", 252 " \n", 253 " count += 1\n", 254 "\n", 255 "print('Total Percent Correct: ', correct/count)" 256 ] 257 }, 258 { 259 "cell_type": "code", 260 "execution_count": 486, 261 "metadata": {}, 262 "outputs": [ 263 { 264 "data": { 265 "text/html": [ 266 "<style>#sk-container-id-15 {\n", 267 " /* Definition of color scheme common for light and dark mode */\n", 268 " --sklearn-color-text: black;\n", 269 " --sklearn-color-line: gray;\n", 270 " /* Definition of color scheme for unfitted estimators */\n", 271 " --sklearn-color-unfitted-level-0: #fff5e6;\n", 272 " --sklearn-color-unfitted-level-1: #f6e4d2;\n", 273 " --sklearn-color-unfitted-level-2: #ffe0b3;\n", 274 " --sklearn-color-unfitted-level-3: chocolate;\n", 275 " /* Definition of color scheme for fitted estimators */\n", 276 " --sklearn-color-fitted-level-0: #f0f8ff;\n", 277 " --sklearn-color-fitted-level-1: #d4ebff;\n", 278 " --sklearn-color-fitted-level-2: #b3dbfd;\n", 279 " --sklearn-color-fitted-level-3: cornflowerblue;\n", 280 "\n", 281 " /* Specific color for light theme */\n", 282 " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", 283 " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", 284 " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", 285 " --sklearn-color-icon: #696969;\n", 286 "\n", 287 " @media (prefers-color-scheme: dark) {\n", 288 " /* Redefinition of color scheme for dark theme */\n", 289 " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", 290 " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", 291 " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", 292 " --sklearn-color-icon: #878787;\n", 293 " }\n", 294 "}\n", 295 "\n", 296 "#sk-container-id-15 {\n", 297 " color: var(--sklearn-color-text);\n", 298 "}\n", 299 "\n", 300 "#sk-container-id-15 pre {\n", 301 " padding: 0;\n", 302 "}\n", 303 "\n", 304 "#sk-container-id-15 input.sk-hidden--visually {\n", 305 " border: 0;\n", 306 " clip: rect(1px 1px 1px 1px);\n", 307 " clip: rect(1px, 1px, 1px, 1px);\n", 308 " height: 1px;\n", 309 " margin: -1px;\n", 310 " overflow: hidden;\n", 311 " padding: 0;\n", 312 " position: absolute;\n", 313 " width: 1px;\n", 314 "}\n", 315 "\n", 316 "#sk-container-id-15 div.sk-dashed-wrapped {\n", 317 " border: 1px dashed var(--sklearn-color-line);\n", 318 " margin: 0 0.4em 0.5em 0.4em;\n", 319 " box-sizing: border-box;\n", 320 " padding-bottom: 0.4em;\n", 321 " background-color: var(--sklearn-color-background);\n", 322 "}\n", 323 "\n", 324 "#sk-container-id-15 div.sk-container {\n", 325 " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", 326 " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", 327 " so we also need the `!important` here to be able to override the\n", 328 " default hidden behavior on the sphinx rendered scikit-learn.org.\n", 329 " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", 330 " display: inline-block !important;\n", 331 " position: relative;\n", 332 "}\n", 333 "\n", 334 "#sk-container-id-15 div.sk-text-repr-fallback {\n", 335 " display: none;\n", 336 "}\n", 337 "\n", 338 "div.sk-parallel-item,\n", 339 "div.sk-serial,\n", 340 "div.sk-item {\n", 341 " /* draw centered vertical line to link estimators */\n", 342 " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", 343 " background-size: 2px 100%;\n", 344 " background-repeat: no-repeat;\n", 345 " background-position: center center;\n", 346 "}\n", 347 "\n", 348 "/* Parallel-specific style estimator block */\n", 349 "\n", 350 "#sk-container-id-15 div.sk-parallel-item::after {\n", 351 " content: \"\";\n", 352 " width: 100%;\n", 353 " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", 354 " flex-grow: 1;\n", 355 "}\n", 356 "\n", 357 "#sk-container-id-15 div.sk-parallel {\n", 358 " display: flex;\n", 359 " align-items: stretch;\n", 360 " justify-content: center;\n", 361 " background-color: var(--sklearn-color-background);\n", 362 " position: relative;\n", 363 "}\n", 364 "\n", 365 "#sk-container-id-15 div.sk-parallel-item {\n", 366 " display: flex;\n", 367 " flex-direction: column;\n", 368 "}\n", 369 "\n", 370 "#sk-container-id-15 div.sk-parallel-item:first-child::after {\n", 371 " align-self: flex-end;\n", 372 " width: 50%;\n", 373 "}\n", 374 "\n", 375 "#sk-container-id-15 div.sk-parallel-item:last-child::after {\n", 376 " align-self: flex-start;\n", 377 " width: 50%;\n", 378 "}\n", 379 "\n", 380 "#sk-container-id-15 div.sk-parallel-item:only-child::after {\n", 381 " width: 0;\n", 382 "}\n", 383 "\n", 384 "/* Serial-specific style estimator block */\n", 385 "\n", 386 "#sk-container-id-15 div.sk-serial {\n", 387 " display: flex;\n", 388 " flex-direction: column;\n", 389 " align-items: center;\n", 390 " background-color: var(--sklearn-color-background);\n", 391 " padding-right: 1em;\n", 392 " padding-left: 1em;\n", 393 "}\n", 394 "\n", 395 "\n", 396 "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", 397 "clickable and can be expanded/collapsed.\n", 398 "- Pipeline and ColumnTransformer use this feature and define the default style\n", 399 "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", 400 "*/\n", 401 "\n", 402 "/* Pipeline and ColumnTransformer style (default) */\n", 403 "\n", 404 "#sk-container-id-15 div.sk-toggleable {\n", 405 " /* Default theme specific background. It is overwritten whether we have a\n", 406 " specific estimator or a Pipeline/ColumnTransformer */\n", 407 " background-color: var(--sklearn-color-background);\n", 408 "}\n", 409 "\n", 410 "/* Toggleable label */\n", 411 "#sk-container-id-15 label.sk-toggleable__label {\n", 412 " cursor: pointer;\n", 413 " display: block;\n", 414 " width: 100%;\n", 415 " margin-bottom: 0;\n", 416 " padding: 0.5em;\n", 417 " box-sizing: border-box;\n", 418 " text-align: center;\n", 419 "}\n", 420 "\n", 421 "#sk-container-id-15 label.sk-toggleable__label-arrow:before {\n", 422 " /* Arrow on the left of the label */\n", 423 " content: \"▸\";\n", 424 " float: left;\n", 425 " margin-right: 0.25em;\n", 426 " color: var(--sklearn-color-icon);\n", 427 "}\n", 428 "\n", 429 "#sk-container-id-15 label.sk-toggleable__label-arrow:hover:before {\n", 430 " color: var(--sklearn-color-text);\n", 431 "}\n", 432 "\n", 433 "/* Toggleable content - dropdown */\n", 434 "\n", 435 "#sk-container-id-15 div.sk-toggleable__content {\n", 436 " max-height: 0;\n", 437 " max-width: 0;\n", 438 " overflow: hidden;\n", 439 " text-align: left;\n", 440 " /* unfitted */\n", 441 " background-color: var(--sklearn-color-unfitted-level-0);\n", 442 "}\n", 443 "\n", 444 "#sk-container-id-15 div.sk-toggleable__content.fitted {\n", 445 " /* fitted */\n", 446 " background-color: var(--sklearn-color-fitted-level-0);\n", 447 "}\n", 448 "\n", 449 "#sk-container-id-15 div.sk-toggleable__content pre {\n", 450 " margin: 0.2em;\n", 451 " border-radius: 0.25em;\n", 452 " color: var(--sklearn-color-text);\n", 453 " /* unfitted */\n", 454 " background-color: var(--sklearn-color-unfitted-level-0);\n", 455 "}\n", 456 "\n", 457 "#sk-container-id-15 div.sk-toggleable__content.fitted pre {\n", 458 " /* unfitted */\n", 459 " background-color: var(--sklearn-color-fitted-level-0);\n", 460 "}\n", 461 "\n", 462 "#sk-container-id-15 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", 463 " /* Expand drop-down */\n", 464 " max-height: 200px;\n", 465 " max-width: 100%;\n", 466 " overflow: auto;\n", 467 "}\n", 468 "\n", 469 "#sk-container-id-15 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", 470 " content: \"▾\";\n", 471 "}\n", 472 "\n", 473 "/* Pipeline/ColumnTransformer-specific style */\n", 474 "\n", 475 "#sk-container-id-15 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", 476 " color: var(--sklearn-color-text);\n", 477 " background-color: var(--sklearn-color-unfitted-level-2);\n", 478 "}\n", 479 "\n", 480 "#sk-container-id-15 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", 481 " background-color: var(--sklearn-color-fitted-level-2);\n", 482 "}\n", 483 "\n", 484 "/* Estimator-specific style */\n", 485 "\n", 486 "/* Colorize estimator box */\n", 487 "#sk-container-id-15 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", 488 " /* unfitted */\n", 489 " background-color: var(--sklearn-color-unfitted-level-2);\n", 490 "}\n", 491 "\n", 492 "#sk-container-id-15 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", 493 " /* fitted */\n", 494 " background-color: var(--sklearn-color-fitted-level-2);\n", 495 "}\n", 496 "\n", 497 "#sk-container-id-15 div.sk-label label.sk-toggleable__label,\n", 498 "#sk-container-id-15 div.sk-label label {\n", 499 " /* The background is the default theme color */\n", 500 " color: var(--sklearn-color-text-on-default-background);\n", 501 "}\n", 502 "\n", 503 "/* On hover, darken the color of the background */\n", 504 "#sk-container-id-15 div.sk-label:hover label.sk-toggleable__label {\n", 505 " color: var(--sklearn-color-text);\n", 506 " background-color: var(--sklearn-color-unfitted-level-2);\n", 507 "}\n", 508 "\n", 509 "/* Label box, darken color on hover, fitted */\n", 510 "#sk-container-id-15 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", 511 " color: var(--sklearn-color-text);\n", 512 " background-color: var(--sklearn-color-fitted-level-2);\n", 513 "}\n", 514 "\n", 515 "/* Estimator label */\n", 516 "\n", 517 "#sk-container-id-15 div.sk-label label {\n", 518 " font-family: monospace;\n", 519 " font-weight: bold;\n", 520 " display: inline-block;\n", 521 " line-height: 1.2em;\n", 522 "}\n", 523 "\n", 524 "#sk-container-id-15 div.sk-label-container {\n", 525 " text-align: center;\n", 526 "}\n", 527 "\n", 528 "/* Estimator-specific */\n", 529 "#sk-container-id-15 div.sk-estimator {\n", 530 " font-family: monospace;\n", 531 " border: 1px dotted var(--sklearn-color-border-box);\n", 532 " border-radius: 0.25em;\n", 533 " box-sizing: border-box;\n", 534 " margin-bottom: 0.5em;\n", 535 " /* unfitted */\n", 536 " background-color: var(--sklearn-color-unfitted-level-0);\n", 537 "}\n", 538 "\n", 539 "#sk-container-id-15 div.sk-estimator.fitted {\n", 540 " /* fitted */\n", 541 " background-color: var(--sklearn-color-fitted-level-0);\n", 542 "}\n", 543 "\n", 544 "/* on hover */\n", 545 "#sk-container-id-15 div.sk-estimator:hover {\n", 546 " /* unfitted */\n", 547 " background-color: var(--sklearn-color-unfitted-level-2);\n", 548 "}\n", 549 "\n", 550 "#sk-container-id-15 div.sk-estimator.fitted:hover {\n", 551 " /* fitted */\n", 552 " background-color: var(--sklearn-color-fitted-level-2);\n", 553 "}\n", 554 "\n", 555 "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n", 556 "\n", 557 "/* Common style for \"i\" and \"?\" */\n", 558 "\n", 559 ".sk-estimator-doc-link,\n", 560 "a:link.sk-estimator-doc-link,\n", 561 "a:visited.sk-estimator-doc-link {\n", 562 " float: right;\n", 563 " font-size: smaller;\n", 564 " line-height: 1em;\n", 565 " font-family: monospace;\n", 566 " background-color: var(--sklearn-color-background);\n", 567 " border-radius: 1em;\n", 568 " height: 1em;\n", 569 " width: 1em;\n", 570 " text-decoration: none !important;\n", 571 " margin-left: 1ex;\n", 572 " /* unfitted */\n", 573 " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", 574 " color: var(--sklearn-color-unfitted-level-1);\n", 575 "}\n", 576 "\n", 577 ".sk-estimator-doc-link.fitted,\n", 578 "a:link.sk-estimator-doc-link.fitted,\n", 579 "a:visited.sk-estimator-doc-link.fitted {\n", 580 " /* fitted */\n", 581 " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", 582 " color: var(--sklearn-color-fitted-level-1);\n", 583 "}\n", 584 "\n", 585 "/* On hover */\n", 586 "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", 587 ".sk-estimator-doc-link:hover,\n", 588 "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", 589 ".sk-estimator-doc-link:hover {\n", 590 " /* unfitted */\n", 591 " background-color: var(--sklearn-color-unfitted-level-3);\n", 592 " color: var(--sklearn-color-background);\n", 593 " text-decoration: none;\n", 594 "}\n", 595 "\n", 596 "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", 597 ".sk-estimator-doc-link.fitted:hover,\n", 598 "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", 599 ".sk-estimator-doc-link.fitted:hover {\n", 600 " /* fitted */\n", 601 " background-color: var(--sklearn-color-fitted-level-3);\n", 602 " color: var(--sklearn-color-background);\n", 603 " text-decoration: none;\n", 604 "}\n", 605 "\n", 606 "/* Span, style for the box shown on hovering the info icon */\n", 607 ".sk-estimator-doc-link span {\n", 608 " display: none;\n", 609 " z-index: 9999;\n", 610 " position: relative;\n", 611 " font-weight: normal;\n", 612 " right: .2ex;\n", 613 " padding: .5ex;\n", 614 " margin: .5ex;\n", 615 " width: min-content;\n", 616 " min-width: 20ex;\n", 617 " max-width: 50ex;\n", 618 " color: var(--sklearn-color-text);\n", 619 " box-shadow: 2pt 2pt 4pt #999;\n", 620 " /* unfitted */\n", 621 " background: var(--sklearn-color-unfitted-level-0);\n", 622 " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", 623 "}\n", 624 "\n", 625 ".sk-estimator-doc-link.fitted span {\n", 626 " /* fitted */\n", 627 " background: var(--sklearn-color-fitted-level-0);\n", 628 " border: var(--sklearn-color-fitted-level-3);\n", 629 "}\n", 630 "\n", 631 ".sk-estimator-doc-link:hover span {\n", 632 " display: block;\n", 633 "}\n", 634 "\n", 635 "/* \"?\"-specific style due to the `<a>` HTML tag */\n", 636 "\n", 637 "#sk-container-id-15 a.estimator_doc_link {\n", 638 " float: right;\n", 639 " font-size: 1rem;\n", 640 " line-height: 1em;\n", 641 " font-family: monospace;\n", 642 " background-color: var(--sklearn-color-background);\n", 643 " border-radius: 1rem;\n", 644 " height: 1rem;\n", 645 " width: 1rem;\n", 646 " text-decoration: none;\n", 647 " /* unfitted */\n", 648 " color: var(--sklearn-color-unfitted-level-1);\n", 649 " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", 650 "}\n", 651 "\n", 652 "#sk-container-id-15 a.estimator_doc_link.fitted {\n", 653 " /* fitted */\n", 654 " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", 655 " color: var(--sklearn-color-fitted-level-1);\n", 656 "}\n", 657 "\n", 658 "/* On hover */\n", 659 "#sk-container-id-15 a.estimator_doc_link:hover {\n", 660 " /* unfitted */\n", 661 " background-color: var(--sklearn-color-unfitted-level-3);\n", 662 " color: var(--sklearn-color-background);\n", 663 " text-decoration: none;\n", 664 "}\n", 665 "\n", 666 "#sk-container-id-15 a.estimator_doc_link.fitted:hover {\n", 667 " /* fitted */\n", 668 " background-color: var(--sklearn-color-fitted-level-3);\n", 669 "}\n", 670 "</style><div id=\"sk-container-id-15\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-43\" type=\"checkbox\" ><label for=\"sk-estimator-id-43\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> Pipeline<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.pipeline.Pipeline.html\">?<span>Documentation for Pipeline</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-44\" type=\"checkbox\" ><label for=\"sk-estimator-id-44\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> CountVectorizer<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html\">?<span>Documentation for CountVectorizer</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>CountVectorizer()</pre></div> </div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-45\" type=\"checkbox\" ><label for=\"sk-estimator-id-45\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> MultinomialNB<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.naive_bayes.MultinomialNB.html\">?<span>Documentation for MultinomialNB</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>MultinomialNB()</pre></div> </div></div></div></div></div></div>" 671 ], 672 "text/plain": [ 673 "Pipeline(steps=[('vect', CountVectorizer()), ('clf', MultinomialNB())])" 674 ] 675 }, 676 "execution_count": 486, 677 "metadata": {}, 678 "output_type": "execute_result" 679 } 680 ], 681 "source": [ 682 "from sklearn.naive_bayes import MultinomialNB\n", 683 "from sklearn.pipeline import Pipeline\n", 684 "from sklearn.feature_extraction.text import CountVectorizer\n", 685 "\n", 686 "pipeline = Pipeline([\n", 687 " ('vect', CountVectorizer()), # Use CountVectorizer to convert text into token counts\n", 688 " ('clf', MultinomialNB()), # Naive Bayes classifier\n", 689 "])\n", 690 "\n", 691 "# Fit the model on the training data\n", 692 "pipeline.fit(X_train, y_train)" 693 ] 694 }, 695 { 696 "cell_type": "code", 697 "execution_count": 487, 698 "metadata": {}, 699 "outputs": [ 700 { 701 "name": "stdout", 702 "output_type": "stream", 703 "text": [ 704 "Total Percent Correct: 0.9777458722182341\n" 705 ] 706 } 707 ], 708 "source": [ 709 "predictions = pipeline.predict(X_test)\n", 710 "\n", 711 "correct = 0\n", 712 "count = 0\n", 713 "while count < len(predictions):\n", 714 " if y_test[count] == predictions[count]:\n", 715 " correct += 1\n", 716 "\n", 717 " count += 1\n", 718 "\n", 719 "print('Total Percent Correct: ', correct/count)" 720 ] 721 } 722 ], 723 "metadata": { 724 "kernelspec": { 725 "display_name": ".venv", 726 "language": "python", 727 "name": "python3" 728 }, 729 "language_info": { 730 "codemirror_mode": { 731 "name": "ipython", 732 "version": 3 733 }, 734 "file_extension": ".py", 735 "mimetype": "text/x-python", 736 "name": "python", 737 "nbconvert_exporter": "python", 738 "pygments_lexer": "ipython3", 739 "version": "3.11.2" 740 } 741 }, 742 "nbformat": 4, 743 "nbformat_minor": 2 744 }