dataeval.utils.models.VAE

class dataeval.utils.models.VAE(input_shape, latent_dim=None, gmm_density_net=None)

Variational Autoencoder model with separate encoder and decoder.

Parameters:
input_shape : tuple[int, ...]

Shape of the input data in CHW format.

latent_dim : int or None, default None

The size of the latent space. If None, the size is computed automatically.

gmm_density_net : GMMDensityNet or None, default None

Optional Gaussian mixture model (GMM) density network that enables GMM-based out-of-distribution (OOD) detection. If provided, the forward pass returns (reconstruction, z, gamma) instead of (reconstruction, mu, logvar). The latent_dim of the GMMDensityNet must match the latent dimension of the VAE.

Example

Sample data:

>>> import torch
>>> x = torch.randn(32, 1, 28, 28)

Standard VAE:

>>> from dataeval.utils.models import VAE
>>> vae = VAE(input_shape=(1, 28, 28))
>>> recon, mu, logvar = vae(x)

VAE with GMM for OOD detection:

>>> from dataeval.utils.models import VAE, GMMDensityNet
>>> vae_gmm = VAE(input_shape=(1, 28, 28), gmm_density_net=GMMDensityNet(latent_dim=256, n_gmm=3))
>>> reconstruction, z, gamma = vae_gmm(x)
>>> # Use with OODReconstruction(vae_gmm) - auto-detects as VAE with GMM

forward(x)

Perform a forward pass through the VAE.

Parameters:
x : torch.Tensor

Input tensor of shape (N, C, H, W), where (C, H, W) matches input_shape.

Returns:

A tuple of three tensors whose contents depend on how the model was constructed.

If gmm_density_net is None: (reconstruction, mu, logvar).

If gmm_density_net is provided: (reconstruction, z, gamma), where z is the latent representation (mu) and gamma holds the mixture assignment probabilities.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
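
The two return shapes described above can be illustrated with a small stand-in module. This is a minimal sketch and not dataeval's implementation: TinyVAE, its single linear encoder and decoder layers, the latent size of 8, and the linear density head are all hypothetical names and choices made for illustration. It assumes gamma is produced by a softmax over the mixture components, so that each row sums to 1.

```python
import math

import torch
from torch import nn


class TinyVAE(nn.Module):
    """Hypothetical stand-in that mimics the documented return contract."""

    def __init__(self, input_shape, latent_dim=8, density_net=None):
        super().__init__()
        self.input_shape = tuple(input_shape)
        flat = math.prod(self.input_shape)
        # The encoder emits mu and logvar concatenated along dim 1.
        self.encoder = nn.Linear(flat, 2 * latent_dim)
        self.decoder = nn.Linear(latent_dim, flat)
        # Optional head mapping the latent code to mixture logits (assumption).
        self.density_net = density_net

    def forward(self, x):
        mu, logvar = self.encoder(x.flatten(1)).chunk(2, dim=1)
        # Sample a latent code: mu + sigma * eps, with eps drawn from N(0, 1).
        sample = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        recon = self.decoder(sample).view(-1, *self.input_shape)
        if self.density_net is None:
            return recon, mu, logvar
        # GMM mode: z is the latent mean, gamma the soft component assignments.
        gamma = torch.softmax(self.density_net(mu), dim=1)
        return recon, mu, gamma


x = torch.randn(4, 1, 28, 28)

# Standard mode: three tensors (reconstruction, mu, logvar).
recon, mu, logvar = TinyVAE((1, 28, 28))(x)

# GMM mode: three tensors (reconstruction, z, gamma) with 3 components.
head = nn.Linear(8, 3)
recon_g, z, gamma = TinyVAE((1, 28, 28), density_net=head)(x)
```

In both modes the tuple has three tensors, so calling code that unpacks the result must know which mode the model was built in; only the first element (the reconstruction) has the same meaning in both.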

reparameterize(mu, logvar)

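Draw a sample from N(mu, var) using the reparameterization trick: z = mu + exp(0.5 * logvar) * eps, where eps is drawn from N(0, 1). Writing the sample this way keeps it differentiable with respect to mu and logvar.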

Parameters:
mu : torch.Tensor

Mean of the latent Gaussian.

logvar : torch.Tensor

Log-variance of the latent Gaussian.

Returns:

Sampled latent vector.

Return type:

torch.Tensor
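
The reparameterization trick can be sketched in a few lines of PyTorch. The function name reparameterize_sketch is hypothetical, and this is the textbook form of the trick, written under the assumption that the method follows it; the library's own method may differ in details.

```python
import torch


def reparameterize_sketch(mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) as a differentiable function of mu and logvar."""
    std = torch.exp(0.5 * logvar)  # log-variance -> standard deviation
    eps = torch.randn_like(std)    # noise drawn from N(0, 1)
    return mu + eps * std


torch.manual_seed(0)

# Mean 3 and variance 4 (standard deviation 2) for every latent entry.
mu = torch.full((10000, 2), 3.0, requires_grad=True)
logvar = torch.log(torch.full((10000, 2), 4.0))

z = reparameterize_sketch(mu, logvar)
sample_mean = z.mean().item()
sample_std = z.std().item()

# Gradients reach mu because the randomness is isolated in eps.
z.sum().backward()
```

Because the noise enters only through eps, the sample is a deterministic function of mu and logvar once eps is fixed, which is what lets backpropagation pass through the sampling step.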