"""
Adapters to hold discrete and continuous surrogate models.
"""
from typing import Callable
import jax.numpy as jnp
from pmrf.frequency import Frequency
from pmrf.parameters import Parameter
from pmrf.models.adapters.base import SingleProperty, SingleDiscreteProperty
class ContinuousSurrogate(SingleProperty, transparent=True):
    """
    A model that predicts its output at an arbitrary frequency using an arbitrary callable.

    This is very useful for embedding machine learning architectures. For example,
    `func` can be any Equinox module.
    """

    # The input parameters of length P.
    theta: Parameter | list[Parameter] = None
    # The underlying model. Must accept an array of shape (P,) and a frequency object.
    # Its result is reshaped to (-1, 1, 1) by `output`, so it should contain exactly
    # one value per frequency point. A genuine nfreq x nports x nports array with
    # nports > 1 would be flattened along the first axis, not preserved.
    func: Callable[[jnp.ndarray, Frequency], jnp.ndarray] = None

    def output(self, freq: Frequency) -> jnp.ndarray:
        """
        Evaluate the surrogate model at the given frequencies.

        Parameters
        ----------
        freq : Frequency
            The frequencies passed through to ``func``.

        Returns
        -------
        jnp.ndarray
            The result of ``func(param_values, freq)`` reshaped to ``(-1, 1, 1)``.

        Raises
        ------
        TypeError
            If ``func`` has not been set (it defaults to ``None``).
        """
        # Fail with a descriptive message instead of "'NoneType' object is not
        # callable"; the exception type (TypeError) is unchanged for callers.
        if self.func is None:
            raise TypeError(
                f"{type(self).__name__}.func is not set; provide a callable "
                "mapping (param_values, freq) -> array."
            )
        # The model input is the flat vector of all parameter values, fixed ones included.
        theta_values = self.flat_param_values(include_fixed=True)
        # HACK: force a single-port (-1, 1, 1) layout for now.
        return self.func(theta_values, freq).reshape(-1, 1, 1)


class DiscreteSurrogate(SingleDiscreteProperty, transparent=True):
    """
    A model that predicts its output at a discrete set of frequency values using an arbitrary callable.

    This is very useful for embedding machine learning architectures. For example,
    `func` can be any Equinox module.
    """

    # The input parameters of length P.
    theta: Parameter | list[Parameter] = None
    # The underlying model. Must accept an array of shape (P,) and return one value
    # per frequency point (presumably self.frequency.npoints values; confirm against
    # SingleDiscreteProperty). The result is reshaped to (-1, 1, 1) by
    # `output_discrete`, so an nports > 1 output would be flattened, not preserved.
    func: Callable[[jnp.ndarray], jnp.ndarray] = None

    def output_discrete(self) -> jnp.ndarray:
        """
        Evaluate the surrogate model on its fixed, discrete frequency set.

        Returns
        -------
        jnp.ndarray
            The result of ``func(param_values)`` reshaped to ``(-1, 1, 1)``.

        Raises
        ------
        TypeError
            If ``func`` has not been set (it defaults to ``None``).
        """
        # Fail with a descriptive message instead of "'NoneType' object is not
        # callable"; the exception type (TypeError) is unchanged for callers.
        if self.func is None:
            raise TypeError(
                f"{type(self).__name__}.func is not set; provide a callable "
                "mapping param_values -> array."
            )
        # The model input is the flat vector of all parameter values, fixed ones included.
        theta_values = self.flat_param_values(include_fixed=True)
        # Force a single-port (-1, 1, 1) layout, matching ContinuousSurrogate.
        return self.func(theta_values).reshape(-1, 1, 1)