dataeval.protocols.LossFn

class dataeval.protocols.LossFn

Protocol for generic loss functions that can be used with PyTorch models.

This is the base protocol for all loss functions. It supports both class-based (torch.nn.Module-like) and functional loss implementations.
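The `isinstance` checks in the examples below imply that the protocol is runtime-checkable, because calling `isinstance` on a plain `typing.Protocol` raises `TypeError`. The following is a minimal sketch of that mechanism. It uses a stand-in protocol, `SketchLossFn` (hypothetical, and assumed to declare only a `__call__` member, which may differ from the real `LossFn` definition), and computes losses on plain Python lists instead of tensors so it runs without PyTorch.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SketchLossFn(Protocol):
    """Hypothetical stand-in for LossFn: any object with a __call__ method."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


def mse(y_true: list[float], y_pred: list[float]) -> float:
    """Functional loss: mean squared error over two equal-length lists."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


class MSE:
    """Class-based loss exposing the same computation through __call__."""

    def __call__(self, y_true: list[float], y_pred: list[float]) -> float:
        return mse(y_true, y_pred)


# Both the plain function and the class instance satisfy the protocol.
print(isinstance(mse, SketchLossFn))    # True
print(isinstance(MSE(), SketchLossFn))  # True
# An object without __call__ does not.
print(isinstance(3.0, SketchLossFn))    # False
# The runtime check only looks for __call__; it does not inspect the
# signature, so an unrelated callable such as len also passes.
print(isinstance(len, SketchLossFn))    # True
```

The last line shows the main limitation of a runtime protocol check: it confirms that the object is callable, not that it accepts the expected arguments or returns a loss value.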

Examples

Using a built-in PyTorch loss:

>>> import torch
>>> import torch.nn as nn
>>> from dataeval.protocols import LossFn
>>> loss_fn = nn.MSELoss()
>>> isinstance(loss_fn, LossFn)
True

Creating a custom functional loss:

>>> def custom_loss(y_true, y_pred):
...     return torch.mean((y_true - y_pred) ** 2)
>>> isinstance(custom_loss, LossFn)
True

Creating a custom class-based loss:

>>> class CustomLoss:
...     def __call__(self, y_true, y_pred):
...         return torch.mean((y_true - y_pred) ** 2)
>>> loss_fn = CustomLoss()
>>> isinstance(loss_fn, LossFn)
True
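Because conformance is structural, code that consumes a loss can annotate its parameter with the protocol and accept any of the forms above interchangeably. The helper below, `evaluate`, is hypothetical and not part of dataeval; it again uses a stand-in protocol and plain lists so it runs on its own. With dataeval installed, one would presumably annotate with `LossFn` imported from `dataeval.protocols` instead.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SketchLossFn(Protocol):
    """Hypothetical stand-in for dataeval.protocols.LossFn."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


def evaluate(loss_fn: SketchLossFn, y_true: list[float], y_pred: list[float]) -> float:
    """Hypothetical helper: apply any conforming loss and return a float."""
    # Fail early with a clear message if the object is not callable at all.
    if not isinstance(loss_fn, SketchLossFn):
        raise TypeError(f"{type(loss_fn).__name__} does not satisfy the loss protocol")
    return float(loss_fn(y_true, y_pred))


def mae(y_true: list[float], y_pred: list[float]) -> float:
    """Functional loss: mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


class ScaledMAE:
    """Class-based loss: MAE multiplied by a fixed weight."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, y_true: list[float], y_pred: list[float]) -> float:
        return self.weight * mae(y_true, y_pred)


y_true, y_pred = [1.0, 2.0, 3.0], [1.0, 3.0, 5.0]
print(evaluate(mae, y_true, y_pred))             # 1.0
print(evaluate(ScaledMAE(0.5), y_true, y_pred))  # 0.5
```

The class-based form is convenient when a loss carries configuration, such as the hypothetical `weight` above, while the functional form suits stateless losses.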