Custom lambdification#

import inspect

import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp
from black import FileMode, format_str

Overwrite printer methods#

As noted in TR-000, it is hard to lambdify a sympy.sqrt to JAX. One possible way out is to define a custom class that derives from sympy.Expr and to override its printer methods.

from sympy.printing.printer import Printer

class ComplexSqrt(sp.Expr):
    def __new__(cls, x, *args, **kwargs):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        if hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        if not x.is_real:
            return sp.sqrt(x)
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )

    def _latex(self, printer: Printer, *args) -> str:
        x = printer._print(self.args[0])
        return Rf"\sqrt[\mathrm{{c}}]{{{x}}}"

    def _numpycode(self, printer: Printer, *args) -> str:
        x = printer._print(self.args[0])
        return f"scimath.sqrt({x})"

    def _pythoncode(self, printer: Printer, *args) -> str:
        printer.module_imports["cmath"].add("sqrt as csqrt")
        x = printer._print(self.args[0])
        return f"csqrt({x})"

Unlike many classes derived from sympy.Expr, ComplexSqrt evaluates directly if its argument contains no free symbols, because its __new__ method does not process the evaluate keyword argument:

\[\displaystyle 2 i\]

The _latex() method ensures that ComplexSqrt renders nicely in notebooks:

x = sp.Symbol("x")
ComplexSqrt(x)
\[\displaystyle \sqrt[\mathrm{c}]{x}\]
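To check both behaviours outside a notebook, the following standalone sketch repeats the ComplexSqrt definition from above (with the Piecewise call closed, and keeping only the _latex printer method) and inspects the results directly. It assumes only SymPy; the symbols x and y and the test values are illustrative.

```python
import sympy as sp


class ComplexSqrt(sp.Expr):
    """Square root that yields i*sqrt(-x) for negative real x."""

    def __new__(cls, x, *args, **kwargs):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        # Arguments without free symbols (plain numbers) are evaluated right away
        if hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        if not x.is_real:
            return sp.sqrt(x)
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )

    def _latex(self, printer, *args) -> str:
        x = printer._print(self.args[0])
        return Rf"\sqrt[\mathrm{{c}}]{{{x}}}"


x = sp.Symbol("x")
y = sp.Symbol("y", real=True)

print(ComplexSqrt(-4))  # 2*I: evaluated on construction
print(ComplexSqrt(4))  # 2
print(type(ComplexSqrt(x)).__name__)  # symbolic argument: left unevaluated
print(ComplexSqrt(y).evaluate())  # explicit evaluation of a real symbol gives a Piecewise
print(sp.latex(ComplexSqrt(x)))  # \sqrt[\mathrm{c}]{x}
```

The Piecewise only appears when evaluate() is called explicitly on a real symbol; for a symbol without assumptions, evaluate() falls back to sympy.sqrt.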

Plot custom class#

In addition, one may modify the Lambdifier class from sympy.plotting.experimental_lambdify, so that sympy.plot() also works on this custom class:

from sympy.plotting.experimental_lambdify import Lambdifier

Lambdifier.builtin_functions_different["ComplexSqrt"] = "sqrt"
%config InlineBackend.figure_formats = ['svg']

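x = sp.Symbol("x")
expr = ComplexSqrt(x)
# Assumed: real part in red, imaginary part in the default color
p1 = sp.plot(sp.re(expr), (x, -1, 2), show=False, line_color="red")
p2 = sp.plot(sp.im(expr), (x, -1, 2), show=False)
p1.append(p2[0])
p1.show()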


More importantly, lambdifying to numpy or math now works as well:

lambdified_py = sp.lambdify(x, ComplexSqrt(x), "math")
source = inspect.getsource(lambdified_py)
def _lambdifygenerated(x):
    return (csqrt(x))
numpy_lambdified = sp.lambdify(x, ComplexSqrt(x), "numpy")
source = inspect.getsource(numpy_lambdified)
def _lambdifygenerated(x):
    return (scimath.sqrt(x))
sample = np.linspace(-1, +1, 5)
numpy_lambdified(sample)
array([0.        +1.j        , 0.        +0.70710678j,
       0.        +0.j        , 0.70710678+0.j        ,
       1.        +0.j        ])
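The printer methods above redirect to cmath and to the scimath module (numpy.emath) because the plain math and NumPy square roots do not switch to complex output for negative real input. A small standalone comparison, using only the standard library and NumPy:

```python
import cmath
import math

import numpy as np

# Standard library: math.sqrt rejects negative input, cmath.sqrt returns a complex number
try:
    math.sqrt(-4.0)
except ValueError as error:
    print("math.sqrt raised:", error)
print("cmath.sqrt:", cmath.sqrt(-4.0))

# NumPy: np.sqrt gives nan for negative reals (the warning is silenced here),
# while np.emath.sqrt returns a complex array as soon as any input is negative
sample = np.linspace(-1, +1, 5)
with np.errstate(invalid="ignore"):
    real_sqrt = np.sqrt(sample)
complex_sqrt = np.emath.sqrt(sample)
print(real_sqrt)
print(complex_sqrt)
```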

As noted in Complex square root, however, numpy.emath is not provided by JAX's NumPy API (jax.numpy). As discussed there, the best we can do is to decorate the numpy.emath version with jax.jit() and work with static arguments only:

jax_lambdified = jax.jit(numpy_lambdified, backend="cpu", static_argnums=0)
DeviceArray(0.+1.j, dtype=complex64)

In that case, inputs that cannot be hashed, such as NumPy arrays, are not accepted as static arguments:

ValueError                                Traceback (most recent call last)
<ipython-input-12-94f797bd7204> in <module>
----> 1 jax_lambdified(sample)

ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1.  -0.5  0.   0.5  1. ]. The error was:
TypeError: unhashable type: 'numpy.ndarray'
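The error arises because jax.jit() has to hash every argument listed in static_argnums: compiled versions are cached per static value. NumPy arrays are mutable and therefore unhashable, whereas scalars and tuples can be hashed. The sketch below reproduces that distinction with plain Python and NumPy (it does not call JAX; is_hashable is a hypothetical helper):

```python
import numpy as np

sample = np.linspace(-1, +1, 5)


def is_hashable(value) -> bool:
    """Hypothetical helper: True if value could serve as a cache key, as static arguments must."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


print(is_hashable(-1.0))  # True: a Python float
print(is_hashable(tuple(sample)))  # True: an immutable tuple of floats
print(is_hashable(sample))  # False: ndarray defines no hash
```

Converting the sample to a tuple would make it hashable, but every new value would then trigger a fresh compilation, so this is no practical way around the limitation for large samples.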

Handle for JAX#

As concluded in Conditional square root, the alternative to lambdifying to numpy.emath is to lambdify to numpy.select(). This has some caveats, though, such as that you should not use __dict__. Worse, JAX is not immediately supported as a backend. Fortunately, we now know how to override the printer methods that lambdify uses.
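As a minimal NumPy-only sketch of the select() pattern that the JaxPrinter below generates (complex_sqrt is a hypothetical name), note that both branches are computed for every element, and select() then picks, element by element, the branch whose condition holds first. The discarded branch therefore produces nan values and warnings, which are silenced here:

```python
import numpy as np


def complex_sqrt(x) -> np.ndarray:
    """Hypothetical helper: i*sqrt(-x) for x < 0, sqrt(x) otherwise."""
    x = np.asarray(x, dtype=float)
    # Both branches are evaluated on all elements; silence the nan warnings
    # from the branch that select() discards
    with np.errstate(invalid="ignore"):
        negative_branch = 1j * np.sqrt(-x)
        positive_branch = np.sqrt(x)
    return np.select(
        [np.less(x, 0), True],
        [negative_branch, positive_branch],
        default=np.nan,
    )


sample = np.linspace(-1, +1, 5)
print(complex_sqrt(sample))
print(np.emath.sqrt(sample))  # same values, for comparison
```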

In addition, we need a new printer class for JAX, so that we can define a special rendering method for ComplexSqrt in the case of JAX. Most of its printing methods can be inherited from SymPy's NumPyPrinter; the rest we override:


An alternative would be to add a _jaxcode() method to the ComplexSqrt class above. See Printing.
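That alternative relies on how SymPy printers dispatch: Printer._print() first checks whether the expression defines a method named after the printer's printmethod attribute (NumPyPrinter uses "_numpycode", which is how ComplexSqrt._numpycode() is picked up above). A minimal sketch, assuming a NumPyPrinter subclass whose printmethod is set to "_jaxcode"; JaxCodePrinter and CSqrt are hypothetical names. To keep it runnable without JAX, the expression is lambdified with modules=np; with JAX installed one would pass modules=jnp instead.

```python
import numpy as np
import sympy as sp
from sympy.printing.numpy import NumPyPrinter


class JaxCodePrinter(NumPyPrinter):
    """Hypothetical printer: expressions opt in by defining a _jaxcode() method."""

    def __init__(self, settings=None):
        super().__init__(settings)
        # Set after the base initializer, in case it assigns printmethod itself
        self.printmethod = "_jaxcode"


class CSqrt(sp.Expr):
    """Hypothetical stripped-down ComplexSqrt that only knows how to print for JAX."""

    def _jaxcode(self, printer, *args) -> str:
        x = printer._print(self.args[0])
        return (
            f"select([less({x}, 0), True], [1j * sqrt(-{x}), sqrt({x})],"
            " default=nan)"
        )


x = sp.Symbol("x")
printer = JaxCodePrinter()
print(printer.doprint(CSqrt(x)))

func = sp.lambdify(x, CSqrt(x), modules=np, printer=printer)
sample = np.linspace(-1, +1, 5)
with np.errstate(invalid="ignore"):  # the discarded select() branch yields nan
    values = func(sample)
print(values)
```

With this approach the rendering logic lives on the expression class rather than on the printer, so the printer subclass stays generic.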

from sympy.printing.numpy import NumPyPrinter

class JaxPrinter(NumPyPrinter):
    _module = "jax"

    def _print_ComplexSqrt(self, expr: sp.Expr) -> str:
        arg = expr.args[0]
        x = self._print(arg)
        return (
            f"select([less({x}, 0), True], [1j * sqrt(-{x}), sqrt({x})],"
            " default=nan,)"
        )
numpy_expr = sp.lambdify(x, ComplexSqrt(x), modules=np, printer=JaxPrinter)
source = inspect.getsource(numpy_expr)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_expr = sp.lambdify(x, ComplexSqrt(x), modules=jnp, printer=JaxPrinter)
source = inspect.getsource(jax_expr)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_expr(sample)
DeviceArray([0.        +1.j        , 0.        +0.70710677j,
             0.        +0.j        , 0.70710677+0.j        ,
             1.        +0.j        ], dtype=complex64)

The lambdified function can of course also be decorated with jax.jit():

jit_expr = jax.jit(jax_expr)

Performance check#

rng = np.random.default_rng()
sample = rng.normal(size=1_000_000)
jax_sample = jnp.array(sample)
%timeit jit_expr(jax_sample)
1.91 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax_expr(jax_sample)
6.31 ms ± 42.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit numpy_expr(sample)
16.9 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)