Source code for pmrf.sampling.algorithms.uniform

import jax
import jax.numpy as jnp

from pmrf.sampling.oneshot import OneshotSampler

class UniformSampler(OneshotSampler):
    """One-shot sampler drawing independent uniform points in the unit hypercube."""

    def generate(self, N, D, key=None, **kwargs) -> jnp.ndarray:
        """
        Generate samples using uniform random sampling.

        Parameters
        ----------
        N : int
            Number of samples.
        D : int
            Dimensionality (number of parameters).
        key : jax.Array
            JAX PRNG key used to draw the samples (e.g. from
            ``jax.random.PRNGKey``). It defaults to ``None`` in the
            signature, but a key must be supplied.
        **kwargs
            Ignored by this sampler.

        Returns
        -------
        jnp.ndarray
            Array of shape ``(N, D)`` with samples in the unit hypercube
            ``[0, 1)^D``.

        Raises
        ------
        TypeError
            If ``key`` is ``None``.
        """
        # jax.random.uniform has no implicit global RNG state and cannot draw
        # without an explicit key; fail early with an actionable message
        # instead of an opaque error from inside JAX.
        if key is None:
            raise TypeError(
                "UniformSampler.generate requires a JAX PRNG key, "
                "e.g. key=jax.random.PRNGKey(0)."
            )
        return jax.random.uniform(key, shape=(N, D))