Speed up lambdifying#

Hide code cell content
%config InlineBackend.figure_formats = ['svg']

from __future__ import annotations

import inspect
import logging
import timeit
import warnings
from collections.abc import Generator, Sequence
from typing import Callable

import ampform
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import qrules
import sympy as sp
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from tensorwaves.data import generate_phsp
from tensorwaves.data.phasespace import TFUniformRealNumberGenerator
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.model import LambdifiedFunction, SympyModel

LOGGER = logging.getLogger()

Create dummy expression#

First, let’s create an amplitude model with ampform. We’ll use this model as a complicated sympy.Expr in the rest of this notebook.

result = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)(980)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
dot = qrules.io.asdot(result, collapse_graphs=True, render_final_state_id=False)
graphviz.Source(dot)

model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.generate()
free_symbols = sorted(model.expression.free_symbols, key=lambda s: s.name)
free_symbols
[C[J/\psi(1S) \to f_{0}(980)_{0} \gamma_{+1}; f_{0}(980) \to \pi^{0}_{0} \pi^{0}_{0}],
 Gamma_f(0)(980),
 d_f(0)(980),
 m_1,
 m_12,
 m_2,
 m_f(0)(980),
 phi_1+2,
 phi_1,1+2,
 theta_1+2,
 theta_1,1+2]

Helicity model components#

A HelicityModel has the benefit that it comes with components (intensities and amplitudes) that together form its expression. Let’s separate these components into amplitude and intensity.

amplitudes = {
    name: expr for name, expr in model.components.items() if name.startswith("A")
}
sorted(amplitudes)
['A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{+1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{+1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{-1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{-1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{+1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{+1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]']
intensities = {
    name: expr for name, expr in model.components.items() if name.startswith("I")
}
assert len(amplitudes) + len(intensities) == len(model.components)

Component structure#

Note that each intensity consists of a subset of these amplitudes. This means that intensities have a larger expression tree than amplitudes.

amplitude_to_symbol = {
    expr: sp.Symbol(f"A{i}") for i, expr in enumerate(amplitudes.values(), 1)
}
intensity_to_symbol = {
    expr: sp.Symbol(f"I{i}") for i, expr in enumerate(intensities.values(), 1)
}
intensity_expr = model.expression.subs(intensity_to_symbol, simultaneous=True)
intensity_expr
\[\displaystyle I_{1} + I_{2} + I_{3} + I_{4}\]
dot = sp.dotprint(intensity_expr)
graphviz.Source(dot)

amplitude_expr = model.expression.subs(amplitude_to_symbol, simultaneous=True)
amplitude_expr
\[\displaystyle \left|{A_{1} + A_{2}}\right|^{2} + \left|{A_{3} + A_{4}}\right|^{2} + \left|{A_{5} + A_{6}}\right|^{2} + \left|{A_{7} + A_{8}}\right|^{2}\]
dot = sp.dotprint(amplitude_expr)
graphviz.Source(dot)

Performance check#

Lambdifying the whole HelicityModel.expression is slowest. The lambdify() function first prints the expression as a str (!) with (in this case) numpy syntax and then uses eval() to convert that back to actual numpy objects:

Hide code cell content
runtime = {}
start = timeit.default_timer()
%%time

np_complete_model = sp.lambdify(free_symbols, model.expression.doit(), "numpy")
CPU times: user 1.46 s, sys: 703 µs, total: 1.46 s
Wall time: 1.46 s
Hide code cell content
stop = timeit.default_timer()
runtime["complete model"] = stop - start

Printing to str and converting back with eval() becomes exponentially slower the larger the expression tree is. This means that it’s more efficient to lambdify sub-trees of the expression tree separately. When lambdifying the four intensities of this model separately, the effect is not noticeable:

%%time

for expr, symbol in intensity_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start
CPU times: user 1.56 s, sys: 4.94 ms, total: 1.56 s
Wall time: 1.56 s

…but each of the eight amplitudes separately does result in a significant speed-up:

%%time

np_amplitudes = {}
for expr, symbol in amplitude_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    np_expr = sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start
    np_amplitudes[symbol] = np_expr
CPU times: user 547 ms, sys: 3.85 ms, total: 550 ms
Wall time: 547 ms

Recombining components#

Recall what the amplitude model looks like when expressed in its amplitude components:

amplitude_expr
\[\displaystyle \left|{A_{1} + A_{2}}\right|^{2} + \left|{A_{3} + A_{4}}\right|^{2} + \left|{A_{5} + A_{6}}\right|^{2} + \left|{A_{7} + A_{8}}\right|^{2}\]

We have to lambdify that top expression as well:

sorted_amplitude_symbols = sorted(np_amplitudes, key=lambda s: s.name)
np_amplitude_expr = sp.lambdify(sorted_amplitude_symbols, amplitude_expr, "numpy")
source = inspect.getsource(np_amplitude_expr)
print(source)
def _lambdifygenerated(A1, A2, A3, A4, A5, A6, A7, A8):
    return (abs(A1 + A2)**2 + abs(A3 + A4)**2 + abs(A5 + A6)**2 + abs(A7 + A8)**2)

We now have a lambdified expression for the complete amplitude model, as well as lambdified expressions that are to be plugged in to its arguments.

def componentwise_lambdified(*args):
    """Evaluate the amplitude model by recombining its lambdified components.

    Every lambdified amplitude is called with ``args``. The resulting values
    are then passed, in the order of ``sorted_amplitude_symbols``, to the
    lambdified top-level amplitude expression.

    .. warning:: The ``args`` have to be given in the same order as the
        arguments of the lambdified amplitude components.
    """
    evaluated_components = [
        np_amplitudes[symbol](*args) for symbol in sorted_amplitude_symbols
    ]
    return np_amplitude_expr(*evaluated_components)

Test with data#

Okay, so does all this work? Let’s first generate a phase space sample with good-old tensorwaves. We can then use this sample as input to the component-wise lambdified function.

sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
)
intensity = LambdifiedFunction(sympy_model, backend="jax")
data_converter = HelicityTransformer(model.adapter)
rng = TFUniformRealNumberGenerator(seed=0)
phsp_sample = generate_phsp(10_000, model.adapter.reaction_info, random_generator=rng)
phsp_set = data_converter.transform(phsp_sample)
Hide code cell source
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(phsp_set["m_12"], bins=50, alpha=0.5, density=True)
ax.hist(
    phsp_set["m_12"],
    bins=50,
    alpha=0.5,
    density=True,
    weights=np.array(intensity(phsp_set)),
)
plt.show()

The arguments of the component-wise lambdified amplitude model should be covered by the entries in the phase space set and the provided parameter defaults:

# Every free symbol of the model has to be supplied either as a kinematic
# variable (a key of the phase space set) or as a parameter default.
kinematic_variable_names = set(phsp_set)
parameter_names = {symbol.name for symbol in model.parameter_defaults}
free_symbol_names = {symbol.name for symbol in free_symbols}
# Use the union (|) and not the symmetric difference (^): the latter drops any
# name that occurs in both sets, which would make this coverage check fail.
assert free_symbol_names <= kinematic_variable_names | parameter_names

That allows us to sort the input arrays and parameter defaults so that they can be used as positional argument input to the component-wise lambdified amplitude model:

merged_par_var_values = {
    symbol.name: value for symbol, value in model.parameter_defaults.items()
}
merged_par_var_values.update(phsp_set)
args_values = [merged_par_var_values[symbol.name] for symbol in free_symbols]

Finally, here’s the result of plugging that back into the component-wise lambdified expression:

componentwise_result = componentwise_lambdified(*args_values)
componentwise_result
array([0.00048765, 0.00033425, 0.00524706, ..., 0.00140122, 0.00714365,
       0.00030117])

And it’s indeed the same as the intensity computed by tensorwaves (direct lambdify):

tensorwaves_result = np.array(intensity(phsp_set))
mean_difference = (componentwise_result - tensorwaves_result).mean()
mean_difference
-7.307471250984975e-11

Arbitrary expressions#

The problem with Test with data is that it requires a HelicityModel. In tensorwaves, we want to work with general sympy.Exprs though (see SympyModel), where we don’t have sub-ampform.helicity.HelicityModel.components available.

Instead, we have to split up the lambdifying in a more general way that can handle arbitrary sympy.core.expr.Exprs. For that we need:

  1. A general method of traversing through a SymPy expression tree. This can be done with Advanced Expression Manipulation.

  2. A fast method to estimate the complexity of a model, so that we can decide whether a node in the expression tree is small enough to be lambdified without much runtime. The best measure for complexity is count_ops() (“count operations”), see notes under Simplify.

Expression complexity#

Let’s tackle 2. first and use the HelicityModel.expression and its components that we lambdified earlier on. Here’s an overview of the number of operations versus the time it took to lambdify each component:

Hide code cell source
df = pd.DataFrame(runtime.values(), index=runtime, columns=["runtime (s)"])
operations = [sp.count_ops(model.expression)]
operations.extend(sp.count_ops(expr) for expr in intensity_to_symbol)
operations.extend(sp.count_ops(expr) for expr in amplitude_to_symbol)
df.insert(0, "operations", operations)
df
operations runtime (s)
complete model 823 0.980456
I1 209 0.279897
I2 203 0.235227
I3 207 0.215937
I4 201 0.233635
A1 103 0.045300
A2 103 0.040710
A3 100 0.039767
A4 100 0.035684
A5 102 0.036551
A6 102 0.036198
A7 99 0.042208
A8 99 0.040205

From this we can already see that the lambdify runtime scales roughly with the number of SymPy operations.

To better visualize this, we can lambdify the expressions in BlattWeisskopfSquared for each angular momentum and compute their runtime a number of times with timeit. Note that BlattWeisskopfSquared becomes increasingly complex the higher the angular momentum.

Hide code cell source
from ampform.dynamics import BlattWeisskopfSquared

angular_momentum, z = sp.symbols("L z")
BlattWeisskopfSquared(angular_momentum, z).doit()
\[\begin{split}\displaystyle \begin{cases} 1 & \text{for}\: L = 0 \\\frac{2 z}{z + 1} & \text{for}\: L = 1 \\\frac{13 z^{2}}{9 z + \left(z - 3\right)^{2}} & \text{for}\: L = 2 \\\frac{277 z^{3}}{z \left(z - 15\right)^{2} + \left(2 z - 5\right) \left(18 z - 45\right)} & \text{for}\: L = 3 \\\frac{12746 z^{4}}{25 z \left(2 z - 21\right)^{2} + \left(z^{2} - 45 z + 105\right)^{2}} & \text{for}\: L = 4 \\\frac{998881 z^{5}}{z^{5} + 15 z^{4} + 315 z^{3} + 6300 z^{2} + 99225 z + 893025} & \text{for}\: L = 5 \\\frac{118394977 z^{6}}{z^{6} + 21 z^{5} + 630 z^{4} + 18900 z^{3} + 496125 z^{2} + 9823275 z + 108056025} & \text{for}\: L = 6 \\\frac{19727003738 z^{7}}{z^{7} + 28 z^{6} + 1134 z^{5} + 47250 z^{4} + 1819125 z^{3} + 58939650 z^{2} + 1404728325 z + 18261468225} & \text{for}\: L = 7 \\\frac{4392846440677 z^{8}}{z^{8} + 36 z^{7} + 1890 z^{6} + 103950 z^{5} + 5457375 z^{4} + 255405150 z^{3} + 9833098275 z^{2} + 273922023375 z + 4108830350625} & \text{for}\: L = 8 \end{cases}\end{split}\]
Hide code cell content
operations = []
runtime = []
for angular_momentum in range(9):
    ff2 = BlattWeisskopfSquared(angular_momentum, z)
    operations.append(sp.count_ops(ff2.doit()))
    n_iterations = 10
    t = timeit.timeit(
        setup=f"""
import sympy as sp
from ampform.dynamics import BlattWeisskopfSquared
z = sp.Symbol("z")
ff2 = BlattWeisskopfSquared({angular_momentum}, z)
    """,
        stmt='sp.lambdify(z, ff2.doit(), "numpy")',
        number=n_iterations,
    )
    runtime.append(t / n_iterations * 1_000)
Hide code cell source
df = pd.DataFrame(
    {
        "operations": operations,
        "runtime (ms)": runtime,
    },
)
df
operations runtime (ms)
0 0 0.81877
1 3 1.24712
2 7 1.64094
3 12 2.52622
4 14 2.29422
5 16 1.88900
6 19 2.24741
7 22 2.72068
8 25 3.01171
Hide code cell source
fig, ax = plt.subplots(figsize=(8, 4))
plt.scatter(x=df["operations"], y=df["runtime (ms)"])
ax.set_ylim(bottom=0)
ax.set_xlabel("operations")
ax.set_ylabel("runtime (ms)")
plt.show()

Identifying nodes#

Now imagine that we don’t know anything about the expression that we created before other than that it is a sympy.Expr.

Approach 1: Generator#

A first attempt is to use a generator to recursively identify components in the expression that lie within a certain ‘complexity’ (as computed by count_ops()).

def recurse_tree(
    expression: sp.Expr, *, min_complexity: int = 0, max_complexity: int
) -> Generator[sp.Expr, None, None]:
    """Yield sub-expressions whose operation count lies within the bounds.

    The direct arguments of ``expression`` are visited in order. An argument
    is yielded if its ``sp.count_ops`` value lies strictly between
    ``min_complexity`` and ``max_complexity``; otherwise the search continues
    recursively in the arguments of that argument.
    """
    for node in expression.args:
        n_operations = sp.count_ops(node)
        if not (min_complexity < n_operations < max_complexity):
            # Outside the accepted range: look for candidates further down.
            yield from recurse_tree(
                node,
                min_complexity=min_complexity,
                max_complexity=max_complexity,
            )
            continue
        yield node

We can then use this generator function to create a mapping of these sub-expressions within the expression tree to Symbols. That mapping can then be used in xreplace() to replace the sub-expressions with those symbols.

%%time

expression = model.expression.doit()
# Map each sub-expression of limited complexity to a placeholder symbol. The
# complexity is already checked inside recurse_tree(), so it is not computed
# again here: that would only inflate the measured runtime of this cell.
sub_expressions = {}
for i, expr in enumerate(recurse_tree(expression, max_complexity=100)):
    symbol = sp.Symbol(f"f{i}")
    sub_expressions[expr] = symbol
expression.xreplace(sub_expressions)
CPU times: user 314 ms, sys: 135 µs, total: 314 ms
Wall time: 313 ms
\[\displaystyle \left|{f_{0} + f_{1}}\right|^{2} + \left|{f_{2} + f_{3}}\right|^{2} + \left|{f_{4} + f_{5}}\right|^{2} + \left|{f_{6} + f_{7}}\right|^{2}\]

Approach 2: Direct substitution#

There is one problem though: xreplace() is not accurate for larger expressions. It would therefore be better to directly substitute the sub-expression with a symbol while we loop over the nodes in the expression tree. The following function can do that:

def split_expression(
    expression: sp.Expr,
    max_complexity: int,
    min_complexity: int = 0,
) -> tuple[sp.Expr, dict[sp.Symbol, sp.Expr]]:
    """Replace sub-expressions of limited complexity by placeholder symbols.

    The expression tree is traversed from the top. Each argument whose
    ``sp.count_ops`` value lies strictly between ``min_complexity`` and
    ``max_complexity`` is replaced by a symbol ``f0``, ``f1``, ...; any other
    argument is searched recursively.

    :param expression: Expression to split up.
    :param max_complexity: Exclusive upper bound on the operation count of a
        sub-expression that gets replaced by a symbol.
    :param min_complexity: Exclusive lower bound on that operation count.
    :return: A tuple of the top expression (with placeholder symbols) and a
        mapping of each placeholder symbol to the sub-expression it replaces.
    """
    symbol_mapping: dict[sp.Symbol, sp.Expr] = {}
    # Reverse lookup, so that identical sub-expressions share one symbol and
    # are not lambdified more than once later on.
    known_symbols: dict[sp.Expr, sp.Symbol] = {}

    def recursive_split(sub_expression: sp.Expr) -> sp.Expr:
        old_args = sub_expression.args
        if not old_args:
            return sub_expression
        new_args = []
        for arg in old_args:
            if min_complexity < sp.count_ops(arg) < max_complexity:
                symbol = known_symbols.get(arg)
                if symbol is None:
                    symbol = sp.Symbol(f"f{len(symbol_mapping)}")
                    known_symbols[arg] = symbol
                    symbol_mapping[symbol] = arg
                new_args.append(symbol)
            else:
                new_args.append(recursive_split(arg))
        if all(new is old for new, old in zip(new_args, old_args)):
            return sub_expression
        # Rebuild the node once from its processed arguments. Substituting
        # argument by argument with xreplace() can miss a later argument,
        # because an earlier substitution may already have altered it.
        return sub_expression.func(*new_args)

    top_expression = recursive_split(expression)
    return top_expression, symbol_mapping

And indeed, this is much faster than Approach 1: Generator (it’s even possible to parallelize this for loop):

%time

top_expression, sub_expressions = split_expression(expression, max_complexity=100)
CPU times: user 7 µs, sys: 1 µs, total: 8 µs
Wall time: 15.5 µs
top_expression
\[\displaystyle \left|{f_{0} + f_{1}}\right|^{2} + \left|{f_{2} + f_{3}}\right|^{2} + \left|{f_{4} + f_{5}}\right|^{2} + \left|{f_{6} + f_{7}}\right|^{2}\]
sub_expressions[sp.Symbol("f0")]
\[\displaystyle \frac{C[J/\psi(1S) \to f_{0}(980)_{0} \gamma_{+1}; f_{0}(980) \to \pi^{0}_{0} \pi^{0}_{0}] \Gamma_{f(0)(980)} m_{f(0)(980)} \left(\frac{\cos{\left(\theta_{1+2} \right)}}{2} + \frac{1}{2}\right) e^{i \phi_{1+2}}}{- \frac{i \Gamma_{f(0)(980)} m_{f(0)(980)} \sqrt{\frac{\left(m_{12}^{2} - \left(m_{1} - m_{2}\right)^{2}\right) \left(m_{12}^{2} - \left(m_{1} + m_{2}\right)^{2}\right)}{m_{12}^{2}}} \sqrt{m_{f(0)(980)}^{2}}}{\sqrt{\frac{\left(m_{f(0)(980)}^{2} - \left(m_{1} - m_{2}\right)^{2}\right) \left(m_{f(0)(980)}^{2} - \left(m_{1} + m_{2}\right)^{2}\right)}{m_{f(0)(980)}^{2}}} \left|{m_{12}}\right|} - m_{12}^{2} + m_{f(0)(980)}^{2}}\]

Lambdify and combine#

Now that we have the machinery to split up arbitrary expressions by complexity, we need to lambdify the top expression as well as each of the sub-expressions and recombine them. The following function can do that and return a recombined Callable.

def optimized_lambdify(
    args: Sequence[sp.Symbol],
    expr: sp.Expr,
    modules: str | None = None,
    min_complexity: int = 0,
    max_complexity: int = 100,
) -> Callable:
    """Lambdify an expression piece by piece and recombine the result.

    ``expr`` is split with ``split_expression`` into a top expression and a
    set of sub-expressions. The top expression and each sub-expression are
    lambdified separately with ``sp.lambdify``.

    :param args: Symbols that form the positional arguments of the returned
        function, in this order.
    :param expr: Expression to lambdify.
    :param modules: Passed on to ``sp.lambdify``.
    :param min_complexity: Passed on to ``split_expression``.
    :param max_complexity: Passed on to ``split_expression``.
    :return: A function that takes the values for ``args``, evaluates all
        lambdified sub-expressions with them, and feeds the results into the
        lambdified top expression.
    """
    top_expression, definitions = split_expression(
        expr,  # the argument of this function, not a global variable
        min_complexity=min_complexity,
        max_complexity=max_complexity,
    )
    placeholders = sorted(definitions, key=lambda s: s.name)
    # The top expression can still contain some of the original symbols (for
    # instance in ``x * <large expression>``), so the top-level function takes
    # the original arguments in addition to the placeholder symbols.
    original_symbols = [args] if isinstance(args, sp.Symbol) else list(args)
    top_lambdified = sp.lambdify(
        [*original_symbols, *placeholders], top_expression, modules
    )
    sub_lambdified = [
        sp.lambdify(args, definitions[symbol], modules) for symbol in placeholders
    ]

    def recombined_function(*values):
        sub_values = [sub_function(*values) for sub_function in sub_lambdified]
        return top_lambdified(*values, *sub_values)

    return recombined_function

We can use the same input values as in Test with data to check that the resulting lambdified expression results in the same output.

%time

treewise_lambdified = optimized_lambdify(free_symbols, expression, "numpy")
CPU times: user 8 µs, sys: 1 µs, total: 9 µs
Wall time: 17.4 µs
treewise_result = treewise_lambdified(*args_values)
treewise_result
array([0.00048765, 0.00033425, 0.00524706, ..., 0.00140122, 0.00714365,
       0.00030117])

And it’s indeed the same as the intensity computed by tensorwaves (direct lambdify):

mean_difference = (treewise_result - tensorwaves_result).mean()
mean_difference
-7.307471274905997e-11

Comparison#

Now have a look at a slightly more complicated model:

Hide code cell source
result = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [+1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
complex_model = model_builder.generate()
dot = qrules.io.asdot(result, collapse_graphs=True, render_final_state_id=False)
graphviz.Source(dot)

This makes it clear that the functions defined in Arbitrary expressions result in a huge speed-up!

new_expression = complex_model.expression.doit()
new_free_symbols = sorted(new_expression.free_symbols, key=lambda s: s.name)
%%time

np_expr = sp.lambdify(new_free_symbols, new_expression)
CPU times: user 4.57 s, sys: 3.16 ms, total: 4.57 s
Wall time: 4.57 s
%%time

np_expr = optimized_lambdify(new_free_symbols, new_expression)
CPU times: user 261 ms, sys: 87 µs, total: 262 ms
Wall time: 260 ms