{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model selection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To choose a metric for our model, we first need to consider what Type I and Type II errors mean for our problem:\n", "\n", "- A Type I (false positive) error would be predicting that a customer will make a purchase when in fact they do not.\n", "- A Type II (false negative) error would be predicting that a customer will not make a purchase when in fact they do.\n", "\n", "Further, we need to consider the business objectives of an e-commerce company that might use this model. We assume that the company has the following objectives:\n", "\n", "1. Maximize revenue by increasing purchase conversion rate\n", "2. Minimize disruption to customer experience from targeted nudges\n", "\n", "Given these objectives, the relevant metrics are precision and recall. Specifically, we want to maximize recall while keeping precision at or above 60% (a threshold set by business requirements and tolerance)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Base model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our baseline model will be the `DummyClassifier` model from `sklearn` with the default parameters. According to the `sklearn` [documentation](https://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyClassifier.html), the default `strategy` parameter is `prior`, which always predicts the class that maximizes the class prior (that is, the most frequent class in the training data)."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Models tested" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In performing model selection, we will consider the following models:\n", "\n", "- Logistic Regression\n", "- Support Vector Machine with an RBF kernel\n", "- Random Forest Classifier\n", "- XGBoost Classifier\n", "\n", "In assessing these models we will:\n", "\n", "1. Perform five-fold cross-validation and look at the mean results\n", "2. Generate and analyze confusion matrices with cross-validated predictions on the training set\n", "3. Generate and analyze precision-recall curves with cross-validated predictions (probabilities in this case) on the training set" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cross validation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We performed five-fold cross-validation for the above models on our training data and observed the following metrics:\n", "\n", "_Please note that the following metrics are mean values over the five folds of cross-validation._" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [ "remove_input" ] }, "outputs": [ { "data": { "text/html": [ "\n", "
| | DummyClassifier | LogisticRegression | SVC | RandomForest | XGBoost |
|---|---|---|---|---|---|
| fit_time | 0.00 (+/- 0.00) | 0.21 (+/- 0.02) | 1.27 (+/- 0.19) | 0.80 (+/- 0.06) | 0.66 (+/- 0.01) |
| score_time | 0.01 (+/- 0.00) | 0.01 (+/- 0.00) | 1.46 (+/- 0.16) | 0.06 (+/- 0.00) | 0.01 (+/- 0.00) |
| test_accuracy | 0.85 (+/- 0.00) | 0.88 (+/- 0.02) | 0.89 (+/- 0.03) | 0.89 (+/- 0.03) | 0.88 (+/- 0.04) |
| train_accuracy | 0.85 (+/- 0.00) | 0.89 (+/- 0.01) | 0.91 (+/- 0.01) | 1.00 (+/- 0.00) | 0.99 (+/- 0.00) |
| test_precision | 0.00 (+/- 0.00) | 0.73 (+/- 0.15) | 0.71 (+/- 0.13) | 0.69 (+/- 0.13) | 0.65 (+/- 0.16) |
| train_precision | 0.00 (+/- 0.00) | 0.77 (+/- 0.02) | 0.78 (+/- 0.01) | 1.00 (+/- 0.00) | 1.00 (+/- 0.00) |
| test_recall | 0.00 (+/- 0.00) | 0.38 (+/- 0.08) | 0.48 (+/- 0.11) | 0.55 (+/- 0.11) | 0.56 (+/- 0.12) |
| train_recall | 0.00 (+/- 0.00) | 0.41 (+/- 0.04) | 0.55 (+/- 0.07) | 1.00 (+/- 0.00) | 0.94 (+/- 0.02) |
| test_f1 | 0.00 (+/- 0.00) | 0.50 (+/- 0.10) | 0.57 (+/- 0.12) | 0.61 (+/- 0.12) | 0.60 (+/- 0.13) |
| train_f1 | 0.00 (+/- 0.00) | 0.53 (+/- 0.04) | 0.64 (+/- 0.05) | 1.00 (+/- 0.00) | 0.97 (+/- 0.01) |
| test_average_precision | 0.15 (+/- 0.00) | 0.61 (+/- 0.14) | 0.65 (+/- 0.17) | 0.68 (+/- 0.18) | 0.65 (+/- 0.18) |
| train_average_precision | 0.15 (+/- 0.00) | 0.67 (+/- 0.04) | 0.78 (+/- 0.03) | 1.00 (+/- 0.00) | 1.00 (+/- 0.00) |