EqxLearnUncertaintySampler

class EqxLearnUncertaintySampler(model, surrogate, cv=None, fit_kwargs=None, **kwargs)[source]

Bases: FieldSampler

Adaptive sampler that targets regions of high surrogate model uncertainty.

This sampler uses an eqx-learn surrogate model to approximate the relationship between physical parameters and RF features. The adaptive loop selects new points where the surrogate’s predictive variance is highest, with the aim of reducing the maximum Mean Absolute Error (MAE) across the design space.

It includes built-in preprocessing to filter out simulation anomalies or NaNs before training the surrogate.
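As a rough illustration of the preprocessing idea, the sketch below drops samples whose features contain NaN or infinite values before they would reach the surrogate. The helper name `drop_invalid_samples` and the list-of-lists data layout are assumptions made for this example, not the sampler's internals, and the handling of `anomaly_threshold` is not shown.

```python
import math

def drop_invalid_samples(params, features):
    """Keep only (theta, features) pairs whose feature values are all finite."""
    kept_params, kept_features = [], []
    for theta, feats in zip(params, features):
        # math.isfinite is False for NaN and for +/- infinity.
        if all(math.isfinite(f) for f in feats):
            kept_params.append(theta)
            kept_features.append(feats)
    return kept_params, kept_features

# Hypothetical data: the second and third simulations produced invalid features.
params = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
features = [[1.0, 2.0], [float("nan"), 2.5], [3.0, float("inf")]]
clean_params, clean_features = drop_invalid_samples(params, features)
```

Filtering parameters and features together keeps the two lists aligned, so the surrogate never sees a parameter row without a matching valid feature row.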

Parameters:
  • model (Model) – The base ParamRF model to sample.

  • surrogate (eqxlearn.BaseModel) – An eqx-learn model instance (e.g., a Gaussian Process) that supports predictive variance via return_var=True.

  • cv (callable, optional) – A cross-validation strategy for monitoring convergence. Defaults to a shuffled 5-fold KFold.

  • fit_kwargs (dict, optional) – Keyword arguments passed to the eqxlearn.fit routine.

  • **kwargs – Additional arguments for the FieldSampler base class.

run(anomaly_threshold=0.001, **kwargs)[source]

Execute the sampling process.

This method handles plotting setup, optional recovery from previous runs, delegates the core algorithm to the subclass’s sample() method, and packages the final results.

Parameters:
  • plot (str or list[str], optional) – Specific features to plot live during the sampling loop and save as PNG images.

  • output_path (str, optional) – The directory where results, numpy crash-checkpoints, and figures should be saved.

  • load_previous (bool, default=False) – If True and output_path is provided, the method attempts to load and return an existing sample_results.hdf5 file instead of re-running the sampling.

  • save_results (bool, default=True) – Whether to save the final SampleResults object to an HDF5 file.

  • save_figures (bool, default=True) – Whether to save the generated plots to disk.

  • **kwargs – Additional arguments passed directly to the subclass’s sample() method.

Returns:

The comprehensive results object containing the evaluated parameters and their corresponding extracted features.

Return type:

SampleResults

train_field(key=None)[source]

Train and return the scalar field (e.g., a surrogate model).

Parameters:

key (Array, optional) – JAX PRNG key for stochastic training.

Returns:

The trained field object/state.

Return type:

Any

evaluate_field(field, theta, key=None)[source]

Evaluate the scalar field at a given physical parameter point.

Parameters:
  • field (Any) – The trained field object returned by train_field().

  • theta (jax.numpy.ndarray) – A 1D array representing a point in physical parameter space.

  • key (Array, optional) – JAX PRNG key for stochastic evaluation.

Returns:

The scalar value of the field at the given point.

Return type:

float

calculate_convergence(key=None)[source]

Calculate metrics to track the algorithmic convergence.

If not overridden, the maximum value found in the grid field will be used to track stability.

Return type:

float

acquire(N, d, *, grid_sampler=None, rtol=0.01, atol=None, patience=5, window=5, num_grid_per_dim=1024, key=None, **kwargs)

Propose N new points by maximizing the learned field over a grid.

This method follows a fixed template:

  • Trains the field using current samples.

  • Generates a dense grid of points in the parameter space.

  • Evaluates the field across that grid.

  • Selects the top \(N\) points using a penalized greedy strategy to ensure spatial diversity.

  • Monitors convergence based on the field’s stability.
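The selection step of the template above can be sketched as follows. This is a minimal illustration of one possible penalized greedy strategy, not the library's actual implementation: the function name `penalized_greedy_top_n`, the linear distance penalty, and the `radius` parameter are all assumptions. Scores are assumed to be non-negative, as a predictive variance would be.

```python
import math

def penalized_greedy_top_n(points, scores, n, radius=0.2):
    """Pick n high-scoring candidates while discouraging near-duplicates.

    After each pick, the score of every candidate is multiplied by a factor
    that rises linearly from 0 (at the picked point) to 1 (at distance
    >= radius), so close neighbours of a chosen point become less attractive.
    """
    working = list(scores)
    chosen = []
    for _ in range(min(n, len(points))):
        best = max(range(len(points)), key=lambda i: working[i])
        chosen.append(best)
        for i, p in enumerate(points):
            dist = math.dist(p, points[best])
            working[i] *= min(1.0, dist / radius)
        working[best] = -math.inf  # never pick the same candidate twice
    return chosen

# Hypothetical 1D candidate grid with field values (e.g. predictive variance).
grid = [(0.0,), (0.05,), (0.5,), (1.0,)]
field = [1.0, 0.95, 0.6, 0.3]
picked = penalized_greedy_top_n(grid, field, n=2)
```

Here the second-highest raw score sits right next to the first pick, so the penalty pushes the second selection to the more distant candidate at 0.5.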

Parameters:
  • N (int) – Number of points to propose.

  • d (int) – Dimensionality of the parameter space.

  • grid_sampler (BaseSampler, optional) – An optional sampler used to generate the candidate grid. If None, Latin Hypercube Sampling is used.

  • rtol (float, default=0.01) – Relative tolerance for convergence checking.

  • atol (float, optional) – Absolute tolerance for convergence checking.

  • patience (int, default=5) – Number of iterations the metric must stay within tolerance to converge.

  • window (int, default=5) – The look-back window for calculating tolerances.

  • num_grid_per_dim (int, default=1024) – Multiplier for determining the grid density (\(K = \text{num\_grid\_per\_dim} \times d\)).

  • key (Array, optional) – JAX PRNG key.

Returns:

New proposed points in the unit hypercube, or None if converged.

Return type:

jax.numpy.ndarray or None
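One plausible reading of how rtol, atol, patience, and window interact is sketched below. The exact rule used by acquire() is not documented here, so the semantics are assumptions: the latest metric is compared with the mean of the previous `window` values, and convergence is declared after `patience` consecutive iterations within tolerance. The name `is_converged` is hypothetical.

```python
def is_converged(history, rtol=0.01, atol=None, patience=5, window=5):
    """Return True if the metric has been stable for `patience` iterations.

    A value counts as stable when it lies within atol + rtol * |ref| of ref,
    where ref is the mean of up to `window` preceding values.
    """
    atol = 0.0 if atol is None else atol
    streak = 0
    for i in range(1, len(history)):
        recent = history[max(0, i - window):i]
        ref = sum(recent) / len(recent)
        if abs(history[i] - ref) <= atol + rtol * abs(ref):
            streak += 1
        else:
            streak = 0  # any excursion resets the count
    return streak >= patience

# Hypothetical convergence metric that settles at 0.2.
history = [1.0, 0.5, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2]
stable = is_converged(history, patience=3, window=2)
too_early = is_converged(history[:6], patience=3, window=2)
```

Requiring a streak rather than a single in-tolerance value guards against stopping on a temporary plateau.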

cdf(theta)

Evaluate the cumulative distribution function (CDF) for the model parameters.

This computes the probability \(P(X \leq \theta)\) and lazily compiles the evaluation graph using jax.jit on the first call.

Parameters:

theta (jax.numpy.ndarray) – The parameter values to evaluate.

Returns:

The evaluated CDF probabilities, bounded between \(0\) and \(1\).

Return type:

jax.numpy.ndarray

execute(N=None, *, initial_models=None, initial_factor=None, batch_size=1, max_iterations=None, key=None, **kwargs)

Execute the adaptive active learning loop.

NB: This method should not be called directly. Call run() instead.

Parameters:
  • N (int, optional) – The total budget of samples. The loop stops once len(sampled_params) >= N.

  • initial_models (list of Model or int, optional) – The starting state. If an integer, that many points are sampled via Latin Hypercube Sampling (LHS) to start the loop.

  • initial_factor (int, optional) – A multiplier for the parameter count to determine the number of initial points (e.g., initial_factor=5 for a 2-param model starts with 10 points).

  • batch_size (int, default=1) – The number of points to propose and simulate in each iteration.

  • max_iterations (int, optional) – The maximum number of adaptive loops to run, regardless of the total sample count.

  • key (Array, optional) – A JAX PRNG key for stochastic acquisition functions.

  • **kwargs – Additional arguments passed to the acquire() method.

Returns:

A tuple whose first element is the final array of sampled physical parameters and whose second element is None.

Return type:

tuple[jax.numpy.ndarray, None]

Raises:

ValueError – If initialization arguments are contradictory or insufficient.
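The control flow described by these parameters can be sketched as a simple loop. This is an illustrative skeleton under stated assumptions, not the method's real code: `run_adaptive_loop` and `random_propose` are hypothetical names, uniform random draws stand in for Latin Hypercube Sampling, and the proposal callback stands in for acquire().

```python
import random

def run_adaptive_loop(n_total, d, initial_factor, propose,
                      batch_size=1, max_iterations=None, seed=0):
    """Grow a sample set until the budget, iteration cap, or convergence."""
    rng = random.Random(seed)
    # Initial design: initial_factor * d points (random stand-in for LHS).
    sampled = [[rng.random() for _ in range(d)]
               for _ in range(initial_factor * d)]
    iteration = 0
    while len(sampled) < n_total:
        if max_iterations is not None and iteration >= max_iterations:
            break
        batch = propose(sampled, batch_size, rng)
        if batch is None:  # the proposal step reports convergence
            break
        sampled.extend(batch)
        iteration += 1
    return sampled

def random_propose(sampled, batch_size, rng):
    """Placeholder acquisition: propose uniformly random points."""
    d = len(sampled[0])
    return [[rng.random() for _ in range(d)] for _ in range(batch_size)]

# A 2-parameter model with initial_factor=5 starts from 10 points.
points = run_adaptive_loop(16, d=2, initial_factor=5,
                           propose=random_propose, batch_size=4)
capped = run_adaptive_loop(16, d=2, initial_factor=5,
                           propose=random_propose, batch_size=4,
                           max_iterations=1)
```

Because the budget is only checked between batches, the final sample count can exceed N by up to batch_size - 1 points in this sketch.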

expensive: bool = True

icdf(u)

Evaluate the inverse cumulative distribution function (ICDF), or quantile function.

This computes \(\theta = F^{-1}(u)\) and lazily compiles the evaluation graph using jax.jit on the first call.

Parameters:

u (jax.numpy.ndarray) – The probability values to evaluate, typically uniformly distributed between \(0\) and \(1\).

Returns:

The corresponding parameter values \(\theta\).

Return type:

jax.numpy.ndarray
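To make the relationship between the unit hypercube and physical parameters concrete, the sketch below implements a CDF and ICDF pair for independent uniform priors. The uniform distribution, the bounds, and the function names are assumptions chosen for illustration; the actual parameter distributions depend on the model.

```python
def uniform_icdf(u, lower, upper):
    """Map probabilities in [0, 1] to parameter values, one bound pair per axis."""
    return [lo + ui * (hi - lo) for ui, lo, hi in zip(u, lower, upper)]

def uniform_cdf(theta, lower, upper):
    """Map parameter values back to probabilities, clipped to [0, 1]."""
    return [min(1.0, max(0.0, (t - lo) / (hi - lo)))
            for t, lo, hi in zip(theta, lower, upper)]

# Hypothetical bounds for a 2-parameter model.
lower, upper = [1.0, 10.0], [3.0, 20.0]
theta = uniform_icdf([0.5, 0.25], lower, upper)
u_back = uniform_cdf(theta, lower, upper)
```

Applying the CDF to the output of the ICDF recovers the original unit-hypercube coordinates, which is the round trip a sampler relies on when it proposes points in the unit hypercube and simulates them in physical units.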

log_prior(theta)

Evaluate the log-prior probability \(\log P(\theta)\) of the model parameters.

The prior is calculated by summing the log-probabilities of the flat parameter vector. The JAX graph is lazily compiled on the first call.

Parameters:

theta (jax.numpy.ndarray) – The parameter values to evaluate.

Returns:

The scalar log-prior probability.

Return type:

float or jax.numpy.ndarray
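The summation over the flat parameter vector can be illustrated with independent normal priors. The choice of normal distributions and the names `normal_logpdf` and `log_prior_sum` are assumptions for this sketch only.

```python
import math

def normal_logpdf(x, mean, std):
    """Log-density of a normal distribution at x."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)

def log_prior_sum(theta, means, stds):
    """Sum per-parameter log-probabilities, assuming independent priors."""
    return sum(normal_logpdf(t, m, s) for t, m, s in zip(theta, means, stds))

# Both parameters sit exactly at their prior means with unit variance.
value = log_prior_sum([0.0, 1.0], means=[0.0, 1.0], stds=[1.0, 1.0])
```

Summing log-probabilities is equivalent to multiplying the individual densities, but it avoids numerical underflow when many parameters are involved.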

model_features(theta)

Extract the RF features from the model for a given set of parameters.

This function maps the parameters into the model, simulates it over the defined frequency band, and extracts the target specifications. The entire extraction pipeline is vectorized over the batch dimension and lazily compiled via jax.jit(jax.vmap(...)).

Parameters:

theta (jax.numpy.ndarray) – A 1D array of a single parameter set, or a 2D array representing a batch of parameters.

Returns:

The extracted model features. Matches the batch dimension of theta.

Return type:

jax.numpy.ndarray

Raises:

RuntimeError – If frequency or features were not provided during initialization.

static read_results(stream)

Reconstruct backend optimization or sampling results from a bytes stream.

Parameters:

stream (BytesIO) – The open byte stream to read from.

Returns:

The deserialized Python objects.

Return type:

Any

update(theta)

Evaluate new parameters and update the internal sampling state.

This function is central to active learning and iterative sampling backends. It extracts the features for the provided parameters, appends the data to the internal tracking arrays, updates any configured live plots, and (if expensive=True) automatically saves intermediate crash-recovery checkpoints to disk using native NumPy serialization.

Parameters:

theta (jax.numpy.ndarray) – A 1D array of a single parameter set, or a 2D array representing a batch of parameters to simulate.

Return type:

None

static write_results(stream, results)

Encode backend optimization or sampling results into a bytes stream.

Parameters:
  • stream (BytesIO) – The open byte stream to write to.

  • results (Any) – The data structure containing the backend results.

Notes

The default implementation uses jsonpickle to serialize complex Python objects.
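The write/read pair forms a round trip through a byte stream. The sketch below shows that pattern using the standard-library json module as a stand-in for jsonpickle, so it only handles plain JSON-compatible data. The function names and the example payload are hypothetical.

```python
import io
import json

def write_results_sketch(stream, results):
    """Encode results as UTF-8 JSON bytes and write them to the stream."""
    stream.write(json.dumps(results).encode("utf-8"))

def read_results_sketch(stream):
    """Read all bytes from the stream and decode them back into objects."""
    return json.loads(stream.read().decode("utf-8"))

payload = {"sampled_params": [[0.1, 0.2]], "converged": True}
buffer = io.BytesIO()
write_results_sketch(buffer, payload)
buffer.seek(0)  # rewind before reading what was just written
restored = read_results_sketch(buffer)
```

Working against a generic byte stream keeps the serialization independent of where the bytes end up, whether that is a file on disk or a dataset inside a larger container.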

sampled_params: Array | None
sampled_features: Array | None
feature_plotters: list[LivePlotter]