Amplitude model with sympy#

This section follows up on the previous chapter, Reaction and Models, and formulates the amplitude model for the \(\gamma p \to \eta\pi^0 p\) channel symbolically with SymPy. See TR‑033 for a purely numerical tutorial.

The model we want to implement is

\[\begin{split} \begin{array}{rcl} I &=& \left|A^{12} + A^{23} + A^{31}\right|^2 \\ A^{12} &=& \frac{\sum a_m Y_2^m (\Omega_1)}{s_{12}-m^2_{a_2}+im_{a_2} \Gamma_{a_2}} \\ A^{23} &=& \frac{\sum b_m Y_1^m (\Omega_2)}{s_{23}-m^2_{\Delta^+}+im_{\Delta^+} \Gamma_{\Delta^+}} \\ A^{31} &=& \frac{c_0}{s_{31}-m^2_{N^*}+im_{N^*} \Gamma_{N^*}} \,, \end{array} \end{split}\]

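where \(1=\eta\), \(2=\pi^0\), and \(3=p\), and each sum runs over \(m=-\ell,\dots,\ell\). The formula shows the default orbital angular momenta \(\ell_{12}=2\), \(\ell_{23}=1\), \(\ell_{31}=0\). The implementation below keeps \(\ell_{12}\), \(\ell_{23}\), \(\ell_{31}\) symbolic and also writes \(A^{31}\) as a sum \(\sum c_m Y_{\ell_{31}}^m(\Omega_3)\); for \(\ell_{31}=0\) this equals the expression above up to the constant factor \(Y_0^0 = 1/\sqrt{4\pi}\).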

Hide code cell content
from __future__ import annotations

import logging
import os
import warnings
from collections import defaultdict

import ipywidgets as w
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.io import aslatex
from ampform.kinematics.angles import Phi, Theta
from ampform.kinematics.lorentz import (
    ArrayMultiplication,
    ArraySize,
    BoostZMatrix,
    Energy,
    EuclideanNorm,
    FourMomentumSymbol,
    RotationYMatrix,
    RotationZMatrix,
    ThreeMomentum,
    three_momentum_norm,
)
from ampform.sympy import unevaluated
from ampform.sympy._array_expressions import ArraySum
from IPython.display import SVG, Image, Latex, display
from tensorwaves.data import (
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)
from tensorwaves.function.sympy import create_parametrized_function

STATIC_PAGE = "EXECUTE_NB" in os.environ

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
logging.disable(logging.WARNING)
warnings.filterwarnings("ignore")

Model implementation#

# Highest orbital angular momentum ℓ considered in each subsystem. Numerical
# functions below are built for ℓ = 0..l_max, and every coefficient array
# covers m = -l_max..l_max.
l_max = 2

\(A^{12}\)#

# A^{12}: sum over m of a_m * Y_{l12}^m(theta_1, phi_1), divided by the
# a_2 resonance denominator s12 - m_a2^2 + i m_a2 Gamma_a2.
# The orbital angular momentum l12 stays symbolic and is substituted later.
s12, m_a2, gamma_a2, l12 = sp.symbols(r"s_{12} m_{a_2} \Gamma_{a_2} l_{12}")
theta1, phi1 = sp.symbols("theta_1 phi_1")
a = sp.IndexedBase("a")
m = sp.symbols("m", cls=sp.Idx)  # summation index, reused by A^{23} and A^{31}
A12 = sp.Sum(a[m] * sp.Ynm(l12, m, theta1, phi1), (m, -l12, l12)) / (
    s12 - m_a2**2 + sp.I * m_a2 * gamma_a2
)
A12
\[\displaystyle \frac{\sum_{m=- l_{12}}^{l_{12}} Y_{l_{12}}^{m}\left(\theta_{1},\phi_{1}\right) {a}_{m}}{i \Gamma_{a_2} m_{a_2} - m_{a_2}^{2} + s_{12}}\]
# Lambdify A^{12} once per l12 = 0..l_max, so that A12_funcs[l] is a numerical
# function of (s12, a[-l_max], ..., a[l_max], m_a2, gamma_a2, theta1, phi1).
# All 2*l_max+1 coefficients are always passed, even if a given l uses fewer.
# .doit() expands the finite sum over m; expand(func=True) rewrites the Ynm
# terms explicitly before lambdifying.
A12_funcs = [
    sp.lambdify(
        [s12, *(a[j] for j in range(-l_max, l_max + 1)), m_a2, gamma_a2, theta1, phi1],
        expr=A12.subs(l12, i).doit().expand(func=True),
    )
    for i in range(l_max + 1)
]

\(A^{23}\)#

# A^{23}: sum over m of b_m * Y_{l23}^m(theta_2, phi_2), divided by the
# Delta^+ resonance denominator s23 - m_Delta^2 + i m_Delta Gamma_Delta.
s23, m_delta, gamma_delta, l23 = sp.symbols(
    r"s_{23} m_{\Delta^+} \Gamma_{\Delta^+} l_{23}"
)
b = sp.IndexedBase("b")
m = sp.symbols("m", cls=sp.Idx)
theta2, phi2 = sp.symbols("theta_2 phi_2")
A23 = sp.Sum(b[m] * sp.Ynm(l23, m, theta2, phi2), (m, -l23, l23)) / (
    s23 - m_delta**2 + sp.I * m_delta * gamma_delta
)
A23
\[\displaystyle \frac{\sum_{m=- l_{23}}^{l_{23}} Y_{l_{23}}^{m}\left(\theta_{2},\phi_{2}\right) {b}_{m}}{i \Gamma_{\Delta^+} m_{\Delta^+} - \left(m_{\Delta^+}\right)^{2} + s_{23}}\]
# One numerical function per l23 = 0..l_max; A23_funcs[l] takes
# (s23, b[-l_max], ..., b[l_max], m_delta, gamma_delta, theta2, phi2).
A23_funcs = [
    sp.lambdify(
        [
            s23,
            *(b[j] for j in range(-l_max, l_max + 1)),
            m_delta,
            gamma_delta,
            theta2,
            phi2,
        ],
        A23.subs(l23, i).doit().expand(func=True),
    )
    for i in range(l_max + 1)
]

\(A^{31}\)#

# A^{31}: sum over m of c_m * Y_{l31}^m(theta_3, phi_3), divided by the
# N* resonance denominator. This generalizes the c_0 term of the model formula
# to arbitrary l31. The summation index m is the Idx defined for A^{23}.
c = sp.IndexedBase("c")
s31, m_nstar, gamma_nstar = sp.symbols(r"s_{31} m_{N^*} \Gamma_{N^*}")
theta3, phi3, l31 = sp.symbols("theta_3 phi_3 l_{31}")
A31 = sp.Sum(c[m] * sp.Ynm(l31, m, theta3, phi3), (m, -l31, l31)) / (
    s31 - m_nstar**2 + sp.I * m_nstar * gamma_nstar
)
A31
\[\displaystyle \frac{\sum_{m=- l_{31}}^{l_{31}} Y_{l_{31}}^{m}\left(\theta_{3},\phi_{3}\right) {c}_{m}}{i \Gamma_{N^*} m_{N^*} - \left(m_{N^*}\right)^{2} + s_{31}}\]
# One numerical function per l31 = 0..l_max; A31_funcs[l] takes
# (s31, c[-l_max], ..., c[l_max], m_nstar, gamma_nstar, theta3, phi3).
A31_funcs = [
    sp.lambdify(
        [
            s31,
            *(c[j] for j in range(-l_max, l_max + 1)),
            m_nstar,
            gamma_nstar,
            theta3,
            phi3,
        ],
        A31.subs(l31, i).doit().expand(func=True),
    )
    for i in range(l_max + 1)
]

\(I = |A|^2 = |A^{12}+A^{23}+A^{31}|^2\)#

# Coherent sum of the three subsystem amplitudes, squared. The l values are
# still symbolic here and are substituted when the functions are compiled.
intensity_expr = sp.Abs(A12 + A23 + A31) ** 2
intensity_expr
\[\displaystyle \left|{\frac{\sum_{m=- l_{12}}^{l_{12}} Y_{l_{12}}^{m}\left(\theta_{1},\phi_{1}\right) {a}_{m}}{i \Gamma_{a_2} m_{a_2} - m_{a_2}^{2} + s_{12}} + \frac{\sum_{m=- l_{23}}^{l_{23}} Y_{l_{23}}^{m}\left(\theta_{2},\phi_{2}\right) {b}_{m}}{i \Gamma_{\Delta^+} m_{\Delta^+} - \left(m_{\Delta^+}\right)^{2} + s_{23}} + \frac{\sum_{m=- l_{31}}^{l_{31}} Y_{l_{31}}^{m}\left(\theta_{3},\phi_{3}\right) {c}_{m}}{i \Gamma_{N^*} m_{N^*} - \left(m_{N^*}\right)^{2} + s_{31}}}\right|^{2}\]

Phase Space Generation#

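Invariant mass \(m_0=\sqrt{s}\) of the initial \(p\gamma\) system, for a photon beam with lab energy \(E_\gamma\) hitting a proton at rest: \(s = m_p^2 + 2E_\gamma m_p\).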

# Energies and masses in GeV (the plots below use GeV units).
E_lab_gamma = 8.5
m_proton = 0.938
# sqrt(s) for a (massless) photon with lab energy E on a proton at rest:
# s = m_p^2 + 2 E m_p
m_0 = np.sqrt(2 * E_lab_gamma * m_proton + m_proton**2)
m_eta = 0.548
m_pi = 0.135
m_0
4.101931740046389
# Generate 500,000 phase-space events for m_0 -> eta pi0 p with a fixed seed.
# The final-state ids (1 = eta, 2 = pi0, 3 = p) match the numbering of the model.
rng = TFUniformRealNumberGenerator(seed=0)
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=m_0,
    final_state_masses={1: m_eta, 2: m_pi, 3: m_proton},
)
phsp_momenta = phsp_generator.generate(500_000, rng)
Hide code cell output

Kinematic variables#

Hide code cell content
@unevaluated
class SquaredInvariantMass(sp.Expr):
    """Squared invariant mass of a four-momentum: energy squared minus the squared three-momentum norm.

    The expression is rendered as ``m_{p}^2`` and is only expanded into ampform
    ``Energy``/``ThreeMomentum`` expressions by :meth:`evaluate`.
    """

    momentum: sp.Basic  # four-momentum expression, e.g. ArraySum(p1, p2)
    _latex_repr_ = "m_{{{momentum}}}^2"

    def evaluate(self) -> sp.Expr:
        """Return ``E(p)**2 - |p_xyz|**2`` for the wrapped momentum ``p``."""
        p = self.momentum
        p_xyz = ThreeMomentum(p)
        return Energy(p) ** 2 - EuclideanNorm(p_xyz) ** 2


def formulate_helicity_angles(
    pi: FourMomentumSymbol, pj: FourMomentumSymbol
) -> tuple[Theta, Phi]:
    """Express the polar and azimuthal angle of ``pi`` in the ``pi + pj`` system.

    ``pi`` is rotated by -phi around z, then by -theta around y, where phi and
    theta are the angles of ``pi + pj``. It is then boosted along z with
    beta = |p_ij| / E_ij (presumably into the ``pi + pj`` rest frame, depending
    on ampform's ``BoostZMatrix`` sign convention). The angles of the result
    are returned as ``(theta, phi)``.
    """
    p_sum = ArraySum(pi, pj)
    azimuth = Phi(p_sum)
    polar = Theta(p_sum)
    velocity = three_momentum_norm(p_sum) / Energy(p_sum)
    rotate_z = RotationZMatrix(-azimuth, n_events=ArraySize(azimuth))
    rotate_y = RotationYMatrix(-polar, n_events=ArraySize(polar))
    boost_z = BoostZMatrix(velocity, n_events=ArraySize(velocity))
    # Matrices act right-to-left: first R_z, then R_y, then the boost.
    transformed = ArrayMultiplication(boost_z, rotate_y, rotate_z, pi)
    return Theta(transformed), Phi(transformed)
# Symbolic four-momenta of the final-state particles (1 = eta, 2 = pi0, 3 = p).
# NOTE(review): the names "p1".."p3" presumably match the keys of phsp_momenta
# that the data transformer consumes — confirm against the generator output.
p1 = FourMomentumSymbol("p1", shape=[])
p2 = FourMomentumSymbol("p2", shape=[])
p3 = FourMomentumSymbol("p3", shape=[])
p12 = ArraySum(p1, p2)
p23 = ArraySum(p2, p3)
p31 = ArraySum(p3, p1)

# Angles of particle i in the (i, j) subsystem, for the cyclic pairs 12, 23, 31.
theta1_expr, phi1_expr = formulate_helicity_angles(p1, p2)
theta2_expr, phi2_expr = formulate_helicity_angles(p2, p3)
theta3_expr, phi3_expr = formulate_helicity_angles(p3, p1)

# Squared invariant masses s_ij of the same subsystems.
s12_expr = SquaredInvariantMass(p12)
s23_expr = SquaredInvariantMass(p23)
s31_expr = SquaredInvariantMass(p31)
Hide code cell source
# Map each model variable to the expression that computes it from the momenta.
kinematic_variables = {
    theta1: theta1_expr,
    theta2: theta2_expr,
    theta3: theta3_expr,
    phi1: phi1_expr,
    phi2: phi2_expr,
    phi3: phi3_expr,
    s12: s12_expr,
    s23: s23_expr,
    s31: s31_expr,
}

Latex(aslatex(kinematic_variables))
\[\begin{split}\begin{array}{rcl} \theta_{1} &=& \theta\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{12}}\right|}{E\left({p}_{12}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{12}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{12}\right)\right) p_{1}\right) \\ \theta_{2} &=& \theta\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{23}}\right|}{E\left({p}_{23}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{23}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{23}\right)\right) p_{2}\right) \\ \theta_{3} &=& \theta\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{31}}\right|}{E\left({p}_{31}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{31}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{31}\right)\right) p_{3}\right) \\ \phi_{1} &=& \phi\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{12}}\right|}{E\left({p}_{12}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{12}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{12}\right)\right) p_{1}\right) \\ \phi_{2} &=& \phi\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{23}}\right|}{E\left({p}_{23}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{23}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{23}\right)\right) p_{2}\right) \\ \phi_{3} &=& \phi\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{31}}\right|}{E\left({p}_{31}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{31}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{31}\right)\right) p_{3}\right) \\ s_{12} &=& m_{{p}_{12}}^2 \\ s_{23} &=& m_{{p}_{23}}^2 \\ s_{31} &=& m_{{p}_{31}}^2 \\ \end{array}\end{split}\]
# Compile the kinematic expressions with the JAX backend and evaluate them on
# the generated momenta. phsp maps each variable name to its per-event values.
helicity_transformer = SympyDataTransformer.from_sympy(
    kinematic_variables, backend="jax"
)
phsp = helicity_transformer(phsp_momenta)
list(phsp)
['theta_1',
 'theta_2',
 'theta_3',
 'phi_1',
 'phi_2',
 'phi_3',
 's_{12}',
 's_{23}',
 's_{31}']

Parameters#

Hide code cell source
# Coupling magnitudes; entry k corresponds to m = k - l_max (m = -2..2).
a_vals = [0, 0.5, 3.5, 4, 2.5]
b_vals = [0, -1.5, 4, 0.5, 0]
c_vals = [0, 0, 2.5, 0, 0]
m_a2_val = 1.32
m_delta_val = 1.54
m_nstar_val = 1.87
gamma_a2_val = 0.1
gamma_delta_val = 0.1
gamma_nstar_val = 0.1
l12_val = 2
l23_val = 1
l31_val = 0

# NOTE: the insertion order of this dict determines the order in which the
# slider widgets are created further below.
parameters_default = {
    m_a2: m_a2_val,
    m_delta: m_delta_val,
    m_nstar: m_nstar_val,
    gamma_a2: gamma_a2_val,
    gamma_delta: gamma_delta_val,
    gamma_nstar: gamma_nstar_val,
    l12: l12_val,
    l23: l23_val,
    l31: l31_val,
}

a_dict = {a[i]: a_vals[i + l_max] for i in range(-l_max, l_max + 1)}
b_dict = {b[i]: b_vals[i + l_max] for i in range(-l_max, l_max + 1)}
c_dict = {c[i]: c_vals[i + l_max] for i in range(-l_max, l_max + 1)}
parameters_default.update(a_dict)
parameters_default.update(b_dict)
parameters_default.update(c_dict)

Latex(aslatex(parameters_default))
Hide code cell output
\[\begin{split}\begin{array}{rcl} m_{a_2} &=& 1.32 \\ m_{\Delta^+} &=& 1.54 \\ m_{N^*} &=& 1.87 \\ \Gamma_{a_2} &=& 0.1 \\ \Gamma_{\Delta^+} &=& 0.1 \\ \Gamma_{N^*} &=& 0.1 \\ l_{12} &=& 2 \\ l_{23} &=& 1 \\ l_{31} &=& 0 \\ {a}_{-2} &=& 0 \\ {a}_{-1} &=& 0.5 \\ {a}_{0} &=& 3.5 \\ {a}_{1} &=& 4 \\ {a}_{2} &=& 2.5 \\ {b}_{-2} &=& 0 \\ {b}_{-1} &=& -1.5 \\ {b}_{0} &=& 4 \\ {b}_{1} &=& 0.5 \\ {b}_{2} &=& 0 \\ {c}_{-2} &=& 0 \\ {c}_{-1} &=& 0 \\ {c}_{0} &=& 2.5 \\ {c}_{1} &=& 0 \\ {c}_{2} &=& 0 \\ \end{array}\end{split}\]

Note

The masses and widths of the resonances are tuned to make the resonance bands in the Dalitz plot more visible.

Visualization#

Model components#

# Evaluate each subsystem amplitude as a function of s at the fixed angles
# theta = phi = pi/4, using the default l values (l12 = 2, l23 = 1, l31 = 0)
# as the index into the per-l function lists.
phi = np.pi / 4
theta = np.pi / 4
s_values = np.linspace(0, 5, num=500)
A12_values = A12_funcs[2](s_values, *a_vals, m_a2_val, gamma_a2_val, theta, phi)
A23_values = A23_funcs[1](s_values, *b_vals, m_delta_val, gamma_delta_val, theta, phi)
A31_values = A31_funcs[0](s_values, *c_vals, m_nstar_val, gamma_nstar_val, theta, phi)
Hide code cell source
fig, axes = plt.subplots(figsize=(10, 9), nrows=3)
angle_projection_text = R"$\phi=\frac{\pi}{4},\theta=\frac{\pi}{4}$"
# One row per subsystem amplitude. The subsystem label is listed explicitly:
# deriving it from a "recoil" index i + 1 labeled A12 as "23", A23 as "13"
# and A31 as "12".
subsystems = ["12", "23", "31"]
amplitude_values = [A12_values, A23_values, A31_values]
for ax, subsystem, A_values in zip(axes, subsystems, amplitude_values):
    ax.plot(s_values, A_values.imag, label="Imaginary Part", c="blue", ls="-")
    ax.plot(s_values, A_values.real, label="Real Part", c="red", ls="--")
    ax.plot(
        s_values,
        np.abs(A_values) ** 2,
        label=f"Intensity $I^{{{subsystem}}}=|A^{{{subsystem}}}|^2$",
        c="g",
        ls="-.",
    )
    ax.plot(
        s_values,
        np.angle(A_values),
        label="Phase (angle)",
        c="m",
        ls=":",
    )
    ax.set_title(Rf"Components of $A^{{{subsystem}}}$ vs s ({angle_projection_text})")
    ax.set_xlabel(f"$s_{{{subsystem}}}$")
    ax.set_ylabel(Rf"$A^{{{subsystem}}}$ components")
    ax.legend()
fig.tight_layout()
plt.show()
../_images/26efc53239754edf599557cedfa9505a23e561f72bddb3bb68851aea6b29781c.svg

For fixed \(\phi\) and \(\theta\), each subsystem amplitude traces a circle in its Argand diagram. Unitarity is preserved within each subsystem because each subsystem contains only a single resonance.

Hide code cell source
fig, axes = plt.subplots(figsize=(12, 4), ncols=3)
# Argand diagram (Im vs Re) of each subsystem amplitude at fixed angles. The
# subsystem labels are explicit because a recoil-index derivation mislabeled
# the panels.
subsystems = ["12", "23", "31"]
amplitude_values = [A12_values, A23_values, A31_values]
for ax, subsystem, A_values in zip(axes, subsystems, amplitude_values):
    ax.set_xlabel(Rf"$\mathrm{{Re}}\,A^{{{subsystem}}}$")
    ax.set_ylabel(Rf"$\mathrm{{Im}}\,A^{{{subsystem}}}$")
    ax.set_title(Rf"$A^{{{subsystem}}}$ ({angle_projection_text})")
    ax.plot(A_values.real, A_values.imag)
fig.suptitle("Argand diagrams")
fig.tight_layout()
# pyplot.show() takes no positional figure argument (only keyword ``block``).
plt.show()
../_images/b624827761d652821adf7de4ebf01d4b89c17e2e461f1dbcb7c3bdbe1d0e3e98.svg

Dalitz Plot#

Hide code cell source
def format_l_as_unicode(name: str) -> str:
    """Render an angular-momentum symbol name such as ``"l_{12}"`` as ``"ℓ₁₂"``.

    ``l`` becomes ``ℓ``, the digits 1-3 become subscript digits, and the
    ``_{`` / ``}`` markup is removed. All other characters are kept unchanged.
    """
    char_table = str.maketrans({"l": "ℓ", "1": "₁", "2": "₂", "3": "₃", "}": None})
    return name.replace("_{", "").translate(char_table)


sliders = {}
categorized_sliders_m = defaultdict(list)
categorized_sliders_gamma = defaultdict(list)
categorized_cphi_pair = defaultdict(list)
categorized_sliders_l = []

# Create one widget per model parameter. Widgets are stored in ``sliders`` under
# ``symbol.name``, which is the keyword the update functions receive. They are
# also grouped per resonance tab: 0 = N* (31), 1 = Δ⁺ (23), 2 = a₂ (12).
for symbol, value in parameters_default.items():
    if symbol.name.startswith(R"\Gamma_{"):
        slider = w.FloatSlider(
            description=Rf"\({sp.latex(symbol)}\)",
            min=0.0,
            max=1.0,
            step=0.01,
            value=value,
            continuous_update=False,
        )
        sliders[symbol.name] = slider
        if symbol.name.startswith(R"\Gamma_{N"):
            categorized_sliders_gamma[0].append(slider)
        elif symbol.name.startswith(R"\Gamma_{\D"):
            categorized_sliders_gamma[1].append(slider)
        elif symbol.name.startswith(R"\Gamma_{a"):
            categorized_sliders_gamma[2].append(slider)

    elif symbol.name.startswith("m_{"):
        slider = w.FloatSlider(
            description=Rf"\({sp.latex(symbol)}\)",
            min=0.63,
            max=4,
            step=0.01,
            value=value,
            continuous_update=False,
        )
        sliders[symbol.name] = slider
        if symbol.name.startswith("m_{N"):
            categorized_sliders_m[0].append(slider)
        elif symbol.name.startswith(R"m_{\D"):
            categorized_sliders_m[1].append(slider)
        elif symbol.name.startswith("m_{a"):
            categorized_sliders_m[2].append(slider)

    elif isinstance(symbol, sp.Indexed):
        # A (possibly complex) coupling is controlled by a magnitude slider and
        # a phase slider stored under "phi_<name>"; insert_phi recombines them.
        c_latex = sp.latex(symbol)
        phi_latex = Rf"\phi_{{{c_latex}}}"
        slider_c = w.FloatSlider(
            description=Rf"\({c_latex}\)",
            min=0,
            max=10,
            step=0.01,
            value=abs(value),
            continuous_update=False,
        )
        slider_phi = w.FloatSlider(
            description=Rf"\({phi_latex}\)",
            min=-np.pi,
            max=+np.pi,
            step=0.01,
            value=np.angle(value),
            continuous_update=False,
        )

        sliders[symbol.name] = slider_c
        sliders[f"phi_{symbol.name}"] = slider_phi
        cphi_hbox = w.HBox([slider_c, slider_phi])
        if symbol.base is a:
            categorized_cphi_pair[2].append(cphi_hbox)
        elif symbol.base is b:
            categorized_cphi_pair[1].append(cphi_hbox)
        elif symbol.base is c:
            categorized_cphi_pair[0].append(cphi_hbox)
        else:
            raise NotImplementedError(symbol.name)

    elif symbol.name.startswith("l_{"):
        # Start at the default l from parameters_default instead of a fixed 1,
        # and allow exactly the l values for which intensity functions are
        # compiled (0..l_max).
        slider = w.IntSlider(
            value=value,
            min=0,
            max=l_max,
            description=format_l_as_unicode(symbol.name),
            continuous_update=False,
        )
        sliders[symbol.name] = slider
        categorized_sliders_l.append(slider)

    else:
        raise NotImplementedError(symbol.name)

resonances_name = ["N* [pη (31)]", "Δ⁺ [π⁰p(23)]", "a₂ [ηπ⁰(12)]"]
# Tab k shows the mass/width sliders and the (magnitude, phase) coupling pairs
# grouped under index k: 0 = N*, 1 = Δ⁺, 2 = a₂.
tab_contents = [
    w.VBox([
        w.HBox(categorized_sliders_m[k] + categorized_sliders_gamma[k]),
        w.VBox(categorized_cphi_pair[k]),
    ])
    for k, _title in enumerate(resonances_name)
]
UI = w.VBox([
    w.HBox(categorized_sliders_l),
    w.Tab(tab_contents, titles=resonances_name),
])
# Precompile one intensity function per (l12, l23, l31) combination such that
# intensity_funcs[i, j, k] has l12 = i, l23 = j, l31 = k. That is how the
# update functions index it (intensity_funcs[l12, l23, l31]). The nesting
# order therefore has i outermost and k innermost; reversing it would swap
# l12 and l31 at lookup.
intensity_funcs = np.array([
    [
        [
            create_parametrized_function(
                expression=intensity_expr.xreplace({l12: i, l23: j, l31: k})
                .doit()
                .expand(func=True),
                parameters=parameters_default,
                backend="jax",
            )
            for k in range(l_max + 1)
        ]
        for j in range(l_max + 1)
    ]
    for i in range(l_max + 1)
])
intensity_funcs.shape
(3, 3, 3)
Hide code cell source
def insert_phi(parameters: dict) -> dict:
    """Fold phase slider values into the coupling magnitudes.

    Entries whose key starts with ``"phi_"`` are phases and are left out of the
    result. A key starting with ``a``, ``b`` or ``c`` that has a matching
    ``"phi_<key>"`` entry has its value multiplied by ``exp(1j * phase)``.
    Every other entry is copied unchanged, in the original order.
    """
    merged = {}
    for name, magnitude in parameters.items():
        if name.startswith("phi_"):
            continue
        phase_name = f"phi_{name}"
        if name.startswith(("a", "b", "c")) and phase_name in parameters:
            magnitude *= np.exp(1j * parameters[phase_name])  # noqa:PLW2901
        merged[name] = magnitude
    return merged
Hide code cell source
%matplotlib widget
%config InlineBackend.figure_formats = ['png']
# Figure and axes for the model-weighted Dalitz plot (s12 vs s23).
fig_2d, ax_2d = plt.subplots(dpi=200)
ax_2d.set_title("Model-weighted Phase space Dalitz Plot")
ax_2d.set_xlabel(R"$m^2(\eta \pi^0)\;\left[\mathrm{GeV}^2\right]$")
ax_2d.set_ylabel(R"$m^2(\pi^0 p)\;\left[\mathrm{GeV}^2\right]$")

# Color mesh created on the first update_histogram call and updated in place
# afterwards.
mesh = None


def update_histogram(**parameters):
    """Redraw the model-weighted Dalitz plot for the current slider values.

    ``parameters`` holds the slider values keyed by slider name. The ``l_{..}``
    values select one of the precompiled ``intensity_funcs``. The remaining
    values, with phases folded in by ``insert_phi``, become that function's
    parameters. The phase-space sample is weighted by the intensity and
    histogrammed in (s12, s23). The color mesh is created on the first call
    and updated in place afterwards.
    """
    global mesh
    l12 = parameters["l_{12}"]
    l23 = parameters["l_{23}"]
    l31 = parameters["l_{31}"]
    intensity_func = intensity_funcs[l12, l23, l31]
    parameters = insert_phi(parameters)
    intensity_func.update_parameters(parameters)
    intensity_weights = intensity_func(phsp)
    bin_values, xedges, yedges = jnp.histogram2d(
        phsp["s_{12}"],
        phsp["s_{23}"],
        bins=200,
        weights=intensity_weights,
        density=True,
    )
    # Hide (near-)empty bins, e.g. outside the kinematically allowed region.
    bin_values = jnp.where(bin_values < 1e-10, jnp.nan, bin_values)
    if mesh is None:
        # Pass the full bin edges (N+1 per axis) so that each cell covers its
        # histogram bin exactly. With only the lower edges, matplotlib centers
        # cells on those edges, which shifts the plot by half a bin.
        mesh = ax_2d.pcolormesh(xedges, yedges, bin_values.T, cmap="jet", vmax=0.15)
    else:
        mesh.set_array(bin_values.T)
    fig_2d.canvas.draw_idle()


# Re-run update_histogram whenever a slider changes.
interactive_plot = w.interactive_output(update_histogram, sliders)
fig_2d.tight_layout()
# NOTE(review): relies on interactive_output having called update_histogram
# once already, so that mesh is no longer None here — confirm.
fig_2d.colorbar(mesh, ax=ax_2d)

# On a statically built page there is no kernel, so show a saved image instead.
if STATIC_PAGE:
    filename = "dalitz-plot-man.png"
    fig_2d.savefig(filename)
    plt.close(fig_2d)
    display(UI, Image(filename))
else:
    display(UI, interactive_plot)
../_images/2d071cf6b87a17e6d17f287bd85f62daabfbe2d005ee518d67a327d8443f4fc6.png

Projection#

Hide code cell source
%matplotlib widget
%config InlineBackend.figure_formats = ['svg']

# Panel titles. Columns are the subsystems 12, 23, 31; rows show cos(theta),
# phi and the invariant mass.
theta_subtitles = [
    R"$\cos (\theta_{1}^{{12}}) \equiv \cos (\theta_{\eta}^{{\eta \pi^0}})$",
    R"$\cos (\theta_{2}^{{23}}) \equiv \cos (\theta_{\pi^0}^{{\pi^0 p}})$",
    R"$\cos (\theta_{3}^{{31}}) \equiv \cos (\theta_{p}^{{p \eta}})$",
]
phi_subtitles = [
    R"$\phi_1^{12} \equiv \phi_{\eta}^{{\eta \pi^0}}$",
    R"$\phi_2^{23} \equiv \phi_{\pi^0}^{{\pi^0 p}}$",
    R"$\phi_3^{31} \equiv \phi_{p}^{{p \eta}}$",
]
mass_subtitles = [
    R"$m_{12} \equiv m_{\eta \pi^0}$",
    R"$m_{23} \equiv m_{\pi^0 p}$",
    R"$m_{31} \equiv m_{p \eta}$",
]

fig, (theta_axes, phi_axes, mass_axes) = plt.subplots(figsize=(12, 8), ncols=3, nrows=3)
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False

for i, ax in enumerate(theta_axes):
    ax.set_title(theta_subtitles[i])
    ax.set_xticks([-1, 0, 1])

for i, ax in enumerate(phi_axes):
    ax.set_title(phi_subtitles[i])
    ax.set_xticks([-np.pi, 0, np.pi])
    ax.set_xticklabels([R"-$\pi$", 0, R"$\pi$"])

for i, ax in enumerate(mass_axes):
    ax.set_title(mass_subtitles[i])
    ax.set_xlabel("Mass [GeV]")

# Per-column plot handles, filled in by update_plot on its first call so that
# later calls can update the data in place.
LINES = 3 * [None]  # mass histograms
THETA_LINES = 3 * [None]  # cos(theta) histograms
PHI_LINES = 3 * [None]  # phi histograms
RESONANCE_LINES = defaultdict(dict)  # column -> {mass parameter name: axvline}
# Resonance mass symbol shown in each mass column (12: a2, 23: Delta, 31: N*).
RESONANCES_MASS_NAME = [m_a2, m_delta, m_nstar]


def _draw_step(lines, index, ax, bin_edges, bin_values):
    """Create (first call) or update (later calls) the histogram line ``lines[index]``.

    The last bin value is repeated so that, with ``where="post"``, every bin
    ``[bin_edges[k], bin_edges[k + 1])`` is drawn at height ``bin_values[k]``,
    including the final bin.
    """
    y_values = jnp.append(bin_values, bin_values[-1])
    if lines[index] is None:
        lines[index] = ax.step(bin_edges, y_values, where="post", alpha=0.5)[0]
    else:
        lines[index].set_ydata(y_values)


def update_plot(**parameters):  # noqa: PLR0914
    """Redraw the 1D projections (mass, cos theta, phi) of the weighted phase space.

    ``parameters`` holds the slider values keyed by slider name. The ``l_{..}``
    values select one of the precompiled ``intensity_funcs``. The remaining
    values, with phases folded in by ``insert_phi``, become its parameters.
    Lines are created on the first call and updated in place afterwards.
    """
    l12 = parameters["l_{12}"]
    l23 = parameters["l_{23}"]
    l31 = parameters["l_{31}"]
    intensity_func = intensity_funcs[l12, l23, l31]
    parameters = insert_phi(parameters)
    intensity_func.update_parameters(parameters)
    intensities = intensity_func(phsp)
    max_value_theta = 0.0
    max_value_phi = 0.0
    max_value_mass = 0.0
    theta_keys = ["theta_1", "theta_2", "theta_3"]
    phi_keys = ["phi_1", "phi_2", "phi_3"]
    m2_keys = ["s_{12}", "s_{23}", "s_{31}"]
    line_labels = [R"$m_{a_2}$", R"$m_{\Delta}$", R"$m_{N^*}$"]
    line_colors = ["r", "g", "b"]
    plot_style = dict(bins=120, weights=intensities, density=True)

    for i, ax in enumerate(mass_axes):
        bin_values, bin_edges = jnp.histogram(np.sqrt(phsp[m2_keys[i]]), **plot_style)
        max_value_mass = max(max_value_mass, bin_values.max())
        _draw_step(LINES, i, ax, bin_edges, bin_values)

        # Dashed vertical line at the current resonance mass. Slider values
        # are keyed by Symbol.name, so look the mass up under that name.
        symbol_key = RESONANCES_MASS_NAME[i].name
        val = parameters[symbol_key]
        resonance_line = RESONANCE_LINES[i].get(symbol_key)
        if resonance_line is None:
            RESONANCE_LINES[i][symbol_key] = ax.axvline(
                val, color=line_colors[i], linestyle="--", label=line_labels[i]
            )
        else:
            resonance_line.set_xdata([val, val])

    for i, ax in enumerate(theta_axes):
        bin_values, bin_edges = jnp.histogram(np.cos(phsp[theta_keys[i]]), **plot_style)
        max_value_theta = max(max_value_theta, bin_values.max())
        _draw_step(THETA_LINES, i, ax, bin_edges, bin_values)

    for i, ax in enumerate(phi_axes):
        bin_values, bin_edges = jnp.histogram(phsp[phi_keys[i]], **plot_style)
        max_value_phi = max(max_value_phi, bin_values.max())
        _draw_step(PHI_LINES, i, ax, bin_edges, bin_values)

    # Share one y-range per row so that the subsystems are directly comparable.
    for ax in mass_axes:
        ax.set_ylim(0, max_value_mass * 1.1)
        ax.legend()

    for ax in theta_axes:
        ax.set_ylim(0, max_value_theta * 1.1)

    for ax in phi_axes:
        ax.set_ylim(0, max_value_phi * 1.1)


# Re-run update_plot whenever a slider changes. The same UI widgets also drive
# the Dalitz plot above.
interactive_plot = w.interactive_output(update_plot, sliders)
fig.tight_layout()

# On a statically built page there is no kernel, so show a saved image instead.
if STATIC_PAGE:
    filename = "1d-histograms-man.svg"
    fig.savefig(filename)
    plt.close(fig)
    display(UI, SVG(filename))
else:
    display(UI, interactive_plot)
../_images/8764a5f935c6f19bfe9b2e2e596a071e27132cf9f61860718afd64d44b7c82af.svg