Importance sampling

Model definition

Hide code cell content
from __future__ import annotations

import logging
import os
import warnings
from pathlib import Path

import jax.numpy as jnp
import numpy as np

# Keep the rendered notebook output free of third-party log noise.
logging.getLogger("absl").setLevel(level=logging.ERROR)  # JAX logs via absl
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")  # mute TensorFlow C++ logging
# NOTE: this ignores *all* Python warnings, not just the ones about square
# roots of negative arguments that motivated it.
warnings.filterwarnings(action="ignore")

We generate data for the reaction \(J/\psi \to \gamma \pi^0\pi^0\). We limit ourselves to two resonances, so that the amplitude model contains one narrow structure. This makes it hard to numerically compute the integral over the intensity distribution.

import qrules

# J/psi -> gamma pi0 pi0, restricted to the f0(980) and omega(782) resonances.
reaction = qrules.generate_transitions(
    formalism="canonical-helicity",
    initial_state=("J/psi(1S)", [-1, +1]),  # J/psi with spin projections -1, +1
    final_state=["gamma", "pi0", "pi0"],
    allowed_interaction_types=["strong", "EM"],
    allowed_intermediate_particles=["f(0)(980)", "omega(782)"],
)
Hide code cell source
import graphviz

src = qrules.io.asdot(reaction, collapse_graphs=True)
output_file = Path("graph")
# Writes the DOT source to "graph" and the image to "graph.svg"; cleanup=True
# deletes the intermediate DOT source file once rendering has succeeded.
graphviz.Source(src).render(output_file, format="svg", cleanup=True)

import ampform
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)

# Formulate the amplitude model for `reaction` with an ampform builder.
builder = ampform.get_builder(reaction)
# NOTE(review): the next three lines set ampform builder options; their exact
# meaning depends on the installed ampform version — confirm in its docs.
builder.align_spin = False
builder.adapter.permutate_registered_topologies()
builder.scalar_initial_state_mass = True
# All three final-state particles are marked stable (presumably IDs 0, 1, 2
# correspond to gamma, pi0, pi0 in the order given to generate_transitions).
builder.stable_final_state_ids = [0, 1, 2]
# Lineshapes: create_non_dynamic_with_ff for the J/psi and a relativistic
# Breit-Wigner for every intermediate resonance ("_with_ff" presumably means
# "with form factor").
builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
for name in reaction.get_intermediate_particles().names:
    builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = builder.formulate()

Phase space distribution

An evenly distributed phase space sample can be generated with a TFPhaseSpaceGenerator:

from tensorwaves.data import (
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)

# A fixed seed makes all samples below reproducible.
rng = TFUniformRealNumberGenerator(seed=0)
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={
        state_id: particle.mass for state_id, particle in reaction.final_state.items()
    },
)
# Maps generated events onto the model's kinematic variables (e.g. "m_01")
transformer = SympyDataTransformer.from_sympy(model.kinematic_variables, backend="jax")
phsp = transformer(phsp_generator.generate(1_000_000, rng))
Hide code cell source
import matplotlib.pyplot as plt


def convert_zero_to_nan(array):
    """Return a float copy of *array* in which every zero entry is NaN.

    Used on histogram counts so that empty bins are left blank by
    ``pcolormesh`` instead of being coloured. The input is not modified.
    """
    values = np.array(array).astype("float")
    values = np.where(values == 0, np.nan, values)
    return jnp.array(values)


# Dalitz-plot histogram of the flat phase space sample.
Z, x_edges, y_edges = jnp.histogram2d(
    phsp["m_01"].real ** 2,
    phsp["m_12"].real ** 2,
    bins=100,
)
X, Y = jnp.meshgrid(x_edges, y_edges)
# histogram2d puts x along the first axis, while pcolormesh expects rows to
# run along y, so transpose (as the other Dalitz plots in this file do):
# https://numpy.org/doc/stable/reference/generated/numpy.histogram2d.html
Z = Z.T
Z = convert_zero_to_nan(Z)

# Bin widths along each axis, in GeV^2/c^4 (assuming masses in GeV/c^2, as
# provided by qrules). 1 GeV^2 = 1e6 MeV^2, hence the factor 1e6 for a label
# in MeV^2/c^4.
bin_width_x = X[0, 1] - X[0, 0]
bin_width_y = Y[1, 0] - Y[0, 0]
bar_title = (
    Rf"events per ${1e6 * bin_width_x:.0f} \times {1e6 * bin_width_y:.0f}$ MeV$^2/c^4$"
)
xlabel = R"$M^2\left(\gamma\pi^0\right)$"
ylabel = R"$M^2\left(\pi^0\pi^0\right)$"

plt.ioff()  # the figure is only written to file, not shown interactively
fig, ax = plt.subplots(dpi=200, figsize=(4.5, 4))
ax.set(title="TFPhaseSpaceGenerator sample", xlabel=xlabel, ylabel=ylabel)
mesh = ax.pcolormesh(X, Y, Z)
c_bar = fig.colorbar(mesh, ax=ax)
c_bar.ax.set_ylabel(bar_title)
fig.savefig("TFPhaseSpaceGenerator.png")
plt.ion()
plt.close(fig)

This TFPhaseSpaceGenerator actually applies a hit-and-miss strategy to a distribution and its weights generated by a TFWeightedPhaseSpaceGenerator, which interfaces to the phasespace package. Let us have a short look at the distribution and the weights generated by a TFWeightedPhaseSpaceGenerator. The ‘unweighted’ distribution is uneven, because the four-momenta of the events are generated with a particular decay algorithm. The weights correct for this: when the events are combined with their weights, we recover the same, evenly distributed sample as above.

from tensorwaves.data import TFWeightedPhaseSpaceGenerator

# Same masses as the TFPhaseSpaceGenerator above, but this generator returns
# the raw events together with their phase space "weights".
weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    final_state_masses={
        state_id: particle.mass for state_id, particle in reaction.final_state.items()
    },
    initial_state_mass=reaction.initial_state[-1].mass,
)
unweighted_phsp = weighted_phsp_generator.generate(1_000_000, rng)
# Take the weights from the raw sample, before the transformer converts it
phsp_weights = unweighted_phsp["weights"]
unweighted_phsp = transformer(unweighted_phsp)
Hide code cell source
from typing import TYPE_CHECKING

from scipy.interpolate import griddata

if TYPE_CHECKING:
    from tensorwaves.interface import DataSample


def plot_distribution_and_weights(phsp: DataSample, weights: np.ndarray) -> None:
    """Plot a weighted sample as three side-by-side Dalitz plots.

    The panels show (1) the histogram of the raw, unweighted events, (2) the
    event weights interpolated onto a regular grid with
    ``scipy.interpolate.griddata``, and (3) the histogram in which every event
    is counted with its weight. The figure is left as the current matplotlib
    figure, so callers can add a title and save it through ``plt.gcf()``.

    Args:
        phsp: Sample with ``"m_01"`` and ``"m_12"`` entries; their real parts
            are squared to obtain the Dalitz-plot coordinates.
        weights: One weight per event in *phsp*.
    """
    n_bins = 100
    x = phsp["m_01"].real ** 2
    y = phsp["m_12"].real ** 2

    # Interpolate the scattered per-event weights onto a regular grid
    X, Y = jnp.meshgrid(
        jnp.linspace(x.min(), x.max(), num=n_bins),
        jnp.linspace(y.min(), y.max(), num=n_bins),
    )
    Z_weights = griddata(np.transpose([x, y]), weights, (X, Y))

    # Weights do not influence the bin edges, so both histograms share them
    Z_unweighted, x_edges, y_edges = jnp.histogram2d(x, y, bins=n_bins)
    Z_weighted, _, _ = jnp.histogram2d(x, y, bins=n_bins, weights=weights)
    # histogram2d puts x along the first axis, pcolormesh expects rows along y:
    # https://numpy.org/doc/stable/reference/generated/numpy.histogram2d.html
    Z_unweighted = convert_zero_to_nan(Z_unweighted.T)
    Z_weighted = convert_zero_to_nan(Z_weighted.T)
    X_edges, Y_edges = jnp.meshgrid(x_edges, y_edges)

    # Label with the bin size of *this* histogram: histogram2d derives its
    # range from each sample's min/max, so the bins generally differ from those
    # of other plots. Squared masses are in GeV^2/c^4 (assuming masses in
    # GeV/c^2, as provided by qrules); 1 GeV^2 = 1e6 MeV^2.
    bin_width_x = x_edges[1] - x_edges[0]
    bin_width_y = y_edges[1] - y_edges[0]
    counts_title = (
        Rf"events per ${1e6 * bin_width_x:.0f} \times {1e6 * bin_width_y:.0f}$"
        R" MeV$^2/c^4$"
    )

    _, axes = plt.subplots(
        dpi=200,
        figsize=(16, 5),
        ncols=3,
        tight_layout=True,
    )
    for ax in axes:
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
    axes[0].set_title("Unweighted distribution")
    axes[1].set_title("Weights")
    axes[2].set_title("Weighted phase space distribution")

    mesh = axes[0].pcolormesh(X_edges, Y_edges, Z_unweighted)
    c_bar = plt.colorbar(mesh, ax=axes[0])
    c_bar.ax.set_ylabel(counts_title)

    mesh = axes[1].pcolormesh(X, Y, Z_weights)
    c_bar = plt.colorbar(mesh, ax=axes[1])
    c_bar.ax.set_ylabel("phase space weight")

    mesh = axes[2].pcolormesh(X_edges, Y_edges, Z_weighted)
    c_bar = plt.colorbar(mesh, ax=axes[2])
    c_bar.ax.set_ylabel(counts_title)


# Three-panel comparison for the TFWeightedPhaseSpaceGenerator sample; the
# overall title and the saved file refer to the figure that was just drawn.
plot_distribution_and_weights(unweighted_phsp, phsp_weights)
plt.suptitle("TFWeightedPhaseSpaceGenerator sample")
plt.savefig("TFWeightedPhaseSpaceGenerator.png")
plt.show()

Intensity distribution

We now use an IntensityDistributionGenerator to generate a hit-and-miss data sample based on the amplitude model that we formulated for this \(J/\psi \to \gamma\pi^0\pi^0\) reaction.

from tensorwaves.function.sympy import create_parametrized_function

# Evaluate the unevaluated parts of the model expression (sympy's doit) and
# wrap the result as a JAX function whose parameters start at the defaults.
intensity_expr = model.expression.doit()
intensity_func = create_parametrized_function(
    backend="jax",
    expression=intensity_expr,
    parameters=model.parameter_defaults,
)
from tensorwaves.data import IntensityDistributionGenerator

# Hit-and-miss sampling of the intensity, drawing candidate events from the
# weighted phase space generator and continuing the same random stream.
data_generator = IntensityDistributionGenerator(
    function=intensity_func,
    domain_generator=weighted_phsp_generator,
    domain_transformer=transformer,
)
data = transformer(data_generator.generate(100_000, rng))

Note that it takes a long time to generate a distribution for an amplitude model. This is because most phase space points are outside the region where the intensity is highest and therefore result in a ‘miss’.

Hide code cell source
Z, x_edges, y_edges = jnp.histogram2d(
    data["m_01"].real ** 2, data["m_12"].real ** 2, bins=100
)
X, Y = jnp.meshgrid(x_edges, y_edges)
# histogram2d returns H[x, y]; pcolormesh wants rows along y, hence .T
# https://numpy.org/doc/stable/reference/generated/numpy.histogram2d.html
Z = convert_zero_to_nan(Z.T)

plt.ioff()  # the figure is only written to file, not shown interactively
fig, ax = plt.subplots(dpi=200, figsize=(4.5, 4))
ax.set(xlabel=xlabel, ylabel=ylabel)
mesh = ax.pcolormesh(X, Y, Z)
c_bar = fig.colorbar(mesh, ax=ax)
c_bar.ax.set_ylabel("intensity")
fig.savefig("intensity-distribution.png")
plt.ion()
plt.close(fig)

The \(\omega\) resonance appears as a narrow structure on the Dalitz plot. This is problematic when computing the integral over this distribution, which is important when performing an UnbinnedNLL fit. The integral that appears in the log-likelihood has to be computed in each fit iteration and this can be done most efficiently when there are more points on which to evaluate the amplitude model in the phase space regions where the intensity is high.

The solution is to evaluate the intensity over an importance-sampled phase space sample. This is a phase space sample with more events in the regions where the intensity is high. Each point \(\tau\) carries a weight that is set to \(1/I(\tau)\). In fact, this is nothing more than the intensity-based sample from the previous step, with the weights computed a posteriori by simply evaluating the amplitude model over the sample (and taking the inverse).

from copy import deepcopy

importance_phsp = deepcopy(data)  # independent copy; `data` stays untouched
# Weight each event by the inverse of its intensity, 1/I(tau), evaluated with
# the parameter values that generated the sample.
original_intensities = intensity_func(importance_phsp)
importance_weights = 1 / original_intensities

Of course, we could define a special class for this.

As expected, the inverse-intensity weights turn the distribution back into a flat phase space sample:

Hide code cell source
# Three-panel comparison for the importance-sampled phase space; the overall
# title and the saved file refer to the figure that was just drawn.
plot_distribution_and_weights(importance_phsp, importance_weights)
plt.suptitle("Importance-sampled phase space distribution")
plt.savefig("importance-sampling.png")
plt.show()

Now, aren’t we duplicating things here? Not really. First, in an actual analysis, there would be no intensity-based data sample. Second, the importance-sampled phase space sample is generated with specific parameter values. During a fit, the parameters change and the integral over the (importance-sampled) phase space changes. So after updating parameters during a fit iteration, we have to multiply the new intensities with the importance weights (the inverse of the original intensity distribution) in order to get the new distribution. This needs to be done in particular when computing the negative log likelihood (UnbinnedNLL).[1]

In the following extreme example, we move the mass of the \(f_0(980)\) resonance far from its original position. As can be seen in the distribution below, the narrow structure has indeed moved, but the structure is still visible as a blur in the original position, because there are many more phase space points in that region.

# Deliberately extreme: move the f0(980) mass far from its original value.
modified_parameters = {"m_{f_{0}(980)}": 2.0}
intensity_func.update_parameters(modified_parameters)
new_intensities = intensity_func(importance_phsp)
Hide code cell source
# Reweight the importance sample: new intensity times 1/original intensity.
reweighted = new_intensities * importance_weights
Z, x_edges, y_edges = jnp.histogram2d(
    importance_phsp["m_01"].real ** 2,
    importance_phsp["m_12"].real ** 2,
    bins=100,
    weights=reweighted,
)
X, Y = jnp.meshgrid(x_edges, y_edges)
# histogram2d returns H[x, y]; pcolormesh wants rows along y, hence .T
# https://numpy.org/doc/stable/reference/generated/numpy.histogram2d.html
Z = convert_zero_to_nan(Z.T)

plt.ioff()  # the figure is only written to file, not shown interactively
fig, ax = plt.subplots(dpi=200, figsize=(4.5, 4))
ax.set(xlabel=xlabel, ylabel=ylabel)
mesh = ax.pcolormesh(X, Y, Z)
c_bar = fig.colorbar(mesh, ax=ax)
c_bar.ax.set_ylabel(R"new intensity $\times$ importance weight")
fig.savefig("importance-sampling-after-modification.png")
plt.ion()
plt.close(fig)