Source code for pmrf.sampling.algorithms.uniform
import jax
import jax.numpy as jnp
from pmrf.sampling.oneshot import OneshotSampler
[docs]
class UniformSampler(OneshotSampler):
    """One-shot sampler that draws i.i.d. points uniformly from the unit hypercube."""

    def generate(self, N, D, key=None, **kwargs) -> jnp.ndarray:
        """
        Generate samples using uniform random sampling.

        Parameters
        ----------
        N : int
            Number of samples.
        D : int
            Dimensionality (number of parameters).
        key : jax.Array, optional
            JAX PRNG key used to draw the samples. If ``None``, a fresh key is
            seeded from operating-system entropy, so repeated calls give
            different (non-reproducible) samples. Pass an explicit key when
            reproducibility is required.
        **kwargs
            Ignored by this sampler.

        Returns
        -------
        jnp.ndarray
            Array of shape ``(N, D)`` with samples in the unit hypercube
            ``[0, 1)^D``.
        """
        if key is None:
            # jax.random.uniform rejects a None key, so the advertised default
            # would otherwise always fail. Seed from OS entropy rather than a
            # fixed constant so that repeated default calls do not silently
            # return identical samples.
            import secrets

            key = jax.random.PRNGKey(secrets.randbits(32))
        return jax.random.uniform(key, shape=(N, D))