dataeval.protocols.EvaluationStrategy

class dataeval.protocols.EvaluationStrategy

Protocol defining the interface for evaluating a trained model.

Implementations must provide an evaluate method with this signature. The protocol uses structural typing, so no explicit inheritance is required.

The @runtime_checkable decorator allows isinstance() checks if needed, though structural typing works without it at type-check time.
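As a minimal sketch of the runtime check described above, the snippet below defines a stand-in protocol, `EvaluationStrategyLike`, that mirrors the documented `evaluate(model, dataset)` signature. The stand-in and both example classes are hypothetical and exist only so the snippet runs without DataEval installed; real code would presumably check against `dataeval.protocols.EvaluationStrategy` directly.

```python
from collections.abc import Mapping
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class EvaluationStrategyLike(Protocol):
    """Hypothetical stand-in mirroring the documented interface."""

    def evaluate(self, model: Any, dataset: Any) -> Mapping[str, Any]: ...


class ConstantEvaluation:
    """Satisfies the protocol structurally; no inheritance involved."""

    def evaluate(self, model: Any, dataset: Any) -> Mapping[str, float]:
        return {"accuracy": 1.0}


class NotAStrategy:
    """Has no evaluate method, so the runtime check fails."""

    def score(self, model: Any, dataset: Any) -> float:
        return 0.0


is_strategy = isinstance(ConstantEvaluation(), EvaluationStrategyLike)
is_not_strategy = isinstance(NotAStrategy(), EvaluationStrategyLike)
print(is_strategy, is_not_strategy)  # True False
```

Note that isinstance() against a runtime-checkable protocol only verifies that an evaluate attribute exists. It does not validate the parameter or return types, so a static type checker remains the stronger guarantee.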

The model parameter accepts any type to support different ML backends (PyTorch, TensorFlow, JAX, etc.).

Examples

Creating a custom evaluation strategy for PyTorch:

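>>> from collections.abc import Mapping
>>> import numpy as np
>>> import torch
>>> from dataeval.protocols import Dataset  # assumed location of the Dataset protocol
>>> class MyEvaluation: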
...     def __init__(self, batch_size: int, metrics: list[str]):
...         self.batch_size = batch_size
...         self.metrics = metrics
...
...     def evaluate(self, model: torch.nn.Module, dataset: Dataset) -> Mapping[str, float | np.ndarray]:
...         # Custom evaluation implementation
...         model.eval()
...         with torch.no_grad():
...             # Compute metrics
...             ...
...         return {"accuracy": 0.95, "f1": 0.93}

Creating an evaluation strategy for JAX:

>>> class JAXEvaluation:
...     def __init__(self, apply_fn):
...         self.apply_fn = apply_fn  # JAX model's forward function
...
...     def evaluate(self, params, dataset: Dataset) -> Mapping[str, float]:
...         import jax.numpy as jnp
...
...         correct = 0
...         total = len(dataset)
...         for i in range(total):
...             x, y, _ = dataset[i]
...             pred = self.apply_fn(params, x)
...             if jnp.argmax(pred) == jnp.argmax(y):
...                 correct += 1
...         return {"accuracy": correct / total}

evaluate(model, dataset)

Evaluate the model on the dataset and return performance metrics.

Parameters:
model : Any

The trained model to evaluate. Can be any model type (PyTorch Module, TensorFlow model, JAX parameters, etc.).

dataset : Dataset[T]

The dataset to evaluate on (typically a test or validation set).

Returns:

Mapping of metric names to values. Each value is either:

- A scalar (float) for single-class metrics
- An array (np.ndarray) for per-class or per-sample metrics

Examples:

{"accuracy": 0.95}  # Single metric
{"accuracy": 0.95, "precision": 0.93, "recall": 0.94}  # Multiple metrics
{"accuracy": np.array([0.9, 0.85, 0.92])}  # Per-class metrics

Return type:

Mapping[str, float | ArrayLike]
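Because each value in the returned mapping may be a scalar or an array, callers usually need to handle both shapes. The sketch below shows one possible way to do that; `summarize_metrics` is a hypothetical helper, not part of DataEval, and it assumes that array-valued metrics are one-dimensional and can reasonably be averaged.

```python
from collections.abc import Iterable, Mapping
from statistics import fmean


def summarize_metrics(results: Mapping[str, object]) -> dict[str, float]:
    """Collapse each metric to one float (hypothetical helper).

    Scalars pass through unchanged; array-like values are averaged.
    """
    summary: dict[str, float] = {}
    for name, value in results.items():
        if isinstance(value, Iterable):
            # Per-class or per-sample metric: reduce to its mean.
            summary[name] = fmean(float(v) for v in value)
        else:
            # Scalar metric: keep as a plain float.
            summary[name] = float(value)
    return summary


# A plain list stands in for an array here so the snippet needs only the stdlib.
results = {"accuracy": 0.95, "recall": [0.9, 0.85, 0.92]}
summary = summarize_metrics(results)
print(sorted(summary))  # ['accuracy', 'recall']
```

Averaging is only one choice of reduction; a caller that cares about the worst-performing class might take the minimum instead.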

Notes

Implementations should:

- Set model to eval mode if needed
- Return consistent metric names across calls
- Handle both single-class and multi-class scenarios
- Use the entire dataset (unlike training which uses subsets)
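The guidelines above can be illustrated with a small framework-free sketch. `PerClassAccuracyEvaluation` is a hypothetical class, the model is assumed to be a plain callable that returns an integer class label, and each dataset item is assumed to be an (input, label, metadata) tuple, as in the JAX example earlier in this page. A callable has no eval mode, so that step is omitted here.

```python
from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
from typing import Any


class PerClassAccuracyEvaluation:
    """Hypothetical strategy: overall accuracy plus per-class accuracy."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes

    def evaluate(
        self,
        model: Callable[[Any], int],
        dataset: Sequence[tuple[Any, int, Any]],
    ) -> Mapping[str, float | list[float]]:
        correct = [0] * self.num_classes
        seen = [0] * self.num_classes
        # Iterate over the entire dataset, not a subset.
        for i in range(len(dataset)):
            x, label, _ = dataset[i]
            seen[label] += 1
            if model(x) == label:
                correct[label] += 1
        total = sum(seen)
        # The same two keys are returned on every call, even when a class
        # has no samples, so metric names stay consistent across calls.
        return {
            "accuracy": sum(correct) / total if total else 0.0,
            "per_class_accuracy": [
                c / s if s else 0.0 for c, s in zip(correct, seen)
            ],
        }


# Toy data: (input, label, metadata) tuples and a threshold "model".
dataset = [(0.2, 0, {}), (0.8, 1, {}), (0.6, 0, {}), (0.9, 1, {})]
results = PerClassAccuracyEvaluation(num_classes=2).evaluate(
    lambda x: int(x > 0.5), dataset
)
print(results)  # {'accuracy': 0.75, 'per_class_accuracy': [0.5, 1.0]}
```

Setting num_classes=1 exercises the single-class case with the same keys, which is one simple way to keep the output shape stable across scenarios.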