"""pmrf.distributions.maf: backend-selectable MAF distribution wrappers."""

from numpyro.distributions import Distribution, constraints

# Backend name used by MAFDistribution.generate when the caller does not
# pass ``backend=`` explicitly.
DEFAULT_BACKEND: str = "flowjax"

def _get_backend_cls(backend_str: str):
    if backend_str == 'flowjax':
        from pmrf.distributions.flowjax import FlowJAXDistribution
    elif backend_str == 'margarine':
        from pmrf.distributions.margarine import MargarineMAFDistribution
    else:
        raise Exception(f'{backend_str} backend not supported by MAFDistribution')
    
def MAFDistribution(data, backend=DEFAULT_BACKEND) -> Distribution:
    """Build a MAF distribution from ``data`` using the chosen backend.

    NOTE(review): a class of the same name is defined later in this module,
    so after import the module-level name ``MAFDistribution`` refers to that
    class, not to this function.

    Args:
        data: Forwarded as the sole positional argument to the backend
            class's ``generate`` (the same delegation the
            ``MAFDistribution.generate`` classmethod performs).
        backend: Backend name, ``'flowjax'`` or ``'margarine'``.

    Returns:
        Whatever the backend class's ``generate(data)`` returns.

    Raises:
        ValueError: If ``backend`` is not a supported backend name.
    """
    # Imports are deferred to the matching branch so only the selected
    # backend's module is loaded.
    if backend == 'flowjax':
        from pmrf.distributions.flowjax import FlowJAXDistribution as backend_cls
    elif backend == 'margarine':
        from pmrf.distributions.margarine import (
            MargarineMAFDistribution as backend_cls,
        )
    else:
        raise ValueError(f'{backend} backend not supported by MAFDistribution')
    # Assumes each adapter exposes the ``generate`` classmethod that
    # MAFDistribution.generate also relies on — TODO confirm the signature.
    return backend_cls.generate(data)
        
class MAFDistribution(Distribution):
    """A generic NumPyro-compatible distribution wrapper for MAF-based models.

    The actual implementation is delegated to a backend adapter. Use
    :meth:`generate` to build a distribution with a concrete backend.
    Attribute lookups that fail on a wrapper instance are forwarded to the
    attached adapter.
    """

    # Backend adapter that failed attribute lookups are forwarded to;
    # ``None`` means no adapter is attached.
    impl = None

    @classmethod
    def generate(cls, *args, backend=DEFAULT_BACKEND, **kwargs) -> Distribution:
        """Build a distribution using the requested backend's ``generate``.

        Args:
            *args: Forwarded unchanged to the backend class's ``generate``.
            backend: Backend name resolved by ``_get_backend_cls``; defaults
                to ``DEFAULT_BACKEND``.
            **kwargs: Forwarded unchanged to the backend class's ``generate``.

        Returns:
            Whatever the backend class's ``generate`` returns.
        """
        backend_cls = _get_backend_cls(backend)
        return backend_cls.generate(*args, **kwargs)

    def __getattr__(self, name):
        """Forward a failed attribute lookup to the attached backend adapter.

        Raises:
            AttributeError: If no adapter is attached, so ``hasattr`` and
                ``copy``/``pickle`` probes behave normally.
        """
        # __getattr__ is only invoked after normal lookup fails. Touching an
        # undefined attribute here (e.g. ``self.backend``) would re-enter it
        # and recurse until RecursionError. Read the instance dict directly,
        # and ``impl`` (always resolvable via the class attribute) instead.
        # A ``backend`` attribute set on the instance still takes precedence,
        # as in the original lookup.
        delegate = self.__dict__.get('backend', self.impl)
        if delegate is None:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}'
            )
        return getattr(delegate, name)