{ "cells": [ { "cell_type": "markdown", "id": "94517988-2d62-4a7b-aef0-b2d5cd760962", "metadata": {}, "source": [ "# Lecture 07: Collaborative filtering class demo\n", "\n", "![](../img/eva-fun-times.png)" ] }, { "cell_type": "code", "execution_count": 1, "id": "8d3f6f02-d323-46c9-8333-0f787ee9f906", "metadata": {}, "outputs": [], "source": [ "import os\n", "import random\n", "import sys\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "sys.path.append(os.path.join(os.path.abspath(\"..\"), \"code\"))" ] }, { "cell_type": "markdown", "id": "da7bcc2b-6007-475d-97e2-a0c57417f0b1", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Example dataset: [Jester 1.7M jokes ratings dataset](https://www.kaggle.com/vikashrajluhaniwal/jester-17m-jokes-ratings-dataset?select=jester_ratings.csv)\n", "\n", "- We'll use a sample of [Jester 1.7M jokes ratings dataset](https://www.kaggle.com/vikashrajluhaniwal/jester-17m-jokes-ratings-dataset) to demonstrate different recommendation systems. \n", "\n", "The dataset comes with two CSVs\n", "- A CSV containing ratings (-10.0 to +10.0) of 150 jokes from 59,132 users. \n", "- A CSV containing joke IDs and the actual text of jokes. \n", "\n", "> Some jokes might be offensive. Please do not look too much into the actual text data if you are sensitive to such language." ] }, { "cell_type": "markdown", "id": "c9393c9e-5326-4608-a741-933e3ff0eeb0", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "- Recommendation systems are most effective when you have a large amount of data.\n", "- But we are only taking a sample here for speed." 
] }, { "cell_type": "code", "execution_count": 3, "id": "08eea17f-eccb-4dae-8db4-7839e9e391c5", "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "filename = \"../data/jester_ratings.csv\"\n", "ratings_full = pd.read_csv(filename)\n", "ratings = ratings_full[ratings_full[\"userId\"] <= 4000]" ] }, { "cell_type": "code", "execution_count": 4, "id": "4b6af512-9f7f-4b19-aa78-9742c5e99ded", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
userIdjokeIdrating
0150.219
117-9.281
218-9.281
3113-6.781
41150.875
\n", "
" ], "text/plain": [ " userId jokeId rating\n", "0 1 5 0.219\n", "1 1 7 -9.281\n", "2 1 8 -9.281\n", "3 1 13 -6.781\n", "4 1 15 0.875" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ratings.head()" ] }, { "cell_type": "code", "execution_count": 5, "id": "7111f9d8-b1bc-4a09-a1a3-eba95ba55372", "metadata": {}, "outputs": [], "source": [ "user_key = \"userId\"\n", "item_key = \"jokeId\"" ] }, { "cell_type": "markdown", "id": "0f0a8452-9d28-4238-8209-41e98d5cd3ec", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Dataset stats " ] }, { "cell_type": "code", "execution_count": 6, "id": "0b4134ed-52e4-42aa-8b88-e963a2ba7582", "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 141362 entries, 0 to 141361\n", "Data columns (total 3 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 userId 141362 non-null int64 \n", " 1 jokeId 141362 non-null int64 \n", " 2 rating 141362 non-null float64\n", "dtypes: float64(1), int64(2)\n", "memory usage: 4.3 MB\n" ] } ], "source": [ "ratings.info()" ] }, { "cell_type": "code", "execution_count": 7, "id": "a3746d08-cb2c-490a-8733-2b323476b870", "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of ratings: 141362\n", "Average rating: 1.200\n", "Number of users (N): 3635\n", "Number of items (M): 140\n", "Fraction non-nan ratings: 0.278\n" ] } ], "source": [ "def get_stats(ratings, item_key=\"jokeId\", user_key=\"userId\"):\n", " print(\"Number of ratings:\", len(ratings))\n", " print(\"Average rating: %0.3f\" % (np.mean(ratings[\"rating\"])))\n", " N = len(np.unique(ratings[user_key]))\n", " M = len(np.unique(ratings[item_key]))\n", " print(\"Number of users (N): %d\" % N)\n", " print(\"Number of items (M): %d\" % M)\n", " print(\"Fraction non-nan ratings: %0.3f\" % 
(len(ratings) / (N * M)))\n", " return N, M\n", "\n", "\n", "N, M = get_stats(ratings)" ] }, { "cell_type": "markdown", "id": "51877561-5a47-4fc8-a1b7-7d499c402efb", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "- Let's construct a utility matrix with `number of users` rows and `number of items` columns from the ratings data. \n", "> Note we are constructing a non-sparse matrix for demonstration purposes here. In real life it's recommended that you work with sparse matrices. " ] }, { "cell_type": "code", "execution_count": 8, "id": "13da40d2-ca4c-45e5-827c-b0e47e418131", "metadata": {}, "outputs": [], "source": [ "user_mapper = dict(zip(np.unique(ratings[user_key]), list(range(N))))\n", "item_mapper = dict(zip(np.unique(ratings[item_key]), list(range(M))))\n", "user_inverse_mapper = dict(zip(list(range(N)), np.unique(ratings[user_key])))\n", "item_inverse_mapper = dict(zip(list(range(M)), np.unique(ratings[item_key])))" ] }, { "cell_type": "code", "execution_count": 9, "id": "4dabc244-057b-4120-aa52-eb015098ad72", "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "def create_Y_from_ratings(\n", " data, N, M, user_mapper, item_mapper, user_key=\"userId\", item_key=\"jokeId\"\n", "): # Function to create a dense utility matrix\n", " Y = np.zeros((N, M))\n", " Y.fill(np.nan)\n", " for index, val in data.iterrows():\n", " n = user_mapper[val[user_key]]\n", " m = item_mapper[val[item_key]]\n", " Y[n, m] = val[\"rating\"]\n", "\n", " return Y" ] }, { "cell_type": "markdown", "id": "e1a80499-3d25-45d0-afc5-00de5fe009ec", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Utility matrix for Jester jokes ratings data\n", "- Rows represent users.\n", "- Columns represent items (jokes in our case).\n", "- Each cell gives the rating given by the user to the corresponding joke. \n", "- Users are features for jokes and jokes are features for users.\n", "- We want to predict the missing entries. 
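As an aside, the sparse representation recommended in the note above can be sketched with `scipy.sparse`. This is a minimal illustration on toy (user, item, rating) triples, not the Jester data:

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy (user index, item index, rating) triples -- hypothetical values,
# standing in for the (userId, jokeId, rating) rows of the ratings frame
users = np.array([0, 0, 1, 2])
items = np.array([0, 2, 1, 2])
vals = np.array([1.5, -3.0, 9.9, 0.5])

# Unobserved cells are simply not stored, so memory scales with the
# number of ratings rather than with N * M
Y_sparse = csr_matrix((vals, (users, items)), shape=(3, 3))
print(Y_sparse.nnz)  # 4 stored ratings
```

One caveat with sparse formats: unstored cells read back as 0, so "missing" and "rated 0" need to be distinguished explicitly (e.g., by keeping the triples around), unlike the `np.nan` convention used in the dense matrix here.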
" ] }, { "cell_type": "code", "execution_count": 10, "id": "7902ee59-7268-4a5b-89ef-c2f9324112f7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3635, 140)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y_mat = create_Y_from_ratings(ratings, N, M, user_mapper, item_mapper)\n", "Y_mat.shape" ] }, { "cell_type": "code", "execution_count": 11, "id": "630c2d22-4355-41c2-9f40-3477a256f506", "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...130131132133134135136137138139
00.219-9.281-9.281-6.7810.875-9.656-9.031-7.469-8.719-9.156...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
1-9.6889.9389.5319.9380.4063.7199.656-2.688-9.562-9.125...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
2-9.844-9.844-7.219-2.031-9.938-9.969-9.875-9.812-9.781-6.844...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3-5.812-4.500-4.906NaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
46.9064.750-5.906-0.406-4.0313.8756.2195.6566.0945.406...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
..................................................................
3630NaN-9.812-0.062NaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3631NaN-9.8447.531-9.719-9.3443.8759.8128.9388.375NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3632NaN-1.9063.969-2.312-0.344-8.8444.188NaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3633NaN-8.875-9.156-9.156NaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3634NaN-6.3121.281-3.5312.125-5.8125.562-6.0620.125NaN...NaNNaN4.188NaNNaNNaNNaNNaNNaNNaN
\n", "

3635 rows × 140 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 8 9 \\\n", "0 0.219 -9.281 -9.281 -6.781 0.875 -9.656 -9.031 -7.469 -8.719 -9.156 \n", "1 -9.688 9.938 9.531 9.938 0.406 3.719 9.656 -2.688 -9.562 -9.125 \n", "2 -9.844 -9.844 -7.219 -2.031 -9.938 -9.969 -9.875 -9.812 -9.781 -6.844 \n", "3 -5.812 -4.500 -4.906 NaN NaN NaN NaN NaN NaN NaN \n", "4 6.906 4.750 -5.906 -0.406 -4.031 3.875 6.219 5.656 6.094 5.406 \n", "... ... ... ... ... ... ... ... ... ... ... \n", "3630 NaN -9.812 -0.062 NaN NaN NaN NaN NaN NaN NaN \n", "3631 NaN -9.844 7.531 -9.719 -9.344 3.875 9.812 8.938 8.375 NaN \n", "3632 NaN -1.906 3.969 -2.312 -0.344 -8.844 4.188 NaN NaN NaN \n", "3633 NaN -8.875 -9.156 -9.156 NaN NaN NaN NaN NaN NaN \n", "3634 NaN -6.312 1.281 -3.531 2.125 -5.812 5.562 -6.062 0.125 NaN \n", "\n", " ... 130 131 132 133 134 135 136 137 138 139 \n", "0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "1 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "2 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "4 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "... ... ... ... ... ... ... ... ... ... ... ... \n", "3630 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3631 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3632 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3633 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3634 ... NaN NaN 4.188 NaN NaN NaN NaN NaN NaN NaN \n", "\n", "[3635 rows x 140 columns]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(Y_mat)" ] }, { "cell_type": "markdown", "id": "fccaa878-6d21-4e9e-bfb3-11ec50e71d2d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "

" ] }, { "cell_type": "markdown", "id": "7403caca-06f6-40af-8857-56d84bb26e41", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Data splitting and evaluation" ] }, { "cell_type": "markdown", "id": "3cd090b1-b6a8-4591-abb6-4c284aaba355", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "- Recall that our goal is to predict missing entries in the utility matrix. " ] }, { "cell_type": "code", "execution_count": 12, "id": "3b1470b0-0544-4dad-b30d-5643e3c0e3dc", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...130131132133134135136137138139
00.219-9.281-9.281-6.7810.875-9.656-9.031-7.469-8.719-9.156...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
1-9.6889.9389.5319.9380.4063.7199.656-2.688-9.562-9.125...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
2-9.844-9.844-7.219-2.031-9.938-9.969-9.875-9.812-9.781-6.844...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3-5.812-4.500-4.906NaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
46.9064.750-5.906-0.406-4.0313.8756.2195.6566.0945.406...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
..................................................................
3630NaN-9.812-0.062NaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3631NaN-9.8447.531-9.719-9.3443.8759.8128.9388.375NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3632NaN-1.9063.969-2.312-0.344-8.8444.188NaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3633NaN-8.875-9.156-9.156NaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3634NaN-6.3121.281-3.5312.125-5.8125.562-6.0620.125NaN...NaNNaN4.188NaNNaNNaNNaNNaNNaNNaN
\n", "

3635 rows × 140 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 8 9 \\\n", "0 0.219 -9.281 -9.281 -6.781 0.875 -9.656 -9.031 -7.469 -8.719 -9.156 \n", "1 -9.688 9.938 9.531 9.938 0.406 3.719 9.656 -2.688 -9.562 -9.125 \n", "2 -9.844 -9.844 -7.219 -2.031 -9.938 -9.969 -9.875 -9.812 -9.781 -6.844 \n", "3 -5.812 -4.500 -4.906 NaN NaN NaN NaN NaN NaN NaN \n", "4 6.906 4.750 -5.906 -0.406 -4.031 3.875 6.219 5.656 6.094 5.406 \n", "... ... ... ... ... ... ... ... ... ... ... \n", "3630 NaN -9.812 -0.062 NaN NaN NaN NaN NaN NaN NaN \n", "3631 NaN -9.844 7.531 -9.719 -9.344 3.875 9.812 8.938 8.375 NaN \n", "3632 NaN -1.906 3.969 -2.312 -0.344 -8.844 4.188 NaN NaN NaN \n", "3633 NaN -8.875 -9.156 -9.156 NaN NaN NaN NaN NaN NaN \n", "3634 NaN -6.312 1.281 -3.531 2.125 -5.812 5.562 -6.062 0.125 NaN \n", "\n", " ... 130 131 132 133 134 135 136 137 138 139 \n", "0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "1 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "2 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "4 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "... ... ... ... ... ... ... ... ... ... ... ... \n", "3630 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3631 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3632 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3633 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN \n", "3634 ... NaN NaN 4.188 NaN NaN NaN NaN NaN NaN NaN \n", "\n", "[3635 rows x 140 columns]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(Y_mat)" ] }, { "cell_type": "markdown", "id": "19fe3fac-fe6a-48a6-909c-c0bb912ee1e7", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Data splitting \n", "\n", "- We split the ratings into train and validation sets. \n", "- It's easier to split the ratings data instead of splitting the utility matrix.\n", "- Don't worry about `y`; we're not really going to use it. 
" ] }, { "cell_type": "code", "execution_count": 13, "id": "df0820a3-890b-4bb1-8ed0-43db8b97fb04", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((113089, 3), (28273, 3))" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import train_test_split\n", "X = ratings.copy()\n", "y = ratings[user_key]\n", "X_train, X_valid, y_train, y_valid = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "X_train.shape, X_valid.shape" ] }, { "cell_type": "markdown", "id": "97339a02-44fe-48b7-b29e-b0164316649b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Now we will create utility matrices for train and validation splits. " ] }, { "cell_type": "code", "execution_count": 14, "id": "67d2c94a-efb8-4bfe-b3eb-a999482f8aa4", "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "train_mat = create_Y_from_ratings(X_train, N, M, user_mapper, item_mapper)\n", "valid_mat = create_Y_from_ratings(X_valid, N, M, user_mapper, item_mapper)" ] }, { "cell_type": "code", "execution_count": 15, "id": "ec7d281d-1374-4009-b760-31c45b60ab79", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((3635, 140), (3635, 140))" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_mat.shape, valid_mat.shape" ] }, { "cell_type": "code", "execution_count": 16, "id": "15d2a1f7-2519-42a3-ad1a-248261c6d2ed", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.22222244055806642" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(len(X_train) / (N * M)) # Fraction of non-nan entries in the train set" ] }, { "cell_type": "code", "execution_count": 17, "id": "da20cf39-8acf-40a6-b990-fd60a3f7a1ee", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.055557083906464924" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "(len(X_valid) / (N * M)) # Fraction of non-nan entries in the valid set" ] }, { "cell_type": "markdown", "id": "33f31392-7204-4f0b-a085-32f30e68b1fa", "metadata": {}, "source": [ "- `train_mat` has only ratings from the train set and `valid_mat` has only ratings from the valid set.\n", "- During training we assume that we do not have access to some of the available ratings. We predict these ratings and evaluate them against ratings in the validation set. " ] }, { "cell_type": "code", "execution_count": 18, "id": "a8af6508-6245-40ea-9419-ead3804fe0f0", "metadata": {}, "outputs": [], "source": [ "def error(X1, X2):\n", " \"\"\"\n", " Returns the root mean squared error.\n", " \"\"\"\n", " return np.sqrt(np.nanmean((X1 - X2) ** 2))\n", "\n", "\n", "def evaluate(pred_X, train_X, valid_X, model_name=\"Global average\"):\n", " print(\"%s train RMSE: %0.2f\" % (model_name, error(pred_X, train_X)))\n", " print(\"%s valid RMSE: %0.2f\" % (model_name, error(pred_X, valid_X)))" ] }, { "cell_type": "markdown", "id": "bc11cab9-fca1-4e2b-a68b-1e2b55e55eba", "metadata": {}, "source": [ "

" ] }, { "cell_type": "markdown", "id": "153577e5-df09-49a6-a65d-fbf22365dd69", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Baseline approaches \n", "\n", "Let's first try some simple approaches to predict missing entries. \n", "\n", "1. Global average baseline\n", "2. Per-user average baseline\n", "3. Per-item average baseline\n", "4. Average of 2 and 3\n", " - Take an average of per-user and per-item averages. \n", "5. [$k$-Nearest Neighbours imputation](https://scikit-learn.org/stable/modules/generated/sklearn.impute.KNNImputer.html) \n", " \n", "I'll show you 1. and 5. You'll explore 2., 3., and 4. in the lab. " ] }, { "cell_type": "markdown", "id": "fb5662ed-49db-443d-9413-5711a97b942c", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Global average baseline\n", "\n", "- Let's examine RMSE of the global average baseline. \n", "- In this baseline we predict everything as the global average rating." ] }, { "cell_type": "code", "execution_count": 19, "id": "65524cea-5241-4efb-ad83-b89e50196e1d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...130131132133134135136137138139
01.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741...1.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741
11.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741...1.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741
21.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741...1.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741
31.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741...1.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741
41.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741...1.207411.207411.207411.207411.207411.207411.207411.207411.207411.20741
\n", "

5 rows × 140 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 \\\n", "0 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "1 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "2 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "3 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "4 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "\n", " 8 9 ... 130 131 132 133 134 \\\n", "0 1.20741 1.20741 ... 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "1 1.20741 1.20741 ... 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "2 1.20741 1.20741 ... 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "3 1.20741 1.20741 ... 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "4 1.20741 1.20741 ... 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "\n", " 135 136 137 138 139 \n", "0 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "1 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "2 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "3 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "4 1.20741 1.20741 1.20741 1.20741 1.20741 \n", "\n", "[5 rows x 140 columns]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "avg = np.nanmean(train_mat)\n", "pred_g = np.zeros(train_mat.shape) + avg\n", "pd.DataFrame(pred_g).head()" ] }, { "cell_type": "code", "execution_count": 20, "id": "a9369cc1-b95b-4e6b-90d3-929efde5f339", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Global average train RMSE: 5.75\n", "Global average valid RMSE: 5.77\n" ] } ], "source": [ "evaluate(pred_g, train_mat, valid_mat, model_name=\"Global average\")" ] }, { "cell_type": "markdown", "id": "7330ab24-8adf-4a3a-b0dd-d169eafb4046", "metadata": {}, "source": [ "

" ] }, { "cell_type": "markdown", "id": "348f7e71-3cea-4f01-acae-00e0d6784d6d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### [$k$-nearest neighbours imputation](https://scikit-learn.org/stable/modules/generated/sklearn.impute.KNNImputer.html)" ] }, { "cell_type": "code", "execution_count": 21, "id": "1a795191-6720-4a86-9bd5-e0466d556efa", "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "from sklearn.impute import KNNImputer\n", "\n", "imputer = KNNImputer(n_neighbors=10, keep_empty_features=True)\n", "train_mat_imp = imputer.fit_transform(train_mat)" ] }, { "cell_type": "code", "execution_count": 22, "id": "bab1c362-4e3b-49d3-9561-6b35b50ed2da", "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...130131132133134135136137138139
0-5.9406-9.2810-9.2810-6.78100.8750-9.6560-9.0310-7.4690-8.7190-9.1560...-4.53111.89680.6905-3.12181.2843-2.6063-0.1812-1.39371.7625-0.4092
12.34059.93809.53109.93800.40603.71909.6560-2.68804.3438-9.1250...2.24373.17195.02515.18128.24075.93115.83756.38121.16876.2532
2-9.8440-3.5750-7.2190-2.0310-9.9380-9.9690-9.8750-9.8120-9.7810-6.8440...-4.4186-3.1156-1.5655-5.62500.3720-4.0439-6.0500-5.5563-5.4125-5.5874
3-5.8120-2.4624-4.9060-2.7781-0.0532-3.85941.7031-0.36871.84690.0593...-2.03442.14692.88751.68451.2437-0.01561.25953.82193.19715.0249
41.31574.75001.8658-0.40601.79373.87506.21901.92206.09405.4060...-0.28441.13134.01573.03444.04060.52184.35944.09683.92503.9657
..................................................................
3630-0.7750-9.8120-0.0620-2.8218-4.1470-4.82812.2718-2.8782-1.01250.0688...-6.68443.05312.86871.52814.5002-0.18782.00314.09082.35635.0406
36312.5188-5.0625-0.4001-9.7190-9.3440-1.6408-4.11878.93808.3750-0.9314...-4.03447.91553.42824.29686.79687.39991.85005.82195.18122.8437
36320.1749-1.90603.9690-1.3844-0.3440-8.84404.1880-1.55645.05930.3343...-4.01262.83442.44992.93122.3750-0.40621.43753.9750-1.22202.8375
3633-4.5937-6.4376-5.9563-9.1560-7.1437-5.58442.2531-0.9688-2.8530-0.6406...-4.69383.41865.1656-0.16262.5594-0.77504.67811.2658-1.1718-0.7157
3634-0.0812-6.31201.2810-3.53102.1250-5.81205.56200.22180.1250-1.1874...-3.81564.18124.18803.72803.07502.10332.81565.53123.82834.1219
\n", "

3635 rows × 140 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 8 \\\n", "0 -5.9406 -9.2810 -9.2810 -6.7810 0.8750 -9.6560 -9.0310 -7.4690 -8.7190 \n", "1 2.3405 9.9380 9.5310 9.9380 0.4060 3.7190 9.6560 -2.6880 4.3438 \n", "2 -9.8440 -3.5750 -7.2190 -2.0310 -9.9380 -9.9690 -9.8750 -9.8120 -9.7810 \n", "3 -5.8120 -2.4624 -4.9060 -2.7781 -0.0532 -3.8594 1.7031 -0.3687 1.8469 \n", "4 1.3157 4.7500 1.8658 -0.4060 1.7937 3.8750 6.2190 1.9220 6.0940 \n", "... ... ... ... ... ... ... ... ... ... \n", "3630 -0.7750 -9.8120 -0.0620 -2.8218 -4.1470 -4.8281 2.2718 -2.8782 -1.0125 \n", "3631 2.5188 -5.0625 -0.4001 -9.7190 -9.3440 -1.6408 -4.1187 8.9380 8.3750 \n", "3632 0.1749 -1.9060 3.9690 -1.3844 -0.3440 -8.8440 4.1880 -1.5564 5.0593 \n", "3633 -4.5937 -6.4376 -5.9563 -9.1560 -7.1437 -5.5844 2.2531 -0.9688 -2.8530 \n", "3634 -0.0812 -6.3120 1.2810 -3.5310 2.1250 -5.8120 5.5620 0.2218 0.1250 \n", "\n", " 9 ... 130 131 132 133 134 135 136 \\\n", "0 -9.1560 ... -4.5311 1.8968 0.6905 -3.1218 1.2843 -2.6063 -0.1812 \n", "1 -9.1250 ... 2.2437 3.1719 5.0251 5.1812 8.2407 5.9311 5.8375 \n", "2 -6.8440 ... -4.4186 -3.1156 -1.5655 -5.6250 0.3720 -4.0439 -6.0500 \n", "3 0.0593 ... -2.0344 2.1469 2.8875 1.6845 1.2437 -0.0156 1.2595 \n", "4 5.4060 ... -0.2844 1.1313 4.0157 3.0344 4.0406 0.5218 4.3594 \n", "... ... ... ... ... ... ... ... ... ... \n", "3630 0.0688 ... -6.6844 3.0531 2.8687 1.5281 4.5002 -0.1878 2.0031 \n", "3631 -0.9314 ... -4.0344 7.9155 3.4282 4.2968 6.7968 7.3999 1.8500 \n", "3632 0.3343 ... -4.0126 2.8344 2.4499 2.9312 2.3750 -0.4062 1.4375 \n", "3633 -0.6406 ... -4.6938 3.4186 5.1656 -0.1626 2.5594 -0.7750 4.6781 \n", "3634 -1.1874 ... -3.8156 4.1812 4.1880 3.7280 3.0750 2.1033 2.8156 \n", "\n", " 137 138 139 \n", "0 -1.3937 1.7625 -0.4092 \n", "1 6.3812 1.1687 6.2532 \n", "2 -5.5563 -5.4125 -5.5874 \n", "3 3.8219 3.1971 5.0249 \n", "4 4.0968 3.9250 3.9657 \n", "... ... ... ... 
\n", "3630 4.0908 2.3563 5.0406 \n", "3631 5.8219 5.1812 2.8437 \n", "3632 3.9750 -1.2220 2.8375 \n", "3633 1.2658 -1.1718 -0.7157 \n", "3634 5.5312 3.8283 4.1219 \n", "\n", "[3635 rows x 140 columns]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(train_mat_imp)" ] }, { "cell_type": "code", "execution_count": 23, "id": "0032b5c0-f80d-4b2a-8578-3206a6b1c434", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KNN imputer train RMSE: 0.00\n", "KNN imputer valid RMSE: 4.79\n" ] } ], "source": [ "evaluate(train_mat_imp, train_mat, valid_mat, model_name=\"KNN imputer\")" ] }, { "cell_type": "markdown", "id": "69e1ea0a-fcee-420c-8172-372fc5d0153e", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### (Optional) Finding [nearest neighbors](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html)\n", "\n", "- We can look at nearest neighbours of a query item. \n", "- Here our columns are jokes, and users are features for jokes, and we'll have to find nearest neighbours of columns vectors. " ] }, { "cell_type": "code", "execution_count": 24, "id": "8ac75d41-354a-4e84-8933-3a44616b5e40", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...130131132133134135136137138139
0-5.9406-9.2810-9.2810-6.78100.8750-9.6560-9.0310-7.4690-8.7190-9.1560...-4.53111.89680.6905-3.12181.2843-2.6063-0.1812-1.39371.7625-0.4092
12.34059.93809.53109.93800.40603.71909.6560-2.68804.3438-9.1250...2.24373.17195.02515.18128.24075.93115.83756.38121.16876.2532
2-9.8440-3.5750-7.2190-2.0310-9.9380-9.9690-9.8750-9.8120-9.7810-6.8440...-4.4186-3.1156-1.5655-5.62500.3720-4.0439-6.0500-5.5563-5.4125-5.5874
3-5.8120-2.4624-4.9060-2.7781-0.0532-3.85941.7031-0.36871.84690.0593...-2.03442.14692.88751.68451.2437-0.01561.25953.82193.19715.0249
41.31574.75001.8658-0.40601.79373.87506.21901.92206.09405.4060...-0.28441.13134.01573.03444.04060.52184.35944.09683.92503.9657
\n", "

5 rows × 140 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 8 \\\n", "0 -5.9406 -9.2810 -9.2810 -6.7810 0.8750 -9.6560 -9.0310 -7.4690 -8.7190 \n", "1 2.3405 9.9380 9.5310 9.9380 0.4060 3.7190 9.6560 -2.6880 4.3438 \n", "2 -9.8440 -3.5750 -7.2190 -2.0310 -9.9380 -9.9690 -9.8750 -9.8120 -9.7810 \n", "3 -5.8120 -2.4624 -4.9060 -2.7781 -0.0532 -3.8594 1.7031 -0.3687 1.8469 \n", "4 1.3157 4.7500 1.8658 -0.4060 1.7937 3.8750 6.2190 1.9220 6.0940 \n", "\n", " 9 ... 130 131 132 133 134 135 136 \\\n", "0 -9.1560 ... -4.5311 1.8968 0.6905 -3.1218 1.2843 -2.6063 -0.1812 \n", "1 -9.1250 ... 2.2437 3.1719 5.0251 5.1812 8.2407 5.9311 5.8375 \n", "2 -6.8440 ... -4.4186 -3.1156 -1.5655 -5.6250 0.3720 -4.0439 -6.0500 \n", "3 0.0593 ... -2.0344 2.1469 2.8875 1.6845 1.2437 -0.0156 1.2595 \n", "4 5.4060 ... -0.2844 1.1313 4.0157 3.0344 4.0406 0.5218 4.3594 \n", "\n", " 137 138 139 \n", "0 -1.3937 1.7625 -0.4092 \n", "1 6.3812 1.1687 6.2532 \n", "2 -5.5563 -5.4125 -5.5874 \n", "3 3.8219 3.1971 5.0249 \n", "4 4.0968 3.9250 3.9657 \n", "\n", "[5 rows x 140 columns]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(train_mat_imp).head()" ] }, { "cell_type": "markdown", "id": "0a23120c-7fea-4e16-b289-d05dd8d466fc", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### (Optional) $k$-nearest neighbours on a query joke\n", "- Let's transpose the matrix. 
" ] }, { "cell_type": "code", "execution_count": null, "id": "bdfd3bb6-03f9-4b83-8a33-3e85bf14df36", "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "item_user_mat = train_mat_imp.T" ] }, { "cell_type": "code", "execution_count": null, "id": "a409b45a-c541-4ff5-8f8a-7017daf3437b", "metadata": {}, "outputs": [], "source": [ "jokes_df = pd.read_csv(\"../data/jester_items.csv\")\n", "jokes_df.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "0f72d898-9e1e-47ef-aea5-5645d0655d45", "metadata": {}, "outputs": [], "source": [ "jester_items_df = pd.read_csv(\"../data/jester_items.csv\")\n", "jester_items_df.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "6cde5af6-d2c0-442e-a219-4c3c2320557e", "metadata": { "slideshow": { "slide_type": "-" } }, "outputs": [], "source": [ "id_joke_map = dict(zip(jokes_df.jokeId, jokes_df.jokeText))" ] }, { "cell_type": "code", "execution_count": null, "id": "7da09b44-cc32-4c69-9291-5581aa4fadf6", "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "from sklearn.neighbors import NearestNeighbors\n", "\n", "\n", "def get_topk_recommendations(X, query_ind=0, metric=\"cosine\", k=5):\n", " query_idx = item_inverse_mapper[query_ind]\n", " model = NearestNeighbors(n_neighbors=k, metric=\"cosine\")\n", " model.fit(X)\n", " neigh_ind = model.kneighbors([X[query_ind]], k, return_distance=False).flatten()\n", " neigh_ind = np.delete(neigh_ind, np.where(query_ind == query_ind))\n", " recs = [id_joke_map[item_inverse_mapper[i]] for i in neigh_ind]\n", " print(\"Query joke: \", id_joke_map[query_idx])\n", "\n", " return pd.DataFrame(data=recs, columns=[\"top recommendations\"])\n", "\n", "\n", "get_topk_recommendations(item_user_mat, query_ind=8, metric=\"cosine\", k=5)" ] }, { "cell_type": "markdown", "id": "86633e8d-cbc9-4a1f-bfb3-e58bd17fe22f", "metadata": {}, "source": [ "## Collaborative filtering using the `surprise` package" ] }, { "cell_type": 
"markdown", "id": "0f708516-8fed-456e-9d6a-b050284d92b9", "metadata": {}, "source": [ "Although matrix factorization is a prominent approach to complete the utility matrix, `TruncatedSVD` is not appropriate in this context because of a large number of NaN values in this matrix. \n", "\n", "- We consider only observed ratings and add regularization to avoid overfitting. \n", "- Here is the loss function \n", "\n", "$$f(Z, W) = \\sum_{(i,j) \\in R} ((w_j^Tz_i) - y_{ij})^2 + \\frac{\\lambda_1}{2}\\lVert Z \\lVert_2^2 + \\frac{\\lambda_2}{2}\\lVert W \\lVert_2^2$$" ] }, { "cell_type": "markdown", "id": "0bde1f0e-72f6-4ace-a368-772ea1dc213b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Let's try it out on our Jester dataset utility matrix. " ] }, { "cell_type": "code", "execution_count": 25, "id": "b3db92b4-a18a-4abe-a55e-91cdf6768f47", "metadata": {}, "outputs": [], "source": [ "import surprise\n", "from surprise import SVD, Dataset, Reader, accuracy" ] }, { "cell_type": "code", "execution_count": 26, "id": "ac9a71c6-1eee-40cf-aed5-840ac40a88a0", "metadata": {}, "outputs": [], "source": [ "reader = Reader()\n", "data = Dataset.load_from_df(ratings, reader) # Load the data\n", "\n", "# I'm being sloppy here. 
Probably there is a way to create validset from our already split data.\n", "trainset, validset = surprise.model_selection.train_test_split(\n", " data, test_size=0.2, random_state=42\n", ") # Split the data" ] }, { "cell_type": "markdown", "id": "80dd2535-ad12-4520-bc06-f51541de0f99", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "**Regularized SVD**" ] }, { "cell_type": "code", "execution_count": 27, "id": "fc16e752-beee-4df2-961f-f67f2e564122", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RMSE: 5.2893\n" ] }, { "data": { "text/plain": [ "5.28926338380112" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "k = 10\n", "algo = SVD(n_factors=k, random_state=42)\n", "algo.fit(trainset)\n", "svd_preds = algo.test(validset)\n", "accuracy.rmse(svd_preds, verbose=True)" ] }, { "cell_type": "markdown", "id": "1491a0e0-52c7-4e19-8f5b-57ea8c7c24ce", "metadata": {}, "source": [ "- No big improvement over the global baseline (RMSE=5.77). \n", "- Probably because we are only considering a sample. " ] }, { "cell_type": "markdown", "id": "6c59c839-a031-4eb6-adde-d74ffdde80b9", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "**Cross-validation for recommender systems**\n", "\n", "- We can also carry out cross-validation and grid search with this package. \n", "- Let's look at an example of cross-validation. 
" ] }, { "cell_type": "code", "execution_count": 28, "id": "360a650f-a0fc-480f-bac1-3923981b0f69", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluating RMSE, MAE of algorithm SVD on 5 split(s).\n", "\n", " Fold 1 Fold 2 Fold 3 Fold 4 Fold 5 Mean Std \n", "RMSE (testset) 5.3022 5.2667 5.2661 5.2970 5.3140 5.2892 0.0194 \n", "MAE (testset) 4.2272 4.1867 4.1792 4.2200 4.2142 4.2055 0.0190 \n", "Fit time 0.21 0.22 0.22 0.22 0.22 0.22 0.00 \n", "Test time 0.07 0.07 0.12 0.07 0.12 0.09 0.03 \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
test_rmsetest_maefit_timetest_time
05.3022324.2272370.2122850.068680
15.2667154.1867040.2225060.066116
25.2660534.1791800.2183990.119088
35.2969704.2199690.2171750.068268
45.3139874.2142290.2171950.118838
\n", "
" ], "text/plain": [ " test_rmse test_mae fit_time test_time\n", "0 5.302232 4.227237 0.212285 0.068680\n", "1 5.266715 4.186704 0.222506 0.066116\n", "2 5.266053 4.179180 0.218399 0.119088\n", "3 5.296970 4.219969 0.217175 0.068268\n", "4 5.313987 4.214229 0.217195 0.118838" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from surprise.model_selection import cross_validate\n", "\n", "pd.DataFrame(cross_validate(algo, data, measures=[\"RMSE\", \"MAE\"], cv=5, verbose=True))" ] }, { "cell_type": "markdown", "id": "c54c7c2a-6afd-4d3d-b974-f918631723f2", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "- Jester dataset is available as one of the built-in datasets in this package and you can load it as follows and run cross-validation as follows. " ] }, { "cell_type": "code", "execution_count": null, "id": "b746cd0c-3832-4207-ada5-b29b9451f761", "metadata": {}, "outputs": [], "source": [ "data = Dataset.load_builtin(\"jester\")\n", "\n", "pd.DataFrame(cross_validate(algo, data, measures=[\"RMSE\", \"MAE\"], cv=5, verbose=True))" ] }, { "cell_type": "markdown", "id": "02922761-e7d7-419e-91e4-0517da452a43", "metadata": {}, "source": [ "

" ] }, { "cell_type": "markdown", "id": "8bd3f7fd-219c-4824-bf3f-7916f3f72dbe", "metadata": {}, "source": [ "### (Optional) PyTorch implementation \n", "\n", "We can also implement the loss function above using `PyTorch`. " ] }, { "cell_type": "code", "execution_count": 29, "id": "5a09504f-7324-4faf-8fbe-5b196741e292", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " userId jokeId rating\n", "0 1 5 0.219\n", "1 1 7 -9.281\n", "2 1 8 -9.281\n", "3 1 13 -6.781\n", "4 1 15 0.875\n" ] } ], "source": [ "import pandas as pd\n", "\n", "# Load the dataset\n", "ratings = pd.read_csv('../data/jester_ratings.csv')\n", "print(ratings.head())" ] }, { "cell_type": "code", "execution_count": 30, "id": "541c68d5-d643-4761-9068-bdd41f4687e2", "metadata": {}, "outputs": [], "source": [ "ratings_df = pd.read_csv('../data/jester_ratings.csv')\n", "ratings = ratings_df[ratings_df[\"userId\"] <= 4000].copy()\n", "# ratings = ratings_full[ratings_full[\"userId\"] <= 4000]" ] }, { "cell_type": "code", "execution_count": 31, "id": "7620328b-4308-46eb-8271-0797cfd120cd", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
userIdjokeIdrating
0150.219
117-9.281
218-9.281
3113-6.781
41150.875
............
141357400018-6.062
1413584000190.125
1413594000765.719
1413604000535.344
14136140001434.188
\n", "

141362 rows × 3 columns

\n", "
" ], "text/plain": [ " userId jokeId rating\n", "0 1 5 0.219\n", "1 1 7 -9.281\n", "2 1 8 -9.281\n", "3 1 13 -6.781\n", "4 1 15 0.875\n", "... ... ... ...\n", "141357 4000 18 -6.062\n", "141358 4000 19 0.125\n", "141359 4000 76 5.719\n", "141360 4000 53 5.344\n", "141361 4000 143 4.188\n", "\n", "[141362 rows x 3 columns]" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ratings" ] }, { "cell_type": "code", "execution_count": 32, "id": "7150f4bc-7e4a-4220-b61e-4de2b0df2fa4", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import LabelEncoder\n", "\n", "user_encoder = LabelEncoder()\n", "ratings['user'] = user_encoder.fit_transform(ratings.userId.values)\n", "\n", "item_encoder = LabelEncoder()\n", "ratings['item'] = item_encoder.fit_transform(ratings.jokeId.values)\n", "\n", "num_users = ratings['user'].nunique()\n", "num_items = ratings['item'].nunique()" ] }, { "cell_type": "code", "execution_count": 33, "id": "eeb0a2ac-3f78-4df0-a75c-8f1a9ccf3309", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
userIdjokeIdratinguseritem
0150.21900
117-9.28101
218-9.28102
3113-6.78103
41150.87504
..................
141357400018-6.06236347
1413584000190.12536348
1413594000765.719363465
1413604000535.344363442
14136140001434.1883634132
\n", "

141362 rows × 5 columns

\n", "
" ], "text/plain": [ " userId jokeId rating user item\n", "0 1 5 0.219 0 0\n", "1 1 7 -9.281 0 1\n", "2 1 8 -9.281 0 2\n", "3 1 13 -6.781 0 3\n", "4 1 15 0.875 0 4\n", "... ... ... ... ... ...\n", "141357 4000 18 -6.062 3634 7\n", "141358 4000 19 0.125 3634 8\n", "141359 4000 76 5.719 3634 65\n", "141360 4000 53 5.344 3634 42\n", "141361 4000 143 4.188 3634 132\n", "\n", "[141362 rows x 5 columns]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ratings" ] }, { "cell_type": "code", "execution_count": 34, "id": "6683d754-74f0-4ac4-9b02-829c35cf48ef", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((113089, 5), (28273, 5))" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = ratings.copy()\n", "y = ratings[user_key]\n", "X_train, X_valid, y_train, y_valid = train_test_split(\n", " X, y, test_size=0.2, random_state=42\n", ")\n", "\n", "X_train.shape, X_valid.shape" ] }, { "cell_type": "markdown", "id": "5f6f7d87-a96c-4b23-94b4-067bc4f9a093", "metadata": {}, "source": [ "Let's create a custom `ItemsRatingsDataset` class for the ratings data, so that we can use of PyTorch's DataLoader for batch processing and data shuffling during training." 
] }, { "cell_type": "code", "execution_count": 35, "id": "803851a5-8a54-4014-9239-62b740e21bf9", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "class ItemsRatingsDataset(Dataset):\n", " def __init__(self, users, items, ratings):\n", " self.users = users\n", " self.items = items\n", " self.ratings = ratings\n", "\n", " def __len__(self):\n", " return len(self.ratings)\n", "\n", " def __getitem__(self, idx):\n", " return {\n", " 'user': torch.tensor(self.users[idx], dtype=torch.long),\n", " 'item': torch.tensor(self.items[idx], dtype=torch.long),\n", " 'rating': torch.tensor(self.ratings[idx], dtype=torch.float)\n", " }\n", "\n", "train_dataset = ItemsRatingsDataset(X_train['user'].values, X_train['item'].values, X_train['rating'].values)\n", "train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True)\n", "\n", "valid_dataset = ItemsRatingsDataset(X_valid['user'].values, X_valid['item'].values, X_valid['rating'].values)\n", "valid_dataloader = DataLoader(valid_dataset, batch_size=512, shuffle=True)" ] }, { "cell_type": "markdown", "id": "22c275d8-46cc-4e62-b9cc-6cfa4f5b83e7", "metadata": {}, "source": [ "The `CFModel` class below defines the architecture and the forward method of collaborative filtering. We are using the `embedding` layer of `torch.nn` which simply creates a lookup table that stores embeddings of a fixed dictionary and size." 
] }, { "cell_type": "code", "execution_count": 36, "id": "99ead3ac-5228-459e-b970-2a0500b71b19", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "class CFModel(nn.Module):\n", " def __init__(self, num_users, num_items, emb_size=100):\n", " super(CFModel, self).__init__()\n", " # Embeddings for users\n", " self.user_emb = nn.Embedding(num_users, emb_size)\n", "\n", " # Embeddings for items\n", " self.item_emb = nn.Embedding(num_items, emb_size)\n", "\n", " # Initialize the embeddings \n", " self.user_emb.weight.data.uniform_(0, 0.05)\n", " self.item_emb.weight.data.uniform_(0, 0.05)\n", " \n", " def forward(self, user, movie):\n", " user_embedding = self.user_emb(user)\n", " item_embedding = self.item_emb(movie)\n", " # calculate predicted ratings as dot products\n", " # of corresponding user and item embeddings\n", " return (user_embedding * item_embedding).sum(1)\n", "\n", "model = CFModel(num_users, num_items)" ] }, { "cell_type": "code", "execution_count": 37, "id": "9e7052e9-2186-4aea-85f9-f2d2b1741670", "metadata": {}, "outputs": [], "source": [ "import torch.optim as optim\n", "\n", "# Loss function\n", "criterion = nn.MSELoss()\n", "\n", "# Regularization coefficient\n", "lambda_reg = 0.001\n", "\n", "# Optimizer (without weight decay) \n", "# We manually add regularization in the loss below\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", "\n", "\n", "import torch\n", "import numpy as np\n", "import torch.nn.functional as F\n", "\n", "def train(model, train_dataloader, valid_dataloader, optimizer, criterion, epochs=10):\n", " for epoch in range(epochs): \n", " model.train() # Set model to training mode\n", " total_train_loss = 0\n", " for batch in train_dataloader:\n", " optimizer.zero_grad()\n", " \n", " # Forward pass\n", " predictions = model(batch['user'], batch['item'])\n", " \n", " # Compute the base loss (MSE)\n", " loss = criterion(predictions, batch['rating'])\n", " \n", " # Compute regularization terms (L2 
norm of user and movie embeddings)\n", " user_reg = lambda_reg * model.user_emb.weight.norm(2)\n", " item_reg = lambda_reg * model.item_emb.weight.norm(2)\n", " \n", " # Total loss is the sum of base loss and regularization terms\n", " train_loss = loss + user_reg + item_reg\n", " \n", " # Backpropagation\n", " train_loss.backward()\n", " optimizer.step()\n", " \n", " total_train_loss += train_loss.item()\n", " \n", " avg_train_loss = total_train_loss / len(train_dataloader)\n", " \n", " model.eval() # Set model to evaluation mode\n", " total_valid_loss = 0\n", " all_predictions = []\n", " all_ratings = []\n", " with torch.no_grad():\n", " for batch in valid_dataloader:\n", " predictions = model(batch['user'], batch['item'])\n", " valid_loss = criterion(predictions, batch['rating'])\n", " total_valid_loss += loss.item()\n", " all_predictions.extend(predictions.tolist())\n", " all_ratings.extend(batch['rating'].tolist())\n", " avg_valid_loss = total_valid_loss / len(valid_dataloader)\n", " rmse = np.sqrt(np.mean((np.array(all_predictions) - np.array(all_ratings)) ** 2))\n", " print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}, Valid RMSE: {rmse:.4f}')" ] }, { "cell_type": "code", "execution_count": 38, "id": "17956fec-e2a0-42bb-ac3d-c48b6ff3248e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, Train Loss: 31.8835, Valid Loss: 26.9124, Valid RMSE: 5.2919\n", "Epoch 2, Train Loss: 25.1327, Valid Loss: 23.5237, Valid RMSE: 4.9331\n", "Epoch 3, Train Loss: 23.2564, Valid Loss: 23.8726, Valid RMSE: 4.8777\n", "Epoch 4, Train Loss: 22.7042, Valid Loss: 20.4802, Valid RMSE: 4.8478\n", "Epoch 5, Train Loss: 22.2066, Valid Loss: 25.0819, Valid RMSE: 4.8148\n", "Epoch 6, Train Loss: 21.5153, Valid Loss: 21.0060, Valid RMSE: 4.7681\n", "Epoch 7, Train Loss: 20.6530, Valid Loss: 21.4653, Valid RMSE: 4.7184\n", "Epoch 8, Train Loss: 19.8259, Valid Loss: 18.8280, Valid RMSE: 4.6790\n", "Epoch 
9, Train Loss: 19.1407, Valid Loss: 17.8768, Valid RMSE: 4.6501\n", "Epoch 10, Train Loss: 18.5690, Valid Loss: 18.0142, Valid RMSE: 4.6269\n", "Epoch 11, Train Loss: 18.0456, Valid Loss: 17.7909, Valid RMSE: 4.6082\n", "Epoch 12, Train Loss: 17.5188, Valid Loss: 17.0157, Valid RMSE: 4.5913\n", "Epoch 13, Train Loss: 16.9827, Valid Loss: 15.0803, Valid RMSE: 4.5783\n", "Epoch 14, Train Loss: 16.4400, Valid Loss: 18.2713, Valid RMSE: 4.5671\n", "Epoch 15, Train Loss: 15.8830, Valid Loss: 19.4104, Valid RMSE: 4.5590\n", "Epoch 16, Train Loss: 15.3265, Valid Loss: 15.5913, Valid RMSE: 4.5510\n", "Epoch 17, Train Loss: 14.7745, Valid Loss: 14.3047, Valid RMSE: 4.5445\n", "Epoch 18, Train Loss: 14.2318, Valid Loss: 14.5296, Valid RMSE: 4.5399\n", "Epoch 19, Train Loss: 13.7009, Valid Loss: 13.7255, Valid RMSE: 4.5345\n", "Epoch 20, Train Loss: 13.1774, Valid Loss: 11.8074, Valid RMSE: 4.5313\n" ] } ], "source": [ "train(model, train_dataloader, valid_dataloader, optimizer, criterion, epochs = 20)" ] }, { "cell_type": "markdown", "id": "f03f70d5-219e-4a29-8a21-d5efe942ce41", "metadata": {}, "source": [ "This is great! With this, we have the flexibility to tailor the loss function in the training loop as needed. For instance, we can integrate both user and item biases into the model and include regularization terms for these biases (challenging lab exercise)." ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:563]", "language": "python", "name": "conda-env-563-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 5 }