# Source code for pmrf.fitting.backends.polychord

import io
from typing import Any
import numpy as np
import h5py
import jax.numpy as jnp

from pmrf.fitting.bayesian import BayesianFitter
from pmrf.models.model import Model
from pmrf.util import time_string

class PolyChordFitter(BayesianFitter):
    """
    Bayesian fitter using the PolyChord nested sampling algorithm.

    This backend uses ``pypolychord`` to perform slice sampling, calculating
    both the global evidence (logZ) and generating samples from the posterior
    distribution.
    """

    def execute(
        self,
        target: jnp.ndarray,
        *,
        fitted_params: str = 'maximum-likelihood',
        nlive_factor: int = 25,
        **kwargs
    ) -> tuple[Model, Any]:
        """
        Execute the PolyChord nested sampling run.

        **NB:** This method should not be called directly. Call :meth:`run`
        instead.

        Parameters
        ----------
        target : jax.numpy.ndarray
            The extracted target features to fit against.
        fitted_params : {'maximum-likelihood', 'mean'}, default='maximum-likelihood'
            How to select the final point estimates for the returned model's
            parameters from the posterior samples.
        nlive_factor : int, default=25
            A multiplier to determine the number of live points. The total
            number of live points (``nlive``) defaults to
            ``nlive_factor * num_params``.
        **kwargs
            Additional keyword arguments passed directly to ``pypolychord.run``.

        Returns
        -------
        tuple[Model, Any]
            The fitted model (with parameter groups updated to the full
            posterior) and the raw ``anesthetic.NestedSamples`` object.

        Raises
        ------
        ValueError
            If ``fitted_params`` is not one of the supported strategies.
        """
        # Dynamic imports for heavy external dependencies
        import pypolychord

        from pmrf.parameters import ParameterGroup
        from pmrf.distributions import AnestheticDistribution

        # Fail fast on an unsupported estimator instead of silently returning
        # the model's initial parameter values after an expensive sampling run.
        if fitted_params not in ('maximum-likelihood', 'mean'):
            raise ValueError(
                f"fitted_params must be 'maximum-likelihood' or 'mean', "
                f"got {fitted_params!r}"
            )

        # 1. Setup PolyChord Configuration
        kwargs.setdefault('nlive', nlive_factor * self.num_params)
        if self.output_path is not None:
            kwargs.setdefault('base_dir', f'{self.output_path}/chains')
            kwargs.setdefault('file_root', self.output_root or 'polychord')

        # 2. Map Parameter Names
        # PolyChord/anesthetic labels use the [name, latex-ish name] pair
        # convention; dots replace underscores for the display variant.
        param_names = self.model.flat_param_names() + list(self.likelihood_params.keys())
        dot_param_names = [name.replace('_', '.') for name in param_names]
        labeled_param_names = np.array(
            [[n, dn] for n, dn in zip(param_names, dot_param_names)]
        )

        # 3. Define Wrappers for Lazily Compiled JAX Functions
        def log_likelihood_np(x):
            # PolyChord expects a plain Python float, not a JAX scalar.
            return float(self.log_likelihood(x, target))

        def prior_np(u):
            # Map unit-hypercube samples through the inverse CDF; return numpy.
            return np.array(self.icdf(u))

        # Trigger JIT compilation up-front so it is excluded from the
        # sampler's per-call cost.
        theta0 = 0.5 * jnp.ones(len(param_names))
        _ = log_likelihood_np(theta0)
        _ = prior_np(theta0)

        dumper = lambda _live, _dead, _logweights, logZ, _logZerr: \
            self.logger.info(f'time: {time_string()} (logZ = {logZ:.2f})')

        # 4. Execute Nested Sampling
        self.logger.info(f'PolyChord started at {time_string()}')
        nested_samples = pypolychord.run(
            log_likelihood_np,
            len(param_names),
            dumper=dumper,
            prior=prior_np,
            paramnames=labeled_param_names,
            **kwargs
        )
        self.logger.info(f'PolyChord finished at {time_string()}')

        # 5. Determine Best Parameters
        x0 = np.array(self.model.flat_param_values())
        num_model_params = self.model.num_flat_params
        if fitted_params == 'maximum-likelihood':
            # Loop-invariant: the maximum-likelihood row is the same for every
            # parameter, so compute its index once instead of per-parameter.
            idx = int(np.argmax(nested_samples.logL.values))
        for i, param_name in enumerate(param_names[:num_model_params]):
            if fitted_params == 'mean':
                x0[i] = nested_samples[param_name].mean()
            else:
                x0[i] = nested_samples[param_name].values[idx]

        # 6. Update Model Priors with the full Posterior Distribution
        model_param_names = self.model.flat_param_names()
        param_group = ParameterGroup(
            model_param_names,
            AnestheticDistribution(nested_samples, model_param_names)
        )
        fitted_model = self.model.with_params(x0).with_param_groups(param_group)

        return fitted_model, nested_samples

    @staticmethod
    def write_results(stream: io.BytesIO, results: Any):
        """
        Encode anesthetic NestedSamples into a CSV string for serialization.
        """
        csv_str: str = results.to_csv()
        stream.write(csv_str.encode('utf-8'))

    @staticmethod
    def read_results(stream: io.BytesIO) -> Any:
        """
        Reconstruct anesthetic NestedSamples from a serialized CSV string.
        """
        from anesthetic import NestedSamples, read_csv

        csv_str = stream.read().decode('utf-8')
        return NestedSamples(read_csv(io.StringIO(csv_str)))