dataeval.utils.training.predict

dataeval.utils.training.predict(x, model, device=None, batch_size=None, preprocess_fn=None)

Make batched predictions with a PyTorch model.

Parameters:
x : np.ndarray or torch.Tensor

Batch of instances.

model : torch.nn.Module

PyTorch model.

device : DeviceLike or None, default None

Hardware device to run the model on. If None, uses the DataEval default device, falling back to the PyTorch default if no DataEval default is set.

batch_size : int or None, default None

Number of instances passed to the model in each forward pass. If None, uses the DataEval default (1e10), which in practice processes the whole input as a single batch.

preprocess_fn : Callable or None, default None

Optional preprocessing function applied to each batch before it is passed to the model.
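The parameters above can be illustrated with a simple batching loop. This is a minimal sketch, not DataEval's implementation: the helper name `batched_predict`, the CPU fallback when `device` is None, the switch to eval mode, and the assumption that `preprocess_fn` takes and returns a tensor batch are all illustrative choices. It only covers models that return a single tensor.

```python
from typing import Callable, Optional

import numpy as np
import torch


def batched_predict(
    x,
    model: torch.nn.Module,
    device=None,
    batch_size: Optional[int] = None,
    preprocess_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    """Hypothetical sketch of batched inference; not the DataEval source."""
    # Assumption: fall back to CPU when no device is given.
    device = torch.device(device) if device is not None else torch.device("cpu")
    # Mirror the documented default (1e10), cast to int so slicing works.
    batch_size = int(batch_size) if batch_size is not None else int(1e10)

    # Accept NumPy arrays as well as tensors.
    x_t = torch.as_tensor(x)

    # Assumption: move the model to the device and use eval mode for inference.
    model = model.to(device).eval()
    outputs = []
    with torch.no_grad():
        for start in range(0, len(x_t), batch_size):
            batch = x_t[start : start + batch_size]
            if preprocess_fn is not None:
                batch = preprocess_fn(batch)
            outputs.append(model(batch.to(device)))
    # Stitch the per-batch outputs back together along the instance axis.
    return torch.cat(outputs, dim=0)


# Usage: 10 instances in batches of 3 (sizes 3, 3, 3, 1).
torch.manual_seed(0)
linear = torch.nn.Linear(4, 2)
inputs = np.random.rand(10, 4).astype(np.float32)
preds = batched_predict(inputs, linear, batch_size=3)
print(preds.shape)  # torch.Size([10, 2])
```

Batching only bounds memory use per forward pass; concatenating along dimension 0 should give the same result as a single full-size pass, up to floating-point rounding.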

Returns:

A PyTorch tensor containing the model outputs, or a tuple of tensors if the model returns a tuple (e.g., VAE models return (reconstruction, mu, logvar)).

Return type:

torch.Tensor or tuple[torch.Tensor, ...]
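When the model returns a tuple, each element has to be concatenated across batches separately rather than as one tensor. The sketch below assumes a hypothetical toy model `TinyVAE` that returns (reconstruction, mu, logvar), as in the VAE case above, and a hypothetical helper `concat_outputs`; neither is part of DataEval, and this is only one plausible way to produce the documented tuple return type.

```python
from typing import List, Tuple, Union

import torch
from torch import nn

Output = Union[torch.Tensor, Tuple[torch.Tensor, ...]]


class TinyVAE(nn.Module):
    """Hypothetical toy model that returns (reconstruction, mu, logvar)."""

    def __init__(self, dim: int = 4, latent: int = 2) -> None:
        super().__init__()
        self.enc = nn.Linear(dim, 2 * latent)
        self.dec = nn.Linear(latent, dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mu, logvar = self.enc(x).chunk(2, dim=-1)
        # Decode the mean directly (no sampling) so the example is deterministic.
        return self.dec(mu), mu, logvar


def concat_outputs(per_batch: List[Output]) -> Output:
    """Concatenate per-batch outputs, element-wise when the model returns tuples."""
    if isinstance(per_batch[0], tuple):
        # zip(*...) groups the i-th element of every batch's tuple together.
        return tuple(torch.cat(parts, dim=0) for parts in zip(*per_batch))
    return torch.cat(per_batch, dim=0)


torch.manual_seed(0)
vae = TinyVAE().eval()
x = torch.rand(10, 4)
with torch.no_grad():
    per_batch = [vae(b) for b in torch.split(x, 3)]
recon, mu, logvar = concat_outputs(per_batch)
print(recon.shape, mu.shape, logvar.shape)
# torch.Size([10, 4]) torch.Size([10, 2]) torch.Size([10, 2])
```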