dataeval.bias.Balance
class dataeval.bias.Balance(num_neighbors=5, class_imbalance_threshold=0.3, factor_correlation_threshold=0.5)
Calculates mutual information (MI) between factors (class label, metadata, label/image properties).
Identifies imbalanced classes and highly correlated metadata factors based on mutual information thresholds.
- Parameters:
- num_neighbors : int, default 5
Number of points to consider as neighbors when estimating mutual information.
- class_imbalance_threshold : float, default 0.3
Threshold for identifying imbalanced classes. A class is flagged as imbalanced when its MI with any metadata factor is above this threshold.
- factor_correlation_threshold : float, default 0.5
Threshold for identifying highly correlated metadata factors. A pair of factors is flagged as highly correlated when its MI is above this threshold.
- factor_correlation_threshold
Threshold for identifying highly correlated metadata factors
- Type:
float
Notes
MI is estimated with mutual_info_classif from sklearn because the class label is categorical. Its outputs depend on a random seed and are consistent only up to O(1e-4). MI is computed differently for categorical and continuous variables.
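To make the quantity concrete, the following is a minimal sketch of a plug-in mutual information estimate between two categorical sequences, written with the standard library only. It is not DataEval's implementation: the function name `discrete_mutual_information` and the toy data are hypothetical, the result is in nats and unnormalized, and the values reported by Balance may be scaled or estimated differently.

```python
from collections import Counter
from math import log


def discrete_mutual_information(xs, ys):
    """Plug-in MI estimate (in nats) between two equal-length categorical sequences.

    Hypothetical helper for illustration only; not part of DataEval.
    """
    if len(xs) != len(ys):
        raise ValueError("sequences must have the same length")
    n = len(xs)
    joint = Counter(zip(xs, ys))  # counts of each (x, y) pair
    count_x = Counter(xs)         # marginal counts of x
    count_y = Counter(ys)         # marginal counts of y
    mi = 0.0
    for (x, y), n_xy in joint.items():
        # p(x, y) * log(p(x, y) / (p(x) * p(y))), expressed with raw counts
        mi += (n_xy / n) * log(n_xy * n / (count_x[x] * count_y[y]))
    return mi


# Toy data (assumed for illustration): a factor that fully determines the
# class label, and one that is independent of it.
labels = ["doctor", "doctor", "artist", "artist"]
gender_dependent = ["M", "M", "F", "F"]
gender_independent = ["M", "F", "M", "F"]

mi_dependent = discrete_mutual_information(labels, gender_dependent)
mi_independent = discrete_mutual_information(labels, gender_independent)
```

In this toy case the fully dependent factor yields MI equal to log 2 (the entropy of a balanced two-class label), while the independent factor yields zero. A count-based estimate like this one is deterministic, unlike the neighbor-based estimate described in the note above.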
Examples
Initialize the Balance class:
>>> balance = Balance()
Specifying custom thresholds:
>>> balance = Balance(class_imbalance_threshold=0.2, factor_correlation_threshold=0.6)
See also
sklearn.feature_selection.mutual_info_classif, sklearn.feature_selection.mutual_info_regression, sklearn.metrics.mutual_info_score
- evaluate(data)
Compute mutual information between factors and identify imbalanced classes.
- Parameters:
- data : AnnotatedDataset[Any] or Metadata
Either an annotated dataset (which will be converted to Metadata) or preprocessed Metadata directly.
- Returns:
Three DataFrames containing MI scores and threshold flags:
- balance: Global class-to-factor mutual information
- factors: Inter-factor mutual information
- classwise: Per-class-to-factor mutual information
- Return type:
Example
Return balance (mutual information) of factors with class_labels:
>>> metadata = generate_random_metadata(
...     labels=["doctor", "artist", "teacher"],
...     factors={"age": [25, 30, 35, 45], "income": [50000, 65000, 80000], "gender": ["M", "F"]},
...     length=100,
...     random_seed=175,
... )
>>> balance = Balance()
>>> result = balance.evaluate(metadata)
>>> result.balance
shape: (4, 2)
┌─────────────┬──────────┐
│ factor_name ┆ mi_value │
│ ---         ┆ ---      │
│ cat         ┆ f64      │
╞═════════════╪══════════╡
│ class_label ┆ 0.888187 │
│ age         ┆ 0.251485 │
│ gender      ┆ 0.00399  │
│ income      ┆ 0.362771 │
└─────────────┴──────────┘
>>> result.factors
shape: (6, 4)
┌─────────┬─────────┬──────────┬───────────────┐
│ factor1 ┆ factor2 ┆ mi_value ┆ is_correlated │
│ ---     ┆ ---     ┆ ---      ┆ ---           │
│ cat     ┆ cat     ┆ f64      ┆ bool          │
╞═════════╪═════════╪══════════╪═══════════════╡
│ age     ┆ gender  ┆ 0.046483 ┆ false         │
│ age     ┆ income  ┆ 0.078066 ┆ false         │
│ gender  ┆ age     ┆ 0.046483 ┆ false         │
│ gender  ┆ income  ┆ 0.047947 ┆ false         │
│ income  ┆ age     ┆ 0.078066 ┆ false         │
│ income  ┆ gender  ┆ 0.047947 ┆ false         │
└─────────┴─────────┴──────────┴───────────────┘
>>> result.classwise
shape: (9, 4)
┌────────────┬─────────────┬──────────┬───────────────┐
│ class_name ┆ factor_name ┆ mi_value ┆ is_imbalanced │
│ ---        ┆ ---         ┆ ---      ┆ ---           │
│ cat        ┆ cat         ┆ f64      ┆ bool          │
╞════════════╪═════════════╪══════════╪═══════════════╡
│ artist     ┆ age         ┆ 0.301469 ┆ true          │
│ artist     ┆ gender      ┆ 0.04493  ┆ false         │
│ artist     ┆ income      ┆ 0.250237 ┆ false         │
│ doctor     ┆ age         ┆ 0.164287 ┆ false         │
│ doctor     ┆ gender      ┆ 0.095962 ┆ false         │
│ doctor     ┆ income      ┆ 0.46587  ┆ true          │
│ teacher    ┆ age         ┆ 0.137221 ┆ false         │
│ teacher    ┆ gender      ┆ 0.018392 ┆ false         │
│ teacher    ┆ income      ┆ 0.160404 ┆ false         │
└────────────┴─────────────┴──────────┴───────────────┘
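The is_imbalanced and is_correlated flags in the output above are consistent with a strict "greater than threshold" comparison against the documented defaults (0.3 and 0.5). The sketch below reproduces that comparison in plain Python using the MI values copied from the tables. The helper `flag_above` is hypothetical, and the strict inequality is an assumption based on the word "above" in the parameter descriptions, not a statement about DataEval's internals.

```python
# MI values copied from the result.classwise table in the example above.
classwise_mi = {
    ("artist", "age"): 0.301469,
    ("artist", "gender"): 0.04493,
    ("artist", "income"): 0.250237,
    ("doctor", "age"): 0.164287,
    ("doctor", "gender"): 0.095962,
    ("doctor", "income"): 0.46587,
    ("teacher", "age"): 0.137221,
    ("teacher", "gender"): 0.018392,
    ("teacher", "income"): 0.160404,
}

# Unique factor pairs copied from the result.factors table in the example above.
factor_mi = {
    ("age", "gender"): 0.046483,
    ("age", "income"): 0.078066,
    ("gender", "income"): 0.047947,
}

CLASS_IMBALANCE_THRESHOLD = 0.3     # documented default
FACTOR_CORRELATION_THRESHOLD = 0.5  # documented default


def flag_above(scores, threshold):
    """Hypothetical helper: map each key to True when its score exceeds the threshold."""
    return {key: value > threshold for key, value in scores.items()}


imbalanced = sorted(
    key for key, flagged in flag_above(classwise_mi, CLASS_IMBALANCE_THRESHOLD).items() if flagged
)
correlated = sorted(
    key for key, flagged in flag_above(factor_mi, FACTOR_CORRELATION_THRESHOLD).items() if flagged
)

print(imbalanced)  # [('artist', 'age'), ('doctor', 'income')]
print(correlated)  # []
```

Lowering class_imbalance_threshold to 0.2, as in the custom-threshold example earlier, would additionally flag the (artist, income) pair under the same assumed comparison.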