dataeval.shift.DriftMMD

class dataeval.shift.DriftMMD(p_val=None, update_strategy=None, sigma=None, n_permutations=None, permutation_batch_size=None, device=None, extractor=None, config=None)

Drift detector using Maximum Mean Discrepancy (MMD) with a permutation test.

Detects distributional differences by comparing kernel embeddings of reference and test datasets in a reproducing kernel Hilbert space (RKHS). A permutation test assesses the statistical significance of the observed MMD² statistic.
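
As a rough illustration of the statistic this detector tests, the following NumPy sketch computes the standard unbiased estimate of MMD² with a Gaussian RBF kernel. It is not DataEval's implementation: the helper names gaussian_kernel and mmd2_unbiased are hypothetical, and the kernel form exp(-||x - y||² / (2σ²)) is an assumed convention.

```python
import numpy as np


def gaussian_kernel(x: np.ndarray, y: np.ndarray, sigma: float) -> np.ndarray:
    """Kernel matrix with entries exp(-||x_i - y_j||^2 / (2 * sigma^2))."""
    # Squared distances via ||a||^2 + ||b||^2 - 2 a.b, clipped to remove round-off negatives
    sq = (x**2).sum(1)[:, None] + (y**2).sum(1)[None, :] - 2.0 * x @ y.T
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * sigma**2))


def mmd2_unbiased(x: np.ndarray, y: np.ndarray, sigma: float) -> float:
    """Unbiased MMD^2 estimate between samples x of shape (m, d) and y of shape (n, d)."""
    m, n = len(x), len(y)
    k_xx = gaussian_kernel(x, x, sigma)
    k_yy = gaussian_kernel(y, y, sigma)
    k_xy = gaussian_kernel(x, y, sigma)
    # Exclude the i == j diagonal terms from the within-sample averages
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * k_xy.mean())


rng = np.random.default_rng(0)
ref = rng.normal(size=(200, 8))
same = rng.normal(size=(200, 8))               # same distribution as ref
shifted = rng.normal(loc=1.0, size=(200, 8))   # mean shifted in every feature
print(mmd2_unbiased(ref, same, sigma=4.0))     # near 0
print(mmd2_unbiased(ref, shifted, sigma=4.0))  # clearly positive
```

Under no drift the unbiased estimate fluctuates around zero and can even be slightly negative, which is why significance is judged with a permutation test rather than from the raw value.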

Uses a fit/predict lifecycle: construct with hyperparameters, call fit() with reference data, then call predict() with test data. Use chunked() to create a chunked wrapper for time-series monitoring.

MMD is particularly effective for high-dimensional data such as images, because it can capture complex distributional differences that univariate tests might miss. The kernel-based approach can detect changes in the marginal distributions as well as changes in the dependencies between features.

Parameters:
p_val : float, default 0.05

Significance threshold for statistical tests, between 0 and 1. For FDR correction, this represents the acceptable false discovery rate. The default of 0.05 corresponds to a 95% confidence level, i.e. a 5% chance of flagging drift when no drift is present.

update_strategy : UpdateStrategy or None, default None

Strategy for updating reference data when new data arrives. When None, reference data remains fixed throughout detection.

sigma : Array or None, default None

Bandwidth parameter(s) for the Gaussian RBF kernel, controlling how sensitive the kernel is to the distance between data points. When None, the bandwidth is selected automatically using the median heuristic. Multiple values may be passed as an array to average the kernel over several scales.

n_permutations : int, default 100

Number of random permutations used in the permutation test to estimate the null distribution of MMD² under no drift. Higher values provide more accurate p-value estimates but increase computation time. Default 100 balances statistical accuracy with computational efficiency.

permutation_batch_size : int or "auto", default "auto"

Batch size for computing permutations, used to reduce memory usage. When "auto" (the default), the batch size is chosen automatically from available GPU memory on CUDA devices, while on CPU all permutations are computed at once. Set an integer to specify the batch size manually, which helps avoid GPU out-of-memory errors with large kernel matrices or many permutations.

device : DeviceLike or None, default None

Hardware device for computation. When None, DataEval’s configured device is used, falling back to PyTorch’s default.

extractor : FeatureExtractor or None, default None

Feature extractor for transforming input data before drift detection. When provided, raw data is passed through the extractor before flattening and comparison. When None, data is used as-is.

config : DriftMMD.Config or None, default None

Optional configuration object supplying default parameters. Parameters passed directly to __init__ override the corresponding config values.
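
The sigma parameter can be pictured with this NumPy sketch of one common form of the median heuristic, σ = sqrt(median(||x_i - x_j||²) / 2) over the pooled samples, together with a kernel averaged over several bandwidths. The exact heuristic DataEval applies, and whether it pools reference and test data, are assumptions here; all function names are hypothetical.

```python
import numpy as np


def pairwise_sq_dists(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between every row of x and every row of y."""
    sq = (x**2).sum(1)[:, None] + (y**2).sum(1)[None, :] - 2.0 * x @ y.T
    return np.maximum(sq, 0.0)  # clip small negatives from floating-point round-off


def median_heuristic_sigma(x: np.ndarray, y: np.ndarray) -> float:
    """Bandwidth sqrt(median / 2) of the pooled off-diagonal squared distances."""
    z = np.concatenate([x, y])
    d2 = pairwise_sq_dists(z, z)
    off_diag = d2[~np.eye(len(z), dtype=bool)]  # self-distances are always 0, skip them
    return float(np.sqrt(0.5 * np.median(off_diag)))


def multi_sigma_kernel(x: np.ndarray, y: np.ndarray, sigmas: np.ndarray) -> np.ndarray:
    """Gaussian RBF kernel averaged over several bandwidths."""
    d2 = pairwise_sq_dists(x, y)
    s = np.asarray(sigmas, dtype=float)[:, None, None]  # shape (n_sigmas, 1, 1)
    # One kernel matrix per bandwidth, stacked along axis 0, then averaged
    return np.exp(-d2[None, :, :] / (2.0 * s**2)).mean(axis=0)


rng = np.random.default_rng(0)
ref = rng.normal(size=(100, 16))
test = rng.normal(size=(50, 16))
sigma = median_heuristic_sigma(ref, test)
k = multi_sigma_kernel(ref, test, np.array([0.5 * sigma, sigma, 2.0 * sigma]))
print(k.shape)  # (100, 50)
```

Averaging over several scales hedges against a single bandwidth being too narrow (all similarities near 0) or too wide (all near 1) for the data at hand.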

Attributes:

p_val

Significance threshold for statistical tests.

Type:

float

update_strategy

Reference data update strategy.

Type:

UpdateStrategy or None

n

Number of samples in the reference dataset.

Type:

int

sigma

Gaussian RBF kernel bandwidth parameter(s).

Type:

Array or None

n_permutations

Number of permutations for statistical testing.

Type:

int

permutation_batch_size

Batch size for computing permutations, or “auto” for automatic detection.

Type:

int or “auto”

device

Hardware device used for computations.

Type:

torch.device

See also

DriftMMD.Stats

Per-prediction statistics returned in DriftOutput.details.

Example

Initialize with image embeddings

>>> import numpy as np
>>> from dataeval.shift import DriftMMD
>>> train_emb = np.ones((100, 128), dtype=np.float32)
>>> drift = DriftMMD().fit(train_emb)

Test incoming images for drift

>>> test_emb = np.zeros((20, 128), dtype=np.float32)
>>> result = drift.predict(test_emb)
>>> print(f"Drift detected: {result.drifted}")
Drift detected: True
>>> print(f"Mean MMD statistic: {result.distance:.2f}")
Mean MMD statistic: 1.26

Chunked drift detection with z-score thresholds

>>> chunked = DriftMMD().chunked(chunk_size=20)
>>> chunked.fit(train_emb)
ChunkedDrift(DriftMMD(p_val=0.05, sigma=None, n_permutations=100, permutation_batch_size='auto', device=None, update_strategy=None, extractor=None), chunker=SizeChunker(chunk_size=20, incomplete='keep'), fitted=True)
>>> result = chunked.predict(test_emb)

Using configuration:

>>> config = DriftMMD.Config(p_val=0.01, n_permutations=200)
>>> drift = DriftMMD(config=config).fit(train_emb)

chunked(chunker=None, chunk_size=None, chunk_count=None, threshold=None)

Create a chunked wrapper around this drift detector.

Returns a ChunkedDrift that splits data into chunks during fit and predict, computing per-chunk metrics and comparing against baseline thresholds.

Parameters:
chunker : BaseChunker or None, default None

Explicit chunker instance.

chunk_size : int or None, default None

Create fixed-size chunks of this many samples.

chunk_count : int or None, default None

Split into this many equal chunks.

threshold : Threshold or None, default None

Threshold strategy for determining drift bounds from baseline. When None, uses the detector’s default threshold.

Returns:

A chunked drift wrapper around this detector.

Return type:

ChunkedDrift[TDetails]
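
A minimal sketch of the chunking idea, assuming a z-score style threshold (baseline mean ± k standard deviations) computed from chunks of the reference data. To keep it short, a simple mean-shift distance stands in for the MMD statistic; the function names, the baseline construction, and the default k = 3 are hypothetical and do not describe ChunkedDrift's actual behavior.

```python
from typing import Callable

import numpy as np


def split_by_size(data: np.ndarray, chunk_size: int) -> list[np.ndarray]:
    """Fixed-size chunks; a smaller trailing chunk is kept."""
    return [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]


def zscore_bounds(baseline: np.ndarray, k: float = 3.0) -> tuple[float, float]:
    """Lower/upper bounds at k standard deviations around the baseline mean."""
    mu, sd = float(baseline.mean()), float(baseline.std(ddof=1))
    return mu - k * sd, mu + k * sd


def chunked_drift(
    reference: np.ndarray,
    test: np.ndarray,
    metric: Callable[[np.ndarray, np.ndarray], float],
    chunk_size: int,
    k: float = 3.0,
) -> list[bool]:
    """Flag each test chunk whose metric falls outside the baseline bounds."""
    # Baseline: metric of each reference chunk against the full reference set
    baseline = np.array([metric(reference, c) for c in split_by_size(reference, chunk_size)])
    low, high = zscore_bounds(baseline, k)
    return [not (low <= metric(reference, c) <= high) for c in split_by_size(test, chunk_size)]


def mean_shift(ref: np.ndarray, chunk: np.ndarray) -> float:
    """Stand-in metric: distance between the chunk mean and the reference mean."""
    return float(np.linalg.norm(chunk.mean(axis=0) - ref.mean(axis=0)))


rng = np.random.default_rng(0)
reference = rng.normal(size=(200, 4))
test = np.concatenate([rng.normal(size=(40, 4)), rng.normal(loc=2.0, size=(40, 4))])
print(chunked_drift(reference, test, mean_shift, chunk_size=20))  # last two chunks flagged
```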

fit(reference_data)

Fit the detector on reference data.

Stores reference data, initializes the kernel, and precomputes the reference kernel matrix.

Parameters:
reference_data : Array

Reference dataset used as baseline distribution for drift detection.

Return type:

Self
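
Precomputing the reference kernel matrix pays off because K_xx depends only on the reference data; each later batch needs only the cross-kernel K_xy and the test kernel K_yy, which can be combined with the cached block into the joint matrix that a permutation test works on. The class below is a hypothetical NumPy sketch of that caching, assuming a single Gaussian bandwidth; it is not DataEval's internal code.

```python
import numpy as np


def rbf(a: np.ndarray, b: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian RBF kernel matrix between the rows of a and b."""
    sq = (a**2).sum(1)[:, None] + (b**2).sum(1)[None, :] - 2.0 * a @ b.T
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * sigma**2))


class CachedReferenceKernel:
    """Hypothetical helper: fit() caches K_xx, joint_kernel() reuses it per batch."""

    def __init__(self, sigma: float) -> None:
        self.sigma = sigma

    def fit(self, reference: np.ndarray) -> "CachedReferenceKernel":
        # Flatten each sample to a vector before comparison
        self.x_ref = reference.reshape(len(reference), -1).astype(np.float64)
        # K_xx depends only on the reference set, so it is computed once here
        self.k_xx = rbf(self.x_ref, self.x_ref, self.sigma)
        return self

    def joint_kernel(self, data: np.ndarray) -> np.ndarray:
        y = data.reshape(len(data), -1).astype(np.float64)
        k_xy = rbf(self.x_ref, y, self.sigma)  # recomputed for every batch
        k_yy = rbf(y, y, self.sigma)           # recomputed for every batch
        # Kernel of the pooled sample [reference; test], built from cached and new blocks
        return np.block([[self.k_xx, k_xy], [k_xy.T, k_yy]])


rng = np.random.default_rng(0)
cache = CachedReferenceKernel(sigma=2.0).fit(rng.normal(size=(50, 4, 4)))
k = cache.joint_kernel(rng.normal(size=(10, 4, 4)))
print(k.shape)  # (60, 60)
```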

predict(data)

Predict whether a batch of data has drifted from the reference data.

Uses a permutation test to assess statistical significance.

Parameters:
data : Any

Batch of instances to predict drift on.

Returns:

Drift prediction with MMD statistics.

Return type:

DriftOutput[DriftMMD.Stats]

score(data)

Compute the p-value resulting from a permutation test using the maximum mean discrepancy.

The maximum mean discrepancy is used as a distance measure between the reference data and the data to be tested.

Parameters:
data : Array

Batch of instances to score.

Returns:

The p-value from the permutation test, the MMD² between the reference and test sets, and the MMD² threshold above which drift is flagged.

Return type:

tuple[float, float, float]
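
The three values returned by score() can be illustrated with a NumPy sketch of a permutation test on a precomputed joint kernel matrix, processing permutations in batches in the spirit of permutation_batch_size. The p-value is the fraction of permuted statistics at least as large as the observed one; taking the threshold as the (1 - p_val) quantile of the permuted statistics is an assumption, and all names are hypothetical.

```python
import numpy as np


def mmd2_batch(k: np.ndarray, m: int) -> np.ndarray:
    """Unbiased MMD^2 for a stack of joint kernels k of shape (b, N, N).

    The first m rows/columns of each matrix are treated as reference samples.
    """
    n = k.shape[-1] - m
    k_xx, k_yy, k_xy = k[:, :m, :m], k[:, m:, m:], k[:, :m, m:]
    xx = (k_xx.sum(axis=(1, 2)) - np.trace(k_xx, axis1=1, axis2=2)) / (m * (m - 1))
    yy = (k_yy.sum(axis=(1, 2)) - np.trace(k_yy, axis1=1, axis2=2)) / (n * (n - 1))
    return xx + yy - 2.0 * k_xy.mean(axis=(1, 2))


def permutation_test(
    k: np.ndarray,
    m: int,
    p_val: float = 0.05,
    n_permutations: int = 100,
    batch_size: int = 25,
    seed: int = 0,
) -> tuple[float, float, float]:
    """Return (p-value, observed MMD^2, MMD^2 threshold) for a joint kernel matrix k."""
    rng = np.random.default_rng(seed)
    observed = float(mmd2_batch(k[None], m)[0])
    null = []
    for start in range(0, n_permutations, batch_size):
        b = min(batch_size, n_permutations - start)
        # Each row of perms relabels which pooled samples count as "reference"
        perms = np.stack([rng.permutation(len(k)) for _ in range(b)])
        # Permute rows and columns together; the (b, N, N) stack is what batch_size bounds
        k_perm = k[perms[:, :, None], perms[:, None, :]]
        null.append(mmd2_batch(k_perm, m))
    null_stats = np.concatenate(null)
    p_value = float((null_stats >= observed).mean())
    threshold = float(np.quantile(null_stats, 1.0 - p_val))
    return p_value, observed, threshold


rng = np.random.default_rng(1)
x = rng.normal(size=(30, 3))
y = rng.normal(loc=1.5, size=(30, 3))
z = np.concatenate([x, y])
d2 = ((z[:, None, :] - z[None, :, :]) ** 2).sum(-1)
p, stat, thr = permutation_test(np.exp(-d2 / (2 * 2.0**2)), m=30)
print(p, stat > thr)  # small p-value; observed MMD^2 exceeds the threshold
```

Smaller batches lower peak memory at the cost of more loop iterations; with a fixed seed the permutations, and therefore the results, do not depend on the batch size.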

property reference_data : numpy.typing.NDArray[numpy.float32]

Reference data, lazily encoded on first access.

Overrides BaseDrift.reference_data via MRO when this mixin appears before BaseDrift in the inheritance list.
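
The lazy-encoding behavior can be sketched with functools.cached_property: the extractor runs and the result is flattened only when reference_data is first read, and later reads reuse the cached array. The class LazyReference and its encode_calls counter are hypothetical; they mirror the described behavior and do not reproduce DataEval's mixin.

```python
from functools import cached_property
from typing import Callable, Optional

import numpy as np


class LazyReference:
    """Hypothetical holder that encodes its reference data on first access only."""

    def __init__(
        self, raw: np.ndarray, extractor: Optional[Callable[[np.ndarray], np.ndarray]] = None
    ) -> None:
        self._raw = raw
        self._extractor = extractor
        self.encode_calls = 0  # counts how often encoding actually runs

    @cached_property
    def reference_data(self) -> np.ndarray:
        self.encode_calls += 1
        data = self._extractor(self._raw) if self._extractor is not None else self._raw
        arr = np.asarray(data, dtype=np.float32)
        return arr.reshape(len(arr), -1)  # flatten each sample to a vector


holder = LazyReference(np.ones((10, 2, 3)), extractor=lambda x: x * 2.0)
print(holder.encode_calls)          # 0: nothing encoded yet
print(holder.reference_data.shape)  # (10, 6)
print(holder.reference_data.shape)  # (10, 6), served from the cache
print(holder.encode_calls)          # 1
```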

Classes

Config

Configuration for DriftMMD detector.

Stats

Statistics from MMD permutation test.