dataeval.utils.torch.trainer.AETrainer
- class dataeval.utils.torch.trainer.AETrainer(model, device=None, batch_size=8)
A class to train and evaluate an autoencoder model.
- Parameters:
- encode(dataset)
Create image embeddings for the dataset using the model’s encoder.
If the model has an encode method, it will be used; otherwise, model.forward will be used.
Note
This function should be run after the model has been trained and evaluated.
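The encoder fallback described above can be sketched in plain Python. The toy classes `EncoderModel` and `PlainModel` and the helper `encode_with_fallback` are hypothetical illustrations, not part of dataeval; real models would be torch modules operating on tensors rather than lists.

```python
class EncoderModel:
    """Hypothetical toy model that defines a dedicated encode method."""

    def encode(self, x):
        return [v * 0.5 for v in x]  # pretend latent: halve each value

    def forward(self, x):
        return x  # full reconstruction; ignored when encode exists


class PlainModel:
    """Hypothetical toy model that only defines forward."""

    def forward(self, x):
        return [v + 1 for v in x]


def encode_with_fallback(model, batch):
    # Prefer the model's own encoder when it defines one...
    encode = getattr(model, "encode", None)
    if callable(encode):
        return encode(batch)
    # ...otherwise fall back to the full forward pass.
    return model.forward(batch)


print(encode_with_fallback(EncoderModel(), [2.0, 4.0]))
print(encode_with_fallback(PlainModel(), [2.0, 4.0]))
```

With `PlainModel` the "embedding" is simply whatever `forward` returns, which is why a model with a real encoder produces more meaningful embeddings.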
- eval(dataset)
Basic image reconstruction evaluation function for autoencoder models.
Uses torch.nn.MSELoss as the default loss function.
Note
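- To replace this function with a custom function, do the following (because the function is assigned to the class, it receives the trainer instance as its first argument, followed by dataset):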
AETrainer.eval = custom_function
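A minimal sketch of such a replacement, assuming the custom function keeps the `(self, dataset)` arguments of `eval`. To stay runnable without dataeval or torch, `TrainerStub` is a hypothetical stand-in for `AETrainer`, the model is a plain callable, and the custom metric is mean absolute error, swapped in for the default MSE purely for illustration.

```python
class TrainerStub:
    """Hypothetical stand-in for AETrainer; it only stores the model."""

    def __init__(self, model):
        self.model = model

    def eval(self, dataset):
        raise NotImplementedError("default behaviour not reproduced here")


def mean_abs_eval(self, dataset):
    # Custom replacement: mean absolute reconstruction error instead of MSE.
    # 'self' is the trainer instance, so the function can reach self.model.
    total, count = 0.0, 0
    for sample in dataset:
        recon = self.model(sample)
        total += sum(abs(r - s) for r, s in zip(recon, sample))
        count += len(sample)
    return total / count


# Patch the class, as the note describes for AETrainer.eval
TrainerStub.eval = mean_abs_eval

trainer = TrainerStub(model=lambda x: [v * 0.9 for v in x])
print(trainer.eval([[1.0, 2.0], [3.0, 4.0]]))
```

Because the assignment happens on the class rather than on one object, every existing and future instance uses the new method.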
- train(dataset, epochs=25)
Basic image reconstruction training function for autoencoder models.
Uses torch.optim.Adam as the default optimizer and torch.nn.MSELoss as the default loss function.
Note
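- To replace this function with a custom function, do the following (because the function is assigned to the class, it receives the trainer instance as its first argument, followed by dataset and epochs):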
AETrainer.train = custom_function
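A minimal sketch of a custom training function, assuming it keeps the `(self, dataset, epochs=25)` signature of `train`. `TrainerStub`, `ScalarAutoencoder` and `sgd_train` are hypothetical; the sketch uses plain gradient descent on the mean squared reconstruction error instead of the default Adam optimizer so that it runs with the standard library alone.

```python
class ScalarAutoencoder:
    """Hypothetical one-weight 'autoencoder': reconstructs x as w * x."""

    def __init__(self, w=0.0):
        self.w = w

    def __call__(self, x):
        return [self.w * v for v in x]


class TrainerStub:
    """Hypothetical stand-in for AETrainer; it only stores the model."""

    def __init__(self, model):
        self.model = model


def sgd_train(self, dataset, epochs=25, lr=0.05):
    # Plain gradient descent on mean squared reconstruction error
    # (stand-in for the default Adam + MSELoss). Assumes the model
    # exposes a single weight 'w', as ScalarAutoencoder does.
    history = []
    for _ in range(epochs):
        grad, loss, n = 0.0, 0.0, 0
        for sample in dataset:
            recon = self.model(sample)
            for r, v in zip(recon, sample):
                err = r - v               # reconstruction error
                loss += err * err
                grad += 2 * err * v       # d(err^2)/dw, since r = w * v
                n += 1
        self.model.w -= lr * grad / n     # one update per epoch
        history.append(loss / n)          # mean loss before the update
    return history


TrainerStub.train = sgd_train

trainer = TrainerStub(ScalarAutoencoder())
losses = trainer.train([[1.0, 2.0], [0.5, 1.5]], epochs=50)
print(round(trainer.model.w, 3), losses[0] > losses[-1])
```

On this toy data the weight approaches 1 (the identity reconstruction) and the per-epoch loss decreases; a real replacement would iterate over a torch DataLoader and step a torch optimizer instead.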