dataeval.shift.UncertaintyFeatureExtractor
class dataeval.shift.UncertaintyFeatureExtractor(model, preds_type='probs', batch_size=32, transforms=None, device=None)
Feature extractor that converts data to model uncertainty scores.
This class implements the FeatureExtractor protocol for use with drift detectors (e.g., DriftUnivariate). It computes prediction uncertainty (entropy) from a classification model.
Uncertainty-based drift detection monitors changes in model confidence rather than raw input features. This approach is particularly effective for detecting drift that affects model performance even when input statistics remain similar, such as out-of-domain samples or adversarial examples.
- Parameters:
- model : torch.nn.Module
Classification model to compute predictions and uncertainties. Should output class probabilities or logits.
- preds_type : "probs" or "logits", default "probs"
Format of model outputs. "probs" expects normalized probabilities summing to 1. "logits" expects raw model outputs and applies softmax.
- batch_size : int, default 32
Batch size for model inference. Larger batches improve GPU utilization but require more memory.
- transforms : Transform or Sequence[Transform] or None, default None
Preprocessing transforms to apply before model inference. Should match preprocessing used during model training for consistent predictions.
- device : DeviceLike or None, default None
Hardware device for computation. When None, uses DataEval's configured device or PyTorch's default.
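The two preds_type values can be illustrated with a small standalone sketch. This is plain Python, not DataEval's implementation: softmax and to_probs are hypothetical helpers introduced here, and the logits are made-up values. It only shows the documented behavior, where "logits" are passed through softmax and "probs" are used as given.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_probs(outputs, preds_type="probs"):
    """Hypothetical helper mirroring the documented preds_type semantics."""
    if preds_type == "logits":
        return softmax(outputs)  # raw scores are normalized first
    if preds_type == "probs":
        return list(outputs)  # already normalized, used unchanged
    raise ValueError(f"unknown preds_type: {preds_type!r}")


logits = [2.0, 1.0, 0.1]  # illustrative raw outputs for 3 classes
probs = to_probs(logits, preds_type="logits")
print(probs, sum(probs))
```

If a model already ends in a softmax layer, as in the examples below, passing preds_type="logits" would normalize the outputs a second time and flatten the distribution, so the setting should match the model's final layer.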
Example
Basic usage with DriftUnivariate
>>> import numpy as np
>>> import torch.nn as nn
>>> from dataeval.shift import DriftUnivariate, UncertaintyFeatureExtractor
>>>
>>> # Create dummy datasets
>>> train_dataset = np.random.randn(100, 16).astype(np.float32)
>>> test_dataset = np.random.randn(20, 16).astype(np.float32)
>>>
>>> # Create a simple model
>>> model = nn.Sequential(nn.Linear(16, 10), nn.Softmax(dim=-1))
>>>
>>> # Create uncertainty feature extractor
>>> uncertainty_extractor = UncertaintyFeatureExtractor(model=model, preds_type="probs", batch_size=32)
>>>
>>> # Use with DriftUnivariate for uncertainty-based drift detection
>>> drift_detector = DriftUnivariate(train_dataset, method="ks", feature_extractor=uncertainty_extractor)
>>>
>>> # Detect drift on new data
>>> result = drift_detector.predict(test_dataset)
>>> print(f"Drift detected: {result.drifted}")
Drift detected: False

With data preprocessing transforms
>>> import torch
>>>
>>> # Create new datasets and model for this example
>>> train_dataset = np.random.randn(100, 16).astype(np.float32)
>>> test_dataset = np.random.randn(20, 16).astype(np.float32)
>>> model = nn.Sequential(nn.Linear(16, 10), nn.Softmax(dim=-1))
>>>
>>> # Simple transform (no normalization needed for this dummy data)
>>> transforms = lambda x: x.float() if not x.is_floating_point() else x
>>>
>>> uncertainty_extractor = UncertaintyFeatureExtractor(model=model, transforms=transforms, device="cpu")
>>>
>>> drift_detector = DriftUnivariate(train_dataset, method="ks", feature_extractor=uncertainty_extractor)

Using different statistical methods
>>> # Create datasets and model for this example
>>> train_dataset = np.random.randn(100, 16).astype(np.float32)
>>> test_dataset = np.random.randn(20, 16).astype(np.float32)
>>> model = nn.Sequential(nn.Linear(16, 10), nn.Softmax(dim=-1))
>>> uncertainty_extractor = UncertaintyFeatureExtractor(model=model, preds_type="probs", batch_size=32)
>>>
>>> # Use Cramér-von Mises test instead of Kolmogorov-Smirnov
>>> drift_detector = DriftUnivariate(
...     train_dataset,
...     method="cvm",  # More sensitive to overall distributional changes
...     feature_extractor=uncertainty_extractor,
... )
>>>
>>> # Or use Mann-Whitney U test for robust median shift detection
>>> drift_detector = DriftUnivariate(
...     train_dataset,
...     method="mwu",  # Robust to outliers
...     feature_extractor=uncertainty_extractor,
... )

Notes
The uncertainty extractor computes Shannon entropy: -sum(p * log(p)) where p are the predicted class probabilities. Higher entropy indicates greater model uncertainty.
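As a worked illustration of the formula above, the following standalone sketch evaluates the entropy for three hand-picked probability vectors. It assumes the natural logarithm and treats terms with p = 0 as contributing zero; how the library itself handles zero probabilities (for example by clipping) is not stated here, and shannon_entropy is a hypothetical helper, not a DataEval function.

```python
import math


def shannon_entropy(probs):
    """Shannon entropy -sum(p * log(p)) in nats; p == 0 terms contribute 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


# Illustrative class-probability vectors for a 4-class model
confident = [0.97, 0.01, 0.01, 0.01]  # model strongly prefers one class
uniform = [0.25, 0.25, 0.25, 0.25]    # model has no preference at all
one_hot = [1.0, 0.0, 0.0, 0.0]        # model is completely certain

h_confident = shannon_entropy(confident)
h_uniform = shannon_entropy(uniform)
h_one_hot = shannon_entropy(one_hot)

print(f"confident: {h_confident:.4f}")
print(f"uniform:   {h_uniform:.4f} (upper bound log(4) = {math.log(4):.4f})")
print(f"one-hot:   {abs(h_one_hot):.4f}")
```

For K classes the entropy ranges from 0 (a one-hot prediction) to log(K) (a uniform prediction), so each input sample is reduced to a single scalar score in that interval.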
This approach works best with well-calibrated models trained on representative data. Poorly calibrated models may produce misleading uncertainty estimates that don’t reliably indicate data quality issues.
Uncertainty-based drift detection is complementary to feature-based methods and can detect semantic drift (changes in data meaning) that may not be apparent in raw feature statistics.
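The idea of comparing uncertainty scores instead of raw features can be sketched without any dependencies. The example below is an assumption-laden illustration, not DataEval's code: entropy and ks_statistic are hypothetical helpers, the probability rows are made up, and only the two-sample Kolmogorov-Smirnov statistic is computed (no p-value or drift threshold).

```python
import math


def entropy(probs):
    """Shannon entropy in nats; p == 0 terms contribute 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def ks_statistic(a, b):
    """Two-sample KS statistic: largest gap between the empirical CDFs."""
    gap = 0.0
    for x in list(a) + list(b):
        cdf_a = sum(v <= x for v in a) / len(a)
        cdf_b = sum(v <= x for v in b) / len(b)
        gap = max(gap, abs(cdf_a - cdf_b))
    return gap


# Made-up model outputs: confident on reference data, hesitant on new data
ref_probs = [[0.90, 0.05, 0.05], [0.85, 0.10, 0.05],
             [0.05, 0.90, 0.05], [0.10, 0.10, 0.80]]
test_probs = [[0.40, 0.30, 0.30], [0.34, 0.33, 0.33],
              [0.50, 0.30, 0.20], [0.45, 0.45, 0.10]]

# One scalar uncertainty score per sample
ref_scores = [entropy(p) for p in ref_probs]
test_scores = [entropy(p) for p in test_probs]

# A large statistic means the two score distributions barely overlap
print(f"KS statistic: {ks_statistic(ref_scores, test_scores):.2f}")
```

Because every test score here is higher than every reference score, the two empirical distributions do not overlap at all. A real detector would additionally turn the statistic into a p-value and compare it against a significance level.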
See also
DriftUnivariate
Univariate drift detection with multiple statistical tests