"""
An expansion of a set of basis functions.
"""
import jax.numpy as jnp
from pmrf.models.adapters.base import SingleDiscreteProperty
from pmrf.parameters import Parameter
# [docs]  (Sphinx "view source" link marker left over from extraction; not code)
class VectorExpansion(SingleDiscreteProperty):
    """
    A model where the output is a linear expansion of vector basis functions.

    The S-parameters are returned as ``offset + coefficients @ basis``, where
    the coefficients are the model parameters.

    Array layout, as fixed by the einsum signatures used in this class:

    * ``basis`` is 4-D, indexed ``(basis function, sample point, m, n)``.
    * The coefficients are 3-D, indexed ``(basis function, m, n)``.
    * The model output is 3-D, indexed ``(sample point, m, n)``.

    ``(m, n)`` is the port pair. The "sample point" axis is presumably
    frequency -- NOTE(review): confirm against ``SingleDiscreteProperty``.
    """

    # Real part of the expansion coefficients, and the optional imaginary part.
    # When ``coefficients_imag`` is left as None the coefficients are purely
    # real.
    coefficients_real: Parameter = None
    coefficients_imag: Parameter = None
    # The basis functions themselves and an optional offset. The offset may be
    # stored in any shape that can be reshaped to the model output shape.
    basis: jnp.ndarray = None
    offset: jnp.ndarray = None

    def output_discrete(self) -> jnp.ndarray:
        """
        Evaluate the expansion for the current coefficients.

        Returns
        -------
        jnp.ndarray
            The sum over basis functions of ``coefficient * basis`` for each
            port pair, plus ``offset`` (reshaped to the output shape) when an
            offset is set.
        """
        # Contract the basis-function axis ``i``, independently per port pair.
        expansion = jnp.einsum(
            'imn,ikmn->kmn', self.coefficients_complex, self.basis
        )
        if self.offset is None:
            return expansion
        return expansion + self.offset.reshape(expansion.shape)

    def inverse(self, sample: jnp.ndarray) -> jnp.ndarray:
        """
        Project a sample of the model output onto the basis functions.

        For each port pair this computes
        ``sum_f (sample[f] - offset[f]) * conj(basis[b, f])``. That recovers
        the coefficients exactly only if the basis functions are orthonormal
        along the sample-point axis.

        Parameters
        ----------
        sample : jnp.ndarray
            A single sample. Accepted layouts are 1-D ``(f,)`` (treated as a
            single port pair), 3-D ``(f, m, n)`` or 4-D ``(1, f, m, n)``.

        Returns
        -------
        jnp.ndarray
            When ``coefficients_imag`` is set, the real parts followed by the
            imaginary parts, concatenated along axis 0. Otherwise only the
            real parts, with shape ``(num_basis, m, n)``.
        """
        if sample.ndim == 1:
            sample = sample[..., None, None]

        if self.offset is not None:
            # Reshape the offset exactly as ``output_discrete`` does
            # (``basis.shape[1:]`` is the output shape), so an offset stored
            # flat or with extra unit axes cannot mis-broadcast here.
            sample = sample - self.offset.reshape(self.basis.shape[1:])

        if sample.ndim == 3:
            # The projection below expects a leading axis of length one.
            sample = sample[None, ...]

        basis_conj_t = self.basis.transpose(1, 0, 2, 3).conj()
        # This projects the sample onto the basis vector for each port (m, n).
        projected = jnp.einsum('ifmn,fbmn->ibmn', sample, basis_conj_t)
        coefficients = projected.reshape(basis_conj_t.shape[1:])

        if self.coefficients_imag is None:
            return coefficients.real
        # ``jnp.concatenate`` matches ``basis_separate`` and works on every
        # JAX release; ``jnp.concat`` is a more recent alias.
        return jnp.concatenate([coefficients.real, coefficients.imag])

    @property
    def num_basis(self) -> int:
        """Number of basis functions (length of the first axis of ``basis``)."""
        return len(self.basis)

    @property
    def basis_separate(self) -> jnp.ndarray:
        """Real parts of ``basis`` followed by its imaginary parts, along axis 0."""
        return jnp.concatenate([self.basis.real, self.basis.imag], axis=0)

    @property
    def coefficients_complex(self) -> jnp.ndarray:
        """
        The coefficients as a single array.

        Returns ``coefficients_real`` unchanged when ``coefficients_imag`` is
        None, otherwise ``coefficients_real + 1j * coefficients_imag``.
        """
        if self.coefficients_imag is None:
            return self.coefficients_real
        # Out-of-place on purpose: ``+=`` on the stored object could modify
        # ``coefficients_real`` itself, or fail to cast complex into a real
        # buffer, if it supports in-place addition.
        return self.coefficients_real + 1j * self.coefficients_imag