dataeval.detectors.drift.DriftMMD

class dataeval.detectors.drift.DriftMMD(data, p_val=0.05, update_strategy=None, sigma=None, n_permutations=100, device=None)

Maximum Mean Discrepancy (MMD) drift detector based on a permutation test.
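For reference, a standard unbiased estimator of the squared MMD between a reference sample X = {x_i}, i = 1..m, and a test sample Y = {y_j}, j = 1..n, under a kernel k is given below. This is the textbook form of the statistic. That DriftMMD uses exactly this estimator is an assumption, not something this page states. In a permutation test, the statistic is recomputed after randomly reassigning the pooled m + n samples into groups of sizes m and n, and the p-value is the fraction of permuted statistics at least as large as the observed one.

```latex
\widehat{\mathrm{MMD}}^2(X, Y)
  = \frac{1}{m(m-1)} \sum_{i=1}^{m} \sum_{i' \neq i} k(x_i, x_{i'})
  + \frac{1}{n(n-1)} \sum_{j=1}^{n} \sum_{j' \neq j} k(y_j, y_{j'})
  - \frac{2}{mn} \sum_{i=1}^{m} \sum_{j=1}^{n} k(x_i, y_j)
```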

Parameters:
data : Embeddings or Array

Data used as reference distribution.

p_val : float or None, default 0.05

P-value used to determine the significance of the permutation test. The MMD test is a single multivariate test over all features, so no per-feature correction is applied.

update_strategy : UpdateStrategy or None, default None

Optional strategy for updating the reference data after each prediction. Use LastSeenUpdateStrategy to keep the last n instances seen by the detector, or ReservoirSamplingUpdateStrategy to update the reference data via reservoir sampling.
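The two update strategies named above can be sketched with plain NumPy to show what "last n instances" and "reservoir sampling" mean for the reference set. The function names below are hypothetical and this is not DataEval's implementation of LastSeenUpdateStrategy or ReservoirSamplingUpdateStrategy; the reservoir variant assumes the classic Algorithm R.

```python
import numpy as np


def last_seen_update(x_ref: np.ndarray, x_new: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical sketch: keep only the last n instances seen (reference + new batch)."""
    return np.concatenate([x_ref, x_new], axis=0)[-n:]


def reservoir_sampling_update(
    x_ref: np.ndarray,
    x_new: np.ndarray,
    n: int,
    n_seen: int,
    rng: np.random.Generator,
) -> tuple[np.ndarray, int]:
    """Hypothetical sketch of reservoir sampling (Algorithm R) with a reservoir of size n.

    n_seen is the number of instances observed so far (including x_ref).
    Returns the updated reservoir and the updated count.
    """
    reservoir = list(x_ref)
    for item in x_new:
        n_seen += 1
        if len(reservoir) < n:
            reservoir.append(item)  # fill the reservoir until it holds n items
        else:
            j = rng.integers(0, n_seen)  # uniform index in [0, n_seen)
            if j < n:
                reservoir[j] = item  # replace an item with probability n / n_seen
    return np.stack(reservoir), n_seen
```

The last-seen variant makes the reference track recent data, while reservoir sampling keeps a uniform random sample of every instance seen so far.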

sigma : Array or None, default None

Optional bandwidth of the internal GaussianRBF kernel. An array of several bandwidths may also be passed, in which case the kernel evaluations are averaged over those bandwidths.
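A minimal sketch of a Gaussian RBF kernel averaged over several bandwidths, assuming the common form k(x, y) = exp(-||x - y||^2 / (2 sigma^2)). The function name gaussian_rbf is hypothetical, and DataEval's GaussianRBF may parameterize the bandwidth differently.

```python
import numpy as np


def gaussian_rbf(x: np.ndarray, y: np.ndarray, sigma: np.ndarray) -> np.ndarray:
    """Gaussian RBF kernel matrix averaged over one or more bandwidths.

    x: (m, d), y: (n, d), sigma: scalar or 1-D array of bandwidths.
    Returns an (m, n) kernel matrix.
    """
    sigma = np.atleast_1d(np.asarray(sigma, dtype=np.float64))
    # Pairwise squared Euclidean distances, shape (m, n)
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    # One kernel matrix per bandwidth, shape (len(sigma), m, n)
    k = np.exp(-sq_dist[None, :, :] / (2.0 * sigma[:, None, None] ** 2))
    # Average the kernel evaluations over the bandwidths
    return k.mean(axis=0)
```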

n_permutations : int, default 100

Number of permutations used in the permutation test.

device : DeviceLike or None, default None

The hardware device to use, if specified; otherwise, the DataEval default or the torch default is used.

Example

>>> from dataeval.detectors.drift import DriftMMD
>>> from dataeval.utils.data import Embeddings

Use Embeddings to encode images before testing for drift

>>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
>>> drift = DriftMMD(train_emb)

Test incoming images for drift

>>> drift.predict(test_images).drifted
True

predict(data)

Predict whether a batch of data has drifted from the reference data, then update the reference data if an update strategy was specified.

Parameters:
data : Embeddings or Array

Batch of instances to predict drift on.

Returns:

Output class containing the drift prediction, p-value, threshold and MMD metric.

Return type:

DriftMMDOutput
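The control flow of predict can be sketched as "score, compare, then update". In this sketch, drift is flagged when the permutation-test p-value falls below p_val, which is an assumption about the decision rule. DriftResultSketch is a hypothetical stand-in for DriftMMDOutput: only the drifted field appears in the example above, and the other field names are illustrative. score_fn and update_fn are hypothetical callables, not DataEval APIs.

```python
from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np


@dataclass
class DriftResultSketch:
    """Hypothetical stand-in for DriftMMDOutput."""

    drifted: bool
    p_value: float
    threshold: float
    distance: float


def predict_sketch(
    x_ref: np.ndarray,
    x: np.ndarray,
    score_fn: Callable[[np.ndarray, np.ndarray], tuple[float, float, float]],
    p_val: float = 0.05,
    update_fn: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None,
) -> tuple[DriftResultSketch, np.ndarray]:
    # Score the batch against the current reference: (p-value, MMD^2, MMD^2 threshold)
    p, mmd2, threshold = score_fn(x_ref, x)
    result = DriftResultSketch(
        drifted=p < p_val, p_value=p, threshold=threshold, distance=mmd2
    )
    # Only after scoring, optionally fold the new batch into the reference data
    new_ref = update_fn(x_ref, x) if update_fn is not None else x_ref
    return result, new_ref
```

Updating only after scoring means each batch is always tested against the reference as it stood before that batch arrived.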

score(data)

Compute the p-value resulting from a permutation test using the maximum mean discrepancy as a distance measure between the reference data and the data to be tested.

Parameters:
data : Embeddings or Array

Batch of instances to score.

Returns:

The p-value obtained from the permutation test, the MMD^2 between the reference and test sets, and the MMD^2 threshold above which drift is flagged.

Return type:

tuple(float, float, float)
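A minimal NumPy sketch of the triple that score returns, under stated assumptions: a single fixed Gaussian RBF bandwidth, the unbiased MMD^2 estimator, a p-value taken as the fraction of permuted statistics at least as large as the observed one, and a threshold taken as the (1 - p_val) quantile of the permuted statistics. The function names are hypothetical, and p_val, sigma and n_permutations are passed explicitly here only to keep the sketch self-contained. DataEval's actual estimator, bandwidth selection and threshold rule may differ.

```python
import numpy as np


def _rbf(x: np.ndarray, y: np.ndarray, sigma: float) -> np.ndarray:
    # Gaussian RBF kernel matrix with a single bandwidth
    sq = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq / (2.0 * sigma**2))


def mmd2_unbiased(k: np.ndarray, m: int) -> float:
    """Unbiased MMD^2 from a pooled kernel matrix whose first m rows/cols are the reference."""
    n = k.shape[0] - m
    k_xx, k_yy, k_xy = k[:m, :m], k[m:, m:], k[:m, m:]
    # Exclude the diagonal from the within-sample averages
    xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return float(xx + yy - 2.0 * k_xy.mean())


def score_sketch(
    x_ref: np.ndarray,
    x: np.ndarray,
    p_val: float = 0.05,
    sigma: float = 1.0,
    n_permutations: int = 100,
    seed: int = 0,
) -> tuple[float, float, float]:
    """Sketch of an MMD permutation test returning (p-value, MMD^2, MMD^2 threshold)."""
    rng = np.random.default_rng(seed)
    m = len(x_ref)
    pooled = np.concatenate([x_ref, x], axis=0)
    k = _rbf(pooled, pooled, sigma)  # kernel matrix is computed once and reused
    mmd2 = mmd2_unbiased(k, m)
    perm_stats = np.empty(n_permutations)
    for i in range(n_permutations):
        idx = rng.permutation(len(pooled))
        # Permuting rows and columns together reassigns samples to the two groups
        perm_stats[i] = mmd2_unbiased(k[np.ix_(idx, idx)], m)
    p_value = float((perm_stats >= mmd2).mean())
    # MMD^2 value exceeded by roughly a p_val fraction of the permuted statistics
    threshold = float(np.quantile(perm_stats, 1.0 - p_val))
    return p_value, mmd2, threshold
```

Because this sketch estimates the p-value from n_permutations samples, its resolution is limited to steps of 1 / n_permutations; with 100 permutations, p-values cannot be resolved more finely than 0.01.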

property x_ref : numpy.typing.NDArray[numpy.float32]

Retrieve the reference data of the drift detector.

Returns:

The reference data as a 32-bit floating point numpy array.

Return type:

NDArray[np.float32]