Indexed free symbols

In TR-005, we made use of indexed symbols to create a \(\boldsymbol{K}\)-matrix. The problem with that approach is that IndexedBase objects, and the Indexed instances they produce when indexed, behave unexpectedly in an expression tree.

The following Expr combines a Symbol with elements of two IndexedBases (Indexed instances):

import sympy as sp

x = sp.Symbol("x")
c = sp.IndexedBase("c")
alpha = sp.IndexedBase("alpha")
expression = c[0, 1] + alpha[2] * x
\[\displaystyle x {\alpha}_{2} + {c}_{0,1}\]

Although the expression appears to have just three free_symbols, it actually has five:

{alpha, alpha[2], c, c[0, 1], x}

This becomes problematic when using lambdify(), particularly through symplot.prepare_sliders().
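The mismatch can be made explicit by separating the Indexed entries from the rest of free_symbols. The following is only a minimal sketch; the names `free`, `indexed`, and `others` are introduced here for illustration and do not come from the report itself.

```python
import sympy as sp

x = sp.Symbol("x")
c = sp.IndexedBase("c")
alpha = sp.IndexedBase("alpha")
expression = c[0, 1] + alpha[2] * x

# Split the free symbols into Indexed elements and everything else
free = expression.free_symbols
indexed = {s for s in free if isinstance(s, sp.Indexed)}
others = free - indexed

print(len(free))                 # total number of free symbols
print(sorted(map(str, indexed)))  # the elements c[0, 1] and alpha[2]
print(sorted(map(str, others)))   # x, plus the two bases c and alpha
```

The bases c and alpha show up next to their own elements, so any tool that creates one argument or slider per free symbol sees more entries than the expression has independent quantities.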

In addition, while c[0, 1] and alpha[2] are Indexed as expected, alpha and c are Symbols, not IndexedBase:

{s: type(s) for s in expression.free_symbols}
{c: sympy.core.symbol.Symbol,
 alpha[2]: sympy.tensor.indexed.Indexed,
 x: sympy.core.symbol.Symbol,
 c[0, 1]: sympy.tensor.indexed.Indexed,
 alpha: sympy.core.symbol.Symbol}

The expression tree partially explains this behavior:

import graphviz

dot = sp.dotprint(expression)
graphviz.Source(dot)

We would like to collapse the nodes under c[0, 1] and alpha[2] into two single Symbol nodes that are still nicely rendered as \(c_{0,1}\) and \(\alpha_2\). The following function does that by converting the [] into subscripts. It keeps the name of the Symbol as short as possible while still rendering nicely as LaTeX:

from sympy.printing.latex import translate

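def to_symbol(idx: sp.Indexed) -> sp.Symbol:
    """Convert an Indexed instance, such as c[0, 1], to a plain Symbol."""
    # Everything before the last "[" is the name of the IndexedBase
    base_name, _, _ = str(idx).rpartition("[")
    subscript = ",".join(map(str, idx.indices))
    if len(idx.indices) > 1:
        # Several indices: convert the base name to LaTeX (e.g. alpha -> \alpha)
        # and group the indices in an explicit subscript
        base_name = translate(base_name)
        subscript = "_{" + subscript + "}"
    return sp.Symbol(f"{base_name}{subscript}")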

Next, we use subs() to substitute the nodes c[0, 1] and alpha[2] with these Symbols:

def replace_indexed_symbols(expression: sp.Expr) -> sp.Expr:
    return expression.subs(
        {
            s: to_symbol(s)
            for s in expression.free_symbols
            if isinstance(s, sp.Indexed)
        }
    )

And indeed, the expression tree has been simplified correctly!

new_expression = replace_indexed_symbols(expression)
dot = sp.dotprint(new_expression)
graphviz.Source(dot)
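The simplification can also be checked without drawing the tree. The sketch below repeats the two functions so that it runs on its own, then inspects the free symbols of the rewritten expression and passes them to lambdify(). The names `symbols` and `func`, the choice of `modules="math"`, and the sample values are assumptions made for this illustration only.

```python
import sympy as sp
from sympy.printing.latex import translate


def to_symbol(idx: sp.Indexed) -> sp.Symbol:
    base_name, _, _ = str(idx).rpartition("[")
    subscript = ",".join(map(str, idx.indices))
    if len(idx.indices) > 1:
        base_name = translate(base_name)
        subscript = "_{" + subscript + "}"
    return sp.Symbol(f"{base_name}{subscript}")


def replace_indexed_symbols(expression: sp.Expr) -> sp.Expr:
    return expression.subs(
        {
            s: to_symbol(s)
            for s in expression.free_symbols
            if isinstance(s, sp.Indexed)
        }
    )


x = sp.Symbol("x")
c = sp.IndexedBase("c")
alpha = sp.IndexedBase("alpha")
new_expression = replace_indexed_symbols(c[0, 1] + alpha[2] * x)

# Sort by name so that the argument order of the lambdified function is fixed
symbols = sorted(new_expression.free_symbols, key=str)
print([s.name for s in symbols])
print(all(type(s) is sp.Symbol for s in symbols))

# Illustrative values, in sorted order: alpha2=2, c_{0,1}=3, x=5
func = sp.lambdify(symbols, new_expression, modules="math")
print(func(2, 3, 5))
```

Only three plain Symbols remain, one per quantity in the expression, so each of them maps to exactly one argument of the lambdified function.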