dataeval.shift.OODReconstruction

class dataeval.shift.OODReconstruction(model, device=None, model_type='auto', use_gmm=None, config=None)

Autoencoder (AE) or Variational Autoencoder (VAE) based out-of-distribution detector.

Supports standard autoencoders and variational autoencoders, with an optional Gaussian Mixture Model (GMM) in the latent space for enhanced detection. The model type can be auto-detected from the model structure or explicitly specified.

Input data must be on the unit interval [0, 1].
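Data that is not already scaled can be mapped onto [0, 1] before it is passed to the detector. The following is a minimal plain-Python sketch of min-max scaling. to_unit_interval is a hypothetical helper, not part of DataEval, and in practice the same arithmetic would be applied to a NumPy array or torch tensor. Fixed bounds (for example 0 and 255 for 8-bit pixels) are assumed here to be preferable to per-dataset bounds, so that reference and test data share one scale.

```python
def to_unit_interval(values, lo=None, hi=None):
    """Min-max scale a flat sequence of numbers onto [0, 1].

    Hypothetical helper (not a DataEval API). `lo` and `hi` default to the
    data's own min and max; pass fixed bounds such as lo=0, hi=255 for 8-bit
    pixels so that reference and test data are scaled identically.
    """
    lo = min(values) if lo is None else lo
    hi = max(values) if hi is None else hi
    if hi <= lo:
        raise ValueError("hi must be greater than lo")
    scaled = [(v - lo) / (hi - lo) for v in values]
    # Values outside [lo, hi] would land outside [0, 1]; reject them explicitly.
    if any(s < 0.0 or s > 1.0 for s in scaled):
        raise ValueError("values fall outside [lo, hi]")
    return scaled


pixels = [0, 51, 255]  # toy 8-bit pixel values
unit = to_unit_interval(pixels, lo=0, hi=255)
```

Rejecting out-of-range values, rather than clipping them, is a deliberate choice in this sketch: clipping would silently hide exactly the kind of unusual input an OOD detector is meant to catch.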

Parameters:
model : torch.nn.Module

An autoencoder or VAE model used to encode and reconstruct images in order to detect out-of-distribution samples. The model type is auto-detected if not specified.

device : DeviceLike or None, default None

The hardware device to use. If None, the DataEval default device or the torch default device is used.

model_type : {"ae", "vae", "auto"} or None, default "auto"

Type of model: "ae" for a standard autoencoder, "vae" for a variational autoencoder, or "auto" to auto-detect from the model structure. If None, defaults to "auto".

use_gmm : bool or None, default None

Whether to use a Gaussian Mixture Model (GMM) in the latent space for enhanced OOD detection. If None, this is auto-detected based on whether the model has a gmm_density_net attribute. When True, the model must output (reconstruction, z, gamma), where z is the latent representation and gamma contains the mixture assignment probabilities.

config : OODReconstruction.Config or None, default None

Optional configuration object with default training parameters. Parameters specified in fit() will override these defaults.
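The documented default for use_gmm, and the output contract it implies, can be summarised in a few lines. This is a sketch under assumptions: resolve_use_gmm and check_gmm_output are hypothetical helpers, and a model built without a GMM head is assumed to either lack the gmm_density_net attribute or set it to None (both are treated as "no GMM" here).

```python
def resolve_use_gmm(model, use_gmm=None):
    """Hypothetical helper mirroring the documented use_gmm default."""
    if use_gmm is not None:
        return use_gmm  # an explicit True/False always wins
    # Assumption: a missing attribute and an attribute set to None both mean "no GMM".
    return getattr(model, "gmm_density_net", None) is not None


def check_gmm_output(output):
    """With use_gmm=True the model must return (reconstruction, z, gamma)."""
    if not (isinstance(output, tuple) and len(output) == 3):
        raise ValueError("expected a (reconstruction, z, gamma) tuple")
    reconstruction, z, gamma = output
    return reconstruction, z, gamma


class PlainModel:  # stand-in for a model without a GMM head
    gmm_density_net = None


class GmmModel:  # stand-in for a model carrying a GMM density network
    def __init__(self):
        self.gmm_density_net = object()
```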

Example

Auto-detection (recommended):

>>> import torch
>>> from dataeval.utils.models import AE, VAE, GMMDensityNet
>>> from dataeval.shift import OODReconstruction
>>>
>>> train_data = torch.rand(10, 1, 28, 28)
>>>
>>> # Auto-detect AE
>>> ae = AE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(ae)  # Automatically detects as "ae", use_gmm=False
>>>
>>> # Auto-detect VAE
>>> vae = VAE(input_shape=(1, 28, 28))
>>> ood = OODReconstruction(vae)  # Automatically detects as "vae", use_gmm=False
>>>
>>> # Auto-detect GMM
>>> ae_gmm = AE(input_shape=(1, 28, 28), gmm_density_net=GMMDensityNet(latent_dim=256, n_gmm=3))
>>> ood = OODReconstruction(ae_gmm)  # Automatically detects as "ae", use_gmm=True

Using configuration:

>>> config = OODReconstruction.Config(epochs=10, batch_size=128, threshold_perc=99.0)
>>> ood = OODReconstruction(vae, config=config)
>>> ood.fit(train_data)  # Uses config defaults

Explicit specification:

>>> ood = OODReconstruction(vae, model_type="vae", use_gmm=False)
>>> ood.fit(train_data, threshold_perc=95, epochs=20)

fit(x_ref, threshold_perc=None, loss_fn=None, optimizer=None, epochs=None, batch_size=None)

Train the model and infer the threshold value.

Parameters specified here override defaults from the config object.
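The precedence between fit() arguments and the config object can be sketched as follows. TrainConfig and resolve_settings are hypothetical stand-ins, not OODReconstruction.Config itself. Only the epochs and batch_size defaults documented for fit() below (20 and 64) are reproduced, and a fit() argument is assumed to override the config only when it is not None.

```python
from dataclasses import dataclass, fields


@dataclass
class TrainConfig:
    """Hypothetical stand-in for OODReconstruction.Config (two fields only)."""

    epochs: int = 20
    batch_size: int = 64


def resolve_settings(config, **overrides):
    """Return effective settings: non-None fit()-style arguments beat the config."""
    settings = {f.name: getattr(config, f.name) for f in fields(config)}
    for name, value in overrides.items():
        if name not in settings:
            raise TypeError(f"unknown setting: {name}")
        if value is not None:  # None means "fall back to the config value"
            settings[name] = value
    return settings


config = TrainConfig(epochs=10)
effective = resolve_settings(config, epochs=None, batch_size=10)
```

Treating None as "not specified" is what lets a single signature serve both cases: callers who pass nothing get the config values, and callers who pass a value get exactly that value.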

Parameters:
x_ref : ArrayLike

Training data.

threshold_perc : float or None, default None

Percentage of reference data that is normal (0-100). If None, uses config.threshold_perc.

loss_fn : ReconstructionLossFn | VAELossFn | Callable | None, default None

Loss function used for training. Can be:

- A ReconstructionLossFn for AE (e.g., torch.nn.MSELoss())
- A VAELossFn for VAE (e.g., ELBOLoss(beta=1.0))
- Any callable with an appropriate signature

If None, uses config.loss_fn, or auto-selects based on model type:

- For AE/GMM-AE: uses MSELoss()
- For VAE: uses ELBOLoss()

optimizer : torch.optim.Optimizer | None, default None

Optimizer used for training. If None, uses config.optimizer or Adam with lr=0.001.

epochs : int or None, default None

Number of training epochs. If None, uses config.epochs (default 20).

batch_size : int or None, default None

Batch size used for training. If None, uses config.batch_size (default 64).
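One common way to turn threshold_perc into a threshold, used here as an assumption rather than a description of DataEval's internals, is to take that percentile of the reference data's OOD scores and flag any sample scoring above it. With threshold_perc=90, roughly 10% of the reference data then lies above the threshold. The sketch uses linear interpolation between ranks (the same rule as numpy.percentile's default); percentile and infer_threshold are hypothetical helpers.

```python
import math


def percentile(values, perc):
    """Linear-interpolation percentile, the same rule as numpy.percentile's default."""
    if not values:
        raise ValueError("need at least one value")
    if not 0 <= perc <= 100:
        raise ValueError("perc must be between 0 and 100")
    ordered = sorted(values)
    pos = (len(ordered) - 1) * perc / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def infer_threshold(ref_scores, threshold_perc):
    """Assumed rule: threshold_perc percent of the reference data counts as normal."""
    return percentile(ref_scores, threshold_perc)


ref_scores = [float(i) for i in range(1, 11)]  # toy reference scores 1.0 .. 10.0
threshold = infer_threshold(ref_scores, 90)
is_ood = [score > threshold for score in ref_scores]
```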

Examples

Using config defaults (recommended):

>>> import torch
>>> from dataeval.shift import OODReconstruction
>>> from dataeval.utils.models import AE, VAE
>>> from dataeval.utils.losses import ELBOLoss
>>> input_shape = (1, 28, 28)
>>> train_data = torch.rand(20, *input_shape)
>>> config = OODReconstruction.Config(epochs=10, threshold_perc=95)
>>> ood = OODReconstruction(AE(input_shape), config=config)
>>> ood.fit(train_data)  # Uses config defaults

Overriding specific parameters:

>>> ood = OODReconstruction(VAE(input_shape))
>>> ood.fit(train_data, epochs=20, batch_size=10)  # Override defaults

Using custom loss:

>>> config = OODReconstruction.Config(loss_fn=ELBOLoss(beta=2.0))
>>> ood = OODReconstruction(VAE(input_shape), config=config)
>>> ood.fit(train_data)
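The two default losses can be sketched on flat lists of floats. This is an illustrative assumption, not DataEval's ELBOLoss: it uses a mean-squared reconstruction term plus beta times the closed-form KL divergence between a diagonal Gaussian posterior N(mu, exp(logvar)) and a standard normal prior, and DataEval may reduce or weight the terms differently. mse_loss, kl_to_standard_normal, elbo_loss and default_loss are hypothetical names.

```python
import math


def mse_loss(x, x_hat):
    """Mean squared reconstruction error over all features (the AE default)."""
    if len(x) != len(x_hat):
        raise ValueError("x and x_hat must have the same length")
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def elbo_loss(x, x_hat, mu, logvar, beta=1.0):
    """Negative ELBO sketch: reconstruction error plus beta-weighted KL term."""
    return mse_loss(x, x_hat) + beta * kl_to_standard_normal(mu, logvar)


def default_loss(model_type):
    """Mirror the documented auto-selection: ELBO for "vae", MSE otherwise (AE, GMM-AE)."""
    return elbo_loss if model_type == "vae" else mse_loss
```

Raising beta, as in the ELBOLoss(beta=2.0) example, weights the KL term more heavily relative to reconstruction, pushing the latent codes closer to the standard normal prior.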

predict(x, batch_size=int(10000000000.0), ood_type='instance')

Predict whether instances are out of distribution.

Parameters:
x : ArrayLike

Input data for OOD prediction.

batch_size : int, default 1e10

Number of instances to process per batch (only used by some detectors).

ood_type : "feature" | "instance", default "instance"

Predict OOD at the "feature" or "instance" level.
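The difference between the two ood_type levels can be illustrated with plain squared reconstruction error. As an assumption (this mirrors a common reconstruction-based scheme, not necessarily DataEval's exact scoring), each feature's score is its squared error, the instance score is the mean over features, and a sample or feature is flagged when its score exceeds a threshold inferred at the same level. reconstruction_scores and flag_ood are hypothetical helpers.

```python
def reconstruction_scores(x, x_hat):
    """Return (feature_scores, instance_scores) for a batch of flattened samples.

    x and x_hat are lists of equal-length feature lists, one list per instance.
    """
    feature_scores = [
        [(a - b) ** 2 for a, b in zip(row, row_hat)] for row, row_hat in zip(x, x_hat)
    ]
    instance_scores = [sum(row) / len(row) for row in feature_scores]
    return feature_scores, instance_scores


def flag_ood(x, x_hat, threshold, ood_type="instance"):
    """Flag instances or individual features whose score exceeds `threshold`."""
    feature_scores, instance_scores = reconstruction_scores(x, x_hat)
    if ood_type == "instance":
        return [score > threshold for score in instance_scores]
    if ood_type == "feature":
        return [[score > threshold for score in row] for row in feature_scores]
    raise ValueError('ood_type must be "instance" or "feature"')


x = [[0.0, 1.0], [0.0, 0.0]]      # two instances with two features each
x_hat = [[0.0, 0.0], [0.0, 0.0]]  # toy reconstructions
```

Feature-level flags show where in a sample the reconstruction failed, which can help localise the unusual region of an image; instance-level flags give one decision per sample.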

Returns:

Predictions, including the is_ood boolean array and the OOD scores.

Return type:

OODOutput

score(x, batch_size=int(10000000000.0))

Compute out of distribution scores for a given dataset.

Parameters:
x : ArrayLike

Input data to score.

batch_size : int, default 1e10

Number of instances to process per batch (only used by some detectors).
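A detector that processes input in chunks typically walks the first axis in slices of batch_size. The sketch below, with a hypothetical batch_slices helper, assumes that behaviour. It also shows why the default of int(1e10) amounts to a single batch for any realistic dataset, so a smaller value is worth passing when memory is limited.

```python
def batch_slices(n, batch_size):
    """Yield (start, stop) index pairs that cover range(n) in chunks of batch_size."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, n, batch_size):
        yield start, min(start + batch_size, n)


# A small batch_size splits 5 instances into three chunks...
small = list(batch_slices(5, 2))
# ...while the default int(1e10) covers any realistic dataset in one chunk.
single = list(batch_slices(5, int(1e10)))
```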

Returns:

Instance-level (and optionally feature-level) OOD scores. Higher scores indicate samples more likely to be OOD.

Return type:

OODScoreOutput

property device : torch.device

The device the model is on.

property model : torch.nn.Module

The underlying autoencoder or VAE model.

property model_type : str

The model type, either "ae" or "vae".

property use_gmm : bool

Whether GMM-based scoring is enabled.
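GMM-based scoring is often done with the sample-energy formulation from the Deep Autoencoding Gaussian Mixture Model (DAGMM). Mixture weights, means and variances are estimated from the soft assignments gamma, and a sample's energy (its negative log-likelihood under that mixture) serves as its OOD score. The sketch below assumes this formulation and simplifies it to a one-dimensional latent z; DataEval's actual scoring, which works on a multivariate latent, may differ. gmm_params and sample_energy are hypothetical helpers.

```python
import math


def gmm_params(z, gamma):
    """Estimate mixture weights, means and variances from soft assignments.

    z: N scalar latent codes; gamma: N rows of K assignment probabilities.
    Assumes every component receives some probability mass.
    """
    n, k = len(z), len(gamma[0])
    phi, mu, var = [], [], []
    for j in range(k):
        g = [row[j] for row in gamma]
        mass = sum(g)
        mean = sum(gi * zi for gi, zi in zip(g, z)) / mass
        variance = sum(gi * (zi - mean) ** 2 for gi, zi in zip(g, z)) / mass
        phi.append(mass / n)
        mu.append(mean)
        var.append(variance + 1e-6)  # small floor keeps each density finite
    return phi, mu, var


def sample_energy(z_i, phi, mu, var):
    """E(z) = -log sum_k phi_k * N(z; mu_k, var_k); higher means more likely OOD."""
    log_terms = [
        math.log(p) - (z_i - m) ** 2 / (2 * v) - 0.5 * math.log(2 * math.pi * v)
        for p, m, v in zip(phi, mu, var)
    ]
    top = max(log_terms)  # log-sum-exp trick avoids underflow for distant points
    return -(top + math.log(sum(math.exp(t - top) for t in log_terms)))


z = [-1.2, -0.8, 0.8, 1.2]  # toy 1-D latent codes
gamma = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]  # hard assignments for clarity
phi, mu, var = gmm_params(z, gamma)
```

A point far from both mixture components, such as z = 5.0, receives a much higher energy than one near a component mean, which is what makes the energy usable as an OOD score.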

Classes

Config

Configuration for OODReconstruction detector training and threshold computation.