Complex square roots
Negative input values
When using numpy as back-end, sympy lambdifies a sqrt() to a numpy.sqrt:
import inspect

import numpy as np
import sympy as sp

x = sp.Symbol("x")
sqrt_expr = sp.sqrt(x)
sqrt_expr  # rendered as the expression in a notebook
np_sqrt = sp.lambdify(x, sqrt_expr, "numpy")
source = inspect.getsource(np_sqrt)
print(source)
def _lambdifygenerated(x):
    return (sqrt(x))
As expected, if input values for the numpy.sqrt are negative, numpy raises a RuntimeWarning and returns NaN:
sample = np.linspace(-1, 1, 5)
np_sqrt(sample)
array([ nan, nan, 0. , 0.70710678, 1. ])
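To see both effects in isolation, here is a small sketch using only numpy and the standard warnings module (the variable names are illustrative): the negative entries come back as NaN together with a RuntimeWarning, which numpy.errstate() can silence locally.

```python
import warnings

import numpy as np

sample = np.linspace(-1, 1, 5)

# Record the RuntimeWarning that numpy emits for square roots of negative reals
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    real_result = np.sqrt(sample)

print(np.isnan(real_result))  # [ True  True False False False]

# numpy.errstate() only suppresses the warning inside this block
with np.errstate(invalid="ignore"):
    quiet_result = np.sqrt(sample)
```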
If we want numpy to return imaginary numbers for negative input values, we can use complex input data instead (e.g. numpy.complex64). Negative values are then treated as lying just above the real axis, so that their square root is a positive imaginary number:
complex_sample = sample.astype(np.complex64)
np_sqrt(complex_sample)
array([0. +1.j , 0. +0.70710677j,
0. +0.j , 0.70710677+0.j ,
1. +0.j ], dtype=complex64)
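"Just above the real axis" refers to the sign of the imaginary zero. The principal complex square root has its branch cut along the negative real axis, and the side a negative number lands on depends on whether its imaginary part is +0.0 or -0.0. Casting real data to complex produces +0.0, hence the positive imaginary results above. A small sketch with illustrative values:

```python
import numpy as np

# -4 with a +0.0 imaginary part, as produced by casting a real number to complex
above = np.complex128(complex(-4.0, 0.0))
# -4 with a -0.0 imaginary part, i.e. "just below" the real axis
below = np.complex128(complex(-4.0, -0.0))

sqrt_above = np.sqrt(above)  # approximately +2j (upper side of the branch cut)
sqrt_below = np.sqrt(below)  # approximately -2j (lower side of the branch cut)

# A real-to-complex cast indeed yields +0.0 imaginary parts (sign bit not set)
cast = np.array([-4.0, -1.0]).astype(np.complex128)
print(np.signbit(cast.imag))  # [False False]
```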
A sympy.sqrt lambdified to JAX exhibits the same behavior:
import jax
import jax.numpy as jnp

jax_sqrt = jax.jit(sp.lambdify(x, sqrt_expr, jnp))
source = inspect.getsource(jax_sqrt)
print(source)
def _lambdifygenerated(x):
    return (sqrt(x))
jax_sqrt(sample)
DeviceArray([ nan, nan, 0. , 0.70710677, 1. ], dtype=float32)
jax_sqrt(complex_sample)
DeviceArray([-4.3711388e-08+1.j , -3.0908620e-08+0.70710677j,
0.0000000e+00+0.j , 7.0710677e-01+0.j ,
1.0000000e+00+0.j ], dtype=complex64)
This approach has a drawback, though: once the input data is complex, every square root in a larger expression (such as an amplitude model) returns imaginary values for negative arguments, which is not always the desired behavior.
Take, for instance, the two square roots appearing in PhaseSpaceFactor: does \(\sqrt{s}\) also have to be evaluatable for negative \(s\)?
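A toy illustration of this concern, using a hypothetical expression sqrt(s - 4) / sqrt(s) (this is not AmpForm's actual PhaseSpaceFactor, and the threshold is arbitrary): we might want the first square root to continue below threshold, but casting the input to complex changes the second one as well. At s = -1 both square roots turn imaginary, and their ratio silently becomes a finite real number.

```python
import numpy as np

THRESHOLD = 4.0  # arbitrary value, for illustration only


def toy_expr(s):
    # Hypothetical expression with two square roots
    return np.sqrt(s - THRESHOLD) / np.sqrt(s)


s_values = np.array([-1.0, 1.0, 9.0])

with np.errstate(invalid="ignore"):
    real_out = toy_expr(s_values)  # NaN wherever a square root argument is negative
complex_out = toy_expr(s_values.astype(np.complex128))  # both square roots go complex

print(real_out)
print(complex_out)
```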
Complex square root
Numpy also offers a special function, numpy.emath.sqrt(), that returns complex results for negative input values even if the input is real:
np.emath.sqrt(-1)
1j
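numpy.emath.sqrt() chooses its output type by looking at the values: if at least one real entry is negative, the input is converted to complex first, otherwise the result stays real. A short check:

```python
import numpy as np

non_negative = np.emath.sqrt(np.array([1.0, 4.0]))
mixed = np.emath.sqrt(np.array([-1.0, 4.0]))

print(non_negative.dtype)  # float64
print(mixed.dtype)  # complex128
```

Because the output dtype depends on the data and not only on its shape and type, this is hard to reconcile with jax.jit(), which traces functions on abstract values.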
Unfortunately, the jax.numpy API does not provide an interface to numpy.emath. It is possible to decorate numpy.emath.sqrt() with jax.jit(), but that only works with static, hashable arguments:
jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
jax_csqrt_error(-1)
TracerArrayConversionError Traceback (most recent call last)
Input In [13], in <cell line: 2>()
1 jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
----> 2 jax_csqrt_error(-1)
File <__array_function__ internals>:180, in sqrt(*args, **kwargs)
File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/numpy/lib/scimath.py:247, in sqrt(x)
198 @array_function_dispatch(_unary_dispatcher)
199 def sqrt(x):
200 """
201 Compute the square root of x.
202
(...)
245 -2j
246 """
--> 247 x = _fix_real_lt_zero(x)
248 return nx.sqrt(x)
File ~/miniconda3/envs/compwa-report/lib/python3.8/site-packages/numpy/lib/scimath.py:134, in _fix_real_lt_zero(x)
112 def _fix_real_lt_zero(x):
113 """Convert `x` to complex if it has real, negative components.
114
115 Otherwise, output is just the array version of the input (via asarray).
(...)
132
133 """
--> 134 x = asarray(x)
135 if any(isreal(x) & (x < 0)):
136 x = _tocomplex(x)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
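Declaring the argument as static (static_argnums=0) hands numpy.emath.sqrt() a concrete value instead of a tracer, so a scalar input works:
jax_csqrt = jax.jit(np.emath.sqrt, backend="cpu", static_argnums=0)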
jax_csqrt(-1)
DeviceArray(0.+1.j, dtype=complex64)
A numpy array, however, is not hashable and can therefore not be passed as a static argument:
jax_csqrt(sample)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [15], in <cell line: 1>()
----> 1 jax_csqrt(sample)
ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1. -0.5 0. 0.5 1. ]. The error was:
TypeError: unhashable type: 'numpy.ndarray'
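Even for scalars, static_argnums has a cost: jax.jit() caches one compiled version per distinct static value, so every new number triggers another trace and compilation. The sketch below makes this visible with a trace counter; the counter and function name are illustrative, and jax.numpy.sqrt stands in for numpy.emath.sqrt().

```python
import jax
import jax.numpy as jnp

trace_count = 0


def counted_sqrt(x):
    global trace_count
    trace_count += 1  # Python side effects run only while JAX traces the function
    return jnp.sqrt(x)


static_sqrt = jax.jit(counted_sqrt, static_argnums=0)

static_sqrt(4.0)
static_sqrt(4.0)  # same static value: the cached version is reused
static_sqrt(9.0)  # new static value: traced and compiled again
print(trace_count)  # 2
```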
Conditional square root
To control which square roots in the complete expression should be evaluatable for negative values, one could use a Piecewise:
from IPython.display import display


def complex_sqrt(x: sp.Symbol) -> sp.Expr:
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )


complex_sqrt(x)
display(
    complex_sqrt(-4),
    complex_sqrt(+4),
)
Be careful, though, when lambdifying this expression: do not use the __dict__ of the numpy module as back-end, but use the module itself instead. When using __dict__, lambdify() generates an if-else expression, which does not work element-wise on arrays and, worse, results in problems with JAX:
Warning
Do not use the module __dict__ for the modules argument of lambdify().
np_complex_sqrt_no_select = sp.lambdify(x, complex_sqrt(x), np.__dict__)
source = inspect.getsource(np_complex_sqrt_no_select)
print(source)
def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
np_complex_sqrt_no_select(-1)
1j
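The scalar call works because x < 0 is a single bool. For an array, the generated conditional expression would need one truth value for a whole array of booleans, which numpy refuses. A hand-written copy of the generated source (the function name is hypothetical) shows this:

```python
import numpy as np


def ternary_complex_sqrt(x):
    # Same structure as the source that lambdify() generates from np.__dict__
    return (1j * np.sqrt(-x)) if (x < 0) else (np.sqrt(x))


print(ternary_complex_sqrt(-1))  # 1j

array_error = None
try:
    ternary_complex_sqrt(np.linspace(-1, 1, 5))
except ValueError as error:
    array_error = error  # "The truth value of an array ... is ambiguous"
```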
jax_complex_sqrt_no_select = jax.jit(np_complex_sqrt_no_select)
jax_complex_sqrt_no_select(-1)
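Under jax.jit(), x is an abstract tracer rather than a number, so the Python if in the generated code cannot be decided while tracing and JAX raises an error. A minimal reproduction written directly with jax.numpy (the function name is hypothetical); JAX raises a jax.errors.ConcretizationTypeError here, or in recent versions a subclass of it, so catching the base class should cover both:

```python
import jax
import jax.numpy as jnp


@jax.jit
def jit_ternary_sqrt(x):
    # Python control flow on a traced value: x < 0 is a tracer, not a bool
    return (1j * jnp.sqrt(-x)) if (x < 0) else jnp.sqrt(x)


jit_error = None
try:
    jit_ternary_sqrt(-1.0)
except jax.errors.ConcretizationTypeError as error:
    jit_error = error
```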
When instead using the numpy module (or "numpy"), lambdify() correctly lambdifies to numpy.select() to represent the cases.
np_complex_sqrt = sp.lambdify(x, complex_sqrt(x), np)
source = inspect.getsource(np_complex_sqrt)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
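Note that numpy.select() receives every branch as an already computed array: both 1j*sqrt(-x) and sqrt(x) are evaluated on all elements before one value per element is picked. The discarded branch contains NaNs, which never reach the result. A hand-written replay of the generated call (variable names are illustrative):

```python
import numpy as np

x = np.linspace(-1, 1, 5)

with np.errstate(invalid="ignore"):  # silence sqrt-of-negative warnings
    negative_branch = 1j * np.sqrt(-x)  # NaN wherever x > 0
    positive_branch = np.sqrt(x)  # NaN wherever x < 0
    # Same structure as the generated select() call
    selected = np.select(
        [np.less(x, 0), True],
        [negative_branch, positive_branch],
        default=np.nan,
    )

print(np.isnan(negative_branch).sum(), np.isnan(positive_branch).sum())  # 2 2
print(np.isnan(selected).any())  # False
```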
Still, JAX does not handle this correctly. First, lambdifying to jax.numpy again results in the if-else syntax:
jnp_complex_sqrt = sp.lambdify(x, complex_sqrt(x), jnp)
source = inspect.getsource(jnp_complex_sqrt)
print(source)
def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
But even if we lambdify to numpy and decorate the result with jax.jit(), the resulting function does not work properly:
jax_complex_sqrt_error = jax.jit(np_complex_sqrt)
source = inspect.getsource(jax_complex_sqrt_error)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_complex_sqrt_error(-1)
TracerArrayConversionError Traceback (most recent call last)
Input In [26], in <cell line: 1>()
----> 1 jax_complex_sqrt_error(-1)
File <lambdifygenerated-4>:2, in _lambdifygenerated(x)
1 def _lambdifygenerated(x):
----> 2 return (select([less(x, 0),True], [1j*sqrt(-x),sqrt(x)], default=nan))
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
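The traceback shows the cause: inside jax.jit(), the lambdified function calls numpy functions (select(), less(), sqrt()) on a JAX tracer, and numpy tries to convert that tracer to a numpy.ndarray via __array__(). A minimal reproduction without SymPy (both function names are hypothetical):

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def numpy_inside_jit(x):
    return np.sqrt(x)  # numpy asks the tracer for a concrete array


@jax.jit
def jnp_inside_jit(x):
    return jnp.sqrt(x)  # jax.numpy operates on tracers directly


conversion_error = None
try:
    numpy_inside_jit(jnp.array(4.0))
except jax.errors.TracerArrayConversionError as error:
    conversion_error = error

jnp_result = jnp_inside_jit(jnp.array(4.0))
```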
The very same function written purely with jax.numpy does work without problems, so it seems this is a SymPy problem:
@jax.jit
def jax_complex_sqrt(x):
    return jnp.select(
        [jnp.less(x, 0), True],
        [1j * jnp.sqrt(-x), jnp.sqrt(x)],
        default=jnp.nan,
    )
jax_complex_sqrt(sample)
DeviceArray([0. +1.j , 0. +0.70710677j,
0. +0.j , 0.70710677+0.j ,
1. +0.j ], dtype=complex64)
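As an aside, jnp.select() also evaluates both branches, just like numpy.select(). One alternative formulation that avoids the NaN-producing branch altogether multiplies sqrt(|x|) by 1j only where x is negative. This is a sketch, not necessarily the approach taken in Handle for JAX; the function name is hypothetical:

```python
import jax
import jax.numpy as jnp


@jax.jit
def abs_complex_sqrt(x):
    # sqrt(|x|) is always real; the factor is 1j where x < 0 and 1 elsewhere
    return jnp.sqrt(jnp.abs(x)) * jnp.where(x < 0, 1j, 1.0)


sample = jnp.linspace(-1, 1, 5)
result = abs_complex_sqrt(sample)
print(result.dtype)  # complex64 unless 64-bit mode is enabled
```

Since the function contains no Python control flow, it can be traced and compiled for array input directly.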
A solution to this is presented in Handle for JAX.