K-fold cross-validation is primarily used in applied machine learning to estimate the performance of a model on unseen data. It is a resampling procedure for evaluating machine learning models on limited data. In summary, the k-fold step splits the data set into k subsets and iterates over them, using one subset as the test set and the remaining k-1 subsets as the training set in each iteration. The figure below shows an example for k = 10.
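The splitting described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the step's actual implementation: the helper name `kfold_indices` is hypothetical, the folds are contiguous and unshuffled, and any leftover samples are spread over the first folds.

```python
from typing import Iterator, List, Tuple


def kfold_indices(n_samples: int, k: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, test_indices) for each of the k folds.

    Hypothetical helper for illustration; the real step may shuffle or
    stratify the data before splitting.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and the number of samples")
    # Spread samples as evenly as possible: the first n_samples % k folds
    # receive one extra element.
    base, extra = divmod(n_samples, k)
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        # The current fold is the test set ...
        test = list(range(start, start + size))
        # ... and the remaining k-1 folds together form the training set.
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size


folds = list(kfold_indices(n_samples=10, k=5))
for train, test in folds:
    print("train:", train, "test:", test)
```

Each sample lands in the test set exactly once, so every prediction in the final report comes from a model that did not see that sample during training.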
Parameters
k
: number of pieces (folds) into which the data set is split
output_path
: file path where the output of the k-fold validation is stored
output_field
: field in which the predictions are stored

The classifier and label fields are present only for inheritance and are not actually used.

classifier_params
: parameters of a lib.nlp classifier to be used during the k-fold validation
Example
{
  "step": "classifier",
  "type": "kfold",
  "k": 5,
  "output_path": "./output.json",
  "output_field": "prediction",
  "label_field": "class",
  "classifier": "none",
  "classifier_params": {
    "explanation_field": "explanation",
    "input_fields": [
      "normalized_extract"
    ],
    "label_field": "label",
    "model_kwargs": {
      "probability": true
    },
    "model_type": "SVC",
    "output_field": "prediction",
    "step": "classifier",
    "type": "sklearn"
  }
}
Output
The step outputs the success rate for each of the k groups (folds). In addition, it lists the overall metrics of the output (for multiclass classification, the precision, recall and f1-score are macro-averaged).
****** KFOLD VALIDATION OUTPUT *****
Group 0:{'successful predicted': 26, 'total': 27, 'ratio': 0.9629629629629629}
Group 1:{'successful predicted': 27, 'total': 27, 'ratio': 1.0}
Group 2:{'successful predicted': 27, 'total': 27, 'ratio': 1.0}
Group 3:{'successful predicted': 25, 'total': 27, 'ratio': 0.9259259259259259}
Group 4:{'successful predicted': 27, 'total': 27, 'ratio': 1.0}
Report saved into: ./output.json
**********************************
****** OVERALL METRICS *****
metrics: {'metrics': {'accuracy': 97.2, 'precision': 97.39999999999999, 'recall': 97.2, 'f1-score': 97.2}, 'confusion_matrix': {'labels': ['cat', 'not_cat'], 'values': [17, 0, 1, 18]}}
Report stored at: ./kfold_validation.json
**********************************
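To make the reported numbers easier to interpret, the following sketch shows one way the per-group ratio and the macro-averaged metrics could be computed. The function names `group_report` and `macro_metrics` are hypothetical and not part of lib.nlp, and the sketch returns fractions between 0 and 1, whereas the sample output above appears to report percentages.

```python
from typing import Dict, List


def group_report(y_true: List[str], y_pred: List[str]) -> Dict[str, float]:
    """Success rate of one fold, in the same shape as the 'Group N' lines."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return {
        "successful predicted": correct,
        "total": len(y_true),
        "ratio": correct / len(y_true),
    }


def macro_metrics(y_true: List[str], y_pred: List[str]) -> Dict[str, float]:
    """Accuracy plus macro-averaged precision, recall and f1-score.

    Macro averaging computes each metric per label and then takes the
    unweighted mean, so every class counts equally regardless of its size.
    """
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    pairs = list(zip(y_true, y_pred))
    labels = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        # Assumption: a metric with an empty denominator counts as 0.0.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(labels)
    return {
        "accuracy": sum(t == p for t, p in pairs) / len(pairs),
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1-score": sum(f1s) / n,
    }


# Made-up toy labels, not taken from the sample run above.
y_true = ["cat", "cat", "not_cat", "not_cat"]
y_pred = ["cat", "not_cat", "not_cat", "not_cat"]
report = group_report(y_true, y_pred)
metrics = macro_metrics(y_true, y_pred)
print(report)
print(metrics)
```

Because the macro average weights every class equally, a poorly predicted minority class lowers the overall precision, recall and f1-score more than it lowers the accuracy.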