dataeval.utils.models.AE
class dataeval.utils.models.AE(input_shape, gmm_density_net=None)
An autoencoder model with a separate encoder and decoder.
- Parameters:
- input_shape : tuple[int, ...]
Shape of the input data in CHW format.
- gmm_density_net : GMMDensityNet or None, default None
Optional GMM density network to enable GMM-based OOD detection. If provided, the forward pass will return (reconstruction, z, gamma) instead of just reconstruction. The GMMDensityNet’s latent_dim must match the encoder’s encoding dimension.
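When gmm_density_net is set, z and gamma are the inputs to a Gaussian mixture fitted in the latent space. The sketch below assumes a DAGMM-style formulation: mixture weights, means and covariances are estimated from the soft assignments gamma, and samples are scored by their negative log-likelihood ("energy"). It illustrates the idea only and is not dataeval's internal implementation; fit_gmm and gmm_energy are hypothetical helpers, and random tensors stand in for the model outputs so the block runs on its own.

```python
import torch


def fit_gmm(z: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6):
    """Hypothetical helper: estimate GMM parameters from latent codes
    z (N, D) and soft assignments gamma (N, K) whose rows sum to 1."""
    n, d = z.shape
    gamma_sum = gamma.sum(dim=0)                          # (K,)
    phi = gamma_sum / n                                   # mixture weights, (K,)
    mu = (gamma.T @ z) / gamma_sum.unsqueeze(-1)          # component means, (K, D)
    diff = z.unsqueeze(1) - mu.unsqueeze(0)               # (N, K, D)
    # Responsibility-weighted covariance per component, (K, D, D)
    cov = torch.einsum("nk,nki,nkj->kij", gamma, diff, diff) / gamma_sum.view(-1, 1, 1)
    cov = 0.5 * (cov + cov.transpose(-1, -2))             # enforce exact symmetry
    cov = cov + eps * torch.eye(d)                        # keep covariances invertible
    return phi, mu, cov


def gmm_energy(z: torch.Tensor, phi: torch.Tensor, mu: torch.Tensor, cov: torch.Tensor):
    """Hypothetical helper: per-sample energy -log sum_k phi_k N(z | mu_k, cov_k).
    Higher energy means the sample is less typical of the fitted mixture."""
    mvn = torch.distributions.MultivariateNormal(mu, covariance_matrix=cov)
    log_prob = mvn.log_prob(z.unsqueeze(1))               # broadcasts to (N, K)
    return -torch.logsumexp(torch.log(phi) + log_prob, dim=1)  # (N,)


# Usage with stand-in tensors (assumption: z is (N, D), gamma is (N, K)).
torch.manual_seed(0)
z_train = torch.randn(256, 4)                             # stand-in for latent codes
gamma_train = torch.softmax(torch.randn(256, 3), dim=1)   # stand-in for gamma
phi, mu, cov = fit_gmm(z_train, gamma_train)

# Five in-distribution codes followed by one far-away code.
z_test = torch.cat([z_train[:5], 10 * torch.ones(1, 4)])
scores = gmm_energy(z_test, phi, mu, cov)
```

With a real model, z and gamma would come from reconstruction, z, gamma = ae_gmm(x). A full covariance per component needs many more samples than latent dimensions (for example, latent_dim=256 as used in the Example), which is why the sketch adds a small eps to the diagonal.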
Example
Sample data:
>>> import torch
>>> from dataeval.utils.models import AE
>>> x = torch.randn(32, 1, 28, 28)
Standard autoencoder:
>>> ae = AE(input_shape=(1, 28, 28))
>>> reconstruction = ae(x)
Autoencoder with GMM for OOD detection:
>>> from dataeval.utils.models import GMMDensityNet
>>> gmm_density_net = GMMDensityNet(latent_dim=256, n_gmm=3)
>>> ae_gmm = AE(input_shape=(1, 28, 28), gmm_density_net=gmm_density_net)
>>> reconstruction, z, gamma = ae_gmm(x)
>>> # Use with OODReconstruction(ae_gmm, model_type="ae", use_gmm=True)
- forward(x)
Perform a forward pass through the encoder and decoder.
- Parameters:
- x : torch.Tensor
Input tensor, batched along the first dimension (e.g. shape (32, 1, 28, 28) for input_shape=(1, 28, 28)).
- Returns:
If gmm_density_net is None: the reconstructed output tensor. If gmm_density_net is provided: the tuple (reconstruction, z, gamma), where z is the latent representation and gamma holds the mixture-component assignment probabilities.
- Return type:
torch.Tensor or tuple[torch.Tensor, torch.Tensor, torch.Tensor]
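Because forward returns either a single tensor or a three-element tuple depending on how the model was built, calling code that supports both configurations has to branch on the result. This is a minimal sketch, assuming only the two return forms documented above; split_ae_output and reconstruction_error are hypothetical helpers, not part of dataeval, and stand-in tensors replace a real AE so the block runs without the library.

```python
import torch


def split_ae_output(out):
    """Hypothetical helper: normalize AE.forward output to (reconstruction, z, gamma).

    z and gamma are None when the model was built without gmm_density_net,
    in which case forward returns a single tensor.
    """
    if isinstance(out, tuple):
        reconstruction, z, gamma = out
        return reconstruction, z, gamma
    return out, None, None


def reconstruction_error(x: torch.Tensor, reconstruction: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-sample mean squared reconstruction error, shape (N,)."""
    return ((x - reconstruction) ** 2).flatten(start_dim=1).mean(dim=1)


x = torch.randn(8, 1, 28, 28)
# Stand-in for ae(x) on a model without gmm_density_net: a single tensor.
recon, z, gamma = split_ae_output(x.clone())
errors = reconstruction_error(x, recon)
```

With a real model, pass the result directly, for example split_ae_output(ae(x)) or split_ae_output(ae_gmm(x)). Per-sample mean squared error is shown only as one plausible reconstruction-based score; it is not necessarily what OODReconstruction computes.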