def visualize_rates(
    labels: list[str], rates: np.ndarray, model_name: str, helicity: bool
) -> go.Figure:
    """Plot a square matrix of decay rates as a Plotly heatmap.

    Args:
        labels: Tick labels used for both axes, one per row/column of ``rates``.
            The subsystem index ranges drawn below are hard-coded and cover
            indices 0-25, i.e. they assume 26 labels.
        rates: Square matrix of decay rates, interpreted as fractions (the
            colour bar is formatted as a percentage and the hover text shows
            ``100 * rate`` with a ``%`` sign).
        model_name: Name inserted into the figure title.
        helicity: If ``True``, use the index layout for the helicity couplings
            (one K*/Δ*/Λ* group each). Otherwise use the layout with two
            repeated K*/Δ*/Λ* groups (indices 0-12 and 13-25), which are
            annotated as parity-violating / parity-conserving and in which
            (near-)zero entries are masked.

    Returns:
        The constructed figure.
    """
    rates = np.asarray(rates)
    raw_max = float(np.abs(rates).max())
    # Round the symmetric colour range down to a multiple of 0.2; larger values
    # saturate the colour scale. If rounding would collapse the range to zero
    # (all |rates| < 0.2), fall back to the unrounded maximum instead.
    abs_max = float(np.floor(raw_max * 5) / 5)
    if abs_max <= 0:
        abs_max = raw_max
    if not helicity:
        # Replace (near-)zero entries with NaN so that they are not coloured.
        rates = np.where(np.abs(rates) > 1e-7, rates, np.nan)

    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            x=labels,
            y=labels,
            z=rates,
            colorscale="RdBu_r",
            customdata=100 * rates,
            hovertemplate="%{x}<br>%{y}<br>Decay rate: <b>%{customdata:+.3g}%</b><extra></extra>",
            zmin=-abs_max,
            zmax=+abs_max,
            colorbar=dict(
                title="Decay rate",
                title_side="right",
                tickformat="+.0%",
            ),
        )
    )

    # NOTE: the Λ* label below contains malformed markup ("</sup>" instead of
    # "<sup>"). It must stay in sync with the colour-table key used by
    # indicate_subsystem, so fix both together.
    if helicity:
        indicate_subsystem(fig, 0, 7, label="𝐾<sup>*</sup>")
        indicate_subsystem(fig, 8, 13, label="𝛥<sup>*</sup>")
        indicate_subsystem(fig, 14, 25, label="𝛬</sup>*</sup>")
    else:
        indicate_subsystem(fig, 0, 3, label="𝐾<sup>*</sup>")
        indicate_subsystem(fig, 4, 6, label="𝛥<sup>*</sup>")
        indicate_subsystem(fig, 7, 12, label="𝛬</sup>*</sup>")
        indicate_subsystem(fig, 13, 16, label="𝐾<sup>*</sup>", legend=False)
        indicate_subsystem(fig, 17, 19, label="𝛥<sup>*</sup>", legend=False)
        indicate_subsystem(fig, 20, 25, label="𝛬</sup>*</sup>", legend=False)
        # The parity annotations are positioned next to the two 13-entry
        # halves (indices 0-12 and 13-25) of this layout only; the helicity
        # layout has no such split, so they are drawn only here.
        fig.add_annotation(
            x=13.5,
            y=6,
            font_size=15,
            showarrow=False,
            text="parity-violating",
            textangle=+90,
        )
        fig.add_annotation(
            x=11.5,
            y=19,
            font_size=15,
            text="parity-conserving",
            showarrow=False,
            textangle=-90,
        )

    fig.update_layout(
        autosize=False,
        height=750,
        legend=dict(font_size=16, yanchor="top", y=1, xanchor="left", x=0.85),
        paper_bgcolor="rgba(0, 0, 0, 0)",
        title=dict(
            text=f"Decay rates of {model_name}",
            xanchor="center",
            x=0.5,
            y=0.89,
        ),
    )
    # NOTE(review): update_scenes targets 3D scene layouts; for this 2D heatmap
    # a square aspect ratio would need `scaleanchor="x"` on the y-axis. Confirm
    # the intent.
    fig.update_scenes(aspectmode="data")
    # Put row 0 at the top, as in a matrix.
    fig.update_yaxes(autorange="reversed")
    return fig
def indicate_subsystem(
    fig: go.Figure,
    idx1: int,
    idx2: int,
    label: str,
    legend: bool = True,
    col: int | None = None,
) -> go.Figure:
    """Draw a translucent square over the diagonal block ``idx1..idx2``.

    The rectangle spans ``idx1 - 0.5`` to ``idx2 + 0.5`` on both axes. This
    assumes the heatmap's (categorical) axes place cell ``i`` at coordinate
    ``i``, so the square covers cells ``idx1`` through ``idx2`` inclusive.

    Args:
        fig: Figure to which the rectangle shape is added.
        idx1: First index (inclusive) of the block.
        idx2: Last index (inclusive) of the block.
        label: Subsystem label: ``"𝐾<sup>*</sup>"``, ``"𝛥<sup>*</sup>"`` or
            ``"𝛬<sup>*</sup>"``. The malformed spelling ``"𝛬</sup>*</sup>"`` is
            also accepted for backward compatibility. The label selects the
            colours and is used as the legend entry name.
        legend: Whether the shape gets a legend entry.
        col: If given, add the shape to subplot ``row=1, col=col``.

    Returns:
        The value returned by ``fig.add_shape``.

    Raises:
        KeyError: If ``label`` is not one of the known subsystem labels.
    """
    # Repair the malformed superscript markup ("</sup>*</sup>") that some
    # callers pass, so the colour lookup succeeds and the legend renders a
    # proper superscript.
    label = label.replace("</sup>*</sup>", "<sup>*</sup>")
    colors = {
        "𝐾<sup>*</sup>": ("red", "rgba(255,0,0,0.1)"),
        "𝛬<sup>*</sup>": ("green", "rgba(0,255,0,0.1)"),
        "𝛥<sup>*</sup>": ("blue", "rgba(0,0,255,0.1)"),
    }
    try:
        linecolor, fillcolor = colors[label]
    except KeyError:
        raise KeyError(
            f"Unknown subsystem label {label!r}; expected one of {list(colors)}"
        ) from None

    # Shift from cell centres to cell edges so the square encloses whole cells.
    left = idx1 - 0.5
    right = idx2 + 0.5
    shape_kwargs = dict(
        type="rect",
        x0=left,
        x1=right,
        y0=left,
        y1=right,
        fillcolor=fillcolor,
        line=dict(color=linecolor, width=1),
        name=label,
        opacity=0.3,
        showlegend=legend,
    )
    if col is not None:
        shape_kwargs.update(row=1, col=col)
    return fig.add_shape(**shape_kwargs)