Autoencoder Trainer

Autoencoders (AEs) are a type of neural network architecture made up of two parts: an encoder and a decoder. While AEs have many uses, DAML uses them for dimensionality reduction on datasets with large images.

How does it work?

The encoder is trained to map each image to a dense, lower-dimensional embedding, while the decoder is trained to reconstruct the original input image from that embedding. Because the embedding must keep enough information for the decoder to rebuild the image, it becomes an efficient compressed representation of the images, allowing for faster model inference and metric computation.
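As an illustration of the encoder/decoder split, here is a small, hypothetical PyTorch autoencoder for 1x28x28 images. TinyAutoencoder is not part of DAML, and the layer sizes and 32-dimensional embedding are arbitrary choices for the example. It exposes an encode method that returns the dense embedding and a forward method that returns the reconstruction.

```python
import torch
from torch import nn


class TinyAutoencoder(nn.Module):
    """Hypothetical example model: compresses 1x28x28 images to a 32-dim embedding."""

    def __init__(self, embedding_dim: int = 32) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),                      # (N, 1, 28, 28) -> (N, 784)
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),     # dense embedding
        )
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Unflatten(1, (1, 28, 28)),      # (N, 784) -> (N, 1, 28, 28)
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Image -> embedding
        return self.encoder(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Image -> embedding -> reconstructed image
        return self.decoder(self.encoder(x))


model = TinyAutoencoder()
images = torch.rand(4, 1, 28, 28)
embeddings = model.encode(images)
reconstructions = model(images)
print(embeddings.shape)       # torch.Size([4, 32])
print(reconstructions.shape)  # torch.Size([4, 1, 28, 28])
```

Each image is represented by 32 numbers instead of 784. Because the model defines encode, a trainer that prefers an encode attribute would return these embeddings rather than the reconstructions.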

Tutorials

Check out this tutorial to begin using the AETrainer class:

Autoencoder Trainer

How To Guides

There are currently no how-to guides for AETrainer. If there are scenarios you would like us to explain, contact us!

DAML API

class daml.models.torch.AETrainer(model: Module, device: str | device = 'auto', batch_size: int = 8)
encode(dataset: Dataset) → Tensor

Encodes data through the model if it has an encode attribute; otherwise, passes the data through model.forward.

Parameters:

dataset (Dataset) – Dataset containing images to be encoded by the model

Returns:

Data encoded by the model

Return type:

torch.Tensor
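The encode-or-forward behavior described above can be sketched in plain PyTorch. encode_dataset is a hypothetical helper, not DAML's implementation. Batching with a DataLoader of batch_size items, running without gradients, taking the image from the first position of each item, and concatenating the per-batch outputs into one tensor are all assumptions about how such a method would typically work.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


def encode_dataset(model: nn.Module, dataset: Dataset, batch_size: int = 8, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper mirroring the documented behavior of AETrainer.encode."""
    model = model.to(device)
    model.eval()  # switch layers such as dropout to inference mode
    # Use model.encode when it exists; otherwise fall back to model.forward via model(...)
    encode_fn = getattr(model, "encode", None)
    if not callable(encode_fn):
        encode_fn = model
    outputs = []
    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=batch_size, shuffle=False):
            # Items may be (image, label, ...) tuples; the image is assumed to come first
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            outputs.append(encode_fn(images.to(device)).cpu())
    return torch.cat(outputs)


class WithEncode(nn.Module):
    """Toy model whose encode method maps 10 features to 2."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(10, 2)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # not used by encode_dataset, since encode exists


data = TensorDataset(torch.rand(20, 10))
print(encode_dataset(nn.Linear(10, 3), data).shape)  # torch.Size([20, 3]): no encode, forward is used
print(encode_dataset(WithEncode(), data).shape)      # torch.Size([20, 2]): encode is used
```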

eval(dataset: Dataset) → float

Basic evaluation function for autoencoder models on reconstruction tasks

Uses torch.optim.Adam and torch.nn.MSELoss by default

Parameters:

dataset (Dataset) – PyTorch Dataset whose items return the image in the first position

Returns:

Total reconstruction loss over all data

Return type:

float
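A minimal sketch of a reconstruction evaluation loop. It assumes the loss is torch.nn.MSELoss between each batch and its reconstruction, and that "total" means the sum of the per-batch losses. evaluate_reconstruction is a hypothetical name, and DAML's actual aggregation may differ.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


def evaluate_reconstruction(model: nn.Module, dataset: Dataset, batch_size: int = 8, device: str = "cpu") -> float:
    """Hypothetical sketch: sum of per-batch MSE reconstruction losses."""
    model = model.to(device)
    model.eval()
    loss_fn = nn.MSELoss()
    total = 0.0
    with torch.no_grad():  # evaluation only, so no gradients are tracked
        for batch in DataLoader(dataset, batch_size=batch_size):
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)
            # The reconstruction target is the input image itself
            total += loss_fn(model(images), images).item()
    return total


data = TensorDataset(torch.rand(16, 1, 8, 8))
print(evaluate_reconstruction(nn.Identity(), data))  # 0.0, since identity reconstructs perfectly
```

Summing batch means makes the result depend on the number of batches. Dividing by the batch count would give a per-batch average instead, which is easier to compare across dataset sizes.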

Note

To replace this function with a custom function, assign it to the class:

AETrainer.eval = custom_function
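Because the function is assigned to the class, Python binds it as a method, so a replacement should accept the trainer instance as its first argument followed by the same arguments as the method it replaces. The sketch below uses a hypothetical stand-in class, StandInTrainer, in place of AETrainer so it runs without DAML installed. The self.model attribute is an assumption for this demo, and the L1 metric is just an example of a custom evaluation.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, TensorDataset


class StandInTrainer:
    """Hypothetical stand-in for AETrainer so this example runs without DAML."""

    def __init__(self, model: nn.Module) -> None:
        self.model = model  # attribute name is an assumption for this demo

    def eval(self, dataset: Dataset) -> float:
        raise NotImplementedError("replaced below")


def l1_eval(self, dataset: Dataset) -> float:
    """Custom evaluation: mean absolute reconstruction error over the whole dataset."""
    # Stack the first element (the image) of every item into one batch
    images = torch.stack([dataset[i][0] for i in range(len(dataset))])
    with torch.no_grad():
        return F.l1_loss(self.model(images), images).item()


# Assigning to the class makes l1_eval a method, so `self` is the trainer instance
StandInTrainer.eval = l1_eval
trainer = StandInTrainer(nn.Identity())
print(trainer.eval(TensorDataset(torch.rand(5, 3))))  # 0.0
```

Note that assigning to the class changes the method for every instance, including ones created earlier.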

train(dataset: Dataset, epochs: int = 25) → List[float]

Basic training function for autoencoder models on reconstruction tasks

Uses torch.optim.Adam and torch.nn.MSELoss by default

Parameters:
  • dataset (Dataset) – PyTorch Dataset whose items return the image in the first position

  • epochs (int, default 25) – Number of full passes over the dataset
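A minimal sketch of a reconstruction training loop using the documented defaults, torch.optim.Adam and torch.nn.MSELoss. train_reconstruction is hypothetical. The learning rate of 1e-3 (PyTorch's Adam default), shuffling each epoch, and returning one summed loss per epoch are assumptions, since the documentation only states that train returns a List[float].

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


def train_reconstruction(model: nn.Module, dataset: Dataset, epochs: int = 25,
                         batch_size: int = 8, device: str = "cpu", lr: float = 1e-3) -> list[float]:
    """Hypothetical sketch: Adam + MSELoss, returning one summed loss per epoch."""
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    history = []
    for _ in range(epochs):
        model.train()
        epoch_loss = 0.0
        for batch in DataLoader(dataset, batch_size=batch_size, shuffle=True):
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(images), images)  # target is the input itself
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        history.append(epoch_loss)
    return history


torch.manual_seed(0)
data = TensorDataset(torch.rand(64, 16))
autoencoder = nn.Sequential(nn.Linear(16, 4), nn.ReLU(), nn.Linear(4, 16))  # 16 -> 4 -> 16
losses = train_reconstruction(autoencoder, data, epochs=5)
print(len(losses))  # 5
```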

Note

To replace this function with a custom function, assign it to the class:

AETrainer.train = custom_function