dataeval.protocols.ReconstructionLossFn
class dataeval.protocols.ReconstructionLossFn
Protocol for reconstruction-based loss functions used with autoencoders.
Intended for standard autoencoders whose forward pass returns only the reconstruction. A conforming loss function takes the original input and its reconstruction and returns the loss as a tensor.
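Because this is a protocol, any callable with the right shape conforms; no subclassing is needed. The sketch below assumes the call shape matches the custom example further down (`x` and `x_recon` in, a scalar `torch.Tensor` out). `ReconLossLike`, `TinyAutoencoder`, and `reconstruction_step` are hypothetical names for illustration only and are not part of dataeval.

```python
from typing import Protocol

import torch
import torch.nn as nn


class ReconLossLike(Protocol):
    """Hypothetical stand-in for the call shape (assumed, not dataeval's definition)."""

    def __call__(self, x: torch.Tensor, x_recon: torch.Tensor) -> torch.Tensor: ...


class TinyAutoencoder(nn.Module):
    """Hypothetical standard autoencoder: forward() returns only the reconstruction."""

    def __init__(self, dim: int = 784, latent: int = 32) -> None:
        super().__init__()
        self.encoder = nn.Linear(dim, latent)
        self.decoder = nn.Linear(latent, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        flat = x.flatten(start_dim=1)  # (N, 1, 28, 28) -> (N, 784)
        recon = self.decoder(torch.relu(self.encoder(flat)))
        return recon.view_as(x)  # restore the input's shape


def reconstruction_step(model: nn.Module, loss_fn: ReconLossLike, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: the loss sees only the input and its reconstruction."""
    x_recon = model(x)
    return loss_fn(x, x_recon)


model = TinyAutoencoder()
x = torch.randn(8, 1, 28, 28)
loss = reconstruction_step(model, nn.MSELoss(), x)
loss.backward()  # gradients flow back through the decoder and encoder
```

The helper never inspects the loss beyond calling it, so `nn.MSELoss()` or any user-defined callable with the same two-argument shape can be passed in.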
Examples
Using MSE for reconstruction:
>>> import torch
>>> import torch.nn as nn
>>> loss_fn = nn.MSELoss()
>>> x = torch.randn(32, 1, 28, 28)
>>> x_recon = torch.randn(32, 1, 28, 28)
>>> loss = loss_fn(x, x_recon)

Creating a custom reconstruction loss:
>>> class CustomReconstructionLoss:
...     def __call__(self, x: torch.Tensor, x_recon: torch.Tensor) -> torch.Tensor:
...         return torch.mean(torch.abs(x - x_recon))
>>> loss_fn = CustomReconstructionLoss()
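Since the MSE loss and the custom mean-absolute-error loss share the same call shape, either can be passed wherever a reconstruction loss is expected. This small check uses tensors whose losses can be computed by hand; it redefines `CustomReconstructionLoss` from the example above so the block runs on its own, and the `losses` dictionary is purely illustrative.

```python
import torch
import torch.nn as nn


class CustomReconstructionLoss:
    """Mean absolute error (L1) reconstruction loss, as in the example above."""

    def __call__(self, x: torch.Tensor, x_recon: torch.Tensor) -> torch.Tensor:
        return torch.mean(torch.abs(x - x_recon))


x = torch.tensor([[1.0, 2.0]])
x_recon = torch.zeros(1, 2)

# Both callables take (x, x_recon) and return a scalar tensor.
losses = {
    "mse": nn.MSELoss()(x, x_recon),  # (1**2 + 2**2) / 2 = 2.5
    "mae": CustomReconstructionLoss()(x, x_recon),  # (1 + 2) / 2 = 1.5
}
```

MSE penalizes large per-pixel errors quadratically, while the L1 variant grows linearly and is less sensitive to a few badly reconstructed pixels.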