Speed up lambdifying#
Show code cell content
%config InlineBackend.figure_formats = ['svg']
from __future__ import annotations
import inspect
import logging
import timeit
import warnings
from collections.abc import Generator, Sequence
from typing import Callable
import ampform
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import qrules
import sympy as sp
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from tensorwaves.data import generate_phsp
from tensorwaves.data.phasespace import TFUniformRealNumberGenerator
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.model import LambdifiedFunction, SympyModel
LOGGER = logging.getLogger()
Create dummy expression#
First, let’s create an amplitude model with ampform. We’ll use this model as a complicated sympy.Expr in the rest of this notebook.
result = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)(980)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
dot = qrules.io.asdot(result, collapse_graphs=True, render_final_state_id=False)
graphviz.Source(dot)
model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.generate()
free_symbols = sorted(model.expression.free_symbols, key=lambda s: s.name)
free_symbols
[C[J/\psi(1S) \to f_{0}(980)_{0} \gamma_{+1}; f_{0}(980) \to \pi^{0}_{0} \pi^{0}_{0}],
Gamma_f(0)(980),
d_f(0)(980),
m_1,
m_12,
m_2,
m_f(0)(980),
phi_1+2,
phi_1,1+2,
theta_1+2,
theta_1,1+2]
Helicity model components#
A HelicityModel has the benefit that it comes with components (intensities and amplitudes) that together form its expression. Let’s separate these components into amplitudes and intensities.
amplitudes = {
    name: expr for name, expr in model.components.items() if name.startswith("A")
}
sorted(amplitudes)
['A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{+1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{+1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{-1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{-1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{+1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{+1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]']
intensities = {
    name: expr for name, expr in model.components.items() if name.startswith("I")
}
Component structure#
Note that each intensity consists of a subset of these amplitudes. This means that intensities have a larger expression tree than amplitudes.
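As a rough check of that claim, here is a small sketch (not part of the original walkthrough) that compares the number of SymPy operations in the smallest intensity with that of the largest amplitude:
# Each intensity embeds several amplitudes, so even the smallest intensity
# should contain more operations than the largest single amplitude.
max_amplitude_ops = max(sp.count_ops(expr) for expr in amplitudes.values())
min_intensity_ops = min(sp.count_ops(expr) for expr in intensities.values())
assert min_intensity_ops > max_amplitude_ops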
amplitude_to_symbol = {
    expr: sp.Symbol(f"A{i}") for i, expr in enumerate(amplitudes.values(), 1)
}
intensity_to_symbol = {
    expr: sp.Symbol(f"I{i}") for i, expr in enumerate(intensities.values(), 1)
}
intensity_expr = model.expression.subs(intensity_to_symbol, simultaneous=True)
intensity_expr
dot = sp.dotprint(intensity_expr)
graphviz.Source(dot)
amplitude_expr = model.expression.subs(amplitude_to_symbol, simultaneous=True)
amplitude_expr
dot = sp.dotprint(amplitude_expr)
graphviz.Source(dot)
Performance check#
Lambdifying the whole HelicityModel.expression is slowest. The lambdify() function first prints the expression as a str (!) with (in this case) numpy syntax and then uses eval() to convert that back to actual numpy objects:
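To make this mechanism concrete, here is a minimal sketch (with a throw-away symbol x, not part of the benchmark) that inspects the source code lambdify() generates for a tiny expression:
# lambdify() writes Python source code as a string and eval()s it into a function.
# inspect.getsource() lets us look at that generated source.
x = sp.Symbol("x")
simple_function = sp.lambdify(x, sp.sin(x) ** 2, "numpy")
print(inspect.getsource(simple_function))
# def _lambdifygenerated(x):
#     return sin(x)**2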
Show code cell content
runtime = {}
start = timeit.default_timer()
%%time
np_complete_model = sp.lambdify(free_symbols, model.expression.doit(), "numpy")
CPU times: user 1.46 s, sys: 703 µs, total: 1.46 s
Wall time: 1.46 s
Show code cell content
stop = timeit.default_timer()
runtime["complete model"] = stop - start
Printing to str and converting back with eval() becomes exponentially slower as the expression tree grows. This means that it’s more efficient to lambdify sub-trees of the expression tree separately. When lambdifying the four intensities of this model separately, the effect is not noticeable:
%%time
for expr, symbol in intensity_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start
CPU times: user 1.56 s, sys: 4.94 ms, total: 1.56 s
Wall time: 1.56 s
…but lambdifying each of the eight amplitudes separately does result in a significant speed-up:
%%time
np_amplitudes = {}
for expr, symbol in amplitude_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    np_expr = sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start
    np_amplitudes[symbol] = np_expr
CPU times: user 547 ms, sys: 3.85 ms, total: 550 ms
Wall time: 547 ms
Recombining components#
Recall what the amplitude model looks like when expressed in terms of its amplitude components:
amplitude_expr
We have to lambdify that top expression as well:
sorted_amplitude_symbols = sorted(np_amplitudes, key=lambda s: s.name)
np_amplitude_expr = sp.lambdify(sorted_amplitude_symbols, amplitude_expr, "numpy")
source = inspect.getsource(np_amplitude_expr)
print(source)
def _lambdifygenerated(A1, A2, A3, A4, A5, A6, A7, A8):
    return (abs(A1 + A2)**2 + abs(A3 + A4)**2 + abs(A5 + A6)**2 + abs(A7 + A8)**2)
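As a quick smoke test (not part of the original flow), this top-level function can be called directly with dummy amplitude values:
# With all eight amplitudes equal to 1+1j, each |A + A|**2 term gives 8,
# so the total intensity should be about 4 * 8 = 32.
np_amplitude_expr(*([1 + 1j] * 8))  # ≈ 32.0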
We now have a lambdified expression for the complete amplitude model, as well as lambdified expressions that are to be plugged into its arguments.
def componentwise_lambdified(*args):
    """Lambdified amplitude model, recombined from its amplitude components.

    .. warning:: Order of the ``args`` has to be the same as that
        of the ``args`` of the lambdified amplitude components.
    """
    amplitude_values = []
    for amp_symbol in sorted_amplitude_symbols:
        np_amplitude = np_amplitudes[amp_symbol]
        values = np_amplitude(*args)
        amplitude_values.append(values)
    return np_amplitude_expr(*amplitude_values)
Test with data#
Okay, so does all this work? Let’s first generate a phase space sample with good-old tensorwaves. We can then use this sample as input to the component-wise lambdified function.
sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
)
intensity = LambdifiedFunction(sympy_model, backend="jax")
data_converter = HelicityTransformer(model.adapter)
rng = TFUniformRealNumberGenerator(seed=0)
phsp_sample = generate_phsp(10_000, model.adapter.reaction_info, random_generator=rng)
phsp_set = data_converter.transform(phsp_sample)
Show code cell source
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(phsp_set["m_12"], bins=50, alpha=0.5, density=True)
ax.hist(
    phsp_set["m_12"],
    bins=50,
    alpha=0.5,
    density=True,
    weights=np.array(intensity(phsp_set)),
)
plt.show()
The arguments of the component-wise lambdified amplitude model should be covered by the entries in the phase space set and the provided parameter defaults:
kinematic_variable_names = set(phsp_set)
parameter_names = {symbol.name for symbol in model.parameter_defaults}
free_symbol_names = {symbol.name for symbol in free_symbols}
assert free_symbol_names <= kinematic_variable_names | parameter_names
That allows us to sort the input arrays and parameter defaults so that they can be used as positional arguments to the component-wise lambdified amplitude model:
merged_par_var_values = {
    symbol.name: value for symbol, value in model.parameter_defaults.items()
}
merged_par_var_values.update(phsp_set)
args_values = [merged_par_var_values[symbol.name] for symbol in free_symbols]
Finally, here’s the result of plugging that back into the component-wise lambdified expression:
componentwise_result = componentwise_lambdified(*args_values)
componentwise_result
array([0.00048765, 0.00033425, 0.00524706, ..., 0.00140122, 0.00714365,
0.00030117])
And it’s indeed the same as the intensity computed by tensorwaves (direct lambdify):
tensorwaves_result = np.array(intensity(phsp_set))
mean_difference = (componentwise_result - tensorwaves_result).mean()
mean_difference
-7.307471250984975e-11
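Since a mean can hide sign cancellations, we can additionally look at the largest element-wise deviation (a small extra check, not part of the original notebook):
# Maximum absolute difference between the two evaluations
np.abs(componentwise_result - tensorwaves_result).max()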
Arbitrary expressions#
The problem with Test with data is that it requires a HelicityModel. In tensorwaves, though, we want to work with general sympy.Exprs (see SympyModel), where we don’t have HelicityModel.components available.
Instead, we have to split up the lambdifying in a more general way that can handle arbitrary sympy.core.expr.Exprs. For that we need:
1. A general method of traversing through a SymPy expression tree. This can be done with Advanced Expression Manipulation.
2. A fast method to estimate the complexity of a model, so that we can decide whether a node in the expression tree is small enough to be lambdified without much runtime. The best measure for complexity is count_ops() (“count operations”), see notes under Simplify and the small illustration below.
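Here is a small, hypothetical illustration of count_ops() as a complexity measure (the symbols x and y are introduced for this example only):
x, y = sp.symbols("x y")
sp.count_ops(x + y)  # 1: a single addition
sp.count_ops(sp.sin(x) * y + x**2)  # 4: sin, multiplication, power, addition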
Expression complexity#
Let’s tackle 2. first and use the HelicityModel.expression and its components that we lambdified earlier on. Here’s an overview of the number of operations versus the time it took to lambdify each component:
Show code cell source
df = pd.DataFrame(runtime.values(), index=runtime, columns=["runtime (s)"])
operations = [sp.count_ops(model.expression)]
operations.extend(sp.count_ops(expr) for expr in intensity_to_symbol)
operations.extend(sp.count_ops(expr) for expr in amplitude_to_symbol)
df.insert(0, "operations", operations)
df
 | operations | runtime (s) |
---|---|---|
complete model | 823 | 0.980456 |
I1 | 209 | 0.279897 |
I2 | 203 | 0.235227 |
I3 | 207 | 0.215937 |
I4 | 201 | 0.233635 |
A1 | 103 | 0.045300 |
A2 | 103 | 0.040710 |
A3 | 100 | 0.039767 |
A4 | 100 | 0.035684 |
A5 | 102 | 0.036551 |
A6 | 102 | 0.036198 |
A7 | 99 | 0.042208 |
A8 | 99 | 0.040205 |
From this we can already see that the lambdify runtime scales roughly with the number of SymPy operations.
To visualize this better, we can lambdify the BlattWeisskopfSquared expression for a range of angular momenta and measure the lambdification runtime several times with timeit. Note that BlattWeisskopfSquared becomes increasingly complex the higher the angular momentum.
Show code cell source
from ampform.dynamics import BlattWeisskopfSquared
angular_momentum, z = sp.symbols("L z")
BlattWeisskopfSquared(angular_momentum, z).doit()
Show code cell content
operations = []
runtime = []
for angular_momentum in range(9):
    ff2 = BlattWeisskopfSquared(angular_momentum, z)
    operations.append(sp.count_ops(ff2.doit()))
    n_iterations = 10
    t = timeit.timeit(
        setup=f"""
import sympy as sp
from ampform.dynamics import BlattWeisskopfSquared
z = sp.Symbol("z")
ff2 = BlattWeisskopfSquared({angular_momentum}, z)
""",
        stmt='sp.lambdify(z, ff2.doit(), "numpy")',
        number=n_iterations,
    )
    runtime.append(t / n_iterations * 1_000)
Show code cell source
df = pd.DataFrame(
    {
        "operations": operations,
        "runtime (ms)": runtime,
    },
)
df
 | operations | runtime (ms) |
---|---|---|
0 | 0 | 0.81877 |
1 | 3 | 1.24712 |
2 | 7 | 1.64094 |
3 | 12 | 2.52622 |
4 | 14 | 2.29422 |
5 | 16 | 1.88900 |
6 | 19 | 2.24741 |
7 | 22 | 2.72068 |
8 | 25 | 3.01171 |
Show code cell source
fig, ax = plt.subplots(figsize=(8, 4))
plt.scatter(x=df["operations"], y=df["runtime (ms)"])
ax.set_ylim(bottom=0)
ax.set_xlabel("operations")
ax.set_ylabel("runtime (ms)")
plt.show()
Identifying nodes#
Now imagine that we don’t know anything about the expression that we created before, other than that it is a sympy.Expr.
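Any sympy.Expr exposes its expression tree through its func and args attributes, which is all we need to traverse it generically. A minimal sketch (with throw-away symbols x and y introduced only for illustration):
x, y = sp.symbols("x y")
example_expr = sp.sin(x) * y + x**2
example_expr.func  # the operation at the root of the tree (here: Add)
example_expr.args  # the sub-expressions (children) of the root node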
Approach 1: Generator#
A first attempt is to use a generator to recursively identify components in the expression that lie within a certain ‘complexity’ (as computed by count_ops()).
def recurse_tree(
    expression: sp.Expr, *, min_complexity: int = 0, max_complexity: int
) -> Generator[sp.Expr, None, None]:
    for arg in expression.args:
        complexity = sp.count_ops(arg)
        if min_complexity < complexity < max_complexity:
            yield arg
        else:
            yield from recurse_tree(
                arg,
                min_complexity=min_complexity,
                max_complexity=max_complexity,
            )
We can then use this generator function to create a mapping of these sub-expressions within the expression tree to Symbols. That mapping can then be used in xreplace() to replace the sub-expressions with those symbols.
%%time
expression = model.expression.doit()
sub_expressions = {}
for i, expr in enumerate(recurse_tree(expression, max_complexity=100)):
    symbol = sp.Symbol(f"f{i}")
    complexity = sp.count_ops(expr)
    sub_expressions[expr] = symbol
expression.xreplace(sub_expressions)
CPU times: user 314 ms, sys: 135 µs, total: 314 ms
Wall time: 313 ms
Approach 2: Direct substitution#
There is one problem though: xreplace() is not accurate for larger expressions. It would therefore be better to directly substitute the sub-expression with a symbol while we loop over the nodes in the expression tree. The following function can do that:
def split_expression(
    expression: sp.Expr,
    max_complexity: int,
    min_complexity: int = 0,
) -> tuple[sp.Expr, dict[sp.Symbol, sp.Expr]]:
    i = 0
    symbol_mapping = {}

    def recursive_split(sub_expression: sp.Expr) -> sp.Expr:
        nonlocal i
        for arg in sub_expression.args:
            complexity = sp.count_ops(arg)
            if min_complexity < complexity < max_complexity:
                symbol = sp.Symbol(f"f{i}")
                i += 1
                symbol_mapping[symbol] = arg
                sub_expression = sub_expression.xreplace({arg: symbol})
            else:
                new_arg = recursive_split(arg)
                sub_expression = sub_expression.xreplace({arg: new_arg})
        return sub_expression

    top_expression = recursive_split(expression)
    return top_expression, symbol_mapping
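To get a feeling for what this does, here is a small toy example (the symbols x and y and the threshold are chosen for illustration only):
x, y = sp.symbols("x y")
toy_expr = sp.sin(x) * y + x**2 * sp.cos(y)
toy_top, toy_definitions = split_expression(toy_expr, max_complexity=3)
toy_top  # sub-trees below the complexity threshold are replaced by symbols f0, f1, ...
toy_definitions  # ...and their original expressions are collected in this mapping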
And indeed, this is much faster than Approach 1: Generator (it’s even possible to parallelize this for loop):
%time
top_expression, sub_expressions = split_expression(expression, max_complexity=100)
CPU times: user 7 µs, sys: 1 µs, total: 8 µs
Wall time: 15.5 µs
top_expression
sub_expressions[sp.Symbol("f0")]
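As an additional consistency check (a sketch, not in the original notebook), substituting the collected definitions back into the top expression should reproduce the original expression:
# Reversing the substitution should give back the expression we started from.
assert top_expression.xreplace(sub_expressions) == expression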
Lambdify and combine#
Now that we have the machinery to split up arbitrary expressions by complexity, we need to lambdify the top expression as well as each of the sub-expressions and recombine them. The following function can do that and return a recombined Callable.
def optimized_lambdify(
    args: Sequence[sp.Symbol],
    expr: sp.Expr,
    modules: str | None = None,
    min_complexity: int = 0,
    max_complexity: int = 100,
) -> Callable:
    top_expression, definitions = split_expression(
        expr,
        min_complexity=min_complexity,
        max_complexity=max_complexity,
    )
    top_symbols = sorted(definitions, key=lambda s: s.name)
    top_lambdified = sp.lambdify(top_symbols, top_expression, modules)
    sub_lambdified = [
        sp.lambdify(args, definitions[symbol], modules) for symbol in top_symbols
    ]

    def recombined_function(*args):
        new_args = [sub_expr(*args) for sub_expr in sub_lambdified]
        return top_lambdified(*new_args)

    return recombined_function
We can use the same input values as in Test with data to check that the resulting lambdified expression results in the same output.
%time
treewise_lambdified = optimized_lambdify(free_symbols, expression, "numpy")
CPU times: user 8 µs, sys: 1 µs, total: 9 µs
Wall time: 17.4 µs
treewise_result = treewise_lambdified(*args_values)
treewise_result
array([0.00048765, 0.00033425, 0.00524706, ..., 0.00140122, 0.00714365,
0.00030117])
And it’s indeed the same as the intensity computed by tensorwaves (direct lambdify):
mean_difference = (treewise_result - tensorwaves_result).mean()
mean_difference
-7.307471274905997e-11
Comparison#
Now let’s have a look at a slightly more complicated model:
Show code cell source
result = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [+1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
complex_model = model_builder.generate()
dot = qrules.io.asdot(result, collapse_graphs=True, render_final_state_id=False)
graphviz.Source(dot)
This makes it clear that the functions defined in Arbitrary expressions result in a huge speed-up!
new_expression = complex_model.expression.doit()
new_free_symbols = sorted(new_expression.free_symbols, key=lambda s: s.name)
%%time
np_expr = sp.lambdify(new_free_symbols, new_expression)
CPU times: user 4.57 s, sys: 3.16 ms, total: 4.57 s
Wall time: 4.57 s
%%time
np_expr = optimized_lambdify(new_free_symbols, new_expression)
CPU times: user 261 ms, sys: 87 µs, total: 262 ms
Wall time: 260 ms