pmrf.distributions.maf

Functions

_get_backend_cls(backend_str)

Classes

MAFDistribution([batch_shape, event_shape, ...])

A generic NumPyro-compatible distribution wrapper for MAF-based models.

class pmrf.distributions.maf.MAFDistribution(batch_shape=(), event_shape=(), *, validate_args=None)

Bases: Distribution

A generic NumPyro-compatible distribution wrapper for MAF (masked autoregressive flow) models. The wrapper does not implement the flow itself; the actual implementation is delegated to a backend adapter.

Parameters:
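  • batch_shape (tuple[int, ...]) – shape of the batch of independent distributions; defaults to ().

  • event_shape (tuple[int, ...]) – shape of a single draw (event); defaults to ().

  • validate_args (bool | None) – whether to validate inputs, as in NumPyro's Distribution; if None, the default validation setting applies.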
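The delegation described above can be pictured with a small pure-Python sketch. It deliberately avoids NumPyro and pmrf: FlowBackend, DelegatingDistribution and ToyBackend are hypothetical names chosen for illustration, and the method names log_prob and sample are assumed by analogy with NumPyro's Distribution interface, not taken from the pmrf adapter API.

```python
from __future__ import annotations

from typing import Any, Protocol


class FlowBackend(Protocol):
    """Hypothetical adapter interface (not the documented pmrf API)."""

    def log_prob(self, value: Any) -> Any: ...

    def sample(self, key: Any, sample_shape: tuple[int, ...] = ()) -> Any: ...


class DelegatingDistribution:
    """Hypothetical stand-in for a wrapper that forwards work to an adapter."""

    impl: FlowBackend | None = None  # no backend attached by default

    def __init__(self, impl: FlowBackend, batch_shape: tuple[int, ...] = (),
                 event_shape: tuple[int, ...] = ()) -> None:
        self.impl = impl
        self.batch_shape = tuple(batch_shape)
        self.event_shape = tuple(event_shape)

    def log_prob(self, value: Any) -> Any:
        # Density evaluation is forwarded unchanged to the adapter.
        return self.impl.log_prob(value)

    def sample(self, key: Any, sample_shape: tuple[int, ...] = ()) -> Any:
        # Sampling is forwarded too; the wrapper adds no flow logic itself.
        return self.impl.sample(key, sample_shape)


class ToyBackend:
    """Hypothetical toy adapter: unnormalised standard-normal log-density."""

    def log_prob(self, value: list[float]) -> float:
        return -0.5 * sum(v * v for v in value)

    def sample(self, key: Any, sample_shape: tuple[int, ...] = ()) -> list[float]:
        # Deterministic zeros, just enough to show the call reaching the adapter.
        n = sample_shape[0] if sample_shape else 1
        return [0.0] * n


dist = DelegatingDistribution(ToyBackend(), event_shape=(2,))
print(dist.log_prob([1.0, 1.0]))  # -1.0
```

Because every call is forwarded, swapping the adapter changes the model without changing the wrapper's public interface.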

impl = None
classmethod generate(*args, backend='flowjax', **kwargs)
Return type:

Distribution
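Since generate selects a backend by name (default 'flowjax') and the module also provides _get_backend_cls(backend_str), a name-to-class lookup of roughly the following shape is one plausible way such a factory can work. This is a sketch under assumptions: the registry, register_backend, get_backend_cls, Wrapper and ToyFlowjaxAdapter are hypothetical, and how the real _get_backend_cls handles an unknown name is not documented here.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical registry mapping backend names to adapter classes.
_BACKENDS: dict[str, type] = {}


def register_backend(name: str) -> Callable[[type], type]:
    """Class decorator that records an adapter class under `name`."""
    def decorator(cls: type) -> type:
        _BACKENDS[name] = cls
        return cls
    return decorator


def get_backend_cls(backend_str: str) -> type:
    """Hypothetical analogue of _get_backend_cls: resolve a name to a class."""
    try:
        return _BACKENDS[backend_str]
    except KeyError:
        known = ", ".join(sorted(_BACKENDS)) or "<none>"
        raise ValueError(
            f"unknown backend {backend_str!r}; known backends: {known}"
        ) from None


@register_backend("flowjax")
class ToyFlowjaxAdapter:
    """Hypothetical adapter; it only stores its construction arguments."""

    def __init__(self, dim: int, **options: Any) -> None:
        self.dim = dim
        self.options = options


class Wrapper:
    impl: Any = None  # filled in by generate()

    @classmethod
    def generate(cls, *args: Any, backend: str = "flowjax", **kwargs: Any) -> Wrapper:
        # Resolve the backend by name, build the adapter, attach it to a new wrapper.
        backend_cls = get_backend_cls(backend)
        obj = cls()
        obj.impl = backend_cls(*args, **kwargs)
        return obj


w = Wrapper.generate(3, hidden=8)
print(type(w.impl).__name__, w.impl.dim, w.impl.options)
```

Keeping the name lookup in one function means a new backend only needs a registry entry; the wrapper class and its generate signature stay unchanged.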