NumPyroFitter

class NumPyroFitter(model, *, likelihood_kind=None, likelihood_params=None, feature_sigmas=None, **kwargs)

Bases: BayesianFitter

Base class for fitters utilizing the NumPyro probabilistic programming library.

This base class translates the unified ParamRF model definitions (priors and compiled likelihoods) into a dynamic NumPyro trace graph. It also provides binary serialization for the resulting MCMC and nested-sampling (NS) sample chains, which is faster and more compact than the default JSON format.

Methods

make_numpyro_model

Construct a dynamic NumPyro model function that mirrors the BayesianFitter logic.

make_numpyro_model(targets)

Construct a dynamic NumPyro model function that mirrors the BayesianFitter logic.

This method traces the prior distributions for both the physical model parameters and the likelihood noise parameters using numpyro.sample. It then concatenates them and injects the lazily compiled ParamRF log-likelihood graph directly into the trace as a numpyro.factor node.

Parameters:

targets (jax.numpy.ndarray) – The extracted measurement data to condition the likelihood on.

Returns:

A zero-argument function defining the complete NumPyro probabilistic model, ready to be passed to a sampling kernel.

Return type:

callable

static write_results(stream, results)

Save the NumPyro sample dictionary as compressed NumPy arrays.

This overrides the default JSON serialization, providing significantly faster write times and smaller file sizes for large MCMC traces.

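Parameters:
  • stream (BytesIO) – The binary stream to write the samples to.

  • results (dict) – The NumPyro sample dictionary to save.
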
static read_results(stream)

Reconstruct the NumPyro sample dictionary from the binary stream.

Parameters:

stream (BytesIO) – The binary stream to read the samples from.

Return type:

Any
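A plausible round trip for these two methods, assuming the "compressed numpy arrays" form a standard `.npz` archive written with `numpy.savez_compressed`. ParamRF's actual on-disk layout may differ; the function names below simply mirror the methods for illustration.

```python
import io

import numpy as np


def write_results(stream, results):
    # Store each sample array under its site name in a compressed .npz archive.
    np.savez_compressed(stream, **{name: np.asarray(v) for name, v in results.items()})


def read_results(stream):
    # Load the archive and rebuild a plain {site name: array} dictionary.
    with np.load(stream) as archive:
        return {name: archive[name] for name in archive.files}


rng = np.random.default_rng(0)
samples = {"params": rng.normal(size=(1000, 2)), "sigma": np.abs(rng.normal(size=1000))}

buffer = io.BytesIO()
write_results(buffer, samples)
buffer.seek(0)  # rewind before reading back
restored = read_results(buffer)
```

Storing raw binary arrays avoids converting every sample to decimal text, which is the main cost of JSON for large traces.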

cdf(theta)

Evaluate the combined cumulative distribution function (CDF).

Parameters:

theta (jax.numpy.ndarray) – The parameter values. Note that this 1D array must contain the model parameters followed sequentially by the likelihood noise parameters.

Returns:

The combined CDF probabilities, each in the interval \([0, 1]\).

Return type:

jax.numpy.ndarray

abstractmethod execute(target, **kwargs)

Implemented by subclasses to run the specific optimization algorithm.

Parameters:
  • target (jax.numpy.ndarray) – The extracted target features to fit against.

  • **kwargs – Backend-specific algorithm parameters passed down from run().

Returns:

The fitted model and the raw results from the solver.

Return type:

tuple[Model, Any]

icdf(u)

Evaluate the combined inverse cumulative distribution function (ICDF).

Parameters:

u (jax.numpy.ndarray) – The probability values. Note that this 1D array corresponds to the probabilities for the model parameters followed by the likelihood noise parameters.

Returns:

The corresponding parameter values (model parameters followed by noise parameters), obtained from the inverse CDFs of the prior distributions.

Return type:

jax.numpy.ndarray

log_likelihood(theta, target)

Evaluate the log-likelihood of the target data.

This expands a 1D parameter vector into the 2D (batched) format expected by the vmapped feature extractor, then evaluates the log probability density of the target data under the selected Gaussian or multivariate Gaussian distribution.

Parameters:
  • theta (jax.numpy.ndarray) – The concatenated 1D array containing the model parameters followed by the likelihood noise standard deviations (\(\sigma\)).

  • target (jax.numpy.ndarray) – The extracted target features (measurement data) to evaluate against.

Returns:

The scalar log-likelihood value.

Return type:

jax.numpy.ndarray
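The steps described above (split theta, batch the parameters for the vmapped extractor, then score the target) can be sketched with `jax.scipy.stats`. This is a hypothetical illustration: the straight-line `model_features`, the frequency grid and the fixed parameter count stand in for the fitter's real extractor. The multivariate variant takes an explicit covariance matrix as a simplifying assumption; how ParamRF parameterizes that covariance is not specified here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

N_MODEL_PARAMS = 2
FREQ = jnp.linspace(1.0, 10.0, 5)  # hypothetical frequency grid


def model_features(theta_2d):
    # Hypothetical vmapped extractor: a straight line over frequency per row.
    return jax.vmap(lambda p: p[0] + p[1] * FREQ)(theta_2d)


def log_likelihood(theta, target):
    # Split theta: model parameters first, noise standard deviations after.
    params, sigma = theta[:N_MODEL_PARAMS], theta[N_MODEL_PARAMS:]
    # Expand the 1D parameters to a batch of one, then drop the batch axis.
    predicted = model_features(params[None, :])[0]
    # Independent Gaussian log-density per feature, summed to a scalar.
    return jnp.sum(norm.logpdf(target, loc=predicted, scale=sigma))


def log_likelihood_mvn(theta, target, cov):
    # Multivariate Gaussian variant with an explicitly supplied covariance.
    predicted = model_features(theta[None, :N_MODEL_PARAMS])[0]
    return multivariate_normal.logpdf(target, mean=predicted, cov=cov)


theta = jnp.array([1.0, 2.0, 0.5])
target = jnp.array([3.1, 7.4, 12.2, 16.4, 21.1])
ll = log_likelihood(theta, target)
ll_mvn = log_likelihood_mvn(theta, target, cov=jnp.eye(5) * 0.5**2)
```

With a diagonal covariance of \(\sigma^2 I\), the two variants give the same value, which is a useful sanity check.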

log_prior(theta)

Evaluate the total log-prior probability.

This lazily compiles the JAX graph to sum the log-prior probabilities of both the underlying model parameters and the added likelihood noise parameters.

Parameters:

theta (jax.numpy.ndarray) – The concatenated 1D array containing the model parameters followed by the likelihood parameters.

Returns:

The scalar log-prior value.

Return type:

jax.numpy.ndarray

model_features(theta)

Extract the RF features from the model for a given set of parameters.

This function maps the parameters into the model, simulates it over the defined frequency band, and extracts the target specifications. The entire extraction pipeline is vectorized over the batch dimension and lazily compiled via jax.jit(jax.vmap(...)).

Parameters:

theta (jax.numpy.ndarray) – A 1D array of a single parameter set, or a 2D array representing a batch of parameters.

Returns:

The extracted model features. Matches the batch dimension of theta.

Return type:

jax.numpy.ndarray

Raises:

RuntimeError – If frequency or features were not provided during initialization.
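The vectorized extraction described above can be sketched with `jax.jit(jax.vmap(...))`. Everything model-specific here is hypothetical: a first-order RC low-pass response over an assumed frequency band stands in for the ParamRF simulation, and |H(f)| in dB stands in for the extracted features.

```python
import jax
import jax.numpy as jnp

FREQ = jnp.linspace(1e6, 1e9, 101)  # hypothetical frequency band in Hz


def _features_single(theta):
    # Hypothetical model: first-order RC low-pass, theta = [R (ohm), C (F)].
    r, c = theta[0], theta[1]
    h = 1.0 / (1.0 + 1j * 2.0 * jnp.pi * FREQ * r * c)
    # Hypothetical extracted feature: magnitude response in dB.
    return 20.0 * jnp.log10(jnp.abs(h))


# Vectorize over the leading batch axis, then compile.
_batched = jax.jit(jax.vmap(_features_single))


def model_features(theta):
    theta = jnp.asarray(theta)
    if theta.ndim == 1:
        # Single parameter set: run as a batch of one, then drop the batch axis.
        return _batched(theta[None, :])[0]
    return _batched(theta)


single = model_features(jnp.array([50.0, 1e-12]))
batch = model_features(jnp.tile(jnp.array([50.0, 1e-12]), (4, 1)))
```

Routing a single parameter vector through a batch of one keeps a single code path; note that `jax.jit` still compiles once per distinct batch size.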

property num_params: int

Total number of active parameters (model free parameters + likelihood noise parameters).

Type:

int

run(measured, **kwargs)

Execute the Bayesian fitting routine.

This method wraps the standard run sequence: before passing execution to the backend, it automatically resolves the target features, the likelihood kind, and the noise priors from the shape and type of the provided measurement data.

Parameters:
  • measured (str or skrf.Network or NetworkCollection) – The measurement data to condition the likelihood on.

  • **kwargs – Additional arguments forwarded to the specific backend solver.

Returns:

The fitted model and the raw results object.

Return type:

tuple[Model, FitResults]