dataeval.utils.training.predict

dataeval.utils.training.predict(x, model, device=None, batch_size=None, preprocess_fn=None)
Make batch predictions on a model.
- Parameters:
- x : np.ndarray or torch.Tensor
Batch of instances.
- model : torch.nn.Module
PyTorch model.
- device : DeviceLike or None, default None
The hardware device to use. If None, the DataEval default or the torch default is used.
- batch_size : int or None, default None
Batch size used during prediction. If None, the DataEval default (1e10) is used, which is large enough that the input is effectively processed as a single batch.
- preprocess_fn : Callable or None, default None
Optional preprocessing function applied to each batch.
- Returns:
PyTorch tensor with the model outputs, or a tuple of tensors if the model returns a tuple (e.g., VAE models return (reconstruction, mu, logvar)).
- Return type:
torch.Tensor or tuple[torch.Tensor, …]
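
The behavior described above can be illustrated with a minimal sketch written in plain PyTorch. This is not DataEval's implementation: `batched_predict` and `TinyVAE` are hypothetical names introduced here, the device is assumed to fall back to CPU, `batch_size=None` is assumed to mean a single batch, and the input is assumed to be non-empty.

```python
import numpy as np
import torch


def batched_predict(x, model, device=None, batch_size=None, preprocess_fn=None):
    """Hypothetical stand-in that mimics the documented behavior; not DataEval code."""
    # Assumption: fall back to CPU when no device is given.
    device = torch.device(device if device is not None else "cpu")
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    n = len(x)
    # Assumption: batch_size=None means "process everything in one batch".
    step = batch_size if batch_size is not None else max(n, 1)
    model = model.to(device).eval()
    outputs = []
    with torch.no_grad():  # inference only, no gradient tracking
        for start in range(0, n, step):
            batch = x[start:start + step]
            if preprocess_fn is not None:
                batch = preprocess_fn(batch)
            outputs.append(model(batch.to(device)))
    # Models that return a tuple: concatenate each element across batches.
    if isinstance(outputs[0], tuple):
        return tuple(torch.cat(parts, dim=0) for parts in zip(*outputs))
    return torch.cat(outputs, dim=0)


class TinyVAE(torch.nn.Module):
    """Toy model returning (reconstruction, mu, logvar), for illustration only."""

    def __init__(self):
        super().__init__()
        self.enc = torch.nn.Linear(4, 2)
        self.dec = torch.nn.Linear(2, 4)

    def forward(self, x):
        mu = self.enc(x)
        logvar = torch.zeros_like(mu)
        return self.dec(mu), mu, logvar


model = torch.nn.Linear(4, 2)
x = np.random.rand(10, 4).astype(np.float32)

out = batched_predict(x, model, batch_size=3)  # batches of 3, 3, 3, 1
print(out.shape)  # torch.Size([10, 2])

recon, mu, logvar = batched_predict(x, TinyVAE(), batch_size=4)
print(recon.shape, mu.shape)  # torch.Size([10, 4]) torch.Size([10, 2])
```

Assuming the real function follows the signature documented above, the equivalent call would be `predict(x, model, batch_size=3)`, with the returned tensor or tuple of tensors handled the same way.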