Source code for pmrf.models.model

"""pmrf.model
================

Core model base class and utilities for ParamRF.

This module defines :class:`Model`, a frozen, JAX-compatible, Equinox module
representing an RF network, along with helper utilities like :func:`wrap`.

"""

from functools import cached_property, update_wrapper
from copy import deepcopy
from typing import Callable, Sequence, TypeVar
import dataclasses
from dataclasses import fields, is_dataclass

import skrf
import numpy as np
import h5py
import jsonpickle
import jax
import jax.numpy as jnp
from jax import flatten_util
from jax.tree_util import SequenceKey, GetAttrKey, DictKey, FlattenedIndexKey
from jaxtyping import PyTree
import equinox as eqx
from numpyro.distributions import Distribution

from pmrf.functions.conversions import a2s, s2a
from pmrf.functions.math import FUNC_LOOKUP
from pmrf.parameters import Parameter, ParameterGroup, is_valid_param, asparam
from pmrf.distributions.parameter import JointParameterDistribution
from pmrf.constants import PRIMARY_PROPERTIES, IndexArray
from pmrf.frequency import Frequency
from pmrf._util import field, classproperty, is_overridden, get_first_underlying_type
from pmrf._tree import nodes_by_type, value_at_path, partition, combine

ModelT = TypeVar('ModelT', bound='Model')

[docs] class Model(eqx.Module): """ Overview -------- Base class representing a computable RF network, referred to in **ParamRF** as a "Model". This class is abstract and should not be instantiated directly. Derive from :class:`Model` and override one of the primary property functions (e.g. :meth:`__call__`, :meth:`s`, :meth:`a`). The model is an Equinox ``Module`` (immutable, dataclass-like) and is treated as a JAX PyTree. Parameters are declared using standard dataclass field syntax with types like :class:`pmrf.Parameter`. Usage ------------- - Define sub-classes with custom parameters and sub-models - Construct models by passing parameters and/or submodels to the initializer (like a dataclass). - Use `with_xxx` functions to update the model e.g. :meth:`with_params`, :meth:`with_replaced`. - Retrieve parameter information via methods such as :meth:`named_params`, :meth:`param_names`, :meth:`flat_params`, etc.. See also the :mod:`pmrf.fitting` and :mod:`pmrf.sampling` modules for details on model fitting and sampling. Examples ------- A ``PiCLC`` network ("foundational" model with fixed parameters and equations): .. code-block:: python import jax.numpy as jnp import pmrf as prf class PiCLC(prf.Model): C1: prf.Parameter = 1.0e-12 L: prf.Parameter = 1.0e-9 C2: prf.Parameter = 1.0e-12 def a(self, freq: prf.Frequency) -> jnp.ndarray: w = freq.w Y1, Y2, Y3 = (1j * w * self.C1), (1j * w * self.C2), 1 / (1j * w * self.L) return jnp.array([ [1 + Y2 / Y3, 1 / Y3], [Y1 + Y2 + Y1*Y2/Y3, 1 + Y1 / Y3], ]).transpose(2, 0, 1) An ``RLC`` network ("circuit" model with free parameters built using cascading) .. code-block:: python import pmrf as prf from pmrf.models import Resistor, Capacitor, Inductor, Cascade from pmrf.parameters import Uniform class RLC(prf.Model): res: Resistor = Resistor(Uniform(9.0, 11.0)) ind: Inductor = Inductor(Uniform(0.0, 10.0, scale=1e-9)) cap: Capacitor = Capacitor(Uniform(0.0, 10.0, scale=1e-12)) def __call__(self) -> prf.Model: return self.res ** self.ind ** self.cap.terminated() """ # Instance fields name: str | None = field(default=None, kw_only=True, static=True) separator: str = field(default='_', kw_only=True, static=True) _z0: complex = field(default=50.0+0j, kw_only=True, static=True) _param_groups: list[ParameterGroup] = field(default_factory=lambda: list(), kw_only=True, repr=False, static=True) def __init_subclass__(cls, **kwargs): """Customize subclass construction. - Applies default names to default ``Model``/``Parameter`` fields. - Adds converters for ``Parameter``-typed fields (``asparam``). - Protects against mutable-default pitfalls by using ``default_factory``. - Dynamically injects convenience methods like ``s_mag``, ``s_mn_mag``, based on :data:`PRIMARY_PROPERTIES` and :data:`FUNC_LOOKUP`. 
""" super().__init_subclass__(**kwargs) # Pre-process fields by modifying any defaults as needed, as well as automatically applying converters for field_name, field_types in cls.__annotations__.items(): # Setup field_kwargs = {} default = getattr(cls, field_name, None) if default is None or isinstance(default, dataclasses.Field): continue # Replace default model and parameter names with the field name if isinstance(default, Model) and default.name is None: default = dataclasses.replace(default, name=field_name) if isinstance(default, Parameter) and default.name is None: default = dataclasses.replace(default, name=field_name) # Auto apply asparam converter, to allow auto-conversion of Parameter-annotated structures field_type = get_first_underlying_type(field_types) if field_type is not None and issubclass(field_type, Parameter): if default is not None and not isinstance(default, Parameter): # Common mistake if isinstance(default, tuple): raise Exception(f"Expected a parameter for default '{field_name}' in class {cls} but found a tuple instead") field_kwargs['default'] = default field_kwargs['converter'] = lambda x, field_name=field_name: asparam(x, name=field_name, fixed=True) # field_kwargs['default_factory'] = lambda default=default: default # Apply default_factory to avoid Python default mutable trap if any(isinstance(default, mutable_type) for mutable_type in {list, dict, tuple, Parameter, Model, jnp.ndarray}): field_kwargs['default_factory'] = lambda default=default: deepcopy(default) # Set final field if len(field_kwargs) != 0: setattr(cls, field_name, eqx.field(**field_kwargs)) # Implement dynamic functions for prop in PRIMARY_PROPERTIES: for suffix, lookup in FUNC_LOOKUP.items(): def make_dynamic_method(prop, func): def dynamic_method(self, *args, **kwargs): prop_fn = getattr(self, prop) matrix = prop_fn(*args, **kwargs) return func(matrix) return dynamic_method # First the regular (non-indexed) function e.g. s_mag func_name = f"{prop}_{suffix}" func = lookup[1] m = make_dynamic_method(prop, func) m._pmrf_auto = True setattr(cls, func_name, m) # Then the index function function e.g. s_mn_mag func_name = f"{prop}_mn_{suffix}" func = lookup[1] m = make_dynamic_method(f"{prop}_mn", func) m._pmrf_auto = True setattr(cls, func_name, m) # ---- Private PyTree specs ------------------------------------------------- @property def _param_value_spec(self) -> PyTree: """PyTree spec for **all** parameter values within the model (arrays).""" def is_param_value(path, node): if len(path) == 0 or not eqx.is_inexact_array(node): return False return hasattr(path[-1], 'name') and path[-1].name == 'value' return jax.tree.map_with_path(is_param_value, self, is_leaf=lambda node: eqx.is_inexact_array(node)) @property def _param_object_spec(self) -> PyTree: """PyTree spec for all :class:`Parameter` objects within the model.""" return jax.tree.map(is_valid_param, self, is_leaf=lambda node: is_valid_param(node)) @property def _core_value_spec(self) -> PyTree: """PyTree spec for core (non-derived) parameter values.""" # A parameter is derived if any of its parent dataclasses have 'derived' set to True in its field metadata. 
path_is_derived = {} def is_leaf(path, node): if is_dataclass(node): for field in fields(node): if field.metadata.get('derived', False): field_path = path + (GetAttrKey(field.name),) path_is_derived[field_path] = True # Set base path as not being derived, and this path's derived as equal to the parents if not already set if len(path) == 0: path_is_derived[path] = False else: path_is_derived.setdefault(path, path_is_derived[path[0:-1]]) return isinstance(node, bool) return jax.tree.map_with_path(lambda path, node: node and not path_is_derived[path], self._param_value_spec, is_leaf=is_leaf, is_leaf_takes_path=True) @property def _core_object_spec(self) -> PyTree: """PyTree spec for core (non-derived) :class:`Parameter` objects.""" return self._param_object_spec # return jax.tree.map(lambda param, core_spec: is_valid_param(param) and core_spec.value, self, self.core_value_spec, is_leaf=lambda node: is_valid_param(node)) @property def _free_value_spec(self) -> PyTree: """PyTree spec for **free** parameter values (arrays).""" # A Pytree filter for all free Model parameter values. def is_not_fixed(path, is_core): if not is_core: return False # Here we check that the array is within a parameter and that the parameter is varying param = value_at_path(self, path[0:-1]) if not isinstance(param, Parameter): raise Exception(f"Found an array in a model not within a parameter: this is not allowed (at path {path})") return not param.fixed return jax.tree.map_with_path(is_not_fixed, self._core_value_spec) @property def _free_object_spec(self) -> PyTree: """PyTree spec for **free** :class:`Parameter` objects.""" # A Pytree filter for all free Model `Parameter` objects. return jax.tree.map(lambda param, fit_spec: is_valid_param(param) and fit_spec.value, self, self._free_value_spec, is_leaf=lambda node: is_valid_param(node)) # ---- Helpers -------------------------------------------------------------- def _path_to_param_name(self, path) -> str | list[str]: """Convert a PyTree path to a fully-qualified parameter name.""" fields = [] for key in path: if isinstance(key, GetAttrKey) or isinstance(key, DictKey): fields.append(key.name) elif isinstance(key, SequenceKey) or isinstance(key, FlattenedIndexKey): fields.append(str(key.idx)) return self.separator.join(fields) def __pow__(self, other: ModelT) -> ModelT: """Cascade composition operator ``**``.""" from pmrf.models.containers import Cascade return Cascade([self, other]) # ---- Defaults / Primary --------------------------------------------------- @classproperty def DEFAULT_PARAMS(cls) -> dict[str, Parameter]: """Default parameters for the model. Returns ------- dict[str, Parameter] Mapping from parameter name to :class:`Parameter`. """ instance = cls() return instance.named_params() @classproperty def DEFAULT_PARAM_NAMES(cls) -> list[str]: """Default parameter names for the model. Returns ------- list[str] """ instance = cls() return instance.param_names() @property def primary_function(self) -> Callable[[Frequency], jnp.ndarray]: """Primary function (``s`` or ``a``) as a callable. Returns ------- Callable[[Frequency], jnp.ndarray] """ return getattr(self, self.primary_property) @property def primary_property(self) -> str: """Name of the primary property (e.g. ``"s"``, ``"a"``). The primary property is the first overridden among :data:`PRIMARY_PROPERTIES`. If ``__call__`` is overridden, delegation occurs to the returned model. Returns ------- str Raises ------ NotImplementedError If no primary property is overridden. 
""" prioritized = () # for future expansion unprioritized = tuple(p for p in PRIMARY_PROPERTIES if p not in prioritized) if is_overridden(type(self), Model, '__call__'): return self().primary_property for property in prioritized: if is_overridden(type(self), Model, property): return property for property in unprioritized: if is_overridden(type(self), Model, property): return property raise NotImplementedError(f"No primary properties in {PRIMARY_PROPERTIES} are overriden, which are the only ones supported currently") # ---- Introspection -------------------------------------------------------- @cached_property @eqx.filter_jit def number_of_ports(self) -> int: """Number of ports. Returns ------- int """ freq = Frequency(1, 2, 2) eval = jax.eval_shape(lambda: self.s(freq)) return eval.shape[1] @cached_property def num_flat_params(self) -> int: """Number of free, **flattened** parameters. Returns ------- int """ return len(self.flat_params()) @cached_property def num_params(self) -> int: """Number of free parameters. Returns ------- int """ return len(self.named_params()) @property def nports(self) -> int: """Alias of :attr:`number_of_ports`.""" return self.number_of_ports @property def port_tuples(self) -> list[tuple[int, int]]: """All (m, n) port index pairs. Returns ------- list[tuple[int, int]] """ return [(y, x) for x in range(self.nports) for y in range(self.nports)] @property def z0(self) -> jnp.ndarray: """Internal characteristic impedance. Returns ------- jnp.ndarray """ return self._z0 # ---- Core API -------------------------------------------------------------
    def __call__(self) -> 'Model':
        """Build a compositional circuit representing this model.

        Returns
        -------
        Model

        Raises
        ------
        NotImplementedError
            In the base class; override in derived classes to build a
            compositional representation.
        """
        raise NotImplementedError("Error: cannot use the __call__ method to build a model in the Model class directly")
    @eqx.filter_jit
    def primary(self, freq: Frequency) -> jnp.ndarray:
        """Dispatch to the primary function for the given frequency."""
        primary_function = self.primary_function
        return primary_function(freq)
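    # Illustrative sketch (not part of the original source): how the primary
    # property is resolved for the two classes in the Model docstring above.
    # ``PiCLC`` overrides ``a``, so its primary property is ``'a'``; ``RLC``
    # overrides ``__call__``, so resolution is delegated to the circuit it builds.
    #
    #     PiCLC().primary_property       # -> 'a'
    #     PiCLC().primary(freq)          # dispatches to PiCLC.a(freq)
    #     RLC().primary_property         # resolved via RLC()() -> Cascade(...)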
    @eqx.filter_jit
    def a(self, freq: Frequency) -> jnp.ndarray:
        """ABCD parameter matrix.

        If only :meth:`s` is implemented, converts via
        :func:`pmrf.functions.conversions.s2a`.

        Parameters
        ----------
        freq : Frequency
            Frequency grid.

        Returns
        -------
        jnp.ndarray
            ABCD matrix with shape ``(nf, 2, 2)``.

        Raises
        ------
        NotImplementedError
            If neither ``a`` nor ``s`` is implemented in a subclass.
        """
        if is_overridden(type(self), Model, '__call__'):
            return self().a(freq)
        if not is_overridden(type(self), Model, 's'):
            raise NotImplementedError(f"Error: model sub-classes currently *have* to implement the 's' or the 'a' function, but class {type(self)} has neither")
        s = self.s(freq)
        return s2a(s, self.z0)
    @eqx.filter_jit
    def s(self, freq: Frequency) -> jnp.ndarray:
        """Scattering parameter matrix.

        If only :meth:`a` is implemented, converts via
        :func:`pmrf.functions.conversions.a2s`.

        Parameters
        ----------
        freq : Frequency
            Frequency grid.

        Returns
        -------
        jnp.ndarray
            S-parameter matrix with shape ``(nf, n, n)``.

        Raises
        ------
        NotImplementedError
            If neither ``a`` nor ``s`` is implemented in a subclass.
        """
        if is_overridden(type(self), Model, '__call__'):
            return self().s(freq)
        if not is_overridden(type(self), Model, 'a'):
            raise NotImplementedError(f"Error: model sub-classes currently *have* to implement the 's' or the 'a' function, but class {type(self)} has neither")
        a = self.a(freq)
        return a2s(a, self.z0)
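    # Illustrative sketch (not part of the original source): evaluating the
    # ``PiCLC`` example from the class docstring on a frequency grid, assuming
    # the ``Frequency(start, stop, npoints)`` form used in ``number_of_ports``.
    # ``PiCLC`` implements ``a``, so ``s`` is obtained via the a2s conversion.
    #
    #     freq = Frequency(1, 2, 101)
    #     s = PiCLC().s(freq)
    #     s.shape                       # (101, 2, 2)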
    # ---- Structure utilities --------------------------------------------------
    def children(self) -> list['Model']:
        """Immediate submodels.

        Returns
        -------
        list[Model]
        """
        return [node for node in eqx.tree_flatten_one_level(self)[0] if isinstance(node, Model)]
    def named_children(self) -> dict[str, 'Model']:
        """Immediate submodels keyed by their ``name`` attribute.

        Returns
        -------
        dict[str, Model]
        """
        return {child.name: child for child in self.children()}
    def submodels(self) -> list['Model']:
        """All nested submodels (depth-first), excluding ``self``.

        Returns
        -------
        list[Model]
        """
        return nodes_by_type(self, Model)[1:]
    def flipped(self) -> 'Model':
        """Return a version of the model with ports flipped.

        Returns
        -------
        Model
        """
        from pmrf.models.containers import Flipped
        return Flipped(self)
    def terminated(self, load: 'Model' = None) -> 'Model':
        """Terminate a 2-port model in a 1-port load (default: short).

        Parameters
        ----------
        load : Model, optional
            Load network. Defaults to a short.

        Returns
        -------
        Model
        """
        from pmrf.models.lumped import Short
        from pmrf.models.containers import Cascade
        load = load or Short()
        terminated_model = Cascade((self, load))
        return terminated_model
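    # Illustrative sketch (not part of the original source): composing models with
    # the cascade operator ``**`` and terminating the result, mirroring the ``RLC``
    # example in the class docstring (Resistor/Inductor/Capacitor and Uniform are
    # the names imported there).
    #
    #     from pmrf.models import Resistor, Inductor, Capacitor
    #     from pmrf.parameters import Uniform
    #     one_port = Resistor(Uniform(9.0, 11.0)) ** Inductor(Uniform(0.0, 10.0, scale=1e-9)) \
    #                ** Capacitor(Uniform(0.0, 10.0, scale=1e-12)).terminated()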
    def copy(self: ModelT) -> ModelT:
        """Return a deep copy of the model."""
        return deepcopy(self)
    def partition(self: ModelT, include_fixed=False, param_objects=False) -> tuple[ModelT, ModelT]:
        """Partition the model into (parameters, static) trees.

        This is useful for internal use, or for inspecting the model and its parameters.

        Parameters
        ----------
        include_fixed : bool, default=False
            Include fixed parameters in the parameter tree.
        param_objects : bool, default=False
            If ``True``, keep full :class:`Parameter` objects; otherwise filter to ``.value``.

        Returns
        -------
        (ModelT, ModelT)
            ``(params_tree, static_tree)``
        """
        if param_objects:
            shared_spec = self._param_object_spec
            filter_spec = self._core_object_spec if include_fixed else self._free_object_spec
        else:
            shared_spec = self._param_value_spec
            filter_spec = self._core_value_spec if include_fixed else self._free_value_spec
        return partition(self, filter_spec, shared_spec)
[docs] def named_params(self, include_fixed=False, flat=False, flat_params=False, submodels: 'Model' | Sequence['Model'] | str | Sequence[str] | None = None) -> dict[str, Parameter]: """Return model parameters as a dict (or flattened array). Keys are fully-qualified parameter names. The order matches the internal flattened array order. Parameters ---------- include_fixed : bool, default=False Include fixed parameters. flat : bool, default=False If ``True``, return a 1D ``jnp.ndarray`` of parameter **values**. flat_params : bool, default=False If ``True``, return a dict of **manufactured** parameter objects (flattened) instead of an array. submodels : Model | Sequence[Model] | str | Sequence[str] | None, optional Restrict to parameters used by the given submodel(s). If strings are provided, ``getattr(self, name)`` is used. Returns ------- dict[str, Parameter] or jnp.ndarray """ spec = self._core_object_spec if include_fixed else self._free_object_spec params_tree = eqx.filter(self, spec, is_leaf=is_valid_param) path_and_params = jax.tree.flatten_with_path(params_tree, is_leaf=is_valid_param) _params: dict[str, Parameter] = {self._path_to_param_name(path): param for path, param in path_and_params[0]} if flat | flat_params: flat_params_dict = {} for name, param in _params.items(): if param.ndim > 1: param_flattened = param.flatten(separator=self.separator) for i, subparam in enumerate(param_flattened): flat_params_dict[f'{name}{self.separator}{i}'] = subparam else: flat_params_dict[name] = param _params = flat_params_dict if submodels is not None: if isinstance(submodels, Model) or isinstance(submodels, str): submodels = [submodels] if len(submodels) != 0 and isinstance(submodels[0], str): submodels = [getattr(self, name) for name in submodels] submodel_param_values = [param for submodel in submodels for param in submodel.params(include_fixed=include_fixed, flat=flat).values()] _params = {k: v for k, v in _params.items() if any(v is p for p in submodel_param_values)} if flat and not flat_params: return jnp.array([p.value for p in _params.values()]) return _params
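    # Illustrative sketch (not part of the original source): parameter introspection
    # on the docstring's ``PiCLC`` example. Its fields are converted with
    # ``fixed=True`` by ``__init_subclass__``, so they only appear when
    # ``include_fixed=True``.
    #
    #     PiCLC().named_params(include_fixed=True)  # {'C1': Parameter(...), 'L': ..., 'C2': ...}
    #     PiCLC().named_params()                    # {} (all parameters fixed by default)
    #     RLC().flat_params()                       # 1D jnp.ndarray of free parameter values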
    def flat_params(self, *args, **kwargs) -> jnp.ndarray:
        """Shorthand for ``named_params(flat=True, ...)``."""
        return self.named_params(*args, flat=True, **kwargs)
    def flat_param_names(self) -> list[str]:
        """Flat parameter names (matching :meth:`flat_params` order)."""
        return list(self.named_params(flat_params=True).keys())
    def param_names(self, *args, **kwargs):
        """Parameter names for the current (possibly filtered) selection."""
        params = self.named_params(*args, **kwargs)
        return list(params.keys())
    def param_groups(self, include_fixed=False) -> list[ParameterGroup]:
        """Return all parameter groups relevant to this model.

        This function is useful for fitters that need to take into account e.g. any
        constraints or correlated priors. Only groups that were created with the same
        ``flat`` flag are returned. Currently, only groups for the current model are
        considered (i.e. groups in children models are ignored).

        Parameters
        ----------
        include_fixed : bool, default=False
            Include groups involving fixed parameters.

        Returns
        -------
        list[ParameterGroup]
        """
        params = self.named_params(flat=True, flat_params=True, include_fixed=include_fixed)
        groups = [group for group in self._param_groups]
        grouped_param_names = {name for group in groups for name in group.param_names}
        for name, param in params.items():
            if name not in grouped_param_names:
                groups.append(ParameterGroup(param_names=[name], prior=param.prior))
        return groups
    def prior(self) -> Distribution:
        """Joint prior distribution over (flattened) parameters."""
        return JointParameterDistribution(self.param_groups(), self.flat_param_names())
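    # Illustrative sketch (not part of the original source): attaching a joint prior
    # to two free parameters via a ParameterGroup and recovering the joint prior over
    # the flat parameter vector. ``joint_dist`` is a hypothetical numpyro distribution,
    # and the parameter names are illustrative (use ``flat_param_names()`` to see yours).
    #
    #     group = ParameterGroup(param_names=['C1', 'C2'], prior=joint_dist)
    #     model = model.with_param_groups(group)
    #     model.prior()    # JointParameterDistribution over the flat parameter vector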
[docs] def with_params( self: ModelT, params: dict[str, Parameter] | dict[str, float] | jnp.ndarray | None = None, check_missing: bool = False, check_unknown: bool = False, fix_others = False, include_fixed = False, **param_kwargs: dict[str, Parameter] | dict[str, float], ) -> ModelT: """Return a new model with core parameters updated. Parameters ---------- params : dict[str, Parameter] | dict[str, float] | jnp.ndarray | None, optional Parameter updates. If an array, **all** values must be provided (matching ``flat_params`` order). You may also pass keyword args. check_missing : bool, default=False Require that all model parameters are specified. check_unknown : bool, default=False Error if unknown parameter keys are provided. fix_others : bool, default=False Fix any parameters not explicitly passed. include_fixed : bool, default=False Include fixed parameters when interpreting ``params`` mapping. **param_kwargs : dict Additional parameter updates by name. Returns ------- ModelT Raises ------ Exception If shape/order mismatches, unknown/missing names (when checked), or if arrays are found outside of Parameters. """ # Prepare input if isinstance(params, dict) or len(param_kwargs) != 0: if include_fixed: raise Exception('Not yet supported') params = params if params is not None else {} params.update(param_kwargs) # Generate an ordered, input flat params array for verification new_params = self.named_params(include_fixed=True) # Validate the callers's input unknown_params = set(params.keys() - new_params.keys()) if check_unknown and len(unknown_params) != 0: raise Exception(f"Error: the following parameters were passed but are not in the model: {unknown_params}") params = {k: v for k, v in params.items() if k not in unknown_params} if check_missing or fix_others: missing_params = set(new_params.keys() - params.keys()) if check_missing and len(missing_params) != 0: raise Exception(f"Error: the following model parameters were missing: {missing_params}") if fix_others: for missing_param_name in missing_params: new_params[missing_param_name] = dataclasses.replace(new_params[missing_param_name], fixed=True) def is_convertible_to_float(x): try: float(x) return True except (ValueError, TypeError): return False # Convert to an array of parameters instead of floats if all(is_convertible_to_float(v) for v in params.values()): for name, value in params.items(): # TODO create specs for the full parameter objects such that we can get and use the built-in scales new_params[name] = dataclasses.replace(new_params[name], value=jnp.array(value), scale=1.0) else: new_params.update(params) new_flat_params = list(new_params.values()) # Get the current flat parameter object params_tree, static = partition(self, self._core_object_spec, self._param_object_spec, is_leaf=is_valid_param) flat_params, treedef = jax.tree.flatten(params_tree, is_leaf=is_valid_param) # We allow the caller to pass None for name and then we update the name. 
Otherwise names should match for i, param in enumerate(flat_params): if new_flat_params[i].name == None: new_flat_params[i] = dataclasses.replace(new_flat_params[i], name=param.name) # Create the update tree and return new_params_tree = jax.tree.unflatten(treedef, new_flat_params) combined: Model = combine(new_params_tree, static, is_leaf=is_valid_param) return combined else: params = jnp.array(params) if params.shape[0] != self.num_flat_params: raise Exception(f'Expected {self.num_flat_params} flat parameters but was passed {params.shape[0]}') params_tree, static = self.partition(include_fixed=include_fixed) params_out, unravel_fn = flatten_util.ravel_pytree(params_tree) if jnp.isscalar(params_out) or params_out.shape[0] == 0: raise Exception("Error: no free model parameters found to make feature function") params_tree_recon = unravel_fn(params) return combine(params_tree_recon, static)
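    # Illustrative sketch (not part of the original source): the three ways of
    # updating parameters supported by ``with_params``. ``res_R`` is a hypothetical
    # parameter name; use ``param_names()`` to list the real ones.
    #
    #     m = RLC()
    #     m2 = m.with_params({'res_R': 10.5})         # mapping of name -> value
    #     m3 = m.with_params(res_R=10.5)              # keyword form
    #     m4 = m.with_params(m.flat_params() * 1.1)   # full flat array, flat_params order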
    def with_flat_params(self, *args, **kwargs):
        """Alias for :meth:`with_params` when passing a flat array."""
        return self.with_params(*args, **kwargs)
    def with_replaced(self: ModelT, *args, **kwargs) -> ModelT:
        """Return a copy with dataclass-style field replacements."""
        return dataclasses.replace(self, *args, **kwargs)
    def with_param_groups(self: ModelT, param_groups: ParameterGroup | list[ParameterGroup]) -> ModelT:
        """Return a model with parameter groups appended.

        Parameter groups can be used to specify relationships between parameters in
        the model, such as joint priors.

        Parameters
        ----------
        param_groups : ParameterGroup or list[ParameterGroup]
            Group(s) to add.

        Returns
        -------
        ModelT
        """
        if not isinstance(param_groups, list):
            param_groups = [param_groups]
        param_groups_old = self._param_groups.copy() if self._param_groups is not None else []
        param_groups_new = param_groups_old + param_groups
        return dataclasses.replace(self, _param_groups=param_groups_new)
    def with_fixed_params(self: ModelT, *params, check_unknown=True) -> ModelT:
        """Return a model with the specified parameters fixed.

        Parameters
        ----------
        *params : str
            Parameter names to fix.
        check_unknown : bool, default=True
            Error if any provided name does not exist.

        Returns
        -------
        ModelT
        """
        if isinstance(params, str):
            params = [params]
        params = set(params)
        current_params = self.named_params()
        current_param_names = set(current_params.keys())
        if check_unknown:
            for param_name in params:
                if param_name not in current_param_names:
                    raise Exception(f"Specified parameter '{param_name}' not found in model")
        new_params = current_params.copy()
        for name, param in current_params.items():
            if name in params:
                new_params[name] = param.as_fixed()
        return self.with_params(new_params)
    def with_free_params(self: ModelT, *params, fix_others=True, check_unknown=True) -> ModelT:
        """Free the specified parameters.

        Parameters
        ----------
        *params : str
            Parameter names to set free.
        fix_others : bool, default=True
            Fix parameters not specified.
        check_unknown : bool, default=True
            Error if any provided name does not exist.

        Returns
        -------
        ModelT
        """
        if isinstance(params, str):
            params = [params]
        params = set(params)
        current_params = self.named_params(include_fixed=True)
        current_param_names = set(current_params.keys())
        if check_unknown:
            for param_name in params:
                if param_name not in current_param_names:
                    raise Exception(f"Specified parameter '{param_name}' not found in model")
        new_params = current_params.copy()
        for name, param in current_params.items():
            if name in params:
                new_params[name] = param.as_free()
            elif fix_others:
                new_params[name] = param.as_fixed()
        return self.with_params(new_params)
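    # Illustrative sketch (not part of the original source): toggling which
    # parameters are free before fitting. Parameter names are hypothetical.
    #
    #     m = RLC()
    #     m_fixed = m.with_fixed_params('cap_C')                   # fix one parameter
    #     m_free  = m.with_free_params('res_R', fix_others=True)   # free only this one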
[docs] def with_free_submodels(self: ModelT, free_submodels: 'Model' | Sequence['Model'] | str | Sequence[str]) -> ModelT: """Fix all parameters **except** those appearing in the given submodels. The submodels can be any models that reference the parameters in this model. Specifically, `free_submodels` can consist of direct children of this model, submodels of those direct children, and any models that are built using the parameters of this model. Then, only the parameters that are found in both the submodels and this model are set to be free, and all others are fixed. Parameters ---------- free_submodels : Model | Sequence[Model] | str | Sequence[str] Submodels whose parameters should remain free. If strings are provided, they are resolved via ``getattr(self, name)``. Returns ------- ModelT """ # TODO this function ends up error in `delias` if `free_submodels` is empty if isinstance(free_submodels, Model) or isinstance(free_submodels, str): free_submodels: list[Model] = [free_submodels] if len(free_submodels) != 0 and isinstance(free_submodels[0], str): free_submodels: list[Model] = [getattr(self, name) for name in free_submodels] allowed_free_param_values = [param for source in free_submodels for param in source.named_params().values()] allowed_free_params = {k: v for k, v in self.named_params().items() if any(v is p for p in allowed_free_param_values)} return self.with_params(allowed_free_params, fix_others=True)
    def to_skrf(self, frequency: Frequency | skrf.Frequency, sigma=0.0, **kwargs) -> skrf.Network:
        """Convert the model, evaluated at the given frequencies, to an :class:`skrf.Network`.

        The active primary property (``self.primary_property``) is used.

        Parameters
        ----------
        frequency : pmrf.Frequency | skrf.Frequency
            Frequency grid.
        sigma : float, default=0.0
            If nonzero, add complex Gaussian noise with stdev ``sigma`` to ``s``.
        **kwargs
            Forwarded to the :class:`skrf.Network` constructor.

        Returns
        -------
        skrf.Network
        """
        if isinstance(frequency, Frequency):
            model_freq = frequency
            measured_freq = frequency.to_skrf()
        else:
            model_freq = Frequency.from_skrf(frequency)
            measured_freq = frequency
        fval, fname = self.primary(model_freq), self.primary_property
        kwargs = kwargs or {}
        kwargs.update({
            fname: fval,
            'frequency': measured_freq,
            'name': kwargs.get('name', self.name),
            'z0': self._z0,
        })
        ntwk = skrf.Network(**kwargs)
        if sigma != 0.0:
            ntwk.s += (np.random.normal(0, sigma, ntwk.s.shape)
                       + 1j * np.random.normal(0, sigma, ntwk.s.shape))
        return ntwk
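    # Illustrative sketch (not part of the original source): exporting a model
    # response to scikit-rf, optionally with additive complex noise on ``s``,
    # assuming a 201-point ``Frequency(start, stop, npoints)`` grid.
    #
    #     freq = Frequency(1, 2, 201)
    #     ntwk = PiCLC().to_skrf(freq, sigma=0.001, name='piclc')
    #     ntwk.s.shape                  # (201, 2, 2); z0 defaults to 50 ohm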
    def write_hdf(self, group: h5py.Group):
        """Write the model into an HDF5 group.

        Parameters
        ----------
        group : h5py.Group
            Target group. Two subgroups are created: ``raw`` and ``params``.
        """
        params_tree, static_tree = self.partition(include_fixed=True, param_objects=True)
        params = self.named_params()
        model_raw_grp = group.create_group('raw')
        model_raw_grp.create_dataset('params', data=jsonpickle.encode(params_tree))
        model_raw_grp.create_dataset('static', data=jsonpickle.encode(static_tree))
        params_grp = group.create_group('params')
        for name, initial_param in params.items():
            params_grp[name] = initial_param.to_json()
    @classmethod
    def read_hdf(cls, group: h5py.Group) -> ModelT:
        """Read a model from an HDF5 group.

        Parameters
        ----------
        group : h5py.Group
            Group previously created via :meth:`write_hdf`.

        Returns
        -------
        ModelT | None
            The decoded model, or ``None`` if decoding fails.
        """
        model_raw_grp = group['raw']
        params_json = model_raw_grp['params'][()]
        params_json = params_json.decode('utf-8') if isinstance(params_json, bytes) else params_json
        static_json = model_raw_grp['static'][()]
        static_json = static_json.decode('utf-8') if isinstance(static_json, bytes) else static_json
        try:
            params_tree = jsonpickle.decode(params_json)
            static_tree = jsonpickle.decode(static_json)
            # NB the following hack actually also BREAKS some model loading... we need to investigate further
            # The following fixes some quirks when e.g. the original model contains lambdas.
            # Not sure 100% why but some fields seem to be in a "bad" state when jsonpickle can't deserialize them
            # params_tree = dataclasses.replace(params_tree)
            # static_tree = dataclasses.replace(static_tree)
            return eqx.combine(params_tree, static_tree)
        except Exception:
            return None
    def save(self, filepath: str):
        """Serialize the model to disk as jsonpickle-encoded JSON.

        Parameters
        ----------
        filepath : str
            Destination file path.
        """
        json = jsonpickle.encode(self)
        with open(filepath, 'w') as f:
            f.write(json)
    @classmethod
    def load(cls: type[ModelT], filepath: str) -> ModelT:
        """Load a model from a jsonpickle-encoded JSON file.

        Parameters
        ----------
        filepath : str
            Source JSON path.

        Returns
        -------
        ModelT
        """
        with open(filepath, 'r') as f:
            return jsonpickle.decode(f.read())
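    # Illustrative sketch (not part of the original source): round-tripping a model
    # through the jsonpickle-based save/load helpers.
    #
    #     m = RLC()                                  # example class from the Model docstring
    #     m.save('rlc_model.json')
    #     m_loaded = Model.load('rlc_model.json')    # or RLC.load(...)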
    def export_touchstone(self, frequency: Frequency | skrf.Frequency, filename: str, sigma: float = 0.0, **skrf_kwargs):
        """Export the model response to a Touchstone file via scikit-rf.

        Parameters
        ----------
        frequency : Frequency | skrf.Frequency
        filename : str
        sigma : float, default=0.0
            Additive complex noise std for S-parameters.
        **skrf_kwargs
            Forwarded to :meth:`skrf.Network.write_touchstone`.

        Returns
        -------
        Any
            Return value of ``Network.write_touchstone``.
        """
        ntwk = self.to_skrf(frequency, sigma=sigma)
        retval = ntwk.write_touchstone(filename, **skrf_kwargs)
        return retval
[docs] def wrap( func: Callable, *args, as_numpy: bool = False, ) -> Callable: """Wrap a function/method ``(model, frequency, *args, **kwargs)`` to accept flat arrays. The wrapper accepts ``(theta, [f], *args, **kwargs)`` where ``theta`` is a flat parameter array, and optionally ``f`` is a frequency array with a provided unit. Supported calls --------------- - ``wrap(userfunc, model, "MHz")`` - ``wrap(userfunc, model, freq)`` - ``wrap(model.userfunc, "MHz")`` - ``wrap(model.userfunc, freq)`` If a :class:`Frequency` object is passed, it is closed over and the wrapped function only accepts ``theta``. Parameters ---------- func : Callable Function or bound method taking ``(model, frequency, ...)``. *args : tuple Either ``(model, Frequency|unit)`` or just ``(Frequency|unit)`` if ``func`` is bound. as_numpy : bool, default=False If ``True``, JIT the wrapper and convert inputs/outputs to NumPy arrays. Returns ------- Callable Wrapped function accepting ``(theta_array, [f_array], *args, **kwargs)``. Raises ------ TypeError If the frequency argument is neither a unit string nor ``Frequency``. """ # Determine if func is a bound method if hasattr(func, '__self__') and func.__self__ is not None: model: Model = func.__self__ func = func.__func__ args = list(args) else: model: Model = args[0] args = args[1:] # Now args[0] should be Frequency or unit frequency_or_unit = args[0] if args else "Hz" # Validate if not isinstance(frequency_or_unit, (str, Frequency)): raise TypeError("Expected Frequency or unit string as second argument") use_variable_frequency = isinstance(frequency_or_unit, str) # Define wrapped function def wrapped_with_freq(theta: jnp.ndarray, f_array: jnp.ndarray, *f_args, **f_kwargs): new_model = model.with_flat_params(theta) frequency = Frequency.from_f(f_array, unit=frequency_or_unit) return func(new_model, frequency, *f_args, **f_kwargs) def wrapped_fixed_freq(theta: jnp.ndarray, *f_args, **f_kwargs): new_model = model.with_flat_params(theta) return func(new_model, frequency_or_unit, *f_args, **f_kwargs) wrapped_fn = wrapped_with_freq if use_variable_frequency else wrapped_fixed_freq if as_numpy: raw_fn = jax.jit(wrapped_fn) if use_variable_frequency: wrapped_fn = lambda theta, f, *a, **kw: np.array(raw_fn(jnp.array(theta), jnp.array(f), *a, **kw)) else: wrapped_fn = lambda theta, *a, **kw: np.array(raw_fn(jnp.array(theta), *a, **kw)) # Forward metadata update_wrapper(wrapped_fn, func) return wrapped_fn
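# Illustrative sketch (not part of the original source): wrapping a bound model
# method so it accepts a flat parameter vector and a frequency array, e.g. for
# use with external NumPy-based optimizers. ``RLC`` is the example class from
# the Model docstring.
#
#     rlc = RLC()
#     s_fn = wrap(rlc.s, "MHz", as_numpy=True)
#     s_matrix = s_fn(rlc.flat_params(), np.linspace(1.0, 100.0, 201))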