dataeval.utils.torch.trainer.AETrainer
======================================

.. py:class:: dataeval.utils.torch.trainer.AETrainer(model, device = 'auto', batch_size = 8)

   A class to train and evaluate an :term:`autoencoder` model.

   :param model: The model to be trained.
   :type model: nn.Module
   :param device: The hardware device to use for training. If "auto", the device is set to "cuda" if available, otherwise "cpu".
   :type device: str or torch.device, default "auto"
   :param batch_size: The number of images to process in a batch.
   :type batch_size: int, default 8


   .. py:method:: encode(dataset)

      Create image :term:`embeddings` for the dataset using the model's encoder.

      If the model has an ``encode`` method, it is used; otherwise, ``model.forward`` is used.

      :param dataset: The dataset to encode. A Torch ``Dataset`` that returns images in the first position of each item.
      :type dataset: Dataset

      :returns: Data encoded by the model.
      :rtype: torch.Tensor

      .. note::

         This method should be run after the model has been trained and evaluated.


   .. py:method:: eval(dataset)

      Basic image reconstruction evaluation function for :term:`autoencoder` models.

      Uses ``torch.nn.MSELoss`` as the default loss function.

      :param dataset: The dataset to evaluate on. A Torch ``Dataset`` that returns images in the first position of each item.
      :type dataset: Dataset

      :returns: Total reconstruction loss over the entire dataset.
      :rtype: float

      .. note::

         To replace this method with a custom function, assign it to the class: ``AETrainer.eval = custom_function``.


   .. py:method:: train(dataset, epochs = 25)

      Basic image reconstruction training function for :term:`autoencoder` models.

      Uses ``torch.optim.Adam`` as the default optimizer and ``torch.nn.MSELoss`` as the default loss function.

      :param dataset: The dataset to train on. A Torch ``Dataset`` that returns images in the first position of each item.
      :type dataset: Dataset
      :param epochs: Number of full passes over the dataset.
      :type epochs: int, default 25

      :returns: A list of the average loss value for each :term:`epoch`.
      :rtype: List[float]

      .. note::

         To replace this method with a custom function, assign it to the class: ``AETrainer.train = custom_function``.
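The ``train``, ``eval`` and ``encode`` methods of ``AETrainer`` describe a standard reconstruction workflow. The following plain-PyTorch sketch approximates that workflow under stated assumptions. It is not dataeval's implementation. The helper names (``resolve_device``, ``train_sketch``, ``eval_sketch``, ``encode_sketch``) and the ``TinyAE`` model are hypothetical. The sketch also assumes that Adam uses its default learning rate, that the "total" loss from ``eval`` is the sum of per-batch MSE values, and that each dataset item is a tuple with the image first.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


def resolve_device(device="auto") -> torch.device:
    # "auto" picks CUDA when available, otherwise CPU (the documented rule).
    if device == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device)


def train_sketch(model: nn.Module, dataset: Dataset, epochs: int = 25,
                 batch_size: int = 8, device="auto") -> list[float]:
    dev = resolve_device(device)
    model.to(dev).train()
    optimizer = torch.optim.Adam(model.parameters())  # default lr (assumption)
    loss_fn = nn.MSELoss()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    history = []
    for _ in range(epochs):
        running = 0.0
        for batch in loader:
            images = batch[0].to(dev)  # images sit in the first return position
            optimizer.zero_grad()
            loss = loss_fn(model(images), images)  # reconstruct the input
            loss.backward()
            optimizer.step()
            running += loss.item()
        history.append(running / len(loader))  # average batch loss this epoch
    return history


@torch.no_grad()
def eval_sketch(model: nn.Module, dataset: Dataset, batch_size: int = 8,
                device="auto") -> float:
    dev = resolve_device(device)
    model.to(dev).eval()
    loss_fn = nn.MSELoss()
    total = 0.0
    for batch in DataLoader(dataset, batch_size=batch_size):
        images = batch[0].to(dev)
        total += loss_fn(model(images), images).item()  # sum of batch losses (assumption)
    return total


@torch.no_grad()
def encode_sketch(model: nn.Module, dataset: Dataset, batch_size: int = 8,
                  device="auto") -> torch.Tensor:
    dev = resolve_device(device)
    model.to(dev).eval()
    # Prefer model.encode; otherwise calling the model runs model.forward.
    encoder = model.encode if hasattr(model, "encode") else model
    chunks = [encoder(batch[0].to(dev)).cpu()
              for batch in DataLoader(dataset, batch_size=batch_size)]
    return torch.cat(chunks, dim=0)


class TinyAE(nn.Module):
    # Hypothetical toy autoencoder for 1x4x4 images.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(16, 4))
        self.decoder = nn.Sequential(nn.Linear(4, 16), nn.Unflatten(1, (1, 4, 4)))

    def encode(self, x):
        return self.encoder(x)

    def forward(self, x):
        return self.decoder(self.encoder(x))


torch.manual_seed(0)
dataset = TensorDataset(torch.rand(20, 1, 4, 4))  # items are (image,) tuples
model = TinyAE()
losses = train_sketch(model, dataset, epochs=3)
total_loss = eval_sketch(model, dataset)
embeddings = encode_sketch(model, dataset)  # shape (20, 4) via TinyAE.encode
```

If dataeval is installed, the documented equivalent is to construct ``AETrainer(model, batch_size=8)`` and then call ``train(dataset, epochs=25)``, ``eval(dataset)`` and ``encode(dataset)`` on it.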