dataeval.extractors.ClassifierUncertaintyExtractor

class dataeval.extractors.ClassifierUncertaintyExtractor(model, preds_type='probs', batch_size=32, transforms=None, device=None)

Computes prediction entropy from a classification model for drift detection.

This class implements the FeatureExtractor protocol for use with drift detectors (e.g., DriftUnivariate). It computes prediction uncertainty (entropy) from a classification model.

Uncertainty-based drift detection monitors changes in model confidence rather than raw input features. This approach is particularly effective for detecting drift that affects model performance even when input statistics remain similar, such as out-of-domain samples or adversarial examples.
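The mechanism can be sketched without DataEval. Score every sample by the entropy of the model's predicted distribution, then compare the reference and test entropy distributions with a two-sample test. This NumPy sketch is an illustration only, not DataEval's implementation. It uses a hypothetical linear-softmax stand-in for the model, and the helper names `predict_probs`, `entropy`, and `ks_statistic` are illustrative. It reports only the Kolmogorov-Smirnov statistic and computes no p-value.

```python
import numpy as np


def predict_probs(x: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in classifier: one linear layer followed by softmax."""
    logits = x @ weights
    logits -= logits.max(axis=1, keepdims=True)  # stabilize exp()
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


def entropy(probs: np.ndarray) -> np.ndarray:
    """Per-sample Shannon entropy in nats (clipping avoids log(0))."""
    return -np.sum(probs * np.log(np.clip(probs, 1e-12, None)), axis=1)


def ks_statistic(a: np.ndarray, b: np.ndarray) -> float:
    """Two-sample Kolmogorov-Smirnov statistic: largest gap between the empirical CDFs."""
    grid = np.concatenate([a, b])
    cdf_a = np.searchsorted(np.sort(a), grid, side="right") / len(a)
    cdf_b = np.searchsorted(np.sort(b), grid, side="right") / len(b)
    return float(np.max(np.abs(cdf_a - cdf_b)))


rng = np.random.default_rng(0)
weights = 0.25 * rng.standard_normal((16, 10))

reference = rng.standard_normal((500, 16))
same_dist = rng.standard_normal((500, 16))
# Wider inputs -> larger logits -> more confident, lower-entropy predictions
shifted = 3.0 * rng.standard_normal((500, 16))

h_ref = entropy(predict_probs(reference, weights))
h_same = entropy(predict_probs(same_dist, weights))
h_shift = entropy(predict_probs(shifted, weights))

print("no drift:", round(ks_statistic(h_ref, h_same), 3))
print("drift:   ", round(ks_statistic(h_ref, h_shift), 3))
```

In this toy setup the scaled inputs also change the raw feature statistics, so the sketch only shows that the entropy distribution moves. A detector such as DriftUnivariate additionally turns the comparison into a drift decision (`result.drifted` in the examples).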

Parameters:
model : torch.nn.Module

Classification model used to compute predictions and uncertainties. It should output class probabilities or logits, as specified by preds_type.

preds_type : "probs" or "logits", default "probs"

Format of model outputs. “probs” expects normalized probabilities summing to 1. “logits” expects raw model outputs and applies softmax.

batch_size : int, default 32

Batch size for model inference. Larger batches improve GPU utilization but require more memory.

transforms : Transform or Sequence[Transform] or None, default None

Preprocessing transforms to apply before model inference. Should match preprocessing used during model training for consistent predictions.

device : DeviceLike or None, default None

Hardware device for computation. When None, uses DataEval’s configured device or PyTorch’s default.
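Because "logits" applies a softmax, preds_type must match what the model's final layer emits. The example models end in nn.Softmax and therefore use "probs". Selecting "logits" for such a model would apply softmax a second time, which flattens the distribution and inflates the entropy. A NumPy sketch of that effect follows. The helpers `softmax` and `entropy` are illustrative and are not DataEval functions.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Row-wise softmax with max-subtraction for numerical stability."""
    shifted = x - x.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def entropy(probs: np.ndarray) -> np.ndarray:
    """Row-wise Shannon entropy in nats; clipping inside the log avoids log(0)."""
    return -np.sum(probs * np.log(np.clip(probs, 1e-12, None)), axis=-1)


# Output of a model whose last layer is already a softmax (like nn.Softmax in the examples)
probs = np.array([[0.90, 0.05, 0.05]])

correct = entropy(probs)            # preds_type="probs": outputs used as-is
inflated = entropy(softmax(probs))  # preds_type="logits" by mistake: softmax applied twice

print(correct.round(3), inflated.round(3))
```

The first entropy is about 0.394 nats and the second about 1.01, even though the model's prediction did not change.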

Attributes:

model

The classification model used for predictions.

Type:

torch.nn.Module

preds_type

Format of model outputs.

Type:

{“probs”, “logits”}

batch_size

Batch size for inference.

Type:

int

device

Hardware device for computation.

Type:

torch.device
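The parameters above interact at inference time roughly as follows. This NumPy sketch is a hypothetical re-creation, not DataEval's code. It assumes that transforms run in the listed order on each batch and that softmax is applied only when preds_type="logits". It also uses a linear stand-in model, and `extract_entropy`, `logit_model`, and `normalize` are illustrative names. The sketch shows that batch_size changes only how the data is chunked, not the resulting entropies. That holds for models without batch-dependent behavior, such as batch normalization in training mode.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence

import numpy as np

Transform = Callable[[np.ndarray], np.ndarray]


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax with max-subtraction for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def extract_entropy(
    data: np.ndarray,
    model: Callable[[np.ndarray], np.ndarray],
    preds_type: str = "probs",
    batch_size: int = 32,
    transforms: Transform | Sequence[Transform] | None = None,
) -> np.ndarray:
    """Hypothetical re-creation of the extraction steps (not DataEval's code)."""
    # Normalize transforms to a list that is applied in order to every batch
    if transforms is None:
        steps: list[Transform] = []
    elif callable(transforms):
        steps = [transforms]
    else:
        steps = list(transforms)

    entropies = []
    for start in range(0, len(data), batch_size):
        batch = data[start : start + batch_size]
        for step in steps:
            batch = step(batch)
        outputs = model(batch)
        # Only raw logits need a softmax; "probs" outputs are used as-is
        probs = softmax(outputs) if preds_type == "logits" else outputs
        entropies.append(-np.sum(probs * np.log(np.clip(probs, 1e-12, None)), axis=1))
    return np.concatenate(entropies)


rng = np.random.default_rng(0)
weights = 0.25 * rng.standard_normal((16, 10))


def logit_model(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in model that returns raw logits."""
    return x @ weights


data = rng.standard_normal((100, 16)).astype(np.float32)
train_mean, train_std = data.mean(axis=0), data.std(axis=0)


def normalize(x: np.ndarray) -> np.ndarray:
    """Reuse the training statistics so inference preprocessing matches training."""
    return (x - train_mean) / train_std


small = extract_entropy(data, logit_model, "logits", batch_size=8, transforms=normalize)
large = extract_entropy(data, logit_model, "logits", batch_size=64, transforms=[normalize])
print(small.shape, np.allclose(small, large))  # (100,) True
```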

Example

Basic usage with DriftUnivariate

>>> import numpy as np
>>> import torch.nn as nn
>>> from dataeval.shift import DriftUnivariate
>>> from dataeval.extractors import ClassifierUncertaintyExtractor
>>>
>>> # Create dummy datasets
>>> train_dataset = np.random.randn(100, 16).astype(np.float32)
>>> test_dataset = np.random.randn(20, 16).astype(np.float32)
>>>
>>> # Create a simple model
>>> model = nn.Sequential(nn.Linear(16, 10), nn.Softmax(dim=-1))
>>>
>>> # Create uncertainty feature extractor
>>> uncertainty_extractor = ClassifierUncertaintyExtractor(model=model, preds_type="probs", batch_size=32)
>>>
>>> # Use with DriftUnivariate for uncertainty-based drift detection
>>> drift_detector = DriftUnivariate(train_dataset, method="ks", extractor=uncertainty_extractor)
>>>
>>> # Detect drift on new data
>>> result = drift_detector.predict(test_dataset)
>>> print(f"Drift detected: {result.drifted}")
Drift detected: False

With data preprocessing transforms

>>> import torch
>>>
>>> # Create new datasets and model for this example
>>> train_dataset = np.random.randn(100, 16).astype(np.float32)
>>> test_dataset = np.random.randn(20, 16).astype(np.float32)
>>> model = nn.Sequential(nn.Linear(16, 10), nn.Softmax(dim=-1))
>>>
>>> # Simple transform: cast non-floating-point tensors to float32 (no normalization needed for this dummy data)
>>> def transforms(x):
...     return x if x.is_floating_point() else x.float()
>>>
>>> uncertainty_extractor = ClassifierUncertaintyExtractor(model=model, transforms=transforms, device="cpu")
>>>
>>> drift_detector = DriftUnivariate(train_dataset, method="ks", extractor=uncertainty_extractor)
>>> result = drift_detector.predict(test_dataset)

Using different statistical methods

>>> # Create datasets and model for this example
>>> train_dataset = np.random.randn(100, 16).astype(np.float32)
>>> test_dataset = np.random.randn(20, 16).astype(np.float32)
>>> model = nn.Sequential(nn.Linear(16, 10), nn.Softmax(dim=-1))
>>> uncertainty_extractor = ClassifierUncertaintyExtractor(model=model, preds_type="probs", batch_size=32)
>>>
>>> # Use the Cramér-von Mises test instead of Kolmogorov-Smirnov
>>> drift_detector = DriftUnivariate(
...     train_dataset,
...     method="cvm",  # Integrates CDF differences over the whole distribution, not just the largest gap
...     extractor=uncertainty_extractor,
... )
>>>
>>> # Or use the rank-based Mann-Whitney U test to detect location shifts
>>> drift_detector = DriftUnivariate(
...     train_dataset,
...     method="mwu",  # Robust to outliers
...     extractor=uncertainty_extractor,
... )

Notes

The uncertainty extractor computes Shannon entropy, -sum(p * log(p)), where p is the vector of predicted class probabilities for a sample and the sum runs over classes. Higher entropy indicates greater model uncertainty.
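The formula does not state the log base. Assuming the natural log (nats), entropy ranges from 0 for a one-hot prediction to log K for a uniform prediction over K classes. This small NumPy check of those bounds uses the equivalent form sum(p * log(1/p)). The helper `shannon_entropy` is illustrative and is not part of DataEval.

```python
import numpy as np


def shannon_entropy(probs: np.ndarray) -> np.ndarray:
    """Row-wise sum(p * log(1/p)) in nats, with 0 * log(1/0) treated as 0."""
    safe = np.where(probs > 0, probs, 1.0)  # log(1/1) = 0, so zero-probability classes add nothing
    return np.sum(probs * np.log(1.0 / safe), axis=-1)


num_classes = 10
confident = np.eye(num_classes)[:1]                            # all mass on class 0
hedged = np.array([[0.5, 0.5] + [0.0] * (num_classes - 2)])    # split between two classes
uniform = np.full((1, num_classes), 1.0 / num_classes)         # no preference at all

for name, p in [("confident", confident), ("hedged", hedged), ("uniform", uniform)]:
    print(f"{name:>9}: {shannon_entropy(p)[0]:.4f}")
```

This prints 0.0000, 0.6931 (ln 2), and 2.3026 (ln 10). The more evenly the model spreads probability across classes, the higher the entropy.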

This approach works best with well-calibrated models trained on representative data. Poorly calibrated models may produce misleading uncertainty estimates that don’t reliably indicate data quality issues.
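Miscalibration can change the entropy without changing a single predicted label. Dividing logits by a temperature T (the rescaling used in temperature scaling) preserves the argmax. T < 1 sharpens the distribution and lowers entropy, while T > 1 flattens it and raises entropy. An overconfident model behaves like one with T set too small, so it reports less uncertainty than it should. The following NumPy illustration uses made-up logits, and its helper names are illustrative.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax with max-subtraction for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def entropy(probs: np.ndarray) -> np.ndarray:
    """Row-wise Shannon entropy in nats."""
    return -np.sum(probs * np.log(np.clip(probs, 1e-12, None)), axis=-1)


# Made-up logits for three samples over four classes
logits = np.array([
    [2.0, 1.0, 0.5, 0.0],
    [0.3, 0.2, 0.1, 0.0],
    [1.0, 3.0, 0.0, 0.5],
])

for temperature in (0.5, 1.0, 2.0):
    probs = softmax(logits / temperature)
    # argmax is identical at every temperature; only the confidence changes
    print(temperature, probs.argmax(axis=1), entropy(probs).round(3))
```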

Uncertainty-based drift detection is complementary to feature-based methods and can detect semantic drift (changes in data meaning) that may not be apparent in raw feature statistics.

See also

dataeval.shift.DriftUnivariate

Univariate drift detection with multiple statistical tests