{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Appendix B: Logistic Loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn.preprocessing import StandardScaler\n", "import plotly.express as px" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Logistic Regression Refresher\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Logistic Regression is a classification model where we calculate the probability of an observation belonging to the positive class (class 1) as:\n", "\n", "$$z=w^Tx$$\n", "\n", "$$\\hat{y} = \\frac{1}{1+\\exp(-z)}$$\n", "\n", "We then assign that observation to a class based on some threshold (usually 0.5):\n", "\n", "$$\\text{Class }\\hat{y}=\\left\\{\n", "\\begin{array}{ll}\n", " 0, & \\hat{y}\\le0.5 \\\\\n", " 1, & \\hat{y}>0.5 \\\\\n", "\\end{array} \n", "\\right.$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Motivating the Loss Function\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- In [Lecture 2](lecture2_gradient-descent.ipynb) we focused on the mean squared error as a loss function for optimizing linear regression:\n", "\n", "$$f(w)=\\frac{1}{n}\\sum^{n}_{i=1}(\\hat{y}_i-y_i)^2$$\n", "\n", "- That doesn't work well for logistic regression, because once $\\hat{y}_i$ comes from the sigmoid function the MSE becomes \"non-convex\" in $w$ (which means it can have local minima and flat regions, so gradient descent is not guaranteed to find the global minimum)\n", "- Instead, we use the following loss function:\n", "\n", "$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\left[y_i\\log\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right) + (1 - y_i)\\log\\left(1 - \\frac{1}{1 + \\exp(-w^Tx_i)}\\right)\\right]$$\n", "\n", "- This function is called the \"log loss\" or \"binary cross-entropy\"\n", "- I want to visually show you the difference between these two functions, and then we'll discuss why the log loss works\n", "- Recall the Pokemon dataset from [Lecture 2](lecture2_gradient-descent.ipynb); I'm going to load it in again (and standardize the data while I'm at it):" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
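<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>name</th>\n", "      <th>attack</th>\n", "      <th>defense</th>\n", "      <th>speed</th>\n", "      <th>capture_rt</th>\n", "      <th>legendary</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>Bulbasaur</td>\n", "      <td>49</td>\n", "      <td>49</td>\n", "      <td>45</td>\n", "      <td>45</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>Ivysaur</td>\n", "      <td>62</td>\n", "      <td>63</td>\n", "      <td>60</td>\n", "      <td>45</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>Venusaur</td>\n", "      <td>100</td>\n", "      <td>123</td>\n", "      <td>80</td>\n", "      <td>45</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>Charmander</td>\n", "      <td>52</td>\n", "      <td>43</td>\n", "      <td>65</td>\n", "      <td>45</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>Charmeleon</td>\n", "      <td>64</td>\n", "      <td>58</td>\n", "      <td>80</td>\n", "      <td>45</td>\n", "      <td>0</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>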
" ], "text/plain": [ " name attack defense speed capture_rt legendary\n", "0 Bulbasaur 49 49 45 45 0\n", "1 Ivysaur 62 63 60 45 0\n", "2 Venusaur 100 123 80 45 0\n", "3 Charmander 52 43 65 45 0\n", "4 Charmeleon 64 58 80 45 0" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(\"data/pokemon.csv\", usecols=['name', 'defense', 'attack', 'speed', 'capture_rt', 'legendary'])\n", "x = StandardScaler().fit_transform(df.drop(columns=[\"name\", \"legendary\"]))\n", "X = np.hstack((np.ones((len(x), 1)), x))\n", "y = df['legendary'].to_numpy()\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 1. , -0.89790944, -0.78077335, -0.73258737, -0.70940526],\n", " [ 1. , -0.49341318, -0.32548801, -0.21387459, -0.70940526],\n", " [ 1. , 0.6889605 , 1.62573488, 0.47774246, -0.70940526],\n", " ...,\n", " [ 1. , 0.72007559, -0.65069183, -0.80174908, -1.10265558],\n", " [ 1. , 0.90676617, 0.91028648, 0.44316161, -1.25995571],\n", " [ 1. 
, 0.53338501, 1.36557183, -0.04097032, -1.25995571]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The goal here is to use the features (but not \"name\", which is only there for illustration purposes) to predict the target \"legendary\" (which takes values of 0/No and 1/Yes).\n", "- So we have 4 features (attack, defense, speed and capture_rt), meaning that our logistic regression model has 5 parameters to estimate: 4 feature coefficients and 1 intercept (the intercept multiplies the column of ones we prepended to `X`)\n", "- At this point, let's define our loss functions:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def sigmoid(w, x):\n", "    \"\"\"Sigmoid function (i.e., logistic regression predictions).\"\"\"\n", "    return 1 / (1 + np.exp(-x @ w))\n", "\n", "\n", "def mse(w, x, y):\n", "    \"\"\"Mean squared error of the sigmoid predictions.\"\"\"\n", "    return np.mean((sigmoid(w, x) - y) ** 2)\n", "\n", "\n", "def logistic_loss(w, x, y):\n", "    \"\"\"Logistic loss (binary cross-entropy).\"\"\"\n", "    y_hat = sigmoid(w, x)  # compute the predictions once and reuse them\n", "    return -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- For a moment, let's fix all the parameters except $w_1$ at some chosen values ($w_0=0.5$, $w_2=-0.5$, $w_3=0.5$, $w_4=-2$)\n", "- We will then calculate both the mean squared error and the log loss for different values of $w_1$, as in the code below:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
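<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>w1</th>\n", "      <th>mse</th>\n", "      <th>log</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>-3.0</td>\n", "      <td>0.451184</td>\n", "      <td>1.604272</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>-2.9</td>\n", "      <td>0.446996</td>\n", "      <td>1.571701</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>-2.8</td>\n", "      <td>0.442773</td>\n", "      <td>1.539928</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>-2.7</td>\n", "      <td>0.438537</td>\n", "      <td>1.508997</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>-2.6</td>\n", "      <td>0.434309</td>\n", "      <td>1.478955</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>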
" ], "text/plain": [ " w1 mse log\n", "0 -3.0 0.451184 1.604272\n", "1 -2.9 0.446996 1.571701\n", "2 -2.8 0.442773 1.539928\n", "3 -2.7 0.438537 1.508997\n", "4 -2.6 0.434309 1.478955" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w1_arr = np.arange(-3, 6.1, 0.1)\n", "losses = pd.DataFrame({\"w1\": w1_arr,\n", " \"mse\": [mse([0.5, w1, -0.5, 0.5, -2], X, y) for w1 in w1_arr],\n", " \"log\": [logistic_loss([0.5, w1, -0.5, 0.5, -2], X, y) for w1 in w1_arr]})\n", "losses.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "hovertemplate": "loss=mse
<br>w1=%{x}<br>
value=%{y}", "legendgroup": "mse", "line": { "color": "#636efa", "dash": "solid" }, "marker": { "symbol": "circle" }, "mode": "lines", "name": "mse", "orientation": "v", "showlegend": true, "type": "scatter", "x": [ -3, -2.9, -2.8, -2.6999999999999997, -2.5999999999999996, -2.4999999999999996, -2.3999999999999995, -2.2999999999999994, -2.1999999999999993, -2.099999999999999, -1.9999999999999991, -1.899999999999999, -1.799999999999999, -1.6999999999999988, -1.5999999999999988, -1.4999999999999987, -1.3999999999999986, -1.2999999999999985, -1.1999999999999984, -1.0999999999999983, -0.9999999999999982, -0.8999999999999981, -0.799999999999998, -0.699999999999998, -0.5999999999999979, -0.4999999999999978, -0.3999999999999977, -0.2999999999999976, -0.1999999999999975, -0.09999999999999742, 2.6645352591003757e-15, 0.10000000000000275, 0.20000000000000284, 0.30000000000000293, 0.400000000000003, 0.5000000000000031, 0.6000000000000032, 0.7000000000000033, 0.8000000000000034, 0.9000000000000035, 1.0000000000000036, 1.1000000000000032, 1.2000000000000037, 1.3000000000000043, 1.400000000000004, 1.5000000000000036, 1.600000000000004, 1.7000000000000046, 1.8000000000000043, 1.900000000000004, 2.0000000000000044, 2.100000000000005, 2.2000000000000046, 2.3000000000000043, 2.400000000000005, 2.5000000000000053, 2.600000000000005, 2.7000000000000046, 2.800000000000005, 2.9000000000000057, 3.0000000000000053, 3.100000000000005, 3.2000000000000055, 3.300000000000006, 3.4000000000000057, 3.5000000000000053, 3.600000000000006, 3.7000000000000064, 3.800000000000006, 3.9000000000000057, 4.000000000000006, 4.100000000000007, 4.200000000000006, 4.300000000000006, 4.400000000000007, 4.500000000000007, 4.600000000000007, 4.700000000000006, 4.800000000000007, 4.9000000000000075, 5.000000000000007, 5.100000000000007, 5.200000000000006, 5.300000000000008, 5.4000000000000075, 5.500000000000007, 5.6000000000000085, 5.700000000000008, 5.800000000000008, 5.9000000000000075, 6.000000000000007 ], 
"xaxis": "x", "y": [ 0.45118413729203294, 0.44699591969245017, 0.4427734039772584, 0.43853694835420387, 0.4343090559208596, 0.43011407789321454, 0.4259779788511469, 0.42192820325092617, 0.4179936388761584, 0.4142046335578609, 0.41059300010449856, 0.40719193891743133, 0.4040358027470889, 0.4011596113169522, 0.3985982066010584, 0.3963849616634432, 0.3945500569518339, 0.3931185069431876, 0.392108261155972, 0.3915286735553178, 0.3913793677717037, 0.39164916852709314, 0.3923146699455984, 0.3933384561849971, 0.39466782081561175, 0.39623537865505953, 0.3979626056035935, 0.39976616699742196, 0.4015656150713772, 0.40329036499792326, 0.40488414875139794, 0.40630624952839506, 0.4075300794681095, 0.4085403743862227, 0.4093302530347399, 0.40989893653614246, 0.4102504299622494, 0.4103931104139613, 0.4103399266922344, 0.41010872202536053, 0.4097220566154894, 0.40920598669629726, 0.4085877149945022, 0.40789273262719666, 0.40714250226917276, 0.40635344852000727, 0.4055372492476331, 0.4047018427935356, 0.4038525393819013, 0.4029929290439403, 0.4021255468365788, 0.4012523541841965, 0.4003750864409099, 0.3994954894445678, 0.39861545729062253, 0.3977370866749279, 0.3968626685433115, 0.3959946393759557, 0.39513551184243345, 0.394287799750402, 0.3934539471083785, 0.392636266832288, 0.3918368914995587, 0.39105773651886194, 0.3903004749015518, 0.3895665222403999, 0.3888570303121277, 0.3881728877645162, 0.38751472651414576, 0.38688293269043433, 0.38627766116658263, 0.38569885289062855, 0.38514625435868227, 0.3846194386579576, 0.3841178275574433, 0.3836407141508363, 0.3831872855723988, 0.38275664532288545, 0.3823478347679552, 0.38195985341032823, 0.38159167759071533, 0.38124227733964483, 0.3809106311790326, 0.3805957387537631, 0.3802966312544198, 0.3800123796677423, 0.37974210095740013, 0.37948496233146223, 0.3792401837929624, 0.3790070391959025, 0.37878485604158707 ], "yaxis": "y" }, { "hovertemplate": "loss=log
<br>w1=%{x}<br>
value=%{y}", "legendgroup": "log", "line": { "color": "#EF553B", "dash": "solid" }, "marker": { "symbol": "circle" }, "mode": "lines", "name": "log", "orientation": "v", "showlegend": true, "type": "scatter", "x": [ -3, -2.9, -2.8, -2.6999999999999997, -2.5999999999999996, -2.4999999999999996, -2.3999999999999995, -2.2999999999999994, -2.1999999999999993, -2.099999999999999, -1.9999999999999991, -1.899999999999999, -1.799999999999999, -1.6999999999999988, -1.5999999999999988, -1.4999999999999987, -1.3999999999999986, -1.2999999999999985, -1.1999999999999984, -1.0999999999999983, -0.9999999999999982, -0.8999999999999981, -0.799999999999998, -0.699999999999998, -0.5999999999999979, -0.4999999999999978, -0.3999999999999977, -0.2999999999999976, -0.1999999999999975, -0.09999999999999742, 2.6645352591003757e-15, 0.10000000000000275, 0.20000000000000284, 0.30000000000000293, 0.400000000000003, 0.5000000000000031, 0.6000000000000032, 0.7000000000000033, 0.8000000000000034, 0.9000000000000035, 1.0000000000000036, 1.1000000000000032, 1.2000000000000037, 1.3000000000000043, 1.400000000000004, 1.5000000000000036, 1.600000000000004, 1.7000000000000046, 1.8000000000000043, 1.900000000000004, 2.0000000000000044, 2.100000000000005, 2.2000000000000046, 2.3000000000000043, 2.400000000000005, 2.5000000000000053, 2.600000000000005, 2.7000000000000046, 2.800000000000005, 2.9000000000000057, 3.0000000000000053, 3.100000000000005, 3.2000000000000055, 3.300000000000006, 3.4000000000000057, 3.5000000000000053, 3.600000000000006, 3.7000000000000064, 3.800000000000006, 3.9000000000000057, 4.000000000000006, 4.100000000000007, 4.200000000000006, 4.300000000000006, 4.400000000000007, 4.500000000000007, 4.600000000000007, 4.700000000000006, 4.800000000000007, 4.9000000000000075, 5.000000000000007, 5.100000000000007, 5.200000000000006, 5.300000000000008, 5.4000000000000075, 5.500000000000007, 5.6000000000000085, 5.700000000000008, 5.800000000000008, 5.9000000000000075, 6.000000000000007 ], 
"xaxis": "x2", "y": [ 1.6042724815909397, 1.5717010229541641, 1.5399277352025134, 1.5089970649918891, 1.4789551947660584, 1.4498499707752563, 1.4217308220966425, 1.3946486601871817, 1.368655740793874, 1.3438054641327422, 1.3201520867392897, 1.2977503199984275, 1.2766547960165007, 1.2569193908060774, 1.2385964070438729, 1.2217356321255692, 1.2063832976419966, 1.1925809671929817, 1.1803643655482796, 1.1697621374101363, 1.1607945076403041, 1.1534718377082804, 1.1477931578512641, 1.1437448859955746, 1.1413000577390067, 1.1404184021711117, 1.1410474655512488, 1.1431247502816644, 1.1465805952268535, 1.151341365133059, 1.157332491594455, 1.1644810135110115, 1.1727174458989076, 1.1819769809220917, 1.192200134492107, 1.203332983209341, 1.2153271120784637, 1.2281393447905733, 1.2417312811487062, 1.2560686390881899, 1.2711204057436218, 1.2868578484794826, 1.303253507257328, 1.3202803416291597, 1.3379111903550227, 1.3561186084988919, 1.374875024728316, 1.394153081511393, 1.4139260158415075, 1.434167985770661, 1.4548543050200822, 1.475961586544332, 1.497467812177854, 1.5193523476030375, 1.5415959185494725, 1.5641805601956482, 1.5870895487330718, 1.610307321965497, 1.633819394287408, 1.6576122701483706, 1.681673359048735, 1.705990894203332, 1.730553856258962, 1.755351902856488, 1.7803753043830541, 1.80561488594071, 1.8310619753400996, 1.8567083567957439, 1.882546229913556, 1.9085681735342432, 1.9347671139756724, 1.9611362972375972, 1.9876692647391974, 2.014359832182527, 2.0412020711855754, 2.068190293309688, 2.0953190361955283, 2.122583051460998, 2.1499772941805424, 2.1774969135555082, 2.2051372447873274, 2.2328938016540145, 2.2607622699292285, 2.288738501404517, 2.316818507885395, 2.3449984565106217, 2.3732746638867663, 2.4016435917642887, 2.43010184240463, 2.4586461537213657, 2.487273393728349 ], "yaxis": "y2" } ], "layout": { "annotations": [ { "font": {}, "showarrow": false, "text": "loss=mse", "x": 0.225, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "bottom", 
"yref": "paper" }, { "font": {}, "showarrow": false, "text": "loss=log", "x": 0.775, "xanchor": "center", "xref": "paper", "y": 1, "yanchor": "bottom", "yref": "paper" } ], "height": 400, "legend": { "title": { "text": "loss" }, "tracegroupgap": 0 }, "margin": { "t": 60 }, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, 
"#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { 
"colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 
0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": 
"white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "width": 800, "xaxis": { "anchor": "y", "domain": [ 0, 0.45 ], "title": { "text": "w1" } }, "xaxis2": { "anchor": "y2", "domain": [ 0.55, 1 ], "showticklabels": true, "title": { "text": "w1" } }, "yaxis": { "anchor": "x", "domain": [ 0, 1 ], "title": { "text": "value" } }, "yaxis2": { "anchor": "x2", "domain": [ 0, 1 ], "showticklabels": true } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = px.line(losses.melt(id_vars=\"w1\", var_name=\"loss\"), x=\"w1\", y=\"value\", color=\"loss\", facet_col=\"loss\", facet_col_spacing=0.1)\n", "fig.update_yaxes(matches=None, showticklabels=True, col=2)\n", "fig.update_xaxes(matches=None, showticklabels=True, col=2)\n", "fig.update_layout(width=800, height=400)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- This is a pretty simple dataset, but you can already see the \"non-convexity\" of the MSE loss function: as $w_1$ increases, the MSE falls to a local minimum near $w_1=-1$, rises to a local maximum near $w_1=0.7$ and then falls again, whereas over the same range the log loss has a single minimum near $w_1=-0.5$.\n", "- If you want a more mathematical description of the logistic loss function, check out [Chapter 3 of Neural Networks and Deep Learning by Michael Nielsen](http://neuralnetworksanddeeplearning.com/chap3.html) or [this YouTube video by Andrew Ng](https://www.youtube.com/watch?v=HIQlmHxI6-0)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Breaking Down the Log Loss Function\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We saw the log loss earlier:\n", "\n", "$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\left[y_i\\log\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right) + (1 - y_i)\\log\\left(1 - \\frac{1}{1 + \\exp(-w^Tx_i)}\\right)\\right]$$\n", "\n", "- It looks complicated, but it's actually quite simple. Let's break it down.\n", "- Recall that this is a binary classification task, so $y_i$ can only be 0 or 1."
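 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
Because $y_i$ is always 0 or 1, one of the two terms inside the sum is always multiplied by zero. Before looking at each case separately, here is a minimal sketch (assuming only NumPy) that evaluates the full formula and a case-by-case version that keeps a single log term per observation via np.where; the two should agree. The names log_loss_full and log_loss_by_case, and the toy labels and probabilities, are hypothetical and chosen only for illustration.

```python
import numpy as np


def log_loss_full(y, y_hat):
    # Full formula: both terms are evaluated, and the label zeroes one of them out.
    return -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))


def log_loss_by_case(y, y_hat):
    # Case by case: y_i = 1 keeps -log(y_hat_i), y_i = 0 keeps -log(1 - y_hat_i).
    per_obs = np.where(y == 1, -np.log(y_hat), -np.log(1 - y_hat))
    return per_obs.mean()


# Hypothetical toy data: two positives and two negatives,
# each predicted once confidently right and once confidently wrong.
y_toy = np.array([1, 1, 0, 0])
y_hat_toy = np.array([0.95, 0.05, 0.05, 0.95])

print(log_loss_full(y_toy, y_hat_toy))
print(log_loss_by_case(y_toy, y_hat_toy))
```

With these toy values both functions should print the same number (about 1.52), since each label simply selects which of the two log terms counts.
"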
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### When `y = 1`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- When $y_i = 1$, the second term (multiplied by $1 - y_i$) vanishes, so if every observation had $y_i = 1$ we would be left with:\n", "\n", "$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\log\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right)$$\n", "\n", "- That looks fine!\n", "- With $y_i = 1$, if $\\hat{y}_i = \\frac{1}{1 + \\exp(-w^Tx_i)}$ is also close to 1, we want the loss to be small; if it is close to 0, we want the loss to be large. That's where the `log()` comes in:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y = 1\n", "y_hat_small = 0.05\n", "y_hat_large = 0.95" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2.995732273553991" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-np.log(y_hat_small)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.05129329438755058" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-np.log(y_hat_large)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### When `y = 0`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- When $y_i = 0$, the first term (multiplied by $y_i$) vanishes, so if every observation had $y_i = 0$ we would be left with:\n", "\n", "$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\log\\left(1 - \\frac{1}{1 + \\exp(-w^Tx_i)}\\right)$$\n", "\n", "- With $y_i = 0$, if $\\hat{y}_i = \\frac{1}{1 + \\exp(-w^Tx_i)}$ is also close to 0, we want the loss to be small; if it is close to 1, we want the loss to be large. Again, that's where the `log()` comes in:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "y = 0\n", "y_hat_small = 0.05\n", "y_hat_large = 0.95" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.05129329438755058" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "-np.log(1 - y_hat_small)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2.99573227355399" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-np.log(1 - y_hat_large)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot Log Loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We know that the predictions $\\hat{y}$ from logistic regression are bounded between 0 and 1 thanks to the sigmoid function\n", "- So let's plot the losses, because it's interesting to see that the worse our predictions are, the larger the loss, and that the loss grows without bound as $\\hat{y}$ approaches the wrong label (e.g., if $y=1$ and our model predicts $\\hat{y}=0.05$, the penalty is $-\\log(0.05)\\approx3.00$, roughly 28 times bigger than the $-\\log(0.90)\\approx0.11$ we get for a prediction of $\\hat{y}=0.90$)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "hovertemplate": "y=y=0
<br>y_hat=%{x}<br>
loss=%{y}", "legendgroup": "y=0", "line": { "color": "#636efa", "dash": "solid" }, "marker": { "symbol": "circle" }, "mode": "lines", "name": "y=0", "orientation": "v", "showlegend": true, "type": "scatter", "x": [ 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35000000000000003, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41000000000000003, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47000000000000003, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.5700000000000001, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.6900000000000001, 0.7000000000000001, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8, 0.81, 0.8200000000000001, 0.8300000000000001, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.9400000000000001, 0.9500000000000001, 0.96, 0.97, 0.98, 0.99 ], "xaxis": "x", "y": [ 0.01005033585350145, 0.020202707317519466, 0.030459207484708574, 0.040821994520255166, 0.05129329438755058, 0.06187540371808753, 0.0725706928348355, 0.08338160893905101, 0.09431067947124129, 0.10536051565782628, 0.11653381625595151, 0.12783337150988489, 0.13926206733350766, 0.15082288973458366, 0.16251892949777494, 0.1743533871447778, 0.18632957819149348, 0.19845093872383818, 0.21072103131565253, 0.2231435513142097, 0.23572233352106983, 0.2484613592984996, 0.2613647641344075, 0.2744368457017603, 0.2876820724517809, 0.3011050927839216, 0.31471074483970024, 0.3285040669720361, 0.342490308946776, 0.35667494393873245, 0.37106368139083207, 0.3856624808119848, 0.4004775665971254, 0.41551544396166595, 0.4307829160924544, 0.4462871026284195, 0.4620354595965587, 0.4780358009429998, 0.4942963218147801, 0.5108256237659907, 0.527632742082372, 0.5447271754416719, 0.5621189181535411, 0.579818495252942, 0.5978370007556204, 0.616186139423817, 0.6348782724359695, 0.6539264674066639, 0.6733445532637656, 0.6931471805599453, 
0.7133498878774648, 0.7339691750802004, 0.7550225842780328, 0.7765287894989964, 0.7985076962177717, 0.8209805520698303, 0.8439700702945291, 0.867500567704723, 0.8915981192837835, 0.916290731874155, 0.941608539858445, 0.9675840262617056, 0.9942522733438669, 1.0216512475319814, 1.0498221244986778, 1.07880966137193, 1.1086626245216114, 1.139434283188365, 1.1711829815029453, 1.2039728043259361, 1.2378743560016172, 1.2729656758128873, 1.3093333199837622, 1.3470736479666092, 1.3862943611198906, 1.4271163556401458, 1.4696759700589417, 1.5141277326297757, 1.5606477482646686, 1.6094379124341005, 1.660731206821651, 1.714798428091927, 1.7719568419318756, 1.83258146374831, 1.897119984885881, 1.9661128563728327, 2.0402208285265546, 2.120263536200091, 2.207274913189721, 2.302585092994046, 2.4079456086518722, 2.525728644308256, 2.659260036932779, 2.8134107167600373, 2.995732273553992, 3.2188758248681997, 3.506557897319981, 3.912023005428145, 4.605170185988091 ], "yaxis": "y" }, { "hovertemplate": "y=y=1
<br>y_hat=%{x}<br>
loss=%{y}", "legendgroup": "y=1", "line": { "color": "#EF553B", "dash": "solid" }, "marker": { "symbol": "circle" }, "mode": "lines", "name": "y=1", "orientation": "v", "showlegend": true, "type": "scatter", "x": [ 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35000000000000003, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41000000000000003, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47000000000000003, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.5700000000000001, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.6900000000000001, 0.7000000000000001, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8, 0.81, 0.8200000000000001, 0.8300000000000001, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.9400000000000001, 0.9500000000000001, 0.96, 0.97, 0.98, 0.99 ], "xaxis": "x", "y": [ 4.605170185988091, 3.912023005428146, 3.506557897319982, 3.2188758248682006, 2.995732273553991, 2.8134107167600364, 2.659260036932778, 2.5257286443082556, 2.4079456086518722, 2.3025850929940455, 2.2072749131897207, 2.120263536200091, 2.0402208285265546, 1.9661128563728327, 1.8971199848858813, 1.8325814637483102, 1.7719568419318752, 1.7147984280919266, 1.6607312068216509, 1.6094379124341003, 1.5606477482646683, 1.5141277326297755, 1.4696759700589417, 1.4271163556401458, 1.3862943611198906, 1.3470736479666092, 1.3093333199837622, 1.2729656758128873, 1.2378743560016174, 1.2039728043259361, 1.171182981502945, 1.1394342831883648, 1.1086626245216111, 1.0788096613719298, 1.0498221244986776, 1.0216512475319814, 0.9942522733438669, 0.9675840262617056, 0.941608539858445, 0.916290731874155, 0.8915981192837835, 0.8675005677047231, 0.843970070294529, 0.8209805520698302, 0.7985076962177716, 0.7765287894989963, 0.7550225842780327, 0.7339691750802004, 0.7133498878774648, 0.6931471805599453, 0.6733445532637656, 0.6539264674066639, 
0.6348782724359695, 0.616186139423817, 0.5978370007556204, 0.579818495252942, 0.5621189181535411, 0.5447271754416722, 0.527632742082372, 0.5108256237659907, 0.4942963218147801, 0.4780358009429998, 0.4620354595965587, 0.4462871026284195, 0.4307829160924542, 0.4155154439616658, 0.40047756659712525, 0.3856624808119846, 0.3710636813908319, 0.3566749439387323, 0.342490308946776, 0.3285040669720361, 0.31471074483970024, 0.3011050927839216, 0.2876820724517809, 0.2744368457017603, 0.2613647641344075, 0.2484613592984996, 0.23572233352106983, 0.2231435513142097, 0.21072103131565253, 0.19845093872383818, 0.18632957819149337, 0.1743533871447778, 0.16251892949777494, 0.15082288973458366, 0.13926206733350766, 0.12783337150988489, 0.11653381625595151, 0.10536051565782628, 0.09431067947124129, 0.08338160893905101, 0.07257069283483537, 0.06187540371808741, 0.05129329438755046, 0.040821994520255166, 0.030459207484708574, 0.020202707317519466, 0.01005033585350145 ], "yaxis": "y" } ], "layout": { "height": 400, "legend": { "title": { "text": "y" }, "tracegroupgap": 0 }, "margin": { "t": 60 }, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, 
"#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, 
"#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": 
{ "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, 
"linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "width": 500, "xaxis": { "anchor": "y", "domain": [ 0, 1 ], "title": { "text": "y_hat" } }, "yaxis": { "anchor": "x", "domain": [ 0, 1 ], "title": { "text": "loss" } } } } }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_hat = np.arange(0.01, 1.00, 0.01)\n", "log_loss = pd.DataFrame({\"y_hat\": y_hat,\n", " \"y=0\": -np.log(1 - y_hat),\n", " \"y=1\": -np.log(y_hat)}).melt(id_vars=\"y_hat\", var_name=\"y\", value_name=\"loss\")\n", "fig = px.line(log_loss, x=\"y_hat\", y=\"loss\", color=\"y\")\n", "fig.update_layout(width=500, height=400)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Log Loss Gradient\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- In [Lecture 2](lecture2_gradient-descent.ipynb) we used the gradient of the log loss to implement gradient descent\n", "- Here's the log loss and its gradient:\n", "\n", 
"$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\left[y_i\\log\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right) + (1 - y_i)\\log\\left(1 - \\frac{1}{1 + \\exp(-w^Tx_i)}\\right)\\right]$$\n", "\n", "$$\\frac{\\partial f(w)}{\\partial w}=\\frac{1}{n}\\sum_{i=1}^nx_i\\left(\\frac{1}{1 + \\exp(-w^Tx_i)} - y_i\\right)$$\n", "\n", "- Let's derive that now.\n", "- As in the refresher above, we'll denote:\n", "\n", "$$z = w^Tx_i$$\n", "\n", "$$\\sigma(z) = \\frac{1}{1 + \\exp(-z)}$$\n", "\n", "- Such that:\n", "\n", "$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\left[y_i\\log\\sigma(z) + (1 - y_i)\\log(1 - \\sigma(z))\\right]$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Okay, let's do it. Differentiating each log term with the chain rule gives:\n", "\n", "$$\n", "\\begin{equation}\n", "\\begin{split}\n", "\\frac{\\partial f(w)}{\\partial w} & =-\\frac{1}{n}\\sum_{i=1}^n\\left[y_i \\times \\frac{1}{\\sigma(z)} \\times \\frac{\\partial \\sigma(z)}{\\partial w} + (1 - y_i) \\times \\frac{1}{1 - \\sigma(z)} \\times \\left(-\\frac{\\partial \\sigma(z)}{\\partial w}\\right)\\right] \\\\\n", "& =-\\frac{1}{n}\\sum_{i=1}^n\\left(\\frac{y_i}{\\sigma(z)} - \\frac{1 - y_i}{1 - \\sigma(z)}\\right)\\frac{\\partial \\sigma(z)}{\\partial w} \\\\\n", "& =-\\frac{1}{n}\\sum_{i=1}^n\\frac{y_i(1 - \\sigma(z)) - (1 - y_i)\\sigma(z)}{\\sigma(z)(1 - \\sigma(z))}\\frac{\\partial \\sigma(z)}{\\partial w} \\\\\n", "& =\\frac{1}{n}\\sum_{i=1}^n \\frac{\\sigma(z)-y_i}{\\sigma(z)(1 - \\sigma(z))}\\frac{\\partial \\sigma(z)}{\\partial w}\n", "\\end{split}\n", "\\end{equation}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Now we just need to work out $\\frac{\\partial \\sigma(z)}{\\partial w}$. I'll mostly skip this part (there's an intuitive derivation [here](https://medium.com/analytics-vidhya/derivative-of-log-loss-function-for-logistic-regression-9b832f025c2d)), but it's just the chain rule, using the sigmoid identity $\\frac{\\partial \\sigma(z)}{\\partial z} = \\sigma(z)(1-\\sigma(z))$ and $\\frac{\\partial z}{\\partial w} = x_i$:\n", "\n", "$$\n", "\\begin{equation}\n", "\\begin{split}\n", "\\frac{\\partial \\sigma(z)}{\\partial w} & = \\frac{\\partial \\sigma(z)}{\\partial z} \\times \\frac{\\partial z}{\\partial w}\\\\\n", "& = \\sigma(z)(1-\\sigma(z))x_i\n", "\\end{split}\n", "\\end{equation}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- So finally:\n", "\n", "$$\n", 
"\\begin{equation}\n", "\\begin{split}\n", "\\frac{\\partial f(w)}{\\partial w} & =\\frac{1}{n}\\sum_{i=1}^n \\frac{\\sigma(z)-y_i}{\\sigma(z)(1 - \\sigma(z))} \\times \\sigma(z)(1-\\sigma(z))x_i \\\\\n", "& = \\frac{1}{n}\\sum_{i=1}^nx_i(\\sigma(z)-y_i) \\\\\n", "& = \\frac{1}{n}\\sum_{i=1}^nx_i\\left(\\frac{1}{1 + \\exp(-w^Tx_i)} - y_i\\right)\n", "\\end{split}\n", "\\end{equation}\n", "$$" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9525741268224334" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "1/(1+np.exp(-3))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.04742587317756678" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "1/(1+np.exp(3))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6.907755278982137" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-np.log(0.001)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "cpsc330", "language": "python", "name": "cpsc330" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" }, "vscode": { "interpreter": { "hash": "f821000d0c0da66e5bcde88c37d59c8e0de03b40667fb62009a8148ca49465a0" } } }, "nbformat": 4, "nbformat_minor": 4 }