dataeval.shift.DriftReconstruction

class dataeval.shift.DriftReconstruction(model, device=None, model_type='auto', use_gmm=None, p_val=None, config=None)

Reconstruction-based drift detector using autoencoders.

Detects drift by comparing reconstruction errors: if the model, trained on the reference data, produces higher reconstruction errors on the test data than on the reference data, the test distribution has likely shifted.
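A minimal sketch of the idea, assuming a PCA projection as a stand-in for a trained autoencoder. PCA gives the best squared-error reconstruction of any linear encoder/decoder pair with that latent size. The helper names `fit_linear_autoencoder` and `reconstruction_errors` are hypothetical and not part of DataEval. Data from a different distribution than the reference reconstructs worse:

```python
import numpy as np


def fit_linear_autoencoder(x_ref, n_components):
    """Hypothetical PCA 'autoencoder': the latent space is the top principal subspace."""
    mean = x_ref.mean(axis=0)
    # Rows of vt are the principal directions of the centered reference data.
    _, _, vt = np.linalg.svd(x_ref - mean, full_matrices=False)
    return mean, vt[:n_components]


def reconstruction_errors(x, mean, components):
    """Per-sample mean squared error after encoding and decoding."""
    z = (x - mean) @ components.T   # encode into the latent space
    x_hat = z @ components + mean   # decode back into the input space
    return ((x - x_hat) ** 2).mean(axis=1)


rng = np.random.default_rng(0)
basis = rng.normal(size=(2, 10))
# Reference data lies close to a 2-D subspace of a 10-D space.
x_ref = rng.normal(size=(500, 2)) @ basis + 0.05 * rng.normal(size=(500, 10))
# "Drifted" test data carries much more variance outside that subspace.
x_test = rng.normal(size=(200, 2)) @ basis + 0.5 * rng.normal(size=(200, 10))

mean, comps = fit_linear_autoencoder(x_ref, n_components=2)
err_ref = reconstruction_errors(x_ref, mean, comps)
err_test = reconstruction_errors(x_test, mean, comps)
```

Here `err_test.mean()` comes out far larger than `err_ref.mean()`, because the model only learned to reconstruct the subspace the reference data occupies.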

Uses a fit/predict lifecycle: construct the detector with a model and hyperparameters, call fit() with reference data (this trains the model), then call predict() with test data.

Supports two modes:

  • Non-chunked (default): Computes mean reconstruction error for the test set and uses a z-test against the reference baseline.

  • Chunked: Splits data into chunks, computes mean reconstruction error per chunk, and uses threshold bounds to flag drift.
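Both modes can be sketched on precomputed per-sample reconstruction errors. This assumes a one-sided z-test on the test-set mean for non-chunked mode, and bounds of mean ± k standard deviations of the reference chunk means for chunked mode. DataEval's actual test statistic and `Threshold` strategies are not specified here and may differ. All function names below are hypothetical.

```python
import math

import numpy as np


def z_test_drift(err_ref, err_test, p_val=0.05):
    """Non-chunked sketch: one-sided z-test on the mean test-set error."""
    mu, sigma = err_ref.mean(), err_ref.std(ddof=1)
    # Standard error of a mean of len(err_test) errors under the reference distribution.
    z = (err_test.mean() - mu) / (sigma / math.sqrt(len(err_test)))
    # Upper-tail p-value: only *higher* test error counts as drift (assumption).
    p = 0.5 * math.erfc(z / math.sqrt(2))
    return p < p_val, float(z), p


def chunk_means(err, chunk_size):
    """Mean error per full chunk; an incomplete trailing chunk is dropped."""
    n = len(err) // chunk_size * chunk_size
    return err[:n].reshape(-1, chunk_size).mean(axis=1)


def chunk_bounds(err_ref, chunk_size, k=3.0):
    """Chunked sketch: bounds are mean +/- k std of the reference chunk means."""
    means = chunk_means(err_ref, chunk_size)
    center, spread = means.mean(), means.std(ddof=1)
    return center - k * spread, center + k * spread


def chunked_drift(err_test, chunk_size, lower, upper):
    """Flag each test chunk whose mean error falls outside [lower, upper]."""
    return [not (lower <= m <= upper) for m in chunk_means(err_test, chunk_size)]


rng = np.random.default_rng(0)
err_ref = rng.gamma(2.0, 0.01, size=1000)  # reference errors, mean about 0.02
err_test = rng.gamma(2.0, 0.03, size=200)  # shifted test errors, mean about 0.06

drifted, z, p = z_test_drift(err_ref, err_test)
lower, upper = chunk_bounds(err_ref, chunk_size=50)
flags = chunked_drift(err_test, chunk_size=50, lower=lower, upper=upper)
```

The non-chunked test yields a single verdict for the whole test set. The chunked variant yields one flag per chunk, which shows *where* in an ordered stream the error rose.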

Parameters:
model : torch.nn.Module

Autoencoder (AE) or variational autoencoder (VAE) model.

device : DeviceLike or None, default None

Hardware device used to train and run the model.

model_type : {"ae", "vae", "auto"} or None, default "auto"

Model type: "ae" for a plain autoencoder, "vae" for a variational autoencoder, or "auto" to detect the type from the model.

use_gmm : bool or None, default None

Whether to use a Gaussian mixture model (GMM) in the latent space.

p_val : float or None, default None

Significance threshold for non-chunked mode. If None, 0.05 is used.

config : DriftReconstruction.Config or None, default None

Optional configuration object.

Examples

>>> from dataeval.shift import DriftReconstruction
>>> from dataeval.utils.models import AE
>>> import torch
>>> model = AE(input_shape=(1, 28, 28))
>>> ref = torch.rand(100, 1, 28, 28).numpy()
>>> detector = DriftReconstruction(model).fit(ref)
>>> test = torch.rand(50, 1, 28, 28).numpy()
>>> result = detector.predict(test)

fit(x_ref, loss_fn=None, optimizer=None, epochs=None, batch_size=None, chunker=None, chunk_size=None, chunk_count=None, chunks=None, chunk_indices=None, threshold=None)

Fit the reconstruction drift detector.

Trains the autoencoder on the reference data, then, if chunking arguments are given, sets up the chunked baseline.

Parameters:
x_ref : ArrayLike

Reference data.

loss_fn : Callable or None, default None

Loss function for training.

optimizer : torch.optim.Optimizer or None, default None

Optimizer for training.

epochs : int or None, default None

Number of training epochs.

batch_size : int or None, default None

Batch size for training.

chunker : BaseChunker or None, default None

Explicit chunker instance for chunked mode.

chunk_size : int or None, default None

Split the reference data into fixed-size chunks of this many samples.

chunk_count : int or None, default None

Split the reference data into this many equally sized chunks.

chunks : list[ArrayLike] or None, default None

Pre-split reference data for chunked mode.

chunk_indices : list[list[int]] or None, default None

Index groupings for chunking reference data.

threshold : Threshold or None, default None

Threshold strategy for chunked mode.

Return type:

Self
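The sketch below shows how `chunk_size` and `chunk_count` could translate into the index groupings accepted by `chunk_indices`. It also shows how index groupings map to pre-split `chunks`. How DataEval's own chunkers handle a remainder is not documented here. This sketch assumes two behaviors: the last chunk runs short when chunking by size, and the remainder is spread over the first chunks when chunking by count. The helper names are hypothetical.

```python
import numpy as np


def chunks_by_size(n, chunk_size):
    """Consecutive groups of chunk_size indices; the last group may be shorter."""
    return [list(range(i, min(i + chunk_size, n))) for i in range(0, n, chunk_size)]


def chunks_by_count(n, chunk_count):
    """chunk_count consecutive groups whose sizes differ by at most one."""
    base, extra = divmod(n, chunk_count)
    groups, start = [], 0
    for c in range(chunk_count):
        size = base + (1 if c < extra else 0)  # spread the remainder over the first groups
        groups.append(list(range(start, start + size)))
        start += size
    return groups


x_ref = np.arange(10, dtype=np.float32)
by_size = chunks_by_size(len(x_ref), 4)    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
by_count = chunks_by_count(len(x_ref), 3)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
# Index groupings (the chunk_indices form) map to pre-split arrays (the chunks form).
chunks = [x_ref[idx] for idx in by_count]
```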

predict(x=None, chunks=None, chunk_indices=None)

Predict whether test data has drifted from reference data.

Parameters:
x : ArrayLike or None, default None

Test data.

chunks : list[ArrayLike] or None, default None

Pre-built test data chunks.

chunk_indices : list[list[int]] or None, default None

Index groupings for chunking test data.

Returns:

In non-chunked mode, details is a DriftReconstructionStats TypedDict; in chunked mode, details is a polars.DataFrame with per-chunk results.

Return type:

DriftOutput

property is_chunked : bool

Whether the detector is operating in chunked mode.

property x_ref : numpy.typing.NDArray[numpy.float32]

Reference data for drift detection.

Returns:

Reference data array.

Return type:

NDArray[np.float32]

Raises:

RuntimeError – If called before fit().
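A minimal sketch of the state handling these properties imply, using a hypothetical `ToyDetector` with no model. In the sketch, `x_ref` raises `RuntimeError` until `fit()` has stored the reference data. The sketch also assumes that `is_chunked` is true whenever a chunking argument was passed to `fit()`; the documentation above does not confirm this.

```python
import numpy as np


class ToyDetector:
    """Hypothetical stand-in: models only the fit/predict state, not drift detection."""

    def __init__(self):
        self._x_ref = None
        self._chunked = False

    def fit(self, x_ref, chunk_size=None, chunk_count=None):
        self._x_ref = np.asarray(x_ref, dtype=np.float32)
        # Assumption: any chunking argument switches the detector into chunked mode.
        self._chunked = chunk_size is not None or chunk_count is not None
        return self  # returning self allows Detector(...).fit(ref) chaining

    @property
    def is_chunked(self) -> bool:
        return self._chunked

    @property
    def x_ref(self) -> np.ndarray:
        if self._x_ref is None:
            raise RuntimeError("x_ref is unavailable until fit() has been called.")
        return self._x_ref


unfitted = ToyDetector()
fitted = ToyDetector().fit([[0.5, 1.5], [2.5, 3.5]], chunk_size=1)
```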

Classes

Config

Configuration for the DriftReconstruction detector.