dataeval.extractors.TorchExtractor
class dataeval.extractors.TorchExtractor(model, transforms=None, device=None, layer_name=None, use_output=True, flatten=True, batch_size=None, postprocess_fn=None)
Extracts embeddings from a PyTorch model, with optional intermediate-layer hooking.
Encapsulates all PyTorch-specific logic for feature extraction:
- Model management (torch.nn.Module)
- Device handling
- Transform pipeline
- Layer hooking for intermediate-layer extraction
Implements the FeatureExtractor protocol.
Parameters:
- model : torch.nn.Module
PyTorch model for feature extraction.
- transforms : Transform or Sequence[Transform] or None, default None
Preprocessing transforms to apply before encoding. When None, uses raw images.
- device : DeviceLike or None, default None
Device for computation. When None, uses DataEval’s configured device.
- layer_name : str or None, default None
Layer to extract embeddings from. When None, uses model output.
- use_output : bool, default True
If True, captures layer output; if False, captures layer input. Only used when layer_name is specified.
- flatten : bool, default True
If True, flattens outputs with more than 2 dimensions to (N, D) shape. If False, preserves the original output shape.
- batch_size : int or None, default None
Forward-pass (compute) batch size: how many images go through the model at once.
None runs a single forward pass over all inputs. When this extractor is wrapped by Embeddings, Embeddings loads images in its own (I/O) chunks and this extractor sub-batches each chunk by this value, so the smaller of the two values bounds the forward-pass size.
- postprocess_fn : PostprocessFn or None, default None
Batch-level decode applied to each minibatch’s full raw model output (passed as-is, including a tuple output). For example, it can turn a detection head’s raw output into a (n_detections, n_classes) score tensor. Must return a 2D tensor per batch (or a tuple whose element 0 is one). Mutually exclusive with layer_name. When set, flatten is bypassed (the decoded output is used as scores as-is). When not set, a tuple model output is reduced to its element 0.
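The parameters above interact in a fixed order per forward pass:

- sub-batching by batch_size,
- then either postprocess_fn or tuple reduction,
- then flatten.

The NumPy sketch below mirrors that documented behavior. It is not DataEval's implementation. `extract`, `forward_sizes`, and the stand-in `forward` callables are hypothetical names for illustration, and plain NumPy arrays stand in for tensors.

```python
import numpy as np

# Hypothetical helper mirroring the documented semantics; NOT DataEval's code.
def extract(forward, images, batch_size=None, flatten=True, postprocess_fn=None):
    """Run `forward` over `images` in sub-batches and stack the results."""
    n = len(images)
    step = max(1, n if batch_size is None else batch_size)  # None -> one pass over all inputs
    outputs = []
    for start in range(0, n, step):
        raw = forward(images[start:start + step])
        if postprocess_fn is not None:
            out = postprocess_fn(raw)       # receives the full raw output, tuple included
            if isinstance(out, tuple):
                out = out[0]                # element 0 must be the 2D score array
            # flatten is bypassed: decoded output is used as-is
        else:
            out = raw[0] if isinstance(raw, tuple) else raw  # tuple output -> element 0
            if flatten and out.ndim > 2:
                out = out.reshape(out.shape[0], -1)          # (N, ...) -> (N, D)
        outputs.append(out)
    return np.concatenate(outputs, axis=0)

# Hypothetical: forward-pass sizes when an outer loader (e.g. Embeddings)
# hands over io_chunk items at a time and the extractor sub-batches each chunk.
def forward_sizes(n, io_chunk, batch_size=None):
    sizes = []
    for c0 in range(0, n, io_chunk):
        chunk = min(io_chunk, n - c0)
        step = chunk if batch_size is None else batch_size
        for s0 in range(0, chunk, step):
            sizes.append(min(step, chunk - s0))
    return sizes

imgs = np.ones((10, 1, 4, 4))                              # stand-in image batch (N, C, H, W)
print(extract(lambda x: x * 2.0, imgs, batch_size=4).shape)  # (10, 16)
print(max(forward_sizes(100, io_chunk=32, batch_size=64)))   # 32, the smaller bound
```

In this sketch, the detection-style case corresponds to a postprocess_fn such as `lambda raw: raw[0].reshape(-1, n_classes)`. Here the number of rows per batch is the number of detections, not the number of images.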
Example
Basic usage with a model:
>>> import torch.nn as nn
>>> from dataeval import Embeddings
>>> from dataeval.extractors import TorchExtractor
>>>
>>> # Expects 1x28x28 inputs (784 features after flattening)
>>> model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128))
>>> extractor = TorchExtractor(model, device="cpu")
>>> # `dataset` is assumed to be an image dataset defined elsewhere
>>> embeddings = Embeddings(dataset, extractor=extractor, batch_size=32)
Extracting from an intermediate layer:
>>> extractor = TorchExtractor(
...     model,
...     layer_name="0",  # Extract from the Flatten layer
...     use_output=True,
... )
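Intermediate-layer extraction in PyTorch is commonly done with forward hooks. This page does not say how TorchExtractor captures layers, so treat this as an assumption. The sketch below uses only PyTorch's own `register_forward_hook` and `named_modules`, and `capture` is a hypothetical helper. It shows what layer "0" of the example model sees with use_output=True versus use_output=False.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128)).eval()

# Hypothetical helper: grab a layer's output (or first input) during one forward pass.
def capture(model, layer_name, x, use_output=True):
    layer = dict(model.named_modules())[layer_name]  # "0" -> nn.Flatten, "1" -> nn.Linear
    captured = {}

    def hook(module, inputs, output):
        # Forward hooks receive the positional inputs as a tuple.
        captured["value"] = output if use_output else inputs[0]

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(x)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["value"]

x = torch.randn(4, 1, 28, 28)                  # assumed 28x28 single-channel images
out = capture(model, "0", x, use_output=True)   # Flatten output, shape (4, 784)
inp = capture(model, "0", x, use_output=False)  # Flatten input, shape (4, 1, 28, 28)
```

With the documented flatten=True default, a 4D capture like `inp` would be reduced to (N, D) shape, here (4, 784).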