import jax
from jax import vmap
import skrf
import jax.numpy as jnp
from pmrf.frequency import Frequency
from pmrf.models.model import Model
from pmrf.models.lumped import Match
from pmrf._util import field
class Measured(Model):
    """
    **Overview**
    A model whose S-parameters come from a measured (or otherwise tabulated)
    network, linearly interpolated onto the requested frequencies.

    The real and imaginary parts of every S-parameter entry are interpolated
    independently. Requested frequencies outside the measured band are clamped
    to the first/last measured frequency, so edge values are held constant
    rather than extrapolated.
    """

    # S-parameters copied from ``skrf.Network.s``; indexed as S[freq, row, col]
    # by :meth:`s`.
    s_measured: jnp.ndarray
    # Frequency points copied from ``skrf.Network.f``.
    f_measured: jnp.ndarray

    def __init__(self, ntwk: skrf.Network | str):
        """Store the S-parameters and frequency points of a network.

        Args:
            ntwk (skrf.Network | str): The network itself, or a file name that
                is passed to ``skrf.Network`` to load it.
        """
        if isinstance(ntwk, str):
            ntwk = skrf.Network(ntwk)
        self.s_measured = jnp.array(ntwk.s)
        self.f_measured = jnp.array(ntwk.f)

    def s(self, freq: Frequency) -> jnp.ndarray:
        """Returns the measured S-parameters interpolated onto ``freq``.

        Args:
            freq (Frequency): Specifies the frequency to calculate the parameters at.

        Returns:
            jnp.ndarray: Interpolated S-parameters of shape
            ``(len(freq.f), n_rows, n_cols)``.

        Note:
            ``jnp.interp`` requires ``f_measured`` to be increasing; this is
            not checked here.
        """
        S, f = self.s_measured, self.f_measured
        n_meas, n_rows, n_cols = S.shape

        # Clamp to the measured band so out-of-band points take the edge values.
        f_new = jnp.clip(freq.f, f[0], f[-1])

        # Interpolate every column of an (n_meas, n_entries) matrix in a single
        # batched call. vmap keeps one interp in the traced graph regardless of
        # the port count, instead of unrolling n_ports**2 jnp.interp calls.
        interp_columns = vmap(
            lambda column: jnp.interp(f_new, f, column), in_axes=1, out_axes=1
        )

        S_flat = S.reshape(n_meas, n_rows * n_cols)
        S_new = interp_columns(jnp.real(S_flat)) + 1j * interp_columns(jnp.imag(S_flat))
        return S_new.reshape(f_new.shape[0], n_rows, n_cols)
class SModel(Model):
    """
    **Overview**
    A general model defined by a constant S-parameter matrix.

    The matrix is stored in ``s_const`` rather than in a field named ``s``.
    That name belongs to the :meth:`s` method, and an instance attribute called
    ``s`` would shadow it, so ``model.s(freq)`` would try to call the array.
    """

    # The constant S-parameter matrix supplied to the constructor.
    s_const: jnp.ndarray

    def __init__(self, s: jnp.ndarray):
        """Store the constant S-parameter matrix.

        Args:
            s (jnp.ndarray): Either a single 2-D ``(n, n)`` matrix that applies
                at every frequency, or an already frequency-resolved
                ``(n_freq, n, n)`` array.
        """
        self.s_const = jnp.asarray(s)

    def s(self, freq: Frequency) -> jnp.ndarray:
        """Returns the S-parameter matrix.

        Args:
            freq (Frequency): Specifies the frequency to calculate the parameters at.

        Returns:
            jnp.ndarray: The stored matrix. A 2-D matrix is broadcast to
            ``(len(freq.f), n, n)``; any other shape is returned unchanged.
        """
        S = self.s_const
        if S.ndim == 2:
            # A single matrix is constant in frequency: repeat it along a
            # leading frequency axis so the result has the (n_freq, n, n)
            # layout that Measured.s also returns.
            S = jnp.broadcast_to(S, (len(freq.f),) + S.shape)
        return S
class AModel(Model):
    """
    **Overview**
    A general model defined by a constant ABCD matrix.

    The matrix is stored in ``a_const`` rather than in a field named ``a``.
    That name belongs to the :meth:`a` method, and an instance attribute called
    ``a`` would shadow it, so ``model.a(freq)`` would try to call the array.
    """

    # The constant ABCD matrix supplied to the constructor.
    a_const: jnp.ndarray

    def __init__(self, a: jnp.ndarray):
        """Store the constant ABCD matrix.

        Args:
            a (jnp.ndarray): Either a single 2-D matrix that applies at every
                frequency, or an already frequency-resolved 3-D array with a
                leading frequency axis.
        """
        self.a_const = jnp.asarray(a)

    def a(self, freq: Frequency) -> jnp.ndarray:
        """Returns the ABCD matrix.

        Args:
            freq (Frequency): Specifies the frequency to calculate the parameters at.

        Returns:
            jnp.ndarray: The stored matrix. A 2-D matrix is broadcast to
            ``(len(freq.f), rows, cols)``; any other shape is returned unchanged.
        """
        A = self.a_const
        if A.ndim == 2:
            # A single matrix is constant in frequency: repeat it along a
            # leading frequency axis so the result is frequency-resolved.
            A = jnp.broadcast_to(A, (len(freq.f),) + A.shape)
        return A