dataeval.protocols.EvidenceLowerBoundLossFn

class dataeval.protocols.EvidenceLowerBoundLossFn

Protocol for Evidence Lower Bound (ELBO) loss functions.
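Here, "protocol" means structural typing: any object with a compatible `__call__` fits, and it does not need to inherit from anything. The following is a minimal sketch of such a protocol using Python's `typing.Protocol`. It is an illustration under assumptions, not dataeval's actual source. `ELBOLossFnSketch` and `MeanSquaredELBOLoss` are hypothetical names, and this page does not say whether dataeval marks its protocol `@runtime_checkable`.

```python
from __future__ import annotations  # keeps the torch annotations below as unevaluated strings

from typing import TYPE_CHECKING, Protocol, runtime_checkable

if TYPE_CHECKING:  # torch is only needed by static type checkers in this sketch
    import torch


# Hypothetical stand-in for dataeval.protocols.EvidenceLowerBoundLossFn;
# dataeval's real definition may differ.
@runtime_checkable
class ELBOLossFnSketch(Protocol):
    def __call__(
        self,
        x: torch.Tensor,
        x_recon: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
    ) -> torch.Tensor: ...


# Hypothetical conforming class: it never subclasses the protocol,
# it only provides a compatible __call__.
class MeanSquaredELBOLoss:
    def __init__(self, beta: float = 1.0) -> None:
        self.beta = beta

    def __call__(
        self,
        x: torch.Tensor,
        x_recon: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
    ) -> torch.Tensor:
        # Uses only tensor methods, so no runtime import of torch is required here.
        recon_loss = ((x - x_recon) ** 2).mean()
        kld_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
        return recon_loss + self.beta * kld_loss


# Structural check: succeeds because MeanSquaredELBOLoss defines __call__.
print(isinstance(MeanSquaredELBOLoss(beta=0.5), ELBOLossFnSketch))  # prints True
```

An `isinstance` check against a runtime-checkable protocol only verifies that a `__call__` attribute exists. It does not inspect argument names or types, so a mismatched signature is caught by a static type checker such as mypy or pyright, not at runtime.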

ELBO loss functions take the original input (x), its reconstruction (x_recon), and the mean (mu) and log-variance (logvar) of the latent distribution. They return the ELBO loss as a torch.Tensor.
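For orientation, the quantity such a function usually computes in a variational autoencoder is the negative ELBO. The form below assumes a diagonal Gaussian encoder with mean $\mu$ and log-variance $\log\sigma^2$ (the `logvar` argument), a standard normal prior, and a $\beta$ weight on the KL term as in a β-VAE. These are assumptions: the protocol itself only fixes the call signature.

```latex
\mathcal{L}_\beta(x) = \mathcal{L}_{\mathrm{recon}}(x, \hat{x})
  + \beta \, D_{\mathrm{KL}}\!\left(\mathcal{N}\big(\mu, \operatorname{diag}(\sigma^2)\big) \,\big\|\, \mathcal{N}(0, I)\right),
\qquad
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```

Here $\hat{x}$ is `x_recon`, $d$ is the latent dimension, and $\sigma_j^2 = \exp(\mathrm{logvar}_j)$. With $\beta = 1$ and a negative log-likelihood reconstruction term, this is exactly the negative ELBO. A mean squared error reconstruction term corresponds to a Gaussian likelihood up to scale and additive constants.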

Examples

Using the built-in ELBOLoss class:

>>> import torch
>>> from dataeval.utils.losses import ELBOLoss
>>> loss_fn = ELBOLoss(beta=1.0)
>>> x = torch.randn(32, 1, 28, 28)
>>> x_recon = torch.randn(32, 1, 28, 28)
>>> mu = torch.randn(32, 128)
>>> logvar = torch.randn(32, 128)
>>> loss = loss_fn(x, x_recon, mu, logvar)

Creating a custom ELBO loss. Any callable with a matching signature satisfies the protocol; it does not need to subclass it:

>>> class CustomELBOLoss:
...     def __init__(self, beta: float = 1.0):
...         self.beta = beta
...
...     def __call__(
...         self, x: torch.Tensor, x_recon: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
...     ) -> torch.Tensor:
...         recon_loss = torch.mean((x - x_recon) ** 2)
...         kld_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
...         return recon_loss + self.beta * kld_loss
>>> custom_loss_fn = CustomELBOLoss(beta=0.5)
>>> loss = custom_loss_fn(x, x_recon, mu, logvar)
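As a sanity check on the KL term in the custom example, the closed form can be evaluated in plain Python. This is a sketch that uses only the standard library so it runs without torch, and `gaussian_kl` is a hypothetical helper. The expression is rearranged as `mu^2 + exp(logvar) - logvar - 1` times 0.5, which is algebraically the same as `-0.5 * (1 + logvar - mu^2 - exp(logvar))`. It is zero when `mu = 0` and `logvar = 0`, that is, when the encoder distribution equals the standard normal prior, and positive otherwise. The `reduce` argument contrasts summing over latent dimensions with averaging, which is what `torch.mean` does in `CustomELBOLoss`.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float], reduce: str = "sum") -> float:
    """KL(N(mu, exp(logvar)) || N(0, 1)) for one sample with a diagonal Gaussian."""
    # Per-dimension term: 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1), never negative.
    terms = [0.5 * (m**2 + math.exp(lv) - lv - 1) for m, lv in zip(mu, logvar)]
    if reduce == "sum":
        return sum(terms)
    if reduce == "mean":
        return sum(terms) / len(terms)
    raise ValueError(f"unknown reduce: {reduce!r}")


print(gaussian_kl([0.0, 0.0], [0.0, 0.0]))                 # 0.0: posterior equals prior
print(gaussian_kl([1.0, 1.0], [0.0, 0.0]))                 # 1.0 (0.5 per dimension)
print(gaussian_kl([1.0, 1.0], [0.0, 0.0], reduce="mean"))  # 0.5
```

Averaging divides the KL term by the number of latent dimensions. When a loss averages both terms, as `CustomELBOLoss` does, the balance between reconstruction and KL therefore differs from the summed form, which effectively changes β.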