Source code for pmrf.sampling._base

from typing import Iterator, TypeVar
from abc import ABC, abstractmethod
import logging

import numpy as np
import skrf
import jax
from jax import flatten_util
import jax.numpy as jnp
import equinox as eqx

from pmrf import extract_features, wrap
from pmrf.frequency import Frequency
from pmrf.models.model import Model
from pmrf.constants import FeatureInputT

# Type variable for annotations in this module; bound to the (forward-referenced)
# ``Model`` class so it only ranges over ``Model`` and its subclasses.
ModelT = TypeVar("ModelT", bound="Model")

class BaseSampler(ABC):
    """Abstract base class for samplers that draw random variants of a model.

    Subclasses implement :meth:`_generate_hypercube_samples`. This base class
    passes each hypercube row through the inverse CDF of the model's prior to
    obtain flat parameter vectors, and can turn those vectors into models
    (:meth:`generate_models`, iteration) or into features
    (:meth:`generate_features`).

    Iteration is a two-step protocol: call :meth:`range` to set the number of
    samples, then iterate, e.g.
    ``for i, model in enumerate(sampler.range(10)): ...``.
    """

    def __init__(self, model: ModelT):
        """Store the model to sample from and set up a module-level logger.

        Args:
            model: The model whose parameters are sampled.
        """
        self.model: Model = model
        # Number of samples for iterator mode; None until `range` is called.
        self.N = None
        self.logger = logging.getLogger(__name__)

    def __iter__(self) -> Iterator[Model]:
        """Yield ``self.N`` sampled models, then reset the sample count.

        Raises:
            RuntimeError: If :meth:`range` was not called first. Because this
                is a generator, the error surfaces on the first ``next()``.
        """
        if self.N is None:
            raise RuntimeError('Error: to use this class as an iterator, call e.g. enumerate (CircuitSampler.range(n))')
        yield from self._iter_models(self.N)
        # Require a fresh call to `range` before the next iteration.
        self.N = None

    def __len__(self) -> int:
        """Return the number of samples configured via :meth:`range`.

        Raises:
            RuntimeError: If :meth:`range` was not called first.
        """
        if self.N is None:
            raise RuntimeError('Error: to use this class as an iterator, call e.g. enumerate (CircuitSampler.range(n))')
        return self.N

    def range(self, N) -> 'BaseSampler':
        """Allow the sampler to be used as an iterable of ``N`` models.

        To use, call e.g. ``for i, system in enumerate(sampler.range(10))``.

        Args:
            N (int): The number of samples to generate.

        Returns:
            BaseSampler: ``self``, so the call can be used directly in a loop.
        """
        self.N = N
        return self

    def generate_models(self, N) -> list[ModelT]:
        """Generate ``N`` random models using the sampler's engine.

        To generate samples one by one instead, use the sampler in iterator
        mode via :meth:`range`.

        Args:
            N (int): The number of samples to generate.

        Returns:
            list: The ``N`` sampled models.
        """
        return list(self._iter_models(N))

    def generate_features(self, N, features: FeatureInputT, frequency: Frequency | skrf.Frequency, dont_jit=False, **kwargs) -> jnp.ndarray:
        """Sample ``N`` parameter vectors and evaluate the features for each.

        Args:
            N (int): The number of samples to generate.
            features: The feature specification forwarded to
                ``extract_features``.
            frequency: The frequency to evaluate at. An ``skrf.Frequency`` is
                converted with ``Frequency.from_skrf``.
            dont_jit (bool): If True, the vectorized feature function is not
                wrapped in ``jax.jit``. Defaults to False.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            The result of mapping the feature function over the rows of the
            sampled parameter matrix with ``jax.vmap``.
        """
        if isinstance(frequency, skrf.Frequency):
            frequency = Frequency.from_skrf(frequency)

        params_matrix = jnp.array(self._generate_params(N))
        params = self.model.flat_params()
        feature_fn = wrap(extract_features, self.model, frequency, features)

        # Evaluate once on the model's current parameters before the batched
        # run; the result is discarded.
        # NOTE(review): presumably this triggers compilation and surfaces
        # errors early -- confirm against `wrap`.
        self.logger.info('Compiling feature function...')
        _features0 = feature_fn(params)

        self.logger.info('Generating features.')
        vectorized_fn = jax.vmap(feature_fn)
        if not dont_jit:
            vectorized_fn = jax.jit(vectorized_fn)
        return vectorized_fn(params_matrix)

    def _iter_models(self, N) -> Iterator[Model]:
        """Yield ``N`` models built from freshly sampled parameter vectors.

        Shared by ``__iter__`` and ``generate_models`` so both construct
        models in the same way.
        """
        params_matrix = self._generate_params(N)
        params, static = self.model.partition()
        # Only the inverse of the flattening is needed: it maps one flat
        # parameter vector back onto the structure of `params`.
        _, unravel_fn = flatten_util.ravel_pytree(params)
        for i in range(N):
            yield eqx.combine(unravel_fn(params_matrix[i, :]), static)

    def _generate_params(self, N):
        """Return ``N`` flat parameter vectors stacked along axis 0.

        Each row of the hypercube sample matrix is passed through the inverse
        CDF of the model's prior.
        """
        params = self.model.flat_params()
        D = len(params)
        U = self._generate_hypercube_samples(N, D)
        prior = self.model.prior()
        X = [prior.icdf(U[i, :]) for i in range(N)]
        return np.stack(X, axis=0)

    @abstractmethod
    def _generate_hypercube_samples(self, N, D) -> jnp.ndarray:
        """Return the raw samples that are mapped through the prior.

        Args:
            N (int): The number of samples (rows).
            D (int): The number of flat model parameters (columns).

        Returns:
            An array indexable as ``U[i, :]`` for ``i`` in ``range(N)``. Each
            row is passed to the prior's ``icdf``, so values are presumably
            expected in the unit hypercube.
        """
        pass