Source code for pmrf.fitting._backends.numpyro

import jax
import jax.numpy as jnp
from typing import Any
import h5py

from pmrf.fitting.bayesian import BayesianFitter, BayesianResults, BayesianContext

class NumPyroResults(BayesianResults):
    """
    NumPyro: Results from ``numpyro.infer``.

    This class handles the storage and retrieval of samples generated by
    NumPyro's MCMC or Nested Sampling algorithms.
    """

    def encode_solver_results(self, group: h5py.Group):
        """
        Encode the NumPyro solver results into an HDF5 group.

        The fitters in this module store the dictionary returned by the
        sampler's ``get_samples`` as ``solver_results``. A dictionary cannot
        be written as a single HDF5 dataset, so a mapping is stored as a
        ``samples`` subgroup holding one dataset per sample site. Any other
        payload is written directly as a ``samples`` dataset.

        Parameters
        ----------
        group : h5py.Group
            The HDF5 group to write the samples to.
        """
        # Function-scope imports keep the module's import block untouched.
        from collections.abc import Mapping

        import numpy as np

        samples = self.solver_results
        if isinstance(samples, Mapping):
            samples_group = group.create_group('samples')
            for site_name, site_values in samples.items():
                # Convert explicitly to a host-side NumPy array before
                # handing the values to h5py.
                samples_group[site_name] = np.asarray(site_values)
        else:
            group['samples'] = samples

    @classmethod
    def decode_solver_results(cls, group: h5py.Group) -> Any:
        """
        Decode NumPyro solver results from an HDF5 group.

        The data is read fully into memory, so the returned object stays
        usable after the underlying HDF5 file has been closed.

        Parameters
        ----------
        group : h5py.Group
            The HDF5 group to read the samples from.

        Returns
        -------
        Any
            A dictionary mapping each sample-site name to its array when
            ``samples`` was stored as a subgroup; otherwise the contents of
            the ``samples`` dataset. Dictionary keys follow HDF5 traversal
            order, which may differ from the original insertion order.
        """
        stored = group['samples']
        if isinstance(stored, h5py.Group):
            samples = {}

            def _collect(name, node):
                # Only datasets carry sample values; intermediate groups
                # (created when a site name contains '/') are skipped.
                if isinstance(node, h5py.Dataset):
                    samples[name] = node[()]

            stored.visititems(_collect)
            return samples
        return stored[()]
class NumPyroFitter(BayesianFitter):
    """
    Common base for fitters built on the NumPyro probabilistic programming library.
    """

    def make_numpyro_model(self, ctx: BayesianContext):
        """
        Build the NumPyro model function for a fitting context.

        The returned function, when traced by NumPyro:

        1. Draws every model parameter from its prior, one sample site per
           parameter name.
        2. Evaluates the jitted feature function on the stacked draws.
        3. Draws the likelihood parameter ``sigma`` from its prior.
        4. Conditions two Normal likelihoods (real and imaginary parts of the
           prediction, sharing ``sigma``) on the measured features.

        Parameters
        ----------
        ctx : BayesianContext
            The Bayesian fitting context containing model, priors, and data.

        Returns
        -------
        callable
            A zero-argument NumPyro model function suitable for an inference
            kernel.
        """
        import numpyro
        import numpyro.distributions as dist

        # Pair every flat parameter name with its prior distribution.
        named_params = ctx.model.named_flat_params()
        ctx.make_log_prior_fn()  # return value is not used here
        prior_sites = [
            (site_name, param.distribution)
            for site_name, param in named_params.items()
        ]

        # Jit the feature function and call it once on the current values;
        # the result itself is discarded.
        self.logger.info("Compiling model and likelihood function...")
        initial_values = ctx.model.flat_param_values()
        compiled_features = jax.jit(ctx.make_feature_function())
        compiled_features(initial_values)

        # Observations are split into real and imaginary parts.
        measured_real = jnp.real(ctx.measured_features)
        measured_imag = jnp.imag(ctx.measured_features)

        def numpyro_model():
            draws = []
            for site_name, site_prior in prior_sites:
                draws.append(numpyro.sample(site_name, site_prior))
            predicted = compiled_features(jnp.stack(draws))
            predicted_real = jnp.real(predicted)
            predicted_imag = jnp.imag(predicted)

            sigma_prior = self.likelihood_params['sigma'].distribution
            sigma = numpyro.sample('sigma', sigma_prior)

            numpyro.sample('obs_real', dist.Normal(predicted_real, sigma), obs=measured_real)
            numpyro.sample('obs_imag', dist.Normal(predicted_imag, sigma), obs=measured_imag)

        return numpyro_model
class NumPyroMCMCFitter(NumPyroFitter):
    """
    NumPyro MCMC: Markov Chain Monte Carlo (MCMC) sampling using ``numpyro.infer.MCMC``.

    Defaults to using the No-U-Turn Sampler (NUTS).
    """

    def _run_algorithm(self, ctx: BayesianContext, *, kernel=None, seed: int = 42, **kwargs) -> NumPyroResults:
        """
        Executes the MCMC sampling run.

        Parameters
        ----------
        ctx : BayesianContext
            The Bayesian fitting context.
        kernel : callable, optional
            Kernel class or factory. It is invoked as ``kernel(model_fn)`` to
            build the MCMC kernel, so an already-constructed kernel instance
            is not accepted. Defaults to `numpyro.infer.NUTS`.
        seed : int, optional
            Seed for the JAX PRNG key that drives the run. Defaults to 42,
            the value that was previously hard-coded.
        **kwargs
            Additional keyword arguments passed to `numpyro.infer.MCMC`
            (e.g., `num_warmup`, `num_samples`, `num_chains`). When omitted,
            `num_warmup` defaults to 500 and `num_samples` to 1000.

        Returns
        -------
        NumPyroResults
            The results object containing the fitted model (using posterior
            means) and the MCMC samples.
        """
        from numpyro.infer import MCMC, NUTS

        if kernel is None:
            kernel = NUTS

        # Get the model parameters
        param_names = ctx.model.flat_param_names()

        # Define the numpyro model
        numpyro_model = self.make_numpyro_model(ctx)

        # Run MCMC
        self.logger.info(f'Fitting for {len(param_names)} model parameter(s)...')
        self.logger.info(f'Parameter names: {param_names}')
        rng = jax.random.PRNGKey(seed)
        kwargs.setdefault('num_warmup', 500)
        kwargs.setdefault('num_samples', 1000)
        mcmc = MCMC(kernel(numpyro_model), **kwargs)
        mcmc.run(rng)
        samples = mcmc.get_samples()

        # Posterior means, in the model's flat parameter order
        x_mean = jnp.stack([samples[param_name].mean() for param_name in param_names])
        fitted_model = ctx.model.with_params(x_mean)

        return NumPyroResults(fitted_model=fitted_model, solver_results=samples)
class NumPyroNSFitter(NumPyroFitter):
    """
    NumPyro NS: Nested sampling using ``numpyro.contrib.nested_sampling.NestedSampler``.
    """

    def _run_algorithm(
        self,
        ctx: BayesianContext,
        *,
        constructor_kwargs=None,
        terminated_kwargs=None,
        seed: int = 42,
        sample_seed: int = 3,
        num_samples: int = 1000,
    ) -> NumPyroResults:
        """
        Executes the Nested Sampling run.

        Parameters
        ----------
        ctx : BayesianContext
            The Bayesian fitting context.
        constructor_kwargs : dict, optional
            Arguments passed to the `NestedSampler` constructor.
        terminated_kwargs : dict, optional
            Arguments passed to the termination condition of the sampler
            (forwarded as ``termination_kwargs``).
        seed : int, optional
            Seed for the JAX PRNG key used by the nested sampling run.
            Defaults to 42, the value that was previously hard-coded.
        sample_seed : int, optional
            Seed for the JAX PRNG key used when drawing posterior samples
            from the finished run. Defaults to 3, the value that was
            previously hard-coded.
        num_samples : int, optional
            Number of posterior samples to draw. Defaults to 1000, the value
            that was previously hard-coded.

        Returns
        -------
        NumPyroResults
            The results object containing the fitted model (using posterior
            means) and the posterior samples.
        """
        from numpyro.contrib.nested_sampling import NestedSampler

        # Get the model parameters
        params = ctx.model.named_flat_params()
        param_names = list(params.keys())

        # Define the numpyro model
        numpyro_model = self.make_numpyro_model(ctx)

        # Run nested sampling
        self.logger.info(f'Fitting for {len(param_names)} model parameter(s)...')
        self.logger.info(f'Parameter names: {param_names}')
        ns = NestedSampler(
            numpyro_model,
            constructor_kwargs=constructor_kwargs,
            termination_kwargs=terminated_kwargs,
        )
        ns.run(jax.random.PRNGKey(seed))
        samples = ns.get_samples(jax.random.PRNGKey(sample_seed), num_samples=num_samples)

        # Posterior means, in the model's flat parameter order
        x_mean = jnp.stack([samples[param_name].mean() for param_name in param_names])
        fitted_model = ctx.model.with_params(x_mean)

        # Return the results
        return NumPyroResults(fitted_model=fitted_model, solver_results=samples)