7.7. Model serialization

This page demonstrates a strategy for exporting an amplitude model, together with its suggested parameter defaults, to disk and loading it back into memory later on for numerical computations with a computational backend.

from __future__ import annotations

import os
import pickle
from textwrap import shorten

import cloudpickle
import jax.numpy as jnp
import sympy as sp
from ampform.sympy import perform_cached_doit
from IPython.display import Markdown
from tensorwaves.function.sympy import create_function

from polarimetry.io import mute_jax_warnings
from polarimetry.lhcb import load_model
from polarimetry.lhcb.particle import load_particles

mute_jax_warnings()

7.7.1. Export model

data_dir = "../../data"
particles = load_particles(f"{data_dir}/particle-definitions.yaml")
model = load_model(f"{data_dir}/model-definitions.yaml", particles, model_id=0)
unfolded_intensity_expr = perform_cached_doit(model.full_expression)
free_parameters = {
    symbol: value
    for symbol, value in model.parameter_defaults.items()
    if isinstance(symbol, sp.Indexed)
    if "production" in str(symbol)
}
fixed_parameters = {
    symbol: value
    for symbol, value in model.parameter_defaults.items()
    if symbol not in free_parameters
}
subs_intensity_expr = unfolded_intensity_expr.xreplace(fixed_parameters)
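The dict comprehensions above treat the sp.Indexed symbols whose name contains "production" as free parameters and substitute all other defaults into the expression. A minimal sketch of that partition on a toy expression; the symbol names C_production and Gamma_res and all values are hypothetical:

```python
import sympy as sp

# Hypothetical toy model: two indexed production couplings and one fixed width
C = sp.IndexedBase("C_production")
gamma = sp.Symbol("Gamma_res")
x = sp.Symbol("x", nonnegative=True)
toy_expr = (C[0] + C[1] * x) / (x**2 + gamma**2)
toy_defaults = {C[0]: 1.0, C[1]: 0.5, gamma: 0.1}

# Same filter as above: indexed symbols with "production" in their name stay free
toy_free = {
    symbol: value
    for symbol, value in toy_defaults.items()
    if isinstance(symbol, sp.Indexed)
    if "production" in str(symbol)
}
toy_fixed = {s: v for s, v in toy_defaults.items() if s not in toy_free}

# Only the fixed parameters are inserted; the couplings remain symbolic
reduced_expr = toy_expr.xreplace(toy_fixed)
```

Substituting the fixed values early keeps the expression smaller, while the free couplings remain available as symbols, for instance for a fit.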
def sigma3via12() -> dict[sp.Symbol, sp.Expr]:
    s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
    m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True)
    return {s3: m1**2 + m2**2 + m3**2 + m0**2 - s1 - s2}
dict_forms = {
    "intensity_expr": unfolded_intensity_expr,
    "variables": {k: v.doit() for (k, v) in model.variables.items()},
    "parameter_defaults": model.parameter_defaults,
    "sigma3": sigma3via12(),
}
filename = "exported-model.pkl"
with open(filename, "wb") as f:
    cloudpickle.dump(dict_forms, f)
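Because SymPy expressions are picklable, the export needs nothing more than a dict and a file. A minimal sketch of the same round trip with the standard-library pickle module and a toy expression (the symbols and values are illustrative, not part of the model); cloudpickle can be used as a drop-in replacement for pickle.dump() and can additionally serialize objects such as lambdas and interactively defined functions:

```python
import pickle
import tempfile
from pathlib import Path

import sympy as sp

# Toy stand-ins for the model content (illustrative only)
s1, s2 = sp.symbols("sigma1:3", nonnegative=True)
a = sp.Symbol("a")
toy_model = {
    "intensity_expr": a * s1 + s2**2,
    "parameter_defaults": {a: 2.5},
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "toy-model.pkl"
    with open(path, "wb") as f:
        pickle.dump(toy_model, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

# The restored symbols keep their assumptions, so structural equality holds
print(restored["intensity_expr"] == toy_model["intensity_expr"])
```

cloudpickle writes a regular pickle stream, which is why the file written above can be read back with pickle.load().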

7.7.2. Import model

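The model was saved as a Python dict to a pickle file. This dictionary contains the SymPy expression for the intensity, the definitions of the kinematic variables, an expression for \(\sigma_3\) in terms of \(\sigma_1\) and \(\sigma_2\), and the suggested parameter default values. After loading, the variable and parameter symbols are substituted with the fully_substitute() function.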

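def load_exported_model(filename: str) -> dict:
    # Only unpickle files from a trusted source: unpickling can execute arbitrary code
    if not os.path.exists(filename):
        msg = f"Input file not found at {filename}"
        raise FileNotFoundError(msg)
    # cloudpickle writes a standard pickle stream, so pickle.load() can read it
    with open(filename, "rb") as f:
        return pickle.load(f)
def fully_substitute(model_description: dict) -> sp.Expr:
    # Order matters: the variable definitions can introduce sigma3 and the masses,
    # sigma3 is then expressed in terms of sigma1 and sigma2, and the parameter
    # defaults come last so that they also replace symbols introduced by earlier steps
    return (
        model_description["intensity_expr"]
        .xreplace(model_description["variables"])
        .xreplace(model_description["sigma3"])
        .xreplace(model_description["parameter_defaults"])
    )
imported_model = load_exported_model(filename)
intensity_on_2vars = fully_substitute(imported_model)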
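The order of the xreplace() calls in fully_substitute() matters, because each step can introduce symbols that only a later step replaces. A minimal sketch with a toy description dict of the same layout compares the order used above with a reversed order; theta, c, the helper substitute_in_order(), and all numeric values are hypothetical:

```python
import sympy as sp

s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True)
theta = sp.Symbol("theta")  # hypothetical derived kinematic variable
c = sp.Symbol("c")  # hypothetical coupling parameter

toy_description = {
    "intensity_expr": c * sp.cos(theta) ** 2,
    "variables": {theta: sp.acos((s1 - s3) / (s1 + s3))},
    "sigma3": {s3: m0**2 + m1**2 + m2**2 + m3**2 - s1 - s2},
    "parameter_defaults": {c: 3.0, m0: 2.0, m1: 1.0, m2: 0.5, m3: 0.5},
}


def substitute_in_order(description: dict, keys: list) -> sp.Expr:
    """Apply the substitution dicts of the description in the given order."""
    expr = description["intensity_expr"]
    for key in keys:
        expr = expr.xreplace(description[key])
    return expr


# Same order as fully_substitute(): only sigma1 and sigma2 remain
good = substitute_in_order(toy_description, ["variables", "sigma3", "parameter_defaults"])
# Parameters first: the masses introduced by the sigma3 step stay symbolic
bad = substitute_in_order(toy_description, ["parameter_defaults", "variables", "sigma3"])
print(good.free_symbols, bad.free_symbols)
```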

7.7.2.1. Compilation

The resulting symbolic expression depends on two variables:

  • \(\sigma_1 = m_{K\pi}^2\), the squared invariant mass of the \(K^- \pi^+\) system, and

  • \(\sigma_2 = m_{pK}^2\), the squared invariant mass of the \(p K^-\) system.

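This expression can be turned into a numerical function either with SymPy's lambdify() or, using JAX as a computational backend, with tensorwaves.

The function generated by lambdify() takes the variables as positional arguments, in the order in which the symbols are passed to it: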

s12 = sp.symbols("sigma1:3", nonnegative=True)
assert intensity_on_2vars.free_symbols == set(s12)

func = sp.lambdify(s12, intensity_on_2vars)
func(1.0, 3.0), func(1.1, 3.2)
(np.float64(1156.5307379422927), np.float64(636.1087670531597))
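Because the generated function is positional, its argument order is fixed by the order of the symbols passed to lambdify(), not by their names. A small sketch with a toy expression (not the intensity) illustrates this and also shows that the same function evaluates NumPy arrays element-wise:

```python
import numpy as np
import sympy as sp

s1, s2 = sp.symbols("sigma1:3", nonnegative=True)
toy_expr = s1 - 2 * s2  # toy expression, not the intensity

toy_func = sp.lambdify((s1, s2), toy_expr)     # arguments: sigma1, sigma2
toy_swapped = sp.lambdify((s2, s1), toy_expr)  # arguments: sigma2, sigma1

scalar = toy_func(1.0, 3.0)  # -5.0
reversed_order = toy_swapped(1.0, 3.0)  # 1.0, because now sigma2 = 1.0 and sigma1 = 3.0
vectorized = toy_func(np.array([1.0, 1.1]), np.array([3.0, 3.2]))  # element-wise
```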

The compilation to JAX is facilitated by tensorwaves:

density = create_function(intensity_on_2vars, backend="jax")
density({"sigma1": jnp.array([1.0, 1.1]), "sigma2": jnp.array([3.0, 3.2])})
Array([1156.53073794,  636.10876705], dtype=float64)
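The function created with tensorwaves instead takes a dict of arrays keyed by symbol name, so the argument order no longer matters. The following is only a sketch of that idea on top of plain lambdify() with NumPy, not how tensorwaves implements it; make_datasample_function() is a hypothetical helper:

```python
import numpy as np
import sympy as sp


def make_datasample_function(expr: sp.Expr, symbols: list):
    """Hypothetical helper: wrap lambdify() so that the result accepts a dict of arrays."""
    names = [symbol.name for symbol in symbols]
    positional_func = sp.lambdify(symbols, expr, modules="numpy")

    def evaluate(data: dict):
        # Look up each input array by symbol name, in the order lambdify() expects
        return positional_func(*(data[name] for name in names))

    return evaluate


s1, s2 = sp.symbols("sigma1:3", nonnegative=True)
toy_density = make_datasample_function(s1 - 2 * s2, [s1, s2])
# The order of the dict keys does not matter
toy_values = toy_density({"sigma2": np.array([3.0, 3.2]), "sigma1": np.array([1.0, 1.1])})
```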

7.7.3. Serialization with srepr

SymPy expressions can also be serialized directly to Python code with the function srepr(). For the full intensity expression, we can do so with:

%%time
eval_str = sp.srepr(unfolded_intensity_expr)
CPU times: user 2.14 s, sys: 4 ms, total: 2.15 s
Wall time: 2.15 s
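srepr() writes each node of the expression tree as a call to its class constructor, with the node's arguments (and any symbol assumptions) nested inside. A small sketch on a toy expression (the symbol a is hypothetical); it also counts the tree nodes with a short recursive helper, count_nodes(), since sp.count_ops() counts operations rather than every node:

```python
import sympy as sp

s1 = sp.Symbol("sigma1", nonnegative=True)
a = sp.Symbol("a")  # hypothetical parameter
toy_expr = a * s1**2

toy_str = sp.srepr(toy_expr)
print(toy_str)  # nested constructor calls such as Mul(...), Pow(...), Symbol(...)


def count_nodes(expr: sp.Basic) -> int:
    """Count every node of the expression tree, including atoms such as Integer(2)."""
    return 1 + sum(count_nodes(arg) for arg in expr.args)


n_tree_nodes = count_nodes(toy_expr)
n_operations = sp.count_ops(toy_expr)  # counts operators, not tree nodes
```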
n_nodes = sp.count_ops(unfolded_intensity_expr)
byt = len(eval_str.encode("utf-8"))
mb = f"{1e-6 * byt:.2f}"
rendering = shorten(eval_str, placeholder=" ...", width=85)
src = f"""
This serializes the intensity expression of {n_nodes:,d} nodes
to a string of **{mb} MB**.

```
{rendering}
```
"""
Markdown(src)

This serializes the intensity expression of 43,198 nodes to a string of 1.05 MB.

Add(Pow(Abs(Add(Mul(Add(Mul(Integer(-1), Pow(Add(Mul(Integer(-1), I, ...

It is up to the user, however, to import the class of every node type in the exported string before it can be evaluated with eval() (see this comment).

eval(eval_str)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[16], line 1
----> 1 eval(eval_str)

File <string>:1

NameError: name 'Add' is not defined

In the case of this intensity expression, it is sufficient to import all definitions from the main sympy module, plus the Str class.

from sympy import *  # noqa: F403
from sympy.core.symbol import Str  # noqa: F401
%%time
eval_imported_intensity_expr = eval(eval_str)
CPU times: user 2.81 s, sys: 51 ms, total: 2.87 s
Wall time: 2.86 s

Notice how the imported expression is exactly the same as the serialized one, including assumptions:

assert eval_imported_intensity_expr == unfolded_intensity_expr
assert hash(eval_imported_intensity_expr) == hash(unfolded_intensity_expr)
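Instead of a star import, the names needed by eval() can also be passed explicitly as its globals dict, which keeps the notebook namespace clean. A minimal sketch on a toy expression (not the intensity), assuming that every class appearing in the string is either available in the top-level sympy module or is Str, which is imported separately as in the cell above:

```python
import sympy as sp
from sympy.core.symbol import Str

s1 = sp.Symbol("sigma1", nonnegative=True)
toy_expr = sp.Abs(sp.Symbol("a") - s1) ** 2 + 1
toy_str = sp.srepr(toy_expr)

# Globals for eval(): a copy of the sympy module namespace plus Str
namespace = dict(vars(sp))
namespace["Str"] = Str
restored_expr = eval(toy_str, namespace)  # only eval strings from trusted sources
```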

Optionally, the import statements can be embedded into the string. The parsing is then done with exec() instead:

exec_str = f"""\
from sympy import *
from sympy.core.symbol import Str

def get_intensity_function() -> Expr:
    return {eval_str}
"""
exec_filename = "exported_intensity_model.py"
with open(exec_filename, "w") as f:
    f.write(exec_str)

See exported_intensity_model.py for the exported model.

%%time
exec(exec_str)
exec_imported_intensity_expr = get_intensity_function()  # noqa: F405
CPU times: user 369 ms, sys: 21 ms, total: 390 ms
Wall time: 389 ms
assert exec_imported_intensity_expr == unfolded_intensity_expr
assert hash(exec_imported_intensity_expr) == hash(unfolded_intensity_expr)

Note

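The exec() load time is shorter than the eval() time above because of caching within SymPy: many of the subexpressions were already constructed by the preceding eval() call.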
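To compare load times without help from the cache, SymPy's cache can be cleared with clear_cache() before timing. A minimal sketch that does so while loading an exported file as a regular module via importlib, as an alternative to calling exec() on its content; the toy expression and the file name exported_toy_model.py are hypothetical:

```python
import importlib.util
import tempfile
import time
from pathlib import Path

import sympy as sp
from sympy.core.cache import clear_cache

s1 = sp.Symbol("sigma1", nonnegative=True)
toy_expr = (sp.Symbol("a") * s1 - 1) ** 2  # toy stand-in for the intensity
module_source = f"""\
from sympy import *
from sympy.core.symbol import Str

def get_intensity_function() -> Expr:
    return {sp.srepr(toy_expr)}
"""

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "exported_toy_model.py"  # hypothetical file name
    path.write_text(module_source)
    clear_cache()  # empty SymPy's cache so that the timing does not profit from it
    start = time.perf_counter()
    spec = importlib.util.spec_from_file_location("exported_toy_model", str(path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file as an import would
    loaded_expr = module.get_intensity_function()
    elapsed = time.perf_counter() - start

print(f"Loaded in {elapsed:.4f} s")
```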

7.7.4. Python package


As noted on the main page, the source code for this analysis is available as a Python package on PyPI and can be installed as follows.

pip install polarimetry-lc2pkpi

Each of the published models can then be loaded as follows:

import polarimetry

model = polarimetry.published_model("Default amplitude model")
model.intensity.cleanup()
\[\displaystyle \sum_{\lambda_{0}=-1/2}^{1/2} \sum_{\lambda_{1}=-1/2}^{1/2}{\left|{\sum_{\lambda_0^{\prime}=-1/2}^{1/2} \sum_{\lambda_1^{\prime}=-1/2}^{1/2}{A^{1}_{\lambda_0^{\prime}, \lambda_1^{\prime}, 0, 0} d^{\frac{1}{2}}_{\lambda_1^{\prime},\lambda_{1}}\left(\zeta^1_{1(1)}\right) d^{\frac{1}{2}}_{\lambda_{0},\lambda_0^{\prime}}\left(\zeta^0_{1(1)}\right) + A^{2}_{\lambda_0^{\prime}, \lambda_1^{\prime}, 0, 0} d^{\frac{1}{2}}_{\lambda_1^{\prime},\lambda_{1}}\left(\zeta^1_{2(1)}\right) d^{\frac{1}{2}}_{\lambda_{0},\lambda_0^{\prime}}\left(\zeta^0_{2(1)}\right) + A^{3}_{\lambda_0^{\prime}, \lambda_1^{\prime}, 0, 0} d^{\frac{1}{2}}_{\lambda_1^{\prime},\lambda_{1}}\left(\zeta^1_{3(1)}\right) d^{\frac{1}{2}}_{\lambda_{0},\lambda_0^{\prime}}\left(\zeta^0_{3(1)}\right)}}\right|^{2}}\]

The expressions have to be converted to a numerical function to evaluate them over larger data samples. There are several ways of doing this (such as algebraically substituting the parameter values first), and which one is best depends on your application. Here’s a small example where we evaluate the model over a set of data points on the Dalitz plane. We first ‘unfold’ the main intensity expression and then lambdify it to a numerical function with JAX as the computational backend.

from ampform.sympy import perform_cached_doit
from ampform_dpd.io import perform_cached_lambdify

intensity_expr = perform_cached_doit(model.full_expression)
intensity_func = perform_cached_lambdify(
    intensity_expr.doit(),
    parameters=model.parameter_defaults,
    backend="jax",
)
Cached expression file /home/runner/.cache/ampform/sympy-v1.13.3/49b64a9082d201b24754e806e678a7e8.pkl not found, performing doit()...
Cached function file /home/runner/.cache/ampform_dpd/jax-v0.4.38/33cfe95d184dcc0cc91f96b1b3f3f47e.pkl not found, lambdifying...
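Two of the strategies mentioned above can be compared on a toy expression (the parameters a and b and their values are hypothetical, and perform_cached_lambdify() is not used here). Substituting the parameter values first yields a smaller expression, while keeping the parameters symbolic allows changing their values, for instance during a fit, without lambdifying again:

```python
import numpy as np
import sympy as sp

s1, s2 = sp.symbols("sigma1:3", nonnegative=True)
a, b = sp.symbols("a b")  # hypothetical parameters
toy_expr = a * s1 + b * s2**2
toy_defaults = {a: 2.0, b: 0.5}

# Strategy 1: insert the parameter values algebraically, lambdify only the variables
fixed_func = sp.lambdify((s1, s2), toy_expr.xreplace(toy_defaults))

# Strategy 2: keep the parameters symbolic and pass them as extra positional arguments
param_func = sp.lambdify((s1, s2, a, b), toy_expr)

x1 = np.array([1.0, 1.1])
x2 = np.array([3.0, 3.2])
y_fixed = fixed_func(x1, x2)
y_param = param_func(x1, x2, toy_defaults[a], toy_defaults[b])
y_new = param_func(x1, x2, 1.0, 0.5)  # other parameter values, same generated function
```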

Now, suppose we have a data sample of phase-space points generated over the Dalitz plane.

import urllib.request
from io import BytesIO
from zipfile import ZipFile

url = "https://github.com/ComPWA/polarimetry/files/14123343/lc2pkpi_phsp.zip"
with urllib.request.urlopen(url) as response:
    zipfile = ZipFile(BytesIO(response.read()))
zipfile.extractall()
import uproot

with uproot.open("lc2pkpi_phsp.root") as root_file:
    events = root_file["Lc2ppiK"]
    df = events.arrays(library="pd")
df
   msq_piK    msq_Kp
0  1.303328  3.505217
1  0.424415  4.382733
2  1.694020  2.827971
3  1.260368  3.106184
4  1.171106  3.424632
5  1.556747  3.210350
6  1.615608  2.484730
7  0.649837  3.523832

Here, we have data points for the two Mandelstam variables \(\sigma_1\) and \(\sigma_2\). The invariants attribute of the amplitude model provides symbolic expressions for computing the third Mandelstam variable, \(\sigma_3\), from the other two. In combination with the variables and parameter_defaults attributes, we can create a data transformer that computes the helicity angles and DPD alignment angles.

from ampform.io import aslatex
from IPython.display import Math

Math(aslatex(model.invariants))
\[\begin{split}\displaystyle \begin{array}{rcl} \sigma_{1} &=& m_{0}^{2} + m_{1}^{2} + m_{2}^{2} + m_{3}^{2} - \sigma_{2} - \sigma_{3} \\ \sigma_{2} &=& m_{0}^{2} + m_{1}^{2} + m_{2}^{2} + m_{3}^{2} - \sigma_{1} - \sigma_{3} \\ \sigma_{3} &=& m_{0}^{2} + m_{1}^{2} + m_{2}^{2} + m_{3}^{2} - \sigma_{1} - \sigma_{2} \\ \end{array}\end{split}\]
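Since the three invariants satisfy \(\sigma_1+\sigma_2+\sigma_3 = m_0^2+m_1^2+m_2^2+m_3^2\), the third one follows from the two columns of the data sample. A small sketch of that arithmetic on the first three rows of the table above; the masses are approximate values in GeV, and the assignment 1 = p, 2 = π⁺, 3 = K⁻ is an assumption inferred from \(\sigma_1 = m_{K\pi}^2\) and \(\sigma_2 = m_{pK}^2\):

```python
import numpy as np
import sympy as sp

s1, s2 = sp.symbols("sigma1:3", nonnegative=True)
m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True)
s3_expr = m0**2 + m1**2 + m2**2 + m3**2 - s1 - s2  # same relation as sigma3via12()

# Approximate masses in GeV (assumed ordering: 0 = Λc+, 1 = p, 2 = π+, 3 = K−)
masses = {m0: 2.28646, m1: 0.938272, m2: 0.13957, m3: 0.493677}
s3_func = sp.lambdify((s1, s2), s3_expr.xreplace(masses))

# First three rows of the msq_piK (sigma1) and msq_Kp (sigma2) columns shown above
sigma1 = np.array([1.303328, 0.424415, 1.694020])
sigma2 = np.array([3.505217, 4.382733, 2.827971])
sigma3 = s3_func(sigma1, sigma2)
```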
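import sympy as sp
from tensorwaves.data.transform import SympyDataTransformer

# The last entry of model.invariants expresses sigma3 in terms of sigma1 and sigma2
*_, (s3, s3_expr) = model.invariants.items()
definitions = dict(model.variables)  # copy, so that model.variables is not modified
definitions[s3] = s3_expr
# Substitute the definitions into each other and insert the parameter defaults
definitions = {
    symbol: expr.xreplace(definitions).xreplace(model.parameter_defaults)
    for symbol, expr in definitions.items()
}
data_transformer = SympyDataTransformer.from_sympy(definitions, backend="jax")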
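A data transformer of this kind maps a dict of input arrays to new arrays, one per symbol definition. The following is a minimal sketch of that idea with lambdify() and NumPy, not the implementation of SympyDataTransformer; the class name MinimalDataTransformer, the ratio variable r, and the value 6.0 (standing in for the sum of squared masses) are hypothetical:

```python
import numpy as np
import sympy as sp


class MinimalDataTransformer:
    """Hypothetical sketch: compute new columns from symbol definitions."""

    def __init__(self, definitions: dict, input_symbols: list) -> None:
        self._input_names = [symbol.name for symbol in input_symbols]
        # One positional NumPy function per defined symbol, keyed by the symbol name
        self._functions = {
            str(symbol): sp.lambdify(input_symbols, expr, modules="numpy")
            for symbol, expr in definitions.items()
        }

    def __call__(self, data: dict) -> dict:
        args = [data[name] for name in self._input_names]
        # Note: a definition that does not depend on the inputs returns a scalar
        return {name: func(*args) for name, func in self._functions.items()}


s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
r = sp.Symbol("r")  # hypothetical derived variable
toy_definitions = {
    s3: 6.0 - s1 - s2,  # 6.0 stands in for the sum of squared masses
    r: s1 / (s1 + s2),
}
transformer = MinimalDataTransformer(toy_definitions, [s1, s2])
toy_data = {"sigma1": np.array([1.0, 1.5]), "sigma2": np.array([3.0, 3.5])}
completed = transformer(toy_data)
completed.update(toy_data)  # keep the input columns next to the computed ones
```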

Finally, we create an input DataSample (a dict of arrays keyed by variable name) that can be fed to the numerical function of the amplitude model.

data = {
    "sigma1": df["msq_piK"].to_numpy(),
    "sigma2": df["msq_Kp"].to_numpy(),
}
completed_data = data_transformer(data)
completed_data.update(data)
intensity_func(completed_data)
Array([2158.16752493, 5447.08091414, 6232.40502307,  681.74714874,
        902.7663634 , 5626.31808739, 7225.75775956, 2247.93599152],      dtype=float64)