dataeval.protocols.ReconstructionLossFn

class dataeval.protocols.ReconstructionLossFn

Protocol for reconstruction-based loss functions used to train autoencoders.

Used for standard autoencoders whose forward pass returns only the reconstruction. The loss function takes the original input and its reconstruction, in that order, and returns the loss as a tensor.
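Assuming `ReconstructionLossFn` is a `typing.Protocol` (as its module and description suggest), conformance is structural: any callable that accepts `(x, x_recon)` and returns a tensor should satisfy it, with no base class or registration. The sketch below uses a hypothetical local stand-in, `ReconstructionLossLike`, that mirrors the documented call shape; it is not dataeval's source. Note that a runtime `isinstance` check against such a protocol only verifies that `__call__` exists, not its argument or return types.

```python
from typing import Protocol, runtime_checkable

import torch
import torch.nn as nn


@runtime_checkable
class ReconstructionLossLike(Protocol):
    """Hypothetical stand-in mirroring the documented call shape (not dataeval's source)."""

    def __call__(self, x: torch.Tensor, x_recon: torch.Tensor) -> torch.Tensor: ...


def l1_loss(x: torch.Tensor, x_recon: torch.Tensor) -> torch.Tensor:
    # A plain function conforms too: it is callable with (x, x_recon) -> Tensor.
    return (x - x_recon).abs().mean()


def evaluate(loss_fn: ReconstructionLossLike) -> torch.Tensor:
    # Type checkers accept any structurally matching callable here.
    x = torch.zeros(4, 1, 28, 28)
    x_recon = torch.ones(4, 1, 28, 28)
    return loss_fn(x, x_recon)


mse_value = evaluate(nn.MSELoss())  # built-in PyTorch loss module
l1_value = evaluate(l1_loss)        # plain Python function
```

Both calls produce a loss of 1.0 here, since every element differs by exactly 1.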

Examples

Using mean squared error (MSE) as the reconstruction loss:

>>> import torch
>>> import torch.nn as nn
>>> loss_fn = nn.MSELoss()
>>> x = torch.randn(32, 1, 28, 28)
>>> x_recon = torch.randn(32, 1, 28, 28)
>>> loss = loss_fn(x, x_recon)

Creating a custom reconstruction loss (mean absolute error):

>>> class CustomReconstructionLoss:
...     def __call__(self, x: torch.Tensor, x_recon: torch.Tensor) -> torch.Tensor:
...         return torch.mean(torch.abs(x - x_recon))
>>> loss_fn = CustomReconstructionLoss()
>>> loss = loss_fn(x, x_recon)
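To show where a conforming loss plugs in, here is a minimal training-step sketch. `TinyAutoencoder` and `train_step` are hypothetical names for illustration, not dataeval APIs; the model's forward pass returns only the reconstruction, which is the case this protocol describes. With its default `reduction="mean"`, `nn.L1Loss` computes the same value as the custom mean-absolute-error loss, so it is used here as a built-in equivalent.

```python
import torch
import torch.nn as nn


class TinyAutoencoder(nn.Module):
    """Hypothetical minimal autoencoder: forward() returns only the reconstruction."""

    def __init__(self, latent_dim: int = 32) -> None:
        super().__init__()
        # Flatten 1x28x28 images and compress them to a small latent vector.
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, latent_dim), nn.ReLU())
        # Expand back and restore the (1, 28, 28) image shape.
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Unflatten(1, (1, 28, 28)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


def train_step(model: nn.Module, loss_fn, optimizer: torch.optim.Optimizer, x: torch.Tensor) -> float:
    """Hypothetical helper: one optimization step with an (x, x_recon) -> Tensor loss."""
    model.train()
    optimizer.zero_grad()
    x_recon = model(x)
    loss = loss_fn(x, x_recon)  # original input first, reconstruction second
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
x = torch.randn(32, 1, 28, 28)

# Train a fresh model for a few steps with each conforming loss.
history = {}
for name, loss_fn in [("mse", nn.MSELoss()), ("l1", nn.L1Loss())]:
    model = TinyAutoencoder()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    history[name] = [train_step(model, loss_fn, optimizer, x) for _ in range(3)]
```

Because `train_step` only relies on the `(x, x_recon) -> Tensor` call shape, swapping losses requires no change to the training code.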