"""pmrf.model
================
Core model base class and utilities for ParamRF.
This module defines :class:`Model`, a frozen, JAX-compatible, Equinox module
representing an RF network, along with helper utilities like :func:`wrap`.
"""
from functools import cached_property, partial
from copy import deepcopy
from typing import Callable, Sequence, TypeVar, Union, Iterator
import dataclasses
from dataclasses import fields, is_dataclass
from functools import update_wrapper
import jax.numpy as jnp
import numpy as np
import jsonpickle
from collections.abc import Sequence
from typing import Sequence, Callable, BinaryIO, TypeVar
import os
import skrf as skrf
import numpy as np
import h5py
import jax
import jax.numpy as jnp
from jax import flatten_util
from jaxtyping import PyTree
from jax import flatten_util
from jax.tree_util import SequenceKey, GetAttrKey, DictKey, SequenceKey, FlattenedIndexKey
import equinox as eqx
from numpyro.distributions import Distribution, Uniform as UniformDistribution
from pmrf.network_collection import NetworkCollection
from pmrf.functions.conversions import a2s, s2a
from pmrf.functions.math import FUNC_LOOKUP
from pmrf.parameters import Parameter, ParameterGroup, is_valid_param, asparam
from pmrf.distributions.parameter import JointParameterDistribution
from pmrf.constants import PRIMARY_PROPERTIES, IndexArray, ModelT
from pmrf.frequency import Frequency
from pmrf._util import field, classproperty, is_overridden, get_first_underlying_type
from pmrf._tree import nodes_by_type, value_at_path, partition, combine
# Self-type variable for Model methods such as `copy(self: ModelT) -> ModelT`, so the
# declared return type is the concrete subclass the method was called on.
# NOTE: this re-binds the `ModelT` imported from pmrf.constants above.
ModelT = TypeVar("ModelT", bound="Model")
[docs]
class Model(eqx.Module):
"""
Overview
--------
This base class is used to represent any computable RF network, referred to in
**ParamRF** as a "Model". This class is abstract and should not be
instantiated directly. Derive from :class:`Model` and override one of
the primary property functions (e.g. :meth:`__call__`, :meth:`s`, :meth:`a`).
The model is an Equinox ``Module`` (immutable, dataclass-like) and is
treated as a JAX PyTree. Parameters are declared using standard dataclass
field syntax with types like :class:`pmrf.Parameter`.
Usage
-----
- Define new models by sub-classing the model and adding custom parameters and/or sub-models
- Construct models by passing parameters and/or submodels to the initializer (like a dataclass).
- Retrieve parameter information via methods such as :meth:`named_params`, :meth:`param_names`, :meth:`flat_params`, etc.
- Use `with_xxx` functions to modify fields, models and parameters within the model e.g. :meth:`with_params`, :meth:`with_fields`.
- Use "past tense" functions to modify the model in conjunction with another model or data, e.g. :meth:`terminated`, :meth:`fitted`.
See also the :mod:`pmrf.fitting` and :mod:`pmrf.sampling` modules for details on model fitting and sampling.
Attributes
----------
name : str or None
An optional name for the model instance.
separator : str
The separator character used when flattening nested parameter names (default is '_').
metadata : dict
A dictionary for storing arbitrary metadata associated with the model (e.g., fit results).
Examples
--------
A ``PiCLC`` network ("foundational" model with fixed parameters and equations):
.. code-block:: python
import jax.numpy as jnp
import pmrf as prf
class PiCLC(prf.Model):
C1: prf.Parameter = 1.0e-12
L: prf.Parameter = 1.0e-9
C2: prf.Parameter = 1.0e-12
def a(self, freq: prf.Frequency) -> jnp.ndarray:
w = freq.w
Y1, Y2, Y3 = (1j * w * self.C1), (1j * w * self.C2), 1 / (1j * w * self.L)
return jnp.array([
[1 + Y2 / Y3, 1 / Y3],
[Y1 + Y2 + Y1*Y2/Y3, 1 + Y1 / Y3],
]).transpose(2, 0, 1)
An ``RLC`` network ("circuit" model with free parameters built using cascading)
.. code-block:: python
import pmrf as prf
from pmrf.models import Resistor, Capacitor, Inductor, Cascade
from pmrf.parameters import Uniform
class RLC(prf.Model):
res: Resistor = Resistor(Uniform(9.0, 11.0))
ind: Inductor = Inductor(Uniform(0.0, 10.0, scale=1e-9))
cap: Capacitor = Capacitor(Uniform(0.0, 10.0, scale=1e-12))
def __call__(self) -> prf.Model:
return self.res ** self.ind ** self.cap.terminated()
"""
# Public instance fields
# All are keyword-only and declared static=True via pmrf._util.field
# (presumably mirroring eqx.field(static=True), i.e. not PyTree leaves - confirm in pmrf._util).
# Optional instance name; _path_to_param_name uses it in place of a list variable name and index.
name: str | None = field(default=None, kw_only=True, static=True)
# Joins nested field names when building fully-qualified parameter names.
separator: str = field(default='_', kw_only=True, static=True)
# Arbitrary user/fitting metadata (e.g. fit results).
metadata: dict = field(default_factory=lambda: dict(), kw_only=True, static=True)
# Private instance fields
# Reference impedance, exposed through the `z0` property and used in the s <-> a conversions.
_z0: complex = field(default=50.0+0j, kw_only=True, static=True)
# Explicitly registered parameter groups; `param_groups` adds singleton groups for the rest.
_param_groups: list[ParameterGroup] = field(default_factory=lambda: list(), kw_only=True, repr=False, static=True)
# ---- Internal initialization methods -------------------------------------------------
def __init_subclass__(cls, **kwargs):
    """Customize subclass construction.

    - Applies default names to default ``Model``/``Parameter`` fields.
    - Adds converters for ``Parameter``-typed fields (``asparam``).
    - Protects against mutable-default pitfalls by using ``default_factory``
      (which replaces any plain ``default``; dataclasses forbids both).
    - Dynamically injects convenience methods like ``s_mag``, ``s_mn_mag``,
      based on :data:`PRIMARY_PROPERTIES` and :data:`FUNC_LOOKUP`. A method
      of the same name written by hand on this class (or on any base class
      below :class:`Model`) is left untouched.

    Raises
    ------
    Exception
        If a ``Parameter``-annotated field has a tuple as its default.
    """
    super().__init_subclass__(**kwargs)

    # Defaults of these types must not be shared between instances, so they are
    # supplied through a deep-copying default_factory instead.
    mutable_types = (list, dict, set, tuple, Parameter, Model, jnp.ndarray, np.ndarray)

    # Pre-process fields by modifying any defaults as needed, as well as automatically applying converters
    for field_name, field_types in cls.__annotations__.items():
        field_kwargs = {}
        default = getattr(cls, field_name, None)
        if default is None or isinstance(default, dataclasses.Field):
            continue

        # Replace default model and parameter names with the field name
        if isinstance(default, Model) and default.name is None:
            default = dataclasses.replace(default, name=field_name)
        if isinstance(default, Parameter) and default.name is None:
            default = dataclasses.replace(default, name=field_name)

        # Auto apply asparam converter, to allow auto-conversion of Parameter-annotated structures
        field_type = get_first_underlying_type(field_types)
        if isinstance(field_type, type) and issubclass(field_type, Parameter):
            if not isinstance(default, Parameter):
                # Common mistake
                if isinstance(default, tuple):
                    raise Exception(f"Expected a parameter for default '{field_name}' in class {cls} but found a tuple instead")
                field_kwargs['default'] = default
            field_kwargs['converter'] = lambda x, field_name=field_name: asparam(x, name=field_name, fixed=True)

        # Apply default_factory to avoid Python default mutable trap. dataclasses raises if
        # both `default` and `default_factory` are given, so the factory supersedes the default.
        if isinstance(default, mutable_types):
            field_kwargs.pop('default', None)
            field_kwargs['default_factory'] = lambda default=default: deepcopy(default)

        # Set final field
        if field_kwargs:
            setattr(cls, field_name, eqx.field(**field_kwargs))

    # Implement dynamic functions
    def make_dynamic_method(prop, func):
        # Resolve `prop` on the instance at call time, then post-process the matrix with `func`.
        def dynamic_method(self, *args, **kwargs):
            prop_fn = getattr(self, prop)
            matrix = prop_fn(*args, **kwargs)
            return func(matrix)
        dynamic_method._pmrf_auto = True
        return dynamic_method

    def is_user_defined(attr_name):
        # True if a class between `cls` and Model (inclusive of `cls`) defines `attr_name`
        # by hand, i.e. not as one of the auto-generated methods.
        for klass in cls.__mro__:
            if klass is Model:
                break
            attr = klass.__dict__.get(attr_name)
            if attr is not None and not getattr(attr, '_pmrf_auto', False):
                return True
        return False

    for prop in PRIMARY_PROPERTIES:
        for suffix, lookup in FUNC_LOOKUP.items():
            func = lookup[1]
            # The regular (non-indexed) function e.g. s_mag, then the indexed one e.g. s_mn_mag
            variants = (
                (f"{prop}_{suffix}", prop),
                (f"{prop}_mn_{suffix}", f"{prop}_mn"),
            )
            for func_name, source_prop in variants:
                if is_user_defined(func_name):
                    continue
                setattr(cls, func_name, make_dynamic_method(source_prop, func))
# ---- Internal PyTree manipulation, introspection and helpers -------------------------------------------------
@property
def _param_value_spec(self) -> PyTree:
    """Boolean PyTree mask selecting **all** parameter value arrays in the model.

    A leaf is ``True`` when it is an inexact array whose last path component
    is named ``value`` (typically a :class:`Parameter`'s value); every other
    leaf is ``False``.
    """
    def _selects(key_path, leaf):
        # Reject the root and anything that is not an inexact array.
        if not key_path or not eqx.is_inexact_array(leaf):
            return False
        return getattr(key_path[-1], 'name', None) == 'value'

    return jax.tree.map_with_path(_selects, self, is_leaf=eqx.is_inexact_array)
@property
def _param_object_spec(self) -> PyTree:
    """Boolean PyTree mask marking every :class:`Parameter` object in the model.

    Nodes accepted by ``is_valid_param`` are treated as leaves and mapped to
    the result of ``is_valid_param`` (i.e. ``True``); other leaves map to
    ``False``.
    """
    mask = jax.tree.map(is_valid_param, self, is_leaf=is_valid_param)
    return mask
@property
def _core_value_spec(self) -> PyTree:
    """PyTree spec for core (non-derived) parameter values.

    Starts from :attr:`_param_value_spec` and additionally de-selects any value
    located underneath a dataclass field whose metadata has ``derived=True``.

    Returns
    -------
    PyTree
        Same structure as ``self._param_value_spec``; a leaf is ``True`` only
        for parameter values that are not beneath a derived field.
    """
    # A parameter is derived if any of its parent dataclasses have 'derived' set to True in its field metadata.
    # Maps path -> "is (inside) a derived field". It is populated as a side effect of `is_leaf`
    # below, which jax calls on each node while walking the spec tree.
    path_is_derived = {}
    def is_leaf(path, node):
        # Mark the direct children of this dataclass that are declared as derived fields.
        # NOTE: `field` here shadows the module-level `field` helper within this function only.
        if is_dataclass(node):
            for field in fields(node):
                if field.metadata.get('derived', False):
                    field_path = path + (GetAttrKey(field.name),)
                    path_is_derived[field_path] = True
        # Set base path as not being derived, and this path's derived as equal to the parents if not already set
        if len(path) == 0:
            path_is_derived[path] = False
        else:
            # Relies on the parent path having been visited (and recorded) before this one.
            path_is_derived.setdefault(path, path_is_derived[path[0:-1]])
        # Spec leaves are bools; stop descending there.
        return isinstance(node, bool)
    return jax.tree.map_with_path(lambda path, node: node and not path_is_derived[path], self._param_value_spec, is_leaf=is_leaf, is_leaf_takes_path=True)
@property
def _core_object_spec(self) -> PyTree:
    """PyTree spec for core (non-derived) :class:`Parameter` objects.

    NOTE(review): this currently returns :attr:`_param_object_spec` unchanged,
    so derived parameters are *not* filtered out (unlike
    :attr:`_core_value_spec`). It is used by ``_iter_params`` and
    ``with_params``, so any change in behaviour affects both.
    """
    return self._param_object_spec
    # Disabled filtering attempt kept for reference; note it refers to `self.core_value_spec`,
    # whereas the visible property is named `_core_value_spec`.
    # return jax.tree.map(lambda param, core_spec: is_valid_param(param) and core_spec.value, self, self.core_value_spec, is_leaf=lambda node: is_valid_param(node))
@property
def _free_value_spec(self) -> PyTree:
    """Boolean PyTree mask selecting the values of **free** (non-fixed) parameters.

    Built from :attr:`_core_value_spec`: a selected value remains selected only
    when the :class:`Parameter` holding it has ``fixed`` set to ``False``.

    Raises
    ------
    Exception
        If a selected array is not held directly by a :class:`Parameter`.
    """
    def _is_free(key_path, selected):
        if selected:
            # The node one level up from a value array must be its owning Parameter.
            owner = value_at_path(self, key_path[:-1])
            if not isinstance(owner, Parameter):
                raise Exception(f"Found an array in a model not within a parameter: this is not allowed (at path {key_path})")
            return not owner.fixed
        return False

    return jax.tree.map_with_path(_is_free, self._core_value_spec)
@property
def _free_object_spec(self) -> PyTree:
    """Boolean PyTree mask marking **free** :class:`Parameter` objects.

    Each Parameter leaf takes the flag that :attr:`_free_value_spec` assigns
    to its ``value``; non-parameter leaves are not selected.
    """
    def _mark(node, value_mask):
        is_param = is_valid_param(node)
        # Short-circuit exactly like `is_param and value_mask.value`.
        return value_mask.value if is_param else is_param

    return jax.tree.map(_mark, self, self._free_value_spec, is_leaf=is_valid_param)
def _path_to_param_name(self, path) -> str:
    """Convert a PyTree key path into a fully-qualified parameter name.

    Path components are joined with ``self.separator``:

    - ``GetAttrKey``: the attribute name is appended.
    - ``DictKey``: the preceding component (the dict's attribute name) is
      replaced by the stringified dict key.
    - ``SequenceKey`` / ``FlattenedIndexKey``: if the indexed element has a
      non-``None`` ``name``, it replaces the preceding component (the container's
      attribute name); otherwise the index is appended.

    Parameters
    ----------
    path : tuple
        Key path, as produced by ``jax.tree.flatten_with_path`` on this model.

    Returns
    -------
    str
    """
    parts: list[str] = []
    for i, key in enumerate(path):
        if isinstance(key, GetAttrKey):
            parts.append(key.name)
        elif isinstance(key, DictKey):
            if parts:
                parts.pop()  # don't include the dict variable name in the path
            # Dict keys need not be strings; join() requires them to be.
            parts.append(str(key.key))
        elif isinstance(key, (SequenceKey, FlattenedIndexKey)):
            # Named elements (e.g. named sub-models in a list) use their name instead of
            # the list variable name and index.
            value = value_at_path(self, path[:i + 1])
            name = getattr(value, 'name', None)
            if name is not None:
                if parts:
                    parts.pop()
                parts.append(str(name))
            else:
                # SequenceKey stores the position in `idx`; FlattenedIndexKey stores it in `key`.
                index = key.idx if isinstance(key, SequenceKey) else key.key
                parts.append(str(index))
    return self.separator.join(parts)
def _partition(self: ModelT, include_fixed=False, param_objects=False) -> tuple[ModelT, ModelT]:
    """Split the model into a parameter tree and a static tree.

    Intended for internal use, or for inspecting the model and its parameters.

    Parameters
    ----------
    include_fixed : bool, default=False
        Put fixed parameters in the parameter tree as well.
    param_objects : bool, default=False
        If ``True``, the parameter tree holds whole :class:`Parameter`
        objects; otherwise it holds only their ``.value`` arrays.

    Returns
    -------
    (ModelT, ModelT)
        ``(params_tree, static_tree)``.
    """
    if not param_objects:
        shared_spec = self._param_value_spec
        filter_spec = self._core_value_spec if include_fixed else self._free_value_spec
    else:
        shared_spec = self._param_object_spec
        filter_spec = self._core_object_spec if include_fixed else self._free_object_spec
    return partition(self, filter_spec, shared_spec)
def _iter_params(
    self,
    *,
    include_fixed: bool = False,
    flatten: bool = False,
    submodels: 'Model | Sequence[Model] | str | Sequence[str] | None' = None,
) -> Iterator[tuple[str, Parameter]]:
    """Yield ``(name, Parameter)`` pairs in the model's internal order.

    Parameters
    ----------
    include_fixed : bool, default=False
        Also yield fixed parameters.
    flatten : bool, default=False
        Expand parameters with ``ndim > 1`` through ``Parameter.flattened``,
        naming the pieces ``{name}{separator}{index}``.
    submodels : Model | Sequence[Model] | str | Sequence[str] | None, optional
        Keep only parameters (matched by identity) that belong to these
        submodel(s); strings are resolved with ``getattr(self, name)``.
    """
    mask = self._core_object_spec if include_fixed else self._free_object_spec
    selected = eqx.filter(self, mask, is_leaf=is_valid_param)
    leaves_with_paths, _ = jax.tree.flatten_with_path(selected, is_leaf=is_valid_param)
    named: list[tuple[str, Parameter]] = [
        (self._path_to_param_name(key_path), leaf) for key_path, leaf in leaves_with_paths
    ]

    # Submodel filtering
    if submodels is not None:
        if isinstance(submodels, (Model, str)):
            submodels = [submodels]
        if submodels and isinstance(submodels[0], str):
            submodels = [getattr(self, attr_name) for attr_name in submodels]
        allowed_ids = {
            id(p) for sub in submodels for p in sub.params(include_fixed=include_fixed)
        }
        named = [(n, p) for n, p in named if id(p) in allowed_ids]

    # De-vectorize multi-dimensional parameters when requested
    if flatten:
        expanded: list[tuple[str, Parameter]] = []
        for n, p in named:
            if p.ndim > 1:
                pieces = p.flattened(separator=self.separator)
                expanded.extend((f"{n}{self.separator}{idx}", piece) for idx, piece in enumerate(pieces))
            else:
                expanded.append((n, p))
        named = expanded

    yield from named
# ---- Defaults / Primary ---------------------------------------------------
@classproperty
def DEFAULT_NAMED_PARAMS(cls) -> dict[str, Parameter]:
    """Named parameters of a default-constructed instance of this class.

    Returns
    -------
    dict[str, Parameter]
        ``cls().named_params()``: parameter name -> :class:`Parameter`.
    """
    return cls().named_params()
@classproperty
def DEFAULT_PARAM_NAMES(cls) -> list[str]:
    """Parameter names of a default-constructed instance of this class.

    Returns
    -------
    list[str]
        ``cls().param_names()``.
    """
    return cls().param_names()
@classproperty
def DEFAULT_PARAMS(cls) -> list[Parameter]:
    """Parameters of a default-constructed instance of this class.

    Returns
    -------
    list[Parameter]
        ``cls().params()``.
    """
    return cls().params()
@property
def primary_function(self) -> Callable[[Frequency], jnp.ndarray]:
    """Bound method implementing the model's primary property (e.g. ``s`` or ``a``).

    Equivalent to ``getattr(self, self.primary_property)``; see
    :attr:`primary_property` for how that property is chosen.

    Returns
    -------
    Callable[[Frequency], jnp.ndarray]

    Raises
    ------
    NotImplementedError
        If no primary property is overridden.
    """
    property_name = self.primary_property
    return getattr(self, property_name)
@property
def primary_property(self) -> str:
    """Name of the primary property (e.g. ``"s"``, ``"a"``).

    If ``__call__`` is overridden, the primary property of the built model is
    returned. Otherwise it is the first entry of :data:`PRIMARY_PROPERTIES`
    that the subclass overrides.

    Returns
    -------
    str

    Raises
    ------
    NotImplementedError
        If no primary property is overridden.
    """
    model_cls = type(self)
    # Compositional models defer to whatever model their __call__ builds.
    if is_overridden(model_cls, Model, '__call__'):
        return self().primary_property

    prioritized = ()  # for future expansion
    search_order = prioritized + tuple(p for p in PRIMARY_PROPERTIES if p not in prioritized)
    for candidate in search_order:
        if is_overridden(model_cls, Model, candidate):
            return candidate
    raise NotImplementedError(f"No primary properties in {PRIMARY_PROPERTIES} are overriden, which are the only ones supported currently")
[docs]
def copy(self: ModelT) -> ModelT:
    """Return an independent deep copy of this model.

    Returns
    -------
    Model
        ``copy.deepcopy(self)``.
    """
    duplicate = deepcopy(self)
    return duplicate
# ---- Introspection properties --------------------------------------------------------
@cached_property
@eqx.filter_jit
def number_of_ports(self) -> int:
    """Number of ports, read off the shape of :meth:`s`.

    :meth:`s` is only shape-evaluated (``jax.eval_shape``) on the small probe
    grid ``Frequency(1, 2, 2)``. The result is cached per instance.

    Returns
    -------
    int
    """
    probe_freq = Frequency(1, 2, 2)
    s_struct = jax.eval_shape(lambda: self.s(probe_freq))
    # S-parameters have shape (nf, n, n); the port count is axis 1.
    return s_struct.shape[1]
@property
def nports(self) -> int:
    """Short alias for :attr:`number_of_ports`."""
    port_count = self.number_of_ports
    return port_count
@property
def port_tuples(self) -> list[tuple[int, int]]:
    """Every ``(m, n)`` port index pair, with ``m`` varying fastest.

    Returns
    -------
    list[tuple[int, int]]
    """
    ports = range(self.nports)
    return [(row, col) for col in ports for row in ports]
@property
def z0(self) -> jnp.ndarray:
    """Reference impedance stored in ``_z0``, returned as a JAX array.

    Returns
    -------
    jnp.ndarray
    """
    impedance = self._z0
    return jnp.array(impedance)
@cached_property
def num_params(self) -> int:
    """Count of free parameters (cached per instance).

    Returns
    -------
    int
        ``len(self.params())``.
    """
    free_params = self.params()
    return len(free_params)
@cached_property
def num_flat_params(self) -> int:
    """Count of free, **flattened** parameters (cached per instance).

    Returns
    -------
    int
        ``len(self.flat_params())``.
    """
    free_flat_params = self.flat_params()
    return len(free_flat_params)
# ---- Core API -------------------------------------------------------------
[docs]
def __call__(self) -> 'Model':
    """Build the model out of sub-models.

    Sub-classes assembled from several sub-models (rather than defined by
    equations) override this to return the composed model.

    Returns
    -------
    Model

    Raises
    ------
    NotImplementedError
        Always, in the base class.
    """
    message = "Error: cannot use the __call__ method to build a model in the Model class directly"
    raise NotImplementedError(message)
[docs]
@eqx.filter_jit
def primary(self, freq: Frequency) -> jnp.ndarray:
    """Evaluate :attr:`primary_function` on ``freq``."""
    return self.primary_function(freq)
[docs]
@eqx.filter_jit
def a(self, freq: Frequency) -> jnp.ndarray:
    """ABCD parameter matrix at ``freq``.

    This base implementation is used when a subclass does not override ``a``:

    1. if ``__call__`` is overridden, the built model's :meth:`a` is returned;
    2. otherwise, if :meth:`s` is overridden, its result is converted with
       :func:`pmrf.functions.conversions.s2a` using :attr:`z0`.

    Parameters
    ----------
    freq : Frequency
        Frequency grid.

    Returns
    -------
    jnp.ndarray
        ABCD matrix with shape ``(nf, 2, 2)``.

    Raises
    ------
    NotImplementedError
        If neither ``a`` nor ``s`` is implemented in a subclass.
    """
    model_cls = type(self)
    if is_overridden(model_cls, Model, '__call__'):
        built = self()
        return built.a(freq)
    if is_overridden(model_cls, Model, 's'):
        return s2a(self.s(freq), self.z0)
    raise NotImplementedError(f"Error: model sub-classes currently *have* to implement the 's' or the 'a' function, but class {type(self)} has neither")
[docs]
@eqx.filter_jit
def s(self, freq: Frequency) -> jnp.ndarray:
    """Scattering parameter matrix at ``freq``.

    This base implementation is used when a subclass does not override ``s``:

    1. if ``__call__`` is overridden, the built model's :meth:`s` is returned;
    2. otherwise, if :meth:`a` is overridden, its result is converted with
       :func:`pmrf.functions.conversions.a2s` using :attr:`z0`.

    Parameters
    ----------
    freq : Frequency
        Frequency grid.

    Returns
    -------
    jnp.ndarray
        S-parameter matrix with shape ``(nf, n, n)``.

    Raises
    ------
    NotImplementedError
        If neither ``a`` nor ``s`` is implemented in a subclass.
    """
    model_cls = type(self)
    if is_overridden(model_cls, Model, '__call__'):
        built = self()
        return built.s(freq)
    if is_overridden(model_cls, Model, 'a'):
        return a2s(self.a(freq), self.z0)
    raise NotImplementedError(f"Error: model sub-classes currently *have* to implement the 's' or the 'a' function, but class {type(self)} has neither")
# ---- Structure utilities --------------------------------------------------
[docs]
def children(self) -> list['Model']:
    """Sub-models held directly by this model (one tree level down).

    Returns
    -------
    list[Model]
    """
    direct_nodes = eqx.tree_flatten_one_level(self)[0]
    found = []
    for node in direct_nodes:
        if isinstance(node, Model):
            found.append(node)
    return found
[docs]
def submodels(self) -> list['Model']:
    """Every nested sub-model in depth-first order, excluding ``self``.

    Returns
    -------
    list[Model]
    """
    every_model = nodes_by_type(self, Model)
    # The first entry is the model itself; drop it.
    return every_model[1:]
# ---- Container model building --------------------------------------------------
def __pow__(self, other: 'Model') -> 'Model':
    """Cascade operator: ``self ** other`` chains ``self`` then ``other``.

    Returns
    -------
    Model
        A :class:`pmrf.models.containers.Cascade` of the two models.
    """
    from pmrf.models.containers import Cascade
    chain = [self, other]
    return Cascade(chain)
[docs]
def connected(self, others: 'Model' | Sequence['Model'], ports: Sequence[int | Sequence[int]]) -> 'Model':
    """Return a model in which this model is connected to ``others``.

    See :class:`pmrf.models.containers.Connected`.

    Parameters
    ----------
    others : Model | Sequence[Model]
        Model(s) to connect to this one; a single model is wrapped in a list.
    ports : Sequence[int | Sequence[int]]
        Port connection specification, forwarded unchanged to ``Connected``.

    Returns
    -------
    Model
        ``Connected([self, *others], ports)``.
    """
    from pmrf.models.containers import Connected
    if isinstance(others, Model):
        others = [others]
    return Connected([self, *others], ports)
[docs]
def flipped(self) -> 'Model':
    """Return a version of the model with its ports flipped.

    An already-flipped model is unwrapped rather than wrapped a second time.

    Returns
    -------
    Model
    """
    from pmrf.models.containers import Flipped
    return self.model if isinstance(self, Flipped) else Flipped(self)
[docs]
def terminated(self, load: 'Model | None' = None) -> 'Model':
    """Return this 2-port model terminated in a 1-port load (default: SHORT).

    Parameters
    ----------
    load : Model, optional
        Load network. ``None`` (the default) selects
        :data:`pmrf.models.lumped.SHORT`.

    Returns
    -------
    Model
        ``Cascade((self, load))``.
    """
    from pmrf.models.lumped import SHORT
    from pmrf.models.containers import Cascade
    # Compare with None explicitly: a Model subclass could define __len__/__bool__,
    # which would make a truthiness test silently swap a valid load for SHORT.
    if load is None:
        load = SHORT
    return Cascade((self, load))
# ---- Model fitting --------------------------------------------------
[docs]
def fitted(self: ModelT, measured: str | skrf.Network | NetworkCollection, **kwargs) -> ModelT:
    """Return a copy of this model fitted to measured data.

    Convenience wrapper around :func:`pmrf.fitting.fit` and an alternative to
    using ParamRF's :mod:`fitting <pmrf.fitting>` module directly; see the
    :func:`fit <pmrf.fitting.BaseFitter.fit>` method for details. As described
    there, a fitter is created and fitted, and the fit results are stored in
    the returned model's metadata under ``'fit_results'``. Keyword arguments
    are split into fitter-initialisation and fit arguments.

    Parameters
    ----------
    measured : prf.NetworkCollection | skrf.Network | str
        The measured data.
    **kwargs
        Forwarded to :func:`pmrf.fitting.Fitter` and :func:`pmrf.fitting.BaseFitter.fit`.

    Returns
    -------
    Model
        The fitted model.
    """
    from pmrf.fitting import fit as _fit
    return _fit(self, measured, **kwargs)
[docs]
def with_submodels_fitted(self: ModelT, measured: str | skrf.Network | NetworkCollection, **kwargs) -> ModelT:
    """Return a copy of this model with its sub-models fitted to measured data.

    Convenience wrapper around :func:`pmrf.fitting.fit_submodels` and an
    alternative to using ParamRF's :mod:`fitting <pmrf.fitting>` module
    directly; see the :func:`fit_submodels <pmrf.fitting.BaseFitter.fit_submodels>`
    method for details. As described there, fit results are stored in the
    returned model's metadata under ``'fit_results'``, and keyword arguments
    are split into fitter-initialisation and fit arguments.

    Parameters
    ----------
    measured : prf.NetworkCollection | skrf.Network | str
        The measured data.
    **kwargs
        Forwarded to :func:`pmrf.fitting.Fitter` and :func:`pmrf.fitting.BaseFitter.fit_submodels`.

    Returns
    -------
    Model
        The fitted model.
    """
    from pmrf.fitting import fit_submodels as _fit_submodels
    return _fit_submodels(self, measured, **kwargs)
# ---- Parameter inspection --------------------------------------------------
[docs]
def named_params(self, include_fixed=False, submodels: 'Model' | Sequence['Model'] | str | Sequence[str] | None = None) -> dict[str, Parameter]:
    """Model parameters keyed by fully-qualified name.

    Iteration order matches the internal flattened array order.

    Parameters
    ----------
    include_fixed : bool, default=False
        Also include fixed parameters.
    submodels : Model | Sequence[Model] | str | Sequence[str] | None, optional
        Keep only parameters used by the given submodel(s); strings are
        resolved with ``getattr(self, name)``.

    Returns
    -------
    dict[str, Parameter]
    """
    pairs = self._iter_params(include_fixed=include_fixed, submodels=submodels)
    return {name: param for name, param in pairs}
[docs]
def named_param_values(self, scaled=False, **kwargs) -> dict[str, jnp.ndarray]:
    """Parameter values keyed by fully-qualified name.

    See :meth:`named_params`.

    Parameters
    ----------
    scaled : bool, default=False
        If ``True``, each entry is ``jnp.array(param)`` (scaled by the
        parameter's scale); otherwise it is the raw ``param.value``.
    **kwargs
        Forwarded to the parameter iteration, as in :meth:`named_params`.

    Returns
    -------
    dict[str, jnp.ndarray]
    """
    extract = (lambda p: jnp.array(p)) if scaled else (lambda p: p.value)
    return {name: extract(param) for name, param in self._iter_params(**kwargs)}
[docs]
def param_names(self, *args, **kwargs) -> list[str]:
    """Parameter names as a list (the keys of :meth:`named_params`).

    All arguments are forwarded to :meth:`named_params`.
    """
    return [*self.named_params(*args, **kwargs)]
[docs]
def params(self, *args, **kwargs) -> list[Parameter]:
    """Parameters as a list (the values of :meth:`named_params`).

    All arguments are forwarded to :meth:`named_params`.
    """
    named = self.named_params(*args, **kwargs)
    return [*named.values()]
[docs]
def params_values(self, *args, **kwargs) -> list[jnp.ndarray]:
    """Parameter values as a list (the values of :meth:`named_param_values`).

    All arguments are forwarded to :meth:`named_param_values`.
    """
    named = self.named_param_values(*args, **kwargs)
    return [*named.values()]
[docs]
def named_flat_params(self, include_fixed=False, submodels: 'Model' | Sequence['Model'] | str | Sequence[str] | None = None) -> dict[str, Parameter]:
    """Flattened (de-vectorized) model parameters keyed by name.

    Multi-dimensional parameters are split into individual parameters whose
    names carry an index suffix, so the returned objects are not necessarily
    the objects stored in the model. Iteration order matches the internal
    flattened array order.

    Parameters
    ----------
    include_fixed : bool, default=False
        Also include fixed parameters.
    submodels : Model | Sequence[Model] | str | Sequence[str] | None, optional
        Keep only parameters used by the given submodel(s); strings are
        resolved with ``getattr(self, name)``.

    Returns
    -------
    dict[str, Parameter]
    """
    flat_pairs = self._iter_params(flatten=True, include_fixed=include_fixed, submodels=submodels)
    return {name: param for name, param in flat_pairs}
[docs]
def named_flat_param_values(self, scaled=False, **kwargs) -> dict[str, jnp.ndarray]:
    """Flattened parameter values keyed by name.

    See :meth:`named_flat_params`.

    Parameters
    ----------
    scaled : bool, default=False
        If ``True``, each entry is ``jnp.array(param)`` (scaled by the
        parameter's scale); otherwise it is the raw ``param.value``.
    **kwargs
        Forwarded to the parameter iteration, as in :meth:`named_params`.

    Returns
    -------
    dict[str, jnp.ndarray]
    """
    extract = (lambda p: jnp.array(p)) if scaled else (lambda p: p.value)
    return {name: extract(param) for name, param in self._iter_params(flatten=True, **kwargs)}
[docs]
def flat_param_names(self, *args, **kwargs) -> list[str]:
    """Flattened parameter names as a list (the keys of :meth:`named_flat_params`).

    All arguments are forwarded to :meth:`named_flat_params`.
    """
    return [*self.named_flat_params(*args, **kwargs)]
[docs]
def flat_params(self, *args, **kwargs) -> list[Parameter]:
    """Flattened parameters as a list (the values of :meth:`named_flat_params`).

    All arguments are forwarded to :meth:`named_flat_params`.
    """
    named = self.named_flat_params(*args, **kwargs)
    return [*named.values()]
[docs]
def flat_param_values(self, *args, **kwargs) -> jnp.ndarray:
    """Flattened parameter values as a single 1-D JAX array.

    See :meth:`named_flat_param_values`; all arguments are forwarded to it.
    """
    values = [*self.named_flat_param_values(*args, **kwargs).values()]
    return jnp.ravel(jnp.array(values))
[docs]
def param_groups(self, include_fixed=False) -> list[ParameterGroup]:
    """Return all parameter groups relevant to this model.

    Useful for fitters that need to take into account e.g. constraints or
    correlated priors. The result is every group registered in
    ``self._param_groups``, followed by one single-parameter group (using the
    parameter's own ``distribution``) for each flat parameter that none of
    the registered groups covers.

    Currently, only groups for the current model are considered (groups
    registered on child models are ignored).

    Parameters
    ----------
    include_fixed : bool, default=False
        Also create singleton groups for fixed flat parameters.
        NOTE(review): registered groups in ``_param_groups`` are returned
        regardless of this flag.

    Returns
    -------
    list[ParameterGroup]
    """
    flat_params = self.named_flat_params(include_fixed=include_fixed)
    groups = list(self._param_groups)
    # NOTE(review): groups are built with `param_names=` but read via `.parameter_names`;
    # confirm ParameterGroup exposes both.
    grouped_names = {name for group in groups for name in group.parameter_names}
    groups.extend(
        ParameterGroup(param_names=[name], dist=param.distribution)
        for name, param in flat_params.items()
        if name not in grouped_names
    )
    return groups
[docs]
def distribution(self) -> JointParameterDistribution:
    """Joint distribution over the model's free flattened parameters.

    Built from :meth:`param_groups` and :meth:`flat_param_names`.

    Returns
    -------
    JointParameterDistribution
    """
    groups = self.param_groups()
    names = self.flat_param_names()
    return JointParameterDistribution(groups, names)
# ---- Parameter manipulation --------------------------------------------------
[docs]
def with_params(
    self: ModelT,
    params: dict[str, Parameter] | dict[str, float] | jnp.ndarray | None = None,
    check_missing: bool = False,
    check_unknown: bool = False,
    fix_others = False,
    include_fixed = False,
    **param_kwargs: dict[str, Parameter] | dict[str, float],
) -> ModelT:
    """Return a new model with parameters updated.

    The behaviour depends on the type of ``params``:

    - ``dict``, keyword arguments, or nothing at all: entries are matched by
      fully-qualified parameter name. A value convertible with ``float()``
      replaces only the parameter's ``value`` (and resets ``scale`` to
      ``1.0``); any other value (normally a :class:`Parameter`) replaces the
      parameter outright. This is decided per entry, so mixed mappings work.
      Neither the caller's dict nor ``param_kwargs`` is mutated.
    - array-like: **all** flat values must be provided, in the same order
      as :meth:`flat_param_values` with the same ``include_fixed`` setting.

    Parameters
    ----------
    params : dict[str, Parameter] | dict[str, float] | jnp.ndarray | None, optional
        Parameter updates. If an array, **all** values must be provided
        (matching ``flat_params`` order). You may also pass keyword args.
    check_missing : bool, default=False
        Require that all model parameters (fixed and free) are specified.
    check_unknown : bool, default=False
        Error if unknown parameter keys are provided (otherwise they are ignored).
    fix_others : bool, default=False
        Fix any parameters not explicitly passed.
    include_fixed : bool, default=False
        For array input, also assign values to fixed parameters. Not yet
        supported for mapping input.
    **param_kwargs : dict
        Additional parameter updates by name.

    Returns
    -------
    ModelT

    Raises
    ------
    Exception
        If shape/order mismatches, unknown/missing names (when checked),
        or if arrays are found outside of Parameters.
    """
    # Mapping-style update (also used when nothing at all is passed)
    if params is None or isinstance(params, dict) or len(param_kwargs) != 0:
        if include_fixed:
            raise Exception('Not yet supported')
        # Merge into a fresh dict so that the caller's mapping is never mutated.
        updates = dict(params) if params is not None else {}
        updates.update(param_kwargs)

        # Ordered name -> Parameter mapping for every parameter (fixed and free); its order
        # is expected to match the flattened parameter-object tree used further below.
        new_params = self.named_params(include_fixed=True)

        # Validate the caller's input
        unknown_params = set(updates.keys() - new_params.keys())
        if check_unknown and len(unknown_params) != 0:
            raise Exception(f"Error: the following parameters were passed but are not in the model: {unknown_params}")
        updates = {k: v for k, v in updates.items() if k not in unknown_params}
        if check_missing or fix_others:
            missing_params = set(new_params.keys() - updates.keys())
            if check_missing and len(missing_params) != 0:
                raise Exception(f"Error: the following model parameters were missing: {missing_params}")
            if fix_others:
                for missing_param_name in missing_params:
                    new_params[missing_param_name] = dataclasses.replace(new_params[missing_param_name], fixed=True)

        def is_convertible_to_float(x):
            try:
                float(x)
                return True
            except (ValueError, TypeError):
                return False

        for name, value in updates.items():
            if is_convertible_to_float(value):
                # TODO create specs for the full parameter objects such that we can get and use the built-in scales
                new_params[name] = dataclasses.replace(new_params[name], value=jnp.array(value), scale=1.0)
            else:
                new_params[name] = value

        # Get the current parameter objects in tree order
        params_tree, static = partition(self, self._core_object_spec, self._param_object_spec, is_leaf=is_valid_param)
        flat_params, treedef = jax.tree.flatten(params_tree, is_leaf=is_valid_param)
        new_flat_params = list(new_params.values())
        if len(new_flat_params) != len(flat_params):
            raise Exception(f"Error: expected {len(flat_params)} parameters in the model tree but resolved {len(new_flat_params)} by name (are parameter names unique?)")

        # We allow the caller to pass None for name and then we update the name. Otherwise names should match
        new_flat_params = [
            dataclasses.replace(new, name=old.name) if new.name is None else new
            for old, new in zip(flat_params, new_flat_params)
        ]

        # Create the update tree and return
        new_params_tree = jax.tree.unflatten(treedef, new_flat_params)
        combined: Model = combine(new_params_tree, static, is_leaf=is_valid_param)
        return combined

    # Array-style update: every selected flat value must be supplied
    params = jnp.array(params)
    if include_fixed:
        expected = len(self.flat_params(include_fixed=True))
    else:
        expected = self.num_flat_params
    if params.shape[0] != expected:
        raise Exception(f'Expected {expected} flat parameters but was passed {params.shape[0]}')
    params_tree, static = self._partition(include_fixed=include_fixed)
    params_out, unravel_fn = flatten_util.ravel_pytree(params_tree)
    if jnp.isscalar(params_out) or params_out.shape[0] == 0:
        raise Exception("Error: no free model parameters found to make feature function")
    params_tree_recon = unravel_fn(params)
    return combine(params_tree_recon, static)
[docs]
def with_fixed_params(self: ModelT, params: list[str] | Callable[[str], bool], check_unknown=True) -> ModelT:
    """Return a model with specified parameters fixed.

    Parameters
    ----------
    params : list[str] | str | Callable[[str], bool]
        Parameter names to fix. A single name is accepted for legacy callers.
        A callable is treated as a predicate: it is applied to every name in
        :meth:`named_params` and the names for which it returns True are fixed.
    check_unknown : bool, default=True
        Error if any provided name is not a parameter of the model. Names of
        parameters that are already fixed are accepted (fixing them is a no-op).

    Returns
    -------
    ModelT

    Raises
    ------
    Exception
        If ``check_unknown`` is True and a name is not a parameter of the model.
    """
    current_params = self.named_params()

    # Normalise ``params`` into a set of names.
    if isinstance(params, str):
        # For legacy callers passing a single name
        params = [params]
    elif callable(params):
        predicate = params
        params = [name for name in current_params if predicate(name)]
    names_to_fix = set(params)

    if check_unknown:
        # Validate against every parameter, fixed ones included, so that re-fixing an
        # already-fixed parameter is not misreported as "not found in model".
        known_names = set(self.named_params(include_fixed=True).keys())
        for param_name in names_to_fix:
            if param_name not in known_names:
                raise Exception(f"Specified parameter '{param_name}' not found in model")

    new_params = {
        name: param.as_fixed() if name in names_to_fix else param
        for name, param in current_params.items()
    }
    return self.with_params(new_params)
[docs]
def with_free_params(self: ModelT, params: list[str] | Callable[[str], bool], fix_others=False) -> ModelT:
    """Free the specified parameters.

    Parameters
    ----------
    params : Sequence[str] | Sequence[Parameter] | str | Callable[[str], bool]
        Parameters to set free, given either by name or as the
        :class:`Parameter` objects held by this model. A single name is also
        accepted. A callable is treated as a predicate over every parameter
        name in the model (fixed parameters included). The names for which it
        returns True are freed. An empty selection frees nothing.
    fix_others : bool, default=False
        Fix every parameter that is not specified.

    Returns
    -------
    ModelT

    Raises
    ------
    Exception
        If a specified name is not a parameter of the model.
    KeyError
        If a specified :class:`Parameter` object is not held by this model.
    """
    current_params = self.named_params(include_fixed=True)

    # Normalise ``params`` into a list of names.
    if isinstance(params, str):
        params = [params]
    elif callable(params):
        predicate = params
        # Include fixed parameters: freeing them is the purpose of this method.
        params = [name for name in current_params if predicate(name)]
    # Materialise so sets/generators can be inspected below.
    params = list(params)
    # Guard the first-element check against an empty selection.
    if params and isinstance(params[0], Parameter):
        # Map Parameter objects back to their names by identity, over free AND fixed params.
        param_id_to_name = {id(param): name for name, param in current_params.items()}
        params = [param_id_to_name[id(param)] for param in params]
    names_to_free = set(params)

    for param_name in names_to_free:
        if param_name not in current_params:
            raise Exception(f"Specified parameter '{param_name}' not found in model")

    new_params = {}
    for name, param in current_params.items():
        if name in names_to_free:
            new_params[name] = param.as_free()
        elif fix_others:
            new_params[name] = param.as_fixed()
        else:
            new_params[name] = param
    return self.with_params(new_params)
[docs]
def with_free_params_only(self: ModelT, *args, **kwargs) -> ModelT:
    """Return a model in which only the specified parameters are free.

    Shorthand for :meth:`with_free_params` with ``fix_others=True``. All
    positional and keyword arguments are forwarded to it unchanged.

    Raises
    ------
    Exception
        If ``fix_others`` is passed explicitly and compares equal to False.
    """
    fix_others = kwargs.setdefault('fix_others', True)
    # Equality (not identity) comparison is kept deliberately, so 0 is also rejected.
    if fix_others == False:  # noqa: E712
        raise Exception("Cannot pass fix_others == False for `with_free_params_only`.")
    return self.with_free_params(*args, **kwargs)
[docs]
def with_param_groups(self: ModelT, param_groups: ParameterGroup | list[ParameterGroup]) -> ModelT:
    """Return a model with parameter groups appended.

    Parameter groups can be used to specify relationships between parameters in the model,
    such as joint priors.

    Parameters
    ----------
    param_groups : ParameterGroup or list[ParameterGroup] or tuple[ParameterGroup, ...]
        Group(s) to add. They are appended after any groups already present.

    Returns
    -------
    ModelT
        A copy of the model; the original model's group list is not modified.
    """
    # A single group (or anything that is not a list/tuple) is wrapped; lists AND tuples
    # of groups are unpacked, so a tuple is no longer appended as one element.
    if isinstance(param_groups, ParameterGroup) or not isinstance(param_groups, (list, tuple)):
        param_groups = [param_groups]
    existing_groups = list(self._param_groups) if self._param_groups is not None else []
    # Build a fresh list so the source model is never mutated.
    return dataclasses.replace(self, _param_groups=[*existing_groups, *param_groups])
# ---- Field and model manipulation --------------------------------------------------
[docs]
@classmethod
def with_defaults(cls, *args, **kwargs) -> type['Model']:
    """Return this model type with default initialization arguments.

    This method is very useful in utilizing an existing model
    with default values, without having to create a new
    model type via inheritance.

    Arguments are forwarded as if they were passed to `__init__`.

    The result is a lightweight callable wrapper around
    :func:`functools.partial`:

    * calling it constructs the model with the stored defaults;
    * it can be chained with further ``with_defaults`` calls (positional
      defaults are appended, keyword defaults override earlier ones);
    * attribute lookups it does not define itself (e.g. class constants or
      classmethods) are forwarded to the wrapped model class.

    Returns
    -------
    type[Model]
        A callable that stands in for the model class with the given defaults.
    """
    class DefaultsWrapper:
        def __init__(self, p):
            self.p = p  # underlying functools.partial around the model class

        def __call__(self, *args, **kwargs):
            return self.p(*args, **kwargs)

        def __getattr__(self, name):
            # Only reached when normal lookup fails. Non-dunder names are forwarded to
            # the wrapped class so class-level attributes stay reachable. Dunders are not
            # forwarded, so protocols (copy, pickle, ...) never pick up the class's
            # unbound special methods. 'p' is guarded so a half-built wrapper cannot recurse.
            if name == 'p' or (name.startswith('__') and name.endswith('__')):
                raise AttributeError(name)
            return getattr(self.p.func, name)

        def __repr__(self):
            return f"{type(self).__name__}({self.p!r})"

        # chaining
        def with_defaults(self, *args, **kwargs):
            # merge new defaults after existing ones
            new_args = self.p.args + args
            new_kwargs = {**self.p.keywords, **kwargs}
            return DefaultsWrapper(partial(self.p.func, *new_args, **new_kwargs))

    return DefaultsWrapper(partial(cls, *args, **kwargs))
[docs]
def with_models(self: ModelT, models: ModelT | Sequence[ModelT]) -> ModelT:
    """Merge the parameters (and parameter groups) of other models into this one.

    Intended for recombining models that were fitted separately, starting from
    the same initial model but with different free parameters.

    Parameters
    ----------
    models : Model or Sequence[Model]
        A single model or a sequence of models. They are applied in order. For
        each one, its :meth:`named_params` are applied via :meth:`with_params`,
        then its :meth:`param_groups` are appended via :meth:`with_param_groups`.

    Returns
    -------
    Model
    """
    others = models if isinstance(models, Sequence) else [models]
    merged = self
    for other in others:
        # NOTE(review): groups are appended, not merged. If ``other`` shares groups with
        # ``self`` (same initial model), they may end up duplicated. Confirm intended.
        merged = merged.with_params(other.named_params()).with_param_groups(other.param_groups())
    return merged
[docs]
def with_fields(self: ModelT, *args, **kwargs) -> ModelT:
    """
    Return a copy of this model with some fields replaced.

    Every argument is passed straight through to :func:`dataclasses.replace`,
    so ``model.with_fields(name="x")`` is ``dataclasses.replace(model, name="x")``.
    """
    replaced = dataclasses.replace(self, *args, **kwargs)
    return replaced
[docs]
def with_submodel_fields(self: ModelT, submodel: str, *args, **kwargs) -> ModelT:
    """
    Return a copy of this model in which one sub-model has some fields replaced.

    The sub-model is looked up as the attribute ``submodel``. Its
    :meth:`with_fields` is called with the remaining arguments, which are
    forwarded to :func:`dataclasses.replace`. The result is stored back under
    the same attribute name.

    Parameters
    ----------
    submodel : str
        The attribute name of the sub-model whose fields are replaced.
    """
    updated_submodel = getattr(self, submodel).with_fields(*args, **kwargs)
    return self.with_fields(**{submodel: updated_submodel})
[docs]
def with_free_submodels(self: ModelT, submodels: 'Model' | Sequence['Model'] | str | Sequence[str], include_fixed=False, fix_others=False) -> ModelT:
    """Free all parameters in the given submodels.

    The submodels' parameter names are collected with :meth:`param_names`
    and then passed to :meth:`with_free_params`.

    Parameters
    ----------
    submodels : Model | Sequence[Model] | str | Sequence[str]
        Submodels whose parameters should be free.
    include_fixed : bool, default=False
        Also free parameters that are currently fixed in the submodels.
    fix_others : bool, default=False
        Fix every parameter that was not selected (forwarded to
        :meth:`with_free_params`).

    Returns
    -------
    ModelT
    """
    selected_names = self.param_names(include_fixed=include_fixed, submodels=submodels)
    return self.with_free_params(selected_names, fix_others=fix_others)
[docs]
def with_fixed_submodels(self: ModelT, submodels: 'Model' | Sequence['Model'] | str | Sequence[str]) -> ModelT:
    """Fix all parameters in the given submodels.

    The submodels' parameter names are collected with :meth:`param_names`
    (default arguments) and then passed to :meth:`with_fixed_params`.

    Parameters
    ----------
    submodels : Model | Sequence[Model] | str | Sequence[str]
        Submodels whose parameters should be fixed.

    Returns
    -------
    ModelT
    """
    selected_names = self.param_names(submodels=submodels)
    return self.with_fixed_params(selected_names)
# ---- File and conversion utilities --------------------------------------------------
[docs]
def to_skrf(self, frequency: Frequency | skrf.Frequency, sigma=0.0, **kwargs) -> skrf.Network:
    """Evaluate the model on a frequency grid and return an :class:`skrf.Network`.

    The network is built from the model's active primary property
    (``self.primary_property``), evaluated via ``self.primary``.

    Parameters
    ----------
    frequency : pmrf.Frequency | skrf.Frequency
        Frequency grid. Either kind is accepted; it is converted to the other
        as needed.
    sigma : float, default=0.0
        If nonzero, noise is added to ``s``. The real and imaginary parts each
        receive independent zero-mean Gaussian noise with standard deviation
        ``sigma``, drawn from NumPy's global random state.
    **kwargs
        Forwarded to the :class:`skrf.Network` constructor. A ``name`` given
        here is respected. The primary property, ``frequency`` and ``z0`` are
        always taken from the model and override any values passed here.

    Returns
    -------
    skrf.Network
    """
    if not isinstance(frequency, Frequency):
        pmrf_freq = Frequency.from_skrf(frequency)
        skrf_freq = frequency
    else:
        pmrf_freq = frequency
        skrf_freq = frequency.to_skrf()

    values = self.primary(pmrf_freq)
    prop_name = self.primary_property
    network_kwargs = {
        **kwargs,
        prop_name: values,
        'frequency': skrf_freq,
        'name': kwargs.get('name', self.name),
        'z0': self._z0,
    }
    network = skrf.Network(**network_kwargs)

    if sigma != 0.0:
        noise_shape = network.s.shape
        # Real part drawn first, then imaginary part.
        real_noise = np.random.normal(0, sigma, noise_shape)
        imag_noise = np.random.normal(0, sigma, noise_shape)
        network.s += real_noise + 1j * imag_noise
    return network
[docs]
def write_hdf(self, group: h5py.Group):
    """Serialize the model into an HDF5 group.

    Metadata is cleared before saving. Two subgroups are created:

    * ``raw``: jsonpickle-encoded datasets ``combined`` (whole model),
      ``params`` and ``static`` (the two halves of
      ``_partition(include_fixed=True, param_objects=True)``).
    * ``params``: one entry per name in :meth:`named_params` (default
      arguments), holding that parameter's ``to_json()``.

    Parameters
    ----------
    group : h5py.Group
        Target group.
    """
    stripped = self.with_fields(metadata=dict())
    param_part, static_part = stripped._partition(include_fixed=True, param_objects=True)
    named = stripped.named_params()

    raw_group = group.create_group('raw')
    for dataset_name, payload in (('combined', stripped), ('params', param_part), ('static', static_part)):
        raw_group.create_dataset(dataset_name, data=jsonpickle.encode(payload))

    params_group = group.create_group('params')
    # NOTE(review): h5py treats '/' in a key as a path separator; confirm parameter names never contain '/'.
    for param_name, param in named.items():
        params_group[param_name] = param.to_json()
[docs]
@classmethod
def read_hdf(cls, group: h5py.Group) -> ModelT | None:
    """Read a model from an HDF5 group.

    The ``raw/combined`` dataset is preferred when present. Otherwise the
    model is rebuilt from the ``raw/params`` and ``raw/static`` datasets.

    Parameters
    ----------
    group : h5py.Group
        Group previously created via :meth:`write_hdf`.

    Returns
    -------
    ModelT | None
        The decoded model. ``None`` is returned (with a logged warning) only if
        decoding the ``params``/``static`` layout fails. Errors while decoding
        ``combined`` propagate to the caller.
    """
    def _as_text(value):
        # h5py can return variable-length string datasets as bytes.
        return value.decode('utf-8') if isinstance(value, bytes) else value

    # SECURITY: jsonpickle.decode can instantiate arbitrary Python objects.
    # Only read files from trusted sources.
    model_raw_grp = group['raw']
    if 'combined' in model_raw_grp:
        return jsonpickle.decode(_as_text(model_raw_grp['combined'][()]))

    params_json = _as_text(model_raw_grp['params'][()])
    static_json = _as_text(model_raw_grp['static'][()])
    try:
        params_tree = jsonpickle.decode(params_json)
        static_tree = jsonpickle.decode(static_json)
        # eqx.combine already yields the reassembled model. The previous cls(...) wrapper
        # passed that model to the constructor as a field value.
        # NOTE(review): confirm eqx.combine matches how _partition(param_objects=True) splits the tree.
        return eqx.combine(params_tree, static_tree)
    except Exception as e:
        import logging
        logging.warning(f'Error while trying to load model: {e}')
        return None
[docs]
def save(self, target: str | os.PathLike | BinaryIO):
    """Serialize the model (jsonpickle JSON) to a file path or an open stream.

    Metadata is cleared before saving.
    The recommended file extension for ParamRF models is `.prf`.

    Parameters
    ----------
    target : str | os.PathLike | file-like
        Destination path (written as UTF-8 text) or an open stream with a
        ``write`` method. Text streams receive ``str``. Binary streams, which
        reject ``str`` with TypeError, receive the UTF-8 encoded bytes.
    """
    model_save = self.with_fields(metadata=dict())
    data = jsonpickle.encode(model_save)
    if isinstance(target, (str, os.PathLike)):
        with open(target, "w", encoding="utf8") as f:
            f.write(data)
        return
    try:
        target.write(data)
    except TypeError:
        # Binary streams (as the BinaryIO hint advertises) require bytes, not str.
        target.write(data.encode("utf8"))
[docs]
@classmethod
def load(cls: type[ModelT], source: str | os.PathLike | BinaryIO) -> ModelT:
    """Load a model previously written by :meth:`save`.

    The recommended file extension for ParamRF models is `.prf`.

    Parameters
    ----------
    source : str | os.PathLike | file-like
        Source path (read as UTF-8 text) or an open stream with a ``read``
        method. Byte content from binary streams is decoded as UTF-8.

    Returns
    -------
    ModelT
        The decoded object. It is not checked to be an instance of ``cls``.
    """
    if isinstance(source, (str, os.PathLike)):
        with open(source, "r", encoding="utf8") as f:
            data = f.read()
    else:
        data = source.read()
    if isinstance(data, (bytes, bytearray)):
        data = data.decode("utf8")
    # SECURITY: jsonpickle.decode can instantiate arbitrary Python objects.
    # Only load files from trusted sources.
    return jsonpickle.decode(data)
[docs]
def export_touchstone(self, filename: str | os.PathLike, frequency: Frequency | skrf.Frequency, sigma: float = 0.0, **skrf_kwargs):
    """Export the model response to a Touchstone file via scikit-rf.

    Parameters
    ----------
    filename : str | os.PathLike
        Output path. Path-like objects are converted with :func:`os.fspath`.
    frequency : Frequency | skrf.Frequency
        Frequency grid passed to :meth:`to_skrf`.
    sigma : float, default=0.0
        Noise level forwarded to :meth:`to_skrf`.
    **skrf_kwargs
        Forwarded to :meth:`skrf.Network.write_touchstone`.

    Returns
    -------
    Any
        Return value of ``Network.write_touchstone``.

    Raises
    ------
    Exception
        If ``filename`` is neither a string nor a path-like object resolving to a string.
    """
    if isinstance(filename, os.PathLike):
        # Accept pathlib.Path etc.; bytes paths still fail the str check below.
        filename = os.fspath(filename)
    if not isinstance(filename, str):
        raise Exception('Filename must be a string')
    ntwk = self.to_skrf(frequency, sigma=sigma)
    return ntwk.write_touchstone(filename, **skrf_kwargs)
[docs]
def wrap(
    func: Callable,
    *args,
    as_numpy: bool = False,
) -> Callable:
    """
    Wrap a function/method taking ``(model, frequency, *args, **kwargs)`` into one that accepts flat arrays.

    The wrapper accepts ``(theta, [f], *args, **kwargs)`` where ``theta`` is a flat
    parameter array, and optionally ``f`` is a frequency array with a provided unit.
    If a :class:`Frequency` object is passed in ``*args``, it is closed over and
    the wrapped function only accepts ``theta``.

    Supported call signatures:

    * ``wrap(userfunc, model, "MHz")``
    * ``wrap(userfunc, model, freq)``
    * ``wrap(model.userfunc, "MHz")``
    * ``wrap(model.userfunc, freq)``

    If no frequency/unit is given, the unit defaults to ``"Hz"``.

    Parameters
    ----------
    func : callable
        Function or bound method taking ``(model, frequency, ...)``.
    *args : tuple
        Variable length argument list. Expects either ``(model, Frequency)``,
        ``(model, unit_str)``, or just ``(Frequency,)`` / ``(unit_str,)``
        if ``func`` is a bound method.
    as_numpy : bool, optional
        If True, JIT the wrapper and convert inputs/outputs to NumPy arrays.
        Default is False.

    Returns
    -------
    wrapper : callable
        Wrapped function accepting ``(theta_array, [f_array], *args, **kwargs)``.

    Raises
    ------
    TypeError
        If the frequency argument is neither a unit string nor a ``Frequency`` object,
        or if ``func`` is not a bound method and no model is given.
    """
    # A bound method (including Equinox bound-method wrappers) exposes both __self__ and
    # __func__. Builtins expose __self__ only and are treated as unbound callables.
    if getattr(func, '__self__', None) is not None and hasattr(func, '__func__'):
        model: Model = func.__self__
        func = func.__func__
        args = list(args)
    else:
        if not args:
            raise TypeError("Expected a model as the first argument when `func` is not a bound method")
        model: Model = args[0]
        args = args[1:]

    # Now args[0] should be Frequency or unit
    frequency_or_unit = args[0] if args else "Hz"

    # Validate
    if not isinstance(frequency_or_unit, (str, Frequency)):
        raise TypeError(f"Expected pmrf.Frequency or unit string as second argument, got type {type(frequency_or_unit)}")
    # A unit string means the frequency array is supplied per call. A Frequency is closed over.
    use_variable_frequency = isinstance(frequency_or_unit, str)

    # Define wrapped function
    def wrapped_with_freq(theta: jnp.ndarray, f_array: jnp.ndarray, *f_args, **f_kwargs):
        new_model = model.with_params(theta)
        frequency = Frequency.from_f(f_array, unit=frequency_or_unit)
        return func(new_model, frequency, *f_args, **f_kwargs)

    def wrapped_fixed_freq(theta: jnp.ndarray, *f_args, **f_kwargs):
        new_model = model.with_params(theta)
        return func(new_model, frequency_or_unit, *f_args, **f_kwargs)

    wrapped_fn = wrapped_with_freq if use_variable_frequency else wrapped_fixed_freq

    if as_numpy:
        raw_fn = jax.jit(wrapped_fn)
        if use_variable_frequency:
            wrapped_fn = lambda theta, f, *a, **kw: np.array(raw_fn(jnp.array(theta), jnp.array(f), *a, **kw))
        else:
            wrapped_fn = lambda theta, *a, **kw: np.array(raw_fn(jnp.array(theta), *a, **kw))

    # Forward metadata
    update_wrapper(wrapped_fn, func)
    return wrapped_fn