
Cross Validation

Cross Validation (CV) is a technique for assessing the generalization performance of a model using data it has never seen before. The validation score gives us a sense of how well the model will perform in the real world. In addition, it allows the user to identify problems such as underfitting, overfitting, and selection bias, which are discussed in the last section.

Creating a Testing Set

For some projects we'll create a dedicated testing set, but in others we can set aside some of the samples from our master dataset to use for testing on the fly. To ensure that both the training and testing sets contain samples that accurately represent the master set, the Dataset object provides a number of methods we can employ.

Randomized Split

The first method of creating a training and testing set, which works for any dataset, is to randomize the samples and then split the dataset into two subsets at a given proportion. In the example below we'll create a training set with 80% of the samples and a testing set with the remaining 20% using the randomize() and split() methods on the Dataset object.

[$training, $testing] = $dataset->randomize()->split(0.8);
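To make the idea behind randomize() followed by split() concrete, here is a minimal plain-PHP sketch over parallel arrays of samples and labels. The function random_split() is hypothetical and not part of Rubix ML, and the library's own implementation may differ; the point is only that samples and labels are shuffled together and then cut at the given ratio.

```php
<?php

/**
 * Shuffle samples and labels together, then split them at the given ratio.
 * Hypothetical helper for illustration, not part of the Rubix ML API.
 * Returns [[trainSamples, trainLabels], [testSamples, testLabels]].
 */
function random_split(array $samples, array $labels, float $ratio): array
{
    // Shuffle a list of indices so samples and labels stay aligned.
    $order = array_keys($samples);
    shuffle($order);

    $samples = array_map(fn ($i) => $samples[$i], $order);
    $labels = array_map(fn ($i) => $labels[$i], $order);

    // The first $ratio of the shuffled data becomes the training set.
    $n = (int) floor($ratio * count($samples));

    return [
        [array_slice($samples, 0, $n), array_slice($labels, 0, $n)],
        [array_slice($samples, $n), array_slice($labels, $n)],
    ];
}

$samples = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0], [9.0], [10.0]];
$labels = ['a', 'a', 'a', 'a', 'a', 'b', 'b', 'b', 'b', 'b'];

[$training, $testing] = random_split($samples, $labels, 0.8);

echo count($training[0]), ' / ', count($testing[0]), PHP_EOL; // 8 / 2
```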

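You can also use the take() method to remove a given number of samples from a dataset and return them as a new testing set, leaving the remaining samples behind as the training set. In the example below, $training starts out as the full master dataset.

$testing = $training->randomize()->take(1000);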
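As a rough analogy for take(), PHP's built-in array_splice() removes elements from an array in place and returns them, so the original array keeps only what was not taken. This sketch works on a plain array of numbers, not on a Rubix ML Dataset, and assumes take() draws from the front of the (already randomized) dataset.

```php
<?php

// Analogy only: the master "dataset" is a plain array of samples.
$training = range(1, 10);
shuffle($training);

// array_splice() removes the first 3 elements from $training (it takes the
// array by reference) and returns them, so the two arrays end up disjoint.
$testing = array_splice($training, 0, 3);

echo count($training), ' / ', count($testing), PHP_EOL; // 7 / 3
```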

Stratified Split

If we have a Labeled dataset containing class labels, we can split the dataset in such a way that each class is represented in both subsets in roughly the same proportion as in the master set. This stratified method helps to reduce selection bias by ensuring that neither subset over- or under-represents a class.

[$training, $testing] = $dataset->stratifiedSplit(0.8);
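The stratification step can be sketched in plain PHP: group sample indices by class label, then split each group separately at the same ratio so the class proportions carry over to both subsets. The stratified_split() helper is hypothetical and does not shuffle within each class, so shuffle first if your data is ordered; Rubix ML's implementation may handle this differently.

```php
<?php

/**
 * Split labeled samples so each class appears in both subsets in roughly the
 * same proportion. Hypothetical helper, not the Rubix ML implementation.
 */
function stratified_split(array $samples, array $labels, float $ratio): array
{
    // Group sample indices by their class label.
    $strata = [];
    foreach ($labels as $i => $label) {
        $strata[$label][] = $i;
    }

    $left = $right = [[], []];

    foreach ($strata as $indices) {
        // Split each class separately at the same ratio.
        $n = (int) floor($ratio * count($indices));

        foreach ($indices as $k => $i) {
            if ($k < $n) {
                $left[0][] = $samples[$i];
                $left[1][] = $labels[$i];
            } else {
                $right[0][] = $samples[$i];
                $right[1][] = $labels[$i];
            }
        }
    }

    return [$left, $right];
}

// 10 cats and 5 dogs: an 80/20 split keeps the 2:1 ratio in both subsets.
$labels = array_merge(array_fill(0, 10, 'cat'), array_fill(0, 5, 'dog'));
$samples = array_map(fn ($i) => [$i], range(0, 14));

[$training, $testing] = stratified_split($samples, $labels, 0.8);

echo json_encode(array_count_values($testing[1])), PHP_EOL; // {"cat":2,"dog":1}
```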

Metrics

Cross validation metrics are used to score the predictions made by an Estimator with respect to their known ground-truth labels. There are different metrics for different types of problems. To return a validation score from a Metric, pass the predictions and labels to its score() method as in the example below.

use Rubix\ML\CrossValidation\Metrics\Accuracy;

$predictions = $estimator->predict($testing);

$metric = new Accuracy();

$score = $metric->score($predictions, $testing->labels());

echo $score;
0.85

Note

All metrics follow the schema that higher scores are better. Thus, common loss functions such as Mean Squared Error and RMSE are given as their negative to conform to this schema.

Classification and Anomaly Detection

Metrics for classification and anomaly detection (a special case of binary classification) compare class predictions to their categorical ground-truth labels. Their scores are calculated from the true positive (TP), true negative (TN), false positive (FP), and false negative (FN) counts derived from the confusion matrix between the set of predictions and their ground-truth labels.

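| Name | Range | Formula | Notes |
|---|---|---|---|
| Accuracy | [0, 1] | $\frac{TP + TN}{TP + TN + FP + FN}$ | Not suited for imbalanced datasets |
| F Beta | [0, 1] | $(1 + \beta^2) \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}{(\beta^2 \cdot \mathrm{precision}) + \mathrm{recall}}$ | |
| Informedness | [-1, 1] | $\frac{TP}{TP + FN} + \frac{TN}{TN + FP} - 1$ | |
| MCC | [-1, 1] | $\frac{TP \times TN - FP \times FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$ | |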
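For the binary case, these metrics follow directly from the four confusion-matrix counts. Below is a minimal plain-PHP sketch: accuracy is the fraction of all predictions that are correct, (TP + TN) / (TP + TN + FP + FN), precision is TP / (TP + FP), and recall is TP / (TP + FN). The function names are hypothetical, multiclass averaging is not covered, and zero denominators are not guarded (PHP 8 throws a DivisionByZeroError).

```php
<?php

// Binary classification metrics from confusion-matrix counts.
// Hypothetical helpers for illustration; no multiclass averaging and no
// guard against zero denominators.

function accuracy(int $tp, int $tn, int $fp, int $fn): float
{
    // Fraction of all predictions that are correct.
    return ($tp + $tn) / ($tp + $tn + $fp + $fn);
}

function f_beta(int $tp, int $fp, int $fn, float $beta = 1.0): float
{
    $precision = $tp / ($tp + $fp);
    $recall = $tp / ($tp + $fn);
    $b2 = $beta ** 2;

    return (1 + $b2) * $precision * $recall / ($b2 * $precision + $recall);
}

function informedness(int $tp, int $tn, int $fp, int $fn): float
{
    // True positive rate + true negative rate - 1.
    return $tp / ($tp + $fn) + $tn / ($tn + $fp) - 1;
}

function mcc(int $tp, int $tn, int $fp, int $fn): float
{
    $denominator = sqrt(($tp + $fp) * ($tp + $fn) * ($tn + $fp) * ($tn + $fn));

    return ($tp * $tn - $fp * $fn) / $denominator;
}

// Example counts: 40 TP, 45 TN, 5 FP, 10 FN (100 predictions in total).
printf("accuracy: %.4f\n", accuracy(40, 45, 5, 10)); // accuracy: 0.8500
printf("f1: %.4f\n", f_beta(40, 5, 10)); // f1: 0.8421
printf("informedness: %.4f\n", informedness(40, 45, 5, 10)); // informedness: 0.7000
printf("mcc: %.4f\n", mcc(40, 45, 5, 10));
```

With β = 1, F Beta reduces to the familiar F1 score, 2TP / (2TP + FP + FN), which is 80 / 95 for these counts.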

Regression

Regression metrics output a score based on the error achieved by comparing continuous-valued predictions and their ground-truth labels.

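| Name | Range | Formula | Notes |
|---|---|---|---|
| Mean Absolute Error | [-∞, 0] | $\frac{1}{n}\sum_{i=1}^{n} \lvert Y_i - \hat{Y}_i \rvert$ | Output in same units as predictions |
| Mean Squared Error | [-∞, 0] | $\frac{1}{n}\sum_{i=1}^{n} (Y_i - \hat{Y}_i)^2$ | Sensitive to outliers |
| Median Absolute Error | [-∞, 0] | $\operatorname{median}(\lvert Y_i - \hat{Y}_i \rvert)$ | Robust to outliers |
| R Squared | [-∞, 1] | $1 - \frac{SS_{res}}{SS_{tot}}$ | |
| RMSE | [-∞, 0] | $\sqrt{\frac{1}{n}\sum_{i=1}^{n} (Y_i - \hat{Y}_i)^2}$ | Output in same units as predictions |
| SMAPE | [-100, 0] | $\frac{100\%}{n}\sum_{t=1}^{n} \frac{\lvert F_t - A_t \rvert}{\lvert A_t \rvert + \lvert F_t \rvert}$ | |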
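The regression formulas can likewise be sketched in plain PHP, negating the error-based scores so that higher is better. The helper names are hypothetical; the median absolute error is taken as the median of the absolute residuals, and SMAPE uses the denominator |A_t| + |F_t| (without halving) so that its score stays within [-100, 0]. Zero denominators are not guarded.

```php
<?php

// Regression metrics following the "higher is better" convention:
// error-based scores are returned as negative numbers.
// Hypothetical helpers; zero denominators (constant labels for R squared,
// or a label and prediction both equal to 0 for SMAPE) are not guarded.

function median(array $values): float
{
    sort($values);
    $n = count($values);
    $mid = intdiv($n, 2);

    return $n % 2 === 1 ? $values[$mid] : ($values[$mid - 1] + $values[$mid]) / 2;
}

function regression_scores(array $predictions, array $labels): array
{
    $n = count($labels);

    // Residuals: ground-truth label minus prediction.
    $errors = array_map(fn ($p, $y) => $y - $p, $predictions, $labels);

    $mae = array_sum(array_map('abs', $errors)) / $n;
    $mse = array_sum(array_map(fn ($e) => $e ** 2, $errors)) / $n;

    // R squared compares the residual sum of squares to the total sum of squares.
    $mean = array_sum($labels) / $n;
    $ssRes = array_sum(array_map(fn ($e) => $e ** 2, $errors));
    $ssTot = array_sum(array_map(fn ($y) => ($y - $mean) ** 2, $labels));

    // SMAPE with an |A| + |F| denominator, bounded to [0, 100] percent.
    $smape = 100 * array_sum(array_map(
        fn ($p, $y) => abs($p - $y) / (abs($y) + abs($p)),
        $predictions,
        $labels,
    )) / $n;

    return [
        'mean absolute error' => -$mae,
        'mean squared error' => -$mse,
        'median absolute error' => -median(array_map('abs', $errors)),
        'r squared' => 1 - $ssRes / $ssTot,
        'rmse' => -sqrt($mse),
        'smape' => -$smape,
    ];
}

print_r(regression_scores([2.5, 0.0, 2.0, 8.0], [3.0, -0.5, 2.0, 7.0]));
```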

Clustering

Clustering metrics derive their scores from a contingency table, which can be thought of as a confusion matrix in which the class names of the predictions are unknown.
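A contingency table simply counts how many samples with each ground-truth label landed in each predicted cluster. In this minimal plain-PHP sketch, rows are cluster ids and columns are class labels; the contingency_table() helper is hypothetical, and the library's Contingency Table report may orient or format the table differently.

```php
<?php

/**
 * Count how many samples with each true label fall into each predicted cluster.
 * Rows are cluster ids, columns are class labels. Hypothetical helper.
 */
function contingency_table(array $predictions, array $labels): array
{
    $table = [];
    $classes = array_unique($labels);

    foreach ($predictions as $i => $cluster) {
        if (!isset($table[$cluster])) {
            // Start every row with a zero count for each known class.
            $table[$cluster] = array_fill_keys($classes, 0);
        }

        ++$table[$cluster][$labels[$i]];
    }

    return $table;
}

$predictions = [0, 0, 0, 1, 1, 1];
$labels = ['cat', 'cat', 'dog', 'dog', 'dog', 'dog'];

echo json_encode(contingency_table($predictions, $labels)), PHP_EOL;
// [{"cat":2,"dog":1},{"cat":0,"dog":3}]
```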

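| Name | Range | Formula | Notes |
|---|---|---|---|
| Completeness | [0, 1] | $1 - \frac{H(K \mid C)}{H(K)}$ | Not suited for hyper-parameter tuning |
| Homogeneity | [0, 1] | $1 - \frac{H(C \mid K)}{H(C)}$ | Not suited for hyper-parameter tuning |
| Rand Index | [-1, 1] | $\frac{\sum_{ij} \binom{n_{ij}}{2} - \left[\sum_i \binom{a_i}{2} \sum_j \binom{b_j}{2}\right] / \binom{n}{2}}{\frac{1}{2}\left[\sum_i \binom{a_i}{2} + \sum_j \binom{b_j}{2}\right] - \left[\sum_i \binom{a_i}{2} \sum_j \binom{b_j}{2}\right] / \binom{n}{2}}$ | |
| V Measure | [0, 1] | $\frac{(1 + \beta) h c}{\beta h + c}$ | |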
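Homogeneity, completeness, and the V Measure are all built from entropies of the label and cluster assignments, where h is homogeneity and c is completeness. This plain-PHP sketch computes them from parallel arrays of cluster ids and labels using the natural logarithm (the ratios do not depend on the base). The helper names are hypothetical, labels and cluster ids are assumed to be strings or integers (as array_count_values() requires), and a score is taken as 1.0 when its entropy is zero. The second usage line also shows why completeness alone is a poor tuning target: putting every sample in one cluster makes it perfect.

```php
<?php

// Entropy-based clustering metrics. Hypothetical helpers for illustration.

function entropy(array $counts): float
{
    $n = array_sum($counts);
    $h = 0.0;

    foreach ($counts as $count) {
        if ($count > 0) {
            $p = $count / $n;
            $h -= $p * log($p);
        }
    }

    return $h;
}

/**
 * Conditional entropy H(A | B) from parallel arrays of A and B values.
 */
function conditional_entropy(array $a, array $b): float
{
    $n = count($a);
    $groups = [];

    // Group the A values by their B value.
    foreach ($b as $i => $value) {
        $groups[$value][] = $a[$i];
    }

    $h = 0.0;

    foreach ($groups as $members) {
        // Weight each group's entropy by the fraction of samples in it.
        $h += count($members) / $n * entropy(array_count_values($members));
    }

    return $h;
}

function v_measure(array $predictions, array $labels, float $beta = 1.0): array
{
    $hC = entropy(array_count_values($labels));
    $hK = entropy(array_count_values($predictions));

    // By convention a score is 1.0 when the corresponding entropy is zero.
    $homogeneity = $hC > 0 ? 1 - conditional_entropy($labels, $predictions) / $hC : 1.0;
    $completeness = $hK > 0 ? 1 - conditional_entropy($predictions, $labels) / $hK : 1.0;

    $denominator = $beta * $homogeneity + $completeness;
    $v = $denominator > 0 ? (1 + $beta) * $homogeneity * $completeness / $denominator : 0.0;

    return ['homogeneity' => $homogeneity, 'completeness' => $completeness, 'v measure' => $v];
}

print_r(v_measure([0, 0, 1, 1], ['a', 'a', 'b', 'b'])); // a perfect clustering
print_r(v_measure([0, 0, 0, 0], ['a', 'a', 'b', 'b'])); // one big cluster
```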

Reports

Cross validation reports give you a deeper sense of how well a particular model performs by providing fine-grained information. The generate() method on the Report Generator interface takes a set of predictions and their corresponding ground-truth labels and returns a Report object filled with useful statistics that can be printed directly to the terminal or saved to a file.

| Report | Usage |
|---|---|
| Confusion Matrix | Classification or Anomaly Detection |
| Contingency Table | Clustering |
| Error Analysis | Regression |
| Multiclass Breakdown | Classification or Anomaly Detection |

Generating a Report

To generate the report, pass the predictions made by an estimator and their ground-truth labels to the generate() method on the report generator instance.

use Rubix\ML\CrossValidation\Reports\ErrorAnalysis;

$predictions = $estimator->predict($testing);

$report = new ErrorAnalysis();

$results = $report->generate($predictions, $testing->labels());
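To get a feel for what an error analysis contains, here is a plain-PHP sketch that computes a handful of the error statistics by hand. The error_summary() helper is hypothetical and covers only a subset of the report's fields; it assumes error means label minus prediction and uses the population standard deviation, either of which may differ from the library's conventions.

```php
<?php

/**
 * A few error statistics in the spirit of an error analysis report.
 * Hypothetical helper: error = label - prediction (an assumed convention),
 * and the standard deviation is the population standard deviation.
 */
function error_summary(array $predictions, array $labels): array
{
    $errors = array_map(fn ($p, $y) => $y - $p, $predictions, $labels);
    $n = count($errors);

    $mean = array_sum($errors) / $n;
    $variance = array_sum(array_map(fn ($e) => ($e - $mean) ** 2, $errors)) / $n;

    // Sorting makes the min, median, and max easy to read off.
    sort($errors);
    $mid = intdiv($n, 2);
    $median = $n % 2 === 1 ? $errors[$mid] : ($errors[$mid - 1] + $errors[$mid]) / 2;

    return [
        'error mean' => $mean,
        'error standard deviation' => sqrt($variance),
        'error min' => $errors[0],
        'error median' => $median,
        'error max' => $errors[$n - 1],
        'cardinality' => $n,
    ];
}

echo json_encode(error_summary([2.0, 3.0, 5.0, 4.0], [3.0, 3.0, 4.0, 6.0]), JSON_PRETTY_PRINT), PHP_EOL;
```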

Printing a Report

The results of the report are returned in a Report object. Report objects implement the Stringable interface which means they can be cast to strings to output the human-readable form of the report.

echo $results;
{
    "mean absolute error": 0.8,
    "median absolute error": 1,
    "mean squared error": 1,
    "mean absolute percentage error": 14.02077497665733,
    "rms error": 1,
    "mean squared log error": 0.019107097505647368,
    "r squared": 0.9958930551562692,
    "error mean": -0.2,
    "error standard deviation": 0.9898464007663,
    "error skewness": -0.22963966338592326,
    "error kurtosis": -1.0520833333333324,
    "error min": -2,
    "error 25%": -1.0,
    "error median": 0.0,
    "error 75%": 0.75,
    "error max": 1,
    "cardinality": 10
}
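The string cast works because PHP calls an object's __toString() method whenever it is used as a string, and PHP 8's Stringable interface marks classes that provide one. The SimpleReport class below is a hypothetical stand-in, not Rubix ML's Report, that renders its attributes as pretty-printed JSON when echoed.

```php
<?php

// A minimal stand-in for a report object: casting it to a string yields
// pretty-printed JSON. Hypothetical class, not Rubix ML's Report.
final class SimpleReport implements Stringable
{
    public function __construct(private array $attributes)
    {
    }

    public function __toString(): string
    {
        // JSON_THROW_ON_ERROR guarantees a string (or an exception), never false.
        return json_encode($this->attributes, JSON_PRETTY_PRINT | JSON_THROW_ON_ERROR);
    }
}

$report = new SimpleReport(['mean absolute error' => 0.8, 'cardinality' => 10]);

echo $report, PHP_EOL;
```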

Accessing Report Attributes

You can access individual report attributes by treating the report object as an associative array.

$mae = $results['mean absolute error'];
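Array-style access on an object is what PHP's built-in ArrayAccess interface provides. This sketch shows one way a report-like object could support reads such as $results['mean absolute error']; the ReadOnlyReport class is hypothetical, and the choice to reject writes is an assumption made for illustration.

```php
<?php

// Array-style read access via PHP's built-in ArrayAccess interface.
// Hypothetical class; this sketch makes the report read-only.
final class ReadOnlyReport implements ArrayAccess
{
    public function __construct(private array $attributes)
    {
    }

    public function offsetExists(mixed $offset): bool
    {
        return array_key_exists($offset, $this->attributes);
    }

    public function offsetGet(mixed $offset): mixed
    {
        if (!array_key_exists($offset, $this->attributes)) {
            throw new OutOfBoundsException("Attribute '$offset' not found.");
        }

        return $this->attributes[$offset];
    }

    public function offsetSet(mixed $offset, mixed $value): void
    {
        throw new LogicException('Reports are read-only.');
    }

    public function offsetUnset(mixed $offset): void
    {
        throw new LogicException('Reports are read-only.');
    }
}

$results = new ReadOnlyReport(['mean absolute error' => 0.8]);

$mae = $results['mean absolute error'];
```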

Saving a Report

Report objects can be converted to JSON encodings, which can be persisted using a Persister object. To save a report, call the toJSON() method on the report to return an encoding object, and then pass a persister to the encoding's saveTo() method as in the example below.

use Rubix\ML\Persisters\Filesystem;

$results->toJSON()->saveTo(new Filesystem('error.report'));
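Outside of Rubix ML, the same idea of persisting a report as JSON can be sketched with built-in PHP functions alone. The save_report() helper and the temporary file location are hypothetical; in Rubix ML, the toJSON()->saveTo() call above plays this role. Reading the file back with json_decode() restores the original associative array.

```php
<?php

/**
 * Persist report attributes as JSON on the local filesystem.
 * Hypothetical helper using only built-in PHP functions.
 */
function save_report(array $attributes, string $path): void
{
    $json = json_encode($attributes, JSON_PRETTY_PRINT | JSON_THROW_ON_ERROR);

    // file_put_contents() returns false on failure rather than throwing.
    if (file_put_contents($path, $json) === false) {
        throw new RuntimeException("Could not write to $path.");
    }
}

$path = sys_get_temp_dir() . DIRECTORY_SEPARATOR . 'error.report';

save_report(['mean absolute error' => 0.8, 'cardinality' => 10], $path);

// Decoding with $associative = true returns arrays instead of objects.
$restored = json_decode(file_get_contents($path), true, 512, JSON_THROW_ON_ERROR);
```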

Validators

Metrics can be used stand-alone or they can be used within a Validator object as the scoring function. Validators automate the cross validation process by training and testing a learner on different subsets of a master dataset. The way in which subsets are chosen depends on the algorithm employed under the hood. Most validators implement the Parallel interface, which allows multiple tests to be run in parallel.

| Validator | Test Coverage | Parallel |
|---|---|---|
| Hold Out | Partial | No |
| K Fold | Full | Yes |
| Leave P Out | Full | Yes |
| Monte Carlo | Asymptotically Full | Yes |

For example, the K Fold validator partitions the dataset into k subsets referred to as folds. In each round it holds out one fold as the validation set and trains the learner on the remaining folds, repeating this k times so that every sample in the dataset is used for validation exactly once. The final score is the average of the k validation scores. To begin, pass an untrained Learner, a Labeled dataset, and your chosen validation metric to the validator's test() method.

use Rubix\ML\CrossValidation\KFold;
use Rubix\ML\CrossValidation\Metrics\FBeta;

$validator = new KFold(5);

$score = $validator->test($estimator, $dataset, new FBeta());

echo $score;
0.9175
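The K Fold procedure itself fits in a few lines of plain PHP. In this sketch, k_fold() is a hypothetical helper that takes a $fit callable (which trains on a subset and returns a prediction function) and a $metric callable. Folds are assigned by index modulo k without shuffling or stratification, which a real validator would typically add, and 2 ≤ k ≤ n is assumed. The toy learner always predicts the majority training label, and labels are assumed to be non-numeric strings so that array_count_values() keys keep their type.

```php
<?php

/**
 * Minimal k-fold cross validation over parallel sample/label arrays.
 * All names here are hypothetical; assumes 2 <= $k <= count($samples).
 */
function k_fold(array $samples, array $labels, int $k, callable $fit, callable $metric): float
{
    $indices = array_keys($samples);
    $scores = [];

    for ($fold = 0; $fold < $k; ++$fold) {
        // Every k-th sample (offset by the fold number) is held out for validation.
        $test = array_values(array_filter($indices, fn ($i) => $i % $k === $fold));
        $train = array_values(array_diff($indices, $test));

        // Train on the other k - 1 folds.
        $predict = $fit(
            array_map(fn ($i) => $samples[$i], $train),
            array_map(fn ($i) => $labels[$i], $train),
        );

        // Score the held-out fold.
        $predictions = array_map(fn ($i) => $predict($samples[$i]), $test);
        $scores[] = $metric($predictions, array_map(fn ($i) => $labels[$i], $test));
    }

    // The final score is the average of the k validation scores.
    return array_sum($scores) / $k;
}

// Toy learner: always predicts the most common label seen during training.
$fit = function (array $samples, array $labels): callable {
    $counts = array_count_values($labels);
    arsort($counts);
    $majority = array_key_first($counts);

    return fn (array $sample) => $majority;
};

// Accuracy: fraction of predictions that exactly match their labels.
$accuracy = fn (array $predictions, array $labels) => count(array_filter(
    array_map(fn ($p, $y) => $p === $y, $predictions, $labels)
)) / count($labels);

$samples = array_map(fn ($i) => [$i], range(0, 9));
$labels = ['a', 'a', 'a', 'a', 'b', 'a', 'a', 'a', 'a', 'b'];

$score = k_fold($samples, $labels, 5, $fit, $accuracy);

echo $score, PHP_EOL; // 0.8
```

Here the fold holding both 'b' samples scores 0 and the other four folds score 1, so the average is 0.8.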

Common Problems

Poor generalization performance can be explained by one or more of these common problems.

Underfitting

A poorly performing model can sometimes be explained as underfitting the training data, a condition in which the learner is unable to capture the underlying pattern or trend given the model's constraints. The result of an underfit model is an estimator with high bias error. Underfitting usually occurs when a simple model is chosen to represent data that is complex and non-linear. Adding more features to the dataset can help; however, if the problem is too severe, a more flexible learning algorithm can be chosen for the task instead.

Overfitting

When a model performs well on training data but poorly during cross validation, it could be that the model has overfit the training data. Overfitting occurs when the model conforms too closely to the training data and therefore fails to generalize well to new data or make predictions reliably. Flexible models are more prone to overfitting due to their ability to memorize individual samples. Most learners employ strategies such as regularization, early stopping, and/or tree pruning to control overfitting; however, if overfitting is still a problem, adding more unique samples to the dataset can also help.

Selection Bias

When a model performs well on certain samples but poorly on others, it could be that the learner was trained with a dataset that exhibits selection bias. Selection bias is the bias introduced when a population is disproportionately represented in a dataset. For example, if a learner is trained to classify pictures of cats and dogs but mostly (say 90%) cats are represented in the dataset, the model will likely have difficulty making real-world predictions where cats and dogs are more equally represented. To correct selection bias, either obtain more unique training samples or up-sample the underrepresented class.
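Random up-sampling duplicates samples of the smaller classes until every class is as large as the biggest one. The upsample() helper below is a hypothetical plain-PHP sketch of that one technique (other rebalancing methods exist). It assumes string labels and list-shaped arrays, and it should be applied to the training split only, so that duplicated samples do not leak into the validation set. With 9 cats and 1 dog, the single dog sample is duplicated 8 times.

```php
<?php

/**
 * Randomly up-sample every minority class (with replacement) until each class
 * has as many samples as the largest one. Hypothetical helper.
 */
function upsample(array $samples, array $labels): array
{
    // Group samples by class label.
    $byClass = [];
    foreach ($labels as $i => $label) {
        $byClass[$label][] = $samples[$i];
    }

    $target = max(array_map('count', $byClass));

    foreach ($byClass as $label => $members) {
        // Draw random duplicates until this class reaches the target size.
        for ($j = count($members); $j < $target; ++$j) {
            $samples[] = $members[array_rand($members)];
            $labels[] = $label;
        }
    }

    return [$samples, $labels];
}

// 90% cats, 10% dogs, as in the example above.
$labels = array_merge(array_fill(0, 9, 'cat'), ['dog']);
$samples = array_map(fn ($i) => [$i], range(0, 9));

[$samples, $labels] = upsample($samples, $labels);

echo json_encode(array_count_values($labels)), PHP_EOL; // {"cat":9,"dog":9}
```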


Last update: 2021-06-05