Source code for pmrf.fitting.fitters._dypolychord

from typing import Any
import jax
import jax.numpy as jnp
import io   
import h5py
import numpy as np

from pmrf.fitting._bayesian import BayesianFitter, BayesianResults
from pmrf.fitting.results import AnestheticResults
from pmrf._util import time_string, explicit_kwargs
   
# Results produced by the dyPolyChord fitter are plain anesthetic results;
# ``PolychordResults`` is simply another name for the same class.
PolychordResults = AnestheticResults

class dyPolychordFitter(BayesianFitter):
    """Bayesian fitter backed by dynamic nested sampling via dyPolyChord.

    Runs ``dyPolyChord.run_dypolychord`` on the fitter's log-likelihood and
    prior-transform functions, reads the resulting chains back with
    anesthetic, and builds a fitted model from the posterior samples.
    """

    def run(
        self,
        best_param_method='maximum-likelihood',
        nlive_init_factor=None,
        nlive_factor=None,
        **kwargs,
    ) -> AnestheticResults:
        """Run dyPolyChord and return the fitted results.

        Args:
            best_param_method: How the fitted model parameters are picked from
                the nested samples. ``'maximum-likelihood'`` uses the sample
                with the highest ``logL``. ``'mean'`` uses anesthetic's mean
                of each parameter column. Any other value logs a warning and
                keeps the initial parameters.
            nlive_init_factor: Multiplier on the total parameter count used as
                the default ``ninit`` (defaults to 1).
            nlive_factor: Multiplier on the total parameter count used as the
                default ``nlive_const`` (defaults to 25).
            **kwargs: Forwarded to ``dyPolyChord.run_dypolychord``. The
                following keys are consumed here and not forwarded:

                - ``base_dir``: output directory, default ``'chains'``.
                - ``file_root``: output file root, default ``'test'``.
                - ``settings_dict_in``: extra PolyChord settings.
                - ``dynamic_goal``: default ``1.0``.

        Returns:
            AnestheticResults wrapping the nested samples read from
            ``<base_dir>/<file_root>``.

        Raises:
            ValueError: If ``nlive_const`` is not greater than ``ninit``.
        """
        # Dynamic imports: dyPolyChord and anesthetic are only loaded when a
        # fit is actually run.
        import os
        import dyPolyChord
        import dyPolyChord.pypolychord_utils
        from anesthetic import read_chains

        if nlive_init_factor is None:
            nlive_init_factor = 1
        if nlive_factor is None:
            nlive_factor = 25

        num_params = self.initial_model.num_flat_params + len(self.likelihood_params)
        param_names = self._flat_param_names()

        # PolyChord settings. Explicit ``base_dir``/``file_root`` kwargs win
        # over values in ``settings_dict_in``. Both always end up in the dict
        # so dyPolyChord and ``read_chains`` agree on the output location.
        settings_dict_in = dict(kwargs.pop('settings_dict_in', None) or {})
        base_dir = kwargs.pop('base_dir', settings_dict_in.get('base_dir', 'chains'))
        file_root = kwargs.pop('file_root', settings_dict_in.get('file_root', 'test'))
        settings_dict_in['base_dir'] = base_dir
        settings_dict_in['file_root'] = file_root
        dynamic_goal = kwargs.pop('dynamic_goal', 1.0)

        kwargs.setdefault('ninit', nlive_init_factor * num_params)
        kwargs.setdefault('nlive_const', nlive_factor * num_params)
        if kwargs['nlive_const'] <= kwargs['ninit']:
            raise ValueError(
                'Number of dynamic live points must be greater than number of init live points'
            )

        # Generate prior and likelihood functions.
        loglikelihood_fn = self._make_log_likelihood_fn(as_numpy=True)
        prior_fn = self._make_prior_transform_fn(as_numpy=True)

        self.logger.info(f'Fitting for {len(param_names)} parameter(s)...')
        self.logger.info(f'Parameter names: {param_names}')
        self.logger.info(f'PolyChord started at {time_string()}')

        # Callable that dyPolyChord uses to launch PolyChord runs.
        combined_fn = dyPolyChord.pypolychord_utils.RunPyPolyChord(
            loglikelihood_fn,
            prior_fn,
            len(param_names),
        )
        dyPolyChord.run_dypolychord(
            combined_fn,
            settings_dict_in=settings_dict_in,
            dynamic_goal=dynamic_goal,
            **kwargs,
        )

        nested_samples = read_chains(os.path.join(base_dir, file_root))
        self.logger.info(f'dyPolyChord finished at {time_string()}')

        x0 = self._select_best_flat_params(nested_samples, param_names, best_param_method)
        fitted_model = self.initial_model.with_flat_params(x0)
        settings = self._settings(kwargs, explicit_kwargs())
        return AnestheticResults(
            measured=self.measured,
            initial_model=self.initial_model,
            fitted_model=fitted_model,
            solver_results=nested_samples,
            settings=settings,
        )

    def _select_best_flat_params(self, nested_samples, param_names, best_param_method):
        """Return the initial flat parameters with model entries replaced by
        the values chosen from ``nested_samples`` via ``best_param_method``."""
        x0 = np.array(self.initial_model.flat_params())
        # Likelihood parameters are the trailing entries of ``param_names``.
        # Slicing by an explicit count (rather than ``[:-n]``) keeps all
        # names when there are zero likelihood parameters.
        num_model_params = len(param_names) - self.num_likelihood_params
        model_param_names = param_names[:num_model_params]

        if best_param_method == 'mean':
            for i, param_name in enumerate(model_param_names):
                x0[i] = nested_samples[param_name].mean()
        elif best_param_method == 'maximum-likelihood':
            # The max-likelihood row is the same for every parameter; find it once.
            idx = int(np.argmax(nested_samples.logL.values))
            for i, param_name in enumerate(model_param_names):
                x0[i] = nested_samples[param_name].values[idx]
        else:
            self.logger.warning("Unknown best parameter method. Skipping")
        return x0