Model serialization#
This notebook illustrates the use of the ampform_dpd.io.serialization
module, which can be used to build amplitude models from the amplitude-serialization initiative.
Warning
The ampform_dpd.io.serialization
module is a preview feature.
Import model#
Show code cell source
from __future__ import annotations
import json
import logging
import os
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import sympy as sp
from ampform.dynamics.form_factor import BlattWeisskopfSquared, FormFactor
from ampform.dynamics.phasespace import BreakupMomentumSquared
from ampform.kinematics.phasespace import Kallen
from ampform.sympy import perform_cached_doit
from IPython.display import JSON, Math
from tqdm.auto import tqdm
from ampform_dpd import DefinedExpression
from ampform_dpd.decay import FinalStateID, State, ThreeBodyDecay
from ampform_dpd.dynamics import (
BreitWigner,
BuggBreitWigner,
ChannelArguments,
EnergyDependentWidth,
MultichannelBreitWigner,
P,
SimpleBreitWigner,
)
from ampform_dpd.io import aslatex, perform_cached_lambdify, simplify_latex_rendering
from ampform_dpd.io.serialization.amplitude import (
HelicityRecoupling,
LSRecoupling,
ParityRecoupling,
formulate,
formulate_aligned_amplitude,
formulate_chain_amplitude,
formulate_recoupling,
)
from ampform_dpd.io.serialization.decay import get_final_state, to_decay
from ampform_dpd.io.serialization.dynamics import (
formulate_breit_wigner,
formulate_dynamics,
formulate_form_factor,
formulate_multichannel_breit_wigner,
to_mandelstam_symbol,
to_mass_symbol,
)
from ampform_dpd.io.serialization.format import (
ModelDefinition,
Propagator,
get_decay_chains,
get_function_definition,
)
# Only let messages of level ERROR and above through on the "ampform.sympy" logger.
logging.getLogger("ampform.sympy").setLevel(logging.ERROR)
# NOTE(review): presumably tweaks how expressions are rendered as LaTeX in the
# cells below — confirm against the ampform_dpd.io documentation.
simplify_latex_rendering()
Construct ThreeBodyDecay#
Name-to-LaTeX converter
def to_latex(name: str) -> str:
    """Convert a particle or resonance name into a LaTeX string.

    The names "Lc", "pi", "K" and "p" are translated through a fixed lookup
    table. Any other name is interpreted as a subsystem letter ("D", "K" or
    "L") followed by a mass label, optionally wrapped in parentheses, and is
    rendered as ``<subsystem>(<mass label>)``. Names that match neither
    pattern, including the empty string, are returned unchanged.

    Args:
        name: Name of the particle or resonance.

    Returns:
        The LaTeX representation of the name, or the name itself if it is
        not recognized.
    """
    known_particles = {
        "Lc": R"\Lambda_c^+",
        "pi": R"\pi^+",
        "K": "K^-",
        "p": "p",
    }
    if name in known_particles:
        return known_particles[name]
    if not name:
        # There is no subsystem letter to inspect; indexing name[0] would
        # raise IndexError, so treat it like any other unrecognized name.
        return name
    subsystem = {"D": "D", "K": "K", "L": R"\Lambda"}.get(name[0])
    if subsystem is None:
        return name
    # Drop the parentheses around the mass label; they are re-added below.
    mass_str = name[1:].strip("(").strip(")")
    return f"{subsystem}({mass_str})"
Dynamics#
See also
# Decay chain definitions of the model; the cells below select chains by index.
# NOTE(review): MODEL_DEFINITION is not defined in the visible cells — it is
# expected to be loaded by an earlier cell.
CHAIN_DEFS = get_decay_chains(MODEL_DEFINITION)
Vertices#
Blatt-Weisskopf form factor#
Show code cell source
# Generic symbols, used only to display the symbolic definitions below.
z = sp.Symbol("z", nonnegative=True)
# Note that the Python name `d` is bound to the symbol named "R".
s, m1, m2, L, d = sp.symbols("s m1 m2 L R", nonnegative=True)
exprs = [
    FormFactor(s, m1, m2, L, d),
    BlattWeisskopfSquared(z, L),
    BreakupMomentumSquared(s, m1, m2),
]
# Pair each expression with its definition; deep=False evaluates only the
# outermost expression and leaves its arguments as they are.
Math(aslatex({e: e.doit(deep=False) for e in exprs}))
# Form factor for the first vertex of the chain at index 2.
# NOTE(review): the variable name suggests this chain is the L(1520) —
# confirm against the model file.
ff_L1520 = formulate_form_factor(
    vertex=CHAIN_DEFS[2]["vertices"][0],
    model=MODEL_DEFINITION,
)
Math(aslatex(ff_L1520))
Propagators#
Breit-Wigner#
Show code cell source
# Generic symbols, used only to display the symbolic definitions below.
# "x:z" is SymPy range syntax and creates the three symbols x, y and z.
x, y, z = sp.symbols("x:z")
# Note that the Python name `d` is bound to the symbol named "R".
s, m0, Γ0, m1, m2, L, d = sp.symbols("s m0 Gamma0 m1 m2 L R", nonnegative=True)
exprs = [
    BreitWigner(s, m0, Γ0, m1, m2, L, d),
    SimpleBreitWigner(s, m0, Γ0),
    EnergyDependentWidth(s, m0, Γ0, m1, m2, L, d),
    FormFactor(s, m1, m2, L, d),
    P(s, m1, m2),
    Kallen(x, y, z),
]
# Pair each expression with its definition; deep=False evaluates only the
# outermost expression and leaves its arguments as they are.
Math(aslatex({e: e.doit(deep=False) for e in exprs}))
# Breit-Wigner for the first propagator of the chain at index 20; the chain
# name is converted to LaTeX for use as the resonance label.
# NOTE(review): the variable name suggests this chain is the K(892) —
# confirm against the model file.
K892_BW = formulate_breit_wigner(
    propagator=CHAIN_DEFS[20]["propagators"][0],
    resonance=to_latex(CHAIN_DEFS[20]["name"]),
    model=MODEL_DEFINITION,
)
Math(aslatex(K892_BW))
Multi-channel Breit-Wigner#
Show code cell source
# Generic symbols, used only to display the symbolic definitions below.
# "x:z" is SymPy range syntax and creates the three symbols x, y and z.
x, y, z = sp.symbols("x:z")
# Note that the Python name `d` is bound to the symbol named "R".
s, m0, Γ0, m1, m2, L, d = sp.symbols("s m0 Gamma0 m1 m2 L R", nonnegative=True)
# One ChannelArguments per channel index (1 and 2). Each channel gets its own
# width, decay-product masses and angular momentum symbol, while s, m0 and the
# meson radius symbol d are shared between the channels.
channels = tuple(
    ChannelArguments(
        s,
        m0,
        width=sp.Symbol(f"Gamma{i}", nonnegative=True),
        m1=sp.Symbol(f"m_{{a,{i}}}", nonnegative=True),
        m2=sp.Symbol(f"m_{{b,{i}}}", nonnegative=True),
        angular_momentum=sp.Symbol(f"L{i}", integer=True, nonnegative=True),
        meson_radius=d,
    )
    for i in [1, 2]
)
exprs = [
    MultichannelBreitWigner(s, m0, channels),
    BreitWigner(s, m0, Γ0, m1, m2, L, d),
    # NOTE(review): the single-channel cell above lists SimpleBreitWigner for
    # the three-argument form — confirm that BreitWigner(s, m0, Γ0) is intended.
    BreitWigner(s, m0, Γ0),
    EnergyDependentWidth(s, m0, Γ0, m1, m2, L, d),
    FormFactor(s, m1, m2, L, d),
    P(s, m1, m2),
    Kallen(x, y, z),
]
# Pair each expression with its definition; deep=False evaluates only the
# outermost expression and leaves its arguments as they are.
Math(aslatex({e: e.doit(deep=False) for e in exprs}))
# Multichannel Breit-Wigner for the first propagator of the chain at index 0;
# the chain name is converted to LaTeX for use as the resonance label.
# NOTE(review): the variable name suggests this chain is the L(1405) with a
# Flatté-like lineshape — confirm against the model file.
L1405_Flatte = formulate_multichannel_breit_wigner(
    propagator=CHAIN_DEFS[0]["propagators"][0],
    resonance=to_latex(CHAIN_DEFS[0]["name"]),
    model=MODEL_DEFINITION,
)
Math(aslatex(L1405_Flatte))
Breit-Wigner with exponential#
The model contains one lineshape function that is not standard, so we have to implement a custom propagator dynamics builder for this.
Show code cell source
# Generic symbols, used only to display the symbolic definition below.
s, m0, Γ0, m1, m2, γ = sp.symbols("s m0 Gamma0 m1 m2 gamma", nonnegative=True)
expr = BuggBreitWigner(s, m0, Γ0, m1, m2, γ)
# Pair the expression with its definition; deep=False evaluates only the
# outermost expression and leaves its arguments as they are.
Math(aslatex({expr: expr.doit(deep=False)}))
# Inspect the chain at index 18, which formulate_dynamics is applied to
# further below together with the custom Bugg builder.
CHAIN_DEFS[18]
# Inspect the function definition registered in the model under this name.
get_function_definition("K700_BuggBW", MODEL_DEFINITION)
def formulate_bugg_breit_wigner(
    propagator: Propagator, resonance: str, model: ModelDefinition
) -> DefinedExpression:
    """Formulate a `BuggBreitWigner` lineshape for one propagator of the model.

    Args:
        propagator: Propagator entry of a decay chain. Its "parametrization"
            entry selects the function definition in the model and its "node"
            entry holds the pair of final-state IDs of the subsystem.
        resonance: Label used in the subscript of the mass, width and slope
            symbols.
        model: Model definition from which the function definition and the
            final state are read.

    Returns:
        A `DefinedExpression` holding the `BuggBreitWigner` expression and the
        values for each of the symbols introduced here.
    """
    lineshape = get_function_definition(propagator["parametrization"], model)
    node = propagator["node"]
    i, j = node

    # Symbols that appear in the lineshape expression
    mandelstam = to_mandelstam_symbol(node)
    m_res = sp.Symbol(f"m_{{{resonance}}}", nonnegative=True)
    Γ_res = sp.Symbol(Rf"\Gamma_{{{resonance}}}", nonnegative=True)
    slope = sp.Symbol(Rf"\gamma_{{{resonance}}}", nonnegative=True)
    m_i = to_mass_symbol(i)
    m_j = to_mass_symbol(j)

    # Values for those symbols: resonance parameters come from the function
    # definition, decay-product masses from the final state of the model.
    final_state = get_final_state(model)
    symbol_values = {
        m_res: lineshape["mass"],
        Γ_res: lineshape["width"],
        m_i: final_state[i].mass,
        m_j: final_state[j].mass,
        slope: lineshape["slope"],
    }
    return DefinedExpression(
        expression=BuggBreitWigner(mandelstam, m_res, Γ_res, m_i, m_j, slope),
        definitions=symbol_values,
    )
General propagator dynamics builder#
# Custom propagator builders, passed to formulate_dynamics() and formulate()
# in the cells below.
# NOTE(review): the key presumably has to match the "type" of the function
# definition in the model file — confirm against the serialization format.
DYNAMICS_BUILDERS = {
    "BreitWignerWidthExpLikeBugg": formulate_bugg_breit_wigner,
}
Show code cell source
# Formulate the dynamics for the three chains used in the sections above
# (indices 0, 18 and 20), now through the general builder with the custom
# builders registered in DYNAMICS_BUILDERS.
exprs = [
    formulate_dynamics(CHAIN_DEFS[0], MODEL_DEFINITION, to_latex, DYNAMICS_BUILDERS),
    formulate_dynamics(CHAIN_DEFS[18], MODEL_DEFINITION, to_latex, DYNAMICS_BUILDERS),
    formulate_dynamics(CHAIN_DEFS[20], MODEL_DEFINITION, to_latex, DYNAMICS_BUILDERS),
]
Math(aslatex(exprs))
Construct AmplitudeModel#
Unpolarized intensity#
# "lambda(:4)" is SymPy range syntax: it creates lambda0, lambda1, lambda2, lambda3.
λ0, λ1, λ2, λ3 = sp.symbols("lambda(:4)", rational=True)
# Only the first element of the returned pair is used here.
amplitude_expr, _ = formulate_aligned_amplitude(MODEL_DEFINITION, λ0, λ1, λ2, λ3)
# Last expression of the cell, so its value is what the notebook displays.
amplitude_expr.cleanup()
Amplitude for the decay chain#
Helicity recouplings#
Show code cell source
# Generic symbols, used only to display the symbolic definitions below.
λa = sp.Symbol(R"\lambda_a", rational=True)
λb = sp.Symbol(R"\lambda_b", rational=True)
λa0 = sp.Symbol(R"\lambda_a^0", rational=True)
λb0 = sp.Symbol(R"\lambda_b^0", rational=True)
f = sp.Symbol("f", integer=True)
l = sp.Symbol("l", integer=True, nonnegative=True)
# Note: this rebinds the notebook-level names `s` and `j` to new symbols.
s = sp.Symbol("s", nonnegative=True, rational=True)
ja = sp.Symbol("j_a", nonnegative=True, rational=True)
jb = sp.Symbol("j_b", nonnegative=True, rational=True)
j = sp.Symbol("j", nonnegative=True, rational=True)
exprs = [
    HelicityRecoupling(λa, λb, λa0, λb0),
    ParityRecoupling(λa, λb, λa0, λb0, f),
    LSRecoupling(λa, λb, l, s, ja, jb, j),
]
# Pair each expression with its definition; deep=False evaluates only the
# outermost expression and leaves its arguments as they are.
Math(aslatex({e: e.doit(deep=False) for e in exprs}))
Recoupling deserialization#
Show code cell source
# Recoupling expressions for the vertices at index 0 and 1 of the chain at index 0.
recouplings = [
    formulate_recoupling(MODEL_DEFINITION, chain_idx=0, vertex_idx=i) for i in range(2)
]
# Pair each expression with its definition; deep=False evaluates only the
# outermost expression and leaves its arguments as they are.
Math(aslatex({e: e.doit(deep=False) for e in recouplings}))
Chain amplitudes#
# Amplitude of the chain at index 0, formulated with the helicity symbols λ0…λ3.
definitions = formulate_chain_amplitude(λ0, λ1, λ2, λ3, MODEL_DEFINITION, chain_idx=0)
Math(aslatex(definitions))
Full amplitude model#
# Build the full amplitude model, using the custom Bugg Breit-Wigner builder
# for the non-standard lineshape and to_latex() for the resonance labels.
MODEL = formulate(
    MODEL_DEFINITION,
    additional_builders=DYNAMICS_BUILDERS,
    cleanup_summations=True,
    to_latex=to_latex,
)
# Last expression of the cell, so its value is what the notebook displays.
MODEL.intensity
Show code cell source
# Render all amplitudes only when the EXECUTE_NB environment variable is set;
# otherwise restrict the rendering to the first two amplitudes of the model.
if "EXECUTE_NB" not in os.environ:
    selected_amplitudes = {
        symbol: amplitude
        for position, (symbol, amplitude) in enumerate(MODEL.amplitudes.items())
        if position < 2
    }
else:
    selected_amplitudes = MODEL.amplitudes
Math(aslatex(selected_amplitudes, terms_per_line=1))
Numeric results#
# Substitute the model's variable definitions into the full expression first,
# then replace the parameters by their default values.
intensity_expr = MODEL.full_expression.xreplace(MODEL.variables)
intensity_expr = intensity_expr.xreplace(MODEL.parameter_defaults)
Lambdify to numeric function
# One numerical intensity function per invariant, keyed by the integer parsed
# from the invariant's symbol name.
intensity_funcs = {}
for s, s_expr in tqdm(MODEL.invariants.items()):
    # The key is the last character of the symbol name, so this only works
    # for names that end in a single digit.
    k = int(str(s)[-1])
    # Insert the mass values into the expression for this invariant.
    s_expr = s_expr.xreplace(MODEL.masses).doit()
    # Replace this invariant by its expression, so that it no longer appears
    # as a free symbol of the intensity.
    expr = perform_cached_doit(intensity_expr.xreplace({s: s_expr}))
    func = perform_cached_lambdify(expr, backend="jax")
    # Exactly two arguments are expected to remain after the substitution.
    assert len(func.argument_order) == 2, func.argument_order
    intensity_funcs[k] = func
# NOTE(review): the loop variables (s, s_expr, k, expr, func) stay defined at
# notebook level after the loop; later cells appear to rely on `k`.
Validation#
Error
The following serves as a numerical check on whether the amplitude model has been deserialized correctly. For now, this is not the case; see ComPWA/ampform-dpd#133 for updates.
# For every entry of the "misc" section whose key contains "checksum", build
# a mapping from checksum name to checksum value.
checksums = {
    misc_key: {checksum["name"]: checksum["value"] for checksum in misc_value}
    for misc_key, misc_value in MODEL_DEFINITION["misc"].items()
    if "checksum" in misc_key
}
checksums
# For every parameter point of the model, build a mapping from parameter name
# to parameter value, keyed by the name of the point.
checksum_points = {
    point["name"]: {par["name"]: par["value"] for par in point["parameters"]}
    for point in MODEL_DEFINITION["parameter_points"]
}
checksum_points
Show code cell source
# Collect one table row per (distribution, validation point) combination.
array = []
for distribution_name, checksum in checksums.items():
    for point_name, expected in checksum.items():
        parameters = checksum_points[point_name]
        # NOTE(review): confirm that "m_31_2" and "m_31" are the invariant
        # masses that belong to sigma1 and sigma2, respectively.
        s1 = parameters["m_31_2"] ** 2
        s2 = parameters["m_31"] ** 2
        # The function stored under key 3 takes sigma1 and sigma2 as arguments.
        computed = intensity_funcs[3]({"sigma1": s1, "sigma2": s2})
        # NOTE(review): this is an exact equality check. If `expected` is a
        # floating-point number, a comparison with a tolerance would be more
        # robust — the type of `expected` is not visible from here.
        status = "🟢" if computed == expected else "🔴"
        array.append((distribution_name, point_name, computed, expected, status))
pd.DataFrame(array, columns=["Distribution", "Point", "Computed", "Expected", "Status"])
Dalitz plot#
Define meshgrid for Dalitz plot
# Number of grid points along each axis of the Dalitz plot.
resolution = 1_000
# Mass symbols of the model, ordered by their string representation.
# NOTE(review): the code below assumes m[0] is the mass of the decaying
# particle and m[1..3] those of the final state — confirm.
m = sorted(MODEL.masses, key=str)
# NOTE(review): i, j and k must be integer subsystem IDs here. They are not
# assigned in the visible cells at this point (`j` was last bound to a SymPy
# symbol and `k` is left over from the lambdify loop) — make sure an earlier
# cell defines them.
# Axis ranges: (sum of two final-state masses)² up to (m[0] − third mass)².
x_min = float(((m[j] + m[k]) ** 2).xreplace(MODEL.masses))
x_max = float(((m[0] - m[i]) ** 2).xreplace(MODEL.masses))
y_min = float(((m[i] + m[k]) ** 2).xreplace(MODEL.masses))
y_max = float(((m[0] - m[j]) ** 2).xreplace(MODEL.masses))
# Widen both ranges by 5% of their original width on either side.
x_diff = x_max - x_min
y_diff = y_max - y_min
x_min -= 0.05 * x_diff
x_max += 0.05 * x_diff
y_min -= 0.05 * y_diff
y_max += 0.05 * y_diff
X, Y = jnp.meshgrid(
    jnp.linspace(x_min, x_max, num=resolution),
    jnp.linspace(y_min, y_max, num=resolution),
)
# sigma_i runs along the x-axis and sigma_j along the y-axis.
dalitz_data = {
    f"sigma{i}": X,
    f"sigma{j}": Y,
}
# Use the intensity function stored under key k to evaluate the grid.
intensities = intensity_funcs[k](dalitz_data)
Dalitz plot is not yet correct
def get_decay_products(
    decay: ThreeBodyDecay, subsystem_id: FinalStateID
) -> tuple[State, State]:
    """Return the final-state particles whose index is not the subsystem ID.

    Args:
        decay: Decay whose ``final_state`` mapping is inspected.
        subsystem_id: ID of the subsystem; it has to be a key of
            ``decay.final_state``.

    Returns:
        The states in ``decay.final_state`` whose ``index`` differs from
        ``subsystem_id``, in the iteration order of that mapping.

    Raises:
        ValueError: If ``subsystem_id`` is not a key of ``decay.final_state``.
    """
    if subsystem_id in decay.final_state:
        return tuple(
            state
            for state in decay.final_state.values()
            if state.index != subsystem_id
        )
    msg = f"Subsystem ID {subsystem_id} is not a valid final state ID"
    raise ValueError(msg)
plt.rc("font", size=18)
# Normalize to the sum over all grid points; nansum skips NaN entries.
I_tot = jnp.nansum(intensities)
normalized_intensities = intensities / I_tot
fig, ax = plt.subplots(figsize=(14, 10))
mesh = ax.pcolormesh(X, Y, normalized_intensities)
ax.set_aspect("equal")
c_bar = plt.colorbar(mesh, ax=ax, pad=0.01)
c_bar.ax.set_ylabel("Normalized intensity (a.u.)")
# Axis label for each subsystem: sigma_i as the squared invariant mass of the
# two decay products that are returned by get_decay_products().
# NOTE(review): DECAY is not defined in the visible cells — it is expected to
# come from an earlier cell. The `i` inside the comprehension is local to it;
# the `i` and `j` used for the axis labels are the notebook-level ones.
sigma_labels = {
    i: Rf"$\sigma_{i} = M^2\left({' '.join(p.latex for p in get_decay_products(DECAY, i))}\right)$"
    for i in (1, 2, 3)
}
ax.set_xlabel(sigma_labels[i])
ax.set_ylabel(sigma_labels[j])
plt.show()