# Source code for pmrf.models.components.lines.nonuniform

"""
Non-uniform transmission lines.
"""
from typing import Callable, Any, Dict

import jax
import jax.numpy as jnp
import equinox as eqx

from pmrf.frequency import Frequency
from pmrf.rf_functions.connections import cascade_s
from pmrf.parameters import Parameter, as_param
from pmrf.models.model import Model
from pmrf.models.components.lines.uniform import RLGCLine

class ProfiledLine(Model, transparent=True):
    r"""
    A non-uniform transmission line defined by an arbitrary profile.

    This model wraps any RLGC transmission line model and allows its parameters
    to vary as a function of length. For example, the line's characteristic
    impedance can be varied exponentially for impedance matching purposes.

    Any line parameter can be uniform across length or follow a spatial profile
    defined by a user-provided function (e.g., splines). Both the uniform
    parameters and the coefficients of the profile functions are registered as
    `Parameter` objects, making the profile parameters compatible with fitting
    and sampling.

    Profile functions are called as ``fn(t, **coefficients)``, where ``t`` is
    the normalized position along the line (the 'stepped' method samples it at
    the section mid-points ``(i + 0.5) / N``).

    Evaluation methods:

    - 'stepped': A discrete cascaded approximation (default).
    - 'riccati': A continuous ODE solver using the Matrix Riccati differential
      equations. The option defaults are accepted by the constructor, but
      :meth:`s` currently raises ``NotImplementedError`` for this method.

    Example
    --------
    .. code-block:: python

        import pmrf as prf
        from pmrf.models import PhysicalLine, ProfiledLine

        def linear_taper(t, start_val, end_val):
            return start_val + (end_val - start_val) * t

        tapered_line = ProfiledLine(
            PhysicalLine,
            linear_taper,
            length=0.1,
            zn={'start_val': 50.0, 'end_val': 100.0},
            epr=2.2,
            method='stepped',
            options={'N': 100},
        )

        freq = prf.Frequency(start=1, stop=10, npoints=101, unit='ghz')
        s_taper = tapered_line.s(freq)
    """

    # Config
    line_fn: Callable[[Any], RLGCLine] = eqx.field(static=True)
    floating: bool = False
    profile_fns: Dict[str, Callable] = eqx.field(static=True)
    method: str = eqx.field(static=True)
    options: dict = eqx.field(static=True)

    # Parameters and sub-models
    length: Parameter
    profile_params: Dict[str, Dict[str, Parameter]]
    uniform_params: Dict[str, Parameter]

    def __init__(
        self,
        line_fn: Callable,
        profile_fn: Callable | None = None,
        *,
        length: Any = 1,
        floating: bool = False,
        method: str = 'stepped',
        options: dict | None = None,
        name: str | None = None,
        z0: complex = 50.0,
        **line_params
    ):
        """
        Build a profiled line around ``line_fn``.

        Parameters
        ----------
        line_fn
            Factory for the underlying line; it is called with keyword
            arguments only (``length`` plus the line parameters).
        profile_fn
            Default profile function, used for every line parameter that is
            passed as a plain ``dict`` of coefficients.
        length
            Total line length; converted with ``as_param``.
        floating
            Forwarded to ``line_fn`` as ``floating`` for every section.
        method
            ``'stepped'`` or ``'riccati'`` (see the class docstring).
        options
            Method options. Missing keys are filled with defaults:
            ``N=50`` for 'stepped'; ``rtol=1e-5``, ``atol=1e-5``,
            ``max_steps=1000`` and ``dz_eval=0.0`` for 'riccati'.
            The dict passed in is copied, never modified.
        name, z0
            Forwarded to the ``Model`` constructor.
        **line_params
            Line parameters. A value may be

            * a ``(callable, dict)`` tuple: a per-parameter profile function
              with its coefficients,
            * a ``dict``: coefficients for ``profile_fn``,
            * anything else: a uniform value, converted with ``as_param``.

        Raises
        ------
        ValueError
            If a parameter is given as a ``dict`` but ``profile_fn`` is None.
        """
        super().__init__(name=name, z0=z0)

        # Defaults. Work on a private copy so that `setdefault` below never
        # writes into a dict owned by the caller.
        options = dict(options) if options else {}
        if method == 'stepped':
            options.setdefault('N', 50)
        elif method == 'riccati':
            options.setdefault('rtol', 1e-5)
            options.setdefault('atol', 1e-5)
            options.setdefault('max_steps', 1000)
            # Add a small offset to avoid evaluating the jacobian at exactly 0.0
            # if the underlying line model has a singularity at length=0
            # NOTE(review): `_s_riccati` below does not read 'dz_eval'; it
            # evaluates sections at 0.0 and at a fixed 1e-8 step.
            options.setdefault('dz_eval', 0.0)

        # Members
        self.line_fn = line_fn
        self.floating = floating
        self.length = as_param(length)
        self.method = method
        self.options = options

        # 1. Instantiate a dummy model to extract all default parameters.
        #    Profile specifications (dicts / tuples) are not valid constructor
        #    values for the underlying line, so they are filtered out here.
        safe_kwargs = {
            k: v for k, v in line_params.items()
            if not isinstance(v, (dict, tuple))
        }
        safe_kwargs['length'] = self.length
        dummy_model = line_fn(**safe_kwargs)
        base_params = dummy_model.named_params(include_fixed=True)

        # Extract the parameter objects, excluding 'length'
        # (the section length is supplied separately in `section`).
        merged_params = {
            k: v for k, v in base_params.items()
            if k != 'length'
        }

        # 2. Overlay the user's raw inputs
        merged_params.update(line_params)

        parsed_profile_fns = {}
        parsed_profile_params = {}
        parsed_uniform_params = {}

        # 3. Parse the unified kwargs into Equinox-friendly PyTrees
        for k, v in merged_params.items():
            if isinstance(v, tuple) and len(v) == 2 and callable(v[0]) and isinstance(v[1], dict):
                # Explicit per-parameter profile: (function, coefficients)
                parsed_profile_fns[k] = v[0]
                parsed_profile_params[k] = {arg: as_param(val) for arg, val in v[1].items()}
            elif isinstance(v, dict):
                # Coefficients for the shared default profile function
                if profile_fn is None:
                    raise ValueError(f"Parameter '{k}' was provided as a dict, but no `profile_fn` was specified.")
                parsed_profile_fns[k] = profile_fn
                parsed_profile_params[k] = {arg: as_param(val) for arg, val in v.items()}
            else:
                parsed_uniform_params[k] = as_param(v)

        self.profile_fns = parsed_profile_fns
        self.profile_params = parsed_profile_params
        self.uniform_params = parsed_uniform_params

    @eqx.filter_jit
    def section(self, dz: float, **profiled_params) -> RLGCLine:
        """
        Build one line section of length ``dz``.

        The section is created by calling ``line_fn`` with the uniform
        parameters, ``length=dz`` and ``floating=self.floating``; any keyword
        in ``profiled_params`` overrides the entry of the same name.
        """
        # Create a fresh dict to avoid mutating self.uniform_params
        current_params = dict(self.uniform_params)
        current_params['length'] = dz
        current_params['floating'] = self.floating
        current_params.update(profiled_params)

        return self.line_fn(**current_params)

    def _profiled_params(self, t: float) -> dict:
        """Helper to evaluate all profile functions at normalized position t."""
        evaluated_params = {}
        for p_name, func in self.profile_fns.items():
            f_kwargs = self.profile_params.get(p_name, {})
            evaluated_params[p_name] = func(t, **f_kwargs)
        return evaluated_params

    @eqx.filter_jit
    def s(self, freq: Frequency) -> jnp.ndarray:
        """
        Compute the S-parameters of the line with the configured method.

        Raises
        ------
        NotImplementedError
            If ``method == 'riccati'``.
        ValueError
            If ``method`` is neither 'stepped' nor 'riccati'.
        """
        if self.method == 'stepped':
            return self._s_stepped(freq)
        if self.method == 'riccati':
            # `_s_riccati` is kept in the class but intentionally not
            # dispatched to until it is functional.
            raise NotImplementedError("Riccati method for ProfiledLine not yet functional")
        raise ValueError(f"Unknown evaluation method: '{self.method}'")

    @eqx.filter_jit
    def _s_stepped(self, freq: Frequency) -> jnp.ndarray:
        """
        Approximate the line by ``options['N']`` cascaded sections.

        Each section has length ``length / N`` and uses the profile values at
        its normalized mid-point ``(i + 0.5) / N``. The per-section S-matrices
        and reference impedances are combined with ``cascade_s``; only its
        first return value (the S-matrix) is returned.
        """
        N = self.options['N']
        dz = self.length / N
        t_centers = (jnp.arange(N) + 0.5) / N

        def s_section(t):
            profiled_params = self._profiled_params(t)
            sec = self.section(dz, **profiled_params)
            return sec.s(freq), jnp.asarray(sec.z0, dtype=complex)

        batch_s, batch_z0 = jax.vmap(s_section)(t_centers)
        S_cas, _ = cascade_s(batch_s, batch_z0)

        return S_cas

    @eqx.filter_jit
    def _s_riccati(self, freq: Frequency) -> jnp.ndarray:
        """
        Integrate the S-matrix along the line as an ODE in ``t`` on [0, 1].

        The state is carried as a ``(real, imag)`` tuple, starts from the
        block matrix ``[[0, I], [I, 0]]`` and is advanced with diffrax's
        ``Dopri5`` solver under a ``PIDController`` configured from
        ``options['rtol']``, ``options['atol']`` and ``options['max_steps']``.

        NOTE: `s` does not call this method at present (it raises
        ``NotImplementedError`` for 'riccati').
        """
        import diffrax

        def s_derivative(t, S_state_tuple, args):
            # 1. Reconstruct complex S-matrix from the real/imag tuple state
            # This explicitly resolves the Diffrax complex dtype warning
            S_real, S_imag = S_state_tuple
            S_state = S_real + 1j * S_imag

            # S_state shape: (..., 2N, 2N)
            N_total = S_state.shape[-1]
            N_ports = N_total // 2

            # Unpack aggregated S-matrix blocks
            # (the S11 block does not appear in the update equations below)
            S12 = S_state[..., :N_ports, N_ports:]
            S21 = S_state[..., N_ports:, :N_ports]
            S22 = S_state[..., N_ports:, N_ports:]

            profiled_params = self._profiled_params(t)

            # Wrap the line construction to isolate `length`
            def local_s_matrix(dz):
                sec = self.section(dz, **profiled_params)
                return sec.s(freq)

            # 2. Extract Generator Matrix Q(z) = dS/dz using Finite Difference
            # This bypasses the jax.jacfwd crash caused by complex matrix inversions/conjugations
            # inside the renormalize_s power-wave formulation.
            dz_delta = 1e-8
            S_plus = local_s_matrix(dz_delta)
            S_zero = local_s_matrix(0.0)
            Q = (S_plus - S_zero) / dz_delta

            Q11 = Q[..., :N_ports, :N_ports]
            Q12 = Q[..., :N_ports, N_ports:]
            Q21 = Q[..., N_ports:, :N_ports]
            Q22 = Q[..., N_ports:, N_ports:]

            # Differential Redheffer Star Product
            dS11 = S12 @ Q11 @ S21
            dS12 = S12 @ (Q12 + Q11 @ S22)
            dS21 = (Q21 + S22 @ Q11) @ S21
            dS22 = Q22 + Q21 @ S22 + S22 @ Q12 + S22 @ Q11 @ S22

            # Reconstruct block matrix
            top = jnp.concatenate([dS11, dS12], axis=-1)
            bot = jnp.concatenate([dS21, dS22], axis=-1)
            dS_dt = jnp.concatenate([top, bot], axis=-2)

            # Chain rule: t is normalized, so d/dt = length * d/dz
            dS_dt_scaled = self.length * dS_dt

            # Return as a tuple of real/imag to maintain Diffrax stability
            return jnp.real(dS_dt_scaled), jnp.imag(dS_dt_scaled)

        # Establish Initial Condition S(0) = [0, I; I, 0]
        # A zero-length section is evaluated only to obtain the batch shape
        # and port count of the S-matrix.
        profiled_params_init = self._profiled_params(0.0)
        sec_dummy = self.section(0.0, **profiled_params_init)
        S_dummy = sec_dummy.s(freq)

        batch_shape = S_dummy.shape[:-2]
        N_ports = S_dummy.shape[-1] // 2

        I_mat = jnp.broadcast_to(jnp.eye(N_ports, dtype=complex), batch_shape + (N_ports, N_ports))
        Z_mat = jnp.zeros_like(I_mat)

        top_init = jnp.concatenate([Z_mat, I_mat], axis=-1)
        bot_init = jnp.concatenate([I_mat, Z_mat], axis=-1)
        S_initial_complex = jnp.concatenate([top_init, bot_init], axis=-2)

        # Package initial state as a real/imaginary tuple
        y0 = (jnp.real(S_initial_complex), jnp.imag(S_initial_complex))

        # Setup and solve ODE
        term = diffrax.ODETerm(s_derivative)
        solver = diffrax.Dopri5()
        stepsize_controller = diffrax.PIDController(
            rtol=self.options['rtol'],
            atol=self.options['atol']
        )

        solution = diffrax.diffeqsolve(
            term,
            solver,
            t0=0.0,
            t1=1.0,
            dt0=0.1,
            y0=y0,
            stepsize_controller=stepsize_controller,
            max_steps=self.options['max_steps']
        )

        # Unpack the final tuple state back into a standard complex S-matrix array
        S_final_real, S_final_imag = solution.ys[0][-1], solution.ys[1][-1]
        return S_final_real + 1j * S_final_imag