dataeval.utils.models.VAE
class dataeval.utils.models.VAE(input_shape, latent_dim=None, gmm_density_net=None)

Variational autoencoder (VAE) model with separate encoder and decoder.
- Parameters:
- input_shape : tuple[int, ...]
Shape of a single input sample in CHW (channels, height, width) format, excluding the batch dimension.
- latent_dim : int or None, default None
Size of the latent space. If None, it is computed automatically.
- gmm_density_net : GMMDensityNet or None, default None
Optional GMM density network that enables GMM-based out-of-distribution (OOD) detection. If provided, the forward pass returns (reconstruction, z, gamma) instead of (reconstruction, mu, logvar). The GMMDensityNet’s latent_dim must match the VAE’s latent dimension.
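In the standard mode, mu and logvar parameterize a diagonal Gaussian over the latent space. The NumPy sketch below shows the textbook uses of these two tensors: the reparameterization trick for sampling z, and the closed-form KL divergence to a standard normal prior. It assumes logvar is the per-dimension log-variance (as the name suggests); it is an illustration, not DataEval's internal training code, and the function names are hypothetical.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    std = np.exp(0.5 * logvar)            # sigma = exp(logvar / 2)
    eps = rng.standard_normal(mu.shape)   # noise drawn independently of mu and logvar
    return mu + std * eps

def kl_to_standard_normal(mu, logvar):
    """Per-sample KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dims."""
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)

# Hypothetical latent statistics: a batch of 4 samples, 8 latent dimensions
rng = np.random.default_rng(0)
mu = np.ones((4, 8))
logvar = np.zeros((4, 8))                 # unit variance
z = reparameterize(mu, logvar, rng)       # shape (4, 8)
kl = kl_to_standard_normal(mu, logvar)    # each entry is 4.0
```

With unit variance the KL term reduces to 0.5 * ||mu||^2, so an all-ones mu in 8 dimensions gives 4.0 per sample.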
Example

Sample data:

>>> import torch
>>> from dataeval.utils.models import VAE, GMMDensityNet
>>> x = torch.randn(32, 1, 28, 28)

Standard VAE:

>>> vae = VAE(input_shape=(1, 28, 28))
>>> recon, mu, logvar = vae(x)

VAE with GMM for OOD detection:

>>> vae_gmm = VAE(input_shape=(1, 28, 28), gmm_density_net=GMMDensityNet(latent_dim=256, n_gmm=3))
>>> reconstruction, z, gamma = vae_gmm(x)
>>> # Use with OODReconstruction(vae_gmm), which auto-detects it as a VAE with GMM

- forward(x)
Perform a forward pass through the VAE.
- Parameters:
- x : torch.Tensor
Input batch of shape (N, *input_shape), for example (N, C, H, W).
- Returns:
If gmm_density_net is None: (reconstruction, mu, logvar). If gmm_density_net is provided: (reconstruction, z, gamma), where z is the latent representation (mu) and gamma holds the mixture-component assignment probabilities.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
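Given z and gamma from the GMM mode, one common way to obtain an OOD score, used by DAGMM-style models, is to estimate mixture weights, means, and covariances from the soft assignments gamma and then score each sample by its energy, the negative log-likelihood under the fitted mixture. The NumPy sketch below illustrates that technique, assuming gamma has shape (N, K) with rows summing to 1 and z has shape (N, D). It is not a description of how DataEval or OODReconstruction computes scores, and the helper names are hypothetical.

```python
import numpy as np

def gmm_params(z, gamma):
    """Estimate GMM weights, means and covariances from z (N, D) and soft assignments gamma (N, K)."""
    nk = gamma.sum(axis=0)                          # effective sample count per component, (K,)
    phi = nk / z.shape[0]                           # mixture weights, (K,)
    means = (gamma.T @ z) / nk[:, None]             # gamma-weighted means, (K, D)
    diff = z[None, :, :] - means[:, None, :]        # (K, N, D)
    cov = np.einsum("nk,knd,kne->kde", gamma, diff, diff) / nk[:, None, None]  # (K, D, D)
    return phi, means, cov

def energy(z, phi, means, cov, eps=1e-6):
    """Per-sample energy -log(sum_k phi_k * N(z | mean_k, cov_k)); higher means less likely."""
    d = means.shape[1]
    cov = cov + eps * np.eye(d)                     # small ridge keeps covariances invertible
    inv = np.linalg.inv(cov)                        # batched inverse, (K, D, D)
    _, logdet = np.linalg.slogdet(cov)              # batched log-determinant, (K,)
    diff = z[:, None, :] - means[None, :, :]        # (N, K, D)
    maha = np.einsum("nkd,kde,nke->nk", diff, inv, diff)  # squared Mahalanobis distances
    log_comp = np.log(phi) - 0.5 * (maha + logdet + d * np.log(2 * np.pi))  # (N, K)
    m = log_comp.max(axis=1, keepdims=True)         # log-sum-exp for numerical stability
    return -(m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=1)))

# Hypothetical stand-ins for the z and gamma returned by a GMM-enabled forward pass
rng = np.random.default_rng(0)
z = rng.standard_normal((200, 2))
logits = rng.standard_normal((200, 3))
gamma = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax, rows sum to 1

phi, means, cov = gmm_params(z, gamma)
scores = energy(z, phi, means, cov)                        # in-distribution energies, (200,)
far = energy(np.array([[10.0, 10.0]]), phi, means, cov)    # a point far from the data
```

The point at (10, 10) receives a much higher energy than any of the 200 in-distribution samples, which is the property a GMM-based OOD detector relies on.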