R/plot_train_valid_error.R
plot_train_valid_error.Rd
Fits the chosen model on the training data for each value in param_vec, computes the training and validation error at each value, and plots both errors against the parameter values.
plot_train_valid_error( model_name, X_train, y_train, X_valid, y_valid, param_name, param_vec )
Argument | Description |
---|---|
model_name | a string naming the machine learning model. Must be one of 'knn', 'decision tree', 'svm', or 'random forests'. |
X_train | a numeric data frame of the training dataset without labels. |
y_train | a numeric vector or factor of the training labels. |
X_valid | a numeric data frame of the validation dataset without labels. |
y_valid | a numeric vector or factor of the validation labels. |
param_name | a string naming the parameter to vary. It must match model_name: 'k' for 'knn', 'maxdepth' for 'decision tree', 'cost' or 'gamma' for 'svm', and 'ntree' for 'random forests'. |
param_vec | a numeric vector of the values of param_name to try. |
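The allowed model_name/param_name combinations listed above can be expressed as a lookup table. This is a hypothetical sketch: `allowed_params` and `check_model_param()` are illustrative names, not part of the package, and whether `plot_train_valid_error()` validates its inputs this way is an assumption.

```r
# Hypothetical lookup of the model/parameter pairs documented for param_name.
# Illustration only; not a function exported by the package.
allowed_params <- list(
  "knn" = "k",
  "decision tree" = "maxdepth",
  "svm" = c("cost", "gamma"),
  "random forests" = "ntree"
)

check_model_param <- function(model_name, param_name) {
  # Reject unknown model names first
  if (!model_name %in% names(allowed_params)) {
    stop("model_name must be one of: ",
         paste(names(allowed_params), collapse = ", "))
  }
  # Then reject parameters that do not belong to that model
  if (!param_name %in% allowed_params[[model_name]]) {
    stop("param_name for '", model_name, "' must be one of: ",
         paste(allowed_params[[model_name]], collapse = ", "))
  }
  invisible(TRUE)
}

check_model_param("svm", "gamma")        # valid pair, returns invisibly
try(check_model_param("knn", "ntree"))   # invalid pair, raises an error
```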
A plot of training and validation error against the values in param_vec.
plot_train_valid_error(
  model_name = "knn",
  X_train = tibble::tibble(a = c(1, 2, 3)),
  y_train = c(1, 2, 3),
  X_valid = tibble::tibble(a = c(1, 2, 3)),
  y_valid = c(1, 2, 3),
  param_name = "k",
  param_vec = seq(3)
)
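For intuition about what is plotted in the 'knn' case, here is a minimal base R sketch. It is not the package's implementation: `knn_error_curve()` and `plot_error_curve()` are hypothetical helpers, error is assumed to be the misclassification rate, `class::knn()` (from the recommended class package) stands in for whatever k-NN backend the package actually uses, and the built-in `iris` data is split into alternating rows purely for illustration.

```r
# Minimal sketch, assuming error = proportion of misclassified rows.
library(class)

knn_error_curve <- function(X_train, y_train, X_valid, y_valid, k_vec) {
  y_train <- as.factor(y_train)
  # Give validation labels the same levels so the factors can be compared
  y_valid <- factor(y_valid, levels = levels(y_train))
  train_err <- numeric(length(k_vec))
  valid_err <- numeric(length(k_vec))
  for (i in seq_along(k_vec)) {
    # Predict both the training set and the validation set with the same k
    pred_train <- knn(X_train, X_train, cl = y_train, k = k_vec[i])
    pred_valid <- knn(X_train, X_valid, cl = y_train, k = k_vec[i])
    train_err[i] <- mean(pred_train != y_train)
    valid_err[i] <- mean(pred_valid != y_valid)
  }
  data.frame(k = k_vec, train_error = train_err, valid_error = valid_err)
}

# Draw both error curves against k with base graphics
plot_error_curve <- function(res) {
  matplot(res$k, res[, c("train_error", "valid_error")], type = "b",
          pch = 1:2, lty = 1:2, col = 1:2, xlab = "k", ylab = "Error")
  legend("topright", legend = c("train", "valid"),
         pch = 1:2, lty = 1:2, col = 1:2)
}

# Illustrative split: odd rows for training, even rows for validation
train_idx <- seq(1, nrow(iris), by = 2)
valid_idx <- seq(2, nrow(iris), by = 2)
res <- knn_error_curve(iris[train_idx, 1:4], iris$Species[train_idx],
                       iris[valid_idx, 1:4], iris$Species[valid_idx],
                       k_vec = 1:10)
print(res)
plot_error_curve(res)
```

Training error typically grows with k while validation error often dips and then rises, which is the trade-off this kind of plot is meant to expose.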