dataeval.utils.training.train
dataeval.utils.training.train(model, x_train, y_train, loss_fn, optimizer, preprocess_fn, epochs, batch_size, device=None, *, progress_callback=None)
Train a PyTorch model.
- Parameters:
- model : torch.nn.Module
Model to train.
- x_train : NDArray
Training data.
- y_train : NDArray or None
Training labels. If None, assumes autoencoder-style training in which x_train is also the target.
- loss_fn : Callable or None
Loss function used for training. If None, uses MSELoss.
- optimizer : torch.optim.Optimizer or None
Optimizer used for training. If None, uses Adam with lr=0.001.
- preprocess_fn : Callable or None
Preprocessing function applied to each training batch.
- epochs : int
Number of training epochs.
- batch_size : int or None
Batch size used for training.
- device : DeviceLike or None, default None
Hardware device to train on. If None, uses the DataEval default or the torch default.
- progress_callback : ProgressCallback or None, default None
Optional progress callback function.
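
The parameter descriptions above imply a conventional mini-batch training loop. The following is a minimal sketch of such a loop in plain PyTorch, not DataEval's actual implementation: `train_sketch` is a hypothetical stand-in written only for illustration. It assumes that None falls back to MSELoss and Adam with lr=0.001 as documented, that batches are shuffled each epoch, and that in autoencoder-style training the preprocessed batch is used as the target. The shuffling and the choice of the preprocessed batch as target are assumptions, and the progress callback is omitted.

```python
import numpy as np
import torch
from torch import nn


def train_sketch(model, x_train, y_train=None, loss_fn=None, optimizer=None,
                 preprocess_fn=None, epochs=1, batch_size=8, device="cpu"):
    """Hypothetical illustration of the documented behavior; not DataEval code."""
    device = torch.device(device)
    model = model.to(device)

    # Documented fallbacks: MSELoss and Adam with lr=0.001 when None is passed.
    if loss_fn is None:
        loss_fn = nn.MSELoss()
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    x = torch.as_tensor(x_train, dtype=torch.float32)
    y = None if y_train is None else torch.as_tensor(y_train, dtype=torch.float32)

    model.train()
    epoch_losses = []
    for _ in range(epochs):
        perm = torch.randperm(len(x))  # assumption: reshuffle every epoch
        total = 0.0
        for start in range(0, len(x), batch_size):
            idx = perm[start:start + batch_size]
            xb = x[idx].to(device)
            if preprocess_fn is not None:
                xb = preprocess_fn(xb)  # applied to each training batch
            # Autoencoder-style training: the input doubles as the target.
            target = xb if y is None else y[idx].to(device)

            optimizer.zero_grad()
            loss = loss_fn(model(xb), target)
            loss.backward()
            optimizer.step()
            total += loss.item() * len(idx)
        epoch_losses.append(total / len(x))  # mean per-sample loss for the epoch
    return epoch_losses


# Usage: autoencoder-style training on random data with all defaults.
torch.manual_seed(0)
x = np.random.default_rng(0).normal(size=(64, 4)).astype(np.float32)
autoencoder = nn.Sequential(nn.Linear(4, 2), nn.ReLU(), nn.Linear(2, 4))
losses = train_sketch(autoencoder, x, epochs=5, batch_size=16)
```

Returning per-epoch losses is a convenience of this sketch only; the reference above does not state what `train` returns.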