Model serialization#

This notebook illustrates the use of the ampform_dpd.io.serialization module, which can be used to build amplitude models from the amplitude-serialization initiative.

Warning

The ampform_dpd.io.serialization module is a preview feature.

Import model#

Hide code cell source
# Load the serialized model description. The encoding is pinned to UTF-8 (the
# JSON interchange encoding) so that decoding does not depend on the platform's
# default locale encoding.
with open("Lc2ppiK.json", encoding="utf-8") as stream:
    MODEL_DEFINITION = json.load(stream)

Construct ThreeBodyDecay#

Hide code cell source
def to_latex(name: str) -> str:
    """Translate a particle or resonance name into a LaTeX string.

    The names "Lc", "pi", "K" and "p" map to fixed LaTeX strings. Any other
    name is split into its first character (the subsystem letter) and the
    remainder (the mass label, with surrounding parentheses stripped). If the
    subsystem letter is one of "D", "K" or "L", the result is formatted as
    ``<subsystem>(<mass label>)``; otherwise the name is returned unchanged.

    Args:
        name: Name as it appears in the model definition.

    Returns:
        The LaTeX representation, or ``name`` itself if it is empty or cannot
        be interpreted.
    """
    latex = {
        "Lc": R"\Lambda_c^+",
        "pi": R"\pi^+",
        "K": "K^-",
        "p": "p",
    }.get(name)
    if latex is not None:
        return latex
    if not name:
        # Nothing to parse; indexing name[0] below would raise IndexError.
        return name
    mass_str = name[1:].strip("(").strip(")")
    subsystem_letter = name[0]
    subsystem = {"D": "D", "K": "K", "L": R"\Lambda"}.get(subsystem_letter)
    if subsystem is None:
        # Unknown subsystem letter: fall back to the raw name.
        return name
    return f"{subsystem}({mass_str})"
# Build the decay object from the model definition; to_latex is passed as the
# callback that formats names.
DECAY = to_decay(MODEL_DEFINITION, to_latex=to_latex)
# Notebook display of the decay.
# NOTE(review): with_jp=True presumably adds spin-parity labels — confirm.
Math(aslatex(DECAY, with_jp=True))

Dynamics#

# Decay-chain definitions taken from the model; the cells below pick individual
# chains by position (indices 0, 2, 18 and 20).
CHAIN_DEFS = get_decay_chains(MODEL_DEFINITION)

Vertices#

Blatt-Weisskopf form factor#

Hide code cell source
# Generic symbols used to display the form-factor definitions.
z = sp.Symbol("z", nonnegative=True)
# Note: the Python name `d` holds the symbol that is printed as "R".
s, m1, m2, L, d = sp.symbols("s m1 m2 L R", nonnegative=True)
exprs = [
    FormFactor(s, m1, m2, L, d),
    BlattWeisskopfSquared(z, L),
    BreakupMomentumSquared(s, m1, m2),
]
# Display each expression next to the result of its doit(deep=False).
Math(aslatex({e: e.doit(deep=False) for e in exprs}))
# Form factor for the first vertex of the chain at index 2.
# NOTE(review): the variable name suggests this chain is the L(1520) — confirm.
ff_L1520 = formulate_form_factor(
    vertex=CHAIN_DEFS[2]["vertices"][0],
    model=MODEL_DEFINITION,
)
Math(aslatex(ff_L1520))

Propagators#

Breit-Wigner#

Hide code cell source
# Generic symbols used to display the Breit-Wigner definitions.
x, y, z = sp.symbols("x:z")
# Note: the Python name `d` holds the symbol that is printed as "R".
s, m0, Γ0, m1, m2, L, d = sp.symbols("s m0 Gamma0 m1 m2 L R", nonnegative=True)
exprs = [
    BreitWigner(s, m0, Γ0, m1, m2, L, d),
    SimpleBreitWigner(s, m0, Γ0),
    EnergyDependentWidth(s, m0, Γ0, m1, m2, L, d),
    FormFactor(s, m1, m2, L, d),
    P(s, m1, m2),
    Kallen(x, y, z),
]
# Display each expression next to the result of its doit(deep=False).
Math(aslatex({e: e.doit(deep=False) for e in exprs}))
# Breit-Wigner for the first propagator of the chain at index 20; the chain
# name is converted with to_latex before being passed as the resonance label.
# NOTE(review): the variable name suggests this chain is the K(892) — confirm.
K892_BW = formulate_breit_wigner(
    propagator=CHAIN_DEFS[20]["propagators"][0],
    resonance=to_latex(CHAIN_DEFS[20]["name"]),
    model=MODEL_DEFINITION,
)
Math(aslatex(K892_BW))

Multi-channel Breit-Wigner#

Hide code cell source
# Generic symbols used to display the multi-channel Breit-Wigner definitions.
x, y, z = sp.symbols("x:z")
# Note: the Python name `d` holds the symbol that is printed as "R".
s, m0, Γ0, m1, m2, L, d = sp.symbols("s m0 Gamma0 m1 m2 L R", nonnegative=True)
# Two channels (i = 1, 2), each with its own width, decay-product masses and
# angular momentum symbols; all channels share the same meson radius symbol d.
channels = tuple(
    ChannelArguments(
        s,
        m0,
        width=sp.Symbol(f"Gamma{i}", nonnegative=True),
        m1=sp.Symbol(f"m_{{a,{i}}}", nonnegative=True),
        m2=sp.Symbol(f"m_{{b,{i}}}", nonnegative=True),
        angular_momentum=sp.Symbol(f"L{i}", integer=True, nonnegative=True),
        meson_radius=d,
    )
    for i in [1, 2]
)
exprs = [
    MultichannelBreitWigner(s, m0, channels),
    BreitWigner(s, m0, Γ0, m1, m2, L, d),
    BreitWigner(s, m0, Γ0),
    EnergyDependentWidth(s, m0, Γ0, m1, m2, L, d),
    FormFactor(s, m1, m2, L, d),
    P(s, m1, m2),
    Kallen(x, y, z),
]
# Display each expression next to the result of its doit(deep=False).
Math(aslatex({e: e.doit(deep=False) for e in exprs}))
# Multi-channel Breit-Wigner for the first propagator of the chain at index 0.
# NOTE(review): the variable name suggests this chain is the L(1405) — confirm.
L1405_Flatte = formulate_multichannel_breit_wigner(
    propagator=CHAIN_DEFS[0]["propagators"][0],
    resonance=to_latex(CHAIN_DEFS[0]["name"]),
    model=MODEL_DEFINITION,
)
Math(aslatex(L1405_Flatte))

Breit-Wigner with exponential#

The model contains one lineshape function that is not standard, so we have to implement a custom propagator dynamics builder for this.

Hide code cell source
# Generic symbols used to display the BuggBreitWigner definition.
s, m0, Γ0, m1, m2, γ = sp.symbols("s m0 Gamma0 m1 m2 gamma", nonnegative=True)
expr = BuggBreitWigner(s, m0, Γ0, m1, m2, γ)
# Display the expression next to the result of its doit(deep=False).
Math(aslatex({expr: expr.doit(deep=False)}))
# Inspect the chain at index 18 (the one used for the Bugg Breit-Wigner below).
CHAIN_DEFS[18]
# Inspect the function definition named "K700_BuggBW" in the model.
get_function_definition("K700_BuggBW", MODEL_DEFINITION)
def formulate_bugg_breit_wigner(
    propagator: Propagator, resonance: str, model: ModelDefinition
) -> DefinedExpression:
    """Formulate a BuggBreitWigner expression for one propagator.

    Args:
        propagator: Propagator entry; its "parametrization" and "node" items
            are read.
        resonance: Label inserted into the names of the mass, width and gamma
            symbols.
        model: Model definition from which the function definition and the
            final state are looked up.

    Returns:
        A DefinedExpression holding the BuggBreitWigner expression together
        with the values for its mass, width, decay-product mass and slope
        symbols.
    """
    parametrization = get_function_definition(propagator["parametrization"], model)
    node = propagator["node"]
    i, j = node

    # Symbols entering the lineshape.
    mandelstam = to_mandelstam_symbol(node)
    resonance_mass = sp.Symbol(f"m_{{{resonance}}}", nonnegative=True)
    resonance_width = sp.Symbol(Rf"\Gamma_{{{resonance}}}", nonnegative=True)
    slope = sp.Symbol(Rf"\gamma_{{{resonance}}}", nonnegative=True)
    mass_i = to_mass_symbol(i)
    mass_j = to_mass_symbol(j)

    # Values for those symbols, taken from the model definition.
    final_state = get_final_state(model)
    symbol_values = {
        resonance_mass: parametrization["mass"],
        resonance_width: parametrization["width"],
        mass_i: final_state[i].mass,
        mass_j: final_state[j].mass,
        slope: parametrization["slope"],
    }
    lineshape = BuggBreitWigner(
        mandelstam, resonance_mass, resonance_width, mass_i, mass_j, slope
    )
    return DefinedExpression(expression=lineshape, definitions=symbol_values)
# Apply the custom builder to the first propagator of the chain at index 18;
# the chain name is converted with to_latex to serve as the resonance label.
CHAIN_18 = CHAIN_DEFS[18]
K700_BuggBW = formulate_bugg_breit_wigner(
    propagator=CHAIN_18["propagators"][0],
    resonance=to_latex(CHAIN_18["name"]),
    model=MODEL_DEFINITION,
)
Math(aslatex(K700_BuggBW))

General propagator dynamics builder#

# Extra dynamics builders handed to formulate_dynamics() and formulate() below.
# NOTE(review): the key is presumably the function type name used in the model
# file — confirm against the serialization format.
DYNAMICS_BUILDERS = {
    "BreitWignerWidthExpLikeBugg": formulate_bugg_breit_wigner,
}
Hide code cell source
# Formulate the dynamics for the same three chains used in the cells above
# (indices 0, 18 and 20), in that order.
exprs = [
    formulate_dynamics(
        CHAIN_DEFS[chain_idx], MODEL_DEFINITION, to_latex, DYNAMICS_BUILDERS
    )
    for chain_idx in (0, 18, 20)
]
Math(aslatex(exprs))

Construct AmplitudeModel#

Unpolarized intensity#

# Four helicity symbols, created with sympy's range syntax "lambda(:4)".
λ0, λ1, λ2, λ3 = sp.symbols("lambda(:4)", rational=True)
# Only the first element of the returned pair is used here.
amplitude_expr, _ = formulate_aligned_amplitude(MODEL_DEFINITION, λ0, λ1, λ2, λ3)
# NOTE(review): presumably tidies up the summations in the expression for
# display — confirm against the documentation of cleanup().
amplitude_expr.cleanup()

Amplitude for the decay chain#

Helicity recouplings#

Hide code cell source
# Generic symbols used to display the three recoupling definitions.
λa = sp.Symbol(R"\lambda_a", rational=True)
λb = sp.Symbol(R"\lambda_b", rational=True)
λa0 = sp.Symbol(R"\lambda_a^0", rational=True)
λb0 = sp.Symbol(R"\lambda_b^0", rational=True)
f = sp.Symbol("f", integer=True)
l = sp.Symbol("l", integer=True, nonnegative=True)
s = sp.Symbol("s", nonnegative=True, rational=True)
ja = sp.Symbol("j_a", nonnegative=True, rational=True)
jb = sp.Symbol("j_b", nonnegative=True, rational=True)
j = sp.Symbol("j", nonnegative=True, rational=True)
exprs = [
    HelicityRecoupling(λa, λb, λa0, λb0),
    ParityRecoupling(λa, λb, λa0, λb0, f),
    LSRecoupling(λa, λb, l, s, ja, jb, j),
]
# Display each expression next to the result of its doit(deep=False).
Math(aslatex({e: e.doit(deep=False) for e in exprs}))

Recoupling deserialization#

Hide code cell source
# Recouplings for both vertices (indices 0 and 1) of the first decay chain.
recouplings = [
    formulate_recoupling(MODEL_DEFINITION, chain_idx=0, vertex_idx=vertex_idx)
    for vertex_idx in (0, 1)
]
# Pair every recoupling with the result of its doit(deep=False) for display.
evaluated_recouplings = {
    recoupling: recoupling.doit(deep=False) for recoupling in recouplings
}
Math(aslatex(evaluated_recouplings))

Chain amplitudes#

# Chain amplitude definitions for the decay chain at index 0, expressed in the
# helicity symbols λ0..λ3.
definitions = formulate_chain_amplitude(λ0, λ1, λ2, λ3, MODEL_DEFINITION, chain_idx=0)
Math(aslatex(definitions))

Full amplitude model#

# Build the complete amplitude model, passing DYNAMICS_BUILDERS as additional
# builders and to_latex as the name formatter.
MODEL = formulate(
    MODEL_DEFINITION,
    additional_builders=DYNAMICS_BUILDERS,
    cleanup_summations=True,
    to_latex=to_latex,
)
# Notebook display of the intensity expression.
MODEL.intensity
Hide code cell source
# Show every amplitude only when the EXECUTE_NB environment variable is set;
# otherwise restrict the display to the first two amplitudes.
show_all_amplitudes = "EXECUTE_NB" in os.environ
selected_amplitudes = (
    MODEL.amplitudes
    if show_all_amplitudes
    else dict(list(MODEL.amplitudes.items())[:2])
)
Math(aslatex(selected_amplitudes, terms_per_line=1))
Hide code cell source
# Notebook display of the model's variable definitions.
Math(aslatex(MODEL.variables))
Hide code cell source
# Notebook display of the invariants and the masses, merged into one mapping.
Math(aslatex({**MODEL.invariants, **MODEL.masses}))

Numeric results#

# Substitute the variable definitions first, then the default parameter values.
intensity_expr = MODEL.full_expression.xreplace(MODEL.variables).xreplace(
    MODEL.parameter_defaults
)
Hide code cell source
# One numerical intensity function per invariant: that invariant is replaced by
# its expression (with the masses substituted), which leaves the other two
# invariants as the function's arguments.
intensity_funcs = {}
for sigma, sigma_expr in tqdm(MODEL.invariants.items()):
    # The dictionary key is the last character of the symbol name, as an int.
    subsystem_id = int(str(sigma)[-1])
    sigma_value = sigma_expr.xreplace(MODEL.masses).doit()
    unfolded_expr = perform_cached_doit(intensity_expr.xreplace({sigma: sigma_value}))
    compiled_func = perform_cached_lambdify(unfolded_expr, backend="jax")
    # Sanity check: exactly two free arguments must remain.
    assert len(compiled_func.argument_order) == 2, compiled_func.argument_order
    intensity_funcs[subsystem_id] = compiled_func

Validation#

Error

The following serves as a numerical check on whether the amplitude model has been deserialized correctly. For now, this is not the case, see ComPWA/ampform-dpd#133 for updates.

# Collect every "misc" entry whose key contains "checksum" as a mapping of
# checksum name to checksum value.
checksums = {}
for misc_key, misc_value in MODEL_DEFINITION["misc"].items():
    if "checksum" not in misc_key:
        continue
    checksums[misc_key] = {entry["name"]: entry["value"] for entry in misc_value}
checksums
# Collect the parameter points as a mapping of point name to
# {parameter name: parameter value}.
checksum_points = {}
for point in MODEL_DEFINITION["parameter_points"]:
    checksum_points[point["name"]] = {
        par["name"]: par["value"] for par in point["parameters"]
    }
checksum_points
Hide code cell source
# Evaluate the intensity at every checksum point and tabulate the comparison
# with the expected value from the model file.
rows = []
for distribution_name, checksum in checksums.items():
    for point_name, expected in checksum.items():
        point = checksum_points[point_name]
        # The invariants are the squares of the masses stored in the point.
        sigma_values = {
            "sigma1": point["m_31_2"] ** 2,
            "sigma2": point["m_31"] ** 2,
        }
        computed = intensity_funcs[3](sigma_values)
        # NOTE(review): this is an exact equality test; a tolerance-based
        # comparison may be more appropriate — confirm the type of `expected`.
        status = "🟢" if computed == expected else "🔴"
        rows.append((distribution_name, point_name, computed, expected, status))
pd.DataFrame(rows, columns=["Distribution", "Point", "Computed", "Expected", "Status"])

Dalitz plot#

Hide code cell source
# i and j select the invariants that become the x and y axes of the Dalitz
# plot below; k is the remaining index out of {1, 2, 3}.
i, j = (2, 1)
k, *_ = {1, 2, 3} - {i, j}
# NOTE(review): assumes MODEL.invariants is ordered sigma1, sigma2, sigma3, so
# that position k - 1 holds the k-th invariant — confirm.
σk, σk_expr = list(MODEL.invariants.items())[k - 1]
Math(aslatex({σk: σk_expr}))
Hide code cell source
resolution = 1_000
# Mass symbols sorted by name; m[0] is used as the parent mass below.
m = sorted(MODEL.masses, key=str)


def _squared_mass_value(mass_expr) -> float:
    """Square a symbolic mass combination and evaluate it with the model masses."""
    return float((mass_expr**2).xreplace(MODEL.masses))


# Axis ranges: from (sum of two decay-product masses)^2 up to
# (parent mass - remaining mass)^2.
x_min, x_max = _squared_mass_value(m[j] + m[k]), _squared_mass_value(m[0] - m[i])
y_min, y_max = _squared_mass_value(m[i] + m[k]), _squared_mass_value(m[0] - m[j])

# Widen both ranges by 5% of their width on either side.
x_margin = 0.05 * (x_max - x_min)
y_margin = 0.05 * (y_max - y_min)
x_min, x_max = x_min - x_margin, x_max + x_margin
y_min, y_max = y_min - y_margin, y_max + y_margin

x_axis = jnp.linspace(x_min, x_max, num=resolution)
y_axis = jnp.linspace(y_min, y_max, num=resolution)
X, Y = jnp.meshgrid(x_axis, y_axis)
dalitz_data = {f"sigma{i}": X, f"sigma{j}": Y}
intensities = intensity_funcs[k](dalitz_data)
Hide code cell source
def get_decay_products(
    decay: ThreeBodyDecay, subsystem_id: FinalStateID
) -> tuple[State, State]:
    """Return the final-state particles other than the one with the given ID.

    Args:
        decay: Decay whose ``final_state`` mapping is inspected.
        subsystem_id: ID of the final-state particle to leave out.

    Returns:
        All final states whose ``index`` differs from ``subsystem_id``.

    Raises:
        ValueError: If ``subsystem_id`` is not a key of ``decay.final_state``.
    """
    final_state = decay.final_state
    if subsystem_id in final_state:
        return tuple(
            state for state in final_state.values() if state.index != subsystem_id
        )
    msg = f"Subsystem ID {subsystem_id} is not a valid final state ID"
    raise ValueError(msg)


plt.rc("font", size=18)
# Normalize by the sum over all grid points; nansum skips NaN entries.
I_tot = jnp.nansum(intensities)
normalized_intensities = intensities / I_tot

fig, ax = plt.subplots(figsize=(14, 10))
mesh = ax.pcolormesh(X, Y, normalized_intensities)
ax.set_aspect("equal")
c_bar = plt.colorbar(mesh, ax=ax, pad=0.01)
c_bar.ax.set_ylabel("Normalized intensity (a.u.)")
# Axis label per invariant: sigma_i is labeled with the LaTeX names of the
# decay products returned by get_decay_products(DECAY, i). The comprehension
# variable `i` is local and does not overwrite the global `i` used below.
sigma_labels = {
    i: Rf"$\sigma_{i} = M^2\left({' '.join(p.latex for p in get_decay_products(DECAY, i))}\right)$"
    for i in (1, 2, 3)
}
# X was filled with sigma{i} and Y with sigma{j} when the grid was built.
ax.set_xlabel(sigma_labels[i])
ax.set_ylabel(sigma_labels[j])
plt.show()