from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Any
import glob
import numpy as np
import jax.numpy as jnp
from pmrf.runner import BaseRunner
from pmrf.models.model import Model
from pmrf.util import LivePlotter, RANK
from pmrf.sampling.results import SampleResults
from pmrf.features import extract_features
class BaseSampler(BaseRunner, ABC):
    r"""
    Base class for model sampling and active learning loops.

    This runner is designed to explore the parameter space of a model, either
    through pre-defined experimental designs or adaptive active learning
    strategies. It manages state caching, crash-recovery checkpoints, and live
    plotting of extracted features during the sampling loop.

    .. rubric:: Methods

    .. autosummary::
        :nosignatures:

        run
        execute
        update

    Parameters
    ----------
    model : :class:`~pmrf.models.model.Model`
        The base ParamRF model to be sampled.
    **kwargs
        Additional arguments passed up to :class:`~pmrf.runner.BaseRunner`
        (e.g., ``frequency``, ``features``).
    """

    # When True, :meth:`update` logs every simulated sample and writes NumPy
    # crash-recovery checkpoints after each call.
    expensive: bool = False

    def __init__(self, model: Model, **kwargs):
        super().__init__(model, **kwargs)

        # Accumulated sampling state, grown row-wise by :meth:`update`.
        self.sampled_params: jnp.ndarray | None = None
        self.sampled_features: jnp.ndarray | None = None

        # Context variables (configured per call to :meth:`run`)
        self.plot_features = None
        self.feature_plotters: list[LivePlotter] = []

    def run(
        self,
        *,
        plot: str | list[str] | None = None,
        output_path: str | None = None,
        load_previous: bool = False,
        save_results: bool = True,
        save_figures: bool = True,
        **kwargs
    ) -> SampleResults:
        r"""
        Execute the sampling process.

        This method handles plotting setup, optional recovery from previous
        runs, delegates the core algorithm to the subclass's :meth:`execute`
        method, and packages the final results.

        Parameters
        ----------
        plot : str or list[str], optional
            Specific features to plot live during the sampling loop and save
            as PNG images.
        output_path : str, optional
            The directory where results, numpy crash-checkpoints, and figures
            should be saved.
        load_previous : bool, default=False
            If ``True`` and ``output_path`` is provided, it will attempt to
            load and return an existing ``sample_results.hdf5`` file (or,
            failing that, the first ``*.hdf5`` file in the directory) instead
            of re-running the sampling. If nothing can be loaded, sampling
            proceeds as normal.
        save_results : bool, default=True
            Whether to save the final ``SampleResults`` object to an HDF5 file.
        save_figures : bool, default=True
            Whether to save the generated plots to disk.
        **kwargs
            Additional arguments passed directly to the subclass's
            :meth:`execute` method.

        Returns
        -------
        :class:`~pmrf.sampling.results.SampleResults`
            The comprehensive results object containing the evaluated
            parameters and their corresponding extracted features.
        """
        # Normalise ``plot`` to a list so it can be iterated more than once
        # (here for plotter setup, later for live plotting and saving).
        if isinstance(plot, str):
            plot = [plot]
        elif plot is not None:
            plot = list(plot)

        self.output_path = output_path
        self.plot_features = plot

        if self.plot_features is not None:
            # Existing plotters are reused; only the missing ones are created.
            for i, feature_name in enumerate(self.plot_features):
                if i >= len(self.feature_plotters):
                    self.feature_plotters.append(
                        LivePlotter(
                            title=f"{feature_name}",
                            xlabel=f"Frequency ({self.frequency.unit})",
                            ylabel=f"{feature_name}",
                        )
                    )

        # 1. Try load from previous results
        if load_previous and output_path is not None:
            previous = self._load_previous_results(output_path)
            if previous is not None:
                return previous

        # 2. Execute Sampling (Delegates to Strategy/Backend)
        self.logger.info(f"Sampling {self.model.num_flat_params} parameters...")
        self.logger.debug(f"Parameter names: {self.model.flat_param_names()}")
        self.logger.debug(f"Features: {self.features}")
        samples, backend_results = self.execute(output_path=output_path, **kwargs)

        # 3. Package Results
        results = SampleResults()
        results.initial_model = self.model
        results.frequency = self.frequency
        results.features = self.features
        # Prefer the state accumulated through update(); fall back to the
        # array returned by execute() when update() was never called.
        results.sampled_params = self.sampled_params if self.sampled_params is not None else samples
        results.sampled_features = self.sampled_features
        results.backend_results = backend_results
        results.backend_class = f"{self.__class__.__module__}.{self.__class__.__qualname__}"

        # Consolidate kwargs
        full_run_kwargs = kwargs.copy()
        full_run_kwargs.update({'load_previous': load_previous, 'save_results': save_results})
        results.run_kwargs = full_run_kwargs

        # 4. Save Output (only on rank 0)
        if output_path is not None and RANK == 0:
            if save_results or save_figures:
                Path(output_path).resolve().mkdir(parents=True, exist_ok=True)
            if save_results:
                self.logger.info('Saving results...')
                results.save_hdf(Path(f'{output_path}/sample_results.hdf5').resolve())
            if save_figures and self.plot_features is not None:
                self.logger.info('Saving figures...')
                for feature, plotter in zip(self.plot_features, self.feature_plotters):
                    plotter.fig.savefig(f"{output_path}/{feature}.png", dpi=400)

        return results

    def _load_previous_results(self, output_path: str) -> SampleResults | None:
        r"""
        Best-effort load of a previously saved results file.

        ``sample_results.hdf5`` (the name written by :meth:`run`) is preferred;
        otherwise the alphabetically first ``*.hdf5`` file in ``output_path``
        is used. Any failure is logged and reported as ``None`` so that the
        caller can fall back to a fresh sampling run.

        Parameters
        ----------
        output_path : str
            Directory to search for HDF5 result files.

        Returns
        -------
        :class:`~pmrf.sampling.results.SampleResults` or None
            The loaded results, or ``None`` if nothing could be loaded.
        """
        preferred = Path(output_path) / "sample_results.hdf5"
        if preferred.is_file():
            candidates = [str(preferred)]
        else:
            # glob order is arbitrary, so sort for a deterministic choice.
            candidates = sorted(glob.glob(f"{output_path}/*.hdf5"))

        if not candidates:
            self.logger.info(f"No previous sampling results found in '{output_path}'.")
            return None

        filename = candidates[0]
        try:
            results = SampleResults.load_hdf(filename)
        except Exception as exc:
            # Deliberately broad: loading is best-effort and must never
            # prevent a fresh run, but the failure should not be invisible.
            self.logger.warning(
                f"Could not load previous sampling results from '{filename}' ({exc!r}); re-running."
            )
            return None

        self.logger.info("Loaded previous sampling results.")
        return results

    def update(self, theta: jnp.ndarray) -> None:
        r"""
        Evaluate new parameters and update the internal sampling state.

        This function is central to active learning and iterative sampling
        backends. It extracts the features for the provided parameters, appends
        the data to the internal tracking arrays, updates any configured live
        plots, and (if ``expensive=True``) automatically saves intermediate
        crash-recovery checkpoints to disk using native NumPy serialization.

        Parameters
        ----------
        theta : jax.numpy.ndarray
            A 1D array of a single parameter set, or a 2D array representing
            a batch of parameters to simulate.
        """
        new_thetas = jnp.atleast_2d(theta)
        num_new, _ = new_thetas.shape
        new_features = None

        # Ensure we only try to extract features if frequency/features are defined
        if num_new > 0 and self.frequency is not None and self.features is not None:
            time_str = datetime.now().strftime("%H:%M:%S")
            num_existing = len(self.sampled_params) if self.sampled_params is not None else 0
            if self.expensive:
                if num_new == 1:
                    self.logger.info(f"Simulating sample #{num_existing + 1} at {time_str}")
                else:
                    self.logger.info(
                        f"Simulating samples #{num_existing + 1}-{num_existing + num_new} at {time_str}"
                    )
            new_features = super().model_features(new_thetas)

        # Update State
        if self.sampled_params is None:
            self.sampled_params = new_thetas
            self.sampled_features = new_features
        else:
            self.sampled_params = jnp.vstack((self.sampled_params, new_thetas))
            if new_features is not None:
                if self.sampled_features is None:
                    # No features recorded so far (e.g. an earlier empty
                    # batch); stacking onto None would raise.
                    self.sampled_features = new_features
                else:
                    self.sampled_features = jnp.vstack((self.sampled_features, new_features))

        # I/O Checkpoint (numpy fallback for crash recovery during active learning)
        # ``output_path`` is only assigned by run(); tolerate update() being
        # called on an instance that has not been through run() yet.
        output_path = getattr(self, "output_path", None)
        if self.expensive and output_path is not None and RANK == 0:
            # run() only creates the directory when saving at the very end,
            # so make sure it exists before the first checkpoint is written.
            Path(output_path).mkdir(parents=True, exist_ok=True)
            np.save(f"{output_path}/sampled_params.npy", np.asarray(self.sampled_params))
            if self.sampled_features is not None:
                np.save(f"{output_path}/sampled_features.npy", np.asarray(self.sampled_features))

        # Live Plotting
        if self.plot_features is not None:
            param_names = self.model.flat_param_names()
            for plotter, plot_feature in zip(self.feature_plotters, self.plot_features):
                for current_theta in new_thetas:
                    m = self.model.with_params(current_theta)
                    curve = extract_features(m, self.frequency, plot_feature)
                    param_dict = {k: float(v) for k, v in zip(param_names, current_theta)}
                    plotter.ax.set_title(f"{plot_feature}")
                    # Only the real part of the extracted feature is drawn.
                    plotter.add_curve(
                        str(param_dict),
                        y_values=np.real(curve),
                        x_values=self.frequency.f_scaled,
                    )

    @abstractmethod
    def execute(
        self,
        **kwargs,
    ) -> tuple[jnp.ndarray, Any]:
        r"""
        Implemented by subclasses to perform the actual sampling algorithm.

        Parameters
        ----------
        **kwargs
            Backend-specific algorithm parameters passed down from :meth:`run`.
            :meth:`run` always supplies ``output_path`` in addition to any
            extra keyword arguments it received.

        Returns
        -------
        tuple[jax.numpy.ndarray, Any]
            A tuple containing the generated array of parameters (thetas) and
            any raw backend-specific results or metadata object.
        """