# Source code for pmrf.fitting.fitters._polychord

from typing import Any
import jax
import jax.numpy as jnp
import io   
import h5py
import numpy as np

from pmrf.fitting._bayesian import BayesianFitter
from pmrf.fitting.results import AnestheticResults
from pmrf._util import time_string, explicit_kwargs
   
# For legacy imports: the old PolychordResults name is kept as an alias of
# AnestheticResults so existing user code continues to work.
PolychordResults = AnestheticResults

class PolychordFitter(BayesianFitter):
    """
    Polychord fitter using pypolychord.run. Polychord has its own license
    available at https://github.com/PolyChord/PolyChordLite.
    """

    def run(self, best_param_method='maximum-likelihood', nlive_factor=None,
            **kwargs) -> AnestheticResults:
        """Run PolyChord nested sampling on the model and return the results.

        Parameters
        ----------
        best_param_method : str, optional
            How to extract the best-fit parameter vector from the nested
            samples: 'maximum-likelihood' (default, takes the sample with
            the highest log-likelihood) or 'mean' (posterior mean). Any
            other value logs a warning and leaves the initial parameters
            unchanged.
        nlive_factor : int, optional
            If given and 'nlive' is not already in ``kwargs``, sets
            ``nlive = nlive_factor * (num_flat_params + num likelihood
            params)``.
        **kwargs
            Forwarded verbatim to ``pypolychord.run``.

        Returns
        -------
        AnestheticResults
            Results wrapping the anesthetic nested samples, the fitted
            model, and the run settings.
        """
        # Dynamic import: pypolychord is an optional heavy dependency, so
        # only require it when a run is actually requested.
        import pypolychord

        if 'nlive' not in kwargs and nlive_factor is not None:
            kwargs['nlive'] = nlive_factor * (
                self.initial_model.num_flat_params + len(self.likelihood_params)
            )

        # Get the model parameters and build LaTeX labels for each of them
        # (underscores become dots inside the \theta subscript).
        param_names = self._flat_param_names()
        dot_param_names = [name.replace('_', '.') for name in param_names]
        labeled_param_names = np.array([
            [name, f'\\theta_{{{name_replaced}}}']
            for name, name_replaced in zip(param_names, dot_param_names)
        ])

        # Generate prior and likelihood functions (as plain-numpy callables,
        # as pypolychord expects).
        x0 = np.array(self.initial_model.flat_params())
        loglikelihood_fn = self._make_log_likelihood_fn(as_numpy=True)
        prior_fn = self._make_prior_transform_fn(as_numpy=True)

        # Progress callback: log the running evidence estimate.
        dumper = lambda _live, _dead, _logweights, logZ, _logZerr: \
            self.logger.info(f'time: {time_string()} (logZ = {logZ:.2f})')

        self.logger.info(f'Fitting for {len(param_names)} parameter(s)...')
        self.logger.info(f'Parameter names: {param_names}')
        self.logger.info(f'PolyChord started at {time_string()}')

        nested_samples = pypolychord.run(
            loglikelihood_fn,
            len(param_names),
            dumper=dumper,
            prior=prior_fn,
            paramnames=labeled_param_names,
            **kwargs,
        )

        self.logger.info(f'PolyChord finished at {time_string()}')

        # Only the leading model parameters are updated; the trailing
        # likelihood (nuisance) parameters are excluded. NOTE: computing the
        # end index explicitly fixes a bug in the previous [0:-n] slice,
        # which was empty when num_likelihood_params == 0 and silently
        # skipped every update.
        num_model_params = len(param_names) - self.num_likelihood_params
        if best_param_method == 'maximum-likelihood':
            # argmax is loop-invariant: compute it once, then index each
            # parameter column with it.
            idx = jnp.argmax(nested_samples.logL.values)
            for i, param_name in enumerate(param_names[:num_model_params]):
                x0[i] = nested_samples[param_name].values[idx]
        elif best_param_method == 'mean':
            for i, param_name in enumerate(param_names[:num_model_params]):
                x0[i] = nested_samples[param_name].mean()
        else:
            # Warn once (not once per parameter) and keep the initial values.
            self.logger.warning("Unknown best parameter method. Skipping")

        fitted_model = self.initial_model.with_flat_params(x0)
        settings = self._settings(kwargs, explicit_kwargs())

        return AnestheticResults(
            measured=self.measured,
            initial_model=self.initial_model,
            fitted_model=fitted_model,
            solver_results=nested_samples,
            settings=settings,
        )