Source code for pmrf.rf_functions.connections.cascade

from typing import Sequence

import jax
import jax.numpy as jnp
import equinox as eqx

from pmrf.constants import NumberLike
from pmrf.rf_functions.conversions import fix_z0_shape
from pmrf.math_functions import nudge_diag

@eqx.filter_jit
def cascade_s(
    Smats: NumberLike,
    z0s: NumberLike,
    method: str = 'redheffer',
    eps: float = 1e-12,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Cascade multiple S-parameter networks together.

    This is the most efficient function to use for cascading S-parameter
    networks together. Currently only networks with an even number of ports
    (2n) are supported: ports ``[0, n)`` of each network form its left side
    and ports ``[n, 2n)`` its right side, and each network's right side is
    connected to the next network's left side using Redheffer's star product.
    Currently all networks must also share the same S-parameter definition
    and characteristic impedance.

    Parameters
    ----------
    Smats : NumberLike
        S-parameter matrices of the component networks, with shape
        ``(N, F, m, m)``: ``N`` networks to be cascaded in order, at ``F``
        frequencies, each with ``m`` ports (``m`` must be even).
    z0s : NumberLike
        Characteristic impedances of the component networks. Passed through
        ``fix_z0_shape`` to obtain one ``(F, m)`` array per network
        (presumably shape ``(N, F, m)`` overall — see ``fix_z0_shape``).
    method : str, optional
        Cascading algorithm. Only ``'redheffer'`` is supported.
    eps : float, optional
        Forwarded to ``nudge_diag`` for the loop matrices that are solved at
        each connection.

    Returns
    -------
    S_cas : jnp.ndarray
        S-parameter matrix of the cascaded system, shape ``(F, m, m)``.
    z0_cas : jnp.ndarray
        Port impedances of the cascaded system: the left-side impedances of
        the first network followed by the right-side impedances of the last.

    Raises
    ------
    ValueError
        If ``method`` is not ``'redheffer'``, if ``Smats`` does not have shape
        ``(N, F, m, m)`` with ``N >= 1`` and an even, square port dimension.

    Notes
    -----
    If the networks' characteristic impedances are not all equal to those of
    the first network, an error is raised at run time via ``eqx.error_if``.
    """
    if method != 'redheffer':
        raise ValueError("Currently only method = 'redheffer' is supported for cascade_s")

    Smats = jnp.asarray(Smats)
    z0s = jnp.asarray(z0s)

    # Shapes are static under jit, so these checks run once at trace time
    # and give clear errors instead of failing deep inside the star product.
    if Smats.ndim != 4:
        raise ValueError(
            f"Smats must have shape (N, F, m, m); got shape {Smats.shape}"
        )
    nnetworks, nfreq, nrows, nports = Smats.shape
    if nnetworks == 0:
        raise ValueError("cascade_s requires at least one network")
    if nrows != nports:
        raise ValueError(
            f"S-parameter matrices must be square; got {nrows} x {nports}"
        )
    if nports % 2 != 0:
        raise ValueError(
            f"Currently only networks with an even number of ports are supported; got {nports}"
        )

    z0s = fix_z0_shape(z0s, nfreq, nports, nnetworks)
    # Every network's impedances must match the first network's; this is a
    # value-dependent check, so it has to be enforced at run time.
    z0s = eqx.error_if(
        z0s,
        jnp.logical_not(jnp.all(z0s == z0s[0])),
        "Currently, all characteristic impedances must be equal in cascade_s",
    )

    def scan_fn(carry, x):
        S_acc, z0_acc = carry
        S_i, z0_i = x
        S_next, z0_next = _cascade_two_s_redheffer(S_acc, z0_acc, S_i, z0_i, eps=eps)
        return (S_next, z0_next), None

    # jax.lax.scan folds the networks in order without unrolling the loop,
    # so compile time does not grow with the number of networks.
    (S_cas, z0_cas), _ = jax.lax.scan(
        scan_fn,
        init=(Smats[0], z0s[0]),
        xs=(Smats[1:], z0s[1:]),
    )
    return S_cas, z0_cas
def _cascade_two_s_redheffer(
    Smat_A: jnp.ndarray,
    z0_A: jnp.ndarray,
    Smat_B: jnp.ndarray,
    z0_B: jnp.ndarray,
    eps: float = 1e-12,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Connect two 2n-port S-matrices with Redheffer's star product.

    Ports ``[0, n)`` of each matrix are its left side and ports ``[n, 2n)``
    its right side. The right side of network A is connected to the left
    side of network B.

    Parameters
    ----------
    Smat_A, Smat_B : jnp.ndarray
        S-matrices of shape ``(..., 2n, 2n)`` (in this module ``(F, 2n, 2n)``).
    z0_A, z0_B : jnp.ndarray
        Port impedances of shape ``(..., 2n)``.
    eps : float, optional
        Forwarded to ``nudge_diag``. Presumably this regularises the diagonal
        of the loop matrices ``I - B11 A22`` and ``I - A22 B11`` so that they
        stay solvable (see ``nudge_diag``).

    Returns
    -------
    S_cas : jnp.ndarray
        Star product, shape ``(..., 2n, 2n)``.
    z0_cas : jnp.ndarray
        A's left-side impedances followed by B's right-side impedances.

    Raises
    ------
    ValueError
        If the port count is odd.
    """
    nports = Smat_A.shape[-1]
    if nports % 2 != 0:
        raise ValueError(
            f"Redheffer cascading requires an even number of ports; got {nports}"
        )
    n = nports // 2

    z0_cas = jnp.concatenate((z0_A[..., :n], z0_B[..., n:]), axis=-1)

    # Partition both matrices into n x n blocks
    A11 = Smat_A[..., :n, :n]
    A12 = Smat_A[..., :n, n:]
    A21 = Smat_A[..., n:, :n]
    A22 = Smat_A[..., n:, n:]
    B11 = Smat_B[..., :n, :n]
    B12 = Smat_B[..., :n, n:]
    B21 = Smat_B[..., n:, :n]
    B22 = Smat_B[..., n:, n:]

    eye = jnp.eye(n)
    # Loop matrices for the multiple reflections between A's right side and
    # B's left side (subtraction broadcasts eye over the batch dimensions).
    loop_BA = nudge_diag(eye - B11 @ A22, eps=eps)
    loop_AB = nudge_diag(eye - A22 @ B11, eps=eps)

    # Solve against the needed right-hand sides directly instead of forming
    # explicit inverses: loop_BA^{-1} [B11 A21 | B12], loop_AB^{-1} [A21 | A22 B12].
    sol_BA = jnp.linalg.solve(loop_BA, jnp.concatenate((B11 @ A21, B12), axis=-1))
    sol_AB = jnp.linalg.solve(loop_AB, jnp.concatenate((A21, A22 @ B12), axis=-1))

    S11 = A11 + A12 @ sol_BA[..., :n]
    S12 = A12 @ sol_BA[..., n:]
    S21 = B21 @ sol_AB[..., :n]
    S22 = B22 + B21 @ sol_AB[..., n:]

    top = jnp.concatenate((S11, S12), axis=-1)
    bottom = jnp.concatenate((S21, S22), axis=-1)
    S_cas = jnp.concatenate((top, bottom), axis=-2)
    return S_cas, z0_cas
def cascade_a(
    Amats: Sequence[jnp.ndarray],
) -> jnp.ndarray:
    """
    Cascade multiple ABCD-parameter networks together.

    Cascading ABCD matrices is plain matrix multiplication. The result is
    ``Amats[0] @ Amats[1] @ ... @ Amats[-1]`` (in that order), computed
    independently at every frequency.

    Parameters
    ----------
    Amats : Sequence[jnp.ndarray]
        ABCD matrices of the component networks in cascade order. This is
        either a sequence of equally-shaped arrays or a single stacked array
        of shape ``(N, Nf, 2, 2)``.

    Returns
    -------
    jnp.ndarray
        ABCD matrix of the cascaded system. It has the shape of a single
        element of ``Amats`` (e.g. ``(Nf, 2, 2)``).

    Raises
    ------
    ValueError
        If ``Amats`` contains no matrices.
    """
    # Stack into one array of shape (N, Nf, 2, 2)
    Amats_array = jnp.asarray(Amats)

    if Amats_array.ndim == 0 or Amats_array.shape[0] == 0:
        raise ValueError("cascade_a requires at least one ABCD matrix")

    # Fast path for a single network: nothing to multiply
    if Amats_array.shape[0] == 1:
        return Amats_array[0]

    def scan_fn(carry, x):
        # '@' broadcasts across the frequency dimension (Nf)
        return carry @ x, None

    # Scan multiplies the matrices left-to-right along the 'N' dimension
    final_a, _ = jax.lax.scan(scan_fn, init=Amats_array[0], xs=Amats_array[1:])
    return final_a
def cascade_t(
    Tmats: Sequence[jnp.ndarray],
) -> jnp.ndarray:
    """
    Cascade multiple wave-transfer (T) matrices.

    The result is ``Tmats[0] @ Tmats[1] @ ... @ Tmats[-1]`` (in that order),
    computed independently at every frequency.

    Parameters
    ----------
    Tmats : Sequence[jnp.ndarray]
        T matrices in cascade order. This is either a sequence of
        equally-shaped arrays or a single stacked array. Each element
        typically has shape ``(Nf, 2, 2)``.

    Returns
    -------
    jnp.ndarray
        Cascaded T matrix with the shape of a single element of ``Tmats``
        (e.g. ``(Nf, 2, 2)``).

    Raises
    ------
    ValueError
        If ``Tmats`` contains no matrices.
    """
    Tmats_array = jnp.asarray(Tmats)

    if Tmats_array.ndim == 0 or Tmats_array.shape[0] == 0:
        raise ValueError("cascade_t requires at least one T matrix")

    # Fast path for a single network: nothing to multiply
    if Tmats_array.shape[0] == 1:
        return Tmats_array[0]

    def scan_fn(carry, x):
        # '@' broadcasts across the frequency dimension
        return carry @ x, None

    final_t, _ = jax.lax.scan(scan_fn, init=Tmats_array[0], xs=Tmats_array[1:])
    return final_t