dataeval.protocols.ModelResetStrategy

class dataeval.protocols.ModelResetStrategy

Protocol for resetting model parameters between training runs.

This protocol enables backend-agnostic model reset functionality. Implementations can provide custom reset logic for any ML framework (PyTorch, TensorFlow, JAX, etc.).
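Structural typing is what makes this backend-agnostic: an implementation needs no base class, only a matching `__call__`. The sketch below assumes, as the examples on this page suggest, that the protocol's single member is `__call__(model) -> model`. `ResetStrategyLike` and `NoOpReset` are hypothetical names used only for illustration; this is not DataEval's source definition.

```python
from typing import Protocol, TypeVar, runtime_checkable

ModelT = TypeVar("ModelT")


@runtime_checkable
class ResetStrategyLike(Protocol[ModelT]):
    """Hypothetical structural mirror of ModelResetStrategy (not DataEval's source)."""

    def __call__(self, model: ModelT) -> ModelT:
        """Return the model (or its parameters) in a freshly reset state."""
        ...


class NoOpReset:
    # No base class needed: a matching __call__ is enough to satisfy the protocol
    def __call__(self, model):
        return model
```

With `@runtime_checkable`, `isinstance` only checks that a `__call__` attribute exists, so any callable (even a plain function) passes; signatures are verified by static type checkers, not at runtime.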

For PyTorch models (nn.Module), a default implementation is provided that calls reset_parameters() on every submodule that defines it. For other backends, users must provide their own reset strategy.
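A minimal sketch of what such a default strategy could look like, assuming it simply walks `model.modules()` and calls `reset_parameters()` where that method exists. `ResetParametersStrategy` is a hypothetical name, and DataEval's actual default may differ in details. The sketch is duck-typed, so it runs on any object that exposes a `modules()` method.

```python
from typing import Any


class ResetParametersStrategy:
    """Hypothetical sketch of a PyTorch-style default reset (not DataEval's actual code)."""

    def __call__(self, model: Any) -> Any:
        # For nn.Module, modules() yields the model itself and every nested submodule
        for module in model.modules():
            reset = getattr(module, "reset_parameters", None)
            # Only layers that define reset_parameters() (e.g. nn.Linear, nn.Conv2d) are touched
            if callable(reset):
                reset()
        return model
```

With PyTorch, `ResetParametersStrategy()(torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU()))` re-draws the `Linear` weights and leaves `ReLU` untouched. Custom modules that hold `nn.Parameter` objects directly but define no `reset_parameters()` are skipped, which is one reason to supply a custom strategy.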

See also

Sufficiency

Uses this protocol to reset the model between runs
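The calling pattern for resetting between runs can be sketched as a generic loop: reset before each run, then continue with the *returned* object (which matters for functional backends such as JAX, where a new parameter pytree is returned rather than mutated in place). `run_trials`, `train_and_eval`, and `runs` are hypothetical names for illustration; this is not Sufficiency's actual implementation or signature.

```python
from typing import Callable, List, TypeVar

M = TypeVar("M")


def run_trials(
    model: M,
    reset: Callable[[M], M],
    train_and_eval: Callable[[M], float],
    runs: int,
) -> List[float]:
    """Hypothetical loop: each run starts from freshly reset parameters."""
    scores: List[float] = []
    for _ in range(runs):
        # Reset first so no run inherits weights learned in a previous one
        model = reset(model)
        scores.append(train_and_eval(model))
    return scores
```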

Examples

Custom reset for PyTorch with specific initialization:

>>> import torch.nn as nn
>>> class XavierReset:
...     def __call__(self, model: nn.Module) -> nn.Module:
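...         for m in model.modules():
...             weight = getattr(m, "weight", None)
...             if weight is not None:
...                 if weight.dim() >= 2:
...                     nn.init.xavier_uniform_(weight)
...                 else:
...                     # Xavier needs fan-in/fan-out; reset 1-D weights (e.g. norm scales) to ones
...                     nn.init.ones_(weight)
...             bias = getattr(m, "bias", None)
...             if bias is not None:
...                 nn.init.zeros_(bias)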
...         return model

Reset strategy for JAX models using parameter reinitialization:

>>> class JAXReset:
...     def __init__(self, init_fn, rng_key):
...         self.init_fn = init_fn  # maps a PRNG key to a fresh parameter pytree
...         self.rng_key = rng_key
...
...     def __call__(self, params):
...         import jax.random as random
...
...         # Discard the old params; split the key so each reset draws new values
...         self.rng_key, subkey = random.split(self.rng_key)
...         return self.init_fn(subkey)

Reset by reloading model weights from a checkpoint:

>>> class CheckpointReset:
...     def __init__(self, checkpoint_path: str):
...         self.checkpoint_path = checkpoint_path
...
...     def __call__(self, model: nn.Module) -> nn.Module:
...         import torch
...
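...         # Assumes the file holds a state_dict saved via torch.save(model.state_dict(), path).
...         # Load onto the CPU; load_state_dict copies values onto the model's own device.
...         state_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
...         model.load_state_dict(state_dict)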
...         return model
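When no checkpoint file is available, a similar effect can be had by snapshotting the initial weights in memory. The sketch below is a hypothetical variant (`SnapshotReset` is not part of DataEval) that assumes the model exposes PyTorch-style `state_dict()` and `load_state_dict()` methods. It deep-copies on capture because training updates tensors in place, and again on restore so repeated resets can never corrupt the stored snapshot.

```python
import copy
from typing import Any


class SnapshotReset:
    """Hypothetical reset strategy: restore the weights captured at construction time."""

    def __init__(self, model: Any) -> None:
        # Deep-copy so later in-place training updates cannot alter the snapshot
        self._initial_state = copy.deepcopy(model.state_dict())

    def __call__(self, model: Any) -> Any:
        # Copy again so the stored snapshot stays pristine across repeated resets
        model.load_state_dict(copy.deepcopy(self._initial_state))
        return model
```

Construct it once, right after building the model (for example `reset = SnapshotReset(model)`), and every later call returns the model to those same starting weights. Unlike reinitialization, each run then starts from identical parameters.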