dataeval.detectors.drift.DriftMMD

class dataeval.detectors.drift.DriftMMD(x_ref, p_val=0.05, x_ref_preprocessed=False, update_x_ref=None, preprocess_fn=None, sigma=None, configure_kernel_from_x_ref=True, n_permutations=100, device=None)

Maximum Mean Discrepancy (MMD) Drift Detection algorithm using a permutation test.

Parameters:
x_ref : ArrayLike

Data used as reference distribution.

p_val : float or None, default 0.05

Significance threshold for the permutation test. Drift is flagged when the p-value obtained from the test falls below this value.

x_ref_preprocessed : bool, default False

Whether the reference data x_ref has already been preprocessed. If True, only the test data x is preprocessed at prediction time. If False, the reference data is preprocessed as well.

update_x_ref : UpdateStrategy or None, default None

Strategy for optionally updating the reference data. Use LastSeenUpdateStrategy to keep the last n instances seen by the detector, or ReservoirSamplingUpdateStrategy to update via reservoir sampling.

preprocess_fn : Callable or None, default None

Function to preprocess the data before computing the data drift metrics. Typically a dimensionality reduction technique.

sigma : ArrayLike or None, default None

Bandwidth of the internal GaussianRBF kernel. Multiple bandwidth values can be passed as an array, in which case the kernel evaluation is averaged over those bandwidths.

configure_kernel_from_x_ref : bool, default True

Whether to configure the kernel bandwidth from the reference data in advance.

n_permutations : int, default 100

Number of permutations used in the permutation test.

device : DeviceLike or None, default None

Hardware device to use. If not specified, the DataEval default or the torch default is used.
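The sigma and configure_kernel_from_x_ref parameters both concern the Gaussian RBF kernel bandwidth. The NumPy sketch below illustrates what averaging a kernel over several bandwidths means, together with one common way of deriving a bandwidth from data (the median heuristic). The function names are hypothetical, and the use of the median heuristic is an assumption made for illustration; this is not DataEval's implementation.

```python
import numpy as np


def squared_distances(x, y):
    """Pairwise squared Euclidean distances between rows of x and rows of y."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    d2 = (x**2).sum(1)[:, None] + (y**2).sum(1)[None, :] - 2.0 * x @ y.T
    return np.maximum(d2, 0.0)  # clip tiny negatives caused by round-off


def median_bandwidth(x, y):
    """Median heuristic (assumed here): sigma^2 = median squared distance / 2."""
    return np.sqrt(0.5 * np.median(squared_distances(x, y)))


def gaussian_rbf(x, y, sigma):
    """Gaussian RBF kernel matrix, averaged over one or more bandwidths."""
    sigmas = np.atleast_1d(np.asarray(sigma, dtype=float))
    d2 = squared_distances(x, y)
    # One kernel matrix per bandwidth, then the mean across bandwidths.
    k = np.exp(-d2[None, :, :] / (2.0 * sigmas[:, None, None] ** 2))
    return k.mean(axis=0)
```

Passing a scalar gives the ordinary single-bandwidth kernel; passing an array such as [0.5, 1.0, 2.0] gives the mean of the three kernel matrices, which makes the test less sensitive to any one bandwidth choice.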

Example

>>> from functools import partial
>>> from dataeval.detectors.drift import DriftMMD, preprocess_drift

Use a preprocess function to encode images before testing for drift

>>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
>>> drift = DriftMMD(train_images, preprocess_fn=preprocess_fn)

Test incoming images for drift

>>> drift.predict(test_images).drifted
True

predict(x)

Predict whether a batch of data has drifted from the reference data, then update the reference data using the specified strategy.

Parameters:
x : ArrayLike

Batch of instances.

Returns:

Output class containing the drift prediction, p-value, threshold and MMD metric.

Return type:

DriftMMDOutput
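The two reference-update strategies named for update_x_ref can be pictured with the NumPy sketch below. The helper functions are hypothetical and are not the LastSeenUpdateStrategy or ReservoirSamplingUpdateStrategy classes themselves; they only illustrate the underlying ideas, assuming the reference set is capped at n instances.

```python
import numpy as np


def last_seen_update(x_ref, x_new, n):
    """Keep only the n most recent instances (hypothetical helper)."""
    return np.concatenate([x_ref, x_new], axis=0)[-n:]


def reservoir_update(x_ref, x_new, n, seen, rng):
    """Reservoir sampling (Algorithm R) over a stream of batches.

    `seen` is the number of instances observed before this batch.
    Returns the updated reservoir and the updated count.
    """
    reservoir = list(x_ref[:n])
    for item in x_new:
        seen += 1
        if len(reservoir) < n:
            reservoir.append(item)  # still filling the reservoir
        else:
            j = rng.integers(0, seen)  # uniform integer in [0, seen)
            if j < n:
                reservoir[j] = item  # replace a random slot
    return np.stack(reservoir), seen
```

The last-seen variant tracks the most recent data, so the reference follows gradual change. Reservoir sampling instead keeps a uniform sample of everything seen so far, so older data stays represented.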

score(x)

Compute the p-value resulting from a permutation test using the maximum mean discrepancy as a distance measure between the reference data and the data to be tested.

Parameters:
x : ArrayLike

Batch of instances.

Returns:

The p-value obtained from the permutation test, the MMD^2 between the reference and test sets, and the MMD^2 threshold above which drift is flagged.

Return type:

tuple(float, float, float)
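To make the three returned values concrete, here is a minimal NumPy sketch of an MMD permutation test. It assumes a single-bandwidth Gaussian RBF kernel, an unbiased MMD^2 estimate that drops diagonal terms, and a threshold taken as the (1 - p_val) quantile of the permuted statistics. All function names are hypothetical, and the details may differ from what DriftMMD does internally.

```python
import numpy as np


def rbf_kernel(z, sigma):
    """Gaussian RBF kernel matrix over all rows of z."""
    d2 = ((z[:, None, :] - z[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2.0 * sigma**2))


def mmd2_from_kernel(k, n_ref):
    """MMD^2 estimate from a joint kernel matrix; the first n_ref rows are the reference."""
    k_xx, k_yy, k_xy = k[:n_ref, :n_ref], k[n_ref:, n_ref:], k[:n_ref, n_ref:]
    n, m = n_ref, k.shape[0] - n_ref
    # Exclude diagonal (self-similarity) terms from the within-sample averages.
    xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return xx + yy - 2.0 * k_xy.mean()


def mmd_permutation_test(x_ref, x, sigma=1.0, n_permutations=100, p_val=0.05, seed=0):
    """Return (p-value, observed MMD^2, MMD^2 threshold)."""
    rng = np.random.default_rng(seed)
    z = np.concatenate([x_ref, x], axis=0).astype(float)
    k = rbf_kernel(z, sigma)  # computed once, reused for every permutation
    n_ref = len(x_ref)
    mmd2 = mmd2_from_kernel(k, n_ref)
    perm_stats = np.empty(n_permutations)
    for i in range(n_permutations):
        idx = rng.permutation(len(z))  # shuffle the reference/test labels
        perm_stats[i] = mmd2_from_kernel(k[np.ix_(idx, idx)], n_ref)
    # Fraction of permuted statistics at least as large as the observed one.
    p = float((perm_stats >= mmd2).mean())
    threshold = float(np.quantile(perm_stats, 1.0 - p_val))
    return p, float(mmd2), threshold
```

Permuting the rows and columns of the precomputed kernel matrix is equivalent to reassigning instances between the two samples, so the kernel never has to be recomputed inside the loop.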

property x_ref : dataeval.typing.ArrayLike

Retrieve the reference data, applying preprocessing if not already done.

Returns:

The reference dataset (x_ref), preprocessed if needed.

Return type:

ArrayLike
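The "preprocess if not already done" behavior described for x_ref can be sketched as a lazily evaluated property. The class below is a hypothetical stand-in written for illustration, not DataEval code; it assumes the preprocessing result is cached so the function runs at most once.

```python
class LazyReference:
    """Hypothetical holder that preprocesses its reference data on first access."""

    def __init__(self, x_ref, preprocess_fn=None, x_ref_preprocessed=False):
        self._x_ref = x_ref
        self._preprocess_fn = preprocess_fn
        self._preprocessed = x_ref_preprocessed

    @property
    def x_ref(self):
        # Apply preprocessing once, then reuse the cached result.
        if not self._preprocessed:
            if self._preprocess_fn is not None:
                self._x_ref = self._preprocess_fn(self._x_ref)
            self._preprocessed = True
        return self._x_ref
```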