dataeval.shift.EmbeddingsFeatureExtractor
class dataeval.shift.EmbeddingsFeatureExtractor(model=None, batch_size=None, transforms=None, layer_name=None, use_output=True, device=None, embeddings=None)
Extract embeddings from datasets for drift detection.
This class implements the FeatureExtractor protocol for use with drift detectors. It converts raw datasets into embeddings using a neural network model and can reuse pre-computed embeddings to avoid redundant computation.
The extractor is stateful: it caches the reference embeddings so that passing the same dataset more than once (for example, first to initialize the reference data and later during drift detection) does not recompute them.
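A minimal sketch of the caching behavior described above, assuming the cache is keyed on object identity (the Notes say the extractor keeps a reference to the dataset). `CachingExtractor`, `embed_fn`, and `fake_embed` are hypothetical names, not DataEval's API; `embed_fn` stands in for running the model over a dataset.

```python
from typing import Any, Callable, Optional


class CachingExtractor:
    """Hypothetical sketch: reuse reference embeddings for the same dataset object."""

    def __init__(
        self,
        embed_fn: Callable[[Any], Any],
        embeddings: Optional[Any] = None,
        source: Optional[Any] = None,
    ) -> None:
        self._embed_fn = embed_fn          # stands in for model inference
        self._ref_data = source            # dataset the cached embeddings belong to
        self._ref_embeddings = embeddings  # pre-computed embeddings, if any
        self.calls = 0                     # how many times embed_fn actually ran

    def __call__(self, data: Any) -> Any:
        # Cache hit only for the same object (identity), not an equal copy.
        if self._ref_embeddings is not None and data is self._ref_data:
            return self._ref_embeddings
        self.calls += 1
        result = self._embed_fn(data)
        # The first dataset seen becomes the cached reference.
        if self._ref_data is None:
            self._ref_data, self._ref_embeddings = data, result
        return result


def fake_embed(rows):
    # Hypothetical embedding: one feature per row (the row sum)
    return [[float(sum(r))] for r in rows]


train = [[1, 2], [3, 4]]
extractor = CachingExtractor(fake_embed)
ref = extractor(train)              # computed and cached as the reference
again = extractor(train)            # same object: served from the cache
copy = extractor([[1, 2], [3, 4]])  # equal but distinct object: recomputed
```

Keying on identity rather than equality keeps the check cheap (no comparison of full datasets), at the cost of recomputing for an equal copy of the data.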
- Parameters:
- model : torch.nn.Module or None, default None
Model used to extract embeddings. When None and embeddings is provided, the model from the embeddings object is used.
- batch_size : int or None, default None
Batch size for passing images through the model. When None, the globally configured batch_size is used.
- transforms : Transform or Sequence[Transform] or None, default None
Preprocessing transforms applied before model inference.
- layer_name : str or None, default None
Network layer from which to extract embeddings. When None, the model output is used.
- use_output : bool, default True
If True, captures the output tensors of layer_name. If False, captures its input tensors.
- device : DeviceLike or None, default None
Hardware device used for computation. When None, DataEval's configured device is used.
- embeddings : Embeddings or None, default None
Pre-computed Embeddings object to reuse. When provided, the extractor avoids recomputing embeddings for the same dataset. This is useful when embeddings have already been computed and you want to run drift detection on them without processing the data again.
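The None-valued fallbacks and the layer_name/use_output switch described in the parameter list can be sketched in plain Python. This is a minimal illustration under stated assumptions, not DataEval's code: `resolve_settings`, `capture`, the global defaults, and the `embeddings.model` attribute are hypothetical, and the network is modeled as an ordered list of named functions instead of a torch.nn.Module so the sketch runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional, Sequence, Tuple

# Hypothetical global defaults standing in for DataEval's configuration.
GLOBAL_BATCH_SIZE = 64
GLOBAL_DEVICE = "cpu"


def resolve_settings(
    model: Optional[Any] = None,
    batch_size: Optional[int] = None,
    device: Optional[str] = None,
    embeddings: Optional[Any] = None,
) -> Tuple[Optional[Any], int, str]:
    """Apply the documented fallbacks for parameters left as None."""
    if model is None and embeddings is not None:
        model = embeddings.model  # hypothetical attribute name
    if batch_size is None:
        batch_size = GLOBAL_BATCH_SIZE
    if device is None:
        device = GLOBAL_DEVICE
    return model, batch_size, device


Layer = Tuple[str, Callable[[Any], Any]]


def capture(
    layers: Sequence[Layer],
    x: Any,
    layer_name: Optional[str] = None,
    use_output: bool = True,
) -> Any:
    """Run x through named layers and return the value seen at layer_name.

    layer_name=None returns the final model output. Otherwise use_output
    selects the named layer's output (True) or its input (False).
    """
    for name, fn in layers:
        y = fn(x)
        if name == layer_name:
            return y if use_output else x
        x = y
    if layer_name is not None:
        # Assumption: this sketch treats an unknown layer name as an error.
        raise KeyError(f"no layer named {layer_name!r}")
    return x


# Toy three-layer "network" on plain lists
layers = [
    ("double", lambda v: [2 * e for e in v]),
    ("relu", lambda v: [max(0, e) for e in v]),
    ("total", lambda v: [sum(v)]),
]
emb = SimpleNamespace(model="model-from-embeddings")  # stand-in Embeddings object
```

In this sketch, an explicitly passed model always wins over the one carried by embeddings, mirroring the "When None and embeddings is provided" wording above.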
Example
Basic usage with a dataset:
>>> import numpy as np
>>> import torch.nn as nn
>>> from dataeval.shift import DriftUnivariate, EmbeddingsFeatureExtractor
>>>
>>> # Create dummy data
>>> train_data = np.random.randn(100, 16).astype(np.float32)
>>> test_data = np.random.randn(50, 16).astype(np.float32)
>>>
>>> # Create feature extractor
>>> model = nn.Sequential(nn.Linear(16, 128), nn.ReLU(), nn.Linear(128, 64))
>>> embeddings_extractor = EmbeddingsFeatureExtractor(model=model, batch_size=32)
>>>
>>> # Use with drift detector on raw datasets
>>> drift_detector = DriftUnivariate(
...     data=train_data,
...     method="ks",
...     feature_extractor=embeddings_extractor,
... )
>>> result = drift_detector.predict(test_data)
>>> print(f"Drift detected: {result.drifted}")
Drift detected: False

Reusing pre-computed embeddings:
>>> from dataeval import Embeddings
>>>
>>> # ExampleDataset is not imported in this snippet; it is assumed to be provided by the documentation's doctest setup
>>> # Use ExampleDataset for structured data (1x4x4 images for simple example)
>>> train_ds = ExampleDataset(100, image_shape=(1, 4, 4), n_classes=10, seed=42)
>>> model_emb = nn.Sequential(nn.Flatten(), nn.Linear(16, 128), nn.ReLU(), nn.Linear(128, 64))
>>>
>>> # Compute embeddings once
>>> train_embeddings = Embeddings(train_ds, batch_size=32, model=model_emb).compute()
>>>
>>> # Reuse embeddings with drift detector
>>> embeddings_extractor = EmbeddingsFeatureExtractor(embeddings=train_embeddings)
>>> drift_detector = DriftUnivariate(
...     data=train_ds,
...     method="ks",
...     feature_extractor=embeddings_extractor,
... )

Notes
The extractor caches a reference to the dataset used during initialization, so passing that same dataset again (as commonly happens when a drift detector initializes its reference data) does not recompute its embeddings.
See also
Embeddings: Underlying embeddings computation class.
DriftUnivariate: Univariate drift detection with multiple statistical tests.
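The first example passes method="ks" to DriftUnivariate, a univariate detector, which suggests each embedding dimension is tested on its own. A minimal sketch of that idea, assuming a two-sample Kolmogorov-Smirnov statistic per dimension: `ks_statistic` and `per_feature_ks` are hypothetical helpers that compute only the statistic (no p-values or multiple-testing correction) and do not reproduce DriftUnivariate's internals.

```python
from bisect import bisect_right
from typing import List, Sequence


def ks_statistic(ref: Sequence[float], test: Sequence[float]) -> float:
    """Two-sample KS statistic: largest gap between the two empirical CDFs."""
    a, b = sorted(ref), sorted(test)
    n, m = len(a), len(b)
    gap = 0.0
    # Both ECDFs are step functions that only change at sample points,
    # so checking every pooled sample value finds the supremum.
    for x in a + b:
        gap = max(gap, abs(bisect_right(a, x) / n - bisect_right(b, x) / m))
    return gap


def per_feature_ks(
    ref_emb: Sequence[Sequence[float]], test_emb: Sequence[Sequence[float]]
) -> List[float]:
    """Apply ks_statistic to each embedding dimension independently."""
    dims = len(ref_emb[0])
    return [
        ks_statistic([row[j] for row in ref_emb], [row[j] for row in test_emb])
        for j in range(dims)
    ]
```

A statistic near 0 means the two samples of that dimension have similar distributions; a value of 1.0 means they do not overlap at all.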