"""Mathematical and unit-conversion helper functions (module ``pmrf.functions.math``)."""

import math
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
from jax.numpy import imag, pi, real, unwrap
from jax import lax
from jax.scipy.special import gammaln
from jax._src.numpy.ufuncs import _constant_like

from pmrf.constants import NumberLike, INF, LOG_OF_NEG

def complex_2_magnitude(z: NumberLike):
    """
    Compute the modulus ``|z|`` of a complex value or array.

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    mag : ndarray or scalar
        The absolute value (modulus) of the input.
    """
    modulus = jnp.abs(z)
    return modulus
def complex_2_db(z: NumberLike):
    r"""
    Express the magnitude of a complex value in decibels.

    The result is :math:`20\log_{10}(|z|)`, computed by taking the modulus
    of ``z`` and passing it to :func:`magnitude_2_db` (with its default
    NaN handling).

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    mag20dB : ndarray or scalar
        The magnitude in decibels.
    """
    linear_magnitude = jnp.abs(z)
    return magnitude_2_db(linear_magnitude)
def complex_2_db10(z: NumberLike):
    r"""
    Express the magnitude of a complex value in power decibels.

    The result is :math:`10\log_{10}(|z|)`, computed by taking the modulus
    of ``z`` and passing it to :func:`mag_2_db10` (with its default NaN
    handling).

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    mag10dB : ndarray or scalar
        The magnitude in decibels (power factor 10).
    """
    linear_magnitude = jnp.abs(z)
    return mag_2_db10(linear_magnitude)
def complex_2_radian(z: NumberLike):
    """
    Return the argument (phase angle) of a complex value in radians.

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    ang_rad : ndarray or scalar
        Counterclockwise angle from the positive real axis, in the range
        ``(-pi, pi]``. The floating dtype is chosen by ``jnp.angle``
        (typically float32 unless JAX 64-bit mode is enabled).
    """
    phase = jnp.angle(z)
    return phase
def complex_2_degree(z: NumberLike):
    """
    Return the argument (phase angle) of a complex value in degrees.

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    ang_deg : ndarray or scalar
        Counterclockwise angle from the positive real axis, in degrees.
    """
    angle_in_degrees = jnp.angle(z, deg=True)
    return angle_in_degrees
def complex_2_quadrature(z: NumberLike):
    r"""
    Split a complex value into quadrature form: (length, arc-length).

    The arc-length from the real axis is :math:`|z| \arg(z)`.

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    mag : array like or scalar
        Magnitude (length).
    arc_length : array like or scalar
        Arc-length from the real axis: angle * magnitude.
    """
    length = jnp.abs(z)
    arc_length = jnp.angle(z) * length
    return (length, arc_length)
def complex_2_reim(z: NumberLike):
    """
    Split a complex value into its real and imaginary parts.

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    real : array like or scalar
        Real part of the input.
    imag : array like or scalar
        Imaginary part of the input.
    """
    real_part = jnp.real(z)
    imag_part = jnp.imag(z)
    return (real_part, imag_part)
def complex_components(z: NumberLike):
    """
    Decompose a complex array into every scalar view used in this module.

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    c_real : array like or scalar
        Real part.
    c_imag : array like or scalar
        Imaginary part.
    c_angle : array like or scalar
        Angle in degrees.
    c_mag : array like or scalar
        Magnitude.
    c_arc : array like or scalar
        Arc-length from the real axis, angle (radians) * magnitude.
    """
    magnitude = jnp.abs(z)
    return (
        jnp.real(z),
        jnp.imag(z),
        jnp.angle(z, deg=True),
        magnitude,
        jnp.angle(z) * magnitude,
    )
def magnitude_2_db(z: NumberLike, zero_nan: bool = True):
    r"""
    Convert linear magnitude to dB.

    Computes :math:`20 \log_{10}(z)`. No absolute value is taken, so ``z``
    is expected to already be a (non-negative) magnitude.

    Parameters
    ----------
    z : number or array_like
        Linear magnitude(s).
    zero_nan : bool, optional, default=True
        If True, post-process the result with ``jnp.nan_to_num``:
        NaN (produced by negative inputs) becomes ``LOG_OF_NEG``,
        ``-inf`` (produced by zero inputs) is kept as ``-inf``, and
        ``+inf`` is replaced by the largest finite float (the
        ``nan_to_num`` default).

    Returns
    -------
    z : number or array_like
        Magnitude in dB.
    """
    out = 20 * jnp.log10(z)
    if zero_nan:
        return jnp.nan_to_num(out, nan=LOG_OF_NEG, neginf=-jnp.inf)
    return out
# Alias of magnitude_2_db.
# NOTE(review): a standalone `def mag_2_db(x)` later in this module rebinds this
# name (that version applies jnp.abs() first and does no NaN post-processing),
# so this alias is shadowed once the module finishes importing.
mag_2_db = magnitude_2_db
def mag_2_db10(z: NumberLike, zero_nan: bool = True):
    r"""
    Convert linear magnitude to dB (power factor 10).

    Computes :math:`10 \log_{10}(z)`. No absolute value is taken, so ``z``
    is expected to already be a (non-negative) magnitude.

    Parameters
    ----------
    z : array_like
        Linear magnitude(s).
    zero_nan : bool, optional, default=True
        If True, post-process the result with ``jnp.nan_to_num``:
        NaN (produced by negative inputs) becomes ``LOG_OF_NEG``,
        ``-inf`` (produced by zero inputs) is kept as ``-inf``, and
        ``+inf`` is replaced by the largest finite float (the
        ``nan_to_num`` default).

    Returns
    -------
    z : array_like
        Magnitude in dB.
    """
    out = 10 * jnp.log10(z)
    if zero_nan:
        return jnp.nan_to_num(out, nan=LOG_OF_NEG, neginf=-jnp.inf)
    return out
def db_2_magnitude(z: NumberLike):
    """
    Convert a level in dB (factor 20) back to linear magnitude.

    Parameters
    ----------
    z : number or array_like
        Level(s) in dB.

    Returns
    -------
    z : number or array_like
        Linear magnitude ``10 ** (z / 20)``.
    """
    exponent = z / 20.0
    return 10 ** exponent
# Alias of db_2_magnitude.
# NOTE(review): a standalone `def db_2_mag(x)` later in this module rebinds this
# name (numerically it computes the same 10 ** (x / 20)), so this alias is
# shadowed once the module finishes importing.
db_2_mag = db_2_magnitude
def db10_2_mag(z: NumberLike):
    """
    Convert a level in dB (power factor 10) back to linear magnitude.

    Parameters
    ----------
    z : array_like
        Level(s) in dB.

    Returns
    -------
    z : array_like
        Linear magnitude ``10 ** (z / 10)``.
    """
    exponent = z / 10.0
    return 10 ** exponent
def magdeg_2_reim(mag: NumberLike, deg: NumberLike):
    """
    Build a complex value from a linear magnitude and a phase in degrees.

    Parameters
    ----------
    mag : number or array_like
        Linear magnitude.
    deg : number or array_like
        Phase in degrees.

    Returns
    -------
    z : array_like
        ``mag * exp(j * deg * pi / 180)``.
    """
    unit_phasor = jnp.exp(1j * deg * pi / 180.0)
    return mag * unit_phasor
def dbdeg_2_reim(db: NumberLike, deg: NumberLike):
    """
    Build a complex value from a dB magnitude (factor 20) and a phase in degrees.

    Parameters
    ----------
    db : number or array_like
        Magnitude in dB.
    deg : number or array_like
        Phase in degrees.

    Returns
    -------
    z : array_like
        A complex number or sequence of complex numbers.
    """
    linear_magnitude = db_2_magnitude(db)
    return magdeg_2_reim(linear_magnitude, deg)
def db_2_np(db: NumberLike):
    """
    Convert a value in decibels (dB) to nepers (Np).

    Parameters
    ----------
    db : number or array_like
        A real number or sequence of real numbers, in dB.

    Returns
    -------
    np : number or array_like
        The same quantity in nepers, ``db * ln(10) / 20``.
    """
    nepers_per_db = jnp.log(10) / 20
    return nepers_per_db * db
def np_2_db(x: NumberLike):
    """
    Convert a value in nepers (Np) to decibels (dB).

    Parameters
    ----------
    x : number or array_like
        A real number or sequence of real numbers, in nepers.

    Returns
    -------
    db : number or array_like
        The same quantity in dB, ``x * 20 / ln(10)``.
    """
    db_per_neper = 20 / jnp.log(10)
    return db_per_neper * x
def radian_2_degree(rad: NumberLike):
    """
    Convert angles from radians to degrees.

    Parameters
    ----------
    rad : number or array_like
        Angle in radians.

    Returns
    -------
    deg : number or array_like
        Angle in degrees.
    """
    degrees = rad * 180 / pi
    return degrees
def degree_2_radian(deg: NumberLike):
    """
    Convert angles from degrees to radians.

    Parameters
    ----------
    deg : number or array_like
        Angle in degrees.

    Returns
    -------
    rad : number or array_like
        Angle in radians.
    """
    radians = deg * pi / 180.0
    return radians
def feet_2_meter(feet: NumberLike = 1):
    """
    Convert a length in feet to meters (1 ft = 0.3048 m).

    Parameters
    ----------
    feet : number or array-like, optional, default=1
        Length in feet.

    Returns
    -------
    meter : number or array-like
        Length in meters.

    See Also
    --------
    meter_2_feet
    """
    meters_per_foot = 0.3048
    return meters_per_foot * feet
def meter_2_feet(meter: NumberLike = 1):
    """
    Convert a length in meters to feet.

    The international foot is defined as exactly 0.3048 m, so 1 m is
    exactly ``1 / 0.3048`` (approximately 3.2808399) feet. Dividing by the
    same constant used in :func:`feet_2_meter` makes the two conversions
    exact inverses of each other, up to floating-point rounding.

    Parameters
    ----------
    meter : number or array-like, optional, default=1
        Length in meters.

    Returns
    -------
    feet : number or array-like
        Length in feet.

    See Also
    --------
    feet_2_meter
    """
    return meter / 0.3048
def db_per_100feet_2_db_per_100meter(db_per_100feet: NumberLike = 1):
    """
    Convert attenuation given in dB/100 ft to dB/100 m.

    The result is ``db_per_100feet * 100 / feet_2_meter(100)``: 100 m spans
    more than 100 ft, so the attenuation per 100 m is correspondingly larger.

    Parameters
    ----------
    db_per_100feet : number or array-like, optional, default=1
        Attenuation in dB/100 ft.

    Returns
    -------
    db_per_100meter : number or array-like
        Attenuation in dB/100 m.

    See Also
    --------
    meter_2_feet
    feet_2_meter
    np_2_db
    db_2_np
    """
    hundred_feet_in_meters = feet_2_meter(100)
    return db_per_100feet * 100 / hundred_feet_in_meters
def unwrap_rad(phi: NumberLike):
    """
    Remove 2*pi jumps from a phase sequence given in radians.

    Unwrapping is performed along axis 0 with ``jnp.unwrap``.

    Parameters
    ----------
    phi : number or array_like
        Phase in radians.

    Returns
    -------
    phi : number or array_like
        Unwrapped phase in radians.
    """
    continuous_phase = unwrap(phi, axis=0)
    return continuous_phase
def sqrt_known_sign(z_squared: NumberLike, z_approx: NumberLike):
    """
    Return a square root of a complex number, choosing the branch via `z_approx`.

    The principal root ``jnp.sqrt(z_squared)`` is kept where the sign of its
    phase angle equals the sign of ``z_approx``'s phase angle; otherwise its
    complex conjugate is returned.

    Parameters
    ----------
    z_squared : number or array-like
        The complex value to be square-rooted.
    z_approx : number or array-like
        The approximate value of the root, used to pick the branch.

    Returns
    -------
    z : number, array-like (same type as z_squared)
        Square root of z_squared.
    """
    root = jnp.sqrt(z_squared)
    phase_signs_agree = jnp.sign(jnp.angle(root)) == jnp.sign(jnp.angle(z_approx))
    return jnp.where(phase_signs_agree, root, root.conj())
def find_correct_sign(z1: NumberLike, z2: NumberLike, z_approx: NumberLike):
    r"""
    Select, element-wise, the root whose phase sign matches `z_approx`.

    Useful when a root choice must be made on a complex number whose
    approximate value is known, i.e. choosing between

    .. math::

        z_1, z_2 = \pm \sqrt{z^2}

    Parameters
    ----------
    z1 : array-like
        Root 1.
    z2 : array-like
        Root 2.
    z_approx : array-like
        Approximate value of the root.

    Returns
    -------
    z3 : jnp.ndarray
        ``z1`` where ``sign(angle(z1)) == sign(angle(z_approx))``, else ``z2``.
    """
    first_matches = jnp.sign(jnp.angle(z1)) == jnp.sign(jnp.angle(z_approx))
    return jnp.where(first_matches, z1, z2)
def find_closest(z1: NumberLike, z2: NumberLike, z_approx: NumberLike):
    """
    Select, element-wise, whichever of z1 or z2 lies closer to z_approx.

    Ties go to ``z2``.

    Parameters
    ----------
    z1 : array-like
        Root 1.
    z2 : array-like
        Root 2.
    z_approx : array-like
        Approximate value of the root.

    Returns
    -------
    z3 : jnp.ndarray
        Array built from z1 and z2.
    """
    first_is_closer = abs(z1 - z_approx) < abs(z2 - z_approx)
    return jnp.where(first_is_closer, z1, z2)
def sqrt_phase_unwrap(z: NumberLike):
    r"""
    Take the square root of a complex array using its unwrapped phase.

    This idea came from Lihan Chen. The phase is unwrapped along axis 0
    before halving, so the result is continuous where the phase is:

    .. math::

        \sqrt{|z|} \exp\left( j \arg_{unwrap}(z) / 2 \right)

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers.

    Returns
    -------
    z : number or array_like
        A complex number or sequence of complex numbers.
    """
    unwrapped_phase = unwrap_rad(complex_2_radian(z))
    return jnp.sqrt(abs(z)) * jnp.exp(0.5 * 1j * unwrapped_phase)
# mathematical functions
def dirac_delta(x: NumberLike):
    r"""
    Return 1.0 where ``x == 0`` and 0.0 elsewhere.

    Despite the name, this is the discrete (Kronecker-style) indicator
    :math:`\delta(x) = 1` if :math:`x = 0`, 0 otherwise, not the Dirac
    distribution.

    Parameters
    ----------
    x : number or array_like
        A real number or sequence of real numbers.

    Returns
    -------
    delta : number or array_like
        1.0 or 0.0 (as floats).

    References
    ----------
    https://en.wikipedia.org/wiki/Dirac_delta_function
    """
    # Multiplying the boolean by 1.0 casts it to float; the non-zero
    # positions are already 0.0, so no second term is needed.
    return (x == 0) * 1.0
def neuman(x: NumberLike):
    r"""
    Compute Neumann's number, :math:`2 - \delta(x)`.

    It equals 1 where ``x == 0`` and 2 elsewhere, where :math:`\delta` is
    :func:`dirac_delta`.

    Parameters
    ----------
    x : number or array_like
        A real number or sequence of real numbers.

    Returns
    -------
    y : number or array_like
        A real number or sequence of real numbers.

    See Also
    --------
    dirac_delta
    """
    delta = dirac_delta(x)
    return 2.0 - delta
def null(A: jnp.ndarray, eps: float = 1e-15):
    """
    Calculate an orthonormal basis of the null space of a 2-D matrix A.

    Parameters
    ----------
    A : array_like
        Input matrix of shape ``(m, n)``.
    eps : float, optional, default=1e-15
        Singular values ``<= eps`` are treated as zero.

    Returns
    -------
    null_space : array_like
        Matrix of shape ``(n, k)`` whose columns span the null space of A,
        so that ``A @ null_space`` is (numerically) zero.

    References
    ----------
    https://scipy-cookbook.readthedocs.io/items/RankNullspace.html
    https://stackoverflow.com/questions/5889142/python-numpy-scipy-finding-the-null-space-of-a-matrix
    """
    _, singular_values, vh = jnp.linalg.svd(A)
    # svd returns only min(m, n) singular values, but vh has n rows. For a wide
    # matrix the extra rows of vh belong to zero singular values, so pad with
    # zeros; otherwise those null vectors would be silently dropped.
    n_rows = vh.shape[0]
    padding = jnp.zeros(n_rows - singular_values.shape[0], dtype=singular_values.dtype)
    s_full = jnp.concatenate([singular_values, padding])
    null_rows = jnp.compress(s_full <= eps, vh, axis=0)
    # A = U S Vh, so null vectors are columns of V = Vh^H, i.e. the
    # *conjugated* rows of vh (the conjugate matters for complex A).
    return null_rows.conj().T
def inf_to_num(x: NumberLike):
    """
    Replace +inf and -inf with large finite numbers.

    ``+inf`` becomes ``INF`` and ``-inf`` becomes ``-INF``; NaN values are
    left as NaN.

    Parameters
    ----------
    x : array-like or number
        The input array or number.

    Returns
    -------
    x : number or array_like
        Input with +/- inf replaced by large finite numbers.
    """
    return jnp.nan_to_num(x, nan=jnp.nan, posinf=INF, neginf=-INF)
def cross_ratio(a: NumberLike, b: NumberLike, c: NumberLike, d: NumberLike):
    r"""
    Calculate the cross ratio of a quadruple of distinct points.

    The cross ratio is defined as:

    .. math::

        r = \frac{ (a-b)(c-d) }{ (a-d)(c-b) }

    Parameters
    ----------
    a, b, c, d : array-like or number
        Input points.

    Returns
    -------
    r : array-like or number
        The cross ratio.

    References
    ----------
    https://en.wikipedia.org/wiki/Cross-ratio
    """
    numerator = (a - b) * (c - d)
    denominator = (a - d) * (c - b)
    return numerator / denominator
def complexify(f: Callable, name: str = None):
    """
    Lift a function of a real variable to act on complex input.

    Given ``f(x)``, returns ``f_c(z) = f(real(z)) + 1j*f(imag(z))``. If the
    real/imag argument is not the first positional argument of ``f``, give
    its keyword name via ``name``; the remaining positional and keyword
    arguments of ``f_c`` are forwarded to ``f``. An explicit keyword that
    repeats ``name`` takes precedence over the value derived from ``z``.

    Parameters
    ----------
    f : Callable
        Function of a real variable.
    name : string, optional
        Keyword name of the real/imag argument if it is not first.

    Returns
    -------
    f_c : Callable
        Function of a complex variable.

    Examples
    --------
    >>> def f(x): return x
    >>> f_c = complexify(f)
    >>> w = f_c(0.2 - 1j*0.3)
    """
    def f_c(z, *args, **kw):
        if name is None:
            return f(real(z), *args, **kw) + 1j * f(imag(z), *args, **kw)
        real_kwargs = {name: real(z), **kw}
        imag_kwargs = {name: imag(z), **kw}
        return f(*args, **real_kwargs) + 1j * f(*args, **imag_kwargs)

    return f_c
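# Older helper functions. Several are still used elsewhere in this module
# (e.g. convolve_interleaved via conv_inter, mag_2_db via dB20, and comb via
# evaluate_bernstein_basis), so they are not merely kept for reference.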
def complex2Scalar(z: NumberLike):
    """
    Serialize a list/array of complex numbers into interleaved real/imag values.

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers. A scalar is treated
        as a length-1 sequence.

    Returns
    -------
    re_im : array_like
        For 1-D input: ``z[0].real, z[0].imag, z[1].real, z[1].imag, ...``.
        For N-D input the interleaving happens per leading-axis slice
        (``real(z[0]), imag(z[0]), real(z[1]), ...``), flattened.

    See Also
    --------
    scalar2Complex
    """
    z = jnp.atleast_1d(jnp.array(z))
    # Stacking along axis 1 and flattening yields real(z[k]) followed by
    # imag(z[k]) for each leading index k. This is the same order as a Python
    # loop over z, but as one vectorized (and jit-traceable) operation.
    return jnp.stack([jnp.real(z), jnp.imag(z)], axis=1).flatten()
def scalar2Complex(s: NumberLike):
    """
    Unserialize interleaved real/imag values into a complex array.

    Inverse of :func:`complex2Scalar`.

    Parameters
    ----------
    s : array_like
        Values ordered as ``z[0].real, z[0].imag, z[1].real, z[1].imag, ...``
        along the first axis. The first axis must have even length.

    Returns
    -------
    z : number or array_like
        A complex array (flattened). Empty input yields an empty complex array.

    Raises
    ------
    ValueError
        If the first axis of ``s`` has odd length.

    See Also
    --------
    complex2Scalar
    """
    s = jnp.array(s)
    if s.shape[0] % 2 != 0:
        raise ValueError(
            f"Expected an even number of interleaved real/imag values, got {s.shape[0]}."
        )
    # Even positions hold real parts, odd positions imaginary parts. Slicing
    # replaces the per-pair Python loop with a single vectorized expression.
    return (s[0::2] + 1j * s[1::2]).flatten()
def multiply_by(x: jnp.ndarray, by: jnp.ndarray, axis=None) -> jnp.ndarray:
    """
    Broadcast-multiply array `x` by array `by`, repeating `by` block-wise along an axis.

    With ``axis=0`` the rows of ``x`` are multiplied cyclically by the
    entries of ``by`` (row ``r`` uses ``by[r % len(by)]``). With ``axis=1``
    the same happens for the columns. A 1-D ``x`` is first reshaped to a
    column (``axis=0``) or a row (otherwise). With ``axis=None`` a plain
    broadcasting multiply is performed. The input ``x`` is never modified.

    Parameters
    ----------
    x : jnp.ndarray
        Input array.
    by : jnp.ndarray
        Array to multiply with. For ``axis=0`` a 1-D ``by`` is used as a column.
    axis : int, optional
        Axis along which to apply the multiplication (0 for rows, 1 for columns).

    Returns
    -------
    jnp.ndarray
        Result of the multiplication (2-D if ``x`` was 1-D and shapes differ).

    Raises
    ------
    ValueError
        If the length of the specified axis is not divisible by the length of `by`.
    """
    by = jnp.asarray(by)
    if by.shape == x.shape:
        # Out-of-place multiply: `x *= by` would mutate a caller's NumPy array.
        return x * by
    if len(x.shape) == 1:
        x = x.reshape(len(x), 1) if axis == 0 else x.reshape(1, len(x))
    if axis is None:
        return x * by
    n = len(by)
    # Compare against None explicitly: axis=0 must also be validated.
    if x.shape[axis] % n != 0:
        raise ValueError(f"The length of the specified axis ({x.shape[axis]}) is not divisible by {n}.")
    if axis == 0:
        if by.ndim == 1:
            # A 1-D `by` repeats down the rows, so treat it as a column.
            by = by.reshape(n, 1)
        by = jnp.tile(by, (x.shape[0] // n, x.shape[1]))
    elif axis == 1:
        by = jnp.tile(by, (x.shape[0], x.shape[1] // n))
    return x * by
def _reduce_every(x, n, axis, reduce_blocks):
    """Validate inputs, group consecutive blocks of `n` along `axis`, and reduce each block."""
    if len(x.shape) == 1:
        # 1-D input becomes a column for axis=0 and a row otherwise; for the
        # row case with no axis given, the data axis is axis 1.
        x = x.reshape(len(x), 1) if axis == 0 else x.reshape(1, len(x))
        if axis is None:
            axis = 1
    if axis is None:
        raise ValueError("axis must be 0 or 1 for 2-D input.")
    if axis < 0:
        axis += len(x.shape)
    if axis not in (0, 1):
        raise ValueError(f"axis must be 0 or 1, got {axis}.")
    # Ensure the axis length is divisible by n
    if x.shape[axis] % n != 0:
        raise ValueError(f"The length of the specified axis ({x.shape[axis]}) is not divisible by {n}.")
    if axis == 0:
        # Group rows: (rows, cols) -> (rows // n, n, cols), reduce over the block axis.
        return reduce_blocks(x.reshape(-1, n, x.shape[1]), 1)
    # Group columns: (rows, cols) -> (rows, cols // n, n), reduce over the block axis.
    return reduce_blocks(x.reshape(x.shape[0], -1, n), 2)


def sum_every(x: jnp.ndarray, n: int, axis=None) -> jnp.ndarray:
    """
    Sum consecutive blocks of size `n` along a specified axis.

    A 1-D input is treated as a column (``axis=0``) or a row (otherwise),
    and the result is 2-D. For 1-D input ``axis=None`` means "along the
    data", i.e. the row case.

    Parameters
    ----------
    x : jnp.ndarray
        Input array (1-D or 2-D).
    n : int
        Block size to sum.
    axis : int, optional
        Axis along which to sum (0 for rows, 1 for columns; -1/-2 accepted).
        Required for 2-D input.

    Returns
    -------
    jnp.ndarray
        Resulting summed array.

    Raises
    ------
    ValueError
        If the axis is missing/invalid or its length is not divisible by `n`.
    """
    return _reduce_every(x, n, axis, lambda blocks, ax: blocks.sum(axis=ax))


def multiply_every(x: jnp.ndarray, n: int, axis=None) -> jnp.ndarray:
    """
    Multiply together consecutive blocks of size `n` along a specified axis.

    A 1-D input is treated as a column (``axis=0``) or a row (otherwise),
    and the result is 2-D. For 1-D input ``axis=None`` means "along the
    data", i.e. the row case.

    Parameters
    ----------
    x : jnp.ndarray
        Input array (1-D or 2-D).
    n : int
        Block size to multiply.
    axis : int, optional
        Axis along which to multiply (0 for rows, 1 for columns; -1/-2
        accepted). Required for 2-D input.

    Returns
    -------
    jnp.ndarray
        Resulting product array.

    Raises
    ------
    ValueError
        If the axis is missing/invalid or its length is not divisible by `n`.
    """
    return _reduce_every(x, n, axis, lambda blocks, ax: blocks.prod(axis=ax))
def convolve_interleaved(x, axis=1) -> jnp.ndarray:
    """
    Convolve the even-indexed samples of `x` with its odd-indexed samples.

    For interleaved data such as ``[a0, b0, a1, b1, ...]`` this returns
    ``jnp.convolve([a0, a1, ...], [b0, b1, ...])``.

    Parameters
    ----------
    x : jnp.ndarray
        1-D input array.
    axis : int, optional, default=1
        Only ``axis=1`` is supported.

    Returns
    -------
    jnp.ndarray
        The full linear convolution of the two de-interleaved sequences.

    Raises
    ------
    NotImplementedError
        If ``axis != 1`` or ``x`` is not 1-D. This is a subclass of
        ``Exception``, so existing ``except Exception`` handlers still work.
    """
    if axis != 1 or len(x.shape) != 1:
        raise NotImplementedError("Not yet implemented")
    evens, odds = x[0::2], x[1::2]
    return jnp.convolve(evens, odds)
def rms(x):
    """
    Compute the root mean square over all elements of the input.

    Parameters
    ----------
    x : jnp.ndarray
        Input array.

    Returns
    -------
    float or jnp.ndarray
        ``sqrt(mean(x ** 2))``.
        NOTE(review): for complex input this squares ``x`` rather than
        using ``|x|**2``; confirm that is intended.
    """
    mean_square = jnp.mean(x**2)
    return jnp.sqrt(mean_square)
def mag_2_db(x):
    """
    Convert a (possibly complex or signed) value to decibels, ``20*log10(|x|)``.

    Parameters
    ----------
    x : jnp.ndarray
        Linear magnitude (the absolute value is taken first).

    Returns
    -------
    jnp.ndarray
        Value in dB.

    Notes
    -----
    This definition rebinds the module-level name ``mag_2_db``, replacing
    the earlier alias of ``magnitude_2_db``. Unlike that function, it takes
    ``abs(x)`` and performs no NaN/inf post-processing.
    """
    magnitude = jnp.abs(x)
    return 20 * jnp.log10(magnitude)
def db_2_mag(x):
    """
    Convert decibels (factor 20) to linear magnitude, ``10 ** (x / 20)``.

    Parameters
    ----------
    x : jnp.ndarray
        Value in dB.

    Returns
    -------
    jnp.ndarray
        Linear magnitude.

    Notes
    -----
    This definition rebinds the module-level name ``db_2_mag``, replacing
    the earlier alias of ``db_2_magnitude`` (which computes the same value).
    """
    exponent = x / 20
    return 10 ** exponent
def polar_2_rect(radii, angles, deg=False):
    """
    Convert polar coordinates to rectangular (complex) form, ``r * exp(j*theta)``.

    Parameters
    ----------
    radii : jnp.ndarray
        Radius (magnitude).
    angles : jnp.ndarray
        Angle (phase), in radians unless `deg` is True.
    deg : bool, optional, default=False
        If True, `angles` are given in degrees.

    Returns
    -------
    jnp.ndarray
        Complex number(s) in rectangular form.
    """
    theta = jnp.deg2rad(angles) if deg else angles
    return radii * jnp.exp(1j * theta)
def rect_2_polar(x, deg=False):
    """
    Convert rectangular (complex) values to polar coordinates.

    Parameters
    ----------
    x : jnp.ndarray
        Complex input.
    deg : bool, optional, default=False
        If True, return the angle in degrees instead of radians.

    Returns
    -------
    radii : jnp.ndarray
        Magnitude.
    angles : jnp.ndarray
        Phase angle.
    """
    radii = abs(x)
    angles = jnp.angle(x, deg=deg)
    return radii, angles
def rsolve(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray:
    r"""
    Solve :math:`x A = B` for :math:`x`.

    Equivalent to ``B @ jnp.linalg.inv(A)`` but avoids forming the inverse
    explicitly. Works on a single matrix pair or on stacks with any
    leading batch dimensions, e.g. ``(nfreqs, nports, nports)``.

    Parameters
    ----------
    A : jnp.ndarray
        Matrix (or stack of matrices) A, shape ``(..., n, n)``.
    B : jnp.ndarray
        Matrix (or stack of matrices) B, shape ``(..., m, n)``.

    Returns
    -------
    x : jnp.ndarray
        Solution matrix, shape ``(..., m, n)``.
    """
    # x A = B  <=>  A^T x^T = B^T: solve the transposed system and transpose
    # back. Swapping the last two axes handles 2-D input as well as any
    # leading batch dimensions, where a hard-coded (0, 2, 1) accepted only 3-D.
    A_t = jnp.swapaxes(A, -1, -2)
    B_t = jnp.swapaxes(B, -1, -2)
    return jnp.swapaxes(jnp.linalg.solve(A_t, B_t), -1, -2)
def nudge_eig(mat: jnp.ndarray, cond: float | None = None, min_eig: float | None = None) -> jnp.ndarray:
    r"""
    Raise small eigenvalues of each matrix to a minimum magnitude.

    Eigenvalues with absolute value smaller than
    ``max(cond * max(|eigenvalue|), min_eig)`` (per matrix) are replaced by
    that threshold, and the matrix is reassembled from its eigendecomposition.
    Can be used to avoid singularities in solving matrix equations.

    Parameters
    ----------
    mat : jnp.ndarray
        Stack of square matrices, shape ``(nfreqs, nports, nports)``.
    cond : float, optional
        Minimum eigenvalue ratio relative to the largest eigenvalue magnitude
        of each matrix. ``None`` selects the default ``1e-9``; an explicit
        ``0`` disables the relative threshold.
    min_eig : float, optional
        Absolute minimum eigenvalue magnitude. ``None`` selects the default
        ``1e-12``.

    Returns
    -------
    res : jnp.ndarray
        Nudged matrices. The dtype is promoted to the complex dtype returned
        by ``jnp.linalg.eig``, so real input yields complex output.
    """
    EIG_COND = 1e-9
    EIG_MIN = 1e-12
    # Compare against None so that an explicit 0 is honoured instead of
    # silently being replaced by the default.
    if cond is None:
        cond = EIG_COND
    if min_eig is None:
        min_eig = EIG_MIN

    mat = jnp.asarray(mat)
    eigw, eigv = jnp.linalg.eig(mat)
    # eig always returns complex results; both lax.cond branches must share
    # one dtype, so the passthrough branch is promoted to match.
    out_dtype = jnp.result_type(mat.dtype, eigv.dtype)

    abs_eigw = jnp.abs(eigw)
    max_eig = jnp.amax(abs_eigw, axis=1)
    # |e| < cond*max  or  |e| < min_eig   <=>   |e| < max(cond*max, min_eig)
    threshold = jnp.maximum(cond * max_eig[:, None], min_eig)
    mask = abs_eigw < threshold

    def fix_branch():
        nudged = jnp.where(mask, threshold, eigw)
        # Build the complex diagonal eigenvalue matrices. A real-valued buffer
        # (zeros_like(mat) for real mat) would discard imaginary parts.
        idx = jnp.arange(eigw.shape[-1])
        diag = jnp.zeros_like(eigv).at[:, idx, idx].set(nudged)
        # x @ eigv = eigv @ diag  =>  x = eigv @ diag @ inv(eigv)
        return rsolve(eigv, eigv @ diag).astype(out_dtype)

    def passthrough_branch():
        return mat.astype(out_dtype)

    return jax.lax.cond(mask.any(), fix_branch, passthrough_branch)
def round_sig(x, sig=3):
    """
    Round a number to a given count of significant digits.

    Parameters
    ----------
    x : float
        Number to round.
    sig : int, optional, default=3
        Number of significant digits.

    Returns
    -------
    float
        Rounded number (``0`` for zero input).
    """
    if x == 0:
        return 0
    # This is scalar bookkeeping, so use double-precision `math` rather than
    # jnp. That avoids a device dispatch per call and float32 rounding of the
    # decimal exponent.
    leading_exponent = math.floor(math.log10(abs(x)))
    return round(x, sig - leading_exponent - 1)
def comb(N: jnp.ndarray, k: jnp.ndarray, exact: bool = False, repetition: bool = False):
    r"""
    The number of combinations of N things taken k at a time ("N choose k").

    When ``exact=False`` the result is approximated efficiently (and
    vectorized) via the log-Gamma function:

    .. math::

        \exp\left\{\ln{\Gamma(N+1)} - [\ln{\Gamma(k+1)} + \ln{\Gamma(N+1-k)}]\right\}

    Parameters
    ----------
    N : array_like
        The number of things.
    k : array_like
        The number of elements taken.
    exact : bool, optional
        If True, the result is computed exactly with arbitrary-precision
        Python integers (``math.comb``) and returned as a Python ``int``
        (0 when ``k > N``). Only scalar integer-valued N and k are supported;
        negative values raise ``ValueError``.
    repetition : bool, optional
        If True, the number of combinations with repetition is computed.

    Returns
    -------
    comb : jnp.ndarray or int
        The number of combinations of N things taken k at a time.
    """
    if repetition:
        return comb(N + k - 1, k, exact=exact, repetition=False)
    if exact:
        # Python ints cannot overflow, unlike a product of a fixed-width jnp
        # arange (e.g. C(40, 20) exceeds int32), and math.comb returns 0 for k > N.
        return math.comb(int(N), int(k))
    # jnp-level arithmetic promotes mixed int/float inputs; strict lax ops
    # (and the private jax._src helper) are not needed.
    log_comb = gammaln(N + 1) - (gammaln(k + 1) + gammaln(N + 1 - k))
    return jnp.exp(log_comb)
def evaluate_power_basis(x, coeffs, lower_bound, upper_bound):
    """
    Evaluate a polynomial given in the (ascending) power basis.

    ``x`` is first mapped to ``t = (x - lower_bound) / (upper_bound - lower_bound)``
    and the polynomial ``sum_i coeffs[i] * t**i`` is evaluated at ``t``.

    Parameters
    ----------
    x : jnp.ndarray
        Input values.
    coeffs : jnp.ndarray
        Polynomial coefficients, lowest degree first.
    lower_bound : float
        Lower bound of the domain.
    upper_bound : float
        Upper bound of the domain.

    Returns
    -------
    jnp.ndarray
        Evaluated polynomial values.
    """
    ascending = jnp.asarray(coeffs)
    span = upper_bound - lower_bound
    t = (x - lower_bound) / span
    # jnp.polyval expects the highest-degree coefficient first.
    return jnp.polyval(ascending[::-1], t)
def evaluate_bernstein_basis(x, coeffs, lower_bound, upper_bound):
    """
    Evaluate a polynomial given in the Bernstein basis.

    ``x`` is mapped to ``t = (x - lower_bound) / (upper_bound - lower_bound)``
    and the result is ``sum_i coeffs[i] * C(n, i) * t**i * (1 - t)**(n - i)``
    with ``n = len(coeffs) - 1``.

    Parameters
    ----------
    x : jnp.ndarray
        Input values of any shape (a scalar is treated as shape ``(1,)``).
    coeffs : jnp.ndarray
        Coefficients (control points) of the Bernstein polynomial.
    lower_bound : float
        Lower bound of the domain.
    upper_bound : float
        Upper bound of the domain.

    Returns
    -------
    jnp.ndarray
        Evaluated values, with the shape of ``jnp.atleast_1d(x)``.
    """
    coeffs = jnp.asarray(coeffs)
    n = len(coeffs) - 1  # Degree of the polynomial
    i = jnp.arange(n + 1)
    binomial_coeffs = comb(n, i)
    t = jnp.atleast_1d((x - lower_bound) / (upper_bound - lower_bound))
    # Evaluate all basis functions on a trailing axis by broadcasting. Unlike
    # a vmap over the first axis, this works for inputs of any rank.
    t_col = t[..., None]
    basis_values = jnp.power(t_col, i) * jnp.power(1 - t_col, n - i)
    return (binomial_coeffs * basis_values) @ coeffs
# Aliases, lookups and partials
conv_inter = convolve_interleaved
# Bound to the abs()-based `def mag_2_db` defined above (the later definition).
dB20 = mag_2_db

# Norm helpers with the order (and optionally the axis) pre-bound.
l1_norm = partial(jnp.linalg.norm, ord=1)
l1_norm_ax0 = partial(jnp.linalg.norm, ord=1, axis=0)
l1_norm_ax1 = partial(jnp.linalg.norm, ord=1, axis=1)
l2_norm = partial(jnp.linalg.norm, ord=2)
l2_norm_ax0 = partial(jnp.linalg.norm, ord=2, axis=0)
l2_norm_ax1 = partial(jnp.linalg.norm, ord=2, axis=1)


def conv_cost(x):
    """Return ``dB20(l2_norm_ax0(conv_inter(l2_norm_ax0(x))))``."""
    column_norms = l2_norm_ax0(x)
    return dB20(l2_norm_ax0(conv_inter(column_norms)))


# Maps a short key to (axis label, function applied to the complex data).
FUNC_LOOKUP: dict[str, tuple[str, Callable | None]] = {
    're': ('Real Part', jnp.real),
    'im': ('Imag Part', jnp.imag),
    'mag': ('Magnitude', jnp.abs),
    'db': ('Magnitude (dB)', complex_2_db),
    'db10': ('Magnitude (dB)', complex_2_db10),
    'rad': ('Phase (rad)', jnp.angle),
    'deg': ('Phase (deg)', lambda x: jnp.angle(x, deg=True)),
    'arcl': ('Arc Length', lambda x: jnp.angle(x) * jnp.abs(x)),
    'rad_unwrap': ('Phase (rad)', lambda x: unwrap_rad(jnp.angle(x))),
    'deg_unwrap': ('Phase (deg)', lambda x: radian_2_degree(unwrap_rad(jnp.angle(x)))),
    'arcl_unwrap': ('Arc Length', lambda x: unwrap_rad(jnp.angle(x)) * jnp.abs(x)),
    'vswr': ('VSWR', lambda x: (1 + jnp.abs(x)) / (1 - jnp.abs(x))),
    # 'time': ('Time (real)', mf.ifft),
    # 'time_db': ('Magnitude (dB)', lambda x: mf.complex_2_db(mf.ifft(x))),
    # 'time_mag': ('Magnitude', lambda x: mf.complex_2_magnitude(mf.ifft(x))),
}