JointDistribution

class JointDistribution(distributions, distribution_names, param_names)

Bases: Distribution

Parameters:
support = RealVector(Real(), 1)
Parameters:

value (NumLikeT)

Return type:

Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

reparametrized_params: list[str] = []
has_rsample = True
property num_params: int
sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key to use for sampling.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

ndarray

log_prob(value)

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array with shape value.shape[:value.ndim - self.event_dim], that is, value.shape with the trailing event dimensions removed

Return type:

ArrayLike

icdf(u)

The inverse cumulative distribution function of this distribution.

Parameters:

u – quantile values, which should lie in [0, 1].

Returns:

the values whose CDF equals u.

cdf(x)

The cumulative distribution function of this distribution.

Parameters:

x – samples from this distribution.

Returns:

the cumulative distribution function evaluated at x.

property min: Array

Lower bounds, inferred dynamically from the numpyro constraints of each component distribution.

MIN_PERCENTILE is used for any distribution that is not strictly bounded.

property max: Array

Upper bounds, inferred dynamically from the numpyro constraints of each component distribution.

MAX_PERCENTILE is used for any distribution that is not strictly bounded.

property bounds: tuple[Array, Array]

Lower and upper bounds, inferred dynamically from the numpyro constraints of each component distribution.

MIN_PERCENTILE and MAX_PERCENTILE are used for “min” and “max”, respectively, for any distribution that is not bounded.

arg_constraints: dict[str, Any] = {}
property batch_shape: tuple[int, ...]

Returns the shape over which the distribution parameters are batched.

Returns:

batch shape of the distribution.

Return type:

tuple[int, …]

entropy()

Returns the entropy of the distribution.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

Parameters:

expand (bool) – whether to expand the result to shape (n,) + batch_shape, where n = len(support).

Return type:

Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

property event_dim: int

Returns:

Number of dimensions of individual events.

Return type:

int

property event_shape: tuple[int, ...]

Returns the shape of a single sample from the distribution without batching.

Returns:

event shape of the distribution.

Return type:

tuple[int, …]

expand(batch_shape)

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters:

batch_shape (tuple) – batch shape to expand to.

Returns:

an instance of ExpandedDistribution.

Return type:

ExpandedDistribution

expand_by(sample_shape)

Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters:

sample_shape (tuple) – The size of the iid batch to be drawn from the distribution.

Returns:

An expanded version of this distribution.

Return type:

ExpandedDistribution

classmethod gather_pytree_aux_fields()
Return type:

tuple[str, …]

classmethod gather_pytree_data_fields()
Return type:

tuple[str, …]

get_args()

Get arguments of the distribution.

Return type:

dict[str, Any]

has_enumerate_support: bool = False
classmethod infer_shapes(*args, **kwargs)

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes that the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args (Any) – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs (Any) – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

property is_discrete: bool
mask(mask)

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.

Parameters:

mask (bool or jnp.ndarray) – A boolean or boolean-valued array (True includes a site, False excludes a site).

Returns:

A masked copy of this distribution.

Return type:

MaskedDistribution

Example:

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO

>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)


>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))


>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

pytree_aux_fields: tuple[str, ...] = ('_batch_shape', '_event_shape')
pytree_data_fields: tuple[str, ...] = ()
rsample(key, sample_shape=())
Parameters:
Return type:

Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

sample_with_intermediates(key, sample_shape=())

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key to use for sampling.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

ndarray

static set_default_validate_args(value)
Parameters:

value (bool)

Return type:

None

shape(sample_shape=())

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
Parameters:

sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.

Returns:

shape of samples.

Return type:

tuple

to_event(reinterpreted_batch_ndims=None)

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters:

reinterpreted_batch_ndims (int | None) – Number of rightmost batch dims to interpret as event dims.

Returns:

An instance of Independent distribution.

Return type:

numpyro.distributions.distribution.Independent

tree_flatten()
Return type:

tuple[tuple[Any, …], tuple[Any, …]]

classmethod tree_unflatten(aux_data, params)
Parameters:
Return type:

Distribution

validate_args(strict=True)

Validate the arguments of the distribution.

Parameters:

strict (bool) – Require strict validation, raising an error if the function is called inside jitted code.

Return type:

None

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.