dataeval.utils.losses.ELBOLoss
class dataeval.utils.losses.ELBOLoss(beta=1.0, reduction='mean')

Class-based ELBO loss function for flexible configuration.
This class provides a callable loss function that can be initialized with specific parameters (like beta for beta-VAE) and then used like PyTorch’s built-in loss functions (e.g., nn.MSELoss()).
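For reference, the textbook beta-VAE objective (the negative evidence lower bound, or ELBO, with a weighted KL term) is written below. This is the standard formulation, not a statement of ELBOLoss internals: the exact reconstruction term $\mathcal{L}_{\text{recon}}$ and how it is reduced are not specified on this page and should be treated as assumptions.

```latex
\mathcal{L}(x) = \mathcal{L}_{\text{recon}}(x, \hat{x})
  + \beta \, D_{\mathrm{KL}}\!\left(\mathcal{N}\big(\mu, \operatorname{diag}(\sigma^2)\big) \,\middle\|\, \mathcal{N}(0, I)\right),
\qquad
D_{\mathrm{KL}} = -\tfrac{1}{2} \sum_{j} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```

Here the `logvar` argument corresponds to $\log\sigma^2$ and the sum runs over latent dimensions $j$. With $\beta = 1$ this is the ordinary negative ELBO; $\beta > 1$ gives the beta-VAE variant.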
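- Parameters:
  - beta (float, default 1.0): Weight applied to the KL-divergence term of the ELBO. beta=1.0 gives the standard VAE objective; larger values (beta-VAE) encourage disentangled latent representations.
  - reduction (str, default 'mean'): How the per-sample losses are combined across the batch, analogous to the reduction argument of PyTorch's built-in loss functions.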
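To illustrate the configure-then-call pattern, here is a minimal, dependency-free sketch of an ELBO-style loss class. It is not dataeval's implementation: the class name `SketchELBOLoss` is hypothetical, inputs are plain lists of floats rather than tensors, the reconstruction term is assumed to be a per-sample sum of squared errors, and the accepted `reduction` values (`"mean"`, `"sum"`, `"none"`) are an assumption borrowed from PyTorch's convention. The KL term uses the closed form for a diagonal Gaussian against a standard normal, scaled by `beta` as in a beta-VAE.

```python
import math
from typing import List, Sequence, Union


class SketchELBOLoss:
    """Hypothetical sketch of a configurable ELBO loss (not dataeval's code).

    Assumes a sum-of-squared-errors reconstruction term and plain float lists.
    """

    def __init__(self, beta: float = 1.0, reduction: str = "mean") -> None:
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.beta = beta
        self.reduction = reduction

    def __call__(
        self,
        x: Sequence[Sequence[float]],
        x_recon: Sequence[Sequence[float]],
        mu: Sequence[Sequence[float]],
        logvar: Sequence[Sequence[float]],
    ) -> Union[float, List[float]]:
        per_sample = []
        for xi, ri, mi, lvi in zip(x, x_recon, mu, logvar):
            # Reconstruction term: sum of squared errors for this sample (assumed)
            recon = sum((a - b) ** 2 for a, b in zip(xi, ri))
            # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims
            kl = -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mi, lvi))
            # Negative ELBO with the KL term weighted by beta
            per_sample.append(recon + self.beta * kl)
        if self.reduction == "none":
            return per_sample
        total = sum(per_sample)
        return total / len(per_sample) if self.reduction == "mean" else total


# Configure once, then call like a loss function
loss_fn = SketchELBOLoss(beta=4.0)
x = [[0.0, 1.0], [1.0, 1.0]]
x_recon = [[0.0, 0.5], [1.0, 1.0]]
mu = [[0.0], [1.0]]
logvar = [[0.0], [0.0]]
print(loss_fn(x, x_recon, mu, logvar))  # 1.125
```

Keeping `beta` and `reduction` on the instance means the same object can be passed anywhere a loss callable is expected, which mirrors how a configured `nn.MSELoss()` instance is reused.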
Examples
Basic usage with default beta:
>>> from dataeval.utils.losses import ELBOLoss
>>> import torch
>>>
>>> loss_fn = ELBOLoss(beta=1.0)
>>> x = torch.rand(32, 1, 28, 28)  # batch of 32 single-channel 28x28 inputs
>>> x_recon = torch.rand(32, 1, 28, 28)  # reconstructions of the same shape
>>> mu = torch.rand(32, 128)  # latent means (128-dimensional latent space)
>>> logvar = torch.rand(32, 128)  # latent log-variances
>>> loss = loss_fn(x, x_recon, mu, logvar)

Using beta-VAE with a higher beta for disentanglement:

>>> loss_fn = ELBOLoss(beta=4.0)
>>> loss = loss_fn(x, x_recon, mu, logvar)

Using with OODReconstruction:

>>> from dataeval.shift import OODReconstruction
>>> from dataeval.utils.models import VAE
>>>
>>> vae_model = VAE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(vae_model, model_type="vae")
>>> custom_loss = ELBOLoss(beta=2.0)
>>> ood.fit(x, threshold_perc=95, loss_fn=custom_loss, epochs=20)
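The OODReconstruction example passes `threshold_perc=95` to `fit`. Going by the parameter name, a plausible reading is that the detection threshold is set at the 95th percentile of per-sample scores on the fitting data, and later samples scoring above it are flagged as out-of-distribution. The sketch below illustrates only that general percentile-threshold idea in plain Python; `percentile`, `fit_threshold`, and `flag_ood` are hypothetical helpers, not part of OODReconstruction, whose actual scoring and thresholding may differ.

```python
from typing import List, Sequence


def percentile(values: Sequence[float], perc: float) -> float:
    """Percentile with linear interpolation between the closest ranks."""
    if not values:
        raise ValueError("values must be non-empty")
    ordered = sorted(values)
    pos = (len(ordered) - 1) * perc / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def fit_threshold(train_scores: Sequence[float], threshold_perc: float = 95.0) -> float:
    # Hypothetical: the threshold is this percentile of the in-distribution scores
    return percentile(train_scores, threshold_perc)


def flag_ood(scores: Sequence[float], threshold: float) -> List[bool]:
    # A sample is flagged as out-of-distribution when its score exceeds the threshold
    return [s > threshold for s in scores]


# Toy per-sample reconstruction scores for the fitting data: 1.0, 2.0, ..., 100.0
train_scores = [float(i) for i in range(1, 101)]
threshold = fit_threshold(train_scores, threshold_perc=95)
print(flag_ood([50.0, 96.0, 200.0], threshold))  # [False, True, True]
```

Under this reading, a higher `beta` in the loss passed to `fit` changes how the model is trained, while the percentile only controls how strict the resulting cut-off is.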