Source code for pmrf.fitting._base

from abc import ABC, abstractmethod
from dataclasses import dataclass
from collections import Counter
import importlib
import logging
from typing import Any, Sequence
from io import BytesIO
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import json
import skrf
import h5py
import jsonpickle
import equinox as eqx
try:
    from mpi4py import MPI
    rank = MPI.COMM_WORLD.Get_rank()
except ImportError:
    rank = 0

from pmrf.models.model import Model
from pmrf.frequency import Frequency, MULTIPLIER_DICT
from pmrf.constants import FeatureT, FeatureInputT
from pmrf._util import LevelFilteredLogger, iter_submodules, load_class_from_string
from pmrf import extract_features, wrap

def Fitter(
    name: str,
    *args,
    **kwargs
) -> 'BaseFitter':
    """Fitter factory function.
    
    This allows creation of a fitter by simply specifying the fitter type; all remaining arguments are forwarded to the fitter's constructor.
    See the relevant fitter classes for detailed documentation.

    Args:
        name (str): The name of the fitter to create, specified as either e.g. 'ScipyMinimize' or 'scipy-minimize'.
        *args, **kwargs: Positional and keyword arguments forwarded to the constructor of the selected fitter class.

    Returns:
        BaseFitter: The concrete fitter instance.
    """
    cls = get_fitter_class(name)
    return cls(*args, **kwargs)
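
# Illustrative sketch (not part of the module): creating a fitter through the factory.
# The model `my_model` and the touchstone path 'dut.s2p' are hypothetical placeholders;
# 'scipy-minimize' is the example alias used in the docstring above.
#
#     fitter = Fitter('scipy-minimize', model=my_model, measured='dut.s2p')
#     results = fitter.run()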
    
@dataclass
class FitSettings:
    frequency: Frequency | None = None
    features: list[FeatureT] | None = None
    fitter_kwargs: dict | None = None
    solver_kwargs: dict | None = None    

class BaseFitter(ABC):
    """An abstract base class that provides the foundation for all fitting algorithms in `pmrf`."""
    def __init__(
        self,
        model: Model,
        measured: str | skrf.Network | dict[str, skrf.Network],
        frequency: Frequency | None = None,
        features: FeatureInputT | None = None,
    ) -> None:
        """Initializes the BaseFitter.

        Args:
            model (Model): The parametric `pmrf` model to be fitted.
            measured (str | skrf.Network | dict[str, skrf.Network]): The measured network data to
                fit the model against. A dict can optionally be passed, in which case the keys of
                the networks can be referenced during feature extraction by also specifying
                `features` as a dictionary. See the documentation for the `features` argument below.
            frequency (Frequency | None, optional): The frequency axis to perform the fit on. If
                `None`, the frequency from the first measured network is used. All networks will be
                interpolated onto this single frequency axis. Defaults to `None`.
            features (FeatureInputT | None, optional): Defines the features to be extracted from the
                network data and model for fitting. See `extract_features(..)` for a detailed
                explanation. As an overview using string aliases, this can be a single feature
                (e.g. 's11'), a list of features (e.g. `['s11', 's11_mag']`), or a dictionary with
                either of the above as values. In the dictionary case, keys must be network names
                in the mapping passed by `measured`, which must also correspond to submodels that
                are attributes of the model. As an example, `{'source_name1': ['s11'],
                'source_name2': ['s21']}` can be passed. Note that if a dictionary of networks is
                passed but `features` is not a dictionary, the given feature(s) are extracted for
                all measured networks/submodels. Defaults to `None`, in which case real and
                imaginary features for all ports are used.
        """
        if isinstance(measured, str):
            measured = skrf.Network(measured)

        # Set the default features and ensure it is not a scalar
        features = features if features is not None else [
            port_feature
            for m, n in model.port_tuples
            for port_feature in (f's{m+1}{n+1}_re', f's{m+1}{n+1}_im')
        ]
        if not isinstance(features, Sequence) and not isinstance(features, dict):
            features = [features]
        if isinstance(measured, dict) and not isinstance(features, dict):
            features = {k: features for k in measured.keys()}

        # All frequencies must be the same across all measurements (at least currently). We copy the input dict.
        measured = measured.copy()
        if frequency is not None:
            measured_freq = frequency
            if isinstance(measured, dict):
                measured = {k: v.interpolate(frequency) for k, v in measured.items()}
            else:
                measured = measured.interpolate(frequency)
        else:
            measured_freq = None
            if isinstance(measured, dict):
                for ntwk in measured.values():
                    if measured_freq is None:
                        measured_freq = ntwk.frequency
                    if ntwk.frequency != measured_freq and not len(ntwk.frequency) == 0:
                        raise ValueError(
                            "Currently `fit_frequency` must be passed for multi-measurement fits "
                            "(i.e. all networks must be explicitly interpolated onto the same frequency for fitting)"
                        )
            else:
                measured_freq = measured.frequency

        # Initialize settings
        self.frequency = frequency or Frequency.from_skrf(measured_freq)
        self.features = features

        # Initialize model parameters from user and store in flat array
        self.initial_model: Model = model
        self.measured: skrf.Network | dict[str, skrf.Network] = measured
        self.measured_features = extract_features(measured, None, features)

        if rank == 0:
            self.logger = logging.getLogger("pmrf.fitting")
        else:
            self.logger = LevelFilteredLogger(null_level=logging.WARNING)
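
    # Illustrative sketch (not part of the class): constructing a concrete fitter for a
    # multi-measurement fit. `MyFitter`, `thru_ntwk`, `line_ntwk` and `fit_freq` are hypothetical
    # placeholders; the keys of `measured` must match submodel attributes of `model`, as described
    # in `__init__` above.
    #
    #     fitter = MyFitter(
    #         model=model,
    #         measured={'thru': thru_ntwk, 'line': line_ntwk},
    #         frequency=fit_freq,  # a pmrf Frequency instance
    #         features={'thru': ['s11', 's21'], 'line': ['s21']},
    #     )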
    def _settings(self, solver_kwargs=None, fitter_kwargs=None) -> FitSettings:
        return FitSettings(
            frequency=self.frequency,
            features=self.features,
            fitter_kwargs=fitter_kwargs,
            solver_kwargs=solver_kwargs,
        )
    @abstractmethod
    def run(self, *args, **kwargs) -> 'FitResults':
        """Executes the fitting algorithm.

        This method must be implemented by all concrete subclasses. It is the main entry point to
        start the optimization or sampling process.

        Returns:
            FitResults: An object containing the results of the fit.
        """
        pass
    def _make_feature_function(self, as_numpy=False):
        general_feature_fn = wrap(extract_features, self.initial_model, self.frequency, as_numpy=as_numpy)
        feature_fn = lambda theta: general_feature_fn(theta, self.features)
        return jax.jit(feature_fn)
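
    # Sketch of how a subclass might use the jitted feature function (assumed usage, not
    # prescribed by this base class): `theta` is the flat parameter array, and residuals are
    # formed against `self.measured_features` extracted in `__init__`.
    #
    #     feature_fn = self._make_feature_function()
    #     residual = feature_fn(theta) - self.measured_features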
@dataclass
class FitResults:
    measured: skrf.Network | dict[str, skrf.Network] | None = None
    initial_model: Model | None = None
    fitted_model: Model | None = None
    solver_results: Any = None
    settings: FitSettings | None = None
    def encode_solver_results(self, group: h5py.Group):
        data = None
        if self.solver_results is not None:
            try:
                data = jsonpickle.encode(self.solver_results)
            except Exception as e:
                logging.error(f"Failed to encode solver results: {e}")
        # Only store the encoded payload if encoding succeeded (h5py cannot store None)
        if data is not None:
            group['data'] = data
    @classmethod
    def decode_solver_results(cls, group: h5py.Group) -> Any:
        if 'data' in group:
            try:
                return jsonpickle.decode(group['data'][()])
            except Exception as e:
                logging.error(f"Failed to decode solver results: {e}")
        return None
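
    # Illustrative round trip (assumes an open, writable h5py.File `f` and a FitResults instance
    # `results`): solver results are stored as a jsonpickle string under the 'data' entry of the
    # given group.
    #
    #     grp = f.create_group('solver_results')
    #     results.encode_solver_results(grp)
    #     restored = FitResults.decode_solver_results(grp)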
    def save_hdf(self, path: str, metadata: dict | None = None):
        with h5py.File(path, 'w') as f:
            # Metadata
            metadata_grp = f.create_group('metadata')
            internal_metadata_grp = metadata_grp.create_group('__pmrf__')
            internal_metadata_grp['fit_results_cls'] = str(self.__class__.__module__ + "." + self.__class__.__qualname__)
            if self.solver_results is not None:
                # Record the module-qualified class name of the solver results object
                internal_metadata_grp['solver_results_cls'] = (
                    self.solver_results.__class__.__module__ + "." + self.solver_results.__class__.__qualname__
                )

            if metadata is not None:
                def save_dict_to_group(d: dict, group: h5py.Group):
                    for k, v in d.items():
                        if isinstance(v, dict):
                            subgrp = group.create_group(k)
                            save_dict_to_group(v, subgrp)
                        else:
                            group[k] = json.dumps(v)
                save_dict_to_group(metadata, metadata_grp)

            # Models
            if self.fitted_model is not None:
                self.fitted_model.write_hdf(f.create_group('fitted_model'))
            if self.initial_model is not None:
                self.initial_model.write_hdf(f.create_group('initial_model'))

            # Measured data
            if self.measured is not None:
                measured_grp = f.create_group('measured')
                if isinstance(self.measured, skrf.Network):
                    measured_grp['name'] = self.measured.name or 'ntwk'
                    measured_grp.create_dataset('s', data=self.measured.s)
                    measured_grp.create_dataset('f', data=self.measured.f)
                    measured_grp.create_dataset('z0', data=self.measured.z0)
                else:
                    for label, ntwk in self.measured.items():
                        measured_ntwk_grp = measured_grp.create_group(label)
                        measured_ntwk_grp['name'] = ntwk.name or 'ntwk'
                        measured_ntwk_grp.create_dataset('s', data=ntwk.s)
                        measured_ntwk_grp.create_dataset('f', data=ntwk.f)
                        measured_ntwk_grp.create_dataset('z0', data=ntwk.z0)

            # Solver results
            if self.solver_results is not None:
                solver_results_grp = f.create_group('solver_results')
                self.encode_solver_results(solver_results_grp)

            # Other input
            ## Setup
            input_grp = f.create_group('settings')

            ## Other settings
            if self.settings.frequency is not None:
                frequency_grp = input_grp.create_group('frequency')
                frequency_grp['f'] = self.settings.frequency.f
                frequency_grp['f_scaled'] = self.settings.frequency.f_scaled
                frequency_grp['unit'] = self.settings.frequency.unit
            if self.settings.features is not None:
                input_grp.create_dataset('features', data=json.dumps(self.settings.features))
            if self.settings.fitter_kwargs is not None:
                input_grp.create_dataset('fitter_kwargs', data=jsonpickle.encode(self.settings.fitter_kwargs))
            if self.settings.solver_kwargs is not None:
                input_grp.create_dataset('solver_kwargs', data=jsonpickle.encode(self.settings.solver_kwargs))
    @classmethod
    def load_hdf(cls, path: str) -> "FitResults":
        with h5py.File(path, 'r') as f:
            # Metadata
            if 'metadata' in f:
                metadata_grp = f['metadata']
                if 'fitter' in metadata_grp and 'version' in metadata_grp['fitter']:
                    internal_metadata_grp = metadata_grp['fitter']
                else:
                    internal_metadata_grp = metadata_grp['__pmrf__']
                version = internal_metadata_grp['version'][()]
                fit_results_cls_path = internal_metadata_grp['fit_results_cls'][()]
                fit_results_cls_path = (
                    fit_results_cls_path.decode('utf-8')
                    if isinstance(fit_results_cls_path, bytes)
                    else fit_results_cls_path
                )
                try:
                    cls = load_class_from_string(fit_results_cls_path)
                except ImportError:
                    logging.warning(f"Could not import class from path '{fit_results_cls_path}'. Using FitResults instead.")

            # Model fit
            if version == 1:
                fitted_model = Model.read_hdf(f['model']) if 'model' in f else None
            elif version == 2:
                fitted_model = Model.read_hdf(f['fit_model']) if 'fit_model' in f else None
            elif version >= 3:
                fitted_model = Model.read_hdf(f['fitted_model']) if 'fitted_model' in f else None

            # Solver results
            solver_results = cls.decode_solver_results(f['solver_results']) if 'solver_results' in f else None

            # Settings ('settings' as written by save_hdf, with 'input' as a fallback for earlier file layouts)
            settings_grp = f['settings'] if 'settings' in f else f['input']
            initial_model = Model.read_hdf(settings_grp['model']) if 'model' in settings_grp else None

            ## Measured networks
            measured = None
            measured_grp = None
            if version <= 3 and 'measured' in settings_grp:
                measured_grp = settings_grp['measured']
            elif 'measured' in f:
                measured_grp = f['measured']
            if measured_grp is not None:
                if 'name' in measured_grp:
                    net_grp = measured_grp
                    name = net_grp['name'][()]
                    name = name.decode('utf-8') if isinstance(name, bytes) else name
                    s = net_grp['s'][()]
                    f_data = net_grp['f'][()]
                    z0 = net_grp['z0'][()]
                    measured = skrf.Network(s=s, f=f_data, z0=z0, name=name)
                else:
                    measured = {}
                    for label in measured_grp.keys():
                        net_grp = measured_grp[label]
                        name = net_grp['name'][()]
                        name = name.decode('utf-8') if isinstance(name, bytes) else name
                        s = net_grp['s'][()]
                        f_data = net_grp['f'][()]
                        z0 = net_grp['z0'][()]
                        network = skrf.Network(s=s, f=f_data, z0=z0, name=name)
                        measured[str(label)] = network

            ## Frequency and features
            frequency = None
            features = None
            if 'frequency' in settings_grp:
                freq_grp = settings_grp['frequency']
                unit = freq_grp['unit'][()]
                unit = unit.decode('utf-8') if isinstance(unit, bytes) else unit
                if 'f_scaled' in freq_grp:
                    f_scaled_arr = freq_grp['f_scaled'][()]
                    frequency = Frequency.from_f(f=f_scaled_arr, unit=unit)
                else:
                    f_arr = freq_grp['f'][()]
                    frequency = Frequency.from_f(f_arr / MULTIPLIER_DICT[unit.lower()], unit=unit)
            if 'features' in settings_grp:
                features = json.loads(settings_grp["features"][()])

            ## Solver args, kwargs and fit args, kwargs
            solver_kwargs, fitter_kwargs = None, None
            if 'solver_kwargs' in settings_grp:
                solver_kwargs = jsonpickle.decode(settings_grp['solver_kwargs'][()])
            if 'fitter_kwargs' in settings_grp:
                fitter_kwargs = jsonpickle.decode(settings_grp['fitter_kwargs'][()])

            settings = FitSettings(frequency, features, fitter_kwargs, solver_kwargs)

            return cls(
                measured=measured,
                initial_model=initial_model,
                fitted_model=fitted_model,
                solver_results=solver_results,
                settings=settings,
            )
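
# Illustrative sketch of persisting fit results to HDF5 and loading previously saved results.
# The file path and metadata dictionary below are hypothetical placeholders.
#
#     results.save_hdf('fit_results.h5', metadata={'device': 'DUT-1'})
#     reloaded = FitResults.load_hdf('fit_results.h5')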
def is_frequentist(solver) -> bool:
    from pmrf.fitting._frequentist import FrequentistFitter
    cls = get_fitter_class(solver)
    return issubclass(cls, FrequentistFitter)


def is_bayesian(solver) -> bool:
    from pmrf.fitting._bayesian import BayesianFitter
    cls = get_fitter_class(solver)
    return issubclass(cls, BayesianFitter)


def is_inference_kind(solver, inference: str):
    if inference == 'frequentist':
        return is_frequentist(solver)
    elif inference == 'bayesian':
        return is_bayesian(solver)
    else:
        raise Exception(f"Unknown inference type '{inference}'")


def get_fitter_class(solver: str):
    # Accept both 'ScipyMinimize' and 'scipy-minimize' style names
    class_names = [solver + 'Fitter']
    class_names.append(''.join(part[0].upper() + part[1:] for part in solver.split('-')) + 'Fitter')
    try:
        for submodule_name, _ in iter_submodules('pmrf.fitting.fitters'):
            fitter_submodel = importlib.import_module(submodule_name)
            for class_name in class_names:
                if hasattr(fitter_submodel, class_name):
                    return getattr(fitter_submodel, class_name)
    except (ImportError, AttributeError):
        pass
    # Raised both on import errors and when no matching class was found
    raise Exception(f'Could not find solver named {solver}')
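
# Example of the name resolution performed by `get_fitter_class` (illustrative): both the
# CamelCase and the hyphenated alias map to the same class, assuming a `ScipyMinimizeFitter`
# is defined in a submodule of `pmrf.fitting.fitters`.
#
#     get_fitter_class('ScipyMinimize')   # -> ScipyMinimizeFitter
#     get_fitter_class('scipy-minimize')  # -> ScipyMinimizeFitter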