Source code for pmrf.algorithms.anomaly

import jax.numpy as jnp
import os
import numpy as np
import jax.numpy as jnp
from jax import vmap
import matplotlib.pyplot as plt

def get_anomaly_mask(Y, threshold):
    """Return a boolean mask marking the rows of ``Y`` that are anomalous.

    A row counts as an anomaly when the absolute second difference of its
    values (a discrete curvature estimate) exceeds ``threshold`` at any
    position.

    Args:
        Y: Array-like whose leading axis is mapped over with ``jax.vmap``;
            each slice along that axis is treated as one signal (presumably
            a 2-D array with one signal per row — confirm with callers).
        threshold: Either a scalar applied at every position, or an array
            with one entry per sample of a row. In the array case the first
            and last entries are discarded so the remaining values line up
            with the interior points produced by the second difference.

    Returns:
        Boolean JAX array with one entry per row; ``True`` means the row's
        curvature exceeds the threshold somewhere.
    """

    def _row_has_spike(row, limit):
        if jnp.ndim(limit) > 0:
            # jnp.diff(n=2) yields len(row) - 2 values centred on the interior
            # samples, so drop one threshold entry from each end to match.
            limit = limit[1:-1]
        return jnp.any(jnp.abs(jnp.diff(row, n=2)) > limit)

    # Rows are mapped; the threshold is shared (not mapped) across rows.
    return vmap(_row_has_spike, in_axes=(0, None))(Y, threshold)
def has_sudden_changes(values, *, upwards=True, downwards=True, window=5,
                       iqr_factor=1.5, relative_epsilon=0.05):
    """Return True if any value jumps outside an IQR fence of its recent history.

    For every index ``i >= window``, the ``window`` values immediately before
    it (``values[i - window:i]``) form its history. A Tukey-style fence is
    built from that history:

    * ``Q1`` / ``Q3`` are the history's 25th / 75th percentiles;
    * the spread is ``max(Q3 - Q1, |median| * relative_epsilon)``, so a flat
      history does not give a zero-width fence unless its median is zero;
    * the fence is ``[Q1 - spread * iqr_factor, Q3 + spread * iqr_factor]``.

    ``values[i]`` is a sudden change if it lies strictly above the upper
    fence (checked only when ``upwards``) or strictly below the lower fence
    (checked only when ``downwards``).

    Args:
        values: Sequence of numbers (assumed 1-D).
        upwards: Whether jumps above the fence count.
        downwards: Whether jumps below the fence count.
        window: Number of preceding values forming each history; must be >= 1.
        iqr_factor: Multiplier applied to the spread when building the fence.
        relative_epsilon: Fraction of ``|median|`` used as the minimum spread.

    Returns:
        A Python ``bool``. It is ``False`` when ``len(values) <= window`` or
        when both ``upwards`` and ``downwards`` are disabled.

    Raises:
        ValueError: If ``window < 1`` (every history would be empty).
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    values = jnp.asarray(values)
    n_targets = len(values) - window
    if n_targets <= 0 or not (upwards or downwards):
        return False

    # Build every history at once: row k is values[k:k + window], the window
    # preceding target values[k + window]. One batched computation replaces
    # a Python loop that synchronised with the device at every index.
    offsets = jnp.arange(n_targets)[:, None] + jnp.arange(window)[None, :]
    histories = values[offsets]
    targets = values[window:]

    q1 = jnp.percentile(histories, 25, axis=1)
    q3 = jnp.percentile(histories, 75, axis=1)
    min_iqr_floor = jnp.abs(jnp.median(histories, axis=1)) * relative_epsilon
    effective_iqr = jnp.maximum(q3 - q1, min_iqr_floor)

    flagged = jnp.zeros(n_targets, dtype=bool)
    if upwards:
        flagged = flagged | (targets > q3 + effective_iqr * iqr_factor)
    if downwards:
        flagged = flagged | (targets < q1 - effective_iqr * iqr_factor)
    return bool(jnp.any(flagged))