{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Appendix B: Logistic Loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.preprocessing import StandardScaler\n",
"import plotly.express as px"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Logistic Regression Refresher\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Logistic Regression is a classification model where we calculate the probability of an observation belonging to a class as:\n",
"\n",
"$$z=w^Tx$$\n",
"\n",
"$$\\hat{y} = \\frac{1}{(1+\\exp(-z))}$$\n",
"\n",
"And then assign that observation to a class based on some threshold (usually 0.5):\n",
"\n",
"$$\\text{Class }\\hat{y}=\\left\\{\n",
"\\begin{array}{ll}\n",
" 0, & \\hat{y}\\le0.5 \\\\\n",
" 1, & \\hat{y}>0.5 \\\\\n",
"\\end{array} \n",
"\\right.$$"
]
},
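{
"cell_type": "markdown",
"metadata": {},
"source": [
"- As a quick numeric sketch of the two steps above (the weights and the observation are made-up values for illustration, not fitted to any data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical weights (intercept first) and one observation with a leading 1 for the intercept\n",
"w = np.array([0.5, -1.0, 2.0])\n",
"x = np.array([1.0, 0.2, 0.7])\n",
"\n",
"z = w @ x                      # linear score z = w^T x\n",
"y_hat = 1 / (1 + np.exp(-z))   # sigmoid squashes z into (0, 1)\n",
"y_class = int(y_hat > 0.5)     # threshold at 0.5 to get a class label\n",
"\n",
"print(z, y_hat, y_class)"
]
},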
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Motivating the Loss Function\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- In [Lecture 2](lecture2_gradient-descent.ipynb) we focussed on the mean squared error as a loss function for optimizing linear regression:\n",
"\n",
"$$f(w)=\\frac{1}{n}\\sum^{n}_{i=1}(\\hat{y}_i-y_i)^2$$\n",
"\n",
"- That won't work well for logistic regression because, once $\\hat{y}_i$ is passed through the sigmoid, the MSE becomes \"non-convex\" in $w$ (which basically means it can have multiple local minima)\n",
"- Instead we use the following loss function:\n",
"\n",
"$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\left[y_i\\log\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right) + (1 - y_i)\\log\\left(1 - \\frac{1}{1 + \\exp(-w^Tx_i)}\\right)\\right]$$\n",
"\n",
"- This function is called the \"log loss\" or \"binary cross entropy\"\n",
"- I want to visually show you the differences between these two functions, and then we'll discuss why the log loss works\n",
"- Recall the Pokemon dataset from [Lecture 2](lecture2_gradient-descent.ipynb). I'm going to load that in again (and standardize the data while I'm at it):"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" name attack defense speed capture_rt legendary\n",
"0 Bulbasaur 49 49 45 45 0\n",
"1 Ivysaur 62 63 60 45 0\n",
"2 Venusaur 100 123 80 45 0\n",
"3 Charmander 52 43 65 45 0\n",
"4 Charmeleon 64 58 80 45 0"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(\"data/pokemon.csv\", usecols=['name', 'defense', 'attack', 'speed', 'capture_rt', 'legendary'])\n",
"x = StandardScaler().fit_transform(df.drop(columns=[\"name\", \"legendary\"]))\n",
"X = np.hstack((np.ones((len(x), 1)), x))\n",
"y = df['legendary'].to_numpy()\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1. , -0.89790944, -0.78077335, -0.73258737, -0.70940526],\n",
" [ 1. , -0.49341318, -0.32548801, -0.21387459, -0.70940526],\n",
" [ 1. , 0.6889605 , 1.62573488, 0.47774246, -0.70940526],\n",
" ...,\n",
" [ 1. , 0.72007559, -0.65069183, -0.80174908, -1.10265558],\n",
" [ 1. , 0.90676617, 0.91028648, 0.44316161, -1.25995571],\n",
" [ 1. , 0.53338501, 1.36557183, -0.04097032, -1.25995571]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- The goal here is to use the features (but not \"name\", that's just there for illustration purposes) to predict the target \"legendary\" (which takes values of 0/No and 1/Yes).\n",
"- So we have 4 features meaning that our logistic regression model will have 5 parameters that need to be estimated (4 feature coefficients and 1 intercept)\n",
"- At this point let's define our loss functions:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid(w, x):\n",
"    \"\"\"Sigmoid function (i.e., logistic regression predictions).\"\"\"\n",
"    return 1 / (1 + np.exp(-x @ w))\n",
"\n",
"\n",
"def mse(w, x, y):\n",
"    \"\"\"Mean squared error.\"\"\"\n",
"    return np.mean((sigmoid(w, x) - y) ** 2)\n",
"\n",
"\n",
"def logistic_loss(w, x, y):\n",
"    \"\"\"Logistic loss.\"\"\"\n",
"    return -np.mean(y * np.log(sigmoid(w, x)) + (1 - y) * np.log(1 - sigmoid(w, x)))"
]
},
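{
"cell_type": "markdown",
"metadata": {},
"source": [
"- One practical caveat, sketched here under assumptions rather than as part of the lecture code: taking `np.log` of a sigmoid output can hit `log(0)` when $w^Tx_i$ is very large in magnitude. A common workaround is to rewrite each term with `np.logaddexp`, using $-\\log(\\hat{y}_i)=\\log(1+\\exp(-z_i))$ and $-\\log(1-\\hat{y}_i)=\\log(1+\\exp(z_i))$ where $z_i=w^Tx_i$. The function name `logistic_loss_stable` and the toy arrays below are hypothetical:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def logistic_loss_stable(w, x, y):\n",
"    \"\"\"Logistic loss computed without forming sigmoid(z) explicitly.\"\"\"\n",
"    z = x @ np.asarray(w)\n",
"    # np.logaddexp(0, a) evaluates log(1 + exp(a)) without overflow\n",
"    return np.mean(y * np.logaddexp(0, -z) + (1 - y) * np.logaddexp(0, z))\n",
"\n",
"\n",
"# Toy data (made up): an intercept column plus one feature\n",
"X_toy = np.array([[1.0, -2.0], [1.0, 0.5], [1.0, 3.0]])\n",
"y_toy = np.array([0, 0, 1])\n",
"\n",
"print(logistic_loss_stable([0.5, 1.0], X_toy, y_toy))    # moderate weights\n",
"print(logistic_loss_stable([0.5, 500.0], X_toy, y_toy))  # extreme weights, result stays finite"
]
},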
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- For a moment, let's assume a value for all the parameters except for $w_1$\n",
"- We will then calculate the mean squared error and the log loss for different values of $w_1$ as in the code below"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" w1 mse log\n",
"0 -3.0 0.451184 1.604272\n",
"1 -2.9 0.446996 1.571701\n",
"2 -2.8 0.442773 1.539928\n",
"3 -2.7 0.438537 1.508997\n",
"4 -2.6 0.434309 1.478955"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w1_arr = np.arange(-3, 6.1, 0.1)\n",
"losses = pd.DataFrame({\"w1\": w1_arr,\n",
" \"mse\": [mse([0.5, w1, -0.5, 0.5, -2], X, y) for w1 in w1_arr],\n",
" \"log\": [logistic_loss([0.5, w1, -0.5, 0.5, -2], X, y) for w1 in w1_arr]})\n",
"losses.head()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"hovertemplate": "loss=mse<br>w1=%{x}<br>value=%{y}<extra></extra>",
"legendgroup": "mse",
"line": {
"color": "#636efa",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "mse",
"orientation": "v",
"showlegend": true,
"type": "scatter",
"x": [
-3,
-2.9,
-2.8,
-2.6999999999999997,
-2.5999999999999996,
-2.4999999999999996,
-2.3999999999999995,
-2.2999999999999994,
-2.1999999999999993,
-2.099999999999999,
-1.9999999999999991,
-1.899999999999999,
-1.799999999999999,
-1.6999999999999988,
-1.5999999999999988,
-1.4999999999999987,
-1.3999999999999986,
-1.2999999999999985,
-1.1999999999999984,
-1.0999999999999983,
-0.9999999999999982,
-0.8999999999999981,
-0.799999999999998,
-0.699999999999998,
-0.5999999999999979,
-0.4999999999999978,
-0.3999999999999977,
-0.2999999999999976,
-0.1999999999999975,
-0.09999999999999742,
2.6645352591003757e-15,
0.10000000000000275,
0.20000000000000284,
0.30000000000000293,
0.400000000000003,
0.5000000000000031,
0.6000000000000032,
0.7000000000000033,
0.8000000000000034,
0.9000000000000035,
1.0000000000000036,
1.1000000000000032,
1.2000000000000037,
1.3000000000000043,
1.400000000000004,
1.5000000000000036,
1.600000000000004,
1.7000000000000046,
1.8000000000000043,
1.900000000000004,
2.0000000000000044,
2.100000000000005,
2.2000000000000046,
2.3000000000000043,
2.400000000000005,
2.5000000000000053,
2.600000000000005,
2.7000000000000046,
2.800000000000005,
2.9000000000000057,
3.0000000000000053,
3.100000000000005,
3.2000000000000055,
3.300000000000006,
3.4000000000000057,
3.5000000000000053,
3.600000000000006,
3.7000000000000064,
3.800000000000006,
3.9000000000000057,
4.000000000000006,
4.100000000000007,
4.200000000000006,
4.300000000000006,
4.400000000000007,
4.500000000000007,
4.600000000000007,
4.700000000000006,
4.800000000000007,
4.9000000000000075,
5.000000000000007,
5.100000000000007,
5.200000000000006,
5.300000000000008,
5.4000000000000075,
5.500000000000007,
5.6000000000000085,
5.700000000000008,
5.800000000000008,
5.9000000000000075,
6.000000000000007
],
"xaxis": "x",
"y": [
0.45118413729203294,
0.44699591969245017,
0.4427734039772584,
0.43853694835420387,
0.4343090559208596,
0.43011407789321454,
0.4259779788511469,
0.42192820325092617,
0.4179936388761584,
0.4142046335578609,
0.41059300010449856,
0.40719193891743133,
0.4040358027470889,
0.4011596113169522,
0.3985982066010584,
0.3963849616634432,
0.3945500569518339,
0.3931185069431876,
0.392108261155972,
0.3915286735553178,
0.3913793677717037,
0.39164916852709314,
0.3923146699455984,
0.3933384561849971,
0.39466782081561175,
0.39623537865505953,
0.3979626056035935,
0.39976616699742196,
0.4015656150713772,
0.40329036499792326,
0.40488414875139794,
0.40630624952839506,
0.4075300794681095,
0.4085403743862227,
0.4093302530347399,
0.40989893653614246,
0.4102504299622494,
0.4103931104139613,
0.4103399266922344,
0.41010872202536053,
0.4097220566154894,
0.40920598669629726,
0.4085877149945022,
0.40789273262719666,
0.40714250226917276,
0.40635344852000727,
0.4055372492476331,
0.4047018427935356,
0.4038525393819013,
0.4029929290439403,
0.4021255468365788,
0.4012523541841965,
0.4003750864409099,
0.3994954894445678,
0.39861545729062253,
0.3977370866749279,
0.3968626685433115,
0.3959946393759557,
0.39513551184243345,
0.394287799750402,
0.3934539471083785,
0.392636266832288,
0.3918368914995587,
0.39105773651886194,
0.3903004749015518,
0.3895665222403999,
0.3888570303121277,
0.3881728877645162,
0.38751472651414576,
0.38688293269043433,
0.38627766116658263,
0.38569885289062855,
0.38514625435868227,
0.3846194386579576,
0.3841178275574433,
0.3836407141508363,
0.3831872855723988,
0.38275664532288545,
0.3823478347679552,
0.38195985341032823,
0.38159167759071533,
0.38124227733964483,
0.3809106311790326,
0.3805957387537631,
0.3802966312544198,
0.3800123796677423,
0.37974210095740013,
0.37948496233146223,
0.3792401837929624,
0.3790070391959025,
0.37878485604158707
],
"yaxis": "y"
},
{
"hovertemplate": "loss=log<br>w1=%{x}<br>value=%{y}<extra></extra>",
"legendgroup": "log",
"line": {
"color": "#EF553B",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "log",
"orientation": "v",
"showlegend": true,
"type": "scatter",
"x": [
-3,
-2.9,
-2.8,
-2.6999999999999997,
-2.5999999999999996,
-2.4999999999999996,
-2.3999999999999995,
-2.2999999999999994,
-2.1999999999999993,
-2.099999999999999,
-1.9999999999999991,
-1.899999999999999,
-1.799999999999999,
-1.6999999999999988,
-1.5999999999999988,
-1.4999999999999987,
-1.3999999999999986,
-1.2999999999999985,
-1.1999999999999984,
-1.0999999999999983,
-0.9999999999999982,
-0.8999999999999981,
-0.799999999999998,
-0.699999999999998,
-0.5999999999999979,
-0.4999999999999978,
-0.3999999999999977,
-0.2999999999999976,
-0.1999999999999975,
-0.09999999999999742,
2.6645352591003757e-15,
0.10000000000000275,
0.20000000000000284,
0.30000000000000293,
0.400000000000003,
0.5000000000000031,
0.6000000000000032,
0.7000000000000033,
0.8000000000000034,
0.9000000000000035,
1.0000000000000036,
1.1000000000000032,
1.2000000000000037,
1.3000000000000043,
1.400000000000004,
1.5000000000000036,
1.600000000000004,
1.7000000000000046,
1.8000000000000043,
1.900000000000004,
2.0000000000000044,
2.100000000000005,
2.2000000000000046,
2.3000000000000043,
2.400000000000005,
2.5000000000000053,
2.600000000000005,
2.7000000000000046,
2.800000000000005,
2.9000000000000057,
3.0000000000000053,
3.100000000000005,
3.2000000000000055,
3.300000000000006,
3.4000000000000057,
3.5000000000000053,
3.600000000000006,
3.7000000000000064,
3.800000000000006,
3.9000000000000057,
4.000000000000006,
4.100000000000007,
4.200000000000006,
4.300000000000006,
4.400000000000007,
4.500000000000007,
4.600000000000007,
4.700000000000006,
4.800000000000007,
4.9000000000000075,
5.000000000000007,
5.100000000000007,
5.200000000000006,
5.300000000000008,
5.4000000000000075,
5.500000000000007,
5.6000000000000085,
5.700000000000008,
5.800000000000008,
5.9000000000000075,
6.000000000000007
],
"xaxis": "x2",
"y": [
1.6042724815909397,
1.5717010229541641,
1.5399277352025134,
1.5089970649918891,
1.4789551947660584,
1.4498499707752563,
1.4217308220966425,
1.3946486601871817,
1.368655740793874,
1.3438054641327422,
1.3201520867392897,
1.2977503199984275,
1.2766547960165007,
1.2569193908060774,
1.2385964070438729,
1.2217356321255692,
1.2063832976419966,
1.1925809671929817,
1.1803643655482796,
1.1697621374101363,
1.1607945076403041,
1.1534718377082804,
1.1477931578512641,
1.1437448859955746,
1.1413000577390067,
1.1404184021711117,
1.1410474655512488,
1.1431247502816644,
1.1465805952268535,
1.151341365133059,
1.157332491594455,
1.1644810135110115,
1.1727174458989076,
1.1819769809220917,
1.192200134492107,
1.203332983209341,
1.2153271120784637,
1.2281393447905733,
1.2417312811487062,
1.2560686390881899,
1.2711204057436218,
1.2868578484794826,
1.303253507257328,
1.3202803416291597,
1.3379111903550227,
1.3561186084988919,
1.374875024728316,
1.394153081511393,
1.4139260158415075,
1.434167985770661,
1.4548543050200822,
1.475961586544332,
1.497467812177854,
1.5193523476030375,
1.5415959185494725,
1.5641805601956482,
1.5870895487330718,
1.610307321965497,
1.633819394287408,
1.6576122701483706,
1.681673359048735,
1.705990894203332,
1.730553856258962,
1.755351902856488,
1.7803753043830541,
1.80561488594071,
1.8310619753400996,
1.8567083567957439,
1.882546229913556,
1.9085681735342432,
1.9347671139756724,
1.9611362972375972,
1.9876692647391974,
2.014359832182527,
2.0412020711855754,
2.068190293309688,
2.0953190361955283,
2.122583051460998,
2.1499772941805424,
2.1774969135555082,
2.2051372447873274,
2.2328938016540145,
2.2607622699292285,
2.288738501404517,
2.316818507885395,
2.3449984565106217,
2.3732746638867663,
2.4016435917642887,
2.43010184240463,
2.4586461537213657,
2.487273393728349
],
"yaxis": "y2"
}
],
"layout": {
"annotations": [
{
"font": {},
"showarrow": false,
"text": "loss=mse",
"x": 0.225,
"xanchor": "center",
"xref": "paper",
"y": 1,
"yanchor": "bottom",
"yref": "paper"
},
{
"font": {},
"showarrow": false,
"text": "loss=log",
"x": 0.775,
"xanchor": "center",
"xref": "paper",
"y": 1,
"yanchor": "bottom",
"yref": "paper"
}
],
"height": 400,
"legend": {
"title": {
"text": "loss"
},
"tracegroupgap": 0
},
"margin": {
"t": 60
},
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"fillpattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"width": 800,
"xaxis": {
"anchor": "y",
"domain": [
0,
0.45
],
"title": {
"text": "w1"
}
},
"xaxis2": {
"anchor": "y2",
"domain": [
0.55,
1
],
"showticklabels": true,
"title": {
"text": "w1"
}
},
"yaxis": {
"anchor": "x",
"domain": [
0,
1
],
"title": {
"text": "value"
}
},
"yaxis2": {
"anchor": "x2",
"domain": [
0,
1
],
"showticklabels": true
}
}
},
"text/html": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = px.line(losses.melt(id_vars=\"w1\", var_name=\"loss\"), x=\"w1\", y=\"value\", color=\"loss\", facet_col=\"loss\", facet_col_spacing=0.1)\n",
"fig.update_yaxes(matches=None, showticklabels=True, col=2)\n",
"fig.update_xaxes(matches=None, showticklabels=True, col=2)\n",
"fig.update_layout(width=800, height=400)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- This is a pretty simple dataset but you can already see the \"non-convexity\" of the MSE loss function: it has a local minimum near $w_1=-1$, rises to a bump near $w_1=0.7$, and then falls again, whereas the log loss has a single minimum (near $w_1=-0.5$).\n",
"- If you want a more mathematical description of the logistic loss function, check out [Chapter 3 of Neural Networks and Deep Learning by Michael Nielsen](http://neuralnetworksanddeeplearning.com/chap3.html) or [this YouTube video by Andrew Ng](https://www.youtube.com/watch?v=HIQlmHxI6-0)."
]
},
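{
"cell_type": "markdown",
"metadata": {},
"source": [
"- To make \"non-convex\" a little more concrete, here is a minimal sketch with a single made-up observation ($x=1$, $y=1$) and a single weight $w$, rather than the Pokemon data. A convex curve sampled on an even grid has non-negative second differences, so a negative value flags non-convexity:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"w_grid = np.arange(-6, 6.1, 0.1)\n",
"y_hat = 1 / (1 + np.exp(-w_grid))   # prediction for x = 1 at each w\n",
"\n",
"mse_curve = (y_hat - 1) ** 2        # squared error against y = 1\n",
"log_curve = -np.log(y_hat)          # log loss against y = 1\n",
"\n",
"# Second differences approximate curvature (up to a factor of the grid step squared)\n",
"mse_curv = np.diff(mse_curve, n=2)\n",
"log_curv = np.diff(log_curve, n=2)\n",
"\n",
"print(\"MSE has negative curvature somewhere:\", (mse_curv < 0).any())\n",
"print(\"Log loss has negative curvature somewhere:\", (log_curv < 0).any())"
]
},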
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Breaking Down the Log Loss Function\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- So we saw the log loss before:\n",
"\n",
"$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\left[y_i\\log\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right) + (1 - y_i)\\log\\left(1 - \\frac{1}{1 + \\exp(-w^Tx_i)}\\right)\\right]$$\n",
"\n",
"- It looks complicated but it's actually quite simple. Let's break it down.\n",
"- Recall that we have a binary classification task here so $y_i$ can only be 0 or 1."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### When `y = 1`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- When $y_i = 1$ the second term vanishes (because $1 - y_i = 0$) and we are left with:\n",
"\n",
"$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\log\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right)$$\n",
"\n",
"- That looks fine!\n",
"- With $y_i = 1$, if $\\hat{y}_i = \\frac{1}{1 + \\exp(-w^Tx_i)}$ is also close to 1 we want the loss to be small, and if it is close to 0 we want the loss to be large. That's where the `log()` comes in:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y = 1\n",
"y_hat_small = 0.05\n",
"y_hat_large = 0.95"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.995732273553991"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"-np.log(y_hat_small)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.05129329438755058"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"-np.log(y_hat_large)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### When `y = 0`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- When $y_i = 0$ the first term vanishes and we are left with:\n",
"\n",
"$$f(w)=-\\frac{1}{n}\\sum_{i=1}^n\\log\\left(1 - \\frac{1}{1 + \\exp(-w^Tx_i)}\\right)$$\n",
"\n",
"- With $y_i = 0$, if $\\hat{y}_i = \\frac{1}{1 + \\exp(-w^Tx_i)}$ is also close to 0 we want the loss to be small, and if it is close to 1 we want the loss to be large. That's where the `log()` comes in:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"y = 0\n",
"y_hat_small = 0.05\n",
"y_hat_large = 0.95"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.05129329438755058"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"-np.log(1 - y_hat_small)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.99573227355399"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"-np.log(1 - y_hat_large)"
]
},
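{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Putting the two cases together: because $y_i$ is either 0 or 1, the full formula simply picks the relevant term for each observation. A small sketch with made-up labels and predictions (the names `y_true` and `y_pred` are just for this illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"y_true = np.array([1, 1, 0, 0])\n",
"y_pred = np.array([0.95, 0.05, 0.05, 0.95])\n",
"\n",
"# For y = 1 only the first term survives; for y = 0 only the second\n",
"per_obs = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))\n",
"\n",
"print(per_obs)          # one loss per observation\n",
"print(per_obs.mean())   # the log loss is their average"
]
},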
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot Log Loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We know that our predictions from logistic regression $\\hat{y}$ are limited between 0 and 1 thanks to the sigmoid function\n",
"- So let's plot the losses, because it's interesting to see how the loss grows as our predictions get worse (e.g., if $y=1$ and our model predicts $\\hat{y}=0.05$, the penalty is about 3.0, compared with about 0.11 for $\\hat{y}=0.90$, and it grows without bound as $\\hat{y}$ approaches 0)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"hovertemplate": "y=y=0<br>y_hat=%{x}<br>loss=%{y}<extra></extra>",
"legendgroup": "y=0",
"line": {
"color": "#636efa",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "y=0",
"orientation": "v",
"showlegend": true,
"type": "scatter",
"x": [
0.01,
0.02,
0.03,
0.04,
0.05,
0.06,
0.07,
0.08,
0.09,
0.1,
0.11,
0.12,
0.13,
0.14,
0.15,
0.16,
0.17,
0.18,
0.19,
0.2,
0.21,
0.22,
0.23,
0.24,
0.25,
0.26,
0.27,
0.28,
0.29,
0.3,
0.31,
0.32,
0.33,
0.34,
0.35000000000000003,
0.36,
0.37,
0.38,
0.39,
0.4,
0.41000000000000003,
0.42,
0.43,
0.44,
0.45,
0.46,
0.47000000000000003,
0.48,
0.49,
0.5,
0.51,
0.52,
0.53,
0.54,
0.55,
0.56,
0.5700000000000001,
0.58,
0.59,
0.6,
0.61,
0.62,
0.63,
0.64,
0.65,
0.66,
0.67,
0.68,
0.6900000000000001,
0.7000000000000001,
0.71,
0.72,
0.73,
0.74,
0.75,
0.76,
0.77,
0.78,
0.79,
0.8,
0.81,
0.8200000000000001,
0.8300000000000001,
0.84,
0.85,
0.86,
0.87,
0.88,
0.89,
0.9,
0.91,
0.92,
0.93,
0.9400000000000001,
0.9500000000000001,
0.96,
0.97,
0.98,
0.99
],
"xaxis": "x",
"y": [
0.01005033585350145,
0.020202707317519466,
0.030459207484708574,
0.040821994520255166,
0.05129329438755058,
0.06187540371808753,
0.0725706928348355,
0.08338160893905101,
0.09431067947124129,
0.10536051565782628,
0.11653381625595151,
0.12783337150988489,
0.13926206733350766,
0.15082288973458366,
0.16251892949777494,
0.1743533871447778,
0.18632957819149348,
0.19845093872383818,
0.21072103131565253,
0.2231435513142097,
0.23572233352106983,
0.2484613592984996,
0.2613647641344075,
0.2744368457017603,
0.2876820724517809,
0.3011050927839216,
0.31471074483970024,
0.3285040669720361,
0.342490308946776,
0.35667494393873245,
0.37106368139083207,
0.3856624808119848,
0.4004775665971254,
0.41551544396166595,
0.4307829160924544,
0.4462871026284195,
0.4620354595965587,
0.4780358009429998,
0.4942963218147801,
0.5108256237659907,
0.527632742082372,
0.5447271754416719,
0.5621189181535411,
0.579818495252942,
0.5978370007556204,
0.616186139423817,
0.6348782724359695,
0.6539264674066639,
0.6733445532637656,
0.6931471805599453,
0.7133498878774648,
0.7339691750802004,
0.7550225842780328,
0.7765287894989964,
0.7985076962177717,
0.8209805520698303,
0.8439700702945291,
0.867500567704723,
0.8915981192837835,
0.916290731874155,
0.941608539858445,
0.9675840262617056,
0.9942522733438669,
1.0216512475319814,
1.0498221244986778,
1.07880966137193,
1.1086626245216114,
1.139434283188365,
1.1711829815029453,
1.2039728043259361,
1.2378743560016172,
1.2729656758128873,
1.3093333199837622,
1.3470736479666092,
1.3862943611198906,
1.4271163556401458,
1.4696759700589417,
1.5141277326297757,
1.5606477482646686,
1.6094379124341005,
1.660731206821651,
1.714798428091927,
1.7719568419318756,
1.83258146374831,
1.897119984885881,
1.9661128563728327,
2.0402208285265546,
2.120263536200091,
2.207274913189721,
2.302585092994046,
2.4079456086518722,
2.525728644308256,
2.659260036932779,
2.8134107167600373,
2.995732273553992,
3.2188758248681997,
3.506557897319981,
3.912023005428145,
4.605170185988091
],
"yaxis": "y"
},
{
"hovertemplate": "y=y=1<br>y_hat=%{x}<br>loss=%{y}<extra></extra>",
"legendgroup": "y=1",
"line": {
"color": "#EF553B",
"dash": "solid"
},
"marker": {
"symbol": "circle"
},
"mode": "lines",
"name": "y=1",
"orientation": "v",
"showlegend": true,
"type": "scatter",
"x": [
0.01,
0.02,
0.03,
0.04,
0.05,
0.06,
0.07,
0.08,
0.09,
0.1,
0.11,
0.12,
0.13,
0.14,
0.15,
0.16,
0.17,
0.18,
0.19,
0.2,
0.21,
0.22,
0.23,
0.24,
0.25,
0.26,
0.27,
0.28,
0.29,
0.3,
0.31,
0.32,
0.33,
0.34,
0.35000000000000003,
0.36,
0.37,
0.38,
0.39,
0.4,
0.41000000000000003,
0.42,
0.43,
0.44,
0.45,
0.46,
0.47000000000000003,
0.48,
0.49,
0.5,
0.51,
0.52,
0.53,
0.54,
0.55,
0.56,
0.5700000000000001,
0.58,
0.59,
0.6,
0.61,
0.62,
0.63,
0.64,
0.65,
0.66,
0.67,
0.68,
0.6900000000000001,
0.7000000000000001,
0.71,
0.72,
0.73,
0.74,
0.75,
0.76,
0.77,
0.78,
0.79,
0.8,
0.81,
0.8200000000000001,
0.8300000000000001,
0.84,
0.85,
0.86,
0.87,
0.88,
0.89,
0.9,
0.91,
0.92,
0.93,
0.9400000000000001,
0.9500000000000001,
0.96,
0.97,
0.98,
0.99
],
"xaxis": "x",
"y": [
4.605170185988091,
3.912023005428146,
3.506557897319982,
3.2188758248682006,
2.995732273553991,
2.8134107167600364,
2.659260036932778,
2.5257286443082556,
2.4079456086518722,
2.3025850929940455,
2.2072749131897207,
2.120263536200091,
2.0402208285265546,
1.9661128563728327,
1.8971199848858813,
1.8325814637483102,
1.7719568419318752,
1.7147984280919266,
1.6607312068216509,
1.6094379124341003,
1.5606477482646683,
1.5141277326297755,
1.4696759700589417,
1.4271163556401458,
1.3862943611198906,
1.3470736479666092,
1.3093333199837622,
1.2729656758128873,
1.2378743560016174,
1.2039728043259361,
1.171182981502945,
1.1394342831883648,
1.1086626245216111,
1.0788096613719298,
1.0498221244986776,
1.0216512475319814,
0.9942522733438669,
0.9675840262617056,
0.941608539858445,
0.916290731874155,
0.8915981192837835,
0.8675005677047231,
0.843970070294529,
0.8209805520698302,
0.7985076962177716,
0.7765287894989963,
0.7550225842780327,
0.7339691750802004,
0.7133498878774648,
0.6931471805599453,
0.6733445532637656,
0.6539264674066639,
0.6348782724359695,
0.616186139423817,
0.5978370007556204,
0.579818495252942,
0.5621189181535411,
0.5447271754416722,
0.527632742082372,
0.5108256237659907,
0.4942963218147801,
0.4780358009429998,
0.4620354595965587,
0.4462871026284195,
0.4307829160924542,
0.4155154439616658,
0.40047756659712525,
0.3856624808119846,
0.3710636813908319,
0.3566749439387323,
0.342490308946776,
0.3285040669720361,
0.31471074483970024,
0.3011050927839216,
0.2876820724517809,
0.2744368457017603,
0.2613647641344075,
0.2484613592984996,
0.23572233352106983,
0.2231435513142097,
0.21072103131565253,
0.19845093872383818,
0.18632957819149337,
0.1743533871447778,
0.16251892949777494,
0.15082288973458366,
0.13926206733350766,
0.12783337150988489,
0.11653381625595151,
0.10536051565782628,
0.09431067947124129,
0.08338160893905101,
0.07257069283483537,
0.06187540371808741,
0.05129329438755046,
0.040821994520255166,
0.030459207484708574,
0.020202707317519466,
0.01005033585350145
],
"yaxis": "y"
}
],
"layout": {
"height": 400,
"legend": {
"title": {
"text": "y"
},
"tracegroupgap": 0
},
"margin": {
"t": 60
},
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"fillpattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"width": 500,
"xaxis": {
"anchor": "y",
"domain": [
0,
1
],
"title": {
"text": "y_hat"
}
},
"yaxis": {
"anchor": "x",
"domain": [
0,
1
],
"title": {
"text": "loss"
}
}
}
}
},
"metadata": {},
"output_type": "display_data"
}
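],
"source": [
"# Per-observation log loss across predicted probabilities 0.01, 0.02, ..., 0.99\n",
"y_hat = np.arange(0.01, 1.00, 0.01)\n",
"log_loss = pd.DataFrame({\"y_hat\": y_hat,\n",
"                         \"y=0\": -np.log(1 - y_hat),  # loss when the true label is 0\n",
"                         \"y=1\": -np.log(y_hat)}).melt(id_vars=\"y_hat\", var_name=\"y\", value_name=\"loss\")  # loss when the true label is 1, then reshape to long format\n",
"fig = px.line(log_loss, x=\"y_hat\", y=\"loss\", color=\"y\")\n",
"fig.update_layout(width=500, height=400)"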
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Log Loss Gradient\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- In [Lecture 2](lecture2_gradient-descent.ipynb) we used the gradient of the log loss to implement gradient descent\n",
"- Here's the log loss and its gradient:\n",
"\n",
"$$f(w)=-\\frac{1}{n}\\sum_{i=1}^ny_i\\log\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right) + (1 - y_i)\\log\\left(1 - \\frac{1}{1 + \\exp(-w^Tx_i)}\\right)$$\n",
"\n",
"$$\\frac{\\partial f(w)}{\\partial w}=\\frac{1}{n}\\sum_{i=1}^nx_i\\left(\\frac{1}{1 + \\exp(-w^Tx_i)} - y_i\\right)$$\n",
"\n",
"- Let's derive that now.\n",
"- We'll denote (matching the refresher above):\n",
"\n",
"$$z = w^Tx_i$$\n",
"\n",
"$$\\sigma(z) = \\frac{1}{1 + \\exp(-z)}$$\n",
"\n",
"- Such that:\n",
"\n",
"$$f(w)=-\\frac{1}{n}\\sum_{i=1}^ny_i\\log\\sigma(z) + (1 - y_i)\\log(1 - \\sigma(z))$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Okay let's do it:\n",
"\n",
"$$\n",
"\\begin{equation}\n",
"\\begin{split}\n",
"\\frac{\\partial f(w)}{\\partial w} & =-\\frac{1}{n}\\sum_{i=1}^ny_i \\times \\frac{1}{\\sigma(z)} \\times \\frac{\\partial \\sigma(z)}{\\partial w} + (1 - y_i) \\times \\frac{1}{1 - \\sigma(z)} \\times -\\frac{\\partial \\sigma(z)}{\\partial w} \\\\\n",
"& =-\\frac{1}{n}\\sum_{i=1}^n\\left(\\frac{y_i}{\\sigma(z)} - \\frac{1 - y_i}{1 - \\sigma(z)}\\right)\\frac{\\partial \\sigma(z)}{\\partial w} \\\\\n",
"& =\\frac{1}{n}\\sum_{i=1}^n \\frac{\\sigma(z)-y_i}{\\sigma(z)(1 - \\sigma(z))}\\frac{\\partial \\sigma(z)}{\\partial w}\n",
"\\end{split}\n",
"\\end{equation}\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Now we just need to work out $\\frac{\\partial \\sigma(z)}{\\partial w}$. I'll mostly skip this part, but there's an intuitive derivation [here](https://medium.com/analytics-vidhya/derivative-of-log-loss-function-for-logistic-regression-9b832f025c2d); it's just a matter of applying the chain rule:\n",
"\n",
"$$\n",
"\\begin{equation}\n",
"\\begin{split}\n",
"\\frac{\\partial \\sigma(z)}{\\partial w} & = \\frac{\\partial \\sigma(z)}{\\partial z} \\times \\frac{\\partial z}{\\partial w}\\\\\n",
"& = \\sigma(z)(1-\\sigma(z))x_i \\\\\n",
"\\end{split}\n",
"\\end{equation}\n",
"$$"
]
},
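{
"cell_type": "markdown",
"metadata": {},
"source": [
"- For reference, here's a sketch of the step skipped above, written directly in terms of $w^Tx_i$ and using the fact that $\\sigma(z)$ evaluates to $\\frac{1}{1 + \\exp(-w^Tx_i)}$ (treat it as a sanity check rather than a full proof):\n",
"\n",
"$$\n",
"\\begin{equation}\n",
"\\begin{split}\n",
"\\frac{\\partial}{\\partial w}\\left(\\frac{1}{1 + \\exp(-w^Tx_i)}\\right) & = \\frac{\\exp(-w^Tx_i)}{(1 + \\exp(-w^Tx_i))^2}x_i \\\\\n",
"& = \\frac{1}{1 + \\exp(-w^Tx_i)} \\times \\frac{\\exp(-w^Tx_i)}{1 + \\exp(-w^Tx_i)} \\times x_i \\\\\n",
"& = \\sigma(z)(1-\\sigma(z))x_i\n",
"\\end{split}\n",
"\\end{equation}\n",
"$$\n",
"\n",
"- The last line uses $1 - \\sigma(z) = \\frac{\\exp(-w^Tx_i)}{1 + \\exp(-w^Tx_i)}$"
]
},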
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- So finally:\n",
"\n",
"$$\n",
"\\begin{equation}\n",
"\\begin{split}\n",
"\\frac{\\partial f(w)}{\\partial w} & =\\frac{1}{n}\\sum_{i=1}^n \\frac{\\sigma(z)-y_i}{\\sigma(z)(1 - \\sigma(z))} \\times \\sigma(z)(1-\\sigma(z))x_i \\\\\n",
"& = \\frac{1}{n}\\sum_{i=1}^nx_i(\\sigma(z)-y_i) \\\\\n",
"& = \\frac{1}{n}\\sum_{i=1}^nx_i\\left(\\frac{1}{1 + \\exp(-w^Tx_i)} - y_i)\\right)\n",
"\\end{split}\n",
"\\end{equation}\n",
"$$"
]
},
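{
"cell_type": "markdown",
"metadata": {},
"source": [
"- As a quick sanity check on the result above, we can compare the derived gradient to a finite-difference approximation of the log loss\n",
"- This is only a minimal sketch: the toy data and the helper names (`sigmoid`, `mean_log_loss`, `log_loss_gradient`, `numerical_gradient`) are illustrative assumptions, not part of the lecture code\n",
"- If the derivation holds, the final comparison should report that the two gradients agree to within numerical tolerance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical toy data (an assumption for illustration): 50 observations, 3 features\n",
"rng = np.random.default_rng(0)\n",
"X_toy = rng.normal(size=(50, 3))\n",
"y_toy = rng.integers(0, 2, size=50)\n",
"w_toy = rng.normal(size=3)\n",
"\n",
"def sigmoid(t):\n",
"    # 1 / (1 + exp(-t)), applied element-wise\n",
"    return 1 / (1 + np.exp(-t))\n",
"\n",
"def mean_log_loss(w, X, y):\n",
"    # f(w): the log loss averaged over all observations\n",
"    p = sigmoid(X @ w)\n",
"    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
"\n",
"def log_loss_gradient(w, X, y):\n",
"    # The derived result: (1/n) * sum_i x_i * (sigmoid(w^T x_i) - y_i)\n",
"    return X.T @ (sigmoid(X @ w) - y) / len(y)\n",
"\n",
"def numerical_gradient(f, w, eps=1e-6):\n",
"    # Central finite differences, one coordinate of w at a time\n",
"    grad = np.zeros_like(w)\n",
"    for j in range(len(w)):\n",
"        step = np.zeros_like(w)\n",
"        step[j] = eps\n",
"        grad[j] = (f(w + step) - f(w - step)) / (2 * eps)\n",
"    return grad\n",
"\n",
"analytic = log_loss_gradient(w_toy, X_toy, y_toy)\n",
"numeric = numerical_gradient(lambda v: mean_log_loss(v, X_toy, y_toy), w_toy)\n",
"\n",
"# Do the derived and finite-difference gradients agree?\n",
"np.allclose(analytic, numeric, atol=1e-6)"
]
},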
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9525741268224334"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sigmoid of 3: a score of w^T x = 3 maps to a probability close to 1\n",
"1/(1+np.exp(-3))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.04742587317756678"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sigmoid of -3: a score of w^T x = -3 maps to a probability close to 0\n",
"1/(1+np.exp(3))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.907755278982137"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# log loss for a single observation whose true class was given probability 0.001\n",
"-np.log(0.001)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "cpsc330",
"language": "python",
"name": "cpsc330"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
},
"vscode": {
"interpreter": {
"hash": "f821000d0c0da66e5bcde88c37d59c8e0de03b40667fb62009a8148ca49465a0"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}