dataeval.protocols.TrainingStrategy
- class dataeval.protocols.TrainingStrategy
Protocol defining the interface for training a model on a dataset subset.
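Implementations must provide a train method with the signature documented below. The protocol uses structural typing: any class with a matching train method satisfies it, and no explicit inheritance is required.
The @runtime_checkable decorator enables isinstance() checks when they are needed; static type checkers verify conformance structurally without it. A runtime isinstance() check only confirms that a train attribute exists, not that its signature matches.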
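A minimal sketch of what structural typing means in practice. To stay self-contained it does not import dataeval; instead it defines a hypothetical local protocol, TrainingStrategyLike, that mirrors the documented train signature, with model and dataset typed as Any. It also illustrates the limitation of runtime isinstance() checks.

```python
from collections.abc import Sequence
from typing import Any, Protocol, runtime_checkable


# Hypothetical local stand-in mirroring the documented train() signature;
# it is NOT imported from dataeval.
@runtime_checkable
class TrainingStrategyLike(Protocol):
    def train(self, model: Any, dataset: Any, indices: Sequence[int]) -> None: ...


class NoOpTraining:
    # Does not inherit from the protocol; having a matching train() is enough.
    def train(self, model: Any, dataset: Any, indices: Sequence[int]) -> None:
        pass


class NotATrainer:
    # Has no train() method, so it does not satisfy the protocol.
    def fit(self, model: Any) -> None:
        pass


class WrongSignature:
    # isinstance() only checks that a member named "train" exists;
    # it does not compare parameters, so this still passes at runtime
    # even though a static type checker would reject it.
    def train(self) -> None:
        pass


print(isinstance(NoOpTraining(), TrainingStrategyLike))    # True
print(isinstance(NotATrainer(), TrainingStrategyLike))     # False
print(isinstance(WrongSignature(), TrainingStrategyLike))  # True
```

The last line is why static type checking, not isinstance(), is the stronger guarantee that a strategy's train method really matches the protocol.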
Examples
Creating a custom training strategy:
>>> from collections.abc import Sequence
>>> import torch
>>> from torch.utils.data import Dataset
>>> class MyTraining:
...     def __init__(self, learning_rate: float, epochs: int):
...         self.learning_rate = learning_rate
...         self.epochs = epochs
...
...     def train(self, model: torch.nn.Module, dataset: Dataset, indices: Sequence[int]) -> None:
...         # Custom training implementation
...         optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
...         for epoch in range(self.epochs):
...             # Training loop using specified indices
...             ...
- train(model, dataset, indices)
Train the model using the specified indices from the dataset.
- Parameters:
- model : nn.Module
The model to train. Training should modify the model in-place.
- dataset : Dataset[T]
The full dataset. Only samples at the specified indices should be used for training.
- indices : Sequence[int]
Positions of the samples in dataset to use for this training step. Because the model is updated in-place, calling train repeatedly with growing index sets trains the same model incrementally on larger subsets.
- Returns:
Nothing; training modifies the model in-place.
- Return type:
None
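The contract above can be illustrated without PyTorch. In this stdlib-only sketch, TinyLinearModel stands in for an nn.Module and IndexedSGD for a real strategy; both names and the plain-SGD, squared-error loop are illustrative assumptions, not dataeval code. It reads only dataset[i] for i in indices, mutates the model in-place, returns None, and is reused on growing subsets.

```python
import random
from collections.abc import Sequence


class TinyLinearModel:
    """Hypothetical stand-in for an nn.Module: predicts w * x + b."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0


class IndexedSGD:
    """Hypothetical strategy: plain SGD on squared error over dataset[i] for i in indices."""

    def __init__(self, learning_rate: float = 0.05, epochs: int = 200, seed: int = 0) -> None:
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.seed = seed

    def train(
        self,
        model: TinyLinearModel,
        dataset: Sequence[tuple[float, float]],
        indices: Sequence[int],
    ) -> None:
        rng = random.Random(self.seed)  # private RNG: same seed, same sample order
        order = list(indices)
        for _ in range(self.epochs):
            rng.shuffle(order)
            for i in order:
                x, y = dataset[i]  # only samples named in indices are read
                error = (model.w * x + model.b) - y
                # Gradient of error**2 w.r.t. w and b; parameters change in place.
                model.w -= self.learning_rate * 2 * error * x
                model.b -= self.learning_rate * 2 * error


# Noise-free samples of y = 3x + 1 for x = 0.0, 0.1, ..., 0.9.
dataset = [(x / 10, 3 * (x / 10) + 1) for x in range(10)]
model = TinyLinearModel()
strategy = IndexedSGD()
for n in (2, 5, 10):  # growing subsets, same model object each time
    strategy.train(model, dataset, list(range(n)))
    print(n, round(model.w, 3), round(model.b, 3))
```

Because train only mutates model, each call resumes from the parameters left by the previous one; once all ten points are included, the estimates should approach w = 3 and b = 1.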
Notes
Implementations should:
- Only use samples at the specified indices
- Modify the model parameters in-place
- Handle their own loss computation and optimization
- Be deterministic or set seeds internally for reproducibility
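These expectations can be checked mechanically. The sketch below defines a hypothetical helper, check_training_strategy, plus a hypothetical RecordingDataset wrapper; neither is part of dataeval. It runs a strategy twice on deep copies of a model, records which indices were read, and compares parameters. The checks are heuristics: a strategy that legitimately leaves parameters unchanged would be flagged, and the default get_params=vars suits plain Python objects only. An nn.Module would need a tensor-aware comparison, which is left out here.

```python
import copy
from collections.abc import Callable, Sequence
from typing import Any


class RecordingDataset:
    """Hypothetical wrapper that records which indices a strategy reads."""

    def __init__(self, data: Sequence[Any]) -> None:
        self._data = data
        self.accessed: set[int] = set()

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Any:
        self.accessed.add(index)
        return self._data[index]


def check_training_strategy(
    strategy: Any,
    model: Any,
    data: Sequence[Any],
    indices: Sequence[int],
    get_params: Callable[[Any], Any] = vars,
) -> list[str]:
    """Heuristically check a strategy against the Notes; [] means no problem found."""
    problems: list[str] = []
    before = copy.deepcopy(get_params(model))
    runs = []
    for _ in range(2):
        candidate = copy.deepcopy(model)  # each run starts from the same state
        recorder = RecordingDataset(data)
        if strategy.train(candidate, recorder, indices) is not None:
            problems.append("train() returned a value instead of None")
        extra = sorted(recorder.accessed - set(indices))
        if extra:
            problems.append(f"read samples outside indices: {extra}")
        runs.append(copy.deepcopy(get_params(candidate)))
    if runs[0] == before:
        problems.append("model parameters were not modified in place")
    if runs[0] != runs[1]:
        problems.append("two identical runs produced different parameters")
    return list(dict.fromkeys(problems))  # drop duplicates, keep order


class Model:
    def __init__(self) -> None:
        self.w = 0.0


class MeanFit:
    """Toy strategy: nudges w toward each selected sample, in place."""

    def train(self, model: Model, dataset: Sequence[float], indices: Sequence[int]) -> None:
        for i in indices:
            model.w += 0.5 * (dataset[i] - model.w)


class Leaky(MeanFit):
    """Toy strategy that wrongly ignores indices and reads every sample."""

    def train(self, model: Model, dataset: Sequence[float], indices: Sequence[int]) -> None:
        super().train(model, dataset, range(len(dataset)))


data = [1.0, 2.0, 3.0, 4.0]
print(check_training_strategy(MeanFit(), Model(), data, [0, 2]))  # []
print(check_training_strategy(Leaky(), Model(), data, [0, 2]))    # ['read samples outside indices: [1, 3]']
```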