Source code for pmrf.rf_functions.connections.terminate

import jax
import jax.numpy as jnp
from pmrf.rf_functions.conversions import fix_z0_shape

def _renormalize_s(
    S: jnp.ndarray,
    z_old: jnp.ndarray,
    z_new: jnp.ndarray,
    s_def: str = "power",
) -> jnp.ndarray:
    """
    Renormalizes a batch of S matrices from one set of port reference
    impedances to another.

    Parameters
    ----------
    S : jnp.ndarray
        S matrices referenced to ``z_old``. Shape: (nfreqs, N, N)
    z_old : jnp.ndarray
        Current per-port reference impedances. Shape: (nfreqs, N)
    z_new : jnp.ndarray
        Target per-port reference impedances. Shape: (nfreqs, N)
    s_def : str
        "power" (Kurokawa power waves), "traveling"
        (a = (V + Z I) / (2 sqrt(Z))), or anything else, which is treated as
        pseudo-waves (a = sqrt(Re Z) / (2 |Z|) * (V + Z I)).

    Returns
    -------
    jnp.ndarray
        The renormalized S matrices, same shape as ``S``.

    Notes
    -----
    With per-port reflection coefficients ``g`` and scale factors ``c``,
    the result is ``diag(l) (S - diag(h)) (I - diag(g) S)^-1 diag(1/c)``.
    For power waves ``h = conj(g)`` and ``l = conj(c)``. For the other
    definitions ``h = g`` and ``l = c``. The scale factors only cancel on
    the diagonal, so they are required for correct transmission terms
    whenever the ports have different impedance ratios.
    Reference impedances are assumed to have a positive real part.
    """
    N = S.shape[-1]
    eye = jnp.eye(N, dtype=S.dtype)

    if s_def == "power":
        denom = z_new + jnp.conj(z_old)
        g = (z_new - z_old) / denom
        c = denom / (2 * jnp.sqrt(jnp.real(z_old) * jnp.real(z_new)))
        h = jnp.conj(g)
        left = jnp.conj(c)
    else:
        g = (z_new - z_old) / (z_new + z_old)
        if s_def == "traveling":
            # Wave normalisation k = 1 / (2 sqrt(Z)), so k_new / k_old = sqrt(z_old) / sqrt(z_new).
            k_ratio = jnp.sqrt(z_old) / jnp.sqrt(z_new)
        else:
            # Pseudo-waves: k = sqrt(Re Z) / (2 |Z|).
            k_ratio = (
                jnp.sqrt(jnp.real(z_new) / jnp.real(z_old))
                * jnp.abs(z_old)
                / jnp.abs(z_new)
            )
        c = k_ratio / (1 - g)
        h = g
        left = c

    # (I - diag(g) S)^-1 diag(1/c), computed with a solve instead of an explicit inverse.
    tail = jnp.linalg.solve(eye - g[..., :, None] * S, (1 / c)[..., :, None] * eye)
    head = S - h[..., :, None] * eye
    return left[..., :, None] * (head @ tail)


def terminate_s_in_s(
    Smat_from: jnp.ndarray,
    z0_from: jnp.ndarray,
    Smat_into: jnp.ndarray,
    z0_into: jnp.ndarray,
    s_def="power",
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Terminates one S-parameter matrix in another S-parameter matrix.

    This supports terminating a 2N-port into an N-port. Port ``N + k`` of
    ``Smat_from`` is connected to port ``k`` of ``Smat_into``. The
    terminating network is renormalized automatically if the connecting
    port impedances do not match. The resulting N-port S matrix, seen from
    the first N ports of ``Smat_from``, is returned.

    Parameters
    ----------
    Smat_from : jnp.ndarray
        The main S matrix to terminate from. Shape: (nfreqs, 2N, 2N)
    z0_from : jnp.ndarray
        The port reference impedance(s) of Smat_from. This is passed through
        ``fix_z0_shape`` for (nfreqs, 2N).
    Smat_into : jnp.ndarray
        The S matrix to terminate into. Shape: (nfreqs, N, N)
    z0_into : jnp.ndarray
        The port reference impedance(s) of Smat_into. This is passed through
        ``fix_z0_shape`` for (nfreqs, N).
    s_def : string
        The S-parameter definition: "power" (default), "traveling", or
        anything else, which is treated as pseudo-waves.
        For power waves, a direct connection exchanges waves (a_L = b_2)
        only if the load is referenced to the conjugate of the connecting
        port's impedance. The load is therefore renormalized to
        ``conj(z0_out)``. This is a no-op for real impedances.

    Returns
    -------
    tuple[jnp.ndarray, jnp.ndarray]
        The resulting S-parameter matrix, shape (nfreqs, N, N), and the
        reference impedances of its ports, ``z0_from[:, :N]``.

    Raises
    ------
    ValueError
        If ``Smat_from`` does not have exactly twice as many ports as
        ``Smat_into``.
    """
    N = Smat_into.shape[1]
    if Smat_from.shape[1] != 2 * N:
        raise ValueError(
            f"terminate_s_in_s requires a 2N-port terminating into an N-port. "
            f"Got {Smat_from.shape[1]} and {N}."
        )
    nfreqs = Smat_from.shape[0]
    z0_from = fix_z0_shape(z0_from, nfreqs, 2 * N)
    z0_into = fix_z0_shape(z0_into, nfreqs, N)

    # Block partition: ports [0, N) are kept, ports [N, 2N) are terminated.
    S11 = Smat_from[:, :N, :N]
    S12 = Smat_from[:, :N, N:]
    S21 = Smat_from[:, N:, :N]
    S22 = Smat_from[:, N:, N:]

    # The impedances the load must be referenced to for a direct connection.
    z0_out = z0_from[:, N:]
    z_target = jnp.conj(z0_out) if s_def == "power" else z0_out

    # Promote S up front so both lax.cond branches return the same dtype.
    dtype = jnp.result_type(Smat_into, z0_into, z_target)
    S_L = Smat_into.astype(dtype)

    # allclose gives floating-point tolerance. lax.cond keeps the no-renorm fast path under jit.
    needs_renorm = jnp.logical_not(jnp.allclose(z0_into, z_target))
    S_L = jax.lax.cond(
        needs_renorm,
        lambda s: _renormalize_s(s, z0_into, z_target, s_def).astype(dtype),
        lambda s: s,
        S_L,
    )

    # S_term = S11 + S12 @ S_L @ (I - S22 @ S_L)^-1 @ S21, using a solve for the inverse.
    eye = jnp.eye(N, dtype=Smat_from.dtype)
    X = jnp.linalg.solve(eye - S22 @ S_L, S21)
    S_term = S11 + S12 @ S_L @ X
    return S_term, z0_from[:, :N]
def terminate_a_in_s(
    Amat: jnp.ndarray,
    Smat: jnp.ndarray,
    z0: jnp.ndarray,
    s_def="power",
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Terminates an ABCD matrix in an S-parameter matrix.

    Currently this only supports terminating a two-port in a one-port. The
    one-port load is attached to port 2 of the ABCD network. The one-port
    S matrix seen at port 1 is returned, referenced to the same ``z0``.

    Parameters
    ----------
    Amat : jnp.ndarray
        The ABCD matrix to terminate from. Shape: (nfreqs, 2, 2)
    Smat : jnp.ndarray
        The one-port S matrix to terminate into. Shape: (nfreqs, 1, 1)
    z0 : jnp.ndarray
        The reference impedance of the S-matrix being terminated. This is
        passed through ``fix_z0_shape`` for (nfreqs, 1).
    s_def : string
        The S-parameter definition. Defaults to "power". For power waves the
        reflection coefficient is (Z - conj(z0)) / (Z + z0). For any other
        definition it is (Z - z0) / (Z + z0). Both are identical for real z0.

    Returns
    -------
    tuple[jnp.ndarray, jnp.ndarray]
        The resulting S-parameter matrix of the terminated system, shape
        (nfreqs, 1, 1), and its reference impedance, shape (nfreqs,).

    Raises
    ------
    ValueError
        If ``Smat`` is not a one-port.
    """
    if Smat.shape[1] != 1:
        raise ValueError("terminate_a_in_s can only be called for one-port S matrices")
    z0 = fix_z0_shape(z0, Smat.shape[0], Smat.shape[1])

    # Use the 1-D column so every term below stays (nfreqs,). Mixing a
    # (nfreqs, 1) array with (nfreqs,) arrays would broadcast to (nfreqs, nfreqs).
    z = z0[:, 0]
    z_ref = jnp.conj(z) if s_def == "power" else z

    s11 = Smat[:, 0, 0]
    A, B, C, D = Amat[:, 0, 0], Amat[:, 0, 1], Amat[:, 1, 0], Amat[:, 1, 1]

    # Load impedance Z_L = n_L / d_L. It is kept as a fraction so an open
    # circuit (s11 = 1) stays finite.
    n_L = z_ref + s11 * z
    d_L = 1 - s11

    # Input voltage and current (up to a common factor) from the ABCD relation.
    v = A * n_L + B * d_L
    i = C * n_L + D * d_L

    s11_out = (v - z_ref * i) / (v + z * i)
    return s11_out.reshape(-1, 1, 1), z