dataeval.shift.OODReconstruction
class dataeval.shift.OODReconstruction(model, device=None, model_type='auto', use_gmm=None, config=None)
Autoencoder (AE) or Variational Autoencoder (VAE) based out-of-distribution detector.
Supports standard autoencoders and variational autoencoders, with an optional Gaussian Mixture Model (GMM) in the latent space for enhanced detection. The model type can be auto-detected from the model structure or specified explicitly.
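The core idea behind reconstruction-based detection can be sketched in plain PyTorch: an autoencoder trained on in-distribution data reconstructs similar inputs well, so a large per-instance reconstruction error suggests an out-of-distribution sample. This is a minimal sketch assuming a model whose forward pass returns only the reconstruction; `reconstruction_scores` and `toy_ae` are hypothetical names, and this is not DataEval's internal scoring code, which also has to handle VAE and GMM outputs.

```python
import torch
from torch import nn


def reconstruction_scores(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Per-instance mean squared reconstruction error (higher = more anomalous).

    Hypothetical helper: assumes model(x) returns a reconstruction shaped like x.
    """
    model.eval()
    with torch.no_grad():
        x_recon = model(x)
    # Average the squared error over every non-batch dimension.
    return ((x - x_recon) ** 2).flatten(start_dim=1).mean(dim=1)


# Toy autoencoder standing in for a real AE (hypothetical, for illustration only).
toy_ae = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 32),
    nn.ReLU(),
    nn.Linear(32, 28 * 28),
    nn.Unflatten(1, (1, 28, 28)),
)
x = torch.rand(8, 1, 28, 28)
scores = reconstruction_scores(toy_ae, x)
print(scores.shape)  # torch.Size([8])
```

An untrained model like `toy_ae` gives meaningless scores; in practice the autoencoder is first trained on reference data, which is what `fit()` does.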
- Parameters:
- model : torch.nn.Module
An autoencoder or VAE model used to encode and reconstruct images when detecting out-of-distribution samples. The model type is auto-detected if not specified.
- device : DeviceLike or None, default None
The hardware device to use. If None, uses the DataEval default or, failing that, the torch default.
- model_type : {"ae", "vae", "auto"} or None, default "auto"
Type of model: "ae" for a standard autoencoder, "vae" for a variational autoencoder, or "auto" to auto-detect from the model structure. None is treated as "auto".
- use_gmm : bool or None, default None
Whether to use a Gaussian Mixture Model in the latent space for enhanced OOD detection. If None, this is auto-detected from whether the model has a gmm_density_net attribute. When True, the model must output (reconstruction, z, gamma), where z is the latent representation and gamma contains the mixture assignment probabilities.
- config : OODReconstruction.Config or None, default None
Optional configuration object with default training parameters. Parameters passed to fit() override these defaults.
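The None-handling rules for `model_type` and `use_gmm` can be sketched as two small functions. This is a minimal sketch under stated assumptions: `resolve_use_gmm`, `resolve_model_type`, `PlainAE`, and `GmmAE` are hypothetical names, and only the documented rules are mirrored (None means "auto"; GMM mode is inferred from a `gmm_density_net` attribute). How "auto" is then resolved to "ae" or "vae" is not documented here and is not sketched.

```python
from typing import Optional

from torch import nn


def resolve_use_gmm(model: nn.Module, use_gmm: Optional[bool]) -> bool:
    """Hypothetical helper mirroring the documented use_gmm=None behavior."""
    if use_gmm is None:
        # Auto-detect: GMM mode if the model exposes a gmm_density_net attribute.
        return hasattr(model, "gmm_density_net")
    # An explicit True/False always wins over auto-detection.
    return use_gmm


def resolve_model_type(model_type: Optional[str]) -> str:
    """Hypothetical helper: None is treated the same as "auto"."""
    return "auto" if model_type is None else model_type


class PlainAE(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Linear(4, 4)


class GmmAE(PlainAE):
    def __init__(self) -> None:
        super().__init__()
        # Registering this submodule is what triggers GMM auto-detection.
        self.gmm_density_net = nn.Linear(2, 3)


print(resolve_use_gmm(PlainAE(), None))  # False
print(resolve_use_gmm(GmmAE(), None))    # True
print(resolve_use_gmm(GmmAE(), False))   # False
print(resolve_model_type(None))          # auto
```

`hasattr` works here because `nn.Module.__getattr__` looks up registered submodules and raises `AttributeError` for missing ones.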
Example
Auto-detection (recommended):
>>> import torch
>>> from dataeval.utils.models import AE, VAE, GMMDensityNet
>>> from dataeval.shift import OODReconstruction
>>>
>>> train_data = torch.rand(10, 1, 28, 28)
>>>
>>> # Auto-detect AE
>>> ae = AE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(ae)  # Automatically detects as "ae", use_gmm=False
>>>
>>> # Auto-detect VAE
>>> vae = VAE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(vae)  # Automatically detects as "vae", use_gmm=False
>>>
>>> # Auto-detect GMM
>>> ae_gmm = AE(input_shape=(1, 28, 28), gmm_density_net=GMMDensityNet(latent_dim=256, n_gmm=3))
>>> ood = OODReconstruction(ae_gmm)  # Automatically detects as "ae", use_gmm=True

Using configuration:
>>> config = OODReconstruction.Config(epochs=10, batch_size=128, threshold_perc=99.0)
>>> ood = OODReconstruction(vae, config=config)
>>> ood.fit(train_data)  # Uses config defaults

Explicit specification:
>>> ood = OODReconstruction(vae, model_type="vae", use_gmm=False)
>>> ood.fit(train_data, threshold_perc=95, epochs=20)
fit(x_ref, threshold_perc=None, loss_fn=None, optimizer=None, epochs=None, batch_size=None)
Train the model and infer the threshold value.
Parameters specified here override defaults from the config object.
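The override rule (use the fit() argument when given, otherwise the config value) can be sketched with a plain dataclass. This is a minimal sketch: `TrainConfig` and `resolve` are hypothetical stand-ins, not the real `OODReconstruction.Config`. The epochs and batch_size defaults (20 and 64) match those documented for `config.epochs` and `config.batch_size`; the threshold_perc default is an assumed value for illustration only.

```python
from dataclasses import dataclass
from typing import Optional, TypeVar

T = TypeVar("T")


@dataclass
class TrainConfig:
    """Hypothetical stand-in for OODReconstruction.Config (not the real class)."""

    epochs: int = 20              # documented default for config.epochs
    batch_size: int = 64          # documented default for config.batch_size
    threshold_perc: float = 95.0  # ASSUMED value, for illustration only


def resolve(override: Optional[T], default: T) -> T:
    # A value passed to fit() wins; None falls back to the config value.
    return default if override is None else override


config = TrainConfig(epochs=10)
epochs = resolve(None, config.epochs)        # 10, taken from config
batch_size = resolve(32, config.batch_size)  # 32, the fit() override
print(epochs, batch_size)  # 10 32
```

Testing `is None` rather than truthiness keeps an explicit falsy value such as `0` from being silently replaced by the config value.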
- Parameters:
- x_ref : ArrayLike
Reference (in-distribution) training data.
- threshold_perc : float or None, default None
Percentage of reference data that is considered normal (0-100). If None, uses config.threshold_perc.
- loss_fn : ReconstructionLossFn | VAELossFn | Callable | None, default None
Loss function used for training. Can be:
  - A ReconstructionLossFn for AE (e.g., torch.nn.MSELoss())
  - A VAELossFn for VAE (e.g., ELBOLoss(beta=1.0))
  - Any callable with an appropriate signature
If None, uses config.loss_fn or, if that is also None, auto-selects based on model type:
  - For AE/GMM-AE: MSELoss()
  - For VAE: ELBOLoss()
- optimizer : torch.optim.Optimizer | None, default None
Optimizer used for training. If None, uses config.optimizer or Adam with lr=0.001.
- epochs : int or None, default None
Number of training epochs. If None, uses config.epochs (default 20).
- batch_size : int or None, default None
Batch size used for training. If None, uses config.batch_size (default 64).
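One natural reading of threshold_perc ("percentage of reference data that is normal") is that the threshold is placed at the threshold_perc-th percentile of the reference scores, so that percentage of reference samples scores at or below it. The sketch below assumes that rule; `infer_threshold` is a hypothetical helper, and DataEval's exact computation is not confirmed here.

```python
import numpy as np


def infer_threshold(ref_scores: np.ndarray, threshold_perc: float) -> float:
    """ASSUMED rule: threshold = threshold_perc-th percentile of reference scores."""
    if not 0.0 <= threshold_perc <= 100.0:
        raise ValueError("threshold_perc must be in [0, 100]")
    return float(np.percentile(ref_scores, threshold_perc))


ref_scores = np.arange(1, 101, dtype=float)  # toy reference scores 1.0 .. 100.0
threshold = infer_threshold(ref_scores, 95.0)
print(threshold)  # ~95.05 (linear interpolation between 95 and 96)
```

With this rule, raising threshold_perc moves the threshold up and makes the detector flag fewer samples as out-of-distribution.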
Examples
Using config defaults (recommended):
>>> import torch
>>> from dataeval.shift import OODReconstruction
>>> from dataeval.utils.models import AE, VAE
>>> input_shape = (1, 28, 28)
>>> train_data = torch.rand(20, *input_shape)
>>> config = OODReconstruction.Config(epochs=10, threshold_perc=95)
>>> ood = OODReconstruction(AE(input_shape), config=config)
>>> ood.fit(train_data)  # Uses config defaults

Overriding specific parameters:
>>> ood = OODReconstruction(VAE(input_shape))
>>> ood.fit(train_data, epochs=20, batch_size=10)  # Override defaults

Using custom loss:
>>> config = OODReconstruction.Config(loss_fn=ELBOLoss(beta=2.0))
>>> ood = OODReconstruction(VAE(input_shape), config=config)
>>> ood.fit(train_data)
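In a β-VAE style objective, beta conventionally scales the KL term of the ELBO, so a value such as `beta=2.0` pushes the latent posterior harder toward the prior. As a rough illustration of what such a loss computes, here is a minimal sketch assuming a Gaussian encoder that outputs a mean `mu` and log-variance `logvar` and a squared-error reconstruction term. `beta_elbo_loss` is a hypothetical function, not DataEval's `ELBOLoss`, whose reconstruction term, reduction, and call signature may differ.

```python
import torch


def beta_elbo_loss(
    x: torch.Tensor,
    x_recon: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 1.0,
) -> torch.Tensor:
    """Negative beta-weighted ELBO, averaged over the batch (hypothetical sketch)."""
    # Reconstruction term: squared error summed over each sample's features.
    recon = ((x - x_recon) ** 2).flatten(start_dim=1).sum(dim=1)
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims.
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)
    return (recon + beta * kl).mean()


x = torch.rand(4, 1, 28, 28)
mu, logvar = torch.zeros(4, 16), torch.zeros(4, 16)
# Perfect reconstruction and a standard-normal posterior give zero loss.
print(beta_elbo_loss(x, x, mu, logvar, beta=2.0).item())  # 0.0
```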
predict(x, batch_size=int(10000000000.0), ood_type='instance')
Predict whether instances are out of distribution or not.
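At the instance level, prediction amounts to comparing each instance's score against the threshold inferred by fit(). This is a minimal sketch assuming an instance is flagged when its score strictly exceeds the threshold; the choice of `>` over `>=`, the return type, and the helper name `predict_from_scores` are all assumptions, not DataEval's API.

```python
import numpy as np


def predict_from_scores(instance_scores: np.ndarray, threshold: float) -> np.ndarray:
    """Hypothetical helper: True where an instance's score exceeds the threshold."""
    return instance_scores > threshold


scores = np.array([0.01, 0.02, 0.30, 0.015])
print(predict_from_scores(scores, threshold=0.05))  # [False False  True False]
```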
score(x, batch_size=int(10000000000.0))
Compute the out of distribution scores for a given dataset.
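The batch_size argument bounds how many samples go through the model at once. The default `int(10000000000.0)` (that is, `int(1e10)`) exceeds the size of most datasets, so presumably scoring then runs in a single pass; a smaller value trades speed for lower peak memory. This sketch shows chunked scoring under the assumption that the score is a per-instance reconstruction MSE; `batched_scores` and `toy_model` are hypothetical names, not DataEval internals.

```python
import torch
from torch import nn


def batched_scores(
    model: nn.Module, x: torch.Tensor, batch_size: int = int(1e10)
) -> torch.Tensor:
    """Per-instance reconstruction MSE, evaluated batch_size samples at a time."""
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(x), batch_size):
            xb = x[start : start + batch_size]
            # Squared error averaged over all non-batch dimensions.
            err = (xb - model(xb)) ** 2
            chunks.append(err.flatten(start_dim=1).mean(dim=1))
    return torch.cat(chunks)


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(12, 12), nn.Unflatten(1, (3, 4)))
x = torch.rand(10, 3, 4)
full = batched_scores(toy_model, x)                   # one pass over all 10 samples
chunked = batched_scores(toy_model, x, batch_size=3)  # chunks of 3, 3, 3, 1
print(full.shape, torch.allclose(full, chunked, atol=1e-6))
```

Because each score depends only on its own instance, chunking changes memory use but not the result, apart from floating-point noise.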
Classes
Config: Configuration for OODReconstruction detector training and threshold computation.