Source code for pmrf.fitting.backends.blackjax

import io
import time
from typing import Any

import numpy as np
import jax
import jax.numpy as jnp
from jax.extend.backend import get_backend
from tqdm.auto import tqdm

from pmrf.fitting.bayesian import BayesianFitter
from pmrf.models.model import Model
from pmrf.parameters import ParameterGroup

class BlackJAXNSFitter(BayesianFitter):
    """
    Bayesian fitter using the BlackJAX nested slice sampling algorithm.

    This backend leverages ``blackjax.nss`` for fully JAX-compiled nested
    sampling. It provides both the global evidence (logZ) and posterior
    samples, natively utilizing hardware acceleration (GPU/TPU) when
    available.
    """
    def execute(
        self,
        target: jnp.ndarray,
        *,
        fitted_params='maximum-likelihood',
        nlive_factor=None,
        num_delete=None,
        num_inner_steps=None,
        logZ_convergence: float = -3.0,
        seed: int = 0,
    ) -> tuple[Model, Any]:
        """
        Execute the nested sampling run using BlackJAX.

        **NB:** This method should not be called directly. Call :meth:`run`
        instead.

        Parameters
        ----------
        target : jax.numpy.ndarray
            The extracted target features to fit against.
        fitted_params : {'maximum-likelihood', 'mean'}, default='maximum-likelihood'
            How to select the final point estimates for the returned model's
            parameters from the posterior samples.
        nlive_factor : int, optional
            A multiplier to determine the total number of live points. The
            total number of live points (``n_live``) evaluates to
            ``nlive_factor * num_params``. Defaults to 25.
        num_delete : int, optional
            The number of dead points to replace per sampling iteration. If
            not provided, it defaults to 10% of ``n_live`` on CPU and 50% on
            GPU/TPU.
        num_inner_steps : int, optional
            The number of MCMC steps used to generate a new live point.
            Defaults to ``3 * num_params``.
        logZ_convergence : float, default=-3.0
            The convergence threshold for the log-evidence. The sampling loop
            stops when the estimated remaining evidence is less than this
            threshold.
        seed : int, default=0
            The PRNG key seed for JAX's random number generator.

        Returns
        -------
        tuple[:class:`~pmrf.models.model.Model`, Any]
            The fitted model (with parameter groups updated to the full
            posterior) and the raw ``anesthetic.NestedSamples`` object.
        """
        # Heavy third-party imports are deferred so the module can be imported
        # without blackjax/anesthetic installed.
        import blackjax
        from anesthetic import NestedSamples

        from pmrf.distributions import AnestheticDistribution

        start_time = time.time()
        rng_key = jax.random.PRNGKey(seed)

        # 1. Parameter Names Mapping
        # Model parameters come first, then likelihood hyper-parameters; the
        # point-estimate loop below relies on this ordering.
        param_names = self.model.flat_param_names() + list(self.likelihood_params.keys())
        dot_param_names = [name.replace('_', '.') for name in param_names]
        # LaTeX-style labels for anesthetic plots, e.g. '\theta_{a.b}'.
        labeled_param_names = {
            name: f'\\theta_{{{name_replaced}}}'
            for name, name_replaced in zip(param_names, dot_param_names)
        }
        d = len(param_names)
        # x0 holds the eventual point estimates; initialized to the model's
        # current flat parameter values (likelihood params are not written back).
        x0 = np.array(self.model.flat_param_values())

        # 2. Trigger Lazy Compilation & Extract Pure JAX Functions
        # We pass dummy arrays to force BaseRunner to compile the JAX graphs
        theta0 = self.model.flat_param_values()
        _ = self.log_prior(theta0)
        _ = self.log_likelihood(theta0, target)
        _ = self.icdf(theta0)

        # Grab the internal compiled JAX closures
        # NOTE(review): assumes the calls above populate these private
        # attributes on the base class — confirm against BayesianFitter.
        logprior_fn = self._log_prior_fn
        prior_fn = self._icdf_fn

        # Wrap the likelihood so it only expects theta (closing over target_features)
        def loglikelihood_fn(theta):
            return self._log_likelihood_fn(theta, target)

        # 3. Setup BlackJAX Configuration
        nlive_factor = nlive_factor if nlive_factor is not None else 25
        n_live = nlive_factor * d
        if num_delete is None:
            # Larger delete batches amortize better on accelerators.
            if get_backend().platform == 'cpu':
                self.logger.info('Running BlackJAX on the CPU')
                num_delete = int(0.1 * n_live)
            else:
                self.logger.info('Running BlackJAX on a GPU/TPU')
                num_delete = int(0.5 * n_live)
        if num_inner_steps is None:
            num_inner_steps = 3 * d

        nested_sampler = blackjax.nss(
            logprior_fn=logprior_fn,
            loglikelihood_fn=loglikelihood_fn,
            num_delete=num_delete,
            num_inner_steps=num_inner_steps,
        )

        # 4. Initialize Particles
        # Draw uniform samples in the unit hypercube and map them through the
        # prior inverse-CDF to get the initial live points.
        rng_key, prior_key = jax.random.split(rng_key)
        u = jax.random.uniform(prior_key, shape=(n_live, d))
        initial_particles = jax.vmap(prior_fn)(u)

        init_fn = jax.jit(nested_sampler.init)
        step_fn = jax.jit(nested_sampler.step)
        state = init_fn(initial_particles)
        dead_points_list = []

        # 5. Execute Sampling Loop
        self.logger.info(f"Starting nested sampling with {n_live} live points and {num_delete} delete points...")
        with tqdm(desc="Sampling", unit=" dead points") as pbar:
            # Loop until the estimated evidence remaining in the live points
            # (logZ_live - logZ) drops below the convergence threshold.
            while not state.logZ_live - state.logZ < logZ_convergence:
                rng_key, step_key = jax.random.split(rng_key)
                state, dead_info = step_fn(step_key, state)
                dead_points_list.append(dead_info)
                pbar.update(num_delete)
                pbar.set_postfix(logZ=f"{state.logZ:.2f}")

        # 6. Finalize the run and package the results
        self.logger.info("Finalizing results...")
        # Merge the remaining live points into the dead-point chain.
        final_dead_points = blackjax.ns.utils.finalise(state, dead_points_list)

        # Use anesthetic to easily calculate logZ and its error
        nested_samples = NestedSamples(
            data=final_dead_points.particles,
            columns=param_names,
            logL=final_dead_points.loglikelihood,
            logL_birth=final_dead_points.loglikelihood_birth,
            labels=labeled_param_names,
        )

        end_time = time.time()
        total_time = end_time - start_time
        self.logger.info(f"Sampling finished in {total_time:.2f} seconds.")
        self.logger.info(f"Final logZ = {nested_samples.logZ()}")

        # 7. Determine Best Parameters
        # Only the model's own parameters (the first num_model_params columns)
        # are written back into x0; likelihood params keep no point estimate.
        num_model_params = self.model.num_flat_params
        for i, param_name in enumerate(param_names[0:num_model_params]):
            if fitted_params == 'mean':
                x0[i] = nested_samples[param_name].mean()
            elif fitted_params == 'maximum-likelihood':
                idx = jnp.argmax(nested_samples.logL.values)
                x0[i] = nested_samples[param_name].values[idx]
            else:
                self.logger.warning("Unknown best parameter method. Skipping")

        # 8. Update Model Priors with Posterior
        # Attach the full posterior as the model's new parameter distribution.
        model_param_names = self.model.flat_param_names()
        param_group = ParameterGroup(
            model_param_names,
            AnestheticDistribution(nested_samples, model_param_names)
        )
        fitted_model = self.model.with_params(x0).with_param_groups(param_group)

        return fitted_model, nested_samples
    # --------------------------------------------------------------------------
    # HDF5 Serialization Hooks
    # (Identical to the PolyChord backend, since both return NestedSamples.)
    # --------------------------------------------------------------------------
[docs] @staticmethod def write_results(stream: io.BytesIO, results: Any): """ Encode anesthetic NestedSamples into a CSV stream for HDF5 serialization. """ csv_str: str = results.to_csv() stream.write(csv_str.encode('utf-8'))
[docs] @staticmethod def read_results(stream: io.BytesIO) -> Any: """ Reconstruct anesthetic NestedSamples from a CSV stream. """ from anesthetic import NestedSamples, read_csv csv_str = stream.read().decode('utf-8') return NestedSamples(read_csv(io.StringIO(csv_str)))