How to detect uncertainty drift with a MAITE model

Problem statement

When a model is deployed, the model itself is a sensitive instrument for distribution shift: as operational data moves into regions the model was not trained on, its predictions grow less confident. The Shannon entropy of the predicted class probabilities summarizes that confidence — low entropy is a peaked, confident prediction; high entropy means the model is hedging. If the distribution of per-prediction entropy shifts upward, the model is increasingly operating in its uncertainty regions: a strong, label-free drift signal. See the uncertainty-based drift detection concept page for the theory and trade-offs.

To turn predictions into an uncertainty feature, DataEval needs to decode the model’s raw output into a (n_predictions, n_classes) array of class scores. The drift-with-uncertainty tutorial does this by wrapping a model in TorchExtractor with a hand-written postprocess_fn specific to that model’s output format.

This guide shows the alternative for models that conform to DataEval’s opinionated ONNX/LiteRT contract: a OnnxImageClassifier (a MAITE image_classification.Model) paired with ScoresExtractor decodes the predictions for you — no postprocess_fn required — and feeds straight into UncertaintyExtractor and a drift detector.

When to use

Use the OnnxImageClassifier + ScoresExtractor path when you want to:

  • Monitor a deployed ONNX or LiteRT model for distribution shift without operational ground-truth labels

  • Reuse a model that already ships with a model-metadata.json describing its input/output contract

  • Avoid writing and maintaining a model-specific postprocess_fn to decode raw predictions

The same pattern works for object detection with OnnxObjectDetector; ScoresExtractor flattens per-detection scores so the rest of the workflow is identical.

What you will need

  1. An ONNX (or LiteRT) model that conforms to the opinionated contract, plus its model-metadata.json. In production you bring your own; in the Build a model section below you export a small one so this guide is self-contained.

  2. A reference dataset (in-distribution) and an operational stream to monitor. This guide uses MNIST as in-distribution and CIFAR-10 as the shifted stream.

  3. A Python environment with the following packages installed:

    • dataeval[onnx]

    • maite-datasets

    • torch (only to export the demo model)

Getting started

Import the libraries needed for a minimal working example.

import json
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from maite_datasets.image_classification import CIFAR10, MNIST
from torch import nn

from dataeval.config import set_seed
from dataeval.data import Indices, Limit, Select, Shuffle
from dataeval.extractors import ScoresExtractor, UncertaintyExtractor
from dataeval.models import OnnxImageClassifier
from dataeval.shift import DriftWasserstein

set_seed(0, all_generators=True)

Build a conforming model for the demo

OnnxImageClassifier wraps two artifacts you would normally already have for a deployed model:

  • the model file (.onnx or .tflite for LiteRtImageClassifier), and

  • a model-metadata.json declaring the input/output contract — task, input channels and size, and the number of classes the model scores.

Because this guide is self-contained, you quickly train a small MNIST classifier and export it. This block is demo scaffolding — in your own workflow you would skip it and point OnnxImageClassifier at the model and metadata you already have.

The exported model emits a softmax scores output (probabilities), which is exactly what the opinionated classification contract expects.

class TinyCNN(nn.Module):
    """A small MNIST classifier emitting raw logits."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class WithSoftmax(nn.Module):
    """Wrap a logits model so its ONNX ``scores`` output is a probability distribution."""

    def __init__(self, base: nn.Module) -> None:
        super().__init__()
        self.base = base

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(self.base(x), dim=1)
# Train the demo classifier on a small slice of MNIST (a few seconds on CPU).
mnist_train = MNIST("./data", image_set="train", download=True)
train_ids = np.random.default_rng(0).permutation(len(mnist_train))[:4000]
x_train = torch.tensor(np.stack([np.asarray(mnist_train[i][0], dtype=np.float32) / 255.0 for i in train_ids]))
y_train = torch.tensor(np.stack([np.asarray(mnist_train[i][1]).argmax() for i in train_ids]))

model = TinyCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
model.train()
for _ in range(3):
    order = torch.randperm(len(x_train))
    for start in range(0, len(x_train), 128):
        batch = order[start : start + 128]
        optimizer.zero_grad()
        loss_fn(model(x_train[batch]), y_train[batch]).backward()
        optimizer.step()
model.eval()
TinyCNN(
  (net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1568, out_features=10, bias=True)
  )
)

Export the trained model to ONNX and write the matching model-metadata.json. The metadata declares a single-channel (GRAYSCALE) 28x28 input and 10 output classes — the contract OnnxImageClassifier reads to build model input and validate output.

Path("data").mkdir(exist_ok=True)
model_path = "data/mnist-demo.onnx"
metadata_path = "data/mnist-demo-metadata.json"

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    torch.onnx.export(
        WithSoftmax(model).eval(),
        (torch.zeros(1, 1, 28, 28),),
        model_path,
        input_names=["image"],
        output_names=["scores"],
        dynamic_axes={"image": {0: "batch"}, "scores": {0: "batch"}},
        opset_version=13,
        dynamo=False,
    )

_ = Path(metadata_path).write_text(
    json.dumps({
        "interface": {"name": "JATIC_ONNX", "version": "v1"},
        "io": {
            "batchSize": -1,
            "interface": "IMAGE_CLASSIFICATION",
            "input": {"channels": "GRAYSCALE", "height": 28, "width": 28},
            "output": {"nClasses": 10},
        },
    }),
    encoding="utf-8",
)

Load the datasets

You need three exchangeable in-distribution slices plus a shifted stream:

  • train and val — two disjoint, same-distribution slices of MNIST; the train-vs-val distance calibrates the baseline for “normal” variation between in-distribution samples.

  • test — a third disjoint MNIST slice held out from calibration; because it is exchangeable with train and val, the detector should report no drift here.

  • cifar — CIFAR-10 imagery, a genuinely shifted operational stream (natural color photos rather than handwritten digits); the detector should report drift here.

The opinionated input builder handles the format differences automatically: CIFAR-10’s RGB 32x32 images are converted to grayscale and resized to the model’s 28x28 input per the metadata contract — no manual preprocessing needed.

SAMPLES_PER_SPLIT = 200

mnist_test = MNIST("./data", image_set="test", download=True)

# Three disjoint, exchangeable MNIST slices: two references that calibrate the baseline
# (train, val) and a held-out set that should NOT drift (test).
perm = np.random.default_rng(1).permutation(len(mnist_test))
train_idx, val_idx, test_idx = np.array_split(perm[: 3 * SAMPLES_PER_SPLIT], 3)
trainset = Select(mnist_test, Indices(train_idx.tolist()))
valset = Select(mnist_test, Indices(val_idx.tolist()))
testset = Select(mnist_test, Indices(test_idx.tolist()))

# Shifted "operational" data: CIFAR-10 natural images
cifarset = Select(CIFAR10("./data", image_set="test", download=True), [Shuffle(0), Limit(SAMPLES_PER_SPLIT)])

print(f"train: {len(trainset)}  val: {len(valset)}  test: {len(testset)}  cifar: {len(cifarset)}")
train: 200  val: 200  test: 200  cifar: 200

Build the uncertainty feature extractor

This is the core of the guide. Three small pieces compose into a single feature extractor:

  1. OnnxImageClassifier loads the model and its metadata and runs inference, returning one (n_classes,) score array per image.

  2. ScoresExtractor adapts that MAITE Model into a FeatureExtractor, stacking per-image scores into an (n_images, n_classes) array.

  3. UncertaintyExtractor converts each row of class probabilities into a single normalized-entropy value.

Note what is absent: there is no postprocess_fn. The opinionated classifier already speaks the (instances, classes) contract ScoresExtractor expects, so the decoding the tutorial does by hand is handled for you. Because the model outputs a softmax distribution, use preds_type="probs".

classifier = OnnxImageClassifier(model_path, metadata_path)
scores = ScoresExtractor(classifier)
uncertainty = UncertaintyExtractor(scores, preds_type="probs", normalize=True)

print(uncertainty)
2026-07-02 17:18:06.747394233 [W:onnxruntime:Default, device_discovery.cc:325 DiscoverDevicesForPlatform] GPU device discovery failed: device_discovery.cc:92 ReadFileContents Failed to open file: "/sys/class/drm/card0/device/vendor"
UncertaintyExtractor(scores=ScoresExtractor, preds_type='probs', normalize=True)

Compute uncertainty for each split

Calling the extractor on a dataset returns an (n_images, 1) array of normalized-entropy values — one per image.

h_train = uncertainty(trainset)
h_val = uncertainty(valset)
h_test = uncertainty(testset)
h_cifar = uncertainty(cifarset)

print(
    f"mean entropy -> train: {h_train.mean():.3f}  val: {h_val.mean():.3f}  "
    f"test: {h_test.mean():.3f}  cifar: {h_cifar.mean():.3f}"
)
mean entropy -> train: 0.183  val: 0.190  test: 0.191  cifar: 0.452

Visualize the uncertainty distributions

An empirical CDF (ECDF) makes the shift easy to read: if the CIFAR-10 curve sits to the right of the MNIST curves, the model is systematically more uncertain on the shifted imagery. The in-distribution splits (train, val, test) should overlap closely.

plt.figure()
for label, h in [("train", h_train), ("val", h_val), ("test", h_test), ("cifar", h_cifar)]:
    plt.ecdf(h.flatten(), label=label)
plt.xlabel("Normalized entropy")
plt.ylabel("Cumulative density")
plt.title("Prediction uncertainty by split")
plt.legend()
plt.show()
../_images/9cdce83228a6e7d0552014ef244890b3587d1a826edc756981472e44df29fb13.png

The MNIST splits cluster together while CIFAR-10 is visibly shifted toward higher entropy — exactly what you expect from a classifier encountering data outside its training distribution.

Detect drift with Wasserstein distance

Eyeballing ECDFs is useful, but you want an automated decision. DriftWasserstein measures the Wasserstein distance between the reference and incoming uncertainty distributions and flags drift when that distance grows beyond a calibrated baseline.

Unlike most detectors, DriftWasserstein takes two in-distribution references in fit(): a training set and a validation set. The train-vs-validation distance defines the normal amount of variation between two same-distribution samples, so the detector only alarms when incoming data is more different from training than validation was.

drift_detector = DriftWasserstein().fit(np.asarray(h_train), np.asarray(h_val))

result_test = drift_detector.predict(np.asarray(h_test))
print(f"MNIST held-out -> drift: {result_test.drifted}  (ratio: {result_test.distance:.2f})")

result_cifar = drift_detector.predict(np.asarray(h_cifar))
print(f"CIFAR-10       -> drift: {result_cifar.drifted}  (ratio: {result_cifar.distance:.2f})")
MNIST held-out -> drift: False  (ratio: 1.07)
CIFAR-10       -> drift: True  (ratio: 18.48)

The held-out MNIST slice stays within the calibrated baseline (no drift), while CIFAR-10 exceeds it (drift detected). The detector turned the visual gap in the ECDF into an automated, label-free alert — and you never had to write a decoder for the model’s output.

On your own

See also

Tutorials

How-to guides