dataeval.workflows.Sufficiency

class dataeval.workflows.Sufficiency(model, train_ds, test_ds, train_fn, eval_fn, runs=1, substeps=5, train_kwargs=None, eval_kwargs=None)

Project dataset sufficiency using a given model and evaluation criteria.

Parameters:
model : nn.Module

Model that will be trained for each subset of data

train_ds : torch.utils.data.Dataset

Full training data that will be split for each run

test_ds : torch.utils.data.Dataset

Data that will be used for every run’s evaluation

train_fn : Callable[[nn.Module, Dataset, Sequence[int]], None]

Function that takes a model (torch.nn.Module), a dataset (torch.utils.data.Dataset), and the indices to train on, and trains the model on the data at those indices.

eval_fn : Callable[[nn.Module, Dataset], Mapping[str, float | ArrayLike]]

Function that takes a model (torch.nn.Module) and a dataset (torch.utils.data.Dataset) and returns a mapping of metric names to values (Mapping[str, float | ArrayLike]) used to assess the model's performance on that data.
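The call contracts for train_fn and eval_fn can be illustrated with a minimal, dependency-free sketch. To keep it runnable without PyTorch, the hypothetical MajorityClassModel stands in for an nn.Module and a list of (input, label) pairs stands in for a torch.utils.data.Dataset; a real train_fn would typically wrap the indices in a Subset and run an optimizer loop.

```python
from collections import Counter
from typing import Mapping, Sequence


class MajorityClassModel:
    """Hypothetical stand-in for an nn.Module: predicts the most common training label."""

    def __init__(self) -> None:
        self.label = None

    def __call__(self, x):
        # Ignores the input and always predicts the learned label
        return self.label


def train_fn(model: MajorityClassModel, dataset: Sequence[tuple], indices: Sequence[int]) -> None:
    # Train only on the items selected by `indices`; the return value is unused
    labels = [dataset[i][1] for i in indices]
    model.label = Counter(labels).most_common(1)[0][0]


def eval_fn(model: MajorityClassModel, dataset: Sequence[tuple]) -> Mapping[str, float]:
    # Return a mapping from metric name to metric value
    correct = sum(model(x) == y for x, y in dataset)
    return {"accuracy": correct / len(dataset)}


train = [(0, "a"), (1, "a"), (2, "b")]
test = [(3, "a"), (4, "b")]
model = MajorityClassModel()
train_fn(model, train, [0, 1])
print(eval_fn(model, test))  # {'accuracy': 0.5}
```

Only the argument order and return types matter to the workflow; the toy "training" here is deliberately trivial.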

runs : int, default 1

Number of models to train over all subsets

substeps : int, default 5

Total number of dataset partitions that each model will train on
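The example output at the end of this page reports steps of 1, 3, 10, 31, and 100 for substeps=5, which is consistent with subset sizes spaced geometrically between 1 and the training set length (assumed here to be 100 items). The hypothetical geometric_steps helper below reproduces that spacing; it is inferred from the output and is not necessarily how Sufficiency computes its steps.

```python
import math


def geometric_steps(n: int, substeps: int) -> list:
    """Hypothetical helper: `substeps` subset sizes spaced geometrically from 1 to n."""
    if substeps == 1:
        return [n]
    stop = math.log10(n)
    # Evenly spaced exponents in log10 space, truncated to integers (like a uint32 cast)
    steps = [int(10 ** (stop * i / (substeps - 1))) for i in range(substeps)]
    steps[-1] = n  # guard against floating-point round-off at the endpoint
    return steps


print(geometric_steps(100, 5))  # [1, 3, 10, 31, 100]
```

Geometric spacing places more evaluation points at small subset sizes, where performance usually changes fastest.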

train_kwargs : Mapping | None, default None

Additional arguments required by the custom training function

eval_kwargs : Mapping | None, default None

Additional arguments required by the custom evaluation function
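Assuming the kwargs mappings are unpacked as keyword arguments when the custom functions are called (an assumption; check the installed version), a train_fn that needs extra settings declares them as keyword parameters. The RecordingModel class, the call_train_fn helper, and the epochs and lr parameters below are all hypothetical.

```python
from typing import Any, Callable, Mapping, Optional, Sequence


class RecordingModel:
    """Hypothetical stand-in for an nn.Module that records each training call."""

    def __init__(self) -> None:
        self.history: list = []


def train_fn(model: RecordingModel, dataset: Any, indices: Sequence[int],
             epochs: int = 1, lr: float = 0.01) -> None:
    # Custom training function with extra keyword parameters (epochs, lr)
    model.history.append((len(indices), epochs, lr))


def call_train_fn(fn: Callable[..., None], model: Any, dataset: Any,
                  indices: Sequence[int],
                  train_kwargs: Optional[Mapping[str, Any]] = None) -> None:
    # None means "no extra arguments"; otherwise unpack the mapping as keywords
    fn(model, dataset, indices, **(train_kwargs or {}))


model = RecordingModel()
call_train_fn(train_fn, model, None, [0, 1, 2], train_kwargs={"epochs": 5})
print(model.history)  # [(3, 5, 0.01)]
```

Giving every extra parameter a default keeps the function usable when train_kwargs is None.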

evaluate(eval_at=None)

Creates data indices, trains models, and returns plotting data

Parameters:
eval_at : int | Iterable[int] | None, default None

Specify this to collect measures at a specific set of dataset lengths, rather than letting Sufficiency choose the lengths to evaluate at internally.
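The documented eval_at contract (None for automatically chosen lengths, a single length, or an iterable of lengths, with ValueError for non-numerical input) can be sketched with a hypothetical normalize_eval_at helper. It mirrors the behavior described here, not Sufficiency's internal code.

```python
from numbers import Real
from typing import Iterable, List, Optional, Union


def normalize_eval_at(eval_at: Union[int, Iterable[int], None]) -> Optional[List[int]]:
    """Hypothetical helper: coerce eval_at to a list of dataset lengths.

    Returns None when eval_at is None, meaning the lengths are chosen automatically.
    """
    if eval_at is None:
        return None
    if isinstance(eval_at, Real):
        return [int(eval_at)]  # a single dataset length
    try:
        values = list(eval_at)
    except TypeError:
        raise ValueError("eval_at must be numerical") from None
    if not all(isinstance(v, Real) for v in values):
        raise ValueError("eval_at must be numerical")
    return [int(v) for v in values]


print(normalize_eval_at([10, 50, 100]))  # [10, 50, 100]
```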

Returns:

Dataclass containing the average of each measure per substep
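"Average of each measure per substep" means that, with runs greater than 1, each measure is averaged across runs at every subset size. A minimal sketch, assuming each run yields one scalar float per substep for every measure name (the signature also allows array-valued metrics, which this hypothetical average_runs helper does not handle):

```python
from statistics import fmean
from typing import Dict, List, Sequence


def average_runs(runs: Sequence[Dict[str, Sequence[float]]]) -> Dict[str, List[float]]:
    """Hypothetical helper: average each measure per substep across runs.

    Each element of `runs` maps a measure name to one value per substep.
    """
    names = runs[0].keys()
    return {
        # zip(*...) groups the values of all runs at the same substep
        name: [fmean(values) for values in zip(*(run[name] for run in runs))]
        for name in names
    }


print(average_runs([{"accuracy": [0.5, 1.0]}, {"accuracy": [0.25, 0.5]}]))
# {'accuracy': [0.375, 0.75]}
```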

Return type:

SufficiencyOutput

Raises:

ValueError – If eval_at is not numerical

Examples

>>> from dataeval.workflows import Sufficiency
>>> suff = Sufficiency(
...     model=model,
...     train_ds=train_ds,
...     test_ds=test_ds,
...     train_fn=train_fn,
...     eval_fn=eval_fn,
...     runs=3,
...     substeps=5,
... )
>>> suff.evaluate()
SufficiencyOutput(steps=array([  1,   3,  10,  31, 100], dtype=uint32), measures={'test': array([1., 1., 1., 1., 1.])}, n_iter=1000)