Sufficiency
Tutorials
Check out this tutorial to begin using the Sufficiency class
How To Guides
DAML API
- class daml.workflows.Sufficiency(model: Module, train_ds: Dataset, test_ds: Dataset, train_fn: Callable[[Module, Dataset, Sequence[int]], None], eval_fn: Callable[[Module, Dataset], Dict[str, float] | Dict[str, ndarray]], runs: int = 1, substeps: int = 5, train_kwargs: Dict[str, Any] | None = None, eval_kwargs: Dict[str, Any] | None = None)
Project dataset sufficiency given a model and evaluation criteria
- Parameters:
model (nn.Module) – Model that will be trained for each subset of data
train_ds (Dataset) – Full training data that will be split for each run
test_ds (Dataset) – Data that will be used for every run’s evaluation
train_fn (Callable[[nn.Module, Dataset, Sequence[int]], None]) – Function which takes a model (torch.nn.Module), a dataset (torch.utils.data.Dataset), and the indices to train on, and trains the model on the selected samples.
eval_fn (Callable[[nn.Module, Dataset], Dict[str, float] | Dict[str, np.ndarray]]) – Function which takes a model (torch.nn.Module) and a dataset (torch.utils.data.Dataset) and returns a dictionary of metric values (Dict[str, float] or Dict[str, np.ndarray]) used to assess model performance on that data.
runs (int, default 1) – Number of models to run over all subsets
substeps (int, default 5) – Total number of dataset partitions that each model will train on
train_kwargs (Dict[str, Any] | None, default None) – Additional arguments required for custom training function
eval_kwargs (Dict[str, Any] | None, default None) – Additional arguments required for custom evaluation function
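The callback contract described above can be sketched without any DAML or PyTorch dependency. The model class and list-based datasets below are hypothetical stand-ins chosen so the example runs anywhere; in real use the model would be a torch.nn.Module and the datasets torch.utils.data.Dataset objects. Only the shapes of train_fn and eval_fn are taken from the signature above.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical stand-ins so the sketch runs without torch installed.
# In real use, `model` is a torch.nn.Module and the datasets are
# torch.utils.data.Dataset instances.
Sample = Tuple[float, int]  # (feature, label)


class ThresholdModel:
    """Toy classifier: predicts 1 when the feature exceeds a learned threshold."""

    def __init__(self) -> None:
        self.threshold = 0.0

    def predict(self, x: float) -> int:
        return int(x > self.threshold)


def train_fn(model: ThresholdModel, dataset: Sequence[Sample], indices: Sequence[int]) -> None:
    """Train in place on the samples selected by `indices`; returns None."""
    subset = [dataset[i] for i in indices]
    zeros = [x for x, y in subset if y == 0]
    ones = [x for x, y in subset if y == 1]
    # Place the threshold midway between the two class means.
    model.threshold = (sum(zeros) / len(zeros) + sum(ones) / len(ones)) / 2


def eval_fn(model: ThresholdModel, dataset: Sequence[Sample]) -> Dict[str, float]:
    """Return a dictionary mapping metric names to scores."""
    correct = sum(model.predict(x) == y for x, y in dataset)
    return {"accuracy": correct / len(dataset)}


train_ds: List[Sample] = [(0.1, 0), (0.2, 0), (0.3, 0), (0.7, 1), (0.8, 1), (0.9, 1)]
test_ds: List[Sample] = [(0.05, 0), (0.4, 0), (0.6, 1), (0.95, 1)]

model = ThresholdModel()
train_fn(model, train_ds, indices=[0, 1, 4, 5])  # train on a subset only
metrics = eval_fn(model, test_ds)
print(metrics)
```

The key points are that train_fn mutates the model and returns nothing, while eval_fn returns one entry per metric name; those metric names are the keys that later appear in the sufficiency results.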
- evaluate(eval_at: ndarray | None = None, niter: int = 1000) → Dict[str, ndarray | Dict[str, ndarray]]
Creates data indices, trains models, and returns plotting data
- Parameters:
eval_at (Optional[np.ndarray]) – Specify this to collect metrics at a specific set of dataset lengths, rather than letting Sufficiency internally create the lengths to evaluate at.
niter (int, default 1000) – Iterations to perform when using the basin-hopping method to curve-fit measure(s).
- Returns:
Dictionary containing the average of each measure per substep
- Return type:
Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]
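When supplying eval_at explicitly, a schedule of training-set lengths has to be chosen. This reference does not state how the default lengths are generated, so the sketch below is only an assumption: it uses geometric spacing, a common choice for learning curves because metrics tend to change fastest at small sample counts. The helper name geometric_lengths and its start_fraction parameter are hypothetical and not part of DAML.

```python
from typing import List


def geometric_lengths(total: int, substeps: int, start_fraction: float = 0.01) -> List[int]:
    """Return `substeps` increasing subset sizes, geometrically spaced up to `total`.

    Assumption: geometric spacing starting at `start_fraction * total`.
    This is an illustrative schedule, not DAML's documented behaviour.
    """
    if substeps < 2:
        return [total]
    first = max(1.0, start_fraction * total)
    ratio = (total / first) ** (1 / (substeps - 1))
    lengths = [int(round(first * ratio**k)) for k in range(substeps)]
    lengths[-1] = total  # guard against rounding drift on the final step
    return lengths


lengths = geometric_lengths(total=10_000, substeps=5)
print(lengths)
```

Because evaluate is typed to accept an ndarray, a list such as this one would need to be wrapped with numpy.array before being passed as eval_at.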
- classmethod inv_project(targets: Dict[str, ndarray], data: Dict[str, ndarray | Dict[str, ndarray]]) → Dict[str, ndarray]
Calculate training samples needed to achieve target model metric values.
- Parameters:
targets (Dict[str, np.ndarray]) – Dictionary of target metric scores (from 0.0 to 1.0) that we want to achieve, where the key is the name of the metric.
data (Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]) – Dictionary containing the average of each measure per substep
- Returns:
Dictionary mapping each metric name to the number of training samples needed to achieve each corresponding entry in targets
- Return type:
Dict[str, np.ndarray]
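The idea behind an inverse projection can be illustrated with a small worked example. This reference does not state which functional form is fitted, so the sketch assumes an inverse power-law learning curve, metric(n) = 1 - (c0 * n^(-c1) + c2), and solves it for n. The function name samples_needed and the coefficient values are hypothetical; they are not part of DAML's API or output.

```python
import math
from typing import List, Sequence


def samples_needed(targets: Sequence[float], c0: float, c1: float, c2: float) -> List[float]:
    """Invert metric(n) = 1 - (c0 * n**(-c1) + c2) for each target score.

    Assumption: an inverse power-law curve with already-fitted coefficients.
    """
    needed = []
    for target in targets:
        gap = (1.0 - target) - c2  # error that more data can still remove
        if gap <= 0:
            # The target sits at or above the curve's asymptote (1 - c2),
            # so no finite amount of data reaches it under this model.
            needed.append(math.inf)
        else:
            needed.append((gap / c0) ** (-1.0 / c1))
    return needed


# Hypothetical fitted coefficients, for illustration only.
needed = samples_needed([0.75, 0.90, 0.96], c0=2.0, c1=0.5, c2=0.05)
print(needed)
```

With these coefficients the curve plateaus at 0.95, so targets of 0.75 and 0.90 map to roughly 100 and 1600 samples, while 0.96 is unreachable. In practice the sample counts would be rounded up to whole numbers.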
- classmethod plot(data: Dict[str, ndarray | Dict[str, ndarray]], class_names: Sequence[str] | None = None) → List[Figure]
Plotting function for data sufficiency tasks
- Parameters:
data (Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]) – Dictionary containing the average of each measure per substep
- Returns:
List of Figures for each measure
- Return type:
List[plt.Figure]
- Raises:
KeyError – If STEPS_KEY or measure is not a valid key
ValueError – If the lengths of the data points in the measures do not match
- classmethod project(data: Dict[str, ndarray | Dict[str, ndarray]], projection: int | Sequence[int] | ndarray) → Dict[str, ndarray]
Projects each measure to the requested step value(s)
- Parameters:
data (Dict[str, Union[np.ndarray, Dict[str, np.ndarray]]]) – Dictionary containing the average of each measure per substep
projection (Union[int, Sequence[int], np.ndarray]) – Step or steps to project
niter (int, default 200) – Number of iterations to perform in the basin-hopping numerical process to curve-fit data
- Raises:
KeyError – If STEPS_KEY or measure is not a valid key
ValueError – If the lengths of the data points in the measures do not match
ValueError – If projection is not an int, Sequence[int], or ndarray
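A forward projection can be sketched in the same spirit. The fitted functional form is not given in this reference, so the example assumes an inverse power-law learning curve, metric(n) = 1 - (c0 * n^(-c1) + c2), evaluated at one or more step values. The function name project_metric and the coefficients are hypothetical, and the input check only loosely mirrors the ValueError described above.

```python
from typing import List, Sequence, Union


def project_metric(steps: Union[int, Sequence[int]], c0: float, c1: float, c2: float) -> List[float]:
    """Evaluate metric(n) = 1 - (c0 * n**(-c1) + c2) at each requested step.

    Assumption: an inverse power-law curve with already-fitted coefficients.
    """
    if isinstance(steps, int):
        steps = [steps]  # accept a single step as well as a sequence
    elif isinstance(steps, str) or not isinstance(steps, Sequence):
        raise ValueError("steps must be an int or a sequence of ints")
    return [1.0 - (c0 * n ** (-c1) + c2) for n in steps]


# Hypothetical fitted coefficients, for illustration only.
projected = project_metric([100, 1600, 10_000], c0=2.0, c1=0.5, c2=0.05)
print(projected)
```

Under these assumed coefficients the projected score rises from about 0.75 at 100 samples to about 0.93 at 10,000 samples, approaching but never exceeding the asymptote of 0.95.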