RMLViz: Visualization package for ML models.

This package contains four functions that let users conveniently plot various visualizations and compare the performance of different classifier models. Its purpose is to reduce the time data scientists spend developing visualizations and comparing models, and so speed up the model-building process. The four functions perform the following tasks:

  1. Compare the performance of various models

  2. Plot the confusion matrix based on the input data

  3. Plot train/validation errors vs. parameter values

  4. Plot the ROC curve and calculate the AUC

Contributors

  • Anas Muhammad (GitHub: anasm-17)
  • Tao Huang (GitHub: taohuang-ubc)
  • Fanli Zhou (GitHub: flizhou)
  • Mike Chen (GitHub: miketianchen)

Installation

You can install RMLViz from GitHub with:

Functions

  • model_comparison_table
      Input: list of models, X_train, y_train, X_test, y_test, scoring option
      Output: data frame of model scores
      Description: Takes in a list of models and the train/test data, then outputs a table comparing the scores of the different models.

  • confusion_matrix
      Input: model, X_train, y_train, X_test, y_test, predicted_y
      Output: confusion matrix plot; data frame of evaluation scores (recall, F1, etc.)
      Description: Takes in a trained model with X and y values and produces a confusion matrix plot. If a predicted_y array is passed in, other evaluation metrics such as recall and precision are also produced.

  • plot_train_valid_error
      Input: model_name, X_train, y_train, X_test, y_test, param_name, param_vec
      Output: plot of train/validation errors vs. parameter values
      Description: Takes in a model name, train/validation data sets, a parameter name, and a vector of parameter values to try, then plots train/validation errors against the parameter values.

  • plot_roc
      Input: model, X_valid, y_valid
      Output: ROC plot
      Description: Takes in a fitted model, the validation set (X_valid), and the validation labels (y_valid), and plots the ROC curve. The plot also reports the AUC score.

Alignment with the R Ecosystem

For each of our functions, no existing R package implements exactly the same functionality. These functions are designed to provide insights about machine learning models conveniently.

Dependencies

R version >= 3.6.1 and R packages:

  • vctrs
  • lifecycle
  • pillar
  • dplyr
  • tidyr
  • magrittr
  • ggplot2
  • broom
  • pls
  • covr
  • gbm
  • tibble
  • testthat
  • purrr
  • pROC
  • plotROC
  • datasets
  • class
  • rpart
  • randomForest
  • e1071
  • mlbench
  • caTools
  • caret
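Before installing, it can help to check which of the packages listed above are missing. The base-R sketch below does this with `system.file()`, which returns an empty string for packages that are not installed; the variable names `deps` and `missing_pkgs` are illustrative only, and the install step is left commented out.

```r
# Warn if R is older than the minimum version stated above
if (getRversion() < "3.6.1") warning("RMLViz expects R >= 3.6.1")

# Packages listed in the Dependencies section of this README
deps <- c("vctrs", "lifecycle", "pillar", "dplyr", "tidyr", "magrittr",
          "ggplot2", "broom", "pls", "covr", "gbm", "tibble", "testthat",
          "purrr", "pROC", "plotROC", "datasets", "class", "rpart",
          "randomForest", "e1071", "mlbench", "caTools", "caret")

# system.file(package = p) returns "" when package p is not installed
is_installed <- vapply(deps, function(p) nzchar(system.file(package = p)),
                       logical(1))
missing_pkgs <- deps[!is_installed]

if (length(missing_pkgs) > 0) {
  message("Missing packages: ", paste(missing_pkgs, collapse = ", "))
  # Uncomment to install the missing packages from CRAN:
  # install.packages(missing_pkgs)
} else {
  message("All listed dependencies are installed.")
}
```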

Usage and examples

  1. model_comparison_table
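No example call is given here for model_comparison_table. As a hedged illustration of the kind of comparison it performs, the base-R sketch below fits two logistic regressions on a binary version of iris and tabulates their train and test scores. `compare_models_sketch` is a hypothetical helper, not part of RMLViz, and accuracy is assumed as the scoring option.

```r
# Binary target for illustration: is the flower Iris versicolor?
dat <- iris
dat$y <- as.integer(dat$Species == "versicolor")

set.seed(123)
train_idx <- sample(nrow(dat), 100)
train <- dat[train_idx, ]
test <- dat[-train_idx, ]

# Assumed scoring function: proportion of correctly classified rows
accuracy <- function(actual, predicted) mean(actual == predicted)

# Hypothetical helper: fits one logistic regression per formula and
# returns a data frame with the train and test score of each model
compare_models_sketch <- function(formulas, train, test, scoring = accuracy) {
  rows <- lapply(names(formulas), function(name) {
    fit <- glm(formulas[[name]], family = binomial, data = train)
    # Predict class 1 when the fitted probability exceeds 0.5
    pred_train <- as.integer(predict(fit, train, type = "response") > 0.5)
    pred_test <- as.integer(predict(fit, test, type = "response") > 0.5)
    data.frame(model = name,
               train_score = scoring(train$y, pred_train),
               test_score = scoring(test$y, pred_test))
  })
  do.call(rbind, rows)
}

formulas <- list(
  sepal_only = y ~ Sepal.Length + Sepal.Width,
  all_features = y ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width
)
comparison <- compare_models_sketch(formulas, train, test)
print(comparison)
```

Passing the scoring function as an argument mirrors the "scoring option" input listed for model_comparison_table, so a different metric can be swapped in without changing the loop.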
  2. confusion_matrix
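library(RMLViz)
data(iris)

# Split iris 75/25 into training and validation sets,
# keeping the species proportions similar in both
set.seed(123)
split <- caTools::sample.split(iris$Species, SplitRatio = 0.75)

training_set <- subset(iris, split == TRUE)
valid_set <- subset(iris, split == FALSE)

# Columns 1-4 are the features; column 5 is the Species label
X_train <- training_set[, -5]
y_train <- training_set[, 5]
X_valid <- valid_set[, -5]
y_valid <- valid_set[, 5]

# Predict species for the training set with k-nearest neighbours (k = 5)
y_pred <- class::knn(X_train, X_train, y_train, k = 5)

confusion_matrix(y_train, y_pred)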
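The scores that confusion_matrix reports alongside the plot follow the standard definitions: for each class, recall is correct predictions divided by actual members of the class, and precision is correct predictions divided by predicted members. The sketch below computes them with base R's `table()` on small made-up label vectors; it illustrates the definitions and is not RMLViz's implementation.

```r
lvls <- c("setosa", "versicolor", "virginica")

# Toy labels for illustration only
truth <- factor(c("setosa", "setosa", "setosa",
                  "versicolor", "versicolor", "versicolor",
                  "virginica", "virginica", "virginica", "virginica"),
                levels = lvls)
pred <- factor(c("setosa", "setosa", "setosa",
                 "versicolor", "versicolor", "virginica",
                 "virginica", "virginica", "virginica", "versicolor"),
               levels = lvls)

# Rows are predicted classes, columns are actual classes
cm <- table(Predicted = pred, Actual = truth)
print(cm)

correct <- diag(cm)
recall <- correct / colSums(cm)     # correct / actual members of each class
precision <- correct / rowSums(cm)  # correct / predicted members of each class
f1 <- 2 * precision * recall / (precision + recall)

scores <- data.frame(class = lvls,
                     recall = as.numeric(recall),
                     precision = as.numeric(precision),
                     f1 = as.numeric(f1))
accuracy <- sum(correct) / sum(cm)
print(scores)
print(accuracy)
```

With these toy labels, 8 of the 10 predictions are correct, so accuracy is 0.8, and 3 of the 4 actual virginica flowers are recovered, so virginica recall is 0.75.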

  3. plot_train_valid_error
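The idea behind plot_train_valid_error can be sketched without the package: fit one model per parameter value, record the training and validation error of each, and plot both curves against the parameter. In this base-R sketch, the polynomial degree of an `lm()` fit on simulated data stands in for the hyperparameter and mean squared error stands in for the error metric; both are assumptions chosen for illustration, not what RMLViz uses internally.

```r
# Simulated regression data: a noisy sine curve
set.seed(1)
n <- 120
x <- runif(n, -3, 3)
y <- sin(x) + rnorm(n, sd = 0.3)
dat <- data.frame(x = x, y = y)

train_idx <- sample(n, 80)
train <- dat[train_idx, ]
valid <- dat[-train_idx, ]

mse <- function(actual, predicted) mean((actual - predicted)^2)

# Parameter values to try (stand-in hyperparameter: polynomial degree)
param_vec <- 1:10
errors <- data.frame(degree = param_vec, train = NA_real_, valid = NA_real_)

for (i in seq_along(param_vec)) {
  d <- param_vec[i]
  fit <- lm(y ~ poly(x, d), data = train)
  errors$train[i] <- mse(train$y, predict(fit, train))
  errors$valid[i] <- mse(valid$y, predict(fit, valid))
}

matplot(errors$degree, as.matrix(errors[, c("train", "valid")]),
        type = "b", pch = 1:2, lty = 1:2, col = c("blue", "red"),
        xlab = "Polynomial degree", ylab = "Mean squared error")
legend("topright", legend = c("train", "validation"),
       col = c("blue", "red"), lty = 1:2, pch = 1:2)
```

Training error can only decrease or stay flat as the degree grows, because each polynomial model contains the previous one; validation error typically turns back up once the model starts to overfit.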

  4. plot_roc
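library(RMLViz)

# Simulate 100 weights and a binary obesity label whose
# probability increases with weight rank
set.seed(420)
num.samples <- 100
weight <- sort(rnorm(n = num.samples, mean = 172, sd = 29))
obese <- ifelse(test = (runif(n = num.samples) < (rank(weight) / num.samples)),
                yes = 1, no = 0)

# Fit a logistic regression and use its fitted probabilities as scores
glm_fit <- glm(obese ~ weight, family = binomial)
obese_proba <- glm_fit$fitted.values

plot_roc(obese, obese_proba)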
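To show what the ROC curve and AUC that plot_roc reports are made of, the base-R sketch below recomputes them from scratch on the same simulated data: it sweeps a threshold over every distinct score, records the true and false positive rates, and integrates with the trapezoidal rule. `roc_points` and `auc_trapezoid` are hypothetical helper names; this is not RMLViz's implementation.

```r
# Same simulated data as the plot_roc example
set.seed(420)
num_samples <- 100
weight <- sort(rnorm(n = num_samples, mean = 172, sd = 29))
obese <- ifelse(runif(n = num_samples) < (rank(weight) / num_samples), 1, 0)
glm_fit <- glm(obese ~ weight, family = binomial)
obese_proba <- fitted(glm_fit)

# Hypothetical helper: TPR and FPR at every distinct score threshold,
# classifying an observation as positive when score >= threshold
roc_points <- function(labels, scores) {
  thresholds <- c(Inf, sort(unique(scores), decreasing = TRUE))
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  tpr <- vapply(thresholds,
                function(t) sum(scores >= t & labels == 1) / n_pos, numeric(1))
  fpr <- vapply(thresholds,
                function(t) sum(scores >= t & labels == 0) / n_neg, numeric(1))
  data.frame(threshold = thresholds, fpr = fpr, tpr = tpr)
}

# Hypothetical helper: area under the piecewise-linear ROC curve
auc_trapezoid <- function(fpr, tpr) {
  sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)
}

roc <- roc_points(obese, obese_proba)
auc <- auc_trapezoid(roc$fpr, roc$tpr)

# Cross-check with the rank-based (Mann-Whitney) form of the AUC
n_pos <- sum(obese == 1)
n_neg <- sum(obese == 0)
auc_rank <- (sum(rank(obese_proba)[obese == 1]) - n_pos * (n_pos + 1) / 2) /
  (n_pos * n_neg)

plot(roc$fpr, roc$tpr, type = "l",
     xlab = "False positive rate", ylab = "True positive rate",
     main = sprintf("ROC curve (AUC = %.3f)", auc))
abline(0, 1, lty = 2)  # chance-level diagonal
```

The cross-check relies on a standard identity: the AUC equals the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counted as one half.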