Class Parity Label Analysis Tutorial
Problem Statement
For machine learning tasks, a discrepancy in label frequencies between train and test datasets can result in poor model performance.
To help with this, DataEval provides the label_parity function, which compares the label distributions of two datasets.
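As a rough illustration of what a discrepancy in label frequencies looks like, the sketch below counts each class with np.bincount and compares per-class proportions. The label arrays are made-up toy data rather than MNIST, and label_proportions is a hypothetical helper, not part of DataEval.

```python
import numpy as np


def label_proportions(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical helper: fraction of samples in each class 0..num_classes-1."""
    counts = np.bincount(labels, minlength=num_classes)
    return counts / counts.sum()


# Toy labels for a 3-class problem: the test set over-represents class 2.
train = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
test = np.array([0, 1, 2, 2, 2, 2])

train_p = label_proportions(train, 3)
test_p = label_proportions(test, 3)
print("train proportions:", train_p)
print("test proportions: ", test_p)
print("difference:       ", test_p - train_p)
```

A large per-class difference like the one for class 2 here is the kind of mismatch that the statistical test in this tutorial is meant to detect.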
When to use
Use the label_parity function when you want to determine whether two datasets share the same label distribution, that is, whether an image's class label is statistically independent of the dataset it comes from.
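One way to make "statistically independent labels" concrete is Pearson's chi-squared test of independence on a 2 × K contingency table, with one row per dataset and one column per class. The NumPy sketch below computes only the statistic and degrees of freedom; it illustrates the general test and is not necessarily how DataEval computes its score. The helper name chi2_independence is hypothetical, and it assumes every label is an integer in 0..num_classes-1. If SciPy is available, scipy.stats.chi2.sf(stat, dof) converts the statistic into a p-value.

```python
import numpy as np


def chi2_independence(
    labels_a: np.ndarray, labels_b: np.ndarray, num_classes: int
) -> tuple[float, int]:
    """Hypothetical helper: Pearson chi-squared statistic for independence of
    class label and dataset membership. Returns (statistic, degrees of freedom)."""
    # Contingency table: one row per dataset, one column per class.
    table = np.stack(
        [
            np.bincount(labels_a, minlength=num_classes),
            np.bincount(labels_b, minlength=num_classes),
        ]
    ).astype(float)
    # Drop classes that occur in neither dataset; their expected count would be zero.
    table = table[:, table.sum(axis=0) > 0]
    # Expected counts under independence: row total * column total / grand total.
    expected = table.sum(axis=1, keepdims=True) * table.sum(axis=0, keepdims=True) / table.sum()
    stat = float(((table - expected) ** 2 / expected).sum())
    dof = (table.shape[0] - 1) * (table.shape[1] - 1)
    return stat, dof


# Identical toy label arrays: the statistic is 0, so there is no evidence of dependence.
same = np.array([0, 0, 1, 1])
stat, dof = chi2_independence(same, same, num_classes=2)
print(f"statistic={stat}, dof={dof}")
```

Building the full contingency table treats both datasets symmetrically, which matches the "independence" framing; the goodness-of-fit view instead treats one dataset as the reference.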
What you will need
A labeled training image dataset
A labeled test image dataset whose label distribution you want to compare against the training set
A Python environment with the following packages installed:
  dataeval[torch] or dataeval[all]
  torchvision
Setting up
Let’s import the libraries needed to set up a minimal working example.
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2
from dataeval.metrics.bias import label_parity
Load the data
We will use the MNIST dataset from torchvision for this tutorial on class label statistics.
to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
train_ds = datasets.MNIST("./data", train=True, download=True, transform=to_tensor)
test_ds = datasets.MNIST("./data", train=False, download=True, transform=to_tensor)
# Take the labels of the first 2000 training images and the first 500 test images.
# MNIST targets are int64 tensors, so .numpy() yields int64 NumPy arrays.
train_labels = train_ds.targets[:2000].numpy()
test_labels = test_ds.targets[:500].numpy()
Evaluate label statistical independence
Now, let’s look at how to use DataEval’s label statistics analyzer.
Call label_parity, passing the training labels first and the test labels second. It computes the chi-squared statistic for the hypothesis that test_ds has the same class distribution as train_ds and returns it, together with the p-value of the test, as results.score and results.p_value. The number of unique classes (10 for MNIST) is not passed explicitly in the call below.
results = label_parity(train_labels, test_labels)
print(f"The chi-squared value for the two label distributions is {results.score}, with p-value {results.p_value}")
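The chi-squared value for the two label distributions is 6.784849368367514, with p-value 0.6595083107993285
Because the p-value (about 0.66) is far above a conventional significance level such as 0.05, the test gives no evidence that the label distribution of the 500 test images differs from that of the 2000 training images.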
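The hypothesis tested in this section can also be written as a chi-squared goodness-of-fit test: the training-set class counts, rescaled to the size of the test subset, give the expected counts, and the test-set class counts are the observed counts. The NumPy sketch below computes that statistic with K − 1 degrees of freedom. Whether this matches what label_parity does internally is an assumption, so its numbers need not equal the output above; chi2_goodness_of_fit is a hypothetical helper.

```python
import numpy as np


def chi2_goodness_of_fit(
    expected_labels: np.ndarray, observed_labels: np.ndarray, num_classes: int
) -> tuple[float, int]:
    """Hypothetical helper: chi-squared goodness-of-fit of observed label counts
    against the class proportions of expected_labels. Returns (statistic, dof)."""
    exp_counts = np.bincount(expected_labels, minlength=num_classes).astype(float)
    obs_counts = np.bincount(observed_labels, minlength=num_classes).astype(float)
    # A class seen only in the observed labels would make the statistic infinite.
    if np.any((exp_counts == 0) & (obs_counts > 0)):
        raise ValueError("observed labels contain a class absent from the expected labels")
    # Rescale expected (training) counts so they sum to the number of observed (test) labels.
    exp_counts *= obs_counts.sum() / exp_counts.sum()
    # Classes absent from both label sets contribute nothing and are left out.
    mask = exp_counts > 0
    stat = float(((obs_counts[mask] - exp_counts[mask]) ** 2 / exp_counts[mask]).sum())
    dof = int(mask.sum()) - 1
    return stat, dof


# Toy example: a balanced expected set versus a skewed observed set.
stat, dof = chi2_goodness_of_fit(np.array([0, 1]), np.array([0, 0, 0, 0]), num_classes=2)
print(f"statistic={stat}, dof={dof}")
```

With the MNIST subsets above, the call would be chi2_goodness_of_fit(train_labels, test_labels, num_classes=10); if SciPy is installed, scipy.stats.chi2.sf(stat, dof) gives the corresponding p-value.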