dataeval.utils.losses.ELBOLoss
class dataeval.utils.losses.ELBOLoss(beta=1.0, reduction='mean')
Class-based ELBO (evidence lower bound) loss function for flexible configuration.
This class provides a callable loss function that is configured once at construction (for example, with beta for a beta-VAE) and then called like PyTorch's built-in loss modules, such as nn.MSELoss().
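For reference, the textbook beta-weighted negative ELBO for a VAE whose encoder outputs a diagonal Gaussian (mean mu, log-variance logvar) is written out below. This page does not say which reconstruction term ELBOLoss uses (for example squared error or binary cross-entropy), so it is left generic as \(\mathcal{L}_{\text{rec}}\); the closed-form KL term assumes a standard normal prior \(\mathcal{N}(0, I)\).

$$
\mathcal{L}(x) = \mathcal{L}_{\text{rec}}(x, \hat{x}) + \beta \, D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
$$

$$
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right), \qquad \log\sigma_j^2 = \texttt{logvar}_j
$$

Here \(d\) is the latent dimension (128 in the examples). With \(\beta = 1\) this is the ordinary VAE objective; \(\beta > 1\) up-weights the KL term, which is the beta-VAE setting.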
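Parameters:
- beta (float, default 1.0): weight on the KL-divergence term of the ELBO. Values above 1.0 give a beta-VAE objective.
- reduction (str, default 'mean'): how per-sample losses are combined into the returned value, analogous to the reduction argument of PyTorch's built-in losses.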
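A minimal NumPy sketch of a beta-weighted ELBO loss, using the same call order as the examples (x, x_recon, mu, logvar), may help show how beta and reduction interact. It is not DataEval's implementation: the squared-error reconstruction term and the "mean"/"sum"/"none" reduction modes are assumptions borrowed from PyTorch conventions, and ELBOLossSketch and kl_diag_gaussian are hypothetical names. NumPy stands in for torch so the sketch runs without PyTorch installed.

```python
import numpy as np


def kl_diag_gaussian(mu: np.ndarray, logvar: np.ndarray) -> np.ndarray:
    """Per-sample KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dimensions."""
    return -0.5 * (1.0 + logvar - mu**2 - np.exp(logvar)).sum(axis=1)


class ELBOLossSketch:
    """Hypothetical beta-weighted ELBO loss; NOT DataEval's ELBOLoss.

    Assumptions: squared-error reconstruction term and PyTorch-style reduction modes.
    """

    def __init__(self, beta: float = 1.0, reduction: str = "mean") -> None:
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.beta = beta
        self.reduction = reduction

    def __call__(self, x, x_recon, mu, logvar):
        batch = x.shape[0]
        # Reconstruction error per sample: squared differences summed over all non-batch axes.
        recon = ((x - x_recon) ** 2).reshape(batch, -1).sum(axis=1)
        # Negative ELBO per sample; beta scales only the KL term.
        per_sample = recon + self.beta * kl_diag_gaussian(mu, logvar)
        if self.reduction == "mean":
            return per_sample.mean()
        if self.reduction == "sum":
            return per_sample.sum()
        return per_sample  # reduction == "none": one value per sample


# Same shapes as the examples on this page, with NumPy arrays instead of torch tensors.
rng = np.random.default_rng(0)
x = rng.random((32, 1, 28, 28))
x_recon = rng.random((32, 1, 28, 28))
mu = rng.random((32, 128))
logvar = rng.random((32, 128))

loss_vae = ELBOLossSketch(beta=1.0)(x, x_recon, mu, logvar)
loss_beta_vae = ELBOLossSketch(beta=4.0)(x, x_recon, mu, logvar)
# Raising beta from 1 to 4 adds 3 * (batch-mean KL) to the loss; the reconstruction part is unchanged.
print(loss_vae, loss_beta_vae)
```

Because beta multiplies only the KL term, changing it shifts the balance between reconstruction fidelity and how closely the latent codes match the prior, without altering the reconstruction term itself.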
Examples
Basic usage with default beta:
>>> from dataeval.utils.losses import ELBOLoss
>>> import torch
>>>
>>> loss_fn = ELBOLoss(beta=1.0)
>>> x = torch.rand(32, 1, 28, 28)
>>> x_recon = torch.rand(32, 1, 28, 28)
>>> mu = torch.rand(32, 128)
>>> logvar = torch.rand(32, 128)
>>> loss = loss_fn(x, x_recon, mu, logvar)
Using a beta-VAE with a higher beta for disentanglement:
>>> loss_fn = ELBOLoss(beta=4.0)
>>> loss = loss_fn(x, x_recon, mu, logvar)
Using ELBOLoss with OODReconstruction:
>>> from dataeval.shift import OODReconstruction
>>> from dataeval.utils.models import VAE
>>>
>>> vae_model = VAE(input_shape=(1, 28, 28))
>>> config = OODReconstruction.Config(loss_fn=ELBOLoss(beta=2.0), epochs=20)
>>> ood = OODReconstruction(vae_model, model_type="vae", threshold_perc=95, config=config)
>>> ood.fit(x)
OODReconstruction(loss_fn=ELBOLoss(beta=2.0, reduction='mean'), optimizer=None, epochs=20, batch_size=64, threshold_perc=95, gmm_weight=0.5, gmm_score_mode='standardized', fitted=False)
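Reconstruction-based OOD detection generally scores each sample by how poorly the trained autoencoder reconstructs it, and threshold_perc=95 suggests a cutoff at the 95th percentile of scores seen on in-distribution data. The NumPy sketch below illustrates only that generic percentile-threshold idea. It is not OODReconstruction's implementation, it does not model the gmm_weight or gmm_score_mode settings in the repr above, and reconstruction_scores, fit_threshold, and is_ood are hypothetical names.

```python
import numpy as np


def reconstruction_scores(x: np.ndarray, x_recon: np.ndarray) -> np.ndarray:
    """Per-sample mean squared reconstruction error (higher means less familiar)."""
    batch = x.shape[0]
    return ((x - x_recon) ** 2).reshape(batch, -1).mean(axis=1)


def fit_threshold(train_scores: np.ndarray, threshold_perc: float = 95.0) -> float:
    """Cutoff at the given percentile of in-distribution (training) scores."""
    return float(np.percentile(train_scores, threshold_perc))


def is_ood(scores: np.ndarray, threshold: float) -> np.ndarray:
    """Flag samples whose score exceeds the fitted cutoff."""
    return np.asarray(scores) > threshold


# Toy values standing in for training reconstruction errors: 0.0, 1.0, ..., 99.0.
train_scores = np.arange(100, dtype=float)
threshold = fit_threshold(train_scores, threshold_perc=95)  # 95th percentile of the toy scores
print(threshold)
# Only the last two test scores exceed the cutoff.
print(is_ood(np.array([10.0, 94.0, 95.0, 200.0]), threshold))
```

By construction, roughly 5% of the in-distribution scores used for fitting land above a 95th-percentile cutoff, so threshold_perc trades false alarms on familiar data against sensitivity to unfamiliar data.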