dataeval.shift.OODReconstruction
class dataeval.shift.OODReconstruction(model, device=None, model_type='auto', use_gmm=None, threshold_perc=None, config=None)
Autoencoder (AE) or Variational Autoencoder (VAE) based out-of-distribution detector.
Supports standard autoencoders and variational autoencoders, with an optional Gaussian Mixture Model (GMM) in the latent space for enhanced detection. The model type can be auto-detected from the model structure or specified explicitly.
Input data must be on the unit interval [0, 1].
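Because inputs must lie on [0, 1], data in another range has to be rescaled before it reaches fit(), predict() or score(). The sketch below shows one common option, min-max scaling, in plain Python. `scale_to_unit_interval` is a hypothetical helper, not part of DataEval; in practice the same transform is usually applied to a tensor or array (for 8-bit images, dividing by 255 is equivalent). The reference range is reused for test data so both sets share one transform.

```python
from typing import Optional, Sequence


def scale_to_unit_interval(
    values: Sequence[float],
    lo: Optional[float] = None,
    hi: Optional[float] = None,
) -> list[float]:
    """Min-max scale values into [0, 1] (hypothetical helper, not a DataEval API).

    Pass the lo/hi computed on the reference data when scaling test data so
    both sets share one transform; results are clipped to [0, 1].
    """
    lo = min(values) if lo is None else lo
    hi = max(values) if hi is None else hi
    span = hi - lo
    if span == 0:
        # Constant input: map everything to 0.0 to avoid division by zero.
        return [0.0 for _ in values]
    return [min(1.0, max(0.0, (v - lo) / span)) for v in values]


# Reference pixels in 0..255 become 0.0..1.0.
ref = scale_to_unit_interval([0, 51, 255])
# Test data reuses the reference range; out-of-range values are clipped.
test = scale_to_unit_interval([-10, 300], lo=0, hi=255)
```

Reusing the reference lo/hi matters: rescaling test data with its own range would hide exactly the distribution shift the detector is meant to find.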
- Parameters:
- model : torch.nn.Module
An autoencoder or VAE model used to encode and reconstruct images for out-of-distribution detection. The model type is auto-detected if not specified.
- device : DeviceLike or None, default None
The hardware device to use. If None, uses the DataEval default or the torch default.
- model_type : {"ae", "vae", "auto"} or None, default "auto"
Type of model: "ae" for a standard autoencoder, "vae" for a variational autoencoder, or "auto" to auto-detect based on model structure. If None, defaults to "auto".
- use_gmm : bool or None, default None
Whether to use a Gaussian Mixture Model in the latent space for enhanced OOD detection. If None, this is auto-detected based on whether the model has a gmm_density_net attribute. When True, the model must output (reconstruction, z, gamma), where z is the latent representation and gamma is the mixture assignment probabilities.
- threshold_perc : float or None, default None
Percentage of reference data considered normal (0-100). If None, uses config.threshold_perc (default 95.0).
- config : OODReconstruction.Config or None, default None
Optional configuration object with default training parameters (loss_fn, optimizer, epochs, batch_size) and scoring settings (threshold_perc, gmm_weight, gmm_score_mode). fit() takes its training parameters from this object.
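The fallbacks for use_gmm and threshold_perc described above can be summarized in a short sketch. It is not DataEval's code: `resolve_options`, `HypotheticalConfig`, `PlainAE` and `GMMAE` are hypothetical stand-ins. Treating a model whose gmm_density_net is None as having no GMM is an assumption, chosen because the examples show a plain AE auto-detecting as use_gmm=False.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class HypotheticalConfig:
    """Stand-in for OODReconstruction.Config; only threshold_perc is modeled."""

    threshold_perc: float = 95.0


def resolve_options(
    model: Any,
    use_gmm: Optional[bool],
    threshold_perc: Optional[float],
    config: Optional[HypotheticalConfig],
) -> tuple[bool, float]:
    """Sketch of how None-valued constructor arguments fall back to defaults."""
    if use_gmm is None:
        # Assumption: a GMM is present when gmm_density_net exists and is not None.
        use_gmm = getattr(model, "gmm_density_net", None) is not None
    if threshold_perc is None:
        # An explicit threshold_perc wins; otherwise use the config's value.
        threshold_perc = (config or HypotheticalConfig()).threshold_perc
    return use_gmm, threshold_perc


class PlainAE:
    gmm_density_net = None  # hypothetical AE built without a GMM


class GMMAE:
    gmm_density_net = object()  # hypothetical AE built with a GMM density net


plain = resolve_options(PlainAE(), None, None, None)
with_gmm = resolve_options(GMMAE(), None, None, HypotheticalConfig(threshold_perc=99.0))
explicit = resolve_options(GMMAE(), False, 90.0, HypotheticalConfig(threshold_perc=99.0))
```

Explicit arguments always take precedence over both auto-detection and the config, which matches the "Explicit specification" example.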
Example
Auto-detection (recommended):
>>> import torch
>>> from dataeval.utils.models import AE, VAE, GMMDensityNet
>>> from dataeval.shift import OODReconstruction
>>>
>>> train_data = torch.rand(10, 1, 28, 28)
>>>
>>> # Auto-detect AE
>>> ae = AE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(ae)  # Automatically detects as "ae", use_gmm=False
>>>
>>> # Auto-detect VAE
>>> vae = VAE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(vae)  # Automatically detects as "vae", use_gmm=False
>>>
>>> # Auto-detect GMM
>>> ae_gmm = AE(input_shape=(1, 28, 28), gmm_density_net=GMMDensityNet(latent_dim=256, n_gmm=3))
>>> ood = OODReconstruction(ae_gmm)  # Automatically detects as "ae", use_gmm=True

Using configuration:
>>> config = OODReconstruction.Config(epochs=10, batch_size=128, threshold_perc=99.0)
>>> ood = OODReconstruction(vae, config=config)
>>> ood.fit(train_data)  # Uses config defaults
OODReconstruction(loss_fn=None, optimizer=None, epochs=10, batch_size=128, threshold_perc=99.0, gmm_weight=0.5, gmm_score_mode='standardized', fitted=False)

Explicit specification:
>>> config = OODReconstruction.Config(epochs=20)
>>> ood = OODReconstruction(vae, model_type="vae", use_gmm=False, threshold_perc=95, config=config)
>>> ood.fit(train_data)
OODReconstruction(loss_fn=None, optimizer=None, epochs=20, batch_size=64, threshold_perc=95, gmm_weight=0.5, gmm_score_mode='standardized', fitted=False)

fit(reference_data)
Train the model and infer the threshold value.
Training parameters (loss_fn, optimizer, epochs, batch_size) are taken from Config.
- Parameters:
- reference_data : ArrayLike
In-distribution reference data, on the unit interval [0, 1], used to train the model and infer the threshold.
- Returns:
The fitted detector (for method chaining).
- Return type:
Self
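fit() both trains the model and infers the threshold. A common way to do the second step, and a plausible reading of threshold_perc ("percentage of reference data considered normal"), is to take that percentile of the reference scores, so that roughly threshold_perc percent of reference samples score at or below it. The sketch assumes that rule with linear interpolation between ranks; DataEval's exact computation may differ, and `infer_threshold` is a hypothetical helper.

```python
import math
from typing import Sequence


def infer_threshold(reference_scores: Sequence[float], threshold_perc: float) -> float:
    """Return the threshold_perc-th percentile of the reference scores.

    Interpolates linearly between the two nearest ranks (the same rule as
    numpy.percentile's default). Hypothetical helper, not a DataEval API.
    """
    if not 0 <= threshold_perc <= 100:
        raise ValueError("threshold_perc must be in [0, 100]")
    ordered = sorted(reference_scores)
    if not ordered:
        raise ValueError("reference_scores must be non-empty")
    # Fractional position of the percentile within the sorted scores.
    pos = (len(ordered) - 1) * threshold_perc / 100
    lower = math.floor(pos)
    upper = min(lower + 1, len(ordered) - 1)
    frac = pos - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * frac


# Eleven reference scores 0.0, 0.1, ..., 1.0: the 95th percentile lies
# halfway between 0.9 and 1.0.
scores = [i / 10 for i in range(11)]
threshold = infer_threshold(scores, 95.0)
```

Raising threshold_perc (for example to 99.0, as in the configuration example) moves the threshold up, so fewer samples are flagged as OOD.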
Examples
>>> import torch
>>> from dataeval.shift import OODReconstruction
>>> from dataeval.utils.models import AE, VAE
>>> input_shape = (1, 28, 28)
>>> train_data = torch.rand(20, *input_shape)
>>> config = OODReconstruction.Config(epochs=10, threshold_perc=95)
>>> ood = OODReconstruction(AE(input_shape), config=config)
>>> ood.fit(train_data)
OODReconstruction(loss_fn=None, optimizer=None, epochs=10, batch_size=64, threshold_perc=95, gmm_weight=0.5, gmm_score_mode='standardized', fitted=False)
predict(data, batch_size=None, ood_type='instance')
Predict whether instances are out of distribution.
- Parameters:
- data : ArrayLike
Input data for OOD prediction.
- batch_size : int or None, default None
Number of instances to process per batch (only used by some detectors). When None, uses the global batch size from get_batch_size().
- ood_type : "feature" | "instance", default "instance"
Predict OOD at the "feature" or "instance" level.
- Returns:
Predictions including an is_ood boolean array and OOD scores.
- Return type:
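Conceptually, predict() turns scores into the is_ood flags by comparing them with the threshold inferred during fit(). The sketch assumes a strict "greater than" comparison at the instance level, consistent with "higher scores indicate samples more likely to be OOD"; whether DataEval uses > or >=, and how feature-level flags are derived, are not stated here. `is_ood` is a hypothetical helper.

```python
from typing import Sequence


def is_ood(scores: Sequence[float], threshold: float) -> list[bool]:
    """Flag each score as OOD when it exceeds the threshold (hypothetical helper)."""
    # Higher scores mean "more likely OOD", so only scores strictly above the
    # reference-derived threshold are flagged.
    return [s > threshold for s in scores]


flags = is_ood([0.2, 0.95, 1.3], threshold=0.95)
```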
score(data, batch_size=None)
Compute out-of-distribution scores for a given dataset.
- Parameters:
- data : ArrayLike
Input data to score.
- batch_size : int or None, default None
Number of instances to process per batch (only used by some detectors). When None, uses the global batch size from get_batch_size().
- Returns:
Instance-level (and optionally feature-level) OOD scores. Higher scores indicate samples more likely to be OOD.
- Return type:
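For a plain autoencoder, a reconstruction-based score measures how poorly the model reproduces its input: a larger error means the sample resembles the training data less. The sketch assumes squared error per feature (the feature-level score) and the mean over features (the instance-level score) for one flattened sample. DataEval's actual scoring may differ, especially for VAE models and for GMM-enabled models, which also use the latent-space mixture. `reconstruction_scores` is a hypothetical helper.

```python
from typing import Sequence


def reconstruction_scores(
    x: Sequence[float], x_recon: Sequence[float]
) -> tuple[list[float], float]:
    """Return (feature-level, instance-level) scores for one flattened sample.

    Feature score: squared reconstruction error per element.
    Instance score: mean of the feature scores. Hypothetical helper.
    """
    if not x or len(x) != len(x_recon):
        raise ValueError("x and x_recon must be non-empty and the same length")
    feature = [(a - b) ** 2 for a, b in zip(x, x_recon)]
    instance = sum(feature) / len(feature)
    return feature, instance


# The model reproduces the first two pixels exactly but misses the third.
feature, instance = reconstruction_scores([0.0, 0.5, 1.0], [0.0, 0.5, 0.0])
```

Keeping the feature-level errors around is what makes an ood_type="feature" view possible: they show which inputs the model failed to reconstruct, while the instance-level mean gives one number per sample for thresholding.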
- property device : torch.device
The device the model is on.
- property model : torch.nn.Module
The underlying autoencoder or VAE model.
- property use_gmm : bool
Whether GMM-based scoring is enabled.
Classes
Config: Configuration for OODReconstruction detector training and threshold computation.