Source code for pmrf.sampling.results

from dataclasses import dataclass
import logging

import numpy as np
import jax.numpy as jnp
import h5py

from pmrf.models.model import Model
from pmrf.results import BaseResults
from pmrf.features import extract_features

@dataclass
class SampleResults(BaseResults):
    """Container for the results of a model sampling process."""

    # Model instances corresponding to the drawn samples, if they were kept.
    sampled_models: list[Model] | None = None
    # Raw sampled parameter values; presumably (n_samples, n_params) -- TODO confirm.
    sampled_params: jnp.ndarray | None = None
    # Features evaluated for the samples; layout is not fixed by this class.
    sampled_features: jnp.ndarray | None = None

    # --------------------------------------------------------------------------
    # Subclass specific Plotting & I/O
    # --------------------------------------------------------------------------

    def _get_feature_plot_data(self, feature: str, plot_samples=10, **kwargs) -> list[dict]:
        """Build line specs for plotting ``feature`` of the initial and sampled models.

        The initial model is drawn as a bold black trace kept on top; up to
        ``plot_samples`` evenly strided sampled models are drawn as faint
        "ghost" traces to visualise the spread of the samples.

        Args:
            feature: Feature name forwarded to ``extract_features``.
            plot_samples: Maximum number of sampled models to draw. Values
                ``<= 0`` draw no sample traces.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A list of dicts, each holding the ``'y'`` data plus line-style
            keyword arguments. Empty if the initial model or the frequency is
            missing. Extraction failures are logged, never raised.
        """
        # Explicit None checks: truth-testing an array-like frequency is
        # ambiguous (numpy/jax raise for multi-element arrays).
        if self.initial_model is None or self.frequency is None:
            logging.warning("Missing initial model or frequency for plotting.")
            return []

        plot_data = []

        # 1. Plot the initial model (base case).
        try:
            y_init = extract_features(self.initial_model, self.frequency, [feature])
            y_init = np.real(y_init[:, 0])  # Cast to real to strip complex dtype
        except Exception as e:
            logging.warning(f"Failed to extract '{feature}' from initial model: {e}")
        else:
            plot_data.append({
                'y': y_init,
                'label': 'Initial Model',
                'linestyle': '-',
                'color': 'k',
                'linewidth': 2,
                'zorder': 10,  # Keep the initial model on top of the samples
            })

        # 2. Plot the sampled models (ghost traces).
        if self.sampled_models and plot_samples > 0:
            step = max(1, len(self.sampled_models) // plot_samples)
            # Striding alone can yield nearly 2x plot_samples traces; cap it.
            chosen = self.sampled_models[::step][:plot_samples]
            n_failed = 0
            labelled = False
            for sm in chosen:
                try:
                    y_sm = extract_features(sm, self.frequency, [feature])
                    y_sm = np.real(y_sm[:, 0])
                except Exception:
                    # Best-effort: skip samples whose feature cannot be evaluated.
                    n_failed += 1
                    continue
                plot_data.append({
                    'y': y_sm,
                    # Label only the first drawn sample so the legend has one entry.
                    'label': None if labelled else 'Sample',
                    'linestyle': '-',
                    'color': 'C0',
                    'alpha': 0.3,  # Make samples semi-transparent
                    'linewidth': 1,
                })
                labelled = True
            if n_failed:
                logging.warning(
                    "Failed to extract '%s' from %d of %d sampled models.",
                    feature, n_failed, len(chosen),
                )

        return plot_data

    def _write_data(self, f: h5py.File):
        """Write the sampling-specific data into ``f``.

        Layout written at the root of ``f``: optional ``sampled_params`` and
        ``sampled_features`` datasets, and an optional ``sampled_models`` group
        containing one subgroup per model, named by its integer index.
        """
        if self.sampled_params is not None:
            f.create_dataset('sampled_params', data=np.asarray(self.sampled_params))
        if self.sampled_features is not None:
            f.create_dataset('sampled_features', data=np.asarray(self.sampled_features))
        if self.sampled_models:
            sm_grp = f.create_group('sampled_models')
            for i, model in enumerate(self.sampled_models):
                self._write_model(sm_grp.create_group(str(i)), model)

    @classmethod
    def _read_data(cls, f: h5py.File, kwargs: dict):
        """Populate ``kwargs`` with the sampling-specific data stored in ``f``.

        Mirrors ``_write_data``: models are read from the root ``sampled_models``
        group it writes. A ``models/sampled`` group (the location this method
        previously looked in) is still accepted as a fallback.
        """
        if 'sampled_params' in f:
            kwargs['sampled_params'] = jnp.array(f['sampled_params'][()])
        if 'sampled_features' in f:
            kwargs['sampled_features'] = jnp.array(f['sampled_features'][()])

        if 'sampled_models' in f:
            sm_grp = f['sampled_models']
        elif 'models' in f and 'sampled' in f['models']:
            sm_grp = f['models/sampled']
        else:
            sm_grp = None

        if sm_grp is not None:
            # Subgroups are named by integer index; sort numerically so "10" follows "9".
            keys = sorted(sm_grp.keys(), key=int)
            kwargs['sampled_models'] = [cls._read_model(sm_grp[k]) for k in keys]