{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Out-of-Distribution (OOD) Detection Tutorial\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## _Problem Statement_\n", "\n", "For most computer vision tasks, such as **image classification** and **object detection**, out-of-distribution (OOD) detection can provide insight into operational drift or training problems. One way to identify OOD images is through autoencoder reconstruction error: an autoencoder trained on in-distribution images tends to reconstruct unfamiliar images poorly.\n", "\n", "To help with this, DataEval provides an OOD detector that flags these images.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### _When to use_\n", "\n", "Use the `OOD_AE` class, and similar detectors, when you want to find the individual images in a dataset that differ most from the rest of the provided set.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### _What you will need_\n", "\n", "1. A training image dataset, along with the approximate percentage of OOD images it is known to contain.\n", "2. A test image dataset to evaluate for OOD images.\n", "3. 
A Python environment with the following packages installed:\n", "   - `dataeval[torch]` or `dataeval[all]`\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### _Setting up_\n", "\n", "Let's import the libraries needed to set up a minimal working example.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove_cell" ] }, "outputs": [], "source": [ "try:\n", "    import google.colab  # noqa: F401\n", "\n", "    # specify the version of DataEval (==X.XX.X) for versions other than the latest\n", "    %pip install -q dataeval[torch]\n", "except Exception:\n", "    pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from dataeval.detectors.ood.ae_torch import OOD_AE\n", "from dataeval.utils.torch.datasets import MNIST\n", "from dataeval.utils.torch.models import AE\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data\n", "\n", "We will use the MNIST dataset for this tutorial.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the MNIST training set and use the first 2000 images\n", "train_ds = MNIST(root=\"./data/\", train=True, download=True, size=2000, unit_interval=True, channels=\"channels_first\")\n", "\n", "# Split out the images and labels\n", "images, labels = train_ds.data, train_ds.targets\n", "# Shape of a single image, used to size the autoencoder\n", "input_shape = images[0].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize the model\n", "\n", "Now, let's look at how to use DataEval's OOD detection methods.
\n", "We will focus on a simple autoencoder network from our Alibi Detect provider.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "OOD_detectors = [\n", "    OOD_AE(AE(input_shape=input_shape)),\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model\n", "\n", "Next, we will train the model on the dataset.\n", "For better results, the number of epochs can be increased.\n", "We set the threshold so that the most extreme 1% of the training data is flagged as out-of-distribution.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for detector in OOD_detectors:\n", "    print(f\"Training {detector.__class__.__name__}...\")\n", "    detector.fit(images, threshold_perc=99, epochs=23)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test for OOD\n", "\n", "We have trained our detector on a dataset of digits.\n", "What happens when we give it corrupted images of digits (which we expect to be \"OOD\")?\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "corruption = MNIST(\n", "    root=\"./data\",\n", "    train=True,\n", "    download=False,\n", "    size=2000,\n", "    unit_interval=True,\n", "    channels=\"channels_first\",\n", "    corruption=\"translate\",\n", ")\n", "corrupted_images = corruption.data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we evaluate both the original and the corrupted images using the trained detector. Each value printed below is the fraction of images flagged as OOD.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "results = {type(detector).__name__: np.mean(detector.predict(images).is_ood) for detector in OOD_detectors}\n", "print(results)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "corrupted = {type(detector).__name__: np.mean(detector.predict(corrupted_images).is_ood) for detector in OOD_detectors}\n", "print(corrupted)" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "tags": [ "remove_cell" ] }, "outputs": [], "source": [ "### TEST ASSERTION CELL ###\n", "assert results[\"OOD_AE\"] < 0.05\n", "assert corrupted[\"OOD_AE\"] > 0.85" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Results\n", "\n", "We can see that the autoencoder-based OOD detector identified most of the translated images as outliers.\n", "\n", "Depending on your needs, certain outlier detectors will work better than others under specific conditions.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv-3.11", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }