dataeval.detectors.drift.DriftMMD

class dataeval.detectors.drift.DriftMMD(data, p_val=0.05, update_strategy=None, sigma=None, n_permutations=100, device=None)

Drift detector using Maximum Mean Discrepancy (MMD) with a permutation test.

Detects distributional differences by comparing kernel embeddings of reference and test datasets in a reproducing kernel Hilbert space (RKHS). Uses permutation testing to assess the statistical significance of the observed MMD² statistic.

MMD is particularly effective for high-dimensional data like images as it can capture complex distributional differences that univariate tests might miss. The kernel-based approach enables detection of both marginal and dependency changes between features.
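
To make the statistic concrete, the following is a minimal NumPy sketch of an unbiased MMD² estimate with a Gaussian RBF kernel. It is an illustration only, not DataEval's implementation: the helper names `rbf_kernel` and `mmd2_unbiased`, the fixed `sigma=3.0`, and the synthetic data are all assumptions made for this example.

```python
import numpy as np


def rbf_kernel(a, b, sigma):
    """Gaussian RBF kernel matrix between the rows of a and the rows of b."""
    # Squared Euclidean distances via the expansion |a|^2 + |b|^2 - 2 a.b
    d2 = (a**2).sum(axis=1)[:, None] + (b**2).sum(axis=1)[None, :] - 2.0 * a @ b.T
    d2 = np.maximum(d2, 0.0)  # guard against tiny negative values from rounding
    return np.exp(-d2 / (2.0 * sigma**2))


def mmd2_unbiased(x, y, sigma):
    """Unbiased estimate of MMD^2 between samples x and y (hypothetical helper)."""
    n, m = len(x), len(y)
    k_xx = rbf_kernel(x, x, sigma)
    k_yy = rbf_kernel(y, y, sigma)
    k_xy = rbf_kernel(x, y, sigma)
    # Within-sample terms exclude the diagonal (each point paired with itself)
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return float(term_xx + term_yy - 2.0 * k_xy.mean())


rng = np.random.default_rng(0)
ref = rng.normal(0.0, 1.0, size=(200, 8))      # stand-in for reference embeddings
same = rng.normal(0.0, 1.0, size=(200, 8))     # drawn from the same distribution
shifted = rng.normal(1.0, 1.0, size=(200, 8))  # mean-shifted distribution

mmd_same = mmd2_unbiased(ref, same, sigma=3.0)
mmd_shifted = mmd2_unbiased(ref, shifted, sigma=3.0)
```

In this sketch `mmd_same` stays close to zero while `mmd_shifted` is clearly larger, which is the gap the permutation test is meant to assess.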

Parameters:
data : Embeddings or Array

Reference dataset used as baseline distribution for drift detection. Should represent the expected data distribution.

p_val : float, default 0.05

Significance threshold for statistical tests, between 0 and 1. For FDR correction, this represents the acceptable false discovery rate. The default of 0.05 corresponds to a 95% confidence level for drift detection.

update_strategy : UpdateStrategy or None, default None

Strategy for updating reference data when new data arrives. When None, reference data remains fixed throughout detection.

sigma : Array or None, default None

Bandwidth parameter(s) for the Gaussian RBF kernel. Controls the kernel’s sensitivity to distance between data points. When None, the bandwidth is selected automatically using the median heuristic. Multiple values can be provided as an array to average the kernel over different scales.

n_permutations : int, default 100

Number of random permutations used in the permutation test to estimate the null distribution of MMD² under no drift. Higher values provide more accurate p-value estimates but increase computation time. Default 100 balances statistical accuracy with computational efficiency.

device : DeviceLike or None, default None

Hardware device for computation. When None, automatically selects DataEval’s configured device, falling back to PyTorch’s default.
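
As a rough illustration of the median heuristic mentioned for `sigma`, the sketch below sets the bandwidth to the median pairwise Euclidean distance over the pooled reference and test samples. Several variants of the heuristic exist, and this page does not state which one DataEval uses, so both the formula and the helper name `median_heuristic_sigma` are assumptions.

```python
import numpy as np


def median_heuristic_sigma(x, y):
    """Median pairwise Euclidean distance over the pooled samples (one common variant)."""
    z = np.concatenate([x, y], axis=0)
    # Squared pairwise distances between all pooled points
    d2 = (z**2).sum(axis=1)[:, None] + (z**2).sum(axis=1)[None, :] - 2.0 * z @ z.T
    d2 = np.maximum(d2, 0.0)
    # Use each unordered pair once and skip the zero diagonal
    upper = np.triu_indices(len(z), k=1)
    return float(np.median(np.sqrt(d2[upper])))


x = np.array([[0.0], [1.0]])
y = np.array([[3.0]])
sigma = median_heuristic_sigma(x, y)  # pairwise distances are 1, 3 and 2
```

Tying the bandwidth to the typical distance between points keeps the kernel from saturating near 0 or 1 for most pairs, which is why it is a common default.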

p_val

Significance threshold for statistical tests.

Type:

float

update_strategy

Reference data update strategy.

Type:

UpdateStrategy or None

n

Number of samples in the reference dataset.

Type:

int

sigma

Gaussian RBF kernel bandwidth parameter(s).

Type:

Array or None

n_permutations

Number of permutations for statistical testing.

Type:

int

device

Hardware device used for computations.

Type:

torch.device

Example

>>> from dataeval.data import Embeddings
>>> from dataeval.detectors.drift import DriftMMD

Use Embeddings to encode images before testing for drift

>>> train_emb = Embeddings(train_images, model=encoder, batch_size=16)
>>> drift = DriftMMD(train_emb)

Test incoming images for drift

>>> drift.predict(test_images).drifted
True

predict(data)

Predict whether a batch of data has drifted from the reference data, then update the reference data using the specified strategy.

Parameters:
data : Embeddings or Array

Batch of instances to predict drift on.

Returns:

Output class containing the drift prediction, p-value, threshold and MMD metric.

Return type:

DriftMMDOutput

score(data)

Compute the p-value resulting from a permutation test using the maximum mean discrepancy as a distance measure between the reference data and the data to be tested.

Parameters:
data : Embeddings or Array

Batch of instances to score.

Returns:

The p-value obtained from the permutation test, the MMD² between the reference and test sets, and the MMD² threshold above which drift is flagged.

Return type:

tuple(float, float, float)
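
The three returned values can be related to one another with a small sketch. It assumes the p-value is the fraction of permuted statistics at least as large as the observed one, and that the threshold is the permuted value exceeded by roughly a `p_val` fraction of permutations. Neither rule is confirmed by this page, and `summarize_permutation_test` is a hypothetical helper, not part of DataEval.

```python
import numpy as np


def summarize_permutation_test(observed, permuted, p_val=0.05):
    """Return (p_value, observed statistic, threshold) from permuted statistics."""
    permuted = np.asarray(permuted, dtype=float)
    # Fraction of permutations whose statistic is at least as extreme as observed
    p_value = float(np.mean(permuted >= observed))
    # Assumed rule: sort descending and take the entry at index p_val * n
    idx = int(p_val * len(permuted))
    threshold = float(np.sort(permuted)[::-1][idx])
    return p_value, float(observed), threshold


# Toy null distribution: 100 permuted statistics with values 0, 1, ..., 99
permuted_stats = np.arange(100.0)
p_value, mmd2, threshold = summarize_permutation_test(97.0, permuted_stats)
drifted = p_value < 0.05
```

With 100 permutations the p-value moves in steps of 0.01, which is one reason a larger `n_permutations` gives finer estimates at the cost of more computation.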

property x_ref : numpy.typing.NDArray[numpy.float32]

Reference data for drift detection.

Lazily encodes the reference dataset on first access. Data is flattened and converted to 32-bit floating point for consistent numerical processing across different input types.

Returns:

Reference data as flattened 32-bit floating point array. Shape is (n_samples, n_features_flattened).

Return type:

NDArray[np.float32]

Notes

Data is cached after first access to avoid repeated encoding overhead.
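
The flattening and type conversion described for `x_ref` can be pictured with plain NumPy. This is a sketch of the described behavior, not the library's code, and `flatten_to_float32` is a hypothetical name.

```python
import numpy as np


def flatten_to_float32(data):
    """Convert to float32 and flatten every dimension after the first."""
    arr = np.asarray(data, dtype=np.float32)
    return arr.reshape(len(arr), -1)


# Four stand-in images of shape (channels=3, height=8, width=8)
images = np.zeros((4, 3, 8, 8), dtype=np.uint8)
flat = flatten_to_float32(images)  # shape (n_samples, n_features_flattened)
```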