Symbolic model serialization
Expression trees
SymPy expressions are built up from symbols and mathematical operations as follows:
import sympy as sp
x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z
expression
Under the hood, SymPy represents these expressions as trees. There are a few ways to visualize this for this particular example:
sp.printing.tree.print_tree(expression, assumptions=False)
Add: -x**2 + sin(x*y)/2 + 1/z
+-Pow: 1/z
| +-Symbol: z
| +-NegativeOne: -1
+-Mul: sin(x*y)/2
| +-Half: 1/2
| +-sin: sin(x*y)
| +-Mul: x*y
| +-Symbol: x
| +-Symbol: y
+-Mul: -x**2
+-NegativeOne: -1
+-Pow: x**2
+-Symbol: x
+-Integer: 2
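The printed tree is essentially a recursive walk over each node's `args`. A minimal sketch of such a walk, written from scratch for illustration (the helper `tree_lines` is hypothetical and not part of SymPy), could look like this:

```python
import sympy as sp


def tree_lines(expr: sp.Basic, depth: int = 0) -> list[str]:
    """Return one indented line per node: the node's class name and its printed form."""
    lines = ["  " * depth + f"{type(expr).__name__}: {expr}"]
    for arg in expr.args:  # children of this node, visited depth-first
        lines.extend(tree_lines(arg, depth + 1))
    return lines


x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z
print("\n".join(tree_lines(expression)))
```

Unlike `print_tree`, this sketch omits assumptions and the `+-` connectors, but it visits the same nodes: every leaf, such as `Half` or `NegativeOne`, is simply a node with empty `args`.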
Expression trees are powerful because we can use them as templates for any human-readable presentation we are interested in. In fact, the LaTeX representation that we saw when constructing the expression was generated by SymPy’s LaTeX printer.
from IPython.display import Markdown
src = sp.latex(expression)
Markdown(f"```latex\n{src}\n```")
- x^{2} + \frac{\sin{\left(x y \right)}}{2} + \frac{1}{z}
Hint
SymPy expressions can serve as a template for generating code!
Here are a number of other representations:
# Python
-x**2 + (1/2)*math.sin(x*y) + 1/z
// C++
-std::pow(x, 2) + (1.0/2.0)*std::sin(x*y) + 1.0/z
! Fortran
-x**2 + (1.0d0/2.0d0)*sin(x*y) + 1d0/z
% Matlab / Octave
-x.^2 + sin(x.*y)/2 + 1./z
# Julia
-x .^ 2 + sin(x .* y) / 2 + 1 ./ z
// Rust
-x.powi(2) + (1_f64/2.0)*(x*y).sin() + z.recip()
<!-- MathML -->
<mrow>
<mrow>
<mo>-</mo>
<msup>
<mi>x</mi>
<mn>2</mn>
</msup>
</mrow>
<mo>+</mo>
<mrow>
<mfrac>
<mrow>
<mi>sin</mi>
<mfenced>
<mrow>
<mi>x</mi>
<mo>⁢</mo>
<mi>y</mi>
</mrow>
</mfenced>
</mrow>
<mn>2</mn>
</mfrac>
</mrow>
<mo>+</mo>
<mfrac>
<mn>1</mn>
<mi>z</mi>
</mfrac>
</mrow>
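The listing above does not say which functions produced these strings. Assuming SymPy's built-in code printers were used, a sketch could look as follows; the printer settings (for instance free-form Fortran 95) are assumptions, and the exact output may differ between SymPy versions.

```python
import sympy as sp

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z

# One SymPy printer per target format; the settings passed here are assumptions
representations = {
    "Python": sp.pycode(expression),
    "C++": sp.cxxcode(expression),
    "Fortran": sp.fcode(expression, standard=95, source_format="free"),
    "Matlab / Octave": sp.octave_code(expression),
    "Julia": sp.julia_code(expression),
    "Rust": sp.rust_code(expression),
    "MathML": sp.mathml(expression, printer="presentation"),
}
for language, code in representations.items():
    print(f"{language}:\n{code}\n")
```

Each printer walks the same expression tree and only differs in how it renders each node type.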
Foldable expressions
The previous example is quite simple, but SymPy works just as well with huge expressions, as we will see in Large expressions. First, though, let’s have a look at how to define these larger expressions in such a way that we can still read them. A nice solution is to define sp.Expr classes with the @unevaluated decorator (see ComPWA/ampform#364). Here, we define a Chew-Mandelstam function \(\rho^\text{CM}\) for \(S\)-waves. This function requires the definition of a break-up momentum \(q\).
from ampform.sympy import unevaluated
@unevaluated(real=False)
class PhspFactorSWave(sp.Expr):
    s: sp.Symbol
    m1: sp.Symbol
    m2: sp.Symbol
    _latex_repr_ = R"\rho^\text{{CM}}\left({s}\right)"
    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)  # defined below; only resolved when evaluate() is called
        cm = (
            (2 * q / sp.sqrt(s))
            * sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2))
            - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
        ) / (16 * sp.pi**2)
        return 16 * sp.pi * sp.I * cm
@unevaluated(real=False)
class BreakupMomentum(sp.Expr):
    s: sp.Symbol
    m1: sp.Symbol
    m2: sp.Symbol
    _latex_repr_ = R"q\left({s}\right)"
    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))
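The @unevaluated decorator comes from AmpForm. To illustrate the underlying idea with plain SymPy only, here is a minimal sketch (not AmpForm's implementation) of a hypothetical `BreakupMomentumSketch` class. It subclasses sp.Expr, keeps its arguments folded, provides a custom LaTeX form through SymPy's `_latex` printer hook, and unfolds into the same formula as `BreakupMomentum.evaluate()` when doit() is called. The symbols are assumed to be positive.

```python
import sympy as sp


class BreakupMomentumSketch(sp.Expr):
    """Hypothetical plain-SymPy stand-in for an @unevaluated class.

    The node stays 'folded' in the expression tree until doit() is called.
    """

    is_commutative = True  # let SymPy treat the node like an ordinary scalar

    def __new__(cls, s, m1, m2):
        return super().__new__(cls, *map(sp.sympify, (s, m1, m2)))

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))

    def doit(self, **hints):
        # Unfold this node, then keep unfolding whatever it contains
        return self.evaluate().doit(**hints)

    def _latex(self, printer, *args, **kwargs):
        # Custom LaTeX for the folded node, similar in spirit to _latex_repr_
        return rf"q\left({printer._print(self.args[0])}\right)"


s, m1, m2 = sp.symbols("s m1 m2", positive=True)
q = BreakupMomentumSketch(s, m1, m2)
expr = 2 * q / sp.sqrt(s)
print(sp.latex(expr))  # folded: q(s) appears as a single symbol
print(expr.doit())  # unfolded: the full square-root expression
```

The folded node behaves like any other SymPy scalar in arithmetic, which is what allows larger expressions to be built from such blocks while remaining readable.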
We now have a very clean mathematical representation of how the \(\rho^\text{CM}\) function is defined in terms of \(q\):
Now, let’s build up a more complicated expression that contains this phase space factor. Here, we use SymPy to derive a Breit-Wigner using a single-channel \(K\) matrix [Chung et al., 1995]:
I = sp.Identity(n=1)
K = sp.MatrixSymbol("K", m=1, n=1)
ρ = sp.MatrixSymbol("rho", m=1, n=1)
T = (I - sp.I * K * ρ).inv() * K
T
T.as_explicit()[0, 0]
Here, we need to provide definitions for the matrix elements of \(K\) and \(\rho\). For \(K\), we take a single pole at \(s = m_0^2\); for \(\rho\), a suitable choice is the phase space factor for \(S\) waves that we defined above:
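s, m1, m2 = sp.symbols("s m1 m2")
m0, Γ0, γ0 = sp.symbols("m0 Gamma0 gamma0")
K_expr = (γ0**2 * m0 * Γ0) / (s - m0**2)
substitutions = {
    K[0, 0]: K_expr,
    ρ[0, 0]: PhspFactorSWave(s, m1, m2),
}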
And there we have it! After some algebraic simplifications, we get a Breit-Wigner with a Chew-Mandelstam phase space factor for \(S\) waves:
T_expr = T.as_explicit().xreplace(substitutions)
BW_expr = T_expr[0, 0].simplify(doit=False)
BW_expr
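Because all matrices here are \(1 \times 1\), the matrix inverse reduces to a scalar division, \(T = K/(1 - iK\rho)\). The following sketch repeats this algebra with plain scalars, keeping \(\rho\) as an ordinary symbol (an assumption made for readability; above, the full Chew-Mandelstam expression is substituted), and checks the resulting Breit-Wigner form:

```python
import sympy as sp

s, m0, Γ0, γ0 = sp.symbols("s m0 Gamma0 gamma0")
ρ = sp.Symbol("rho")  # placeholder for the phase space factor, kept symbolic

K = γ0**2 * m0 * Γ0 / (s - m0**2)  # single-pole K-matrix element, as above
T = K / (1 - sp.I * K * ρ)  # 1x1 case of (1 - iKρ)^(-1) K
BW = sp.simplify(T)
print(BW)
```

Algebraically, \(T = \gamma_0^2 m_0 \Gamma_0 / (s - m_0^2 - i\gamma_0^2 m_0 \Gamma_0 \rho)\), so the pole of \(K\) at \(s = m_0^2\) is shifted by a term proportional to \(\rho\).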
The expression tree of BW_expr now has a node that is ‘folded’:
After unfolding, we get the full expression tree of fundamental mathematical operations:
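SymPy's own unevaluated objects behave in a similar way. As an analogy (Sum is a built-in SymPy class, unrelated to the AmpForm classes above), counting the nodes of a folded Sum before and after doit() shows how unfolding expands the tree. The helper `count_nodes` is hypothetical:

```python
import sympy as sp


def count_nodes(expr: sp.Basic) -> int:
    """Count every node of the expression tree (repeated sub-trees counted each time)."""
    return sum(1 for _ in sp.preorder_traversal(expr))


x, k = sp.symbols("x k")
folded = sp.Sum(x**k, (k, 0, 5))  # one unevaluated Sum node
unfolded = folded.doit()  # explicit polynomial in x
print(folded, "->", unfolded)
print(count_nodes(folded), "->", count_nodes(unfolded))
```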
Large expressions
Here, we import the large symbolic intensity expression that was used for and see how well SymPy serialization performs on a much more complicated model.
The model contains 43,198 mathematical operations. See ComPWA/polarimetry#319 for the origin of this investigation.
Serialization with srepr
SymPy expressions can also be serialized directly to Python code with the function srepr(). For the full intensity expression, we can do so with:
%%time
eval_str = sp.srepr(unfolded_intensity_expr)
CPU times: user 799 ms, sys: 16 μs, total: 799 ms
Wall time: 798 ms
This serializes the intensity expression of 43,198 nodes to a string of 1.04 MB.
Add(Pow(Abs(Add(Mul(Add(Mul(Integer(-1), Pow(Add(Mul(Integer(-1), I, ... ))))))))))
It is up to the user, however, to import the classes of each exported node before the string can be parsed with eval() (see this comment).
imported_intensity_expr = eval(eval_str)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[20], line 1
----> 1 imported_intensity_expr = eval(eval_str)
File <string>:1
NameError: name 'Add' is not defined
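One way around this NameError, sketched here on the small expression from the start of this page (with symbols assumed to be `positive=True`, to show that assumptions survive), is to pass eval() an explicit namespace that contains all public SymPy names:

```python
import sympy as sp

x, y, z = sp.symbols("x y z", positive=True)
expression = sp.sin(x * y) / 2 - x**2 + 1 / z
eval_str = sp.srepr(expression)

# Fill a namespace with every public SymPy name (Add, Mul, Symbol, Integer, ...)
namespace: dict = {}
exec("from sympy import *", namespace)
imported = eval(eval_str, namespace)

print(imported == expression)
```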
In the case of this intensity expression, it is sufficient to import all definitions from the main sympy module and the Str class. Optionally, the required import statements can be embedded into the string:
from pathlib import Path
exec_str = f"""\
from sympy import *
from sympy.core.symbol import Str
def get_intensity_function() -> Expr:
    return {eval_str}
"""
exec_filename = Path("../_static/exported_intensity_model.py")
with open(exec_filename, "w") as f:
    f.write(exec_str)
See exported_intensity_model.py for the exported model.
The parsing is then done with exec() instead of the eval() function:
%%time
exec(exec_str)
imported_intensity_expr = get_intensity_function()
CPU times: user 464 ms, sys: 27 ms, total: 491 ms
Wall time: 488 ms
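Instead of calling exec() on the string, the exported file could also be loaded as a regular Python module with importlib. This is a swapped-in alternative, not what is done above. A self-contained sketch, using a temporary directory, the small example expression, and a hypothetical module name `exported_model`:

```python
import importlib.util
import tempfile
from pathlib import Path

import sympy as sp

x, y, z = sp.symbols("x y z", positive=True)
expression = sp.sin(x * y) / 2 - x**2 + 1 / z

# Same layout as exec_str above: imports plus a getter function
module_src = f"""\
from sympy import *


def get_expression() -> Expr:
    return {sp.srepr(expression)}
"""

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "exported_model.py"
    path.write_text(module_src)
    spec = importlib.util.spec_from_file_location("exported_model", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file like a normal import
    restored = module.get_expression()

print(restored == expression)
```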
Notice how the imported expression is exactly the same as the serialized one, including assumptions:
Common sub-expressions
A problem is that the expression string generated with srepr() is not human-readable in practice for large expressions. One way out may be to extract common components of the main expression with Foldable expressions. Another may be to use SymPy to detect and collect common sub-expressions.
sub_exprs, common_expr = sp.cse(unfolded_intensity_expr, order="none")
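What sp.cse() returns is easiest to see on a toy expression (chosen here for illustration): a list of `(symbol, sub-expression)` replacements plus a list of reduced expressions. Substituting the replacements back in reverse order recovers the original:

```python
import sympy as sp

x, y = sp.symbols("x y")
expr = sp.exp(x + y) / (x + y) + sp.sin(x + y) ** 2  # x + y occurs three times

replacements, reduced = sp.cse(expr)
print(replacements)  # intermediate symbols and what they stand for
print(reduced[0])  # the expression written in terms of those symbols

# Later replacements may refer to earlier ones, so substitute in reverse order
rebuilt = reduced[0]
for symbol, sub_expr in reversed(replacements):
    rebuilt = rebuilt.xreplace({symbol: sub_expr})
print(rebuilt == expr)
```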
This already works quite well with sp.lambdify (without cse=True, this would take minutes):
%%time
args = sorted(unfolded_intensity_expr.free_symbols, key=str)
_ = sp.lambdify(args, unfolded_intensity_expr, cse=True, dummify=True)
CPU times: user 1.26 s, sys: 6.99 ms, total: 1.27 s
Wall time: 1.27 s
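The effect of cse=True on the generated code can be inspected on a small expression. This sketch uses a toy expression and assumes `modules="math"` so that NumPy is not required; it prints the generated source and checks that both variants agree numerically:

```python
import inspect

import sympy as sp

x, y = sp.symbols("x y")
expr = sp.exp(x + y) / (x + y) + sp.sin(x + y) ** 2  # toy expression with a repeated x + y

f_plain = sp.lambdify((x, y), expr, modules="math")
f_cse = sp.lambdify((x, y), expr, modules="math", cse=True)

print(inspect.getsource(f_cse))  # x + y is assigned to an intermediate variable once
print(f_plain(0.3, 0.4), f_cse(0.3, 0.4))
```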
Still, many of the common sub-expressions have exactly the same structure and differ only in the symbols they contain. It would be better to detect such structurally similar expressions, so that we can serialize them as functions or custom sub-definitions.
In SymPy, the equivalence between expressions can be determined with the match() method using Wild symbols. We therefore first have to make all symbols in the common sub-expressions ‘wild’. In addition, in the case of this intensity expression, some of the symbols are indexed and need to be replaced first.
pure_symbol_expr = unfolded_intensity_expr.replace(
query=lambda z: isinstance(z, sp.Indexed),
value=lambda z: sp.Symbol(sp.latex(z), **z.assumptions0),
)
sub_exprs, common_expr = sp.cse(pure_symbol_expr, order="none")
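The replacement of Indexed symbols can be tried out on a toy expression with a hypothetical IndexedBase `a`. Each `a[i]` becomes an ordinary Symbol named after its LaTeX representation, so that only plain symbols remain:

```python
import sympy as sp

a = sp.IndexedBase("a")  # hypothetical indexed parameters a[0], a[1]
x = sp.Symbol("x")
expr = a[0] * x + a[1] ** 2

pure = expr.replace(
    query=lambda z: isinstance(z, sp.Indexed),
    value=lambda z: sp.Symbol(sp.latex(z), **z.assumptions0),
)
print(pure, pure.free_symbols)
```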
Note, for example, that the following two common sub-expressions are structurally equivalent:
Wild symbols now allow us to find how these expressions relate to each other.
from ampform.io import aslatex
from IPython.display import Math
is_symbol = lambda z: isinstance(z, sp.Symbol)
make_wild = lambda z: sp.Wild(z.name)
X = [x.replace(is_symbol, make_wild) for _, x in sub_exprs]
Math(aslatex(X[5].match(X[8])))
Hint
This can be used to define functions for larger, common expression blocks.
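As a rough sketch of that idea, with toy expressions and hypothetical names (`blocks`, `pattern`, `shared`), one of two structurally identical blocks can be turned into a Wild pattern, matched against each block, and replaced by a single shared sp.Lambda that is called with the matched arguments:

```python
import sympy as sp


def is_symbol(z: sp.Basic) -> bool:
    return isinstance(z, sp.Symbol)


def make_wild(z: sp.Symbol) -> sp.Wild:
    return sp.Wild(z.name)


a, b, c, d = sp.symbols("a b c d")
# Two toy blocks with the same structure but different symbols
blocks = [sp.sin(a) * sp.exp(b), sp.sin(c) * sp.exp(d)]

# Turn the first block into a pattern and give its Wilds a fixed order
pattern = blocks[0].replace(is_symbol, make_wild)
params = sorted(pattern.atoms(sp.Wild), key=str)

# Hypothetical shared definition: a Lambda with ordinary symbols as parameters
arg_symbols = sp.symbols(f"arg0:{len(params)}")
shared = sp.Lambda(arg_symbols, pattern.xreplace(dict(zip(params, arg_symbols))))

# Each block can then be expressed as a call of the shared definition
calls = []
for block in blocks:
    mapping = block.match(pattern)  # e.g. which symbol plays the role of each Wild
    calls.append(tuple(mapping[p] for p in params))
print(shared, calls)
```

In a serialized model, `shared` would become one function definition and each block a call to it, which is where the size reduction would come from.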