J/ψ → π⁰ π⁺ π⁻

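The decay \(J/\psi \to \pi^0 \pi^+ \pi^-\) has an initial state with spin but a spinless final state. It is a follow-up to D⁰ → K⁰ K⁺ K⁻, where there were no alignment Wigner-\(D\) functions. Because the final-state particles here are spinless as well, there are again no alignment Wigner-\(D\) functions, but the spin of the initial state introduces an isobar Wigner-\(d\) function (see below).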

Hide code cell content
from __future__ import annotations

import logging
import os
import warnings
from textwrap import dedent
from typing import TYPE_CHECKING

import ampform
import graphviz
import ipywidgets
import jax.numpy as jnp
import matplotlib.pyplot as plt
import qrules
import sympy as sp
from ampform.kinematics.lorentz import FourMomentumSymbol, InvariantMass
from ampform.sympy import perform_cached_doit
from IPython.display import Latex, Markdown, clear_output, display
from ipywidgets import (
    Accordion,
    Checkbox,
    GridBox,
    HBox,
    Layout,
    SelectMultiple,
    Tab,
    ToggleButtons,
    VBox,
    interactive_output,
)
from tensorwaves.data.phasespace import TFPhaseSpaceGenerator
from tensorwaves.data.rng import TFUniformRealNumberGenerator
from tensorwaves.data.transform import SympyDataTransformer

from ampform_dpd import DalitzPlotDecompositionBuilder
from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay
from ampform_dpd.decay import Particle
from ampform_dpd.io import (
    as_markdown_table,
    aslatex,
    get_readable_hash,
    perform_cached_lambdify,
    simplify_latex_rendering,
)

if TYPE_CHECKING:
    from ampform.helicity import HelicityModel
    from qrules.transition import ReactionInfo
    from tensorwaves.interface import DataSample, ParameterValue, ParametrizedFunction

# Render SymPy expressions with the compact LaTeX notation of ampform-dpd.
simplify_latex_rendering()

# Silence noisy third-party output: JAX log messages, TensorFlow's C++ logging
# and runtime warnings.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is set after the tensorwaves imports; if
# those already initialised TensorFlow, this may come too late — confirm.
logging.getLogger("jax").setLevel(logging.ERROR)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.simplefilter("ignore", category=RuntimeWarning)

# When the notebook is executed non-interactively (EXECUTE_NB is set), also
# mute the log output of the caching helpers (presumably their progress bars,
# given the flag name).
NO_TQDM = "EXECUTE_NB" in os.environ
if NO_TQDM:
    for _logger_name in ("ampform.sympy", "ampform_dpd.io"):
        logging.getLogger(_logger_name).setLevel(logging.ERROR)

Decay definition

Hide code cell source
# All J/psi -> pi0 pi- pi+ transitions in the helicity formalism that proceed
# via the a0(980) or rho(770) isobars. A mass conservation factor of 0 allows
# these resonances regardless of their mass relative to the available energy.
REACTION = qrules.generate_transitions(
    formalism="helicity",
    initial_state="J/psi(1S)",
    final_state=["pi0", "pi-", "pi+"],
    allowed_intermediate_particles=["a(0)(980)", "rho(770)"],
    mass_conservation_factor=0,
)
# Copy of the reaction with normalized state IDs, as needed for the
# three-body (DPD) description below.
REACTION123 = normalize_state_ids(REACTION)
dot = qrules.io.asdot(REACTION123, collapse_graphs=True)
graphviz.Source(dot)
Hide code cell source
# Three-body decay description built from the normalized transitions; the
# table lists the initial state followed by the three final-state particles.
DECAY = to_three_body_decay(REACTION123.transitions, min_ls=True)
Markdown(
    as_markdown_table([DECAY.initial_state] + list(DECAY.final_state.values()))
)
Hide code cell source
# Unique isobars appearing in the decay chains, ordered by the first letter of
# their name and then by mass.
resonances = sorted(
    {chain.resonance for chain in DECAY.chains},
    key=lambda particle: (particle.name[0], particle.mass),
)
resonance_names = [particle.name for particle in resonances]
Markdown(as_markdown_table(resonances))
Hide code cell source
# Display the decay as LaTeX; with_jp=True presumably adds the J^P quantum
# numbers of the particles (confirm against ampform_dpd.io.aslatex).
Latex(aslatex(DECAY, with_jp=True))

Model formulation

DPD model

Hide code cell source
# Formulate the Dalitz-plot decomposition with subsystem 1 as the reference
# subsystem. The builder is only needed for this one call, so it is not bound
# to a name.
DPD_MODEL = DalitzPlotDecompositionBuilder(DECAY, min_ls=True).formulate(
    reference_subsystem=1
)
DPD_MODEL.intensity.cleanup()
Hide code cell source

There is an isobar Wigner-\(d\) function, which takes the following helicity angles as arguments:

Hide code cell source

AmpForm model

Hide code cell source
# The same reaction formulated with AmpForm's helicity formalism, for comparison
# with the DPD model above.
model_builder = ampform.get_builder(REACTION)
model_builder.use_helicity_couplings = False
# Treat the initial-state mass as a plain scalar symbol and keep all three
# final-state particles stable.
model_builder.config.scalar_initial_state_mass = True
model_builder.config.stable_final_state_ids = list(range(3))
AMPFORM_MODEL = model_builder.formulate()
AMPFORM_MODEL.intensity
Hide code cell source
Hide code cell source

Phase-space sample

Hide code cell source
# Four-momentum symbols p0, p1, p2 (presumably the keys of the generated
# phase-space sample), bound to the 1-based names p1, p2, p3 used by DPD.
p1, p2, p3 = (FourMomentumSymbol(f"p{index}", shape=[]) for index in range(3))
s1, s2, s3 = sorted(DPD_MODEL.invariants, key=str)

# Express the DPD invariants s1, s2, s3 (squared pair masses) and the
# AmpForm-style pair masses m_01, m_02, m_12 in terms of the four-momenta.
mass_definitions = dict(DPD_MODEL.masses)
mass_definitions.update({
    s1: InvariantMass(p2 + p3) ** 2,
    s2: InvariantMass(p1 + p3) ** 2,
    s3: InvariantMass(p1 + p2) ** 2,
    sp.Symbol("m_01", nonnegative=True): InvariantMass(p1 + p2),
    sp.Symbol("m_02", nonnegative=True): InvariantMass(p1 + p3),
    sp.Symbol("m_12", nonnegative=True): InvariantMass(p2 + p3),
})

# Rewrite every DPD kinematic variable so that it depends on four-momenta only.
dpd_variables = {
    variable: definition.doit()
    .xreplace(DPD_MODEL.variables)
    .xreplace(mass_definitions)
    for variable, definition in DPD_MODEL.variables.items()
}
dpd_transformer = SympyDataTransformer.from_sympy(dpd_variables, backend="jax")
ampform_transformer = SympyDataTransformer.from_sympy(
    AMPFORM_MODEL.kinematic_variables, backend="jax"
)
def generate_phase_space(reaction: ReactionInfo, size: int) -> dict[str, jnp.ndarray]:
    """Generate a phase-space sample of *size* events for *reaction*.

    The random-number generator is seeded with 0, so repeated calls give the
    same sample. The initial-state mass is taken from state ID ``-1`` and the
    final-state masses from ``reaction.final_state``.
    """
    random_generator = TFUniformRealNumberGenerator(seed=0)
    masses = {
        state_id: particle.mass for state_id, particle in reaction.final_state.items()
    }
    generator = TFPhaseSpaceGenerator(
        initial_state_mass=reaction.initial_state[-1].mass,
        final_state_masses=masses,
    )
    return generator.generate(size, random_generator)


# One shared sample of four-momenta, converted to the kinematic variables that
# each model expects.
phsp = generate_phase_space(AMPFORM_MODEL.reaction_info, size=100_000)
ampform_phsp, dpd_phsp = ampform_transformer(phsp), dpd_transformer(phsp)

Convert to numerical functions

def unfold_intensity(model: HelicityModel) -> sp.Expr:
    """Return the fully evaluated intensity expression of *model*.

    The intensity and each amplitude are evaluated with the cached ``doit``.
    The evaluated amplitudes are then substituted for their symbols in the
    intensity. Any object exposing ``intensity`` and an ``amplitudes`` mapping
    works; the notebook also passes the DPD model.
    """
    intensity = perform_cached_doit(model.intensity)
    amplitude_definitions = {}
    for amplitude, definition in model.amplitudes.items():
        amplitude_definitions[amplitude] = perform_cached_doit(definition)
    return intensity.xreplace(amplitude_definitions)


# Fully evaluated intensity expressions of both models...
ampform_intensity_expr = unfold_intensity(AMPFORM_MODEL)
dpd_intensity_expr = unfold_intensity(DPD_MODEL)

# ...and their numerical versions, with the model parameters exposed as
# adjustable parameters (starting from the model defaults).
ampform_func = perform_cached_lambdify(
    ampform_intensity_expr, parameters=AMPFORM_MODEL.parameter_defaults
)
dpd_func = perform_cached_lambdify(
    dpd_intensity_expr, parameters=DPD_MODEL.parameter_defaults
)

Visualization

Hide code cell source
def compute_sub_intensities(
    func: ParametrizedFunction,
    phsp: DataSample,
    resonance_name: str | list[str],
) -> jnp.ndarray:
    """Evaluate *func* on *phsp* with only the given resonances switched on.

    The couplings of all other resonances are temporarily set to zero. The
    original parameter values are restored afterwards, even if the evaluation
    raises.

    Args:
        func: Parametrized intensity function. Its parameters are modified in
            place during the call and restored before returning.
        phsp: Kinematic data sample to evaluate *func* on.
        resonance_name: Resonance identifier(s) to keep. The notebook passes
            the selected LaTeX strings as a list. A single string is accepted
            too and treated as a one-element list, not as a sequence of
            characters.

    Returns:
        The intensity array computed by *func*.
    """
    resonance_names = (
        [resonance_name] if isinstance(resonance_name, str) else list(resonance_name)
    )
    original_parameters = dict(func.parameters)
    try:
        _set_couplings_to_zero(func, resonance_names)
        return func(phsp)
    finally:
        # Always undo the temporary coupling changes so that later calls see
        # the user's slider values again.
        func.update_parameters(original_parameters)


def _set_couplings_to_zero(
    func: ParametrizedFunction, resonance_names: list[str]
) -> None:
    """Zero every coupling of *func* whose name mentions none of *resonance_names*.

    Couplings that do mention a selected resonance are written back with their
    current value. Non-coupling parameters are not touched.
    """
    new_values = {}
    for name, current in _get_couplings(func).items():
        is_selected = any(resonance in name for resonance in resonance_names)
        new_values[name] = current if is_selected else 0
    func.update_parameters(new_values)


def _get_couplings(func: ParametrizedFunction) -> dict[str, ParameterValue]:
    return {
        key: value
        for key, value in func.parameters.items()
        if key.startswith("C") or "production" in key
    }
Hide code cell source
def create_sliders() -> dict[str, ToggleButtons]:
    """Create one -1/0/+1 toggle per parameter of the AmpForm and DPD models.

    The sliders are keyed by parameter (symbol) name; DPD defaults override
    AmpForm defaults for shared names. Every toggle starts at ``"+1"``, except
    DPD decay couplings with ``"+"`` in their name that also contain
    ``"{1}"``, ``"*"`` or ``"rho"``, which start at ``"-1"``.
    NOTE(review): presumably these signs make the default AmpForm and DPD
    intensities agree — confirm.
    """
    parameters = {**AMPFORM_MODEL.parameter_defaults, **DPD_MODEL.parameter_defaults}
    sliders = {}
    for symbol in parameters:
        name = symbol.name
        flip_sign = (
            name.startswith(R"\mathcal{H}^\mathrm{decay}")
            and "+" in name
            and any(s in name for s in ["{1}", "*", "rho"])
        )
        sliders[name] = ToggleButtons(
            description=Rf"\({sp.latex(symbol)}\)",
            options=["-1", "0", "+1"],
            value="-1" if flip_sign else "+1",
            continuous_update=False,
        )
    return sliders


def to_unicode(particle: Particle) -> str:
    """Return a plain-text (Unicode) label for *particle*.

    Used for the resonance selector options and the tab titles. The name is
    rewritten by a fixed sequence of substring replacements: Greek letters,
    then charge signs as superscripts, ``(0)``/``(1)`` as subscripts, and a
    trailing neutral ``)0`` as ``)⁰``.
    """
    unicode = particle.name
    unicode = unicode.replace("pi", "π")
    unicode = unicode.replace("rho", "ρ")  # Greek rho, not a Latin "p" (proton)
    unicode = unicode.replace("Sigma", "Σ")
    # Drop every tilde, then put one in front of each Σ.
    # NOTE(review): this marks every Σ with a tilde, even one whose name had
    # no tilde; only correct if Σ names here always denote anti-particles.
    unicode = unicode.replace("~", "")
    unicode = unicode.replace("Σ", "~Σ")
    unicode = unicode.replace("+", "⁺")
    unicode = unicode.replace("-", "⁻")
    unicode = unicode.replace("(0)", "₀")
    unicode = unicode.replace("(1)", "₁")
    return unicode.replace(")0", ")⁰")


sliders = create_sliders()
resonance_selector = SelectMultiple(
    description="Resonance",
    # Show Unicode labels, but use the LaTeX strings as the selected values
    # (they are matched against parameter names when zeroing couplings).
    options={to_unicode(p): p.latex for p in resonances},
    # Pre-select (at most) the first two resonances. Slicing avoids an
    # IndexError when the decay has fewer than two resonances.
    value=[p.latex for p in resonances[:2]],
    layout=Layout(
        height=f"{14 * (len(resonances) + 1)}pt",
        width="auto",
    ),
)
hide_expressions = Checkbox(description="Hide expressions", value=True)
simplify_expressions = Checkbox(description="Simplify", value=True)
# "Simplify" only matters when expressions are shown, so disable it while
# "Hide expressions" is ticked.
ipywidgets.link((hide_expressions, "value"), (simplify_expressions, "disabled"))

# Layout: the resonance selector and checkboxes on the left, and one tab per
# resonance on the right. Each tab holds an accordion with one grid of
# coupling sliders per package.
package_names = ("AmpForm", "AmpForm-DPD")
ui = HBox([
    VBox([resonance_selector, hide_expressions, simplify_expressions]),
    Tab(
        children=[
            Accordion(
                children=[
                    GridBox([
                        sliders[key]
                        for key in sorted(sliders)
                        # only sliders whose parameter name mentions resonance p
                        if p.latex in key
                        # AmpForm section: names starting with "C" or "H";
                        # AmpForm-DPD section: names starting with \mathcal{H}
                        if (
                            key[0] in {"C", "H"}
                            if package == "AmpForm"
                            else key.startswith(R"\mathcal{H}")
                        )
                    ])
                    for package in package_names
                ],
                selected_index=1,  # open the AmpForm-DPD section by default
                titles=package_names,
            )
            for p in resonances
        ],
        titles=[to_unicode(p) for p in resonances],
    ),
])
Hide code cell source
%matplotlib widget
plt.rc("font", size=12)
fig, axes = plt.subplots(figsize=(16, 6), ncols=3, nrows=2)
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
(
    (ax_s1, ax_s2, ax_s3),
    (ax_t1, ax_t2, ax_t3),
) = axes
for ax in axes[:, 0].flatten():
    ax.set_ylabel("Intensity (a.u.)")
for ax in axes[:, 1:].flatten():
    ax.set_yticks([])

final_state = DECAY.final_state
ax_s1.set_xlabel(f"$m({final_state[2].latex}, {final_state[3].latex})$")
ax_s2.set_xlabel(f"$m({final_state[1].latex}, {final_state[3].latex})$")
ax_s3.set_xlabel(f"$m({final_state[1].latex}, {final_state[2].latex})$")
ax_t1.set_xlabel(Rf"$\theta({final_state[2].latex}, {final_state[3].latex})$")
ax_t2.set_xlabel(Rf"$\theta({final_state[1].latex}, {final_state[3].latex})$")
ax_t3.set_xlabel(Rf"$\theta({final_state[1].latex}, {final_state[2].latex})$")
fig.suptitle(f"Selected resonances: ${', '.join(resonance_selector.value)}$")
fig.tight_layout()

lines = None

__EXPRS: dict[int, sp.Expr] = {}


def _get_symbol_values(
    expr: sp.Expr,
    parameters: dict[str, ParameterValue],
    selected_resonances: list[str],
) -> dict[sp.Symbol, sp.Rational]:
    """Map the free symbols of *expr* to rational parameter values.

    Only symbols whose name is a key of *parameters* are included. A
    parameter whose name mentions none of *selected_resonances* gets the
    value zero; otherwise its value from *parameters* is used.
    """
    values = {}
    for symbol in expr.free_symbols:
        if symbol.name not in parameters:
            continue
        is_selected = any(r in symbol.name for r in selected_resonances)
        values[symbol] = sp.Rational(parameters[symbol.name] if is_selected else 0)
    return values


def _simplify(
    intensity_expr: sp.Expr,
    parameter_defaults: dict[str, ParameterValue],
    variables: dict[sp.Symbol, sp.Expr],
    selected_resonances: list[str],
) -> sp.Expr:
    """Insert parameter values and constant kinematic variables, then trigsimp.

    Parameters of unselected resonances are set to zero (see
    ``_get_symbol_values``). Only the *variables* whose definitions have no
    free symbols are substituted. Results are memoized in the module-level
    ``__EXPRS`` dict, keyed by a hash of the expression and all substitutions.
    """

    def _sorted_items(mapping):
        # Deterministic ordering so that the hash is reproducible.
        return tuple(sorted(mapping.items(), key=lambda item: str(item[0])))

    substitutions = _get_symbol_values(
        intensity_expr, parameter_defaults, selected_resonances
    )
    constants = {
        symbol: definition
        for symbol, definition in variables.items()
        if not definition.free_symbols
    }
    cache_key = get_readable_hash((
        intensity_expr,
        _sorted_items(constants),
        _sorted_items(substitutions),
    ))
    if cache_key not in __EXPRS:
        substituted = intensity_expr.xreplace(substitutions).xreplace(constants)
        __EXPRS[cache_key] = sp.trigsimp(substituted)
    return __EXPRS[cache_key]


def _weighted_histograms(
    sample: jnp.ndarray,
    edges: jnp.ndarray,
    ampform_weights: jnp.ndarray,
    dpd_weights: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Histogram *sample* twice: weighted by the AmpForm and by the DPD intensities."""
    amp_values, _ = jnp.histogram(sample, bins=edges, weights=ampform_weights)
    dpd_values, _ = jnp.histogram(sample, bins=edges, weights=dpd_weights)
    return amp_values, dpd_values


def _display_expressions(
    ampform_pars: dict[str, int],
    dpd_pars: dict[str, int],
    selected_resonances: list[str],
) -> None:
    """Render the AmpForm and DPD intensity expressions side by side as LaTeX.

    If the "Simplify" checkbox is ticked, the expressions are first simplified
    with the current slider values and selected resonances (see ``_simplify``).
    """
    if simplify_expressions.value:
        ampform_expr = _simplify(
            ampform_intensity_expr,
            ampform_pars,
            AMPFORM_MODEL.kinematic_variables,
            selected_resonances,
        )
        dpd_expr = _simplify(
            dpd_intensity_expr,
            dpd_pars,
            DPD_MODEL.variables,
            selected_resonances,
        )
    else:
        ampform_expr = ampform_intensity_expr
        dpd_expr = dpd_intensity_expr
    src = Rf"""
    \begin{{eqnarray}}
      \text{{AmpForm:}} && {sp.latex(ampform_expr)} \\
      \text{{DPD:}}     && {sp.latex(dpd_expr)} \\
    \end{{eqnarray}}
    """
    src = dedent(src).strip()
    display(Latex(src))


def plot_contributions(**kwargs) -> None:
    """Redraw the AmpForm vs. DPD comparison plots for the current widget state.

    Called by ``interactive_output`` whenever a control changes. *kwargs* maps
    every slider name to its ``"-1"``/``"0"``/``"+1"`` value. It also carries
    the three non-slider controls, which are dropped because their state is
    read directly from the widgets.
    """
    global lines
    kwargs.pop("resonance_selector")
    kwargs.pop("hide_expressions")
    kwargs.pop("simplify_expressions")
    selected_resonances = list(resonance_selector.value)
    fig.suptitle(f"Selected resonances: ${', '.join(selected_resonances)}$")
    dpd_pars = {k: int(v) for k, v in kwargs.items() if k in dpd_func.parameters}
    ampform_pars = {
        k: int(v) for k, v in kwargs.items() if k in ampform_func.parameters
    }
    ampform_func.update_parameters(ampform_pars)
    dpd_func.update_parameters(dpd_pars)
    ampform_intensities = compute_sub_intensities(
        ampform_func, ampform_phsp, selected_resonances
    )
    dpd_intensities = compute_sub_intensities(dpd_func, dpd_phsp, selected_resonances)
    weights = (ampform_intensities, dpd_intensities)

    # (AmpForm, DPD) histogram pairs per column. The kinematic keys must match
    # the particle pairs in the x-labels: column k shows the pair without
    # particle k, i.e. (2,3), (1,3), (1,2) in 1-based / (1,2), (0,2), (0,1)
    # in 0-based numbering.
    s_edges = jnp.linspace(0.2, 3.0, num=50)
    t_edges = jnp.linspace(0, jnp.pi, num=50)
    s_values = [
        _weighted_histograms(ampform_phsp[key].real, s_edges, *weights)
        for key in ("m_12", "m_02", "m_01")
    ]
    t_values = [
        _weighted_histograms(dpd_phsp[key].real, t_edges, *weights)
        for key in ("theta_23", "theta_31", "theta_12")
    ]

    amp_kwargs = dict(color="r", label="ampform", linestyle="solid")
    dpd_kwargs = dict(color="blue", label="dpd", linestyle="dotted")
    if lines is None:
        sx = (s_edges[:-1] + s_edges[1:]) / 2
        tx = (t_edges[:-1] + t_edges[1:]) / 2
        new_lines = []
        for ax, x, (amp_values, dpd_values) in zip(
            (ax_s1, ax_s2, ax_s3, ax_t1, ax_t2, ax_t3),
            (sx, sx, sx, tx, tx, tx),
            (*s_values, *t_values),
        ):
            new_lines.append(ax.step(x, amp_values, **amp_kwargs)[0])
            new_lines.append(ax.step(x, dpd_values, **dpd_kwargs)[0])
        # Only publish the handles once all twelve lines exist.
        lines = new_lines
        ax_s1.legend(loc="upper right")
    else:
        y_data = [values for pair in (*s_values, *t_values) for values in pair]
        for line, values in zip(lines, y_data):
            line.set_ydata(values)

    # Shared y-range per row so that the three panels are directly comparable.
    sy_max = max(jnp.nanmax(values) for pair in s_values for values in pair)
    ty_max = max(jnp.nanmax(values) for pair in t_values for values in pair)
    for ax in axes[0]:
        ax.set_ylim(0, 1.05 * sy_max)
    for ax in axes[1]:
        ax.set_ylim(0, 1.05 * ty_max)

    fig.canvas.draw_idle()

    if hide_expressions.value:
        clear_output()
    else:
        _display_expressions(ampform_pars, dpd_pars, selected_resonances)


# Re-run plot_contributions whenever any slider or control changes. The three
# non-slider controls are passed too, so that they also trigger a redraw.
output = interactive_output(
    plot_contributions,
    controls=dict(
        sliders,
        resonance_selector=resonance_selector,
        hide_expressions=hide_expressions,
        simplify_expressions=simplify_expressions,
    ),
)
display(output, ui)