dataeval.protocols.UpdateStrategy

class dataeval.protocols.UpdateStrategy

Protocol defining the interface for updating reference data in drift detectors.

Update strategies control how drift detectors maintain their reference dataset as new data arrives. Implementations must provide a __call__ method that updates the reference data based on new observations.
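As a rough illustration of what "provide a __call__ method" means in practice, the sketch below defines a stand-in protocol and checks a class against it structurally. `UpdateStrategyLike` and `ReplaceWithNewest` are hypothetical names used only for this illustration, and the two-argument signature is an assumption taken from the example further down this page, not a copy of the library's definition.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class UpdateStrategyLike(Protocol):
    """Hypothetical stand-in for the protocol: any callable taking (reference_data, data)."""

    def __call__(self, reference_data: Any, data: Any) -> Any: ...


class ReplaceWithNewest:
    """Hypothetical strategy: discard the old reference and keep only the new batch."""

    def __call__(self, reference_data: Any, data: Any) -> Any:
        return data


class NotAStrategy:
    """Has no __call__ method, so it does not satisfy the protocol."""


# No inheritance is required; the check is structural.
# Note that runtime_checkable only verifies that __call__ exists, not its signature.
is_strategy = isinstance(ReplaceWithNewest(), UpdateStrategyLike)
is_not_strategy = isinstance(NotAStrategy(), UpdateStrategyLike)
```

Because the check is structural, a class never needs to subclass the protocol; it only needs a compatible `__call__`.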

See also

LastSeenUpdate

Built-in strategy keeping the last n samples

ReservoirSamplingUpdate

Built-in strategy using reservoir sampling

Notes

Implementations should:

- Accept current reference data and new observations
- Return updated reference data with consistent shape
- Handle edge cases (empty arrays, size mismatches)
- Maintain internal state if needed (e.g., sample counts)
- Ensure output size doesn’t exceed configured limits
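The guidelines above can be sketched with a small strategy that keeps only the most recent rows. `KeepLastN` and its `seen` counter are hypothetical names for this illustration; this is not the built-in LastSeenUpdate, and it assumes the reference data is a 2-D array of shape (n_ref, n_features).

```python
import numpy as np
from numpy.typing import NDArray


class KeepLastN:
    """Hypothetical strategy: keep only the n most recent samples."""

    def __init__(self, n: int) -> None:
        if n <= 0:
            raise ValueError("n must be a positive integer")
        self.n = n
        self.seen = 0  # internal state: number of new samples observed so far

    def __call__(self, reference_data: NDArray[np.float32], data: NDArray[np.float32]) -> NDArray[np.float32]:
        ref = np.asarray(reference_data, dtype=np.float32)
        new = np.asarray(data, dtype=np.float32)
        # Edge case: an empty batch leaves the (size-capped) reference unchanged
        if new.size == 0:
            return ref[-self.n :]
        new = np.atleast_2d(new)
        # Edge case: reject batches whose feature dimensions differ from the reference
        if new.shape[1:] != ref.shape[1:]:
            raise ValueError(f"feature shape mismatch: {new.shape[1:]} vs {ref.shape[1:]}")
        self.seen += len(new)
        # Append the new rows, then cap the output at n rows
        return np.concatenate([ref, new], axis=0)[-self.n :]


strategy = KeepLastN(n=5)
ref = np.zeros((4, 3), dtype=np.float32)
updated = strategy(ref, np.ones((3, 3), dtype=np.float32))
unchanged = strategy(ref, np.empty((0, 3), dtype=np.float32))
```

Raising on a feature mismatch is one possible design choice; a strategy could instead skip the offending batch, depending on how strict the pipeline needs to be.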

Examples

Creating a custom update strategy that keeps a moving average:

>>> import numpy as np
>>> from numpy.typing import NDArray
>>>
>>> class MovingAverageUpdate:
...     '''Update strategy that maintains a moving average of reference data.'''
...
...     def __init__(self, n: int, alpha: float = 0.9) -> None:
...         '''
...         Parameters
...         ----------
...         n : int
...             Maximum number of samples to maintain
...         alpha : float, default 0.9
...             Exponential moving average weight (0 < alpha < 1)
...         '''
...         self.n = n
...         self.alpha = alpha
...
...     def __call__(self, reference_data: NDArray[np.float32], data: NDArray[np.float32]) -> NDArray[np.float32]:
...         '''
...         Update reference data with exponential moving average.
...
...         Parameters
...         ----------
...         reference_data : NDArray[np.float32]
...             Current reference data of shape (n_ref, n_features)
...         data : NDArray[np.float32]
...             New observations of shape (n_new, n_features)
...
...         Returns
...         -------
...         NDArray[np.float32]
...             Updated reference data of shape (n_updated, n_features)
...         '''
...         data_flat = np.atleast_2d(np.asarray(data, dtype=np.float32))
...         # Copy so the caller's reference array is not modified in place
...         updated = np.array(reference_data, dtype=np.float32)
...         # Blend overlapping rows: alpha * old + (1 - alpha) * new
...         n_overlap = min(len(updated), len(data_flat))
...         if n_overlap > 0:
...             updated[:n_overlap] = (
...                 self.alpha * updated[:n_overlap] + (1 - self.alpha) * data_flat[:n_overlap]
...             )
...         # Append any new samples beyond the overlap, then keep the last n rows
...         result = np.concatenate([updated, data_flat[n_overlap:]], axis=0)
...         return result[-self.n :]
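To make the arithmetic of the blending step concrete, here is a small worked sketch. `ema_blend` is a hypothetical helper that mirrors the overlap logic of MovingAverageUpdate above on a copy of the reference array; it is meant only to show the numbers, not to replace the class.

```python
import numpy as np


def ema_blend(reference, new, alpha=0.9):
    """Hypothetical helper: blend overlapping rows as alpha * old + (1 - alpha) * new."""
    reference = np.asarray(reference, dtype=np.float32)
    new = np.atleast_2d(np.asarray(new, dtype=np.float32))
    k = min(len(reference), len(new))  # number of rows present in both arrays
    blended = reference.copy()
    blended[:k] = alpha * reference[:k] + (1 - alpha) * new[:k]
    # Rows of `new` beyond the overlap are appended unchanged
    return np.concatenate([blended, new[k:]], axis=0)


ref = np.zeros((2, 1), dtype=np.float32)
new = np.full((3, 1), 10.0, dtype=np.float32)

out = ema_blend(ref, new)  # two blended rows (0.9 * 0 + 0.1 * 10), then one appended row
capped = out[-2:]          # keeping the last n=2 rows drops the oldest blended row
```

Note that capping with a trailing slice discards rows from the front, so with a small n the blended rows are the first to be dropped.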

Using a custom update strategy with a drift detector:

>>> from dataeval.shift import DriftUnivariate
>>> import numpy as np
>>>
>>> # Create reference data
>>> ref_data = np.random.normal(0, 1, (100, 10))
>>>
>>> # Initialize drift detector with custom update strategy
>>> update_strategy = MovingAverageUpdate(n=100, alpha=0.9)
>>> detector = DriftUnivariate(method="ks", update_strategy=update_strategy).fit(ref_data)
>>>
>>> # Detect drift on new data - reference will be updated automatically
>>> new_data = np.random.normal(0.5, 1, (50, 10))
>>> result = detector.predict(new_data)