dataeval.extractors.TorchExtractor

class dataeval.extractors.TorchExtractor(model, transforms=None, device=None, layer_name=None, use_output=True, flatten=True, batch_size=None, postprocess_fn=None)

Extracts embeddings from a PyTorch model, with optional intermediate layer hooking.

Encapsulates all PyTorch-specific logic for feature extraction:

  • Model management (torch.nn.Module)

  • Device handling

  • Transform pipeline

  • Layer hooking for intermediate layer extraction

Implements the FeatureExtractor protocol.
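
The layer hooking mentioned above can be illustrated with plain PyTorch forward hooks. This is a minimal sketch of the general technique, not DataEval's internal implementation; the `captured` dictionary and the `hook` function are hypothetical names used only for illustration. It shows the difference between capturing a layer's input and its output, which is what use_output selects.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128))
captured = {}  # hypothetical holder for the hooked tensors


def hook(module, inputs, output):
    # A forward hook receives the layer's positional inputs as a tuple
    # and the layer's output.
    captured["input"] = inputs[0].detach()
    captured["output"] = output.detach()


# "1" is the name nn.Sequential assigns to the Linear layer.
handle = model.get_submodule("1").register_forward_hook(hook)
with torch.no_grad():
    model(torch.zeros(4, 1, 28, 28))
handle.remove()  # detach the hook once the tensors are captured

input_shape = tuple(captured["input"].shape)    # what the Linear layer received
output_shape = tuple(captured["output"].shape)  # what the Linear layer produced
```

Here the layer input has shape (4, 784) and the layer output has shape (4, 128), so the choice between the two changes the embedding dimension.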

Parameters:
model : torch.nn.Module

PyTorch model for feature extraction.

transforms : Transform or Sequence[Transform] or None, default None

Preprocessing transforms to apply before encoding. When None, uses raw images.

device : DeviceLike or None, default None

Device for computation. When None, uses DataEval’s configured device.

layer_name : str or None, default None

Layer to extract embeddings from. When None, uses model output.

use_output : bool, default True

If True, captures layer output; if False, captures layer input. Only used when layer_name is specified.

flatten : bool, default True

If True, flattens outputs with more than 2 dimensions to (N, D) shape. If False, preserves the original output shape.

batch_size : int or None, default None

Forward-pass (compute) batch size: how many images go through the model at once. None runs a single forward pass over all inputs. When this extractor is wrapped by Embeddings, Embeddings loads images in its own (I/O) chunks and this extractor sub-batches each chunk by this value, so a forward pass never sees more images than the smaller of the two sizes.

postprocess_fn : PostprocessFn or None, default None

Batch-level decode function applied to each minibatch’s full raw model output, which is passed as-is even when it is a tuple. A typical use is turning a detection head’s raw output into a (n_detections, n_classes) score tensor. It must return a 2D tensor per batch, or a tuple whose element 0 is a 2D tensor. Mutually exclusive with layer_name. When set, flatten is bypassed and the decoded output is used as scores unchanged. When not set, a tuple model output is reduced to its element 0.
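
As an illustration of the postprocess_fn contract described above, the following sketch decodes a tuple output into a 2D score tensor. The output layout `(boxes, logits)`, the softmax step, and the name `decode_scores` are assumptions made for this example; a real detection head will have its own format, and the exact PostprocessFn type is not spelled out here.

```python
import torch


def decode_scores(raw_output):
    # Assumed (hypothetical) head output: a tuple (boxes, logits) where
    # logits has shape (batch, n_queries, n_classes).
    _boxes, logits = raw_output
    scores = logits.softmax(dim=-1)
    # Collapse batch and query dimensions into one row per detection,
    # giving the 2D (n_detections, n_classes) tensor the contract requires.
    return scores.reshape(-1, scores.shape[-1])


# Stand-in for one minibatch of raw model output.
raw = (torch.zeros(2, 5, 4), torch.zeros(2, 5, 3))
scores = decode_scores(raw)
scores_shape = tuple(scores.shape)
```

Going by the constructor signature, a function like this would be supplied as `TorchExtractor(model, postprocess_fn=decode_scores)`, with layer_name left as None.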

Example

Basic usage with a model:

>>> import torch.nn as nn
>>> from dataeval import Embeddings
>>> from dataeval.extractors import TorchExtractor
>>>
>>> model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128))
>>> extractor = TorchExtractor(model, device="cpu")
>>> embeddings = Embeddings(dataset, extractor=extractor, batch_size=32)

Extracting from an intermediate layer:

>>> extractor = TorchExtractor(
...     model,
...     layer_name="0",  # Extract from Flatten layer
...     use_output=True,
... )

property batch_size : int | None

Return the default batch size for inference, if set.

property flatten : bool

Return whether outputs are flattened to 2D.

property layer_name : str | None

Return the layer name for intermediate extraction, if set.

property use_output : bool

Return whether output (True) or input (False) is captured from the layer.
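
The interaction between this extractor's batch_size and the chunk size of a wrapping Embeddings, described under the batch_size parameter, comes down to simple arithmetic. The helper below is a hypothetical, pure-Python sketch of that arithmetic under the behavior stated there; it is not part of DataEval and does not call it.

```python
def forward_batch_sizes(n_images, io_chunk, compute_batch):
    """Return the size of each forward pass (hypothetical helper).

    io_chunk      -- images loaded per I/O chunk by the wrapper
    compute_batch -- extractor batch_size; None means one pass per chunk
    """
    sizes = []
    for chunk_start in range(0, n_images, io_chunk):
        chunk_len = min(io_chunk, n_images - chunk_start)
        step = chunk_len if compute_batch is None else compute_batch
        # Sub-batch the loaded chunk by the compute batch size.
        for start in range(0, chunk_len, step):
            sizes.append(min(step, chunk_len - start))
    return sizes


small_compute = forward_batch_sizes(70, io_chunk=32, compute_batch=8)
large_compute = forward_batch_sizes(70, io_chunk=32, compute_batch=64)
no_compute = forward_batch_sizes(70, io_chunk=32, compute_batch=None)
```

With 70 images and I/O chunks of 32, a compute batch of 8 gives passes of at most 8 images, while a compute batch of 64 or None gives passes of at most 32, since the chunk size becomes the binding limit.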