dataeval.protocols.EvidenceLowerBoundLossFn
class dataeval.protocols.EvidenceLowerBoundLossFn
Protocol for Evidence Lower Bound (ELBO) loss functions.
ELBO loss functions take the original input (x), its reconstruction (x_recon), and the mean (mu) and log-variance (logvar) of the latent distribution, and return the ELBO loss as a tensor.
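Structurally, the protocol describes a callable that accepts the four tensors in the order listed above and returns a loss tensor. The sketch below approximates that shape under this assumption; the names ELBOLossFnSketch and MyLoss are hypothetical, and dataeval's actual definition may differ (for example, in whether it is runtime-checkable). It shows that any class with a matching __call__ conforms without inheriting from anything.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Protocol, runtime_checkable

if TYPE_CHECKING:  # torch is only needed for the type annotations in this sketch
    import torch


@runtime_checkable
class ELBOLossFnSketch(Protocol):
    """Hypothetical stand-in for the shape of EvidenceLowerBoundLossFn."""

    def __call__(
        self,
        x: torch.Tensor,
        x_recon: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
    ) -> torch.Tensor: ...


class MyLoss:
    """Conforms structurally: it only needs a compatible __call__, no base class."""

    def __call__(self, x, x_recon, mu, logvar):
        raise NotImplementedError  # a real loss would return recon + KL here


# runtime_checkable only verifies that a __call__ attribute exists, not its signature
print(isinstance(MyLoss(), ELBOLossFnSketch))  # True
print(isinstance(object(), ELBOLossFnSketch))  # False
```

Because runtime isinstance checks ignore argument names and types, a static type checker is the more reliable way to confirm that a custom loss matches the expected signature.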
Examples
Using the provided ELBOLoss class:
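>>> import torch
>>> from dataeval.utils.losses import ELBOLoss
>>> loss_fn = ELBOLoss(beta=1.0)
>>> x = torch.randn(32, 1, 28, 28)  # batch of 32 single-channel 28x28 inputs
>>> x_recon = torch.randn(32, 1, 28, 28)  # reconstructions, same shape as x
>>> mu = torch.randn(32, 128)  # latent means for a 128-dimensional latent space
>>> logvar = torch.randn(32, 128)  # latent log-variances
>>> loss = loss_fn(x, x_recon, mu, logvar)
Creating a custom ELBO loss: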
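>>> class CustomELBOLoss:
...     def __init__(self, beta: float = 1.0):
...         self.beta = beta  # weight on the KL term
...
...     def __call__(
...         self, x: torch.Tensor, x_recon: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
...     ) -> torch.Tensor:
...         recon_loss = torch.mean((x - x_recon) ** 2)  # mean squared reconstruction error
...         # closed-form KL divergence to a standard normal prior, averaged over elements
...         kld_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
...         return recon_loss + self.beta * kld_loss
>>> custom_loss = CustomELBOLoss(beta=0.5)(x, x_recon, mu, logvar)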
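For reference, CustomELBOLoss follows the usual β-weighted VAE objective, written out below under the common assumptions of a diagonal Gaussian posterior with σ² = exp(logvar) and a standard normal prior (the protocol itself does not state these assumptions). With β = 1 the first line is the negative ELBO; the second line is the closed-form KL term that kld_loss computes.

```latex
\mathcal{L}(x) = \mathbb{E}_{q(z \mid x)}\big[-\log p(x \mid z)\big]
  + \beta \, D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big)

D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
  = -\tfrac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```

In the example, logvar plays the role of log σ² and logvar.exp() that of σ². The example averages with torch.mean instead of summing over latent dimensions, which rescales the KL term, and it uses mean squared error as the reconstruction term, which matches a fixed-variance Gaussian likelihood up to scale and an additive constant.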