Source code for pmrf.distributions.joint

"""
A distribution composed of multiple independent distributions.
"""

from numpyro.distributions import Distribution, constraints
import numpyro.distributions as dist
import jax.numpy as jnp
import jax

from pmrf.constants import MIN_PERCENTILE, MAX_PERCENTILE

[docs] class JointDistribution(Distribution): support = constraints.real_vector reparametrized_params = [] has_rsample = True def __init__( self, distributions: list[Distribution], distribution_names: list[list[str]], param_names: list[str] ): """ A joint distribution composed of multiple independent component distributions. Parameters ---------- distributions : list[Distribution] List of numpyro distributions. distribution_names : list[list[str]] List of lists, where each sub-list contains the names of the parameters associated with the corresponding distribution. param_names : list[str] The global ordered list of all parameter names. """ self.distributions = distributions self.distribution_names = distribution_names self.param_names = param_names self.name_to_index = {name: i for i, name in enumerate(param_names)} self._validate_inputs() event_shape = (len(param_names),) batch_shape = () super().__init__(batch_shape=batch_shape, event_shape=event_shape) def _validate_inputs(self): if len(self.distributions) != len(self.distribution_names): raise ValueError("The number of distributions must match the number of name lists in distribution_names.") all_provided_names = [name for names in self.distribution_names for name in names] if set(all_provided_names) != set(self.param_names): raise ValueError("Mismatch between distribution_names and the global param_names.") for dist in self.distributions: if not hasattr(dist, "sample") or not hasattr(dist, "log_prob"): raise ValueError("All distributions must support 'sample' and 'log_prob'.") @property def num_params(self) -> int: return len(self.param_names)
[docs] def sample(self, key, sample_shape=()): values = [None] * len(self.param_names) for dist, group_names in zip(self.distributions, self.distribution_names): subkey, key = jax.random.split(key) samples = dist.sample(subkey, sample_shape) # Expand dimension if the distribution is univariate to allow [..., i] indexing if not dist.event_shape: samples = samples[..., None] for i, name in enumerate(group_names): idx = self.name_to_index[name] values[idx] = samples[..., i] return jnp.stack(values, axis=-1)
[docs] def log_prob(self, value): total_logp = 0.0 for dist, group_names in zip(self.distributions, self.distribution_names): idxs = [self.name_to_index[name] for name in group_names] group_vals = value[..., idxs] # Squeeze dimension if the distribution expects scalars if not dist.event_shape and len(group_names) == 1: group_vals = group_vals[..., 0] logp = dist.log_prob(group_vals) total_logp += logp return total_logp
[docs] def icdf(self, u): values = [None] * len(self.param_names) for dist, group_names in zip(self.distributions, self.distribution_names): idxs = [self.name_to_index[name] for name in group_names] group_u = u[..., idxs] if not dist.event_shape and len(group_names) == 1: group_u = group_u[..., 0] group_x = dist.icdf(group_u) if not dist.event_shape: group_x = group_x[..., None] for i, name in enumerate(group_names): values[self.name_to_index[name]] = group_x[..., i] return jnp.stack(values, axis=-1)
[docs] def cdf(self, x): values = [None] * len(self.param_names) for dist, group_names in zip(self.distributions, self.distribution_names): idxs = [self.name_to_index[name] for name in group_names] group_x = x[..., idxs] if not dist.event_shape and len(group_names) == 1: group_x = group_x[..., 0] group_u = dist.cdf(group_x) if not dist.event_shape: group_u = group_u[..., None] for i, name in enumerate(group_names): values[self.name_to_index[name]] = group_u[..., i] return jnp.stack(values, axis=-1)
@property def min(self) -> jnp.ndarray: """ Dynamically infer lower bounds based on the numpyro constraints of each distribution. MIN_PERCENTILE is used if the distribution is not strictly bounded. """ lower_bounds = [None] * len(self.param_names) for d, group_names in zip(self.distributions, self.distribution_names): # 1. Safely extract the lower bound for this distribution group if isinstance(d, dist.ImproperUniform): lb_val = -jnp.inf elif isinstance(d, dist.Delta): lb_val = d.v elif isinstance(d, dist.Uniform): lb_val = d.low else: # Best-effort extraction from Numpyro constraints lb_val = getattr(d.support, "lower_bound", d.icdf(MIN_PERCENTILE)) # 2. Map the extracted bounds back to the global parameter order for i, name in enumerate(group_names): idx = self.name_to_index[name] lb = lb_val # Handle cases where the bound is a vector (e.g., multivariate constraints) if isinstance(lb, (jnp.ndarray, float, int)) and jnp.ndim(lb) > 0: lb = lb[i] lower_bounds[idx] = lb return jnp.array(lower_bounds) @property def max(self) -> jnp.ndarray: """ Dynamically infer upper bounds based on the numpyro constraints of each distribution. MAX_PERCENTILE is used if the distribution is not strictly bounded. """ upper_bounds = [None] * len(self.param_names) for d, group_names in zip(self.distributions, self.distribution_names): # 1. Safely extract the upper bound for this distribution group if isinstance(d, dist.ImproperUniform): ub_val = jnp.inf elif isinstance(d, dist.Delta): ub_val = d.v elif isinstance(d, dist.Uniform): ub_val = d.high else: # Best-effort extraction from Numpyro constraints ub_val = getattr(d.support, "upper_bound", d.icdf(MAX_PERCENTILE)) # 2. Map the extracted bounds back to the global parameter order for i, name in enumerate(group_names): idx = self.name_to_index[name] ub = ub_val # Handle cases where the bound is a vector (e.g., multivariate constraints) if isinstance(ub, (jnp.ndarray, float, int)) and jnp.ndim(ub) > 0: ub = ub[i] upper_bounds[idx] = ub return jnp.array(upper_bounds) @property def bounds(self) -> tuple[jnp.ndarray, jnp.ndarray]: """ Dynamically infer bounds based on the numpyro constraints of each distribution. MIN_PERCENTILE and MAX_PERCENTILE are used for "min" and "max" if the distribution is not bounded. """ return self.min, self.max