How to measure dataset sufficiency for image classification

This beginner-friendly guide shows how to analyze an image classification model’s hypothetical performance.

Estimated time to complete: 10 minutes

Relevant ML stages: Model Development

Relevant personas: ML Engineer

What you’ll do

  • Evaluate an image classification model’s performance with the MNIST dataset

  • Define a custom evaluation function with metrics of interest

  • Project the model’s performance over increasing sample sizes

What you’ll learn

  • How to evaluate a model’s limits for different metrics on the MNIST dataset

  • How to estimate the number of samples required to reach specific performance thresholds

Problem Statement

In machine learning tasks, we often want to evaluate a model’s performance on a small, preliminary dataset. When data collection is expensive, we would also like to extrapolate that performance to a larger, hypothetical dataset.

DataEval provides a method for projecting performance with sufficiency curves.
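The idea behind a sufficiency curve can be illustrated with a small sketch. It assumes an inverse power-law learning curve, error ≈ a·n^(−b), a common model for how error falls as the sample count n grows, and fits it with a straight line in log-log space. This is an illustration under that assumption only: DataEval’s Sufficiency class may use a different functional form and fitting procedure. The function names and the synthetic measurements are hypothetical.

```python
import numpy as np


def fit_power_law(sample_sizes, errors):
    """Fit error ~= a * n**(-b) with a straight-line fit in log-log space."""
    log_n = np.log(np.asarray(sample_sizes, dtype=float))
    log_e = np.log(np.asarray(errors, dtype=float))
    # A degree-1 polyfit returns [slope, intercept]; log(e) = log(a) - b * log(n)
    slope, intercept = np.polyfit(log_n, log_e, 1)
    return float(np.exp(intercept)), float(-slope)


def project_error(a, b, n):
    """Projected error at sample size n under the fitted power law."""
    return a * np.asarray(n, dtype=float) ** (-b)


# Synthetic learning-curve points generated from error = 2 * n**-0.5
# (illustrative only, not real MNIST results)
sizes = np.array([100.0, 200.0, 400.0, 800.0, 1600.0])
errs = 2.0 * sizes**-0.5

a, b = fit_power_law(sizes, errs)
print(f"a={a:.3f}, b={b:.3f}")
print(f"projected error at n=10,000: {float(project_error(a, b, 10_000)):.4f}")
```

Projections far beyond the largest measured sample size are the least reliable part of such a curve, so treat long-range extrapolations with caution.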

When to use

Use the Sufficiency class when you want to extrapolate hypothetical performance, for example when you have a small dataset and want to know whether collecting more data is worthwhile.
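To estimate how many samples a target requires, an assumed power-law curve error = a·n^(−b) can be inverted: solving target = a·n^(−b) for n gives n = (a / target)^(1/b). The sketch below assumes a and b have already been fitted. The function name and parameter values are hypothetical. For a metric such as accuracy, express the target as an error, 1 − accuracy.

```python
def samples_for_target(a: float, b: float, target_error: float) -> float:
    """Invert error = a * n**(-b) to find the n where error equals target_error."""
    if a <= 0 or b <= 0 or target_error <= 0:
        raise ValueError("a, b and target_error must all be positive")
    return (a / target_error) ** (1.0 / b)


# Hypothetical fitted curve error = 1.0 * n**-0.5; target accuracy 0.9 -> error 0.1
n_needed = samples_for_target(a=1.0, b=0.5, target_error=1 - 0.9)
print(f"about {n_needed:.0f} samples")
```

When the fitted exponent b is small, the required n grows very quickly as the target tightens. That growth is itself a useful signal about whether collecting more data is worthwhile.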

What you will need

  1. A particular model architecture.

  2. Metric(s) that we would like to evaluate.

  3. A dataset of interest.

  4. A Python environment with the following packages installed:

    • dataeval

    • dataeval_plots

    • maite_datasets

    • plotly

    • tabulate

    • torch

    • torchmetrics
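A quick check like the following can confirm that the modules imported in Setting up are available before you run the notebook. It uses only the standard library’s importlib.util.find_spec. The module list is taken from the import statements in this guide, and the corresponding pip distribution names may differ.

```python
import importlib.util

# Top-level module names from the imports in this guide
# (pip distribution names may differ from these module names)
REQUIRED_MODULES = [
    "dataeval",
    "dataeval_plots",
    "maite_datasets",
    "numpy",
    "plotly",
    "tabulate",
    "torch",
    "torchmetrics",
    "IPython",
]


def missing_modules(names):
    """Return the top-level module names the import system cannot find."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("Missing modules:", ", ".join(missing) if missing else "none")
```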

Setting up

Let’s import the libraries needed to set up a minimal working example.

import os
from collections.abc import Sequence
from typing import Any, cast

import dataeval_plots as dep
import numpy as np
import plotly.io as pio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
from IPython.display import display  # noqa: A004
from maite_datasets.image_classification import MNIST
from numpy.typing import NDArray
from tabulate import tabulate
from torch.utils.data import DataLoader, Subset
from torch.utils.data import Dataset as TorchDataset

from dataeval import config
from dataeval.performance import Sufficiency, SufficiencyConfig
from dataeval.protocols import Dataset, DatumMetadata
from dataeval.selection import Limit, Select

DatumType = tuple[NDArray[np.number[Any]], NDArray[np.number[Any]], DatumMetadata]

# Configure the cuBLAS workspace before any CUDA work starts (needed for deterministic cuBLAS)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Set seed for reproducibility
config.set_seed(0, all_generators=True)

# Set hardware based on system
device = "cuda" if torch.cuda.is_available() else "cpu"
config.set_device(device=device)

# Additional reproducibility and printing options
np.set_printoptions(formatter={"float": lambda x: f"{x:0.4f}"})
torch.set_float32_matmul_precision("high")
# If torch.compile fails, log a warning and fall back to eager execution
torch._dynamo.config.suppress_errors = True
torch.use_deterministic_algorithms(True)

# Use plotly to render plots
dep.set_default_backend("plotly")

# Use the notebook renderer so JS is embedded
pio.renderers.default = "notebook"

Load data and create model

Before calculating the sufficiency of a dataset, the dataset must be loaded and the model architecture defined. We will walk through these in the following steps.

Loading MNIST data

Load the MNIST training and test splits. To keep this notebook fast, we use subsets of 2,500 training images and 500 test images.

# Configure the dataset transforms

transforms = [
    lambda x: x / 255.0,  # scale to [0, 1]
    lambda x: x.astype(np.float32),  # convert to float32
]

# Download MNIST, apply the transforms, and keep a subset of each split
train_ds = Select(
    MNIST(root="./data", image_set="train", transforms=transforms, download=True),
    selections=[Limit(2500)],
)
test_ds = Select(
    MNIST(root="./data", image_set="test", transforms=transforms, download=True),
    selections=[Limit(500)],
)
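The two transforms run in sequence: dividing a uint8 image by 255.0 yields float64 values in [0, 1], and the second step downcasts to float32, PyTorch’s default floating-point dtype. The sketch applies them to a fake 1x28x28 image, assuming (as the list order suggests) that the dataset applies them in order. apply_transforms is a hypothetical helper.

```python
import numpy as np

# The same two transforms defined for the MNIST datasets
transforms = [
    lambda x: x / 255.0,  # uint8 [0, 255] -> float64 [0, 1]
    lambda x: x.astype(np.float32),  # downcast to float32
]


def apply_transforms(image, fns):
    """Hypothetical helper: apply each transform in list order."""
    for fn in fns:
        image = fn(image)
    return image


# Fake 1x28x28 uint8 image standing in for an MNIST digit (values 0..255)
fake = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(1, 28, 28)
out = apply_transforms(fake, transforms)
print(out.dtype, out.shape, out.min(), out.max())
```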

Creating a PyTorch model

Next, we define the network architecture that will be trained and then evaluated throughout the sufficiency calculation.

# Define our network architecture
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(6400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Compile the model (cast sets the type to Net as compile returns an Unknown)
model: Net = cast(Net, torch.compile(Net().to(device)))
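The in_features of fc1 (6400) follows from convolution arithmetic. With no padding and stride 1, each 5x5 convolution shrinks the side length by 4, so a 28x28 input becomes 24x24 after conv1 and 20x20 after conv2. With 16 channels, that gives 16 × 20 × 20 = 6400. The helper below is a hypothetical sketch of the standard output-size formula (dilation 1).

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


side = conv2d_out(28, kernel=5)  # conv1: 28 -> 24
side = conv2d_out(side, kernel=5)  # conv2: 24 -> 20
in_features = 16 * side * side  # conv2 has 16 output channels
print(in_features)  # 6400
```

If you add pooling or change the input size, recompute this value. Alternatively, torch.nn.LazyLinear infers in_features on the first forward pass.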

Strategy Protocols

Training and evaluation functions depend heavily on user-defined choices such as metrics, loss functions, optimizers, model architectures, and input sizes.

To accommodate these choices, DataEval uses protocols. Sufficiency requires two of them: TrainingStrategy and EvaluationStrategy. Below, we define strategies for this notebook and combine them into a SufficiencyConfig that is passed to the Sufficiency class.
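Protocols use structural typing. A strategy does not need to inherit from any DataEval class. It only needs methods with the expected names and signatures. The sketch below illustrates this with Python’s typing.Protocol. TrainingStrategyLike and EvaluationStrategyLike are hypothetical stand-ins modeled on the methods used in this guide, not DataEval’s actual definitions.

```python
from collections.abc import Sequence
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class TrainingStrategyLike(Protocol):
    """Hypothetical stand-in: anything with a matching train() method."""

    def train(self, model: Any, dataset: Any, indices: Sequence[int]) -> None: ...


@runtime_checkable
class EvaluationStrategyLike(Protocol):
    """Hypothetical stand-in: anything with a matching evaluate() method."""

    def evaluate(self, model: Any, dataset: Any) -> dict[str, float]: ...


class ToyTraining:
    # No base class: having a train() method is what makes it conform
    def train(self, model, dataset, indices):
        pass


class ToyEvaluation:
    def evaluate(self, model, dataset):
        return {"Accuracy": 0.0}


# runtime_checkable isinstance() checks only that the methods exist,
# not their signatures or return types
print(isinstance(ToyTraining(), TrainingStrategyLike))  # True
print(isinstance(ToyEvaluation(), EvaluationStrategyLike))  # True
print(isinstance(ToyEvaluation(), TrainingStrategyLike))  # False
```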

Training strategy

class MNISTTrainingStrategy:
    def train(self, model: nn.Module, dataset: Dataset[DatumType], indices: Sequence[int]) -> None:
        # Hyperparameters chosen for this example only
        criterion = torch.nn.CrossEntropyLoss().to(device)
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        epochs = 10

        # Train only on the samples selected by the indices passed in
        dataloader = DataLoader(Subset(cast(TorchDataset, dataset), indices), batch_size=8)

        # Put layers into training mode (the evaluation strategy switches the model to eval mode)
        model.train()

        for _ in range(epochs):
            for batch in dataloader:
                # Load data/images to device
                X = torch.Tensor(batch[0]).to(device)
                # Convert one-hot targets/labels to class indices on the device
                y = torch.argmax(torch.asarray(batch[1], dtype=torch.int).to(device), dim=1)
                # Zero out gradients
                optimizer.zero_grad()
                # Forward propagation
                outputs = model(X)
                # Compute loss
                loss = criterion(outputs, y)
                # Back propagation
                loss.backward()
                # Update weights/parameters
                optimizer.step()
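The training loop expects labels to arrive one-hot encoded. It takes the argmax along the class axis to turn them into the integer class indices it passes to CrossEntropyLoss. The same conversion in plain NumPy, on a made-up batch:

```python
import numpy as np

# Made-up batch of 3 one-hot labels over 10 classes (classes 7, 2 and 0)
one_hot = np.zeros((3, 10), dtype=np.int32)
one_hot[np.arange(3), [7, 2, 0]] = 1

# argmax over the class axis gives integer class indices
labels = one_hot.argmax(axis=1)
print(labels)  # [7 2 0]
```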

Evaluation strategy

class MNISTEvaluationStrategy:
    def evaluate(self, model: nn.Module, dataset: Dataset[DatumType]) -> dict[str, float]:
        # Metrics of interest
        metrics = {
            "Accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device),
            "AUROC": torchmetrics.AUROC(task="multiclass", num_classes=10).to(device),
            "TPR at 0.5 Fixed FPR": torchmetrics.ROC(task="multiclass", average="macro", num_classes=10).to(device),
        }
        result = {}
        # Set model layers into evaluation mode
        model.eval()
        dataloader = DataLoader(cast(TorchDataset, dataset), batch_size=8)
        # Disable gradient tracking, which saves memory and compute during inference
        with torch.no_grad():
            for batch in dataloader:
                # Load data/images to device
                X = torch.Tensor(batch[0]).to(device)
                # Convert one-hot targets/labels to class indices on the device
                y = torch.argmax(torch.asarray(batch[1], dtype=torch.int).to(device), dim=1)
                preds = model(X)
                for metric in metrics.values():
                    metric.update(preds, y)
            # Compute the macro-averaged ROC curve
            false_positive_rate, true_positive_rate, _ = metrics["TPR at 0.5 Fixed FPR"].compute()
            # Find the ROC point whose FPR is closest to the desired rate
            desired_rate = 0.5
            closest_desired_index = torch.argmin(torch.abs(false_positive_rate - desired_rate)).item()
            # Store plain Python floats to match the declared dict[str, float] return type
            result["TPR at 0.5 Fixed FPR"] = true_positive_rate[closest_desired_index].item()
            result["Accuracy"] = metrics["Accuracy"].compute().item()
            result["AUROC"] = metrics["AUROC"].compute().item()
        return result
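The TPR lookup in evaluate picks the ROC point whose FPR is nearest to 0.5 rather than interpolating between points. The NumPy sketch below reproduces that nearest-point logic on a made-up ROC curve. tpr_at_fpr is a hypothetical helper name.

```python
import numpy as np


def tpr_at_fpr(fpr, tpr, desired_fpr=0.5):
    """Return the TPR at the ROC point whose FPR is closest to desired_fpr."""
    fpr = np.asarray(fpr, dtype=float)
    tpr = np.asarray(tpr, dtype=float)
    idx = int(np.argmin(np.abs(fpr - desired_fpr)))
    return float(tpr[idx])


# Made-up ROC curve points, sorted by increasing FPR
fpr = [0.0, 0.1, 0.3, 0.45, 0.7, 1.0]
tpr = [0.0, 0.6, 0.85, 0.93, 0.98, 1.0]
print(tpr_at_fpr(fpr, tpr))  # 0.93 (FPR 0.45 is the nearest point to 0.5)
```

If you prefer a linearly interpolated value, np.interp(0.5, fpr, tpr) is an alternative, provided fpr is sorted in increasing order.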

Sufficiency config

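Instantiate both strategy classes (pass instances, not the classes themselves) and combine them in a SufficiencyConfig. Here, runs sets the number of independent training runs, which improves the stability of the curve, and substeps sets the number of steps along the learning curve to evaluate.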

mnist_config = SufficiencyConfig(
    training_strategy=MNISTTrainingStrategy(),
    evaluation_strategy=MNISTEvaluationStrategy(),
    runs=5,
    substeps=10,
)
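This guide does not show how Sufficiency chooses the training-set size for each substep. Purely for illustration, the sketch below assumes one common choice for learning curves: sizes spaced evenly on a log scale, because performance usually changes fastest at small sample counts. substep_sizes and its n_min default are hypothetical.

```python
import numpy as np


def substep_sizes(n_total: int, substeps: int, n_min: int = 10) -> np.ndarray:
    """Hypothetical schedule: `substeps` sizes spaced evenly on a log scale."""
    sizes = np.geomspace(n_min, n_total, num=substeps)
    # Round to whole samples; np.unique also sorts and drops rounding duplicates
    return np.unique(np.round(sizes).astype(int))


print(substep_sizes(2500, 10))
```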

Initialize sufficiency metric

Pass the model, the training and test datasets, and the configuration (which holds the training and evaluation strategies) to the Sufficiency class.

# Instantiate sufficiency metric
suff = Sufficiency(
    model=model,
    train_ds=train_ds,
    test_ds=test_ds,
    config=mnist_config,
)
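Because the configuration trains the model over several runs, each step of the learning curve has several measurements. A generic way to summarize them is a per-step mean and standard deviation across runs. The sketch below does this on a synthetic (runs x substeps) array. This guide does not show the structure of the object returned by Sufficiency.evaluate, so this is not its API.

```python
import numpy as np

# Synthetic accuracies: rows are 5 runs, columns are 4 substeps (illustrative only)
acc = np.array([
    [0.80, 0.88, 0.93, 0.96],
    [0.78, 0.89, 0.94, 0.96],
    [0.82, 0.87, 0.92, 0.95],
    [0.79, 0.90, 0.93, 0.97],
    [0.81, 0.86, 0.93, 0.96],
])

mean = acc.mean(axis=0)  # average curve across runs
std = acc.std(axis=0, ddof=1)  # sample standard deviation across runs
for step, (m, s) in enumerate(zip(mean, std)):
    print(f"substep {step}: {m:.3f} +/- {s:.3f}")
```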

Evaluate Sufficiency

Now call evaluate, which trains the models and produces the learning curve.

# Train & test model
output = suff.evaluate()
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return fn(*args, **kwargs)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     tracer.run()
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     super().run()
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     while self.step():
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]           ~~~~~~~~~^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     self.dispatch_table[inst.opcode](self, inst)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     self._return(inst)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~~^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     self.output.compile_subgraph(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         self,
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ...<2 lines>...
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ),
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     )
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/output_graph.py", line 1101, in compile_subgraph
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     self.compile_and_call_fx_graph(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         tx, list(reversed(stack_values)), root, output_replacements
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     )
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/output_graph.py", line 1382, in compile_and_call_fx_graph
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     compiled_fn = self.call_user_compiler(gm)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/output_graph.py", line 1432, in call_user_compiler
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return self._call_user_compiler(gm)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]            ~~~~~~~~~~~~~~~~~~~~~~~~^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/output_graph.py", line 1483, in _call_user_compiler
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         e.__traceback__
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ) from None
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     compiled_fn = compiler_fn(gm, self.example_inputs())
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     compiled_gm = compiler_fn(gm, example_inputs)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/__init__.py", line 2340, in __call__
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return compile_fx(model_, inputs_, config_patches=self.config)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/compile_fx.py", line 1863, in compile_fx
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return aot_autograd(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]            ~~~~~~~~~~~~~
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ...<6 lines>...
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         cudagraphs=cudagraphs,
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ~~~~~~~~~~~~~~~~~~~~~~
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     )(model_, example_inputs_)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~^^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_dynamo/backends/common.py", line 83, in __call__
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_functorch/aot_autograd.py", line 1155, in aot_module_simplified
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     compiled_fn = dispatch_and_compile()
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     compiled_fn, _ = create_aot_dispatcher_function(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         functional_call,
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ...<3 lines>...
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         shape_env,
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     )
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return _create_aot_dispatcher_function(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     )
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     compiled_fn, fw_metadata = compiler_fn(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]                                ~~~~~~~~~~~^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         flat_fn,
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ...<2 lines>...
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         fw_metadata=fw_metadata,
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     )
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 449, in aot_dispatch_autograd
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     fw_module, bw_module = aot_config.partition_fn(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]                            ~~~~~~~~~~~~~~~~~~~~~~~^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     )
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/compile_fx.py", line 1779, in partition_fn
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     _recursive_joint_graph_passes(gm)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/compile_fx.py", line 322, in _recursive_joint_graph_passes
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     joint_graph_passes(gm)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~~~~~~~~^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 468, in joint_graph_passes
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ).apply_graph_pass(patterns.apply)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]       ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/fx/passes/graph_transform_observer.py", line 70, in apply_graph_pass
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return pass_fn(self.gm.graph)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py", line 1773, in apply
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     if is_match(m) and entry.extra_check(m):
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]                        ~~~~~~~~~~~~~~~~~^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py", line 1352, in check_fn
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     if is_match(specific_pattern_match) and extra_check(specific_pattern_match):
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]                                             ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/fx_passes/pad_mm.py", line 708, in should_pad_mm
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return should_pad_common(mat1, mat2) and should_pad_bench(
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]                                              ~~~~~~~~~~~~~~~~^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         match, mat1, mat2, torch.ops.aten.mm
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     )
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/fx_passes/pad_mm.py", line 386, in should_pad_bench
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return _should_pad_bench(*args, **kwargs)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/fx_passes/pad_mm.py", line 594, in _should_pad_bench
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ori_time = do_bench(orig_bench_fn)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/runtime/benchmarking.py", line 66, in wrapper
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return fn(self, *args, **kwargs)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/torch/_inductor/runtime/benchmarking.py", line 202, in benchmark_gpu
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return self.triton_do_bench(_callable, **kwargs, return_mode="median")
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]            ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/triton/testing.py", line 115, in do_bench
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     di = runtime.driver.active.get_device_interface()
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/triton/runtime/driver.py", line 23, in __getattr__
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     self._initialize_obj()
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     ~~~~~~~~~~~~~~~~~~~~^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/triton/runtime/driver.py", line 20, in _initialize_obj
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     self._obj = self._init_fn()
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]                 ~~~~~~~~~~~~~^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/triton/runtime/driver.py", line 9, in _create_driver
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     return actives[0]()
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]            ~~~~~~~~~~^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/triton/backends/nvidia/driver.py", line 450, in __init__
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     self.utils = CudaUtils()  # TODO: make static
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]                  ~~~~~~~~~^^
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/triton/backends/nvidia/driver.py", line 80, in __init__
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/triton/backends/nvidia/driver.py", line 57, in compile_module_from_src
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]   File "/builds/jatic/aria/dataeval/.nox/docs/lib/python3.13/site-packages/triton/runtime/build.py", line 32, in _build
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233]     raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233] RuntimeError: Failed to find C compiler. Please specify via CC environment variable.
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233] 
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233] Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
W0115 08:16:26.897000 1705 torch/_dynamo/convert_frame.py:1233] 
# Print out sufficiency output in a table format
formatted = {"Steps": output.steps, **output.averaged_measures}
print(tabulate(formatted, headers=list(formatted), tablefmt="pretty"))
+-------+----------------------+---------------------+--------------------+
| Steps | TPR at 0.5 Fixed FPR |      Accuracy       |       AUROC        |
+-------+----------------------+---------------------+--------------------+
|  25   |  0.785096275806427   | 0.22600000202655793 | 0.7452823281288147 |
|  41   |  0.904829478263855   | 0.5143999934196473  | 0.8521184682846069 |
|  69   |  0.9653781175613403  | 0.6363999962806701  | 0.9167528867721557 |
|  116  |  0.9803223729133606  |  0.706000006198883  | 0.9414556980133056 |
|  193  |  0.9876001358032227  | 0.7616000056266785  | 0.9595475792884827 |
|  322  |  0.9971008777618409  | 0.8248000144958496  |  0.97947758436203  |
|  538  |  0.9978660106658935  | 0.8824000000953675  | 0.9869886755943298 |
|  898  |         1.0          |         0.9         | 0.9915173768997192 |
| 1498  |  0.9988282799720765  | 0.8968000054359436  | 0.9905665993690491 |
| 2500  |         1.0          | 0.9351999878883361  | 0.9954016089439393 |
+-------+----------------------+---------------------+--------------------+
# Print out projected output values
projection = output.project([1000, 2500, 5000])
projected = {"Steps": projection.steps, **projection.averaged_measures}
print(tabulate(projected, headers=list(projected), tablefmt="pretty"))
+-------+----------------------+--------------------+--------------------+
| Steps | TPR at 0.5 Fixed FPR |      Accuracy      |       AUROC        |
+-------+----------------------+--------------------+--------------------+
| 1000  |   0.99852820559649   | 0.8947735116932173 | 0.9898820089745924 |
| 2500  |  0.998855574889278   | 0.9146385101854015 | 0.9928518526713005 |
| 5000  |  0.9989164160917727  | 0.9226150775593616 | 0.993784323048578  |
+-------+----------------------+--------------------+--------------------+
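# Compare the measured value at 2500 samples with the projected value at 2500 samples
# (projection.steps is [1000, 2500, 5000], so index -2 is the 2500-sample projection)
for name, values in output.averaged_measures.items():
    print(abs(values[-1] - projection.averaged_measures[name][-2]))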
0.0011444251107219916
0.020561477702934594
0.002549756272638759
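The loop above prints the gaps in the order of `output.averaged_measures` (TPR at 0.5 fixed FPR, accuracy, AUROC) without labels. The sketch below repeats the same arithmetic in plain Python, using values copied by hand from the measured and projected tables, so that each gap is paired with its metric name. It does not call DataEval; the dictionary names are illustrative only.

```python
# Values copied from the measured (Steps = 2500) and projected (Steps = 2500) tables
measured_at_2500 = {
    "TPR at 0.5 Fixed FPR": 1.0,
    "Accuracy": 0.9351999878883361,
    "AUROC": 0.9954016089439393,
}
projected_at_2500 = {
    "TPR at 0.5 Fixed FPR": 0.998855574889278,
    "Accuracy": 0.9146385101854015,
    "AUROC": 0.9928518526713005,
}

# Absolute gap between the measured value and the curve's projection at 2500 samples
gaps = {
    name: abs(measured_at_2500[name] - projected_at_2500[name])
    for name in measured_at_2500
}
for name, gap in gaps.items():
    print(f"{name:<22} {gap:.4f}")
```

Accuracy shows the largest gap, roughly 0.02, which suggests its projections deserve the most caution of the three metrics.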
# Plot the output using the dataeval-plots library
for plot in dep.plot(output, backend="plotly"):
    display(plot)

Results

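Using these learning curves, we can project performance on much larger datasets, assuming the same model architecture and training procedure. As the comparison at 2500 samples shows, the projected values fall within about 0.02 of the measured values, so projections should be read as estimates rather than guarantees.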
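To build intuition for what a learning-curve projection does, the sketch below fits a three-parameter power law, error ≈ a · n^(-b) + c, to the measured accuracies (rounded from the table above) and extrapolates it to larger sample sizes. This functional form and the grid-search fit are assumptions chosen for illustration: DataEval's `Sufficiency` may use a different curve and optimizer, so these numbers will not match `output.project` exactly. `fit_power_law` and `project` are hypothetical helpers, not DataEval APIs.

```python
import numpy as np

# Measured accuracy at each training-set size (rounded from the sufficiency table)
steps = np.array([25, 41, 69, 116, 193, 322, 538, 898, 1498, 2500], dtype=float)
accuracy = np.array([0.2260, 0.5144, 0.6364, 0.7060, 0.7616,
                     0.8248, 0.8824, 0.9000, 0.8968, 0.9352])


def fit_power_law(n, error, num_c=500):
    """Hypothetical helper: fit error ~= a * n**(-b) + c.

    For each candidate asymptote c on a grid, fit a straight line to
    log(error - c) versus log(n) and keep the c with the lowest squared error.
    """
    log_n = np.log(n)
    best = (np.inf, 0.0, 0.0, 0.0)
    # c must stay below the smallest observed error so the log is defined
    for c in np.linspace(0.0, 0.999 * error.min(), num_c):
        slope, intercept = np.polyfit(log_n, np.log(error - c), 1)
        pred = np.exp(intercept) * n**slope + c
        sse = float(np.sum((pred - error) ** 2))
        if sse < best[0]:
            best = (sse, float(np.exp(intercept)), float(-slope), float(c))
    return best[1:]


def project(n, a, b, c):
    """Hypothetical helper: projected accuracy = 1 - projected error."""
    return 1.0 - (a * np.asarray(n, dtype=float) ** (-b) + c)


a, b, c = fit_power_law(steps, 1.0 - accuracy)
for n, acc in zip([1000, 2500, 5000], project([1000, 2500, 5000], a, b, c)):
    print(f"{n:>5}: projected accuracy {acc:.3f}")
# The asymptote c caps accuracy at 1 - c, no matter how many samples are added
print(f"asymptotic accuracy ceiling: {1 - c:.3f}")
```

The asymptote `c` is the design choice that matters most here: it encodes the idea that a fixed model eventually stops improving, which is why a projection can level off well below 100%.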

Predicting sample requirements

We can also predict the number of training samples required to reach specific performance thresholds.

Suppose we want to know how many samples are needed to reach 90%, 93%, and 99% on each of three metrics: accuracy, area under the receiver operating characteristic curve (AUROC), and true positive rate (TPR) at a fixed false positive rate (FPR) of 0.5.

# Initialize the array of desired thresholds to apply to all metrics
desired_values = np.array([0.90, 0.93, 0.99])
metrics = ["Accuracy", "AUROC", "TPR at 0.5 Fixed FPR"]

# Map each metric name to the same set of thresholds
evaluated_metrics = {metric: desired_values for metric in metrics}
# Evaluate the learning curve to infer the needed amount of training data
samples_needed = output.inv_project(evaluated_metrics)
# Print the number of samples needed to reach each threshold
for metric, samples in samples_needed.items():
    print(metric)
    for threshold, sample_size in zip(evaluated_metrics[metric], samples):
        print(
            f"To achieve {threshold:.0%} {metric},"
            f" {int(sample_size)} samples are needed."
        )
    print()
Accuracy
To achieve 90% Accuracy, 1203 samples are needed.
To achieve 93% Accuracy, 20040 samples are needed.
To achieve 99% Accuracy, -1 samples are needed.

AUROC
To achieve 90% AUROC, 61 samples are needed.
To achieve 93% AUROC, 87 samples are needed.
To achieve 99% AUROC, 1023 samples are needed.

TPR at 0.5 Fixed FPR
To achieve 90% TPR at 0.5 Fixed FPR, 39 samples are needed.
To achieve 93% TPR at 0.5 Fixed FPR, 48 samples are needed.
To achieve 99% TPR at 0.5 Fixed FPR, 163 samples are needed.

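A value of -1 indicates that the projection could not find a sample size at which the fitted curve reaches the requested threshold. Given the current model, the projection therefore suggests that 99% accuracy is unlikely to be achieved by collecting more training data alone.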