import dataclasses
import json
import math
from copy import deepcopy
from typing import Any

import equinox as eqx
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist
from numpyro.distributions.distribution import Distribution

from pmrf.distributions import NormalDistribution, UniformDistribution
from pmrf.distributions.stacked import StackedDistribution
from pmrf.field import field
from pmrf.constants import MIN_PERCENTILE, MAX_PERCENTILE
[docs]
class Parameter(eqx.Module):
    """
    A container for a parameter, usually used within a `Model`.

    This class serves as the fundamental building block for defining the
    tunable or fixed parameters within a **paramrf** `Model` and for fitting.
    It is designed to be a flexible container that behaves like a standard
    numerical type (e.g., a `numpy.ndarray`) while holding additional metadata
    for model fitting and analysis.

    Usage
    -----
    * Use in mathematical operations just like a JAX/numpy array.
    * ``Parameter`` objects are JAX PyTrees, compatible with JAX transformations (jit, grad).
    * Mark as ``fixed`` (honoured by fitting and sampling routines).
    * Associate distributions, specified as numpyro distributions (uniform, normal, etc.).

    Attributes
    ----------
    value : jnp.ndarray
        The underlying unscaled value, converted with
        ``jnp.asarray(x, dtype=jnp.float64)`` (JAX falls back to float32 unless
        64-bit mode is enabled).
    distribution : numpyro.distributions.Distribution or None
        The (unscaled) prior distribution associated with this parameter.
    fixed : bool
        If True, the parameter is treated as a constant during optimization/sampling.
        Static field.
    scale : float
        A scaling factor. The effective value used in calculations is
        ``value * scale``. Static field.
    name : str or None
        An optional name for the parameter. Static field.
    flat_names : list[str] or None
        Optional explicit per-element names used by :meth:`flattened` instead of
        the auto-generated ``f"{name}{separator}{i}"`` names. Static field.

    Examples
    --------
    .. code-block:: python

        import pmrf as prf
        import jax.numpy as jnp

        # A simple, single-valued parameter with a scale, initialized with a float
        p1 = prf.Parameter(value=1.0, scale=1e-12, name='C1')

        # This parameter can be used in calculations directly (scaling is done during casting)
        impedance = 1 / (2j * jnp.pi * 1e9 * p1)
        print(f"Impedance: {impedance}")

        # A parameter that is fixed and will not be optimized during a fit
        p2 = prf.Parameter(value=50.0, fixed=True, name='R2')

        # A parameter with a uniform distribution
        # The provided factory functions in pmrf.parameters are a convenient way to create these
        from pmrf.parameters import Uniform
        p3 = Uniform(0.9, 1.1, scale=1e-9, name='L1')
        print(f"Initial value of L1: {p3.value}")  # initialized to the mean

        # More complicated parameters can be initialized with any "numpyro" distribution
        from numpyro.distributions import LogNormal
        p4 = prf.Parameter(value=10.0, distribution=LogNormal(10.0, 1.0))
    """

    # Underlying values/dists (unscaled). Multiply by ``scale`` to get the true value
    # (done automatically when converting to an array).
    value: jnp.ndarray = field(converter=lambda x: jnp.asarray(x, dtype=jnp.float64))
    distribution: Distribution | None = field(default=None)

    # Static (metadata) fields
    fixed: bool = field(default=False, static=True)
    scale: float = field(default=1.0, static=True)
    name: str | None = field(default=None, static=True)
    flat_names: list[str] | None = field(default=None, converter=lambda x: list(x) if x is not None else x, static=True)

    @property
    def shape(self) -> tuple[int, ...]:
        """
        The shape of this parameter's (unscaled) value.
        """
        return self.value.shape

    @property
    def size(self) -> int:
        """
        The total number of elements in this parameter's value.
        """
        return self.value.size

    @property
    def min(self) -> jnp.ndarray:
        r"""
        The unscaled lower bound of the parameter's distribution.

        * ``ImproperUniform`` or no distribution: ``-inf`` with this parameter's shape.
        * ``Delta``: the point-mass location ``v``.
        * ``Uniform``: the ``low`` bound.
        * Any other distribution: its inverse CDF at ``MIN_PERCENTILE``.

        Returns
        -------
        jnp.ndarray
            The minimum value.
        """
        if isinstance(self.distribution, dist.ImproperUniform):
            return jnp.full(self.shape, -jnp.inf)
        elif isinstance(self.distribution, dist.Delta):
            return self.distribution.v
        elif isinstance(self.distribution, dist.Uniform):
            return self.distribution.low
        elif self.distribution is not None:
            return self.distribution.icdf(MIN_PERCENTILE)
        return jnp.full(self.shape, -jnp.inf)

    @property
    def max(self) -> jnp.ndarray:
        r"""
        The unscaled upper bound of the parameter's distribution.

        * ``ImproperUniform`` or no distribution: ``+inf`` with this parameter's shape.
        * ``Delta``: the point-mass location ``v``.
        * ``Uniform``: the ``high`` bound.
        * Any other distribution: its inverse CDF at ``MAX_PERCENTILE``.

        Returns
        -------
        jnp.ndarray
            The maximum value.
        """
        if isinstance(self.distribution, dist.ImproperUniform):
            return jnp.full(self.shape, jnp.inf)
        elif isinstance(self.distribution, dist.Delta):
            return self.distribution.v
        elif isinstance(self.distribution, dist.Uniform):
            return self.distribution.high
        elif self.distribution is not None:
            return self.distribution.icdf(MAX_PERCENTILE)
        return jnp.full(self.shape, jnp.inf)

    @property
    def mean(self) -> jnp.ndarray:
        r"""
        The unscaled mean value of the parameter's distribution.

        Returns
        -------
        jnp.ndarray
            The distribution's mean; ``v`` for a ``Delta``; or the current value
            if no distribution (or an ``ImproperUniform``) is set.
        """
        # ImproperUniform doesn't have a defined mean, so we fall back to the initialized value
        if isinstance(self.distribution, dist.ImproperUniform) or self.distribution is None:
            return self.value
        elif isinstance(self.distribution, dist.Delta):
            return self.distribution.v
        return self.distribution.mean

    def with_value(self, value: jnp.ndarray) -> 'Parameter':
        r"""
        Return a copy of the parameter with a new unscaled value.

        Parameters
        ----------
        value : jnp.ndarray
            The new unscaled value to set.

        Returns
        -------
        Parameter
            A copy of this object with ``value`` replaced.
        """
        return dataclasses.replace(self, value=value)

    def shifted(self, amount: jnp.ndarray) -> 'Parameter':
        r"""
        Return a copy of this parameter with its value and distribution shifted by ``amount``.

        The value becomes ``value + amount``. A uniform distribution has both
        ``low`` and ``high`` shifted; a normal distribution has ``loc`` shifted.
        The distribution is deep-copied before modification, so ``self`` is unchanged.

        Parameters
        ----------
        amount : jnp.ndarray
            The (unscaled) amount to shift by. Can be positive or negative.

        Returns
        -------
        Parameter
            A copy of this object with ``value`` and ``distribution`` shifted.

        Raises
        ------
        ValueError
            If the distribution is not a ``UniformDistribution`` or
            ``NormalDistribution`` (including when no distribution is set).
        """
        if not isinstance(self.distribution, (UniformDistribution, NormalDistribution)):
            raise ValueError("Can only call 'shifted' on a normal or uniform parameter")

        new_value = self.value + amount
        # Mutate a private copy so the original parameter's distribution is untouched.
        new_distribution = deepcopy(self.distribution)
        if isinstance(new_distribution, UniformDistribution):
            new_distribution.low += amount
            new_distribution.high += amount
        else:
            new_distribution.loc += amount
        return dataclasses.replace(self, value=new_value, distribution=new_distribution)

    def with_distribution(self, distribution: Distribution) -> 'Parameter':
        r"""
        Return a copy of the parameter with a new distribution.

        Parameters
        ----------
        distribution : numpyro.distributions.Distribution
            The distribution to associate with this parameter.

        Returns
        -------
        Parameter
            A copy of this object with ``distribution`` replaced.

        Raises
        ------
        TypeError
            If ``distribution`` is not a numpyro Distribution.
        """
        if not isinstance(distribution, Distribution):
            raise TypeError('Only numpyro distributions are supported as parameter distributions')
        return dataclasses.replace(self, distribution=distribution)

    def flattened(self, separator='_') -> 'list[Parameter]':
        r"""
        Flatten self into a list of scalar Parameters.

        If the internal parameter is scalar (and has no ``flat_names``), the list
        contains ``self``. Otherwise the value is raveled and the parameter is
        split (de-vectorized) element by element. A scalar (un-batched)
        distribution is shared by every element.

        Parameters
        ----------
        separator : str, optional, default='_'
            Separator used for naming split parameters (e.g., name_0, name_1).
            Ignored when ``flat_names`` is set.

        Returns
        -------
        list[Parameter]
            The list of individual parameters.

        Raises
        ------
        ValueError
            If the distribution cannot be de-vectorized, or if the number of split
            distributions or ``flat_names`` does not match the number of elements.
        """
        # Handle scalar / 0-d array
        if self.value.ndim == 0 and self.flat_names is None:
            return [self]

        flat_val = jnp.ravel(self.value)
        n = flat_val.size

        if self.distribution is not None:
            dists_split = _split_vectorized_distribution(self.distribution)
            # A scalar distribution applies to every element; without broadcasting it,
            # zip() below would silently drop all but the first element.
            if len(dists_split) == 1 and n != 1:
                dists_split = dists_split * n
        else:
            dists_split = [None] * n

        flat_names = self.flat_names
        if flat_names is None:
            if self.name is not None:
                flat_names = [f"{self.name}{separator}{i}" for i in range(n)]
            else:
                flat_names = [None] * n

        # zip() truncates silently, so validate lengths explicitly.
        if len(dists_split) != n:
            raise ValueError(
                f"Distribution of parameter {self.name!r} splits into {len(dists_split)} "
                f"distributions, but its value has {n} elements"
            )
        if len(flat_names) != n:
            raise ValueError(
                f"Parameter {self.name!r} has {len(flat_names)} flat_names, "
                f"but its value has {n} elements"
            )

        return [
            Parameter(value=val, distribution=p, fixed=self.fixed, scale=self.scale, name=nm)
            for val, p, nm in zip(flat_val, dists_split, flat_names)
        ]

    def as_fixed(self) -> 'Parameter':
        r"""
        Return a copy of self with ``fixed=True``.

        Returns
        -------
        Parameter
            The new, fixed parameter.
        """
        return dataclasses.replace(self, fixed=True)

    def as_free(self) -> 'Parameter':
        r"""
        Return a copy of self with ``fixed=False``.

        Returns
        -------
        Parameter
            The new, free parameter.
        """
        return dataclasses.replace(self, fixed=False)

    # Arithmetic and array conversions

    def __array__(self, dtype=None, copy=None):
        r"""
        NumPy array interface.

        NumPy requires ``__array__`` to return a genuine ``numpy.ndarray``
        (a JAX array is rejected), so the scaled value is converted with
        ``np.asarray``. The ``copy`` keyword is accepted for NumPy 2 compatibility.

        Parameters
        ----------
        dtype : Any, optional
            Desired dtype.
        copy : bool or None, optional
            If True, always return a fresh copy.

        Returns
        -------
        numpy.ndarray
            The scaled value as an array (``value * scale``).
        """
        arr = np.asarray(self.value * self.scale, dtype=dtype)
        return arr.copy() if copy else arr

    def __jax_array__(self, dtype=None):
        r"""
        JAX array interface.

        Parameters
        ----------
        dtype : Any, optional
            Desired dtype.

        Returns
        -------
        jnp.ndarray
            The scaled value as an array (``value * scale``).
        """
        return jnp.asarray(self.value * self.scale, dtype=dtype)

    def __len__(self):
        r"""
        Length of the parameter value.

        Returns
        -------
        int
            ``1`` for scalars, otherwise ``len(value)``.
        """
        if len(self.value.shape) == 0:
            return 1  # e.g. for jax scalars
        return len(self.value)

    def __repr__(self):
        # Build the representation dynamically.
        # 'value' is always printed as it is the core of the Parameter.
        args = [f"value={_format_val(self.value)}"]
        # Only add attributes if they deviate from the default
        if self.scale != 1.0:
            args.append(f"scale={self.scale}")
        if self.fixed is not False:
            args.append(f"fixed={self.fixed}")
        if self.distribution is not None:
            args.append(f"distribution={_format_distribution(self.distribution)}")
        if self.name is not None:
            args.append(f"name={repr(self.name)}")
        return f"Parameter({', '.join(args)})"

    def __add__(self, other):
        r"""Elementwise addition."""
        return jnp.add(jnp.array(self), jnp.array(other))

    def __sub__(self, other):
        r"""Elementwise subtraction."""
        return jnp.subtract(jnp.array(self), jnp.array(other))

    def __mul__(self, other):
        r"""Elementwise multiplication."""
        return jnp.multiply(jnp.array(self), jnp.array(other))

    def __truediv__(self, other):
        r"""Elementwise true division."""
        return jnp.divide(jnp.array(self), jnp.array(other))

    def __pow__(self, other):
        r"""Elementwise power."""
        return jnp.power(jnp.array(self), jnp.array(other))

    def __matmul__(self, other):
        r"""Matrix multiplication."""
        return jnp.matmul(jnp.array(self), jnp.array(other))

    def __radd__(self, other):
        r"""Reflected elementwise addition."""
        return jnp.add(jnp.array(other), jnp.array(self))

    def __rsub__(self, other):
        r"""Reflected elementwise subtraction."""
        return jnp.subtract(jnp.array(other), jnp.array(self))

    def __rmul__(self, other):
        r"""Reflected elementwise multiplication."""
        return jnp.multiply(jnp.array(other), jnp.array(self))

    def __rtruediv__(self, other):
        r"""Reflected elementwise true division."""
        return jnp.divide(jnp.array(other), jnp.array(self))

    def __rpow__(self, other):
        r"""Reflected elementwise power."""
        return jnp.power(jnp.array(other), jnp.array(self))

    def __rmatmul__(self, other):
        r"""Reflected matrix multiplication."""
        return jnp.matmul(jnp.array(other), jnp.array(self))

    def __neg__(self):
        r"""Elementwise negation of the scaled value."""
        return jnp.negative(jnp.array(self))

    def __pos__(self):
        r"""Unary plus: the scaled value as an array."""
        return jnp.array(self)

    def __abs__(self):
        r"""Elementwise absolute value of the scaled value."""
        return jnp.abs(jnp.array(self))

    def copy(self):
        r"""
        Return a shallow copy.

        Returns
        -------
        Parameter
            A copy created via ``dataclasses.replace`` (field values are shared).
        """
        return dataclasses.replace(self)

    # Serialization

    def to_json(self) -> str:
        r"""
        Serialize the parameter to a JSON string.

        Returns
        -------
        str
            A JSON-formatted string containing value, distribution, fixed, scale,
            name and flat_names.
        """
        d = {
            "value": self.value.tolist(),
            "distribution": _serialize_distribution(self.distribution),
            "fixed": self.fixed,
            "scale": self.scale,
            "name": self.name,
            "flat_names": self.flat_names,
        }
        return json.dumps(d, indent=2)

    @classmethod
    def from_json(cls, s: str) -> "Parameter":
        r"""
        Deserialize a parameter from a JSON string.

        Parameters
        ----------
        s : str
            The JSON string produced by :meth:`to_json`. Strings written before
            ``flat_names`` was serialized are accepted (``flat_names`` is None).

        Returns
        -------
        Parameter
            A reconstructed :class:`Parameter` instance.
        """
        d = json.loads(s)
        return cls(
            value=jnp.asarray(d["value"]),
            distribution=_deserialize_distribution(d["distribution"]),
            fixed=d["fixed"],
            scale=d["scale"],
            name=d["name"],
            flat_names=d.get("flat_names"),
        )
[docs]
def is_param(x) -> bool:
    r"""
    Tell whether ``x`` is a :class:`Parameter` (or a subclass instance).

    Parameters
    ----------
    x
        Any object.

    Returns
    -------
    bool
        `True` when ``x`` is a Parameter, `False` otherwise.
    """
    answer = isinstance(x, Parameter)
    return answer
[docs]
def is_valid_param(x) -> bool:
    r"""
    Tell whether ``x`` is a :class:`Parameter` whose ``value`` is not None.

    Parameters
    ----------
    x
        Any object.

    Returns
    -------
    bool
        `True` for a Parameter carrying a value, `False` otherwise.
    """
    if not isinstance(x, Parameter):
        return False
    return x.value is not None
[docs]
def is_free_param(x) -> bool:
    r"""
    Tell whether ``x`` is a :class:`Parameter` that is not marked ``fixed``.

    Parameters
    ----------
    x
        Any object.

    Returns
    -------
    bool
        `True` for a free (non-fixed) Parameter, `False` otherwise.
    """
    if isinstance(x, Parameter):
        return not x.fixed
    return False
[docs]
def is_fixed_param(x) -> bool:
    r"""
    Tell whether ``x`` is a :class:`Parameter` marked ``fixed``.

    Parameters
    ----------
    x
        Any object.

    Returns
    -------
    bool
        The Parameter's ``fixed`` flag, or `False` for non-Parameters.
    """
    if not isinstance(x, Parameter):
        return False
    return x.fixed
[docs]
def as_param(x: Any | list[Any] | dict[str, Any], **kwargs) -> "Parameter | list[Any] | dict[str, Any]":
    r"""
    Ensure an object is a `Parameter` or a container over parameters.

    * A `Parameter` is returned unchanged (``kwargs`` are ignored).
    * A list or dict is converted element-wise (recursively), preserving its
      structure.
    * Any other object is wrapped with ``pmrf.parameters.factories.Fixed`` when
      ``fixed=True`` is passed, and with ``Free`` otherwise.

    Parameters
    ----------
    x : Any or list or dict
        The object to convert.
    **kwargs
        Additional keyword arguments passed to the factory (e.g. ``name``). The
        key ``fixed`` selects between ``Fixed`` and ``Free`` and is not forwarded.
        Note that the same keyword arguments (including ``name``) are applied to
        every element of a list or dict.

    Returns
    -------
    Parameter or list or dict
        The object wrapped as a `Parameter`, or a list/dict of the same
        structure whose leaves are Parameters.
    """
    # Imported lazily, presumably to avoid a circular import with the factories module.
    from pmrf.parameters.factories import Free, Fixed

    if isinstance(x, Parameter):
        return x
    if isinstance(x, list):
        return [as_param(xi, **kwargs) for xi in x]
    if isinstance(x, dict):
        return {k: as_param(xi, **kwargs) for k, xi in x.items()}

    is_fixed = kwargs.pop('fixed', False)
    factory = Fixed if is_fixed else Free
    return factory(value=x, **kwargs)
def _split_vectorized_distribution(d: Distribution) -> list[Distribution]:
    """
    Split a batch of univariate numpyro distributions into a flat list of scalar distributions.

    Each parameter listed in ``d.arg_constraints`` is broadcast to the
    distribution's batch shape, flattened, and used element-by-element to
    re-instantiate ``type(d)``. ``ImproperUniform`` (which is rebuilt from its
    support rather than from ``arg_constraints``) is handled separately.

    Parameters
    ----------
    d : numpyro.distributions.Distribution
        A distribution with ``event_shape == ()`` and arbitrary ``batch_shape``.

    Returns
    -------
    list[numpyro.distributions.Distribution]
        A flat list of scalar distributions corresponding to the flattened batch.
        A distribution with an empty batch shape is returned as ``[d]``.

    Raises
    ------
    ValueError
        If ``d`` has a non-empty event shape, or one of its parameters cannot be
        broadcast to ``d.batch_shape``.
    """
    if d.event_shape != ():
        raise ValueError(f"Cannot split distribution with event_shape={d.event_shape} (likely an Independent or Multivariate dist)")

    batch_shape = d.batch_shape
    if not batch_shape:  # Scalar distribution
        return [d]

    total_size = math.prod(batch_shape)

    # ImproperUniform is constructed from its support and shapes, not arg_constraints.
    if isinstance(d, dist.ImproperUniform):
        return [
            dist.ImproperUniform(d.support, batch_shape=(), event_shape=d.event_shape)
            for _ in range(total_size)
        ]

    # d.arg_constraints keys usually match the __init__ arguments (e.g. loc/scale, low/high).
    flat_params = {}
    for name in d.arg_constraints.keys():
        val = jnp.asarray(getattr(d, name))
        # Broadcast to the batch shape; handles e.g. Normal(0, [1, 2]) with scalar loc.
        try:
            val_broadcast = jnp.broadcast_to(val, batch_shape)
        except ValueError as err:
            raise ValueError(f"Parameter '{name}' with shape {val.shape} cannot be broadcast to batch_shape {batch_shape}") from err
        flat_params[name] = jnp.ravel(val_broadcast)

    # Reconstruct one scalar distribution per flattened batch element.
    dist_class = type(d)
    return [
        dist_class(**{name: vals[i] for name, vals in flat_params.items()})
        for i in range(total_size)
    ]
def _stack_vectorized_distributions(dists: list[Distribution | None]) -> Distribution | None:
    """
    Combine a list of scalar numpyro distributions into a single batched distribution.

    When every distribution has the same class and that class exposes its
    constructor arguments through ``arg_constraints``, those arguments are
    stacked along a new leading axis and the class is re-instantiated (native
    NumPyro batching). Otherwise a :class:`StackedDistribution` wrapping the list
    is returned.

    Parameters
    ----------
    dists : list[numpyro.distributions.Distribution | None]
        A list of distributions to stack.

    Returns
    -------
    numpyro.distributions.Distribution | None
        The vectorized distribution, or None if every entry is None (including
        an empty list).

    Raises
    ------
    ValueError
        If some, but not all, entries are None.
    """
    if all(d is None for d in dists):
        return None
    if any(d is None for d in dists):
        raise ValueError("Cannot stack a mix of parameters where some have distributions and others do not.")

    first = dists[0]
    if all(type(d) is type(first) for d in dists):
        param_names = list(first.arg_constraints.keys())
        # Fast path: same family rebuilt from stacked arguments. Families with no
        # arg_constraints (e.g. ImproperUniform, built from support/shapes) cannot be
        # rebuilt this way, since calling the class with no arguments raises, so
        # they take the flexible path instead.
        if param_names:
            stacked_kwargs = {name: jnp.stack([getattr(d, name) for d in dists]) for name in param_names}
            return type(first)(**stacked_kwargs)

    # Flexible path: custom meta-distribution for mixed (or non-rebuildable) types.
    return StackedDistribution(dists)
def _serialize_distribution(d: Distribution | None) -> dict | None:
    r"""
    Convert a numpyro distribution into a small JSON-friendly dictionary.

    Only public (non underscore-prefixed) instance attributes are kept. JAX
    arrays become nested lists, constraints become ``{"__constraint__": <class
    name>}`` markers, and every other value is stored as-is.

    Parameters
    ----------
    d : numpyro.distributions.Distribution or None
        The distribution to serialize.

    Returns
    -------
    dict or None
        ``{"class": <class name>, "params": {...}}``, or ``None`` when ``d`` is None.
    """
    if d is None:
        return None

    def _encode(value):
        if isinstance(value, jnp.ndarray):
            return value.tolist()
        if isinstance(value, dist.constraints.Constraint):
            # Only the constraint's class name is kept (e.g. '_Real').
            return {"__constraint__": type(value).__name__}
        return value

    encoded = {
        key: _encode(value)
        for key, value in d.__dict__.items()
        if not key.startswith("_")
    }
    return {"class": d.__class__.__name__, "params": encoded}
# Helper to deserialize a numpyro Distribution
def _deserialize_distribution(dct: dict | None) -> Distribution | None:
    r"""
    Deserialize a numpyro distribution from a dictionary.

    The input dictionary is not modified. Constraint markers
    (``{"__constraint__": "<ClassName>"}``) are mapped to the attribute of
    ``numpyro.distributions.constraints`` named by the lower-cased,
    underscore-stripped class name. Lists (arrays serialized with ``tolist()``)
    are restored as JAX arrays; other values are passed through unchanged.

    Parameters
    ----------
    dct : dict or None
        A dictionary produced by :func:`_serialize_distribution`.

    Returns
    -------
    numpyro.distributions.Distribution or None
        The reconstructed distribution, or ``None``.

    Raises
    ------
    ValueError
        If the distribution class is unknown.
    AttributeError
        If a serialized constraint name has no matching attribute in
        ``numpyro.distributions.constraints``.
    """
    if dct is None:
        return None

    cls = getattr(dist, dct["class"], None)
    if cls is None:
        raise ValueError(f"Unknown distribution class: {dct['class']}")

    # Build a fresh dict so the caller's serialized data is left untouched.
    params = {}
    for k, v in dct["params"].items():
        if isinstance(v, dict) and "__constraint__" in v:
            # Map the class name back to the constraint object, e.g. '_Real' -> 'real'
            constraint_name = v["__constraint__"].lower().strip('_')
            params[k] = getattr(dist.constraints, constraint_name)
        elif isinstance(v, list):
            # Arrays were serialized with .tolist(); restore them as arrays.
            params[k] = jnp.asarray(v)
        else:
            params[k] = v
    return cls(**params)
def _format_val(val, sig_figs: int = 4) -> str:
"""Safely extract and format a scalar or array value for clean printing."""
if hasattr(val, "ndim") and val.ndim == 0:
return f"{float(val):.{sig_figs}g}"
elif hasattr(val, "size") and val.size == 1:
return f"{float(val.item()):.{sig_figs}g}"
else:
# Fallback for multi-dimensional arrays
return f"f64{list(val.shape)}" if hasattr(val, "shape") else repr(val)
def _format_distribution(d: Distribution) -> str:
    """
    Render a numpyro distribution compactly as ``ClassName(arg=value, ...)``.

    The arguments shown are the keys of ``d.arg_constraints`` that exist as
    attributes on ``d``, each formatted with :func:`_format_val`. Objects with
    no ``arg_constraints`` attribute fall back to ``repr``.
    """
    if not hasattr(d, "arg_constraints"):
        return repr(d)

    rendered = ", ".join(
        f"{key}={_format_val(getattr(d, key))}"
        for key in d.arg_constraints.keys()
        if hasattr(d, key)
    )
    return f"{d.__class__.__name__}({rendered})"