This function encodes categorical variables by fitting a posterior distribution to the target variable y for each category, using a known conjugate prior. The mean(s) of each category's posterior distribution are used as the encodings.

conjugate_encoder(
  X_train,
  X_test = NULL,
  y,
  cat_columns,
  prior_params,
  objective = "regression"
)

Arguments

X_train

A tibble representing the training data set containing some categorical features/columns.

X_test

An optional tibble representing the test set, containing some set of categorical features/columns. Default = NULL.

y

A numeric vector or character vector representing the target variable. If the objective is "binary", then the vector should only contain two unique values.

cat_columns

A character vector containing the names of the categorical columns in the tibble that should be encoded.

prior_params

A named list specifying the parameters of the assumed prior. For regression, this requires a list with four named elements: mu, vega, alpha, beta. All must be real numbers; vega, alpha, and beta must be greater than 0, while mu can be negative. For binary classification, this requires a list with two named elements: alpha, beta. Both must be real numbers greater than 0.

objective

A string, either "regression" or "binary", specifying the problem. Default is "regression". For regression, only the uniform quantization method is incorporated here for simplicity.
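
The role of the prior parameters can be illustrated with the textbook conjugate updates for a single category. The following is a minimal base R sketch, assuming a Normal-Inverse-Gamma prior for "regression" and a Beta prior for "binary"; the helper names nig_posterior and beta_posterior_mean are hypothetical and are not part of this package. The package's internal computation is not shown on this page, so these values are not guaranteed to reproduce its output.

```r
# Hypothetical helper (not part of the package): textbook
# Normal-Inverse-Gamma update for the target values of ONE category.
nig_posterior <- function(y_cat, mu, vega, alpha, beta) {
  n    <- length(y_cat)
  ybar <- mean(y_cat)
  ss   <- sum((y_cat - ybar)^2)        # within-category sum of squares

  vega_n  <- vega + n
  mu_n    <- (vega * mu + n * ybar) / vega_n
  alpha_n <- alpha + n / 2
  beta_n  <- beta + 0.5 * ss + (n * vega / vega_n) * (ybar - mu)^2 / 2

  list(
    post_mean = mu_n,                  # posterior mean of the category mean
    post_var  = beta_n / (alpha_n - 1) # E[sigma^2]; only defined if alpha_n > 1
  )
}

# Hypothetical helper (not part of the package): Beta posterior mean
# for a 0/1 target observed within ONE category.
beta_posterior_mean <- function(y_cat, alpha, beta) {
  (alpha + sum(y_cat)) / (alpha + beta + length(y_cat))
}

# Per-category posterior means of mpg, grouped by cyl, under the assumed prior.
sapply(
  split(mtcars$mpg, mtcars$cyl),
  function(v) nig_posterior(v, mu = 3, vega = 5, alpha = 3, beta = 3)$post_mean
)
```

Larger values of vega (or of alpha + beta in the binary case) pull the encoding of a sparsely observed category more strongly toward the prior.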

Value

A list containing the processed training and test sets, in which the named categorical columns are replaced with their encodings. For regression, each encoded column is replaced by two columns (a posterior mean and a variance), so the output gains one additional column per encoded categorical column, since the assumed prior distribution is two-dimensional.
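
The replacement of a categorical column by its encoding can be sketched in base R as a lookup against a per-category table learned on the training data. This is only an illustration: apply_encoding, the new column name, and the encoding values are hypothetical, and returning NA for categories that appear only in the test set is an assumption of this sketch, not documented behavior of the package.

```r
# Hypothetical helper (not part of the package): swap one categorical
# column for a numeric encoding looked up by category label.
apply_encoding <- function(df, column, enc_table,
                           new_name = paste0(column, "_encoded")) {
  idx <- match(as.character(df[[column]]), names(enc_table))
  df[[new_name]] <- unname(enc_table[idx])  # NA if the category was not seen
  df[[column]] <- NULL                      # drop the original column
  df
}

train <- data.frame(color = c("red", "blue", "red"), x = 1:3)
test  <- data.frame(color = c("blue", "green"), x = 4:5)

# Illustrative encodings only; these are not computed by the package.
enc <- c(red = 0.7, blue = 0.2)

# Processed training and test sets returned together as a list.
out <- list(
  apply_encoding(train, "color", enc),
  apply_encoding(test, "color", enc)
)
```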

Examples

conjugate_encoder(
  X_train = mtcars,
  y = mtcars$mpg,
  cat_columns = c("cyl", "vs"),
  prior_params = list(mu = 3, vega = 5, alpha = 3, beta = 3),
  objective = "regression"
)
#> Joining, by = "cyl"
#> Joining, by = "vs"
#> [[1]]
#>     mpg  disp  hp drat    wt  qsec am gear carb cyl_encoded_mean
#> 1  21.0 160.0 110 3.90 2.620 16.46  1    4    4         17.48031
#> 2  21.0 160.0 110 3.90 2.875 17.02  1    4    4         17.48031
#> 3  22.8 108.0  93 3.85 2.320 18.61  1    4    1         23.46585
#> 4  21.4 258.0 110 3.08 3.215 19.44  0    3    1         17.48031
#> 5  18.7 360.0 175 3.15 3.440 17.02  0    3    2         13.46486
#> 6  18.1 225.0 105 2.76 3.460 20.22  0    3    1         17.48031
#> 7  14.3 360.0 245 3.21 3.570 15.84  0    3    4         13.46486
#> 8  24.4 146.7  62 3.69 3.190 20.00  0    4    2         23.46585
#> 9  22.8 140.8  95 3.92 3.150 22.90  0    4    2         23.46585
#> 10 19.2 167.6 123 3.92 3.440 18.30  0    4    4         17.48031
#> 11 17.8 167.6 123 3.92 3.440 18.90  0    4    4         17.48031
#> 12 16.4 275.8 180 3.07 4.070 17.40  0    3    3         13.46486
#> 13 17.3 275.8 180 3.07 3.730 17.60  0    3    3         13.46486
#> 14 15.2 275.8 180 3.07 3.780 18.00  0    3    3         13.46486
#> 15 10.4 472.0 205 2.93 5.250 17.98  0    3    4         13.46486
#> 16 10.4 460.0 215 3.00 5.424 17.82  0    3    4         13.46486
#> 17 14.7 440.0 230 3.23 5.345 17.42  0    3    4         13.46486
#> 18 32.4  78.7  66 4.08 2.200 19.47  1    4    1         23.46585
#> 19 30.4  75.7  52 4.93 1.615 18.52  1    4    2         23.46585
#> 20 33.9  71.1  65 4.22 1.835 19.90  1    4    1         23.46585
#> 21 21.5 120.1  97 3.70 2.465 20.01  0    3    1         23.46585
#> 22 15.5 318.0 150 2.76 3.520 16.87  0    3    2         13.46486
#> 23 15.2 304.0 150 3.15 3.435 17.30  0    3    2         13.46486
#> 24 13.3 350.0 245 3.73 3.840 15.41  0    3    4         13.46486
#> 25 19.2 400.0 175 3.08 3.845 17.05  0    3    2         13.46486
#> 26 27.3  79.0  66 4.08 1.935 18.90  1    4    1         23.46585
#> 27 26.0 120.3  91 4.43 2.140 16.70  1    5    2         23.46585
#> 28 30.4  95.1 113 3.77 1.513 16.90  1    5    2         23.46585
#> 29 15.8 351.0 264 4.22 3.170 14.50  1    5    4         13.46486
#> 30 19.7 145.0 175 3.62 2.770 15.50  1    5    6         17.48031
#> 31 15.0 301.0 335 3.54 3.570 14.60  1    5    8         13.46486
#> 32 21.4 121.0 109 4.11 2.780 18.60  1    4    2         23.46585
#>    cyl_encoded_var vs_encoded_mean vs_encoded_var
#> 1         35.71723        14.77658       35.68746
#> 2         35.71723        14.77658       35.68746
#> 3         85.50876        21.64402       81.70632
#> 4         35.71723        21.64402       81.70632
#> 5         23.57909        14.77658       35.68746
#> 6         35.71723        21.64402       81.70632
#> 7         23.57909        14.77658       35.68746
#> 8         85.50876        21.64402       81.70632
#> 9         85.50876        21.64402       81.70632
#> 10        35.71723        21.64402       81.70632
#> 11        35.71723        21.64402       81.70632
#> 12        23.57909        14.77658       35.68746
#> 13        23.57909        14.77658       35.68746
#> 14        23.57909        14.77658       35.68746
#> 15        23.57909        14.77658       35.68746
#> 16        23.57909        14.77658       35.68746
#> 17        23.57909        14.77658       35.68746
#> 18        85.50876        21.64402       81.70632
#> 19        85.50876        21.64402       81.70632
#> 20        85.50876        21.64402       81.70632
#> 21        85.50876        21.64402       81.70632
#> 22        23.57909        14.77658       35.68746
#> 23        23.57909        14.77658       35.68746
#> 24        23.57909        14.77658       35.68746
#> 25        23.57909        14.77658       35.68746
#> 26        85.50876        21.64402       81.70632
#> 27        85.50876        14.77658       35.68746
#> 28        85.50876        21.64402       81.70632
#> 29        23.57909        14.77658       35.68746
#> 30        35.71723        14.77658       35.68746
#> 31        23.57909        14.77658       35.68746
#> 32        85.50876        21.64402       81.70632
#>