7.9. PC and PV currents#

Hide code cell content
from __future__ import annotations

import itertools
import re
from collections.abc import Iterable
from functools import cache
from pathlib import Path

import cloudpickle
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
import sympy as sp
import unicodeitplus
from ampform.io import aslatex
from ampform.sympy._cache import cache_to_disk
from ampform_dpd import AmplitudeModel  # pyright:ignore[reportPrivateUsage]
from ampform_dpd.decay import IsobarNode, ThreeBodyDecayChain, to_particle
from ampform_dpd.io import cached
from IPython.display import HTML, Markdown, Math, display
from tensorwaves.interface import ParametrizedFunction
from tqdm.auto import tqdm

from polarimetry.data import create_data_transformer, generate_phasespace_sample
from polarimetry.function import integrate_intensity
from polarimetry.io import mute_jax_warnings
from polarimetry.lhcb import load_model_builder, load_model_parameters
from polarimetry.lhcb.particle import load_particles

# Inject the MathJax 2.7.1 loader script into the notebook output; presumably
# needed so that the LaTeX produced by `Math(...)` cells below renders in the
# HTML build of this page.
src = '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
display(HTML(src))
mute_jax_warnings()

# Input definition files, given as paths relative to the working directory.
# NOTE(review): assumes the notebook is executed from its own directory.
MODEL_FILE = "../../data/model-definitions.yaml"
PARTICLES = load_particles("../../data/particle-definitions.yaml")

7.9.1. Model formulation#

Hide code cell source
def formulate_model(title: str) -> AmplitudeModel:
    """Build the amplitude model called ``title`` from ``MODEL_FILE``.

    The symbolic model is formulated by the model builder, after which the
    parameter values imported from the same definition file are written into
    its ``parameter_defaults``.
    """
    model_builder = load_model_builder(MODEL_FILE, PARTICLES, title)
    parameter_values = load_model_parameters(
        MODEL_FILE, model_builder.decay, title, PARTICLES
    )
    amplitude_model = model_builder.formulate()
    amplitude_model.parameter_defaults.update(parameter_values)
    return amplitude_model


# The two models compared on this page, selected by their title in MODEL_FILE:
# the model with LS couplings (canonical basis) and the default model
# (helicity basis).
CANO_MODEL = formulate_model("Alternative amplitude model obtained using LS couplings")
HELI_MODEL = formulate_model("Default amplitude model")
Hide code cell source
@cache_to_disk(
    dependencies=[
        "ampform",
        "ampform_dpd",
        "cloudpickle",
        "jax",
        "lc2pkpi-polarimetry",
        "sympy",
    ],
    dump_function=cloudpickle.dump,
)
def lambdify_ls_and_default_model() -> tuple[
    ParametrizedFunction, ParametrizedFunction
]:
    """Lambdify the default (helicity) model and the LS-coupling model.

    The result is cached to disk with ``cache_to_disk`` (listing the packages
    above as dependencies) and serialized with ``cloudpickle``.

    Returns:
        ``(default_model_function, ls_model_function)``. Note that, despite the
        function name, the default model comes first.
    """
    return lambdify(HELI_MODEL), lambdify(CANO_MODEL)


def lambdify(model: AmplitudeModel) -> ParametrizedFunction:
    """Create a numerical intensity function for ``model``.

    Only the parameters whose name contains ``"production"`` (the production
    couplings) stay free parameters of the returned function. All other
    parameter defaults are substituted into the unfolded intensity expression
    before lambdifying.

    Returns:
        The parametrized function created by ``cached.lambdify``. It is not a
        SymPy expression, which the previous annotation claimed.
    """
    intensity_expr = cached.unfold(model)
    pars = model.parameter_defaults
    free_parameters = {s: v for s, v in pars.items() if "production" in str(s)}
    fixed_parameters = {s: v for s, v in pars.items() if s not in free_parameters}
    subs_intensity_expr = cached.xreplace(intensity_expr, fixed_parameters)
    return cached.lambdify(subs_intensity_expr, free_parameters)


# Unpack in the order returned by lambdify_ls_and_default_model: the default
# (helicity) model first, the LS-coupling model second.
HELI_FUNC, CANO_FUNC = lambdify_ls_and_default_model()

7.9.2. Sort couplings by PC/PV#

Hide code cell source
def is_pc(node: IsobarNode) -> bool:
    """Check if the production node in a decay is parity-conserving.

    The node conserves parity when the parent parity equals the product of
    both child parities and the orbital factor ``(-1) ** L``.

    Returns:
        The result of that comparison, a bool. The previous ``int``
        annotation was incorrect.
    """
    parent = to_particle(node.parent)
    child1 = to_particle(node.child1)
    child2 = to_particle(node.child2)
    L = node.interaction.L
    return parent.parity == child1.parity * child2.parity * (-1) ** L


def get_chain_identifier(chain: ThreeBodyDecayChain) -> str:
    """Label a decay chain by its resonance LaTeX and the production ``L``.

    Example output: ``"\\Lambda(1405), L=0"``.
    """
    resonance_latex = chain.resonance.latex
    orbital_l = chain.incoming_ls.L
    return f"{resonance_latex}, L={orbital_l}"


# Map each LS partial wave ("<resonance>, L=<L>") to a LaTeX-formatted flag that
# tells whether its production node is parity-conserving. The values are
# strings (\texttt{True} / \texttt{False}) rather than bools, so the annotation
# is dict[str, str]. sort_couplings also uses these strings as a sort key;
# "False" < "True" puts the parity-violating couplings first.
PC_PV_MAPPING: dict[str, str] = {
    get_chain_identifier(chain): Rf"\texttt{{{is_pc(chain.production_node)}}}"
    for chain in CANO_MODEL.decay.chains
}
Math(aslatex(PC_PV_MAPPING))
\[\begin{split}\displaystyle \begin{array}{rcl} \Lambda(1405), L=0 &=& \texttt{True} \\ \Lambda(1405), L=1 &=& \texttt{False} \\ \Lambda(1520), L=1 &=& \texttt{False} \\ \Lambda(1520), L=2 &=& \texttt{True} \\ \Lambda(1600), L=0 &=& \texttt{False} \\ \Lambda(1600), L=1 &=& \texttt{True} \\ \Lambda(1670), L=0 &=& \texttt{True} \\ \Lambda(1670), L=1 &=& \texttt{False} \\ \Lambda(1690), L=1 &=& \texttt{False} \\ \Lambda(1690), L=2 &=& \texttt{True} \\ \Lambda(2000), L=0 &=& \texttt{True} \\ \Lambda(2000), L=1 &=& \texttt{False} \\ \Delta(1232), L=1 &=& \texttt{True} \\ \Delta(1232), L=2 &=& \texttt{False} \\ \Delta(1600), L=1 &=& \texttt{True} \\ \Delta(1600), L=2 &=& \texttt{False} \\ \Delta(1700), L=1 &=& \texttt{False} \\ \Delta(1700), L=2 &=& \texttt{True} \\ K(700), L=0 &=& \texttt{True} \\ K(700), L=1 &=& \texttt{False} \\ K(892), L=0 &=& \texttt{False} \\ K(892), L=1 &=& \texttt{True} \\ K(892), L=2 &=& \texttt{False} \\ K(1430), L=0 &=& \texttt{True} \\ K(1430), L=1 &=& \texttt{False} \\ \end{array}\end{split}\]
Hide code cell source
def sort_couplings(parameters: Iterable[str], helicity: bool) -> list[str]:
    """Return the production couplings in ``parameters`` in display order.

    Helicity couplings are ordered by ``factorize_latex`` (resonance, mass,
    indices). LS couplings are first grouped by their ``PC_PV_MAPPING`` flag
    ("False" sorts before "True", so parity-violating couplings come first).
    Within each group they are ordered by ``factorize_latex``.
    """
    production_couplings = [name for name in parameters if "production" in name]

    def by_parity_then_resonance(coupling: str):
        partial_wave = get_partial_wave_latex(coupling, helicity)
        return PC_PV_MAPPING[partial_wave], factorize_latex(coupling)

    sort_key = factorize_latex if helicity else by_parity_then_resonance
    return sorted(production_couplings, key=sort_key)


@cache
def factorize_latex(coupling: str) -> tuple[str, int, str, str]:
    """Split a coupling name into resonance, mass, and its two indices.

    The name must end in ``[<resonance>(<mass>), <index1>, <index2>]``. The
    resonance is an optional backslash, one of ``D``, ``K``, or ``L``, and then
    lowercase letters (e.g. ``\\Lambda``, ``K``).

    Returns:
        ``(resonance, mass, index1, index2)`` with ``mass`` converted to int.

    Raises:
        ValueError: If ``coupling`` does not have this form. Previously, a
            non-matching name failed with an opaque ``AttributeError`` on
            ``None``.
    """
    match = re.match(r"^.*\[(\\?[DKL][a-z]*)\((\d+)\), (.+), (.+)\]$", coupling)
    if match is None:
        msg = f"Cannot factorize coupling name {coupling!r}"
        raise ValueError(msg)
    resonance, mass, *indices = match.groups()
    return resonance, int(mass), *indices


@cache
def get_partial_wave_latex(coupling: str, helicity: bool) -> str:
    """Format the partial wave of ``coupling`` as a LaTeX label.

    Both indices are passed through ``to_frac``. For helicity couplings, the
    label is ``"<R>(<mass>), <index1>, <index2>"``. Otherwise it is
    ``"<R>(<mass>), L=<index1>"``, and the second index is discarded.
    """
    resonance, mass, first_raw, second_raw = factorize_latex(coupling)
    first = to_frac(first_raw)
    second = to_frac(second_raw)
    if not helicity:
        return f"{resonance}({mass}), L={first}"
    return f"{resonance}({mass}), {first}, {second}"


@cache
def to_frac(text: str) -> str:
    r"""Render an integer fraction ``"p/q"`` as a signed LaTeX ``\frac``.

    The sign is taken from the numerator, and the absolute value goes in the
    fraction, e.g. ``"-1/2"`` becomes ``-\frac{1}{2}`` and ``"3/2"`` becomes
    ``+\frac{3}{2}``. Text without a ``/`` is returned unchanged.

    Raises:
        ValueError: If either side of the ``/`` is not an integer.
    """
    if "/" not in text:
        return text
    # The part before "/" is the numerator. The previous code named it
    # "denominator", and the actual denominator "nominator".
    numerator_text, denominator_text = text.split("/")
    numerator = int(numerator_text)
    denominator = int(denominator_text)
    sign = "-" if numerator < 0 else "+"
    return Rf"{sign}\frac{{{abs(numerator)}}}{{{denominator}}}"
# Production couplings of both functions, in plotting order. LS couplings are
# grouped parity-violating first, then parity-conserving; helicity couplings
# are ordered by resonance only.
CANO_COUPLINGS = sort_couplings(CANO_FUNC.parameters, helicity=False)
HELI_COUPLINGS = sort_couplings(HELI_FUNC.parameters, helicity=True)
Hide code cell source
# Overview table: LS partial waves (left column) next to helicity couplings
# (right column). Rows are paired by position only; this assumes both lists
# have the same length.
sp.Matrix([
    [sp.Symbol(get_partial_wave_latex(s, helicity=False)) for s in CANO_COUPLINGS],
    [sp.Symbol(get_partial_wave_latex(s, helicity=True)) for s in HELI_COUPLINGS],
]).T
\[\begin{split}\displaystyle \left[\begin{matrix}K(700), L=1 & K(700), 0, -\frac{1}{2}\\K(892), L=0 & K(700), 0, +\frac{1}{2}\\K(892), L=2 & K(892), -1, -\frac{1}{2}\\K(1430), L=1 & K(892), 0, -\frac{1}{2}\\\Delta(1232), L=2 & K(892), 0, +\frac{1}{2}\\\Delta(1600), L=2 & K(892), 1, +\frac{1}{2}\\\Delta(1700), L=1 & K(1430), 0, -\frac{1}{2}\\\Lambda(1405), L=1 & K(1430), 0, +\frac{1}{2}\\\Lambda(1520), L=1 & \Delta(1232), -\frac{1}{2}, 0\\\Lambda(1600), L=0 & \Delta(1232), +\frac{1}{2}, 0\\\Lambda(1670), L=1 & \Delta(1600), -\frac{1}{2}, 0\\\Lambda(1690), L=1 & \Delta(1600), +\frac{1}{2}, 0\\\Lambda(2000), L=1 & \Delta(1700), -\frac{1}{2}, 0\\K(700), L=0 & \Delta(1700), +\frac{1}{2}, 0\\K(892), L=1 & \Lambda(1405), -\frac{1}{2}, 0\\K(892), L=1 & \Lambda(1405), +\frac{1}{2}, 0\\K(1430), L=0 & \Lambda(1520), -\frac{1}{2}, 0\\\Delta(1232), L=1 & \Lambda(1520), +\frac{1}{2}, 0\\\Delta(1600), L=1 & \Lambda(1600), -\frac{1}{2}, 0\\\Delta(1700), L=2 & \Lambda(1600), +\frac{1}{2}, 0\\\Lambda(1405), L=0 & \Lambda(1670), -\frac{1}{2}, 0\\\Lambda(1520), L=2 & \Lambda(1670), +\frac{1}{2}, 0\\\Lambda(1600), L=1 & \Lambda(1690), -\frac{1}{2}, 0\\\Lambda(1670), L=0 & \Lambda(1690), +\frac{1}{2}, 0\\\Lambda(1690), L=2 & \Lambda(2000), -\frac{1}{2}, 0\\\Lambda(2000), L=0 & \Lambda(2000), +\frac{1}{2}, 0\end{matrix}\right]\end{split}\]

7.9.3. Compute decay rates#

# Phase-space sample used for all numerical integrals below: 100,000 events
# with a fixed seed, passed through the data transformer of the default model.
# That transformer presumably computes the kinematic variables the lambdified
# functions take as input.
TRANSFORMER = create_data_transformer(HELI_MODEL)
PHSP = generate_phasespace_sample(HELI_MODEL.decay, n_events=100_000, seed=0)
PHSP = TRANSFORMER(PHSP)
Hide code cell source
def compute_decay_rate_matrix(
    intensity_func: ParametrizedFunction, couplings: list[str]
) -> np.ndarray:
    """Compute the normalized decay rate for every pair of couplings.

    Entry ``[i, j]`` is computed as follows:

    - Set every coupling in ``couplings`` to zero except ``couplings[i]`` and
      ``couplings[j]``.
    - Integrate the intensity over the phase-space sample ``PHSP``.
    - Divide by the integrated intensity with the original parameter values.

    The diagonal therefore holds the rate of a single coupling. The matrix is
    symmetric, so each unordered pair is evaluated only once.

    The parameters of ``intensity_func`` are restored to their original values
    afterwards, even when an evaluation raises.

    Args:
        intensity_func: Function whose parameters include all ``couplings``.
        couplings: Coupling parameter names; fixes the row/column order.

    Returns:
        A ``(len(couplings), len(couplings))`` array. The old annotation
        claimed a tuple, but only the matrix was ever returned.
    """
    I_tot = integrate_intensity(intensity_func(PHSP))
    n_couplings = len(couplings)
    matrix = np.zeros((n_couplings, n_couplings))
    original_parameters = dict(intensity_func.parameters)
    is_helicity = is_helicity_func(intensity_func)
    # Yields ((i, ci), (j, cj)) with i <= j, row by row.
    pairs = itertools.combinations_with_replacement(enumerate(couplings), 2)
    n_pairs = n_couplings * (n_couplings + 1) // 2
    with tqdm(desc="Computing decay rates", total=n_pairs) as progress_bar:
        try:
            for (i, ci), (j, cj) in pairs:
                pi = get_partial_wave_name(ci, is_helicity)
                pj = get_partial_wave_name(cj, is_helicity)
                progress_bar.set_postfix_str(Rf"{pi} x {pj}")
                # Switch off all couplings except the current pair. Every
                # coupling is overwritten here, so no reset between pairs is
                # needed.
                new_parameters = dict.fromkeys(couplings, 0)
                new_parameters[ci] = original_parameters[ci]
                new_parameters[cj] = original_parameters[cj]
                intensity_func.update_parameters(new_parameters)
                I_sub = integrate_intensity(intensity_func(PHSP))
                matrix[i, j] = matrix[j, i] = I_sub / I_tot
                progress_bar.update()
        finally:
            intensity_func.update_parameters(original_parameters)
    return matrix


def is_helicity_func(intensity_func: ParametrizedFunction) -> bool:
    """Tell whether ``intensity_func`` uses helicity (not LS) couplings.

    Returns False as soon as one parameter name contains both ``"production"``
    and ``"LS"``, and True otherwise (including when there are no production
    couplings).
    """
    for name in intensity_func.parameters:
        if "production" in name and "LS" in name:
            return False
    return True


@cache
def get_partial_wave_name(coupling: str, helicity: bool) -> str:
    """Return a plain-Unicode version of the partial-wave label of ``coupling``.

    ``\\frac{p}{q}`` becomes ``p/q``, ``1/2`` becomes ``½``, and the remaining
    LaTeX is converted with ``unicodeitplus.replace``.
    """
    with_slashes = re.sub(
        r"\\frac{(\d+)}{(\d+)}",
        r"\1/\2",
        get_partial_wave_latex(coupling, helicity),
    )
    return unicodeitplus.replace(with_slashes.replace("1/2", "½"))
# Pairwise decay-rate matrices for both bases. Each is square, with one row and
# one column per sorted coupling.
CANO_RATES = compute_decay_rate_matrix(CANO_FUNC, CANO_COUPLINGS)
assert CANO_RATES.shape == (len(CANO_COUPLINGS), len(CANO_COUPLINGS))
HELI_RATES = compute_decay_rate_matrix(HELI_FUNC, HELI_COUPLINGS)
assert HELI_RATES.shape == (len(HELI_COUPLINGS), len(HELI_COUPLINGS))
Hide code cell source
def compute_interference(rates: jnp.ndarray) -> jnp.ndarray:
    """Split a pairwise rate matrix into single rates and interference terms.

    The diagonal is kept as is (single-coupling rates). Each off-diagonal entry
    becomes ``(rates[i, j] - rates[i, i] - rates[j, j]) / 2``: half of what the
    pair contributes on top of the two individual rates. The two symmetric
    entries therefore add up to the full interference.

    Raises:
        AssertionError: If ``rates`` is not square.
    """
    n_rows, n_cols = rates.shape
    assert n_rows == n_cols
    singles = rates.diagonal()
    single_matrix = singles * np.identity(n_rows)
    pair_sum = singles[:, None] + singles[None, :]
    half_interference = (rates - pair_sum + single_matrix) / 2
    return half_interference + single_matrix
# Single-coupling rates on the diagonal, halved interference terms off the
# diagonal.
CANO_INTERFERENCE = compute_interference(CANO_RATES)
HELI_INTERFERENCE = compute_interference(HELI_RATES)

7.9.4. Visualize decay rates#

Hide code cell source
def visualize_rates(
    labels: list[str], rates: np.ndarray, model_name: str, helicity: bool
) -> go.Figure:
    """Plot a square decay-rate matrix as a heatmap with subsystem boxes.

    Args:
        labels: Tick labels, one per row/column of ``rates``.
        rates: Square matrix of decay rates. The colour bar and hover text
            format the values as percentages.
        model_name: Model name inserted into the figure title.
        helicity: Selects the hard-coded subsystem boxes for the helicity
            basis. For the canonical basis, entries with ``abs(rate) <= 1e-7``
            are replaced by NaN (left blank), and the PC/PV halves are
            annotated.

    Returns:
        The plotly figure. The previous annotation claimed a ``Path``.
    """
    abs_max = float(jnp.abs(rates).max())
    # Round the symmetric colour range down to a multiple of 0.2 (larger values
    # saturate). Keep the exact maximum if rounding would collapse the range
    # to zero.
    abs_max = float(np.floor(abs_max * 5) / 5) or abs_max
    if not helicity:
        # Presumably blanks out interference terms that vanish, e.g. between
        # parity-conserving and parity-violating couplings.
        rates = jnp.where(jnp.abs(rates) > 1e-7, rates, jnp.nan)
    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            x=labels,
            y=labels,
            z=rates,
            colorscale="RdBu_r",
            customdata=100 * rates,
            hovertemplate="%{x}<br>%{y}<br>Decay rate: <b>%{customdata:+.3g}%</b><extra></extra>",
            zmin=-abs_max,
            zmax=+abs_max,
            colorbar=dict(
                title="Decay rate",
                title_side="right",
                tickformat="+.0%",
            ),
        )
    )
    if helicity:
        indicate_subsystem(fig, 0, 7, label="𝐾<sup>*</sup>")
        indicate_subsystem(fig, 8, 13, label="𝛥<sup>*</sup>")
        indicate_subsystem(fig, 14, 25, label="𝛬<sup>*</sup>")
    else:
        # NOTE(review): these index ranges are hard-coded for the coupling
        # order of this particular model. The canonical labels contain a
        # duplicate ("K(892), L=1"), which plotly presumably merges into a
        # single category. Confirm the boxes still line up if the model changes.
        indicate_subsystem(fig, 0, 3, label="𝐾<sup>*</sup>")
        indicate_subsystem(fig, 4, 6, label="𝛥<sup>*</sup>")
        indicate_subsystem(fig, 7, 12, label="𝛬<sup>*</sup>")
        indicate_subsystem(fig, 13, 15, label="𝐾<sup>*</sup>", legend=False)
        indicate_subsystem(fig, 16, 18, label="𝛥<sup>*</sup>", legend=False)
        indicate_subsystem(fig, 19, 24, label="𝛬<sup>*</sup>", legend=False)
        fig.add_annotation(
            x=13.5,
            y=6,
            font_size=15,
            showarrow=False,
            text="parity-violating",
            textangle=+90,
        )
        fig.add_annotation(
            x=11.5,
            y=19,
            font_size=15,
            text="parity-conserving",
            showarrow=False,
            textangle=-90,
        )
    fig.update_layout(
        autosize=False,
        height=750,
        legend=dict(font_size=16, yanchor="top", y=1, xanchor="left", x=0.85),
        paper_bgcolor="rgba(0, 0, 0, 0)",
        title=dict(
            text=f"Decay rates of {model_name}",
            xanchor="center",
            x=0.5,
            y=0.89,
        ),
    )
    fig.update_scenes(aspectmode="data")
    fig.update_yaxes(autorange="reversed")
    return fig


def indicate_subsystem(
    fig: go.Figure, idx1: int, idx2: int, label: str, legend: bool = True
) -> go.Figure:
    """Draw a translucent square around diagonal cells ``idx1`` to ``idx2``.

    Args:
        fig: Figure to draw on.
        idx1: Index of the first cell (inclusive).
        idx2: Index of the last cell (inclusive).
        label: One of the keys of ``colors`` below. It selects the line and
            fill colour and becomes the legend entry name.
        legend: Whether the shape gets a legend entry.

    Returns:
        The figure, as returned by ``fig.add_shape``. The previous annotation
        claimed ``None``.

    Raises:
        KeyError: If ``label`` is not a known subsystem label.
    """
    # Accept the earlier malformed spelling of the Λ label ("</sup>" instead
    # of "<sup>") so that existing callers keep working.
    label = label.replace("</sup>*</sup>", "<sup>*</sup>")
    colors = {
        "𝐾<sup>*</sup>": ("red", "rgba(255,0,0,0.1)"),
        "𝛬<sup>*</sup>": ("green", "rgba(0,255,0,0.1)"),
        "𝛥<sup>*</sup>": ("blue", "rgba(0,0,255,0.1)"),
    }
    linecolor, fillcolor = colors[label]
    # Heatmap cells are centred on integer positions; extend by half a cell so
    # the square encloses the outer cells completely.
    left = idx1 - 0.5
    right = idx2 + 0.5
    return fig.add_shape(
        type="rect",
        x0=left,
        x1=right,
        y0=left,
        y1=right,
        fillcolor=fillcolor,
        line=dict(color=linecolor, width=1),
        name=label,
        opacity=0.3,
        showlegend=legend,
    )
Hide code cell source
# Unicode axis labels for the LS couplings, in the same order as the rate matrix.
CANO_LABELS = [get_partial_wave_name(p, helicity=False) for p in CANO_COUPLINGS]
fig = visualize_rates(
    CANO_LABELS, CANO_INTERFERENCE, "LS model (canonical basis)", helicity=False
)
# NOTE(review): the file name says "default" although this figure shows the LS
# model. It is kept unchanged in case other pages link to this path.
CANONICAL_PATH = Path("../_static/images/decay-rates-default-canonical.svg")
fig.write_image(CANONICAL_PATH)
fig
Hide code cell source
# Unicode axis labels for the helicity couplings, in the same order as the rate
# matrix.
HELI_LABELS = [get_partial_wave_name(p, helicity=True) for p in HELI_COUPLINGS]
fig = visualize_rates(
    HELI_LABELS, HELI_INTERFERENCE, "default model (helicity basis)", helicity=True
)
# Export the figure as a static SVG next to the other documentation images.
HELICITY_PATH = Path("../_static/images/decay-rates-default-helicity.svg")
fig.write_image(HELICITY_PATH)
fig