dataeval.protocols.TrainingStrategy

class dataeval.protocols.TrainingStrategy

Protocol defining the interface for training a model on a dataset subset.

Implementations must provide a train method with the signature documented below. The protocol uses structural typing, so no explicit inheritance is required.
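As a minimal sketch of structural typing, the snippet below declares a stand-in protocol that mirrors the documented train signature. The names TrainingStrategyLike, CountingTraining and run_step are hypothetical and are used so the example runs without DataEval installed. CountingTraining never inherits from the protocol, yet it is accepted wherever the protocol is expected because its train method has a matching shape.

```python
from collections.abc import Sequence
from typing import Any, Protocol


class TrainingStrategyLike(Protocol):
    """Hypothetical stand-in mirroring the documented train signature."""

    def train(self, model: Any, dataset: Any, indices: Sequence[int]) -> None: ...


class CountingTraining:
    """Satisfies the protocol structurally: note there is no base class."""

    def train(self, model: Any, dataset: Any, indices: Sequence[int]) -> None:
        # Toy "training": record how many samples were used in this step.
        model["samples_seen"] += len(indices)


def run_step(strategy: TrainingStrategyLike, model: Any, dataset: Any, indices: Sequence[int]) -> None:
    # A static type checker such as mypy accepts CountingTraining here
    # because its train method matches the protocol's signature.
    strategy.train(model, dataset, indices)


model = {"samples_seen": 0}
run_step(CountingTraining(), model, dataset=list(range(10)), indices=[0, 2, 4])
print(model["samples_seen"])  # 3
```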

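The @runtime_checkable decorator enables isinstance() checks when needed; static type checkers rely on structural typing and do not need it. Note that a runtime isinstance() check only confirms that a train attribute exists, not that its signature matches.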
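A minimal sketch of the runtime check, using a hypothetical runtime-checkable stand-in (TrainingStrategyLike) with the same train signature rather than the DataEval class itself. The WrongSignature case illustrates that isinstance() inspects only the presence of a train member, so a static type checker remains the stronger guarantee.

```python
from collections.abc import Sequence
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class TrainingStrategyLike(Protocol):
    """Hypothetical stand-in mirroring the documented train signature."""

    def train(self, model: Any, dataset: Any, indices: Sequence[int]) -> None: ...


class GoodTraining:
    def train(self, model: Any, dataset: Any, indices: Sequence[int]) -> None:
        pass


class WrongSignature:
    # Takes no arguments, but isinstance() still passes: only the
    # presence of a train member is checked at runtime.
    def train(self) -> None:
        pass


class NoTrain:
    # Has no train method at all, so it fails the runtime check.
    def fit(self, model: Any) -> None:
        pass


print(isinstance(GoodTraining(), TrainingStrategyLike))    # True
print(isinstance(WrongSignature(), TrainingStrategyLike))  # True
print(isinstance(NoTrain(), TrainingStrategyLike))         # False
```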

The model parameter accepts any type to support different ML backends (PyTorch, TensorFlow, JAX, etc.).
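Because model is typed as Any, nothing ties a strategy to a deep-learning framework. The hypothetical sketch below trains a plain-Python linear model y = w * x + b with stochastic gradient descent and updates the dataclass fields in place. The (input, target, metadata) sample layout is an assumption borrowed from the JAX example; LinearModel and SGDTraining are illustrative names only.

```python
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any


@dataclass
class LinearModel:
    """Hypothetical framework-free model: y = w * x + b."""

    w: float = 0.0
    b: float = 0.0


class SGDTraining:
    def __init__(self, learning_rate: float, epochs: int):
        self.learning_rate = learning_rate
        self.epochs = epochs

    def train(self, model: LinearModel, dataset: Sequence[Any], indices: Sequence[int]) -> None:
        for _ in range(self.epochs):
            for idx in indices:
                x, y, _ = dataset[idx]
                # Gradient of the squared error 0.5 * (w*x + b - y)**2
                err = model.w * x + model.b - y
                model.w -= self.learning_rate * err * x
                model.b -= self.learning_rate * err


# Samples follow y = 2x + 1; the metadata slot is unused here.
data = [(x, 2.0 * x + 1.0, {}) for x in (0.0, 0.5, 1.0, 1.5, 2.0)]
model = LinearModel()
SGDTraining(learning_rate=0.1, epochs=500).train(model, data, indices=range(len(data)))
print(round(model.w, 3), round(model.b, 3))  # close to w = 2, b = 1
```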

Examples

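Creating a custom training strategy for PyTorch (this sketch assumes each sample is an (input, class_index, metadata) tuple):

>>> from collections.abc import Sequence
>>> import torch
>>> from torch.utils.data import Dataset
>>> class MyTraining:
...     def __init__(self, learning_rate: float, epochs: int):
...         self.learning_rate = learning_rate
...         self.epochs = epochs
...
...     def train(self, model: torch.nn.Module, dataset: Dataset, indices: Sequence[int]) -> None:
...         optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
...         loss_fn = torch.nn.CrossEntropyLoss()
...         model.train()
...         for _ in range(self.epochs):
...             # Iterate over the specified indices only
...             for idx in indices:
...                 x, y, _ = dataset[idx]
...                 optimizer.zero_grad()
...                 logits = model(torch.as_tensor(x, dtype=torch.float32).unsqueeze(0))
...                 loss = loss_fn(logits, torch.as_tensor([y], dtype=torch.long))
...                 loss.backward()
...                 optimizer.step()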

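Creating a training strategy for JAX. Because JAX arrays are immutable, this sketch assumes model is a mutable dict holding "params" and a "loss_fn(params, x, y)", and writes the updated parameters back to it:

>>> from collections.abc import Sequence
>>> from typing import Any
>>> import jax
>>> import jax.numpy as jnp
>>> class JAXTraining:
...     def __init__(self, learning_rate: float, epochs: int):
...         self.learning_rate = learning_rate
...         self.epochs = epochs
...
...     def train(self, model: dict, dataset: Any, indices: Sequence[int]) -> None:
...         params = model["params"]
...         # Gradient of the loss with respect to params (first argument)
...         grad_fn = jax.jit(jax.grad(model["loss_fn"]))
...         for _ in range(self.epochs):
...             for idx in indices:
...                 x, y, _ = dataset[idx]
...                 grads = grad_fn(params, jnp.asarray(x), jnp.asarray(y))
...                 params = jax.tree.map(lambda p, g: p - self.learning_rate * g, params, grads)
...         # Functional updates produce new arrays, so store them back on the container
...         model["params"] = params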

train(model, dataset, indices)

Train the model using the specified indices from the dataset.

Parameters:
model : Any

The model to train. Can be any model type (PyTorch Module, TensorFlow model, etc.). Training should modify the model in place when the backend supports it.

dataset : Dataset[T]

The full dataset. Only samples at the specified indices should be used for training.

indices : Sequence[int]

Indices of the samples in dataset to use for this training step. Passing successively larger index sets lets the same model be trained incrementally on growing subsets.
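The incremental use described for indices might look like the following hypothetical caller, which passes successively larger prefixes of one fixed random order to the same strategy and model, as a learning-curve style evaluation would. RecordingTraining is an illustrative stand-in that only records what it receives, and passing cumulative subsets (rather than only newly added indices) is an assumption of this sketch.

```python
import random
from collections.abc import Sequence
from typing import Any


class RecordingTraining:
    """Hypothetical strategy that only logs the size of each subset it receives."""

    def train(self, model: dict[str, Any], dataset: Sequence[Any], indices: Sequence[int]) -> None:
        # A real strategy would update the model's weights here.
        model["history"].append(len(indices))


dataset = list(range(100))
model: dict[str, Any] = {"history": []}
strategy = RecordingTraining()

# Fix one random order up front, then train on growing prefixes of it,
# so each subset contains every sample from the previous one.
order = list(range(len(dataset)))
random.Random(0).shuffle(order)
for size in (10, 25, 50, 100):
    strategy.train(model, dataset, order[:size])

print(model["history"])  # [10, 25, 50, 100]
```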

Returns:

Nothing is returned; training modifies the model in place.

Return type:

None

Notes

Implementations should:

- Only use samples at the specified indices
- Modify the model parameters in place
- Handle their own loss computation and optimization
- Be deterministic or set seeds internally for reproducibility
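The guidelines above can be checked with a small harness. In this sketch, TrackingDataset (hypothetical) records every index read, and SeededTraining shuffles its indices with a private random.Random(seed), so two runs from the same starting model produce identical results without touching global random state. The toy mean-estimator model is likewise an illustration, not part of the protocol.

```python
import random
from collections.abc import Sequence


class TrackingDataset:
    """Hypothetical wrapper that records which indices were accessed."""

    def __init__(self, data: Sequence[float]):
        self.data = data
        self.accessed: set[int] = set()

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> float:
        self.accessed.add(idx)
        return self.data[idx]


class SeededTraining:
    def __init__(self, learning_rate: float, epochs: int, seed: int = 0):
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.seed = seed

    def train(self, model: dict[str, float], dataset: TrackingDataset, indices: Sequence[int]) -> None:
        rng = random.Random(self.seed)  # private RNG, reseeded on every call
        order = list(indices)
        for _ in range(self.epochs):
            rng.shuffle(order)  # order varies per epoch but is reproducible
            for idx in order:
                # SGD step on 0.5 * (m - x)**2, moving m toward the subset mean
                model["mean"] -= self.learning_rate * (model["mean"] - dataset[idx])


values = [float(v) for v in range(20)]
subset = [1, 3, 5, 7]

runs = []
for _ in range(2):
    ds = TrackingDataset(values)
    model = {"mean": 0.0}
    SeededTraining(learning_rate=0.05, epochs=50).train(model, ds, subset)
    runs.append((model["mean"], ds.accessed))

print(runs[0][0] == runs[1][0])   # True: same seed, same result
print(runs[0][1] <= set(subset))  # True: only the given indices were read
```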