dataeval.protocols.ModelResetStrategy
- class dataeval.protocols.ModelResetStrategy
Protocol for resetting model parameters between training runs.
This protocol enables backend-agnostic model reset functionality. Implementations can provide custom reset logic for any ML framework (PyTorch, TensorFlow, JAX, etc.).
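To make the backend-agnostic idea concrete, here is a minimal sketch assuming the protocol amounts to a callable that takes a model and returns the reset model, as the examples on this page suggest. `ResetStrategySketch` and `ZeroWeightsReset` are hypothetical names used only for illustration, and the actual protocol definition in dataeval may differ.

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class ResetStrategySketch(Protocol):
    """Hypothetical stand-in: any callable that takes a model and returns a model."""

    def __call__(self, model: Any) -> Any: ...


class ZeroWeightsReset:
    """Framework-free strategy; the 'model' here is just a dict of weight lists."""

    def __call__(self, model: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Build a fresh model with every weight set back to 0.0,
        # leaving the trained model untouched.
        return {name: [0.0] * len(values) for name, values in model.items()}


strategy = ZeroWeightsReset()
trained = {"layer1": [0.3, -1.2], "layer2": [0.7]}
fresh = strategy(trained)
```

Because the check is structural, `ZeroWeightsReset` satisfies the sketch without inheriting from it, which is what lets a strategy for any framework be swapped in.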
For PyTorch models (nn.Module), a default implementation is provided that calls reset_parameters() on each layer. For other backends, users must provide their own reset strategy.
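The per-layer reset described above can be sketched roughly as follows. This is an illustration of the idea, not the library's actual default implementation: `reset_layers`, `TinyLayer`, and `TinyModel` are hypothetical names, and the stand-in classes only mimic the `modules()` and `reset_parameters()` methods that `nn.Module` and layers such as `nn.Linear` expose, so the sketch runs without PyTorch installed.

```python
from typing import Any


def reset_layers(model: Any) -> Any:
    """Call reset_parameters() on every submodule that defines it."""
    for module in model.modules():
        reset = getattr(module, "reset_parameters", None)
        # Containers usually define no reset_parameters(); skip them.
        if callable(reset):
            reset()
    return model


class TinyLayer:
    """Stand-in for a layer with a single resettable weight."""

    def __init__(self) -> None:
        self.weight = 0.0

    def reset_parameters(self) -> None:
        self.weight = 0.0

    def modules(self):
        yield self


class TinyModel:
    """Stand-in for a container that yields itself and then its layers."""

    def __init__(self) -> None:
        self.layers = [TinyLayer(), TinyLayer()]

    def modules(self):
        yield self
        yield from self.layers


model = TinyModel()
for layer in model.layers:
    layer.weight = 5.0  # pretend the model was trained

reset_model = reset_layers(model)
```

The function is duck-typed on purpose: it should work unchanged on a real `nn.Module`, since `modules()` iterates over the module itself and all of its submodules.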
Examples
Custom reset for PyTorch with specific initialization:
>>> import torch.nn as nn
>>> class XavierReset:
...     def __call__(self, model: nn.Module) -> nn.Module:
...         for m in model.modules():
...             # Xavier init needs at least 2 dimensions, so skip 1-D weights (e.g. BatchNorm)
...             if hasattr(m, "weight") and m.weight is not None and m.weight.dim() >= 2:
...                 nn.init.xavier_uniform_(m.weight)
...             if hasattr(m, "bias") and m.bias is not None:
...                 nn.init.zeros_(m.bias)
...         return model

Reset strategy for JAX models using parameter reinitialization:
>>> class JAXReset:
...     def __init__(self, init_fn, rng_key):
...         self.init_fn = init_fn
...         self.rng_key = rng_key
...
...     def __call__(self, params):
...         import jax.random as random
...
...         # Reinitialize parameters with new random key
...         self.rng_key, subkey = random.split(self.rng_key)
...         return self.init_fn(subkey)

Reset by reloading model weights from checkpoint:
>>> class CheckpointReset:
...     def __init__(self, checkpoint_path: str):
...         self.checkpoint_path = checkpoint_path
...
...     def __call__(self, model: nn.Module) -> nn.Module:
...         import torch
...
...         model.load_state_dict(torch.load(self.checkpoint_path))
...         return model

See also
Sufficiency: Uses this protocol for model reset between runs