J/ψ → K⁰ Σ⁺ p̅

Decay definition

We follow this example, which was generated with QRules, and leave out the \(K\)-resonances and the resonances that lie far outside of phase space.

Hide code cell source
# Generate the transitions for J/psi(1S) -> K0 Sigma+ p~ with QRules, limited
# to the strong interaction and using the canonical-helicity formalism.
REACTION = qrules.generate_transitions(
    initial_state="J/psi(1S)",
    final_state=["K0", "Sigma+", "p~"],
    allowed_interaction_types="strong",
    formalism="canonical-helicity",
    mass_conservation_factor=0.05,
)
# NOTE(review): presumably renumbers the state IDs to the convention expected
# by the later cells — confirm against normalize_state_ids.
REACTION = normalize_state_ids(REACTION)
# Render all transitions as one collapsed graph; the trailing expression is
# what the notebook cell displays.
dot = qrules.io.asdot(REACTION, collapse_graphs=True)
graphviz.Source(dot)
Hide code cell source
# Convert the QRules transitions into a three-body decay description.
# NOTE(review): min_ls=True presumably keeps only the lowest LS coupling per
# chain — confirm against to_three_body_decay.
DECAY = to_three_body_decay(REACTION.transitions, min_ls=True)
# Tabulate the initial state followed by the final-state particles.
Markdown(as_markdown_table([DECAY.initial_state, *DECAY.final_state.values()]))
Hide code cell source
# Collect every distinct resonance appearing in the decay chains, then order
# them by the first character of their name and, within that, by mass.
resonances = list({chain.resonance for chain in DECAY.chains})
resonances.sort(key=lambda particle: (particle.name[0], particle.mass))
resonance_names = [particle.name for particle in resonances]
# Trailing expression: the table shown as the cell output.
Markdown(as_markdown_table(resonances))
Hide code cell source
# Display the decay as LaTeX.
# NOTE(review): with_jp=True presumably adds spin-parity to each state —
# confirm against aslatex.
Latex(aslatex(DECAY, with_jp=True))

Model formulation

The total, aligned intensity expression looks as follows:

# Set up the amplitude builder for the decay defined earlier.
# NOTE(review): min_ls=False here differs from the min_ls=True that was used
# when DECAY was created — confirm that this combination is intended.
model_builder = DalitzPlotDecompositionBuilder(DECAY, min_ls=False)
# Register the same dynamics builder (Breit-Wigner with form factor) for
# every decay chain.
for chain in model_builder.decay.chains:
    model_builder.dynamics_choices.register_builder(
        chain, formulate_breit_wigner_with_form_factor
    )
# Formulate the model with subsystem 2 as reference; the trailing expression
# displays the resulting intensity.
model = model_builder.formulate(reference_subsystem=2)
model.intensity
Hide code cell source

Each unaligned amplitude is defined as follows:

Hide code cell source
Hide code cell source
# Generic non-negative symbols used only to display the lineshape
# definitions; the Python name w0 is bound to the symbol named "Gamma0".
s, m0, w0, m1, m2, L, R, z = sp.symbols("s m0 Gamma0 m1 m2 L R z", nonnegative=True)
exprs = [
    RelativisticBreitWigner(s, m0, w0, m1, m2, L, R),
    EnergyDependentWidth(s, m0, w0, m1, m2, L, R),
    FormFactor(s, m1, m2, L, R),
    BlattWeisskopfSquared(z, L),
]
# Pair each expression with its definition; deep=False evaluates only the
# outermost expression and leaves nested ones folded.
Latex(aslatex({e: e.doit(deep=False) for e in exprs}))

Preparing for input data

Hide code cell source
# Subsystem IDs whose invariants are used as the x and y grid variables in
# the later cells.
i, j = (3, 2)
# The one remaining subsystem ID out of {1, 2, 3}.
k, *_ = {1, 2, 3} - {i, j}
# Take the k-th entry (1-based) of the model's invariants.
# NOTE(review): presumably σk_expr writes σk in terms of the other two
# invariants and the masses — confirm against model.invariants.
σk, σk_expr = list(model.invariants.items())[k - 1]
Latex(aslatex({σk: σk_expr}))
Hide code cell source
# Number of grid points along each axis.
resolution = 1_000
# Mass symbols ordered by their string representation.
m = sorted(model.masses, key=str)
# Lower and upper bound of each axis, built from the mass symbols and
# evaluated with the mass values stored in the model.
x_min, x_max = (
    float(bound.xreplace(model.masses))
    for bound in ((m[j] + m[k]) ** 2, (m[0] - m[i]) ** 2)
)
y_min, y_max = (
    float(bound.xreplace(model.masses))
    for bound in ((m[i] + m[k]) ** 2, (m[0] - m[j]) ** 2)
)
x_diff = x_max - x_min
y_diff = y_max - y_min
# Widen both ranges by 5% of their width on either side.
x_min, x_max = x_min - 0.05 * x_diff, x_max + 0.05 * x_diff
y_min, y_max = y_min - 0.05 * y_diff, y_max + 0.05 * y_diff
X, Y = jnp.meshgrid(
    jnp.linspace(x_min, x_max, num=resolution),
    jnp.linspace(y_min, y_max, num=resolution),
)
Hide code cell source
# Start from the model's variable definitions and add the expression for σk.
definitions = dict(model.variables)
definitions.update({σk: σk_expr})
# Substitute the definitions into one another once, then insert the mass
# values. The right-hand side still sees the unsubstituted dict because the
# name is only rebound after the new dict has been built.
definitions = dict(
    (variable, definition.xreplace(definitions).xreplace(model.masses))
    for variable, definition in definitions.items()
)
data_transformer = SympyDataTransformer.from_sympy(definitions, backend="jax")
# Seed the sample with the two grid variables and add everything the
# transformer computes from them.
dalitz_data = {f"sigma{i}": X, f"sigma{j}": Y}
dalitz_data.update(data_transformer(dalitz_data))
Hide code cell source
# Unfold the intensity, insert the amplitude definitions and unfold again.
unfolded_intensity_expr = perform_cached_doit(model.intensity)
full_intensity_expr = perform_cached_doit(
    unfolded_intensity_expr.xreplace(model.amplitudes)
)
# Free parameters: indexed symbols whose string form mentions "production"
# or "decay". Every other parameter default is treated as fixed.
free_parameters = {
    symbol: value
    for symbol, value in model.parameter_defaults.items()
    if isinstance(symbol, sp.Indexed)
    and any(tag in str(symbol) for tag in ("production", "decay"))
}
fixed_parameters = {
    symbol: value
    for symbol, value in model.parameter_defaults.items()
    if symbol not in free_parameters
}
# Substitute the fixed values before lambdifying so that only the free
# parameters remain adjustable on the resulting function.
intensity_func = perform_cached_lambdify(
    full_intensity_expr.xreplace(fixed_parameters),
    parameters=free_parameters,
    backend="jax",
)
intensities = intensity_func(dalitz_data)

Dalitz plot

Hide code cell source
def get_decay_products(subsystem_id: int) -> tuple[State, State]:
    """Return the final-state particles whose index differs from ``subsystem_id``."""
    final_states = DECAY.final_state.values()
    return tuple(filter(lambda state: state.index != subsystem_id, final_states))


plt.rc("font", size=18)
# Normalize by the sum over all grid points, ignoring NaN entries.
I_tot = jnp.nansum(intensities)
normalized_intensities = intensities / I_tot

# One axis label per subsystem ID, listing the particles returned by
# get_decay_products for that ID.
sigma_labels = {
    subsystem_id: Rf"$\sigma_{subsystem_id} = M^2\left({' '.join(p.latex for p in get_decay_products(subsystem_id))}\right)$"
    for subsystem_id in (1, 2, 3)
}

fig, ax = plt.subplots(figsize=(14, 10))
mesh = ax.pcolormesh(X, Y, normalized_intensities)
ax.set_aspect("equal")
c_bar = plt.colorbar(mesh, ax=ax, pad=0.01)
c_bar.ax.set_ylabel("Normalized intensity (a.u.)")
ax.set(xlabel=sigma_labels[i], ylabel=sigma_labels[j])
plt.show()
Hide code cell source
def compute_sub_intensity(
    func: ParametrizedFunction, phsp: DataSample, resonance_latex: str
) -> jnp.ndarray:
    r"""Evaluate ``func`` with the couplings of all other resonances set to zero.

    Every parameter whose name contains ``\mathcal{H}`` but does not contain
    ``resonance_latex`` is temporarily set to ``0``. The parameter values
    found on entry are written back before returning, also when the
    evaluation raises, so that ``func`` is never left in the zeroed state.

    Args:
        func: Function whose ``parameters`` are modified for the duration of
            the call and restored afterwards.
        phsp: Data sample that ``func`` is evaluated on.
        resonance_latex: Substring identifying the parameters that should
            keep their current value.

    Returns:
        The result of ``func(phsp)`` with the other couplings switched off.
    """
    original_parameters = dict(func.parameters)
    zero_parameters = {
        name: 0
        for name in func.parameters
        if R"\mathcal{H}" in name and resonance_latex not in name
    }
    try:
        func.update_parameters(zero_parameters)
        return func(phsp)
    finally:
        # Restore unconditionally: the function object is shared between
        # calls, so a failed evaluation must not leak zeroed couplings.
        func.update_parameters(original_parameters)


def _mass_label(subsystem_id: int) -> str:
    """Return the LaTeX axis label for the invariant mass of a subsystem."""
    products = " ".join(p.latex for p in get_decay_products(subsystem_id))
    return Rf"$\sqrt{{\sigma_{subsystem_id}}} = M\left({products}\right)$"


plt.rc("font", size=16)
fig, axes = plt.subplots(figsize=(18, 6), ncols=2, sharey=True)
fig.subplots_adjust(wspace=0.02)
ax1, ax2 = axes
# The projections are drawn against sqrt(sigma), i.e. the invariant mass, so
# that the resonance masses can be marked directly with vertical lines.
x = jnp.sqrt(X[0])
y = jnp.sqrt(Y[:, 0])
ax1.fill_between(x, jnp.nansum(normalized_intensities, axis=0), alpha=0.5)
ax2.fill_between(y, jnp.nansum(normalized_intensities, axis=1), alpha=0.5)
for ax in axes:
    # Pin the lower limit to zero and keep the current upper limit. Passing
    # only `bottom` avoids overwriting the grid bound `y_max` defined earlier.
    ax.set_ylim(bottom=0)
    # Freeze the x range so the vertical resonance lines cannot widen it.
    ax.autoscale(enable=False, axis="x")
ax1.set_ylabel("Normalized intensity (a.u.)")
# Label the axes with the quantity that is actually plotted (the mass, not
# the mass squared used on the Dalitz plot axes).
ax1.set_xlabel(_mass_label(i))
ax2.set_xlabel(_mass_label(j))
# The decay products of both plotted subsystems do not change per chain.
products_i = set(get_decay_products(subsystem_id=i))
products_j = set(get_decay_products(subsystem_id=j))
resonance_counter1, resonance_counter2 = 0, 0
for chain in tqdm(model.decay.chains, disable=NO_TQDM):
    resonance = chain.resonance
    chain_products = set(chain.decay_products)
    # The counters are incremented before use, so line colors start at "C1".
    # NOTE(review): presumably "C0" is the color taken by fill_between above.
    if chain_products == products_i:
        ax = ax1
        resonance_counter1 += 1
        color = f"C{resonance_counter1}"
        projection_axis = 0
        x_data = x
    elif chain_products == products_j:
        ax = ax2
        resonance_counter2 += 1
        color = f"C{resonance_counter2}"
        projection_axis = 1
        x_data = y
    else:
        msg = (
            f"Resonance {resonance.name} does not decay to the products of"
            f" subsystem {i} or {j}, so it has no projection axis"
        )
        raise NotImplementedError(msg)
    sub_intensities = compute_sub_intensity(
        intensity_func, dalitz_data, resonance.latex
    )
    # Normalize with the same total as the full intensity so that the curves
    # are comparable to the filled projections.
    ax.plot(x_data, jnp.nansum(sub_intensities / I_tot, axis=projection_axis), c=color)
    ax.axvline(resonance.mass, label=f"${resonance.latex}$", c=color, ls="dashed")
for ax in axes:
    ax.legend(fontsize=12)
plt.show()