"""Source code for ``pmrf.sampling.backends.eqxlearn``."""

from functools import partial
from typing import Any

import jax.numpy as jnp
import jax.random as jr

from pmrf.models.model import Model
from pmrf.sampling.algorithms import FieldSampler
from pmrf.sampling.base import SampleResults
from pmrf.algorithms.anomaly import get_anomaly_mask

class EqxLearnUncertaintySampler(FieldSampler):
    r"""Adaptive sampler that targets regions of high surrogate model uncertainty.

    This sampler uses an ``eqx-learn`` surrogate model to approximate the
    relationship between physical parameters and RF features. The adaptive
    loop selects new points where the surrogate's predictive variance is
    highest, effectively minimizing the Maximum Mean Absolute Error (MAE)
    across the entire design space.

    It includes built-in preprocessing to filter out simulation anomalies
    or NaNs before training the surrogate.

    Parameters
    ----------
    model : :class:`~pmrf.models.model.Model`
        The base ParamRF model to sample.
    surrogate : :class:`eqxlearn.BaseModel`
        An ``eqx-learn`` model instance (e.g., a Gaussian Process) that
        supports predictive variance via ``return_var=True``.
    cv : callable, optional
        A cross-validation strategy for monitoring convergence. Defaults to
        a shuffled 5-fold KFold.
    fit_kwargs : dict, optional
        Keyword arguments passed to the ``eqxlearn.fit`` routine.
    **kwargs
        Additional arguments for the
        :class:`~pmrf.sampling.algorithms.FieldSampler` base class.
    """

    def __init__(
        self,
        model: Model,
        surrogate: Any,
        cv=None,
        fit_kwargs: dict | None = None,
        **kwargs,
    ):
        # Imported lazily so that eqx-learn remains an optional dependency.
        from eqxlearn.model_selection import KFold

        super().__init__(model=model, **kwargs)
        self.surrogate = surrogate
        # Default CV factory: shuffled 5-fold split; instantiated later with a PRNG key.
        self.cv = cv or partial(KFold, n_splits=5, shuffle=True)
        self.fit_kwargs = fit_kwargs or {}
        # Set by run(); None disables anomaly filtering in _preprocess().
        self.anomaly_threshold = None
[docs] def run(self, anomaly_threshold=1e-3, **kwargs) -> SampleResults: self.anomaly_threshold = anomaly_threshold return super().run(**kwargs)
def _preprocess(self) -> tuple[jnp.ndarray, jnp.ndarray]: params, features = self.sampled_params, self.sampled_features anomaly_threshold = self.anomaly_threshold num_features = features.shape[-1] for fi in range(num_features): features_i = features[...,fi] # Filter out NaNs mask = jnp.any(jnp.isnan(features)) mask = mask | jnp.any(jnp.isnan(params)) # Filter out anomalies if anomaly_threshold is not None: mask = mask | get_anomaly_mask(features_i, anomaly_threshold) # Apply the masks if len(features[mask]) > 0: self.logger.warning(f"Not using {len(features[mask])} bad features") features = features[~mask] if params is not None: params = params[~mask] # Flatten features features = features.reshape(features.shape[0], -1) return params, features
[docs] def train_field(self, key=None) -> Any: from eqxlearn import fit X, y = self._preprocess() self.logger.info("Training surrogate model...") fitted_eqx_model, losses = fit(self.surrogate, X=X, y=y, key=key, **self.fit_kwargs) self.logger.info(f"Final loss: {losses[-1]:.2f}") return fitted_eqx_model
[docs] def evaluate_field(self, field: Any, theta: jnp.ndarray, key=None) -> float: _y_mean, y_var = field(theta, return_var=True) rayleigh_factor = jnp.sqrt(jnp.pi) / 2.0 total_std = jnp.sqrt(y_var.real + y_var.imag) expected_mae = jnp.mean(rayleigh_factor * total_std) return 20 * jnp.log10(expected_mae)
[docs] def calculate_convergence(self, key=None) -> float: from eqxlearn.model_selection import cross_validate from eqxlearn.metrics import mean_absolute_error X, y = self._preprocess() key, split_key = jr.split(key) cv_instance = self.cv(key=split_key) def mae_db(model, X, y): return jnp.max(20 * jnp.log10(mean_absolute_error(y, model.predict(X), axis=1))) self.logger.info("Validating surrogate model...") key, cv_key = jr.split(key) results = cross_validate( self.surrogate, X, y, cv=cv_instance, scoring=mae_db, return_loss=True, key=cv_key, **self.fit_kwargs ) score = jnp.mean(results['test_score']) losses = results['loss'] loss_str = [round(float(loss), 2) for loss in losses] self.logger.info(f"Average error = {score:.2f} dB. Training losses = {loss_str}") return score