Class Parity Label Analysis Tutorial

Problem Statement

For machine learning tasks, a discrepancy in label frequencies between train and test datasets can result in poor model performance.

To help with this, DataEval has a tool that compares the label distributions of two datasets.

When to use

Use the Parity class, or a similar tool, when you want to test whether the label distributions of two datasets differ by more than chance would explain.

What you will need

  1. A labeled training image dataset

  2. A labeled test image dataset whose label distribution you want to evaluate

Setting up

Let’s import the libraries needed to set up a minimal working example.

import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms.v2 as v2

from dataeval.metrics import Parity

Load the data

We will use the MNIST dataset from torchvision for this tutorial on class label statistics.

to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
train_ds = datasets.MNIST("./data", train=True, download=True, transform=to_tensor)
test_ds = datasets.MNIST("./data", train=False, download=True, transform=to_tensor)

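# Take the labels of the first 2000 training images and the first 500 test images.
# MNIST targets are stored as a torch tensor, so convert them to int64 NumPy arrays.
train_labels = train_ds.targets[:2000].numpy()
test_labels = test_ds.targets[:500].numpy()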
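
Before running a statistical test, it can help to look at the per-class counts directly. The following is a minimal sketch, not part of DataEval: the helper name class_counts and the randomly generated demo_train and demo_test arrays are hypothetical stand-ins, used so that the snippet runs without downloading MNIST. In the tutorial you would pass train_labels and test_labels instead.

```python
import numpy as np


def class_counts(labels, num_classes):
    """Return the number of samples per class as an array of length num_classes."""
    # minlength keeps classes that have zero samples in the result
    return np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_classes)


# Hypothetical stand-in labels (assumption: 10 classes, as in MNIST)
rng = np.random.default_rng(0)
demo_train = rng.integers(0, 10, size=2000)
demo_test = rng.integers(0, 10, size=500)

train_counts = class_counts(demo_train, 10)
test_counts = class_counts(demo_test, 10)

# Relative frequencies are easier to compare than raw counts
# when the two datasets have different sizes
print(train_counts / train_counts.sum())
print(test_counts / test_counts.sum())
```

Comparing relative frequencies side by side gives a quick visual check; the chi-squared test in the following sections quantifies whether the differences are larger than sampling noise would suggest.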

Initialize the metric

Now, let’s look at how to use DataEval’s label statistics analyzer. Start by initializing a Parity object with the two sets of labels to be compared. MNIST has 10 unique classes; in the call below the number of classes is not passed explicitly.

lsi = Parity(train_labels, test_labels)

Evaluate label statistical independence

Compute the chi-squared statistic for the hypothesis that test_ds has the same class distribution as train_ds by calling evaluate(). The call also returns the p-value of the test.

chisquared, p = lsi.evaluate()
print(f"The chi-squared value for the two label distributions is {chisquared}, with p-value {p}")
The chi-squared value for the two label distributions is 6.784849368367514, with p-value 0.6595083107993285
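
To make the test itself more concrete, the sketch below computes a chi-squared goodness-of-fit statistic with scipy.stats.chisquare. It is an illustration under stated assumptions, not necessarily how Parity is implemented: the function name label_chisquare and the synthetic label arrays are hypothetical, and the expected counts are obtained by rescaling the reference frequencies to the size of the observed dataset.

```python
import numpy as np
from scipy.stats import chisquare


def label_chisquare(expected_labels, observed_labels, num_classes):
    """Chi-squared test of observed label counts against a reference distribution."""
    exp_counts = np.bincount(expected_labels, minlength=num_classes).astype(float)
    obs_counts = np.bincount(observed_labels, minlength=num_classes).astype(float)

    # Rescale the reference counts so both totals match, as the test requires
    exp_scaled = exp_counts / exp_counts.sum() * obs_counts.sum()

    # Caveat: a class with zero expected count would cause a division by zero;
    # this sketch assumes every class appears in the reference labels
    result = chisquare(f_obs=obs_counts, f_exp=exp_scaled)
    return float(result.statistic), float(result.pvalue)


# Hypothetical reference set: 200 samples for each of 10 classes
reference = np.repeat(np.arange(10), 200)

# Same proportions as the reference: 50 samples per class
balanced = np.repeat(np.arange(10), 50)

# Skewed set: 40 samples per class plus 100 extra samples of class 0
skewed = np.concatenate([np.repeat(np.arange(10), 40), np.zeros(100, dtype=np.int64)])

stat_same, p_same = label_chisquare(reference, balanced, 10)
stat_skew, p_skew = label_chisquare(reference, skewed, 10)
print(stat_same, p_same)
print(stat_skew, p_skew)
```

A large p-value, such as the 0.66 reported above for the MNIST subsets, means the test finds no evidence that the two label distributions differ. A small p-value, for example below a conventional threshold of 0.05, would suggest that the test labels are distributed differently from the training labels.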