Source code for pmrf.sampling._base

from typing import Iterator, TypeVar, Generic
from abc import ABC, abstractmethod
import logging

import numpy as np
import skrf
import jax
from jax import flatten_util
import jax.numpy as jnp
import equinox as eqx

from pmrf import extract_features, wrap
from pmrf.frequency import Frequency
from pmrf.models.model import Model
from pmrf.constants import FeatureInputT

# Type variable for the concrete model class a sampler produces. The bound is
# written as the string 'Model' (a forward reference to the `Model` class
# imported above), so samplers can be parameterized by any `Model` subclass.
ModelT = TypeVar('ModelT', bound='Model')

class BaseSampler(ABC, Generic[ModelT]):
    """
    Abstract base class for parameter sampling engines.

    This class provides the infrastructure to generate random variations of a
    `pmrf.Model` based on the prior distributions of its parameters. It
    supports generating batches of models, batches of features (e.g.,
    S-parameters), or iterating over generated models.

    Subclasses must implement `_generate_hypercube_samples` to define the
    sampling strategy (e.g., Monte Carlo, Latin Hypercube).

    Attributes
    ----------
    model : Model
        The base model defining the parameter structure and priors.
    N : int or None
        The number of samples for the current iteration context. Set by
        `range()` and cleared again when an iteration pass ends.
    logger : logging.Logger
        Logger instance.
    """

    # Raised when the sampler is iterated / measured without a prior range().
    _ITER_ERROR_MSG = 'Error: to use this class as an iterator, call e.g. enumerate (CircuitSampler.range(n))'

    def __init__(self, model: ModelT):
        """
        Initialize the sampler.

        Parameters
        ----------
        model : Model
            The model to sample from.
        """
        self.model: ModelT = model
        self.N: int | None = None
        self.logger = logging.getLogger(__name__)

    def __iter__(self) -> Iterator[ModelT]:
        """
        Iterate over generated models.

        Requires `range(N)` to have been called previously to set the sample
        count. The count is consumed by one pass: it is cleared when the
        iteration finishes, fails, or the generator is closed early.

        Yields
        ------
        Model
            A new model instance with sampled parameters.

        Raises
        ------
        RuntimeError
            If `range()` was not called before iteration.
        """
        if self.N is None:
            raise RuntimeError(self._ITER_ERROR_MSG)
        n_samples = self.N
        try:
            yield from self._models_from_params(self._generate_params(n_samples))
        finally:
            # Reset even on `break`/close so a stale count is never reused.
            self.N = None

    def __len__(self) -> int:
        """
        Return the number of samples in the current iteration context.

        Returns
        -------
        int
            Number of samples.

        Raises
        ------
        RuntimeError
            If `range()` was not called before checking length.
        """
        if self.N is None:
            raise RuntimeError(self._ITER_ERROR_MSG)
        return self.N

    def range(self, N) -> 'BaseSampler[ModelT]':
        """
        Configure the sampler for iteration over `N` samples.

        Allows the Sampler to be used as an iterable.

        Examples
        --------
        >>> for i, system in enumerate(sampler.range(10)):
        ...     print(system)

        Parameters
        ----------
        N : int
            The number of samples to generate.

        Returns
        -------
        BaseSampler
            Returns self to allow chaining into an iterator.
        """
        self.N = N
        return self

    def generate_models(self, N) -> list[ModelT]:
        """
        Generate `N` random model instances using the sampler's engine.

        Parameters
        ----------
        N : int
            The number of samples to generate.

        Returns
        -------
        list[Model]
            A list of model instances with parameters sampled from the prior.
        """
        return list(self._models_from_params(self._generate_params(N)))

    def generate_features(self, N, features: FeatureInputT, frequency: Frequency | skrf.Frequency, dont_jit=False, **kwargs) -> jax.Array:
        """
        Generate feature vectors for `N` random samples.

        This method vectorizes the feature extraction process using JAX
        (`jax.vmap`, optionally wrapped in `jax.jit`).

        Parameters
        ----------
        N : int
            The number of samples to generate.
        features : FeatureInputT
            The list or dictionary of features to extract (e.g., `['s11_db']`).
        frequency : Frequency or skrf.Frequency
            The frequency grid to evaluate the features on. An
            `skrf.Frequency` is converted with `Frequency.from_skrf`.
        dont_jit : bool, optional, default=False
            If True, disables JIT compilation for the vectorized function.
        **kwargs
            Additional keyword arguments. NOTE(review): these are currently
            accepted but not forwarded to the feature extractor.

        Returns
        -------
        jax.Array
            The array of generated features, one row per sample.
        """
        if isinstance(frequency, skrf.Frequency):
            frequency = Frequency.from_skrf(frequency)

        params_matrix = jnp.array(self._generate_params(N))
        nominal_params = self.model.flat_param_values()
        feature_fn = wrap(extract_features, self.model, frequency, features)

        self.logger.info('Compiling feature function...')
        # One eager evaluation at the model's current parameters; the result is
        # discarded. Presumably this surfaces feature-pipeline errors before the
        # batched call — NOTE(review): confirm intent.
        feature_fn(nominal_params)

        self.logger.info('Generating features.')
        vectorized_fn = jax.vmap(feature_fn)
        if not dont_jit:
            vectorized_fn = jax.jit(vectorized_fn)
        return vectorized_fn(params_matrix)

    def _models_from_params(self, params_matrix) -> Iterator[ModelT]:
        """Yield one model per row of `params_matrix`, reusing the base model's static part."""
        params, static = self.model._partition()
        # ravel_fn maps a flat parameter vector back onto the params pytree.
        _, ravel_fn = flatten_util.ravel_pytree(params)
        for row in params_matrix:
            yield eqx.combine(ravel_fn(row), static)

    def _generate_params(self, N):
        """
        Internal method to generate parameter values in physical space.

        It first generates samples in the unit hypercube [0, 1]^D and then
        maps them to physical space using the model's prior inverse CDF (PPF).

        Parameters
        ----------
        N : int
            Number of samples.

        Returns
        -------
        np.ndarray
            Matrix of parameters with shape `(N, D)`. For `N == 0` an empty
            `(0, D)` array is returned.
        """
        D = len(self.model.flat_param_values())
        if N == 0:
            # np.stack would raise on an empty list; return an empty matrix.
            return np.empty((0, D))
        U = self._generate_hypercube_samples(N, D)
        prior = self.model.distribution()
        # Rows are mapped one at a time through the prior's inverse CDF.
        return np.stack([prior.icdf(U[i, :]) for i in range(N)], axis=0)

    @abstractmethod
    def _generate_hypercube_samples(self, N, D) -> jnp.ndarray:
        """
        Generate samples in the unit hypercube [0, 1]^D.

        Parameters
        ----------
        N : int
            Number of samples.
        D : int
            Dimensionality (number of parameters).

        Returns
        -------
        jnp.ndarray
            Samples with shape `(N, D)`.
        """
        pass