Source code for pmrf.parameters

import json
import dataclasses
from dataclasses import dataclass
from typing import Sequence

import jax.numpy as jnp
import numpy as np
import equinox as eqx
import numpyro.distributions as dist
from numpyro.distributions.distribution import Distribution

from pmrf._util import field

# Lower / upper quantile levels passed to a prior's ``icdf`` when a finite
# bound is needed for a distribution without explicit limits.
MIN_PERCENTILE, MAX_PERCENTILE = 0.01, 0.99

class Parameter(eqx.Module):
    """
    Overview
    --------
    A container for a parameter, usually used within a `Model`.

    This class serves as the fundamental building block for defining the
    tunable or fixed parameters within a **paramrf** `Model` and for fitting.
    It is designed to be a flexible container that behaves like a standard
    numerical type (e.g., a `numpy.ndarray`) while holding additional metadata
    for model fitting and analysis.

    Usage
    -----
    - Use in mathematical operations just like a JAX/numpy array.
    - `Parameter` objects are JAX PyTrees, making them compatible with JAX's
      transformations (`jit`, `grad`, etc.).
    - Mark as `fixed` (which is honoured by fitting and sampling routines).
    - Associate priors, specified as numpyro distributions (uniform, normal etc.).

    Examples
    --------
    .. code-block:: python

        import pmrf as prf
        import jax.numpy as jnp

        # A simple, single-valued parameter, initialized with a float
        p1 = prf.Parameter(value=1.0e-12, name='C1')

        # This parameter can be used in calculations directly
        # (scaling is applied during casting)
        impedance = 1 / (2j * jnp.pi * 1e9 * p1)
        print(f"Impedance: {impedance}")

        # A parameter that is fixed and will not be optimized during a fit
        p2 = prf.Parameter(value=50.0, fixed=True, name='Z0')

        # A parameter with a uniform distribution prior.
        # Factory functions are a convenient way to create these.
        p3 = prf.Uniform(low=0.9e-9, high=1.1e-9, name='L1')

        # The parameter's value is initialized to the mean of the distribution
        print(f"Initial value of L1: {p3.value}")
    """

    # Underlying values/dists are stored unscaled. Multiply by ``scale`` (below)
    # to get the true value; this is done automatically when converting to an
    # array. None of these except ``name`` are static, so they can be updated.
    value: jnp.ndarray = field(converter=lambda x: jnp.asarray(x, dtype=jnp.float64))
    prior: Distribution | None = field(default=None)
    fixed: bool = field(default=False)
    scale: float = field(default=1.0)
    name: str | None = field(default=None, static=True)

    @property
    def ndim(self) -> int:
        """The number of free dimensions for this parameter.

        Returns ``0`` for fixed parameters. Otherwise returns the leading batch
        dimension of the prior, or ``1`` when the prior is unbatched or absent.
        """
        if self.fixed:
            return 0
        # A Parameter object itself represents a single fittable entity,
        # even if its value is an array. The prior handles the dimensionality.
        if self.prior is not None:
            if len(self.prior.batch_shape) == 0:
                return 1
            return self.prior.batch_shape[0]
        return 1

    def _unbounded(self, fill: float):
        """Return ``fill`` for a scalar value, else an array of ``fill`` shaped like ``value``."""
        # ``ndim`` is used instead of ``jnp.isscalar`` because the latter's
        # treatment of 0-d arrays has changed between JAX versions.
        if jnp.ndim(self.value) == 0:
            return fill
        return jnp.full(jnp.shape(self.value), fill)

    @property
    def min(self) -> jnp.ndarray:
        r"""The unscaled minimum value of the parameter's distribution.

        For a ``numpyro`` ``Uniform`` prior this is the exact lower bound; for
        any other prior it is the ``MIN_PERCENTILE`` quantile.

        Returns
        -------
        jnp.ndarray
            The minimum value, or ``-inf`` (an array of ``-inf`` shaped like
            ``value`` for non-scalar values) if no prior is set.
        """
        if self.prior is not None:
            if isinstance(self.prior, dist.Uniform):
                return self.prior.low
            return self.prior.icdf(MIN_PERCENTILE)
        return self._unbounded(-jnp.inf)

    @property
    def max(self) -> jnp.ndarray:
        r"""The unscaled maximum value of the parameter's distribution.

        For a ``numpyro`` ``Uniform`` prior this is the exact upper bound; for
        any other prior it is the ``MAX_PERCENTILE`` quantile.

        Returns
        -------
        jnp.ndarray
            The maximum value, or ``+inf`` (an array of ``+inf`` shaped like
            ``value`` for non-scalar values) if no prior is set.
        """
        if self.prior is not None:
            if isinstance(self.prior, dist.Uniform):
                return self.prior.high
            return self.prior.icdf(MAX_PERCENTILE)
        # Upper bound of an unconstrained parameter is +inf (not -inf).
        return self._unbounded(jnp.inf)

    def with_value(self, value: jnp.ndarray) -> 'Parameter':
        r"""Return a copy with a new unscaled value.

        Parameters
        ----------
        value : jnp.ndarray
            The new unscaled value to set.

        Returns
        -------
        Parameter
            A copy of this object with ``value`` replaced.
        """
        return dataclasses.replace(self, value=value)

    def with_prior(self, prior: Distribution) -> 'Parameter':
        r"""Return a copy with a new prior distribution.

        Parameters
        ----------
        prior : numpyro.distributions.Distribution
            The prior to associate with this parameter.

        Returns
        -------
        Parameter
            A copy of this object with ``prior`` replaced.
        """
        return dataclasses.replace(self, prior=prior)

    def flatten(self, separator='_') -> 'Parameter | list[Parameter]':
        r"""Flatten self into scalar parameters.

        Returns ``self`` if the underlying value is scalar; otherwise returns
        one ``Parameter`` per leading-axis element, named
        ``f"{name}{separator}{i}"`` and sharing ``fixed`` and ``scale``.

        Parameters
        ----------
        separator : str, optional
            Inserted between the original name and the element index.

        Returns
        -------
        Parameter | list[Parameter]
            The raveled parameters.

        Raises
        ------
        ValueError
            If the prior cannot be split into per-element distributions.
        """
        if jnp.ndim(self.value) == 0:
            return self
        if self.prior is not None:
            priors_split = _split_vectorized_distribution(self.prior)
        else:
            priors_split = [None] * len(self.value)
        return [
            Parameter(
                value=val,
                prior=p,
                fixed=self.fixed,
                scale=self.scale,
                name=f"{self.name}{separator}{i}",
            )
            for i, (val, p) in enumerate(zip(self.value, priors_split))
        ]

    def as_fixed(self) -> 'Parameter':
        r"""Return a copy of self with ``fixed`` set to True.

        Returns
        -------
        Parameter
            The new, fixed parameter.
        """
        return dataclasses.replace(self, fixed=True)

    def as_free(self) -> 'Parameter':
        r"""Return a copy of self with ``fixed`` set to False.

        Returns
        -------
        Parameter
            The new, free parameter.
        """
        return dataclasses.replace(self, fixed=False)

    # Arithmetic and array conversions

    def __array__(self, dtype=None, copy=None):
        r"""NumPy array interface.

        NumPy requires ``__array__`` to return a genuine ``numpy.ndarray``, so
        the scaled JAX value is converted with ``numpy.asarray``.

        Parameters
        ----------
        dtype : Any, optional
            Desired dtype.
        copy : bool, optional
            Accepted for NumPy 2 compatibility. The result is always freshly
            computed from ``value * scale``, so it never aliases internal state.

        Returns
        -------
        numpy.ndarray
            The scaled value as an array.
        """
        return np.asarray(self.value * self.scale, dtype=dtype)

    def __jax_array__(self, dtype=None):
        r"""JAX array interface.

        Parameters
        ----------
        dtype : Any, optional
            Desired dtype.

        Returns
        -------
        jnp.ndarray
            The scaled value as an array.
        """
        return jnp.asarray(self.value * self.scale, dtype=dtype)

    def __len__(self):
        r"""Length of the parameter value.

        Returns
        -------
        int
            ``1`` for scalars, otherwise ``len(value)``.
        """
        if len(self.value.shape) == 0:
            return 1  # e.g. for jax scalars
        return len(self.value)

    # ``jnp.array(self)`` dispatches to ``__jax_array__``, so every operator
    # below works on the *scaled* value and returns a plain JAX array.

    def __add__(self, other):
        r"""Elementwise addition."""
        return jnp.add(jnp.array(self), jnp.array(other))

    def __sub__(self, other):
        r"""Elementwise subtraction."""
        return jnp.subtract(jnp.array(self), jnp.array(other))

    def __mul__(self, other):
        r"""Elementwise multiplication."""
        return jnp.multiply(jnp.array(self), jnp.array(other))

    def __truediv__(self, other):
        r"""Elementwise true division."""
        return jnp.divide(jnp.array(self), jnp.array(other))

    def __pow__(self, other):
        r"""Elementwise power."""
        return jnp.power(jnp.array(self), jnp.array(other))

    def __radd__(self, other):
        r"""Reflected elementwise addition."""
        return jnp.add(jnp.array(other), jnp.array(self))

    def __rsub__(self, other):
        r"""Reflected elementwise subtraction."""
        return jnp.subtract(jnp.array(other), jnp.array(self))

    def __rmul__(self, other):
        r"""Reflected elementwise multiplication."""
        return jnp.multiply(jnp.array(other), jnp.array(self))

    def __rtruediv__(self, other):
        r"""Reflected elementwise true division."""
        return jnp.divide(jnp.array(other), jnp.array(self))

    def __rpow__(self, other):
        r"""Reflected elementwise power."""
        return jnp.power(jnp.array(other), jnp.array(self))

    def __neg__(self):
        r"""Elementwise negation."""
        return jnp.negative(jnp.array(self))

    def copy(self):
        r"""Return a shallow copy.

        Returns
        -------
        Parameter
            A copy created via ``dataclasses.replace``.
        """
        return dataclasses.replace(self)

    # Serialization

    def to_json(self) -> str:
        r"""Serialize the parameter to a JSON string.

        Returns
        -------
        str
            A JSON-formatted string containing value, prior, fixed, scale and name.
        """
        # ``scale`` is normally a float, but may be an array; arrays are not
        # JSON-serializable, so they are converted to (nested) lists.
        scale = self.scale.tolist() if hasattr(self.scale, "tolist") else self.scale
        d = {
            "value": self.value.tolist(),
            "prior": _serialize_distribution(self.prior),
            "fixed": self.fixed,
            "scale": scale,
            "name": self.name,
        }
        return json.dumps(d, indent=2)

    @classmethod
    def from_json(cls, s: str) -> "Parameter":
        r"""Deserialize a parameter from a JSON string.

        Parameters
        ----------
        s : str
            The JSON string produced by :meth:`to_json`.

        Returns
        -------
        Parameter
            A reconstructed :class:`Parameter` instance.
        """
        d = json.loads(s)
        scale = d["scale"]
        # An array-valued scale was stored as a list; restore it as an array so
        # that ``value * scale`` broadcasts instead of failing on a list.
        if isinstance(scale, list):
            scale = jnp.asarray(scale)
        return cls(
            value=jnp.asarray(d["value"]),
            prior=_deserialize_distribution(d["prior"]),
            fixed=d["fixed"],
            scale=scale,
            name=d["name"],
        )
@dataclass
class ParameterGroup:
    r"""
    A metadata class that groups a set of named flat parameters and defines
    any relationships between them (e.g. a joint prior over all of them).
    """

    param_names: list[str]
    prior: dist.Distribution | None = field(default=None)

    def __init__(self, param_names: list[str] | dict[str, Parameter], prior: dist.Distribution | None = None):
        r"""Construct a :class:`ParameterGroup`.

        Parameters
        ----------
        param_names : list[str] | dict[str, Parameter]
            The names of the flattened parameters (or a mapping to parameters).
            Stored as given.
        prior : numpyro.distributions.Distribution, optional
            An optional prior over the flattened parameters.
        """
        self.param_names = param_names
        self.prior = prior

    @property
    def num_flat_params(self) -> int:
        r"""Number of flattened parameters.

        Returns
        -------
        int
            The count of names in ``param_names``.
        """
        return len(self.param_names)

    def _as_flat(self, bound) -> jnp.ndarray:
        r"""Coerce a prior bound to a 1-D array of length ``num_flat_params``.

        Plain Python floats and size-1 arrays are broadcast; arrays of any
        shape with exactly ``num_flat_params`` elements are flattened. Any
        other size raises, as the original ``reshape`` did.
        """
        return jnp.broadcast_to(jnp.ravel(jnp.asarray(bound)), (self.num_flat_params,))

    @property
    def min(self) -> jnp.ndarray:
        r"""The unscaled minimum value of the parameter group's distribution.

        Uses the prior's ``min`` or ``low`` attribute when present, otherwise
        the ``MIN_PERCENTILE`` quantile via ``icdf``.

        Returns
        -------
        jnp.ndarray
            The per-parameter minima, or ``-inf`` for each parameter if no
            distribution is set.
        """
        if self.prior is not None:
            if hasattr(self.prior, 'min'):
                return self._as_flat(self.prior.min)
            elif hasattr(self.prior, 'low'):
                return self._as_flat(self.prior.low)
            else:
                # TODO implement optimization to determine minima
                return self.prior.icdf(jnp.array([MIN_PERCENTILE] * self.num_flat_params))
        return jnp.array([-jnp.inf] * self.num_flat_params)

    @property
    def max(self) -> jnp.ndarray:
        r"""The unscaled maximum value of the parameter group's distribution.

        Uses the prior's ``max`` or ``high`` attribute when present, otherwise
        the ``MAX_PERCENTILE`` quantile via ``icdf``.

        Returns
        -------
        jnp.ndarray
            The per-parameter maxima, or ``+inf`` for each parameter if no
            distribution is set.
        """
        if self.prior is not None:
            if hasattr(self.prior, 'max'):
                return self._as_flat(self.prior.max)
            elif hasattr(self.prior, 'high'):
                return self._as_flat(self.prior.high)
            else:
                # TODO implement optimization to determine maximum
                return self.prior.icdf(jnp.array([MAX_PERCENTILE] * self.num_flat_params))
        return jnp.array([jnp.inf] * self.num_flat_params)
def _repeat_n(x, n: int | None):
    """Return ``x`` as an array, stacked ``n`` times along a new leading axis unless ``n`` is None or 1."""
    if n is not None and n != 1:
        return jnp.array([x] * n)
    return jnp.array(x)


def Uniform(low: float | Sequence[float], high: float | Sequence[float], n: int | None = None, value=None, **kwargs) -> 'Parameter':
    r"""Creates a `Parameter` with a uniform prior distribution.

    Parameters
    ----------
    low : float | Sequence[float]
        The lower bound of the distribution. Can be a sequence for a
        multi-valued Parameter.
    high : float | Sequence[float]
        The upper bound of the distribution. Can be a sequence for a
        multi-valued Parameter.
    n : int, optional
        The number of identical parameters to create in an array. Defaults to None.
    value : optional
        The initial value. If None, the midpoint of the distribution is used.
        When ``n`` is given (and not 1), ``value`` is repeated ``n`` times.
        Defaults to None.
    **kwargs
        Additional keyword arguments passed to the `Parameter` constructor.

    Returns
    -------
    Parameter
        The created Parameter object.
    """
    low, high = _repeat_n(low, n), _repeat_n(high, n)
    if value is not None:
        value = _repeat_n(value, n)
    prior = dist.Uniform(low, high)
    values = (low + high) / 2.0 if value is None else value
    return Parameter(value=values, prior=prior, **kwargs)


def Normal(mean: float | Sequence[float], std: float | Sequence[float], n: int | None = None, value=None, **kwargs) -> 'Parameter':
    r"""Creates a `Parameter` with a normal (Gaussian) prior distribution.

    Parameters
    ----------
    mean : float | Sequence[float]
        The mean of the distribution. Can be a sequence for a multi-valued Parameter.
    std : float | Sequence[float]
        The standard deviation of the distribution. Can be a sequence for a
        multi-valued Parameter.
    n : int, optional
        The number of identical parameters to create in an array. Defaults to None.
    value : optional
        The initial value. If None, the mean of the distribution is used.
        When ``n`` is given (and not 1), ``value`` is repeated ``n`` times.
        Defaults to None.
    **kwargs
        Additional keyword arguments passed to the `Parameter` constructor.

    Returns
    -------
    Parameter
        The created Parameter object.
    """
    mean, std = _repeat_n(mean, n), _repeat_n(std, n)
    if value is not None:
        value = _repeat_n(value, n)
    prior = dist.Normal(mean, std)
    values = mean if value is None else value
    return Parameter(value=values, prior=prior, **kwargs)


def PercentNormal(mean: float | Sequence[float], perc: float | Sequence[float], **kwargs) -> 'Parameter':
    r"""Creates a `Parameter` with a normal (Gaussian) distribution and a percentage standard deviation.

    Parameters
    ----------
    mean : float | Sequence[float]
        The mean of the distribution. Can be a sequence for a multi-valued Parameter.
    perc : float | Sequence[float]
        The percentage width used to initialize the standard deviation,
        assuming the percentage represents +-2*sigma (95% coverage). As an
        example, passing `5.0` results in `std = 0.025*|mean|`. Can be a
        sequence for a multi-valued Parameter. Scalars and sequences broadcast
        against each other.
    **kwargs
        Additional keyword arguments passed to the `Normal` factory function.

    Returns
    -------
    Parameter
        The created Parameter object.
    """
    # sigma = (perc / 100) * |mean| / 2. Converting both to arrays lets any mix
    # of scalar/sequence arguments broadcast (a bare list would otherwise be
    # repeated or rejected by ``*``), and ``abs`` keeps sigma positive for
    # negative means.
    std = jnp.abs(jnp.asarray(perc) * jnp.asarray(mean)) / 200.0
    return Normal(mean=mean, std=std, **kwargs)


def Fixed(value, n: int | None = None, **kwargs) -> 'Parameter':
    r"""Creates a `Parameter` that is marked as fixed.

    Parameters
    ----------
    value
        The value of the parameter.
    n : int, optional
        The number of identical parameters to create in an array. Defaults to None.
    **kwargs
        Additional keyword arguments passed to the `Parameter` constructor.

    Returns
    -------
    Parameter
        The created fixed Parameter object.
    """
    return Parameter(value=_repeat_n(value, n), fixed=True, **kwargs)


def Free(value, n: int | None = None, **kwargs) -> 'Parameter':
    r"""Creates a `Parameter` that is marked as not fixed (i.e., free to vary).

    Parameters
    ----------
    value
        The value of the parameter.
    n : int, optional
        The number of identical parameters to create in an array. Defaults to None.
    **kwargs
        Additional keyword arguments passed to the `Parameter` constructor.

    Returns
    -------
    Parameter
        The created free Parameter object.
    """
    return Parameter(value=_repeat_n(value, n), **kwargs)


def is_param(x) -> bool:
    r"""Checks if an object is an instance of a `Parameter`.

    Parameters
    ----------
    x
        The object to check.

    Returns
    -------
    bool
        `True` if the object is a Parameter, `False` otherwise.
    """
    return isinstance(x, Parameter)


def is_valid_param(x) -> bool:
    r"""Checks if an object is an instance of a `Parameter`, and if its value is not None.

    Parameters
    ----------
    x
        The object to check.

    Returns
    -------
    bool
        `True` if the object is a valid Parameter, `False` otherwise.
    """
    return isinstance(x, Parameter) and x.value is not None


def is_free_param(x) -> bool:
    r"""Checks if an object is a non-fixed `Parameter`.

    Parameters
    ----------
    x
        The object to check.

    Returns
    -------
    bool
        `True` if the object is a non-fixed Parameter, `False` otherwise.
    """
    return isinstance(x, Parameter) and not x.fixed


def is_fixed_param(x) -> bool:
    r"""Checks if an object is a fixed `Parameter`.

    Parameters
    ----------
    x
        The object to check.

    Returns
    -------
    bool
        `True` if the object is a fixed Parameter, `False` otherwise.
    """
    return isinstance(x, Parameter) and x.fixed


def asparam(x, **kwargs) -> Parameter:
    r"""Ensures an object is a `Parameter`.

    If the object is already a `Parameter`, it is returned unchanged (and
    ``kwargs`` are ignored). Otherwise, the object is wrapped in a new
    `Parameter`.

    Parameters
    ----------
    x
        The object to convert.
    **kwargs
        Additional keyword arguments (e.g. ``name``, ``fixed``) passed to the
        `Parameter` constructor when a new `Parameter` is created.

    Returns
    -------
    Parameter
        The object as a `Parameter`.
    """
    if isinstance(x, Parameter):
        return x
    return Parameter(value=x, **kwargs)


def _split_vectorized_distribution(dist):
    r"""Split a 1D batch of univariate numpyro distributions into a list.

    Each parameter named in ``dist.arg_constraints`` is read from the
    distribution. Scalar parameters are shared by every element; array-like
    parameters (JAX, NumPy or list) are broadcast to the batch shape and
    indexed per element.

    Parameters
    ----------
    dist : numpyro.distributions.Distribution
        A distribution with ``event_shape == ()`` and 1D ``batch_shape``.
        (Note: this argument shadows the module-level ``dist`` alias inside
        this function.)

    Returns
    -------
    list[numpyro.distributions.Distribution]
        A list of per-element distributions.

    Raises
    ------
    ValueError
        If the distribution cannot be split (e.g., wrong shapes).
    """
    if dist.event_shape != ():
        raise ValueError(f"Cannot split distribution with event_shape={dist.event_shape} (likely an Independent)")
    batch_shape = dist.batch_shape
    if len(batch_shape) != 1:
        raise ValueError("Distribution is not a 1D batch of univariate distributions")
    (batch_size,) = batch_shape

    shared, batched = {}, {}
    for pname in dist.arg_constraints:
        pvalue = getattr(dist, pname)
        if np.ndim(pvalue) == 0:
            shared[pname] = pvalue
        else:
            # Broadcasting also covers params numpyro stored with a size-1 axis.
            batched[pname] = jnp.broadcast_to(jnp.asarray(pvalue), batch_shape)

    dist_class = type(dist)
    return [
        dist_class(**shared, **{pname: arr[i] for pname, arr in batched.items()})
        for i in range(batch_size)
    ]


def _serialize_distribution(d: Distribution | None) -> dict | None:
    r"""Serialize a numpyro distribution to a lightweight dictionary.

    Every public (non-underscore) instance attribute is stored; array-like
    values (anything with ``tolist``) are converted to plain lists/scalars.

    Parameters
    ----------
    d : numpyro.distributions.Distribution or None
        The distribution to serialize.

    Returns
    -------
    dict or None
        A dictionary with ``class`` and ``params`` keys, or ``None``.
    """
    if d is None:
        return None
    return {
        "class": d.__class__.__name__,
        "params": {
            k: v.tolist() if hasattr(v, "tolist") else v
            for k, v in vars(d).items()
            if not k.startswith("_")
        },
    }


def _deserialize_distribution(dct: dict | None) -> Distribution | None:
    r"""Deserialize a numpyro distribution from a dictionary.

    Parameters
    ----------
    dct : dict or None
        A dictionary produced by :func:`_serialize_distribution`.

    Returns
    -------
    numpyro.distributions.Distribution or None
        The reconstructed distribution, or ``None``.

    Raises
    ------
    ValueError
        If ``dct["class"]`` does not name a distribution class in
        ``numpyro.distributions``.
    """
    if dct is None:
        return None
    cls = getattr(dist, dct["class"], None)
    # Only accept genuine Distribution subclasses, not arbitrary attributes of
    # the numpyro.distributions namespace named by the (external) JSON.
    if not (isinstance(cls, type) and issubclass(cls, Distribution)):
        raise ValueError(f"Unknown distribution class: {dct['class']}")
    # JSON turns arrays into lists; numpyro would keep them as lists, which
    # breaks arithmetic (e.g. ``icdf``) and per-element splitting later.
    params = {
        k: jnp.asarray(v) if isinstance(v, list) else v
        for k, v in dct["params"].items()
    }
    return cls(**params)