Custom lambdification#
Overwrite printer methods#
As noted in TR-000, it’s hard to lambdify a sympy.sqrt to JAX. One possible way out is to define a custom class that derives from sympy.Expr and override its printer methods.
import sympy as sp
from sympy.printing.printer import Printer


class ComplexSqrt(sp.Expr):
    def __new__(cls, x, *args, **kwargs):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        # Evaluate immediately if the argument contains no symbols
        if hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        if not x.is_real:
            return sp.sqrt(x)
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )

    def _latex(self, printer: Printer, *args) -> str:
        x = printer._print(self.args[0])
        return Rf"\sqrt[\mathrm{{c}}]{{{x}}}"

    def _numpycode(self, printer: Printer, *args) -> str:
        printer.module_imports["numpy.lib"].add("scimath")
        x = printer._print(self.args[0])
        return f"scimath.sqrt({x})"

    def _pythoncode(self, printer: Printer, *args) -> str:
        printer.module_imports["cmath"].add("sqrt as csqrt")
        x = printer._print(self.args[0])
        return f"csqrt({x})"
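For a real argument, evaluate() boils down to the branch rule sqrt(x) = i·sqrt(−x) for x < 0. The standard-library sketch below mirrors the two Piecewise branches with math.sqrt and checks them against cmath.sqrt, the function that _pythoncode() prints. The helper name piecewise_complex_sqrt is hypothetical and only serves as an illustration:

```python
import cmath
import math


def piecewise_complex_sqrt(x: float) -> complex:
    """Hypothetical helper mirroring the Piecewise in ComplexSqrt.evaluate()."""
    if x < 0:
        # Branch (sp.I * sp.sqrt(-x), x < 0)
        return 1j * math.sqrt(-x)
    # Branch (sp.sqrt(x), True)
    return complex(math.sqrt(x))


for value in (-4.0, -0.5, 0.0, 2.25):
    # Both routes agree on the principal square root of a real number
    assert cmath.isclose(piecewise_complex_sqrt(value), cmath.sqrt(value))
```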
Unlike a typical sympy.Expr subclass, this class evaluates directly whenever its argument has no free symbols, because the evaluate keyword argument is not processed by the __new__ method:
ComplexSqrt(-4)
The _latex() method ensures that ComplexSqrt renders nicely in notebooks:
x = sp.Symbol("x")
ComplexSqrt(x)
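In _latex(), the doubled braces in the f-string are escapes: {{ and }} produce literal braces, while the innermost {x} is replaced by the printed argument. A quick check with plain Python, assuming the argument prints as x:

```python
# Printed form of the argument, as printer._print(self.args[0]) would return it
x = "x"
latex = Rf"\sqrt[\mathrm{{c}}]{{{x}}}"
print(latex)  # \sqrt[\mathrm{c}]{x}
```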
Plot custom class#
In addition, one can modify the Lambdifier class of SymPy's plotting module, so that sympy.plot() also works on this custom class:
from sympy.plotting.experimental_lambdify import Lambdifier

# Map the name ComplexSqrt onto sqrt when sympy.plot() lambdifies the expression
Lambdifier.builtin_functions_different["ComplexSqrt"] = "sqrt"
%config InlineBackend.figure_formats = ['svg']
x = sp.Symbol("x")
expr = ComplexSqrt(x)
p1 = sp.plot(sp.re(expr), (x, -1, 2), show=False, line_color="red")
p2 = sp.plot(sp.im(expr), (x, -1, 2), show=False)
p1.append(p2[0])
p1.show()
Lambdifying#
Most importantly, lambdifying to numpy or math now works as well:
import inspect

lambdified_py = sp.lambdify(x, ComplexSqrt(x), "math")
source = inspect.getsource(lambdified_py)
print(source)
def _lambdifygenerated(x):
    return (csqrt(x))
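The module_imports entry registered in _pythoncode() is what brings csqrt into scope. As a simplified imitation (not lambdify's actual implementation), the generated source can be executed by hand in a namespace that first receives the requested import. This also shows why cmath is needed: the real-valued math.sqrt rejects negative numbers.

```python
import math

# Source text printed by inspect.getsource(lambdified_py) above
source = "def _lambdifygenerated(x):\n    return (csqrt(x))\n"

# Simplified imitation: run the requested import, then the generated source,
# in one shared namespace
namespace: dict = {}
exec("from cmath import sqrt as csqrt", namespace)
exec(source, namespace)
func = namespace["_lambdifygenerated"]

print(func(-4))  # complex result instead of an exception
try:
    math.sqrt(-4)  # math.sqrt only handles non-negative input
except ValueError:
    print("math.sqrt raised ValueError")
```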
numpy_lambdified = sp.lambdify(x, ComplexSqrt(x), "numpy")
source = inspect.getsource(numpy_lambdified)
print(source)
def _lambdifygenerated(x):
    return (scimath.sqrt(x))
import numpy as np

sample = np.linspace(-1, +1, 5)
numpy_lambdified(sample)
array([0. +1.j , 0. +0.70710678j,
0. +0.j , 0.70710678+0.j ,
1. +0.j ])
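The scimath import matters because numpy.sqrt itself never leaves the real domain. A minimal comparison on the same sample, using only NumPy (np.emath is NumPy's public alias for the numpy.lib.scimath functions):

```python
import numpy as np

sample = np.linspace(-1, +1, 5)

# numpy.sqrt stays real: negative entries become nan
# (with a RuntimeWarning, silenced here)
with np.errstate(invalid="ignore"):
    real_result = np.sqrt(sample)

# numpy.emath.sqrt switches to complex output when an entry is negative
complex_result = np.emath.sqrt(sample)
print(real_result)
print(complex_result)
```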
Just as noted in Complex square root, though, numpy.emath is not provided by the NumPy API of JAX. As discussed there, the best we can do is to decorate the numpy.emath version with jax.jit() and work with static arguments only:
import jax

# static_argnums=0: jit treats x as a compile-time constant, which must be hashable
jax_lambdified = jax.jit(numpy_lambdified, backend="cpu", static_argnums=0)
jax_lambdified(-1)
DeviceArray(0.+1.j, dtype=complex64)
In that case, unhashable inputs such as NumPy arrays are still not accepted as static arguments:
jax_lambdified(sample)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-12-94f797bd7204> in <module>
----> 1 jax_lambdified(sample)
ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1. -0.5 0. 0.5 1. ]. The error was:
TypeError: unhashable type: 'numpy.ndarray'
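The error arises because jax.jit uses static arguments as part of its compilation cache key, which requires them to be hashable, and a NumPy array is not. The same constraint can be reproduced in plain Python with functools.lru_cache. This is only an analogy, not JAX's mechanism, and cached_norm is a hypothetical function:

```python
import functools
import math


@functools.lru_cache(maxsize=None)
def cached_norm(values):
    # lru_cache stores results in a dict keyed on the arguments,
    # so every argument must be hashable
    return math.sqrt(sum(v * v for v in values))


print(cached_norm((3.0, 4.0)))  # tuples are hashable
try:
    cached_norm([3.0, 4.0])  # lists, like NumPy arrays, are not
except TypeError as error:
    print(error)
```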
Handle for JAX#
As concluded in Conditional square root, the alternative to lambdifying to numpy.emath is to lambdify to numpy.select(). This has some caveats, though; for instance, you should not use __dict__. Worse, JAX is not supported as a backend out of the box. Fortunately, we now know how to override the printer methods that lambdify relies on.
The additional tool we need is a new printer class for JAX, so that we can define a special rendering method for ComplexSqrt in the case of JAX. Most of its printing methods can be the same as those of SymPy’s NumPyPrinter; the rest we override:
Note
An alternative would be to add a _jaxcode() method to the ComplexSqrt class above (this requires the printer's printmethod attribute to be "_jaxcode"). See Printing.
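from sympy.printing.numpy import NumPyPrinter


class JaxPrinter(NumPyPrinter):
    _module = "jax"
    # NumPyPrinter dispatches on "_numpycode", which would pick up
    # ComplexSqrt._numpycode() instead of _print_ComplexSqrt() below
    printmethod = "_jaxcode"

    def _print_ComplexSqrt(self, expr: sp.Expr) -> str:
        arg = expr.args[0]
        x = self._print(arg)
        return (
            f"select([less({x}, 0), True], [1j * sqrt(-{x}), sqrt({x})], default=nan,)"
        )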
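One subtlety: SymPy's Printer._print first checks whether the expression itself has a method named after the printer's printmethod attribute, and only then looks for a _print_<ClassName> method on the printer. Since NumPyPrinter sets printmethod = "_numpycode", a subclass that keeps this value lets ComplexSqrt._numpycode() win over _print_ComplexSqrt(). The toy classes below are a simplified, hypothetical model of that lookup order, not SymPy's actual code:

```python
class ToyPrinter:
    """Simplified model of the lookup order in SymPy's Printer._print."""

    printmethod = "_numpycode"

    def _print(self, expr) -> str:
        # 1. The expression's own method named by `printmethod` takes precedence
        if self.printmethod and hasattr(expr, self.printmethod):
            return getattr(expr, self.printmethod)(self)
        # 2. Otherwise, look for _print_<ClassName> along the class hierarchy
        for cls in type(expr).__mro__:
            method = getattr(self, f"_print_{cls.__name__}", None)
            if method is not None:
                return method(expr)
        return repr(expr)


class ToyJaxPrinter(ToyPrinter):
    printmethod = "_jaxcode"

    def _print_Sqrt(self, expr) -> str:
        return "select(...)"


class Sqrt:
    def _numpycode(self, printer) -> str:
        return "scimath.sqrt(x)"


print(ToyPrinter()._print(Sqrt()))  # scimath.sqrt(x)
print(ToyJaxPrinter()._print(Sqrt()))  # select(...)
```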
numpy_expr = sp.lambdify(x, ComplexSqrt(x), modules=np, printer=JaxPrinter)
source = inspect.getsource(numpy_expr)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
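To see why this select() construction reproduces the complex square root, here is the same expression written directly against NumPy (the helper name select_sqrt is hypothetical). Note that select() evaluates every branch for every element, so the unused branch produces nan values that are simply discarded:

```python
import numpy as np


def select_sqrt(x):
    """NumPy transcription of the code printed by JaxPrinter._print_ComplexSqrt."""
    x = np.asarray(x, dtype=float)
    # Both branches are evaluated everywhere, so sqrt of a negative number
    # yields nan in the unused branch; silence the resulting warnings
    with np.errstate(invalid="ignore"):
        return np.select(
            [np.less(x, 0), True],
            [1j * np.sqrt(-x), np.sqrt(x)],
            default=np.nan,
        )


sample = np.linspace(-1, +1, 5)
result = select_sqrt(sample)
# Same values as the numpy.emath (scimath) version
assert np.allclose(result, np.emath.sqrt(sample))
```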
import jax.numpy as jnp

jax_expr = sp.lambdify(x, ComplexSqrt(x), modules=jnp, printer=JaxPrinter)
source = inspect.getsource(jax_expr)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_expr(sample)
DeviceArray([0. +1.j , 0. +0.70710677j,
0. +0.j , 0.70710677+0.j ,
1. +0.j ], dtype=complex64)
The lambdified function can of course also be decorated with jax.jit():
jit_expr = jax.jit(jax_expr)
Performance check#
rng = np.random.default_rng()
sample = rng.normal(size=1_000_000)
jax_sample = jnp.array(sample)
%timeit jit_expr(jax_sample)
1.91 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax_expr(jax_sample)
6.31 ms ± 42.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit numpy_expr(sample)
16.9 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
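%timeit is an IPython magic. Outside a notebook, a rough equivalent can be built from the standard library's timeit.repeat(). The helper below (hypothetical name time_per_loop) reports the mean and standard deviation per call over repeat runs of number loops, mirroring the "7 runs, 100 loops each" setting above. Because JAX dispatches work asynchronously, a JAX function would be timed through a wrapper such as lambda a: jit_expr(a).block_until_ready(), so that the measurement includes the actual computation.

```python
import statistics
import timeit


def time_per_loop(func, *args, number=100, repeat=7):
    """Return (mean, stdev) of the time per call in seconds, like %timeit."""
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    per_loop = [total / number for total in totals]
    return statistics.mean(per_loop), statistics.stdev(per_loop)


# Placeholder workload; swap in e.g. numpy_expr and sample in a real session
mean, std = time_per_loop(sum, range(10_000))
print(f"{mean * 1e3:.3f} ms ± {std * 1e6:.1f} µs per loop")
```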