dataeval.utils.training.train

dataeval.utils.training.train(model, x_train, y_train, loss_fn, optimizer, preprocess_fn, epochs, batch_size, device=None, *, progress_callback=None)

Train a PyTorch model.

Parameters:
model : torch.nn.Module

Model to train.

x_train : NDArray

Training data.

y_train : NDArray or None

Training labels. If None, assumes autoencoder-style training where x_train is also the target.

loss_fn : Callable or None

Loss function used for training. If None, uses MSELoss.

optimizer : torch.optim.Optimizer or None

Optimizer used for training. If None, uses Adam with lr=0.001.

preprocess_fn : Callable or None

Preprocessing function applied to each training batch.

epochs : int

Number of training epochs.

batch_size : int or None

Batch size used for training.

device : DeviceLike or None, default None

Hardware device to train on. If None, falls back to the DataEval default device or, failing that, the torch default.

progress_callback : ProgressCallback or None, default None

Optional progress callback function.

Return type:

None
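
Example:

The parameter fallbacks described above can be illustrated with a minimal sketch in plain PyTorch. This is not DataEval's implementation. The name train_sketch is hypothetical, and details such as sequential (unshuffled) batching, float32 conversion, and using the preprocessed input as the autoencoder target are assumptions made for illustration only.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, x_train, y_train=None, loss_fn=None, optimizer=None,
                 preprocess_fn=None, epochs=1, batch_size=32, device="cpu"):
    """Hypothetical stand-in that mirrors the documented parameters."""
    model.to(device).train()

    # Documented fallbacks: MSELoss and Adam with lr=0.001.
    if loss_fn is None:
        loss_fn = torch.nn.MSELoss()
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Without labels, the dataset holds inputs only (autoencoder-style).
    tensors = [torch.as_tensor(x_train, dtype=torch.float32)]
    if y_train is not None:
        tensors.append(torch.as_tensor(y_train, dtype=torch.float32))
    loader = DataLoader(TensorDataset(*tensors), batch_size=batch_size)

    for _ in range(epochs):
        for batch in loader:
            x = batch[0].to(device)
            if preprocess_fn is not None:
                x = preprocess_fn(x)
            # Assumption: with no labels, the preprocessed input is the target.
            y = batch[1].to(device) if len(batch) > 1 else x
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()


# Autoencoder-style run: no labels, default loss and optimizer.
torch.manual_seed(0)
x = np.random.default_rng(0).normal(size=(64, 8)).astype(np.float32)
model = torch.nn.Sequential(
    torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, 8)
)
before = [p.detach().clone() for p in model.parameters()]
result = train_sketch(model, x, epochs=2, batch_size=16)
```

Like the documented function, the sketch returns None and updates the model in place. Going by the signature above, the corresponding DataEval call would pass the same arguments positionally, for example train(model, x, None, None, None, None, 2, 16).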