dataeval.protocols.ModelResetStrategy

class dataeval.protocols.ModelResetStrategy

Protocol for resetting model parameters between training runs.

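This protocol makes model reset backend-agnostic. An implementation is any callable that takes a model (or, for functional frameworks such as JAX, its parameters) and returns the reset version, so custom reset logic can be written for any ML framework (PyTorch, TensorFlow, JAX, etc.).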
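The examples below share one shape: a callable that receives the model and returns the reset model. The sketch here shows how such an interface can be declared and consumed. It assumes the protocol is a typing.Protocol whose only member is __call__(model) -> model, as the examples suggest. ResetStrategyLike, run_trials, zero_reset and fake_train are hypothetical names for illustration, not part of dataeval.

```python
from typing import Any, Callable, Protocol, runtime_checkable


@runtime_checkable
class ResetStrategyLike(Protocol):
    """Hypothetical stand-in for ModelResetStrategy: model in, reset model out."""

    def __call__(self, model: Any) -> Any: ...


def run_trials(
    model: Any,
    reset: ResetStrategyLike,
    train_fn: Callable[[Any], float],
    n_runs: int,
) -> list[float]:
    """Hypothetical driver: reset the model before every run, then train it."""
    scores = []
    for _ in range(n_runs):
        # Keep the returned object: a strategy may reset in place or build a new model.
        model = reset(model)
        scores.append(train_fn(model))
    return scores


# Any callable satisfies the protocol structurally; no subclassing is needed.
def zero_reset(model: dict) -> dict:
    return {name: 0.0 for name in model}


def fake_train(model: dict) -> float:
    # Pretend training nudges every parameter and reports their sum.
    for name in model:
        model[name] += 1.0
    return sum(model.values())


print(isinstance(zero_reset, ResetStrategyLike))  # True
print(run_trials({"w": 5.0, "b": 2.0}, zero_reset, fake_train, n_runs=3))  # [2.0, 2.0, 2.0]
```

Because run_trials uses the return value, strategies that mutate the model in place (like the PyTorch examples) and strategies that return a new object (like the JAX example) both fit the same driver, and every run starts from a freshly reset state.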

For PyTorch models (nn.Module), a default implementation is provided that calls reset_parameters() on each layer that defines it. No default exists for other backends, so a reset strategy must be supplied explicitly.
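The default PyTorch behavior described above can be sketched roughly as follows. This is an illustration, not dataeval's actual implementation, and the function name reset_by_layer is hypothetical. It relies only on nn.Module.modules(), which yields the module itself and every submodule, and on the reset_parameters() method that built-in layers such as nn.Linear and nn.Conv2d define; modules without that method, such as activations and containers, are skipped.

```python
from typing import Any


def reset_by_layer(model: Any) -> Any:
    """Hypothetical sketch of a default reset for torch.nn.Module-like models.

    Walks every submodule (model.modules() includes the model itself) and calls
    reset_parameters() on those that define it, re-running each layer's own
    default initializer. Modules without that method are left unchanged.
    """
    for module in model.modules():
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()
    return model
```

With PyTorch installed, calling reset_by_layer on nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)) re-draws both Linear layers' weights and biases in place and leaves the ReLU untouched. Because the sketch relies only on duck typing, it does not need to import torch itself.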

Examples

Custom PyTorch reset that applies Xavier-uniform initialization to weight matrices and zeros to their biases:

>>> import torch.nn as nn
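>>> class XavierReset:
...     def __call__(self, model: nn.Module) -> nn.Module:
...         for m in model.modules():
...             weight = getattr(m, "weight", None)
...             if weight is not None and weight.dim() >= 2:
...                 # Xavier init needs fan-in/fan-out, i.e. a weight with 2+ dims
...                 nn.init.xavier_uniform_(weight)
...                 if getattr(m, "bias", None) is not None:
...                     nn.init.zeros_(m.bias)
...             elif hasattr(m, "reset_parameters"):
...                 # e.g. BatchNorm/LayerNorm (1-D weight): use the layer's own default
...                 m.reset_parameters()
...         return model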

Reset strategy for JAX models that re-initializes the parameters with a fresh random key:

>>> class JAXReset:
...     def __init__(self, init_fn, rng_key):
...         # init_fn: callable mapping a PRNG key to a fresh parameter pytree
...         self.init_fn = init_fn
...         self.rng_key = rng_key
...
...     def __call__(self, params):
...         import jax
...
...         # The incoming (trained) params are discarded; split off a new
...         # subkey so each reset draws a different initialization
...         self.rng_key, subkey = jax.random.split(self.rng_key)
...         return self.init_fn(subkey)

Reset by reloading saved model weights from a checkpoint file:

>>> class CheckpointReset:
...     def __init__(self, checkpoint_path: str):
...         self.checkpoint_path = checkpoint_path
...
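...     def __call__(self, model: nn.Module) -> nn.Module:
...         import torch
...
...         # Expects a file written by torch.save(model.state_dict(), path).
...         # map_location="cpu" lets a GPU-saved checkpoint load on any machine.
...         state_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
...         model.load_state_dict(state_dict)
...         return model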
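If writing a checkpoint file is inconvenient, the same idea can be kept in memory: snapshot the model once and hand back a fresh copy on each reset. The sketch below assumes the model can be duplicated with copy.deepcopy (true for typical nn.Module instances and plain parameter containers); the class name SnapshotReset is hypothetical. Unlike CheckpointReset, it returns a new object instead of modifying the model in place, which appears consistent with the protocol, since the JAX example likewise returns new parameters rather than mutating its input.

```python
import copy
from typing import Any


class SnapshotReset:
    """Hypothetical reset strategy that restores an in-memory snapshot.

    The snapshot is taken once, at construction. Each call ignores the trained
    model it receives and returns an independent deep copy of the snapshot,
    so later training never mutates the stored initial state.
    """

    def __init__(self, initial_model: Any) -> None:
        self._snapshot = copy.deepcopy(initial_model)

    def __call__(self, model: Any) -> Any:
        # The incoming model is discarded; a fresh copy is returned instead.
        return copy.deepcopy(self._snapshot)


params = {"w": [0.1, 0.2], "b": [0.0]}
reset = SnapshotReset(params)
params["w"][0] = 9.9       # simulate training mutating the weights
restored = reset(params)
print(restored)            # {'w': [0.1, 0.2], 'b': [0.0]}
print(restored is params)  # False
```

Because a new object comes back, anything holding references to the old model's parameters, such as a PyTorch optimizer, would need to be rebuilt after each reset.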

See also

Sufficiency

Uses this protocol to reset the model between training runs.