import jax
import jax.numpy as jnp
from typing import Any
import h5py
from pmrf.fitting._bayesian import BayesianFitter, BayesianResults
from pmrf.fitting.results._numpyro import NumPyroResults
from pmrf._util import explicit_kwargs
class NumPyroFitter(BayesianFitter):
    """
    Base class for fitters that sample the posterior with NumPyro.

    Subclasses obtain a NumPyro model function from
    :meth:`_make_numpyro_model` and pass it to a NumPyro inference engine.
    """

    def _make_numpyro_model(self):
        """
        Build a zero-argument NumPyro model function for this fitting problem.

        Inside the returned model:

        * each flat model parameter is sampled from its ``prior`` under its
          own name, and the samples are stacked into one vector;
        * the vector is passed through the (JIT-compiled) feature function;
        * a shared noise scale ``sigma`` is sampled from
          ``self.likelihood_params['sigma'].prior``;
        * the real and imaginary parts of ``self.measured_features`` are
          observed under ``Normal(prediction, sigma)`` likelihoods.

        Returns
        -------
        callable
            The NumPyro model function (takes no arguments, returns ``None``).

        Raises
        ------
        KeyError
            If ``self.likelihood_params`` has no ``'sigma'`` entry.
        """
        import numpyro
        import numpyro.distributions as dist

        # Query the flat parameters once. They provide the sample-site names,
        # the priors and the input for the compilation warm-up below.
        params = self.initial_model.flat_params()
        # Called for its side effects only; the return value is not used here.
        self._make_log_prior_fn()
        param_names = list(params.keys())
        param_priors = [param.prior for param in params.values()]
        # Resolved eagerly so a missing 'sigma' entry fails here, not during sampling.
        sigma_prior = self.likelihood_params['sigma'].prior

        # Generate and JIT-compile the feature function, then call it once to
        # trigger compilation (and surface errors) before sampling starts.
        # NOTE(review): the warm-up input is ``flat_params()`` while the model
        # below passes a stacked array, so the cached trace may not be reused
        # by the sampler. Confirm whether the warm-up is still wanted.
        self.logger.info("Compiling model and likelihood function...")
        feature_fn = jax.jit(self._make_feature_function())
        feature_fn(params)

        # Split the complex observations so each part gets a real-valued likelihood.
        obs_real = jnp.real(self.measured_features)
        obs_imag = jnp.imag(self.measured_features)

        def numpyro_model():
            x = jnp.stack([
                numpyro.sample(name, prior)
                for name, prior in zip(param_names, param_priors)
            ])
            y_pred = feature_fn(x)
            y_real, y_imag = jnp.real(y_pred), jnp.imag(y_pred)
            sigma = numpyro.sample('sigma', sigma_prior)
            numpyro.sample('obs_real', dist.Normal(y_real, sigma), obs=obs_real)
            numpyro.sample('obs_imag', dist.Normal(y_imag, sigma), obs=obs_imag)

        return numpyro_model
class NumPyroMCMCFitter(NumPyroFitter):
    """
    NumPyro fitter using numpyro.infer.MCMC.
    """

    def run(self, kernel=None, **kwargs) -> NumPyroResults:
        """
        Sample the posterior with ``numpyro.infer.MCMC``.

        Parameters
        ----------
        kernel : callable, optional
            Factory called as ``kernel(model)`` to build the MCMC kernel,
            e.g. ``numpyro.infer.NUTS`` (used when ``None``).
        **kwargs
            Forwarded to ``numpyro.infer.MCMC``. ``num_warmup`` defaults to
            500 and ``num_samples`` to 1000 when not given.

        Returns
        -------
        NumPyroResults
            ``fitted_model`` is ``initial_model`` with every flat parameter
            set to its posterior mean. ``solver_results`` holds the raw
            samples returned by ``MCMC.get_samples()``.
        """
        from numpyro.infer import MCMC, NUTS

        if kernel is None:
            kernel = NUTS

        # NOTE(review): these names index the posterior samples, so they must
        # match the sample-site names used in _make_numpyro_model
        # (``flat_params().keys()``). Presumably flat_param_names() returns the
        # same list; confirm.
        param_names = self.initial_model.flat_param_names()
        numpyro_model = self._make_numpyro_model()

        self.logger.info(f'Fitting for {len(param_names)} model parameter(s)...')
        self.logger.info(f'Parameter names: {param_names}')

        # The PRNG seed is fixed, so repeated runs use the same random key.
        rng = jax.random.PRNGKey(42)
        kwargs.setdefault('num_warmup', 500)
        kwargs.setdefault('num_samples', 1000)
        mcmc = MCMC(kernel(numpyro_model), **kwargs)
        mcmc.run(rng)
        samples = mcmc.get_samples()

        # Summarise the posterior by the mean of each parameter's samples.
        x_mean = jnp.stack([samples[name].mean() for name in param_names])
        fitted_model = self.initial_model.with_flat_params(x_mean)

        # Kept directly inside ``run`` (not moved into a helper): presumably
        # explicit_kwargs() inspects the calling frame. TODO confirm.
        settings = self._settings(kwargs, explicit_kwargs())
        return NumPyroResults(
            measured=self.measured,
            initial_model=self.initial_model,
            fitted_model=fitted_model,
            solver_results=samples,
            settings=settings,
        )
class NumPyroNSFitter(NumPyroFitter):
    """
    NumPyro fitter using numpyro.contrib.nested_sampling.NestedSampler.
    """

    def run(self, *, constructor_kwargs=None, terminated_kwargs=None, **kwargs) -> NumPyroResults:
        """
        Sample the posterior with NumPyro's ``NestedSampler``.

        Parameters
        ----------
        constructor_kwargs : dict, optional
            Passed to ``NestedSampler`` as ``constructor_kwargs``.
        terminated_kwargs : dict, optional
            Passed to ``NestedSampler`` as ``termination_kwargs``.
        **kwargs
            ``num_samples`` (default 1000) sets how many posterior samples are
            drawn from the nested-sampling result. The remaining entries are
            not used by the sampler and are only recorded in the settings.

        Returns
        -------
        NumPyroResults
            ``fitted_model`` is ``initial_model`` with every flat parameter
            set to its posterior mean. ``solver_results`` holds the samples
            returned by ``NestedSampler.get_samples``.
        """
        from numpyro.contrib.nested_sampling import NestedSampler

        # Same source of names as the sample sites in _make_numpyro_model.
        params = self.initial_model.flat_params()
        param_names = list(params.keys())
        numpyro_model = self._make_numpyro_model()

        # Run nested sampling
        self.logger.info(f'Fitting for {len(param_names)} model parameter(s)...')
        self.logger.info(f'Parameter names: {param_names}')
        kwargs.setdefault('num_samples', 1000)
        ns = NestedSampler(
            numpyro_model,
            constructor_kwargs=constructor_kwargs,
            termination_kwargs=terminated_kwargs,
        )
        # Fixed PRNG seeds: 42 for the sampler run, 3 for drawing posterior samples.
        ns.run(jax.random.PRNGKey(42))
        samples = ns.get_samples(jax.random.PRNGKey(3), num_samples=kwargs['num_samples'])

        # Summarise the posterior by the mean of each parameter's samples.
        x_mean = jnp.stack([samples[name].mean() for name in param_names])
        fitted_model = self.initial_model.with_flat_params(x_mean)

        # Kept directly inside ``run``: presumably explicit_kwargs() inspects
        # the calling frame. TODO confirm.
        settings = self._settings(kwargs, explicit_kwargs())
        return NumPyroResults(
            measured=self.measured,
            initial_model=self.initial_model,
            fitted_model=fitted_model,
            solver_results=samples,
            settings=settings,
        )