Source code for pmrf.runner

"""
The base runner class for running an algorithm such as fitting or sampling.
"""
from typing import Any
import io
from abc import ABC, abstractmethod
import logging

import jax
import jax.numpy as jnp
import jsonpickle

from pmrf.models.model import Model
from pmrf.frequency import Frequency
from pmrf.constants import FeatureSpecT
from pmrf.features import extract_features
from pmrf.util import RANK, LevelFilteredLogger


class BaseRunner(ABC):
    r"""
    The unified base class for all ParamRF runners (fitters and samplers).

    This class manages the core model, the frequency bands, and the
    specifications for feature extraction. It also handles the lazy
    compilation of JAX-traced functions for evaluating probabilities and
    model features.

    .. rubric:: Main methods

    .. autosummary::
        :nosignatures:

        cdf
        icdf
        log_prior
        model_features
        run
        write_results
        read_results

    Parameters
    ----------
    model : Model
        The base ParamRF model containing the free parameters to be evaluated.
    frequency : Frequency, optional
        The frequency definitions over which the model should be evaluated.
    features : FeatureSpecT, optional
        The target features to extract from the model's simulated network.
    **feature_kwargs
        Additional keyword arguments passed directly to the underlying
        ``extract_features`` routine.

    Raises
    ------
    ValueError
        If the initialized ``model`` contains no free parameters.
    """

    def __init__(
        self,
        model: Model,
        *,
        frequency: Frequency | None = None,
        features: FeatureSpecT | None = None,
        **feature_kwargs,
    ):
        if model.num_flat_params == 0:
            raise ValueError("Model has no free parameters.")

        self.model = model
        self.frequency = frequency
        self.features = features
        self.feature_kwargs = feature_kwargs
        self.output_path = None
        self.output_root = None

        # Only rank 0 gets a fully verbose logger; other ranks are filtered
        # so multi-process runs do not emit duplicated log lines.
        if RANK == 0:
            self.logger = logging.getLogger(f"pmrf.{self.__class__.__name__}")
        else:
            self.logger = LevelFilteredLogger(null_level=logging.WARNING)

        # Lazily-compiled JAX callables; each is built on first use.
        self._cdf_fn = None
        self._icdf_fn = None
        self._log_prior_fn = None
        self._feature_fn = None

    def cdf(self, theta: jnp.ndarray) -> jnp.ndarray:
        r"""
        Evaluate the cumulative distribution function (CDF) for the model
        parameters.

        This computes the probability $P(X \leq \theta)$ and lazily compiles
        the evaluation graph using ``jax.jit`` on the first call.

        Parameters
        ----------
        theta : jax.numpy.ndarray
            The parameter values to evaluate.

        Returns
        -------
        jax.numpy.ndarray
            The evaluated CDF probabilities, bounded between $0$ and $1$.
        """
        if self._cdf_fn is None:
            # Consistent with log_prior: announce the one-time compilation.
            self.logger.debug("Lazily compiling model CDF graph...")
            self._cdf_fn = jax.jit(self.model.distribution().cdf)
        return self._cdf_fn(jnp.array(theta))

    def icdf(self, u: jnp.ndarray) -> jnp.ndarray:
        r"""
        Evaluate the inverse cumulative distribution function (ICDF), or
        quantile function.

        This computes $\theta = F^{-1}(u)$ and lazily compiles the evaluation
        graph using ``jax.jit`` on the first call.

        Parameters
        ----------
        u : jax.numpy.ndarray
            The probability values to evaluate, typically uniformly
            distributed between $0$ and $1$.

        Returns
        -------
        jax.numpy.ndarray
            The corresponding parameter values $\theta$.
        """
        if self._icdf_fn is None:
            # Consistent with log_prior: announce the one-time compilation.
            self.logger.debug("Lazily compiling model ICDF graph...")
            self._icdf_fn = jax.jit(self.model.distribution().icdf)
        return self._icdf_fn(jnp.array(u))

    def log_prior(self, theta: jnp.ndarray) -> float | jnp.ndarray:
        r"""
        Evaluate the log-prior probability $\log P(\theta)$ of the model
        parameters.

        The prior is calculated by summing the log-probabilities of the flat
        parameter vector. The JAX graph is lazily compiled on the first call.

        Parameters
        ----------
        theta : jax.numpy.ndarray
            The parameter values to evaluate.

        Returns
        -------
        float or jax.numpy.ndarray
            The scalar log-prior probability.
        """
        if self._log_prior_fn is None:
            self.logger.debug("Lazily compiling model log-prior graph...")
            model_dist = self.model.distribution()

            @jax.jit
            def log_prior_fn(t):
                # Sum the log_probs of the flat parameter vector
                return jnp.sum(model_dist.log_prob(t))

            self._log_prior_fn = log_prior_fn
        return self._log_prior_fn(jnp.array(theta))

    def model_features(self, theta: jnp.ndarray) -> jnp.ndarray:
        r"""
        Extract the RF features from the model for a given set of parameters.

        This function maps the parameters into the model, simulates it over
        the defined frequency band, and extracts the target specifications.
        The entire extraction pipeline is vectorized over the batch dimension
        and lazily compiled via ``jax.jit(jax.vmap(...))``.

        Parameters
        ----------
        theta : jax.numpy.ndarray
            A 1D array of a single parameter set, or a 2D array representing
            a batch of parameters. Array-likes (e.g. lists) are accepted and
            converted internally.

        Returns
        -------
        jax.numpy.ndarray
            The extracted model features. Matches the batch dimension of
            ``theta``.

        Raises
        ------
        RuntimeError
            If ``frequency`` or ``features`` were not provided during
            initialization.
        """
        if self._feature_fn is None:
            if self.frequency is None or self.features is None:
                raise RuntimeError("Cannot compile feature function: frequency or features not set.")

            params_tree, static_tree = self.model.partition()

            # NOTE(review): imported lazily (not at module top), presumably to
            # keep module import light — confirm before hoisting.
            from jax import flatten_util
            import equinox as eqx

            _, unravel_fn = flatten_util.ravel_pytree(params_tree)

            def _single_feature_fn(flat_params):
                # Rebuild the model from one flat parameter vector, then
                # extract its features over the configured frequency band.
                params_tree_recon = unravel_fn(flat_params)
                m = eqx.combine(params_tree_recon, static_tree)
                return extract_features(
                    m, self.frequency, self.features, **self.feature_kwargs
                )

            self._feature_fn = jax.jit(jax.vmap(_single_feature_fn))

        # Convert exactly once: the original read ``theta.ndim`` on the raw
        # argument, which raised AttributeError for plain Python sequences.
        theta_arr = jnp.asarray(theta)
        features_2d = self._feature_fn(jnp.atleast_2d(theta_arr))
        # Squeeze the batch axis back out for a single parameter set.
        return features_2d if theta_arr.ndim == 2 else features_2d[0]

    @abstractmethod
    def run(self, *args, **kwargs):
        r"""
        Execute the primary algorithm for the runner.

        Must be implemented by subclasses.
        """
        pass

    @staticmethod
    def write_results(stream: io.BytesIO, results: Any):
        r"""
        Encode backend optimization or sampling results into a bytes stream.

        Parameters
        ----------
        stream : io.BytesIO
            The open byte stream to write to.
        results : Any
            The data structure containing the backend results.

        Notes
        -----
        The default implementation uses ``jsonpickle`` for robust,
        research-grade serialization of complex Python objects.
        """
        pickle_str = jsonpickle.encode(results)
        stream.write(pickle_str.encode('utf-8'))

    @staticmethod
    def read_results(stream: io.BytesIO) -> Any:
        r"""
        Reconstruct backend optimization or sampling results from a bytes
        stream.

        Parameters
        ----------
        stream : io.BytesIO
            The open byte stream to read from.

        Returns
        -------
        Any
            The deserialized Python objects.
        """
        pickle_str = stream.read().decode('utf-8')
        return jsonpickle.decode(pickle_str)