Complex square roots
Negative input values
When using numpy as back-end, sympy lambdifies a sqrt() to a numpy.sqrt:
import inspect

import numpy as np
import sympy as sp

x = sp.Symbol("x")
sqrt_expr = sp.sqrt(x)
sqrt_expr
np_sqrt = sp.lambdify(x, sqrt_expr, "numpy")
source = inspect.getsource(np_sqrt)
print(source)
def _lambdifygenerated(x):
    return (sqrt(x))
As expected, if the input values for numpy.sqrt are negative, numpy raises a RuntimeWarning and returns NaN:
sample = np.linspace(-1, 1, 5)
np_sqrt(sample)
array([ nan, nan, 0. , 0.70710678, 1. ])
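The RuntimeWarning can also be captured or turned into an error explicitly. The following is a minimal sketch that uses only the standard warnings module and numpy.errstate(). It re-creates sample so that the block runs on its own.

```python
import warnings

import numpy as np

sample = np.linspace(-1, 1, 5)

# Record warnings instead of letting them print to stderr
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = np.sqrt(sample)  # negative entries become nan

print(result)
print([w.category.__name__ for w in caught])

# Alternatively, escalate NumPy's "invalid value" condition to an exception
with np.errstate(invalid="raise"):
    try:
        np.sqrt(sample)
    except FloatingPointError as exc:
        print(f"FloatingPointError: {exc}")
```

Raising is often easier to debug than silently propagating NaN through a larger model.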
To make numpy return imaginary numbers for negative input values, one can use complex input data instead (e.g. numpy.complex64). Negative values are then treated as lying just above the real axis, so that their square root is a positive imaginary number:
complex_sample = sample.astype(np.complex64)
np_sqrt(complex_sample)
array([0. +1.j , 0. +0.70710677j,
0. +0.j , 0.70710677+0.j ,
1. +0.j ], dtype=complex64)
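"Just above the real axis" refers to the sign of the zero imaginary part. NumPy's complex square root follows the principal branch, whose cut lies along the negative real axis. A negative real number with imaginary part +0.0 therefore maps to +i, and one with -0.0 maps to -i. The explicit -0.0 input below is only an illustration; astype() does not produce it.

```python
import numpy as np

# Same real part; imaginary parts +0.0 and -0.0 (complex() keeps the sign of zero)
above_cut = np.complex128(complex(-1.0, 0.0))
below_cut = np.complex128(complex(-1.0, -0.0))

# The sign of the zero imaginary part selects the side of the branch cut
print(np.sqrt(above_cut))  # value: +1j
print(np.sqrt(below_cut))  # value: -1j

# astype() fills the imaginary part with +0.0, so negative inputs sit "above" the cut
complex_sample = np.linspace(-1, 1, 5).astype(np.complex64)
print(np.signbit(complex_sample.imag).any())  # False
```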
A sympy.sqrt lambdified to JAX exhibits the same behavior:
import jax
import jax.numpy as jnp

jax_sqrt = jax.jit(sp.lambdify(x, sqrt_expr, jnp))
source = inspect.getsource(jax_sqrt)
print(source)
def _lambdifygenerated(x):
    return (sqrt(x))
jax_sqrt(sample)
DeviceArray([ nan, nan, 0. , 0.70710677, 1. ], dtype=float32)
jax_sqrt(complex_sample)
DeviceArray([-4.3711388e-08+1.j , -3.0908620e-08+0.70710677j,
0.0000000e+00+0.j , 7.0710677e-01+0.j ,
1.0000000e+00+0.j ], dtype=complex64)
There is a problem with this approach, though. Once the input data is complex, every square root in a larger expression (such as an amplitude model) returns imaginary values for negative arguments, and that is not always the desired behavior.
Take, for instance, the two square roots appearing in PhaseSpaceFactor: does \(\sqrt{s}\) also have to be evaluatable for negative \(s\)?
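The following minimal sketch makes this concern concrete. It uses a hypothetical expression with two square roots, \(\sqrt{s}\,\sqrt{s-4}\), chosen only for illustration; it is not the definition of PhaseSpaceFactor. Casting the input to complex makes both square roots imaginary at once. For \(s=-1\), their product therefore becomes real and negative: \(i \cdot i\sqrt{5} = -\sqrt{5}\).

```python
import numpy as np
import sympy as sp

s = sp.Symbol("s")

# Hypothetical expression with two square roots (NOT the actual PhaseSpaceFactor)
expr = sp.sqrt(s) * sp.sqrt(s - 4)
func = sp.lambdify(s, expr, "numpy")

values = np.array([-1.0, 5.0])
with np.errstate(invalid="ignore"):
    real_result = func(values)  # s = -1 gives nan
complex_result = func(values.astype(np.complex128))  # both roots imaginary at s = -1

print(real_result)
print(complex_result)
```

With complex input, there is no way to keep one of the two square roots real while letting the other one become imaginary.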
Complex square root
NumPy also offers a special function that returns complex results for negative values, even if the input values are real: numpy.emath.sqrt():
np.emath.sqrt(-1)
1j
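The output type of numpy.emath.sqrt() depends on the data. In NumPy's documented behavior, the result stays real when all inputs are non-negative and becomes complex as soon as any input is negative. A short sketch:

```python
import numpy as np

# All entries non-negative: the result stays real (float64)
real_out = np.emath.sqrt(np.array([1.0, 4.0]))
# Any negative entry: the whole result is promoted to complex
complex_out = np.emath.sqrt(np.array([-1.0, 4.0]))
print(real_out, real_out.dtype)
print(complex_out, complex_out.dtype)

# For comparison, np.sqrt returns nan for -1.0 (warning silenced here)
with np.errstate(invalid="ignore"):
    print(np.sqrt(np.array([-1.0, 4.0])))
```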
Unfortunately, the jax.numpy API does not provide an interface to numpy.emath. It is possible to decorate numpy.emath.sqrt() with jax.jit(), but that only works with static, hashable arguments:
jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
jax_csqrt_error(-1)  # fails: np.emath.sqrt cannot convert the traced argument into a NumPy array
jax_csqrt = jax.jit(np.emath.sqrt, backend="cpu", static_argnums=0)
jax_csqrt(-1)
DeviceArray(0.+1.j, dtype=complex64)
jax_csqrt(sample)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [15], in <cell line: 1>()
----> 1 jax_csqrt(sample)
ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1. -0.5 0. 0.5 1. ]. The error was:
TypeError: unhashable type: 'numpy.ndarray'
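Static arguments act as cache keys for compiled functions. That is why they must be hashable, and why every new static value triggers a fresh trace. The following sketch makes both points visible. It uses a hypothetical wrapper, counted_emath_sqrt, that counts how often JAX executes the Python function.

```python
import jax
import numpy as np

trace_count = 0


def counted_emath_sqrt(x):
    global trace_count
    trace_count += 1  # Python side effect: only runs when JAX (re)traces
    return np.emath.sqrt(x)


# First argument static, as with jax_csqrt
jitted = jax.jit(counted_emath_sqrt, static_argnums=0)
jitted(-1.0)
jitted(-1.0)  # same static value: cached, no retrace
jitted(-4.0)  # new static value: traced and compiled again
print(trace_count)

# NumPy arrays define no hash, so they cannot serve as static arguments
try:
    hash(np.linspace(-1, 1, 5))
except TypeError as exc:
    print(f"TypeError: {exc}")
```

Even for hashable inputs, this approach compiles once per distinct value, so it does not scale to array-valued data.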
Conditional square root
To control which square roots in the complete expression should be evaluatable for negative values, one could use a Piecewise:
from IPython.display import display


def complex_sqrt(x: sp.Symbol) -> sp.Expr:
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )


complex_sqrt(x)
display(
    complex_sqrt(-4),
    complex_sqrt(+4),
)
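For real numeric arguments, this Piecewise does not change the mathematics. It reproduces the principal square root that SymPy itself uses (sp.sqrt(-4) already evaluates to 2*I). Its purpose is only to control what the numerical back-end does. A quick check, re-defining complex_sqrt so that the block runs on its own:

```python
import sympy as sp


def complex_sqrt(x) -> sp.Expr:
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )


# For real numbers, the Piecewise coincides with SymPy's principal square root
for value in [-9, -2, 0, 3, 16]:
    print(value, complex_sqrt(value), sp.sqrt(value))
    assert sp.simplify(complex_sqrt(value) - sp.sqrt(value)) == 0
```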
Be careful, though, when lambdifying this expression: do not use the __dict__ of the numpy module as back-end; use the module itself instead. With __dict__, lambdify() returns an if-else statement, which is inefficient and, worse, causes problems with JAX:
Warning
Do not use the module __dict__ for the modules argument of lambdify().
np_complex_sqrt_no_select = sp.lambdify(x, complex_sqrt(x), np.__dict__)
source = inspect.getsource(np_complex_sqrt_no_select)
print(source)
def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
np_complex_sqrt_no_select(-1)
1j
jax_complex_sqrt_no_select = jax.jit(np_complex_sqrt_no_select)
jax_complex_sqrt_no_select(-1)  # fails: the if statement needs a concrete bool, but x is a tracer
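The failure under jax.jit() comes from Python-level branching. Inside a jitted function, x is an abstract tracer, so x < 0 has no concrete truth value that an if statement could use. The following minimal sketch shows this with a hypothetical helper, branching_sqrt, that mirrors the generated if-else code. It assumes a JAX version in which jax.errors.ConcretizationTypeError exists.

```python
import jax
import jax.numpy as jnp


def branching_sqrt(x):
    # Same structure as the if-else code that lambdify() generates from np.__dict__
    return 1j * jnp.sqrt(-x) if x < 0 else jnp.sqrt(x)


print(branching_sqrt(-1.0))  # works: x < 0 is a concrete Python bool here

try:
    jax.jit(branching_sqrt)(-1.0)
except jax.errors.ConcretizationTypeError as exc:
    # Under jit, x is a tracer, so `if x < 0` cannot be decided at trace time
    print(type(exc).__name__)
```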
When instead using the numpy module (or "numpy"), lambdify() correctly lambdifies to numpy.select() to represent the cases:
np_complex_sqrt = sp.lambdify(x, complex_sqrt(x), np)
source = inspect.getsource(np_complex_sqrt)
print(source)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
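Evaluated with plain NumPy, this select()-based function handles whole arrays. The following self-contained sketch applies it to the same sample as above. numpy.select() evaluates every branch on every element, so the unused branch still takes square roots of negative numbers. The resulting "invalid value" warnings are silenced here.

```python
import numpy as np
import sympy as sp

x = sp.Symbol("x")
expr = sp.Piecewise((sp.sqrt(-x) * sp.I, x < 0), (sp.sqrt(x), True))
np_complex_sqrt = sp.lambdify(x, expr, np)  # pass the module, not np.__dict__

sample = np.linspace(-1, 1, 5)
# Both branches are computed for all elements; select() keeps the right one
with np.errstate(invalid="ignore"):
    result = np_complex_sqrt(sample)
print(result)
```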
Still, JAX does not handle this correctly. First, lambdifying to JAX again results in the if-else syntax:
jnp_complex_sqrt = sp.lambdify(x, complex_sqrt(x), jnp)
source = inspect.getsource(jnp_complex_sqrt)
print(source)
def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
But even if we lambdify to numpy and decorate the result with a jax.jit() decorator, the resulting function does not work properly:
jax_complex_sqrt_error = jax.jit(np_complex_sqrt)
source = inspect.getsource(jax_complex_sqrt_error)
print(source)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_complex_sqrt_error(-1)
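In this function, select(), less(), and sqrt() are plain NumPy functions. NumPy functions cannot operate on the abstract tracers that jax.jit() passes in, which is most likely what breaks this call. The following minimal sketch shows that mechanism with a hypothetical helper, numpy_sqrt. It assumes a JAX version that provides jax.errors.TracerArrayConversionError.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_sqrt(x):
    # Calls a plain NumPy ufunc, like the numpy-lambdified function does
    return np.sqrt(x)


print(numpy_sqrt(jnp.array([1.0, 4.0])))  # works outside jit: concrete values

try:
    jax.jit(numpy_sqrt)(jnp.array([1.0, 4.0]))
except jax.errors.TracerArrayConversionError as exc:
    # Under jit, x is a tracer that NumPy cannot convert into an ndarray
    print(type(exc).__name__)
```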
The very same function, written purely with jax.numpy, works without problems, so this seems to be a SymPy problem:
@jax.jit
def jax_complex_sqrt(x):
    return jnp.select(
        [jnp.less(x, 0), True],
        [1j * jnp.sqrt(-x), jnp.sqrt(x)],
        default=jnp.nan,
    )
jax_complex_sqrt(sample)
DeviceArray([0. +1.j , 0. +0.70710677j,
0. +0.j , 0.70710677+0.j ,
1. +0.j ], dtype=complex64)
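For a Piecewise with only two cases, jnp.where() can replace jnp.select(). This is a swapped-in alternative, not the approach taken in the original. The sketch below writes the function that way and compares the result against numpy.emath.sqrt(). The tolerance allows for JAX's default single precision.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def where_complex_sqrt(x):
    # Two-case alternative to jnp.select: pick the imaginary branch where x < 0
    return jnp.where(x < 0, 1j * jnp.sqrt(-x), jnp.sqrt(x))


sample = np.linspace(-1, 1, 5)
jax_result = where_complex_sqrt(sample)
reference = np.emath.sqrt(sample)  # NumPy's complex-aware square root
print(jax_result)
print(np.allclose(np.asarray(jax_result), reference, atol=1e-6))
```

Like jnp.select(), jnp.where() evaluates both branches. The forward values are unaffected, but gradients can pick up the nan values from the unused branch, a pitfall described in the JAX FAQ.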
A solution to this is presented in Handle for JAX.