dataeval.shift.DriftDomainClassifier

class dataeval.shift.DriftDomainClassifier(n_folds=None, threshold=None, config=None)

Multivariate drift detector based on a domain classifier.

Detects drift by training a LightGBM classifier to distinguish reference data from test data. If the classifier discriminates well, measured by a high area under the ROC curve (AUROC), the two distributions differ and drift is detected. An AUROC near 0.5 means the classifier cannot tell the samples apart.
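The AUROC used here has a simple probabilistic reading. It is the chance that a randomly chosen test sample receives a higher "is test" score than a randomly chosen reference sample, with ties counted as one half. The sketch below computes it directly from that definition. It assumes the classifier's scores are already available as plain lists and does not train LightGBM. The function names `auroc` and `is_drift` are hypothetical and not part of dataeval.

```python
def auroc(ref_scores, test_scores):
    """AUROC as P(test score > reference score) + 0.5 * P(tie)."""
    wins = 0.0
    for t in test_scores:
        for r in ref_scores:
            if t > r:
                wins += 1.0
            elif t == r:
                wins += 0.5  # ties contribute half a win
    return wins / (len(ref_scores) * len(test_scores))


def is_drift(auc, threshold=0.55):
    # Non-chunked rule from the docs: drift when AUROC exceeds the threshold.
    return auc > threshold


# Well-separated scores give a high AUROC; identical scores give 0.5.
print(auroc([0.1, 0.2, 0.3], [0.25, 0.35, 0.4]))  # 8 of 9 pairs won
print(auroc([0.5, 0.5], [0.5, 0.5]))
```

The pairwise loop is O(n·m). A rank-based formula (Mann-Whitney U) gives the same value more efficiently on large samples.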

Uses a fit/predict lifecycle: construct with hyperparameters, call fit() with reference data, then call predict() with test data.
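The fit/predict lifecycle follows a common estimator pattern. `fit()` stores the reference data and returns `self`, which allows chaining such as `DriftDomainClassifier().fit(ref)`. Using the detector before fitting raises an error, matching the `RuntimeError` documented for `x_ref`. The class below, `LifecycleSketch`, is a hypothetical skeleton that shows only this control flow. Its `predict()` returns a placeholder dict and does no drift testing.

```python
from typing import List, Optional, Sequence


class LifecycleSketch:
    """Hypothetical fit/predict skeleton, not the dataeval implementation."""

    def __init__(self, threshold: float = 0.55) -> None:
        self.threshold = threshold
        self._x_ref: Optional[List] = None

    def fit(self, x_ref: Sequence) -> "LifecycleSketch":
        self._x_ref = list(x_ref)
        return self  # returning self enables chained construction

    @property
    def x_ref(self) -> List:
        if self._x_ref is None:
            raise RuntimeError("fit() must be called before using the detector")
        return self._x_ref

    def predict(self, x: Sequence) -> dict:
        ref = self.x_ref  # raises RuntimeError if fit() was never called
        # A real detector would train the domain classifier on ref vs x here.
        return {"n_ref": len(ref), "n_test": len(x)}


detector = LifecycleSketch().fit([1, 2, 3])
print(detector.predict([4, 5]))
```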

Supports two modes:

  • Non-chunked (default): Computes a single AUROC for the entire test set vs reference. Drift is flagged when AUROC exceeds the threshold (default 0.55).

  • Chunked: Splits data into chunks, computes AUROC per chunk, and uses threshold bounds to flag drift per chunk. Enable by passing chunking parameters to fit().

Parameters:
n_folds : int, default 5

Number of cross-validation (CV) folds.
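With cross-validation, every sample can be scored by a model that never saw it during training. These out-of-fold scores are what make the AUROC an honest measure of separability. The documentation does not say how the folds are built. The sketch below assumes stratification, meaning each class (reference = 0, test = 1) is dealt round-robin across folds so every fold keeps roughly the same class ratio. It does not shuffle, which many real implementations do. `stratified_fold_ids` is a hypothetical helper.

```python
from collections import defaultdict


def stratified_fold_ids(labels, n_folds=5):
    """Assign a fold id to each sample, dealing each class round-robin."""
    fold_ids = [0] * len(labels)
    seen = defaultdict(int)  # samples of each class assigned so far
    for i, label in enumerate(labels):
        fold_ids[i] = seen[label] % n_folds
        seen[label] += 1
    return fold_ids


# 10 reference samples (label 0) and 5 test samples (label 1).
labels = [0] * 10 + [1] * 5
folds = stratified_fold_ids(labels, n_folds=5)
# Each fold now holds 2 reference samples and 1 test sample.
# Samples in fold k are scored by a model trained on the other folds.
```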

threshold : float or tuple[float, float], default 0.55

For non-chunked mode: a float, and drift is flagged when AUROC > threshold. For chunked mode: a (lower, upper) tuple of AUROC bounds used to flag drift per chunk.
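The two accepted forms of `threshold` can be dispatched on their type. The non-chunked rule below follows the documentation (AUROC > threshold). For chunked mode, the documentation only says the tuple gives bounds. This sketch assumes a chunk is flagged when its AUROC falls outside the (lower, upper) band. `flag_drift` is a hypothetical helper.

```python
from typing import Tuple, Union

Threshold = Union[float, Tuple[float, float]]


def flag_drift(auroc: float, threshold: Threshold) -> bool:
    if isinstance(threshold, tuple):
        lower, upper = threshold
        # Assumption: a chunk drifts when its AUROC leaves the (lower, upper) band.
        return auroc < lower or auroc > upper
    # Documented non-chunked rule: drift when AUROC exceeds the threshold.
    return auroc > threshold


print(flag_drift(0.60, 0.55))          # single float threshold
print(flag_drift(0.70, (0.45, 0.65)))  # per-chunk bounds
```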

config : DriftDomainClassifier.Config or None, default None

Optional configuration object supplying default parameters. Parameters passed directly to __init__ override the corresponding config values.
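The signature defaults `n_folds` and `threshold` to None. This lets the detector tell "not passed" apart from an explicit value, so an explicit argument can override the config. The sketch below shows that precedence: explicit argument, then config value, then built-in default. `ConfigSketch` and `resolve` are hypothetical stand-ins, not `DriftDomainClassifier.Config` itself. The defaults 5 and 0.55 are taken from the parameter list.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Threshold = Union[float, Tuple[float, float]]


@dataclass
class ConfigSketch:
    n_folds: int = 5
    threshold: Threshold = 0.55


def resolve(n_folds: Optional[int] = None,
            threshold: Optional[Threshold] = None,
            config: Optional[ConfigSketch] = None):
    cfg = config if config is not None else ConfigSketch()
    # An explicitly passed argument wins; otherwise fall back to the config.
    return (
        n_folds if n_folds is not None else cfg.n_folds,
        threshold if threshold is not None else cfg.threshold,
    )


print(resolve(n_folds=3, config=ConfigSketch(n_folds=10)))
```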

Examples

Non-chunked mode:

>>> import numpy as np
>>> from dataeval.shift import DriftDomainClassifier
>>> ref = np.random.randn(200, 4).astype(np.float32)
>>> test = np.random.randn(100, 4).astype(np.float32)
>>> detector = DriftDomainClassifier().fit(ref)
>>> result = detector.predict(test)
>>> print(f"Drift: {result.drifted}")
Drift: ...

Chunked mode:

>>> detector = DriftDomainClassifier(threshold=(0.45, 0.65)).fit(ref, chunk_size=100)
>>> result = detector.predict(test)

Using configuration:

>>> config = DriftDomainClassifier.Config(n_folds=10, threshold=(0.4, 0.6))
>>> detector = DriftDomainClassifier(config=config)

fit(x_ref, chunker=None, chunk_size=None, chunk_count=None, chunks=None, chunk_indices=None)

Fit the domain classifier on the reference data.

When chunking is enabled, the detector computes per-chunk baseline AUROC values from the reference data and derives threshold bounds. During prediction, the test data is split into chunks of the same size used here, so that per-chunk statistics are comparable to the baseline.

If chunk_count is provided, the effective chunk size is computed as len(x_ref) // chunk_count and locked in for prediction. Use chunk_size directly when you want explicit control over the chunk size used for both fitting and prediction.
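The chunk_count conversion is integer division on the reference length, and the resulting size is reused on the test data. For example, 200 reference rows with `chunk_count=3` give a chunk size of 200 // 3 = 66. A 100-row test set then yields one full chunk of 66 rows plus 34 leftover rows. The documentation does not say how leftover rows are handled, so this sketch only reports their count. `effective_chunk_size` and `split_fixed` are hypothetical helpers.

```python
def effective_chunk_size(n_ref: int, chunk_count: int) -> int:
    if chunk_count <= 0:
        raise ValueError("chunk_count must be positive")
    return n_ref // chunk_count


def split_fixed(n_samples: int, chunk_size: int):
    """Return index lists for the full chunks and the number of leftover rows."""
    n_full = n_samples // chunk_size
    chunks = [list(range(i * chunk_size, (i + 1) * chunk_size)) for i in range(n_full)]
    leftover = n_samples - n_full * chunk_size
    return chunks, leftover


size = effective_chunk_size(200, 3)  # locked in at fit time
test_chunks, leftover = split_fixed(100, size)
print(size, len(test_chunks), leftover)
```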

Parameters:
x_ref : ArrayLike

Reference data with dim[n_samples, n_features].

chunker : ArrayChunker or None, default None

Explicit chunker instance for chunked mode.

chunk_size : int or None, default None

Size of each fixed-size chunk. The same size is reused during prediction so that per-chunk statistics stay comparable.

chunk_count : int or None, default None

Split into this many equal chunks. Converted to a fixed chunk_size based on the reference data length.

chunks : list[ArrayLike] or None, default None

Pre-split reference data for chunked mode.

chunk_indices : list[list[int]] or None, default None

Index groupings for chunking reference data.
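`chunk_indices` takes a list of index lists, one list per chunk. This lets chunks follow metadata instead of row order. The sketch below builds such groupings from a per-row key, such as the collection day, and then gathers the rows of each chunk. Both helpers, `indices_by_key` and `gather_chunks`, are hypothetical. The only format taken from the documentation is list[list[int]].

```python
def indices_by_key(keys):
    """Group row indices by key, keeping groups in first-seen order."""
    groups = {}
    for i, key in enumerate(keys):
        groups.setdefault(key, []).append(i)
    return list(groups.values())


def gather_chunks(rows, chunk_indices):
    # Each inner list of indices selects the rows that form one chunk.
    return [[rows[i] for i in group] for group in chunk_indices]


days = ["mon", "tue", "mon", "wed", "tue"]
chunk_indices = indices_by_key(days)  # one chunk per day
rows = ["r0", "r1", "r2", "r3", "r4"]
print(gather_chunks(rows, chunk_indices))
```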

Return type:

Self

predict(x=None, chunks=None, chunk_indices=None)

Perform inference on the test data.

Parameters:
x : ArrayLike or None

Test (analysis) data with dim[n_samples, n_features]. Required in non-chunked mode, and in chunked mode unless pre-built chunks are provided.

chunks : list[ArrayLike] or None, default None

Pre-built test data chunks.

chunk_indices : list[list[int]] or None, default None

Index groupings for chunking test data.

Returns:

Non-chunked mode: details is a DriftDomainClassifierStats TypedDict. Chunked mode: details is a polars.DataFrame with per-chunk results.

Return type:

DriftOutput

property is_chunked : bool

Whether the detector is operating in chunked mode.

property x_ref : numpy.typing.NDArray[numpy.float32]

Reference data for drift detection.

Returns:

Reference data array.

Return type:

NDArray[np.float32]

Raises:

RuntimeError – If called before fit().

Classes

Config

Configuration for DriftDomainClassifier detector.