StochasticGradientDescentLinearRegression.ipynb (20529B)
1 { 2 "cells": [ 3 { 4 "cell_type": "code", 5 "execution_count": 9, 6 "metadata": {}, 7 "outputs": [], 8 "source": [ 9 "import numpy as np\n", 10 "from sklearn.preprocessing import add_dummy_feature\n", 11 "\n", 12 "# init: seed first so the data, the initial theta and the sampled indices are all reproducible\n", 13 "np.random.seed(42)\n", 14 "m = 100\n", 15 "X = 2 * np.random.rand(m, 1)\n", 16 "y = 4 + 3 * X + np.random.randn(m, 1)  # true model: y = 4 + 3x + Gaussian noise\n", 17 "X_b = add_dummy_feature(X)  # prepend x0 = 1 so theta[0] acts as the bias term\n", 18 "\n", 19 "# Stochastic gradient descent\n", 20 "epochs = 50\n", 21 "t0, t1 = 5, 50  # learning schedule hyperparameters\n", 22 "\n", 23 "def learning_schedule(t):\n", 24 "    \"\"\"Learning rate for global step t: t0 / (t + t1), decaying from t0/t1 towards 0.\"\"\"\n", 25 "    return t0 / (t + t1)\n", 26 "\n", 27 "theta = np.random.randn(2, 1)  # random initialization\n", 28 "for epoch in range(epochs):\n", 29 "    for iteration in range(m):\n", 30 "        random_index = np.random.randint(m)  # one instance per step, sampled with replacement\n", 31 "        xi = X_b[random_index : random_index + 1]\n", 32 "        yi = y[random_index : random_index + 1]\n", 33 "        gradients = 2 * xi.T @ (xi @ theta - yi)  # squared-error gradient on this single instance\n", 34 "        eta = learning_schedule(epoch * m + iteration)\n", 35 "        theta = theta - eta * gradients\n" 36 ] 37 }, 38 { 39 "cell_type": "code", 40 "execution_count": 10, 41 "metadata": {}, 42 "outputs": [ 43 { 44 "name": "stdout", 45 "output_type": "stream", 46 "text": [ 47 "[[4.07817553]\n", 48 " [3.05643471]]\n" 49 ] 50 } 51 ], 52 "source": [ 53 "print(theta)" 54 ] 55 }, 56 { 57 "cell_type": "code", 58 "execution_count": 14, 59 "metadata": {}, 60 "outputs": [ 61 { 62 "data": { 63 "text/html": [ 64 "<style>#sk-container-id-3 {\n", 65 " /* Definition of color scheme common for light and dark mode */\n", 66 " --sklearn-color-text: black;\n", 67 " --sklearn-color-line: gray;\n", 68 " /* Definition of color scheme for unfitted estimators */\n", 69 " --sklearn-color-unfitted-level-0: #fff5e6;\n", 70 " --sklearn-color-unfitted-level-1: #f6e4d2;\n", 71 " --sklearn-color-unfitted-level-2: #ffe0b3;\n", 72 " --sklearn-color-unfitted-level-3: chocolate;\n", 73 " /* Definition of color scheme for fitted estimators */\n", 74 " --sklearn-color-fitted-level-0: #f0f8ff;\n", 75 " --sklearn-color-fitted-level-1: #d4ebff;\n",
76 " --sklearn-color-fitted-level-2: #b3dbfd;\n", 77 " --sklearn-color-fitted-level-3: cornflowerblue;\n", 78 "\n", 79 " /* Specific color for light theme */\n", 80 " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", 81 " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", 82 " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", 83 " --sklearn-color-icon: #696969;\n", 84 "\n", 85 " @media (prefers-color-scheme: dark) {\n", 86 " /* Redefinition of color scheme for dark theme */\n", 87 " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", 88 " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", 89 " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", 90 " --sklearn-color-icon: #878787;\n", 91 " }\n", 92 "}\n", 93 "\n", 94 "#sk-container-id-3 {\n", 95 " color: var(--sklearn-color-text);\n", 96 "}\n", 97 "\n", 98 "#sk-container-id-3 pre {\n", 99 " padding: 0;\n", 100 "}\n", 101 "\n", 102 "#sk-container-id-3 input.sk-hidden--visually {\n", 103 " border: 0;\n", 104 " clip: rect(1px 1px 1px 1px);\n", 105 " clip: rect(1px, 1px, 1px, 1px);\n", 106 " height: 1px;\n", 107 " margin: -1px;\n", 108 " overflow: hidden;\n", 109 " padding: 0;\n", 110 " position: absolute;\n", 111 " width: 1px;\n", 112 "}\n", 113 "\n", 114 "#sk-container-id-3 div.sk-dashed-wrapped {\n", 115 " border: 1px dashed var(--sklearn-color-line);\n", 116 " margin: 0 0.4em 0.5em 0.4em;\n", 117 " box-sizing: border-box;\n", 118 " padding-bottom: 0.4em;\n", 119 " background-color: var(--sklearn-color-background);\n", 120 "}\n", 121 "\n", 122 "#sk-container-id-3 
div.sk-container {\n", 123 " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", 124 " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", 125 " so we also need the `!important` here to be able to override the\n", 126 " default hidden behavior on the sphinx rendered scikit-learn.org.\n", 127 " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n", 128 " display: inline-block !important;\n", 129 " position: relative;\n", 130 "}\n", 131 "\n", 132 "#sk-container-id-3 div.sk-text-repr-fallback {\n", 133 " display: none;\n", 134 "}\n", 135 "\n", 136 "div.sk-parallel-item,\n", 137 "div.sk-serial,\n", 138 "div.sk-item {\n", 139 " /* draw centered vertical line to link estimators */\n", 140 " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", 141 " background-size: 2px 100%;\n", 142 " background-repeat: no-repeat;\n", 143 " background-position: center center;\n", 144 "}\n", 145 "\n", 146 "/* Parallel-specific style estimator block */\n", 147 "\n", 148 "#sk-container-id-3 div.sk-parallel-item::after {\n", 149 " content: \"\";\n", 150 " width: 100%;\n", 151 " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", 152 " flex-grow: 1;\n", 153 "}\n", 154 "\n", 155 "#sk-container-id-3 div.sk-parallel {\n", 156 " display: flex;\n", 157 " align-items: stretch;\n", 158 " justify-content: center;\n", 159 " background-color: var(--sklearn-color-background);\n", 160 " position: relative;\n", 161 "}\n", 162 "\n", 163 "#sk-container-id-3 div.sk-parallel-item {\n", 164 " display: flex;\n", 165 " flex-direction: column;\n", 166 "}\n", 167 "\n", 168 "#sk-container-id-3 div.sk-parallel-item:first-child::after {\n", 169 " align-self: flex-end;\n", 170 " width: 50%;\n", 171 "}\n", 172 "\n", 173 "#sk-container-id-3 div.sk-parallel-item:last-child::after {\n", 174 " align-self: flex-start;\n", 175 " width: 50%;\n", 176 "}\n", 177 "\n", 
178 "#sk-container-id-3 div.sk-parallel-item:only-child::after {\n", 179 " width: 0;\n", 180 "}\n", 181 "\n", 182 "/* Serial-specific style estimator block */\n", 183 "\n", 184 "#sk-container-id-3 div.sk-serial {\n", 185 " display: flex;\n", 186 " flex-direction: column;\n", 187 " align-items: center;\n", 188 " background-color: var(--sklearn-color-background);\n", 189 " padding-right: 1em;\n", 190 " padding-left: 1em;\n", 191 "}\n", 192 "\n", 193 "\n", 194 "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", 195 "clickable and can be expanded/collapsed.\n", 196 "- Pipeline and ColumnTransformer use this feature and define the default style\n", 197 "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", 198 "*/\n", 199 "\n", 200 "/* Pipeline and ColumnTransformer style (default) */\n", 201 "\n", 202 "#sk-container-id-3 div.sk-toggleable {\n", 203 " /* Default theme specific background. It is overwritten whether we have a\n", 204 " specific estimator or a Pipeline/ColumnTransformer */\n", 205 " background-color: var(--sklearn-color-background);\n", 206 "}\n", 207 "\n", 208 "/* Toggleable label */\n", 209 "#sk-container-id-3 label.sk-toggleable__label {\n", 210 " cursor: pointer;\n", 211 " display: block;\n", 212 " width: 100%;\n", 213 " margin-bottom: 0;\n", 214 " padding: 0.5em;\n", 215 " box-sizing: border-box;\n", 216 " text-align: center;\n", 217 "}\n", 218 "\n", 219 "#sk-container-id-3 label.sk-toggleable__label-arrow:before {\n", 220 " /* Arrow on the left of the label */\n", 221 " content: \"▸\";\n", 222 " float: left;\n", 223 " margin-right: 0.25em;\n", 224 " color: var(--sklearn-color-icon);\n", 225 "}\n", 226 "\n", 227 "#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {\n", 228 " color: var(--sklearn-color-text);\n", 229 "}\n", 230 "\n", 231 "/* Toggleable content - dropdown */\n", 232 "\n", 233 "#sk-container-id-3 div.sk-toggleable__content {\n", 234 " max-height: 0;\n", 
235 " max-width: 0;\n", 236 " overflow: hidden;\n", 237 " text-align: left;\n", 238 " /* unfitted */\n", 239 " background-color: var(--sklearn-color-unfitted-level-0);\n", 240 "}\n", 241 "\n", 242 "#sk-container-id-3 div.sk-toggleable__content.fitted {\n", 243 " /* fitted */\n", 244 " background-color: var(--sklearn-color-fitted-level-0);\n", 245 "}\n", 246 "\n", 247 "#sk-container-id-3 div.sk-toggleable__content pre {\n", 248 " margin: 0.2em;\n", 249 " border-radius: 0.25em;\n", 250 " color: var(--sklearn-color-text);\n", 251 " /* unfitted */\n", 252 " background-color: var(--sklearn-color-unfitted-level-0);\n", 253 "}\n", 254 "\n", 255 "#sk-container-id-3 div.sk-toggleable__content.fitted pre {\n", 256 " /* unfitted */\n", 257 " background-color: var(--sklearn-color-fitted-level-0);\n", 258 "}\n", 259 "\n", 260 "#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", 261 " /* Expand drop-down */\n", 262 " max-height: 200px;\n", 263 " max-width: 100%;\n", 264 " overflow: auto;\n", 265 "}\n", 266 "\n", 267 "#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", 268 " content: \"▾\";\n", 269 "}\n", 270 "\n", 271 "/* Pipeline/ColumnTransformer-specific style */\n", 272 "\n", 273 "#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", 274 " color: var(--sklearn-color-text);\n", 275 " background-color: var(--sklearn-color-unfitted-level-2);\n", 276 "}\n", 277 "\n", 278 "#sk-container-id-3 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", 279 " background-color: var(--sklearn-color-fitted-level-2);\n", 280 "}\n", 281 "\n", 282 "/* Estimator-specific style */\n", 283 "\n", 284 "/* Colorize estimator box */\n", 285 "#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", 286 " /* unfitted */\n", 287 " background-color: 
var(--sklearn-color-unfitted-level-2);\n", 288 "}\n", 289 "\n", 290 "#sk-container-id-3 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", 291 " /* fitted */\n", 292 " background-color: var(--sklearn-color-fitted-level-2);\n", 293 "}\n", 294 "\n", 295 "#sk-container-id-3 div.sk-label label.sk-toggleable__label,\n", 296 "#sk-container-id-3 div.sk-label label {\n", 297 " /* The background is the default theme color */\n", 298 " color: var(--sklearn-color-text-on-default-background);\n", 299 "}\n", 300 "\n", 301 "/* On hover, darken the color of the background */\n", 302 "#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {\n", 303 " color: var(--sklearn-color-text);\n", 304 " background-color: var(--sklearn-color-unfitted-level-2);\n", 305 "}\n", 306 "\n", 307 "/* Label box, darken color on hover, fitted */\n", 308 "#sk-container-id-3 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", 309 " color: var(--sklearn-color-text);\n", 310 " background-color: var(--sklearn-color-fitted-level-2);\n", 311 "}\n", 312 "\n", 313 "/* Estimator label */\n", 314 "\n", 315 "#sk-container-id-3 div.sk-label label {\n", 316 " font-family: monospace;\n", 317 " font-weight: bold;\n", 318 " display: inline-block;\n", 319 " line-height: 1.2em;\n", 320 "}\n", 321 "\n", 322 "#sk-container-id-3 div.sk-label-container {\n", 323 " text-align: center;\n", 324 "}\n", 325 "\n", 326 "/* Estimator-specific */\n", 327 "#sk-container-id-3 div.sk-estimator {\n", 328 " font-family: monospace;\n", 329 " border: 1px dotted var(--sklearn-color-border-box);\n", 330 " border-radius: 0.25em;\n", 331 " box-sizing: border-box;\n", 332 " margin-bottom: 0.5em;\n", 333 " /* unfitted */\n", 334 " background-color: var(--sklearn-color-unfitted-level-0);\n", 335 "}\n", 336 "\n", 337 "#sk-container-id-3 div.sk-estimator.fitted {\n", 338 " /* fitted */\n", 339 " background-color: var(--sklearn-color-fitted-level-0);\n", 340 "}\n", 341 "\n", 342 "/* 
on hover */\n", 343 "#sk-container-id-3 div.sk-estimator:hover {\n", 344 " /* unfitted */\n", 345 " background-color: var(--sklearn-color-unfitted-level-2);\n", 346 "}\n", 347 "\n", 348 "#sk-container-id-3 div.sk-estimator.fitted:hover {\n", 349 " /* fitted */\n", 350 " background-color: var(--sklearn-color-fitted-level-2);\n", 351 "}\n", 352 "\n", 353 "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n", 354 "\n", 355 "/* Common style for \"i\" and \"?\" */\n", 356 "\n", 357 ".sk-estimator-doc-link,\n", 358 "a:link.sk-estimator-doc-link,\n", 359 "a:visited.sk-estimator-doc-link {\n", 360 " float: right;\n", 361 " font-size: smaller;\n", 362 " line-height: 1em;\n", 363 " font-family: monospace;\n", 364 " background-color: var(--sklearn-color-background);\n", 365 " border-radius: 1em;\n", 366 " height: 1em;\n", 367 " width: 1em;\n", 368 " text-decoration: none !important;\n", 369 " margin-left: 1ex;\n", 370 " /* unfitted */\n", 371 " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", 372 " color: var(--sklearn-color-unfitted-level-1);\n", 373 "}\n", 374 "\n", 375 ".sk-estimator-doc-link.fitted,\n", 376 "a:link.sk-estimator-doc-link.fitted,\n", 377 "a:visited.sk-estimator-doc-link.fitted {\n", 378 " /* fitted */\n", 379 " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", 380 " color: var(--sklearn-color-fitted-level-1);\n", 381 "}\n", 382 "\n", 383 "/* On hover */\n", 384 "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", 385 ".sk-estimator-doc-link:hover,\n", 386 "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", 387 ".sk-estimator-doc-link:hover {\n", 388 " /* unfitted */\n", 389 " background-color: var(--sklearn-color-unfitted-level-3);\n", 390 " color: var(--sklearn-color-background);\n", 391 " text-decoration: none;\n", 392 "}\n", 393 "\n", 394 "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", 395 ".sk-estimator-doc-link.fitted:hover,\n", 396 "div.sk-label-container:hover 
.sk-estimator-doc-link.fitted:hover,\n", 397 ".sk-estimator-doc-link.fitted:hover {\n", 398 " /* fitted */\n", 399 " background-color: var(--sklearn-color-fitted-level-3);\n", 400 " color: var(--sklearn-color-background);\n", 401 " text-decoration: none;\n", 402 "}\n", 403 "\n", 404 "/* Span, style for the box shown on hovering the info icon */\n", 405 ".sk-estimator-doc-link span {\n", 406 " display: none;\n", 407 " z-index: 9999;\n", 408 " position: relative;\n", 409 " font-weight: normal;\n", 410 " right: .2ex;\n", 411 " padding: .5ex;\n", 412 " margin: .5ex;\n", 413 " width: min-content;\n", 414 " min-width: 20ex;\n", 415 " max-width: 50ex;\n", 416 " color: var(--sklearn-color-text);\n", 417 " box-shadow: 2pt 2pt 4pt #999;\n", 418 " /* unfitted */\n", 419 " background: var(--sklearn-color-unfitted-level-0);\n", 420 " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", 421 "}\n", 422 "\n", 423 ".sk-estimator-doc-link.fitted span {\n", 424 " /* fitted */\n", 425 " background: var(--sklearn-color-fitted-level-0);\n", 426 " border: var(--sklearn-color-fitted-level-3);\n", 427 "}\n", 428 "\n", 429 ".sk-estimator-doc-link:hover span {\n", 430 " display: block;\n", 431 "}\n", 432 "\n", 433 "/* \"?\"-specific style due to the `<a>` HTML tag */\n", 434 "\n", 435 "#sk-container-id-3 a.estimator_doc_link {\n", 436 " float: right;\n", 437 " font-size: 1rem;\n", 438 " line-height: 1em;\n", 439 " font-family: monospace;\n", 440 " background-color: var(--sklearn-color-background);\n", 441 " border-radius: 1rem;\n", 442 " height: 1rem;\n", 443 " width: 1rem;\n", 444 " text-decoration: none;\n", 445 " /* unfitted */\n", 446 " color: var(--sklearn-color-unfitted-level-1);\n", 447 " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", 448 "}\n", 449 "\n", 450 "#sk-container-id-3 a.estimator_doc_link.fitted {\n", 451 " /* fitted */\n", 452 " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", 453 " color: var(--sklearn-color-fitted-level-1);\n", 454 "}\n", 
455 "\n", 456 "/* On hover */\n", 457 "#sk-container-id-3 a.estimator_doc_link:hover {\n", 458 " /* unfitted */\n", 459 " background-color: var(--sklearn-color-unfitted-level-3);\n", 460 " color: var(--sklearn-color-background);\n", 461 " text-decoration: none;\n", 462 "}\n", 463 "\n", 464 "#sk-container-id-3 a.estimator_doc_link.fitted:hover {\n", 465 " /* fitted */\n", 466 " background-color: var(--sklearn-color-fitted-level-3);\n", 467 "}\n", 468 "</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>SGDRegressor(n_iter_no_change=100, penalty=None, random_state=42, tol=1e-05)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" checked><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> SGDRegressor<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.4/modules/generated/sklearn.linear_model.SGDRegressor.html\">?<span>Documentation for SGDRegressor</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>SGDRegressor(n_iter_no_change=100, penalty=None, random_state=42, tol=1e-05)</pre></div> </div></div></div></div>" 469 ], 470 "text/plain": [ 471 "SGDRegressor(n_iter_no_change=100, penalty=None, random_state=42, tol=1e-05)" 472 ] 473 }, 474 "execution_count": 14, 475 "metadata": {}, 476 "output_type": "execute_result" 477 } 478 ], 479 "source": [ 480 "# Built-in version of this using scikit-learn's SGDRegressor\n", 481 "from sklearn.linear_model import SGDRegressor\n", 482 "\n", 483 "# max_iter caps the number of epochs; training stops once the loss fails to improve by tol\n", 484 "# for n_iter_no_change consecutive epochs. eta0 = initial learning rate; penalty=None = no regularization.\n", 485 "sgd_reg = SGDRegressor(max_iter=1000, tol=1e-5, penalty=None, eta0=0.01,\n", 486 "                       n_iter_no_change=100, random_state=42)\n", 487 "sgd_reg.fit(X, y.ravel())  # y has shape (m, 1); ravel() gives the 1-D target scikit-learn expects" 488 ] 489 }, 490 { 491 "cell_type": "code", 492 "execution_count": 18, 493 "metadata": {}, 494 "outputs": [ 495 { 496 "data": { 497 "text/plain": [ 498 "(array([4.10476564]), array([3.01131049]))" 499 ] 500 }, 501 "execution_count": 18, 502 "metadata": {}, 503 "output_type": "execute_result" 504 } 505 ], 506 "source": [ 507 "sgd_reg.intercept_, sgd_reg.coef_  # compare with theta[0] (bias) and theta[1] (slope) from the manual loop" 508 ] 509 } 510 ], 511 "metadata": { 512 "kernelspec": { 513 "display_name": "notebook", 514 "language": "python", 515 "name": "notebook" 516 }, 517 "language_info": { 518 "codemirror_mode": { 519 "name": "ipython", 520 "version": 3 521 }, 522 "file_extension": ".py", 523 "mimetype": "text/x-python", 524 "name": "python", 525 "nbconvert_exporter": "python", 526 "pygments_lexer": "ipython3", 527 "version": "3.11.2" 528 } 529 }, 530 "nbformat": 4, 531 "nbformat_minor": 2 532 }