Source code for pmrf.fitting._frequentist

import numpy as np
import skrf
import jax
import jax.numpy as jnp
import equinox as eqx

from pmrf.functions import l2_norm_ax0, mag_2_db
from pmrf.models.model import Model
from pmrf.parameters import Parameter, ParameterGroup
from pmrf.constants import FeatureInputT, ArrayFuncT

from pmrf.fitting._base import BaseFitter, FitResults

class FrequentistResults(FitResults):
    """Results of a frequentist fit.

    Adds no fields or behavior of its own; everything is inherited from
    `FitResults`.
    """
class FrequentistFitter(BaseFitter):
    """
    A base class for frequentist (classical) optimization methods.

    This class extends `BaseFitter` by adding the concept of a `cost_fn`,
    a function that takes the difference between model features and measured
    features and computes a single scalar value representing the "cost" or
    "error".
    """

    def __init__(
        self,
        model: Model,
        measured: skrf.Network | dict[str, skrf.Network],
        frequency: skrf.Frequency | None = None,
        features: FeatureInputT | None = None,
        cost: ArrayFuncT | list[ArrayFuncT] | eqx.Module | None = None,
        *args,
        **kwargs
    ) -> None:
        """Initializes the FrequentistFitter.

        Args:
            model (Model): The parametric `pmrf` model to be fitted.
            measured (skrf.Network | dict[str, skrf.Network]): The measured
                network data to fit the model against.
            frequency (skrf.Frequency | None, optional): The frequency axis
                to perform the fit on. Defaults to `None`.
            features (FeatureInputT | None, optional): The features to
                extract for comparison. Defaults to `None`.
            cost (ArrayFuncT | list[ArrayFuncT] | eqx.Module | None, optional):
                A function, a list/tuple of functions, or a ready-made
                `eqx.Module` defining the cost metric. A sequence of
                functions is chained with `eqx.nn.Sequential`; an
                `eqx.Module` is used directly. If `None`, a default chain is
                built from `l2_norm_ax0` and `mag_2_db`, with an extra
                `l2_norm_ax0` stage when more than one feature is configured.
                Defaults to `None`.
            *args: Forwarded to `BaseFitter.__init__`.
            **kwargs: Forwarded to `BaseFitter.__init__`.
        """
        super().__init__(model=model, measured=measured, frequency=frequency, features=features, *args, **kwargs)

        if cost is None:
            # Default metric depends on how many features the base class resolved.
            if len(self.features) > 1:
                cost = [l2_norm_ax0, l2_norm_ax0, mag_2_db]
            else:
                cost = [l2_norm_ax0, mag_2_db]

        # Check for a module *before* normalizing to a list; otherwise a
        # supplied module would always be wrapped and this branch never taken.
        if isinstance(cost, eqx.Module):
            self.cost_metric_fn = cost
        else:
            if not isinstance(cost, (list, tuple)):
                cost = [cost]
            self.cost_metric_fn = eqx.nn.Sequential([eqx.nn.Lambda(fn) for fn in cost])

    def _make_cost_function(self, as_numpy=False):
        """Build, warm up and return the scalar cost function.

        The returned callable maps a flat parameter vector to the cost of the
        difference between `self.measured_features` and the model features
        produced by `self._make_feature_function()`.

        Args:
            as_numpy (bool, optional): If `True`, the returned callable
                converts its input with `jnp.array` and returns a Python
                `float`. If `False`, the jitted JAX function is returned
                as is. Defaults to `False`.

        Returns:
            The cost function, already called once on the initial parameters
            so that JIT compilation happens here rather than on first use.
        """
        x0_jax = self.initial_model.flat_params()
        feature_fn_jax = self._make_feature_function()

        # Define the JAX cost function to be minimized
        @jax.jit
        def jax_cost_fn(flat_params) -> jnp.ndarray:
            model_features = feature_fn_jax(flat_params)
            error = self.measured_features - model_features
            cost_val = self.cost_metric_fn(error)
            # Reuse the computed value for the scalar test instead of
            # evaluating the metric a second time.
            if jnp.isscalar(cost_val):
                return cost_val
            return cost_val[0]

        if as_numpy:
            def cost_fn(x):
                return float(jax_cost_fn(jnp.array(x)))

            x0 = np.array(x0_jax)
        else:
            # Bind the warm-up input on this path too, so the call below
            # never references an unbound name.
            cost_fn = jax_cost_fn
            x0 = x0_jax

        self.logger.info("Compiling cost function...")
        # Result is discarded; the call only triggers compilation.
        _c0 = cost_fn(x0)

        return cost_fn

    def _bounds(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Collect per-parameter lower and upper bounds.

        Bounds are read from each parameter group's `min`/`max` and keyed by
        the group's `param_names`.

        Returns:
            A `(minimums, maximums)` pair of arrays. Entries follow the order
            of `self.initial_model.flat_param_names()`.

        Raises:
            Exception: If any flat parameter name is left without a minimum
                or maximum after all groups have been visited.
        """
        param_groups = self.initial_model.param_groups()
        param_names = self.initial_model.flat_param_names()

        # Seed with None so parameters not covered by any group can be detected.
        name_to_minimum = dict.fromkeys(param_names)
        name_to_maximum = dict.fromkeys(param_names)

        for param_group in param_groups:
            group_minimums, group_maximums = param_group.min, param_group.max
            for i, name in enumerate(param_group.param_names):
                name_to_minimum[name] = group_minimums[i]
                name_to_maximum[name] = group_maximums[i]

        minimums = list(name_to_minimum.values())
        maximums = list(name_to_maximum.values())
        if any(value is None for value in minimums) or any(value is None for value in maximums):
            raise Exception('Parameter found that did not belong to a parameter groups')

        return jnp.array(minimums), jnp.array(maximums)