Indexed free symbols
In TR-005, we made use of indexed symbols to create a \(\boldsymbol{K}\)-matrix. The problem with that approach is that IndexedBase objects, and the Indexed instances that result from taking indices on them, behave strangely in an expression tree. The following Expr uses a Symbol and two elements of IndexedBase objects (that is, two Indexed instances):
import sympy as sp
x = sp.Symbol("x")
c = sp.IndexedBase("c")
alpha = sp.IndexedBase("alpha")
expression = c[0, 1] + alpha[2] * x
expression
Although seemingly there are just three free_symbols, there are actually five:
expression.free_symbols
{alpha, alpha[2], c, c[0, 1], x}
This becomes problematic when using lambdify(), particularly through symplot.prepare_sliders().
In addition, while c[0, 1] and alpha[2] are Indexed as expected, alpha and c are Symbols, not IndexedBase:
{s: type(s) for s in expression.free_symbols}
{c: sympy.core.symbol.Symbol,
alpha[2]: sympy.tensor.indexed.Indexed,
x: sympy.core.symbol.Symbol,
c[0, 1]: sympy.tensor.indexed.Indexed,
alpha: sympy.core.symbol.Symbol}
The expression tree partially explains this behavior:
import graphviz
dot = sp.dotprint(expression)
graphviz.Source(dot);
We would like to collapse the nodes under c[0, 1] and alpha[2] into two single Symbol nodes that are still nicely rendered as \(c_{0,1}\) and \(\alpha_2\). The following function does that by converting the [] into subscripts. It keeps the name of the Symbol as short as possible while making sure that it still renders nicely as LaTeX:
from sympy.printing.latex import translate


def to_symbol(idx: sp.Indexed) -> sp.Symbol:
    base_name, _, _ = str(idx).rpartition("[")
    subscript = ",".join(map(str, idx.indices))
    if len(idx.indices) > 1:
        base_name = translate(base_name)
        subscript = "_{" + subscript + "}"
    return sp.Symbol(f"{base_name}{subscript}")
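As a quick check of the naming scheme, the sketch below applies the function to the two Indexed instances from the example expression and inspects the resulting names. The function is repeated so that the snippet runs on its own; the variable names single_index and double_index are only illustrative.

```python
import sympy as sp
from sympy.printing.latex import translate


def to_symbol(idx: sp.Indexed) -> sp.Symbol:
    # Repeated from above so that this snippet is self-contained
    base_name, _, _ = str(idx).rpartition("[")
    subscript = ",".join(map(str, idx.indices))
    if len(idx.indices) > 1:
        base_name = translate(base_name)
        subscript = "_{" + subscript + "}"
    return sp.Symbol(f"{base_name}{subscript}")


c = sp.IndexedBase("c")
alpha = sp.IndexedBase("alpha")

# Illustrative names, not part of the original report
single_index = to_symbol(alpha[2])  # one index: no braces are added
double_index = to_symbol(c[0, 1])  # several indices: wrapped in _{...}

print(single_index.name)
print(double_index.name)
```

With a single index, the index is simply appended to the base name, which SymPy's LaTeX printer already treats as a subscript. Only with more than one index are the explicit braces needed to keep the comma-separated indices together.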
Next, we use subs() to substitute the nodes c[0, 1] and alpha[2] with these Symbols:
def replace_indexed_symbols(expression: sp.Expr) -> sp.Expr:
    return expression.subs({
        s: to_symbol(s)
        for s in expression.free_symbols
        if isinstance(s, sp.Indexed)
    })
And indeed, the expression tree has been simplified correctly!
new_expression = replace_indexed_symbols(expression)
dot = sp.dotprint(new_expression)
graphviz.Source(dot);
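The same can be checked without graphviz. The following is a minimal sketch, not part of the original report, that verifies that only three plain Symbol instances remain as free symbols after the replacement and that the result can be passed to lambdify() with those symbols as arguments. The two helper functions are repeated so that the snippet runs on its own. The argument order and the choice of modules="math" are assumptions made for this illustration.

```python
import sympy as sp
from sympy.printing.latex import translate


def to_symbol(idx: sp.Indexed) -> sp.Symbol:
    # Repeated from above so that this snippet is self-contained
    base_name, _, _ = str(idx).rpartition("[")
    subscript = ",".join(map(str, idx.indices))
    if len(idx.indices) > 1:
        base_name = translate(base_name)
        subscript = "_{" + subscript + "}"
    return sp.Symbol(f"{base_name}{subscript}")


def replace_indexed_symbols(expression: sp.Expr) -> sp.Expr:
    return expression.subs({
        s: to_symbol(s)
        for s in expression.free_symbols
        if isinstance(s, sp.Indexed)
    })


x = sp.Symbol("x")
c = sp.IndexedBase("c")
alpha = sp.IndexedBase("alpha")
expression = c[0, 1] + alpha[2] * x
new_expression = replace_indexed_symbols(expression)

# Only plain Symbol instances should remain after the replacement
print({s: type(s).__name__ for s in new_expression.free_symbols})

# Argument order chosen for this illustration: x, alpha_2, c_{0,1}
args = [x, to_symbol(alpha[2]), to_symbol(c[0, 1])]
func = sp.lambdify(args, new_expression, modules="math")
print(func(5, 2, 3))  # evaluates c_{0,1} + alpha_2 * x with x=5, alpha_2=2, c_{0,1}=3
```

Because each former Indexed node is now a single Symbol, the number of lambdify() arguments matches the number of parameters one would expect to see as sliders.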