AnestheticDistribution

class AnestheticDistribution(nested_samples, param_names=None, validate_args=None)

Bases: SampledDistribution

Adapter for distributions represented by samples from the Anesthetic library.

Although the nested samples cannot be sampled from as if they were an actual distribution, this class is still useful for storing sampling results directly within a model.

Parameters:

param_names (list[str] | None)

param_names()

Retrieves parameter names associated with the samples of this distribution.

Returns:

The list of parameter names

Return type:

list[str]

samples(prior=False, weighted=False)

Retrieve samples drawn from the distribution.

Parameters:

weighted (bool, optional, default=False) – If True, returns weighted (non-resampled) samples.

Returns:

The array of samples.

Return type:

jnp.ndarray

weights(prior=False)

Retrieve the weights associated with the samples, if any.

Returns:

The array of weights.

Return type:

jnp.ndarray
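Nested-sampling output is usually a set of weighted samples, so summary statistics should use the weights rather than treating every row equally. The sketch below assumes that samples(weighted=True) yields an (N, D) array and weights() the matching length-N array; the arrays are made-up stand-ins, not output of this class.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for samples(weighted=True) and weights():
# 4 draws of 2 parameters, with unnormalized weights.
samples = jnp.array([[0.0, 1.0],
                     [1.0, 2.0],
                     [2.0, 3.0],
                     [3.0, 4.0]])
w = jnp.array([0.1, 0.2, 0.3, 0.4])

# Weighted mean of each parameter (jnp.average normalizes the weights).
weighted_mean = jnp.average(samples, axis=0, weights=w)

# Weighted variance of each parameter around that mean.
weighted_var = jnp.average((samples - weighted_mean) ** 2, axis=0, weights=w)
```

Resampled (unweighted) samples are equally weighted, so a plain mean over rows plays the same role for them.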

bounds(prior=False)

Return type:

tuple[Array, Array]
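The entry above does not say how the bounds are computed. One plausible reading, which is purely an assumption here, is a (lower, upper) pair holding the per-parameter minimum and maximum over the stored samples; `sample_bounds` is a hypothetical helper, not part of this class.

```python
import jax.numpy as jnp

def sample_bounds(samples):
    """Hypothetical helper: per-parameter (lower, upper) of an (N, D) sample array."""
    return jnp.min(samples, axis=0), jnp.max(samples, axis=0)

samples = jnp.array([[0.5, -2.0],
                     [1.5, 0.0],
                     [-0.5, 3.0]])
lower, upper = sample_bounds(samples)
```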

marginalize(keep_params)

Marginalizes out all parameters except those specified in keep_params.

Parameters:

keep_params (list[str]) – The list of parameter names to retain. All other parameters will be marginalized out.

Returns:

A new distribution object representing the marginalized distribution.

Return type:

AnestheticDistribution
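For a distribution represented by samples, marginalizing out a parameter amounts to dropping its column: each remaining row is still a draw from the marginal, and any weights stay unchanged. The sketch below illustrates that idea on a plain array; `marginalize_samples` is a hypothetical helper and not necessarily how marginalize() is implemented.

```python
import jax.numpy as jnp

def marginalize_samples(samples, param_names, keep_params):
    """Hypothetical helper: keep only the columns named in keep_params.

    Weights (if any) are not touched, since integrating a parameter out of a
    sample-based distribution simply means ignoring that coordinate.
    """
    idx = jnp.array([param_names.index(name) for name in keep_params])
    return samples[:, idx], list(keep_params)

samples = jnp.arange(12.0).reshape(4, 3)   # 4 draws of parameters a, b, c
kept, names = marginalize_samples(samples, ["a", "b", "c"], ["c", "a"])
```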

arg_constraints = {}
property batch_shape: tuple[int, ...]

Returns the shape over which the distribution parameters are batched.

Returns:

batch shape of the distribution.

Return type:

tuple[int, …]

cdf(value)

The cumulative distribution function of this distribution.

Parameters:

value (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – samples from this distribution.

Returns:

output of the cumulative distribution function evaluated at value.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

entropy()

Returns the entropy of the distribution.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

enumerate_support(expand=True)

Returns an array with shape len(support) x batch_shape containing all values in the support.

Parameters:

expand (bool)

Return type:

Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

property event_dim: int

Returns:

Number of dimensions of individual events.

Return type:

int

property event_shape: tuple[int, ...]

Returns the shape of a single sample from the distribution without batching.

Returns:

event shape of the distribution.

Return type:

tuple[int, …]

expand(batch_shape)

Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.

Parameters:

batch_shape (tuple) – batch shape to expand to.

Returns:

an instance of ExpandedDistribution.

Return type:

ExpandedDistribution

expand_by(sample_shape)

Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters:

sample_shape (tuple) – The size of the iid batch to be drawn from the distribution.

Returns:

An expanded version of this distribution.

Return type:

ExpandedDistribution

classmethod gather_pytree_aux_fields()
Return type:

tuple[str, …]

classmethod gather_pytree_data_fields()
Return type:

tuple[str, …]

get_args()

Get arguments of the distribution.

Return type:

dict[str, Any]

has_enumerate_support = False
property has_rsample: bool
icdf(u)

The inverse cumulative distribution function of this distribution.

Parameters:

u – quantile values, which should lie in [0, 1].

Returns:

the samples whose CDF values equal u.

classmethod infer_shapes(*args, **kwargs)

Infers batch_shape and event_shape given shapes of args to __init__().

Note

This assumes the distribution shape depends only on the shapes of tensor inputs, not on the data contained in those inputs.

Parameters:
  • *args (Any) – Positional args replacing each input arg with a tuple representing the sizes of each tensor input.

  • **kwargs (Any) – Keywords mapping name of input arg to tuple representing the sizes of each tensor input.

Returns:

A pair (batch_shape, event_shape) of the shapes of a distribution that would be created with input args of the given shapes.

Return type:

tuple

property is_discrete: bool
log_prob(value)

Evaluates the log probability density for a batch of samples given by value.

Parameters:

value – A batch of samples from the distribution.

Returns:

an array of shape value.shape with the trailing len(self.event_shape) event dimensions removed (i.e. sample_shape + batch_shape).

Return type:

ArrayLike

mask(mask)

Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's Distribution.batch_shape.

Parameters:

mask (bool or jnp.ndarray) – A boolean or boolean valued array (True includes a site, False excludes a site).

Returns:

A masked copy of this distribution.

Return type:

MaskedDistribution

Example:

>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO

>>> def model(data, m):
...     f = numpyro.sample("latent_fairness", dist.Beta(1, 1))
...     with numpyro.plate("N", data.shape[0]):
...         # only take into account the values selected by the mask
...         masked_dist = dist.Bernoulli(f).mask(m)
...         numpyro.sample("obs", masked_dist, obs=data)


>>> def guide(data, m):
...     alpha_q = numpyro.param("alpha_q", 5., constraint=constraints.positive)
...     beta_q = numpyro.param("beta_q", 5., constraint=constraints.positive)
...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))


>>> data = jnp.concatenate([jnp.ones(5), jnp.zeros(5)])
>>> # select values equal to one
>>> masked_array = jnp.where(data == 1, True, False)
>>> optimizer = numpyro.optim.Adam(step_size=0.05)
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 300, data, masked_array)
>>> params = svi_result.params
>>> # inferred_mean is closer to 1
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
property mean: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Mean of the distribution.

property num_params: int
pytree_aux_fields = ('_batch_shape', '_event_shape')
pytree_data_fields = ()
reparametrized_params = []
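rsample(key, sample_shape=())

Parameters:
  • key – the random number generator key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Return type: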

Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

sample(key, sample_shape=())

Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape. Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned sample will be filled with iid draws from the distribution instance.

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

ndarray

sample_with_intermediates(key, sample_shape=())

Same as sample except that any intermediate computations are returned (useful for TransformedDistribution).

Parameters:
  • key (jax.random.PRNGKey) – the random number generator key to be used for the distribution.

  • sample_shape (tuple) – the sample shape for the distribution.

Returns:

an array of shape sample_shape + batch_shape + event_shape

Return type:

ndarray

static set_default_validate_args(value)
Parameters:

value (bool)

Return type:

None

shape(sample_shape=())

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
Parameters:

sample_shape (tuple) – the size of the iid batch to be drawn from the distribution.

Returns:

shape of samples.

Return type:

tuple

property support: Constraint | None

The support of this distribution. Subclasses can override this as a class attribute or as a property.

to_event(reinterpreted_batch_ndims=None)

Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.

Parameters:

reinterpreted_batch_ndims (int | None) – Number of rightmost batch dims to interpret as event dims.

Returns:

An instance of Independent distribution.

Return type:

numpyro.distributions.distribution.Independent

tree_flatten()
Return type:

tuple[tuple[Any, …], tuple[Any, …]]

classmethod tree_unflatten(aux_data, params)
Parameters:
Return type:

Distribution

validate_args(strict=True)

Validate the arguments of the distribution.

Parameters:

strict (bool) – Require strict validation, raising an error if the function is called inside jitted code.

Return type:

None

property variance: Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray

Variance of the distribution.