Source code for pmrf.distributions.parameter

from numpyro.distributions import Distribution, constraints
import jax.numpy as jnp
import jax

from pmrf.parameters import ParameterGroup

class JointParameterDistribution(Distribution):
    """Joint distribution over a flat parameter vector built from parameter groups.

    Each ``ParameterGroup`` contributes the distribution stored on its ``kind``
    attribute (``"prior"`` by default) over the parameters listed in its
    ``param_names``. Values are laid out in a single vector ordered by the
    ``param_names`` passed to the constructor. Sampling, log-density, CDF and
    inverse CDF are computed group by group, then combined.

    NOTE(review): every method reads/writes group values along the last axis
    (``[..., i]``), so each group distribution presumably has event shape
    ``(len(group.param_names),)`` — confirm for single-parameter groups backed
    by scalar (event_shape ``()``) distributions.
    """

    # ``support`` must remain a constraint object rather than a method: numpyro
    # passes ``dist.support`` to ``biject_to`` and calls ``self.support(value)``
    # when validating samples. Subclasses that know a tighter support should
    # override this class attribute.
    support = constraints.real_vector
    reparametrized_params = []
    has_rsample = True

    def __init__(self, param_groups: list[ParameterGroup], param_names: list[str], kind='prior'):
        """Create the joint distribution.

        Args:
            param_groups: Groups whose ``kind`` attribute holds a distribution
                over the group's ``param_names``.
            param_names: Global parameter ordering; defines the event vector layout.
            kind: Name of the attribute on each group holding its distribution.

        Raises:
            ValueError: If a group lacks the ``kind`` distribution, the
                distribution lacks ``sample``/``log_prob``, or the group names
                a parameter not present in ``param_names``.
        """
        self.param_groups = param_groups
        self.param_names = param_names
        self.kind = kind
        self.name_to_index = {name: i for i, name in enumerate(param_names)}
        self._validate_groups()
        super().__init__(batch_shape=(), event_shape=(len(param_names),))

    def _validate_groups(self):
        """Check every group exposes a usable ``kind`` distribution over known names."""
        for group in self.param_groups:
            if not hasattr(group, self.kind):
                raise ValueError(f"Parameter group {group} has no {self.kind} defined.")
            dist: Distribution = getattr(group, self.kind)
            if not hasattr(dist, "sample") or not hasattr(dist, "log_prob"):
                raise ValueError(f"Parameter group {self.kind} must support sample and log_prob.")
            # Fail early: an unknown name would otherwise surface later as a
            # KeyError from every sample/log_prob/cdf/icdf call.
            unknown = [name for name in group.param_names if name not in self.name_to_index]
            if unknown:
                raise ValueError(f"Parameter group {group} refers to unknown parameters {unknown}.")

    def _group_dist(self, group) -> Distribution:
        """Return the distribution of ``group`` selected by ``self.kind``."""
        return getattr(group, self.kind)

    def _group_indices(self, group) -> list[int]:
        """Positions of ``group``'s parameters within the global event vector."""
        return [self.name_to_index[name] for name in group.param_names]

    @staticmethod
    def _scatter(values, idxs, group_out):
        """Write column ``i`` of ``group_out`` into ``values[idxs[i]]``."""
        for i, idx in enumerate(idxs):
            values[idx] = group_out[..., i]

    def _assemble(self, values):
        """Stack per-parameter arrays into the last axis, in ``param_names`` order.

        Raises:
            ValueError: If some parameter was not provided by any group.
        """
        missing = [name for name, v in zip(self.param_names, values) if v is None]
        if missing:
            raise ValueError(f"No parameter group provides values for {missing}.")
        return jnp.stack(values, axis=-1)

    def sample(self, key, sample_shape=()):
        """Draw samples by sampling each group independently with its own subkey.

        Returns:
            Array of shape ``sample_shape + (len(param_names),)`` when every group
            distribution has the expected event shape.
        """
        values = [None] * len(self.param_names)
        for group in self.param_groups:
            subkey, key = jax.random.split(key)
            samples = self._group_dist(group).sample(subkey, sample_shape)
            self._scatter(values, self._group_indices(group), samples)
        return self._assemble(values)

    def log_prob(self, value):
        """Sum of each group's log-density evaluated on its slice of ``value``."""
        total_logp = 0.0
        for group in self.param_groups:
            group_vals = value[..., self._group_indices(group)]
            total_logp += self._group_dist(group).log_prob(group_vals)
        return total_logp

    def icdf(self, u):
        """Map ``u`` to parameter space by applying each group's inverse CDF to its slice."""
        values = [None] * len(self.param_names)
        for group in self.param_groups:
            idxs = self._group_indices(group)
            group_x = self._group_dist(group).icdf(u[..., idxs])
            self._scatter(values, idxs, group_x)
        # Result is reshaped to ``u``'s shape, as in the original implementation.
        return self._assemble(values).reshape(u.shape)

    def cdf(self, x):
        """Apply each group's CDF to its slice of ``x`` and reassemble in global order."""
        values = [None] * len(self.param_names)
        for group in self.param_groups:
            idxs = self._group_indices(group)
            group_u = self._group_dist(group).cdf(x[..., idxs])
            self._scatter(values, idxs, group_u)
        return self._assemble(values)