dataeval.protocols.UpdateStrategy
class dataeval.protocols.UpdateStrategy
Protocol defining the interface for updating reference data in drift detectors.
Update strategies control how drift detectors maintain their reference dataset as new data arrives. Implementations must provide a __call__ method that updates the reference data based on new observations.
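As a rough illustration of the calling convention described above, the sketch below defines a stand-in protocol and a trivial class that satisfies it. `UpdateStrategyLike` and `ReplaceWithNewest` are hypothetical names introduced only for this sketch, and the signature is inferred from the example on this page rather than copied from dataeval's source.

```python
from typing import Protocol, runtime_checkable

import numpy as np
from numpy.typing import NDArray


@runtime_checkable
class UpdateStrategyLike(Protocol):
    """Hypothetical stand-in for the interface; not dataeval's own definition."""

    def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32]) -> NDArray[np.float32]: ...


class ReplaceWithNewest:
    """Hypothetical strategy: drop the old reference and keep only the new batch."""

    def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32]) -> NDArray[np.float32]:
        # Keep the existing reference when no new observations arrived.
        if len(x_new) == 0:
            return x_ref
        return np.asarray(x_new, dtype=np.float32)


strategy = ReplaceWithNewest()
x_ref = np.zeros((4, 3), dtype=np.float32)
x_new = np.ones((2, 3), dtype=np.float32)
updated = strategy(x_ref, x_new)

# Structural typing: no inheritance is needed, only a matching __call__.
# Note that this runtime check only confirms __call__ exists, not its signature.
print(isinstance(strategy, UpdateStrategyLike), updated.shape)
```

Because the protocol is structural, any callable object with a compatible `__call__` should be usable; a static type checker is what actually verifies the argument and return types.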
Examples
Creating a custom update strategy that keeps a moving average:
>>> import numpy as np
>>> from numpy.typing import NDArray
>>> from dataeval.utils.arrays import flatten_samples
>>>
>>> class MovingAverageUpdate:
...     '''Update strategy that maintains a moving average of reference data.'''
...
...     def __init__(self, n: int, alpha: float = 0.9) -> None:
...         '''
...         Parameters
...         ----------
...         n : int
...             Maximum number of samples to maintain
...         alpha : float, default 0.9
...             Exponential moving average weight (0 < alpha < 1)
...         '''
...         self.n = n
...         self.alpha = alpha
...
...     def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32]) -> NDArray[np.float32]:
...         '''
...         Update reference data with exponential moving average.
...
...         Parameters
...         ----------
...         x_ref : NDArray[np.float32]
...             Current reference data of shape (n_ref, n_features)
...         x_new : NDArray[np.float32]
...             New observations of shape (n_new, n_features)
...
...         Returns
...         -------
...         NDArray[np.float32]
...             Updated reference data of shape (n_updated, n_features)
...         '''
...         x_new_flat = flatten_samples(x_new)
...         # Work on a copy so the caller's reference array is not modified in place
...         x_ref = x_ref.copy()
...         # Compute moving average for overlapping samples
...         n_overlap = min(len(x_ref), len(x_new_flat))
...         if n_overlap > 0:
...             x_ref[:n_overlap] = self.alpha * x_ref[:n_overlap] + (1 - self.alpha) * x_new_flat[:n_overlap]
...         # Append remaining new samples
...         result = np.concatenate([x_ref, x_new_flat[n_overlap:]], axis=0)
...         # Keep at most the last n samples
...         return result[-self.n :]

Using a custom update strategy with a drift detector:
>>> from dataeval.shift import DriftUnivariate
>>> import numpy as np
>>>
>>> # Create reference data
>>> ref_data = np.random.normal(0, 1, (100, 10))
>>>
>>> # Initialize drift detector with custom update strategy
>>> update_strategy = MovingAverageUpdate(n=100, alpha=0.9)
>>> detector = DriftUnivariate(ref_data, method="ks", update_strategy=update_strategy)
>>>
>>> # Detect drift on new data - reference will be updated automatically
>>> new_data = np.random.normal(0.5, 1, (50, 10))
>>> result = detector.predict(new_data)

Notes
Implementations should:
- Accept current reference data and new observations
- Return updated reference data with consistent shape
- Handle edge cases (empty arrays, size mismatches)
- Maintain internal state if needed (e.g., sample counts)
- Ensure output size doesn’t exceed configured limits
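The checklist above can be made concrete with a small sketch. `BoundedWindowUpdate` is a hypothetical class written only for illustration (it is not part of dataeval), and it assumes samples are stacked along the first axis; it shows one way to cover empty inputs, mismatched feature shapes, a running sample count, and a hard size limit.

```python
import numpy as np
from numpy.typing import NDArray


class BoundedWindowUpdate:
    """Hypothetical strategy keeping at most ``n`` of the most recent samples."""

    def __init__(self, n: int) -> None:
        if n <= 0:
            raise ValueError("n must be a positive integer")
        self.n = n
        self.n_seen = 0  # internal state: number of new samples observed so far

    def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32]) -> NDArray[np.float32]:
        x_ref = np.asarray(x_ref, dtype=np.float32)
        x_new = np.asarray(x_new, dtype=np.float32)

        # Edge case: nothing new arrived, so only enforce the size limit.
        if x_new.shape[0] == 0:
            return x_ref[-self.n :]

        self.n_seen += x_new.shape[0]

        # Edge case: no reference yet, so the new batch becomes the reference.
        if x_ref.shape[0] == 0:
            return x_new[-self.n :]

        # Edge case: per-sample shapes must agree before concatenating.
        if x_ref.shape[1:] != x_new.shape[1:]:
            raise ValueError(f"shape mismatch: {x_ref.shape[1:]} vs {x_new.shape[1:]}")

        # Concatenate builds a new array, so the caller's x_ref is not mutated.
        return np.concatenate([x_ref, x_new], axis=0)[-self.n :]


update = BoundedWindowUpdate(n=5)
ref = np.arange(12, dtype=np.float32).reshape(4, 3)
new = np.full((3, 3), -1.0, dtype=np.float32)
out = update(ref, new)  # 4 + 3 samples, truncated to the last 5
```

Raising on a shape mismatch is one design choice among several; an implementation could instead flatten or reshape inputs, as the moving-average example does with `flatten_samples`.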
See also
LastSeenUpdate : Built-in strategy keeping the last n samples
ReservoirSamplingUpdate : Built-in strategy using reservoir sampling
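For intuition about the reservoir sampling idea named above, here is a minimal sketch of the classic Algorithm R written as an update callable. `ReservoirSketch` is a hypothetical class and is not dataeval's `ReservoirSamplingUpdate`, whose internals may differ; the sketch assumes 2-D inputs of shape (n_samples, n_features) and treats the initial reference as samples already seen.

```python
import numpy as np
from numpy.typing import NDArray


class ReservoirSketch:
    """Hypothetical reservoir sampler (Algorithm R); illustrative only."""

    def __init__(self, n: int, seed: int = 0) -> None:
        self.n = n
        self.n_seen = 0  # total samples offered to the reservoir so far
        self.rng = np.random.default_rng(seed)

    def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32]) -> NDArray[np.float32]:
        # np.array copies, so the caller's reference data is left untouched.
        reservoir = np.array(x_ref[: self.n], dtype=np.float32)
        # On the first call, count the existing reference as already seen.
        if self.n_seen == 0:
            self.n_seen = len(reservoir)

        for sample in np.asarray(x_new, dtype=np.float32):
            self.n_seen += 1
            if len(reservoir) < self.n:
                # Reservoir not full yet: always keep the sample.
                reservoir = np.concatenate([reservoir, sample[None]], axis=0)
            else:
                # Draw j uniformly from [0, n_seen); replace slot j if it falls
                # inside the reservoir, which happens with probability n / n_seen.
                j = int(self.rng.integers(0, self.n_seen))
                if j < self.n:
                    reservoir[j] = sample
        return reservoir


sampler = ReservoirSketch(n=10, seed=0)
ref = np.zeros((10, 2), dtype=np.float32)
new = np.ones((50, 2), dtype=np.float32)
out = sampler(ref, new)  # still 10 rows, now a mix of old and new samples
```

Unlike a last-n window, this keeps every sample seen so far in the reservoir with equal probability, so older data fades out gradually rather than being dropped all at once.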