Parameter (pmrf.Parameter)

class Parameter(value, distribution=None, fixed=False, scale=1.0, name=None, flat_names=None)

Bases: Module

A container for a parameter, usually used within a Model.

This class is the fundamental building block for defining the tunable and fixed parameters of a paramrf Model and for fitting. It is a flexible container that behaves like a standard numerical type (e.g., a numpy.ndarray) while holding additional metadata for model fitting and analysis.

Usage

  • Use in mathematical operations just like a JAX/numpy array.

  • Parameter objects are JAX PyTrees, compatible with JAX transformations (jit, grad).

  • Mark as fixed (honoured by fitting and sampling routines).

  • Associate a prior distribution, specified as any numpyro distribution (Uniform, Normal, etc.).
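The snippet below is a minimal sketch of how a fitting routine might honour the fixed flag. Only the free parameters are exposed to the optimizer, and fixed ones pass through unchanged. It uses a hypothetical stand-in class `FitParam` and hypothetical helpers `split_free` / `merge_free`, not the pmrf API. The actual pmrf fitting internals may differ.

```python
from dataclasses import dataclass, replace

# Hypothetical stand-in for a parameter; only the fields used here.
@dataclass(frozen=True)
class FitParam:
    name: str
    value: float
    fixed: bool = False

def split_free(params):
    """Return the names and values of the free (non-fixed) parameters, in order."""
    free = [p for p in params if not p.fixed]
    return [p.name for p in free], [p.value for p in free]

def merge_free(params, new_values):
    """Rebuild the list, substituting new values for free parameters only."""
    values = iter(new_values)
    return [p if p.fixed else replace(p, value=next(values)) for p in params]

params = [FitParam("C1", 1.0), FitParam("R2", 50.0, fixed=True), FitParam("L1", 1.0)]
names, x0 = split_free(params)  # the optimizer would only see x0
updated = merge_free(params, [2.0, 3.0])
print([(p.name, p.value) for p in updated])
```

Here `R2` is marked fixed, so it never enters the vector handed to the optimizer and keeps its value of 50.0 after the update.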

Variables:

value (jnp.ndarray) – The underlying unscaled value. Automatically converted to a float64 array.

distribution (numpyro.distributions.Distribution or None) – The prior distribution associated with this parameter.

fixed (bool) – If True, the parameter is treated as a constant during optimization and sampling.

scale (float) – A scaling factor. The effective value used in calculations is value * scale.

name (str or None) – An optional name for the parameter (marked as static).
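The relationship between value and scale can be sketched with a hypothetical stand-in class, `ScaledParam`, which is not part of pmrf. The stored number stays of order one while the scale carries the magnitude. Casting to float yields value * scale. Keeping optimized quantities of order one is a common way to keep optimizer steps well conditioned.

```python
import math
from dataclasses import dataclass

# Hypothetical stand-in: stores an unscaled value plus a scale factor.
@dataclass(frozen=True)
class ScaledParam:
    value: float
    scale: float = 1.0

    def __float__(self):
        # The effective value used in calculations is value * scale.
        return self.value * self.scale

c1 = ScaledParam(value=1.0, scale=1e-12)  # 1 pF stored as value 1.0
f = 1e9                                   # 1 GHz
z = 1 / (2j * math.pi * f * float(c1))    # capacitor impedance 1/(j*omega*C)
print(z)  # approximately -159.15j ohms
```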

Examples

```python
import jax.numpy as jnp
from numpyro.distributions import LogNormal

import pmrf as prf
from pmrf.parameters import Uniform

# A simple, single-valued parameter with a scale, initialized with a float
p1 = prf.Parameter(value=1.0, scale=1e-12, name='C1')

# The parameter can be used in calculations directly; the scale is applied
# automatically, so p1 behaves like 1e-12 here
impedance = 1 / (2j * jnp.pi * 1e9 * p1)
print(f"Impedance: {impedance}")

# A parameter that is fixed and will not be optimized during a fit
p2 = prf.Parameter(value=50.0, fixed=True, name='R2')

# A parameter with a uniform distribution.
# The factory functions in pmrf.parameters are a convenient way to create these.
p3 = Uniform(0.9, 1.1, scale=1e-9, name='L1')
print(f"Initial value of L1: {p3.value}")  # unscaled; initialized to the mean

# More complicated parameters can be initialized with any numpyro distribution.
# LogNormal(loc, scale) is parameterized in log space, so loc=log(10) gives a
# median of 10, consistent with the initial value.
p4 = prf.Parameter(value=10.0, distribution=LogNormal(jnp.log(10.0), 1.0))
```
value: Array
distribution: Distribution | None = None
fixed: bool = False
scale: float = 1.0
name: str | None = None
flat_names: list[str] | None = None
property shape: tuple[int, ...]

The shape of this parameter.

property size: int

The number of elements in this parameter.

property min: Array

The unscaled minimum value of the parameter’s distribution.

For a uniform distribution this is the lower bound; for all other distributions, a lower percentile is used instead.

Returns:

The minimum value, or the current value if no distribution is set.

Return type:

jnp.ndarray

property max: Array

The unscaled maximum value of the parameter’s distribution.

For a uniform distribution this is the upper bound; for all other distributions, an upper percentile is used instead.

Returns:

The maximum value, or the current value if no distribution is set.

Return type:

jnp.ndarray
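A percentile is needed because distributions such as the normal have unbounded support and so have no finite minimum or maximum. A uniform distribution, by contrast, has exact bounds. The sketch below illustrates the idea with the standard library's `statistics.NormalDist`. The helper `dist_bounds` and the percentile `q=0.01` are hypothetical choices, since the exact percentile pmrf uses is not stated here.

```python
from statistics import NormalDist

def dist_bounds(kind, a, b, q=0.01):
    """Hypothetical helper returning (min, max) for a distribution.

    kind="uniform": a, b are the exact lower and upper bounds.
    kind="normal":  a, b are the mean and standard deviation; the bounds are
                    the q and 1 - q percentiles (q=0.01 is an assumed value).
    """
    if kind == "uniform":
        return a, b
    if kind == "normal":
        dist = NormalDist(mu=a, sigma=b)
        return dist.inv_cdf(q), dist.inv_cdf(1 - q)
    raise ValueError(f"unsupported distribution: {kind}")

print(dist_bounds("uniform", 0.9, 1.1))  # (0.9, 1.1)
lo, hi = dist_bounds("normal", 0.0, 1.0)  # about (-2.326, 2.326)
```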

property mean: Array

The unscaled mean value of the parameter’s distribution.

Returns:

The mean value, or the current value if no distribution is set.

Return type:

jnp.ndarray

with_value(value)

Return a copy of the parameter with a new unscaled value.

Parameters:

value (jnp.ndarray) – The new unscaled value to set.

Returns:

A copy of this object with value replaced.

Return type:

Parameter

shifted(amount)

Return a copy of this parameter with its mean shifted by an amount.

Parameters:

amount (jnp.ndarray) – The amount to shift by. Can be positive or negative.

Returns:

A copy of this object, shifted by amount.

Return type:

Parameter

with_distribution(distribution)

Return a copy of the parameter with a new distribution.

Parameters:

distribution (numpyro.distributions.Distribution) – The distribution to associate with this parameter.

Returns:

A copy of this object with distribution replaced.

Return type:

Parameter

Raises:

Exception – If distribution is not a numpyro Distribution.

flattened(separator='_')

Flatten self into a list of scalar Parameters.

If the parameter is already scalar, the list contains only self. Otherwise, the parameter is split (de-vectorized) into one scalar Parameter per element, if possible.

Parameters:

separator (str, optional, default='_') – Separator used for naming split parameters (e.g., name_0, name_1).

Returns:

The list of individual parameters.

Return type:

list[Parameter]

Raises:

ValueError – If any internal distributions cannot be de-vectorized.
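The naming scheme for split parameters can be sketched as follows. The helper `flatten_values` is hypothetical and works on plain Python values rather than Parameter objects. It covers only 1-D vectors, because the element ordering and naming pmrf uses for multi-dimensional parameters are not described here. Splitting the distribution is presumably where de-vectorization can fail. For example, a distribution with correlated elements, such as a multivariate normal, has no per-element equivalent.

```python
def flatten_values(name, value, separator="_"):
    """Hypothetical sketch: split a 1-D vector into (name, scalar) pairs."""
    if not isinstance(value, (list, tuple)):
        # Already scalar: a single entry, mirroring "the list contains only self".
        return [(name, value)]
    # One scalar entry per element, named <name><separator><index>.
    return [(f"{name}{separator}{i}", v) for i, v in enumerate(value)]

print(flatten_values("C", [1.0, 2.0, 3.0]))
# [('C_0', 1.0), ('C_1', 2.0), ('C_2', 3.0)]
print(flatten_values("R", 50.0))  # [('R', 50.0)]
```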

as_fixed()

Return a copy of self with fixed=True.

Returns:

The new, fixed parameter.

Return type:

Parameter

as_free()

Return a copy of self with fixed=False.

Returns:

The new, free parameter.

Return type:

Parameter

copy()

Return a shallow copy.

Returns:

A copy created via dataclasses.replace.

Return type:

Parameter
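The copy-returning methods (with_value, as_fixed, as_free, copy) follow the immutable-update pattern that `dataclasses.replace` provides. The sketch below uses a hypothetical frozen dataclass, `ParamRecord`, holding a subset of Parameter's fields. It only illustrates the pattern and is not pmrf's implementation.

```python
from dataclasses import dataclass, replace
from typing import Optional

# Hypothetical stand-in holding a subset of Parameter's fields.
@dataclass(frozen=True)
class ParamRecord:
    value: float
    fixed: bool = False
    scale: float = 1.0
    name: Optional[str] = None

p = ParamRecord(value=50.0, name="R2")
p_fixed = replace(p, fixed=True)        # analogous to as_fixed()
p_free = replace(p_fixed, fixed=False)  # analogous to as_free()
p_new = replace(p, value=75.0)          # analogous to with_value(75.0)
print(p.fixed, p_fixed.fixed, p_new.value)  # False True 75.0
```

Because `replace` is shallow, each copy shares field objects (such as a distribution) with the original. This is safe as long as those objects are treated as immutable.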

to_json()

Serialize the parameter to a JSON string.

Returns:

A JSON-formatted string containing value, distribution, fixed, scale, and name.

Return type:

str

classmethod from_json(s)

Deserialize a parameter from a JSON string.

Parameters:

s (str) – The JSON string produced by to_json().

Returns:

A reconstructed Parameter instance.

Return type:

Parameter
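A to_json / from_json pair is expected to round-trip: deserializing the output of serialization should reproduce an equal object. The sketch below shows that property with the standard `json` module and a hypothetical dataclass, `ParamData`, covering the fields listed for to_json. Storing the array as a list and encoding the distribution as a plain dict are assumptions for illustration. pmrf's actual JSON layout, in particular how distributions are encoded, may differ.

```python
import json
from dataclasses import dataclass, asdict
from typing import Optional

# Hypothetical stand-in; the array value is stored as a list and the
# distribution as a plain dict, both assumptions about the encoding.
@dataclass
class ParamData:
    value: list
    distribution: Optional[dict] = None
    fixed: bool = False
    scale: float = 1.0
    name: Optional[str] = None

def to_json(p):
    return json.dumps(asdict(p))

def from_json(s):
    return ParamData(**json.loads(s))

p = ParamData(
    value=[1.0],
    distribution={"type": "Uniform", "low": 0.9, "high": 1.1},
    scale=1e-9,
    name="L1",
)
s = to_json(p)
restored = from_json(s)
print(restored == p)  # True
```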

Parameters: