Sufficiency

Tutorials

Check out this tutorial to begin using the Sufficiency class

Sufficiency Tutorial

How To Guides

  1. Sufficiency How To

DAML API

class daml.metrics.Sufficiency(model: Module, train_ds: Dataset, test_ds: Dataset, train_fn: Callable[[Module, Dataset, Sequence[int]], None], eval_fn: Callable[[Module, Dataset], Dict[str, float] | Dict[str, ndarray]], runs: int = 1, substeps: int = 5, train_kwargs: Dict[str, Any] | None = None, eval_kwargs: Dict[str, Any] | None = None)

Project dataset sufficiency using a given model and evaluation criteria

Parameters:
  • model (nn.Module) – Model that will be trained for each subset of data

  • train_ds (Dataset) – Full training data that will be split for each run

  • test_ds (Dataset) – Data that will be used for every run’s evaluation

  • train_fn (Callable[[nn.Module, Dataset, Sequence[int]], None]) – Function that takes a model (torch.nn.Module), a dataset (torch.utils.data.Dataset), and the indices to train on, then trains the model on those samples.

  • eval_fn (Callable[[nn.Module, Dataset], Dict[str, float] | Dict[str, np.ndarray]]) – Function that takes a model (torch.nn.Module) and a dataset (torch.utils.data.Dataset), and returns a dictionary of metric values (Dict[str, float] or Dict[str, np.ndarray]) used to assess the model's performance on that data.

  • runs (int, default 1) – Number of models to train and evaluate over all subsets

  • substeps (int, default 5) – Total number of dataset partitions that each model will train on

  • train_kwargs (Dict[str, Any] | None, default None) – Additional arguments required for custom training function

  • eval_kwargs (Dict[str, Any] | None, default None) – Additional arguments required for custom evaluation function
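
The callback shapes described above can be sketched as follows. To keep the sketch dependency-free, a hypothetical MeanModel stands in for a torch.nn.Module and a list of (feature, target) pairs stands in for a torch.utils.data.Dataset. Only the call signatures mirror the documented ones; everything else is an assumption made for illustration.

```python
from typing import Dict, Sequence, Tuple

Sample = Tuple[float, float]  # (feature, target); stand-in for a Dataset item


class MeanModel:
    """Hypothetical stand-in for torch.nn.Module: predicts the mean target seen in training."""

    def __init__(self) -> None:
        self.mean = 0.0

    def __call__(self, x: float) -> float:
        return self.mean


def train_fn(model: MeanModel, dataset: Sequence[Sample], indices: Sequence[int]) -> None:
    """Train the model in place, using only the samples selected by `indices`."""
    targets = [dataset[i][1] for i in indices]
    model.mean = sum(targets) / len(targets)


def eval_fn(model: MeanModel, dataset: Sequence[Sample]) -> Dict[str, float]:
    """Return a dictionary mapping each metric name to its value on `dataset`."""
    errors = [abs(model(x) - y) for x, y in dataset]
    return {"mae": sum(errors) / len(errors)}


# Illustrative data only.
train_ds = [(0.0, 1.0), (1.0, 2.0), (2.0, 3.0), (3.0, 4.0)]
test_ds = [(4.0, 2.0), (5.0, 3.0)]

model = MeanModel()
train_fn(model, train_ds, indices=[0, 1])  # train on a two-sample subset
metrics = eval_fn(model, test_ds)          # evaluate on the full test set
```

With real PyTorch objects, functions of this shape would be passed as train_fn and eval_fn, for example Sufficiency(model, train_ds, test_ds, train_fn, eval_fn).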

evaluate() → Dict[str, ndarray]

Creates data indices, trains models, and returns plotting data

Returns:

Dictionary containing the average of each measure per substep

Return type:

Dict[str, np.ndarray]
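
To illustrate what "the average of each measure per substep" means, here is a small sketch. It assumes that every run produces one metric value per substep; the variable names and numbers are invented, and plain lists stand in for np.ndarray.

```python
from typing import Dict, List

# Hypothetical raw results: one list per run, one accuracy value per substep.
runs_accuracy: List[List[float]] = [
    [0.50, 0.75, 0.875],  # run 1
    [0.75, 0.75, 1.0],    # run 2
]


def average_per_substep(per_run: List[List[float]]) -> List[float]:
    """Average a metric across runs, keeping one value per substep."""
    n_runs = len(per_run)
    # zip(*per_run) groups the values belonging to the same substep.
    return [sum(values) / n_runs for values in zip(*per_run)]


averaged: Dict[str, List[float]] = {"accuracy": average_per_substep(runs_accuracy)}
```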

classmethod plot(data: Dict[str, ndarray], class_names: Sequence[str] | None = None) → List[Figure]

Plotting function for data sufficiency tasks

Parameters:

data (Dict[str, np.ndarray]) – Dictionary containing the average of each measure per substep

Returns:

List of Figures for each measure

Return type:

List[plt.Figure]

Raises:
  • KeyError – If STEPS_KEY or a measure is not a valid key

  • ValueError – If the measures do not all contain the same number of data points
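
The error conditions above suggest the layout that data is expected to have: one entry for the steps plus one entry per measure, all of equal length. The check below is a sketch of that idea only, not the library's code. The value of STEPS_KEY is a placeholder and the validate helper is hypothetical.

```python
from typing import Dict, Sequence

STEPS_KEY = "steps"  # placeholder; the real value is whatever the library's STEPS_KEY holds


def validate(data: Dict[str, Sequence[float]]) -> None:
    """Mirror the documented error conditions on a plain dictionary."""
    if STEPS_KEY not in data:
        raise KeyError(f"{STEPS_KEY} not found in data")
    n_steps = len(data[STEPS_KEY])
    for name, values in data.items():
        # Every measure needs exactly one data point per step.
        if len(values) != n_steps:
            raise ValueError(f"{name} has {len(values)} points, expected {n_steps}")


# A well-formed dictionary passes silently.
validate({STEPS_KEY: [10, 100, 1000], "accuracy": [0.6, 0.8, 0.9]})
```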

classmethod project(data: Dict[str, ndarray], projection: int | Sequence[int] | ndarray) → Dict[str, ndarray]

Projects each measure to the step value or values given in projection

Parameters:
  • data (Dict[str, np.ndarray]) – Dictionary containing the average of each measure per substep

  • projection (int | Sequence[int] | np.ndarray) – Step or steps to project

Raises:
  • KeyError – If STEPS_KEY or a measure is not a valid key

  • ValueError – If the measures do not all contain the same number of data points

  • ValueError – If projection is not an int, Sequence[int], or ndarray
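
To make "projecting" concrete, the sketch below extrapolates a measured error curve to larger training-set sizes. This page does not say which curve the library fits. The sketch assumes, purely for illustration, an inverse power law error = a * steps**(-b) fitted by least squares in log-log space. The helper names fit_power_law and project_errors are hypothetical and not part of the DAML API.

```python
import math
from typing import List, Sequence, Tuple


def fit_power_law(steps: Sequence[float], errors: Sequence[float]) -> Tuple[float, float]:
    """Least-squares fit of error = a * steps**(-b), done as a line fit in log-log space."""
    xs = [math.log(s) for s in steps]
    ys = [math.log(e) for e in errors]
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    slope = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys)) / sum(
        (x - x_mean) ** 2 for x in xs
    )
    intercept = y_mean - slope * x_mean
    return math.exp(intercept), -slope  # (a, b)


def project_errors(
    steps: Sequence[float], errors: Sequence[float], projection: Sequence[float]
) -> List[float]:
    """Evaluate the fitted curve at each requested step count."""
    a, b = fit_power_law(steps, errors)
    return [a * p ** (-b) for p in projection]


# Illustrative numbers: the error halves each time the training set grows tenfold.
measured_steps = [10, 100, 1000]
measured_errors = [0.4, 0.2, 0.1]
projected = project_errors(measured_steps, measured_errors, projection=[10_000])
```

Under this assumed model, the curve continues the measured trend: the projected error at 10,000 samples is about half the error measured at 1,000.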