dataeval.utils.models.GMMDensityNet
class dataeval.utils.models.GMMDensityNet(latent_dim, n_gmm=2, hidden_dim=10)
Gaussian Mixture Model (GMM) density network that converts latent representations to mixture assignment probabilities.
This network can be appended to AE or VAE models to enable GMM-based OOD detection by producing gamma (mixture assignment probabilities) from the latent representation.
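- Parameters:
- latent_dim : int
Dimensionality of the latent representation passed to the network.
- n_gmm : int, default 2
Number of mixture components in the output.
- hidden_dim : int, default 10
Size of the hidden layer.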
Example
Creating a VAE with GMM density estimation:
>>> from dataeval.shift import OODReconstruction
>>> from dataeval.utils.models import VAE, GMMDensityNet
>>>
>>> # Use with OODReconstruction
>>> gmm_density_net = GMMDensityNet(latent_dim=256, n_gmm=3)
>>> vae_gmm_model = VAE(input_shape=(1, 28, 28), gmm_density_net=gmm_density_net)
>>> ood = OODReconstruction(vae_gmm_model, model_type="vae", use_gmm=True)
Notes
The network architecture is based on the GMM density network from Alibi-Detect, adapted from TensorFlow to PyTorch. It consists of:
- A hidden layer with tanh activation
- An output layer with softmax activation to produce valid probability distributions
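The two layers described in the notes can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the dataeval source: the class name `GMMDensityNetSketch` is hypothetical, and the use of `nn.Linear` layers is an assumption consistent with the description above.

```python
import torch
from torch import nn


class GMMDensityNetSketch(nn.Module):
    """Hypothetical stand-in for illustration; not the dataeval implementation."""

    def __init__(self, latent_dim: int, n_gmm: int = 2, hidden_dim: int = 10) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),  # hidden layer
            nn.Tanh(),                          # tanh activation
            nn.Linear(hidden_dim, n_gmm),       # output layer (one logit per component)
            nn.Softmax(dim=-1),                 # normalize each row into a distribution
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (batch_size, latent_dim) -> gamma: (batch_size, n_gmm)
        return self.net(z)


torch.manual_seed(0)
sketch = GMMDensityNetSketch(latent_dim=256, n_gmm=3)
z = torch.randn(4, 256)  # a batch of 4 latent vectors
gamma = sketch(z)        # shape (4, 3); each row is non-negative and sums to 1
```

The softmax on the last dimension is what guarantees that every row of `gamma` is a valid probability distribution over the mixture components.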
- forward(z)
Convert latent representation to mixture assignment probabilities.
- Parameters:
- z : torch.Tensor
Latent representation with shape (batch_size, latent_dim).
- Returns:
Mixture assignment probabilities (gamma) with shape (batch_size, n_gmm). Each row sums to 1.0 and represents the probability distribution over mixture components for that sample.
- Return type:
torch.Tensor
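To make the shape and meaning of gamma concrete, the sketch below works with a hand-written gamma tensor rather than a real model output. It shows how soft assignments could be turned into per-component statistics in the style of DAGMM-type estimators. The variable names `phi` and `mu` and the toy values are assumptions for illustration; this is not necessarily what `OODReconstruction` computes internally.

```python
import torch

# Hand-written gamma for 4 samples and 2 components (each row sums to 1).
gamma = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.0, 1.0]])
# Toy latent vectors with latent_dim = 2, one row per sample.
z = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 4.0], [6.0, 4.0]])

# Most likely component for each sample.
hard_assignment = gamma.argmax(dim=-1)

# Mixture weights: average responsibility of each component over the batch.
phi = gamma.mean(dim=0)

# Component means: responsibility-weighted average of the latent vectors,
# giving a tensor of shape (n_gmm, latent_dim).
mu = (gamma.T @ z) / gamma.sum(dim=0).unsqueeze(-1)
```

Because every row of gamma sums to 1, the entries of `phi` also sum to 1.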