import jax.numpy as jnp
from pmrf.constants import NumberLike
import jax
from jax import lax
from jax.scipy.special import gammaln
from jax._src.numpy.ufuncs import _constant_like
[docs]
def a2s(a: jnp.ndarray, z0: NumberLike = 50) -> jnp.ndarray:
    """
    Convert ABCD (chain) parameters of a 2-port network to S-parameters.

    Parameters
    ----------
    a : jnp.ndarray, shape (nfreqs, 2, 2)
        ABCD matrix at each frequency, laid out as ``[[A, B], [C, D]]``.
    z0 : scalar, (2,), (nfreqs,), or (nfreqs, 2)
        Port reference impedance(s), expanded to shape (nfreqs, 2) by
        ``fix_z0_shape``. Defaults to 50.

    Returns
    -------
    jnp.ndarray of shape (nfreqs, 2, 2)
        S-parameter matrix ``[[S11, S12], [S21, S22]]`` at each frequency.

    Raises
    ------
    ValueError
        If ``a`` is not 3-D (from unpacking ``a.shape``).
    IndexError
        If the per-frequency matrix is not 2x2.
    """
    # Taken from scikit-rf. See the copyright notice in pmrf._frequency.py
    nfreqs, nrows, ncols = a.shape
    # Check BOTH trailing dims: unpacking into ``nports, nports`` kept only the
    # last one, so e.g. a (nfreqs, 3, 2) array used to slip through.
    if (nrows, ncols) != (2, 2):
        raise IndexError('abcd parameters are defined for 2-ports networks only')
    z0 = fix_z0_shape(z0, nfreqs, 2)
    z01 = z0[:, 0]
    z02 = z0[:, 1]
    A = a[:, 0, 0]
    B = a[:, 0, 1]
    C = a[:, 1, 0]
    D = a[:, 1, 1]
    denom = A * z02 + B + C * z01 * z02 + D * z01
    # Common factor of the transmission terms.
    root = jnp.sqrt(z01.real * z02.real)

    s11 = (A * z02 + B - C * z01.conj() * z02 - D * z01.conj()) / denom
    s12 = (2 * (A * D - B * C) * root) / denom
    s21 = (2 * root) / denom
    s22 = (-A * z02.conj() + B - C * z01 * z02.conj() + D * z01) / denom

    # Assemble explicitly as [[s11, s12], [s21, s22]] per frequency,
    # giving shape (nfreqs, 2, 2).
    return jnp.stack(
        [jnp.stack([s11, s12], axis=-1), jnp.stack([s21, s22], axis=-1)],
        axis=-2,
    )
[docs]
def s2a(s: jnp.ndarray, z0: NumberLike = 50) -> jnp.ndarray:
    """
    Convert S-parameters of a 2-port network to ABCD (chain) parameters.

    Parameters
    ----------
    s : jnp.ndarray, shape (nfreqs, 2, 2)
        S-parameter matrix ``[[S11, S12], [S21, S22]]`` at each frequency.
    z0 : scalar, (2,), (nfreqs,), or (nfreqs, 2)
        Port reference impedance(s), expanded to shape (nfreqs, 2) by
        ``fix_z0_shape``. Defaults to 50.

    Returns
    -------
    jnp.ndarray of shape (nfreqs, 2, 2)
        ABCD matrix laid out as ``[[A, B], [C, D]]`` at each frequency.

    Raises
    ------
    ValueError
        If ``s`` is not 3-D (from unpacking ``s.shape``).
    IndexError
        If the per-frequency matrix is not 2x2.

    Notes
    -----
    Every entry is divided by a term proportional to S21, so frequencies
    where S21 == 0 produce inf/nan.
    """
    # Taken from scikit-rf. See the copyright notice in pmrf._frequency.py
    nfreqs, nrows, ncols = s.shape
    # Check BOTH trailing dims: unpacking into ``nports, nports`` kept only the
    # last one, so e.g. a (nfreqs, 3, 2) array used to slip through.
    if (nrows, ncols) != (2, 2):
        raise IndexError('abcd parameters are defined for 2-ports networks only')
    z0 = fix_z0_shape(z0, nfreqs, 2)
    z01 = z0[:, 0]
    z02 = z0[:, 1]
    s11 = s[:, 0, 0]
    s12 = s[:, 0, 1]
    s21 = s[:, 1, 0]
    s22 = s[:, 1, 1]
    denom = 2 * s21 * jnp.sqrt(z01.real * z02.real)

    # Sub-expressions shared by several ABCD entries.
    s12s21 = s12 * s21
    port1 = z01.conj() + s11 * z01
    port2 = z02.conj() + s22 * z02

    A = (port1 * (1 - s22) + s12s21 * z01) / denom
    B = (port1 * port2 - s12s21 * z01 * z02) / denom
    C = ((1 - s11) * (1 - s22) - s12s21) / denom
    D = ((1 - s11) * port2 + s12s21 * z02) / denom

    # Assemble explicitly as [[A, B], [C, D]] per frequency,
    # giving shape (nfreqs, 2, 2).
    return jnp.stack(
        [jnp.stack([A, B], axis=-1), jnp.stack([C, D], axis=-1)],
        axis=-2,
    )
[docs]
def renormalize_s(
    s: jnp.ndarray,
    z_old: NumberLike,
    z_new: NumberLike,
) -> jnp.ndarray:
    """
    Renormalize S-parameters from z_old to z_new impedances.

    Parameters
    ----------
    s : jnp.ndarray, shape (nfreqs, nports, nports)
        S-parameter matrix referenced to ``z_old``.
    z_old : scalar, (nports,), (nfreqs,), or (nfreqs, nports)
        Current reference impedance(s).
    z_new : scalar, (nports,), (nfreqs,), or (nfreqs, nports)
        Target reference impedance(s).

    Returns
    -------
    jnp.ndarray of shape (nfreqs, nports, nports)
        Renormalized S-parameters.

    Raises
    ------
    IndexError
        If ``z_old`` or ``z_new`` has a shape ``fix_z0_shape`` rejects.

    Notes
    -----
    Waves are taken as a = (V + z I) / (2 sqrt(z)) and b = (V - z I) / (2 sqrt(z)).
    Changing the reference from z_a to z_b gives, per port,
    a' = p a + q b and b' = q a + p b, where
    p = (z_a + z_b) / (2 sqrt(z_a z_b)) and q = (z_a - z_b) / (2 sqrt(z_a z_b)).
    Substituting b = S a yields S' = (Q + P S) (P + Q S)^-1, with P = diag(p)
    and Q = diag(q). For equal, positive real impedances, p = 1 and q = 0,
    so S' = S.
    """
    nfreqs, nports, _ = s.shape
    z_a = fix_z0_shape(z_old, nfreqs, nports)  # (nfreqs, nports)
    z_b = fix_z0_shape(z_new, nfreqs, nports)

    # Diagonal entries of P and Q for every frequency and port. The previous
    # version used inverses of diag(z_a +/- z_b) here. That returned the
    # reciprocal of the correct result and divided by zero whenever a port
    # kept its impedance.
    scale = 0.5 / jnp.sqrt(z_a * z_b)
    p = scale * (z_a + z_b)
    q = scale * (z_a - z_b)

    # Batched diag(v) for v of shape (nfreqs, nports).
    eye = jnp.eye(nports)
    P = p[:, :, None] * eye
    Q = q[:, :, None] * eye

    # diag(v) @ S is row scaling: v[:, :, None] * S.
    numer = Q + p[:, :, None] * s
    denom = P + q[:, :, None] * s

    # S' = numer @ inv(denom), computed via a linear solve for stability:
    # S'^T = solve(denom^T, numer^T).
    def _mT(x):
        return jnp.swapaxes(x, -1, -2)

    return _mT(jnp.linalg.solve(_mT(denom), _mT(numer)))
[docs]
def fix_z0_shape(z0: NumberLike, nfreqs: int, nports: int) -> jnp.ndarray:
    """
    Expand a reference-impedance specification to shape (nfreqs, nports).

    Parameters
    ----------
    z0 : scalar, (nports,), (nfreqs,), or (nfreqs, nports)
        Reference impedance(s). A length-``nports`` vector is matched before
        a length-``nfreqs`` one, so when ``nfreqs == nports`` a 1-D ``z0`` is
        treated as per-port.
    nfreqs : int
        Number of frequency points.
    nports : int
        Number of ports.

    Returns
    -------
    jnp.ndarray of shape (nfreqs, nports)

    Raises
    ------
    IndexError
        If ``z0`` has any other shape.
    """
    # Taken from scikit-rf. See the copyright notice in pmrf._frequency.py
    # asarray (instead of ``.copy()``) makes lists and NumPy inputs come back
    # as JAX arrays too. JAX arrays are immutable, so no defensive copy is needed.
    z0 = jnp.asarray(z0)
    if z0.shape == (nfreqs, nports):
        return z0
    if z0.ndim == 0:
        return jnp.full((nfreqs, nports), z0)
    # Only 1-D inputs may be broadcast. A 2-D array of the wrong shape used to
    # be stacked silently into a 3-D array.
    if z0.ndim == 1:
        if z0.shape[0] == nports:
            # Same per-port impedances at every frequency.
            return jnp.broadcast_to(z0, (nfreqs, nports))
        if z0.shape[0] == nfreqs:
            # Same impedance on every port at a given frequency.
            return jnp.broadcast_to(z0[:, None], (nfreqs, nports))
    raise IndexError('z0 is not an acceptable shape')