Source code for pmrf.sampling.oneshot

from abc import ABC, abstractmethod
from typing import Any

import jax.numpy as jnp

from pmrf.sampling.base import BaseSampler
from pmrf.util import generate_key

class OneshotSampler(BaseSampler, ABC):
    r"""
    Base class for static, non-adaptive sampling strategies.

    This class is used for strategies that propose all sample points at once,
    independent of any model feedback (e.g., Uniform, Latin Hypercube, or
    Sobol sequences). Subclasses implement :meth:`generate` to propose points
    in the unit hypercube; :meth:`execute` then maps each proposed point to
    the physical parameter space with :meth:`icdf` (which uses the model's
    prior distributions) and hands the mapped points to :meth:`update` in
    batches.

    .. rubric:: Methods

    .. autosummary::
        :nosignatures:

        run
        execute
        generate
    """

    def execute(
        self, *, N: int = 100, batch_size: int = 1, key: Any = None, **kwargs
    ) -> tuple[jnp.ndarray, Any]:
        r"""
        Generate and evaluate a batch of points in one execution.

        Parameters
        ----------
        N : int, default=100
            The total number of samples to generate and simulate. Must be
            non-negative.
        batch_size : int, default=1
            The number of points passed to :meth:`update` in each iteration.
            Must be at least 1.
        key : optional
            Random key forwarded to :meth:`generate`. If ``None``, a fresh key
            is created with ``pmrf.util.generate_key()``.
        **kwargs
            Additional arguments passed to the :meth:`generate` method.

        Returns
        -------
        tuple[jax.numpy.ndarray, None]
            A tuple of the sampler's ``sampled_params`` attribute (read after
            all batches have been passed to :meth:`update`) and ``None`` as a
            placeholder for backend results.

        Raises
        ------
        ValueError
            If ``N`` is negative or ``batch_size`` is less than 1.
        """
        # Validate up front: a zero step would make ``range`` fail with an
        # opaque error, and a negative step would silently skip every update.
        if N < 0:
            raise ValueError(f"N must be non-negative, got {N}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        if key is None:
            key = generate_key()

        d = self.model.num_flat_params
        U = self.generate(N, d, key=key, **kwargs)

        # Map every unit-hypercube point to the physical parameter space.
        thetas = jnp.array([self.icdf(u) for u in U])

        # Feed the mapped points to ``update`` in consecutive slices of at
        # most ``batch_size`` rows (the last slice may be shorter).
        for start in range(0, thetas.shape[0], batch_size):
            self.update(thetas[start : start + batch_size])

        return self.sampled_params, None

    @abstractmethod
    def generate(self, N: int, d: int, **kwargs) -> jnp.ndarray:
        r"""
        Propose points within the unit hypercube $[0, 1]^d$.

        Parameters
        ----------
        N : int
            Number of points to propose.
        d : int
            Dimensionality of the parameter space.
        **kwargs
            Strategy-specific options. When called from :meth:`execute`, this
            always contains ``key``, the random key to use.

        Returns
        -------
        jax.numpy.ndarray
            An array of shape ``(N, d)`` containing the proposed coordinates.
        """
        raise NotImplementedError