# D⁰ → K⁰ K⁺ K⁻

The decay \(D^0 \to K^0K^+K^-\) has a spinless initial state and a spinless final state, so no spin alignment (Wigner rotation) is needed in the Dalitz-plot decomposition. This notebook shows that the model formulated by AmpForm is the same as the one formulated by AmpForm-DPD. To simplify this comparison, we do not define any dynamics.

Hide code cell content
from __future__ import annotations

import logging
import os
import warnings
from textwrap import dedent
from typing import TYPE_CHECKING

import ampform
import graphviz
import ipywidgets
import jax.numpy as jnp
import matplotlib.pyplot as plt
import qrules
import sympy as sp
from ampform.kinematics.lorentz import FourMomentumSymbol, InvariantMass
from ampform.sympy import perform_cached_doit
from IPython.display import Latex, Markdown, clear_output, display
from ipywidgets import (
    Accordion,
    Checkbox,
    GridBox,
    HBox,
    Layout,
    SelectMultiple,
    Tab,
    ToggleButtons,
    VBox,
    interactive_output,
)
from tensorwaves.data.phasespace import TFPhaseSpaceGenerator
from tensorwaves.data.rng import TFUniformRealNumberGenerator
from tensorwaves.data.transform import SympyDataTransformer

from ampform_dpd import DalitzPlotDecompositionBuilder
from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay
from ampform_dpd.decay import Particle
from ampform_dpd.io import (
    as_markdown_table,
    aslatex,
    get_readable_hash,
    perform_cached_lambdify,
    simplify_latex_rendering,
)

if TYPE_CHECKING:
    from ampform.helicity import HelicityModel
    from qrules.transition import ReactionInfo
    from tensorwaves.interface import DataSample, ParameterValue, ParametrizedFunction

simplify_latex_rendering()

# Silence noisy third-party output.
logging.getLogger("jax").setLevel(logging.ERROR)  # mute JAX
# NOTE(review): TensorFlow is already imported above (via tensorwaves), and this
# variable is typically read when TF loads, so it may not take effect here —
# consider setting it before the imports.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # mute TF
warnings.simplefilter("ignore", category=RuntimeWarning)

# When EXECUTE_NB is set, also lower the ampform / ampform_dpd logger levels.
NO_TQDM = "EXECUTE_NB" in os.environ
if NO_TQDM:
    for logger_name in ("ampform.sympy", "ampform_dpd.io"):
        logging.getLogger(logger_name).setLevel(logging.ERROR)

## Decay definition

Hide code cell source
# All D0 → K0 K- K+ transitions via the allowed intermediate states.
REACTION = qrules.generate_transitions(
    initial_state="D0",
    final_state=["K0", "K-", "K+"],
    allowed_intermediate_particles=["a(0)", "f(0)(980)", "pi(1)"],
    formalism="helicity",
    mass_conservation_factor=0.2,
)
# Renumber the state IDs (the resulting decay lists the states as 0, 1, 2, 3).
REACTION123 = normalize_state_ids(REACTION)
dot = qrules.io.asdot(REACTION123, collapse_graphs=True)
graphviz.Source(dot)
../_images/ffbb15ecd139f3904de55d1401a55060b0bd75593acdc167dfeac1d171f56321.svg
Hide code cell source
# Three-body decay description (minimal LS couplings) used by AmpForm-DPD.
DECAY = to_three_body_decay(REACTION123.transitions, min_ls=True)
# Table of the initial state followed by the final-state particles.
Markdown(
    as_markdown_table([DECAY.initial_state] + list(DECAY.final_state.values()))
)

index

name

LaTeX

\(J^P\)

mass (MeV)

width (MeV)

0

D0

\(D^{0}\)

\(0^-\)

1,864

0

1

K0

\(K^{0}\)

\(0^-\)

497

0

2

K-

\(K^{-}\)

\(0^-\)

493

0

3

K+

\(K^{+}\)

\(0^-\)

493

0

Hide code cell source
# Unique resonances over all decay chains, grouped by family letter and mass.
# The full name is a final tiebreaker: resonances with equal family letter and
# mass (e.g. the three a0(980) charge states) would otherwise keep the iteration
# order of the set, which is not guaranteed to be stable between sessions. This
# order determines the default selection and the tab order in the widgets.
resonances = sorted(
    {t.resonance for t in DECAY.chains},
    key=lambda p: (p.name[0], p.mass, p.name),
)
resonance_names = [p.name for p in resonances]
Markdown(as_markdown_table(resonances))

name

LaTeX

\(J^P\)

mass (MeV)

width (MeV)

a(0)(980)+

\(a_{0}(980)^{+}\)

\(0^+\)

980

75

a(0)(980)-

\(a_{0}(980)^{-}\)

\(0^+\)

980

75

a(0)(980)0

\(a_{0}(980)^{0}\)

\(0^+\)

980

75

f(0)(980)

\(f_{0}(980)\)

\(0^+\)

990

60

pi(1)(1400)+

\(\pi_{1}(1400)^{+}\)

\(1^-\)

1,354

330

pi(1)(1400)-

\(\pi_{1}(1400)^{-}\)

\(1^-\)

1,354

330

pi(1)(1400)0

\(\pi_{1}(1400)^{0}\)

\(1^-\)

1,354

330

Hide code cell source
# All decay chains rendered as LaTeX, including J^P quantum numbers.
decay_latex = aslatex(DECAY, with_jp=True)
Latex(decay_latex)
\[\begin{split}\begin{array}{c} D^{0} \to \left(a_{0}(980)^{+} \to K^{0} K^{+}\right) K^{-} \\ D^{0} \to \left(a_{0}(980)^{-} \to K^{0} K^{-}\right) K^{+} \\ D^{0} \to \left(a_{0}(980)^{0} \to K^{-} K^{+}\right) K^{0} \\ D^{0} \to \left(f_{0}(980) \to K^{-} K^{+}\right) K^{0} \\ D^{0} \to \left(\pi_{1}(1400)^{+} \to K^{0} K^{+}\right) K^{-} \\ D^{0} \to \left(\pi_{1}(1400)^{-} \to K^{0} K^{-}\right) K^{+} \\ D^{0} \to \left(\pi_{1}(1400)^{0} \to K^{-} K^{+}\right) K^{0} \\ \end{array}\end{split}\]

## Model formulation

### DPD model

Note that, as opposed to Λc⁺ → pπ⁺K⁻ and J/ψ → K⁰Σ⁺p̅, there are no alignment Wigner-\(d\) functions, because the final state is spinless. Only the isobar Wigner-\(d\) functions remain.

Hide code cell source
# Formulate the DPD amplitude model with minimal LS couplings, using subsystem 1
# as the reference subsystem. The builder is not kept around.
DPD_MODEL = DalitzPlotDecompositionBuilder(DECAY, min_ls=True).formulate(
    reference_subsystem=1
)
DPD_MODEL.intensity.cleanup()
\[\displaystyle \left|{\sum_{\lambda_0^{\prime}=0} \sum_{\lambda_1^{\prime}=0} \sum_{\lambda_2^{\prime}=0} \sum_{\lambda_3^{\prime}=0}{A^{1}_{\lambda_0^{\prime}, \lambda_1^{\prime}, \lambda_2^{\prime}, \lambda_3^{\prime}} + A^{2}_{\lambda_0^{\prime}, \lambda_1^{\prime}, \lambda_2^{\prime}, \lambda_3^{\prime}} + A^{3}_{\lambda_0^{\prime}, \lambda_1^{\prime}, \lambda_2^{\prime}, \lambda_3^{\prime}}}}\right|^{2}\]
Hide code cell source
\[\begin{split}\begin{array}{rcl} A^{1}_{0, 0, 0, 0} &=& \sum_{\lambda_{R}=-1}^{1}{\delta_{0 \lambda_{R}} \mathcal{H}^\mathrm{decay}_{\pi_{1}(1400)^{0}, 0, 0} \mathcal{H}^\mathrm{production}_{\pi_{1}(1400)^{0}, \lambda_{R}, 0} d^{1}_{\lambda_{R},0}\left(\theta_{23}\right)} + \sum_{\lambda_{R}=0}{\delta_{0 \lambda_{R}} \mathcal{H}^\mathrm{decay}_{a_{0}(980)^{0}, 0, 0} \mathcal{H}^\mathrm{production}_{a_{0}(980)^{0}, \lambda_{R}, 0} d^{0}_{\lambda_{R},0}\left(\theta_{23}\right)} + \sum_{\lambda_{R}=0}{\delta_{0 \lambda_{R}} \mathcal{H}^\mathrm{decay}_{f_{0}(980), 0, 0} \mathcal{H}^\mathrm{production}_{f_{0}(980), \lambda_{R}, 0} d^{0}_{\lambda_{R},0}\left(\theta_{23}\right)} \\ A^{2}_{0, 0, 0, 0} &=& \sum_{\lambda_{R}=-1}^{1}{\delta_{0 \lambda_{R}} \mathcal{H}^\mathrm{decay}_{\pi_{1}(1400)^{+}, 0, 0} \mathcal{H}^\mathrm{production}_{\pi_{1}(1400)^{+}, \lambda_{R}, 0} d^{1}_{\lambda_{R},0}\left(\theta_{31}\right)} + \sum_{\lambda_{R}=0}{\delta_{0 \lambda_{R}} \mathcal{H}^\mathrm{decay}_{a_{0}(980)^{+}, 0, 0} \mathcal{H}^\mathrm{production}_{a_{0}(980)^{+}, \lambda_{R}, 0} d^{0}_{\lambda_{R},0}\left(\theta_{31}\right)} \\ A^{3}_{0, 0, 0, 0} &=& \sum_{\lambda_{R}=-1}^{1}{\delta_{0 \lambda_{R}} \mathcal{H}^\mathrm{decay}_{\pi_{1}(1400)^{-}, 0, 0} \mathcal{H}^\mathrm{production}_{\pi_{1}(1400)^{-}, \lambda_{R}, 0} d^{1}_{\lambda_{R},0}\left(\theta_{12}\right)} + \sum_{\lambda_{R}=0}{\delta_{0 \lambda_{R}} \mathcal{H}^\mathrm{decay}_{a_{0}(980)^{-}, 0, 0} \mathcal{H}^\mathrm{production}_{a_{0}(980)^{-}, \lambda_{R}, 0} d^{0}_{\lambda_{R},0}\left(\theta_{12}\right)} \\ \end{array}\end{split}\]

The isobar Wigner-\(d\) functions take the following helicity angles as arguments:

Hide code cell source
\[\begin{split}\begin{array}{rcl} \theta_{23} &=& \operatorname{acos}{\left(\frac{2 \sigma_{1} \left(- m_{1}^{2} - m_{2}^{2} + \sigma_{3}\right) - \left(m_{0}^{2} - m_{1}^{2} - \sigma_{1}\right) \left(m_{2}^{2} - m_{3}^{2} + \sigma_{1}\right)}{\sqrt{\lambda\left(m_{0}^{2}, m_{1}^{2}, \sigma_{1}\right)} \sqrt{\lambda\left(\sigma_{1}, m_{2}^{2}, m_{3}^{2}\right)}} \right)} \\ \theta_{31} &=& \operatorname{acos}{\left(\frac{2 \sigma_{2} \left(- m_{2}^{2} - m_{3}^{2} + \sigma_{1}\right) - \left(m_{0}^{2} - m_{2}^{2} - \sigma_{2}\right) \left(- m_{1}^{2} + m_{3}^{2} + \sigma_{2}\right)}{\sqrt{\lambda\left(m_{0}^{2}, m_{2}^{2}, \sigma_{2}\right)} \sqrt{\lambda\left(\sigma_{2}, m_{3}^{2}, m_{1}^{2}\right)}} \right)} \\ \theta_{12} &=& \operatorname{acos}{\left(\frac{2 \sigma_{3} \left(- m_{1}^{2} - m_{3}^{2} + \sigma_{2}\right) - \left(m_{0}^{2} - m_{3}^{2} - \sigma_{3}\right) \left(m_{1}^{2} - m_{2}^{2} + \sigma_{3}\right)}{\sqrt{\lambda\left(m_{0}^{2}, m_{3}^{2}, \sigma_{3}\right)} \sqrt{\lambda\left(\sigma_{3}, m_{1}^{2}, m_{2}^{2}\right)}} \right)} \\ \end{array}\end{split}\]

### AmpForm model

AmpForm does not formulate alignment Wigner-\(D\) functions. For this spinless final state no alignment is needed, so the AmpForm intensity should be the same as that of the DPD model.

Hide code cell source
model_builder = ampform.get_builder(REACTION)
# One coefficient C per decay chain instead of separate helicity couplings.
model_builder.use_helicity_couplings = False
builder_config = model_builder.config
builder_config.scalar_initial_state_mass = True
builder_config.stable_final_state_ids = [0, 1, 2]
AMPFORM_MODEL = model_builder.formulate()
AMPFORM_MODEL.intensity
\[\displaystyle \sum_{m_{A}=0} \sum_{m_{0}=0} \sum_{m_{1}=0} \sum_{m_{2}=0}{\left|{A^{01}_{m_{A}, m_{0}, m_{1}, m_{2}} + A^{02}_{m_{A}, m_{0}, m_{1}, m_{2}} + A^{12}_{m_{A}, m_{0}, m_{1}, m_{2}}}\right|^{2}}\]
Hide code cell source
\[\begin{split}\begin{array}{rcl} A^{01}_{0, 0, 0, 0} &=& C_{D^{0} \to K^{+}_{0} {\pi_{1}(1400)^{-}}_{0}; \pi_{1}(1400)^{-} \to K^{-}_{0} K^{0}_{0}} D^{0}_{0,0}\left(- \phi_{01},\theta_{01},0\right) D^{1}_{0,0}\left(- \phi^{01}_{0},\theta^{01}_{0},0\right) + C_{D^{0} \to K^{+}_{0} {a_{0}(980)^{-}}_{0}; a_{0}(980)^{-} \to K^{-}_{0} K^{0}_{0}} D^{0}_{0,0}\left(- \phi_{01},\theta_{01},0\right) D^{0}_{0,0}\left(- \phi^{01}_{0},\theta^{01}_{0},0\right) \\ A^{02}_{0, 0, 0, 0} &=& C_{D^{0} \to K^{-}_{0} {\pi_{1}(1400)^{+}}_{0}; \pi_{1}(1400)^{+} \to K^{+}_{0} K^{0}_{0}} D^{0}_{0,0}\left(- \phi_{02},\theta_{02},0\right) D^{1}_{0,0}\left(- \phi^{02}_{0},\theta^{02}_{0},0\right) + C_{D^{0} \to K^{-}_{0} {a_{0}(980)^{+}}_{0}; a_{0}(980)^{+} \to K^{+}_{0} K^{0}_{0}} D^{0}_{0,0}\left(- \phi_{02},\theta_{02},0\right) D^{0}_{0,0}\left(- \phi^{02}_{0},\theta^{02}_{0},0\right) \\ A^{12}_{0, 0, 0, 0} &=& C_{D^{0} \to K^{0}_{0} {\pi_{1}(1400)^{0}}_{0}; \pi_{1}(1400)^{0} \to K^{+}_{0} K^{-}_{0}} D^{0}_{0,0}\left(- \phi_{0},\theta_{0},0\right) D^{1}_{0,0}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right) + C_{D^{0} \to K^{0}_{0} {a_{0}(980)^{0}}_{0}; a_{0}(980)^{0} \to K^{+}_{0} K^{-}_{0}} D^{0}_{0,0}\left(- \phi_{0},\theta_{0},0\right) D^{0}_{0,0}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right) + C_{D^{0} \to K^{0}_{0} {f_{0}(980)}_{0}; f_{0}(980) \to K^{+}_{0} K^{-}_{0}} D^{0}_{0,0}\left(- \phi_{0},\theta_{0},0\right) D^{0}_{0,0}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right) \\ \end{array}\end{split}\]
Hide code cell source
\[\begin{split}\begin{array}{rcl} m_{01} &=& m_{{p}_{01}} \\ m_{02} &=& m_{{p}_{02}} \\ m_{12} &=& m_{{p}_{12}} \\ \phi_{0} &=& \phi\left({p}_{12}\right) \\ \phi^{01}_{0} &=& \phi\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{01}}\right|}{E\left({p}_{01}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{01}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{01}\right)\right) p_{0}\right) \\ \phi^{02}_{0} &=& \phi\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{02}}\right|}{E\left({p}_{02}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{02}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{02}\right)\right) p_{0}\right) \\ \phi_{01} &=& \phi\left({p}_{01}\right) \\ \phi^{12}_{1} &=& \phi\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{12}}\right|}{E\left({p}_{12}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{12}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{12}\right)\right) p_{1}\right) \\ \phi_{02} &=& \phi\left({p}_{02}\right) \\ \theta_{0} &=& \theta\left({p}_{12}\right) \\ \theta^{01}_{0} &=& \theta\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{01}}\right|}{E\left({p}_{01}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{01}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{01}\right)\right) p_{0}\right) \\ \theta^{02}_{0} &=& \theta\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{02}}\right|}{E\left({p}_{02}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{02}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{02}\right)\right) p_{0}\right) \\ \theta_{01} &=& \theta\left({p}_{01}\right) \\ \theta^{12}_{1} &=& \theta\left(\boldsymbol{B_z}\left(\frac{\left|\vec{{p}_{12}}\right|}{E\left({p}_{12}\right)}\right) \boldsymbol{R_y}\left(- \theta\left({p}_{12}\right)\right) \boldsymbol{R_z}\left(- \phi\left({p}_{12}\right)\right) p_{1}\right) \\ \theta_{02} &=& \theta\left({p}_{02}\right) \\ \end{array}\end{split}\]

## Phase space sample

Hide code cell source
# DPD particles 1, 2, 3 correspond to AmpForm's final-state momenta p0, p1, p2.
p1, p2, p3 = (FourMomentumSymbol(f"p{i}", shape=[]) for i in range(3))
s1, s2, s3 = sorted(DPD_MODEL.invariants, key=str)
mass_definitions = {
    **DPD_MODEL.masses,
    # DPD invariants: squared mass of the pair that excludes particle k
    s1: InvariantMass(p2 + p3) ** 2,
    s2: InvariantMass(p1 + p3) ** 2,
    s3: InvariantMass(p1 + p2) ** 2,
    # AmpForm two-body invariant masses
    sp.Symbol("m_01", nonnegative=True): InvariantMass(p1 + p2),
    sp.Symbol("m_02", nonnegative=True): InvariantMass(p1 + p3),
    sp.Symbol("m_12", nonnegative=True): InvariantMass(p2 + p3),
}
# Express every DPD kinematic variable in terms of the four-momenta.
dpd_variables = {
    variable: definition.doit()
    .xreplace(DPD_MODEL.variables)
    .xreplace(mass_definitions)
    for variable, definition in DPD_MODEL.variables.items()
}
dpd_transformer = SympyDataTransformer.from_sympy(dpd_variables, backend="jax")

ampform_transformer = SympyDataTransformer.from_sympy(
    AMPFORM_MODEL.kinematic_variables,
    backend="jax",
)
def generate_phase_space(
    reaction: ReactionInfo, size: int, seed: int = 0
) -> dict[str, jnp.ndarray]:
    """Generate a phase-space sample for ``reaction`` with TensorWaves.

    Args:
        reaction: Reaction whose initial-state mass (state ID ``-1``) and
            final-state masses (keyed by final-state ID) define the phase space.
        size: Number of events to generate.
        seed: Seed of the uniform random-number generator. Defaults to ``0``, so
            repeated calls with the same arguments yield the same sample.

    Returns:
        Whatever ``TFPhaseSpaceGenerator.generate`` returns (annotated here as a
        mapping of names to arrays — NOTE(review): confirm exact key names).
    """
    rng = TFUniformRealNumberGenerator(seed=seed)
    phsp_generator = TFPhaseSpaceGenerator(
        initial_state_mass=reaction.initial_state[-1].mass,
        final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
    )
    return phsp_generator.generate(size, rng)


phsp = generate_phase_space(AMPFORM_MODEL.reaction_info, size=10**5)
# Both models get their kinematic variables from the same four-momenta sample.
ampform_phsp = ampform_transformer(phsp)
dpd_phsp = dpd_transformer(phsp)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1714472689.973854    3254 service.cc:145] XLA service 0x7f4db418cab0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1714472689.973904    3254 service.cc:153]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1714472689.997112    3254 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

## Convert to numerical functions

def unfold_intensity(model: HelicityModel) -> sp.Expr:
    """Return the intensity of ``model`` with all amplitudes fully substituted.

    The intensity expression is unfolded with ``perform_cached_doit``. Each
    amplitude symbol in it is then replaced by the unfolded definition from
    ``model.amplitudes``. This is also applied to the DPD model, which provides
    the same ``intensity`` and ``amplitudes`` attributes.
    """
    intensity = perform_cached_doit(model.intensity)
    amplitude_definitions = {}
    for amplitude_symbol, amplitude_expr in model.amplitudes.items():
        amplitude_definitions[amplitude_symbol] = perform_cached_doit(amplitude_expr)
    return intensity.xreplace(amplitude_definitions)


# Fully unfolded symbolic intensities of both models.
ampform_intensity_expr = unfold_intensity(AMPFORM_MODEL)
dpd_intensity_expr = unfold_intensity(DPD_MODEL)

# Numerical (parametrized) versions, with each model's default parameter values.
ampform_func = perform_cached_lambdify(
    ampform_intensity_expr, parameters=AMPFORM_MODEL.parameter_defaults
)
dpd_func = perform_cached_lambdify(
    dpd_intensity_expr, parameters=DPD_MODEL.parameter_defaults
)

## Visualization

Hide code cell source
def compute_sub_intensities(
    func: ParametrizedFunction, phsp: DataSample, resonance_name: str | list[str]
) -> jnp.ndarray:
    """Evaluate ``func`` on ``phsp`` with only the selected resonances switched on.

    Couplings whose parameter name contains none of the given resonance names are
    temporarily set to zero (see ``_set_couplings_to_zero``). The parameters of
    ``func`` are restored afterwards, also when the evaluation raises.

    Args:
        func: Parametrized intensity function. Its parameters are modified
            temporarily.
        phsp: Data sample on which ``func`` is evaluated.
        resonance_name: One resonance name, or a list of names, as they appear
            inside the parameter names (here: LaTeX strings).

    Returns:
        The result of ``func(phsp)`` with the reduced set of couplings.
    """
    # A bare string would otherwise be iterated character by character, which
    # keeps nearly every coupling.
    if isinstance(resonance_name, str):
        resonance_names = [resonance_name]
    else:
        resonance_names = list(resonance_name)
    original_parameters = dict(func.parameters)
    try:
        _set_couplings_to_zero(func, resonance_names)
        return func(phsp)
    finally:
        func.update_parameters(original_parameters)


def _set_couplings_to_zero(
    func: ParametrizedFunction, resonance_names: list[str]
) -> None:
    """Switch off every coupling that does not belong to ``resonance_names``.

    A coupling keeps its current value if its parameter name contains any of the
    given names as a substring. All other couplings are set to ``0``. The
    parameters of ``func`` are updated in place.
    """
    new_values = {}
    for parameter_name, current_value in _get_couplings(func).items():
        is_selected = any(name in parameter_name for name in resonance_names)
        new_values[parameter_name] = current_value if is_selected else 0
    func.update_parameters(new_values)


def _get_couplings(func: ParametrizedFunction) -> dict[str, ParameterValue]:
    return {
        key: value
        for key, value in func.parameters.items()
        if key.startswith("C") or "production" in key
    }
Hide code cell source
def create_sliders() -> dict[str, ToggleButtons]:
    """Create one sign toggle (``-1``/``0``/``+1``) per model parameter.

    The parameters of the AmpForm and DPD models are merged. For identical
    symbols, the DPD entry is the one kept. The parameter default values are not
    used: every toggle starts at ``"+1"``, except for ``\\mathcal{H}^\\mathrm{decay}``
    couplings whose name contains ``"+"`` and one of ``"{1}"``, ``"*"`` or
    ``"rho"``, which start at ``"-1"``.

    Returns:
        Mapping of parameter (symbol) name to its ``ToggleButtons`` widget.
    """
    all_parameters = dict(AMPFORM_MODEL.parameter_defaults.items())
    all_parameters.update(DPD_MODEL.parameter_defaults)
    sliders = {}
    # Only the symbols are needed; the default values are deliberately ignored.
    for symbol in all_parameters:
        name = symbol.name
        starts_negative = (
            name.startswith(R"\mathcal{H}^\mathrm{decay}")
            and "+" in name
            and any(s in name for s in ["{1}", "*", "rho"])
        )
        sliders[name] = ToggleButtons(
            description=Rf"\({sp.latex(symbol)}\)",
            options=["-1", "0", "+1"],
            value="-1" if starts_negative else "+1",
            continuous_update=False,
        )
    return sliders


def to_unicode(particle: Particle) -> str:
    """Convert a particle name into a compact Unicode label for widget captions.

    Greek letter names become symbols, charges become superscripts, and ``(0)``
    and ``(1)`` become subscript digits, e.g. ``"pi(1)(1400)+"`` → ``"π₁(1400)⁺"``
    and ``"rho(770)0"`` → ``"ρ(770)⁰"``. All ``"~"`` are removed and every ``Σ``
    is prefixed with ``"~"`` (NOTE(review): this also marks non-anti Σ states).

    Only ``particle.name`` is used.
    """
    # Order matters: e.g. "~" must be stripped before "Σ" gets its prefix, and
    # "(0)" must be handled before the ")0" charge suffix.
    replacements = (
        ("pi", "π"),
        ("rho", "ρ"),  # Greek rho (previously mapped to a Latin "p" by mistake)
        ("Sigma", "Σ"),
        ("~", ""),
        ("Σ", "~Σ"),
        ("+", "⁺"),
        ("-", "⁻"),
        ("(0)", "₀"),
        ("(1)", "₁"),
        (")0", ")⁰"),
    )
    unicode = particle.name
    for old, new in replacements:
        unicode = unicode.replace(old, new)
    return unicode


sliders = create_sliders()
# Multi-selection of resonances; the first two are preselected.
resonance_selector = SelectMultiple(
    description="Resonance",
    options={to_unicode(p): p.latex for p in resonances},
    value=[resonances[0].latex, resonances[1].latex],
    layout=Layout(width="auto", height=f"{14 * (len(resonances) + 1)}pt"),
)
hide_expressions = Checkbox(description="Hide expressions", value=True)
simplify_expressions = Checkbox(description="Simplify", value=True)
# "Simplify" is disabled whenever expressions are hidden.
ipywidgets.link((hide_expressions, "value"), (simplify_expressions, "disabled"))

package_names = ("AmpForm", "AmpForm-DPD")
# One tab per resonance, each with one accordion section per package that lists
# the toggles of that package's parameters for the resonance.
slider_tabs = Tab(
    children=[
        Accordion(
            children=[
                GridBox([
                    sliders[key]
                    for key in sorted(sliders)
                    if p.latex in key
                    and (
                        key.startswith(R"\mathcal{H}")
                        if package != "AmpForm"
                        else key[0] in {"C", "H"}
                    )
                ])
                for package in package_names
            ],
            selected_index=1,
            titles=package_names,
        )
        for p in resonances
    ],
    titles=[to_unicode(p) for p in resonances],
)
ui = HBox([
    VBox([resonance_selector, hide_expressions, simplify_expressions]),
    slider_tabs,
])
Hide code cell source
%matplotlib widget
plt.rc("font", size=12)
fig, axes = plt.subplots(figsize=(16, 6), ncols=3, nrows=2)
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
(
    (ax_s1, ax_s2, ax_s3),
    (ax_t1, ax_t2, ax_t3),
) = axes
for ax in axes[:, 0].flatten():
    ax.set_ylabel("Intensity (a.u.)")
for ax in axes[:, 1:].flatten():
    ax.set_yticks([])

final_state = DECAY.final_state
ax_s1.set_xlabel(f"$m({final_state[2].latex}, {final_state[3].latex})$")
ax_s2.set_xlabel(f"$m({final_state[1].latex}, {final_state[3].latex})$")
ax_s3.set_xlabel(f"$m({final_state[1].latex}, {final_state[2].latex})$")
ax_t1.set_xlabel(Rf"$\theta({final_state[2].latex}, {final_state[3].latex})$")
ax_t2.set_xlabel(Rf"$\theta({final_state[1].latex}, {final_state[3].latex})$")
ax_t3.set_xlabel(Rf"$\theta({final_state[1].latex}, {final_state[2].latex})$")
fig.suptitle(f'Selected resonances: ${", ".join(resonance_selector.value)}$')
fig.tight_layout()

lines = None

__EXPRS: dict[int, sp.Expr] = {}


def _get_symbol_values(
    expr: sp.Expr,
    parameters: dict[str, ParameterValue],
    selected_resonances: list[str],
) -> dict[sp.Symbol, sp.Rational]:
    """Map the parameter symbols occurring in ``expr`` to exact rational values.

    Only free symbols of ``expr`` whose name is a key of ``parameters`` are
    included. A parameter keeps its value if its name contains any of
    ``selected_resonances``; otherwise it is mapped to zero.
    """

    def effective_value(name: str):
        if any(resonance in name for resonance in selected_resonances):
            return parameters[name]
        return 0

    symbol_values = {}
    for symbol in expr.free_symbols:
        if symbol.name not in parameters:
            continue
        symbol_values[symbol] = sp.Rational(effective_value(symbol.name))
    return symbol_values


def _simplify(
    intensity_expr: sp.Expr,
    parameter_defaults: dict[str, ParameterValue],
    variables: dict[sp.Symbol, sp.Expr],
    selected_resonances: list[str],
) -> sp.Expr:
    """Substitute fixed values into ``intensity_expr`` and simplify it (memoized).

    Parameters are replaced by rationals, with unselected resonances set to zero.
    Variables whose definition has no free symbols are substituted as constants.
    The result is passed through ``sp.trigsimp``. Results are cached in
    ``__EXPRS`` under a readable hash of the inputs, so repeated widget updates
    with the same state reuse them.
    """
    parameters = _get_symbol_values(
        intensity_expr, parameter_defaults, selected_resonances
    )
    fixed_variables = {
        symbol: definition
        for symbol, definition in variables.items()
        if not definition.free_symbols
    }

    def sorted_items(mapping):
        return tuple((key, mapping[key]) for key in sorted(mapping, key=str))

    cache_key = get_readable_hash((
        intensity_expr,
        sorted_items(fixed_variables),
        sorted_items(parameters),
    ))
    if cache_key not in __EXPRS:
        substituted = intensity_expr.xreplace(parameters).xreplace(fixed_variables)
        __EXPRS[cache_key] = sp.trigsimp(substituted)
    return __EXPRS[cache_key]


def _weighted_histogram(
    values: jnp.ndarray, edges: jnp.ndarray, weights: jnp.ndarray
) -> jnp.ndarray:
    """Return the bin contents of the real part of ``values``, weighted by ``weights``."""
    bin_contents, _ = jnp.histogram(values.real, bins=edges, weights=weights)
    return bin_contents


def _upper_ylim(y_max) -> float:
    """Return a y-axis upper limit with 5% headroom; ``1.0`` if there is no signal."""
    y_max = float(y_max)
    if not y_max > 0:  # all-zero (or all-NaN) histograms, e.g. nothing selected
        return 1.0
    return 1.05 * y_max


def plot_contributions(**kwargs) -> None:
    """Redraw all histograms (and optionally the expressions) for the widget state.

    Called through ``interactive_output`` with one keyword argument per control.
    The resonance selector and the checkboxes are read from their widgets
    directly. Every other keyword is a toggle value (``"-1"``, ``"0"``, ``"+1"``)
    keyed by parameter name.

    The AmpForm and DPD intensities are histogrammed in the three invariant
    masses (top row) and the three DPD helicity angles (bottom row), so that
    both models can be compared panel by panel. Unless hidden, the intensity
    expressions of both models (optionally simplified) are displayed as well.
    """
    global lines
    kwargs.pop("resonance_selector")
    kwargs.pop("hide_expressions")
    kwargs.pop("simplify_expressions")
    selected_resonances = list(resonance_selector.value)
    fig.suptitle(f'Selected resonances: ${", ".join(selected_resonances)}$')
    dpd_pars = {k: int(v) for k, v in kwargs.items() if k in dpd_func.parameters}
    ampform_pars = {
        k: int(v) for k, v in kwargs.items() if k in ampform_func.parameters
    }
    ampform_func.update_parameters(ampform_pars)
    dpd_func.update_parameters(dpd_pars)
    ampform_intensities = compute_sub_intensities(
        ampform_func, ampform_phsp, selected_resonances
    )
    dpd_intensities = compute_sub_intensities(dpd_func, dpd_phsp, selected_resonances)

    s_edges = jnp.linspace(0.98, 1.38, num=50)
    t_edges = jnp.linspace(0, jnp.pi, num=50)
    # One entry per panel as (axis, data, bin edges), in the order of `lines`.
    # Each angle panel uses the same particle pair as the mass panel above it.
    panels = [
        (ax_s1, ampform_phsp["m_12"], s_edges),
        (ax_s2, ampform_phsp["m_02"], s_edges),
        (ax_s3, ampform_phsp["m_01"], s_edges),
        (ax_t1, dpd_phsp["theta_23"], t_edges),
        (ax_t2, dpd_phsp["theta_31"], t_edges),
        (ax_t3, dpd_phsp["theta_12"], t_edges),
    ]
    histograms = [
        (
            _weighted_histogram(data, edges, ampform_intensities),
            _weighted_histogram(data, edges, dpd_intensities),
        )
        for _, data, edges in panels
    ]

    amp_kwargs = dict(color="r", label="ampform", linestyle="solid")
    dpd_kwargs = dict(color="blue", label="dpd", linestyle="dotted")
    if lines is None:
        new_lines = []
        for (ax, _, edges), (amp_values, dpd_values) in zip(panels, histograms):
            bin_centers = (edges[:-1] + edges[1:]) / 2
            new_lines.append(ax.step(bin_centers, amp_values, **amp_kwargs)[0])
            new_lines.append(ax.step(bin_centers, dpd_values, **dpd_kwargs)[0])
        lines = new_lines
        ax_s1.legend(loc="upper right")
    else:
        for i, (amp_values, dpd_values) in enumerate(histograms):
            lines[2 * i].set_ydata(amp_values)
            lines[2 * i + 1].set_ydata(dpd_values)

    # Each row shares one y-range so the panels are directly comparable.
    sy_max = max(jnp.nanmax(values) for pair in histograms[:3] for values in pair)
    ty_max = max(jnp.nanmax(values) for pair in histograms[3:] for values in pair)
    for ax in axes[0]:
        ax.set_ylim(0, _upper_ylim(sy_max))
    for ax in axes[1]:
        ax.set_ylim(0, _upper_ylim(ty_max))

    fig.canvas.draw_idle()

    if hide_expressions.value:
        clear_output()
    else:
        if simplify_expressions.value:
            ampform_expr = _simplify(
                ampform_intensity_expr,
                ampform_pars,
                AMPFORM_MODEL.kinematic_variables,
                selected_resonances,
            )
            dpd_expr = _simplify(
                dpd_intensity_expr,
                dpd_pars,
                DPD_MODEL.variables,
                selected_resonances,
            )
        else:
            ampform_expr = ampform_intensity_expr
            dpd_expr = dpd_intensity_expr
        src = Rf"""
        \begin{{eqnarray}}
          \text{{AmpForm:}} && {sp.latex(ampform_expr)} \\
          \text{{DPD:}}     && {sp.latex(dpd_expr)} \\
        \end{{eqnarray}}
        """
        src = dedent(src).strip()
        display(Latex(src))


# Re-run plot_contributions whenever any toggle, the selector, or a checkbox
# changes; the controls are passed to it as keyword arguments.
output = interactive_output(
    plot_contributions,
    controls=dict(
        sliders,
        resonance_selector=resonance_selector,
        hide_expressions=hide_expressions,
        simplify_expressions=simplify_expressions,
    ),
)
display(output, ui)