# Source code for pmrf.sampling.algorithms.field

from abc import ABC, abstractmethod
from typing import Any, Sequence

import jax
import jax.numpy as jnp
import jax.random as jr

from pmrf.sampling.acqusition import AcquisitionSampler
from pmrf.models.model import Model
from pmrf.sampling.base import BaseSampler, SampleResults
from pmrf.util import lhs_sample, LivePlotter, RANK
from pmrf.algorithms import has_converged

class FieldSampler(AcquisitionSampler, ABC):
    r"""Base class for sampling new points at the maxima of a learned scalar field.

    Implements an active-learning strategy: a scalar "field" (such as an
    uncertainty map or an acquisition function) is trained on the samples
    gathered so far, and the next points to simulate are chosen at the
    maxima of that field across the parameter space.

    Subclasses must implement :meth:`train_field` and :meth:`evaluate_field`
    to define the specific behavior of the adaptive loop.

    .. rubric:: Methods

    .. autosummary::
        :nosignatures:

        run
        execute
        update
        train_field
        evaluate_field
        calculate_convergence
        acquire
    """

    def __init__(self, model: Model, **kwargs):
        super().__init__(model, **kwargs)
        # One live plot per tracked convergence metric (created lazily in acquire).
        self.converge_plotters: list[LivePlotter] = []
        # Per-iteration history of convergence metric tuples.
        self.converge_values: list[tuple[float, ...]] = []
[docs] def run( self, output_path: str | None = None, save_figures: bool = True, **kwargs ) -> SampleResults: result = super().run(output_path=output_path, save_figures=save_figures, **kwargs) if output_path is not None and RANK == 0: if save_figures: for i, plotter in enumerate(self.converge_plotters): plotter.fig.savefig(f"{output_path}/convergence_{i}.png", dpi=400) return result
[docs] @abstractmethod def train_field(self, key=None) -> Any: r""" Train and return the scalar field (e.g., a surrogate model). Parameters ---------- key : jax.Array, optional JAX PRNG key for stochastic training. Returns ------- Any The trained field object/state. """ pass
[docs] @abstractmethod def evaluate_field(self, field: Any, theta: jnp.ndarray, key=None) -> float: r""" Evaluate the scalar field at a given physical parameter point. Parameters ---------- field : Any The trained field object returned by :meth:`train_field`. theta : jax.numpy.ndarray A 1D array representing a point in physical parameter space. key : jax.Array, optional JAX PRNG key for stochastic evaluation. Returns ------- float The scalar value of the field at the given point. """ pass
[docs] def calculate_convergence(self, key=None) -> float | Sequence[float] | None: r""" Calculate metrics to track the algorithmic convergence. If not overridden, the maximum value found in the grid field will be used to track stability. """ return None
    # --------------------------------------------------------------------------
    # The Template Method
    # --------------------------------------------------------------------------
[docs] def acquire( self, N: int, d: int, *, grid_sampler: BaseSampler | None = None, rtol=0.01, atol=None, patience=5, window=5, num_grid_per_dim=1024, key=None, **kwargs ) -> jnp.ndarray | None: r""" Propose N new points by maximizing the learned field over a grid. This method follows a fixed template: * Trains the field using current samples. * Generates a dense grid of points in the parameter space. * Evaluates the field across that grid. * Selects the top $N$ points using a penalized greedy strategy to ensure spatial diversity. * Monitors convergence based on the field's stability. Parameters ---------- N : int Number of points to propose. d : int Dimensionality of the parameter space. grid_sampler : :class:`~pmrf.sampling.base.BaseSampler`, optional An optional sampler used to generate the candidate grid. If ``None``, Latin Hypercube Sampling is used. rtol : float, default=0.01 Relative tolerance for convergence checking. atol : float, optional Absolute tolerance for convergence checking. patience : int, default=5 Number of iterations the metric must stay within tolerance to converge. window : int, default=5 The look-back window for calculating tolerances. num_grid_per_dim : int, default=1024 Multiplier for determining the grid density ($K = num\_grid\_per\_dim \times d$). key : jax.Array, optional JAX PRNG key. Returns ------- jax.numpy.ndarray or None New proposed points in the unit hypercube, or ``None`` if converged. """ # 1. Train the field (using subclass implementation) key, field_key = jr.split(key) field = self.train_field(key=field_key) # 2. Generate Grid Points K = num_grid_per_dim * d key, grid_key = jr.split(key) if grid_sampler is not None: grid_models, _ = grid_sampler.run(N=K, key=grid_key) theta_grid = jnp.array([m.flat_param_values() for m in grid_models]) else: U_grid = lhs_sample(K, d, key=grid_key) theta_grid = jax.vmap(self.icdf)(U_grid) # 3. 
Evaluate Field on the Grid (using subclass implementation) key, eval_key = jr.split(key) eval_keys = jr.split(eval_key, len(theta_grid)) def eval_theta_fn(theta, k) -> float: return self.evaluate_field(field, theta, key=k) grid_field = jax.vmap(eval_theta_fn)(theta_grid, eval_keys) # 4. Select N Diverse Points selected_field_thetas, selected_field_values = self._select_field_thetas(N, theta_grid, grid_field) max_field_value = jnp.max(grid_field) if N == 1: self.logger.info(f"Field maximum = {float(max_field_value):.2f}") else: value_str = ", ".join([f"{v:.2f}" for v in selected_field_values]) self.logger.info(f"Field maxima = [{value_str}]") # 5. Convergence Tracking (using subclass implementation) key, validate_key = jr.split(key) conv_metric = self.calculate_convergence(key=validate_key) if conv_metric is not None: new_converge_values = tuple(conv_metric) if isinstance(conv_metric, Sequence) else (conv_metric,) else: new_converge_values = (float(max_field_value),) self.converge_values.append(new_converge_values) # Check for convergence converged = True has_converged_vals = list(map(list, zip(*self.converge_values))) for values in has_converged_vals: if not has_converged(values, rtol=rtol, atol=atol, patience=patience, window=window): converged = False break for i, new_value in enumerate(new_converge_values): if len(self.converge_plotters) <= i: self.converge_plotters.append(LivePlotter(f"Convergence {i}", "Iteration", "Value")) self.converge_plotters[i].add_point(f"Convergence {i}", new_value) # 6. Return mapped hypercube points if converged: return None return jnp.array([self.cdf(theta) for theta in selected_field_thetas])
def _select_field_thetas(self, N: int, thetas: jnp.ndarray, values: jnp.ndarray) -> jnp.ndarray: """Selects N points iteratively using a penalized greedy strategy to ensure diversity.""" K = len(values) if N >= K: return thetas, values if N <= 0: return jnp.array([]), jnp.array([]) if N == 1: best_idx = jnp.argmax(values) return thetas[best_idx][None, :], values[best_idx].reshape(1) t_min, t_max = jnp.min(thetas, axis=0), jnp.max(thetas, axis=0) t_range = jnp.where((t_max - t_min) == 0, 1.0, t_max - t_min) norm_thetas = (thetas - t_min) / t_range v_min, v_max = jnp.min(values), jnp.max(values) if v_max == v_min: return thetas[:N], values[:N] scores = (values - v_min) / (v_max - v_min) L = 0.1 * jnp.sqrt(thetas.shape[1]) selected_indices = [] for _ in range(N): best_idx = jnp.argmax(scores) selected_indices.append(best_idx) if len(selected_indices) < N: dist_sq = jnp.sum((norm_thetas - norm_thetas[best_idx])**2, axis=1) penalty = 1.0 - jnp.exp(-dist_sq / (2 * L**2)) scores = scores * penalty scores = scores.at[best_idx].set(-jnp.inf) return thetas[jnp.array(selected_indices)], values[jnp.array(selected_indices)]