7.7. Model serialization#
This page demonstrates a strategy for exporting an amplitude model with its suggested parameter defaults to disk and loading it back into memory later on for computations with the computational backend.
Import Python libraries
from __future__ import annotations
import os
import pickle
from textwrap import shorten
import cloudpickle
import jax.numpy as jnp
import sympy as sp
from ampform.sympy import perform_cached_doit
from IPython.display import Markdown
from tensorwaves.function.sympy import create_function
from polarimetry.io import mute_jax_warnings
from polarimetry.lhcb import load_model
from polarimetry.lhcb.particle import load_particles
mute_jax_warnings()
7.7.1. Export model#
data_dir = "../../data"
particles = load_particles(f"{data_dir}/particle-definitions.yaml")
model = load_model(f"{data_dir}/model-definitions.yaml", particles, model_id=0)
unfolded_intensity_expr = perform_cached_doit(model.full_expression)
Substitute fixed parameters
free_parameters = {
    symbol: value
    for symbol, value in model.parameter_defaults.items()
    if isinstance(symbol, sp.Indexed)
    if "production" in str(symbol)
}
fixed_parameters = {
    symbol: value
    for symbol, value in model.parameter_defaults.items()
    if symbol not in free_parameters
}
subs_intensity_expr = unfolded_intensity_expr.xreplace(fixed_parameters)
def sigma3via12() -> dict[sp.Symbol, sp.Expr]:
    s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
    m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True)
    return {s3: m1**2 + m2**2 + m3**2 + m0**2 - s1 - s2}
dict_forms = {
    "intensity_expr": unfolded_intensity_expr,
    "variables": {k: v.doit() for (k, v) in model.variables.items()},
    "parameter_defaults": model.parameter_defaults,
    "sigma3": sigma3via12(),
}
filename = "exported-model.pkl"
with open(filename, "wb") as f:
    cloudpickle.dump(dict_forms, f)
7.7.2. Import model#
The model is saved as a Python dict to a pickle file. The dictionary contains the SymPy expression for the model as well as its suggested parameter default values and variable definitions. These parameter and variable symbols are substituted into the expression with the fully_substitute() function.
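The fully_substitute() helper is defined in a hidden code cell on the original page. A minimal sketch, assuming it only needs to apply the substitutions stored in the exported dictionary (variable definitions, the expression for sigma3, and the parameter defaults), could look as follows:
def fully_substitute(imported_model: dict) -> sp.Expr:
    # Sketch: substitute variable definitions, express sigma3 in terms of
    # sigma1 and sigma2, and insert the suggested parameter defaults
    return (
        imported_model["intensity_expr"]
        .xreplace(imported_model["variables"])
        .xreplace(imported_model["sigma3"])
        .xreplace(imported_model["parameter_defaults"])
    )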
def load_model(filename: str) -> dict:
    if not os.path.exists(filename):
        msg = f"Input file not found: {filename}"
        raise ValueError(msg)
    with open(filename, "rb") as f:
        return pickle.load(f)
imported_model = load_model(filename)
intensity_on_2vars = fully_substitute(imported_model)
7.7.2.1. Compilation#
The resulting symbolic expression depends on two variables:
\(\sigma_1 = m_{K\pi}^2\), the squared invariant mass of the \(K^- \pi^+\) system, and
\(\sigma_2 = m_{pK}^2\), the squared invariant mass of the \(p K^-\) system.
This expression is turned into a numerical function either with lambdify() or with create_function(), using JAX as a computational backend. For the SymPy backend, positional arguments are used:
s12 = sp.symbols("sigma1:3", nonnegative=True)
assert intensity_on_2vars.free_symbols == set(s12)
func = sp.lambdify(s12, intensity_on_2vars)
func(1.0, 3.0), func(1.1, 3.2)
(np.float64(1156.5307379422927), np.float64(636.1087670531597))
The compilation to JAX is facilitated by tensorwaves:
density = create_function(intensity_on_2vars, backend="jax")
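In contrast to the lambdified function above, the resulting tensorwaves function is called with a DataSample, that is, a dict of arrays keyed by variable name. As an illustration (the test points below are chosen to match the ones passed to the lambdified function above):
test_sample = {
    "sigma1": jnp.array([1.0, 1.1]),
    "sigma2": jnp.array([3.0, 3.2]),
}
# Evaluate the JAX-compiled intensity on the two test points
density(test_sample)
This should reproduce the values obtained from the lambdified function above.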
7.7.3. Serialization with srepr#
SymPy expressions can directly be serialized to Python code as well, with the function srepr(). For the full intensity expression, we can do so with:
%%time
eval_str = sp.srepr(unfolded_intensity_expr)
CPU times: user 2.15 s, sys: 4.99 ms, total: 2.16 s
Wall time: 2.16 s
n_nodes = sp.count_ops(unfolded_intensity_expr)
byt = len(eval_str.encode("utf-8"))
mb = f"{1e-6 * byt:.2f}"
rendering = shorten(eval_str, placeholder=" ...", width=85)
src = f"""
This serializes the intensity expression of {n_nodes:,d} nodes
to a string of **{mb} MB**.
```
{rendering}
```
"""
Markdown(src)
This serializes the intensity expression of 43,198 nodes to a string of 1.04 MB.
Add(Pow(Abs(Add(Mul(Add(Mul(Integer(-1), Pow(Add(Mul(Integer(-1), I, ...
It is up to the user, however, to import the classes of each exported node before the string can be parsed back into an expression with eval() (see this comment).
eval(eval_str)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[16], line 1
----> 1 eval(eval_str)
File <string>:1
NameError: name 'Add' is not defined
In the case of this intensity expression, it is sufficient to import all definitions from the main sympy module along with the Str class.
from sympy import * # noqa: F403
from sympy.core.symbol import Str # noqa: F401
%%time
eval_imported_intensity_expr = eval(eval_str)
CPU times: user 2.82 s, sys: 72 ms, total: 2.89 s
Wall time: 2.89 s
Notice how the imported expression is exactly the same as the serialized one, including assumptions.
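The comparison itself is performed in a hidden cell on the original page. A minimal check, assuming SymPy's structural equality (which takes symbol assumptions into account) is sufficient, would be:
# Structural comparison of the round-tripped expression with the original
assert eval_imported_intensity_expr == unfolded_intensity_expr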
Optionally, the import statements can be embedded into the string. The parsing is then done with exec() instead:
exec_str = f"""\
from sympy import *
from sympy.core.symbol import Str

def get_intensity_function() -> Expr:
    return {eval_str}
"""
exec_filename = "exported_intensity_model.py"
with open(exec_filename, "w") as f:
    f.write(exec_str)
See exported_intensity_model.py for the exported model.
%%time
exec(exec_str)
exec_imported_intensity_expr = get_intensity_function() # noqa: F405
CPU times: user 376 ms, sys: 16 ms, total: 392 ms
Wall time: 390 ms
Note
The load time is faster than the eval() call above due to caching within SymPy.
7.7.4. Python package#
As noted on the main page, the source code for this analysis is available as a Python package on PyPI and can be installed as follows.
pip install polarimetry-lc2pkpi
Each of the models can then simply be imported as
import polarimetry
model = polarimetry.published_model("Default amplitude model")
model.intensity.cleanup()
The expressions have to be converted to a numerical function to evaluate them over larger data samples. There are several ways of doing this (such as algebraically substituting the parameter values first), but it depends on your application what is best. Here’s a small example where we want to evaluate the model over a set of data points on the Dalitz plane. We first ‘unfold’ the main intensity expression and lambdify it to a numerical function with JAX as computational backend.
from ampform.sympy import perform_cached_doit
from ampform_dpd.io import perform_cached_lambdify
intensity_expr = perform_cached_doit(model.full_expression)
intensity_func = perform_cached_lambdify(
intensity_expr.doit(),
parameters=model.parameter_defaults,
backend="jax",
)
Cached expression file /home/runner/.cache/ampform/sympy-v1.13.3/5f128eade2e5fe6ee90e1ad294d08808.pkl not found, performing doit()...
Now, let’s say we have some data sample containing generated phase space data points in the Dalitz plane.
Download ZIP data
import urllib.request
from io import BytesIO
from zipfile import ZipFile
url = "https://github.com/ComPWA/polarimetry/files/14123343/lc2pkpi_phsp.zip"
with urllib.request.urlopen(url) as response:
    zipfile = ZipFile(BytesIO(response.read()))
    zipfile.extractall()
import uproot
with uproot.open("lc2pkpi_phsp.root") as root_file:
    events = root_file["Lc2ppiK"]
    df = events.arrays(library="pd")
df
|   | msq_piK  | msq_Kp   |
|---|----------|----------|
| 0 | 1.303328 | 3.505217 |
| 1 | 0.424415 | 4.382733 |
| 2 | 1.694020 | 2.827971 |
| 3 | 1.260368 | 3.106184 |
| 4 | 1.171106 | 3.424632 |
| 5 | 1.556747 | 3.210350 |
| 6 | 1.615608 | 2.484730 |
| 7 | 0.649837 | 3.523832 |
Here, we have data points for the two Mandelstam variables \(\sigma_1\) and \(\sigma_2\). The invariants attribute of the amplitude model provides symbolic expressions for computing the third Mandelstam variable. In combination with the variables and parameter_defaults attributes, we can create a data transformer for computing helicity angles and DPD alignment angles.
from ampform.io import aslatex
from IPython.display import Math
Math(aslatex(model.invariants))
import sympy as sp
from tensorwaves.data.transform import SympyDataTransformer
*_, (s3, s3_expr) = model.invariants.items()
definitions = model.variables
definitions[s3] = s3_expr
definitions = {
    symbol: expr.xreplace(definitions).xreplace(model.parameter_defaults)
    for symbol, expr in definitions.items()
}
data_transformer = SympyDataTransformer.from_sympy(definitions, backend="jax")
Finally, we can create an input DataSample that we can feed to the numerical function for the amplitude model.
data = {
    "sigma1": df["msq_piK"].to_numpy(),
    "sigma2": df["msq_Kp"].to_numpy(),
}
completed_data = data_transformer(data)
completed_data.update(data)
intensity_func(completed_data)
Array([2158.16752493, 5447.08091414, 6232.40502307, 681.74714874,
902.7663634 , 5626.31808739, 7225.75775956, 2247.93599152], dtype=float64)