Visualization package for ML models.
This package contains four functions that allow users to conveniently plot a variety of visualizations and compare the performance of different classifier models. Its purpose is to cut down the time data scientists spend building visualizations and comparing models, speeding up the model-development process. The four functions perform the following tasks:

- Compare the performance of various models
- Plot the confusion matrix based on the input data
- Plot train/validation errors vs. parameter values
- Plot the ROC curve and calculate the AUC
| Contributor | GitHub Handle |
|---|---|
| Anas Muhammad | anasm-17 |
| Tao Huang | taohuang-ubc |
| Fanli Zhou | flizhou |
| Mike Chen | miketianchen |
You can install RMLViz from GitHub with:
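A typical approach uses devtools; note that the `UBC-MDS/RMLViz` repository path below is an assumption, so substitute the actual GitHub owner/repo if it differs:

# install devtools if it is not already available
install.packages("devtools")

# repository path assumed; adjust owner/repo as needed
devtools::install_github("UBC-MDS/RMLViz")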
| Function Name | Input | Output | Description |
|---|---|---|---|
| model_comparison_table | list of models, X_train, y_train, X_test, y_test, scoring option | dataframe of model scores | Takes in a list of models and the train/test data, then outputs a table comparing the scores of the different models. |
| confusion_matrix | model, X_train, y_train, X_test, y_test, predicted_y | confusion matrix plot, dataframe of various scores (recall, F1, etc.) | Takes in a trained model with X and y values to produce a confusion matrix visual. If a predicted_y array is passed in, other evaluation metrics such as recall and precision are also produced. |
| plot_train_valid_error | model_name, X_train, y_train, X_test, y_test, param_name, param_vec | plot of train/validation errors vs. parameter values | Takes in a model name, train/validation data sets, a parameter name, and a vector of parameter values to try, then plots train/validation errors against those parameter values. |
| plot_roc | model, X_valid, y_valid | ROC plot | Takes in a fitted model, the validation set (X_valid) and the validation set labels (y_valid), then plots the ROC curve annotated with the AUC score. |
For all of our functions, there are no existing R packages that implement exactly the same functionality. These functions make it convenient to gain insight into machine learning models.
R version >= 3.6.1 and R packages:
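At a minimum, the usage examples below depend on the following packages (a non-exhaustive list inferred from the example code, not an official dependency list):

# packages used by the usage examples below (inferred, not exhaustive)
install.packages(c("dplyr", "mlbench", "caret", "caTools", "class", "gbm"))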
model_comparison_table
library(RMLViz)
library(mlbench)

## load the Sonar data and keep the first five features plus the label
data(Sonar)
toy_classification_data <- dplyr::select(dplyr::as_tibble(Sonar), V1, V2, V3, V4, V5, Class)

## 90/10 train/test split, stratified on the class label
train_ind <- caret::createDataPartition(toy_classification_data$Class, p = 0.9, list = FALSE)
train_set_cf <- toy_classification_data[train_ind, ]
test_set_cf <- toy_classification_data[-train_ind, ]

## classification models setup
gbm <- caret::train(Class ~ ., train_set_cf, method = "gbm", verbose = FALSE)
lm_cf <- caret::train(Class ~ ., train_set_cf, method = "LogitBoost", verbose = FALSE)

## compare train and test scores for the two models
model_comparison_table(train_set_cf, test_set_cf,
                       gbm_mod = gbm, log_mod = lm_cf)
#> # A tibble: 2 x 5
#> model train_Accuracy train_Kappa test_Accuracy test_Kappa
#> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 gbm_mod 0.761 0.516 0.35 -0.3
#> 2 log_mod 0.75 0.503 0.45 -0.0784
confusion_matrix
library(RMLViz)

## load the iris data and make a 75/25 train/validation split
data(iris)
set.seed(123)
split <- caTools::sample.split(iris$Species, SplitRatio = 0.75)
training_set <- subset(iris, split == TRUE)
valid_set <- subset(iris, split == FALSE)

X_train <- training_set[, -5]
y_train <- training_set[, 5]
X_valid <- valid_set[, -5]
y_valid <- valid_set[, 5]

## k-nearest-neighbour predictions on the training set
## (named y_pred rather than predict to avoid masking base::predict)
y_pred <- class::knn(X_train, X_train, y_train, k = 5)

## confusion matrix plot plus per-class evaluation metrics
confusion_matrix(y_train, y_pred)
#> Sensitivity Specificity Pos Pred Value Neg Pred Value Precision
#> Class: setosa 1.0000000 1.0000000 1.00 1.000000 1.00
#> Class: versicolor 0.9473684 1.0000000 1.00 0.974359 1.00
#> Class: virginica 1.0000000 0.9736842 0.95 1.000000 0.95
#> Recall F1 Prevalence Detection Rate Detection Prevalence
#> Class: setosa 1.0000000 1.000000 0.3333333 0.3333333 0.3333333
#> Class: versicolor 0.9473684 0.972973 0.3333333 0.3157895 0.3157895
#> Class: virginica 1.0000000 0.974359 0.3333333 0.3333333 0.3508772
#> Balanced Accuracy
#> Class: setosa 1.0000000
#> Class: versicolor 0.9736842
#> Class: virginica 0.9868421
plot_train_valid_error
library(RMLViz)

## same iris train/validation split as in the confusion_matrix example
data(iris)
set.seed(123)
split <- caTools::sample.split(iris$Species, SplitRatio = 0.75)
training_set <- subset(iris, split == TRUE)
valid_set <- subset(iris, split == FALSE)

X_train <- training_set[, -5]
y_train <- training_set[, 5]
X_valid <- valid_set[, -5]
y_valid <- valid_set[, 5]

## plot train/validation error for k = 1, ..., 50
plot_train_valid_error('knn',
                       X_train, y_train,
                       X_valid, y_valid,
                       'k', seq(50))
plot_roc
library(RMLViz)

## simulate weights and a binary obesity label whose probability
## of being 1 increases with weight
set.seed(420)
num.samples <- 100
weight <- sort(rnorm(n = num.samples, mean = 172, sd = 29))
obese <- ifelse(test = (runif(n = num.samples) < (rank(weight) / num.samples)),
                yes = 1, no = 0)

## fit a logistic regression and extract the fitted probabilities
glm.fit <- glm(obese ~ weight, family = binomial)
obese_proba <- glm.fit$fitted.values

## ROC curve, annotated with the AUC
plot_roc(obese, obese_proba)