AcquisitionSampler
- class AcquisitionSampler(model, **kwargs)
Bases: BaseSampler, ABC
Base class for adaptive, acquisition-based sampling (Active Learning).
This sampler implements a feedback loop that:
Starts with an initial random experimental design (LHS).
Repeatedly queries an acquisition function to find the most “informative” next points to simulate.
Updates the internal state until a budget (\(N\)) or convergence is reached.
This class is marked as expensive=True, meaning it will automatically save intermediate NumPy checkpoints for crash recovery during the loop.
Methods
run() – Execute the sampling process.
execute() – Execute the adaptive active learning loop.
update() – Evaluate new parameters and update the internal sampling state.
acquire() – Implemented by active learning backends (e.g., EqxLearnUncertaintySampler).
- Parameters:
model (Model)
- execute(N=None, *, initial_models=None, initial_factor=None, batch_size=1, max_iterations=None, key=None, **kwargs)
Execute the adaptive active learning loop.
NB: This method should not be called directly. Call run() instead.
- Parameters:
N (int, optional) – The total budget of samples. The loop stops once len(sampled_params) >= N.
initial_models (list of Model or int, optional) – The starting state. If an integer, that many points are sampled via Latin Hypercube Sampling (LHS) to start the loop.
initial_factor (int, optional) – A multiplier for the parameter count to determine the number of initial points (e.g., initial_factor=5 for a 2-param model starts with 10 points).
batch_size (int, default=1) – The number of points to propose and simulate in each iteration.
max_iterations (int, optional) – The maximum number of adaptive loops to run, regardless of the total sample count.
key (Array, optional) – A JAX random PRNG key for stochastic acquisition functions.
**kwargs – Additional arguments passed to the acquire() method.
- Returns:
A tuple whose first element is the final array of sampled physical parameters and whose second element is None.
- Return type:
tuple[jax.numpy.ndarray, None]
- Raises:
ValueError – If initialization arguments are contradictory or insufficient.
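The interplay of N, initial_factor, batch_size and max_iterations can be sketched in plain Python. This is a toy loop under stated assumptions, not the library's implementation: latin_hypercube, adaptive_loop and random_acquire are hypothetical names, points are plain tuples rather than JAX arrays, and the acquisition step is a random placeholder.

```python
import random


def latin_hypercube(n, d, rng):
    """Return n points in [0, 1)^d with one point per stratum in each dimension."""
    columns = []
    for _ in range(d):
        # One jittered point per stratum [i/n, (i+1)/n), shuffled per dimension.
        col = [(i + rng.random()) / n for i in range(n)]
        rng.shuffle(col)
        columns.append(col)
    return [tuple(col[i] for col in columns) for i in range(n)]


def adaptive_loop(acquire, d, N, *, initial_factor=5, batch_size=1,
                  max_iterations=None, seed=0):
    """Toy version of the documented feedback loop (an assumption, not library code)."""
    rng = random.Random(seed)
    sampled = latin_hypercube(initial_factor * d, d, rng)  # initial design
    iterations = 0
    while len(sampled) < N:  # budget check: stop once len(sampled) >= N
        if max_iterations is not None and iterations >= max_iterations:
            break
        proposal = acquire(batch_size, d, sampled, rng)
        if proposal is None:  # acquisition step signals convergence
            break
        sampled.extend(proposal)
        iterations += 1
    return sampled


def random_acquire(n, d, sampled, rng):
    """Placeholder acquisition: uniform random proposals in the unit hypercube."""
    return [tuple(rng.random() for _ in range(d)) for _ in range(n)]


points = adaptive_loop(random_acquire, d=2, N=16, batch_size=4)
print(len(points))  # 10 initial points + 2 batches of 4 = 18
```

In this sketch the budget is only checked between batches, so with batch_size > 1 the final count can exceed N.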
- abstractmethod acquire(N, d, *, key=None, **kwargs)
Implemented by active learning backends (e.g., EqxLearnUncertaintySampler).
Should inspect self.sampled_params and self.sampled_features, train a surrogate/acquisition function, and return N proposed points in the unit hypercube (shape (N, d)).
Return None to signal algorithmic convergence to the master loop.
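As a rough illustration of this contract, the standalone function below proposes points in the unit hypercube and returns None to signal convergence. It is only a sketch: maximin_acquire, n_candidates and tol are hypothetical names, and a simple space-filling (maximin distance) rule stands in for a trained surrogate, which a real backend would use instead.

```python
import math
import random


def maximin_acquire(n, d, sampled, *, n_candidates=256, tol=1e-3, seed=0):
    """Propose n points in [0, 1)^d that lie far from already-sampled points.

    Returns None when even the best candidate is within `tol` of an existing
    sample, which this sketch treats as convergence.
    """
    if not sampled:
        raise ValueError("expected at least one previously sampled point")
    rng = random.Random(seed)
    known = list(sampled)
    candidates = [tuple(rng.random() for _ in range(d))
                  for _ in range(n_candidates)]
    proposals = []
    for _ in range(n):
        # Score each candidate by its distance to the nearest known point.
        best = max(candidates,
                   key=lambda c: min(math.dist(c, k) for k in known))
        if min(math.dist(best, k) for k in known) < tol:
            return None
        proposals.append(best)
        known.append(best)  # later picks in the same batch avoid this one
    return proposals


batch = maximin_acquire(3, 2, [(0.5, 0.5)])
```

The returned list plays the role of the (N, d) array described above; a JAX-based backend would return an array of that shape instead of a list of tuples.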
- cdf(theta)
Evaluate the cumulative distribution function (CDF) for the model parameters.
This computes the probability \(P(X \leq \theta)\) and lazily compiles the evaluation graph using jax.jit on the first call.
- Parameters:
theta (jax.numpy.ndarray) – The parameter values to evaluate.
- Returns:
The evaluated CDF probabilities, bounded between \(0\) and \(1\).
- Return type:
jax.numpy.ndarray
- icdf(u)
Evaluate the inverse cumulative distribution function (ICDF), or quantile function.
This computes \(\theta = F^{-1}(u)\) and lazily compiles the evaluation graph using jax.jit on the first call.
- Parameters:
u (jax.numpy.ndarray) – The probability values to evaluate, typically uniformly distributed between \(0\) and \(1\).
- Returns:
The corresponding parameter values \(\theta\).
- Return type:
jax.numpy.ndarray
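The relationship between cdf() and icdf() can be illustrated with a single parameter. The sketch below assumes a normal prior purely for illustration (the actual parameter distributions are defined by the model) and uses the standard library's statistics.NormalDist rather than the sampler's own methods.

```python
from statistics import NormalDist

# Hypothetical prior for one parameter (assumed for illustration only).
prior = NormalDist(mu=50.0, sigma=5.0)

u = [0.1, 0.5, 0.9]                      # points in the unit interval
theta = [prior.inv_cdf(p) for p in u]    # icdf: probability -> parameter value
u_back = [prior.cdf(t) for t in theta]   # cdf: parameter value -> probability

print(theta)   # increasing values centred on the prior mean
print(u_back)  # recovers u up to floating-point error
```

This is the same mapping that turns points proposed in the unit hypercube into physical parameter values.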
- log_prior(theta)
Evaluate the log-prior probability \(\log P(\theta)\) of the model parameters.
The prior is calculated by summing the log-probabilities of the flat parameter vector. The JAX graph is lazily compiled on the first call.
- Parameters:
theta (jax.numpy.ndarray) – The parameter values to evaluate.
- Returns:
The scalar log-prior probability.
- Return type:
float or jax.numpy.ndarray
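A minimal sketch of the summation, assuming independent priors with one distribution per entry of the flat parameter vector. The two normal priors and the standalone log_prior function here are hypothetical stand-ins, not the model's real definitions.

```python
import math
from statistics import NormalDist

# Hypothetical independent priors, one per entry of the flat parameter vector.
priors = [NormalDist(50.0, 5.0), NormalDist(1.0, 0.1)]


def log_prior(theta):
    """Sum the per-parameter log-densities into a single scalar."""
    return sum(math.log(p.pdf(t)) for p, t in zip(priors, theta))


at_mode = log_prior([50.0, 1.0])    # both parameters at their prior means
off_mode = log_prior([60.0, 1.0])   # first parameter two sigma away
print(at_mode > off_mode)
```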
- model_features(theta)
Extract the RF features from the model for a given set of parameters.
This function maps the parameters into the model, simulates it over the defined frequency band, and extracts the target specifications. The entire extraction pipeline is vectorized over the batch dimension and lazily compiled via jax.jit(jax.vmap(...)).
- Parameters:
theta (jax.numpy.ndarray) – A 1D array of a single parameter set, or a 2D array representing a batch of parameters.
- Returns:
The extracted model features. Matches the batch dimension of theta.
- Return type:
jax.numpy.ndarray
- Raises:
RuntimeError – If frequency or features were not provided during initialization.
- static read_results(stream)
Reconstruct backend optimization or sampling results from a bytes stream.
- Parameters:
stream (BytesIO) – The open byte stream to read from.
- Returns:
The deserialized Python objects.
- Return type:
Any
- run(*, plot=None, output_path=None, load_previous=False, save_results=True, save_figures=True, **kwargs)
Execute the sampling process.
This method handles plotting setup, optional recovery from previous runs, delegates the core algorithm to the subclass’s sample() method, and packages the final results.
- Parameters:
plot (str or list[str], optional) – Specific features to plot live during the sampling loop and save as PNG images.
output_path (str, optional) – The directory where results, NumPy crash-recovery checkpoints, and figures should be saved.
load_previous (bool, default=False) – If True and output_path is provided, it will attempt to load and return an existing sample_results.hdf5 file instead of re-running the sampling.
save_results (bool, default=True) – Whether to save the final SampleResults object to an HDF5 file.
save_figures (bool, default=True) – Whether to save the generated plots to disk.
**kwargs – Additional arguments passed directly to the subclass’s sample() method.
- Returns:
The comprehensive results object containing the evaluated parameters and their corresponding extracted features.
- Return type:
SampleResults
- update(theta)
Evaluate new parameters and update the internal sampling state.
This function is central to active learning and iterative sampling backends. It extracts the features for the provided parameters, appends the data to the internal tracking arrays, updates any configured live plots, and (if expensive=True) automatically saves intermediate crash-recovery checkpoints to disk using native NumPy serialization.
- Parameters:
theta (jax.numpy.ndarray) – A 1D array of a single parameter set, or a 2D array representing a batch of parameters to simulate.
- Return type:
None
- static write_results(stream, results)
Encode backend optimization or sampling results into a bytes stream.
- Parameters:
stream (BytesIO) – The open byte stream to write to.
results (Any) – The data structure containing the backend results.
Notes
The default implementation uses jsonpickle for robust, research-grade serialization of complex Python objects.
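The stream round trip behind write_results() and read_results() can be sketched as follows. To stay dependency-free, this example swaps in the standard library's json module in place of jsonpickle, so it only handles plain JSON types; the two functions are standalone stand-ins, and the dictionary keys are made up for illustration.

```python
import io
import json


def write_results(stream, results):
    """Encode results as UTF-8 JSON into an open binary stream."""
    stream.write(json.dumps(results).encode("utf-8"))


def read_results(stream):
    """Decode whatever write_results stored in the stream."""
    return json.loads(stream.read().decode("utf-8"))


payload = {"sampled_params": [[0.1, 0.2]], "iterations": 3}  # made-up data
buffer = io.BytesIO()
write_results(buffer, payload)
buffer.seek(0)  # rewind before reading back
restored = read_results(buffer)
print(restored == payload)
```

A backend that overrides one of the two methods would presumably need to override the other so that the pair stays symmetric.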