Complex square roots#

Hide code cell content
import inspect

import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp
from black import FileMode, format_str
from IPython.display import display

Negative input values#

When using numpy as back-end, sympy lambdifies a sqrt() to a numpy.sqrt:

x = sp.Symbol("x")
sqrt_expr = sp.sqrt(x)
sqrt_expr
\[\displaystyle \sqrt{x}\]
np_sqrt = sp.lambdify(x, sqrt_expr, "numpy")
source = inspect.getsource(np_sqrt)
print(source)
def _lambdifygenerated(x):
    return (sqrt(x))

As expected, if input values for the numpy.sqrt are negative, numpy raises a RuntimeWarning and returns NaN:

sample = np.linspace(-1, 1, 5)
np_sqrt(sample)
array([       nan,        nan, 0.        , 0.70710678, 1.        ])

If we want numpy to return imaginary numbers for negative input values, one can use complex input data instead (e.g. numpy.complex64). Negative values are then treated as lying just above the real axis, so that their square root is a positive imaginary number:

complex_sample = sample.astype(np.complex64)
np_sqrt(complex_sample)
array([0.        +1.j        , 0.        +0.70710677j,
       0.        +0.j        , 0.70710677+0.j        ,
       1.        +0.j        ], dtype=complex64)

A sympy.sqrt lambdified to JAX exhibits the same behavior:

jax_sqrt = jax.jit(sp.lambdify(x, sqrt_expr, jnp))
source = inspect.getsource(jax_sqrt)
print(source)
def _lambdifygenerated(x):
    return (sqrt(x))
jax_sqrt(sample)
DeviceArray([       nan,        nan, 0.        , 0.70710677, 1.        ],            dtype=float32)
jax_sqrt(complex_sample)
DeviceArray([-4.3711388e-08+1.j        , -3.0908620e-08+0.70710677j,
              0.0000000e+00+0.j        ,  7.0710677e-01+0.j        ,
              1.0000000e+00+0.j        ], dtype=complex64)

There is a problem with this approach, though: once the input data is complex, every square root in a larger expression (for instance, an amplitude model) returns an imaginary result for negative values, which is not always the desired behavior.

Take for instance the two square roots appearing in PhaseSpaceFactor — does the \(\sqrt{s}\) also have to be evaluatable for negative \(s\)?

Complex square root#

Numpy also offers a special function that returns an imaginary square root for negative values, even if the input values are real: numpy.emath.sqrt():

np.emath.sqrt(-1)
1j

Unfortunately, the jax.numpy API does not provide an interface to numpy.emath. It is possible to decorate numpy.emath.sqrt() with jax.jit(), but that only works with static, hashable arguments:

jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
jax_csqrt_error(-1)
Hide code cell output
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel_launcher.py:17, in <module>
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/traitlets/config/application.py:976, in Application.launch_instance(cls, argv, **kwargs)
    975 app.initialize(argv)
--> 976 app.start()

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelapp.py:712, in IPKernelApp.start(self)
    711 try:
--> 712     self.io_loop.start()
    713 except KeyboardInterrupt:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/tornado/platform/asyncio.py:199, in BaseAsyncIOLoop.start(self)
    198     asyncio.set_event_loop(self.asyncio_loop)
--> 199     self.asyncio_loop.run_forever()
    200 finally:

File ~/miniconda3/envs/compwa-report/lib/python3.8/asyncio/base_events.py:570, in BaseEventLoop.run_forever(self)
    569 while True:
--> 570     self._run_once()
    571     if self._stopping:

File ~/miniconda3/envs/compwa-report/lib/python3.8/asyncio/base_events.py:1859, in BaseEventLoop._run_once(self)
   1858     else:
-> 1859         handle._run()
   1860 handle = None

File ~/miniconda3/envs/compwa-report/lib/python3.8/asyncio/events.py:81, in Handle._run(self)
     80 try:
---> 81     self._context.run(self._callback, *self._args)
     82 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(self)
    509 try:
--> 510     await self.process_one()
    511 except Exception:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelbase.py:499, in Kernel.process_one(self, wait)
    498         return None
--> 499 await dispatch(*args)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelbase.py:406, in Kernel.dispatch_shell(self, msg)
    405     if inspect.isawaitable(result):
--> 406         await result
    407 except Exception:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelbase.py:730, in Kernel.execute_request(self, stream, ident, parent)
    729 if inspect.isawaitable(reply_content):
--> 730     reply_content = await reply_content
    732 # Flush output before sending the reply.

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/ipkernel.py:383, in IPythonKernel.do_execute(self, code, silent, store_history, user_expressions, allow_stdin, cell_id)
    382 if with_cell_id:
--> 383     res = shell.run_cell(
    384         code,
    385         store_history=store_history,
    386         silent=silent,
    387         cell_id=cell_id,
    388     )
    389 else:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(self, *args, **kwargs)
    527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2881, in InteractiveShell.run_cell(self, raw_cell, store_history, silent, shell_futures, cell_id)
   2880 try:
-> 2881     result = self._run_cell(
   2882         raw_cell, store_history, silent, shell_futures, cell_id
   2883     )
   2884 finally:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2936, in InteractiveShell._run_cell(self, raw_cell, store_history, silent, shell_futures, cell_id)
   2935 try:
-> 2936     return runner(coro)
   2937 except BaseException as e:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(coro)
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3135, in InteractiveShell.run_cell_async(self, raw_cell, store_history, silent, shell_futures, transformed_cell, preprocessing_exc_tuple, cell_id)
   3133 interactivity = "none" if silent else self.ast_node_interactivity
-> 3135 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3136        interactivity=interactivity, compiler=compiler, result=result)
   3138 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3338, in InteractiveShell.run_ast_nodes(self, nodelist, cell_name, interactivity, compiler, result)
   3337     asy = compare(code)
-> 3338 if await self.run_code(code, result, async_=asy):
   3339     return True

    [... skipping hidden 1 frame]

Input In [13], in <cell line: 2>()
      1 jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
----> 2 jax_csqrt_error(-1)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/_src/traceback_util.py:143, in api_boundary.<locals>.reraise_with_filtered_traceback(*args, **kwargs)
    142 try:
--> 143   return fun(*args, **kwargs)
    144 except Exception as e:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/_src/api.py:426, in _cpp_jit.<locals>.cache_miss(*args, **kwargs)
    425 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 426 out_flat = xla.xla_call(
    427     flat_fun,
    428     *args_flat,
    429     device=device,
    430     backend=backend,
    431     name=flat_fun.__name__,
    432     donated_invars=donated_invars)
    433 out_pytree_def = out_tree()

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:1565, in CallPrimitive.bind(self, fun, *args, **params)
   1564 def bind(self, fun, *args, **params):
-> 1565   return call_bind(self, fun, *args, **params)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:1556, in call_bind(primitive, fun, *args, **params)
   1555 with maybe_new_sublevel(top_trace):
-> 1556   outs = primitive.process(top_trace, fun, tracers, params)
   1557 return map(full_lower, apply_todos(env_trace_todo(), outs))

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:1568, in CallPrimitive.process(self, trace, fun, tracers, params)
   1567 def process(self, trace, fun, tracers, params):
-> 1568   return trace.process_call(self, fun, tracers, params)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:609, in EvalTrace.process_call(self, primitive, f, tracers, params)
    608 def process_call(self, primitive, f, tracers, params):
--> 609   return primitive.impl(f, *tracers, **params)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/interpreters/xla.py:578, in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    577 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 578   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    579                                *unsafe_map(arg_spec, args))
    580   try:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/linear_util.py:262, in cache.<locals>.memoized_fun(fun, *args)
    261 else:
--> 262   ans = call(fun, *args)
    263   cache[key] = (ans, fun.stores)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/interpreters/xla.py:652, in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    651 abstract_args, _ = unzip2(arg_specs)
--> 652 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
    653 if any(isinstance(c, core.Tracer) for c in consts):

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:1209, in trace_to_jaxpr_final(fun, in_avals, transform_name)
   1208 main.jaxpr_stack = ()  # type: ignore
-> 1209 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1210 del fun, main

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:1188, in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1187 in_tracers = map(trace.new_arg, in_avals)
-> 1188 ans = fun.call_wrapped(*in_tracers)
   1189 out_tracers = map(trace.full_raise, ans)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
    165 try:
--> 166   ans = self.f(*args, **dict(self.params, **kwargs))
    167 except:
    168   # Some transformations yield from inside context managers, so we have to
    169   # interrupt them before reraising the exception. Otherwise they will only
    170   # get garbage-collected at some later time, running their cleanup tasks only
    171   # after this exception is handled, which can corrupt the global state.

File <__array_function__ internals>:180, in sqrt(*args, **kwargs)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/numpy/lib/scimath.py:247, in sqrt(x)
    200 """
    201 Compute the square root of x.
    202 
   (...)
    245 -2j
    246 """
--> 247 x = _fix_real_lt_zero(x)
    248 return nx.sqrt(x)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/numpy/lib/scimath.py:134, in _fix_real_lt_zero(x)
    113 """Convert `x` to complex if it has real, negative components.
    114 
    115 Otherwise, output is just the array version of the input (via asarray).
   (...)
    132 
    133 """
--> 134 x = asarray(x)
    135 if any(isreal(x) & (x < 0)):

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:472, in Tracer.__array__(self, *args, **kw)
    471 def __array__(self, *args, **kw):
--> 472   raise TracerArrayConversionError(self)

UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)
Input In [13], in <cell line: 2>()
      1 jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
----> 2 jax_csqrt_error(-1)

File <__array_function__ internals>:180, in sqrt(*args, **kwargs)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/numpy/lib/scimath.py:247, in sqrt(x)
    198 @array_function_dispatch(_unary_dispatcher)
    199 def sqrt(x):
    200     """
    201     Compute the square root of x.
    202 
   (...)
    245     -2j
    246     """
--> 247     x = _fix_real_lt_zero(x)
    248     return nx.sqrt(x)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/numpy/lib/scimath.py:134, in _fix_real_lt_zero(x)
    112 def _fix_real_lt_zero(x):
    113     """Convert `x` to complex if it has real, negative components.
    114 
    115     Otherwise, output is just the array version of the input (via asarray).
   (...)
    132 
    133     """
--> 134     x = asarray(x)
    135     if any(isreal(x) & (x < 0)):
    136         x = _tocomplex(x)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
jax_csqrt = jax.jit(np.emath.sqrt, backend="cpu", static_argnums=0)
jax_csqrt(-1)
DeviceArray(0.+1.j, dtype=complex64)
jax_csqrt(sample)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [15], in <cell line: 1>()
----> 1 jax_csqrt(sample)

ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1.  -0.5  0.   0.5  1. ]. The error was:
TypeError: unhashable type: 'numpy.ndarray'

Conditional square root#

To be able to control which square roots in the complete expression should be evaluatable for negative values, one could use Piecewise:

def complex_sqrt(x: sp.Symbol) -> sp.Expr:
    """Build a square root of ``x`` that is imaginary for negative ``x``.

    The result is a ``sympy.Piecewise`` with two cases: ``I * sqrt(-x)``
    when ``x < 0`` and ``sqrt(x)`` otherwise. If ``x`` is a plain number
    instead of a symbol, SymPy resolves the piecewise immediately, e.g.
    ``complex_sqrt(-4)`` gives ``2*I`` and ``complex_sqrt(4)`` gives ``2``.
    """
    imaginary_root = sp.sqrt(-x) * sp.I
    is_negative = x < 0
    real_root = sp.sqrt(x)
    # The catch-all ``True`` condition makes the second case the fallback.
    return sp.Piecewise(
        (imaginary_root, is_negative),
        (real_root, True),
    )


complex_sqrt(x)
\[\begin{split}\displaystyle \begin{cases} i \sqrt{- x} & \text{for}\: x < 0 \\\sqrt{x} & \text{otherwise} \end{cases}\end{split}\]
display(
    complex_sqrt(-4),
    complex_sqrt(+4),
)
\[\displaystyle 2 i\]
\[\displaystyle 2\]

Be careful though when lambdifying this expression: do not use the __dict__ of the numpy module as backend, but use the module itself instead. When using __dict__, lambdify() will return an if-else statement, which is inefficient and, worse, will result in problems with JAX:

Warning

Do not use the module __dict__ for the modules argument of lambdify().

np_complex_sqrt_no_select = sp.lambdify(x, complex_sqrt(x), np.__dict__)
source = inspect.getsource(np_complex_sqrt_no_select)
print(source)
def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
np_complex_sqrt_no_select(-1)
1j
jax_complex_sqrt_no_select = jax.jit(np_complex_sqrt_no_select)
jax_complex_sqrt_no_select(-1)

When instead using the numpy module (or "numpy"), lambdify() correctly lambdifies to numpy.select() to represent the cases.

np_complex_sqrt = sp.lambdify(x, complex_sqrt(x), np)
source = inspect.getsource(np_complex_sqrt)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )

Still, JAX does not handle this correctly. First, lambdifying to JAX again results in this if-else syntax:

jnp_complex_sqrt = sp.lambdify(x, complex_sqrt(x), jnp)
source = inspect.getsource(jnp_complex_sqrt)
print(source)
def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))

But even if we lambdify to numpy and decorate the result with a jax.jit() decorator, the resulting function does not work properly:

jax_complex_sqrt_error = jax.jit(np_complex_sqrt)
source = inspect.getsource(jax_complex_sqrt_error)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_complex_sqrt_error(-1)
Hide code cell output
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel_launcher.py:17, in <module>
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/traitlets/config/application.py:976, in Application.launch_instance(cls, argv, **kwargs)
    975 app.initialize(argv)
--> 976 app.start()

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelapp.py:712, in IPKernelApp.start(self)
    711 try:
--> 712     self.io_loop.start()
    713 except KeyboardInterrupt:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/tornado/platform/asyncio.py:199, in BaseAsyncIOLoop.start(self)
    198     asyncio.set_event_loop(self.asyncio_loop)
--> 199     self.asyncio_loop.run_forever()
    200 finally:

File ~/miniconda3/envs/compwa-report/lib/python3.8/asyncio/base_events.py:570, in BaseEventLoop.run_forever(self)
    569 while True:
--> 570     self._run_once()
    571     if self._stopping:

File ~/miniconda3/envs/compwa-report/lib/python3.8/asyncio/base_events.py:1859, in BaseEventLoop._run_once(self)
   1858     else:
-> 1859         handle._run()
   1860 handle = None

File ~/miniconda3/envs/compwa-report/lib/python3.8/asyncio/events.py:81, in Handle._run(self)
     80 try:
---> 81     self._context.run(self._callback, *self._args)
     82 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(self)
    509 try:
--> 510     await self.process_one()
    511 except Exception:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelbase.py:499, in Kernel.process_one(self, wait)
    498         return None
--> 499 await dispatch(*args)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelbase.py:406, in Kernel.dispatch_shell(self, msg)
    405     if inspect.isawaitable(result):
--> 406         await result
    407 except Exception:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/kernelbase.py:730, in Kernel.execute_request(self, stream, ident, parent)
    729 if inspect.isawaitable(reply_content):
--> 730     reply_content = await reply_content
    732 # Flush output before sending the reply.

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/ipkernel.py:383, in IPythonKernel.do_execute(self, code, silent, store_history, user_expressions, allow_stdin, cell_id)
    382 if with_cell_id:
--> 383     res = shell.run_cell(
    384         code,
    385         store_history=store_history,
    386         silent=silent,
    387         cell_id=cell_id,
    388     )
    389 else:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(self, *args, **kwargs)
    527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2881, in InteractiveShell.run_cell(self, raw_cell, store_history, silent, shell_futures, cell_id)
   2880 try:
-> 2881     result = self._run_cell(
   2882         raw_cell, store_history, silent, shell_futures, cell_id
   2883     )
   2884 finally:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2936, in InteractiveShell._run_cell(self, raw_cell, store_history, silent, shell_futures, cell_id)
   2935 try:
-> 2936     return runner(coro)
   2937 except BaseException as e:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(coro)
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3135, in InteractiveShell.run_cell_async(self, raw_cell, store_history, silent, shell_futures, transformed_cell, preprocessing_exc_tuple, cell_id)
   3133 interactivity = "none" if silent else self.ast_node_interactivity
-> 3135 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3136        interactivity=interactivity, compiler=compiler, result=result)
   3138 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3338, in InteractiveShell.run_ast_nodes(self, nodelist, cell_name, interactivity, compiler, result)
   3337     asy = compare(code)
-> 3338 if await self.run_code(code, result, async_=asy):
   3339     return True

    [... skipping hidden 1 frame]

Input In [26], in <cell line: 1>()
----> 1 jax_complex_sqrt_error(-1)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/_src/traceback_util.py:143, in api_boundary.<locals>.reraise_with_filtered_traceback(*args, **kwargs)
    142 try:
--> 143   return fun(*args, **kwargs)
    144 except Exception as e:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/_src/api.py:426, in _cpp_jit.<locals>.cache_miss(*args, **kwargs)
    425 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 426 out_flat = xla.xla_call(
    427     flat_fun,
    428     *args_flat,
    429     device=device,
    430     backend=backend,
    431     name=flat_fun.__name__,
    432     donated_invars=donated_invars)
    433 out_pytree_def = out_tree()

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:1565, in CallPrimitive.bind(self, fun, *args, **params)
   1564 def bind(self, fun, *args, **params):
-> 1565   return call_bind(self, fun, *args, **params)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:1556, in call_bind(primitive, fun, *args, **params)
   1555 with maybe_new_sublevel(top_trace):
-> 1556   outs = primitive.process(top_trace, fun, tracers, params)
   1557 return map(full_lower, apply_todos(env_trace_todo(), outs))

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:1568, in CallPrimitive.process(self, trace, fun, tracers, params)
   1567 def process(self, trace, fun, tracers, params):
-> 1568   return trace.process_call(self, fun, tracers, params)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:609, in EvalTrace.process_call(self, primitive, f, tracers, params)
    608 def process_call(self, primitive, f, tracers, params):
--> 609   return primitive.impl(f, *tracers, **params)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/interpreters/xla.py:578, in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    577 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 578   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    579                                *unsafe_map(arg_spec, args))
    580   try:

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/linear_util.py:262, in cache.<locals>.memoized_fun(fun, *args)
    261 else:
--> 262   ans = call(fun, *args)
    263   cache[key] = (ans, fun.stores)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/interpreters/xla.py:652, in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    651 abstract_args, _ = unzip2(arg_specs)
--> 652 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
    653 if any(isinstance(c, core.Tracer) for c in consts):

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:1209, in trace_to_jaxpr_final(fun, in_avals, transform_name)
   1208 main.jaxpr_stack = ()  # type: ignore
-> 1209 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1210 del fun, main

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:1188, in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1187 in_tracers = map(trace.new_arg, in_avals)
-> 1188 ans = fun.call_wrapped(*in_tracers)
   1189 out_tracers = map(trace.full_raise, ans)

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
    165 try:
--> 166   ans = self.f(*args, **dict(self.params, **kwargs))
    167 except:
    168   # Some transformations yield from inside context managers, so we have to
    169   # interrupt them before reraising the exception. Otherwise they will only
    170   # get garbage-collected at some later time, running their cleanup tasks only
    171   # after this exception is handled, which can corrupt the global state.

File <lambdifygenerated-4>:2, in _lambdifygenerated(x)
      1 def _lambdifygenerated(x):
----> 2     return (select([less(x, 0),True], [1j*sqrt(-x),sqrt(x)], default=nan))

File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/jax/core.py:472, in Tracer.__array__(self, *args, **kw)
    471 def __array__(self, *args, **kw):
--> 472   raise TracerArrayConversionError(self)

UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)
Input In [26], in <cell line: 1>()
----> 1 jax_complex_sqrt_error(-1)

File <lambdifygenerated-4>:2, in _lambdifygenerated(x)
      1 def _lambdifygenerated(x):
----> 2     return (select([less(x, 0),True], [1j*sqrt(-x),sqrt(x)], default=nan))

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)

The very same function created purely with jax.numpy works without problems, so it seems this is a SymPy problem:

@jax.jit
def jax_complex_sqrt(x):
    """Square root that is imaginary for negative ``x``, written in jax.numpy.

    Mirrors the ``select`` code that ``sympy.lambdify`` generates for the
    ``complex_sqrt`` Piecewise, but uses ``jnp`` functions so that
    ``jax.jit`` can trace it. Elements with ``x < 0`` get ``1j * sqrt(-x)``
    and all other elements get ``sqrt(x)``.
    """
    # Both choices are computed for every element; ``select`` only picks
    # one afterwards, so NaNs in the unused branch are discarded.
    conditions = [jnp.less(x, 0), True]
    choices = [1j * jnp.sqrt(-x), jnp.sqrt(x)]
    # The second condition is always True, so the default is never reached.
    return jnp.select(conditions, choices, default=jnp.nan)
jax_complex_sqrt(sample)
DeviceArray([0.        +1.j        , 0.        +0.70710677j,
             0.        +0.j        , 0.70710677+0.j        ,
             1.        +0.j        ], dtype=complex64)

A solution to this is presented in Handle for JAX.