dataeval.workflows.Sufficiency
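class dataeval.workflows.Sufficiency(model, train_ds, test_ds, train_fn, eval_fn, runs=1, substeps=5, train_kwargs=None, eval_kwargs=None)
Project dataset sufficiency using a given model and evaluation criteria. The model is trained on progressively larger subsets of the training data and evaluated on the test data after each substep, so that performance can be examined as a function of training-set size.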
- Parameters:
- model : nn.Module
Model that will be trained on each subset of the data.
- train_ds : torch.Dataset
Full training dataset that will be split into subsets for each run.
- test_ds : torch.Dataset
Dataset used to evaluate the model in every run.
- train_fn : Callable[[nn.Module, Dataset, Sequence[int]], None]
Function that takes a model (torch.nn.Module), a dataset (torch.utils.data.Dataset), and the indices to train on, and trains the model on that subset of the data.
- eval_fn : Callable[[nn.Module, Dataset], Mapping[str, float | ArrayLike]]
Function that takes a model (torch.nn.Module) and a dataset (torch.utils.data.Dataset) and returns a mapping of metric names to values (Mapping[str, float | ArrayLike]) used to assess the model's performance on that data.
- runs : int, default 1
Number of models to train over all subsets.
- substeps : int, default 5
Total number of dataset partitions that each model will train on.
- train_kwargs : Mapping | None, default None
Additional arguments required by a custom training function.
- eval_kwargs : Mapping | None, default None
Additional arguments required by a custom evaluation function.
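The callback contract for train_fn and eval_fn can be illustrated with a minimal sketch. It assumes a classification model that outputs logits and a dataset that yields (input, label) pairs; the optimizer, learning rate, batch sizes, epoch count, and the "accuracy" metric name are placeholders chosen for illustration, not values required by Sufficiency.

```python
from typing import Dict, Sequence

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset


def train_fn(model: nn.Module, dataset: Dataset, indices: Sequence[int]) -> None:
    """Train `model` in place on the samples of `dataset` selected by `indices`."""
    # Restrict training to the indices chosen for the current substep.
    loader = DataLoader(Subset(dataset, list(indices)), batch_size=16, shuffle=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed optimizer and learning rate
    loss_fn = nn.CrossEntropyLoss()  # assumes a classification task with logit outputs
    model.train()
    for _ in range(5):  # assumed number of epochs
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()


def eval_fn(model: nn.Module, dataset: Dataset) -> Dict[str, float]:
    """Return a mapping of metric name to value for `model` evaluated on `dataset`."""
    loader = DataLoader(dataset, batch_size=64)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            preds = model(x).argmax(dim=1)  # predicted class per sample
            correct += (preds == y).sum().item()
            total += y.numel()
    # "accuracy" is an illustrative metric name; any string keys are allowed.
    return {"accuracy": correct / total if total else float("nan")}
```

These functions would then be passed to the constructor as train_fn=train_fn and eval_fn=eval_fn. Each key returned by eval_fn presumably becomes one of the measures that evaluate() averages per substep.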
evaluate(eval_at=None)
Creates data indices, trains models, and returns plotting data.
- Parameters:
- eval_at : int | Iterable[int] | None, default None
Specific dataset lengths at which to collect measures. If None, Sufficiency generates the lengths to evaluate at internally.
- Returns:
Dataclass containing the average of each measure per substep.
- Return type:
SufficiencyOutput
- Raises:
ValueError – If eval_at is not numerical.
Examples
>>> suff = Sufficiency(
...     model=model,
...     train_ds=train_ds,
...     test_ds=test_ds,
...     train_fn=train_fn,
...     eval_fn=eval_fn,
...     runs=3,
...     substeps=5,
... )
>>> suff.evaluate()
SufficiencyOutput(steps=array([ 1, 3, 10, 31, 100], dtype=uint32), measures={'test': array([1., 1., 1., 1., 1.])}, n_iter=1000)
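The steps in the example output (1, 3, 10, 31, 100) are consistent with geometrically spaced dataset lengths truncated to unsigned integers, assuming a training set of 100 samples and substeps=5. The hypothetical helper substep_lengths below sketches that schedule together with the eval_at override and the ValueError for non-numerical input; it is an assumption about the behavior, not Sufficiency's actual implementation.

```python
from typing import Iterable, Optional, Union

import numpy as np


def substep_lengths(
    n_train: int,
    substeps: int,
    eval_at: Optional[Union[int, Iterable[int]]] = None,
) -> np.ndarray:
    """Hypothetical helper: dataset lengths at which models would be trained and evaluated."""
    if eval_at is not None:
        # User-supplied lengths: accept a single value or an iterable of values.
        lengths = np.array([eval_at] if np.isscalar(eval_at) else list(eval_at))
        if not np.issubdtype(lengths.dtype, np.number):
            raise ValueError("eval_at must contain only numerical values")
        return lengths.astype(np.uint32)
    # Assumed default: geometrically spaced lengths from 1 to n_train,
    # truncated (not rounded) to integers, e.g. 31.62 -> 31.
    return np.geomspace(1, n_train, substeps).astype(np.uint32)
```

Truncation rather than rounding is what reproduces 3 and 31 in the example output; rounding would give 32 for the fourth step.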