dataeval.protocols.UpdateStrategy
class dataeval.protocols.UpdateStrategy
Protocol defining the interface for updating reference data in drift detectors.
Update strategies control how drift detectors maintain their reference dataset as new data arrives. Implementations must provide a __call__ method that takes the current reference data and the new observations and returns the updated reference data.
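As a rough illustration of the structural typing involved, the sketch below defines a stand-in protocol with the same call signature as the example on this page. The names UpdateStrategyLike and KeepLatest are hypothetical and exist only for this sketch; the actual protocol is the one defined in dataeval.protocols.

```python
from typing import Protocol, runtime_checkable

import numpy as np
from numpy.typing import NDArray


@runtime_checkable
class UpdateStrategyLike(Protocol):
    """Hypothetical stand-in mirroring the call signature shown on this page."""

    def __call__(
        self, reference_data: NDArray[np.float32], data: NDArray[np.float32]
    ) -> NDArray[np.float32]: ...


class KeepLatest:
    """Hypothetical strategy: replace the reference data with the newest batch."""

    def __call__(
        self, reference_data: NDArray[np.float32], data: NDArray[np.float32]
    ) -> NDArray[np.float32]:
        return np.atleast_2d(np.asarray(data, dtype=np.float32))


strategy = KeepLatest()

# No inheritance is needed: an object with a __call__ method satisfies the
# protocol. isinstance() only checks that the method exists, not its signature.
is_strategy = isinstance(strategy, UpdateStrategyLike)

updated = strategy(
    np.zeros((4, 3), dtype=np.float32),
    np.ones((2, 3), dtype=np.float32),
)
```

Because the check is structural, a plain class like KeepLatest can be used without subclassing anything; a static type checker is what verifies the argument and return types.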
Examples
Creating a custom update strategy that keeps a moving average:
>>> import numpy as np
>>> from numpy.typing import NDArray
>>>
>>> class MovingAverageUpdate:
...     '''Update strategy that maintains a moving average of reference data.'''
...
...     def __init__(self, n: int, alpha: float = 0.9) -> None:
...         '''
...         Parameters
...         ----------
...         n : int
...             Maximum number of samples to maintain
...         alpha : float, default 0.9
...             Exponential moving average weight (0 < alpha < 1)
...         '''
...         self.n = n
...         self.alpha = alpha
...
...     def __call__(self, reference_data: NDArray[np.float32], data: NDArray[np.float32]) -> NDArray[np.float32]:
...         '''
...         Update reference data with exponential moving average.
...
...         Parameters
...         ----------
...         reference_data : NDArray[np.float32]
...             Current reference data of shape (n_ref, n_features)
...         data : NDArray[np.float32]
...             New observations of shape (n_new, n_features)
...
...         Returns
...         -------
...         NDArray[np.float32]
...             Updated reference data of shape (n_updated, n_features)
...         '''
...         data_flat = np.atleast_2d(np.asarray(data, dtype=np.float32))
...         # Compute moving average for overlapping samples
...         n_overlap = min(len(reference_data), len(data_flat))
...         if n_overlap > 0:
...             reference_data[:n_overlap] = (
...                 self.alpha * reference_data[:n_overlap] + (1 - self.alpha) * data_flat[:n_overlap]
...             )
...         # Append remaining new samples
...         result = np.concatenate([reference_data, data_flat[n_overlap:]], axis=0)
...         return result[-self.n :]

Using a custom update strategy with a drift detector:
>>> from dataeval.shift import DriftUnivariate
>>> import numpy as np
>>>
>>> # Create reference data
>>> ref_data = np.random.normal(0, 1, (100, 10))
>>>
>>> # Initialize drift detector with custom update strategy
>>> update_strategy = MovingAverageUpdate(n=100, alpha=0.9)
>>> detector = DriftUnivariate(method="ks", update_strategy=update_strategy).fit(ref_data)
>>>
>>> # Detect drift on new data - reference will be updated automatically
>>> new_data = np.random.normal(0.5, 1, (50, 10))
>>> result = detector.predict(new_data)

Notes
Implementations should:
- Accept current reference data and new observations
- Return updated reference data with consistent shape
- Handle edge cases (empty arrays, size mismatches)
- Maintain internal state if needed (e.g., sample counts)
- Ensure output size doesn’t exceed configured limits
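The guidelines above can be sketched as follows. This is a minimal illustration, not one of the library's built-in strategies: the class name BoundedWindowUpdate and its total_seen counter are hypothetical, and raising ValueError on a feature-count mismatch is just one possible policy, assumed here for the example.

```python
import numpy as np
from numpy.typing import NDArray


class BoundedWindowUpdate:
    """Hypothetical strategy: keep at most `n` of the most recent samples."""

    def __init__(self, n: int) -> None:
        if n <= 0:
            raise ValueError("n must be a positive integer")
        self.n = n
        self.total_seen = 0  # internal state: new samples observed so far

    def __call__(
        self, reference_data: NDArray[np.float32], data: NDArray[np.float32]
    ) -> NDArray[np.float32]:
        ref = np.asarray(reference_data, dtype=np.float32)
        new = np.asarray(data, dtype=np.float32)

        # Edge case: empty batch, so nothing is added and only the cap applies.
        if new.size == 0:
            return np.atleast_2d(ref)[-self.n :]

        new = np.atleast_2d(new)  # a 1-D input is treated as a single sample
        self.total_seen += len(new)

        # Edge case: empty reference, so the new batch becomes the reference.
        if ref.size == 0:
            return new[-self.n :]

        ref = np.atleast_2d(ref)
        # Edge case: feature-count mismatch, reported rather than guessed at.
        if new.shape[1] != ref.shape[1]:
            raise ValueError(
                f"expected {ref.shape[1]} features, got {new.shape[1]}"
            )

        # np.concatenate allocates a new array, so the caller's reference
        # data is left unmodified; the slice enforces the size limit.
        return np.concatenate([ref, new], axis=0)[-self.n :]


strategy = BoundedWindowUpdate(n=5)
ref = np.zeros((4, 3), dtype=np.float32)
out = strategy(ref, np.ones((3, 3), dtype=np.float32))
unchanged = strategy(ref, np.empty((0, 3), dtype=np.float32))
```

Returning a new array instead of writing into reference_data keeps the strategy free of side effects on the caller's data, which makes it easier to test in isolation.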
See also
LastSeenUpdate : Built-in strategy keeping the last n samples
ReservoirSamplingUpdate : Built-in strategy using reservoir sampling