Source code for pmrf.parameters.parameter

import json
import dataclasses
from copy import deepcopy
from typing import Any

import jax.numpy as jnp
import equinox as eqx
import numpyro.distributions as dist
from numpyro.distributions.distribution import Distribution

from pmrf.distributions import NormalDistribution, UniformDistribution
from pmrf.distributions.stacked import StackedDistribution
from pmrf.field import field
from pmrf.constants import MIN_PERCENTILE, MAX_PERCENTILE

class Parameter(eqx.Module):
    """
    A container for a parameter, usually used within a `Model`.

    This class serves as the fundamental building block for defining the tunable or fixed
    parameters within a **paramrf** `Model` and for fitting. It is designed to be a flexible
    container that behaves like a standard numerical type (e.g., a `numpy.ndarray`) while
    holding additional metadata for model fitting and analysis.

    Usage
    -----
    * Use in mathematical operations just like a JAX/numpy array.
    * ``Parameter`` objects are JAX PyTrees, compatible with JAX transformations (jit, grad).
    * Mark as ``fixed`` (honoured by fitting and sampling routines).
    * Associate distributions, specified as numpyro distributions (uniform, normal, etc.).

    Attributes
    ----------
    value : jnp.ndarray
        The underlying unscaled value. Automatically converted to a float64 array.
    distribution : numpyro.distributions.Distribution or None
        The prior distribution associated with this parameter.
    fixed : bool
        If True, the parameter is treated as a constant during optimization/sampling.
    scale : float
        A scaling factor. The effective value used in calculations is ``value * scale``.
    name : str or None
        An optional name for the parameter (marked as static).
    flat_names : list[str] or None
        Optional explicit names given to the individual elements by :meth:`flattened`
        (marked as static). When ``None``, names are derived from ``name``.

    Examples
    --------
    .. code-block:: python

        import pmrf as prf
        import jax.numpy as jnp

        # A simple, single-valued parameter with a scale, initialized with a float
        p1 = prf.Parameter(value=1.0, scale=1e-12, name='C1')

        # This parameter can be used in calculations directly (scaling is done during casting)
        impedance = 1 / (2j * jnp.pi * 1e9 * p1)
        print(f"Impedance: {impedance}")

        # A parameter that is fixed and will not be optimized during a fit
        p2 = prf.Parameter(value=50.0, fixed=True, name='R2')

        # A parameter with a uniform distribution
        # The provided factory functions in pmrf.parameters are a convenient way to create these
        from pmrf.parameters import Uniform
        p3 = Uniform(0.9, 1.1, scale=1e-9, name='L1')
        print(f"Initial value of L1: {p3.value}")  # initialized to the mean

        # More complicated parameters can be initialized with any "numpyro" distribution
        from numpyro.distributions import LogNormal
        p4 = prf.Parameter(value=10.0, distribution=LogNormal(10.0, 1.0))
    """

    # Underlying values/dists (unscaled). Multiply by ``scale`` (below) to get the true value
    # (done automatically when converting to an array).
    value: jnp.ndarray = field(converter=lambda x: jnp.asarray(x, dtype=jnp.float64))
    distribution: Distribution | None = field(default=None)

    # Static (metadata) fields
    fixed: bool = field(default=False, static=True)
    scale: float = field(default=1.0, static=True)
    name: str | None = field(default=None, static=True)
    flat_names: list[str] | None = field(default=None, converter=lambda x: list(x) if x is not None else x, static=True)

    @property
    def shape(self) -> tuple[int, ...]:
        """
        The shape of this parameter.
        """
        return self.value.shape

    @property
    def size(self) -> int:
        """
        The total number of elements in this parameter's value.
        """
        return self.value.size

    @property
    def min(self) -> jnp.ndarray:
        r"""
        The unscaled minimum value of the parameter's distribution.

        Uniform distributions report their lower bound and delta distributions their
        point value; every other proper distribution reports its ``MIN_PERCENTILE``
        quantile.

        Returns
        -------
        jnp.ndarray
            The minimum value, or ``-inf`` (broadcast to ``shape``) if no distribution
            is set or the distribution is an ``ImproperUniform``.
        """
        d = self.distribution
        if d is None or isinstance(d, dist.ImproperUniform):
            return jnp.full(self.shape, -jnp.inf)
        if isinstance(d, dist.Delta):
            return d.v
        if isinstance(d, dist.Uniform):
            return d.low
        return d.icdf(MIN_PERCENTILE)

    @property
    def max(self) -> jnp.ndarray:
        r"""
        The unscaled maximum value of the parameter's distribution.

        Uniform distributions report their upper bound and delta distributions their
        point value; every other proper distribution reports its ``MAX_PERCENTILE``
        quantile.

        Returns
        -------
        jnp.ndarray
            The maximum value, or ``+inf`` (broadcast to ``shape``) if no distribution
            is set or the distribution is an ``ImproperUniform``.
        """
        d = self.distribution
        if d is None or isinstance(d, dist.ImproperUniform):
            return jnp.full(self.shape, jnp.inf)
        if isinstance(d, dist.Delta):
            return d.v
        if isinstance(d, dist.Uniform):
            return d.high
        return d.icdf(MAX_PERCENTILE)

    @property
    def mean(self) -> jnp.ndarray:
        r"""
        The unscaled mean value of the parameter's distribution.

        Returns
        -------
        jnp.ndarray
            The mean value, or the current value if no distribution is set.
        """
        # ImproperUniform doesn't have a defined mean, so we fall back to the initialized value
        if isinstance(self.distribution, dist.ImproperUniform) or self.distribution is None:
            return self.value
        elif isinstance(self.distribution, dist.Delta):
            return self.distribution.v
        return self.distribution.mean

    def with_value(self, value: jnp.ndarray) -> 'Parameter':
        r"""
        Return a copy of the parameter with a new unscaled value.

        Parameters
        ----------
        value : jnp.ndarray
            The new unscaled value to set.

        Returns
        -------
        Parameter
            A copy of this object with ``value`` replaced.
        """
        return dataclasses.replace(self, value=value)

    def shifted(self, amount: jnp.ndarray) -> 'Parameter':
        r"""
        Return a copy of this parameter with its value and distribution shifted by an amount.

        Parameters
        ----------
        amount : jnp.ndarray
            The amount to shift by. Can be positive or negative.

        Returns
        -------
        Parameter
            A copy of this object with ``value`` and ``distribution`` replaced.

        Raises
        ------
        Exception
            If the distribution is not a ``UniformDistribution`` or ``NormalDistribution``
            (this includes parameters without a distribution).
        """
        new_value = self.value + amount

        # Work on a deep copy so the distribution held by ``self`` is never modified.
        new_distribution = deepcopy(self.distribution)
        # NOTE(review): only the stored bounds/location are shifted here; anything the
        # distribution derived from them at construction time is not recomputed -- confirm
        # this is sufficient for UniformDistribution / NormalDistribution.
        if isinstance(new_distribution, UniformDistribution):
            new_distribution.low += amount
            new_distribution.high += amount
        elif isinstance(new_distribution, NormalDistribution):
            new_distribution.loc += amount
        else:
            raise Exception("Can only call 'shifted' on a normal or uniform parameter")

        return dataclasses.replace(self, value=new_value, distribution=new_distribution)

    def with_distribution(self, distribution: Distribution) -> 'Parameter':
        r"""
        Return a copy of the parameter with a new distribution.

        Parameters
        ----------
        distribution : numpyro.distributions.Distribution
            The distribution to associate with this parameter.

        Returns
        -------
        Parameter
            A copy of this object with ``distribution`` replaced.

        Raises
        ------
        Exception
            If ``distribution`` is not a numpyro Distribution.
        """
        if not isinstance(distribution, Distribution):
            raise Exception('Only numpyro distributions are supported as parameter distributions')
        return dataclasses.replace(self, distribution=distribution)

    def flattened(self, separator='_') -> 'list[Parameter]':
        r"""
        Flatten self into a list of scalar Parameters.

        If the internal parameter is scalar (and no ``flat_names`` are set), the list will
        contain self. Otherwise, the parameter is split (de-vectorized) if possible.

        Parameters
        ----------
        separator : str, optional, default='_'
            Separator used for naming split parameters (e.g., name_0, name_1).

        Returns
        -------
        list[Parameter]
            The list of individual parameters.

        Raises
        ------
        ValueError
            If any internal distributions cannot be de-vectorized, or if the number of
            split distributions or ``flat_names`` does not match the number of elements.
        """
        # Handle scalar / 0-d array
        if self.value.ndim == 0 and self.flat_names is None:
            return [self]

        # Flatten the value
        flat_val = jnp.ravel(self.value)
        count = flat_val.size

        # Split distribution if present
        if self.distribution is not None:
            dists_split = _split_vectorized_distribution(self.distribution)
        else:
            dists_split = [None] * count

        # ``zip`` below would silently drop elements on a length mismatch, so fail loudly.
        if len(dists_split) != count:
            raise ValueError(
                f"Cannot flatten parameter: distribution splits into {len(dists_split)} "
                f"element(s) but the value has {count}"
            )

        # Generate names
        flat_names = self.flat_names
        if flat_names is None:
            if self.name is not None:
                flat_names = [f"{self.name}{separator}{i}" for i in range(count)]
            else:
                flat_names = [None] * count
        elif len(flat_names) != count:
            raise ValueError(
                f"Cannot flatten parameter: got {len(flat_names)} flat name(s) "
                f"but the value has {count} element(s)"
            )

        return [
            Parameter(value=val, distribution=p, fixed=self.fixed, scale=self.scale, name=n)
            for val, p, n in zip(flat_val, dists_split, flat_names)
        ]

    def as_fixed(self) -> 'Parameter':
        r"""
        Return a copy of self with ``fixed=True``.

        Returns
        -------
        Parameter
            The new, fixed parameter.
        """
        return dataclasses.replace(self, fixed=True)

    def as_free(self) -> 'Parameter':
        r"""
        Return a copy of self with ``fixed=False``.

        Returns
        -------
        Parameter
            The new, free parameter.
        """
        return dataclasses.replace(self, fixed=False)

    # Arithmetic and array conversions
    def __array__(self, dtype=None, copy=None):
        r"""
        NumPy array interface.

        Parameters
        ----------
        dtype : Any, optional
            Desired dtype.
        copy : bool or None, optional
            Accepted for compatibility with the NumPy 2 ``__array__`` signature. The
            scaled value is always computed afresh; ``True`` additionally forces an
            independent copy of the result.

        Returns
        -------
        numpy.ndarray
            The scaled value as a NumPy array (``value * scale``).
        """
        # Local import: numpy is only needed for this protocol method.
        import numpy as np

        # NumPy requires ``__array__`` to return a genuine ``numpy.ndarray``; returning a
        # JAX array makes ``np.asarray(parameter)`` fail.
        scaled = np.asarray(self.value * self.scale, dtype=dtype)
        return scaled.copy() if copy else scaled

    def __jax_array__(self, dtype=None):
        r"""
        JAX array interface.

        Parameters
        ----------
        dtype : Any, optional
            Desired dtype.

        Returns
        -------
        jnp.ndarray
            The scaled value as an array (``value * scale``).
        """
        return jnp.asarray(self.value * self.scale, dtype=dtype)

    def __len__(self):
        r"""
        Length of the parameter value.

        Returns
        -------
        int
            ``1`` for scalars, otherwise ``len(value)``.
        """
        if len(self.value.shape) == 0:
            return 1  # e.g. for jax scalars
        return len(self.value)

    def __repr__(self):
        # Build the representation dynamically
        # 'value' is always printed as it is the core of the Parameter
        args = [f"value={_format_val(self.value)}"]

        # Only add attributes if they deviate from the default
        if self.scale != 1.0:
            args.append(f"scale={self.scale}")
        if self.fixed is not False:
            args.append(f"fixed={self.fixed}")
        if self.distribution is not None:
            args.append(f"distribution={_format_distribution(self.distribution)}")
        if self.name is not None:
            args.append(f"name={repr(self.name)}")

        return f"Parameter({', '.join(args)})"

    def __add__(self, other):
        r"""Elementwise addition."""
        return jnp.add(jnp.array(self), jnp.array(other))

    def __sub__(self, other):
        r"""Elementwise subtraction."""
        return jnp.subtract(jnp.array(self), jnp.array(other))

    def __mul__(self, other):
        r"""Elementwise multiplication."""
        return jnp.multiply(jnp.array(self), jnp.array(other))

    def __truediv__(self, other):
        r"""Elementwise true division."""
        return jnp.divide(jnp.array(self), jnp.array(other))

    def __radd__(self, other):
        r"""Reflected elementwise addition."""
        return jnp.add(jnp.array(other), jnp.array(self))

    def __rsub__(self, other):
        r"""Reflected elementwise subtraction."""
        return jnp.subtract(jnp.array(other), jnp.array(self))

    def __rmul__(self, other):
        r"""Reflected elementwise multiplication."""
        return jnp.multiply(jnp.array(other), jnp.array(self))

    def __rtruediv__(self, other):
        r"""Reflected elementwise true division."""
        return jnp.divide(jnp.array(other), jnp.array(self))

    def copy(self):
        r"""
        Return a shallow copy.

        Returns
        -------
        Parameter
            A copy created via ``dataclasses.replace``.
        """
        return dataclasses.replace(self)

    # Serialization
    def to_json(self) -> str:
        r"""
        Serialize the parameter to a JSON string.

        Returns
        -------
        str
            A JSON-formatted string containing value, distribution, fixed, scale, name,
            and flat_names.
        """
        d = {
            "value": self.value.tolist(),
            "distribution": _serialize_distribution(self.distribution),
            "fixed": self.fixed,
            "scale": self.scale,
            "name": self.name,
            # Included so that a to_json/from_json round trip does not lose element names.
            "flat_names": self.flat_names,
        }
        return json.dumps(d, indent=2)

    @classmethod
    def from_json(cls, s: str) -> "Parameter":
        r"""
        Deserialize a parameter from a JSON string.

        Parameters
        ----------
        s : str
            The JSON string produced by :meth:`to_json`.

        Returns
        -------
        Parameter
            A reconstructed :class:`Parameter` instance.
        """
        d = json.loads(s)
        return cls(
            value=jnp.asarray(d["value"]),
            distribution=_deserialize_distribution(d["distribution"]),
            fixed=d["fixed"],
            scale=d["scale"],
            name=d["name"],
            # ``.get`` keeps JSON written without this key loadable.
            flat_names=d.get("flat_names"),
        )
def is_param(x) -> bool:
    r"""
    Tell whether ``x`` is a `Parameter` instance.

    Parameters
    ----------
    x
        The object to check.

    Returns
    -------
    bool
        `True` when ``x`` is a `Parameter`, `False` for anything else.
    """
    result = isinstance(x, Parameter)
    return result
def is_valid_param(x) -> bool:
    r"""
    Tell whether ``x`` is a `Parameter` whose ``value`` is not None.

    Parameters
    ----------
    x
        The object to check.

    Returns
    -------
    bool
        `True` when ``x`` is a `Parameter` holding a value, `False` otherwise.
    """
    # Anything that is not a Parameter is rejected before ``value`` is touched.
    if not isinstance(x, Parameter):
        return False
    return x.value is not None
def is_free_param(x) -> bool:
    r"""
    Tell whether ``x`` is a `Parameter` that is not marked as fixed.

    Parameters
    ----------
    x
        The object to check.

    Returns
    -------
    bool
        `True` when ``x`` is a `Parameter` with a falsy ``fixed`` flag, `False` otherwise.
    """
    # Anything that is not a Parameter is rejected before ``fixed`` is touched.
    if not isinstance(x, Parameter):
        return False
    return not x.fixed
def is_fixed_param(x) -> bool:
    r"""
    Tell whether ``x`` is a `Parameter` that is marked as fixed.

    Parameters
    ----------
    x
        The object to check.

    Returns
    -------
    bool
        `False` when ``x`` is not a `Parameter`; otherwise the parameter's ``fixed`` flag.
    """
    # Anything that is not a Parameter is rejected before ``fixed`` is touched.
    if not isinstance(x, Parameter):
        return False
    return x.fixed
def as_param(x: Any | list[Any] | dict[str, Any], **kwargs) -> Parameter:
    r"""
    Coerce an object (or a list/dict of objects) into `Parameter` form.

    An existing `Parameter` is handed back untouched and ``kwargs`` are ignored for it.
    Lists and dicts are converted element by element, with the same ``kwargs`` applied
    to every element. Any other object becomes the value of a newly created parameter.

    Parameters
    ----------
    x
        The object to convert.
    **kwargs
        Additional keyword arguments forwarded to the parameter factory (e.g. `name`).
        A truthy ``fixed`` entry selects the ``Fixed`` factory instead of ``Free``;
        the ``fixed`` entry itself is not forwarded.

    Returns
    -------
    Parameter
        The object wrapped as a `Parameter` (or a list/dict of them for container input).
    """
    # Imported inside the function rather than at module level.
    from pmrf.parameters.factories import Free, Fixed

    if isinstance(x, Parameter):
        return x

    if isinstance(x, list):
        return [as_param(item, **kwargs) for item in x]

    if isinstance(x, dict):
        return {key: as_param(item, **kwargs) for key, item in x.items()}

    # Leaf value: choose the factory from the (removed) ``fixed`` flag.
    factory = Fixed if kwargs.pop('fixed', False) else Free
    return factory(value=x, **kwargs)
def _split_vectorized_distribution(d: Distribution) -> list[Distribution]:
    """
    Split an arbitrarily shaped batch of univariate numpyro distributions into a list of scalar distributions.

    Handles broadcasting of distribution parameters to the batch shape.

    Parameters
    ----------
    d : numpyro.distributions.Distribution
        A distribution with ``event_shape == ()`` and arbitrary ``batch_shape``.

    Returns
    -------
    list[numpyro.distributions.Distribution]
        A flat list of scalar distributions corresponding to the flattened batch.

    Raises
    ------
    ValueError
        If the distribution has a non-empty ``event_shape``, or one of its parameters
        cannot be broadcast to ``batch_shape``.
    """
    if d.event_shape != ():
        raise ValueError(f"Cannot split distribution with event_shape={d.event_shape} (likely an Independent or Multivariate dist)")

    batch_shape = d.batch_shape
    if not batch_shape:
        # Scalar distribution
        return [d]

    # Calculate total size to verify flatten length
    total_size = 1
    for dim in batch_shape:
        total_size *= dim

    # Special handling for ImproperUniform
    if isinstance(d, dist.ImproperUniform):
        return [
            dist.ImproperUniform(d.support, batch_shape=(), event_shape=d.event_shape)
            for _ in range(total_size)
        ]

    # Broadcast every init parameter (e.g. 'loc', 'scale', 'low', 'high') to the batch shape
    # and flatten it. d.arg_constraints keys usually match the __init__ arguments.
    flat_params = {}
    for name in d.arg_constraints:
        val = jnp.asarray(getattr(d, name))
        # This handles cases where e.g. Normal(0, [1, 2]) has scalar loc and vector scale.
        try:
            val_broadcast = jnp.broadcast_to(val, batch_shape)
        except ValueError as err:
            # Shapes are strictly incompatible, though numpyro usually prevents this earlier
            raise ValueError(
                f"Parameter '{name}' with shape {val.shape} cannot be broadcast to batch_shape {batch_shape}"
            ) from err
        flat_params[name] = jnp.ravel(val_broadcast)

    # Reconstruct individual scalar distributions from the i-th value of each parameter
    dist_class = d.__class__
    return [
        dist_class(**{name: vals[i] for name, vals in flat_params.items()})
        for i in range(total_size)
    ]


def _stack_vectorized_distributions(dists: list[Distribution | None]) -> Distribution | None:
    """
    Combine a list of scalar numpyro distributions into a single batched distribution.

    Parameters
    ----------
    dists : list[numpyro.distributions.Distribution | None]
        A list of distributions to stack.

    Returns
    -------
    numpyro.distributions.Distribution | None
        The vectorized distribution, or None if no distributions were provided.

    Raises
    ------
    ValueError
        If some, but not all, entries are None.
    """
    if all(d is None for d in dists):
        return None

    if any(d is None for d in dists):
        raise ValueError("Cannot stack a mix of parameters where some have distributions and others do not.")

    dist_classes = {type(d) for d in dists}

    # Fast path: If they are all the exact same family, use native NumPyro batching
    if len(dist_classes) == 1:
        dist_cls = dists[0].__class__
        stacked_kwargs = {
            name: jnp.stack([getattr(d, name) for d in dists])
            for name in dists[0].arg_constraints
        }
        return dist_cls(**stacked_kwargs)

    # Flexible path: Use our custom meta-distribution for mixed types
    return StackedDistribution(dists)


def _serialize_distribution(d: Distribution | None) -> dict | None:
    r"""
    Serialize a numpyro distribution to a lightweight dictionary.

    Public (non-underscore) instance attributes are stored as parameters. Arrays are
    converted to plain lists and constraints are stored by class name only.

    Parameters
    ----------
    d : numpyro.distributions.Distribution or None
        The distribution to serialize.

    Returns
    -------
    dict or None
        A dictionary with ``class`` and ``params`` keys, or ``None``.
    """
    if d is None:
        return None

    params = {}
    for k, v in d.__dict__.items():
        if k.startswith("_"):
            continue
        if isinstance(v, jnp.ndarray):
            params[k] = v.tolist()
        elif isinstance(v, dist.constraints.Constraint):
            # Just store the name of the constraint (e.g., 'real', 'positive')
            params[k] = {"__constraint__": type(v).__name__}
        elif hasattr(v, "tolist"):
            # NumPy arrays / scalars are not JSON serializable either; convert them too.
            params[k] = v.tolist()
        else:
            params[k] = v

    return {"class": d.__class__.__name__, "params": params}


# Helper to deserialize a numpyro Distribution
def _deserialize_distribution(dct: dict | None) -> Distribution | None:
    r"""
    Deserialize a numpyro distribution from a dictionary.

    The input dictionary is left unmodified.

    Parameters
    ----------
    dct : dict or None
        A dictionary produced by :func:`_serialize_distribution`.

    Returns
    -------
    numpyro.distributions.Distribution or None
        The reconstructed distribution, or ``None``.

    Raises
    ------
    ValueError
        If the distribution class is unknown.
    """
    if dct is None:
        return None

    cls = getattr(dist, dct["class"], None)
    if cls is None:
        raise ValueError(f"Unknown distribution class: {dct['class']}")

    # Build fresh kwargs instead of rewriting the caller's dictionary in place.
    kwargs = {}
    for k, v in dct["params"].items():
        if isinstance(v, dict) and "__constraint__" in v:
            # Map the string name back to the numpyro constraint object,
            # e.g. '_Real' -> 'real'
            constraint_name = v["__constraint__"].lower().strip('_')
            kwargs[k] = getattr(dist.constraints, constraint_name)
        elif isinstance(v, list):
            # Array parameters were written out with ``tolist``; restore them as arrays.
            kwargs[k] = jnp.asarray(v)
        else:
            kwargs[k] = v

    return cls(**kwargs)


def _format_val(val, sig_figs: int = 4) -> str:
    """Safely extract and format a scalar or array value for clean printing.

    0-d and single-element values are printed as a number with ``sig_figs`` significant
    figures; larger arrays are summarised by their shape; anything else uses ``repr``.
    """
    if hasattr(val, "ndim") and val.ndim == 0:
        return f"{float(val):.{sig_figs}g}"
    elif hasattr(val, "size") and val.size == 1:
        return f"{float(val.item()):.{sig_figs}g}"
    else:
        # Fallback for multi-dimensional arrays
        return f"f64{list(val.shape)}" if hasattr(val, "shape") else repr(val)


def _format_distribution(d: Distribution) -> str:
    """Format a numpyro distribution dynamically using its arg_constraints.

    Falls back to ``repr(d)`` when the object exposes no ``arg_constraints``.
    """
    class_name = d.__class__.__name__

    if hasattr(d, "arg_constraints"):
        # Pull out loc, scale, low, high, etc., dynamically
        args = [
            f"{param_name}={_format_val(getattr(d, param_name))}"
            for param_name in d.arg_constraints
            if hasattr(d, param_name)
        ]
        return f"{class_name}({', '.join(args)})"

    return repr(d)