Source code for pmrf.functions.conversions

import jax.numpy as jnp
import jax
from jax import lax
from jax.scipy.special import gammaln
from jax._src.numpy.ufuncs import _constant_like

from pmrf.constants import NumberLike
from pmrf.functions import rsolve, nudge_eig

# Small real offset added to reference impedances whose real part is exactly
# zero (see y2s, s2z and z2s), keeping the 1/sqrt(Re(z0)) scaling finite.
ZERO = 1.0e-4

def s2s(s: NumberLike, z0: NumberLike, s_def_new, s_def_old):
    """Convert S-parameters from one wave definition to another.

    The port voltages and currents implied by ``s`` under ``s_def_old`` are
    reconstructed (one unit-incident-wave excitation per column), and then
    re-expressed as incident/reflected waves under ``s_def_new``.
    Adapted from scikit-rf's ``s2s``.

    Parameters
    ----------
    s : array of shape (nfreqs, nports, nports)
        S-parameter matrices defined with ``s_def_old``.
    z0 : NumberLike
        Port reference impedances, in any shape accepted by ``fix_z0_shape``.
    s_def_new, s_def_old : {'power', 'traveling'}
        Target and source S-parameter definitions.

    Returns
    -------
    array of shape (nfreqs, nports, nports), same dtype as ``s``
        The converted S-parameters. ``s`` is returned unchanged when both
        definitions are equal. It is also returned unchanged (via
        ``lax.cond``) when every reference impedance is real, because both
        wave definitions agree for real, positive impedances.

    Raises
    ------
    ValueError
        If the definitions differ and either one is unknown. Both
        ``lax.cond`` branches are traced, so this is raised at trace time
        even when ``z0`` turns out to be real.
    """
    if s_def_new == s_def_old:
        return s

    nfreqs, nports, nports = s.shape
    z0 = fix_z0_shape(z0, nfreqs, nports)
    all_real = jnp.isreal(z0).all()
    diag_idx = jnp.arange(s.shape[1])
    Id = jnp.eye(s.shape[1], dtype=complex)

    def _diag(values):
        # Stack of diagonal matrices (like ``s``) with ``values`` on the diagonals.
        return jnp.zeros_like(s).at[:, diag_idx, diag_idx].set(values)

    def real_branch():
        return s

    def imag_branch():
        # Port voltages and currents for unit incident waves, old definition.
        if s_def_old == 'power':
            F = _diag(1.0 / jnp.sqrt(z0.real))
            G = _diag(z0)
            v = F @ (G.conjugate() + G @ s)
            i = F @ (Id - s)
        elif s_def_old == 'traveling':
            F = _diag(jnp.sqrt(z0))
            G = _diag(1 / jnp.sqrt(z0))
            v = F @ (Id + s)
            i = G @ (Id - s)
        else:
            raise ValueError(f'Unknown s_def: {s_def_old}')

        # Incident (a) and reflected (b) waves under the new definition.
        if s_def_new == 'power':
            F = _diag(1 / (2 * jnp.sqrt(z0.real)))
            G = _diag(z0)
            a = F @ (v + G @ i)
            b = F @ (v - G.conjugate() @ i)
        elif s_def_new == 'traveling':
            F = _diag(1 / jnp.sqrt(z0))
            G = _diag(z0)
            a = F @ (v + G @ i)
            b = F @ (v - G @ i)
        else:
            raise ValueError(f'Unknown s_def: {s_def_new}')

        # Vectorized form of s_new[:, m, n] = b[:, m, n] / a[:, n, n].
        # Cast back to s.dtype so both lax.cond branches return the same type.
        a_diag = jnp.diagonal(a, axis1=1, axis2=2)  # (nfreqs, nports)
        return (b / a_diag[:, None, :]).astype(s.dtype)

    return jax.lax.cond(all_real, real_branch, imag_branch)
def a2s(a: jnp.ndarray, z0: NumberLike = 50) -> jnp.ndarray:
    """Convert 2-port ABCD (chain) parameters to S-parameters.

    Taken from scikit-rf. See the copyright notice in pmrf._frequency.py

    Parameters
    ----------
    a : array of shape (nfreqs, 2, 2)
        ABCD matrices, ``[[A, B], [C, D]]`` per frequency.
    z0 : NumberLike
        Port reference impedances, in any shape accepted by ``fix_z0_shape``.

    Returns
    -------
    array of shape (nfreqs, 2, 2)
        S-parameter matrices ``[[s11, s12], [s21, s22]]``.

    Raises
    ------
    IndexError
        If the network does not have exactly two ports.
    """
    nfreqs, _, nports = a.shape
    if nports != 2:
        raise IndexError('abcd parameters are defined for 2-ports networks only')

    z0 = fix_z0_shape(z0, nfreqs, nports)
    z01, z02 = z0[:, 0], z0[:, 1]
    A, B, C, D = (a[:, r, c] for r, c in ((0, 0), (0, 1), (1, 0), (1, 1)))

    denom = A*z02 + B + C*z01*z02 + D*z01
    root = jnp.sqrt(z01.real * z02.real)

    s11 = (A*z02 + B - C*z01.conj()*z02 - D*z01.conj()) / denom
    s12 = (2*(A*D - B*C)*root) / denom
    s21 = (2*root) / denom
    s22 = (-A*z02.conj() + B - C*z01*z02.conj() + D*z01) / denom

    # Assemble (nfreqs, 2, 2) with rows [s11, s12] and [s21, s22].
    top = jnp.stack([s11, s12], axis=-1)
    bottom = jnp.stack([s21, s22], axis=-1)
    return jnp.stack([top, bottom], axis=-2)
def s2a(s: jnp.ndarray, z0: NumberLike = 50) -> jnp.ndarray:
    """Convert 2-port S-parameters to ABCD (chain) parameters.

    Taken from scikit-rf. See the copyright notice in pmrf._frequency.py

    Parameters
    ----------
    s : array of shape (nfreqs, 2, 2)
        S-parameter matrices.
    z0 : NumberLike
        Port reference impedances, in any shape accepted by ``fix_z0_shape``.

    Returns
    -------
    array of shape (nfreqs, 2, 2)
        ABCD matrices ``[[A, B], [C, D]]``. Entries are non-finite wherever
        ``s21`` is zero, since every term is divided by it.

    Raises
    ------
    IndexError
        If the network does not have exactly two ports.
    """
    nfreqs, _, nports = s.shape
    if nports != 2:
        raise IndexError('abcd parameters are defined for 2-ports networks only')

    z0 = fix_z0_shape(z0, nfreqs, nports)
    z01, z02 = z0[:, 0], z0[:, 1]
    s11, s12, s21, s22 = s[:, 0, 0], s[:, 0, 1], s[:, 1, 0], s[:, 1, 1]

    denom = (2*s21*jnp.sqrt(z01.real * z02.real))
    port1 = z01.conj() + s11*z01
    port2 = z02.conj() + s22*z02
    cross = s12*s21

    A = (port1*(1 - s22) + cross*z01) / denom
    B = (port1*port2 - cross*z01*z02) / denom
    C = ((1 - s11)*(1 - s22) - cross) / denom
    D = ((1 - s11)*port2 + cross*z02) / denom

    # Assemble (nfreqs, 2, 2) with rows [A, B] and [C, D].
    return jnp.stack(
        [jnp.stack([A, B], axis=-1), jnp.stack([C, D], axis=-1)],
        axis=-2,
    )
def y2s(y: jnp.ndarray, z0: NumberLike = 50, s_def = 'power') -> jnp.ndarray:
    """Convert admittance (Y) parameters to scattering (S) parameters.

    Adapted from scikit-rf's ``y2s``.

    Parameters
    ----------
    y : array of shape (nfreqs, nports, nports)
        Admittance matrices; converted to a complex array.
    z0 : NumberLike
        Port reference impedances, in any shape accepted by ``fix_z0_shape``.
        Entries whose real part is exactly zero are offset by ``ZERO``.
    s_def : {'power', 'traveling'}
        S-parameter wave definition of the result.

    Returns
    -------
    complex array of shape (nfreqs, nports, nports)

    Raises
    ------
    ValueError
        If ``s_def`` is unknown.
    """
    nfreqs, nports, nports = y.shape
    z0 = fix_z0_shape(z0, nfreqs, nports)
    z0 = z0.astype(dtype=complex)
    # Avoid a zero real part: the power-wave scaling divides by sqrt(Re(z0)).
    z0 = jnp.where(z0.real == 0, z0 + ZERO, z0)
    y = jnp.array(y, dtype=complex)

    # One (nports, nports) identity per frequency: (nfreqs, nports, nports).
    Id = jnp.eye(nports, dtype=complex)[None, :, :]
    Id = jnp.broadcast_to(Id, (nfreqs, nports, nports))
    diag_idx = jnp.arange(y.shape[1])

    if s_def == 'power':
        # Power waves (Kurokawa). F and G are per-frequency diagonal matrices.
        F = jnp.zeros_like(y).at[:, diag_idx, diag_idx].set(1.0 / (2 * jnp.sqrt(z0.real)))
        G = jnp.zeros_like(y).at[:, diag_idx, diag_idx].set(z0)
        s = rsolve(F @ (Id + G @ y), F @ (Id - jnp.conjugate(G) @ y))
    elif s_def == 'traveling':
        # Traveling-waves definition. Cf. Wikipedia "Impedance parameters" page.
        # JAX arrays are immutable, so the diagonal is filled with .at[].set
        # rather than NumPy-style in-place assignment through an einsum view.
        sqrtz0 = jnp.zeros_like(y).at[:, diag_idx, diag_idx].set(jnp.sqrt(z0))
        s = rsolve(Id + sqrtz0 @ y @ sqrtz0, Id - sqrtz0 @ y @ sqrtz0)
    else:
        raise ValueError(f'Unknown s_def: {s_def}')
    return s
def s2z(s: jnp.ndarray, z0: NumberLike = 50, s_def = 'power') -> jnp.ndarray:
    """Convert scattering (S) parameters to impedance (Z) parameters.

    Parameters
    ----------
    s : array of shape (nfreqs, nports, nports)
        S-parameter matrices; converted to a complex array.
    z0 : NumberLike
        Port reference impedances, in any shape accepted by ``fix_z0_shape``.
        Entries whose real part is exactly zero are offset by ``ZERO``.
    s_def : {'power', 'traveling'}
        Wave definition that ``s`` was computed with.

    Returns
    -------
    complex array of shape (nfreqs, nports, nports)
        The matrix to be inverted is passed through ``nudge_eig`` before
        solving.

    Raises
    ------
    ValueError
        If ``s_def`` is unknown.
    """
    nfreqs, _, nports = s.shape
    z0 = fix_z0_shape(z0, nfreqs, nports).astype(dtype=complex)
    # Keep Re(z0) non-zero; the power-wave scaling divides by its square root.
    z0 = jnp.where(z0.real == 0, z0 + ZERO, z0)
    s = jnp.array(s, dtype=complex)

    eye = jnp.broadcast_to(
        jnp.eye(nports, dtype=complex)[None, :, :], (nfreqs, nports, nports)
    )
    idx = jnp.arange(s.shape[1])

    def on_diagonal(values):
        # Per-frequency diagonal matrices carrying `values` on their diagonals.
        return jnp.zeros_like(s).at[:, idx, idx].set(values)

    if s_def == 'power':
        # Power waves: Eq. (19) of Kurokawa.
        F = on_diagonal(1.0 / (2 * jnp.sqrt(z0.real)))
        G = on_diagonal(z0)
        return jnp.linalg.solve(nudge_eig((eye - s) @ F), (s @ G + jnp.conjugate(G)) @ F)

    if s_def == 'traveling':
        # Traveling waves; cf. Wikipedia "Impedance parameters".
        sqrtz0 = on_diagonal(jnp.sqrt(z0))
        return sqrtz0 @ jnp.linalg.solve(nudge_eig(eye - s), (eye + s) @ sqrtz0)

    raise ValueError(f'Unknown s_def: {s_def}')
def z2s(z: NumberLike, z0:NumberLike = 50, s_def = 'power') -> jnp.ndarray:
    """Convert impedance (Z) parameters to scattering (S) parameters.

    Parameters
    ----------
    z : array of shape (nfreqs, nports, nports)
        Impedance matrices; converted to a complex array.
    z0 : NumberLike
        Port reference impedances, in any shape accepted by ``fix_z0_shape``.
        Entries whose real part is exactly zero are offset by ``ZERO``.
    s_def : {'power', 'traveling'}
        S-parameter wave definition of the result.

    Returns
    -------
    complex array of shape (nfreqs, nports, nports)

    Raises
    ------
    ValueError
        If ``s_def`` is unknown.
    """
    nfreqs, _, nports = z.shape
    z0 = fix_z0_shape(z0, nfreqs, nports).astype(dtype=complex)
    # Keep Re(z0) non-zero; the power-wave scaling divides by its square root.
    z0 = jnp.where(z0.real == 0, z0 + ZERO, z0)
    z = jnp.array(z, dtype=complex)
    idx = jnp.arange(z.shape[1])

    def on_diagonal(values):
        # Per-frequency diagonal matrices carrying `values` on their diagonals.
        return jnp.zeros_like(z).at[:, idx, idx].set(values)

    if s_def == 'power':
        # Power waves: Eq. (18) of Kurokawa.
        F = on_diagonal(1.0 / (2 * jnp.sqrt(z0.real)))
        G = on_diagonal(z0)
        return rsolve(F @ (z + G), F @ (z - jnp.conjugate(G)))

    if s_def == 'traveling':
        # Traveling waves; cf. Wikipedia "Impedance parameters".
        eye = on_diagonal(1.0)
        sqrty0 = on_diagonal(jnp.sqrt(1.0/z0))
        z_norm = sqrty0 @ z @ sqrty0
        return rsolve(z_norm + eye, z_norm - eye)

    raise ValueError(f'Unknown s_def: {s_def}')
def renormalize_s(s: jnp.ndarray, z_old: NumberLike, z_new: NumberLike, s_def_old='power', s_def_new='power') -> jnp.ndarray:
    """Re-reference S-parameters to new port impedances and/or a new wave definition.

    The conversion goes through Z-parameters, which depend on neither the
    reference impedance nor the wave definition. ``s`` is converted to Z using
    ``z_old`` / ``s_def_old``, and then back to S using ``z_new`` / ``s_def_new``.
    """
    z = s2z(s, z0=z_old, s_def=s_def_old)
    return z2s(z, z0=z_new, s_def=s_def_new)
# def renormalize_s_direct( # s: jnp.ndarray, # z_old: NumberLike, # z_new: NumberLike, # ) -> jnp.ndarray: # """ # Renormalize S-parameters from z_old to z_new impedances. # Parameters # ---------- # s : jnp.ndarray, shape (nfreqs, nports, nports) # S-parameter matrix. # z_old : scalar, (nports,), or (nfreqs, nports) # z_new : scalar, (nports,), or (nfreqs, nports) # Returns # ------- # jnp.ndarray of shape (nfreqs, nports, nports) # Renormalized S-parameters. # """ # nfreqs, nports, _ = s.shape # Z_A = fix_z0_shape(z_old, nfreqs, nports) # shape: (nfreqs, nports) # Z_B = fix_z0_shape(z_new, nfreqs, nports) # # Check if z_old and z_new are frequency-independent # freq_indep = ( # jnp.ndim(z_old) == 0 or jnp.shape(z_old) == (nports,) # ) and ( # jnp.ndim(z_new) == 0 or jnp.shape(z_new) == (nports,) # ) # if freq_indep: # z_a = Z_A[0] # shape: (nports,) # z_b = Z_B[0] # z_prod_sqrt_inv = 1.0 / jnp.sqrt(z_a * z_b) # D = 0.5 * jnp.diag(z_prod_sqrt_inv) # M_s = D @ jnp.linalg.inv(jnp.diag(z_a + z_b)) # M_c = D @ jnp.linalg.inv(jnp.diag(z_a - z_b)) # def renorm_fixed(s_f): # return (M_c + M_s @ s_f) @ jnp.linalg.inv(M_s + M_c @ s_f) # return jax.vmap(renorm_fixed)(s) # else: # def renorm_per_freq(s_f, z_a, z_b): # z_prod_sqrt_inv = 1.0 / jnp.sqrt(z_a * z_b) # D = 0.5 * jnp.diag(z_prod_sqrt_inv) # M_s = D @ jnp.linalg.inv(jnp.diag(z_a + z_b)) # M_c = D @ jnp.linalg.inv(jnp.diag(z_a - z_b)) # return (M_c + M_s @ s_f) @ jnp.linalg.inv(M_s + M_c @ s_f) # return jax.vmap(renorm_per_freq, in_axes=(0, 0, 0))(s, Z_A, Z_B)
def fix_z0_shape(z0: NumberLike, nfreqs: int, nports: int) -> jnp.ndarray:
    """Broadcast reference impedances ``z0`` to a ``(nfreqs, nports)`` JAX array.

    Adapted from scikit-rf. See the copyright notice in pmrf._frequency.py

    Accepted inputs, checked in this order:

    - array-like already shaped ``(nfreqs, nports)``: returned as a JAX array
      (lists included, so callers can rely on ``.astype`` / 2-D indexing);
    - scalar: the same impedance at every port and frequency;
    - 1-D of length ``nports``: per-port impedances, constant over frequency.
      This case wins when ``nports == nfreqs``;
    - 1-D of length ``nfreqs``: frequency-dependent impedance shared by all ports.

    Raises
    ------
    IndexError
        If ``z0`` matches none of the shapes above.
    """
    shape = jnp.shape(z0)
    target = (nfreqs, nports)

    if shape == target:
        # jnp.asarray (not .copy()) so list input also becomes an array;
        # JAX arrays are immutable, so no defensive copy is needed.
        return jnp.asarray(z0)

    if len(shape) == 0:
        # Wrap in a nested list like the original construction, to keep the
        # same dtype inference, then broadcast instead of building an
        # nfreqs x nports Python list.
        return jnp.broadcast_to(jnp.array([[z0]]), target)

    if len(shape) == 1 and shape[0] == nports:
        return jnp.broadcast_to(jnp.asarray(z0)[None, :], target)

    if len(shape) == 1 and shape[0] == nfreqs:
        return jnp.broadcast_to(jnp.asarray(z0)[:, None], target)

    raise IndexError('z0 is not an acceptable shape')