# Source code for pmrf.fitting._bayesian

import numpy as np
import jax
import jax.numpy as jnp
import skrf
import numpyro.distributions as dist
import dataclasses

from pmrf.parameters import Parameter, ParameterGroup, Uniform
from pmrf.models.model import Model
from pmrf.constants import FeatureInputT
from pmrf.fitting._base import BaseFitter, FitResults


class BayesianResults(FitResults):
    """Results container returned by Bayesian fitters.

    Currently adds nothing to `FitResults`; it exists as a distinct type so
    Bayesian fit results can be told apart from other fit results.
    """

    pass
class BayesianFitter(BaseFitter):
    """A base class for Bayesian fitting methods.

    This class extends `BaseFitter` by adding the concept of a likelihood
    function. The flat parameter vector consumed by the prior, prior-transform
    and likelihood functions built here is the model's flat parameters
    followed by the likelihood parameters, in the insertion order of
    ``likelihood_params``.
    """

    def __init__(
        self,
        model: Model,
        measured: skrf.Network | dict[str, skrf.Network],
        frequency: skrf.Frequency | None = None,
        features: FeatureInputT | None = None,
        likelihood_kind: str | None = "gaussian",
        likelihood_params: dict[str, Parameter] | None = None,
        *args,
        **kwargs,
    ) -> None:
        """Initializes the BayesianFitter.

        Args:
            model (Model): The parametric `pmrf` model to be fitted.
            measured (skrf.Network | dict[str, skrf.Network]): The measured
                network data to fit the model against.
            frequency (skrf.Frequency | None, optional): The frequency axis to
                perform the fit on. Defaults to `None`.
            features (FeatureInputT | None, optional): The features to extract
                for comparison. Note that not all features make sense for all
                likelihoods, but no error checking is done for this. Defaults
                to `None`, in which case real and imaginary features for all
                model ports are used.
            likelihood_kind (str, optional): The kind of likelihood to use.
                Defaults to "gaussian" for a one-dimensional Gaussian
                likelihood with exactly one likelihood parameter (if
                ``likelihood_params`` is omitted, a parameter 'sigma' with a
                ``Uniform(0.0, 50.0e-3)`` prior is used). Can also be
                "multivariate_gaussian", in which case ``likelihood_params``
                must be given together with the keyword argument
                ``feature_sigmas``: a list of N strings, N being the number of
                features, naming for each feature the likelihood parameter
                used as its standard deviation.
            likelihood_params (dict[str, Parameter], optional): A dictionary of
                likelihood parameters to use for the likelihood function.
            *args: Forwarded to `BaseFitter.__init__`.
            **kwargs: Forwarded to `BaseFitter.__init__` after removing
                ``feature_sigmas``.

        Raises:
            ValueError: If ``likelihood_kind`` is unsupported, or if
                ``likelihood_params`` / ``feature_sigmas`` are missing or
                inconsistent with the chosen likelihood kind.
        """
        feature_sigmas = kwargs.pop('feature_sigmas', None)
        # NOTE(review): any positional *args are placed before the keyword
        # arguments by Python, so non-empty *args would collide with
        # model/measured/... depending on BaseFitter's signature — confirm.
        super().__init__(model=model, measured=measured, frequency=frequency, features=features, *args, **kwargs)

        if likelihood_kind == 'multivariate_gaussian':
            if likelihood_params is None:
                raise ValueError('Likelihood parameters must be provided for multivariate Gaussian likelihoods')
            if feature_sigmas is None:
                raise ValueError('Currently only feature_sigmas is supported for multivariate Gaussian likelihoods')
            # Fail here rather than later inside JIT tracing with a bare list.index error.
            unknown = [key for key in feature_sigmas if key not in likelihood_params]
            if unknown:
                raise ValueError(f"feature_sigmas refers to unknown likelihood parameters: {unknown}")
            self.feature_sigmas = feature_sigmas
            self.likelihood_kind = likelihood_kind
            self.likelihood_params = likelihood_params
        elif likelihood_kind == 'gaussian':
            # The Gaussian likelihood reads sigma from the last flat parameter, so
            # exactly one likelihood parameter is required (an empty dict would make
            # the last *model* parameter act as sigma).
            if likelihood_params is not None and len(likelihood_params) != 1:
                raise ValueError("A gaussian likelihood only has a single likelihood parameter 'sigma'")
            if likelihood_params is None:
                likelihood_params = {'sigma': Uniform(0.0, 50.0e-3, name='sigma')}
            self.likelihood_params = likelihood_params
            self.likelihood_kind = likelihood_kind
        else:
            raise ValueError(f"Unsupported likelihood kind: {likelihood_kind}")

    @property
    def num_params(self) -> int:
        """Total number of flat parameters (model + likelihood)."""
        return self.num_model_params + self.num_likelihood_params

    @property
    def num_model_params(self) -> int:
        """Number of flat parameters of the initial model."""
        return self.initial_model.num_flat_params

    @property
    def num_likelihood_params(self) -> int:
        """Number of likelihood parameters."""
        return len(self.likelihood_params)

    def _flat_param_names(self) -> list[str]:
        """Names of the flat parameters: model names, then likelihood parameter keys."""
        return self.initial_model.flat_param_names() + list(self.likelihood_params.keys())

    def _make_prior_transform_fn(self, as_numpy=False):
        """Build a JIT-compiled map from the unit hypercube to parameter space.

        Model parameters are mapped with the model prior's ``icdf``, and each
        likelihood parameter with its own prior's ``icdf``.

        Args:
            as_numpy: If True, wrap the function so it returns a NumPy array.

        Returns:
            The prior transform function. It is called once on a vector of
            0.5s before returning, to trigger compilation.
        """
        model_prior = self.initial_model.prior()
        num_model_params = self.num_model_params
        likelihood_priors = [param.prior for param in self.likelihood_params.values()]
        num_params = num_model_params + len(likelihood_priors)

        @jax.jit
        def prior_transform_fn(u):
            theta_model = model_prior.icdf(u[:num_model_params])
            theta_likelihood = jnp.array(
                [prior.icdf(u[num_model_params + i]) for i, prior in enumerate(likelihood_priors)]
            )
            return jnp.concatenate((theta_model, theta_likelihood))

        if as_numpy:
            prior_transform_fn_jax = prior_transform_fn
            prior_transform_fn = lambda hypercube: np.array(prior_transform_fn_jax(hypercube))

        self.logger.info('Compiling prior transform...')
        prior_transform_fn(jnp.array([0.5] * num_params))
        return prior_transform_fn

    def _make_log_prior_fn(self, as_numpy=False):
        """Build a JIT-compiled log-prior over the flat parameter vector.

        The result is the sum of the model prior's ``log_prob`` and every
        likelihood parameter prior's ``log_prob``.

        Args:
            as_numpy: If True, wrap the function so it accepts array-likes and
                returns a Python float.

        Returns:
            The log-prior function. It is called once on a vector of 0.5s
            before returning, to trigger compilation.
        """
        model_prior = self.initial_model.prior()
        num_model_params = self.num_model_params
        likelihood_priors = [param.prior for param in self.likelihood_params.values()]
        num_params = num_model_params + len(likelihood_priors)

        @jax.jit
        def logprior_fn(params: jax.Array) -> float:
            logprob_model = model_prior.log_prob(params[:num_model_params])
            logprob_likelihood = jnp.array(
                [prior.log_prob(params[num_model_params + i]) for i, prior in enumerate(likelihood_priors)]
            )
            return jnp.sum(logprob_model) + jnp.sum(logprob_likelihood)

        if as_numpy:
            logprior_fn_jax = logprior_fn
            logprior_fn = lambda x: float(logprior_fn_jax(jnp.array(x)))

        self.logger.info('Compiling log prior...')
        logprior_fn(jnp.array([0.5] * num_params))
        return logprior_fn

    def _make_log_likelihood_fn(self, as_numpy=False):
        """Build the log-likelihood for ``self.likelihood_kind`` and compile it.

        Compilation is triggered by evaluating the function once at the
        initial model parameters, followed by each likelihood parameter's
        prior mean.

        Args:
            as_numpy: If True, wrap the function so it accepts array-likes and
                returns a Python float.

        Raises:
            ValueError: If ``self.likelihood_kind`` is unsupported.
        """
        if self.likelihood_kind == 'gaussian':
            log_likelihood_fn = self._make_gaussian_log_likelihood_fn()
        elif self.likelihood_kind == 'multivariate_gaussian':
            log_likelihood_fn = self._make_multivariate_gaussian_log_likelihood_fn()
        else:
            raise ValueError(f"Unsupported likelihood kind: {self.likelihood_kind}")

        x0 = jnp.array(
            list(self.initial_model.flat_params())
            + [param.prior.mean for param in self.likelihood_params.values()]
        )
        if as_numpy:
            log_likelihood_fn_jax = log_likelihood_fn
            log_likelihood_fn = lambda x: float(log_likelihood_fn_jax(jnp.array(x)))
            x0 = np.array(x0)

        self.logger.info("Compiling likelihood function...")
        log_likelihood_fn(x0)
        return log_likelihood_fn

    def _make_gaussian_log_likelihood_fn(self):
        """Build a JIT-compiled Gaussian log-likelihood with one shared sigma.

        The last flat parameter is sigma; the rest are passed to the feature
        function. Only the real parts of the predicted and measured features
        are compared.
        """
        feature_fn_jax = self._make_feature_function()

        @jax.jit
        def loglikelihood_fn(flat_params) -> float:
            theta, sigma = flat_params[:-1], flat_params[-1]
            y_pred = jnp.real(feature_fn_jax(theta))
            y_meas = jnp.real(self.measured_features)
            return dist.Normal(loc=y_pred, scale=sigma).log_prob(y_meas).sum()

        return loglikelihood_fn

    def _make_multivariate_gaussian_log_likelihood_fn(self):
        """Build a JIT-compiled Gaussian log-likelihood with a sigma per feature.

        The trailing ``len(self.likelihood_params)`` flat parameters are the
        sigmas. ``self.feature_sigmas`` selects which of them scales each
        feature. Only the real parts of the predicted and measured features are
        compared.
        """
        feature_fn_jax = self._make_feature_function()
        param_keys = list(self.likelihood_params.keys())
        num_sigma = len(param_keys)
        # Resolve feature -> sigma indices once, outside the traced function.
        sigma_indices = jnp.array([param_keys.index(key) for key in self.feature_sigmas])

        @jax.jit
        def loglikelihood_fn(flat_params) -> float:
            # Split by length rather than `[:-num_sigma]`, which is empty when num_sigma == 0.
            num_theta = flat_params.shape[0] - num_sigma
            theta, sigmas = flat_params[:num_theta], flat_params[num_theta:]
            y_pred = jnp.real(feature_fn_jax(theta))
            y_meas = jnp.real(self.measured_features)
            # NOTE(review): assumes the feature axis is the last axis of y_pred so
            # that `scales` (one entry per feature) broadcasts — confirm.
            scales = sigmas[sigma_indices]
            return dist.Normal(loc=y_pred, scale=scales).log_prob(y_meas).sum()

        return loglikelihood_fn