dataeval.core.mutual_info_classwise

dataeval.core.mutual_info_classwise(class_labels, factor_data, discrete_features=None, num_neighbors=5)

Compute normalized mutual information (NMI) between each individual class and a set of factors.

Factors include class label, metadata, and label/image properties.
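The sketch below makes the "classwise" idea concrete. It is minimal and rests on several assumptions. Every factor is already discrete. Each class becomes a one-vs-rest indicator. The plug-in NMI between that indicator and each factor is normalized by the arithmetic mean of the two entropies, which is a common convention. The helpers `discrete_nmi` and `classwise_nmi_sketch` are hypothetical. This is not DataEval's implementation, which relies on sklearn estimators as described in the Notes. The sketch also omits the extra class-label column that the real return value contains.

```python
import numpy as np


def _entropy(counts: np.ndarray) -> float:
    """Shannon entropy (in nats) of a vector of nonnegative counts."""
    p = counts[counts > 0] / counts.sum()
    return float(-(p * np.log(p)).sum())


def discrete_nmi(x, y) -> float:
    """Plug-in NMI of two discrete 1D arrays: 2 I(X;Y) / (H(X) + H(Y))."""
    _, xi = np.unique(np.asarray(x), return_inverse=True)
    _, yi = np.unique(np.asarray(y), return_inverse=True)
    # Contingency table: joint[a, b] counts samples with x-code a and y-code b.
    joint = np.zeros((xi.max() + 1, yi.max() + 1))
    np.add.at(joint, (xi, yi), 1)
    h_x = _entropy(joint.sum(axis=1))
    h_y = _entropy(joint.sum(axis=0))
    h_xy = _entropy(joint.ravel())
    mi = h_x + h_y - h_xy
    denom = h_x + h_y
    # Two constant variables carry no information; define NMI as 0 then.
    return 0.0 if denom == 0 else 2.0 * mi / denom


def classwise_nmi_sketch(class_labels, factor_data) -> np.ndarray:
    """Hypothetical helper: (num_classes, num_factors) NMI, discrete factors only."""
    class_labels = np.asarray(class_labels)
    factor_data = np.asarray(factor_data)
    classes = np.unique(class_labels)
    out = np.empty((classes.size, factor_data.shape[1]))
    for i, c in enumerate(classes):
        # One-vs-rest target: 1 for samples of class c, 0 otherwise.
        indicator = (class_labels == c).astype(int)
        for j in range(factor_data.shape[1]):
            out[i, j] = discrete_nmi(indicator, factor_data[:, j])
    return out
```

For example, `classwise_nmi_sketch([0, 0, 1, 1], [[5, 1], [5, 0], [7, 1], [7, 0]])` gives 1.0 for the first factor in both rows, because that factor separates the two classes perfectly. The second factor is independent of the class, so its value is 0.0 up to floating-point rounding.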

Parameters:
class_labels : Array1D[int]

Target class labels as integer indices. Can be a 1D list or array-like object.

factor_data : Array2D[int | float]

Factor values after binning or digitization, with one row per sample and one column per factor. Can be a 2D list or array-like object.

discrete_features : Array1D[bool] | None = None

Boolean array or iterable indicating, for each factor, whether that factor is discrete. Can be a 1D list or array-like object.

num_neighbors : int = 5

Number of nearest neighbors used when estimating mutual information for continuous factors.

Returns:

Array of shape (num_classes, num_factors + 1). For each class, it contains the estimated normalized mutual information with the class label and with each of the num_factors metadata factors. Symmetry is enforced.
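The short sketch below shows one way to read the returned matrix. It hard-codes the output printed in the example at the end of this page and drops the leading column. That column is assumed here to be the class label itself, since it is 1.0 for every class. The sketch then reports which factor has the largest NMI with each class. The factor names come from the comments in that example.

```python
import numpy as np

# Output printed by the example on this page (rows: classes 0, 1, 2).
nmi = np.array([
    [1.000e+00, 2.077e-02, 2.296e-03, 7.317e-04],
    [1.000e+00, 4.893e-02, 2.451e-02, 4.362e-03],
    [1.000e+00, 1.868e-02, 3.820e-02, 1.006e-03],
])
factor_names = ["age", "income", "gender"]

# Assumption: column 0 is the class label, the remaining columns are the factors.
factor_nmi = nmi[:, 1:]
strongest = factor_nmi.argmax(axis=1)  # index of the most informative factor per class
for cls, idx in enumerate(strongest):
    print(f"class {cls}: {factor_names[idx]} (NMI={factor_nmi[cls, idx]:.3g})")
```

All of these values are small (below 0.05), which is what one would expect for factors drawn independently of the labels, as they are in the example.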

Return type:

NDArray[np.float64]

See also

sklearn.feature_selection.mutual_info_classif, sklearn.feature_selection.mutual_info_regression, sklearn.metrics.mutual_info_score

Notes

We use mutual_info_classif from sklearn because the class label is categorical. Its outputs depend on a random seed and are consistent only up to O(1e-4). MI is computed differently for categorical and continuous variables. In all cases, we return either a normalization or a transformation of MI onto the interval [0, 1].
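For reference, here is one common way to map MI between two discrete variables onto [0, 1]: divide by the arithmetic mean of the marginal entropies. This is the default normalization of sklearn.metrics.normalized_mutual_info_score. It is shown only as an illustration. The Notes above do not state which normalization or transformation this function applies.

```latex
I(X;Y) = H(X) + H(Y) - H(X,Y), \qquad
\mathrm{NMI}(X;Y) = \frac{2\, I(X;Y)}{H(X) + H(Y)} \in [0, 1]
```

The upper bound holds because I(X;Y) \le \min(H(X), H(Y)) \le \tfrac{1}{2}(H(X) + H(Y)).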

Example

Return the classwise balance (normalized mutual information) of factors with each individual class label.

>>> import numpy as np
>>> from dataeval.core import mutual_info_classwise
>>> rng = np.random.default_rng(175)
>>> class_labels = rng.choice([0, 1, 2], size=100)
>>> factor_data = np.column_stack([
...     rng.choice([25, 35, 45, 55], size=100),  # age
...     rng.choice([50000, 65000, 80000], size=100),  # income
...     rng.choice([0, 1], size=100),  # gender
... ])
>>> mutual_info_classwise(class_labels=class_labels, factor_data=factor_data)
array([[1.000e+00, 2.077e-02, 2.296e-03, 7.317e-04],
       [1.000e+00, 4.893e-02, 2.451e-02, 4.362e-03],
       [1.000e+00, 1.868e-02, 3.820e-02, 1.006e-03]])