dataeval.workflows.Sufficiency
==============================

.. py:class:: dataeval.workflows.Sufficiency(model, train_ds, test_ds, train_fn, eval_fn, runs = 1, substeps = 5, train_kwargs = None, eval_kwargs = None)

   Project dataset :term:`sufficiency` using a given model and evaluation criteria

   :param model: Model that will be trained for each subset of data
   :type model: nn.Module
   :param train_ds: Full training data that will be split for each run
   :type train_ds: torch.utils.data.Dataset
   :param test_ds: Data that will be used for every run's evaluation
   :type test_ds: torch.utils.data.Dataset
   :param train_fn: Function which takes a model (torch.nn.Module), a dataset
                    (torch.utils.data.Dataset), and the indices to train on, and executes model
                    training against the selected data.
   :type train_fn: Callable[[nn.Module, Dataset, Sequence[int]], None]
   :param eval_fn: Function which takes a model (torch.nn.Module) and a dataset
                   (torch.utils.data.Dataset) and returns a dictionary of metric values
                   (Mapping[str, float | ArrayLike]) used to assess model performance on
                   that data.
   :type eval_fn: Callable[[nn.Module, Dataset], Mapping[str, float | ArrayLike]]
   :param runs: Number of models to train over all subsets
   :type runs: int, default 1
   :param substeps: Total number of dataset partitions that each model will train on
   :type substeps: int, default 5
   :param train_kwargs: Additional arguments required for custom training function
   :type train_kwargs: Mapping | None, default None
   :param eval_kwargs: Additional arguments required for custom evaluation function
   :type eval_kwargs: Mapping | None, default None

   .. py:method:: evaluate(eval_at = None, niter = 1000)

      Creates data indices, trains models, and returns plotting data

      :param eval_at: Specify this to collect accuracies over a specific set of dataset lengths, rather
                      than letting :term:`sufficiency` internally create the lengths to evaluate at.
      :type eval_at: int | Iterable[int] | None, default None
      :param niter: Iterations to perform when using the basin-hopping method to curve-fit measure(s).
      :type niter: int, default 1000

      :returns: Dataclass containing the average of each measure per substep
      :rtype: SufficiencyOutput

      :raises ValueError: If `eval_at` is not numerical

      .. rubric:: Examples

      >>> suff = Sufficiency(
      ...     model=model,
      ...     train_ds=train_ds,
      ...     test_ds=test_ds,
      ...     train_fn=train_fn,
      ...     eval_fn=eval_fn,
      ...     runs=3,
      ...     substeps=5,
      ... )
      >>> suff.evaluate()
      SufficiencyOutput(steps=array([ 1, 3, 10, 31, 100], dtype=uint32), params={'test': array([ 0., 42., 0.])}, measures={'test': array([1., 1., 1., 1., 1.])})
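
The example above assumes that ``train_fn`` and ``eval_fn`` already exist. The following is a minimal sketch of one way to write callables matching the documented signatures, for a simple classification model. The toy data, the ``nn.Linear`` model, the batch size, the learning rate, and the ``"accuracy"`` metric key are all illustrative assumptions rather than anything the library requires.

```python
from typing import Mapping, Sequence

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset


def train_fn(model: nn.Module, dataset: Dataset, indices: Sequence[int]) -> None:
    """Train `model` in place on the samples of `dataset` selected by `indices`."""
    # Restrict training to the requested subset of the full training data.
    loader = DataLoader(Subset(dataset, list(indices)), batch_size=16, shuffle=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # assumed hyperparameters
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for inputs, labels in loader:  # a single epoch keeps the sketch short
        optimizer.zero_grad()
        loss_fn(model(inputs), labels).backward()
        optimizer.step()


def eval_fn(model: nn.Module, dataset: Dataset) -> Mapping[str, float]:
    """Return classification accuracy on `dataset` under the key 'accuracy'."""
    loader = DataLoader(dataset, batch_size=16)
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            predictions = model(inputs).argmax(dim=1)
            correct += int((predictions == labels).sum())
            total += int(labels.numel())
    return {"accuracy": correct / total}


# Hypothetical toy data: 100 training and 20 test samples, 4 features, 2 classes.
torch.manual_seed(0)
train_ds = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
test_ds = TensorDataset(torch.randn(20, 4), torch.randint(0, 2, (20,)))
model = nn.Linear(4, 2)

# Exercise both callables directly, outside of Sufficiency.
train_fn(model, train_ds, range(10))
metrics = eval_fn(model, test_ds)
```

Callables written this way can then be passed to ``Sufficiency`` as ``train_fn`` and ``eval_fn``, as in the example above. Any extra arguments they need would presumably be supplied through ``train_kwargs`` and ``eval_kwargs``.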