"""Source code for pmrf.fitting.bayesian."""

import dataclasses
from dataclasses import dataclass
from functools import partial
from abc import abstractmethod

import equinox as eqx
import numpy as np
import jax
import jax.numpy as jnp
import skrf
import numpyro.distributions as dist
import matplotlib.pyplot as plt

from pmrf.network_collection import NetworkCollection
from pmrf._util import RANK, wait_for_all_ranks
from pmrf.models import Model
from pmrf.parameters import Parameter, ParameterGroup, Uniform
from pmrf.distributions.trainable import TrainableDistributionT
from pmrf.fitting.base import BaseFitter, FitResults, FitContext

# Default prior for likelihood noise parameters: Uniform(0, 20e-3).
# Used as a factory — calling DefaultSigmaPrior() (optionally with name=...)
# builds a fresh Uniform parameter with these bounds.
DefaultSigmaPrior = partial(Uniform, 0.0, 20e-3)

class BayesianResults(FitResults):
    """
    Abstract base class for results obtained from a Bayesian fitting process.

    This class extends `FitResults` to include methods specific to posterior
    sampling and distribution training.
    """

    @abstractmethod
    def prior_samples(self, equal_weights=False) -> jnp.ndarray:
        """
        Retrieve samples drawn from the prior distribution.

        Parameters
        ----------
        equal_weights : bool, optional, default=False
            If True, returns unweighted (resampled) samples.

        Returns
        -------
        jnp.ndarray
            The array of prior samples.
        """
        pass

    @abstractmethod
    def posterior_samples(self, equal_weights=False) -> jnp.ndarray:
        """
        Retrieve samples drawn from the posterior distribution.

        Parameters
        ----------
        equal_weights : bool, optional, default=False
            If True, returns unweighted (resampled) samples.

        Returns
        -------
        jnp.ndarray
            The array of posterior samples.
        """
        pass

    @abstractmethod
    def weights(self) -> jnp.ndarray:
        """
        Retrieve the weights associated with the posterior samples.

        Returns
        -------
        jnp.ndarray
            Array of sample weights.
        """
        pass

    def fit_posterior(self, train_dist: TrainableDistributionT | None = None, equal_weights=False,
                      drift_sigma=0.0, boost_method=None, boost_samples=10000, **train_kwargs):
        """
        Fit a trainable distribution to the posterior samples.

        Parameters
        ----------
        train_dist : TrainableDistributionT or None, optional
            The distribution class to train. If None, defaults to
            `MargarineMAFDistribution`.
        equal_weights : bool, optional, default=False
            If True, uses equal weights for training; otherwise uses sample weights.
        drift_sigma : float, optional, default=0.0
            Standard deviation for drift augmentation to broaden the posterior
            support. NOTE: `boost_method` is only honoured when this is non-zero.
        boost_method : str or None, optional
            Method to boost sample count ('kde' or None).
        boost_samples : int, optional, default=10000
            Number of samples to generate if boosting is enabled.
        **train_kwargs
            Additional keyword arguments passed to the distribution's training method.

        Raises
        ------
        Exception
            If `boost_method` is set to an unrecognised value.
        """
        param_names: list[str] = self.fitted_model.flat_param_names()
        # Keep only the flat model-parameter columns; any trailing columns
        # (e.g. likelihood parameters) are excluded from training data.
        training_data: jnp.ndarray = self.posterior_samples(equal_weights=equal_weights)[:, 0:self.fitted_model.num_flat_params]

        if drift_sigma != 0.0:
            if boost_method == 'kde':
                # Optionally boost the sample count with a KDE resample before
                # adding drift noise.
                from margarine.kde import KDE
                kde = KDE(training_data)
                kde.generate_kde()
                training_data = kde.sample(boost_samples)
            elif boost_method is not None:  # fixed: was `!= None`
                raise Exception('Unknown posterior training data boost method')
            # Jitter each sample proportionally to the magnitude of the
            # per-parameter mean to broaden the posterior support.
            scale = np.abs(np.mean(training_data, axis=0)) * drift_sigma
            training_data = training_data + np.random.normal(loc=0.0, scale=scale, size=training_data.shape)

        if train_dist is None:
            # Imported lazily to avoid a hard dependency at module import time.
            from pmrf.distributions import MargarineMAFDistribution
            train_dist = MargarineMAFDistribution

        # `trained_dist` (not `dist`) so the module-level
        # `numpyro.distributions as dist` import is not shadowed.
        if equal_weights:
            trained_dist = train_dist.from_samples(training_data, **train_kwargs)
        else:
            weights = self.weights()
            trained_dist = train_dist.from_weighted_samples(training_data, weights, **train_kwargs)

        # Replace the fitted model's parameter group with the trained posterior.
        param_group = ParameterGroup(param_names, trained_dist)
        self.fitted_model = self.fitted_model.with_param_groups(param_group)
@dataclass
class BayesianContext(FitContext):
    """
    Context object for Bayesian fitting, containing likelihood configurations.

    Attributes
    ----------
    likelihood_kind : str or None
        The type of likelihood function (e.g., 'gaussian', 'multivariate_gaussian').
    likelihood_params : dict[str, Parameter] or None
        Parameters governing the likelihood function (e.g., noise standard deviation).
    feature_sigmas : list[str] or None
        Mapping of feature names to likelihood parameter names for multivariate cases.
    """

    likelihood_kind: str | None = None
    likelihood_params: dict[str, Parameter] | None = None
    feature_sigmas: list[str] | None = None

    @property
    def num_params(self) -> int:
        """int: Total number of parameters (model + likelihood)."""
        return self.num_model_params + self.num_likelihood_params

    @property
    def num_model_params(self) -> int:
        """int: Number of flat parameters in the model."""
        return self.model.num_flat_params

    @property
    def num_likelihood_params(self) -> int:
        """int: Number of parameters in the likelihood function."""
        return len(self.likelihood_params)

    def likelihood_param_names(self) -> list[str]:
        """
        Get names of the likelihood parameters.

        Returns
        -------
        list of str
            The names of the likelihood parameters.
        """
        return list(self.likelihood_params.keys())

    def flat_param_names(self) -> list[str]:
        """
        Get names of all parameters (model and likelihood).

        Returns
        -------
        list of str
            Combined list of parameter names (model first, then likelihood).
        """
        return self.model_param_names() + self.likelihood_param_names()

    def make_prior_transform_fn(self, as_numpy=False):
        """
        Create the prior transform function (unit hypercube to parameter space).

        Parameters
        ----------
        as_numpy : bool, optional, default=False
            If True, returns a function compatible with NumPy arrays; otherwise JAX arrays.

        Returns
        -------
        callable
            Function transforming a unit hypercube vector `u` to parameter vector `theta`.
        """
        model_prior = self.model.distribution()
        num_model_params = self.model.num_flat_params
        num_likelihood_params = len(self.likelihood_params)

        @jax.jit
        def prior_transform_fn(u):
            # Model parameters occupy the first entries of `u`; likelihood
            # parameters follow. Each is mapped through its prior's inverse CDF.
            theta_model = model_prior.icdf(u[0:num_model_params])
            theta_likelihood = jnp.array([param.distribution.icdf(u[num_model_params:][i]) for i, param in enumerate(self.likelihood_params.values())])
            return jnp.concat((theta_model, theta_likelihood))

        if as_numpy:
            # Wrap so callers receive plain NumPy arrays instead of JAX arrays.
            prior_transform_fn_jax = prior_transform_fn
            prior_transform_fn = lambda hypercube: np.array(prior_transform_fn_jax(hypercube))

        # Trigger JIT compilation once up-front using the mid-cube point.
        self.logger.info('Compiling prior transform...')
        _prior = prior_transform_fn(jnp.array([0.5] * (num_model_params + num_likelihood_params)))
        return prior_transform_fn

    def make_log_prior_fn(self, as_numpy=False):
        """
        Create the log-prior probability density function.

        Parameters
        ----------
        as_numpy : bool, optional, default=False
            If True, returns a function compatible with NumPy arrays.

        Returns
        -------
        callable
            Function returning the log-probability of the given parameters.
        """
        model_prior = self.model.distribution()
        num_model_params = self.model.num_flat_params
        num_likelihood_params = len(self.likelihood_params)

        @jax.jit
        def logprior_fn(params: jax.Array) -> float:
            # Log-prior is the sum of the model prior and each independent
            # likelihood-parameter prior.
            logprob_model = model_prior.log_prob(params[0:num_model_params])
            logprob_likelihood = jnp.array([param.distribution.log_prob(params[num_model_params:][i]) for i, param in enumerate(self.likelihood_params.values())])
            return jnp.sum(logprob_model) + jnp.sum(logprob_likelihood)

        if as_numpy:
            logprior_fn_jax = logprior_fn
            logprior_fn = lambda x: float(logprior_fn_jax(jnp.array(x)))

        # Trigger JIT compilation once up-front.
        self.logger.info('Compiling log prior...')
        _prior = logprior_fn(jnp.array([0.5] * (num_model_params + num_likelihood_params)))
        return logprior_fn

    def make_log_likelihood_fn(self, as_numpy=False):
        """
        Create the log-likelihood function based on the context configuration.

        Parameters
        ----------
        as_numpy : bool, optional, default=False
            If True, returns a function compatible with NumPy arrays.

        Returns
        -------
        callable
            Function returning the log-likelihood of the parameters given the data.

        Raises
        ------
        Exception
            If the `likelihood_kind` is unsupported.
        """
        if self.likelihood_kind == 'gaussian':
            log_likelihood_fn = self.make_gaussian_log_likelihood_fn()
        elif self.likelihood_kind == 'multivariate_gaussian':
            log_likelihood_fn = self.make_multivariate_gaussian_log_likelihood_fn()
        else:
            raise Exception(f"Unsupported likelihood kind: {self.likelihood_kind}")

        # Warm-up point for compilation: current model parameter values plus
        # the prior mean of each likelihood parameter.
        x0 = jnp.array(list(self.model.flat_param_values()) + [param.distribution.mean for param in self.likelihood_params.values()])

        if as_numpy:
            log_likelihood_fn_jax = log_likelihood_fn
            log_likelihood_fn = lambda x: float(log_likelihood_fn_jax(jnp.array(x)))
            x0 = np.array(x0)

        # Trigger JIT compilation once up-front.
        self.logger.info(f"Compiling likelihood function...")
        _log_likelihood = log_likelihood_fn(x0)
        return log_likelihood_fn

    def make_gaussian_log_likelihood_fn(self):
        """
        Create a simple Gaussian log-likelihood function (single sigma).

        Returns
        -------
        callable
            JIT-compiled log-likelihood function.
        """
        feature_fn_jax = self.make_feature_function()

        @jax.jit
        def loglikelihood_fn(flat_params) -> float:
            # Last flat parameter is the single noise sigma; the rest are model
            # parameters.
            theta, sigma = flat_params[0:-1], flat_params[-1]
            y_pred = jnp.real(feature_fn_jax(theta))
            y_meas = jnp.real(self.measured_features)
            logL = dist.Normal(loc=y_pred, scale=sigma).log_prob(y_meas).sum()
            return logL

        return loglikelihood_fn

    def make_multivariate_gaussian_log_likelihood_fn(self):
        """
        Create a multivariate Gaussian log-likelihood function (multiple sigmas).

        Returns
        -------
        callable
            JIT-compiled log-likelihood function.
        """
        feature_fn_jax = self.make_feature_function()

        @jax.jit
        def loglikelihood_fn(flat_params) -> float:
            # Trailing flat parameters are the noise sigmas, one per likelihood
            # parameter; the rest are model parameters.
            num_sigma = len(self.likelihood_params)
            theta, sigmas = flat_params[0:-num_sigma], flat_params[-num_sigma:]
            y_pred = jnp.real(feature_fn_jax(theta))
            y_meas = jnp.real(self.measured_features)
            # Map each feature to its sigma via `feature_sigmas` name lookup
            # (static under jit, since the index list is plain Python).
            param_keys = list(self.likelihood_params.keys())
            sigma_indices = jnp.array([param_keys.index(key) for key in self.feature_sigmas])
            scales = sigmas[sigma_indices]
            logL = dist.Normal(loc=y_pred, scale=scales).log_prob(y_meas).sum()
            return logL

        return loglikelihood_fn
class BayesianFitter(BaseFitter):
    """
    A base class for Bayesian fitting methods.

    This class extends `BaseFitter` by adding the concept of a likelihood function.
    """

    def __init__(
        self,
        model: Model,
        *,
        likelihood_kind: str | None = None,
        likelihood_params: dict[str, Parameter] = None,
        feature_sigmas: list[str] | None = None,
        **kwargs
    ) -> None:
        """
        Initializes the BayesianFitter.

        Parameters
        ----------
        model : Model
            The parametric `pmrf` model to be fitted.
        likelihood_kind : str, optional
            The kind of likelihood to use. Can be either 'gaussian' or
            'multivariate_gaussian'. Defaults internally to 'gaussian' for
            one-port fits, and 'multivariate_gaussian' for greater port fits.
            For 'gaussian', a single likelihood parameter, 'sigma', is needed.
            For 'multivariate_gaussian', either multiple standard deviations
            'sigma_0', 'sigma_1', ..., 'sigma_N' may be passed, where N is the
            number of features, or an arbitrary number of arbitrarily named
            likelihood parameters may be passed, along with a list of strings
            `feature_sigmas` of size N containing the names of the likelihood
            parameters to use for each feature.
        likelihood_params : dict[str, Parameter], optional
            A dictionary of likelihood parameters to use for the likelihood function.
        feature_sigmas : list[str], optional
            A list of sigma names for each feature. Only used when
            `likelihood_kind` is 'multivariate_gaussian'.
        **kwargs
            Additional arguments forwarded to :class:`BaseFitter`.
        """
        self.likelihood_kind = likelihood_kind
        self.likelihood_params = likelihood_params
        self.feature_sigmas = feature_sigmas
        super().__init__(model, **kwargs)

    def _create_context(self, measured, *, likelihood_kind=None, likelihood_params=None, feature_sigmas=None, **kwargs) -> BayesianContext:
        """
        Create a BayesianContext for the fitting process.

        Parameters
        ----------
        measured : skrf.Network or NetworkCollection or str
            The measured data (a path string is loaded as an `skrf.Network`).
        likelihood_kind : str, optional
            Override for the likelihood kind.
        likelihood_params : dict, optional
            Override for the likelihood parameters.
        feature_sigmas : list, optional
            Override for the feature sigmas.
        **kwargs
            Additional context arguments (e.g., `features`, `sparam_kind`).

        Returns
        -------
        BayesianContext
            The configured Bayesian context.

        Raises
        ------
        Exception
            If feature sigmas are missing for multivariate likelihoods or if
            multiple params are passed for a simple gaussian.
        """
        # Call-level overrides fall back to instance-level settings.
        # NOTE: `or` also treats falsy values (e.g. an empty dict) as "unset",
        # matching existing behaviour.
        features = kwargs.pop('features', None) or self.features
        sparam_kind = kwargs.pop('sparam_kind', None) or self.sparam_kind
        likelihood_kind = likelihood_kind or self.likelihood_kind
        likelihood_params = likelihood_params or self.likelihood_params
        feature_sigmas = feature_sigmas or self.feature_sigmas

        if isinstance(measured, str):
            measured = skrf.Network(measured)

        is_two_port = all([ntwk.nports == 2 for ntwk in measured]) if isinstance(measured, NetworkCollection) else measured.nports == 2
        is_one_port = all([ntwk.nports == 1 for ntwk in measured]) if isinstance(measured, NetworkCollection) else measured.nports == 1

        # A one-port device only has reflection data.
        if is_one_port and sparam_kind == 'all':
            sparam_kind = 'reflection'

        # Default: multivariate gaussian when fitting all S-params on a
        # multi-port device, simple gaussian otherwise.
        likelihood_kind = likelihood_kind if likelihood_kind is not None else 'multivariate_gaussian' if sparam_kind == 'all' and not is_one_port else 'gaussian'

        # Derive default likelihood parameters / features / feature-sigma maps
        # from the port count and sparam kind.
        default_likelihood_params = default_features = default_feature_sigmas = None
        if likelihood_kind == 'gaussian':
            default_likelihood_params = {'sigma': DefaultSigmaPrior()}
            if is_two_port:
                if sparam_kind == 'all':
                    default_features = ['s11_re', 's11_im', 's12_re', 's12_im', 's21_re', 's21_im', 's22_re', 's22_im']
                elif sparam_kind == 'reflection':
                    default_features = ['s11_re', 's11_im', 's22_re', 's22_im']
                elif sparam_kind == 'transmission':
                    default_features = ['s12_re', 's12_im', 's21_re', 's21_im']
            else:
                default_features = ['s_re', 's_im']
        elif likelihood_kind == 'multivariate_gaussian':
            if is_two_port:
                # One sigma for reflection (gamma) features, one for
                # transmission (tau) features.
                default_likelihood_params = {sigma_name: DefaultSigmaPrior() for sigma_name in ['sigma_gamma', 'sigma_tau']}
                if sparam_kind == 'all':
                    default_features = ['s11_re', 's11_im', 's12_re', 's12_im', 's21_re', 's21_im', 's22_re', 's22_im']
                    default_feature_sigmas = ['sigma_gamma', 'sigma_gamma', 'sigma_tau', 'sigma_tau', 'sigma_tau', 'sigma_tau', 'sigma_gamma', 'sigma_gamma']
                else:
                    raise Exception('No need to use multivariate gaussian when fitting only transmission or reflection coefficients')
            else:
                pass

        likelihood_params = likelihood_params if likelihood_params is not None else default_likelihood_params
        features = features if features is not None else default_features
        feature_sigmas = feature_sigmas if feature_sigmas is not None else default_feature_sigmas

        base_ctx = super()._create_context(measured, features=features, sparam_kind=sparam_kind, **kwargs)

        # Validate the resolved configuration and fill any remaining defaults.
        if likelihood_kind == 'multivariate_gaussian':
            if feature_sigmas is None:
                raise Exception('feature_sigmas must be passed for multivariate Gaussian likelihoods')
            likelihood_params = likelihood_params if likelihood_params is not None else {sigma_name: DefaultSigmaPrior(name=sigma_name) for sigma_name in feature_sigmas}
        elif likelihood_kind == 'gaussian':
            if likelihood_params is not None and len(likelihood_params) > 1:
                raise Exception("A gaussian likelihood only has a single likelihood parameter 'sigma'")
            likelihood_params = likelihood_params if likelihood_params is not None else {'sigma': DefaultSigmaPrior(name='sigma')}
        else:
            raise Exception(f"Unsupported likelihood kind: {likelihood_kind}")

        return BayesianContext(
            model=base_ctx.model,
            measured=base_ctx.measured,
            frequency=base_ctx.frequency,
            features=base_ctx.features,
            measured_features=base_ctx.measured_features,
            output_path=base_ctx.output_path,
            output_root=base_ctx.output_root,
            sparam_kind=base_ctx.sparam_kind,
            likelihood_kind=likelihood_kind,
            likelihood_params=likelihood_params,
            feature_sigmas=feature_sigmas,
            logger=self.logger,
        )

    def _run_context(self, ctx: BayesianContext, plot_params=False, fit_posterior=False, fit_posterior_dist=None, fit_posterior_kwargs=None, *args, **kwargs) -> BayesianResults:
        """
        Execute the Bayesian fitting process within a context.

        Parameters
        ----------
        ctx : BayesianContext
            The execution context.
        plot_params : bool, optional, default=False
            If True, saves a plot of the parameter distributions.
        fit_posterior : bool, optional, default=False
            If True, fits a trainable distribution to the posterior samples
            after the run.
        fit_posterior_dist : TrainableDistributionT, optional
            The specific distribution class to fit to the posterior.
        fit_posterior_kwargs : dict, optional
            Arguments passed to the posterior fitting method.
        *args, **kwargs
            Additional arguments forwarded to the parent `run` method.

        Returns
        -------
        BayesianResults
            The results of the fit.
        """
        user_callback = kwargs.get('callback', None)
        fit_posterior_kwargs = fit_posterior_kwargs or {}

        def callback(results: BayesianResults):
            nonlocal user_callback
            nonlocal fit_posterior_dist
            # Only rank 0 trains the posterior distribution; other ranks wait.
            if RANK == 0:
                from pmrf.distributions import MargarineMAFDistribution
                fit_posterior_dist = fit_posterior_dist or MargarineMAFDistribution
                results.fit_posterior(fit_posterior_dist, **fit_posterior_kwargs)
            wait_for_all_ranks()
            if user_callback:
                # BUG FIX: previously invoked kwargs['callback'](results), but
                # kwargs['callback'] had been replaced with this wrapper itself,
                # causing infinite recursion whenever a user callback was given.
                user_callback(results)

        if fit_posterior:
            # Replace the user's callback with our wrapper; the wrapper chains
            # to the original after posterior fitting.
            user_callback = kwargs.pop('callback', None)
            kwargs['callback'] = callback

        results: BayesianResults = super()._run_context(ctx, *args, **kwargs)

        if plot_params:
            results.plot_params()
            plt.savefig(f'{ctx.output_path}/params.png')

        return results