dataeval.utils.dataset.split_dataset

dataeval.utils.dataset.split_dataset(labels, num_folds=1, stratify=False, split_on=None, metadata=None, test_frac=0.0, val_frac=0.0)

Top-level splitting function. Returns a dataclass containing lists of training and validation indices, one [train, val] pair per fold. Optionally, indices for a test holdout set (see test_frac) are also included.

Parameters:
  • labels (list or NDArray of ints) – Classification labels used to generate the splits. The length of labels determines the size of the dataset.

  • num_folds (int, default 1) – Number of [train, val] folds. If equal to 1, val_frac must be greater than 0.0.

  • stratify (bool, default False) – If True, the dataset is split so that the class distribution of the entire dataset is preserved within each [train, val] partition. This is generally recommended.

  • split_on (list or None, default None) – Keys of the metadata dictionary on which to group the dataset. A grouped partition is divided so that no group appears in both the training and validation sets. Choose split_on groups that mitigate validation bias.

  • metadata (dict or None, default None) – Dictionary containing the data used for optional dataset grouping. See split_on above.

  • test_frac (float, default 0.0) – Fraction of the data to hold out as an optional test set.

  • val_frac (float, default 0.0) – Fraction of the training data to set aside for validation when a single [train, val] split is desired (num_folds equal to 1).
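To make the interplay of stratify, num_folds, test_frac and val_frac concrete, here is a minimal pure-Python sketch of a stratified split. It is not dataeval's implementation: the function name stratified_split, the seed argument, the dict return value, the per-class rounding, the round-robin fold assignment and the ValueError type are all assumptions made for illustration.

```python
import random
from collections import defaultdict


def stratified_split(labels, num_folds=1, test_frac=0.0, val_frac=0.0, seed=0):
    """Hypothetical sketch of stratified splitting; not dataeval's implementation.

    Returns {"test": [...], "folds": [(train, val), ...]} with sorted index lists.
    """
    # Mirrors the documented constraint; the ValueError type is an assumption.
    if num_folds == 1 and val_frac <= 0.0:
        raise ValueError("val_frac must be greater than 0.0 when num_folds == 1")
    rng = random.Random(seed)

    # Bucket indices by class so every class is split independently,
    # which preserves the overall class distribution in each partition.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    test, remaining = [], []
    fold_vals = [[] for _ in range(num_folds)]
    for idxs in by_class.values():
        rng.shuffle(idxs)
        n_test = round(len(idxs) * test_frac)
        test.extend(idxs[:n_test])  # per-class test holdout
        rest = idxs[n_test:]
        remaining.extend(rest)
        if num_folds == 1:
            # Single split: val_frac of this class's non-test indices go to val.
            fold_vals[0].extend(rest[: round(len(rest) * val_frac)])
        else:
            # k folds: deal this class's indices round-robin across the folds
            # (val_frac is ignored in this sketch when num_folds > 1).
            for i, idx in enumerate(rest):
                fold_vals[i % num_folds].append(idx)

    folds = []
    for val in fold_vals:
        val_set = set(val)
        train = sorted(i for i in remaining if i not in val_set)
        folds.append((train, sorted(val)))
    return {"test": sorted(test), "folds": folds}


labels = [0] * 80 + [1] * 20  # imbalanced toy labels (80/20)
out = stratified_split(labels, num_folds=5)
for train, val in out["folds"]:
    # prints "80 20 4" for every fold: each val set keeps the 80/20 class ratio
    print(len(train), len(val), sum(labels[i] for i in val))
```

Because rounding happens per class, the achieved sizes can drift slightly from the requested fractions: with the same 80/20 labels, test_frac=0.1 and val_frac=0.25, this sketch yields 10 test, 22 validation and 68 training indices.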

Returns:

split_defs – Output class containing, for each fold, lists of training and validation indices, plus optional test indices.

Return type:

SplitDatasetOutput

Raises:

TypeError – Raised if split_on is passed, but metadata is None or empty
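The documented argument constraints can be expressed as a small pre-flight check. This is a hypothetical helper, not part of dataeval: the name check_split_args, the metadata key "site", and the use of ValueError for the num_folds/val_frac constraint are assumptions; only the TypeError condition comes from the documentation above.

```python
def check_split_args(num_folds=1, split_on=None, metadata=None, val_frac=0.0):
    """Hypothetical pre-flight check mirroring the documented constraints."""
    # Documented: TypeError when split_on is passed but metadata is None or empty.
    if split_on is not None and not metadata:
        raise TypeError("split_on requires a non-empty metadata dict")
    # Documented constraint on num_folds == 1; the ValueError type is assumed.
    if num_folds == 1 and val_frac <= 0.0:
        raise ValueError("val_frac must be greater than 0.0 when num_folds == 1")


try:
    # "site" is a hypothetical metadata key used only for illustration.
    check_split_args(split_on=["site"], metadata={}, val_frac=0.2)
except TypeError as err:
    print(f"TypeError: {err}")  # prints "TypeError: split_on requires a non-empty metadata dict"
```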

Note

When grouping and/or stratification are specified, the resulting test and validation fractions may differ from test_frac and val_frac, because stratification and grouping take priority over the requested percentages.
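A minimal sketch of why grouping can override the requested fractions, assuming a simple greedy strategy that moves whole groups into validation until the target size is reached. The group_split function, its greedy strategy and the toy group IDs are hypothetical; dataeval's actual grouping logic may differ.

```python
import random
from collections import defaultdict


def group_split(groups, val_frac, seed=0):
    """Hypothetical sketch: assign whole groups to validation until val_frac is reached.

    No group appears in both train and val, so the achieved fraction can differ
    from the requested one when groups are large or uneven.
    """
    # Collect sample indices per group.
    by_group = defaultdict(list)
    for idx, g in enumerate(groups):
        by_group[g].append(idx)
    keys = list(by_group)
    random.Random(seed).shuffle(keys)

    target = val_frac * len(groups)
    val = []
    for key in keys:
        if len(val) >= target:
            break
        val.extend(by_group[key])  # the entire group moves together
    val_set = set(val)
    train = [i for i in range(len(groups)) if i not in val_set]
    return train, sorted(val)


groups = ["a"] * 50 + ["b"] * 30 + ["c"] * 20  # hypothetical per-sample group IDs
train, val = group_split(groups, val_frac=0.1)
print(len(val) / len(groups))  # 0.2, 0.3 or 0.5 depending on which group is drawn first
```

Whichever group is drawn first, the achieved validation fraction is 0.2, 0.3 or 0.5 rather than the requested 0.1, because splitting a group across partitions is not allowed.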