dataeval.utils.dataset.split_dataset
- dataeval.utils.dataset.split_dataset(labels, num_folds=1, stratify=False, split_on=None, metadata=None, test_frac=0.0, val_frac=0.0)
Top-level splitting function. Returns a dataclass containing lists of train and validation indices. Indices for a test holdout may optionally be included.
- Parameters:
labels (list or NDArray of ints) – Classification labels used to generate splits. Also determines the size of the dataset.
num_folds (int, default 1) – Number of [train, val] folds. If equal to 1, val_frac must be greater than 0.0.
stratify (bool, default False) – If True, the dataset is split so that the class distribution of the entire dataset is preserved within each [train, val] partition. This is generally recommended.
split_on (list or None, default None) – Keys of the metadata dictionary on which to group the dataset. A grouped partition is divided so that no group appears in both the training and validation sets. Choose split_on keys that mitigate validation bias.
metadata (dict or None, default None) – Dictionary containing the data used for optional dataset grouping. See split_on above.
test_frac (float, default 0.0) – Fraction of the data to hold out as an optional test set.
val_frac (float, default 0.0) – Fraction of the training data to set aside for validation when a single [train, val] split is desired.
- Returns:
split_defs – Output class containing, for each fold, lists of training and validation indices, plus optional test indices.
- Return type:
- Raises:
TypeError – Raised if split_on is passed but metadata is None or empty.
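To illustrate what stratify=True means across num_folds folds, here is a minimal, self-contained sketch in plain Python. It is not dataeval's implementation: the helper name stratified_folds, its seed argument, and its (train, val) tuple output are hypothetical. It groups indices by class, shuffles each class, and deals the indices round-robin into folds, so every validation fold receives roughly the same share of each class. It assumes num_folds >= 2.

```python
import random
from collections import defaultdict


def stratified_folds(labels, num_folds, seed=0):
    """Hypothetical helper (not part of dataeval): return a list of
    (train_indices, val_indices) tuples, one per fold, with each class
    spread as evenly as possible across the validation folds."""
    rng = random.Random(seed)

    # Collect the sample indices belonging to each class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    fold_of = {}
    offset = 0
    for indices in by_class.values():
        rng.shuffle(indices)
        # Deal this class's indices round-robin across folds. Carrying the
        # offset over from the previous class keeps overall fold sizes even.
        for position, idx in enumerate(indices):
            fold_of[idx] = (offset + position) % num_folds
        offset += len(indices)

    splits = []
    for k in range(num_folds):
        # Fold k is the validation set; every other fold is training data.
        val = sorted(i for i, f in fold_of.items() if f == k)
        train = sorted(i for i, f in fold_of.items() if f != k)
        splits.append((train, val))
    return splits


labels = [0] * 6 + [1] * 3
for train, val in stratified_folds(labels, num_folds=3):
    print(len(train), [labels[i] for i in val])
```

With six samples of class 0 and three of class 1, each of the three validation folds receives two class-0 samples and one class-1 sample, preserving the dataset's 2:1 class ratio.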
Note
When groups and/or stratification are specified, the realized test and validation fractions can differ from test_frac and val_frac, because stratification and grouping take priority over the requested fractions.
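The note above can be seen in a small, self-contained sketch of a group-aware holdout. This is not how dataeval assigns groups. The helper grouped_holdout and its greedy "add whole groups until the target is reached" rule are hypothetical, chosen only to show why keeping every group on one side of the split can push the realized validation fraction away from val_frac.

```python
import random
from collections import defaultdict


def grouped_holdout(groups, val_frac, seed=0):
    """Hypothetical helper (not part of dataeval): move whole groups into
    validation until at least val_frac of the samples are held out.
    No group is ever split between training and validation."""
    rng = random.Random(seed)

    # Collect the sample indices belonging to each group.
    members = defaultdict(list)
    for idx, group in enumerate(groups):
        members[group].append(idx)

    group_ids = list(members)
    rng.shuffle(group_ids)

    target = val_frac * len(groups)
    train, val = [], []
    for gid in group_ids:
        # Whole groups go to one side only, so the validation set can
        # overshoot the target by up to the size of one group.
        if len(val) < target:
            val.extend(members[gid])
        else:
            train.extend(members[gid])
    return sorted(train), sorted(val)


# One large group ("a") holds 6 of the 10 samples.
groups = ["a"] * 6 + ["b"] * 2 + ["c"] * 2
train, val = grouped_holdout(groups, val_frac=0.2)
print(f"requested 0.20, realized {len(val) / len(groups):.2f}")
```

If group "a" happens to be drawn first, the realized validation fraction is 0.60 rather than the requested 0.20. Group integrity wins over the requested fraction, which is the trade-off the note describes.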