from abc import abstractmethod
import numpy as np
import jax
import jax.numpy as jnp
import skrf
import numpyro.distributions as dist
import dataclasses
from pmrf.constants import FeatureInputT
from pmrf.models import Model
from pmrf.parameters import Parameter, ParameterGroup, Uniform
from pmrf.distributions.trainable import TrainableDistributionT
from pmrf.fitting._base import BaseFitter, FitResults
class BayesianResults(FitResults):
    """Abstract results container for Bayesian fitting methods.

    Extends :class:`FitResults` with access to prior and posterior samples,
    and the ability to replace the fitted model's priors with a distribution
    trained on the posterior samples.
    """

    @abstractmethod
    def prior_samples(self) -> jnp.ndarray:
        """Return samples drawn from the prior distribution."""

    @abstractmethod
    def posterior_samples(self) -> jnp.ndarray:
        """Return samples drawn from the posterior distribution."""

    def update_priors(self, train_dist: TrainableDistributionT | None = None, **train_kwargs) -> None:
        """Train a distribution on the posterior samples and install it as the
        fitted model's new prior.

        Args:
            train_dist (TrainableDistributionT | None, optional):
                The trainable distribution class to fit to the posterior samples.
                Defaults to `None`, in which case `MargarineMAFDistribution` is used.
            **train_kwargs:
                Extra keyword arguments forwarded to ``train_dist.from_samples``.
        """
        if train_dist is None:
            # Imported lazily so the margarine dependency is only required when used.
            from pmrf.distributions import MargarineMAFDistribution
            train_dist = MargarineMAFDistribution
        param_names: list[str] = self.fitted_model.flat_param_names()
        # Keep only the model-parameter columns; any trailing columns hold
        # likelihood parameters (see BayesianFitter's flat parameter layout).
        samples: jnp.ndarray = self.posterior_samples()[:, 0:self.fitted_model.num_flat_params]
        # Named `trained_dist` (not `dist`) to avoid shadowing the
        # module-level `numpyro.distributions as dist` import.
        trained_dist = train_dist.from_samples(samples, **train_kwargs)
        param_group = ParameterGroup(param_names, trained_dist)
        self.fitted_model = self.fitted_model.with_param_groups(param_group)
class BayesianFitter(BaseFitter):
    """
    A base class for Bayesian fitting methods.

    This class extends `BaseFitter` by adding the concept of a likelihood function.
    The flat parameter vector is laid out as the model's flat parameters followed
    by the likelihood parameters.
    """

    def __init__(
        self,
        model: Model,
        measured: skrf.Network | dict[str, skrf.Network],
        *args,
        frequency: skrf.Frequency | None = None,
        features: FeatureInputT | None = None,
        likelihood_kind: str | None = "gaussian",
        likelihood_params: dict[str, Parameter] | None = None,
        **kwargs
    ) -> None:
        """Initializes the BayesianFitter.

        Args:
            model (Model):
                The parametric `pmrf` model to be fitted.
            measured (skrf.Network | dict[str, skrf.Network]):
                The measured network data to fit the model against.
            frequency (skrf.Frequency | None, optional):
                The frequency axis to perform the fit on. Defaults to `None`.
            features (FeatureInputT | None, optional):
                The features to extract for comparison.
                Note that not all features make sense for all likelihoods, but no error checking is done for this.
                Defaults to `None`, in which case real and imaginary feature for all model ports are used.
            likelihood_kind (str, optional):
                The kind of likelihood to use. Defaults to "gaussian" for a one-dimensional Gaussian likelihood
                requiring a single likelihood parameter 'sigma'. Can also be "multivariate_gaussian", in which case either
                multiple standard deviations 'sigma_0', 'sigma_1', ..., 'sigma_N' may be passed, where N is the number of features,
                or an arbitrary number of arbitrarily named likelihood parameters may be passed, along with a list of strings `feature_sigmas`
                of size N containing the names of the likelihood parameters to use for each feature.
            likelihood_params (dict[str, Parameter] | None, optional):
                A dictionary of likelihood parameters to use for the likelihood function.
                Defaults to `None`, in which case uniform priors on [0, 50e-3] are created.
            train_posteriors (bool, optional):
                Whether or not model posteriors should be trained and set for the fitted model.
                NOTE(review): not consumed here — presumably forwarded to the base class
                via ``**kwargs``; confirm against `BaseFitter`.

        Raises:
            Exception: If `likelihood_kind` is unsupported, if `feature_sigmas` is
                missing for a multivariate Gaussian, or if more than one parameter
                is given for a plain Gaussian likelihood.
        """
        # `feature_sigmas` is only meaningful for multivariate Gaussian likelihoods,
        # so it is popped before forwarding the remaining kwargs to the base class.
        feature_sigmas: list[str] = kwargs.pop('feature_sigmas', None)
        super().__init__(model=model, measured=measured, frequency=frequency, features=features, *args, **kwargs)
        if likelihood_kind == 'multivariate_gaussian':
            if feature_sigmas is None:
                raise Exception('feature_sigmas must be passed for multivariate Gaussian likelihoods')
            self.feature_sigmas = feature_sigmas
            self.likelihood_kind = likelihood_kind
            # Default: one uniform sigma prior per (unique) likelihood parameter name.
            self.likelihood_params = likelihood_params if likelihood_params is not None else {
                sigma_name: Uniform(0.0, 50.0e-3, name=sigma_name) for sigma_name in feature_sigmas
            }
        elif likelihood_kind == 'gaussian':
            if likelihood_params is not None and len(likelihood_params) > 1:
                raise Exception("A gaussian likelihood only has a single likelihood parameter 'sigma'")
            self.likelihood_params = likelihood_params if likelihood_params is not None else {'sigma': Uniform(0.0, 50.0e-3, name='sigma')}
            self.likelihood_kind = likelihood_kind
        else:
            raise Exception(f"Unsupported likelihood kind: {likelihood_kind}")

    @property
    def num_params(self) -> int:
        """Total number of flat parameters (model + likelihood)."""
        return self.num_model_params + self.num_likelihood_params

    @property
    def num_model_params(self) -> int:
        """Number of flat parameters belonging to the model."""
        return self.initial_model.num_flat_params

    @property
    def num_likelihood_params(self) -> int:
        """Number of likelihood parameters (e.g. sigmas)."""
        return len(self.likelihood_params)

    def _model_param_names(self) -> list[str]:
        """Flat names of the model parameters."""
        return self.initial_model.flat_param_names()

    def _likelihood_param_names(self) -> list[str]:
        """Names of the likelihood parameters, in insertion order."""
        return list(self.likelihood_params.keys())

    def _flat_param_names(self) -> list[str]:
        """Names for the full flat vector: model parameters then likelihood parameters."""
        return self._model_param_names() + self._likelihood_param_names()

    def _make_prior_transform_fn(self, as_numpy=False):
        """Build a jitted function mapping unit-hypercube samples to parameter space.

        Args:
            as_numpy (bool, optional): If True, wrap the jitted function so it
                accepts/returns plain NumPy arrays (for NumPy-based samplers).

        Returns:
            Callable mapping a unit-hypercube vector to a flat parameter vector
            via the inverse CDFs of the model and likelihood priors.
        """
        model_prior = self.initial_model.distribution()
        # Use num_flat_params for consistency with _make_log_prior_fn
        # (the original mixed len(flat_params()) and num_flat_params).
        num_model_params = self.initial_model.num_flat_params
        num_likelihood_params = len(self.likelihood_params)

        @jax.jit
        def prior_transform_fn(u):
            theta_model = model_prior.icdf(u[0:num_model_params])
            theta_likelihood = jnp.array([
                param.distribution.icdf(u[num_model_params:][i])
                for i, param in enumerate(self.likelihood_params.values())
            ])
            # jnp.concatenate is the portable spelling (jnp.concat is a newer alias).
            return jnp.concatenate((theta_model, theta_likelihood))

        if as_numpy:
            prior_transform_fn_jax = prior_transform_fn
            prior_transform_fn = lambda hypercube: np.array(prior_transform_fn_jax(hypercube))
        self.logger.info('Compiling prior transform...')
        # Dummy evaluation to trigger JIT compilation up-front.
        _prior = prior_transform_fn(jnp.array([0.5] * (num_model_params + num_likelihood_params)))
        return prior_transform_fn

    def _make_log_prior_fn(self, as_numpy=False):
        """Build a jitted log-prior over the full flat parameter vector.

        Args:
            as_numpy (bool, optional): If True, wrap so the function accepts
                array-likes and returns a Python float.

        Returns:
            Callable computing the summed log prior probability of a flat
            parameter vector (model part + likelihood part).
        """
        model_prior = self.initial_model.distribution()
        num_model_params = self.initial_model.num_flat_params
        num_likelihood_params = len(self.likelihood_params)

        @jax.jit
        def logprior_fn(params: jax.Array) -> float:
            logprob_model = model_prior.log_prob(params[0:num_model_params])
            logprob_likelihood = jnp.array([
                param.distribution.log_prob(params[num_model_params:][i])
                for i, param in enumerate(self.likelihood_params.values())
            ])
            return jnp.sum(logprob_model) + jnp.sum(logprob_likelihood)

        if as_numpy:
            logprior_fn_jax = logprior_fn
            logprior_fn = lambda x: float(logprior_fn_jax(jnp.array(x)))
        self.logger.info('Compiling log prior...')
        # Dummy evaluation to trigger JIT compilation up-front.
        _prior = logprior_fn(jnp.array([0.5] * (num_model_params + num_likelihood_params)))
        return logprior_fn

    def _make_log_likelihood_fn(self, as_numpy=False):
        """Build the log-likelihood function for the configured likelihood kind.

        Args:
            as_numpy (bool, optional): If True, wrap so the function accepts
                array-likes and returns a Python float.

        Returns:
            Callable computing the log likelihood of a flat parameter vector.

        Raises:
            Exception: If `self.likelihood_kind` is unsupported.
        """
        if self.likelihood_kind == 'gaussian':
            log_likelihood_fn = self._make_gaussian_log_likelihood_fn()
        elif self.likelihood_kind == 'multivariate_gaussian':
            log_likelihood_fn = self._make_multivariate_gaussian_log_likelihood_fn()
        else:
            raise Exception(f"Unsupported likelihood kind: {self.likelihood_kind}")
        # Representative point (initial model params + prior means of sigmas)
        # used only to trigger JIT compilation below.
        x0 = jnp.array(
            list(self.initial_model.flat_params())
            + [param.distribution.mean for param in self.likelihood_params.values()]
        )
        if as_numpy:
            log_likelihood_fn_jax = log_likelihood_fn
            log_likelihood_fn = lambda x: float(log_likelihood_fn_jax(jnp.array(x)))
            x0 = np.array(x0)
        self.logger.info(f"Compiling likelihood function...")
        _log_likelihood = log_likelihood_fn(x0)
        return log_likelihood_fn

    def _make_gaussian_log_likelihood_fn(self):
        """Build a jitted Gaussian log-likelihood with a single shared sigma.

        The last entry of the flat parameter vector is sigma; the rest are
        model parameters.
        """
        feature_fn_jax = self._make_feature_function()

        @jax.jit
        def loglikelihood_fn(flat_params) -> float:
            theta, sigma = flat_params[0:-1], flat_params[-1]
            y_pred = jnp.real(feature_fn_jax(theta))
            y_meas = jnp.real(self.measured_features)
            return dist.Normal(loc=y_pred, scale=sigma).log_prob(y_meas).sum()

        return loglikelihood_fn

    def _make_multivariate_gaussian_log_likelihood_fn(self):
        """Build a jitted Gaussian log-likelihood with one sigma per feature.

        The trailing `num_likelihood_params` entries of the flat parameter
        vector are sigmas; `self.feature_sigmas` maps each feature to the
        name of its sigma parameter.
        """
        feature_fn_jax = self._make_feature_function()
        # These are static Python values baked in at trace time, so they are
        # hoisted out of the jitted function (behavior-identical, traced once).
        num_sigma = len(self.likelihood_params)
        param_keys = list(self.likelihood_params.keys())
        sigma_indices = jnp.array([param_keys.index(key) for key in self.feature_sigmas])

        @jax.jit
        def loglikelihood_fn(flat_params) -> float:
            theta, sigmas = flat_params[0:-num_sigma], flat_params[-num_sigma:]
            y_pred = jnp.real(feature_fn_jax(theta))
            y_meas = jnp.real(self.measured_features)
            # Per-feature scales, gathered in feature order.
            scales = sigmas[sigma_indices]
            return dist.Normal(loc=y_pred, scale=scales).log_prob(y_meas).sum()

        return loglikelihood_fn