Source code for pmrf.fitting.base

import glob
from abc import ABC, abstractmethod
from dataclasses import dataclass
import importlib
import logging
from typing import Any, Sequence, Callable
import os
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import jax
import json
import skrf
import h5py
import jsonpickle
from skrf import Network

from pmrf.models import Model
from pmrf.frequency import Frequency
from pmrf.constants import FeatureT
from pmrf._util import LevelFilteredLogger, iter_submodules, load_class_from_string, RANK
from pmrf.frequency import Frequency, MULTIPLIER_DICT
from pmrf.constants import FeatureInputT
from pmrf import extract_features, wrap
from pmrf.network_collection import NetworkCollection

@dataclass
class FitSettings:
    """
    Record of how a fit was configured.

    Every field is optional, so ``FitSettings()`` is a valid (empty) record.
    Field order is significant: the HDF5 reader in this module constructs
    the class positionally as ``FitSettings(frequency, features,
    fitter_kwargs, solver_kwargs)``.

    Attributes
    ----------
    frequency : Frequency or None
        Frequency grid on which the model was evaluated during the fit.
    features : list of FeatureT or None
        Features that were extracted and matched during the fit.
    fitter_kwargs : dict or None
        Keyword arguments handed to the fitter backend.
    solver_kwargs : dict or None
        Keyword arguments handed to the numerical solver.
    """

    # Evaluation grid used for the fit.
    frequency: Frequency | None = None
    # Features compared between model and measurement.
    features: list[FeatureT] | None = None
    # Backend-level configuration.
    fitter_kwargs: dict | None = None
    # Solver-level configuration.
    solver_kwargs: dict | None = None

@dataclass
class FitContext:
    """
    Everything a fitter backend needs in order to execute one fit.

    Instances are normally produced by ``BaseFitter._create_context`` and
    passed to ``BaseFitter._run_context`` / ``_run_algorithm``.

    Attributes
    ----------
    model : Model
        The parametric model being fitted.
    measured : skrf.Network or NetworkCollection
        The target measured data.
    frequency : Frequency
        The frequency grid on which model and measurement are compared.
    features : list of FeatureT
        The features to match against.
    measured_features : np.ndarray
        Feature values extracted from ``measured``.
    output_path : str or None
        Directory for output files.
    output_root : str or None
        Root filename prefix for outputs.
    sparam_kind : str or None
        S-parameter kind used for port expansion (e.g. 'all', 'transmission').
    logger : logging.Logger or None
        Logger used to report progress.
    """

    model: Model
    measured: skrf.Network | NetworkCollection
    frequency: Frequency
    features: list[FeatureT]
    measured_features: np.ndarray
    output_path: str | None = None
    output_root: str | None = None
    sparam_kind: str | None = None
    logger: logging.Logger | None = None

    def model_param_names(self) -> list[str]:
        """
        Names of the model's flat parameters.

        Returns
        -------
        list of str
            Whatever ``self.model.flat_param_names()`` reports.
        """
        names = self.model.flat_param_names()
        return names

    def make_feature_function(self, as_numpy=False):
        """
        Build a ``jax.jit``-compiled map from parameter vector to feature values.

        Parameters
        ----------
        as_numpy : bool, default=False
            Forwarded to ``wrap``; if True the wrapped function works with
            NumPy arrays, otherwise with JAX arrays.

        Returns
        -------
        callable
            ``f(theta)`` returning the feature values of ``self.features``
            evaluated on ``self.frequency``.
        """
        extractor = wrap(extract_features, self.model, self.frequency, as_numpy=as_numpy)

        def features_of(theta):
            # self.features / self.sparam_kind are looked up when the function
            # is called (i.e. at trace time), not when it is created.
            return extractor(theta, self.features, sparam_kind=self.sparam_kind)

        return jax.jit(features_of)

    def settings(self, solver_kwargs=None, fitter_kwargs=None) -> FitSettings:
        """
        Snapshot this context as a :class:`FitSettings` record.

        Parameters
        ----------
        solver_kwargs : dict, optional
            Solver-specific arguments to record.
        fitter_kwargs : dict, optional
            Fitter-specific arguments to record.

        Returns
        -------
        FitSettings
            Settings carrying this context's frequency and features.
        """
        return FitSettings(
            frequency=self.frequency,
            features=self.features,
            fitter_kwargs=fitter_kwargs,
            solver_kwargs=solver_kwargs,
        )
class BaseFitter(ABC):
    """
    An abstract base class that provides the foundation for all fitting algorithms in `pmrf`.

    Concrete subclasses implement :meth:`_run_algorithm`. This base class
    builds the :class:`FitContext` (loading data, aligning frequencies and
    extracting measured features), drives single and sequential submodel
    fits, and optionally saves and plots the results.
    """

    def __init__(
        self,
        model: Model,
        *,
        features: FeatureInputT | None = None,
        output_path: str | None = None,
        output_root: str | None = None,
        sparam_kind: str = 'all',
    ) -> None:
        """
        Initializes the BaseFitter.

        Parameters
        ----------
        model : Model
            The parametric `pmrf` Model to be fitted.
        features : FeatureInputT or None, optional
            Defines the features to be extracted from the model and network(s).
            Defaults to `None`, in which case real and imaginary features for all ports are used.
            Can be a single feature e.g. 's11', a list of features (e.g., `['s11', 's11_mag']`),
            or a dictionary with either of the above as value. In the dictionary case,
            keys must be network names in the collection passed by `measured` during fitting,
            which must also correspond to submodels which are attributes of the model.
            For example, ``{'name1': ['s11'], 'name2': ['s21']}`` can be passed.
            Note that if a collection of networks is passed, but a feature dictionary is not,
            it is assumed that those feature(s) should be extracted for each network/submodel.
            See `extract_features(..)` for more details.
        output_path : str or None
            The path for fitters to write output data to. Defaults to `None`.
        output_root : str or None
            The root name to prepend (with an underscore) to output files in the output path.
            Defaults to `None`.
        sparam_kind : str or None
            The S-parameter data kind to use for port-expansion in feature extraction.
            Can either be 'transmission', 'reflection' or 'all'.
            See `extract_features` for more details.
        """
        self.model: Model = model
        self.features: FeatureInputT | None = features
        self.output_path = output_path
        self.output_root = output_root
        self.sparam_kind = sparam_kind

        # Rank 0 logs through the package logger. Other ranks (RANK is
        # presumably the MPI process rank — confirm in pmrf._util) get a
        # LevelFilteredLogger with null_level=WARNING, presumably to silence
        # duplicate lower-level messages.
        if RANK == 0:
            self.logger = logging.getLogger("pmrf.fitting")
        else:
            self.logger = LevelFilteredLogger(null_level=logging.WARNING)

    def fit(
        self,
        measured: str | Network | NetworkCollection,
        **kwargs
    ) -> 'FitResults':
        """
        Fits the model to measured data.

        This method fits the full model using the original features specified.
        Arguments are forwarded to ``self._run_context(...)``.

        Parameters
        ----------
        measured : str, skrf.Network or prf.NetworkCollection
            The measured network data to fit the model against. Can be a file
            path (loaded with ``skrf.Network``), a scikit-rf `Network` or a
            paramrf `NetworkCollection`. For the latter case the network names
            should be referenced during feature extraction by specifying
            features as a dictionary. If networks do not have the same
            frequency, a common frequency is used.
        **kwargs
            Additional arguments forwarded to the underlying algorithm via ``self._run_context``.

        Returns
        -------
        FitResults
            The fit results object.
        """
        if isinstance(measured, str):
            measured = skrf.Network(measured)
        ctx = self._create_context(measured)
        return self._run_context(ctx, **kwargs)

    def fit_submodels(
        self,
        measured: NetworkCollection,
        **kwargs
    ) -> 'FitResults':
        """
        Fits the submodels.

        This method fits the model to the measured data by fitting its submodels in a sequential manner.
        For each network in ``measured``, the submodel with the same name is freed (all others fixed)
        and fitted against that network alone. Arguments are forwarded to ``self._run_context(...)``.

        Parameters
        ----------
        measured : prf.NetworkCollection
            The measured network data to fit the model against.
            Must be a ParamRF `NetworkCollection`. Network names should be referenced during feature
            extraction by specifying features as a dictionary. If networks do not have the same
            frequency, a common frequency is used.
        **kwargs
            Additional arguments forwarded to the underlying algorithm via ``self._run_context``.
            ``save_model`` / ``save_results`` apply to the combined outputs only; the per-submodel
            runs never save models or results.

        Returns
        -------
        FitResults
            The fit results object. `solver_results` contains a dictionary of the individual submodel results.
            Its ``fitted_model`` carries the combined results in its metadata under ``'fit_results'``.
        """
        all_results: dict[str, FitResults] = {}

        # Per-submodel runs only plot; saving happens once for the combined model below.
        ctx_kwargs = dict(kwargs)
        ctx_kwargs['save_results'] = False
        ctx_kwargs['save_model'] = False
        ctx_kwargs.setdefault('figure_subfolder', '../../figures')

        for ntwk in measured:
            name = ntwk.name
            self.logger.info(f'Fitting {name}...')
            model = self.model.with_free_submodels([name], fix_others=True)
            comp_measured = measured.filter(lambda other, name=name: other.name == name)
            # With no base output path nothing must be written to disk; building the
            # f-string unconditionally would yield a literal 'None/submodels/...' path.
            output_path = f'{self.output_path}/submodels/{name}' if self.output_path is not None else None
            ctx = self._create_context(comp_measured, model=model, output_path=output_path, output_root=name)
            all_results[name] = self._run_context(ctx, **ctx_kwargs)

        fitted_model = self.model.with_models([result.fitted_model for result in all_results.values()])
        fit_results = FitResults(
            initial_model=self.model,
            fitted_model=fitted_model,
            solver_results=all_results,
        )
        fitted_model = self._with_fit_results(fitted_model, fit_results)

        if RANK == 0 and self.output_path is not None:
            save_model = kwargs.get('save_model', True)
            save_results = kwargs.get('save_results', True)
            if save_model or save_results:
                Path(self.output_path).resolve().mkdir(parents=True, exist_ok=True)
            if save_model:
                self.logger.info('Saving combined model...')
                name = fitted_model.name or 'model'
                fitted_model.save(f'{self.output_path}/fitted_{name}.prf')
            if save_results:
                self.logger.info('Saving combined results...')
                fit_results.save_hdf(f'{self.output_path}/results.hdf5')

        # Expose the metadata-carrying model on the returned results, matching
        # _run_context. This is done after saving so the HDF5 file stores the
        # model without the self-referencing 'fit_results' metadata entry.
        fit_results.fitted_model = fitted_model
        return fit_results

    def _create_context(
        self,
        measured,
        *,
        model=None,
        features=None,
        output_path=None,
        output_root=None,
        sparam_kind=None,
    ) -> FitContext:
        """
        Creates a FitContext from the provided measurement and optional overrides.

        Parameters
        ----------
        measured : str, skrf.Network or NetworkCollection
            The measured data (a string is loaded with ``skrf.Network``).
        model : Model, optional
            Model override. Defaults to self.model.
        features : FeatureInputT, optional
            Features override. Defaults to self.features.
        output_path : str, optional
            Output path override. Defaults to self.output_path.
        output_root : str, optional
            Output root override. Defaults to self.output_root.
        sparam_kind : str, optional
            S-parameter kind override. Defaults to self.sparam_kind.

        Returns
        -------
        FitContext
            The initialized context object.
        """
        # Only `None` means "not overridden"; falsy values such as '' or [] are honoured.
        model = self.model if model is None else model
        features = self.features if features is None else features
        sparam_kind = self.sparam_kind if sparam_kind is None else sparam_kind
        output_path = self.output_path if output_path is None else output_path
        output_root = self.output_root if output_root is None else output_root

        # Make sure measured is loaded, and that all frequencies are the same
        if isinstance(measured, str):
            measured = skrf.Network(measured)
        measured = measured.copy()
        if isinstance(measured, NetworkCollection):
            measured.interpolate_self()
            frequency = measured.common_frequency()
        else:
            frequency = measured.frequency
        frequency = Frequency.from_skrf(frequency)

        # Default features: real and imaginary parts of every port pair of the model.
        if features is None:
            features = [
                port_feature
                for m, n in model.port_tuples
                for port_feature in (f's{m+1}{n+1}_re', f's{m+1}{n+1}_im')
            ]
        # Ensure features is not a scalar. A str is itself a Sequence, so it must
        # be special-cased or a single feature such as 's11' would stay unwrapped.
        elif isinstance(features, str) or not isinstance(features, (Sequence, dict)):
            features = [features]

        # Broadcast a plain feature list to every network in a collection.
        if isinstance(measured, NetworkCollection) and not isinstance(features, dict):
            features = {ntwk.name: features for ntwk in measured}

        measured_features = extract_features(measured, None, features, sparam_kind=sparam_kind)

        return FitContext(
            measured=measured,
            model=model,
            frequency=frequency,
            features=features,
            measured_features=measured_features,
            logger=self.logger,
            output_path=output_path,
            output_root=output_root,
            sparam_kind=sparam_kind,
        )

    def _load_previous_results(self, context: FitContext) -> 'FitResults | None':
        """
        Best-effort load of results previously saved in ``context.output_path``.

        The results file that :meth:`_run_context` itself writes
        (``[<output_root>_]results.hdf5``) is preferred; otherwise the first
        ``*.hdf5`` file in sorted order is used.

        Returns
        -------
        FitResults or None
            The loaded results, or ``None`` if no file exists or loading failed.
        """
        candidates = sorted(glob.glob(f"{context.output_path}/*.hdf5"))
        if not candidates:
            return None
        expected_name = f'{context.output_root}_results.hdf5' if context.output_root is not None else 'results.hdf5'
        filename = next((c for c in candidates if Path(c).name == expected_name), candidates[0])
        try:
            results = FitResults.load_hdf(filename)
        except Exception as e:
            # Deliberately non-fatal: fall back to running the fit from scratch.
            self.logger.warning(f"Could not load previous results from '{filename}': {e}")
            return None
        self.logger.info(f"Loaded previous results from '{filename}'.")
        return results

    @staticmethod
    def _with_fit_results(model: Model, results: 'FitResults') -> Model:
        """
        Return ``model`` with ``results`` stored in its metadata under ``'fit_results'``.

        The metadata dict is copied first so that a dict possibly shared with
        other models (e.g. the initial model) is not mutated in place.
        """
        metadata = dict(model.metadata or {})
        metadata['fit_results'] = results
        return model.with_fields(metadata=metadata)

    def _run_context(
        self,
        context: FitContext,
        *,
        load_previous: bool = True,
        new_uniform_frac: float | None = 0.01,
        save_model: bool = True,
        save_results: bool = True,
        plot_s_db: bool = True,
        figure_subfolder: str | None = None,
        callback: Callable[['FitResults'], None] | None = None,
        **kwargs
    ) -> 'FitResults':
        """
        Executes the fitting context. This is a low-level method and should seldom be used directly.

        This method runs the fitting algorithm implemented by the underlying sub-class.
        It contains several convenience parameters, allowing for e.g. automatic saving and plotting of results.
        Additional arguments are forwarded to the underlying algorithm via ``self._run_algorithm``.

        Parameters
        ----------
        context: FitContext
            The fitting context.
        load_previous : bool, default=True
            Whether or not to try and load previous results from the output path.
            Loaded results are returned as-is, without re-running the fit.
        new_uniform_frac : float or None, optional, default=0.01
            The fraction to update model distribution bounds uniformly around the fitted model values.
        save_model : bool, default=True
            Saves the model to the output path (if provided).
        save_results : bool, default=True
            Saves the results to hdf format in the output path (if provided).
        plot_s_db : bool, default=True
            Plots the S-parameters in dB and saves the figure in the output path (if provided).
        figure_subfolder : str or None, optional
            Sub-folder of the output path in which figures are saved. Defaults to the output path itself.
        callback : Callable[[FitResults], None] or None, optional
            A callback to run after fitting but before saving and plotting.
        **kwargs
            Additional arguments forwarded to the underlying fitter via ``self._run_algorithm``.
            They are also recorded as ``results.settings.solver_kwargs``.

        Returns
        -------
        FitResults
            The fit results object. Its ``fitted_model`` carries the results in its metadata
            under ``'fit_results'``.
        """
        # Try load from previous results
        if load_previous and context.output_path is not None:
            previous = self._load_previous_results(context)
            if previous is not None:
                return previous

        # Output fit parameters and features
        self.logger.info(f"Fitting for {context.model.num_flat_params} parameters")
        self.logger.info(f"Parameter names: {context.model.flat_param_names()}")
        self.logger.info(f'Features: {context.features}')

        results = self._run_algorithm(context, **kwargs)
        results.measured = context.measured
        results.initial_model = context.model
        results.settings = context.settings(solver_kwargs=kwargs)

        if new_uniform_frac is not None:
            results.fitted_model = results.fitted_model.with_uniform_distributions(new_uniform_frac)

        if callback is not None:
            callback(results)

        output_path = context.output_path
        save_output = output_path is not None and (save_model or save_results or plot_s_db) and RANK == 0
        if save_output:
            output_prefix = f'{output_path}/{context.output_root}_' if context.output_root is not None else f'{output_path}/'
            fitted_model = results.fitted_model
            # A single Network names itself; a collection contributes its first network's name.
            measured = results.measured
            first_measured = measured if isinstance(measured, skrf.Network) else measured[0]
            name = first_measured.name or fitted_model.name or 'model'

            if save_model:
                Path(output_path).resolve().mkdir(parents=True, exist_ok=True)
                self.logger.info(f'Saving {name} model...')
                fitted_model.save(Path(f'{output_prefix}fitted_{name}.prf').resolve())

            if save_results:
                Path(output_path).resolve().mkdir(parents=True, exist_ok=True)
                self.logger.info(f'Saving {name} results...')
                results.save_hdf(Path(f'{output_prefix}results.hdf5').resolve())

            if plot_s_db:
                self.logger.info(f'Plotting {name} S-parameters in db...')
                figure_path = f'{output_path}/{figure_subfolder}' if figure_subfolder is not None else output_path
                figure_prefix = f'{figure_path}/{context.output_root}_' if context.output_root is not None else f'{figure_path}/'
                Path(figure_path).resolve().mkdir(parents=True, exist_ok=True)
                results.plot_s_db()
                plt.savefig(Path(f'{figure_prefix}s_db.png').resolve(), dpi=400)
                plt.close()

        results.fitted_model = self._with_fit_results(results.fitted_model, results)
        return results

    @abstractmethod
    def _run_algorithm(self, context: FitContext, **kwargs) -> 'FitResults':
        """
        Executes the fitting algorithm. This is a low-level method and should seldom be used directly.

        This method must be implemented by all concrete subclasses. It is the main
        entry point to start the optimization or sampling process.

        Parameters
        ----------
        context : FitContext
            The fitting context.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        FitResults
            The fit results object.
        """
        raise NotImplementedError
@dataclass
class FitResults:
    """
    Container for the results of a model fitting process.

    Attributes
    ----------
    measured : skrf.Network, NetworkCollection, or None
        The original measured data (target).
    initial_model : Model or None
        The model with the initial parameters.
    fitted_model : Model or None
        The model with the fitted parameters.
    solver_results : Any
        The raw result object returned by the optimization backend
        (for sequential submodel fits, a dict of per-submodel FitResults).
    settings : FitSettings or None
        The configuration used to execute the fit.
    """
    measured: skrf.Network | NetworkCollection | None = None
    initial_model: Model | None = None
    fitted_model: Model | None = None
    solver_results: Any = None
    settings: FitSettings | None = None

    # --------------------------------------------------------------------------
    # Plotting
    # --------------------------------------------------------------------------

    def plot_s_db(self, use_initial_model=False):
        """
        Plots the S-parameters (magnitude in dB) of the measured vs. model data.

        One row is drawn per measured network (for a NetworkCollection, the
        matching sub-model is looked up as an attribute of the model) and one
        column per S-parameter; measured data is drawn dashed in black. The
        figure is left as the current matplotlib figure and nothing is returned.

        Parameters
        ----------
        use_initial_model : bool, default=False
            If True, plot ``initial_model`` instead of ``fitted_model``.
        """
        model = self.initial_model if use_initial_model else self.fitted_model
        if self.measured is None or model is None:
            print("Missing measured data or fitted model.")
            return
        if self.settings is None or self.settings.frequency is None:
            print("Missing settings (frequency data) to generate fitted networks.")
            return
        frequency = self.settings.frequency

        # 1. Normalize input into a list of tuples: (Name, Measured_Network, Fitted_Network)
        data_to_plot = []
        if isinstance(self.measured, NetworkCollection):
            for meas_nw in self.measured:
                # Only the attribute lookup is guarded, so AttributeErrors raised
                # inside to_skrf are not misreported as a missing sub-model.
                try:
                    sub_model = getattr(model, meas_nw.name)
                except AttributeError:
                    print(f"Warning: Could not find sub-model attribute '{meas_nw.name}' in fitted_model.")
                    continue
                data_to_plot.append((meas_nw.name, meas_nw, sub_model.to_skrf(frequency)))
        else:
            data_to_plot.append(("Main Model", self.measured, model.to_skrf(frequency)))

        if not data_to_plot:
            return

        # 2. Determine grid dimensions: one row per network, columns for the widest one.
        n_rows = len(data_to_plot)
        max_ports = max(meas.number_of_ports for _, meas, _ in data_to_plot)
        n_cols = max_ports * max_ports

        fig, axes = plt.subplots(
            nrows=n_rows,
            ncols=n_cols,
            figsize=(4 * n_cols, 3.5 * n_rows),
            squeeze=False,
        )

        # 3. Plotting loop
        for row_idx, (label, meas, fit) in enumerate(data_to_plot):
            n_ports = meas.number_of_ports
            plot_col_idx = 0
            for i in range(n_ports):
                for j in range(n_ports):
                    ax = axes[row_idx, plot_col_idx]
                    # Only draw S-parameters that the model network also has.
                    if i < fit.number_of_ports and j < fit.number_of_ports:
                        fit.plot_s_db(m=i, n=j, ax=ax, label="Model")
                        meas.plot_s_db(m=i, n=j, ax=ax, label="Measured", linestyle='--', color='k')
                        ax.set_title(f"{label} - S{i+1}{j+1}")
                        ax.grid(True, which="major", linestyle="-", alpha=0.5)
                        # Keep a single legend per row (first column only).
                        if plot_col_idx == 0:
                            ax.legend(fontsize='small')
                        elif ax.get_legend():
                            ax.get_legend().remove()
                    plot_col_idx += 1
            # Hide unused axes in rows with fewer ports than the widest row.
            for k in range(plot_col_idx, n_cols):
                axes[row_idx, k].axis('off')

        fig.tight_layout()

    # --------------------------------------------------------------------------
    # Public HDF5 IO Interface
    # --------------------------------------------------------------------------

    def save_hdf(self, path: str, metadata: dict | None = None):
        """
        Save the fit results to an HDF5 file (overwriting it).

        Parameters
        ----------
        path : str or path-like
            Destination file.
        metadata : dict or None, optional
            Extra metadata; leaf values are stored JSON-encoded.
        """
        with h5py.File(path, 'w') as f:
            self._write_to_group(f, metadata)

    @classmethod
    def load_hdf(cls, path: str) -> "FitResults":
        """
        Load fit results from an HDF5 file.

        The concrete class is taken from the file's stored ``fit_results_cls``
        entry when present, otherwise ``cls`` is used.
        """
        with h5py.File(path, 'r') as f:
            return cls._read_from_group(f)

    # --------------------------------------------------------------------------
    # Polymorphic Solver Results IO
    # --------------------------------------------------------------------------

    def encode_solver_results(self, group: h5py.Group):
        """
        Encode solver results into an HDF5 group.

        Base implementation uses recursion to handle FitResults, dicts, and lists.
        Other objects are stored as a jsonpickle string; encoding failures are
        logged rather than raised. Subclasses can override this for specific
        formats (e.g. Scipy results).
        """
        if self.solver_results is None:
            return

        # 1. Recursive structures (dict, list, tuple, set, or nested FitResults)
        if isinstance(self.solver_results, (dict, list, tuple, set, FitResults)):
            self._encode_recursive(self.solver_results, group)
            return

        # 2. Fallback for generic objects
        try:
            data = jsonpickle.encode(self.solver_results)
            dt = h5py.special_dtype(vlen=str)
            # Store as 1-element array to avoid scalar dataset issues in some viewers
            group.create_dataset('data', data=np.array([data], dtype=dt))
        except Exception as e:
            logging.error(f"Failed to encode solver results: {e}")

    @classmethod
    def decode_solver_results(cls, group: h5py.Group) -> Any:
        """
        Decode solver results from an HDF5 group.

        Returns ``None`` when nothing decodable is present or decoding fails
        (failures are logged).
        """
        # 1. Recursive structure detected
        if '__type__' in group.attrs:
            return cls._decode_recursive(group)

        # 2. Legacy/generic jsonpickle object detected
        if 'data' in group:
            try:
                return jsonpickle.decode(cls._read_str(group['data']))
            except Exception as e:
                logging.error(f"Failed to decode solver results: {e}")
                return None

        return None

    # --------------------------------------------------------------------------
    # Recursive Encoding Helpers (Protected)
    # --------------------------------------------------------------------------

    def _encode_recursive(self, obj: Any, group: h5py.Group):
        """Write ``obj`` into ``group``, tagging each node with a ``__type__`` attribute."""
        # Case A: Nested FitResults
        if isinstance(obj, FitResults):
            group.attrs['__type__'] = 'FitResults'
            obj._write_to_group(group)
            return

        # Case B: Dictionary (keys are stored as strings)
        if isinstance(obj, dict):
            group.attrs['__type__'] = 'dict'
            for k, v in obj.items():
                self._encode_recursive(v, group.create_group(str(k)))
            return

        # Case C: Iterables, stored under their index
        if isinstance(obj, (list, tuple, set)):
            group.attrs['__type__'] = type(obj).__name__
            for i, item in enumerate(obj):
                self._encode_recursive(item, group.create_group(str(i)))
            return

        # Case D: Leaf node, stored as a jsonpickle string
        group.attrs['__type__'] = 'leaf'
        try:
            data = jsonpickle.encode(obj)
            dt = h5py.special_dtype(vlen=str)
            group.create_dataset('data', data=np.array([data], dtype=dt))
        except Exception as e:
            logging.error(f"Failed to encode leaf: {e}")

    @classmethod
    def _decode_recursive(cls, group: h5py.Group) -> Any:
        """Inverse of :meth:`_encode_recursive`."""
        type_tag = group.attrs.get('__type__')

        if type_tag == 'FitResults':
            # _read_from_group resolves the concrete subclass from stored metadata.
            return FitResults._read_from_group(group)

        if type_tag in ('list', 'tuple', 'set'):
            keys = sorted(group.keys(), key=lambda x: int(x) if x.isdigit() else x)
            items = [cls._decode_recursive(group[k]) for k in keys]
            if type_tag == 'tuple':
                return tuple(items)
            if type_tag == 'set':
                return set(items)
            return items

        if type_tag == 'dict':
            return {k: cls._decode_recursive(group[k]) for k in group.keys()}

        # Leaf or unknown
        if 'data' in group:
            return jsonpickle.decode(cls._read_str(group['data']))
        return None

    # --------------------------------------------------------------------------
    # Internal IO Drivers
    # --------------------------------------------------------------------------

    def _write_to_group(self, group: h5py.Group, metadata: dict | None = None):
        """Internal driver to save full object state."""
        # 1. Metadata (class path lets the loader restore the concrete subclass)
        metadata_grp = group.create_group('metadata')
        internal_grp = metadata_grp.create_group('__pmrf__')
        internal_grp['fit_results_cls'] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        internal_grp['version'] = 4
        if metadata:
            self._save_dict_to_group(metadata, metadata_grp)

        # 2. Models (explicit None checks: objects may define __len__/__bool__)
        if self.initial_model is not None:
            self.initial_model.write_hdf(group.create_group('initial_model'))
        if self.fitted_model is not None:
            self.fitted_model.write_hdf(group.create_group('fitted_model'))

        # 3. Measured
        if self.measured is not None:
            meas_grp = group.create_group('measured')
            if isinstance(self.measured, skrf.Network):
                self._write_network(meas_grp, self.measured)
            else:  # NetworkCollection: one sub-group per network name
                for ntwk in self.measured:
                    self._write_network(meas_grp.create_group(ntwk.name), ntwk)

        # 4. Settings
        if self.settings is not None:
            self._write_settings(group.create_group('settings'))

        # 5. Solver results
        if self.solver_results is not None:
            self.encode_solver_results(group.create_group('solver_results'))

    @classmethod
    def _read_from_group(cls, group: h5py.Group) -> "FitResults":
        """Internal driver to load full object state."""
        # 1. Determine class type ('fitter' is an older metadata group name)
        target_cls = cls
        if 'metadata' in group:
            meta_grp = group['metadata']
            meta = meta_grp.get('__pmrf__') or meta_grp.get('fitter') or {}
            if 'fit_results_cls' in meta:
                target_cls = load_class_from_string(cls._read_str(meta['fit_results_cls']))

        # 2. Load models
        fitted_model = Model.read_hdf(group['fitted_model']) if 'fitted_model' in group else None
        initial_model = Model.read_hdf(group['initial_model']) if 'initial_model' in group else None

        # 3. Load measured (a single network stores 's' and 'f' directly)
        measured = None
        if 'measured' in group:
            measured_grp = group['measured']
            if 's' in measured_grp and 'f' in measured_grp:
                measured = cls._read_network(measured_grp)
            else:
                params = cls._group_to_dict(measured_grp['params']) if 'params' in measured_grp else None
                measured = NetworkCollection(params=params)
                for label in measured_grp.keys():
                    if label == 'params':
                        continue
                    measured.add(cls._read_network(measured_grp[label]))

        # 4. Load settings
        settings = cls._read_settings(group['settings']) if 'settings' in group else None

        # 5. Load solver results with the resolved class's decoder
        solver_results = None
        if 'solver_results' in group:
            solver_results = target_cls.decode_solver_results(group['solver_results'])

        return target_cls(
            measured=measured,
            initial_model=initial_model,
            fitted_model=fitted_model,
            solver_results=solver_results,
            settings=settings,
        )

    # --------------------------------------------------------------------------
    # Static Helper Methods
    # --------------------------------------------------------------------------

    @staticmethod
    def _read_str(dataset) -> Any:
        """
        Read a string stored as a scalar or 1-element dataset.

        h5py may return ``bytes`` (and a 1-element array for array datasets);
        both are normalized to ``str``. Non-string values are returned as read.
        """
        data = dataset[()]
        if isinstance(data, np.ndarray):
            data = data[0]
        if isinstance(data, bytes):
            data = data.decode('utf-8')
        return data

    @staticmethod
    def _save_dict_to_group(d: dict, group: h5py.Group):
        """Recursively store ``d``; nested dicts become groups, leaves JSON strings."""
        for k, v in d.items():
            if isinstance(v, dict):
                FitResults._save_dict_to_group(v, group.create_group(k))
            else:
                group[k] = json.dumps(v)

    @staticmethod
    def _group_to_dict(group: h5py.Group):
        """Recursively read a group into a dict of raw dataset values."""
        result = {}
        for key, item in group.items():
            if isinstance(item, h5py.Group):
                result[key] = FitResults._group_to_dict(item)
            else:
                result[key] = item[()]
        return result

    @staticmethod
    def _write_network(group: h5py.Group, ntwk: skrf.Network):
        """Store a network's name, S-parameters, frequencies, z0 and params."""
        group['name'] = ntwk.name or 'network'
        group.create_dataset('s', data=ntwk.s)
        group.create_dataset('f', data=ntwk.f)
        group.create_dataset('z0', data=ntwk.z0)
        if ntwk.params is not None:
            params_grp = group.create_group('params')
            for key, value in ntwk.params.items():
                params_grp[key] = value

    @staticmethod
    def _read_network(group: h5py.Group) -> skrf.Network:
        """Inverse of :meth:`_write_network`."""
        name = FitResults._read_str(group['name'])
        s = group['s'][()]
        f_data = group['f'][()]
        z0 = group['z0'][()]
        params = FitResults._group_to_dict(group['params']) if 'params' in group else None
        return skrf.Network(s=s, f=f_data, z0=z0, name=name, params=params)

    def _write_settings(self, group: h5py.Group):
        """Store ``self.settings``; features as JSON, kwargs as jsonpickle strings."""
        if self.settings.frequency is not None:
            frequency_grp = group.create_group('frequency')
            frequency_grp['f'] = self.settings.frequency.f
            frequency_grp['f_scaled'] = self.settings.frequency.f_scaled
            frequency_grp['unit'] = self.settings.frequency.unit
        if self.settings.features is not None:
            group.create_dataset('features', data=json.dumps(self.settings.features))
        if self.settings.fitter_kwargs is not None:
            group.create_dataset('fitter_kwargs', data=jsonpickle.encode(self.settings.fitter_kwargs))
        if self.settings.solver_kwargs is not None:
            group.create_dataset('solver_kwargs', data=jsonpickle.encode(self.settings.solver_kwargs))

    @staticmethod
    def _read_settings(group: h5py.Group):
        """Inverse of :meth:`_write_settings`; returns a FitSettings."""
        frequency = None
        if 'frequency' in group:
            freq_grp = group['frequency']
            unit = FitResults._read_str(freq_grp['unit'])
            if 'f_scaled' in freq_grp:
                frequency = Frequency.from_f(f=freq_grp['f_scaled'][()], unit=unit)
            else:
                # Older files only store 'f' in Hz; rescale into the stored unit.
                fallback_multipliers = {'hz': 1, 'khz': 1e3, 'mhz': 1e6, 'ghz': 1e9, 'thz': 1e12}
                div = fallback_multipliers.get(unit.lower(), 1)
                frequency = Frequency.from_f(freq_grp['f'][()] / div, unit=unit)

        features = json.loads(FitResults._read_str(group['features'])) if 'features' in group else None

        # h5py returns variable-length strings as bytes; _read_str normalizes them to str.
        solver_kwargs = None
        if 'solver_kwargs' in group:
            solver_kwargs = jsonpickle.decode(FitResults._read_str(group['solver_kwargs']))

        fitter_kwargs = None
        if 'fitter_kwargs' in group:
            fitter_kwargs = jsonpickle.decode(FitResults._read_str(group['fitter_kwargs']))

        return FitSettings(frequency, features, fitter_kwargs, solver_kwargs)
def Fitter(
    model: Model,
    *,
    inference: str | None = None,
    backend: str | None = None,
    **kwargs
) -> 'BaseFitter':
    """
    Fitter factory function.

    This allows the creation of a fitter by simply specifying the inference type or fitter backend
    and having all arguments forwarded. See the relevant fitter classes for detailed documentation.

    Parameters
    ----------
    model : Model
        The parametric `pmrf` Model to be fitted. See the documentation for `BaseFitter`.
    inference : str, optional
        High-level inference mode. Can be either 'frequentist' or 'bayesian'. If provided and
        ``backend`` is ``None``, a suitable default backend is selected automatically
        ('scipy-minimize' or 'polychord' respectively).
    backend : str, optional
        Explicit fitter backend name. If provided, this takes precedence over ``inference``;
        when ``inference`` is also given the two must be compatible.
    **kwargs
        Additional arguments forwarded to the fitter constructor.

    Returns
    -------
    BaseFitter
        The concrete fitter instance.

    Raises
    ------
    Exception
        If the inference type is unknown, incompatible with the backend, or the backend is not found.
    """
    if inference is None and backend is None:
        inference = 'frequentist'
    if inference not in (None, 'frequentist', 'bayesian'):
        raise Exception('Unknown inference type')
    if backend is None:
        backend = 'scipy-minimize' if inference == 'frequentist' else 'polychord'
    # A backend given without an inference mode is used as-is; compatibility is
    # only checked when an inference mode is known.
    if inference is not None and not is_inference_kind(backend, inference):
        raise Exception('Inference type incompatible with backend')
    cls = get_fitter_class(backend)
    return cls(model, **kwargs)


def is_frequentist(solver) -> bool:
    """
    Check if a solver is a Frequentist fitter.

    Parameters
    ----------
    solver : str
        The name of the solver.

    Returns
    -------
    bool
        True if the solver corresponds to a FrequentistFitter subclass.
    """
    from pmrf.fitting.frequentist import FrequentistFitter
    return issubclass(get_fitter_class(solver), FrequentistFitter)


def is_bayesian(solver) -> bool:
    """
    Check if a solver is a Bayesian fitter.

    Parameters
    ----------
    solver : str
        The name of the solver.

    Returns
    -------
    bool
        True if the solver corresponds to a BayesianFitter subclass.
    """
    from pmrf.fitting.bayesian import BayesianFitter
    return issubclass(get_fitter_class(solver), BayesianFitter)


def is_inference_kind(solver, inference: str):
    """
    Check if a solver matches a specific inference kind.

    Parameters
    ----------
    solver : str
        The name of the solver.
    inference : str
        The inference kind ('frequentist' or 'bayesian').

    Returns
    -------
    bool
        True if the solver matches the inference kind.

    Raises
    ------
    Exception
        If the inference kind is unknown.
    """
    if inference == 'frequentist':
        return is_frequentist(solver)
    elif inference == 'bayesian':
        return is_bayesian(solver)
    else:
        raise Exception(f"Unknown inference type '{inference}'")


def get_fitter_class(solver: str):
    """
    Retrieve the Fitter class corresponding to a solver name.

    Two class names are tried in every submodule of ``pmrf.fitting._backends``:
    ``<solver>Fitter`` and the PascalCase form of the dash-separated name
    (e.g. 'scipy-minimize' -> 'SciPyMinimizeFitter').

    Parameters
    ----------
    solver : str
        The name of the solver (e.g., 'scipy-minimize').

    Returns
    -------
    class
        The fitter class found in the backends.

    Raises
    ------
    Exception
        If the solver class cannot be found or imported.
    """
    # Restore the internal capitalisation of known library names.
    solver = solver.replace('scipy', 'sciPy')
    solver = solver.replace('polychord', 'polyChord')
    class_names = [
        solver + 'Fitter',
        # part[:1] (not part[0]) so empty segments such as in 'a--b' do not raise.
        ''.join(part[:1].upper() + part[1:] for part in solver.split('-')) + 'Fitter',
    ]
    try:
        for submodule_name, _ in iter_submodules('pmrf.fitting._backends'):
            fitter_submodule = importlib.import_module(submodule_name)
            for class_name in class_names:
                if hasattr(fitter_submodule, class_name):
                    return getattr(fitter_submodule, class_name)
    except (ImportError, AttributeError) as e:
        raise Exception(f'Could not find solver named {solver} with error: {e}') from e
    # Previously fell through and returned None, failing later with a confusing error.
    raise Exception(f'Could not find solver named {solver}')


def fit(model: Model, measured: str | skrf.Network | NetworkCollection, **kwargs) -> Model:
    """Fits a model to measured data.

    This is an alternative API to ParamRF's :mod:`fitting <pmrf.fitting>` module.
    See the :func:`fit <pmrf.fitting.BaseFitter.fit>` method for more details.

    Internally, a fitter is first created using :func:`pmrf.fitting.Fitter`, and then
    :func:`fitter.fit(...)` is called, with the fitted model being returned.
    Fit results are stored in the model's metadata with key 'fit_results'.
    Keyword arguments whose names are in ``FITTER_INIT_PARAMS`` go to the
    fitter constructor; the rest go to ``fit``.

    Parameters
    ----------
    model: prf.Model
        The model to fit.
    measured : prf.NetworkCollection | skrf.Network | str
        The measured data.
    **kwargs
        Additional arguments forwarded to :func:`pmrf.fitting.Fitter` and :func:`pmrf.fitting.BaseFitter.fit`.

    Returns
    -------
    Model
        The fitted model.
    """
    from pmrf.fitting import Fitter, FITTER_INIT_PARAMS
    init_kwargs = {k: kwargs.pop(k) for k in FITTER_INIT_PARAMS if k in kwargs}
    return Fitter(model, **init_kwargs).fit(measured, **kwargs).fitted_model


def fit_submodels(model: Model, measured: str | skrf.Network | NetworkCollection, **kwargs) -> Model:
    """Fits a model's submodels to measured data.

    This is an alternative API to ParamRF's :mod:`fitting <pmrf.fitting>` module.
    See the :func:`fit_submodels <pmrf.fitting.BaseFitter.fit_submodels>` method for more details.

    Internally, a fitter is first created using :func:`pmrf.fitting.Fitter`, and then
    :func:`fitter.fit_submodels(...)` is called, with the fitted model being returned.
    Fit results are stored in the model's metadata with key 'fit_results'.
    Keyword arguments whose names are in ``FITTER_INIT_PARAMS`` go to the
    fitter constructor; the rest go to ``fit_submodels``.

    Parameters
    ----------
    model: prf.Model
        The model to fit.
    measured : prf.NetworkCollection | skrf.Network | str
        The measured data.
    **kwargs
        Additional arguments forwarded to :func:`pmrf.fitting.Fitter` and :func:`pmrf.fitting.BaseFitter.fit_submodels`.

    Returns
    -------
    Model
        The fitted model.
    """
    from pmrf.fitting import Fitter, FITTER_INIT_PARAMS
    init_kwargs = {k: kwargs.pop(k) for k in FITTER_INIT_PARAMS if k in kwargs}
    return Fitter(model, **init_kwargs).fit_submodels(measured, **kwargs).fitted_model