# Source code for pmrf.fitting.fitters._numpyro

import jax
import jax.numpy as jnp
from typing import Any
import h5py

from pmrf.fitting._bayesian import BayesianFitter, BayesianResults
from pmrf.fitting.results._numpyro import NumPyroResults
from pmrf._util import explicit_kwargs
        
class NumPyroFitter(BayesianFitter):
    """Base class for Bayesian fitters whose posterior is sampled with NumPyro.

    Subclasses obtain a NumPyro model function from :meth:`_make_numpyro_model`
    and hand it to a NumPyro inference engine.
    """

    def _make_numpyro_model(self):
        """Build a zero-argument NumPyro model function for this fitting problem.

        The returned model:

        * samples one site per flat model parameter, named after the parameter
          and drawn from that parameter's ``prior``;
        * evaluates the (JIT-compiled) feature function on the stacked samples;
        * samples a noise scale ``'sigma'`` from
          ``self.likelihood_params['sigma'].prior``;
        * conditions independent Normal likelihoods (sites ``'obs_real'`` and
          ``'obs_imag'``) on the real and imaginary parts of
          ``self.measured_features``.

        Returns:
            The model callable, ready to be passed to a NumPyro kernel or sampler.

        Raises:
            KeyError: if ``self.likelihood_params`` has no ``'sigma'`` entry.
                This is raised here, at build time, rather than during sampling.
        """
        import numpyro
        import numpyro.distributions as dist

        params = self.initial_model.flat_params()
        # The return value is discarded; presumably called for its side
        # effects. NOTE(review): confirm this call is still needed here.
        self._make_log_prior_fn()

        param_names = list(params.keys())
        param_priors = [param.prior for param in params.values()]
        # Resolve the noise prior once, outside the model closure. A missing
        # 'sigma' entry then fails immediately, and the lookup is not repeated
        # every time the sampler traces the model.
        sigma_prior = self.likelihood_params['sigma'].prior

        self.logger.info("Compiling model and likelihood function...")
        feature_fn = jax.jit(self._make_feature_function())
        # Evaluate once eagerly so that compilation (and any errors) happen here.
        # The same flat_params() result fetched above is reused for this call.
        # NOTE(review): this warm-up passes whatever flat_params() returns, while
        # the model below passes a stacked jnp array. If their pytree structure,
        # shape or dtype differ, jax.jit will retrace on the first sampler call —
        # confirm.
        _y0 = feature_fn(params)

        obs_real = jnp.real(self.measured_features)
        obs_imag = jnp.imag(self.measured_features)

        def numpyro_model():
            x = jnp.stack([
                numpyro.sample(param_name, prior)
                for param_name, prior in zip(param_names, param_priors)
            ])

            y_pred = feature_fn(x)
            y_real, y_imag = jnp.real(y_pred), jnp.imag(y_pred)
            sigma = numpyro.sample('sigma', sigma_prior)

            # Real and imaginary parts share a single noise scale.
            numpyro.sample('obs_real', dist.Normal(y_real, sigma), obs=obs_real)
            numpyro.sample('obs_imag', dist.Normal(y_imag, sigma), obs=obs_imag)

        return numpyro_model

class NumPyroMCMCFitter(NumPyroFitter):
    """Bayesian fitter that samples the posterior with ``numpyro.infer.MCMC``."""

    def run(self, kernel=None, **kwargs) -> NumPyroResults:
        """Run MCMC and fit the model to the posterior parameter means.

        Args:
            kernel: NumPyro kernel class, instantiated as ``kernel(model)``.
                Defaults to ``NUTS``.
            **kwargs: Forwarded to ``numpyro.infer.MCMC``. ``num_warmup``
                defaults to 500 and ``num_samples`` to 1000.

        Returns:
            NumPyroResults whose fitted model uses the posterior mean of each
            parameter, and whose ``solver_results`` is the MCMC samples dict.
        """
        from numpyro.infer import MCMC, NUTS

        # Reassign the argument itself, rather than a new local, so the
        # frame seen by explicit_kwargs() below is unchanged.
        if kernel is None:
            kernel = NUTS

        names = self.initial_model.flat_param_names()
        model_fn = self._make_numpyro_model()

        self.logger.info(f'Fitting for {len(names)} model parameter(s)...')
        self.logger.info(f'Parameter names: {names}')

        key = jax.random.PRNGKey(42)
        for option, default in (('num_warmup', 500), ('num_samples', 1000)):
            kwargs.setdefault(option, default)

        sampler = MCMC(kernel(model_fn), **kwargs)
        sampler.run(key)
        posterior = sampler.get_samples()

        # Point estimate: posterior mean of every model parameter.
        means = jnp.stack([posterior[name].mean() for name in names])
        fitted = self.initial_model.with_flat_params(means)

        settings = self._settings(kwargs, explicit_kwargs())
        return NumPyroResults(
            measured=self.measured,
            initial_model=self.initial_model,
            fitted_model=fitted,
            solver_results=posterior,
            settings=settings,
        )
class NumPyroNSFitter(NumPyroFitter):
    """Bayesian fitter using ``numpyro.contrib.nested_sampling.NestedSampler``."""

    def run(self, *, constructor_kwargs=None, terminated_kwargs=None, **kwargs) -> NumPyroResults:
        """Run nested sampling and fit the model to the posterior parameter means.

        Args:
            constructor_kwargs: Forwarded to ``NestedSampler`` as
                ``constructor_kwargs``.
            terminated_kwargs: Forwarded to ``NestedSampler`` as
                ``termination_kwargs``.
            **kwargs: ``num_samples`` (default 1000) sets how many posterior
                samples are drawn from the nested-sampling result. No other key
                is used by the sampler; such keys are only recorded in the
                results' settings, and a warning is logged.

        Returns:
            NumPyroResults whose fitted model uses the posterior mean of each
            parameter, and whose ``solver_results`` is the dict of posterior
            samples.
        """
        from numpyro.contrib.nested_sampling import NestedSampler

        # Get the model parameters
        params = self.initial_model.flat_params()
        param_names = list(params.keys())

        # Honour a caller-supplied sample count instead of silently ignoring it.
        # The default is written back into kwargs so it is recorded in settings.
        num_samples = kwargs.setdefault('num_samples', 1000)
        ignored = sorted(key for key in kwargs if key != 'num_samples')
        if ignored:
            self.logger.warning(
                f'Ignoring unused keyword argument(s) for nested sampling: {ignored}'
            )

        # Define the numpyro model
        numpyro_model = self._make_numpyro_model()

        # Run nested sampling
        self.logger.info(f'Fitting for {len(param_names)} model parameter(s)...')
        self.logger.info(f'Parameter names: {param_names}')
        ns = NestedSampler(
            numpyro_model,
            constructor_kwargs=constructor_kwargs,
            termination_kwargs=terminated_kwargs,
        )
        ns.run(jax.random.PRNGKey(42))
        samples = ns.get_samples(jax.random.PRNGKey(3), num_samples=num_samples)

        # Posterior means
        x_mean = jnp.stack([samples[param_name].mean() for param_name in param_names])
        fitted_model = self.initial_model.with_flat_params(x_mean)

        # Return the results
        settings = self._settings(kwargs, explicit_kwargs())
        return NumPyroResults(
            measured=self.measured,
            initial_model=self.initial_model,
            fitted_model=fitted_model,
            solver_results=samples,
            settings=settings,
        )