dataeval.utils.losses.ELBOLoss

class dataeval.utils.losses.ELBOLoss(beta=1.0, reduction='mean')

Class-based evidence lower bound (ELBO) loss function with configurable parameters.

This class provides a callable loss function that can be initialized with specific parameters (like beta for beta-VAE) and then used like PyTorch’s built-in loss functions (e.g., nn.MSELoss()).
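The configure-once, call-many pattern can be sketched without PyTorch. In this illustration, `weighted_elbo` and `WeightedELBO` are hypothetical stand-ins, not part of DataEval, and they operate on already-computed reconstruction and KL values. The class stores beta at construction, much as ELBOLoss stores its parameters. For comparison, `functools.partial` binds the same setting to a plain function and gives the same result.

```python
from functools import partial


def weighted_elbo(recon: float, kl: float, beta: float = 1.0) -> float:
    """Hypothetical functional form: reconstruction term plus beta-weighted KL."""
    return recon + beta * kl


class WeightedELBO:
    """Hypothetical class-based form: configuration is fixed at construction."""

    def __init__(self, beta: float = 1.0) -> None:
        self.beta = beta

    def __call__(self, recon: float, kl: float) -> float:
        # Every call reuses the stored configuration, like nn.MSELoss().
        return weighted_elbo(recon, kl, beta=self.beta)


loss_fn = WeightedELBO(beta=4.0)             # configure once
same_fn = partial(weighted_elbo, beta=4.0)   # equivalent functional binding
print(loss_fn(0.5, 0.25), same_fn(0.5, 0.25))  # 1.5 1.5
```

The class form keeps the configuration inspectable (`loss_fn.beta`) and lets a configured loss travel as a single object. The loss_fn argument in the OODReconstruction example relies on exactly that.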

Parameters:
beta : float, default 1.0

Weight for the KL divergence term. Values above 1.0 give a beta-VAE, which trades some reconstruction quality for a more regular latent space.

reduction : str, default "mean"

Reduction applied to the reconstruction loss: “mean”, “sum”, or “none”. The KL divergence term always uses mean reduction, regardless of this setting.
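The loss combines a reconstruction term with a KL term weighted by beta, roughly loss = reconstruction + beta * KL. The pure-Python sketch below makes both parameters concrete. It is not DataEval's implementation, and `kl_per_sample` and `elbo_loss_sketch` are hypothetical names. It assumes three things. First, the reconstruction term is mean squared error. Second, the KL term is the standard closed form for N(mu, exp(logvar)) against N(0, I), summed over latent dimensions and then averaged over the batch. Third, reduction="none" returns one value per sample.

```python
import math
from typing import List, Sequence, Union

# A batch is a list of flattened samples (plain floats instead of tensors).
Batch = Sequence[Sequence[float]]


def kl_per_sample(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def elbo_loss_sketch(
    x: Batch,
    x_recon: Batch,
    mu: Batch,
    logvar: Batch,
    beta: float = 1.0,
    reduction: str = "mean",
) -> Union[float, List[float]]:
    """Hypothetical beta-weighted ELBO loss; assumes an MSE reconstruction term."""
    # Squared reconstruction error for every element of every sample.
    sq = [[(a - b) ** 2 for a, b in zip(xs, rs)] for xs, rs in zip(x, x_recon)]
    # KL term: always averaged over the batch, whatever `reduction` says.
    kl = sum(kl_per_sample(m, lv) for m, lv in zip(mu, logvar)) / len(mu)
    if reduction == "mean":
        # Mean over all elements, like nn.MSELoss(reduction="mean").
        n = sum(len(row) for row in sq)
        return sum(map(sum, sq)) / n + beta * kl
    if reduction == "sum":
        return sum(map(sum, sq)) + beta * kl
    if reduction == "none":
        # Assumption: one value per sample (summed error plus the shared KL term).
        return [sum(row) + beta * kl for row in sq]
    raise ValueError(f"unknown reduction: {reduction!r}")


x = [[0.0, 1.0]]
x_recon = [[0.0, 0.0]]
mu = [[1.0]]
logvar = [[0.0]]
print(elbo_loss_sketch(x, x_recon, mu, logvar, beta=2.0))  # 0.5 + 2.0 * 0.5 = 1.5
```

Changing `reduction` rescales only the reconstruction term. The KL term stays a batch mean, so with reduction="sum" and large inputs the reconstruction term will usually dominate unless beta is raised accordingly.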

Examples

Basic usage with default beta:

>>> from dataeval.utils.losses import ELBOLoss
>>> import torch
>>>
>>> loss_fn = ELBOLoss(beta=1.0)
>>> x = torch.rand(32, 1, 28, 28)
>>> x_recon = torch.rand(32, 1, 28, 28)
>>> mu = torch.rand(32, 128)
>>> logvar = torch.rand(32, 128)
>>> loss = loss_fn(x, x_recon, mu, logvar)

Using a beta-VAE configuration (higher beta) to encourage disentanglement:

>>> loss_fn = ELBOLoss(beta=4.0)
>>> loss = loss_fn(x, x_recon, mu, logvar)

Using a custom loss with OODReconstruction:

>>> from dataeval.shift import OODReconstruction
>>> from dataeval.utils.models import VAE
>>>
>>> vae_model = VAE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(vae_model, model_type="vae")
>>> custom_loss = ELBOLoss(beta=2.0)
>>> ood.fit(x, threshold_perc=95, loss_fn=custom_loss, epochs=20)