dataeval.shift.DriftWasserstein
class dataeval.shift.DriftWasserstein(ratio_threshold=None, n_features=None, extractor=None, update_strategy=None, config=None)
Drift detector using Wasserstein distance with a validation baseline.
Detects distributional changes by comparing the Wasserstein distance between training and operational data against a baseline distance computed from training and validation data. Drift is flagged when the ratio of operational distance to baseline distance exceeds a threshold.
Unlike hypothesis-test-based detectors, this detector requires two reference datasets to be provided at fit time: a training set and an in-distribution validation set. The train/validation Wasserstein distance serves as a calibrated baseline for what “no drift” looks like. At predict time, the train/operational Wasserstein distance is divided by this baseline; ratios substantially above 1.0 indicate drift.
For multivariate data (e.g. model embeddings), the distance comparison is applied independently to each feature. A feature is considered drifted when its individual distance ratio exceeds the threshold, and overall drift is declared when any feature drifts.
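Written as formulas, the per-feature rule above reads as follows, where tr, val and op denote the training, validation and operational samples of feature j, F̂ is an empirical CDF, d is n_features and τ is ratio_threshold. The first line is the standard one-dimensional Wasserstein-1 distance between empirical distributions; that the detector uses exactly this variant (rather than another order p) is an assumption here.

```latex
\begin{aligned}
W_1(u, v) &= \int_{-\infty}^{\infty} \bigl|\hat F_u(x) - \hat F_v(x)\bigr|\,dx \\
r_j &= \frac{W_1\bigl(x^{\mathrm{tr}}_j,\; x^{\mathrm{op}}_j\bigr)}{W_1\bigl(x^{\mathrm{tr}}_j,\; x^{\mathrm{val}}_j\bigr)}, \qquad j = 1, \dots, d \\
\mathrm{drift}_j &= \bigl[\, r_j > \tau \,\bigr], \qquad \mathrm{drift} = \bigvee_{j=1}^{d} \mathrm{drift}_j
\end{aligned}
```

For instance, with τ = 1.4, a feature whose baseline distance is 0.10 and whose operational distance is 0.15 has r = 1.5 and is flagged; an operational distance of 0.12 gives r = 1.2 and is not.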
Uses a fit/predict lifecycle: construct with hyperparameters, call fit() with training and validation data, then call predict() with test data. Use chunked() to create a chunked wrapper for time-series monitoring.
- Parameters:
- ratio_threshold : float, default 1.4
Distance ratio above which drift is declared, per feature. A value of 1.4 means operational distances more than 40% larger than the train/validation baseline are flagged as drift.
- n_features : int or None, default None
Number of features to analyse. When None, it is inferred automatically from the flattened shape of the first data sample.
- extractor : FeatureExtractor or None, default None
Optional feature extraction function to convert input data to arrays. When provided, enables drift detection on non-array inputs such as datasets, metadata, or raw model outputs. The extractor is applied to training, validation, and test data before drift detection. When None, data must already be Array-like.
- update_strategy : UpdateStrategy or None, default None
Strategy for updating the training reference data when new data arrives. When None, reference data remains fixed throughout detection.
- config : DriftWasserstein.Config or None, default None
Optional configuration object with default parameters. Parameters specified directly in __init__ override config defaults.
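The signature defaults ratio_threshold to None while the documented default is 1.4. One plausible reading, consistent with the config description above, is a three-level lookup: an explicit argument wins, otherwise the config value is used, otherwise a default config supplies 1.4. The sketch below implements that lookup with a hypothetical HypotheticalConfig dataclass and resolve_ratio_threshold helper; neither is part of dataeval, and the library's real resolution logic may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HypotheticalConfig:
    """Hypothetical stand-in for DriftWasserstein.Config (not dataeval's class)."""

    ratio_threshold: float = 1.4


def resolve_ratio_threshold(
    ratio_threshold: Optional[float] = None,
    config: Optional[HypotheticalConfig] = None,
) -> float:
    """Return the effective threshold: explicit argument > config > built-in default."""
    if ratio_threshold is not None:
        # A value passed directly to __init__ overrides the config.
        return ratio_threshold
    if config is None:
        # No config supplied: fall back to a default config (1.4).
        config = HypotheticalConfig()
    return config.ratio_threshold


print(resolve_ratio_threshold())                                # 1.4
print(resolve_ratio_threshold(config=HypotheticalConfig(1.2)))  # 1.2
print(resolve_ratio_threshold(1.6, HypotheticalConfig(1.2)))    # 1.6
```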
Example
Basic drift detection with Wasserstein distance
>>> import numpy as np
>>> from dataeval.shift import DriftWasserstein
>>> rng = np.random.default_rng(42)
>>> train_emb = rng.standard_normal((200, 64)).astype(np.float32)
>>> val_emb = rng.standard_normal((100, 64)).astype(np.float32)
>>> drift_detector = DriftWasserstein().fit(train_emb, val_emb)
>>> test_emb = np.zeros((50, 64), dtype=np.float32)
>>> result = drift_detector.predict(test_emb)
>>> print(f"Drift detected: {result.drifted}")
Drift detected: True

Chunked drift detection
>>> chunked = DriftWasserstein().chunked(chunk_size=50)
>>> chunked.fit(train_emb, val_emb)
ChunkedDrift(DriftWasserstein(...), chunker=SizeChunker(...), fitted=True)
>>> result = chunked.predict(test_emb)
>>> print(f"Drift detected: {result.drifted}, chunks: {len(result.details)}")
Drift detected: True, chunks: 1

Using configuration:
>>> config = DriftWasserstein.Config(ratio_threshold=1.2)
>>> drift = DriftWasserstein(config=config).fit(train_emb, val_emb)

- chunked(chunker=None, chunk_size=None, chunk_count=None, threshold=None)
Create a chunked wrapper around this drift detector.
Returns a ChunkedDrift that splits data into chunks during fit and predict, computing per-chunk metrics and comparing them against baseline thresholds.
- Parameters:
- chunker : BaseChunker or None, default None
Explicit chunker instance.
- chunk_size : int or None, default None
Create fixed-size chunks of this many samples.
- chunk_count : int or None, default None
Split into this many equal chunks.
- threshold : Threshold or None, default None
Threshold strategy for determining drift bounds from the baseline. When None, uses the detector’s default threshold.
- Returns:
A chunked drift wrapper around this detector.
- Return type:
ChunkedDrift[TDetails]
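To make the difference between chunk_size and chunk_count concrete, here is a minimal NumPy sketch of the two splitting rules. split_into_chunks is a hypothetical helper, not dataeval's SizeChunker or any other chunker class, and two of its rules are assumptions: exactly one of the two arguments must be given, and a shorter trailing chunk is kept when the sample count is not a multiple of chunk_size.

```python
from typing import List, Optional

import numpy as np


def split_into_chunks(
    data: np.ndarray,
    chunk_size: Optional[int] = None,
    chunk_count: Optional[int] = None,
) -> List[np.ndarray]:
    """Split samples (axis 0) into fixed-size or equal-count chunks (hypothetical helper)."""
    if (chunk_size is None) == (chunk_count is None):
        raise ValueError("Pass exactly one of chunk_size or chunk_count.")
    if chunk_size is not None:
        # Fixed-size chunks; a shorter trailing chunk is kept (assumption).
        return [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
    # chunk_count chunks whose sizes differ by at most one sample.
    return np.array_split(data, chunk_count)


test_emb = np.zeros((50, 64), dtype=np.float32)
print(len(split_into_chunks(test_emb, chunk_size=50)))  # 1
print([len(c) for c in split_into_chunks(np.zeros((10, 2)), chunk_count=4)])  # [3, 3, 2, 2]
```

With 50 test samples and chunk_size=50 this yields a single chunk, consistent with the "chunks: 1" output in the chunked example.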
- fit(reference_data, validation_data=None)
Fit detector with training and validation reference data.
Encodes both datasets, stores the training set as the reference for subsequent drift comparisons, and computes the per-feature Wasserstein distance between training and validation data as the drift baseline.
- Parameters:
- reference_data : Any
Training dataset used as the primary reference for drift detection. Can be Array-like or any type supported by the configured extractor.
- validation_data : Any, default None
Validation dataset drawn from the same distribution as reference_data. Used to calibrate the baseline distance. Must be compatible with reference_data (same feature dimensionality after encoding). Required despite the None default.
- Return type:
Self
- Raises:
ValueError – If validation_data is not provided.
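The fit contract above can be sketched as a small class: validation_data is required despite its None default, both sets must share the same flattened feature count, and the per-feature baseline is computed once and stored. ToyWassersteinFit is hypothetical and NumPy-only; it flattens each sample in place of the real encoding/extractor step, raises RuntimeError where dataeval documents NotFittedError for baseline_distances, and its wasserstein_1d helper is the standard one-dimensional Wasserstein-1 distance between empirical samples, assumed rather than confirmed to match the detector's internal computation.

```python
from typing import Optional

import numpy as np


def wasserstein_1d(u: np.ndarray, v: np.ndarray) -> float:
    """W1 between two 1-D samples: integral of |F_u - F_v| over the merged support."""
    u_sorted, v_sorted = np.sort(u), np.sort(v)
    grid = np.sort(np.concatenate([u_sorted, v_sorted]))
    # Empirical CDFs evaluated at the left end of each interval between grid points.
    u_cdf = np.searchsorted(u_sorted, grid[:-1], side="right") / u.size
    v_cdf = np.searchsorted(v_sorted, grid[:-1], side="right") / v.size
    return float(np.sum(np.abs(u_cdf - v_cdf) * np.diff(grid)))


class ToyWassersteinFit:
    """Hypothetical sketch of fit(); not dataeval's implementation."""

    def __init__(self) -> None:
        self._reference: Optional[np.ndarray] = None
        self._baseline: Optional[np.ndarray] = None

    def fit(self, reference_data, validation_data=None) -> "ToyWassersteinFit":
        if validation_data is None:
            raise ValueError("validation_data is required to compute the baseline.")
        # Stand-in for encoding: flatten each sample to (n_samples, n_features).
        ref = np.asarray(reference_data, dtype=np.float32).reshape(len(reference_data), -1)
        val = np.asarray(validation_data, dtype=np.float32).reshape(len(validation_data), -1)
        if ref.shape[1] != val.shape[1]:
            raise ValueError("reference and validation data must have the same n_features.")
        self._reference = ref
        # Baseline: one train/validation distance per feature, computed once here.
        self._baseline = np.array(
            [wasserstein_1d(ref[:, j], val[:, j]) for j in range(ref.shape[1])],
            dtype=np.float32,
        )
        return self

    @property
    def baseline_distances(self) -> np.ndarray:
        if self._baseline is None:
            # RuntimeError stands in for dataeval's NotFittedError.
            raise RuntimeError("Call fit() before reading baseline_distances.")
        return self._baseline


rng = np.random.default_rng(42)
detector = ToyWassersteinFit().fit(
    rng.standard_normal((200, 8, 8)), rng.standard_normal((100, 8, 8))
)
print(detector.baseline_distances.shape)  # (64,): 8 * 8 flattened features
```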
- predict(data)
Predict drift and, when an update_strategy is configured, update the reference data.
Computes per-feature Wasserstein distances between the training reference and data, divides them by the train/validation baseline distances, and flags drift when any feature ratio exceeds ratio_threshold.
- score(data)
Compute per-feature distance ratios and raw Wasserstein distances.
Encodes data, computes the per-feature Wasserstein distance between the training reference and data, then divides by the baseline train/validation distances.
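predict() and score() differ mainly in what they return: score() exposes the raw train/operational distances and their ratios to the baseline, and predict() reduces the ratios to per-feature and overall flags. A minimal NumPy sketch with hypothetical toy_score and toy_predict functions, assuming the reference and baseline arrays come from a fit step like the one described for fit(), and that every baseline distance is nonzero (the sketch has no division-by-zero guard). wasserstein_1d is the standard one-dimensional Wasserstein-1 distance, assumed to match the detector's.

```python
from typing import Tuple

import numpy as np


def wasserstein_1d(u: np.ndarray, v: np.ndarray) -> float:
    """W1 between two 1-D samples: integral of |F_u - F_v| over the merged support."""
    u_sorted, v_sorted = np.sort(u), np.sort(v)
    grid = np.sort(np.concatenate([u_sorted, v_sorted]))
    # Empirical CDFs evaluated at the left end of each interval between grid points.
    u_cdf = np.searchsorted(u_sorted, grid[:-1], side="right") / u.size
    v_cdf = np.searchsorted(v_sorted, grid[:-1], side="right") / v.size
    return float(np.sum(np.abs(u_cdf - v_cdf) * np.diff(grid)))


def toy_score(
    reference: np.ndarray, baseline: np.ndarray, data
) -> Tuple[np.ndarray, np.ndarray]:
    """Return (ratios, distances), one entry per feature (hypothetical helper)."""
    data = np.asarray(data, dtype=np.float32).reshape(len(data), -1)
    # Raw train/operational distance for each feature column.
    distances = np.array(
        [wasserstein_1d(reference[:, j], data[:, j]) for j in range(reference.shape[1])]
    )
    # Assumes every baseline entry is nonzero.
    return distances / baseline, distances


def toy_predict(
    reference: np.ndarray, baseline: np.ndarray, data, ratio_threshold: float = 1.4
) -> Tuple[bool, np.ndarray]:
    """Reduce score ratios to per-feature flags and one overall flag."""
    ratios, _ = toy_score(reference, baseline, data)
    feature_drifted = ratios > ratio_threshold
    # Overall drift is declared when any single feature drifts.
    return bool(feature_drifted.any()), feature_drifted


rng = np.random.default_rng(42)
train = rng.standard_normal((200, 4)).astype(np.float32)
val = rng.standard_normal((100, 4)).astype(np.float32)
baseline = np.array([wasserstein_1d(train[:, j], val[:, j]) for j in range(4)])
drifted, per_feature = toy_predict(train, baseline, np.zeros((50, 4)))
print(drifted)  # True: an all-zero batch sits far from a standard normal reference
```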
- property baseline_distances : numpy.typing.NDArray[numpy.float32]
Per-feature Wasserstein distances between training and validation data.
Computed once during fit() and reused across all subsequent predict() calls.
- Returns:
Baseline distances, shape (n_features,).
- Return type:
NDArray[np.float32]
- Raises:
NotFittedError – If called before fit().
- property n_features : int
Number of features in the reference data.
Lazily computes the number of features from the encoded reference array if not provided during initialization.
- property reference_data : numpy.typing.NDArray[numpy.float32]
Reference data, lazily encoded on first access.
Overrides BaseDrift.reference_data via MRO when this mixin appears before BaseDrift in the inheritance list.
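The lazy encoding and the MRO override described for reference_data and n_features can be illustrated in plain Python. In this hypothetical sketch, LazyReferenceMixin, ToyBaseDrift and ToyDetector are stand-ins, not dataeval classes, and a plain callable plays the role of the FeatureExtractor: the mixin's reference_data property shadows the base class's property because the mixin comes first in the bases, the extractor runs only on first access, and n_features is read from the flattened shape of one encoded sample.

```python
from typing import Any, Callable, Optional

import numpy as np


class ToyBaseDrift:
    """Stand-in base class whose reference_data returns the raw, unencoded input."""

    def __init__(self, reference: Any) -> None:
        self._raw_reference = reference

    @property
    def reference_data(self) -> Any:
        return self._raw_reference


class LazyReferenceMixin:
    """Encodes the stored reference on first access and caches the result."""

    extractor: Optional[Callable[[Any], Any]] = None
    _encoded: Optional[np.ndarray] = None
    encode_calls: int = 0

    @property
    def reference_data(self) -> np.ndarray:
        if self._encoded is None:
            raw = self._raw_reference  # provided by the base class
            if self.extractor is not None:
                raw = self.extractor(raw)
            arr = np.asarray(raw, dtype=np.float32)
            # Flatten each sample so features are columns.
            self._encoded = arr.reshape(len(arr), -1)
            self.encode_calls += 1
        return self._encoded

    @property
    def n_features(self) -> int:
        # Flattened size of one sample of the encoded reference.
        return int(self.reference_data.shape[1])


class ToyDetector(LazyReferenceMixin, ToyBaseDrift):
    """Mixin listed first, so its reference_data wins in the MRO."""


det = ToyDetector(np.ones((10, 3, 4)))  # 10 samples of shape (3, 4)
print(det.n_features)    # 12: first access encodes and caches
print(det.encode_calls)  # 1: later accesses reuse the cached array
print([cls.__name__ for cls in ToyDetector.__mro__])
# ['ToyDetector', 'LazyReferenceMixin', 'ToyBaseDrift', 'object']
```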
Classes
- Configuration for DriftWasserstein detector.
- Per-feature statistics from Wasserstein drift detection.