dataeval.shift.OODReconstruction

class dataeval.shift.OODReconstruction(model, device=None, model_type='auto', use_gmm=None, threshold_perc=None, config=None)

Out-of-distribution detector based on an autoencoder (AE) or variational autoencoder (VAE).

Supports standard autoencoders and variational autoencoders, with an optional Gaussian Mixture Model (GMM) in the latent space for enhanced detection. The model type can be auto-detected from the model structure or specified explicitly.

Input data must be on the unit interval [0, 1].
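One way to meet this requirement is sketched below, assuming images arrive either as 8-bit integers or as arbitrary floats. The helper to_unit_interval is hypothetical (not a DataEval API). It divides uint8 data by 255 and otherwise applies min-max scaling, optionally with bounds taken from the reference set so that test data is mapped consistently.

```python
import numpy as np
from typing import Optional


def to_unit_interval(images: np.ndarray, lo: Optional[float] = None, hi: Optional[float] = None) -> np.ndarray:
    """Hypothetical helper (not a DataEval API): rescale data into [0, 1]."""
    if images.dtype == np.uint8:
        # 8-bit images have a fixed 0..255 range
        return images.astype(np.float32) / 255.0
    x = images.astype(np.float32)
    # Fall back to this array's own range unless reference bounds are supplied
    lo = float(x.min()) if lo is None else lo
    hi = float(x.max()) if hi is None else hi
    if hi <= lo:
        # Degenerate range: avoid division by zero
        return np.zeros_like(x)
    # Clip so values outside the reference range still land in [0, 1]
    return np.clip((x - lo) / (hi - lo), 0.0, 1.0)


rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(4, 1, 28, 28), dtype=np.uint8)
scaled = to_unit_interval(raw)
```

Clipping keeps values inside [0, 1] but also hides how far a test value strays outside the reference range, so whether to clip is a judgment call.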

Parameters:
model : torch.nn.Module

An autoencoder or VAE that encodes and reconstructs the input images; how well a sample is reconstructed is used to detect out-of-distribution samples. The model type is auto-detected unless model_type is given.

device : DeviceLike or None, default None

The hardware device to use. If None, the DataEval default or the torch default device is used.

model_type : {"ae", "vae", "auto"} or None, default "auto"

Type of model: "ae" for a standard autoencoder, "vae" for a variational autoencoder, or "auto" to auto-detect from the model structure. If None, defaults to "auto".

use_gmm : bool or None, default None

Whether to use a Gaussian Mixture Model in the latent space for enhanced OOD detection. If None, this is auto-detected based on whether the model has a gmm_density_net attribute. When True, the model must output (reconstruction, z, gamma), where z is the latent representation and gamma is the mixture assignment probabilities.

threshold_perc : float or None, default None

Percentage (0-100) of the reference data to treat as in-distribution when setting the threshold. If None, uses config.threshold_perc (default 95.0).

config : OODReconstruction.Config or None, default None

Optional configuration object holding default training parameters (loss_fn, optimizer, epochs, batch_size) and threshold settings. A threshold_perc passed to the constructor takes precedence over config.threshold_perc.
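As a sketch of how a percentile threshold typically works (an assumption about this detector's internals, not confirmed by this page): the threshold is the threshold_perc-th percentile of the scores on the reference data, and a new sample is flagged when its score exceeds it. The helpers fit_threshold and flag_ood are hypothetical.

```python
import numpy as np


def fit_threshold(reference_scores: np.ndarray, threshold_perc: float = 95.0) -> float:
    """Hypothetical helper: score below which threshold_perc % of reference scores fall."""
    if not 0.0 <= threshold_perc <= 100.0:
        raise ValueError("threshold_perc must be between 0 and 100")
    return float(np.percentile(reference_scores, threshold_perc))


def flag_ood(scores: np.ndarray, threshold: float) -> np.ndarray:
    """Hypothetical helper: True where a score is strictly above the threshold."""
    return np.asarray(scores) > threshold


reference_scores = np.arange(100, dtype=float)     # stand-in scores 0, 1, ..., 99
threshold = fit_threshold(reference_scores, 95.0)  # 94.05 with linear interpolation
is_ood = flag_ood(np.array([10.0, 94.0, 96.0, 250.0]), threshold)
```

With threshold_perc=95, about 5% of the reference data itself ends up above the threshold, which is the expected false-positive rate on in-distribution data.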

Example

Auto-detection (recommended):

>>> import torch
>>> from dataeval.utils.models import AE, VAE, GMMDensityNet
>>> from dataeval.shift import OODReconstruction
>>>
>>> train_data = torch.rand(10, 1, 28, 28)
>>>
>>> # Auto-detect AE
>>> ae = AE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(ae)  # Automatically detects as "ae", use_gmm=False
>>>
>>> # Auto-detect VAE
>>> vae = VAE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(vae)  # Automatically detects as "vae", use_gmm=False
>>>
>>> # Auto-detect GMM
>>> ae_gmm = AE(input_shape=(1, 28, 28), gmm_density_net=GMMDensityNet(latent_dim=256, n_gmm=3))
>>> ood = OODReconstruction(ae_gmm)  # Automatically detects as "ae", use_gmm=True

Using configuration:

>>> config = OODReconstruction.Config(epochs=10, batch_size=128, threshold_perc=99.0)
>>> ood = OODReconstruction(vae, config=config)
>>> ood.fit(train_data)  # Uses config defaults
OODReconstruction(loss_fn=None, optimizer=None, epochs=10, batch_size=128, threshold_perc=99.0, gmm_weight=0.5, gmm_score_mode='standardized', fitted=False)

Explicit specification:

>>> config = OODReconstruction.Config(epochs=20)
>>> ood = OODReconstruction(vae, model_type="vae", use_gmm=False, threshold_perc=95, config=config)
>>> ood.fit(train_data)
OODReconstruction(loss_fn=None, optimizer=None, epochs=20, batch_size=64, threshold_perc=95, gmm_weight=0.5, gmm_score_mode='standardized', fitted=False)

fit(reference_data)

Train the model and infer the threshold value.

Training parameters (loss_fn, optimizer, epochs, batch_size) are taken from Config.
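To illustrate the two steps of fit (train a reconstruction model, then derive a threshold), the sketch below swaps in PCA as a linear autoencoder and uses mean squared reconstruction error as the score. LinearAE and instance_scores are hypothetical stand-ins: the real detector trains a torch model using the loss_fn, optimizer, epochs and batch_size from Config, and its scoring may differ.

```python
import numpy as np


class LinearAE:
    """Hypothetical stand-in for a trained autoencoder: PCA with latent_dim components."""

    def __init__(self, latent_dim: int) -> None:
        self.latent_dim = latent_dim

    def fit(self, x: np.ndarray) -> "LinearAE":
        flat = x.reshape(len(x), -1)
        self.mean_ = flat.mean(axis=0)
        # Top right-singular vectors span the best rank-k subspace (in squared error)
        _, _, vt = np.linalg.svd(flat - self.mean_, full_matrices=False)
        self.components_ = vt[: self.latent_dim]
        return self

    def reconstruct(self, x: np.ndarray) -> np.ndarray:
        flat = x.reshape(len(x), -1) - self.mean_
        z = flat @ self.components_.T              # encode to latent_dim
        recon = z @ self.components_ + self.mean_  # decode back
        return recon.reshape(x.shape)


def instance_scores(model: LinearAE, x: np.ndarray) -> np.ndarray:
    """Assumed score: mean squared reconstruction error per sample."""
    err = (x - model.reconstruct(x)) ** 2
    return err.reshape(len(x), -1).mean(axis=1)


rng = np.random.default_rng(0)
# Synthetic reference data near a 2-D subspace of an 8-D space, values in [0, 1]
basis = rng.uniform(size=(2, 8))
latent = rng.uniform(size=(200, 2))
reference = np.clip(latent @ basis / 2 + rng.normal(0.0, 0.01, size=(200, 8)), 0.0, 1.0)

# Step 1: "train" the model on reference data
model = LinearAE(latent_dim=2).fit(reference)
# Step 2: infer the threshold from reference scores (95th percentile)
threshold = np.percentile(instance_scores(model, reference), 95.0)

# Uniform noise does not lie near the reference subspace, so it should be flagged
outliers = rng.uniform(size=(5, 8))
flags = instance_scores(model, outliers) > threshold
```

PCA is the optimal linear autoencoder under squared error, which makes it a convenient toy; nonlinear AE and VAE models can capture curved data manifolds that a linear subspace cannot.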

Parameters:
reference_data : ArrayLike

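Reference (in-distribution) data used to train the model and infer the threshold; values must lie on the unit interval [0, 1].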

Returns:

The fitted detector (for method chaining).

Return type:

Self

Examples

>>> import torch
>>> from dataeval.shift import OODReconstruction
>>> from dataeval.utils.models import AE, VAE
>>> input_shape = (1, 28, 28)
>>> train_data = torch.rand(20, *input_shape)
>>> config = OODReconstruction.Config(epochs=10, threshold_perc=95)
>>> ood = OODReconstruction(AE(input_shape), config=config)
>>> ood.fit(train_data)
OODReconstruction(loss_fn=None, optimizer=None, epochs=10, batch_size=64, threshold_perc=95, gmm_weight=0.5, gmm_score_mode='standardized', fitted=False)

predict(data, batch_size=None, ood_type='instance')

Predict whether instances are out of distribution.

Parameters:
data : ArrayLike

Input data for OOD prediction.

batch_size : int or None, default None

Number of instances to process per batch (only used by some detectors). When None, uses the global batch size from get_batch_size().

ood_type : "feature" | "instance", default "instance"

Predict OOD at the "feature" or "instance" level.
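The difference between the two ood_type levels can be sketched with squared reconstruction error, assuming feature-level scores are the element-wise errors and instance-level scores are their mean per sample. reconstruction_scores is a hypothetical helper; the detector's actual score (for example with a VAE or GMM) may be computed differently.

```python
import numpy as np


def reconstruction_scores(x: np.ndarray, x_recon: np.ndarray, ood_type: str = "instance") -> np.ndarray:
    """Hypothetical helper: squared reconstruction error at two granularities."""
    err = (x - x_recon) ** 2
    if ood_type == "feature":
        # One score per input element (e.g. per pixel), same shape as x
        return err
    if ood_type == "instance":
        # One score per sample: average the element-wise errors
        return err.reshape(len(x), -1).mean(axis=1)
    raise ValueError(f"ood_type must be 'feature' or 'instance', got {ood_type!r}")


# Two 1x2x2 "images"; the reconstruction of the second misses one pixel entirely
x = np.zeros((2, 1, 2, 2))
x_recon = x.copy()
x_recon[1, 0, 0, 0] = 1.0

feature_scores = reconstruction_scores(x, x_recon, "feature")    # shape (2, 1, 2, 2)
instance_scores = reconstruction_scores(x, x_recon, "instance")  # values 0.0 and 0.25
```

Feature-level scores localize where a sample deviates, while instance-level scores give the single number that is compared against the threshold.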

Returns:

Predictions, including an is_ood boolean array and the OOD scores.

Return type:

OODOutput

score(data, batch_size=None)

Compute out-of-distribution scores for a given dataset.

Parameters:
data : ArrayLike

Input data to score.

batch_size : int or None, default None

Number of instances to process per batch (only used by some detectors). When None, uses the global batch size from get_batch_size().

Returns:

Instance-level (and optionally feature-level) OOD scores. Higher scores indicate samples more likely to be OOD.

Return type:

OODScoreOutput

property device : torch.device

The device the model is on.

property model : torch.nn.Module

The underlying autoencoder or VAE model.

property model_type : str

The model type, either "ae" or "vae".

property use_gmm : bool

Whether GMM-based scoring is enabled.
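When GMM scoring is enabled, the model returns a latent z and mixture assignment probabilities gamma. A common way to turn these into a score is the sample energy from DAGMM (Zong et al., ICLR 2018): estimate mixture weights, means and covariances from gamma, then score each latent by its negative log-likelihood under the mixture. The sketch below assumes that formulation; gmm_params and sample_energy are hypothetical helpers, and this page does not confirm that DataEval computes its GMM score this way.

```python
import numpy as np


def gmm_params(z: np.ndarray, gamma: np.ndarray, eps: float = 1e-6):
    """Hypothetical helper: mixture weights, means, covariances from soft assignments."""
    n_k = gamma.sum(axis=0)                        # (K,) effective counts
    phi = n_k / len(z)                             # (K,) mixture weights
    mu = (gamma.T @ z) / n_k[:, None]              # (K, D) component means
    diff = z[None, :, :] - mu[:, None, :]          # (K, N, D)
    cov = np.einsum("kn,knd,kne->kde", gamma.T, diff, diff) / n_k[:, None, None]
    cov += eps * np.eye(z.shape[1])                # keep covariances invertible
    return phi, mu, cov


def sample_energy(z: np.ndarray, phi: np.ndarray, mu: np.ndarray, cov: np.ndarray) -> np.ndarray:
    """Assumed DAGMM-style energy: -log sum_k phi_k N(z | mu_k, cov_k)."""
    d = z.shape[1]
    diff = z[None, :, :] - mu[:, None, :]                 # (K, N, D)
    inv = np.linalg.inv(cov)                              # (K, D, D)
    maha = np.einsum("knd,kde,kne->kn", diff, inv, diff)  # squared Mahalanobis distances
    _, logdet = np.linalg.slogdet(cov)                    # (K,)
    log_comp = np.log(phi)[:, None] - 0.5 * (maha + logdet[:, None] + d * np.log(2 * np.pi))
    # Log-sum-exp over components for numerical stability
    m = log_comp.max(axis=0)
    return -(m + np.log(np.exp(log_comp - m).sum(axis=0)))


rng = np.random.default_rng(0)
# Stand-in latents from two well-separated clusters in 2-D
z_ref = np.concatenate([rng.normal(-3.0, 1.0, (100, 2)), rng.normal(3.0, 1.0, (100, 2))])
# Stand-in assignments; in practice a GMM density network would predict gamma
gamma = np.zeros((200, 2))
gamma[:100, 0] = 1.0
gamma[100:, 1] = 1.0

phi, mu, cov = gmm_params(z_ref, gamma)
queries = np.array([[-3.0, -3.0], [3.0, 3.0], [0.0, 0.0], [10.0, -10.0]])
energy = sample_energy(queries, phi, mu, cov)
```

Low energy means the latent sits in a dense region of the reference mixture; energy grows as z moves away from every component.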

Classes

Config

Configuration for OODReconstruction detector training and threshold computation.