dataeval.protocols.LossFn
class dataeval.protocols.LossFn
Protocol for generic loss functions that can be used with PyTorch models.
This is the base protocol for all loss functions. It supports both class-based (torch.nn.Module-like) and functional loss implementations.
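Because both functional and class-based losses are accepted, the protocol is presumably a structural check on callability. The following is a minimal sketch of how such a protocol can be declared with `typing.Protocol` and `@runtime_checkable`, assuming the real `LossFn` requires only a `__call__` member. `LossFnSketch`, `mse` and `MSE` are hypothetical names, and plain Python lists stand in for tensors so the sketch runs without PyTorch or DataEval installed.

```python
from typing import Any, Protocol, runtime_checkable


# Hypothetical stand-in for dataeval.protocols.LossFn, ASSUMING the real
# protocol only requires that objects be callable.
@runtime_checkable
class LossFnSketch(Protocol):
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


# Functional loss: a plain function (pure-Python mean squared error).
def mse(y_true: list[float], y_pred: list[float]) -> float:
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Class-based loss: any instance whose class defines __call__.
class MSE:
    def __call__(self, y_true: list[float], y_pred: list[float]) -> float:
        return mse(y_true, y_pred)


print(isinstance(mse, LossFnSketch))    # True
print(isinstance(MSE(), LossFnSketch))  # True
print(isinstance(3, LossFnSketch))      # False: ints are not callable
print(isinstance(len, LossFnSketch))    # True: only the presence of __call__ is checked
```

The last line shows a property of any runtime-checkable protocol: `isinstance` only tests that the required member exists, not its argument count or types. Signature mismatches are caught by a static type checker, not at runtime.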
Examples
Using a built-in PyTorch loss:

>>> import torch.nn as nn
>>> from dataeval.protocols import LossFn
>>> loss_fn = nn.MSELoss()
>>> isinstance(loss_fn, LossFn)
True

Creating a custom functional loss:

>>> import torch
>>> def custom_loss(y_true, y_pred):
...     return torch.mean((y_true - y_pred) ** 2)
>>> isinstance(custom_loss, LossFn)
True

Creating a custom class-based loss:

>>> class CustomLoss:
...     def __call__(self, y_true, y_pred):
...         return torch.mean((y_true - y_pred) ** 2)
>>> loss_fn = CustomLoss()
>>> isinstance(loss_fn, LossFn)
True
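Code that consumes a loss does not need to know which of these forms it received; it simply calls it. The following is a minimal sketch of a hypothetical consumer, `average_loss`, that averages a loss over batches of `(y_true, y_pred)` pairs. To keep it runnable without PyTorch or DataEval, it uses plain Python lists instead of tensors and a `Callable` alias (`LossLike`, also hypothetical) in place of the `LossFn` annotation real code would use.

```python
from collections.abc import Callable, Iterable

# Hypothetical alias: real code would annotate with dataeval.protocols.LossFn;
# a Callable alias keeps this sketch dependency-free.
LossLike = Callable[[list[float], list[float]], float]


def average_loss(loss_fn: LossLike, batches: Iterable[tuple[list[float], list[float]]]) -> float:
    """Average loss_fn over (y_true, y_pred) batches; works for any callable loss."""
    losses = [loss_fn(y_true, y_pred) for y_true, y_pred in batches]
    return sum(losses) / len(losses)


def l1(y_true: list[float], y_pred: list[float]) -> float:
    # Functional loss: mean absolute error.
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


class L1:
    # Class-based loss wrapping the same computation.
    def __call__(self, y_true: list[float], y_pred: list[float]) -> float:
        return l1(y_true, y_pred)


batches = [([1.0, 2.0], [1.0, 3.0]), ([0.0, 0.0], [2.0, 2.0])]
print(average_loss(l1, batches))    # 1.25
print(average_loss(L1(), batches))  # 1.25
```

Both calls produce the same result because `average_loss` relies only on the call interface, which is exactly what lets functional and class-based losses be used interchangeably.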