Takes a model name, training and validation datasets, a parameter name, and a vector of parameter values to try, then plots the training and validation errors against those parameter values.
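Below is a minimal sketch of the computation described above, for the 'knn' case only. It assumes that "error" means the misclassification rate on class labels, and it uses `knn()` from the class package, a recommended package that ships with R. The helper name `knn_error_curve()` is hypothetical, and this sketch is not necessarily how `plot_train_valid_error()` is implemented.

```r
library(class)  # provides knn(); ships with standard R installations

# Hypothetical helper (not part of this package): training and validation
# misclassification error of k-NN for each value of k in k_vec.
knn_error_curve <- function(X_train, y_train, X_valid, y_valid, k_vec) {
  y_train <- factor(y_train)
  # Align validation labels with the training levels (unseen labels become NA)
  y_valid <- factor(y_valid, levels = levels(y_train))
  rows <- lapply(k_vec, function(k) {
    # Score the training set against itself and the validation set with the same k
    pred_train <- knn(X_train, X_train, cl = y_train, k = k)
    pred_valid <- knn(X_train, X_valid, cl = y_train, k = k)
    data.frame(
      k = k,
      train_error = mean(pred_train != y_train),
      valid_error = mean(pred_valid != y_valid)
    )
  })
  do.call(rbind, rows)
}

# Usage: split iris into training and validation sets, then plot both curves
set.seed(1)
idx <- sample(nrow(iris), 100)
errs <- knn_error_curve(iris[idx, 1:4], iris$Species[idx],
                        iris[-idx, 1:4], iris$Species[-idx],
                        k_vec = c(1, 5, 15, 25))
matplot(errs$k, as.matrix(errs[, c("train_error", "valid_error")]), type = "b",
        pch = 1:2, col = 1:2, lty = 1, xlab = "k", ylab = "error")
legend("topright", legend = c("train", "valid"), col = 1:2, pch = 1:2, lty = 1)
```

With k = 1, each training point is its own nearest neighbour, so the training error is close to zero. The gap between the two curves is what a train/validation error plot is meant to expose.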

plot_train_valid_error(
  model_name,
  X_train,
  y_train,
  X_valid,
  y_valid,
  param_name,
  param_vec
)

Arguments

model_name

a string giving the machine learning model name. Must be one of 'knn', 'decision tree', 'svm', or 'random forests'.

X_train

a numeric data frame of the training dataset without labels.

y_train

a numeric vector or factor of the training labels.

X_valid

a numeric data frame of the validation dataset without labels.

y_valid

a numeric vector or factor of the validation labels.

param_name

a string giving the name of the parameter to vary. It must match model_name: 'k' for 'knn', 'maxdepth' for 'decision tree', 'cost' or 'gamma' for 'svm', and 'ntree' for 'random forests'.

param_vec

a numeric vector of the parameter values to try.
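The model/parameter pairing documented for param_name can be expressed as a lookup table. The sketch below is a hypothetical helper, `is_valid_param()`, that checks whether a requested pair is one of the documented combinations; it is not part of this package's API.

```r
# Hypothetical lookup (not part of this package): documented parameter
# names for each supported model_name.
allowed_params <- list(
  "knn"            = "k",
  "decision tree"  = "maxdepth",
  "svm"            = c("cost", "gamma"),
  "random forests" = "ntree"
)

# TRUE if model_name is supported and param_name is documented for it;
# && short-circuits, so unknown model names never index the list.
is_valid_param <- function(model_name, param_name) {
  model_name %in% names(allowed_params) &&
    param_name %in% allowed_params[[model_name]]
}

is_valid_param("svm", "gamma")     # TRUE
is_valid_param("knn", "maxdepth")  # FALSE
```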

Value

A plot of the training and validation errors against the parameter values.

Examples

plot_train_valid_error(
  model_name = "knn",
  X_train = tibble::tibble(a = c(1, 2, 3)),
  y_train = c(1, 2, 3),
  X_valid = tibble::tibble(a = c(1, 2, 3)),
  y_valid = c(1, 2, 3),
  param_name = "k",
  param_vec = seq(3)
)