from typing import Callable
from functools import partial
import jax
import jax.numpy as jnp
from jax.numpy import imag, pi, real, unwrap
from jax import lax
from jax.scipy.special import gammaln
from jax._src.numpy.ufuncs import _constant_like
from pmrf.constants import NumberLike, INF, LOG_OF_NEG
[docs]
def complex_2_magnitude(z: NumberLike):
    """
    Compute the modulus :math:`|z|` of the input.

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    mag : ndarray or scalar
        Element-wise absolute value of `z`.
    """
    modulus = jnp.abs(z)
    return modulus
[docs]
def complex_2_db(z: NumberLike):
    r"""
    Express the magnitude of a complex input in decibels, :math:`20\log_{10}(|z|)`.

    The modulus is taken first and then handed to :func:`magnitude_2_db`
    (with its default NaN handling).

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    mag20dB : ndarray or scalar
        Magnitude of `z` in dB.
    """
    modulus = jnp.abs(z)
    return magnitude_2_db(modulus)
[docs]
def complex_2_db10(z: NumberLike):
    r"""
    Express the magnitude of a complex input in dB using :math:`10\log_{10}(|z|)`.

    The modulus is taken first and then handed to :func:`mag_2_db10`
    (with its default NaN handling).

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    mag10dB : ndarray or scalar
        ``10*log10(|z|)``.
    """
    modulus = jnp.abs(z)
    return mag_2_db10(modulus)
[docs]
def complex_2_radian(z: NumberLike):
    """
    Return the argument (phase angle) of a complex input, in radians.

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    ang_rad : ndarray or scalar
        Counter-clockwise angle from the positive real axis, in ``(-pi, pi]``,
        as a real floating-point array.
    """
    phase = jnp.angle(z)
    return phase
[docs]
def complex_2_degree(z: NumberLike):
    """
    Return the argument (phase angle) of a complex input, in degrees.

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    ang_deg : ndarray or scalar
        Counter-clockwise angle from the positive real axis, in degrees.
    """
    phase_deg = jnp.angle(z, deg=True)
    return phase_deg
[docs]
def complex_2_quadrature(z: NumberLike):
    r"""
    Split a complex input into (length, arc-length from the real axis).

    The arc-length is :math:`|z| \arg(z)`, i.e. the angle in radians scaled
    by the magnitude.

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    mag : array like or scalar
        Magnitude (length) ``|z|``.
    arc_length : array like or scalar
        ``angle(z) * |z|``.
    """
    mag = jnp.abs(z)
    arc_length = jnp.angle(z) * mag
    return mag, arc_length
[docs]
def complex_2_reim(z: NumberLike):
    """
    Split a complex input into its real and imaginary parts.

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    real : array like or scalar
        Real part of `z`.
    imag : array like or scalar
        Imaginary part of `z`.
    """
    re_part = jnp.real(z)
    im_part = jnp.imag(z)
    return re_part, im_part
[docs]
def complex_components(z: NumberLike):
    """
    Break a complex input into every scalar view this module offers.

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    c_real : array like or scalar
        Real part.
    c_imag : array like or scalar
        Imaginary part.
    c_angle : array like or scalar
        Angle in degrees.
    c_mag : array like or scalar
        Magnitude.
    c_arc : array like or scalar
        Arc-length from the real axis, ``angle (rad) * magnitude``.
    """
    c_real = jnp.real(z)
    c_imag = jnp.imag(z)
    c_angle = jnp.angle(z, deg=True)
    c_mag = jnp.abs(z)
    c_arc = jnp.angle(z) * jnp.abs(z)
    return (c_real, c_imag, c_angle, c_mag, c_arc)
[docs]
def magnitude_2_db(z: NumberLike, zero_nan: bool = True):
    r"""
    Convert linear magnitude to dB, :math:`20\log_{10}(z)`.

    Parameters
    ----------
    z : number or array_like
        Linear magnitude(s). The logarithm is applied to `z` as given, so pass
        ``|z|`` for complex data.
    zero_nan : bool, optional
        If True (the default), NaN results (e.g. from negative inputs) are
        replaced by the module constant ``LOG_OF_NEG``, not by zero despite
        the parameter name. Infinite results are left infinite.

    Returns
    -------
    z : number or array_like
        Magnitude in dB.
    """
    out = 20 * jnp.log10(z)
    if zero_nan:
        # jnp.nan_to_num replaces +inf with the largest finite float unless
        # told otherwise. Keep both infinities so that log10(0) -> -inf and
        # log10(inf) -> +inf are treated symmetrically.
        return jnp.nan_to_num(out, nan=LOG_OF_NEG, posinf=jnp.inf, neginf=-jnp.inf)
    return out


# Alias. NOTE: a later ``def mag_2_db`` in this module rebinds this name to a
# different function (it applies abs() itself and does no NaN handling).
mag_2_db = magnitude_2_db
[docs]
def mag_2_db10(z: NumberLike, zero_nan: bool = True):
    r"""
    Convert a linear quantity to dB using :math:`10\log_{10}(z)`.

    Parameters
    ----------
    z : array_like
        Linear value(s). The logarithm is applied to `z` as given, so pass
        ``|z|`` for complex data.
    zero_nan : bool, optional
        If True (the default), NaN results (e.g. from negative inputs) are
        replaced by the module constant ``LOG_OF_NEG``, not by zero despite
        the parameter name. Infinite results are left infinite.

    Returns
    -------
    z : array_like
        ``10*log10(z)``.
    """
    out = 10 * jnp.log10(z)
    if zero_nan:
        # Keep both infinities; without posinf, jnp.nan_to_num would turn +inf
        # into the largest finite float while -inf is explicitly preserved.
        return jnp.nan_to_num(out, nan=LOG_OF_NEG, posinf=jnp.inf, neginf=-jnp.inf)
    return out
[docs]
def db_2_magnitude(z: NumberLike):
    r"""
    Convert a dB value back to linear magnitude, :math:`10^{z/20}`.

    Parameters
    ----------
    z : number or array_like
        Value(s) in dB (real).

    Returns
    -------
    z : number or array_like
        Linear magnitude ``10**(z/20)``.
    """
    exponent = z / 20.
    return 10 ** exponent


# Alias. NOTE: a later ``def db_2_mag`` in this module rebinds this name.
db_2_mag = db_2_magnitude
[docs]
def db10_2_mag(z: NumberLike):
    r"""
    Convert a ``10*log10`` dB value back to linear scale, :math:`10^{z/10}`.

    Parameters
    ----------
    z : array_like
        Value(s) in dB (real).

    Returns
    -------
    z : array_like
        Linear value ``10**(z/10)``.
    """
    exponent = z / 10.
    return 10 ** exponent
[docs]
def magdeg_2_reim(mag: NumberLike, deg: NumberLike):
    """
    Build complex values from linear magnitude and phase in degrees.

    Parameters
    ----------
    mag : number or array_like
        Linear magnitude(s) (real).
    deg : number or array_like
        Phase(s) in degrees (real).

    Returns
    -------
    z : array_like
        ``mag * exp(j * deg * pi / 180)``.
    """
    phasor = jnp.exp(1j * deg * pi / 180.)
    return mag * phasor
[docs]
def dbdeg_2_reim(db: NumberLike, deg: NumberLike):
    """
    Build complex values from a dB magnitude and a phase in degrees.

    Parameters
    ----------
    db : number or array_like
        Magnitude(s) in dB (``20*log10`` convention, real).
    deg : number or array_like
        Phase(s) in degrees (real).

    Returns
    -------
    z : array_like
        ``10**(db/20) * exp(j * deg * pi / 180)``.
    """
    linear = 10 ** (db / 20.)
    return linear * jnp.exp(1j * deg * pi / 180.)
[docs]
def db_2_np(db: NumberLike):
    """
    Convert decibels (dB) to nepers (Np): ``Np = dB * ln(10) / 20``.

    Parameters
    ----------
    db : number or array_like
        Value(s) in dB (real).

    Returns
    -------
    np : number or array_like
        Value(s) in nepers.
    """
    scale = jnp.log(10) / 20
    return scale * db
[docs]
def np_2_db(x: NumberLike):
    """
    Convert nepers (Np) to decibels (dB): ``dB = Np * 20 / ln(10)``.

    Parameters
    ----------
    x : number or array_like
        Value(s) in nepers (real).

    Returns
    -------
    db : number or array_like
        Value(s) in dB.
    """
    factor = 20 / jnp.log(10)
    return factor * x
[docs]
def radian_2_degree(rad: NumberLike):
    """
    Convert angles from radians to degrees.

    Parameters
    ----------
    rad : number or array_like
        Angle(s) in radians.

    Returns
    -------
    deg : number or array_like
        Angle(s) in degrees.
    """
    scaled = rad * 180
    return scaled / pi
[docs]
def degree_2_radian(deg: NumberLike):
    """
    Convert angles from degrees to radians.

    Parameters
    ----------
    deg : number or array_like
        Angle(s) in degrees.

    Returns
    -------
    rad : number or array_like
        Angle(s) in radians.
    """
    scaled = deg * pi
    return scaled / 180.
[docs]
def feet_2_meter(feet: NumberLike = 1):
    """
    Convert a length in feet to meters (1 ft = 0.3048 m).

    Parameters
    ----------
    feet : number or array-like, optional
        Length in feet. Default is 1.

    Returns
    -------
    meter : number or array-like
        Length in meters.

    See Also
    --------
    meter_2_feet
    """
    meters_per_foot = 0.3048
    return meters_per_foot * feet
[docs]
def meter_2_feet(meter: NumberLike = 1):
    """
    Convert a length in meters to feet (1 m = 3.28084 ft).

    The factor is rounded, so ``meter_2_feet(feet_2_meter(x))`` is only
    approximately ``x`` (the exact factor is ``1 / 0.3048``).

    Parameters
    ----------
    meter : number or array-like, optional
        Length in meters. Default is 1.

    Returns
    -------
    feet : number or array-like
        Length in feet.

    See Also
    --------
    feet_2_meter
    """
    feet_per_meter = 3.28084
    return feet_per_meter * meter
[docs]
def db_per_100feet_2_db_per_100meter(db_per_100feet: NumberLike = 1):
    """
    Convert an attenuation in dB/100 ft to dB/100 m.

    Equivalent to ``db_per_100feet * 100 / feet_2_meter(100)``, i.e. dividing
    by 0.3048.

    Parameters
    ----------
    db_per_100feet : number or array-like, optional
        Attenuation in dB / 100 ft. Default is 1.

    Returns
    -------
    db_per_100meter : number or array-like
        Attenuation in dB / 100 m.

    See Also
    --------
    meter_2_feet
    feet_2_meter
    np_2_db
    db_2_np
    """
    hundred_feet_in_meters = 0.3048 * 100
    return db_per_100feet * 100 / hundred_feet_in_meters
[docs]
def unwrap_rad(phi: NumberLike):
    """
    Unwrap a phase given in radians along axis 0.

    Parameters
    ----------
    phi : number or array_like
        Phase in radians.

    Returns
    -------
    phi : number or array_like
        Unwrapped phase in radians.
    """
    unwrapped = jnp.unwrap(phi, axis=0)
    return unwrapped
[docs]
def sqrt_known_sign(z_squared: NumberLike, z_approx: NumberLike):
    """
    Return the square root of a complex number, with sign chosen to match `z_approx`.

    The principal root ``z = sqrt(z_squared)`` is kept where the sign of its
    angle equals the sign of ``angle(z_approx)``. Elsewhere the other root,
    ``-z``, is returned. Either way the result squares to `z_squared`.

    Parameters
    ----------
    z_squared : number or array-like
        The complex value(s) to be square-rooted.
    z_approx : number or array-like
        Approximate value of the root; only the sign of its angle is used.

    Returns
    -------
    z : number, array-like
        Square root of `z_squared`.
    """
    principal = jnp.sqrt(z_squared)
    same_side = jnp.sign(jnp.angle(principal)) == jnp.sign(jnp.angle(z_approx))
    # The two square roots of z_squared are +z and -z. conj(z) is NOT a root
    # (it squares to conj(z_squared)), so the alternative must be -z.
    return jnp.where(same_side, principal, -principal)
[docs]
def find_correct_sign(z1: NumberLike, z2: NumberLike, z_approx: NumberLike):
    r"""
    Pick, element-wise, whichever of `z1`/`z2` lies on the same side as `z_approx`.

    Intended for root selection when both candidate roots are known,
    :math:`z_1, z_2 = \pm\sqrt{z^2}`, and an approximate answer is available.
    "Same side" means the sign of the angle (``sign(angle(.))``) agrees.

    Parameters
    ----------
    z1 : array-like
        First candidate root.
    z2 : array-like
        Second candidate root.
    z_approx : array-like
        Approximate answer.

    Returns
    -------
    z3 : jnp.ndarray
        `z1` where ``sign(angle(z1)) == sign(angle(z_approx))``, otherwise `z2`.
    """
    z1_side = jnp.sign(jnp.angle(z1))
    approx_side = jnp.sign(jnp.angle(z_approx))
    return jnp.where(z1_side == approx_side, z1, z2)
[docs]
def find_closest(z1: NumberLike, z2: NumberLike, z_approx: NumberLike):
    """
    Pick, element-wise, whichever of `z1`/`z2` is nearer to `z_approx`.

    Ties (equal distance) resolve to `z2`.

    Parameters
    ----------
    z1 : array-like
        First candidate.
    z2 : array-like
        Second candidate.
    z_approx : array-like
        Approximate answer.

    Returns
    -------
    z3 : jnp.ndarray
        Array built from `z1` and `z2`.
    """
    closer_to_first = abs(z1 - z_approx) < abs(z2 - z_approx)
    return jnp.where(closer_to_first, z1, z2)
[docs]
def sqrt_phase_unwrap(z: NumberLike):
    r"""
    Take the square root of complex values using their unwrapped phase.

    This idea came from Lihan Chen. The phase is unwrapped along axis 0
    before halving:

    .. math::
        \sqrt{|z|} \exp\left( j \arg_{unwrap}(z) / 2 \right)

    Parameters
    ----------
    z : number or array_like
        Complex value or sequence of complex values.

    Returns
    -------
    z : number or array_like
        Complex square root(s) with continuous phase.
    """
    unwrapped_phase = jnp.unwrap(jnp.angle(z), axis=0)
    return jnp.sqrt(abs(z)) * jnp.exp(0.5 * 1j * unwrapped_phase)
# mathematical functions
[docs]
def dirac_delta(x: NumberLike):
    r"""
    Discrete delta: :math:`\delta(x) = 1` if ``x == 0``, else 0.

    (This is the Kronecker-style delta on exact equality with zero, returned
    as a float.)

    Parameters
    ----------
    x : number or array_like
        Real value or sequence of real values.

    Returns
    -------
    delta : number or array_like
        1.0 where `x` equals 0, otherwise 0.0.

    References
    ----------
    https://en.wikipedia.org/wiki/Dirac_delta_function
    """
    return 1.0 * (x == 0)
[docs]
def neuman(x: NumberLike):
    r"""
    Neumann factor: :math:`2 - \delta(x)`.

    Evaluates to 1 where ``x == 0`` and to 2 everywhere else.

    Parameters
    ----------
    x : number or array_like
        Real value or sequence of real values.

    Returns
    -------
    y : number or array_like
        1.0 or 2.0 element-wise.

    See Also
    --------
    dirac_delta
    """
    delta = 1.0 * (x == 0)
    return 2. - delta
[docs]
def null(A: jnp.ndarray, eps: float = 1e-15):
    """
    Calculate an orthonormal basis of the null space of matrix `A`.

    Parameters
    ----------
    A : array_like
        Matrix of shape ``(m, n)``, real or complex.
    eps : float
        Singular values ``<= eps`` are treated as zero. The default is far
        below single-precision round-off; float32 input usually needs a
        larger tolerance.

    Returns
    -------
    null_space : array_like
        Array of shape ``(n, k)`` whose columns span the null space, so that
        ``A @ null_space`` is approximately zero.

    References
    ----------
    https://scipy-cookbook.readthedocs.io/items/RankNullspace.html
    https://stackoverflow.com/questions/5889142/python-numpy-scipy-finding-the-null-space-of-a-matrix
    """
    _, s, vh = jnp.linalg.svd(A)
    # vh is (n, n) but s has only min(m, n) entries. For a wide matrix the
    # trailing rows of vh have no singular value at all and are null vectors,
    # so select rows by numerical rank instead of masking with s (a mask as
    # long as s would silently drop those rows).
    rank = int(jnp.sum(s > eps))
    # A = U S Vh, so null vectors are columns of V = Vh^H. The conjugate is
    # required for complex A and is a no-op for real A.
    return vh[rank:].conj().T
[docs]
def inf_to_num(x: NumberLike):
    """
    Replace +inf / -inf with the large finite module constant ``INF``.

    NaN values are passed through unchanged.

    Parameters
    ----------
    x : array-like or number
        The input array or number.

    Returns
    -------
    x : number or array_like
        `x` with +inf replaced by ``INF`` and -inf by ``-INF``.
    """
    return jnp.nan_to_num(x, nan=jnp.nan, posinf=INF, neginf=-INF)
[docs]
def cross_ratio(a: NumberLike, b: NumberLike, c: NumberLike, d: NumberLike):
    r"""
    Calculate the cross ratio of a quadruple of points.

    .. math::
        r = \frac{ (a-b)(c-d) }{ (a-d)(c-b) }

    Parameters
    ----------
    a, b, c, d : array-like or number
        The four points (real or complex).

    Returns
    -------
    r : array-like or number
        The cross ratio.

    References
    ----------
    https://en.wikipedia.org/wiki/Cross-ratio
    """
    numerator = (a - b) * (c - d)
    denominator = (a - d) * (c - b)
    return numerator / denominator
[docs]
def complexify(f: Callable, name: str = None):
    """
    Lift a function of a real variable to a function of a complex variable.

    Given ``f(x)``, returns ``f_c(z) = f(real(z)) + 1j*f(imag(z))``. Extra
    positional/keyword arguments of ``f_c`` are forwarded to `f`.

    Parameters
    ----------
    f : Callable
        Function of a real variable.
    name : string, optional
        If given, the real/imag part is passed to `f` as the keyword argument
        `name` instead of as the first positional argument. An explicit
        keyword of the same name passed to ``f_c`` takes precedence.

    Returns
    -------
    f_c : Callable
        Function of a complex variable.

    Examples
    --------
    >>> def f(x): return x
    >>> f_c = complexify(f)
    >>> f_c(0.2 - 1j*0.3)  # f(0.2) + 1j*f(-0.3)
    """
    def f_c(z, *args, **kw):
        if name is None:
            return f(real(z), *args, **kw) + 1j * f(imag(z), *args, **kw)
        # Later keys win, so caller-supplied kwargs override the part value.
        re_kwargs = {name: real(z), **kw}
        im_kwargs = {name: imag(z), **kw}
        return f(*args, **re_kwargs) + 1j * f(*args, **im_kwargs)
    return f_c
# old functions just for reference
[docs]
def complex2Scalar(z: NumberLike):
    """
    Serialize complex numbers into an interleaved real/imag array.

    Parameters
    ----------
    z : number or array_like
        A complex number or sequence of complex numbers. A scalar is treated
        as a length-1 sequence. For N-d input the interleaving happens along
        the first axis.

    Returns
    -------
    re_im : array_like
        Flat array ``z[0].real, z[0].imag, z[1].real, z[1].imag, ...``.

    See Also
    --------
    scalar2Complex
    """
    # atleast_1d: a scalar could not be iterated by the old element loop.
    z = jnp.atleast_1d(jnp.asarray(z))
    # Stacking on a new axis 1 puts real(z[k]) directly before imag(z[k]);
    # this is vectorised and matches the old per-element loop ordering.
    return jnp.stack([jnp.real(z), jnp.imag(z)], axis=1).flatten()
[docs]
def scalar2Complex(s: NumberLike):
    """
    Unserialize interleaved real/imag values into a complex array.

    Inverse of :func:`complex2Scalar`.

    Parameters
    ----------
    s : array_like
        Values ordered as ``z[0].real, z[0].imag, z[1].real, z[1].imag, ...``
        along the first axis.

    Returns
    -------
    z : number or array_like
        Flat complex array.

    Raises
    ------
    ValueError
        If the first axis of `s` has odd length (an unpaired real part).

    See Also
    --------
    complex2Scalar
    """
    s = jnp.asarray(s)
    if len(s) % 2 != 0:
        # JAX does not reliably raise on an out-of-range s[k + 1], so check
        # explicitly rather than returning a garbage last element.
        raise ValueError(f"Expected an even number of entries (real/imag pairs), got {len(s)}.")
    return (s[0::2] + 1j * s[1::2]).flatten()
[docs]
def multiply_by(x: jnp.ndarray, by: jnp.ndarray, axis=None) -> jnp.ndarray:
    """
    Multiply `x` by `by`, repeating `by` along `axis` when the shapes differ.

    Parameters
    ----------
    x : jnp.ndarray
        Array to scale. When shapes differ, a 1-D `x` is first reshaped to a
        column (``axis=0``) or a row (otherwise), so the result is 2-D.
    by : jnp.ndarray
        Multiplier. If ``by.shape == x.shape`` it is applied element-wise.
        Otherwise it is tiled ``x.shape[axis] // len(by)`` times along `axis`.
        For ``axis=0`` the tiled shape only matches `x` if `by` is a column
        ``(n, 1)``; presumably callers pass that. TODO confirm.
    axis : {None, 0, 1}, optional
        Axis along which `by` is repeated. With None, no tiling happens and
        ordinary broadcasting applies.

    Returns
    -------
    jnp.ndarray
        The product (a new array; `x` is not modified in place).

    Raises
    ------
    ValueError
        If ``x.shape[axis]`` is not divisible by ``len(by)``.
    """
    if by.shape == x.shape:
        # Out-of-place product: ``x *= by`` would mutate a NumPy input.
        return x * by
    n = len(by)
    if x.ndim == 1:
        x = x.reshape(-1, 1) if axis == 0 else x.reshape(1, -1)
    # Compare against None explicitly: a truthiness test skipped axis=0.
    if axis is not None and x.shape[axis] % n != 0:
        raise ValueError(f"The length of the specified axis ({x.shape[axis]}) is not divisible by {n}.")
    if axis == 0:
        by = jnp.tile(by, (x.shape[0] // n, x.shape[1]))
    elif axis == 1:
        by = jnp.tile(by, (x.shape[0], x.shape[1] // n))
    return x * by
[docs]
def sum_every(x: jnp.ndarray, n: int, axis=None) -> jnp.ndarray:
    """
    Sum consecutive groups of `n` entries along `axis`.

    Parameters
    ----------
    x : jnp.ndarray
        1-D or 2-D array. 1-D input is treated as a column (``axis=0``) or a
        row (otherwise), so the result is 2-D.
    n : int
        Group size.
    axis : {0, 1}
        Axis to group along. The default None is not usable (it cannot index
        ``x.shape``).

    Returns
    -------
    jnp.ndarray
        Array whose `axis` length is divided by `n`.

    Raises
    ------
    ValueError
        If ``x.shape[axis]`` is not divisible by `n`.
    """
    if x.ndim == 1:
        # Promote 1-D input unconditionally. Previously this only happened
        # when the length was divisible by n, so a bad length hit an
        # IndexError on x.shape[1] instead of the ValueError below.
        x = x.reshape(-1, 1) if axis == 0 else x.reshape(1, -1)
    if x.shape[axis] % n != 0:
        raise ValueError(f"The length of the specified axis ({x.shape[axis]}) is not divisible by {n}.")
    if axis == 0:
        # Group rows: (rows/n, n, cols) -> sum over each group of n rows.
        x = x.reshape(-1, n, x.shape[1]).sum(axis=1)
    elif axis == 1:
        # Group columns: (rows, cols/n, n) -> sum over each group of n cols.
        x = x.reshape(x.shape[0], -1, n).sum(axis=2)
    return x
[docs]
def multiply_every(x: jnp.ndarray, n: int, axis=None) -> jnp.ndarray:
    """
    Multiply together consecutive groups of `n` entries along `axis`.

    Parameters
    ----------
    x : jnp.ndarray
        1-D or 2-D array. 1-D input is treated as a column (``axis=0``) or a
        row (otherwise), so the result is 2-D.
    n : int
        Group size.
    axis : {0, 1}
        Axis to group along. The default None is not usable (it cannot index
        ``x.shape``).

    Returns
    -------
    jnp.ndarray
        Array whose `axis` length is divided by `n`.

    Raises
    ------
    ValueError
        If ``x.shape[axis]`` is not divisible by `n`.
    """
    if x.ndim == 1:
        # Promote 1-D input unconditionally so a length that is not a multiple
        # of n reaches the ValueError below instead of an IndexError.
        x = x.reshape(-1, 1) if axis == 0 else x.reshape(1, -1)
    if x.shape[axis] % n != 0:
        raise ValueError(f"The length of the specified axis ({x.shape[axis]}) is not divisible by {n}.")
    if axis == 0:
        # Group rows and take the product within each group.
        x = x.reshape(-1, n, x.shape[1]).prod(axis=1)
    elif axis == 1:
        # Group columns and take the product within each group.
        x = x.reshape(x.shape[0], -1, n).prod(axis=2)
    return x
[docs]
def convolve_interleaved(x, axis=1) -> jnp.ndarray:
    """
    Convolve the even-indexed samples of `x` with its odd-indexed samples.

    Only 1-D input with ``axis=1`` (the default) is supported.

    Parameters
    ----------
    x : jnp.ndarray
        1-D array ``[a0, b0, a1, b1, ...]``.
    axis : int, optional
        Must be 1.

    Returns
    -------
    jnp.ndarray
        ``jnp.convolve(x[0::2], x[1::2])``.

    Raises
    ------
    NotImplementedError
        For ``axis != 1`` or input that is not 1-D. This is a subclass of
        ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    if axis != 1 or len(x.shape) != 1:
        raise NotImplementedError("Not yet implemented")
    evens, odds = x[0::2], x[1::2]
    return jnp.convolve(evens, odds)
[docs]
def rms(x):
    """
    Root-mean-square over all elements of `x`: ``sqrt(mean(|x|**2))``.

    ``|x|**2`` is used instead of ``x**2`` so that complex input yields the
    real RMS magnitude; for real input the two are identical.
    """
    return jnp.sqrt(jnp.mean(jnp.abs(x) ** 2))
[docs]
def mag_2_db(x):
    """
    Convert to dB as ``20 * log10(|x|)``.

    The absolute value is taken first, so complex input is accepted. Unlike
    ``magnitude_2_db``, no NaN/inf substitution is done.

    NOTE: this definition rebinds the module-level name ``mag_2_db``, which an
    earlier assignment in the module makes an alias of ``magnitude_2_db``.
    """
    modulus = jnp.abs(x)
    return 20 * jnp.log10(modulus)
[docs]
def db_2_mag(x):
    """
    Convert dB to linear magnitude as ``10 ** (x / 20)``.

    NOTE: this definition rebinds the module-level name ``db_2_mag``, which an
    earlier assignment in the module makes an alias of ``db_2_magnitude``.
    """
    scaled = x / 20
    return 10 ** scaled
[docs]
def polar_2_rect(radii, angles, deg=False):
    """
    Build complex values ``radii * exp(j * angles)`` from polar coordinates.

    `angles` are in radians, or in degrees when ``deg=True``.
    """
    theta = jnp.deg2rad(angles) if deg else angles
    return radii * jnp.exp(1j * theta)
[docs]
def rect_2_polar(x, deg=False):
    """
    Split complex values into ``(magnitude, angle)``.

    The angle is in radians, or in degrees when ``deg=True``.
    """
    magnitude = abs(x)
    phase = jnp.angle(x, deg=deg)
    return magnitude, phase
[docs]
def round_sig(x, sig=3):
    """
    Round a scalar `x` to `sig` significant figures.

    Zero is returned as the integer 0. The decimal exponent is obtained with
    ``floor(log10(|x|))`` and passed to the built-in :func:`round`.
    """
    if x == 0:
        return 0
    exponent = int(jnp.floor(jnp.log10(abs(x))))
    return round(x, sig - exponent - 1)
[docs]
def comb(N: jnp.ndarray, k: jnp.ndarray, exact: bool = False, repetition: bool = False):
    r"""The number of combinations of N things taken k at a time.

    This is often expressed as "N choose k".

    Args:
      N: The number of things.
      k: The number of elements taken.
      exact: If `True`, the result is computed exactly and returned as an integer type.
        Currently, vectorization is not supported for exact=True. Out-of-range
        `k` (``k < 0`` or ``k > N``) yields 0.
      repetition: If `repetition` is True, then the number of combinations with
        repetition is computed.

    Returns:
      The number of combinations of N things taken k at a time.

    Notes:
      When exact=False, the result is approximately and efficiently computed using the following formula:

      .. math::
        \begin{equation}
        \exp\left\{\ln{\Gamma(N+1)} - [\ln{\Gamma(k+1)} + \ln{\Gamma(N+1-k)}]\right\}
        \end{equation}

      Where we use the Gamma function. For out-of-range `k`, one log-gamma
      term is +inf, so the result is 0.
    """
    if repetition:
        # Multiset coefficient: C(N + k - 1, k).
        return comb(N + k - 1, k, exact=exact, repetition=False)
    if exact:
        if k < 0 or k > N:
            # Both range products below would be empty (== 1), wrongly
            # returning 1. There is no way to pick k < 0 or k > N items.
            return jnp.array(0)
        max_divisor = lax.max(k, N - k)
        min_divisor = lax.min(k, N - k)
        # N! / max! == N * (N - 1) * ... * (max + 1)
        N_factorial_over_max_factorial = jnp.prod(jnp.arange(N, max_divisor, -1))
        return lax.div(N_factorial_over_max_factorial, jnp.prod(jnp.arange(1, min_divisor + 1)))
    # Plain (promoting) arithmetic instead of the private jax._src
    # ``_constant_like`` helper and dtype-strict lax ops, so mixed int/float
    # N and k also work.
    N_plus_1 = N + 1
    return jnp.exp(gammaln(N_plus_1) - (gammaln(k + 1) + gammaln(N_plus_1 - k)))
[docs]
def evaluate_power_basis(x, coeffs, lower_bound, upper_bound):
    """
    Evaluate ``sum_i coeffs[i] * t**i`` with ``t`` = `x` mapped onto [0, 1].

    ``t = (x - lower_bound) / (upper_bound - lower_bound)``. `coeffs` are
    ordered lowest degree first.
    """
    t = (x - lower_bound) / (upper_bound - lower_bound)
    # jnp.polyval wants the highest-degree coefficient first.
    highest_first = jnp.asarray(coeffs)[::-1]
    return jnp.polyval(highest_first, t)
[docs]
def evaluate_bernstein_basis(x, coeffs, lower_bound, upper_bound):
    """
    Evaluate a polynomial given in Bernstein form on ``[lower_bound, upper_bound]``.

    With ``t = (x - lower_bound) / (upper_bound - lower_bound)`` and
    ``n = len(coeffs) - 1``, computes
    ``sum_i coeffs[i] * C(n, i) * t**i * (1 - t)**(n - i)``.

    Returns
    -------
    jnp.ndarray
        1-D array with one value per element of ``jnp.atleast_1d(t)``; a
        scalar `x` therefore yields a length-1 array.
    """
    coeffs = jnp.asarray(coeffs)
    degree = len(coeffs) - 1
    idx = jnp.arange(degree + 1)
    weights = comb(degree, idx)
    t = (x - lower_bound) / (upper_bound - lower_bound)

    def _value_at(t_k):
        # Bernstein basis polynomials b_{i,n}(t_k) without the binomial factor.
        basis = jnp.power(t_k, idx) * jnp.power(1 - t_k, degree - idx)
        return jnp.dot(coeffs, weights * basis)

    return jax.vmap(_value_at)(jnp.atleast_1d(t))
# Aliases, lookups and partials
conv_inter = convolve_interleaved
# Captures whichever ``mag_2_db`` is bound when this line runs.
dB20 = mag_2_db

# jnp.linalg.norm with the order (and optionally the axis) fixed.
l1_norm = partial(jnp.linalg.norm, ord=1)
l1_norm_ax0 = partial(jnp.linalg.norm, ord=1, axis=0)
l1_norm_ax1 = partial(jnp.linalg.norm, ord=1, axis=1)
l2_norm = partial(jnp.linalg.norm, ord=2)
l2_norm_ax0 = partial(jnp.linalg.norm, ord=2, axis=0)
l2_norm_ax1 = partial(jnp.linalg.norm, ord=2, axis=1)


def conv_cost(x):
    """
    Scalar dB cost built from an interleaved self-convolution.

    Takes the 2-norm of `x` along axis 0, convolves the even-indexed entries
    of that vector with the odd-indexed ones (``conv_inter``), takes the
    2-norm of the result along axis 0, and converts it with ``dB20``.
    """
    column_norms = l2_norm_ax0(x)
    convolved = conv_inter(column_norms)
    return dB20(l2_norm_ax0(convolved))
# Short key -> (display label, function applied to the complex data).
FUNC_LOOKUP: dict[str, tuple[str, Callable | None]] = dict(
    # Cartesian parts
    re=('Real Part', jnp.real),
    im=('Imag Part', jnp.imag),
    # Magnitude
    mag=('Magnitude', jnp.abs),
    db=('Magnitude (dB)', complex_2_db),
    db10=('Magnitude (dB)', complex_2_db10),
    # Wrapped phase
    rad=('Phase (rad)', jnp.angle),
    deg=('Phase (deg)', lambda x: jnp.angle(x, deg=True)),
    arcl=('Arc Length', lambda x: jnp.angle(x) * jnp.abs(x)),
    # Phase unwrapped along axis 0 (via unwrap_rad)
    rad_unwrap=('Phase (rad)', lambda x: unwrap_rad(jnp.angle(x))),
    deg_unwrap=('Phase (deg)', lambda x: radian_2_degree(unwrap_rad(jnp.angle(x)))),
    arcl_unwrap=('Arc Length', lambda x: unwrap_rad(jnp.angle(x)) * jnp.abs(x)),
    # (1 + |x|) / (1 - |x|); infinite at |x| == 1
    vswr=('VSWR', lambda x: (1 + jnp.abs(x)) / (1 - jnp.abs(x))),
    # NOTE: the time-domain entries below are disabled; ``mf`` is not
    # imported in this module.
    # 'time': ('Time (real)', mf.ifft),
    # 'time_db': ('Magnitude (dB)', lambda x: mf.complex_2_db(mf.ifft(x))),
    # 'time_mag': ('Magnitude', lambda x: mf.complex_2_magnitude(mf.ifft(x))),
)