dataeval.utils.models.AE

class dataeval.utils.models.AE(input_shape, gmm_density_net=None)

An autoencoder model with a separate encoder and decoder.

Parameters:
input_shape : tuple[int, ...]

Shape of the input data in CHW format.

gmm_density_net : GMMDensityNet or None, default None

Optional GMM density network that enables GMM-based OOD detection. If provided, the forward pass returns (reconstruction, z, gamma) instead of the reconstruction alone. The GMMDensityNet’s latent_dim must match the encoder’s encoding dimension.
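The input_shape parameter expects channel-first (CHW) order. Many image pipelines store images channel-last (HWC) instead, for example arrays obtained with np.asarray on a PIL image or decoded by OpenCV. The sketch below assumes the image is a 3-D NumPy array in HWC order and reorders it before you choose input_shape. hwc_to_chw is a hypothetical helper and is not part of dataeval.

```python
import numpy as np


def hwc_to_chw(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reorder a single (H, W, C) image to (C, H, W)."""
    if image.ndim != 3:
        raise ValueError(f"expected a 3-D (H, W, C) array, got shape {image.shape}")
    # Move the channel axis (last) to the front; height and width keep their order.
    return np.transpose(image, (2, 0, 1))


# Example: a 28x28 RGB image stored channel-last.
hwc_image = np.zeros((28, 28, 3), dtype=np.float32)
chw_image = hwc_to_chw(hwc_image)
input_shape = chw_image.shape  # (3, 28, 28): CHW order, as input_shape expects
```

Stacking several converted images with np.stack produces the (N, C, H, W) batch layout used in the example below.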

Example

Sample data:

>>> import torch
>>> from dataeval.utils.models import AE
>>> x = torch.randn(32, 1, 28, 28)

Standard autoencoder:

>>> ae = AE(input_shape=(1, 28, 28))
>>> reconstruction = ae(x)

Autoencoder with GMM for OOD detection:

>>> from dataeval.utils.models import GMMDensityNet
>>> gmm_density_net = GMMDensityNet(latent_dim=256, n_gmm=3)
>>> ae_gmm = AE(input_shape=(1, 28, 28), gmm_density_net=gmm_density_net)
>>> reconstruction, z, gamma = ae_gmm(x)
>>> # Use with OODReconstruction(ae_gmm, model_type="ae", use_gmm=True)

forward(x)

Perform a forward pass through the encoder and decoder.

Parameters:
x : torch.Tensor

Input tensor of shape (N, C, H, W), where (C, H, W) matches input_shape.

Returns:

If gmm_density_net is None, the reconstructed output tensor. If gmm_density_net is provided, the tuple (reconstruction, z, gamma), where z is the latent representation and gamma contains the mixture component assignment probabilities.

Return type:

torch.Tensor or tuple[torch.Tensor, torch.Tensor, torch.Tensor]
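forward returns either a single tensor or a three-element tuple, depending on whether gmm_density_net was set. Code that has to handle both configurations can normalize the output first. The helper split_forward_output below is hypothetical and is not part of dataeval. It relies only on the documented return types and uses duck typing, so it accepts tensors or any other values.

```python
from typing import Any, Optional, Tuple


def split_forward_output(out: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Hypothetical helper: normalize AE.forward output to (reconstruction, z, gamma).

    Without a GMM density net, forward returns only the reconstruction, so z and
    gamma are reported as None. With one, forward returns a 3-tuple.
    """
    if isinstance(out, tuple):
        if len(out) != 3:
            raise ValueError(f"expected (reconstruction, z, gamma), got {len(out)} items")
        reconstruction, z, gamma = out
        return reconstruction, z, gamma
    # A plain reconstruction tensor (torch.Tensor is not a tuple subclass).
    return out, None, None
```

With the models from the example above, split_forward_output(ae(x)) yields (reconstruction, None, None), and split_forward_output(ae_gmm(x)) yields all three tensors.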