dataeval.shift.OODDomainClassifier
class dataeval.shift.OODDomainClassifier(n_folds=None, n_repeats=None, n_std=None, hyperparameters=None, config=None)
Domain classifier-based out-of-distribution (OOD) detector.
Uses a LightGBM classifier’s ability to distinguish test samples from reference samples as an OOD signal. Samples that a classifier can easily identify as “not reference” are likely OOD.
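The core idea can be illustrated with a single train/held-out split. This is a minimal sketch under stated assumptions, not the library's implementation: it swaps LightGBM for a nearest-centroid classifier so that only NumPy is needed, and the helper names (fit_centroids, predict_class1, heldout_class1_rate) are hypothetical. Held-out samples from a shifted distribution are almost always assigned to the "not reference" class, while samples drawn from the reference distribution are not reliably separable from it.

```python
import numpy as np


def fit_centroids(x_ref, x_test):
    """'Train' the stand-in classifier: one centroid per domain (0 = reference, 1 = test)."""
    return x_ref.mean(axis=0), x_test.mean(axis=0)


def predict_class1(x, c0, c1):
    """Predict class 1 ("not reference") where a sample is closer to the test centroid."""
    d0 = np.linalg.norm(x - c0, axis=1)
    d1 = np.linalg.norm(x - c1, axis=1)
    return (d1 < d0).astype(int)


def heldout_class1_rate(x_ref, x_test):
    """Train on the first half of each set; report the class-1 rate on held-out test samples."""
    c0, c1 = fit_centroids(x_ref[: len(x_ref) // 2], x_test[: len(x_test) // 2])
    held_out = x_test[len(x_test) // 2 :]
    return float(predict_class1(held_out, c0, c1).mean())


rng = np.random.default_rng(0)
ref = rng.standard_normal((200, 8))
shifted = rng.standard_normal((50, 8)) + 3.0  # clearly OOD
same = rng.standard_normal((50, 8))  # drawn from the reference distribution
rate_shifted = heldout_class1_rate(ref, shifted)
rate_same = heldout_class1_rate(ref, same)
print(rate_shifted, rate_same)
```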
During fit(), the detector establishes a null distribution of per-point class-1 prediction rates by running repeated k-fold cross-validation (CV) on internal splits of the reference data. The threshold is set to mean + n_std * std of this null distribution.
During predict() and score(), the detector labels the test data as class 1 and the reference data as class 0, runs repeated k-fold CV, and returns per-point class-1 rates. Points whose rates exceed the threshold are flagged as OOD.
Parameters:
- n_folds : int, default 5
Number of cross-validation folds per repeat.
- n_repeats : int, default 5
Number of times to repeat the k-fold split.
- n_std : float, default 2.0
Number of standard deviations above the null mean at which the threshold is set.
- hyperparameters : dict or None, default None
LightGBM hyperparameters for the domain classifier.
- config : OODDomainClassifier.Config or None, default None
Optional configuration object.
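A rough sketch of how the three numeric parameters interact, assuming only NumPy: each repeat trains n_folds models (n_folds * n_repeats fits in total), and the threshold is mean + n_std * std of the null rates. The null rates below are made up for illustration. Treating each rate as a hard 0/1 vote averaged over n_repeats (so rates are multiples of 1 / n_repeats) is an assumption about the implementation, as is the use of the population standard deviation (NumPy's default ddof=0).

```python
import numpy as np


def threshold_from_null(null_rates, n_std=2.0):
    """Threshold = mean + n_std * std of the null class-1 rates."""
    null_rates = np.asarray(null_rates, dtype=float)
    return float(null_rates.mean() + n_std * null_rates.std())


# Assumption: each point is held out once per repeat and receives a hard 0/1
# prediction, so its class-1 rate is a multiple of 1 / n_repeats.
n_repeats = 5
possible_rates = np.arange(n_repeats + 1) / n_repeats  # 0.0, 0.2, ..., 1.0

# Illustrative (made-up) null rates: mean 0.5, population std sqrt(0.03).
null_rates = np.array([0.4, 0.6, 0.4, 0.2, 0.6, 0.8, 0.4, 0.6])
thr = threshold_from_null(null_rates, n_std=2.0)
print(thr, possible_rates[possible_rates > thr])
```

Under these assumptions the threshold is about 0.846, so with n_repeats=5 only a point predicted as class 1 in every repeat (rate 1.0) would be flagged. This shows how n_repeats limits how finely the rates can resolve borderline points.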
Examples
>>> import numpy as np
>>> from dataeval.shift import OODDomainClassifier
>>> ref = np.random.randn(200, 8).astype(np.float32)
>>> test = np.random.randn(50, 8).astype(np.float32) + 3
>>> detector = OODDomainClassifier(n_folds=3, n_repeats=3)
>>> detector.fit(ref)
>>> predictions = detector.predict(test)
fit(x_ref, threshold_perc=None)
Fit the detector using reference (in-distribution) data.
Computes a null distribution of class-1 prediction rates by splitting the reference data internally (half as pseudo-class-0, half as pseudo-class-1) and running repeated k-fold CV. The OOD threshold is derived from this null distribution.
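The fit() procedure can be sketched as follows. This is a hedged illustration, not dataeval's code. It makes several assumptions: a nearest-centroid classifier replaces LightGBM, plain (unstratified) k-fold splits are used, and a class-1 rate is the fraction of repeats in which a held-out point is predicted as class 1. The null distribution is built from the rates of the pseudo-class-1 half, mirroring how test points are scored as class 1 in predict(); that choice is also an assumption. The helper names class1_rates and fit_null_threshold are hypothetical, and threshold_perc is not modeled.

```python
import numpy as np


def class1_rates(x0, x1, n_folds=5, n_repeats=5, seed=0):
    """Per-point fraction of repeats in which a held-out point is predicted as class 1.

    A nearest-centroid classifier stands in for LightGBM (hypothetical substitute).
    Returns the rates for the stacked data [x0; x1] and the matching labels.
    """
    x = np.concatenate([x0, x1])
    y = np.concatenate([np.zeros(len(x0), dtype=int), np.ones(len(x1), dtype=int)])
    rng = np.random.default_rng(seed)
    hits = np.zeros(len(x))
    for _ in range(n_repeats):
        # A fresh random k-fold split per repeat; every point is held out exactly once.
        for held_out in np.array_split(rng.permutation(len(x)), n_folds):
            train = np.ones(len(x), dtype=bool)
            train[held_out] = False
            c0 = x[train & (y == 0)].mean(axis=0)
            c1 = x[train & (y == 1)].mean(axis=0)
            d0 = np.linalg.norm(x[held_out] - c0, axis=1)
            d1 = np.linalg.norm(x[held_out] - c1, axis=1)
            hits[held_out] += (d1 < d0).astype(float)
    return hits / n_repeats, y


def fit_null_threshold(x_ref, n_folds=5, n_repeats=5, n_std=2.0, seed=0):
    """Split the reference data in half, score it against itself, derive a threshold."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x_ref))
    half = len(x_ref) // 2
    pseudo0, pseudo1 = x_ref[idx[:half]], x_ref[idx[half:]]
    rates, y = class1_rates(pseudo0, pseudo1, n_folds, n_repeats, seed=seed + 1)
    null_rates = rates[y == 1]  # assumption: null taken from the pseudo-class-1 half
    threshold = float(null_rates.mean() + n_std * null_rates.std())
    return threshold, null_rates


rng = np.random.default_rng(0)
ref = rng.standard_normal((200, 8)).astype(np.float32)
threshold, null_rates = fit_null_threshold(ref, n_folds=3, n_repeats=3)
print(f"threshold = {threshold:.3f}")
```

On reference data with no shift, the null rates should hover around 0.5, because the classifier has no real difference to exploit. The threshold then sits n_std standard deviations above that mean.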
predict(x, batch_size=int(10000000000.0), ood_type='instance')
Predict whether instances are out of distribution.
score(x, batch_size=int(10000000000.0))
Compute out-of-distribution scores for a given dataset.
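Scoring and instance-level prediction use real domain labels: reference as class 0, the new data as class 1. The sketch below is a hedged illustration under these assumptions: a nearest-centroid stand-in for LightGBM, unstratified k-fold splits, and rates computed as the fraction of repeats in which a held-out point is predicted as class 1. The module-level functions score and predict here are hypothetical helpers, not the methods' real signatures, and batch_size and ood_type are not modeled. The threshold value is made up; in practice it would come from fit().

```python
import numpy as np


def class1_rates(x0, x1, n_folds=5, n_repeats=5, seed=0):
    """Per-point fraction of repeats in which a held-out point is predicted as class 1."""
    x = np.concatenate([x0, x1])
    y = np.concatenate([np.zeros(len(x0), dtype=int), np.ones(len(x1), dtype=int)])
    rng = np.random.default_rng(seed)
    hits = np.zeros(len(x))
    for _ in range(n_repeats):
        for held_out in np.array_split(rng.permutation(len(x)), n_folds):
            train = np.ones(len(x), dtype=bool)
            train[held_out] = False
            # Nearest-centroid stand-in for LightGBM (hypothetical substitute).
            c0 = x[train & (y == 0)].mean(axis=0)
            c1 = x[train & (y == 1)].mean(axis=0)
            d0 = np.linalg.norm(x[held_out] - c0, axis=1)
            d1 = np.linalg.norm(x[held_out] - c1, axis=1)
            hits[held_out] += (d1 < d0).astype(float)
    return hits / n_repeats, y


def score(x, x_ref, n_folds=5, n_repeats=5, seed=0):
    """Class-1 rate of each test point, with reference = class 0 and test = class 1."""
    rates, y = class1_rates(x_ref, x, n_folds, n_repeats, seed)
    return rates[y == 1]


def predict(x, x_ref, threshold, n_folds=5, n_repeats=5, seed=0):
    """Instance-level OOD flags: True where the score exceeds the threshold."""
    return score(x, x_ref, n_folds, n_repeats, seed) > threshold


rng = np.random.default_rng(0)
ref = rng.standard_normal((200, 8)).astype(np.float32)
test = rng.standard_normal((50, 8)).astype(np.float32) + 3
threshold = 0.8  # hypothetical value; in practice this comes from fit()
scores = score(test, ref, n_folds=3, n_repeats=3)
is_ood = predict(test, ref, threshold, n_folds=3, n_repeats=3)
print(is_ood.mean())
```

With a +3 shift in every feature, each test point should be assigned to class 1 in every repeat, so the scores are 1.0 and every point is flagged.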
Classes
Config: Configuration for OODDomainClassifier.