Source code for pmrf.distributions.maf

import os

from numpyro.distributions import Distribution, constraints

# Maps backend names (e.g. 'margarine_maf') to adapter classes. Filled lazily
# by get_adapter() so optional backend dependencies are only imported on use.
_BACKEND_REGISTRY = {}

# Becomes True once get_adapter() has attempted to register the backends.
_backend_initialized = False


def get_adapter(backend):
    """
    Return the adapter class registered under ``backend``.

    On the first call this tries to import and register the optional
    backends. A backend whose import raises ``ImportError`` is skipped, so it
    simply does not appear in the registry.

    Parameters
    ----------
    backend : str
        Registry key of the backend, e.g. ``'margarine_maf'``.

    Returns
    -------
    type
        The adapter class registered under ``backend``.

    Raises
    ------
    ValueError
        If no adapter is registered under ``backend`` (this includes the case
        where the backend's optional dependency could not be imported).
    """
    # Without this declaration the assignment below makes the name local to
    # the function, and the read in the ``if`` raises UnboundLocalError.
    global _backend_initialized

    if not _backend_initialized:
        try:
            from pmrf.distributions._backends.margarine import MargarineMAFAdapter
        except ImportError:
            # Optional backend: leave it unregistered when unavailable.
            pass
        else:
            _BACKEND_REGISTRY['margarine_maf'] = MargarineMAFAdapter
        _backend_initialized = True

    if backend not in _BACKEND_REGISTRY:
        raise ValueError(f"Unsupported backend '{backend}'. Available: {list(_BACKEND_REGISTRY)}")
    return _BACKEND_REGISTRY[backend]
class MAFDistribution(Distribution):
    """
    A generic NumPyro-compatible distribution wrapper for MAF-based models.

    The actual implementation is delegated to a backend adapter obtained from
    ``get_adapter(backend)``. This class uses the following members of the
    adapter:

    - on the adapter class: ``load(path)``, ``generate(data, weights=..., **kwargs)``
      and the constructor ``adapter_cls(maf_model)``;
    - on the adapter instance: ``event_dim``, ``save(path)``,
      ``sample(key, sample_shape)``, ``log_prob(value)``, ``icdf(u)``,
      ``min`` and ``max``.
    """

    support = constraints.real_vector
    has_rsample = False  # No reparameterization support

    def __init__(self, maf_or_path, backend='margarine_maf', validate_args=None):
        """
        Wrap a MAF model, or load one from disk, using the given backend.

        Parameters
        ----------
        maf_or_path : object or str or os.PathLike
            Either a backend-native MAF model, or a filesystem path. A path is
            passed (as a ``str``) to the adapter class's ``load``.
        backend : str
            Registry key passed to ``get_adapter``.
        validate_args : bool, optional
            Forwarded to ``numpyro.distributions.Distribution``.

        Raises
        ------
        ValueError
            If ``backend`` is not a registered backend.
        """
        self.backend = backend
        adapter_cls = get_adapter(backend)

        # Treat pathlib.Path and other path-like objects as paths too (the
        # ``load`` classmethod already accepts them); previously only ``str``
        # was recognized and a Path would be wrapped as if it were a model.
        if isinstance(maf_or_path, (str, os.PathLike)):
            maf_model = adapter_cls.load(os.fspath(maf_or_path))
        else:
            maf_model = maf_or_path

        self.adapter = adapter_cls(maf_model)
        # A single, unbatched event of size ``adapter.event_dim``.
        event_shape = (self.adapter.event_dim,)
        super().__init__(
            batch_shape=(),
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def save(self, path):
        """Save the wrapped MAF model to ``path`` via the backend adapter."""
        self.adapter.save(path)

    @classmethod
    def load(cls, path, backend='margarine_maf', validate_args=None):
        """
        Load a MAFDistribution from a file using the specified backend.

        ``path`` is handed unchanged to the adapter class's ``load``.

        Raises
        ------
        ValueError
            If ``backend`` is not a registered backend.
        """
        adapter_cls = get_adapter(backend)
        maf_model = adapter_cls.load(path)
        return cls(maf_model, backend=backend, validate_args=validate_args)

    @classmethod
    def generate(cls, data, weights=None, backend='margarine_maf', validate_args=None, **kwargs):
        """
        Build a new MAF from ``data`` with the given backend and wrap it.

        ``data``, ``weights`` and any extra keyword arguments are passed
        unchanged to the adapter class's ``generate``; see the backend for
        their meaning.

        Raises
        ------
        ValueError
            If ``backend`` is not a registered backend.
        """
        adapter_cls = get_adapter(backend)
        maf_model = adapter_cls.generate(data, weights=weights, **kwargs)
        return cls(maf_model, backend=backend, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples via the backend adapter's ``sample(key, sample_shape)``."""
        return self.adapter.sample(key, sample_shape)

    def log_prob(self, value):
        """Return the adapter's log-density for ``value``."""
        return self.adapter.log_prob(value)

    def icdf(self, u):
        """Return the adapter's inverse-CDF transform of ``u``."""
        return self.adapter.icdf(u)

    @property
    def min(self):
        """The adapter's ``min`` attribute (presumably per-dimension lower bounds — confirm in backend)."""
        return self.adapter.min

    @property
    def max(self):
        """The adapter's ``max`` attribute (presumably per-dimension upper bounds — confirm in backend)."""
        return self.adapter.max