Source code for pmrf.fitting.frequentist

from dataclasses import dataclass

import numpy as np
import jax
import jax.numpy as jnp
import equinox as eqx

from pmrf.functions import l2_norm_ax0, mag_2_db, conv_inter
from pmrf.models.model import Model
from pmrf.constants import ArrayFuncT
from pmrf.fitting.base import BaseFitter, FitResults, FitContext

# Default cost pipelines. Each entry is a list of functions that is applied,
# in order, to the feature error (measured - model) when FrequentistFitter
# composes it into an ``eqx.nn.Sequential`` of ``eqx.nn.Lambda`` layers.
L2_COST = [l2_norm_ax0] * 2 + [mag_2_db]
CONVOLUTIONAL_COST = [l2_norm_ax0, conv_inter, *L2_COST[1:]]

class FrequentistResults(FitResults):
    """
    Results produced by a Frequentist (classical) fitting run.

    This is a marker subclass of `FitResults`. It adds no fields or behaviour
    of its own and currently has exactly the same structure as its base class.
    """
@dataclass
class FrequentistContext(FitContext):
    """
    Context object for Frequentist fitting, containing the cost function.

    Attributes
    ----------
    cost_function : eqx.Module or None
        Module mapping the feature error (``measured_features - model_features``)
        to the cost. `FrequentistFitter` typically builds it by composing a
        sequence of functions.
    """

    cost_function: eqx.Module | None = None

    def make_cost_function(self, as_numpy=False):
        """
        Create the cost function to be minimized.

        The returned function evaluates the model features for a flat
        parameter vector, subtracts them from the measured features, passes
        the error through `cost_function` and returns a scalar cost.

        Parameters
        ----------
        as_numpy : bool, optional, default=False
            If True, return a wrapper that accepts array-likes (e.g. NumPy
            arrays), converts them with ``jnp.array`` and returns a Python
            ``float``. Otherwise return the JIT-compiled JAX function itself.

        Returns
        -------
        callable
            Cost function taking flat parameters and returning a scalar cost.
            It is evaluated once at the model's current parameter values
            before being returned, so JIT compilation happens here.
        """
        x0_jax = self.model.flat_param_values()
        feature_fn_jax = self.make_feature_function()

        @jax.jit
        def jax_cost_fn(flat_params) -> jnp.ndarray:
            model_features = feature_fn_jax(flat_params)
            error = self.measured_features - model_features
            # Evaluate the cost pipeline exactly once per call.
            cost_val = self.cost_function(error)
            # ndim is static during tracing, so this branch is resolved at
            # trace time rather than inside the compiled computation.
            if jnp.ndim(cost_val) == 0:
                return cost_val
            # Non-scalar result: use its first element as the scalar cost.
            return cost_val[0]

        # Define the starting point on both paths. Previously x0 was only
        # assigned in the NumPy branch.
        if as_numpy:
            def cost_fn(flat_params) -> float:
                return float(jax_cost_fn(jnp.array(flat_params)))

            x0 = np.array(x0_jax)
        else:
            cost_fn = jax_cost_fn
            x0 = x0_jax

        # Warm-up call: triggers tracing/compilation before optimization starts.
        self.logger.info("Compiling cost function...")
        cost_fn(x0)
        return cost_fn

    def bounds(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Retrieve the lower and upper bounds for all model parameters.

        Bounds are collected from the model's parameter groups and returned
        in the order of ``self.model.flat_param_names()``.

        Returns
        -------
        tuple of jnp.ndarray
            A tuple containing (lower_bounds, upper_bounds).

        Raises
        ------
        ValueError
            If any parameter is not associated with a parameter group. This
            is a subclass of ``Exception``, so existing handlers still catch it.
        """
        param_names = self.model.flat_param_names()
        name_to_minimum = dict.fromkeys(param_names)
        name_to_maximum = dict.fromkeys(param_names)

        for param_group in self.model.param_groups():
            group_minimums, group_maximums = param_group.min, param_group.max
            for i, name in enumerate(param_group.parameter_names):
                # NOTE(review): a name present in a group but absent from
                # flat_param_names() would be appended here, lengthening the
                # bounds arrays. Presumably this never happens; TODO confirm.
                name_to_minimum[name] = group_minimums[i]
                name_to_maximum[name] = group_maximums[i]

        missing = [
            name
            for name in name_to_minimum
            if name_to_minimum[name] is None or name_to_maximum[name] is None
        ]
        if missing:
            raise ValueError(
                f"Parameters found that did not belong to any parameter group: {missing}"
            )

        return (
            jnp.array(list(name_to_minimum.values())),
            jnp.array(list(name_to_maximum.values())),
        )
class FrequentistFitter(BaseFitter):
    """
    A base class for frequentist (classical) optimization methods.

    This class extends `BaseFitter` by adding the concept of a cost function,
    which takes the difference between model features and measured features
    and computes a single scalar value representing the "cost" or "error".
    """

    def __init__(
        self,
        model: Model,
        *,
        cost_kind: str | None = None,
        cost_function: ArrayFuncT | list[ArrayFuncT] | eqx.Module | None = None,
        **kwargs
    ) -> None:
        """
        Initializes the FrequentistFitter.

        Parameters
        ----------
        model : Model
            The parametric `pmrf` Model to be fitted.
        cost_kind : str, optional
            A cost 'kind' alias to initialize the features and cost function
            from. Can be one of 'convolutional', 'complex', 'magnitude' or
            None. ``None`` falls back to 'convolutional' when the fitting
            context is created.
        cost_function : ArrayFuncT, list[ArrayFuncT] or eqx.Module, optional
            A function or sequence of functions defining the cost metric.
            A list (or tuple) of functions is composed sequentially. An
            ``eqx.Module`` is used as-is. If `None`, then `cost_kind` defines
            the cost function. Defaults to `None`.
        **kwargs
            Additional arguments forwarded to :class:`BaseFitter`.
        """
        self.cost_kind = cost_kind
        self.cost_function = cost_function
        super().__init__(model, **kwargs)

    def _create_context(self, measured, *, cost_kind=None, cost_function=None, **kwargs) -> FrequentistContext:
        """
        Create a FrequentistContext for the fitting process.

        Parameters
        ----------
        measured : skrf.Network or NetworkCollection
            The measured data.
        cost_kind : str, optional
            Override for the cost kind alias.
        cost_function : callable, sequence of callables or eqx.Module, optional
            Override for the cost function.
        **kwargs
            Additional context arguments (e.g., `features`).

        Returns
        -------
        FrequentistContext
            The configured Frequentist context.

        Raises
        ------
        ValueError
            If an unknown `cost_kind` alias is provided. This is a subclass of
            ``Exception``, so existing handlers still catch it.
        """
        features = kwargs.pop('features', None) or self.features
        # Call-site override > constructor value > 'convolutional' default.
        cost_kind = cost_kind or self.cost_kind or 'convolutional'
        cost_function = cost_function or self.cost_function

        default_features, default_cost = self._cost_kind_defaults(cost_kind)
        if features is None:
            features = default_features
        if cost_function is None:
            cost_function = default_cost

        base_ctx = super()._create_context(measured, features=features, **kwargs)

        return FrequentistContext(
            model=base_ctx.model,
            measured=base_ctx.measured,
            frequency=base_ctx.frequency,
            features=base_ctx.features,
            measured_features=base_ctx.measured_features,
            output_path=base_ctx.output_path,
            output_root=base_ctx.output_root,
            sparam_kind=base_ctx.sparam_kind,
            cost_function=self._as_cost_module(cost_function),
            logger=self.logger,
        )

    @staticmethod
    def _cost_kind_defaults(cost_kind):
        """Return ``(default_features, default_cost)`` for a cost-kind alias."""
        # Fresh feature lists on every call so callers may mutate them safely.
        if cost_kind == 'convolutional':
            return ['s', 's_mag'], CONVOLUTIONAL_COST
        if cost_kind == 'complex':
            return ['s'], L2_COST
        if cost_kind == 'magnitude':
            return ['s_mag'], L2_COST
        raise ValueError(
            f"Unknown cost kind alias passed to frequentist fitter: {cost_kind!r}"
        )

    @staticmethod
    def _as_cost_module(cost_function):
        """Convert a callable, a sequence of callables or an eqx.Module into an eqx.Module."""
        # Check for a Module first. Previously a Module was wrapped into a list
        # before this check, so it always ended up inside Sequential([Lambda(...)]).
        if isinstance(cost_function, eqx.Module):
            return cost_function
        if not isinstance(cost_function, (list, tuple)):
            cost_function = [cost_function]
        return eqx.nn.Sequential([eqx.nn.Lambda(fn) for fn in cost_function])