
Utils Submodule

SRToolkit.utils

Utilities for expression representation, compilation, generation, and evaluation.

Modules:

Name Description
symbol_library

The SymbolLibrary class — manages the token vocabulary and token properties.

expression_tree

The Node binary-tree representation and conversion utilities for expressions.

expression_compiler

Compiles token-list or tree expressions into executable Python callables.

expression_simplifier

SymPy-backed algebraic simplification, including constant folding.

expression_generator

PCFG construction from a SymbolLibrary and Monte-Carlo expression sampling.

grammar

CFG/PCFG representation, constraint protocol, and stateful derivation — Rule, Grammar, Constraint, Derivation, and more.

measures

Distance and similarity measures: edit distance, tree edit distance, and Behavior-aware Expression Distance (BED).

serialization

Internal JSON serialization utilities for numpy types.

Node

Node(symbol: str, right: Optional[Node] = None, left: Optional[Node] = None)

A node in a binary expression tree.

  • Binary operators ("op") set both left and right.
  • Unary functions ("fn") set only left; right is None.
  • Leaves (variables, constants, literals, numeric values) have both children as None.

Examples:

>>> node = Node("+", Node("x"), Node("1"))
>>> len(node)
3
Warning

The second positional argument is right, not left. When passing children positionally (e.g. Node("+", Node("a"), Node("b"))), Node("a") becomes the right child and Node("b") the left. Use keyword arguments to avoid confusion: Node("+", right=Node("a"), left=Node("b")).
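Because the parameter order is symbol, right, left, a positional call silently swaps the operands. This matters for non-commutative operators. The sketch below uses a hypothetical stand-in class with the same parameter order and a tiny reader that prints "left symbol right" for binary nodes. It is an illustration, not SRToolkit's Node.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Hypothetical stand-in with the same parameter order as Node:
    # symbol, then right, then left.
    symbol: str
    right: Optional["Node"] = None
    left: Optional["Node"] = None


def infix(node: Node) -> str:
    # Read a binary tree as "left symbol right"; leaves print their symbol.
    if node.left is None and node.right is None:
        return node.symbol
    return f"{infix(node.left)} {node.symbol} {infix(node.right)}"


positional = Node("-", Node("a"), Node("b"))          # Node("a") becomes the RIGHT child
keyword = Node("-", left=Node("a"), right=Node("b"))  # explicit and unambiguous

print(infix(positional))  # b - a
print(infix(keyword))     # a - b
```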

Parameters:

Name Type Description Default
symbol str

Token string stored at this node.

required
right Optional[Node]

Right operand (binary operators only).

None
left Optional[Node]

Left operand (operators and unary functions).

None
Source code in SRToolkit/utils/expression_tree.py
def __init__(self, symbol: str, right: Optional["Node"] = None, left: Optional["Node"] = None) -> None:
    """
    A node in a binary expression tree.

    - Binary operators (``"op"``) set both ``left`` and ``right``.
    - Unary functions (``"fn"``) set only ``left``; ``right`` is ``None``.
    - Leaves (variables, constants, literals, numeric values) have both children as ``None``.

    Examples:
        >>> node = Node("+", Node("x"), Node("1"))
        >>> len(node)
        3

    Warning:
        The second positional argument is ``right``, not ``left``. When passing
        children positionally (e.g. ``Node("+", Node("a"), Node("b"))``),
        ``Node("a")`` becomes the ``right`` child and ``Node("b")`` the ``left``.
        Use keyword arguments to avoid confusion: ``Node("+", right=Node("a"), left=Node("b"))``.

    Args:
        symbol: Token string stored at this node.
        right: Right operand (binary operators only).
        left: Left operand (operators and unary functions).
    """
    self.symbol = symbol
    self.right = right
    self.left = left

to_list

to_list(symbol_library: Optional[SymbolLibrary] = None, notation: str = 'infix') -> List[str]

Transforms the tree rooted at this node into a list of tokens.

Examples:

>>> node = Node("+", Node("X_0"), Node("1"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['1', '+', 'X_0']
>>> node.to_list(notation="postfix")
['1', 'X_0', '+']
>>> node.to_list(notation="prefix")
['+', '1', 'X_0']
>>> node = Node("+", Node("*", Node("X_0"), Node("X_1")), Node("1"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['1', '+', 'X_1', '*', 'X_0']
>>> node.to_list(notation="infix")
['1', '+', '(', 'X_1', '*', 'X_0', ')']
>>> node = Node("sin", None, Node("X_0"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['sin', '(', 'X_0', ')']
>>> node = Node("^2", None, Node("X_0"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['X_0', '^2']
>>> node.to_list()
['(', 'X_0', ')', '^2']
>>> node = Node("*", Node("*", Node("X_0"), Node("X_0")),  Node("X_0"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols(),notation="infix")
['X_0', '*', '(', 'X_0', '*', 'X_0', ')']

Parameters:

Name Type Description Default
symbol_library Optional[SymbolLibrary]

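Symbol library used to determine token types and precedences during infix reconstruction. If None, the currently active library (set via 'with SymbolLibrary(...) as sl:') is used instead; if no library is active and notation is "infix", a UserWarning is issued and the output may contain redundant parentheses.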

None
notation str

Output notation: "infix", "prefix", or "postfix". Default "infix".

'infix'

Returns:

Type Description
List[str]

Token list representing the subtree rooted at this node.

Raises:

Type Description
Exception

If notation is not one of the accepted values.

Exception

If symbol_library is provided and a token's type cannot be resolved during infix reconstruction.
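With a symbol library, infix reconstruction wraps a child operator in parentheses when -1 < precedence(child) <= precedence(parent). The following standalone sketch applies that rule to binary operators and leaves only (no unary functions). It assumes a hypothetical precedence table where + is 0, * is 1, and any symbol missing from the table (a leaf) counts as -1. It reuses the documented traversal orders but is not SRToolkit's implementation.

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical precedences; symbols not listed (leaves) are treated as -1.
PRECEDENCE = {"+": 0, "*": 1}


@dataclass
class Node:
    symbol: str
    right: Optional["Node"] = None
    left: Optional["Node"] = None


def to_list(node: Node, notation: str = "infix") -> List[str]:
    left = [] if node.left is None else to_list(node.left, notation)
    right = [] if node.right is None else to_list(node.right, notation)
    if notation == "prefix":
        return [node.symbol] + left + right
    if notation == "postfix":
        return left + right + [node.symbol]
    if notation != "infix":
        raise ValueError(f"Unknown notation: {notation}")  # the real method raises Exception
    if node.left is None and node.right is None:
        return [node.symbol]
    prec = PRECEDENCE[node.symbol]
    # Wrap a child operator whose precedence is not higher than the parent's.
    if node.left is not None and -1 < PRECEDENCE.get(node.left.symbol, -1) <= prec:
        left = ["("] + left + [")"]
    if node.right is not None and -1 < PRECEDENCE.get(node.right.symbol, -1) <= prec:
        right = ["("] + right + [")"]
    return left + [node.symbol] + right


tree = Node("*", Node("*", Node("X_0"), Node("X_0")), Node("X_0"))
print(to_list(tree))  # ['X_0', '*', '(', 'X_0', '*', 'X_0', ')']
```

Equal precedence triggers parentheses, which is why the nested product above is wrapped while X_1 * X_0 under a + is not.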

Source code in SRToolkit/utils/expression_tree.py
def to_list(self, symbol_library: Optional[SymbolLibrary] = None, notation: str = "infix") -> List[str]:
    """
    Transforms the tree rooted at this node into a list of tokens.

    Examples:
        >>> node = Node("+", Node("X_0"), Node("1"))
        >>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
        ['1', '+', 'X_0']
        >>> node.to_list(notation="postfix")
        ['1', 'X_0', '+']
        >>> node.to_list(notation="prefix")
        ['+', '1', 'X_0']
        >>> node = Node("+", Node("*", Node("X_0"), Node("X_1")), Node("1"))
        >>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
        ['1', '+', 'X_1', '*', 'X_0']
        >>> node.to_list(notation="infix")
        ['1', '+', '(', 'X_1', '*', 'X_0', ')']
        >>> node = Node("sin", None, Node("X_0"))
        >>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
        ['sin', '(', 'X_0', ')']
        >>> node = Node("^2", None, Node("X_0"))
        >>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
        ['X_0', '^2']
        >>> node.to_list()
        ['(', 'X_0', ')', '^2']
        >>> node = Node("*", Node("*", Node("X_0"), Node("X_0")),  Node("X_0"))
        >>> node.to_list(symbol_library=SymbolLibrary.default_symbols(),notation="infix")
        ['X_0', '*', '(', 'X_0', '*', 'X_0', ')']

    Args:
        symbol_library: Symbol library used to determine token types and precedences
            during infix reconstruction. If ``None``, the active library (set via
            ``with SymbolLibrary(...) as sl:``) is used; if none is active and
            ``notation`` is ``"infix"``, a ``UserWarning`` is issued and the output
            may contain redundant parentheses.
        notation: Output notation: ``"infix"``, ``"prefix"``, or ``"postfix"``.
            Default ``"infix"``.

    Returns:
        Token list representing the subtree rooted at this node.

    Raises:
        Exception: If ``notation`` is not one of the accepted values.
        Exception: If ``symbol_library`` is provided and a token's type cannot be
            resolved during infix reconstruction.
    """
    # if symbol_library is None:
    #     symbol_library = SymbolLibrary.default_symbols()

    left = [] if self.left is None else self.left.to_list(symbol_library, notation)
    right = [] if self.right is None else self.right.to_list(symbol_library, notation)

    if notation == "prefix":
        return [self.symbol] + left + right

    elif notation == "postfix":
        return left + right + [self.symbol]

    elif notation == "infix" and symbol_library is None:
        try:
            symbol_library = SymbolLibrary.get_active()
        except RuntimeError:
            pass

    if notation == "infix" and symbol_library is None:
        warnings.warn(
            "Symbol library not provided. Generated expression may contain unnecessary parentheses and"
            " have other issues."
        )
        if self.left is None and self.right is None:
            return [self.symbol]
        if self.right is None and self.left is not None:
            if self.symbol[0] == "^":
                return ["("] + left + [")", self.symbol]
            else:
                return [self.symbol, "("] + left + [")"]
        else:
            if len(left) > 1:
                left = ["("] + left + [")"]
            if len(right) > 1:
                right = ["("] + right + [")"]
            return left + [self.symbol] + right

    if notation == "infix":
        assert symbol_library is not None, "[Node.to_list] parameter symbol_library should be of type SymbolLibrary"
        if is_float(self.symbol):
            return [self.symbol]
        if symbol_library.get_type(self.symbol) in ["var", "const", "lit"]:
            return [self.symbol]
        elif symbol_library.get_type(self.symbol) == "fn":
            if symbol_library.get_precedence(self.symbol) > 0:
                return [self.symbol, "("] + left + [")"]
            else:
                if len(left) > 1:
                    left = ["("] + left + [")"]
                return left + [self.symbol]
        elif symbol_library.get_type(self.symbol) == "op":
            if (
                self.left is not None
                and not is_float(self.left.symbol)
                and -1
                < symbol_library.get_precedence(self.left.symbol)
                <= symbol_library.get_precedence(self.symbol)
            ):
                left = ["("] + left + [")"]
            if (
                self.right is not None
                and not is_float(self.right.symbol)
                and -1
                < symbol_library.get_precedence(self.right.symbol)
                <= symbol_library.get_precedence(self.symbol)
            ):
                right = ["("] + right + [")"]
            return left + [self.symbol] + right
        else:
            raise Exception(f"Invalid symbol type for symbol {self.symbol}.")
    else:
        raise Exception(
            "Invalid notation selected. Use 'infix', 'prefix', 'postfix', or leave blank (defaults to 'infix')."
        )

to_latex

to_latex(symbol_library: Optional[SymbolLibrary] = None) -> str

Transforms the tree rooted at this node into a LaTeX expression.

Examples:

>>> node = Node("+", right=Node("X_0"), left=Node("1"))
>>> node.to_latex(symbol_library=SymbolLibrary.default_symbols())
'$1 + X_{0}$'
>>> node = Node("+", right=Node("*", right=Node("X_0"), left=Node("X_1")), left=Node("1"))
>>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
$1 + X_{1} \cdot X_{0}$
>>> node = Node("sin", None, Node("X_0"))
>>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
$\sin X_{0}$
>>> node = Node("+", right=Node("*", right=Node("X_0"), left=Node("C")), left=Node("C"))
>>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
$C_{0} + C_{1} \cdot X_{0}$

Parameters:

Name Type Description Default
symbol_library Optional[SymbolLibrary]

Symbol library providing LaTeX templates for each token. If None, falls back to the currently active library set via 'with SymbolLibrary(...) as sl:'. Defaults to None.

None

Returns:

Type Description
str

A LaTeX string of the form $...$.

Raises:

Type Description
Exception

If the tree contains a token whose type cannot be resolved in symbol_library.
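In the last example, the two C placeholders are rendered as C_{0} and C_{1}, numbered left to right. The sketch below shows one way to produce that numbering from an infix token list. It assumes that numbering follows reading order and uses a hypothetical token-to-LaTeX table (C, X_i, *). The real conversion is done by a private recursive helper that is not shown on this page, so treat this as an illustration only.

```python
import re
from typing import List


def latex_from_tokens(tokens: List[str]) -> str:
    # Hypothetical mapping: "C" -> C_{k} (k counts up), "X_i" -> X_{i}, "*" -> \cdot.
    out: List[str] = []
    next_const = 0
    for tok in tokens:
        if tok == "C":
            out.append(f"C_{{{next_const}}}")
            next_const += 1
        elif re.fullmatch(r"X_\d+", tok):
            out.append("X_{" + tok[2:] + "}")
        elif tok == "*":
            out.append(r"\cdot")
        else:
            out.append(tok)
    return "$" + " ".join(out) + "$"


print(latex_from_tokens(["C", "+", "C", "*", "X_0"]))  # $C_{0} + C_{1} \cdot X_{0}$
```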

Source code in SRToolkit/utils/expression_tree.py
def to_latex(self, symbol_library: Optional[SymbolLibrary] = None) -> str:
    r"""
    Transforms the tree rooted at this node into a LaTeX expression.

    Examples:
        >>> node = Node("+", right=Node("X_0"), left=Node("1"))
        >>> node.to_latex(symbol_library=SymbolLibrary.default_symbols())
        '$1 + X_{0}$'
        >>> node = Node("+", right=Node("*", right=Node("X_0"), left=Node("X_1")), left=Node("1"))
        >>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
        $1 + X_{1} \cdot X_{0}$
        >>> node = Node("sin", None, Node("X_0"))
        >>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
        $\sin X_{0}$
        >>> node = Node("+", right=Node("*", right=Node("X_0"), left=Node("C")), left=Node("C"))
        >>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
        $C_{0} + C_{1} \cdot X_{0}$

    Args:
        symbol_library: Symbol library providing LaTeX templates for each token.
            If None, falls back to the currently active library set via
            'with SymbolLibrary(...) as sl:'. Defaults to None.

    Returns:
        A LaTeX string of the form ``$...$``.

    Raises:
        Exception: If the tree contains a token whose type cannot be resolved in
            ``symbol_library``.
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_active()
    return f"${self.__to_latex_rec(symbol_library)[0]}$"

height

height() -> int

Return the height of the subtree rooted at this node.

A single-node tree has height 1.

Examples:

>>> node = Node("+", Node("x"), Node("1"))
>>> node.height()
2

Returns:

Type Description
int

Height of the subtree.

Source code in SRToolkit/utils/expression_tree.py
def height(self) -> int:
    """
    Return the height of the subtree rooted at this node.

    A single-node tree has height 1.

    Examples:
        >>> node = Node("+", Node("x"), Node("1"))
        >>> node.height()
        2

    Returns:
        Height of the subtree.
    """
    return 1 + max(
        (self.left.height() if self.left is not None else 0),
        (self.right.height() if self.right is not None else 0),
    )

__len__

__len__() -> int

Return the number of nodes in the subtree rooted at this node.

Examples:

>>> node = Node("+", Node("x"), Node("1"))
>>> len(node)
3

Returns:

Type Description
int

Total node count of the subtree.
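The height and node-count recursions can be checked side by side on a slightly deeper tree. The sketch below uses a hypothetical stand-in class with the same symbol/right/left layout. A unary function such as sin still adds one level and one node.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    # Hypothetical stand-in: symbol, right, left (same order as Node).
    symbol: str
    right: Optional["Node"] = None
    left: Optional["Node"] = None


def height(node: Node) -> int:
    # A leaf has height 1; each level on the longest path adds one.
    return 1 + max(
        height(node.left) if node.left is not None else 0,
        height(node.right) if node.right is not None else 0,
    )


def size(node: Node) -> int:
    # Count this node plus every node in both subtrees.
    return (1
            + (size(node.left) if node.left is not None else 0)
            + (size(node.right) if node.right is not None else 0))


# sin(X_0 * X_1) + 1: "sin" is unary, so it only has a left child.
tree = Node("+", right=Node("1"),
            left=Node("sin", left=Node("*", left=Node("X_0"), right=Node("X_1"))))
print(height(tree), size(tree))  # 4 6
```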

Source code in SRToolkit/utils/expression_tree.py
def __len__(self) -> int:
    """
    Return the number of nodes in the subtree rooted at this node.

    Examples:
        >>> node = Node("+", Node("x"), Node("1"))
        >>> len(node)
        3

    Returns:
        Total node count of the subtree.
    """
    return 1 + (len(self.left) if self.left is not None else 0) + (len(self.right) if self.right is not None else 0)

__str__

__str__() -> str

Return the expression as a single string built from to_list() in infix notation. Unless a SymbolLibrary is active, the output may contain redundant parentheses.

Examples:

>>> node = Node("+", Node("x"), Node("1"))
>>> str(node)
'1+x'

Returns:

Type Description
str

Concatenated token string with no spaces.

Source code in SRToolkit/utils/expression_tree.py
def __str__(self) -> str:
    """
    Return the expression as a single string built from ``to_list()`` in infix notation.
    Unless a SymbolLibrary is active, the output may contain redundant parentheses.

    Examples:
        >>> node = Node("+", Node("x"), Node("1"))
        >>> str(node)
        '1+x'

    Returns:
        Concatenated token string with no spaces.
    """
    return "".join(self.to_list())

__copy__

__copy__() -> Node

Return a deep copy of the subtree rooted at this node.

Examples:

>>> node = Node("+", Node("X_0"), Node("1"))
>>> new_node = copy(node)
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['1', '+', 'X_0']
>>> new_node.to_list(symbol_library=SymbolLibrary.default_symbols())
['1', '+', 'X_0']
>>> node == node
True
>>> node == new_node
False

Returns:

Type Description
Node

An independent copy of the subtree.
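Defining __copy__ recursively makes copy.copy return a deep copy, because copy.copy delegates to a class's __copy__ when one is defined. In the standalone sketch below, the stand-in class is hypothetical. eq=False keeps identity-based ==, which matches node == new_node evaluating to False in the example above.

```python
from copy import copy
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)  # identity equality, like the documented Node
class Node:
    symbol: str
    right: Optional["Node"] = None
    left: Optional["Node"] = None

    def __copy__(self) -> "Node":
        # Recurse into children so the copy shares no nodes with the original.
        left = copy(self.left) if self.left is not None else None
        right = copy(self.right) if self.right is not None else None
        return Node(self.symbol, left=left, right=right)


original = Node("+", Node("X_0"), Node("1"))  # right=X_0, left=1
dup = copy(original)
dup.left.symbol = "2"          # mutate the copy only
print(original.left.symbol)    # 1
print(original == dup)         # False
```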

Source code in SRToolkit/utils/expression_tree.py
def __copy__(self) -> "Node":
    """
    Return a deep copy of the subtree rooted at this node.

    Examples:
        >>> node = Node("+", Node("X_0"), Node("1"))
        >>> new_node = copy(node)
        >>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
        ['1', '+', 'X_0']
        >>> new_node.to_list(symbol_library=SymbolLibrary.default_symbols())
        ['1', '+', 'X_0']
        >>> node == node
        True
        >>> node == new_node
        False

    Returns:
        An independent copy of the subtree.
    """
    if self.left is not None:
        left = copy(self.left)
    else:
        left = None
    if self.right is not None:
        right = copy(self.right)
    else:
        right = None
    return Node(copy(self.symbol), left=left, right=right)

AncestorInfo dataclass

AncestorInfo(nonterminal: str, rule: Rule, child_index: int)

One entry in an ExpansionContext's ancestor chain.

Attributes:

Name Type Description
nonterminal str

The non-terminal symbol at this ancestor position.

rule Rule

The rule applied at this position.

child_index int

The index within rule.rhs that leads toward the current slot, counting only non-terminal positions.
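Because child_index counts only non-terminal positions, terminals in rule.rhs are skipped when numbering. The following sketch shows that counting. The helper name is hypothetical; it mirrors the nt_child_index counter in the Derivation.apply source on this page.

```python
from typing import Dict, FrozenSet, List


def nonterminal_child_indices(rhs: List[str], nonterminals: FrozenSet[str]) -> Dict[int, int]:
    """Map each rhs position that holds a non-terminal to its child_index."""
    indices: Dict[int, int] = {}
    next_index = 0
    for position, symbol in enumerate(rhs):
        if symbol in nonterminals:  # terminals are skipped, not counted
            indices[position] = next_index
            next_index += 1
    return indices


print(nonterminal_child_indices(["E", "+", "F"], frozenset({"E", "F"})))  # {0: 0, 2: 1}
```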

Constraint

Bases: Generic[L, G]

Base class for derivation constraints (hard filters).

Subclass and override the methods you need. Constraints carry only construction-time configuration; all per-derivation state is managed by the engine and threaded through method arguments.

Scoping — set nonterminals and/or rule_names to restrict allows to a subset of slots or rules. A scope miss is treated as implicit acceptance. update is always called regardless of scope so that global counters stay accurate.

Examples:

>>> from SRToolkit.utils.grammar import Constraint, Grammar, Rule
>>> g = Grammar([Rule("E", ["x"]), Rule("E", ["y"])])
>>> class RejectY(Constraint):
...     def allows(self, slot, rule, global_): return rule.rhs != ["y"]
>>> g.add_constraint(RejectY())
>>> d = g.start_derivation("E")
>>> [r.rhs for r in d.options()]
[['x']]
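The scoping rule (a scope miss counts as acceptance) can be sketched without SRToolkit as a filter loop over candidate rules. This is a hypothetical simplification with three assumptions. First, allows takes only the rule here, whereas the real signature also receives slot and global_. Second, Rule gains an optional name field for rule_names scoping. Third, a scope of None means "unscoped".

```python
from dataclasses import dataclass
from typing import Iterable, List, Optional, Set


@dataclass
class Rule:
    lhs: str
    rhs: List[str]
    name: Optional[str] = None  # hypothetical label used for rule_names scoping


class ScopedConstraint:
    """Hypothetical constraint carrying only construction-time scope config."""

    def __init__(self, nonterminals: Optional[Set[str]] = None,
                 rule_names: Optional[Set[str]] = None) -> None:
        self.nonterminals = nonterminals
        self.rule_names = rule_names

    def allows(self, rule: Rule) -> bool:
        return True


def scope_miss(c: ScopedConstraint, nonterminal: str, rule: Rule) -> bool:
    # Outside the configured scope means the constraint does not judge this rule.
    if c.nonterminals is not None and nonterminal not in c.nonterminals:
        return True
    if c.rule_names is not None and rule.name not in c.rule_names:
        return True
    return False


def filter_rules(constraints: Iterable[ScopedConstraint], nonterminal: str,
                 candidates: List[Rule]) -> List[Rule]:
    for c in constraints:
        # A scope miss counts as implicit acceptance.
        candidates = [r for r in candidates
                      if scope_miss(c, nonterminal, r) or c.allows(r)]
        if not candidates:
            break
    return candidates


class RejectY(ScopedConstraint):
    def allows(self, rule: Rule) -> bool:
        return rule.rhs != ["y"]


only_in_f = RejectY(nonterminals={"F"})
e_rules = [Rule("E", ["x"]), Rule("E", ["y"])]
f_rules = [Rule("F", ["x"]), Rule("F", ["y"])]
print([r.rhs for r in filter_rules([only_in_f], "E", e_rules)])  # [['x'], ['y']]
print([r.rhs for r in filter_rules([only_in_f], "F", f_rules)])  # [['x']]
```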

to_dict

to_dict() -> dict

Serialise this constraint to a JSON-safe dictionary.

The base implementation stores only the fully-qualified class path under constraint_class. Subclasses should call super().to_dict() and add their own constructor arguments so that from_dict can reconstruct the instance faithfully. See from_dict for an example.

Returns:

Type Description
dict

Dictionary with at least the key constraint_class.

Source code in SRToolkit/utils/grammar/constraints.py
def to_dict(self) -> dict:
    """
    Serialise this constraint to a JSON-safe dictionary.

    The base implementation stores only the fully-qualified class path under
    ``constraint_class``. Subclasses should call ``super().to_dict()`` and
    add their own constructor arguments so that ``from_dict`` can reconstruct
    the instance faithfully. See
    [from_dict][SRToolkit.utils.grammar.constraints.Constraint.from_dict]
    for an example.

    Returns:
        Dictionary with at least the key ``constraint_class``.
    """
    return {"constraint_class": f"{self.__class__.__module__}.{self.__class__.__qualname__}"}

from_dict classmethod

from_dict(d: dict) -> 'Constraint'

Reconstruct a constraint from a dictionary produced by to_dict.

When called on the base Constraint class, dispatches to the correct subclass via the constraint_class key using importlib — both built-in and user-defined subclasses are supported. When called on a concrete subclass, the subclass must override this method.

The dictionary must contain at minimum the key constraint_class, whose value is the fully-qualified class path (e.g. "mymodule.MyConstraint"). Any additional keys are forwarded to the subclass override.

To make a custom subclass serialisable, override both to_dict and from_dict:

from SRToolkit.utils.grammar import Constraint

class MyConstraint(Constraint):
    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def to_dict(self) -> dict:
        # Keep the base "constraint_class" entry and add constructor arguments.
        return {**super().to_dict(), "threshold": self.threshold}

    @classmethod
    def from_dict(cls, d: dict) -> "MyConstraint":
        return cls(d["threshold"])

Examples:

>>> from SRToolkit.utils.grammar import Constraint, MaxDepth
>>> c = Constraint.from_dict(MaxDepth(5).to_dict())
>>> isinstance(c, MaxDepth) and c.limit == 5
True

Parameters:

Name Type Description Default
d dict

Dictionary previously returned by to_dict.

required

Returns:

Type Description
'Constraint'

A reconstructed Constraint instance.

Raises:

Type Description
KeyError

If constraint_class is missing from d (dispatch path).

ImportError

If the module cannot be imported (dispatch path).

AttributeError

If the class cannot be found in the module (dispatch path).

NotImplementedError

If called on a subclass that has not overridden this method.
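The dispatch path can be exercised without SRToolkit. In the sketch below, Base is a hypothetical class that mirrors the to_dict/from_dict pair in the source that follows. A throwaway module registered in sys.modules stands in for a user package; importlib.import_module returns an already-registered module without touching the filesystem. Reassigning Threshold.__module__ is only a trick to make the demo's class path importable.

```python
import importlib
import sys
import types


class Base:
    """Hypothetical stand-in for Constraint's serialisation protocol."""

    def to_dict(self) -> dict:
        cls = type(self)
        return {"constraint_class": f"{cls.__module__}.{cls.__qualname__}"}

    @classmethod
    def from_dict(cls, d: dict) -> "Base":
        if cls is Base:
            # Resolve "module.Class" and let the subclass rebuild itself.
            module_path, cls_name = d["constraint_class"].rsplit(".", 1)
            resolved = getattr(importlib.import_module(module_path), cls_name)
            return resolved.from_dict(d)
        raise NotImplementedError(f"{cls.__name__}.from_dict is not implemented.")


class Threshold(Base):
    def __init__(self, limit: int) -> None:
        self.limit = limit

    def to_dict(self) -> dict:
        return {**super().to_dict(), "limit": self.limit}

    @classmethod
    def from_dict(cls, d: dict) -> "Threshold":
        return cls(d["limit"])


# Register a throwaway module and pretend Threshold lives in it.
plugins = types.ModuleType("demo_plugins")
sys.modules["demo_plugins"] = plugins
Threshold.__module__ = "demo_plugins"
plugins.Threshold = Threshold

payload = Threshold(5).to_dict()
restored = Base.from_dict(payload)
print(payload)         # {'constraint_class': 'demo_plugins.Threshold', 'limit': 5}
print(restored.limit)  # 5
```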

Source code in SRToolkit/utils/grammar/constraints.py
@classmethod
def from_dict(cls, d: dict) -> "Constraint":
    """
    Reconstruct a constraint from a dictionary produced by ``to_dict``.

    When called on the base [Constraint][SRToolkit.utils.grammar.constraints.Constraint]
    class, dispatches to the correct subclass via the ``constraint_class`` key using
    `importlib` — both built-in and user-defined subclasses are supported.
    When called on a concrete subclass, the subclass must override this method.

    The dictionary must contain at minimum the key ``constraint_class``, whose value
    is the fully-qualified class path (e.g. ``"mymodule.MyConstraint"``).  Any
    additional keys are forwarded to the subclass override.

    To make a custom subclass serialisable, override both ``to_dict`` and
    ``from_dict``::

        class MyConstraint(Constraint):
            def __init__(self, threshold: float) -> None:
                self.threshold = threshold

            def to_dict(self) -> dict:
                return {**super().to_dict(), "threshold": self.threshold}

            @classmethod
            def from_dict(cls, d: dict) -> "MyConstraint":
                return cls(d["threshold"])

    Examples:
        >>> from SRToolkit.utils.grammar import Constraint, MaxDepth
        >>> c = Constraint.from_dict(MaxDepth(5).to_dict())
        >>> isinstance(c, MaxDepth) and c.limit == 5
        True

    Args:
        d: Dictionary previously returned by ``to_dict``.

    Returns:
        A reconstructed [Constraint][SRToolkit.utils.grammar.constraints.Constraint] instance.

    Raises:
        KeyError: If ``constraint_class`` is missing from ``d`` (dispatch path).
        ImportError: If the module cannot be imported (dispatch path).
        AttributeError: If the class cannot be found in the module (dispatch path).
        NotImplementedError: If called on a subclass that has not overridden this method.
    """
    if cls is Constraint:
        class_path = d["constraint_class"]
        module_path, cls_name = class_path.rsplit(".", 1)
        resolved = getattr(importlib.import_module(module_path), cls_name)
        return resolved.from_dict(d)
    raise NotImplementedError(f"{cls.__name__}.from_dict is not implemented.")

initial_local

initial_local(start: str) -> L

Return the initial local state for the root slot of a derivation that starts at the non-terminal start. The base implementation returns None.

Source code in SRToolkit/utils/grammar/constraints.py
def initial_local(self, start: str) -> L:
    """Return the initial local state for the constraint."""
    return None  # type: ignore[return-value]

initial_global

initial_global() -> G

Return the initial global state for the constraint. It is called once for each new derivation; the base implementation returns None.

Source code in SRToolkit/utils/grammar/constraints.py
def initial_global(self) -> G:
    """Return the initial global state for the constraint."""
    return None  # type: ignore[return-value]

allows

allows(slot: ExpansionContext[L], rule: Rule, global_: G) -> bool

Return True if rule may be applied at slot.

Called only when the slot's non-terminal and rule name are within scope.

Parameters:

Name Type Description Default
slot ExpansionContext[L]

Current derivation position with this constraint's local state.

required
rule Rule

Candidate production rule.

required
global_ G

Current per-derivation global state.

required

Returns:

Type Description
bool

True to keep the rule in the candidate set.

Source code in SRToolkit/utils/grammar/constraints.py
def allows(self, slot: ExpansionContext[L], rule: Rule, global_: G) -> bool:
    """
    Return ``True`` if ``rule`` may be applied at ``slot``.

    Called only when the slot's non-terminal and rule name are within scope.

    Args:
        slot: Current derivation position with this constraint's local state.
        rule: Candidate production rule.
        global_: Current per-derivation global state.

    Returns:
        ``True`` to keep the rule in the candidate set.
    """
    return True

update

update(slot: ExpansionContext[L], rule: Rule, global_: G) -> tuple[list[L], G]

Called after rule is applied at slot.

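Returns per-child local states (one per non-terminal in rule.rhs, in order) and the new global state. Return a modified global_ value to update per-derivation counters. The base implementation passes the parent slot's local state unchanged to every non-terminal child and returns global_ as is.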

Parameters:

Name Type Description Default
slot ExpansionContext[L]

Derivation position immediately before the rule was applied.

required
rule Rule

The rule that was applied.

required
global_ G

Global state before this application.

required

Returns:

Type Description
tuple[list[L], G]

(child_locals, new_global), where len(child_locals) equals the number of non-terminals in rule.rhs.

Source code in SRToolkit/utils/grammar/constraints.py
def update(self, slot: ExpansionContext[L], rule: Rule, global_: G) -> tuple[list[L], G]:
    """
    Called after ``rule`` is applied at ``slot``.

    Returns per-child local states (one per non-terminal in ``rule.rhs``,
    in order) and the new global state. May be used to update global
    counters by returning a new ``global_`` value.

    Args:
        slot: Derivation position immediately before the rule was applied.
        rule: The rule that was applied.
        global_: Global state before this application.

    Returns:
        ``(child_locals, new_global)`` where ``len(child_locals)`` equals
        the number of non-terminals in ``rule.rhs``.
    """
    n = sum(1 for s in rule.rhs if s in slot.nonterminals)
    return [slot.local] * n, global_
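The four hooks fit together as follows. The sketch below defines a hypothetical depth-limiting constraint, not SRToolkit's MaxDepth. It keeps the slot depth as its local state and a count of applied rules as its global state. Slot is a stand-in for ExpansionContext that exposes only the nonterminals and local attributes used by the base update. In the library, the derivation engine calls these hooks; here they are called by hand.

```python
from dataclasses import dataclass
from typing import FrozenSet, List, Tuple


@dataclass
class Rule:
    lhs: str
    rhs: List[str]


@dataclass
class Slot:
    # Hypothetical stand-in for ExpansionContext[int].
    nonterminal: str
    nonterminals: FrozenSet[str]
    local: int


class DepthLimit:
    """Hypothetical constraint: local = depth of the slot, global = rules applied."""

    def __init__(self, limit: int) -> None:
        self.limit = limit  # construction-time configuration only

    def initial_local(self, start: str) -> int:
        return 1  # the root slot is at depth 1

    def initial_global(self) -> int:
        return 0  # no rules applied yet

    def allows(self, slot: Slot, rule: Rule, global_: int) -> bool:
        # At the limit, only rules without non-terminal children survive.
        opens_children = any(s in slot.nonterminals for s in rule.rhs)
        return slot.local < self.limit or not opens_children

    def update(self, slot: Slot, rule: Rule, global_: int) -> Tuple[List[int], int]:
        # One child local per non-terminal in rhs, each one level deeper.
        n = sum(1 for s in rule.rhs if s in slot.nonterminals)
        return [slot.local + 1] * n, global_ + 1


nts = frozenset({"E"})
grow, stop = Rule("E", ["E", "+", "E"]), Rule("E", ["x"])
c = DepthLimit(limit=2)

root = Slot("E", nts, c.initial_local("E"))
child_locals, g = c.update(root, grow, c.initial_global())
child = Slot("E", nts, child_locals[0])
print(child_locals, g)                                          # [2, 2] 1
print([r.rhs for r in (grow, stop) if c.allows(child, r, g)])  # [['x']]
```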

Derivation

Derivation(grammar: Grammar, start: str)

Stateful leftmost derivation over a Grammar.

A derivation begins at a start non-terminal and proceeds by repeatedly choosing a rule to apply to the leftmost unexpanded non-terminal. options returns candidate rules filtered by all registered constraints; apply advances the derivation by one production.

Per-slot local state and per-derivation global state for each registered constraint are maintained internally; constraint instances carry only construction-time configuration and are safe to share across parallel derivations.

Obtain a Derivation via Grammar.start_derivation.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.complete
False
>>> opts = d.options()
>>> len(opts)
1
>>> d.apply(opts[0])
>>> d.complete
True
>>> d.to_token_list()
['x']
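A leftmost derivation can be pictured as a stack whose top is always the leftmost pending symbol. This is similar in spirit to the comment that _frames[-1] is the current leftmost open slot in the source below. The following constraint-free sketch is hypothetical and is not the Derivation implementation. It uses a plain dict grammar and a list of rule indices instead of Rule objects, and reproduces the token order of the apply example further down.

```python
from typing import Dict, List

# Hypothetical grammar encoding: non-terminal -> list of right-hand sides.
GRAMMAR: Dict[str, List[List[str]]] = {
    "E": [["E", "+", "F"], ["x"]],
    "F": [["y"]],
}


def leftmost_derive(grammar: Dict[str, List[List[str]]], start: str,
                    choices: List[int]) -> List[str]:
    """Expand the leftmost open non-terminal at each step.

    choices[i] is the index of the rule applied at step i.
    """
    stack = [start]
    tokens: List[str] = []
    steps = iter(choices)
    while stack:
        sym = stack.pop()
        if sym not in grammar:
            tokens.append(sym)  # terminal: emit it in order
            continue
        rhs = grammar[sym][next(steps)]
        # Push in reverse so the leftmost child ends up on top of the stack.
        stack.extend(reversed(rhs))
    return tokens


print(leftmost_derive(GRAMMAR, "E", [0, 1, 0]))  # ['x', '+', 'y']
```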
Source code in SRToolkit/utils/grammar/derivation.py
def __init__(self, grammar: Grammar, start: str) -> None:
    self._grammar = grammar
    self._nonterminals: frozenset[str] = frozenset(grammar.nonterminals)
    self._steps: int = 0
    self._root: ParseTreeNode = ParseTreeNode(start, None)

    initial_frame = _Frame(
        nonterminal=start,
        node=self._root,
        ancestors=(),
        parent_rule=None,
        child_index=None,
    )
    self._globals: dict[int, Any] = {}
    for c in grammar._constraints:
        cid = id(c)
        initial_frame.locals[cid] = c.initial_local(start)
        self._globals[cid] = c.initial_global()
    # _frames[-1] is the current leftmost open slot.
    self._frames: list[_Frame] = [initial_frame]

complete property

complete: bool

True when no unexpanded non-terminals remain.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.complete
False
>>> d.apply(d.options()[0])
>>> d.complete
True

local_stack

local_stack(constraint: Constraint) -> list[Any]

Return the local state stack for constraint across the open frontier, leftmost slot first.

Parameters:

Name Type Description Default
constraint Constraint

A constraint previously registered on the grammar.

required

Returns:

Type Description
list[Any]

One entry per open non-terminal in left-to-right order.

Source code in SRToolkit/utils/grammar/derivation.py
def local_stack(self, constraint: Constraint) -> list[Any]:
    """
    Return the local state stack for ``constraint`` across the open frontier,
    leftmost slot first.

    Args:
        constraint: A constraint previously registered on the grammar.

    Returns:
        One entry per open non-terminal in left-to-right order.
    """
    cid = id(constraint)
    return [f.locals[cid] for f in reversed(self._frames)]

global_state

global_state(constraint: Constraint) -> Any

Return the per-derivation global state for constraint.

Parameters:

Name Type Description Default
constraint Constraint

A constraint previously registered on the grammar.

required

Returns:

Type Description
Any

The current global value owned by this derivation for constraint.

Source code in SRToolkit/utils/grammar/derivation.py
def global_state(self, constraint: Constraint) -> Any:
    """
    Return the per-derivation global state for ``constraint``.

    Args:
        constraint: A constraint previously registered on the grammar.

    Returns:
        The current global value owned by this derivation for ``constraint``.
    """
    return self._globals[id(constraint)]

options

options() -> list[Rule]

Return candidate rules for the current leftmost unexpanded non-terminal, filtered by every registered constraint's allows.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> g.add_rule(Rule("E", ["y"]))
>>> d = g.start_derivation("E")
>>> len(d.options())
2

Returns:

Type Description
list[Rule]

List of Rule objects that every constraint accepts.

Raises:

Type Description
RuntimeError

If the derivation is already complete.

Source code in SRToolkit/utils/grammar/derivation.py
def options(self) -> list[Rule]:
    """
    Return candidate rules for the current leftmost unexpanded
    non-terminal, filtered by every registered constraint's
    [allows][SRToolkit.utils.grammar.constraints.Constraint.allows].

    Examples:
        >>> from SRToolkit.utils.grammar import Grammar, Rule
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["x"]))
        >>> g.add_rule(Rule("E", ["y"]))
        >>> d = g.start_derivation("E")
        >>> len(d.options())
        2

    Returns:
        List of [Rule][SRToolkit.utils.grammar.Rule] objects that every
        constraint accepts.

    Raises:
        RuntimeError: If the derivation is already complete.
    """
    if self.complete:
        raise RuntimeError("Derivation is already complete; no options remain.")

    top = self._frames[-1]
    candidates = self._grammar.rules_for(top.nonterminal)

    for c in self._grammar._constraints:
        cid = id(c)
        slot = self._slot_for(top, top.locals[cid])
        global_ = self._globals[cid]
        surviving: list[Rule] = []
        for rule in candidates:
            if _scope_miss(c, top.nonterminal, rule):
                surviving.append(rule)
                continue
            if c.allows(slot, rule, global_):
                surviving.append(rule)
        candidates = surviving
        if not candidates:
            break

    return candidates

apply

apply(rule: Rule) -> None

Apply rule to the current leftmost unexpanded non-terminal.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["E", "+", "F"]))
>>> g.add_rule(Rule("E", ["x"]))
>>> g.add_rule(Rule("F", ["y"]))
>>> d = g.start_derivation("E")
>>> d.apply(g.rules_for("E")[0])   # E -> E + F
>>> d.apply(g.rules_for("E")[1])   # E -> x
>>> d.apply(g.rules_for("F")[0])   # F -> y
>>> d.to_token_list()
['x', '+', 'y']

Parameters:

Name Type Description Default
rule Rule

A rule whose lhs matches the current leftmost non-terminal.

required

Raises:

Type Description
RuntimeError

If the derivation is already complete.

ValueError

If rule.lhs does not match the current non-terminal.

Source code in SRToolkit/utils/grammar/derivation.py
def apply(self, rule: Rule) -> None:
    """
    Apply ``rule`` to the current leftmost unexpanded non-terminal.

    Examples:
        >>> from SRToolkit.utils.grammar import Grammar, Rule
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["E", "+", "F"]))
        >>> g.add_rule(Rule("E", ["x"]))
        >>> g.add_rule(Rule("F", ["y"]))
        >>> d = g.start_derivation("E")
        >>> d.apply(g.rules_for("E")[0])   # E -> E + F
        >>> d.apply(g.rules_for("E")[1])   # E -> x
        >>> d.apply(g.rules_for("F")[0])   # F -> y
        >>> d.to_token_list()
        ['x', '+', 'y']

    Args:
        rule: A rule whose ``lhs`` matches the current leftmost
            non-terminal.

    Raises:
        RuntimeError: If the derivation is already complete.
        ValueError: If ``rule.lhs`` does not match the current
            non-terminal.
    """
    if self.complete:
        raise RuntimeError("Derivation is already complete.")

    top = self._frames[-1]
    if rule.lhs != top.nonterminal:
        raise ValueError(f"Rule lhs '{rule.lhs}' does not match current non-terminal '{top.nonterminal}'.")

    # Build child parse-tree nodes and the new frames for each NT child.
    child_frames: list[_Frame] = []
    nt_child_index = 0
    for sym in rule.rhs:
        child_node = ParseTreeNode(sym, None)
        top.node.children.append(child_node)
        if sym in self._nonterminals:
            frame = AncestorInfo(top.nonterminal, rule, nt_child_index)
            child_frames.append(
                _Frame(
                    nonterminal=sym,
                    node=child_node,
                    ancestors=top.ancestors + (frame,),
                    parent_rule=rule,
                    child_index=nt_child_index,
                )
            )
            nt_child_index += 1

    top.node.rule_applied = rule
    n_nt_children = len(child_frames)

    # Thread state through every constraint using the pre-apply slot view.
    pre_frontier_size = len(self._frames)
    for c in self._grammar._constraints:
        cid = id(c)
        slot = self._slot_for(top, top.locals[cid], frontier_size=pre_frontier_size)
        child_locals, new_global = c.update(slot, rule, self._globals[cid])
        if len(child_locals) != n_nt_children:
            raise RuntimeError(
                f"Constraint {c!r}.update() returned {len(child_locals)} child locals "
                f"but rule '{rule.lhs} -> {rule.rhs}' has {n_nt_children} "
                f"non-terminal children."
            )
        for i, cf in enumerate(child_frames):
            cf.locals[cid] = child_locals[i]
        self._globals[cid] = new_global

    # Pop the top frame and push children so that the leftmost child ends
    # up at the end of the list (i.e. on top of the stack).
    self._frames.pop()
    for cf in reversed(child_frames):
        self._frames.append(cf)
    self._steps += 1
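Expansion stays leftmost-first because of the stack discipline in `apply`: it pops the expanded slot and pushes that slot's non-terminal children in reverse, so the leftmost child ends up on top. The sketch below applies the same idea to a plain dict grammar. `leftmost_derive`, `TreeNode`, and the index-based `choices` argument are hypothetical stand-ins rather than SRToolkit API, and constraint state is left out. The choices replay the `apply` example above.

```python
from dataclasses import dataclass, field


@dataclass
class TreeNode:
    symbol: str
    children: list["TreeNode"] = field(default_factory=list)


def leftmost_derive(rules: dict[str, list[list[str]]], start: str, choices: list[int]) -> TreeNode:
    """Expand the leftmost open non-terminal once per entry in `choices`.

    `choices[i]` indexes into the rule list of whichever non-terminal is on top.
    """
    root = TreeNode(start)
    stack = [root]  # end of the list is the top: the leftmost open non-terminal
    for idx in choices:
        if not stack:
            raise RuntimeError("derivation is already complete")
        top = stack.pop()
        children = [TreeNode(sym) for sym in rules[top.symbol][idx]]
        top.children.extend(children)
        # Push non-terminal children right-to-left so the leftmost one is on top.
        for child in reversed(children):
            if child.symbol in rules:
                stack.append(child)
    if stack:
        raise RuntimeError("derivation is incomplete")
    return root


def leaves(node: TreeNode) -> list[str]:
    """Terminal symbols of the tree in left-to-right order."""
    if not node.children:
        return [node.symbol]
    return [sym for child in node.children for sym in leaves(child)]


rules = {"E": [["E", "+", "F"], ["x"]], "F": [["y"]]}
tree = leftmost_derive(rules, "E", [0, 1, 0])  # E -> E + F, E -> x, F -> y
print(leaves(tree))  # ['x', '+', 'y']
```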

sample

sample() -> None

Apply one rule chosen proportionally by weight (PCFG) or uniformly (CFG) from the surviving candidates.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.sample()
>>> d.complete
True

Raises:

Type Description
RuntimeError

If the derivation is already complete.

RuntimeError

If all candidate rules are filtered out by allows.

Source code in SRToolkit/utils/grammar/derivation.py
def sample(self) -> None:
    """
    Apply one rule chosen proportionally by weight (PCFG) or uniformly
    (CFG) from the surviving candidates.

    Examples:
        >>> from SRToolkit.utils.grammar import Grammar, Rule
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["x"]))
        >>> d = g.start_derivation("E")
        >>> d.sample()
        >>> d.complete
        True

    Raises:
        RuntimeError: If the derivation is already complete.
        RuntimeError: If all candidate rules are filtered out by
            [allows][SRToolkit.utils.grammar.constraints.Constraint.allows].
    """
    candidates = self.options()
    if not candidates:
        raise RuntimeError(
            "No valid rules available for the current non-terminal after applying constraint filters."
        )

    weights = [float(r.weight) for r in candidates]
    total = sum(weights)
    probs = [w / total for w in weights]
    self.apply(candidates[np.random.choice(len(candidates), p=probs)])
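`sample` normalises the surviving rules' weights into probabilities and then draws one rule. When every weight is the default `1.0`, the draw is uniform, which is the CFG case. Below is a standalone sketch of that step. It uses the standard library's `random.choices` in place of the `np.random.choice` call in the source, and `pick_weighted` is a hypothetical helper name.

```python
import random
from collections import Counter


def pick_weighted(candidates: list, weights: list[float], rng: random.Random) -> object:
    """Return one candidate with probability proportional to its weight."""
    if not candidates:
        raise RuntimeError("no candidates survived the constraint filters")
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must sum to a positive number")
    # Same normalisation as Derivation.sample: p_i = w_i / sum(w).
    probs = [w / total for w in weights]
    return rng.choices(candidates, weights=probs, k=1)[0]


rng = random.Random(0)
counts = Counter(pick_weighted(["E -> E + F", "E -> F"], [0.4, 0.6], rng) for _ in range(10_000))
print(counts["E -> F"] / 10_000)  # close to 0.6
```

`random.choices` would accept unnormalised weights directly. The explicit division is kept here only to mirror the source.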

generate

generate(limit: int = 1000) -> list[str]

Run the derivation to completion and return the token list.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> g.start_derivation("E").generate()
['x']

Parameters:

Name Type Description Default
limit int

Maximum number of rule applications. A negative value means unlimited. Default 1000.

1000

Returns:

Type Description
list[str]

Flat list of terminal tokens in left-to-right order.

Raises:

Type Description
RuntimeError

If the derivation does not complete within limit steps (only when limit >= 0).

Source code in SRToolkit/utils/grammar/derivation.py
def generate(self, limit: int = 1000) -> list[str]:
    """
    Run the derivation to completion and return the token list.

    Examples:
        >>> from SRToolkit.utils.grammar import Grammar, Rule
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["x"]))
        >>> g.start_derivation("E").generate()
        ['x']

    Args:
        limit: Maximum number of rule applications. A negative value means
            unlimited. Default ``1000``.

    Returns:
        Flat list of terminal tokens in left-to-right order.

    Raises:
        RuntimeError: If the derivation does not complete within
            ``limit`` steps (only when ``limit >= 0``).
    """
    steps = 0
    while not self.complete:
        if steps >= limit >= 0:
            raise RuntimeError(f"Derivation did not complete within {limit} rule applications.")
        self.sample()
        steps += 1
    return self.to_token_list()
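The guard `steps >= limit >= 0` is a Python chained comparison. It means `steps >= limit and limit >= 0`, so any negative `limit` switches the budget off entirely. The sketch below isolates that loop with a toy object in place of a `Derivation`. `run_until_complete` and `Countdown` are hypothetical names used only for illustration.

```python
from typing import Callable


def run_until_complete(step: Callable[[], None], is_complete: Callable[[], bool], limit: int = 1000) -> int:
    """Call `step` until `is_complete()` is true; return the number of steps taken."""
    steps = 0
    while not is_complete():
        # Chained comparison: (steps >= limit) and (limit >= 0).
        if steps >= limit >= 0:
            raise RuntimeError(f"did not complete within {limit} steps")
        step()
        steps += 1
    return steps


class Countdown:
    """Toy stand-in for a derivation that completes after `n` steps."""

    def __init__(self, n: int) -> None:
        self.open = n

    def step(self) -> None:
        self.open -= 1

    def complete(self) -> bool:
        return self.open == 0


c = Countdown(3)
print(run_until_complete(c.step, c.complete, limit=5))  # 3
```

With `limit=0`, any derivation that is not already complete fails on the first check. This matches the `generate_one(max_steps=0, max_retries=1)` example returning `None`.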

to_token_list

to_token_list() -> list[str]

Return the completed expression as a flat token list.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.apply(d.options()[0])
>>> d.to_token_list()
['x']

Returns:

Type Description
list[str]

Flat list of terminal tokens in left-to-right order.

Raises:

Type Description
RuntimeError

If the derivation is not yet complete.

Source code in SRToolkit/utils/grammar/derivation.py
def to_token_list(self) -> list[str]:
    """
    Return the completed expression as a flat token list.

    Examples:
        >>> from SRToolkit.utils.grammar import Grammar, Rule
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["x"]))
        >>> d = g.start_derivation("E")
        >>> d.apply(d.options()[0])
        >>> d.to_token_list()
        ['x']

    Returns:
        Flat list of terminal tokens in left-to-right order.

    Raises:
        RuntimeError: If the derivation is not yet complete.
    """
    if not self.complete:
        raise RuntimeError("Derivation is not yet complete.")
    return self.to_parse_tree().to_token_list()

to_parse_tree

to_parse_tree() -> ParseTree

Return the completed derivation as a ParseTree.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.apply(d.options()[0])
>>> isinstance(d.to_parse_tree(), ParseTree)
True

Returns:

Type Description
ParseTree

The ParseTree rooted at the start symbol.

Raises:

Type Description
RuntimeError

If the derivation is not yet complete.

Source code in SRToolkit/utils/grammar/derivation.py
def to_parse_tree(self) -> ParseTree:
    """
    Return the completed derivation as a
    [ParseTree][SRToolkit.utils.grammar.ParseTree].

    Examples:
        >>> from SRToolkit.utils.grammar import Grammar, Rule
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["x"]))
        >>> d = g.start_derivation("E")
        >>> d.apply(d.options()[0])
        >>> isinstance(d.to_parse_tree(), ParseTree)
        True

    Returns:
        The [ParseTree][SRToolkit.utils.grammar.ParseTree] rooted at the
        start symbol.

    Raises:
        RuntimeError: If the derivation is not yet complete.
    """
    if not self.complete:
        raise RuntimeError("Derivation is not yet complete.")
    return ParseTree(self._root)

DimensionalConsistency

DimensionalConsistency(variable_units: dict[str, dict], target_unit: dict, constant_units: Optional[dict[str, dict]] = None, symbol_library: Optional[SymbolLibrary] = None, allow_unit_polymorphic_constants: bool = False)

Bases: Constraint[Optional[Unit], None]

Stateful constraint that enforces dimensional analysis during expression generation.

Local state at each slot is the required unit (Optional[Unit]). None means "any unit is acceptable here" — used for the children of multiplicative operators, whose individual units are underdetermined.

Unit representation: units are dict[str, Fraction] mapping base dimension names (e.g. "m", "s", "kg") to rational exponents. Dimensionless is {} (empty dict).
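With this representation, unit arithmetic reduces to exponent bookkeeping. Products add exponents, quotients subtract them, and powers scale them. Rational powers such as square roots are the reason exponents are `Fraction`s. Below is a self-contained sketch with hypothetical helper names. It is not the library's private `_to_unit`, whose exact behaviour is not shown here.

```python
from fractions import Fraction

Unit = dict[str, Fraction]  # base dimension -> exponent; {} means dimensionless


def to_unit(spec: dict) -> Unit:
    """Convert e.g. {"m": 1, "s": -1} to Fraction exponents, dropping zeros."""
    return {dim: Fraction(exp) for dim, exp in spec.items() if Fraction(exp) != 0}


def mul(a: Unit, b: Unit) -> Unit:
    """Unit of a product: exponents add."""
    out = dict(a)
    for dim, exp in b.items():
        out[dim] = out.get(dim, Fraction(0)) + exp
    return {dim: exp for dim, exp in out.items() if exp != 0}


def div(a: Unit, b: Unit) -> Unit:
    """Unit of a quotient: the divisor's exponents are subtracted."""
    return mul(a, {dim: -exp for dim, exp in b.items()})


def power(a: Unit, p) -> Unit:
    """Unit of a**p: every exponent is multiplied by p (p may be rational)."""
    p = Fraction(p)
    return {dim: exp * p for dim, exp in a.items() if exp * p != 0}


velocity = div(to_unit({"m": 1}), to_unit({"s": 1}))
print(velocity == {"m": 1, "s": -1})                      # True
print(mul(velocity, to_unit({"s": 1})))                   # {'m': Fraction(1, 1)}
print(power(to_unit({"m": 2}), Fraction(1, 2)) == {"m": 1})  # True
print(div(velocity, velocity))                            # {} (dimensionless)
```

Under standard dimensional analysis, an additive node requires both operands to carry the same unit. Multiplicative children are underdetermined individually, which is why their local state can be `None`.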

Free constants (const type): treated as dimensionless by default. Set allow_unit_polymorphic_constants=True to let them absorb any required unit.

Named physical constants (lit type, e.g. "g", "c"): declared via constant_units; checked like variables.

Undeclared variables or literals: conservatively accepted.

Rule classification: uses rule.name when available (see Grammar.from_symbol_library for the naming scheme). Falls back to rhs-shape heuristics for unnamed rules.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> from SRToolkit.utils.grammar import DimensionalConsistency
>>> from fractions import Fraction
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["E", "+", "F"], weight=0.6, name="E_add_+"))
>>> g.add_rule(Rule("E", ["F"], weight=0.4, name="E_to_F"))
>>> g.add_rule(Rule("F", ["v"], weight=0.5, name="F_v"))
>>> g.add_rule(Rule("F", ["t"], weight=0.5, name="F_t"))
>>> dc = DimensionalConsistency(
...     variable_units={"v": {"m": 1, "s": -1}, "t": {"s": 1}},
...     target_unit={"m": 1, "s": -1},
... )
>>> g.add_constraint(dc)
>>> d = g.start_derivation("E")
>>> all(r.rhs != ["t"] for r in d.options())
True

Parameters:

Name Type Description Default
variable_units dict[str, dict]

Mapping from variable token to its unit.

required
target_unit dict

Required unit of the generated expression.

required
constant_units Optional[dict[str, dict]]

Units for named physical constants (lit type).

None
symbol_library Optional[SymbolLibrary]

Used to classify tokens by type and precedence.

None
allow_unit_polymorphic_constants bool

If True, free constants absorb whatever unit their slot requires. Default False.

False
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(
    self,
    variable_units: dict[str, dict],
    target_unit: dict,
    constant_units: Optional[dict[str, dict]] = None,
    symbol_library: Optional[SymbolLibrary] = None,
    allow_unit_polymorphic_constants: bool = False,
) -> None:
    self._var_units: dict[str, Unit] = {k: _to_unit(v) for k, v in variable_units.items()}
    self._target: Unit = _to_unit(target_unit)
    self._const_units: dict[str, Unit] = {k: _to_unit(v) for k, v in (constant_units or {}).items()}
    self._sl = symbol_library
    self._allow_poly_const = allow_unit_polymorphic_constants

ExpansionContext dataclass

ExpansionContext(nonterminal: str, local: L, steps: int, parent_rule: Optional[Rule], child_index: Optional[int], ancestors: tuple[AncestorInfo, ...], partial_tree: ParseTreeNode, nonterminals: frozenset[str], frontier_size: int = 1)

Bases: Generic[L]

Read-only view of the current derivation position passed to constraints.

The engine constructs ExpansionContext views on demand: in Derivation.options, one view per constraint at the current position, against which every candidate rule is checked; in Derivation.apply, one view per constraint for the selected rule.

Attributes:

Name Type Description
nonterminal str

The non-terminal symbol being expanded at this position.

local L

This constraint's per-slot inherited state.

steps int

Number of rule applications made so far in the derivation (i.e. how many times apply() has been called).

parent_rule Optional[Rule]

The rule whose application created this slot (None for the start symbol).

child_index Optional[int]

Index of this slot in parent_rule.rhs counting only non-terminal positions (None for the start symbol).

ancestors tuple[AncestorInfo, ...]

Ancestor chain from the root to this slot's immediate parent, in root-first order.

partial_tree ParseTreeNode

Root of the partially-built parse tree. Read-only by contract — constraints must not mutate it.

nonterminals frozenset[str]

Frozen set of all non-terminal symbols in the grammar.

frontier_size int

Number of open (unexpanded) non-terminal slots in the current derivation frontier, including this slot. Useful for budget constraints such as MaxNodes that need to account for the minimum number of rule applications still required to complete the derivation.
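Every open slot must be expanded at least once. The frontier therefore gives a lower bound on the rule applications still needed. The sketch below shows how a budget check could use `steps` and `frontier_size`. It is an illustrative argument, not `MaxNodes`' actual logic. `BudgetView` is a hypothetical subset of `ExpansionContext`, and `fits_budget` and `max_steps` are likewise hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BudgetView:
    """Hypothetical subset of ExpansionContext fields used by a budget check."""

    steps: int          # rule applications made so far
    frontier_size: int  # open non-terminal slots, including this one


def fits_budget(view: BudgetView, n_nt_children: int, max_steps: int) -> bool:
    """Could a rule with `n_nt_children` non-terminal children still finish in time?

    After applying it, steps + 1 applications are spent, and the frontier holds
    the other (frontier_size - 1) slots plus the new children. Each of those
    needs at least one more application.
    """
    spent = view.steps + 1
    still_open = view.frontier_size - 1 + n_nt_children
    return spent + still_open <= max_steps


view = BudgetView(steps=2, frontier_size=2)
print(fits_budget(view, 0, max_steps=5))  # True  (terminal rule: 3 + 1 <= 5)
print(fits_budget(view, 1, max_steps=5))  # True  (unary rule: 3 + 2 <= 5)
print(fits_budget(view, 2, max_steps=5))  # False (binary rule: 3 + 3 > 5)
```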

Grammar

Grammar(rules: Optional[list[Rule]] = None, start: Optional[str] = None)

A context-free grammar (CFG) or probabilistic context-free grammar (PCFG).

Rules are added via add_rule. The set of non-terminals is derived automatically: a symbol is a non-terminal if and only if it appears as the lhs of at least one rule.

A grammar is a CFG when every rule carries the default weight (1.0), making sampling uniform within each group. It is a PCFG when any rule has a weight that differs from 1.0.

Constraints are registered via add_constraint. During a derivation, options returns only rules that every constraint's allows accepts.
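Weights act as relative preferences among rules that share a left-hand side. The sketch below turns a weighted rule list into per-non-terminal probabilities. It uses a hypothetical plain-tuple rule format rather than the `Rule` class. `Derivation.sample` normalises only over the rules that survive the constraint filters, so the effective probabilities during a constrained derivation can differ from this unconstrained table.

```python
from collections import defaultdict
from math import isclose


def rule_probabilities(rules: list[tuple[str, list[str], float]]) -> dict[tuple[str, tuple[str, ...]], float]:
    """Normalise weights within each lhs group: P(lhs -> rhs) = w / sum of weights for lhs."""
    totals: dict[str, float] = defaultdict(float)
    for lhs, _, weight in rules:
        totals[lhs] += weight
    return {(lhs, tuple(rhs)): weight / totals[lhs] for lhs, rhs, weight in rules}


pcfg = [("E", ["E", "+", "F"], 0.4), ("E", ["F"], 0.6), ("F", ["x"], 1.0)]
cfg = [("E", ["x"], 1.0), ("E", ["y"], 1.0)]

print(rule_probabilities(pcfg))  # E rules: 0.4 and 0.6; F -> x: 1.0
print(rule_probabilities(cfg))   # uniform: 0.5 each
print(any(w != 1.0 for _, _, w in pcfg), any(w != 1.0 for _, _, w in cfg))  # True False
```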

Examples:

>>> g = Grammar([
...     Rule("E", ["E", "+", "F"], weight=0.4),
...     Rule("E", ["F"], weight=0.6),
...     Rule("F", ["x"]),
... ])
>>> "E" in g.nonterminals
True
>>> g.is_pcfg()
True
>>> len(g.rules_for("E"))
2

Parameters:

Name Type Description Default
rules Optional[list[Rule]]

Optional list of Rule objects to add at construction time. Equivalent to calling add_rule for each entry.

None
start Optional[str]

Default start non-terminal used by start_derivation when no start argument is given.

None
Source code in SRToolkit/utils/grammar/grammar.py
def __init__(self, rules: Optional[list[Rule]] = None, start: Optional[str] = None) -> None:
    """
    Args:
        rules: Optional list of [Rule][SRToolkit.utils.grammar.Rule] objects to add at
            construction time. Equivalent to calling
            [add_rule][SRToolkit.utils.grammar.Grammar.add_rule] for each entry.
        start: Default start non-terminal used by
            [start_derivation][SRToolkit.utils.grammar.Grammar.start_derivation]
            when no ``start`` argument is given.
    """
    self.start: Optional[str] = start
    self._rules: list[Rule] = []
    self._rules_by_lhs: dict[str, list[Rule]] = {}
    self._constraints: list = []
    for rule in rules or []:
        self.add_rule(rule)

nonterminals property

nonterminals: set[str]

Set of all non-terminal symbols in the grammar.

A symbol is a non-terminal if and only if it appears as the lhs of at least one rule.

Examples:

>>> g = Grammar()
>>> g.add_rule(Rule("E", ["F"]))
>>> g.add_rule(Rule("F", ["x"]))
>>> g.nonterminals == {"E", "F"}
True

add_rule

add_rule(rule: Rule) -> None

Add a production rule to the grammar.

Examples:

>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> len(g.rules_for("E"))
1

Parameters:

Name Type Description Default
rule Rule

The Rule to register.

required
Source code in SRToolkit/utils/grammar/grammar.py
def add_rule(self, rule: Rule) -> None:
    """
    Add a production rule to the grammar.

    Examples:
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["x"]))
        >>> len(g.rules_for("E"))
        1

    Args:
        rule: The [Rule][SRToolkit.utils.grammar.Rule] to register.
    """
    self._rules.append(rule)
    self._rules_by_lhs.setdefault(rule.lhs, []).append(rule)

add_constraint

add_constraint(constraint: 'Constraint') -> None

Register a constraint applied at each derivation step.

A rule is offered as an option only when every registered constraint's allows returns True for it.

Examples:

>>> from SRToolkit.utils.grammar import MaxDepth
>>> g = Grammar([Rule("E", ["E", "+", "E"]), Rule("E", ["x"])])
>>> g.add_constraint(MaxDepth(0))
>>> d = g.start_derivation("E")
>>> [r.rhs for r in d.options()]
[['x']]

Parameters:

Name Type Description Default
constraint 'Constraint'

A Constraint instance — typically a built-in such as MaxDepth, MaxNodes, MaxOccurrences, NoNested, or DimensionalConsistency, or a user-defined subclass.

required
Source code in SRToolkit/utils/grammar/grammar.py
def add_constraint(self, constraint: "Constraint") -> None:
    """
    Register a [constraint][SRToolkit.utils.grammar.constraints.Constraint] applied at each derivation step.

    A rule is offered as an option only when every registered constraint's
    [allows][SRToolkit.utils.grammar.constraints.Constraint.allows] returns ``True``
    for it.

    Examples:
        >>> from SRToolkit.utils.grammar import MaxDepth
        >>> g = Grammar([Rule("E", ["E", "+", "E"]), Rule("E", ["x"])])
        >>> g.add_constraint(MaxDepth(0))
        >>> d = g.start_derivation("E")
        >>> [r.rhs for r in d.options()]
        [['x']]

    Args:
        constraint: A [Constraint][SRToolkit.utils.grammar.constraints.Constraint]
            instance — typically a built-in such as
            [MaxDepth][SRToolkit.utils.grammar.constraints.MaxDepth],
            [MaxNodes][SRToolkit.utils.grammar.constraints.MaxNodes],
            [MaxOccurrences][SRToolkit.utils.grammar.constraints.MaxOccurrences],
            [NoNested][SRToolkit.utils.grammar.constraints.NoNested], or
            [DimensionalConsistency][SRToolkit.utils.grammar.constraints.DimensionalConsistency],
            or a user-defined subclass.
    """
    self._constraints.append(constraint)
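The engine code above shows the shape a constraint must have. `allows(slot, rule, global_)` returns a bool. `update(slot, rule, global_)` returns one inherited local per non-terminal child, plus the new global state. The duck-typed sketch below follows that shape for a depth limit. `RuleStub`, `SlotStub`, and `DepthLimitSketch` are hypothetical stand-ins that do not subclass the real `Constraint`. The real `MaxDepth` may count depth differently. With a limit of 0, it reproduces the `[['x']]` result of the example above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RuleStub:
    """Stand-in for SRToolkit's Rule: only lhs and rhs are modelled."""

    lhs: str
    rhs: tuple[str, ...]


@dataclass(frozen=True)
class SlotStub:
    """Stand-in for the ExpansionContext fields this sketch reads."""

    local: int                    # inherited state: depth of this slot (root = 0)
    nonterminals: frozenset[str]


class DepthLimitSketch:
    """Hypothetical MaxDepth-like constraint."""

    def __init__(self, max_depth: int) -> None:
        self.max_depth = max_depth

    def _nt_children(self, slot: SlotStub, rule: RuleStub) -> list[str]:
        return [sym for sym in rule.rhs if sym in slot.nonterminals]

    def allows(self, slot: SlotStub, rule: RuleStub, global_: None) -> bool:
        # A rule that opens new non-terminal slots would place them one level deeper.
        return not self._nt_children(slot, rule) or slot.local < self.max_depth

    def update(self, slot: SlotStub, rule: RuleStub, global_: None) -> tuple[list[int], None]:
        # Derivation.apply expects exactly one child local per non-terminal child.
        return [slot.local + 1] * len(self._nt_children(slot, rule)), global_


nts = frozenset({"E"})
binary = RuleStub("E", ("E", "+", "E"))
leaf = RuleStub("E", ("x",))
root = SlotStub(local=0, nonterminals=nts)

print([r.rhs for r in (binary, leaf) if DepthLimitSketch(0).allows(root, r, None)])  # [('x',)]
print(DepthLimitSketch(1).update(root, binary, None))  # ([1, 1], None)
```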

rules_for

rules_for(nonterminal: str) -> list[Rule]

Return all rules whose lhs matches nonterminal.

Examples:

>>> g = Grammar()
>>> g.add_rule(Rule("E", ["F"]))
>>> g.add_rule(Rule("E", ["x"]))
>>> len(g.rules_for("E"))
2
>>> g.rules_for("Z")
[]

Parameters:

Name Type Description Default
nonterminal str

Non-terminal to look up.

required

Returns:

Type Description
list[Rule]

List of matching Rule objects in insertion order.

Source code in SRToolkit/utils/grammar/grammar.py
def rules_for(self, nonterminal: str) -> list[Rule]:
    """
    Return all rules whose ``lhs`` matches ``nonterminal``.

    Examples:
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["F"]))
        >>> g.add_rule(Rule("E", ["x"]))
        >>> len(g.rules_for("E"))
        2
        >>> g.rules_for("Z")
        []

    Args:
        nonterminal: Non-terminal to look up.

    Returns:
        List of matching [Rule][SRToolkit.utils.grammar.Rule] objects in insertion order.
    """
    return list(self._rules_by_lhs.get(nonterminal, []))

is_pcfg

is_pcfg() -> bool

Return True if any rule deviates from the default weight of 1.0.

Examples:

>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> g.add_rule(Rule("E", ["y"]))
>>> g.is_pcfg()
False
>>> g.add_rule(Rule("F", ["z"], weight=0.3))
>>> g.is_pcfg()
True

Returns:

Type Description
bool

True if at least one rule has a weight other than 1.0.

Source code in SRToolkit/utils/grammar/grammar.py
def is_pcfg(self) -> bool:
    """
    Return ``True`` if any rule deviates from the default weight of ``1.0``.

    Examples:
        >>> g = Grammar()
        >>> g.add_rule(Rule("E", ["x"]))
        >>> g.add_rule(Rule("E", ["y"]))
        >>> g.is_pcfg()
        False
        >>> g.add_rule(Rule("F", ["z"], weight=0.3))
        >>> g.is_pcfg()
        True

    Returns:
        ``True`` if at least one rule has a weight other than ``1.0``.
    """
    return any(r.weight != 1.0 for r in self._rules)

validate

validate(parse_tree: ParseTree, require_start: bool = False) -> bool

Validate that parse_tree is structurally valid, uses only productions from this grammar, and satisfies every registered constraint.

By default the root of the tree may be any non-terminal in the grammar, not just self.start. This lets you validate sub-trees (e.g. an F-rooted fragment) in addition to full expressions. Pass require_start=True to additionally enforce that the root symbol equals self.start (useful when validating complete, top-level expressions).

The check replays the parse tree through a fresh Derivation rooted at parse_tree.root.symbol, walking nodes in leftmost order (mirroring the derivation frontier). At each internal node it checks that:

  • The applied rule is permitted at the current frontier position — meaning it exists in the grammar and every registered constraint accepts it.
  • The number of children equals len(rule.rhs).
  • Each children[i].symbol equals rule.rhs[i].

Terminal leaves are accepted when their symbol is not a non-terminal in this grammar. A non-terminal symbol appearing as a leaf (unexpanded) is rejected.

Examples:

>>> g = Grammar(start="E")
>>> r = Rule("E", ["x"])
>>> g.add_rule(r)
>>> leaf = ParseTreeNode("x", None)
>>> root = ParseTreeNode("E", r, [leaf])
>>> g.validate(ParseTree(root))
True
>>> foreign = Rule("E", ["y"])
>>> root2 = ParseTreeNode("E", foreign, [ParseTreeNode("y", None)])
>>> g.validate(ParseTree(root2))
False
>>> g.add_rule(Rule("F", ["x"]))
>>> r_f = Rule("F", ["x"])
>>> g.add_rule(r_f)
>>> f_root = ParseTreeNode("F", r_f, [ParseTreeNode("x", None)])
>>> g.validate(ParseTree(f_root))
True
>>> g.validate(ParseTree(f_root), require_start=True)
False

Parameters:

Name Type Description Default
parse_tree ParseTree

The ParseTree to validate.

required
require_start bool

When True, return False if the root symbol does not equal self.start. Raises ValueError if self.start is None and require_start is True.

False

Returns:

Type Description
bool

True if the tree is structurally consistent, every production exists in this grammar, and all constraints would have permitted the derivation (and the root matches self.start when require_start=True).

Raises:

Type Description
ValueError

If require_start=True but self.start is None.

Source code in SRToolkit/utils/grammar/grammar.py
def validate(self, parse_tree: ParseTree, require_start: bool = False) -> bool:
    """
    Validate that ``parse_tree`` is structurally valid, uses only productions
    from this grammar, and satisfies every registered constraint.

    By default the root of the tree may be **any non-terminal** in the grammar,
    not just ``self.start``.  This lets you validate sub-trees (e.g. an
    ``F``-rooted fragment) in addition to full expressions.  Pass
    ``require_start=True`` to additionally enforce that the root symbol equals
    ``self.start`` (useful when validating complete, top-level expressions).

    The check replays the parse tree through a fresh
    [Derivation][SRToolkit.utils.grammar.derivation.Derivation] rooted at
    ``parse_tree.root.symbol``, walking nodes in leftmost order (mirroring the
    derivation frontier).  At each internal node it checks that:

    - The applied rule is permitted at the current frontier position — meaning
      it exists in the grammar **and** every registered constraint accepts it.
    - The number of children equals ``len(rule.rhs)``.
    - Each ``children[i].symbol`` equals ``rule.rhs[i]``.

    Terminal leaves are accepted when their symbol is not a non-terminal in
    this grammar.  A non-terminal symbol appearing as a leaf (unexpanded) is
    rejected.

    Examples:
        >>> g = Grammar(start="E")
        >>> r = Rule("E", ["x"])
        >>> g.add_rule(r)
        >>> leaf = ParseTreeNode("x", None)
        >>> root = ParseTreeNode("E", r, [leaf])
        >>> g.validate(ParseTree(root))
        True
        >>> foreign = Rule("E", ["y"])
        >>> root2 = ParseTreeNode("E", foreign, [ParseTreeNode("y", None)])
        >>> g.validate(ParseTree(root2))
        False
        >>> g.add_rule(Rule("F", ["x"]))
        >>> r_f = Rule("F", ["x"])
        >>> g.add_rule(r_f)
        >>> f_root = ParseTreeNode("F", r_f, [ParseTreeNode("x", None)])
        >>> g.validate(ParseTree(f_root))
        True
        >>> g.validate(ParseTree(f_root), require_start=True)
        False

    Args:
        parse_tree: The [ParseTree][SRToolkit.utils.grammar.ParseTree] to validate.
        require_start: When ``True``, return ``False`` if the root symbol does
            not equal ``self.start``.  Raises ``ValueError`` if ``self.start``
            is ``None`` and ``require_start`` is ``True``.

    Returns:
        ``True`` if the tree is structurally consistent, every production exists
        in this grammar, and all constraints would have permitted the derivation
        (and the root matches ``self.start`` when ``require_start=True``).

    Raises:
        ValueError: If ``require_start=True`` but ``self.start`` is ``None``.
    """
    if require_start:
        if self.start is None:
            raise ValueError("require_start=True but this grammar has no start symbol.")
        if parse_tree.root.symbol != self.start:
            return False
    root = parse_tree.root
    if root.rule_applied is None:
        return len(root.children) == 0 and root.symbol not in self.nonterminals
    try:
        d = self.start_derivation(root.symbol)
    except ValueError:
        return False
    stack = [root]
    while stack:
        node = stack.pop()
        if node.rule_applied is None:
            if node.symbol in self.nonterminals:
                return False
            continue
        if d.complete:
            return False
        if node.rule_applied not in d.options():
            return False
        if len(node.children) != len(node.rule_applied.rhs):
            return False
        for child, expected in zip(node.children, node.rule_applied.rhs):
            if child.symbol != expected:
                return False
        d.apply(node.rule_applied)
        for child in reversed(node.children):
            if child.symbol in self.nonterminals:
                stack.append(child)
    return d.complete
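The structural half of `validate` can be read as a leftmost walk over the tree. At each internal node the walk checks the production, the child count, and the child symbols, and it rejects any unexpanded non-terminal leaf. The sketch below covers only that half and omits the constraint replay that the real method performs through `Derivation.options`. `PNode`, `validate_tree`, and the `(lhs, rhs)` tuple rule format are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Optional

RuleT = tuple[str, tuple[str, ...]]  # (lhs, rhs)


@dataclass
class PNode:
    symbol: str
    rule: Optional[RuleT] = None  # production applied here; None for leaves
    children: list["PNode"] = field(default_factory=list)


def validate_tree(rules: set[RuleT], root: PNode) -> bool:
    nonterminals = {lhs for lhs, _ in rules}
    stack = [root]
    while stack:
        node = stack.pop()
        if node.rule is None:
            # A leaf must be a terminal; an unexpanded non-terminal is rejected.
            if node.symbol in nonterminals or node.children:
                return False
            continue
        lhs, rhs = node.rule
        if node.rule not in rules or lhs != node.symbol:
            return False
        if len(node.children) != len(rhs):
            return False
        if any(child.symbol != sym for child, sym in zip(node.children, rhs)):
            return False
        # Visit children leftmost-first, mirroring the derivation frontier.
        stack.extend(reversed(node.children))
    return True


rules = {("E", ("F",)), ("F", ("x",))}
good = PNode("E", ("E", ("F",)), [PNode("F", ("F", ("x",)), [PNode("x")])])
unexpanded = PNode("E", ("E", ("F",)), [PNode("F")])
foreign = PNode("E", ("E", ("y",)), [PNode("y")])
print(validate_tree(rules, good), validate_tree(rules, unexpanded), validate_tree(rules, foreign))
# True False False
```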

start_derivation

start_derivation(start: Optional[str] = None) -> 'Derivation'

Begin a new derivation.

Examples:

>>> g = Grammar(start="E")
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation()
>>> d.complete
False

Parameters:

Name Type Description Default
start Optional[str]

Start non-terminal. Defaults to self.start when None.

None

Returns:

Type Description
'Derivation'

A Derivation at the first expansion step.

Raises:

Type Description
ValueError

If the resolved start symbol is not a non-terminal in this grammar.

Source code in SRToolkit/utils/grammar/grammar.py
def start_derivation(self, start: Optional[str] = None) -> "Derivation":
    """
    Begin a new derivation.

    Examples:
        >>> g = Grammar(start="E")
        >>> g.add_rule(Rule("E", ["x"]))
        >>> d = g.start_derivation()
        >>> d.complete
        False

    Args:
        start: Start non-terminal. Defaults to ``self.start`` when ``None``.

    Returns:
        A [Derivation][SRToolkit.utils.grammar.derivation.Derivation] at the first expansion step.

    Raises:
        ValueError: If the resolved start symbol is not a non-terminal in
            this grammar.
    """
    from .derivation import Derivation

    s = self.start if start is None else start

    if s is None:
        raise ValueError(
            "No start symbol. Either add it thorough the constructor or as parameter "
            "in the Grammar.start_derivation method."
        )
    if s not in self.nonterminals:
        raise ValueError(f"Start symbol '{s}' is not a non-terminal in this grammar.")

    return Derivation(self, s)

generate_one

generate_one(start: Optional[str] = None, max_steps: int = 1000, max_retries: int = 10) -> Optional[list[str]]

Generate a single expression by sampling the grammar to completion.

Convenience wrapper around start_derivation and Derivation.generate.

Each attempt runs a fresh derivation from scratch. When a derivation exceeds max_steps rule applications (e.g. because random sampling kept choosing recursive rules), the attempt is discarded and a new one starts. None is returned only when every attempt fails, signalling that the caller should relax the grammar or its constraints, or increase max_steps / max_retries.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar(start="E")
>>> g.add_rule(Rule("E", ["x"]))
>>> g.generate_one()
['x']
>>> g.generate_one(max_steps=0, max_retries=1) is None
True

Parameters:

Name Type Description Default
start Optional[str]

Start non-terminal. Defaults to self.start when None.

None
max_steps int

Maximum rule applications per attempt. A negative number means unlimited (no retry logic applies).

1000
max_retries int

Maximum number of fresh attempts before returning None. Values below 1 are treated as 1.

10

Returns:

Type Description
Optional[list[str]]

A list of terminal tokens in left-to-right order, or None if every attempt exceeded max_steps.

Raises:

Type Description
ValueError

If the resolved start symbol is not a non-terminal in this grammar.

Source code in SRToolkit/utils/grammar/grammar.py
def generate_one(
    self,
    start: Optional[str] = None,
    max_steps: int = 1000,
    max_retries: int = 10,
) -> Optional[list[str]]:
    """
    Generate a single expression by sampling the grammar to completion.

    Convenience wrapper around
    [start_derivation][SRToolkit.utils.grammar.grammar.Grammar.start_derivation]
    and [Derivation.generate][SRToolkit.utils.grammar.derivation.Derivation.generate].

    Each attempt runs a fresh derivation from scratch.  When a derivation
    exceeds ``max_steps`` rule applications (e.g. because random sampling
    kept choosing recursive rules), the attempt is discarded and a new one
    starts.  ``None`` is returned only when every attempt fails, signalling
    that the caller should either relax the grammar, add more liberal
    constraints, or increase ``max_steps`` / ``max_retries``.

    Examples:
        >>> from SRToolkit.utils.grammar import Grammar, Rule
        >>> g = Grammar(start="E")
        >>> g.add_rule(Rule("E", ["x"]))
        >>> g.generate_one()
        ['x']
        >>> g.generate_one(max_steps=0, max_retries=1) is None
        True

    Args:
        start: Start non-terminal. Defaults to ``self.start`` when ``None``.
        max_steps: Maximum rule applications per attempt. A negative number
            means unlimited (no retry logic applies).
        max_retries: Maximum number of fresh attempts before returning
            ``None``. Values below ``1`` are treated as ``1``.

    Returns:
        A list of terminal tokens in left-to-right order, or ``None`` if
        every attempt exceeded ``max_steps``.

    Raises:
        ValueError: If the resolved start symbol is not a non-terminal in
            this grammar.
    """
    for _ in range(max(1, max_retries)):
        try:
            return self.start_derivation(start).generate(limit=max_steps)
        except RuntimeError:
            continue
    return None

to_dict

to_dict() -> dict

Serialise this grammar to a JSON-safe dictionary.

Constraints are serialised via their own to_dict method. User-defined constraint subclasses must implement to_dict/from_dict to survive the round-trip; built-in constraints are fully supported.

Returns:

Type Description
dict

Dictionary with keys start, rules, and constraints.

Source code in SRToolkit/utils/grammar/grammar.py
def to_dict(self) -> dict:
    """
    Serialise this grammar to a JSON-safe dictionary.

    Constraints are serialised via their own ``to_dict`` method.  User-defined
    constraint subclasses must implement ``to_dict``/``from_dict`` to survive the
    round-trip; built-in constraints are fully supported.

    Returns:
        Dictionary with keys ``start``, ``rules``, and ``constraints``.
    """
    return {
        "start": self.start,
        "rules": [r.to_dict() for r in self._rules],
        "constraints": [c.to_dict() for c in self._constraints],
    }

from_dict classmethod

from_dict(d: dict) -> 'Grammar'

Reconstruct a Grammar from a dictionary produced by to_dict.

Parameters:

Name Type Description Default
d dict

Dictionary with keys start, rules, and constraints.

required

Returns:

Type Description
'Grammar'

A new Grammar with all rules and constraints registered.

Source code in SRToolkit/utils/grammar/grammar.py
@classmethod
def from_dict(cls, d: dict) -> "Grammar":
    """
    Reconstruct a [Grammar][SRToolkit.utils.grammar.Grammar] from a dictionary
    produced by [to_dict][SRToolkit.utils.grammar.Grammar.to_dict].

    Args:
        d: Dictionary with keys ``start``, ``rules``, and ``constraints``.

    Returns:
        A new [Grammar][SRToolkit.utils.grammar.Grammar] with all rules and
        constraints registered.
    """
    from .constraints import Constraint

    rules = [Rule.from_dict(r) for r in d.get("rules", [])]
    g = cls(rules, start=d.get("start"))
    for cd in d.get("constraints", []):
        g.add_constraint(Constraint.from_dict(cd))
    return g
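The to_dict/from_dict pair is meant to survive a JSON round-trip. The sketch below shows that property with hypothetical `MiniRule`/`MiniGrammar` stand-ins; they are not the library's classes, and constraint serialisation (which the real methods delegate to each constraint's own to_dict/from_dict) is left out.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Optional


@dataclass
class MiniRule:                      # hypothetical stand-in for Rule
    lhs: str
    rhs: list[str]
    weight: float = 1.0
    name: Optional[str] = None


@dataclass
class MiniGrammar:                   # hypothetical stand-in for Grammar
    start: Optional[str]
    rules: list[MiniRule] = field(default_factory=list)

    def to_dict(self) -> dict:
        # Only str, float, list and None values: directly JSON-serialisable.
        return {"start": self.start, "rules": [asdict(r) for r in self.rules], "constraints": []}

    @classmethod
    def from_dict(cls, d: dict) -> "MiniGrammar":
        return cls(start=d.get("start"), rules=[MiniRule(**r) for r in d.get("rules", [])])


g = MiniGrammar("E", [MiniRule("E", ["E", "+", "F"], 0.4), MiniRule("E", ["F"], 0.6), MiniRule("F", ["x"])])
payload = json.dumps(g.to_dict())
restored = MiniGrammar.from_dict(json.loads(payload))
print(restored == g)   # True
```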

from_symbol_library classmethod

from_symbol_library(symbol_library: Optional[SymbolLibrary] = None, start: Optional[str] = None) -> 'Grammar'

Build a PCFG from a SymbolLibrary using a generic operator-precedence non-terminal hierarchy.

One non-terminal is created per unique OP precedence level present in the library. Known precedences map to readable names; custom precedences fall back to L_{p}:

  • OP_ADDITIVE → E
  • OP_MULTIPLICATIVE → F
  • OP_POWER → B
  • custom precedence p → L_{p}

The chain runs from the lowest-precedence level (the grammar start) down to T, K, R, and V. Only levels that have at least one operator are generated; there are no empty pass-through non-terminals.
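The shape of the operator chain can be sketched without the library. The precedence values below are hypothetical placeholders (SRToolkit's OP_ADDITIVE, OP_MULTIPLICATIVE and OP_POWER constants may have different values), and weights are omitted; the sketch only reproduces the structure described above: one non-terminal per populated level, left recursion for ordinary operators, right recursion for power, and a pass-through to the next tighter level.

```python
# Hypothetical precedence values; the real constants may differ.
OP_ADDITIVE, OP_MULTIPLICATIVE, OP_POWER = 1, 2, 3
NAMES = {OP_ADDITIVE: "E", OP_MULTIPLICATIVE: "F", OP_POWER: "B"}


def build_chain(op_groups: dict[int, list[str]]) -> list[tuple[str, list[str]]]:
    """Return (lhs, rhs) productions for an operator-precedence chain ending in T."""
    precs = sorted(op_groups)
    names = [NAMES.get(p, f"L_{p}") for p in precs]
    rules = []
    for i, (prec, nt) in enumerate(zip(precs, names)):
        nxt = names[i + 1] if i + 1 < len(names) else "T"
        for sym in op_groups[prec]:
            if prec == OP_POWER:       # right-associative: recurse on the right
                rules.append((nt, [nxt, sym, nt]))
            else:                      # left-associative: recurse on the left
                rules.append((nt, [nt, sym, nxt]))
        rules.append((nt, [nxt]))      # pass-through to the next, tighter level
    return rules


# No multiplicative operators, so E chains straight to B (no empty F level).
for lhs, rhs in build_chain({OP_ADDITIVE: ["+", "-"], OP_POWER: ["^"]}):
    print(lhs, "->", " ".join(rhs))
```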

Heuristic weights are calibrated against empirical expression distributions (e.g. Wikipedia mathematical formulae): E recursion ~30 %, F recursion ~40 %, B recursion ~5 %, custom-level recursion ~40 %.
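One way to read these recursion weights, assuming each choice at a level is made independently: if a level recurses with probability p, the number of operators it chains in one expansion is geometric, P(k) = p^k (1 - p), with mean p / (1 - p). This interpretation is an illustration, not a statement from the library.

```python
# Recursion weights quoted in the paragraph above.
REC_WEIGHTS = {"E": 0.30, "F": 0.40, "B": 0.05}


def expected_ops(p: float) -> float:
    """Mean of a geometric count: P(k ops) = p**k * (1 - p), so E[k] = p / (1 - p)."""
    return p / (1.0 - p)


for nt, p in REC_WEIGHTS.items():
    print(f"{nt}: {expected_ops(p):.3f} operators per expansion on average")
```

Under this reading, an additive level chains about 0.43 operators per visit on average and a power level about 0.05, which keeps sampled expressions short.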

Examples:

>>> sl = SymbolLibrary.from_symbol_list(["+", "-", "*", "sin", "^2"], 2)
>>> g = Grammar.from_symbol_library(sl)
>>> "E" in g.nonterminals and "V" in g.nonterminals
True
>>> g.is_pcfg()
True

Parameters:

Name Type Description Default
symbol_library Optional[SymbolLibrary]

Token vocabulary. Falls back to the active library from the context manager when None.

None
start Optional[str]

Override the name of the lowest-precedence non-terminal (and g.start). When None (default) the name is auto-inferred from the precedence constants ("E" for standard libraries).

None

Returns:

Type Description
'Grammar'

A Grammar with heuristic PCFG weights.

Raises:

Type Description
ValueError

If the symbol library contains no variables, constants, or literals, making it impossible to generate any terminal expression.

ValueError

If start is given and is also a terminal symbol in the symbol library.

Source code in SRToolkit/utils/grammar/grammar.py
@classmethod
def from_symbol_library(
    cls,
    symbol_library: Optional[SymbolLibrary] = None,
    start: Optional[str] = None,
) -> "Grammar":
    """
    Build a PCFG from a [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary]
    using a generic operator-precedence non-terminal hierarchy.

    One non-terminal is created per unique ``OP`` precedence level present in
    the library.  Known precedences map to readable names; custom precedences
    fall back to ``L_{p}``:

    - ``OP_ADDITIVE`` → ``E``
    - ``OP_MULTIPLICATIVE`` → ``F``
    - ``OP_POWER`` → ``B``
    - custom precedence *p* → ``L_{p}``

    The chain runs from the lowest-precedence level (the grammar start) down
    to ``T``, ``K``, ``R``, and ``V``.  Only levels that have at least one
    operator are generated; there are no empty pass-through non-terminals.

    Heuristic weights are calibrated against empirical expression distributions
    (e.g. Wikipedia mathematical formulae): E recursion ~30 %, F recursion ~40 %,
    B recursion ~5 %, custom-level recursion ~40 %.

    Examples:
        >>> sl = SymbolLibrary.from_symbol_list(["+", "-", "*", "sin", "^2"], 2)
        >>> g = Grammar.from_symbol_library(sl)
        >>> "E" in g.nonterminals and "V" in g.nonterminals
        True
        >>> g.is_pcfg()
        True

    Args:
        symbol_library: Token vocabulary.  Falls back to the active library
            from the context manager when ``None``.
        start: Override the name of the lowest-precedence non-terminal (and
            ``g.start``).  When ``None`` (default) the name is auto-inferred
            from the precedence constants (``"E"`` for standard libraries).

    Returns:
        A [Grammar][SRToolkit.utils.grammar.Grammar] with heuristic PCFG weights.

    Raises:
        ValueError: If the symbol library contains no variables, constants, or
            literals, making it impossible to generate any terminal expression.
        ValueError: If ``start`` is also a terminal symbol in the symbol library.
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_active()

    symbols = list(symbol_library.symbols.values())

    # Group operators by precedence level
    op_groups: dict[int, list[str]] = {}
    for s in symbols:
        if s["type"] == OP:
            op_groups.setdefault(s["precedence"], []).append(s["symbol"])
    sorted_precs = sorted(op_groups)

    R_fns = [s["symbol"] for s in symbols if s["type"] == FN and s["precedence"] == FN_PREFIX]
    P_fns = [s["symbol"] for s in symbols if s["type"] == FN and s["precedence"] == FN_POSTFIX]
    variables = [s["symbol"] for s in symbols if s["type"] == VAR]
    consts = [s["symbol"] for s in symbols if s["type"] == CONST]
    lits = [s["symbol"] for s in symbols if s["type"] == LIT]

    if not variables and not consts and not lits:
        raise ValueError(
            "Symbol library has no variables, constants, or literals; "
            "the grammar cannot generate any terminal expression."
        )

    # Collect terminal names so NT names can be made unique against them.
    all_terminals: set[str] = {sym for ops_list in op_groups.values() for sym in ops_list}
    all_terminals.update(R_fns + P_fns + variables + consts + lits)

    # If the caller explicitly chose a start name that is also a terminal symbol,
    # that is an unresolvable conflict — raise immediately.
    if start is not None and start in all_terminals:
        raise ValueError(
            f"start={start!r} is also a terminal symbol in the SymbolLibrary; "
            f"choose a different start symbol (e.g. {start + '_'!r})."
        )

    def _unique_nt(base: str, taken: set[str]) -> str:
        name = base
        while name in taken:
            name += "_"
        return name

    level_nts: dict[int, str] = {}
    for prec in sorted_precs:
        level_nts[prec] = _unique_nt(_level_name(prec), all_terminals)

    t_nt = _unique_nt("T", all_terminals)
    k_nt = _unique_nt("K", all_terminals)
    r_nt = _unique_nt("R", all_terminals)
    v_nt = _unique_nt("V", all_terminals)

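    # An explicit ``start`` renames the lowest-precedence level (or T when the
    # library has no operators) so that ``g.start`` actually has rules.
    if start is not None:
        if sorted_precs:
            level_nts[sorted_precs[0]] = start
        else:
            t_nt = start

    # top_nt is the topmost level of the hierarchy; functions recurse back here.
    top_nt = level_nts[sorted_precs[0]] if sorted_precs else t_nt

    g = cls(start=top_nt)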

    # Build operator-precedence chain.  Level i expands to itself OP level i+1;
    # the highest level chains to T.
    for i, prec in enumerate(sorted_precs):
        nt = level_nts[prec]
        next_nt = level_nts[sorted_precs[i + 1]] if i + 1 < len(sorted_precs) else t_nt
        ops = op_groups[prec]
        weight = _LEVEL_REC_WEIGHTS.get(prec, 0.4)
        infix = _LEVEL_RULE_INFIXES.get(prec, "op")

        for sym in ops:
            # Power is right-associative (a^b^c = a^(b^c)), so recurse on the right.
            if prec == OP_POWER:
                g.add_rule(Rule(nt, [next_nt, sym, nt], weight=weight / len(ops), name=f"{nt}_{infix}_{sym}"))
            else:
                g.add_rule(Rule(nt, [nt, sym, next_nt], weight=weight / len(ops), name=f"{nt}_{infix}_{sym}"))
        g.add_rule(Rule(nt, [next_nt], weight=1.0 - weight, name=f"{nt}_to_{next_nt}"))

    # T level — leaf dispatcher; K and V branches added only when their symbols exist.
    if variables and (consts or lits):
        g.add_rule(Rule(t_nt, [r_nt], weight=0.2, name=f"{t_nt}_to_{r_nt}"))
        g.add_rule(Rule(t_nt, [k_nt], weight=0.2, name=f"{t_nt}_to_{k_nt}"))
        g.add_rule(Rule(t_nt, [v_nt], weight=0.6, name=f"{t_nt}_to_{v_nt}"))
    elif consts or lits:
        g.add_rule(Rule(t_nt, [r_nt], weight=0.3, name=f"{t_nt}_to_{r_nt}"))
        g.add_rule(Rule(t_nt, [k_nt], weight=0.7, name=f"{t_nt}_to_{k_nt}"))
    else:
        g.add_rule(Rule(t_nt, [r_nt], weight=0.3, name=f"{t_nt}_to_{r_nt}"))
        g.add_rule(Rule(t_nt, [v_nt], weight=0.7, name=f"{t_nt}_to_{v_nt}"))

    # K level — constants and literals
    if lits and consts:
        for sym in lits:
            g.add_rule(Rule(k_nt, [sym], weight=0.2 / len(lits), name=f"{k_nt}_{sym}"))
        for sym in consts:
            g.add_rule(Rule(k_nt, [sym], weight=0.8 / len(consts), name=f"{k_nt}_{sym}"))
    elif lits:
        for sym in lits:
            g.add_rule(Rule(k_nt, [sym], weight=1.0 / len(lits), name=f"{k_nt}_{sym}"))
    elif consts:
        for sym in consts:
            g.add_rule(Rule(k_nt, [sym], weight=1.0 / len(consts), name=f"{k_nt}_{sym}"))

    # R level — prefix functions, postfix functions (top_nt sym), and parens.
    # Postfix rules live here directly (no separate P non-terminal).
    if R_fns:
        for sym in R_fns:
            g.add_rule(Rule(r_nt, [sym, "(", top_nt, ")"], weight=0.4 / len(R_fns), name=f"{r_nt}_fn_{sym}"))
        if P_fns:
            for sym in P_fns:
                g.add_rule(Rule(r_nt, [top_nt, sym], weight=0.05 / len(P_fns), name=f"{r_nt}_postfix_{sym}"))
            g.add_rule(Rule(r_nt, ["(", top_nt, ")"], weight=0.55, name=f"{r_nt}_paren"))
        else:
            g.add_rule(Rule(r_nt, ["(", top_nt, ")"], weight=0.6, name=f"{r_nt}_paren"))
    else:
        if P_fns:
            for sym in P_fns:
                g.add_rule(Rule(r_nt, [top_nt, sym], weight=0.05 / len(P_fns), name=f"{r_nt}_postfix_{sym}"))
            g.add_rule(Rule(r_nt, ["(", top_nt, ")"], weight=0.95, name=f"{r_nt}_paren"))
        else:
            g.add_rule(Rule(r_nt, ["(", top_nt, ")"], weight=1.0, name=f"{r_nt}_paren"))

    # V level — variables
    if variables:
        for sym in variables:
            g.add_rule(Rule(v_nt, [sym], weight=1.0 / len(variables), name=f"{v_nt}_{sym}"))

    return g

from_grammar_string classmethod

from_grammar_string(text: str, start: Optional[str] = None) -> 'Grammar'

Construct a Grammar from a string in NLTK production-rule notation.

Each non-empty, non-comment line must have the form:

LHS -> RHS_1 | RHS_2 | ...

where each alternative is a space-separated sequence of symbols. Terminals are enclosed in single quotes (e.g. '+'); unquoted tokens are treated as non-terminals. An optional weight in square brackets at the end of an alternative is parsed as a float (e.g. E '+' F [0.4]); alternatives without a weight default to 1.0.

The start symbol may be embedded as a comment header emitted by to_grammar_string:

# start: E

The explicit start parameter takes precedence over this comment. When start is None, only the first # start: comment encountered is used; subsequent ones are ignored along with all other comment lines.
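The start-symbol resolution rule can be isolated in a few lines. The `resolve_start` helper below is a hypothetical sketch of that rule only; where it returns `None`, the real method raises `ValueError`.

```python
from typing import Optional

START_PREFIX = "# start:"


def resolve_start(text: str, start: Optional[str] = None) -> Optional[str]:
    """Explicit argument wins; otherwise use the first '# start:' header, if any."""
    if start is not None:
        return start
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith(START_PREFIX):
            return line[len(START_PREFIX):].strip()
    return None   # the real method raises ValueError at this point


spec = "# start: E\n# start: F\nE -> E '+' F | F\nF -> 'x'"
print(resolve_start(spec))        # E  (first header wins)
print(resolve_start(spec, "F"))   # F  (explicit argument wins)
```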

Rule names are not preserved in NLTK notation and are set to None on all returned Rule objects.

Examples:

>>> g = Grammar.from_grammar_string("E -> E '+' F | F\nF -> 'x'", start="E")
>>> sorted(g.nonterminals)
['E', 'F']
>>> g.rules_for("F")[0].rhs
['x']
>>> g.start
'E'
>>> g2 = Grammar.from_grammar_string(
...     "# start: E\nE -> E '+' F [0.4] | F [0.6]\nF -> 'x' [1.0]"
... )
>>> g2.start
'E'
>>> g2.is_pcfg()
True

Parameters:

Name Type Description Default
text str

Grammar specification in NLTK notation, optionally with a # start: <symbol> header line.

required
start Optional[str]

Start non-terminal stored on the returned grammar. Takes precedence over a # start: comment in text. Required when neither the parameter nor the comment is present.

None

Returns:

Type Description
'Grammar'

A new Grammar populated with the parsed rules.

Raises:

Type Description
ValueError

If a content line contains no ->.

ValueError

If a weight token cannot be converted to float.

ValueError

If text contains no parseable rules.

ValueError

If the start symbol cannot be determined from either the parameter or the # start: comment.

Source code in SRToolkit/utils/grammar/grammar.py
@classmethod
def from_grammar_string(cls, text: str, start: Optional[str] = None) -> "Grammar":
    """
    Construct a [Grammar][SRToolkit.utils.grammar.Grammar] from a string in
    NLTK production-rule notation.

    Each non-empty, non-comment line must have the form::

        LHS -> RHS_1 | RHS_2 | ...

    where each alternative is a space-separated sequence of symbols.
    Terminals are enclosed in single quotes (e.g. ``'+'``); unquoted
    tokens are treated as non-terminals.  An optional weight in square
    brackets at the end of an alternative is parsed as a
    ``float`` (e.g. ``E '+' F [0.4]``); alternatives without a weight
    default to ``1.0``.

    The start symbol may be embedded as a comment header emitted by
    [to_grammar_string][SRToolkit.utils.grammar.Grammar.to_grammar_string]::

        # start: E

    The explicit ``start`` parameter takes precedence over this comment.
    When ``start`` is ``None``, only the first ``# start:`` comment encountered
    is used; subsequent ones are ignored along with all other comment lines.

    Rule names are not preserved in NLTK notation and are set to ``None``
    on all returned [Rule][SRToolkit.utils.grammar.Rule] objects.

    Examples:
        >>> g = Grammar.from_grammar_string("E -> E '+' F | F\\nF -> 'x'", start="E")
        >>> sorted(g.nonterminals)
        ['E', 'F']
        >>> g.rules_for("F")[0].rhs
        ['x']
        >>> g.start
        'E'

        >>> g2 = Grammar.from_grammar_string(
        ...     "# start: E\\nE -> E '+' F [0.4] | F [0.6]\\nF -> 'x' [1.0]"
        ... )
        >>> g2.start
        'E'
        >>> g2.is_pcfg()
        True

    Args:
        text: Grammar specification in NLTK notation, optionally with a
            ``# start: <symbol>`` header line.
        start: Start non-terminal stored on the returned grammar.  Takes
            precedence over a ``# start:`` comment in ``text``.  Required
            when neither the parameter nor the comment is present.

    Returns:
        A new [Grammar][SRToolkit.utils.grammar.Grammar] populated with
        the parsed rules.

    Raises:
        ValueError: If a content line contains no ``->``.
        ValueError: If a weight token cannot be converted to ``float``.
        ValueError: If ``text`` contains no parseable rules.
        ValueError: If the start symbol cannot be determined from either
            the parameter or the ``# start:`` comment.
    """
    resolved_start = start
    rules: list[Rule] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith("#"):
            if resolved_start is None and line.startswith("# start:"):
                resolved_start = line[len("# start:") :].strip()
            continue
        rules.extend(Rule.from_line(line))
    if not rules:
        raise ValueError("No rules parsed from grammar string.")
    if resolved_start is None:
        raise ValueError(
            "No start symbol found. Either pass start=<nonterminal> or include a '# start: <symbol>' line."
        )
    return cls(rules, start=resolved_start)

to_grammar_string

to_grammar_string() -> str

Serialise this grammar to a string in NLTK production-rule notation.

All rules sharing the same left-hand side are written on a single line, separated by |. Symbols that are non-terminals (i.e. appear as the lhs of at least one rule) are written unquoted; all other symbols are enclosed in single quotes. When is_pcfg returns True, the probability of each alternative (its weight divided by the sum of weights for that non-terminal) is appended in square brackets.

Rule names and registered constraints are not included in the output. When self.start is set, a # start: <symbol> comment is prepended so that from_grammar_string can reconstruct the grammar without requiring an explicit start argument.
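The serialisation rules above (group by left-hand side, quote every symbol that is not some rule's lhs, divide each weight by its group total) can be sketched on plain `(lhs, rhs, weight)` triples. The triples and the `to_nltk` helper are hypothetical stand-ins for Rule objects and the real method.

```python
from collections import defaultdict

# Hypothetical (lhs, rhs, weight) triples standing in for Rule objects.
rules = [("E", ["E", "+", "F"], 2.0), ("E", ["F"], 6.0), ("F", ["x"], 1.0)]


def to_nltk(rules: list, pcfg: bool = True) -> str:
    by_lhs = defaultdict(list)
    for lhs, rhs, w in rules:
        by_lhs[lhs].append((rhs, w))
    nts = set(by_lhs)                     # a symbol is a non-terminal iff it is some lhs
    lines = []
    for lhs, alts in by_lhs.items():
        total = sum(w for _, w in alts)   # normaliser for this non-terminal only
        parts = []
        for rhs, w in alts:
            alt = " ".join(s if s in nts else f"'{s}'" for s in rhs)
            if pcfg:
                alt += f" [{w / total}]"
            parts.append(alt)
        lines.append(f"{lhs} -> {' | '.join(parts)}")
    return "\n".join(lines)


print(to_nltk(rules))
# E -> E '+' F [0.25] | F [0.75]
# F -> 'x' [1.0]
```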

Examples:

>>> g = Grammar([
...     Rule("E", ["E", "+", "F"]),
...     Rule("E", ["F"]),
...     Rule("F", ["x"]),
... ], start="E")
>>> print(g.to_grammar_string())
# start: E
E -> E '+' F | F
F -> 'x'
>>> gp = Grammar([
...     Rule("E", ["x"], weight=1.0),
...     Rule("E", ["y"], weight=3.0),
... ])
>>> print(gp.to_grammar_string())
E -> 'x' [0.25] | 'y' [0.75]

Returns:

Type Description
str

Multi-line string in NLTK production-rule notation, optionally preceded by a # start: header.

Source code in SRToolkit/utils/grammar/grammar.py
def to_grammar_string(self) -> str:
    """
    Serialise this grammar to a string in NLTK production-rule notation.

    All rules sharing the same left-hand side are written on a single
    line, separated by `` | ``.  Symbols that are non-terminals (i.e.
    appear as the ``lhs`` of at least one rule) are written unquoted;
    all other symbols are enclosed in single quotes.  When
    [is_pcfg][SRToolkit.utils.grammar.Grammar.is_pcfg] returns ``True``,
    the probability of each alternative (its weight divided by the sum of
    weights for that non-terminal) is appended in square brackets.

    Rule names and registered constraints are not included in the output.
    When ``self.start`` is set, a ``# start: <symbol>`` comment is
    prepended so that
    [from_grammar_string][SRToolkit.utils.grammar.Grammar.from_grammar_string]
    can reconstruct the grammar without requiring an explicit ``start``
    argument.

    Examples:
        >>> g = Grammar([
        ...     Rule("E", ["E", "+", "F"]),
        ...     Rule("E", ["F"]),
        ...     Rule("F", ["x"]),
        ... ], start="E")
        >>> print(g.to_grammar_string())
        # start: E
        E -> E '+' F | F
        F -> 'x'

        >>> gp = Grammar([
        ...     Rule("E", ["x"], weight=1.0),
        ...     Rule("E", ["y"], weight=3.0),
        ... ])
        >>> print(gp.to_grammar_string())
        E -> 'x' [0.25] | 'y' [0.75]

    Returns:
        Multi-line string in NLTK production-rule notation, optionally
        preceded by a ``# start:`` header.
    """
    nts = self.nonterminals
    pcfg = self.is_pcfg()
    lines: list[str] = []
    if self.start is not None:
        lines.append(f"# start: {self.start}")
    for lhs, rules in self._rules_by_lhs.items():
        parts: list[str] = []
        if pcfg:
            total = sum(r.weight for r in rules)
        for rule in rules:
            rhs_tokens = [sym if sym in nts else f"'{sym}'" for sym in rule.rhs]
            alt = " ".join(rhs_tokens)
            if pcfg:
                alt += f" [{rule.weight / total}]"
            parts.append(alt)
        lines.append(f"{lhs} -> {' | '.join(parts)}")
    return "\n".join(lines)

MaxDepth

MaxDepth(limit: int)

Bases: Constraint[int, None]

Hard limit on derivation depth.

Local state is the remaining depth budget at each slot. A rule is rejected when its application would require at least one recursive non-terminal child but the budget has reached zero.
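A minimal sketch of the depth-budget idea, assuming a toy grammar and treating any non-terminal child as recursive (both assumptions). Each slot carries a budget, children receive one less, and rules that would open a child are filtered out once the budget reaches zero.

```python
# Hypothetical toy grammar: non-terminal -> list of right-hand sides.
GRAMMAR = {"E": [["E", "+", "E"], ["sin", "(", "E", ")"], ["x"]]}


def allowed(rhs: list[str], budget: int) -> bool:
    """Reject a rule that opens a non-terminal child once the depth budget is spent."""
    opens_child = any(sym in GRAMMAR for sym in rhs)
    return not (opens_child and budget <= 0)


def options(nt: str, budget: int) -> list[list[str]]:
    """Rules for nt that survive the depth check at this slot."""
    return [rhs for rhs in GRAMMAR[nt] if allowed(rhs, budget)]


def deepest(nt: str, budget: int) -> int:
    """Depth reached by always taking the first surviving rule; children get budget - 1."""
    rhs = options(nt, budget)[0]
    kids = [deepest(sym, budget - 1) for sym in rhs if sym in GRAMMAR]
    return 1 + max(kids, default=0)


print(options("E", 0))   # [['x']]
print(deepest("E", 3))   # 4: three recursive levels plus the final leaf rule
```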

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
...     Rule("E", ["E", "+", "E"]),
...     Rule("E", ["x"]),
... ])
>>> g.add_constraint(MaxDepth(0))
>>> d = g.start_derivation("E")
>>> all(r.rhs == ["x"] for r in d.options())
True

Parameters:

Name Type Description Default
limit int

Maximum nesting depth (number of non-terminal expansion levels).

required
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(self, limit: int) -> None:
    self.limit = limit

MaxNodes

MaxNodes(limit: int)

Bases: Constraint[None, int]

Hard limit on the total number of rule applications in a derivation.

Global state is a running count of applications so far. A rule is rejected when applying it would push the count over the limit, taking into account the non-terminal children it introduces (each of which will require at least one more application).
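The admissibility check can be sketched as follows. This version also reserves one application for every other non-terminal slot still open in the derivation, which is one way to guarantee the limit is never exceeded; the library's exact bookkeeping may differ. The toy grammar and helper names are hypothetical.

```python
GRAMMAR = {"E": [["E", "+", "E"], ["x"]]}   # hypothetical toy grammar


def admissible(rhs: list[str], count: int, open_other: int, limit: int) -> bool:
    """One application now, plus at least one for each slot that will still be open."""
    new_children = sum(sym in GRAMMAR for sym in rhs)
    return count + 1 + open_other + new_children <= limit


def greedy_derive(limit: int) -> tuple[list[str], int]:
    """Leftmost derivation that always takes the first admissible rule."""
    stack, out, count = ["E"], [], 0
    while stack:
        sym = stack.pop()
        if sym not in GRAMMAR:
            out.append(sym)
            continue
        open_other = sum(s in GRAMMAR for s in stack)   # other pending non-terminals
        rhs = next(r for r in GRAMMAR[sym] if admissible(r, count, open_other, limit))
        count += 1
        stack.extend(reversed(rhs))
    return out, count


print(greedy_derive(3))   # (['x', '+', 'x'], 3)
```

Because the invariant "applications used plus open slots never exceeds the limit" holds after every step, the terminal rule is always admissible and `next(...)` never runs out of candidates.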

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
...     Rule("E", ["E", "+", "E"]),
...     Rule("E", ["x"]),
... ])
>>> g.add_constraint(MaxNodes(3))
>>> d = g.start_derivation("E")
>>> # A budget of 3 applications leaves room for at most one '+' expansion
>>> tokens = d.generate()
>>> len(tokens) <= 5
True

Parameters:

Name Type Description Default
limit int

Maximum number of rule applications.

required
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(self, limit: int) -> None:
    self.limit = limit

MaxOccurrences

MaxOccurrences(symbol: str, limit: int)

Bases: Constraint[None, int]

Hard limit on how many times a specific terminal symbol may appear.

Global state counts occurrences committed so far. A rule whose rhs contains the tracked symbol is rejected once the count equals the limit.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
...     Rule("E", ["E", "+", "E"]),
...     Rule("E", ["x"]),
...     Rule("E", ["y"]),
... ])
>>> g.add_constraint(MaxOccurrences("x", 1))
>>> d = g.start_derivation("E")
>>> tokens = d.generate(limit=200)
>>> tokens.count("x") <= 1
True

Parameters:

Name Type Description Default
symbol str

Terminal token to track.

required
limit int

Maximum allowed occurrences.

required
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(self, symbol: str, limit: int) -> None:
    self.symbol = symbol
    self.limit = limit

NoNested

NoNested(symbols: str | Iterable[str])

Bases: Constraint[bool, None]

Prevent any symbol in a group from appearing nested inside any other symbol in the same group.

Local state is a boolean "currently under any symbol in the group". Any rule whose rhs contains a group symbol is rejected while the local state is True. Children inherit True once such a rule is applied.

Pass a single string to forbid self-nesting only; pass multiple symbols to forbid cross-nesting within the group (e.g. NoNested(TRIG_FNS) prevents sin(cos(x)) as well as sin(sin(x))).

Warning

This constraint propagates the "inside" flag to all non-terminal children of any rule whose rhs contains a group symbol. It works correctly when each group symbol appears in a rule with exactly one non-terminal child (typical prefix-function rules such as E -> sin ( E )) and for infix operators where every child should be blocked (e.g. E -> E + E). It does not work correctly for mixed rules that combine a function call with additional non-terminal siblings in a single production (e.g. E -> sin ( E ) + F): the sibling F will be incorrectly treated as inside the group. If your grammar uses such rules, implement a custom Constraint instead.
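The flag propagation described above, including the sibling caveat from the warning, fits in a short sketch. The group, the non-terminal set and both helpers are hypothetical illustrations rather than the library's implementation.

```python
GROUP = frozenset({"sin", "cos"})     # hypothetical nesting group
NONTERMINALS = {"E", "F"}


def allowed(rhs: list[str], inside: bool) -> bool:
    """While inside the group, reject any rule that mentions a group symbol."""
    return not (inside and any(sym in GROUP for sym in rhs))


def child_flags(rhs: list[str], inside: bool) -> list[bool]:
    """Every non-terminal child inherits True once a group symbol appears in the rule."""
    flag = inside or any(sym in GROUP for sym in rhs)
    return [flag for sym in rhs if sym in NONTERMINALS]


print(allowed(["cos", "(", "E", ")"], inside=True))                 # False: cos inside sin
print(child_flags(["sin", "(", "E", ")"], inside=False))            # [True]
# The caveat: in a mixed rule the sibling F is flagged as well.
print(child_flags(["sin", "(", "E", ")", "+", "F"], inside=False))  # [True, True]
```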

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
...     Rule("E", ["sin", "(", "E", ")"]),
...     Rule("E", ["cos", "(", "E", ")"]),
...     Rule("E", ["x"]),
... ])
>>> g.add_constraint(NoNested(["sin", "cos"]))
>>> d = g.start_derivation("E")
>>> d.apply(g.rules_for("E")[0])   # apply sin(E)
>>> # Now inside sin; both sin and cos rules should be gone
>>> opts = d.options()
>>> all("sin" not in r.rhs and "cos" not in r.rhs for r in opts)
True

Parameters:

Name Type Description Default
symbols str | Iterable[str]

A single terminal token or an iterable of terminal tokens that form the nesting group.

required
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(self, symbols: str | Iterable[str]) -> None:
    if isinstance(symbols, str):
        self.symbols: frozenset[str] = frozenset({symbols})
    else:
        self.symbols = frozenset(symbols)

ParseTree

ParseTree(root: ParseTreeNode)

Full derivation history of an expression under a grammar.

Unlike Node, which holds only terminal symbols for expression evaluation, a ParseTree retains every non-terminal and every production applied during the derivation.
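The two traversals a parse tree supports (terminal tokens left to right, and productions in pre-order) can be sketched with a hypothetical `PTNode` stand-in that stores a rule as an `(lhs, rhs)` tuple instead of a Rule object.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PTNode:
    """Hypothetical stand-in for ParseTreeNode; a rule is kept as (lhs, rhs)."""
    symbol: str
    rule: Optional[tuple[str, tuple[str, ...]]] = None
    children: list[PTNode] = field(default_factory=list)


def tokens(node: PTNode) -> list[str]:
    """Terminal leaves (rule is None) in left-to-right order."""
    if node.rule is None:
        return [node.symbol]
    return [tok for child in node.children for tok in tokens(child)]


def productions(node: PTNode) -> list[tuple[str, tuple[str, ...]]]:
    """Rules in pre-order: a node's own rule before those of its children."""
    if node.rule is None:
        return []
    out = [node.rule]
    for child in node.children:
        out.extend(productions(child))
    return out


# Parse tree for "x + x" under E -> E '+' F | F, F -> 'x'
f1 = PTNode("F", ("F", ("x",)), [PTNode("x")])
f2 = PTNode("F", ("F", ("x",)), [PTNode("x")])
e1 = PTNode("E", ("E", ("F",)), [f1])
root = PTNode("E", ("E", ("E", "+", "F")), [e1, PTNode("+"), f2])

print(tokens(root))                           # ['x', '+', 'x']
print([lhs for lhs, _ in productions(root)])  # ['E', 'E', 'F', 'F']
```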

Examples:

>>> r = Rule("E", ["x"])
>>> leaf = ParseTreeNode("x", None)
>>> root = ParseTreeNode("E", r, [leaf])
>>> pt = ParseTree(root)
>>> pt.to_token_list()
['x']
>>> pt.productions_used()
[Rule(lhs='E', rhs=['x'], weight=1.0, name=None)]

Attributes:

Name Type Description
root

The root node of the parse tree.

Source code in SRToolkit/utils/grammar/grammar.py
def __init__(self, root: ParseTreeNode) -> None:
    self.root = root

to_token_list

to_token_list() -> list[str]

Collect all terminal tokens in left-to-right order.

Examples:

>>> leaf1 = ParseTreeNode("a", None)
>>> leaf2 = ParseTreeNode("+", None)
>>> leaf3 = ParseTreeNode("b", None)
>>> root = ParseTreeNode("E", Rule("E", ["a", "+", "b"]), [leaf1, leaf2, leaf3])
>>> ParseTree(root).to_token_list()
['a', '+', 'b']

Returns:

Type Description
list[str]

Flat list of terminal tokens in left-to-right order.

Source code in SRToolkit/utils/grammar/grammar.py
def to_token_list(self) -> list[str]:
    """
    Collect all terminal tokens in left-to-right order.

    Examples:
        >>> leaf1 = ParseTreeNode("a", None)
        >>> leaf2 = ParseTreeNode("+", None)
        >>> leaf3 = ParseTreeNode("b", None)
        >>> root = ParseTreeNode("E", Rule("E", ["a", "+", "b"]), [leaf1, leaf2, leaf3])
        >>> ParseTree(root).to_token_list()
        ['a', '+', 'b']

    Returns:
        Flat list of terminal tokens in left-to-right order.
    """
    tokens: list[str] = []
    self._collect_tokens(self.root, tokens)
    return tokens

productions_used

productions_used() -> list[Rule]

Return all rules applied in the derivation, in pre-order.

Examples:

>>> leaf = ParseTreeNode("x", None)
>>> r = Rule("E", ["x"])
>>> root = ParseTreeNode("E", r, [leaf])
>>> ParseTree(root).productions_used()
[Rule(lhs='E', rhs=['x'], weight=1.0, name=None)]

Returns:

Type Description
list[Rule]

List of Rule objects in pre-order traversal order.

Source code in SRToolkit/utils/grammar/grammar.py
def productions_used(self) -> list[Rule]:
    """
    Return all rules applied in the derivation, in pre-order.

    Examples:
        >>> leaf = ParseTreeNode("x", None)
        >>> r = Rule("E", ["x"])
        >>> root = ParseTreeNode("E", r, [leaf])
        >>> ParseTree(root).productions_used()
        [Rule(lhs='E', rhs=['x'], weight=1.0, name=None)]

    Returns:
        List of [Rule][SRToolkit.utils.grammar.Rule] objects in pre-order traversal order.
    """
    rules: list[Rule] = []
    self._collect_rules(self.root, rules)
    return rules

ParseTreeNode dataclass

ParseTreeNode(symbol: str, rule_applied: Optional[Rule], children: list[ParseTreeNode] = list())

A node in a derivation parse tree.

Terminal leaves have rule_applied=None and children=[]. Internal nodes store the rule that expanded the non-terminal at this position.

Attributes:

Name Type Description
symbol str

The grammar symbol at this node (terminal or non-terminal name).

rule_applied Optional[Rule]

The Rule used to expand this node. None for terminal leaves.

children list[ParseTreeNode]

Ordered child nodes corresponding to the symbols in rule_applied.rhs.

Rule dataclass

Rule(lhs: str, rhs: list[str], weight: float = 1.0, name: Optional[str] = None)

A single production rule in a grammar.

A grammar whose rules all keep the default weight of 1 behaves as a plain CFG; when weights differ across rules, the grammar is treated as a PCFG and each non-terminal's productions are sampled in proportion to their weights. Weights are unnormalised; the grammar normalises them within each non-terminal's group of rules at sampling time.
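Per-group normalisation can be illustrated with the standard library alone. The alternatives below are hypothetical; `random.choices` accepts unnormalised weights, so sampling does not need an explicit division, while the normalised probabilities are what a PCFG reports.

```python
import random
from collections import Counter

# Hypothetical alternatives for one non-terminal E: (rhs, unnormalised weight).
alternatives = [(("E", "+", "F"), 2.0), (("F",), 3.0)]

total = sum(w for _, w in alternatives)
probs = {rhs: w / total for rhs, w in alternatives}   # normalised within the E group
print(probs)   # {('E', '+', 'F'): 0.4, ('F',): 0.6}

rng = random.Random(0)
draws = rng.choices([rhs for rhs, _ in alternatives],
                    weights=[w for _, w in alternatives], k=10_000)
freq = Counter(draws)
print(freq[("F",)] / 10_000)   # close to 0.6
```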

Examples:

>>> r = Rule("E", ["E", "+", "F"], weight=0.4, name="E_add_+")
>>> r.lhs
'E'
>>> r.rhs
['E', '+', 'F']
>>> r.weight
0.4
>>> r.name
'E_add_+'
>>> Rule("E", ["F"]).weight
1.0

Attributes:

Name Type Description
lhs str

The non-terminal being expanded, e.g. "E".

rhs list[str]

Ordered sequence of symbols the non-terminal expands to. Each element is either a terminal token (e.g. "+" or "sin") or the name of another non-terminal. A symbol is treated as a non-terminal if and only if it appears as the lhs of at least one rule in the grammar.

weight float

Unnormalised sampling weight. Defaults to 1.0; when all rules for a non-terminal share the same weight, sampling among them is uniform.

name Optional[str]

Optional stable identifier for this rule. Used by constraints for scoping and identification. None by default.

from_line classmethod

from_line(line: str) -> list['Rule']

Parse one NLTK production line into a list of Rule objects.

The line must have the form:

LHS -> RHS_1 | RHS_2 | ...

where each alternative is a space-separated sequence of symbols. Terminals are enclosed in single quotes; unquoted tokens are non-terminals. An optional weight in square brackets at the end of an alternative is parsed as a float; alternatives without a weight default to 1.0.

Examples:

>>> Rule.from_line("E -> E '+' F [0.4] | F [0.6]")
[Rule(lhs='E', rhs=['E', '+', 'F'], weight=0.4, name=None), Rule(lhs='E', rhs=['F'], weight=0.6, name=None)]
>>> Rule.from_line("F -> 'x'")
[Rule(lhs='F', rhs=['x'], weight=1.0, name=None)]

Parameters:

Name Type Description Default
line str

A single non-empty, non-comment production line.

required

Returns:

Type Description
list['Rule']

List of Rule objects, one per alternative.

Raises:

Type Description
ValueError

If line contains no ->.

ValueError

If the left-hand side is empty.

ValueError

If a weight token cannot be converted to float.

ValueError

If no alternatives are parsed from the right-hand side.

Source code in SRToolkit/utils/grammar/grammar.py
@classmethod
def from_line(cls, line: str) -> list["Rule"]:
    """
    Parse one NLTK production line into a list of [Rule][SRToolkit.utils.grammar.Rule] objects.

    The line must have the form::

        LHS -> RHS_1 | RHS_2 | ...

    where each alternative is a space-separated sequence of symbols.
    Terminals are enclosed in single quotes; unquoted tokens are non-terminals.
    An optional weight in square brackets at the end of an alternative is parsed
    as a ``float``; alternatives without a weight default to ``1.0``.

    Examples:
        >>> Rule.from_line("E -> E '+' F [0.4] | F [0.6]")
        [Rule(lhs='E', rhs=['E', '+', 'F'], weight=0.4, name=None), Rule(lhs='E', rhs=['F'], weight=0.6, name=None)]
        >>> Rule.from_line("F -> 'x'")
        [Rule(lhs='F', rhs=['x'], weight=1.0, name=None)]

    Args:
        line: A single non-empty, non-comment production line.

    Returns:
        List of [Rule][SRToolkit.utils.grammar.Rule] objects, one per alternative.

    Raises:
        ValueError: If ``line`` contains no ``->``.
        ValueError: If the left-hand side is empty.
        ValueError: If a weight token cannot be converted to ``float``.
        ValueError: If no alternatives are parsed from the right-hand side.
    """
    line = line.strip()
    if "->" not in line:
        raise ValueError(f"Expected '->' in grammar line: {line!r}")
    lhs, rhs_part = line.split("->", 1)
    lhs = lhs.strip()
    if not lhs:
        raise ValueError(f"Empty left-hand side in grammar line: {line!r}")
    rules = []
    for alt in rhs_part.split("|"):
        alt = alt.strip()
        if not alt:
            continue
        weight = 1.0
        m = re.search(r"\[([^\]]+)\]\s*$", alt)
        if m:
            try:
                weight = float(m.group(1))
            except ValueError:
                raise ValueError(f"Invalid weight {m.group(1)!r} in grammar line: {line!r}")
            alt = alt[: m.start()].strip()
        tokens = [t[1:-1] if t.startswith("'") and t.endswith("'") else t for t in re.findall(r"'[^']*'|\S+", alt)]
        if tokens:
            rules.append(cls(lhs, tokens, weight=weight))
    if not rules:
        raise ValueError(f"No alternatives parsed from grammar line: {line!r}")
    return rules
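The parsing steps above can be traced outside the class. Below is a minimal standalone sketch that repeats the same two regular expressions but returns plain `(lhs, rhs, weight)` tuples. `parse_production` is a hypothetical name, and the real method's error checks are left out. One consequence of the code as written: the right-hand side is split on every `|`, including one inside quotes, so a quoted `'|'` terminal does not come out as a single token.

```python
import re
from typing import List, Tuple


def parse_production(line: str) -> List[Tuple[str, List[str], float]]:
    """Hypothetical sketch of Rule.from_line's parsing; returns (lhs, rhs, weight) tuples."""
    lhs, rhs_part = line.strip().split("->", 1)
    lhs = lhs.strip()
    rules = []
    for alt in rhs_part.split("|"):  # naive split: also splits inside quotes
        alt = alt.strip()
        if not alt:
            continue
        weight = 1.0
        # An optional trailing "[w]" is the weight of this alternative.
        m = re.search(r"\[([^\]]+)\]\s*$", alt)
        if m:
            weight = float(m.group(1))
            alt = alt[: m.start()].strip()
        # A quoted run (which may contain spaces) is one terminal; the quotes are stripped.
        tokens = [t[1:-1] if t.startswith("'") and t.endswith("'") else t
                  for t in re.findall(r"'[^']*'|\S+", alt)]
        if tokens:
            rules.append((lhs, tokens, weight))
    return rules


print(parse_production("E -> E '+' F [0.4] | F [0.6]"))
# [('E', ['E', '+', 'F'], 0.4), ('E', ['F'], 0.6)]
print(parse_production("F -> 'sin' '(' E ')' | 'a b'"))
# [('F', ['sin', '(', 'E', ')'], 1.0), ('F', ['a b'], 1.0)]
```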

SymbolLibrary

SymbolLibrary(symbols: Optional[List[str]] = None, num_variables: int = 0, preamble: Optional[List[str]] = None)

A registry of tokens and their properties, used throughout the toolkit to parse, compile, and generate symbolic expressions.

By default, the library uses NumPy for operator and function evaluation. To use a different backend, pass the required import statements via preamble.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x", "x")
>>> library.get_type("x")
'var'
>>> library.get_precedence("x")
0
>>> library.get_np_fn("x")
'x'
>>> library.remove_symbol("x")
>>> library = SymbolLibrary.default_symbols()
>>> # You can also initialize the library with a list of symbols (listed in SymbolLibrary.default_symbols)
>>> # and the number of variables.
>>> library2 = SymbolLibrary(["+", "*", "sin"], num_variables=2)
>>> len(library2)
5
>>> # Use as a context manager to avoid passing sl explicitly
>>> # with SymbolLibrary.default_symbols(num_variables=2) as sl:
>>> #     tree = tokens_to_tree(["X_0", "+", "X_1", "*", "C"])

Parameters:

Name Type Description Default
symbols Optional[List[str]]

Symbols to pre-populate from the default set. None with num_variables=0 produces an empty library; with num_variables > 0, only variable tokens are added. Symbols not present in the default set are silently ignored. See default_symbols for the supported names.

None
num_variables int

Number of variable tokens to add, labeled X_0 through X_{num_variables-1}. Default is 0.

0
preamble Optional[List[str]]

Import statements prepended to compiled expression functions. Defaults to ["import numpy as np"].

None

Attributes:

Name Type Description
symbols

Mapping from token string to its property dict (symbol, type, precedence, NumPy function string, LaTeX template, and Cython dispatch ID).

Source code in SRToolkit/utils/symbol_library.py
def __init__(
    self, symbols: Optional[List[str]] = None, num_variables: int = 0, preamble: Optional[List[str]] = None
) -> None:
    """
    A registry of tokens and their properties, used throughout the toolkit to parse,
    compile, and generate symbolic expressions.

    By default, the library uses NumPy for operator and function evaluation. To use a
    different backend, pass the required import statements via ``preamble``.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x", "x")
        >>> library.get_type("x")
        'var'
        >>> library.get_precedence("x")
        0
        >>> library.get_np_fn("x")
        'x'
        >>> library.remove_symbol("x")
        >>> library = SymbolLibrary.default_symbols()
        >>> # You can also initialize the library with a list of symbols (listed in SymbolLibrary.default_symbols)
        >>> # and the number of variables.
        >>> library2 = SymbolLibrary(["+", "*", "sin"], num_variables=2)
        >>> len(library2)
        5
        >>> # Use as a context manager to avoid passing sl explicitly
        >>> # with SymbolLibrary.default_symbols(num_variables=2) as sl:
        >>> #     tree = tokens_to_tree(["X_0", "+", "X_1", "*", "C"])

    Args:
        symbols: Symbols to pre-populate from the default set. ``None`` with
            ``num_variables=0`` produces an empty library; with ``num_variables > 0``,
            only variable tokens are added. Symbols not present in the default set are
            silently ignored. See [default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols] for the supported names.
        num_variables: Number of variable tokens to add, labeled ``X_0`` through
            ``X_{num_variables-1}``. Default is ``0``.
        preamble: Import statements prepended to compiled expression functions.
            Defaults to ``["import numpy as np"]``.

    Attributes:
        symbols: Mapping from token string to its property dict (type, precedence,
            NumPy function string, LaTeX template).
    """
    if preamble is None:
        self.preamble = ["import numpy as np"]
    else:
        self.preamble = preamble

    if symbols is None and num_variables == 0:
        self.symbols: Dict[str, Any] = dict()
        self.num_variables = 0
    else:
        if symbols is None:
            symbols = []

        self.symbols = SymbolLibrary.from_symbol_list(symbols, num_variables).symbols
        self.num_variables = num_variables

add_symbol

add_symbol(symbol: str, symbol_type: str, precedence: int, np_fn: str, latex_str: Optional[str] = None, cython_id: int = -1)

Add a token to the library with its associated type, precedence, NumPy function string, and LaTeX template.

Symbol types:

  • "op": binary operator (e.g. +, *).
  • "fn": unary function (e.g. sin, sqrt).
  • "lit": literal with a fixed value (e.g. pi, e).
  • "const": free constant whose value is optimised during parameter estimation (e.g. C). Using a single "const" token is recommended; multiple tokens increase complexity and reduce readability.
  • "var": input variable whose values are read from the data array X.

If latex_str is omitted, a default template is generated based on the symbol type: "{} \text{symb} {}" for operators (e.g. + → {} \text{+} {}), "\text{symb} {}" for functions (e.g. sin → \text{sin} {}), and "\text{symb}" for all other types.
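The source below builds these defaults with doubled braces (`{{` and `}}`). The stored template for an operator `+` is therefore literally `{} \text{{+}} {}`. It only reads as `\text{+}` once it is filled in with `str.format`, which the `{}` placeholders suggest is how the toolkit uses it. This standalone copy of that fallback branch (`default_latex_template` is a hypothetical name) shows both the stored form and the formatted form:

```python
def default_latex_template(symbol: str, symbol_type: str) -> str:
    # Copy of the fallback branch in add_symbol: literal braces are doubled
    # so that a later str.format call turns them into single braces.
    if symbol_type == "op":
        return r"{} \text{{" + symbol + r"}} {}"
    if symbol_type == "fn":
        return r"\text{{" + symbol + r"}} {}"
    return r"\text{{" + symbol + r"}}"


op_template = default_latex_template("%", "op")
print(op_template)                                        # {} \text{{%}} {}
print(op_template.format("a", "b"))                       # a \text{%} b
print(default_latex_template("relu", "fn").format("x"))   # \text{relu} x
```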

Note

For best performance, we recommend selecting symbols from the predefined set via the constructor or from_symbol_list rather than adding custom ones. Predefined symbols have C implementations and benefit from Cython acceleration; custom symbols always fall back to Python callables.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
>>> library.add_symbol("C", "const", 5, "C[{}]", r"c_{}")
>>> library.add_symbol("X_0", "var", 5, "X[:, 0]", r"X_0")
>>> library.add_symbol("pi", "lit", 5, "np.pi", r"\pi")

Parameters:

Name Type Description Default
symbol str

Token string to register.

required
symbol_type str

One of "op", "fn", "lit", "const", or "var".

required
precedence int

Operator precedence, used for infix reconstruction and PCFG generation. For "op" symbols use the named constants OP_ADDITIVE (100), OP_MULTIPLICATIVE (200), or OP_POWER (300) — or any integer between them for custom operators. For "fn" symbols use FN_PREFIX (1000) for prefix functions such as sin and FN_POSTFIX (-1000) for postfix functions such as ^2. For "var", "const", and "lit" use LEAF (1000).

required
np_fn str

Python/NumPy expression string used in compiled callables (e.g. "np.sin({})"). For "var" type, passing None or "" auto-generates "X[:, i]" where i is the current num_variables count at the time of the call.

required
latex_str Optional[str]

LaTeX template string with {} placeholders for operands. Auto-generated if omitted.

None
cython_id int

Integer dispatch ID used by the Cython-based evaluator. -1 (default) means the symbol has no C implementation and will fall back to a Python callable during Cython evaluation. Custom symbols should leave this at -1. To redefine a built-in symbol while preserving its Cython acceleration, look up its ID in the source of default_symbols, or query it with get_cython_id on SymbolLibrary.default_symbols().

-1

Raises:

Type Description
ValueError

If symbol_type is not one of the valid types.

Source code in SRToolkit/utils/symbol_library.py
def add_symbol(
    self,
    symbol: str,
    symbol_type: str,
    precedence: int,
    np_fn: str,
    latex_str: Optional[str] = None,
    cython_id: int = -1,
):
    r"""
    Add a token to the library with its associated type, precedence, NumPy function
    string, and LaTeX template.

    Symbol types:

    - ``"op"``: binary operator (e.g. ``+``, ``*``).
    - ``"fn"``: unary function (e.g. ``sin``, ``sqrt``).
    - ``"lit"``: literal with a fixed value (e.g. ``pi``, ``e``).
    - ``"const"``: free constant whose value is optimised during parameter estimation
      (e.g. ``C``). Using a single ``"const"`` token is recommended; multiple tokens
      increase complexity and reduce readability.
    - ``"var"``: input variable whose values are read from the data array ``X``.

    If ``latex_str`` is omitted, a default template is generated based on the symbol type:
    ``"{} \text{symb} {}"`` for operators (e.g. ``+`` → ``{} \text{+} {}``),
    ``"\text{symb} {}"`` for functions (e.g. ``sin`` → ``\text{sin} {}``),
    and ``"\text{symb}"`` for all other types.

    Note:
        For best performance, we recommend selecting symbols from the predefined set via the
        constructor or [from_symbol_list][SRToolkit.utils.symbol_library.SymbolLibrary.from_symbol_list]
        rather than adding custom ones. Predefined symbols have C implementations and
        benefit from Cython acceleration; custom symbols always fall back to Python callables.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x")
        >>> library.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
        >>> library.add_symbol("C", "const", 5, "C[{}]", r"c_{}")
        >>> library.add_symbol("X_0", "var", 5, "X[:, 0]", r"X_0")
        >>> library.add_symbol("pi", "lit", 5, "np.pi", r"\pi")

    Args:
        symbol: Token string to register.
        symbol_type: One of ``"op"``, ``"fn"``, ``"lit"``, ``"const"``, or ``"var"``.
        precedence: Operator precedence, used for infix reconstruction and PCFG generation.
            For ``"op"`` symbols use the named constants
            ``OP_ADDITIVE`` (100), ``OP_MULTIPLICATIVE`` (200), or ``OP_POWER`` (300) —
            or any integer between them for custom operators.
            For ``"fn"`` symbols use ``FN_PREFIX`` (1000) for prefix functions such as
            ``sin`` and ``FN_POSTFIX`` (-1000) for postfix functions such as ``^2``.
            For ``"var"``, ``"const"``, and ``"lit"`` use ``LEAF`` (1000).
        np_fn: Python/NumPy expression string used in compiled callables
            (e.g. ``"np.sin({})"``). For ``"var"`` type, passing ``None``
            or ``""`` auto-generates ``"X[:, i]"`` where ``i`` is the current
            ``num_variables`` count at the time of the call.
        latex_str: LaTeX template string with ``{}`` placeholders for operands.
            Auto-generated if omitted.
        cython_id: Integer dispatch ID used by the Cython-based evaluator.
            ``-1`` (default) means the symbol has no C implementation and will
            fall back to a Python callable during Cython evaluation. Custom symbols
            should leave this at ``-1``; to redefine a built-in symbol while
            preserving its Cython acceleration, look up its ID in the source of
            [default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols].

    Raises:
        ValueError: If ``symbol_type`` is not one of the valid types.
    """
    if symbol_type not in VALID_SYMBOL_TYPES:
        raise ValueError(f"Invalid symbol type '{symbol_type}'. Must be one of: {sorted(VALID_SYMBOL_TYPES)}")

    if latex_str is None:
        if symbol_type == "op":
            latex_str = r"{} \text{{" + symbol + r"}} {}"
        elif symbol_type == "fn":
            latex_str = r"\text{{" + symbol + r"}} {}"
        else:
            latex_str = r"\text{{" + symbol + r"}}"

    if symbol_type == "var" and (np_fn is None or np_fn == ""):
        np_fn = "X[:, {}]".format(self.num_variables)

    if symbol_type == "var":
        self.num_variables += 1

    self.symbols[symbol] = {
        "symbol": symbol,
        "type": symbol_type,
        "precedence": precedence,
        "np_fn": np_fn,
        "latex_str": latex_str,
        "cython_id": cython_id,
    }
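The relative ordering of the precedence values is what makes infix reconstruction possible. Here is a minimal sketch, which is not the toolkit's own routine. It parenthesises an operand only when its precedence is lower than its parent's, or lower-or-equal for the right operand. It assumes a nested-tuple tree and covers only the left-associative operators `+ - * /`. A right-associative `^` (OP_POWER) would need the mirror-image rule.

```python
from typing import Tuple

# Values quoted in the precedence description above.
OP_ADDITIVE, OP_MULTIPLICATIVE = 100, 200
LEAF = 1000

PRECEDENCE = {"+": OP_ADDITIVE, "-": OP_ADDITIVE, "*": OP_MULTIPLICATIVE, "/": OP_MULTIPLICATIVE}


def to_infix(node: Tuple) -> Tuple[str, int]:
    """Return (infix string, precedence) for a (symbol, left, right) or (leaf,) tuple."""
    if len(node) == 1:
        return node[0], LEAF
    symbol, left, right = node
    prec = PRECEDENCE[symbol]
    left_str, left_prec = to_infix(left)
    right_str, right_prec = to_infix(right)
    if left_prec < prec:      # e.g. (x + y) * z
        left_str = f"({left_str})"
    if right_prec <= prec:    # e.g. x - (y - z); these operators are left-associative
        right_str = f"({right_str})"
    return f"{left_str} {symbol} {right_str}", prec


print(to_infix(("*", ("+", ("x",), ("y",)), ("z",)))[0])  # (x + y) * z
print(to_infix(("-", ("x",), ("-", ("y",), ("z",))))[0])  # x - (y - z)
print(to_infix(("+", ("x",), ("*", ("y",), ("z",))))[0])  # x + y * z
```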

remove_symbol

remove_symbol(symbol: str)

Remove a token from the library.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> len(library.symbols)
1
>>> library.remove_symbol("x")
>>> len(library.symbols)
0

Parameters:

Name Type Description Default
symbol str

Token string to remove.

required

Raises:

Type Description
KeyError

If symbol is not present in the library.

Source code in SRToolkit/utils/symbol_library.py
def remove_symbol(self, symbol: str):
    """
    Remove a token from the library.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x")
        >>> len(library.symbols)
        1
        >>> library.remove_symbol("x")
        >>> len(library.symbols)
        0

    Args:
        symbol: Token string to remove.

    Raises:
        KeyError: If ``symbol`` is not present in the library.
    """
    del self.symbols[symbol]
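Reading the add_symbol and remove_symbol sources together shows that removal only deletes the dictionary entry. num_variables is not decremented. If you rely on auto-generated column indices (np_fn of None or ""), removing a variable leaves a gap rather than freeing its column. The sketch below reproduces just that bookkeeping in a hypothetical `MiniLibrary` class, so it runs without the toolkit installed:

```python
from typing import Dict, Optional


class MiniLibrary:
    """Hypothetical stand-in that reproduces only the variable bookkeeping of SymbolLibrary."""

    def __init__(self) -> None:
        self.symbols: Dict[str, str] = {}
        self.num_variables = 0

    def add_var(self, symbol: str, np_fn: Optional[str] = None) -> None:
        # Same rule as add_symbol for "var": an empty np_fn becomes X[:, <current count>].
        if np_fn is None or np_fn == "":
            np_fn = "X[:, {}]".format(self.num_variables)
        self.num_variables += 1
        self.symbols[symbol] = np_fn

    def remove_symbol(self, symbol: str) -> None:
        # Same as remove_symbol above: only the dictionary entry is deleted.
        del self.symbols[symbol]


lib = MiniLibrary()
lib.add_var("a")        # X[:, 0]
lib.add_var("b")        # X[:, 1]
lib.remove_symbol("a")
lib.add_var("c")        # X[:, 2], not X[:, 0]
print(lib.symbols, lib.num_variables)  # {'b': 'X[:, 1]', 'c': 'X[:, 2]'} 3
```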

get_type

get_type(symbol: str) -> str

Return the type of a symbol.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.get_type("x")
'var'

Parameters:

Name Type Description Default
symbol str

Token to look up.

required

Returns:

Type Description
str

The type string ("op", "fn", "lit", "const", or "var") if the symbol is in the library, otherwise an empty string.

Source code in SRToolkit/utils/symbol_library.py
def get_type(self, symbol: str) -> str:
    """
    Return the type of a symbol.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x")
        >>> library.get_type("x")
        'var'

    Args:
        symbol: Token to look up.

    Returns:
        The type string (``"op"``, ``"fn"``, ``"lit"``, ``"const"``, or ``"var"``) if the symbol is in the library, otherwise an empty string.
    """
    if symbol in self.symbols:
        return self.symbols[symbol]["type"]
    else:
        return ""
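get_type returns an empty string instead of raising for unknown tokens, so it can drive simple structural checks directly. The sketch below maps each type to an arity, following the Node conventions: operators have two children, functions one, and leaves none. It then checks whether a token list in prefix (Polish) order forms exactly one complete expression. The prefix layout, the `ARITY` table and `is_complete_prefix` are illustrative assumptions, and a plain dict lookup stands in for `library.get_type`.

```python
from typing import Callable, List

# Arity per symbol type, following the Node conventions.
ARITY = {"op": 2, "fn": 1, "lit": 0, "const": 0, "var": 0}


def is_complete_prefix(tokens: List[str], get_type: Callable[[str], str]) -> bool:
    """True if tokens (in prefix order) form exactly one complete expression."""
    open_slots = 1  # positions still waiting for a subtree
    for token in tokens:
        symbol_type = get_type(token)
        if symbol_type not in ARITY or open_slots == 0:
            return False  # unknown token (get_type returned "") or tokens after a finished tree
        open_slots += ARITY[symbol_type] - 1
    return open_slots == 0


types = {"+": "op", "sin": "fn", "X_0": "var", "C": "const"}


def get_type(symbol: str) -> str:
    # Stand-in for library.get_type: "" for unknown symbols.
    return types.get(symbol, "")


print(is_complete_prefix(["+", "sin", "X_0", "C"], get_type))  # True
print(is_complete_prefix(["+", "X_0"], get_type))              # False
print(is_complete_prefix(["+", "X_0", "foo"], get_type))       # False
```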

get_precedence

get_precedence(symbol: str) -> int

Return the precedence of a symbol.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.get_precedence("x")
0

Parameters:

Name Type Description Default
symbol str

Token to look up.

required

Returns:

Type Description
int

The precedence value if the symbol is in the library, otherwise -1.

Source code in SRToolkit/utils/symbol_library.py
def get_precedence(self, symbol: str) -> int:
    """
    Return the precedence of a symbol.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x")
        >>> library.get_precedence("x")
        0

    Args:
        symbol: Token to look up.

    Returns:
        The precedence value if the symbol is in the library, otherwise ``-1``.
    """
    if symbol in self.symbols:
        return self.symbols[symbol]["precedence"]
    else:
        return -1

get_np_fn

get_np_fn(symbol: str) -> str

Return the NumPy function string for a symbol.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.get_np_fn("x")
'x'

Parameters:

Name Type Description Default
symbol str

Token to look up.

required

Returns:

Type Description
str

The NumPy function string if the symbol is in the library, otherwise an empty string.

Source code in SRToolkit/utils/symbol_library.py
def get_np_fn(self, symbol: str) -> str:
    """
    Return the NumPy function string for a symbol.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x")
        >>> library.get_np_fn("x")
        'x'

    Args:
        symbol: Token to look up.

    Returns:
        The NumPy function string if the symbol is in the library, otherwise an empty string.
    """
    if symbol in self.symbols:
        return self.symbols[symbol]["np_fn"]
    else:
        return ""
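The np_fn strings are `str.format` templates whose `{}` slots receive the already-rendered operands. The toolkit's own compiler is not shown in this section, so treat the following as a rough illustration only: `to_np_code` and the tuple tree are hypothetical. It fills the templates bottom-up, producing a NumPy expression that can be evaluated against a data matrix `X`.

```python
from typing import Tuple

import numpy as np

# np_fn templates copied from default_symbols further down this page.
NP_FN = {"+": "{} + {}", "*": "{} * {}", "sin": "np.sin({})", "X_0": "X[:, 0]", "X_1": "X[:, 1]"}
OPS = {"+", "*"}


def to_np_code(node: Tuple) -> str:
    """Fill the {} placeholders bottom-up for a nested (symbol, *children) tuple."""
    symbol, *children = node
    parts = []
    for child in children:
        code = to_np_code(child)
        # Parenthesise operator subtrees inside operators; a precedence-aware
        # compiler could drop some of these brackets.
        if symbol in OPS and child[0] in OPS:
            code = f"({code})"
        parts.append(code)
    return NP_FN[symbol].format(*parts)


code = to_np_code(("*", ("+", ("X_0",), ("X_1",)), ("sin", ("X_0",))))
print(code)  # (X[:, 0] + X[:, 1]) * np.sin(X[:, 0])

X = np.array([[1.0, 2.0], [0.0, 3.0]])
print(eval(code, {"np": np, "X": X}))  # one value per row of X
```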

get_cython_id

get_cython_id(symbol: str) -> int

Return the Cython stack machine dispatch ID for a symbol.

Returns -1 for symbols without a C implementation (they will fall back to a Python callable during Cython evaluation).

Examples:

>>> library = SymbolLibrary.default_symbols()
>>> library.get_cython_id("+")
0
>>> library.get_cython_id("sin")
10
>>> library.get_cython_id("unknown_symbol")
-1

Parameters:

Name Type Description Default
symbol str

Token to look up.

required

Returns:

Type Description
int

Integer dispatch ID, or -1 if the symbol is not in the library or has no Cython implementation.

Source code in SRToolkit/utils/symbol_library.py
def get_cython_id(self, symbol: str) -> int:
    """
    Return the Cython stack machine dispatch ID for a symbol.

    Returns ``-1`` for symbols without a C implementation (they will fall
    back to a Python callable during Cython evaluation).

    Examples:
        >>> library = SymbolLibrary.default_symbols()
        >>> library.get_cython_id("+")
        0
        >>> library.get_cython_id("sin")
        10
        >>> library.get_cython_id("unknown_symbol")
        -1

    Args:
        symbol: Token to look up.

    Returns:
        Integer dispatch ID, or ``-1`` if the symbol is not in the library
        or has no Cython implementation.
    """
    if symbol in self.symbols:
        return self.symbols[symbol].get("cython_id", -1)
    return -1

get_latex_str

get_latex_str(symbol: str) -> str

Return the LaTeX template string for a symbol.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x", "test")
>>> library.get_latex_str("x")
'test'

Parameters:

Name Type Description Default
symbol str

Token to look up.

required

Returns:

Type Description
str

The LaTeX template string if the symbol is in the library, otherwise an empty string.

Source code in SRToolkit/utils/symbol_library.py
def get_latex_str(self, symbol: str) -> str:
    """
    Return the LaTeX template string for a symbol.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x", "test")
        >>> library.get_latex_str("x")
        'test'

    Args:
        symbol: Token to look up.

    Returns:
        The LaTeX template string if the symbol is in the library, otherwise an empty string.
    """
    if symbol in self.symbols:
        return self.symbols[symbol]["latex_str"]
    else:
        return ""
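The LaTeX templates follow the same `str.format` convention: `{}` is an operand slot and `{{` / `}}` are literal braces, which is why the default `/` template is written `\frac{{{}}}{{{}}}`. The small renderer below works over a nested tuple tree with templates copied from default_symbols. `to_latex` is hypothetical and makes two simplifications: leaves are emitted verbatim, and no precedence brackets are added.

```python
from typing import Tuple

# LaTeX templates copied from default_symbols further down this page.
LATEX = {
    "+": r"{} + {}",
    "/": r"\frac{{{}}}{{{}}}",
    "^": r"{}^{{{}}}",
    "sin": r"\sin {}",
    "sqrt": r"\sqrt {{{}}}",
}


def to_latex(node: Tuple) -> str:
    """Render a nested (symbol, *children) tuple; leaves are emitted verbatim."""
    symbol, *children = node
    if not children:
        return symbol
    return LATEX[symbol].format(*(to_latex(c) for c in children))


expr = ("/", ("sin", ("x",)), ("+", ("x",), ("^", ("y",), ("2",))))
print(to_latex(expr))  # \frac{\sin x}{x + y^{2}}
```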

get_symbols_of_type

get_symbols_of_type(symbol_type: str) -> List[str]

Return all symbols of a given type.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.add_symbol("y", "var", 0, "y")
>>> library.get_symbols_of_type("var")
['x', 'y']

Parameters:

Name Type Description Default
symbol_type str

Type to filter by. One of "op", "fn", "var", "const", "lit".

required

Returns:

Type Description
List[str]

List of token strings matching the requested type. Returns an empty list if no symbols match or if symbol_type is not recognised.

Source code in SRToolkit/utils/symbol_library.py
def get_symbols_of_type(self, symbol_type: str) -> List[str]:
    """
    Return all symbols of a given type.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x")
        >>> library.add_symbol("y", "var", 0, "y")
        >>> library.get_symbols_of_type("var")
        ['x', 'y']

    Args:
        symbol_type: Type to filter by. One of ``"op"``, ``"fn"``, ``"var"``,
            ``"const"``, ``"lit"``.

    Returns:
        List of token strings matching the requested type. Returns an empty list
        if no symbols match or if ``symbol_type`` is not recognised.
    """
    symbols = list()
    for symbol in self.symbols.keys():
        if self.get_type(symbol) == symbol_type:
            symbols.append(symbol)

    return symbols

symbols2index

symbols2index() -> Dict[str, int]

Return a mapping from each token to its index in insertion order.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.add_symbol("y", "var", 0, "y")
>>> print(library.symbols2index())
{'x': 0, 'y': 1}
>>> library.remove_symbol("x")
>>> print(library.symbols2index())
{'y': 0}

Returns:

Type Description
Dict[str, int]

Dict mapping each token string to its zero-based position in the library.

Source code in SRToolkit/utils/symbol_library.py
def symbols2index(self) -> Dict[str, int]:
    """
    Return a mapping from each token to its index in insertion order.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x")
        >>> library.add_symbol("y", "var", 0, "y")
        >>> print(library.symbols2index())
        {'x': 0, 'y': 1}
        >>> library.remove_symbol("x")
        >>> print(library.symbols2index())
        {'y': 0}

    Returns:
        Dict mapping each token string to its zero-based position in the library.
    """
    return {s: i for i, s in enumerate(self.symbols.keys())}
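A typical use of this mapping is turning token sequences into integer IDs, for example as input to a neural model. The indices follow insertion order and shift when a symbol is removed (as the example above shows), so it is safest to store the mapping together with any encoded data. Below is a minimal NumPy sketch. `symbols2index` here is a standalone copy of the method's one-liner, and the token sequence is hypothetical.

```python
from typing import Dict, Iterable

import numpy as np


def symbols2index(symbols: Iterable[str]) -> Dict[str, int]:
    # Same comprehension as the method above; a list stands in for library.symbols.keys().
    return {s: i for i, s in enumerate(symbols)}


vocab = symbols2index(["+", "*", "sin", "C", "X_0", "X_1"])
tokens = ["X_0", "+", "C", "*", "X_1"]  # hypothetical token sequence

ids = np.array([vocab[t] for t in tokens])
one_hot = np.eye(len(vocab), dtype=np.float32)[ids]  # shape (len(tokens), len(vocab))
print(ids)            # [4 0 3 1 5]
print(one_hot.shape)  # (5, 6)
```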

from_symbol_list staticmethod

from_symbol_list(symbols: List[str], num_variables: int = 25) -> SymbolLibrary

Create a SymbolLibrary containing only the specified subset of default symbols.

The supported token names are those defined in default_symbols.

Examples:

>>> library = SymbolLibrary.from_symbol_list(["+", "*", "C"], num_variables=2)
>>> len(library.symbols)
5

Parameters:

Name Type Description Default
symbols List[str]

Token strings to include. Symbols not in the default set are silently ignored.

required
num_variables int

Number of variable tokens (X_0 through X_{num_variables-1}). Default is 25.

25

Returns:

Type Description
SymbolLibrary

A SymbolLibrary restricted to the requested symbols and variables.

Source code in SRToolkit/utils/symbol_library.py
@staticmethod
def from_symbol_list(symbols: List[str], num_variables: int = 25) -> "SymbolLibrary":
    """
    Create a [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] containing only the specified subset of default symbols.

    The supported token names are those defined in [default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols].

    Examples:
        >>> library = SymbolLibrary.from_symbol_list(["+", "*", "C"], num_variables=2)
        >>> len(library.symbols)
        5

    Args:
        symbols: Token strings to include. Symbols not in the default set are silently ignored.
        num_variables: Number of variable tokens (``X_0`` through ``X_{num_variables-1}``).
            Default is ``25``.

    Returns:
        A [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] restricted to the requested symbols and variables.
    """
    variables = [f"X_{i}" for i in range(num_variables)]
    symbols = symbols + variables

    sl = SymbolLibrary.default_symbols(num_variables)

    all_symbols = list(sl.symbols.keys())
    for symbol in all_symbols:
        if symbol not in symbols:
            sl.remove_symbol(symbol)

    return sl
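from_symbol_list starts from default_symbols and deletes whatever was not requested. The resulting library therefore keeps the default insertion order, not the order of the symbols argument, and that order in turn fixes the indices returned by symbols2index. The standalone sketch below reproduces that filtering. `filter_default` is hypothetical, and the default order is abbreviated.

```python
from typing import List


def filter_default(default_order: List[str], requested: List[str], num_variables: int) -> List[str]:
    # Mirrors from_symbol_list: keep the requested tokens plus X_0..X_{n-1}, walking
    # the default library in its own order and silently dropping unknown names.
    keep = set(requested) | {f"X_{i}" for i in range(num_variables)}
    return [s for s in default_order if s in keep]


# Abbreviated stand-in for the insertion order used by default_symbols.
default_order = ["+", "-", "*", "/", "^", "u-", "sqrt", "sin", "cos", "pi", "e", "C", "X_0", "X_1"]

result = filter_default(default_order, ["C", "*", "+", "unknown"], num_variables=2)
print(result)  # ['+', '*', 'C', 'X_0', 'X_1']
```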

default_symbols staticmethod

default_symbols(num_variables: int = 25) -> SymbolLibrary

Return a SymbolLibrary pre-populated with standard mathematical symbols.

Supported tokens:

  • Operators ("op"): +, -, *, /, ^
  • Functions ("fn"): u-, sqrt, sin, cos, exp, tan, arcsin, arccos, arctan, sinh, cosh, tanh, floor, ceil, ln, log, ^-1, ^2, ^3, ^4, ^5
  • Literals ("lit"): pi, e
  • Free constant ("const"): C
  • Variables ("var"): X_0 through X_{num_variables-1}, mapped to columns of the input array in order.

Examples:

>>> library = SymbolLibrary.default_symbols()
>>> len(library)
54

Parameters:

Name Type Description Default
num_variables int

Number of variable tokens to include. Default is 25.

25

Returns:

Type Description
SymbolLibrary

A SymbolLibrary populated with the symbols listed above.

Source code in SRToolkit/utils/symbol_library.py
@staticmethod
def default_symbols(num_variables: int = 25) -> "SymbolLibrary":
    """
    Return a [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] pre-populated with standard mathematical symbols.

    Supported tokens:

    - **Operators** (``"op"``): ``+``, ``-``, ``*``, ``/``, ``^``
    - **Functions** (``"fn"``): ``u-``, ``sqrt``, ``sin``, ``cos``, ``exp``, ``tan``,
      ``arcsin``, ``arccos``, ``arctan``, ``sinh``, ``cosh``, ``tanh``, ``floor``,
      ``ceil``, ``ln``, ``log``, ``^-1``, ``^2``, ``^3``, ``^4``, ``^5``
    - **Literals** (``"lit"``): ``pi``, ``e``
    - **Free constant** (``"const"``): ``C``
    - **Variables** (``"var"``): ``X_0`` through ``X_{num_variables-1}``,
      mapped to columns of the input array in order.

    Examples:
        >>> library = SymbolLibrary.default_symbols()
        >>> len(library)
        54

    Args:
        num_variables: Number of variable tokens to include. Default is ``25``.

    Returns:
        A [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] populated with the symbols listed above.
    """
    sl = SymbolLibrary()
    sl.add_symbol(
        "+",
        symbol_type="op",
        precedence=OP_ADDITIVE,
        np_fn="{} + {}",
        latex_str=r"{} + {}",
        cython_id=0,
    )
    sl.add_symbol(
        "-",
        symbol_type="op",
        precedence=OP_ADDITIVE,
        np_fn="{} - {}",
        latex_str=r"{} - {}",
        cython_id=1,
    )
    sl.add_symbol(
        "*",
        symbol_type="op",
        precedence=OP_MULTIPLICATIVE,
        np_fn="{} * {}",
        latex_str=r"{} \cdot {}",
        cython_id=2,
    )
    sl.add_symbol(
        "/",
        symbol_type="op",
        precedence=OP_MULTIPLICATIVE,
        np_fn="{} / {}",
        latex_str=r"\frac{{{}}}{{{}}}",
        cython_id=3,
    )
    sl.add_symbol(
        "^",
        symbol_type="op",
        precedence=OP_POWER,
        np_fn="np.power({},{})",
        latex_str=r"{}^{{{}}}",
        cython_id=4,
    )
    sl.add_symbol("u-", symbol_type="fn", precedence=FN_PREFIX, np_fn="-{}", latex_str=r"- {}", cython_id=25)
    sl.add_symbol(
        "sqrt",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.sqrt({})",
        latex_str=r"\sqrt {{{}}}",
        cython_id=14,
    )
    sl.add_symbol(
        "sin",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.sin({})",
        latex_str=r"\sin {}",
        cython_id=10,
    )
    sl.add_symbol(
        "cos",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.cos({})",
        latex_str=r"\cos {}",
        cython_id=11,
    )
    sl.add_symbol(
        "exp",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.exp({})",
        latex_str=r"e^{{{}}}",
        cython_id=13,
    )
    sl.add_symbol(
        "tan",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.tan({})",
        latex_str=r"\tan {}",
        cython_id=12,
    )
    sl.add_symbol(
        "arcsin",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.arcsin({})",
        latex_str=r"\arcsin {}",
        cython_id=17,
    )
    sl.add_symbol(
        "arccos",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.arccos({})",
        latex_str=r"\arccos {}",
        cython_id=18,
    )
    sl.add_symbol(
        "arctan",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.arctan({})",
        latex_str=r"\arctan {}",
        cython_id=19,
    )
    sl.add_symbol(
        "sinh",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.sinh({})",
        latex_str=r"\sinh {}",
        cython_id=20,
    )
    sl.add_symbol(
        "cosh",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.cosh({})",
        latex_str=r"\cosh {}",
        cython_id=21,
    )
    sl.add_symbol(
        "tanh",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.tanh({})",
        latex_str=r"\tanh {}",
        cython_id=22,
    )
    sl.add_symbol(
        "floor",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.floor({})",
        latex_str=r"\lfloor {} \rfloor",
        cython_id=23,
    )
    sl.add_symbol(
        "ceil",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.ceil({})",
        latex_str=r"\lceil {} \rceil",
        cython_id=24,
    )
    sl.add_symbol(
        "ln",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.log({})",
        latex_str=r"\ln {}",
        cython_id=15,
    )
    sl.add_symbol(
        "log",
        symbol_type="fn",
        precedence=FN_PREFIX,
        np_fn="np.log10({})",
        latex_str=r"\log_{{10}} {}",
        cython_id=16,
    )
    sl.add_symbol(
        "^-1",
        symbol_type="fn",
        precedence=FN_POSTFIX,
        np_fn="1/{}",
        latex_str=r"{}^{{-1}}",
        cython_id=26,
    )
    sl.add_symbol("^2", symbol_type="fn", precedence=FN_POSTFIX, np_fn="{}**2", latex_str=r"{}^2", cython_id=27)
    sl.add_symbol("^3", symbol_type="fn", precedence=FN_POSTFIX, np_fn="{}**3", latex_str=r"{}^3", cython_id=28)
    sl.add_symbol("^4", symbol_type="fn", precedence=FN_POSTFIX, np_fn="{}**4", latex_str=r"{}^4", cython_id=29)
    sl.add_symbol("^5", symbol_type="fn", precedence=FN_POSTFIX, np_fn="{}**5", latex_str=r"{}^5", cython_id=30)
    sl.add_symbol(
        "pi",
        symbol_type="lit",
        precedence=LEAF,
        np_fn="np.pi",
        latex_str=r"\pi",
    )
    sl.add_symbol(
        "e",
        symbol_type="lit",
        precedence=LEAF,
        np_fn="np.e",
        latex_str=r"e",
    )
    sl.add_symbol(
        "C",
        symbol_type="const",
        precedence=LEAF,
        np_fn="np.full(X.shape[0], C[{}])",
        latex_str=r"C_{{{}}}",
    )

    if num_variables > 0:
        for i in range(num_variables):
            sl.add_symbol(f"X_{i}", "var", LEAF, "X[:, {}]".format(i), "X_{{{}}}".format(i))

    return sl
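The default C template wraps `C[{}]` in `np.full(X.shape[0], ...)`. A plausible reason is that an expression made only of a constant then still yields one value per row of X instead of a bare scalar. The sketch below illustrates this using np_fn strings copied from the source above, with `{}` filled by `str.format`. The data values are arbitrary.

```python
import numpy as np

X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
C = np.array([0.5])
env = {"np": np, "X": X, "C": C}

# np_fn strings from default_symbols, with {} filled in by str.format.
constant = eval("np.full(X.shape[0], C[{}])".format(0), env)
bare = eval("C[{}]".format(0), env)
power = eval("np.power({},{})".format("X[:, 0]", "X[:, 1]"), env)

print(constant)        # [0.5 0.5 0.5]
print(np.ndim(bare))   # 0
print(power.tolist())  # [1.0, 81.0, 15625.0]
```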

get_active staticmethod

get_active() -> SymbolLibrary

Return the currently active SymbolLibrary.

Checks, in order: (1) the context manager stack, (2) the module-level default set via set_default.

Returns:

Type Description
SymbolLibrary

The active SymbolLibrary instance.

Raises:

Type Description
RuntimeError

If no library is active and no default has been set.

Source code in SRToolkit/utils/symbol_library.py
@staticmethod
def get_active() -> "SymbolLibrary":
    """
    Return the currently active [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary].

    Checks, in order: (1) the context manager stack, (2) the module-level default set via
    [set_default][SRToolkit.utils.symbol_library.SymbolLibrary.set_default].

    Returns:
        The active [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] instance.

    Raises:
        RuntimeError: If no library is active and no default has been set.
    """
    try:
        return _active_sl.get()
    except LookupError:
        if _default_sl is not None:
            return _default_sl
        raise RuntimeError(
            "No active SymbolLibrary. Either pass one explicitly, use "
            "'with SymbolLibrary(...) as sl:', or call SymbolLibrary.set_default(sl)."
        )

get_or_default staticmethod

get_or_default() -> SymbolLibrary

Return the active SymbolLibrary, falling back to default_symbols when nothing is active.

Checks, in order: (1) the context manager stack, (2) the module-level default set via set_default, (3) a freshly constructed default library.

Returns:

Type Description
SymbolLibrary

The active or default SymbolLibrary instance.

Source code in SRToolkit/utils/symbol_library.py
@staticmethod
def get_or_default() -> "SymbolLibrary":
    """
    Return the active [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary], falling back to
    [default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols] when nothing is active.

    Checks, in order: (1) the context manager stack, (2) the module-level default set via
    [set_default][SRToolkit.utils.symbol_library.SymbolLibrary.set_default], (3) a freshly
    constructed default library.

    Returns:
        The active or default [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] instance.
    """
    try:
        return _active_sl.get()
    except LookupError:
        if _default_sl is not None:
            return _default_sl
        return SymbolLibrary.default_symbols()

set_default staticmethod

set_default(sl: Optional[SymbolLibrary]) -> None

Set (or clear) a module-level default SymbolLibrary.

The default is used as a fallback by get_active and get_or_default when no context manager is active. It is module-global (not per-thread or per-task) and intended for scripts and notebooks where a single library is used throughout a session.

Pass None to clear the default.

Parameters:

Name Type Description Default
sl Optional[SymbolLibrary]

Library to set as the module-level default, or None to clear it.

required
Source code in SRToolkit/utils/symbol_library.py
@staticmethod
def set_default(sl: Optional["SymbolLibrary"]) -> None:
    """
    Set (or clear) a module-level default [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary].

    The default is used as a fallback by [get_active][SRToolkit.utils.symbol_library.SymbolLibrary.get_active]
    and [get_or_default][SRToolkit.utils.symbol_library.SymbolLibrary.get_or_default] when no context manager
    is active. It is module-global (not per-thread or per-task) and intended for scripts and
    notebooks where a single library is used throughout a session.

    Pass ``None`` to clear the default.

    Args:
        sl: Library to set as the module-level default, or ``None`` to clear it.
    """
    global _default_sl
    _default_sl = sl

to_dict

to_dict() -> dict

Serialize the library to a JSON-safe dictionary.

Examples:

>>> library = SymbolLibrary.from_symbol_list(["+"], num_variables=1)
>>> d = library.to_dict()
>>> d["format_version"]
1
>>> d["num_variables"]
1

Returns:

Type Description
dict

A dictionary suitable for passing to from_dict.

Source code in SRToolkit/utils/symbol_library.py
def to_dict(self) -> dict:
    """
    Serialize the library to a JSON-safe dictionary.

    Examples:
        >>> library = SymbolLibrary.from_symbol_list(["+"], num_variables=1)
        >>> d = library.to_dict()
        >>> d["format_version"]
        1
        >>> d["num_variables"]
        1

    Returns:
        A dictionary suitable for passing to [from_dict][SRToolkit.utils.symbol_library.SymbolLibrary.from_dict].
    """
    return {
        "format_version": 1,
        "type": "SymbolLibrary",
        "symbols": self.symbols,
        "preamble": self.preamble,
        "num_variables": self.num_variables,
    }

from_dict staticmethod

from_dict(d: dict) -> SymbolLibrary

Reconstruct a SymbolLibrary from a dictionary produced by to_dict.

Parameters:

Name Type Description Default
d dict

Dictionary representation of the library, as produced by to_dict.

required

Returns:

Type Description
SymbolLibrary

The reconstructed SymbolLibrary.

Raises:

Type Description
ValueError

If d["format_version"] is present and is not 1. A dictionary without a format_version key is accepted as version 1.

Source code in SRToolkit/utils/symbol_library.py
@staticmethod
def from_dict(d: dict) -> "SymbolLibrary":
    """
    Reconstruct a [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] from a dictionary produced by [to_dict][SRToolkit.utils.symbol_library.SymbolLibrary.to_dict].

    Args:
        d: Dictionary representation of the library, as produced by
            [to_dict][SRToolkit.utils.symbol_library.SymbolLibrary.to_dict].

    Returns:
        The reconstructed [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary].

    Raises:
        ValueError: If ``d["format_version"]`` is present and is not ``1``; a missing key is treated as version ``1``.
    """
    if d.get("format_version", 1) != 1:
        raise ValueError(
            f"[SymbolLibrary.from_dict] Unsupported format_version: {d.get('format_version')!r}. Expected 1."
        )
    sl = SymbolLibrary()
    sl.symbols = d["symbols"]
    sl.preamble = d["preamble"]
    sl.num_variables = d["num_variables"]
    return sl
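
The sketch below round-trips a hypothetical MiniLibrary through JSON text. MiniLibrary serializes the same fields that to_dict writes, and the sketch reproduces the version check from from_dict, including its lenient handling of dictionaries that lack a format_version key. The field contents are made up for illustration.

```python
import json


class MiniLibrary:
    """Hypothetical stand-in with the same serialized fields as SymbolLibrary."""

    def __init__(self, symbols=None, preamble=None, num_variables=0):
        self.symbols = symbols or {}
        self.preamble = preamble or []
        self.num_variables = num_variables

    def to_dict(self) -> dict:
        return {
            "format_version": 1,
            "type": "MiniLibrary",
            "symbols": self.symbols,
            "preamble": self.preamble,
            "num_variables": self.num_variables,
        }

    @staticmethod
    def from_dict(d: dict) -> "MiniLibrary":
        # A missing key counts as version 1, as with d.get("format_version", 1) above.
        if d.get("format_version", 1) != 1:
            raise ValueError(f"Unsupported format_version: {d.get('format_version')!r}")
        return MiniLibrary(d["symbols"], d["preamble"], d["num_variables"])


lib = MiniLibrary({"X_0": {"type": "var"}}, [], 1)
text = json.dumps(lib.to_dict())  # plain JSON text
restored = MiniLibrary.from_dict(json.loads(text))
print(restored.num_variables, list(restored.symbols))  # 1 ['X_0']
```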

__len__

__len__() -> int

Return the number of symbols currently in the library.

Examples:

>>> library = SymbolLibrary.default_symbols(5)
>>> len(library)
34
>>> library.add_symbol("a", "lit", 5, "a", "a")
>>> len(library)
35

Returns:

Type Description
int

Number of tokens registered in the library.

Source code in SRToolkit/utils/symbol_library.py
def __len__(self) -> int:
    """
    Return the number of symbols currently in the library.

    Examples:
        >>> library = SymbolLibrary.default_symbols(5)
        >>> len(library)
        34
        >>> library.add_symbol("a", "lit", 5, "a", "a")
        >>> len(library)
        35

    Returns:
        Number of tokens registered in the library.
    """
    return len(self.symbols)

__str__

__str__() -> str

Return a comma-separated string of all registered token strings.

Examples:

>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x", "x")
>>> str(library)
'x'
>>> library.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
>>> str(library)
'x, sin'

Returns:

Type Description
str

All token names joined by ", ", in insertion order.

Source code in SRToolkit/utils/symbol_library.py
def __str__(self) -> str:
    r"""
    Return a comma-separated string of all registered token strings.

    Examples:
        >>> library = SymbolLibrary()
        >>> library.add_symbol("x", "var", 0, "x", "x")
        >>> str(library)
        'x'
        >>> library.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
        >>> str(library)
        'x, sin'

    Returns:
        All token names joined by ``", "``, in insertion order.
    """
    return ", ".join(self.symbols.keys())

__copy__

__copy__() -> SymbolLibrary

Return a copy of the library with independent copies of all attributes.

Examples:

>>> old_symbols = SymbolLibrary()
>>> old_symbols.add_symbol("x", "var", 0, "x", "x")
>>> print(old_symbols)
x
>>> new_symbols = copy.copy(old_symbols)
>>> new_symbols.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
>>> print(old_symbols)
x
>>> print(new_symbols)
x, sin

Returns:

Type Description
SymbolLibrary

A new SymbolLibrary instance with deep-copied symbols and preamble.

Source code in SRToolkit/utils/symbol_library.py
def __copy__(self) -> "SymbolLibrary":
    r"""
    Return a copy of the library with independent copies of all attributes.

    Examples:
        >>> old_symbols = SymbolLibrary()
        >>> old_symbols.add_symbol("x", "var", 0, "x", "x")
        >>> print(old_symbols)
        x
        >>> new_symbols = copy.copy(old_symbols)
        >>> new_symbols.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
        >>> print(old_symbols)
        x
        >>> print(new_symbols)
        x, sin

    Returns:
        A new [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] instance with deep-copied symbols and preamble.
    """
    sl = SymbolLibrary()
    sl.symbols = copy.deepcopy(self.symbols)
    sl.preamble = copy.deepcopy(self.preamble)
    sl.num_variables = self.num_variables
    return sl
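
This example shows why __copy__ deep-copies symbols, assuming that symbols maps each token to a dict of its properties. A shallow copy would share those inner dicts, so an edit made through one library would leak into the other. The standalone sketch uses a made-up token entry to show the difference.

```python
import copy

# Hypothetical symbol table: token -> property dict.
symbols = {"x": {"type": "var"}}

shallow = copy.copy(symbols)  # new outer dict, but the inner dicts are shared
deep = copy.deepcopy(symbols)  # new outer dict and new inner dicts

symbols["x"]["type"] = "lit"
print(shallow["x"]["type"])  # lit  (changed along with the original)
print(deep["x"]["type"])  # var  (unaffected)
```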

EstimationSettings

Bases: TypedDict

Shared settings for parameter estimation and BED evaluation.

Passed as **kwargs to SR_dataset, SR_evaluator, and ParameterEstimator. All fields are optional.

Examples:

>>> settings: EstimationSettings = {"method": "L-BFGS-B", "max_iter": 200}
>>> settings.get("method")
'L-BFGS-B'
>>> settings.get("tol", 1e-6)
1e-06

Attributes:

Name Type Description
method str

Optimization algorithm for parameter fitting. Default: "L-BFGS-B".

tol float

Termination tolerance for the optimizer. Default: 1e-6.

gtol float

Gradient-norm termination tolerance. Default: 1e-3.

max_iter int

Maximum optimizer iterations. Default: 100.

constant_bounds Tuple[float, float]

(lower, upper) bounds for sampled constant values. Default: (-5, 5).

initialization str

Constant initialization strategy — "random" samples uniformly within constant_bounds; "mean" sets all constants to the midpoint. Default: "random".

max_constants int

Maximum number of free constants permitted in a single expression. Expressions exceeding this limit score NaN. Default: 8.

max_expr_length int

Maximum expression length in tokens. -1 disables the limit. Default: -1.

backend str

Evaluation backend used by ParameterEstimator. "stack" uses the Cython-backed postfix stack machine (default); "codegen" generates Python/NumPy source via exec(); "stack_py" uses the pure-Python stack machine (faster than "stack" when the symbol library contains many custom symbols).

num_points_sampled int

Number of domain points used when evaluating expression behavior for BED. -1 uses all points in X. Default: 64.

bed_X Optional[ndarray]

Fixed evaluation points for BED. If None, points are sampled from domain_bounds or selected randomly from X. Default: None.

num_consts_sampled int

Number of constant vectors sampled per expression for BED. Default: 32.

domain_bounds Optional[List[Tuple[float, float]]]

Per-variable (lower, upper) bounds used to sample bed_X when it is None. Default: None.
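
Because every field of EstimationSettings is optional (for example via total=False), a settings dict can hold any subset of keys, and the consumer supplies defaults for the rest. This page does not show where SR_dataset, SR_evaluator, or ParameterEstimator merge in those defaults. The sketch below is a hypothetical pattern that uses a four-field subset and the defaults documented above.

```python
from typing import Tuple, TypedDict


class Settings(TypedDict, total=False):
    """Hypothetical four-field subset of EstimationSettings; every key is optional."""

    method: str
    tol: float
    max_iter: int
    constant_bounds: Tuple[float, float]


# Defaults as documented in the attribute list above.
DEFAULTS: Settings = {
    "method": "L-BFGS-B",
    "tol": 1e-6,
    "max_iter": 100,
    "constant_bounds": (-5.0, 5.0),
}


def resolve(overrides: Settings) -> Settings:
    """Hypothetical helper: keys present in overrides win over the defaults."""
    merged = DEFAULTS.copy()
    merged.update(overrides)
    return merged


settings: Settings = {"max_iter": 200}
resolved = resolve(settings)
print(resolved["method"], resolved["max_iter"])  # L-BFGS-B 200
```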

EvalResult dataclass

EvalResult(min_error: float, best_expr: str, num_evaluated: int, evaluation_calls: int, top_models: List[ModelResult], all_models: List[ModelResult], approach_name: str, success: bool, dataset_name: Optional[str] = None, metadata: Optional[dict] = None, augmentations: Dict[str, Dict[str, Any]] = dict())

Result for a single SR experiment, as returned by SR_results[i].

Examples:

>>> model = ModelResult(expr=["X_0"], error=0.05)
>>> result = EvalResult(
...     min_error=0.05,
...     best_expr="X_0",
...     num_evaluated=500,
...     evaluation_calls=612,
...     top_models=[model],
...     all_models=[model],
...     approach_name="MyApproach",
...     success=True,
... )
>>> result.min_error
0.05
>>> result.success
True
>>> result.dataset_name is None
True

Attributes:

Name Type Description
min_error float

Lowest error achieved across all evaluated expressions.

best_expr str

String representation of the best expression found.

num_evaluated int

Number of unique expressions evaluated.

evaluation_calls int

Number of times evaluate_expr was called (includes cache hits).

top_models List[ModelResult]

Top-k models sorted by error.

all_models List[ModelResult]

All evaluated models sorted by error.

approach_name str

Name of the SR approach, or empty string if not provided.

success bool

Whether min_error is below the configured success_threshold.

dataset_name Optional[str]

Name of the dataset. None if not provided.

metadata Optional[dict]

Arbitrary metadata dict associated with the dataset. None if not provided.

augmentations Dict[str, Dict[str, Any]]

Per-augmenter data keyed by augmenter name. Populated by ResultAugmenter subclasses via add_augmentation.

add_augmentation

add_augmentation(name: str, data: Dict[str, Any], aug_type: str) -> None

Attach augmentation data produced by a ResultAugmenter to this result.

If name is already present in augmentations, a numeric suffix is appended (name_1, name_2, …) to avoid overwriting existing data.

Examples:

>>> model = ModelResult(expr=["X_0"], error=0.05)
>>> result = EvalResult(
...     min_error=0.05, best_expr="X_0", num_evaluated=10,
...     evaluation_calls=10, top_models=[model], all_models=[model],
...     approach_name="MyApproach", success=True,
... )
>>> result.add_augmentation("complexity", {"value": 3}, "ComplexityAugmenter")
>>> result.augmentations["complexity"]["value"]
3
>>> result.add_augmentation("complexity", {"value": 5}, "ComplexityAugmenter")
>>> "complexity_1" in result.augmentations
True

Parameters:

Name Type Description Default
name str

Key under which the augmentation is stored in augmentations. A suffix is added automatically if the key already exists.

required
data Dict[str, Any]

Arbitrary dict of augmentation data. Any existing "_type" key is overwritten with aug_type. The dict is modified in place and stored by reference, so pass a copy if the caller reuses it.

required
aug_type str

Augmenter class name, stored as data["_type"].

required
Source code in SRToolkit/utils/types.py
def add_augmentation(self, name: str, data: Dict[str, Any], aug_type: str) -> None:
    """
    Attach augmentation data produced by a [ResultAugmenter][SRToolkit.evaluation.sr_evaluator.ResultAugmenter] to this result.

    If ``name`` is already present in ``augmentations``, a numeric suffix is
    appended (``name_1``, ``name_2``, …) to avoid overwriting existing data.

    Examples:
        >>> model = ModelResult(expr=["X_0"], error=0.05)
        >>> result = EvalResult(
        ...     min_error=0.05, best_expr="X_0", num_evaluated=10,
        ...     evaluation_calls=10, top_models=[model], all_models=[model],
        ...     approach_name="MyApproach", success=True,
        ... )
        >>> result.add_augmentation("complexity", {"value": 3}, "ComplexityAugmenter")
        >>> result.augmentations["complexity"]["value"]
        3
        >>> result.add_augmentation("complexity", {"value": 5}, "ComplexityAugmenter")
        >>> "complexity_1" in result.augmentations
        True

    Args:
        name: Key under which the augmentation is stored in ``augmentations``.
            A suffix is added automatically if the key already exists.
        data: Arbitrary dict of augmentation data. Any existing ``"_type"`` key
            will be overwritten with ``aug_type``.
        aug_type: Augmenter class name, stored as ``data["_type"]``.
    """
    resolved = name
    counter = 1
    while resolved in self.augmentations:
        resolved = f"{name}_{counter}"
        counter += 1
    data["_type"] = aug_type
    self.augmentations[resolved] = data
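
The suffix-resolution loop in add_augmentation is short enough to run on its own. The function below copies it and replaces the result object with a plain dict named store, which is a hypothetical simplification. It also makes one side effect visible: data is modified in place and stored by reference, so reusing one payload dict across several calls produces entries that all alias the same object.

```python
from typing import Any, Dict


def add_augmentation(store: Dict[str, Dict[str, Any]], name: str, data: Dict[str, Any], aug_type: str) -> str:
    """Standalone copy of the suffix-resolution loop above; returns the key that was used."""
    resolved = name
    counter = 1
    while resolved in store:
        resolved = f"{name}_{counter}"
        counter += 1
    data["_type"] = aug_type  # mutates the caller's dict
    store[resolved] = data  # stored by reference, not copied
    return resolved


store: Dict[str, Dict[str, Any]] = {}
payload = {"value": 3}
keys = [add_augmentation(store, "complexity", payload, "ComplexityAugmenter") for _ in range(3)]
print(keys)  # ['complexity', 'complexity_1', 'complexity_2']
print(payload)  # {'value': 3, '_type': 'ComplexityAugmenter'}
print(store["complexity"] is store["complexity_2"])  # True: all entries alias one dict
```

Passing dict(payload) on each call would keep the stored entries independent.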

to_dict

to_dict() -> dict

Serialize this evaluation result to a JSON-safe dictionary.

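NumPy arrays and scalars within nested ModelResult entries and within augmentations are converted to native Python types so the result can be passed directly to json.dump. The metadata dict is included as-is, so it must already contain only JSON-serializable values.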

Examples:

>>> model = ModelResult(expr=["X_0"], error=0.05)
>>> result = EvalResult(
...     min_error=0.05, best_expr="X_0", num_evaluated=10,
...     evaluation_calls=10, top_models=[model], all_models=[model],
...     approach_name="MyApproach", success=True,
... )
>>> d = result.to_dict()
>>> d["min_error"]
0.05
>>> d["approach_name"]
'MyApproach'
>>> len(d["top_models"])
1

Returns:

Type Description
dict

A JSON-safe dictionary suitable for passing to from_dict.

Source code in SRToolkit/utils/types.py
def to_dict(self) -> dict:
    """
    Serialize this evaluation result to a JSON-safe dictionary.

    NumPy arrays and scalars within nested [ModelResult][SRToolkit.utils.types.ModelResult] entries are
    converted to native Python types so the result can be passed directly
    to ``json.dump``.

    Examples:
        >>> model = ModelResult(expr=["X_0"], error=0.05)
        >>> result = EvalResult(
        ...     min_error=0.05, best_expr="X_0", num_evaluated=10,
        ...     evaluation_calls=10, top_models=[model], all_models=[model],
        ...     approach_name="MyApproach", success=True,
        ... )
        >>> d = result.to_dict()
        >>> d["min_error"]
        0.05
        >>> d["approach_name"]
        'MyApproach'
        >>> len(d["top_models"])
        1

    Returns:
        A JSON-safe dictionary suitable for passing to [from_dict][SRToolkit.utils.types.EvalResult.from_dict].
    """
    return {
        "min_error": float(self.min_error),
        "best_expr": self.best_expr,
        "num_evaluated": int(self.num_evaluated),
        "evaluation_calls": int(self.evaluation_calls),
        "top_models": [m.to_dict() for m in self.top_models],
        "all_models": [m.to_dict() for m in self.all_models],
        "approach_name": self.approach_name,
        "success": bool(self.success),
        "dataset_name": self.dataset_name,
        "metadata": self.metadata,
        "augmentations": _to_json_safe(self.augmentations),
    }
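
to_dict sends augmentations through the private _to_json_safe helper and sends model entries through ModelResult.to_dict, but it copies metadata unchanged. A metadata dict that holds NumPy values would therefore make json.dump fail. The recursive converter below is a hypothetical way to pre-clean such a dict. It is not the library's _to_json_safe, whose exact behaviour is not shown here. It relies on the fact that both NumPy arrays and NumPy scalars provide tolist(). The standard-library array.array, which also has tolist(), stands in for NumPy so the sketch needs no third-party imports.

```python
import array
import json
from typing import Any


def to_json_safe(obj: Any) -> Any:
    """Hypothetical recursive converter; not the library's private _to_json_safe."""
    if isinstance(obj, dict):
        return {str(k): to_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_safe(v) for v in obj]
    if hasattr(obj, "tolist"):
        # NumPy arrays and scalars (and array.array) convert themselves to native Python values.
        return obj.tolist()
    return obj


metadata = {"bounds": (0.0, 1.0), "samples": array.array("d", [1.5, 2.5])}
print(json.dumps(to_json_safe(metadata)))  # {"bounds": [0.0, 1.0], "samples": [1.5, 2.5]}
```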

from_dict staticmethod

from_dict(data: dict) -> EvalResult

Reconstruct an EvalResult from a dictionary produced by to_dict.

Examples:

>>> model = ModelResult(expr=["X_0"], error=0.05)
>>> result = EvalResult(
...     min_error=0.05, best_expr="X_0", num_evaluated=10,
...     evaluation_calls=10, top_models=[model], all_models=[model],
...     approach_name="MyApproach", success=True,
... )
>>> result2 = EvalResult.from_dict(result.to_dict())
>>> result2.min_error
0.05
>>> result2.best_expr
'X_0'
>>> len(result2.top_models)
1

Parameters:

Name Type Description Default
data dict

Dictionary representation of an EvalResult, as produced by to_dict.

required

Returns:

Type Description
EvalResult

The reconstructed EvalResult.

Source code in SRToolkit/utils/types.py
@staticmethod
def from_dict(data: dict) -> "EvalResult":
    """
    Reconstruct an [EvalResult][SRToolkit.utils.types.EvalResult] from a dictionary produced by [to_dict][SRToolkit.utils.types.EvalResult.to_dict].

    Examples:
        >>> model = ModelResult(expr=["X_0"], error=0.05)
        >>> result = EvalResult(
        ...     min_error=0.05, best_expr="X_0", num_evaluated=10,
        ...     evaluation_calls=10, top_models=[model], all_models=[model],
        ...     approach_name="MyApproach", success=True,
        ... )
        >>> result2 = EvalResult.from_dict(result.to_dict())
        >>> result2.min_error
        0.05
        >>> result2.best_expr
        'X_0'
        >>> len(result2.top_models)
        1

    Args:
        data: Dictionary representation of an [EvalResult][SRToolkit.utils.types.EvalResult], as produced
            by [to_dict][SRToolkit.utils.types.EvalResult.to_dict].

    Returns:
        The reconstructed [EvalResult][SRToolkit.utils.types.EvalResult].
    """
    return EvalResult(
        min_error=data["min_error"],
        best_expr=data["best_expr"],
        num_evaluated=data["num_evaluated"],
        evaluation_calls=data["evaluation_calls"],
        top_models=[ModelResult.from_dict(m) for m in data["top_models"]],
        all_models=[ModelResult.from_dict(m) for m in data["all_models"]],
        approach_name=data["approach_name"],
        success=data["success"],
        dataset_name=data.get("dataset_name"),
        metadata=data.get("metadata"),
        augmentations=_from_json_safe(data["augmentations"]),
    )

ModelResult dataclass

ModelResult(expr: List[str], error: float, parameters: Optional[ndarray] = None, augmentations: Dict[str, Dict[str, Any]] = dict())

A single model entry in EvalResult.top_models and EvalResult.all_models.

Examples:

>>> result = ModelResult(expr=["C", "*", "X_0"], error=0.42)
>>> result.expr
['C', '*', 'X_0']
>>> result.error
0.42
>>> result.parameters is None
True

Attributes:

Name Type Description
expr List[str]

Token list representing the expression, e.g. ["C", "*", "X_0"].

error float

Numeric error under the ranking function (RMSE or BED).

parameters Optional[ndarray]

Fitted constant values. Present for RMSE ranking only, None otherwise.

augmentations Dict[str, Dict[str, Any]]

Per-augmenter data keyed by augmenter name. Populated by ResultAugmenter subclasses via add_augmentation.
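
Dataclasses reject mutable defaults such as a literal dict(). The "= dict()" shown in the ModelResult and EvalResult signatures most likely stands for field(default_factory=dict), which gives each instance its own augmentations dict. The sketch below illustrates this with a hypothetical MiniModelResult, which holds parameters as a plain list rather than an ndarray. It then round-trips the instance through JSON using dataclasses.asdict.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class MiniModelResult:
    """Hypothetical mirror of ModelResult with parameters stored as a plain list."""

    expr: List[str]
    error: float
    parameters: Optional[List[float]] = None
    # Each instance gets a fresh dict; a shared literal {} default would be rejected.
    augmentations: Dict[str, Dict[str, Any]] = field(default_factory=dict)


a = MiniModelResult(["C", "*", "X_0"], 0.42, [1.7])
b = MiniModelResult(["X_0"], 0.9)
a.augmentations["latex"] = {"value": "$C X_0$"}
print(b.augmentations)  # {}

# Round trip through JSON text and back into a dataclass instance.
restored = MiniModelResult(**json.loads(json.dumps(asdict(a))))
print(restored == a)  # True
```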

add_augmentation

add_augmentation(name: str, data: Dict[str, Any], aug_type: str) -> None

Attach augmentation data produced by a ResultAugmenter to this result.

If name is already present in augmentations, a numeric suffix is appended (name_1, name_2, …) to avoid overwriting existing data.

Examples:

>>> result = ModelResult(expr=["X_0"], error=0.1)
>>> result.add_augmentation("latex", {"value": "$X_0$"}, "LaTeXAugmenter")
>>> result.augmentations["latex"]["value"]
'$X_0$'
>>> result.add_augmentation("latex", {"value": "$X_0$"}, "LaTeXAugmenter")
>>> "latex_1" in result.augmentations
True

Parameters:

Name Type Description Default
name str

Key under which the augmentation is stored in augmentations. A suffix is added automatically if the key already exists.

required
data Dict[str, Any]

Arbitrary dict of augmentation data. Any existing "_type" key is overwritten with aug_type. The dict is modified in place and stored by reference, so pass a copy if the caller reuses it.

required
aug_type str

Augmenter class name, stored as data["_type"].

required
Source code in SRToolkit/utils/types.py
def add_augmentation(self, name: str, data: Dict[str, Any], aug_type: str) -> None:
    """
    Attach augmentation data produced by a [ResultAugmenter][SRToolkit.evaluation.sr_evaluator.ResultAugmenter] to this result.

    If ``name`` is already present in ``augmentations``, a numeric suffix is
    appended (``name_1``, ``name_2``, …) to avoid overwriting existing data.

    Examples:
        >>> result = ModelResult(expr=["X_0"], error=0.1)
        >>> result.add_augmentation("latex", {"value": "$X_0$"}, "LaTeXAugmenter")
        >>> result.augmentations["latex"]["value"]
        '$X_0$'
        >>> result.add_augmentation("latex", {"value": "$X_0$"}, "LaTeXAugmenter")
        >>> "latex_1" in result.augmentations
        True

    Args:
        name: Key under which the augmentation is stored in ``augmentations``.
            A suffix is added automatically if the key already exists.
        data: Arbitrary dict of augmentation data. Any existing ``"_type"`` key
            will be overwritten with ``aug_type``.
        aug_type: Augmenter class name, stored as ``data["_type"]``.
    """
    resolved = name
    counter = 1
    while resolved in self.augmentations:
        resolved = f"{name}_{counter}"
        counter += 1
    data["_type"] = aug_type
    self.augmentations[resolved] = data

to_dict

to_dict() -> dict

Serialize this model result to a JSON-safe dictionary.

NumPy arrays and scalars are converted to native Python types so the result can be passed directly to json.dump.

Examples:

>>> result = ModelResult(expr=["X_0", "+", "C"], error=0.25)
>>> d = result.to_dict()
>>> d["expr"]
['X_0', '+', 'C']
>>> d["error"]
0.25
>>> d["parameters"] is None
True

Returns:

Type Description
dict

A JSON-safe dictionary suitable for passing to from_dict.

Source code in SRToolkit/utils/types.py
def to_dict(self) -> dict:
    """
    Serialize this model result to a JSON-safe dictionary.

    NumPy arrays and scalars are converted to native Python types so the
    result can be passed directly to ``json.dump``.

    Examples:
        >>> result = ModelResult(expr=["X_0", "+", "C"], error=0.25)
        >>> d = result.to_dict()
        >>> d["expr"]
        ['X_0', '+', 'C']
        >>> d["error"]
        0.25
        >>> d["parameters"] is None
        True

    Returns:
        A JSON-safe dictionary suitable for passing to [from_dict][SRToolkit.utils.types.ModelResult.from_dict].
    """
    return {
        "expr": self.expr,
        "error": float(self.error),
        "parameters": _to_json_safe(self.parameters),
        "augmentations": _to_json_safe(self.augmentations),
    }

from_dict staticmethod

from_dict(data: dict) -> ModelResult

Reconstruct a ModelResult from a dictionary produced by to_dict.

Examples:

>>> result = ModelResult(expr=["X_0", "+", "C"], error=0.25)
>>> result2 = ModelResult.from_dict(result.to_dict())
>>> result2.expr
['X_0', '+', 'C']
>>> result2.error
0.25

Parameters:

Name Type Description Default
data dict

Dictionary representation of a ModelResult, as produced by to_dict.

required

Returns:

Type Description
ModelResult

The reconstructed ModelResult.

Source code in SRToolkit/utils/types.py
@staticmethod
def from_dict(data: dict) -> "ModelResult":
    """
    Reconstruct a [ModelResult][SRToolkit.utils.types.ModelResult] from a dictionary produced by [to_dict][SRToolkit.utils.types.ModelResult.to_dict].

    Examples:
        >>> result = ModelResult(expr=["X_0", "+", "C"], error=0.25)
        >>> result2 = ModelResult.from_dict(result.to_dict())
        >>> result2.expr
        ['X_0', '+', 'C']
        >>> result2.error
        0.25

    Args:
        data: Dictionary representation of a [ModelResult][SRToolkit.utils.types.ModelResult], as produced
            by [to_dict][SRToolkit.utils.types.ModelResult.to_dict].

    Returns:
        The reconstructed [ModelResult][SRToolkit.utils.types.ModelResult].
    """
    return ModelResult(
        expr=data["expr"],
        error=data["error"],
        parameters=_from_json_safe(data["parameters"]),
        augmentations=_from_json_safe(data["augmentations"]),
    )

compile_expr

compile_expr(expr: Union[List[str], Node], symbol_library: Optional[SymbolLibrary] = None, backend: str = 'stack') -> Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]

Compile an expression into a callable f(X, C) → np.ndarray.

Examples:

>>> f = compile_expr(["X_0", "+", "1"])
>>> f(np.array([[1.0], [2.0], [3.0]]), None)
array([2., 3., 4.])
>>> f = compile_expr(["X_0", "+", "1"], backend="codegen")
>>> f(np.array([[1], [2], [3]]), np.array([]))
array([2, 3, 4])

Parameters:

Name Type Description Default
expr Union[List[str], Node]

Expression as a token list in infix notation or a Node tree.

required
symbol_library Optional[SymbolLibrary]

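Symbol library used to look up token types. If None, the library returned by SymbolLibrary.get_or_default is used: the active library, then the module-level default, then SymbolLibrary.default_symbols.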

None
backend str

Evaluation backend. One of:

  • "stack" (default): postfix stack-machine evaluator backed by Cython; falls back to pure-Python when the compiled extension is unavailable.
  • "codegen": generates Python/NumPy source via exec(); compatible with any custom symbol library.
  • "stack_py": pure-Python stack-machine evaluator; avoids the Cython→Python boundary overhead when the library contains many custom symbols.
'stack'

Returns:

Type Description
Callable[[ndarray, Optional[ndarray]], ndarray]

A callable f(X, C) where X is a 2-D array of shape (n_samples, n_features) and C is a 1-D array of constant values (pass None or an empty array for constant-free expressions). The callable returns a 1-D output array of shape (n_samples,).

Raises:

Type Description
ValueError

If backend is not one of the supported values.

Source code in SRToolkit/utils/expression_compiler.py
def compile_expr(
    expr: Union[List[str], Node],
    symbol_library: Optional[SymbolLibrary] = None,
    backend: str = "stack",
) -> Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]:
    """
    Compile an expression into a callable ``f(X, C) → np.ndarray``.

    Examples:
        >>> f = compile_expr(["X_0", "+", "1"])
        >>> f(np.array([[1.0], [2.0], [3.0]]), None)
        array([2., 3., 4.])
        >>> f = compile_expr(["X_0", "+", "1"], backend="codegen")
        >>> f(np.array([[1], [2], [3]]), np.array([]))
        array([2, 3, 4])

    Args:
        expr: Expression as a token list in infix notation or a
            [Node][SRToolkit.utils.expression_tree.Node] tree.
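        symbol_library: Symbol library used to look up token types. If ``None``, the library
            returned by [get_or_default][SRToolkit.utils.symbol_library.SymbolLibrary.get_or_default]
            is used (the active library, then the module-level default, then
            [SymbolLibrary.default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols]).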
        backend: Evaluation backend. One of:

            - ``"stack"`` (default): postfix stack-machine evaluator backed by
              Cython; falls back to pure-Python when the compiled extension is
              unavailable.
            - ``"codegen"``: generates Python/NumPy source via ``exec()``;
              compatible with any custom symbol library.
            - ``"stack_py"``: pure-Python stack-machine evaluator; avoids the
              Cython→Python boundary overhead when the library contains many
              custom symbols.

    Returns:
        A callable ``f(X, C)`` where ``X`` is a 2-D array of shape
        ``(n_samples, n_features)`` and ``C`` is a 1-D array of constant values
        (pass ``None`` or an empty array for constant-free expressions).
        Returns a 1-D output array of shape ``(n_samples,)``.

    Raises:
        ValueError: If ``backend`` is not one of the supported values.
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_or_default()
    if backend == "stack":
        return _expr_to_cython_callable(expr, symbol_library)
    elif backend == "codegen":
        return _expr_to_executable_function(expr, symbol_library)
    elif backend == "stack_py":
        return _expr_to_python_callable(expr, symbol_library)
    else:
        raise ValueError(f"Unknown backend '{backend}'. Must be one of: 'stack', 'codegen', 'stack_py'.")
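
The "stack" and "stack_py" backends are described as postfix stack-machine evaluators. The pure-Python sketch below shows that general technique under several simplifying assumptions:

- Only the four binary operators are supported, with no functions or unary minus.
- Rows are plain lists instead of NumPy arrays.
- Each "C" token consumes the next value of C from left to right. This ordering is an assumption, not confirmed by this page.

The sketch converts the infix token list to postfix once, using the shunting-yard algorithm, and then runs the resulting program on every row.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical operator table: token -> (precedence, implementation). Binary operators only.
BINARY = {
    "+": (1, lambda a, b: a + b),
    "-": (1, lambda a, b: a - b),
    "*": (2, lambda a, b: a * b),
    "/": (2, lambda a, b: a / b),
}


def to_postfix(tokens: List[str]) -> List[str]:
    """Shunting-yard conversion from infix to postfix (all operators left-associative)."""
    out: List[str] = []
    ops: List[str] = []
    for tok in tokens:
        if tok in BINARY:
            # Pop operators of higher or equal precedence before pushing this one.
            while ops and ops[-1] in BINARY and BINARY[ops[-1]][0] >= BINARY[tok][0]:
                out.append(ops.pop())
            ops.append(tok)
        elif tok == "(":
            ops.append(tok)
        elif tok == ")":
            while ops[-1] != "(":
                out.append(ops.pop())
            ops.pop()  # discard the matching "("
        else:
            out.append(tok)  # operand: variable, constant placeholder or number
    while ops:
        out.append(ops.pop())
    return out


def compile_postfix(tokens: List[str]) -> Callable[..., List[float]]:
    """Convert once to postfix, then return f(X, C) that runs the program on each row."""
    program = to_postfix(tokens)

    def f(X: Sequence[Sequence[float]], C: Optional[Sequence[float]] = None) -> List[float]:
        results: List[float] = []
        for row in X:
            stack: List[float] = []
            consts = iter(C if C is not None else [])  # "C" tokens take values left to right
            for tok in program:
                if tok in BINARY:
                    right = stack.pop()
                    left = stack.pop()
                    stack.append(BINARY[tok][1](left, right))
                elif tok == "C":
                    stack.append(next(consts))
                elif tok.startswith("X_"):
                    stack.append(row[int(tok[2:])])
                else:
                    stack.append(float(tok))
            results.append(stack.pop())
        return results

    return f


expr = ["C", "*", "(", "X_0", "+", "1", ")"]
print(to_postfix(expr))  # ['C', 'X_0', '1', '+', '*']
f = compile_postfix(expr)
print(f([[1.0], [2.0], [3.0]], [2.0]))  # [4.0, 6.0, 8.0]
```

The row loop only keeps the sketch dependency-free. A vectorized evaluator would apply each postfix step to whole columns at once.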

compile_expr_rmse

compile_expr_rmse(expr: Union[List[str], Node], symbol_library: Optional[SymbolLibrary] = None, backend: str = 'stack', X: Optional[ndarray] = None) -> Callable[[np.ndarray, np.ndarray, np.ndarray], float]

Compile an expression into an RMSE callable f(X, C, y) → float.

Examples:

>>> f = compile_expr_rmse(["X_0", "+", "1"])
>>> f(np.array([[1.0], [2.0], [3.0]]), np.array([]), np.array([2.0, 3.0, 4.0]))
0.0
>>> f = compile_expr_rmse(["X_0", "+", "1"], backend="codegen")
>>> print(float(f(np.array([[1], [2], [3]]), np.array([]), np.array([2, 3, 4]))))
0.0

Parameters:

Name Type Description Default
expr Union[List[str], Node]

Expression as a token list in infix notation or a Node tree.

required
symbol_library Optional[SymbolLibrary]

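Symbol library used to look up token types. If None, the library returned by SymbolLibrary.get_or_default is used: the active library, then the module-level default, then SymbolLibrary.default_symbols.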

None
backend str

Evaluation backend. One of:

  • "stack" (default): postfix stack-machine evaluator backed by Cython; RMSE is computed in C without an intermediate output array. Falls back to pure-Python when the compiled extension is unavailable.
  • "codegen": generates Python/NumPy source via exec(); compatible with any custom symbol library.
  • "stack_py": pure-Python stack-machine evaluator; avoids the Cython→Python boundary overhead when the library contains many custom symbols.
'stack'
X Optional[ndarray]

Optional input data of shape (n_samples, n_features). When provided with backend="stack" or backend="stack_py", all constant-free subtrees are pre-evaluated against X at compile time — a significant speedup when the same X is reused across many calls with varying C (e.g. inside an optimiser loop). Ignored for backend="codegen".

None

Returns:

Type Description
Callable[[ndarray, ndarray, ndarray], float]

A callable f(X, C, y) returning the scalar RMSE as a float.

Raises:

Type Description
ValueError

If backend is not one of the supported values.

Source code in SRToolkit/utils/expression_compiler.py
def compile_expr_rmse(
    expr: Union[List[str], Node],
    symbol_library: Optional[SymbolLibrary] = None,
    backend: str = "stack",
    X: Optional[np.ndarray] = None,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray], float]:
    """
    Compile an expression into an RMSE callable ``f(X, C, y) → float``.

    Examples:
        >>> f = compile_expr_rmse(["X_0", "+", "1"])
        >>> f(np.array([[1.0], [2.0], [3.0]]), np.array([]), np.array([2.0, 3.0, 4.0]))
        0.0
        >>> f = compile_expr_rmse(["X_0", "+", "1"], backend="codegen")
        >>> print(float(f(np.array([[1], [2], [3]]), np.array([]), np.array([2, 3, 4]))))
        0.0

    Args:
        expr: Expression as a token list in infix notation or a
            [Node][SRToolkit.utils.expression_tree.Node] tree.
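        symbol_library: Symbol library used to look up token types. If ``None``, the library
            returned by [get_or_default][SRToolkit.utils.symbol_library.SymbolLibrary.get_or_default]
            is used (the active library, then the module-level default, then
            [SymbolLibrary.default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols]).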
        backend: Evaluation backend. One of:

            - ``"stack"`` (default): postfix stack-machine evaluator backed by
              Cython; RMSE is computed in C without an intermediate output array.
              Falls back to pure-Python when the compiled extension is unavailable.
            - ``"codegen"``: generates Python/NumPy source via ``exec()``;
              compatible with any custom symbol library.
            - ``"stack_py"``: pure-Python stack-machine evaluator; avoids the
              Cython→Python boundary overhead when the library contains many
              custom symbols.

        X: Optional input data of shape ``(n_samples, n_features)``. When provided
            with ``backend="stack"`` or ``backend="stack_py"``, all constant-free
            subtrees are pre-evaluated against *X* at compile time — a significant
            speedup when the same *X* is reused across many calls with varying *C*
            (e.g. inside an optimiser loop). Ignored for ``backend="codegen"``.

    Returns:
        A callable ``f(X, C, y)`` returning the scalar RMSE as a float.

    Raises:
        ValueError: If ``backend`` is not one of the supported values.
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_or_default()
    if backend == "stack":
        return _expr_to_cython_error_callable(expr, symbol_library, X)
    elif backend == "codegen":
        return _expr_to_error_function(expr, symbol_library)
    elif backend == "stack_py":
        return _expr_to_python_stack_error_callable(expr, symbol_library, X)
    else:
        raise ValueError(f"Unknown backend '{backend}'. Must be one of: 'stack', 'codegen', 'stack_py'.")
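The "stack" and "stack_py" backends are described as postfix stack-machine evaluators. The following minimal pure-Python sketch illustrates that idea only. It evaluates a postfix token list row by row and accumulates the squared error directly, so no prediction array is built. Several details are assumptions for this sketch, not the library's SymbolLibrary lookup: the operator tables, the X_<i> naming rule, and the convention that the k-th "C" token reads C[k].

```python
import math
import operator
from typing import Callable, Dict, List, Sequence

# Hypothetical token tables; the library resolves tokens through its SymbolLibrary.
BINARY: Dict[str, Callable[[float, float], float]] = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "/": operator.truediv,
}
UNARY: Dict[str, Callable[[float], float]] = {"sin": math.sin, "cos": math.cos}


def eval_postfix(tokens: Sequence[str], x: Sequence[float], c: Sequence[float]) -> float:
    """Evaluate a postfix token list on a single input row."""
    stack: List[float] = []
    const_idx = 0
    for tok in tokens:
        if tok in BINARY:
            right = stack.pop()  # the right operand was pushed last
            left = stack.pop()
            stack.append(BINARY[tok](left, right))
        elif tok in UNARY:
            stack.append(UNARY[tok](stack.pop()))
        elif tok.startswith("X_"):
            stack.append(x[int(tok[2:])])
        elif tok == "C":
            stack.append(c[const_idx])  # assumed: k-th "C" reads c[k]
            const_idx += 1
        else:
            stack.append(float(tok))  # numeric literal
    if len(stack) != 1:
        raise ValueError("malformed postfix expression")
    return stack[0]


def rmse_postfix(tokens: Sequence[str], X: Sequence[Sequence[float]],
                 C: Sequence[float], y: Sequence[float]) -> float:
    """RMSE accumulated on the fly, without storing predictions."""
    sq_err = 0.0
    for row, target in zip(X, y):
        diff = eval_postfix(tokens, row, C) - target
        sq_err += diff * diff
    return math.sqrt(sq_err / len(y))


# Postfix form of the infix ["X_0", "+", "1"] used in the docstring example above.
print(rmse_postfix(["X_0", "1", "+"], [[1.0], [2.0], [3.0]], [], [2.0, 3.0, 4.0]))  # 0.0
```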

generate_n_expressions

generate_n_expressions(expression_description: Union[str, SymbolLibrary, Grammar], num_expressions: int, unique: bool = True, max_expression_length: int = 50, verbose: bool = True, max_consecutive_generation_failures: int = 100, max_consecutive_uniqueness_failures: int = 200, max_derivation_steps: int = 1000, start: str = 'E') -> List[List[str]]

Sample num_expressions expressions from a grammar or symbol library.

Examples:

>>> len(generate_n_expressions(SymbolLibrary.default_symbols(5), 100, verbose=False))
100
>>> generate_n_expressions(SymbolLibrary.from_symbol_list([], 1), 3, unique=False, verbose=False, max_expression_length=1)
[['X_0'], ['X_0'], ['X_0']]

Parameters:

Name Type Description Default
expression_description Union[str, SymbolLibrary, Grammar]

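Grammar source, one of:

  • A Grammar object (used directly, including any registered constraints).
  • A SymbolLibrary (a generic PCFG is built automatically via Grammar.from_symbol_library).
  • A grammar string in NLTK notation (with an optional # start: nonterminal line, where nonterminal names the start symbol), used by Grammar.from_grammar_string.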

required
num_expressions int

Number of expressions to generate.

required
unique bool

If True, no two expressions in the output have the same token sequence (semantically equivalent expressions, such as X_0+X_1 and X_1+X_0, may still appear). Default True.

True
max_expression_length int

Maximum token count per expression. Values ≤ 0 allow unbounded length. Default 50.

50
verbose bool

Display a progress bar showing the total number of attempts, the ratio of invalid expressions (derivation failed or exceeded max_expression_length), and, when unique=True, the ratio of duplicate expressions among valid ones.

True
max_consecutive_generation_failures int

Maximum number of consecutive attempts that produce an invalid expression (derivation failed or result exceeded max_expression_length) before raising an exception. Resets on any valid expression. Default 100.

100
max_consecutive_uniqueness_failures int

Maximum number of consecutive valid expressions that are already in the output set before stopping early and returning what has been collected so far. Only relevant when unique=True. Resets whenever a new unique expression is found. Default 200.

200
max_derivation_steps int

Maximum number of rule applications per single derivation attempt before it is abandoned. Increase for grammars with deep recursion that legitimately require many steps. Default 1000.

1000
start str

Start non-terminal used when expression_description is a grammar string. Ignored for Grammar and SymbolLibrary inputs. Default "E".

'E'

Returns:

Type Description
List[List[str]]

List of expressions, each represented as a list of string tokens in infix notation. May contain fewer than num_expressions entries if unique=True and the search space is exhausted before the target count is reached.

Raises:

Type Description
Exception

If expression_description is not a Grammar, SymbolLibrary, or str.

Exception

If generation fails max_consecutive_generation_failures times in a row, indicating the grammar or length constraint may be too restrictive.

Warns:

Type Description
UserWarning

If more than 80 % of attempts produce an invalid expression (after at least 500 total attempts), suggesting the grammar or constraints are overly restrictive. Emitted at most once per call.

UserWarning

If unique=True and no new unique expression is found after max_consecutive_uniqueness_failures consecutive valid attempts, indicating the search space may be exhausted. The expressions collected so far are returned.

Source code in SRToolkit/utils/expression_generator.py
def generate_n_expressions(
    expression_description: Union[str, SymbolLibrary, Grammar],
    num_expressions: int,
    unique: bool = True,
    max_expression_length: int = 50,
    verbose: bool = True,
    max_consecutive_generation_failures: int = 100,
    max_consecutive_uniqueness_failures: int = 200,
    max_derivation_steps: int = 1000,
    start: str = "E",
) -> List[List[str]]:
    """
    Sample ``num_expressions`` expressions from a grammar or symbol library.

    Examples:
        >>> len(generate_n_expressions(SymbolLibrary.default_symbols(5), 100, verbose=False))
        100
        >>> generate_n_expressions(SymbolLibrary.from_symbol_list([], 1), 3, unique=False, verbose=False, max_expression_length=1)
        [['X_0'], ['X_0'], ['X_0']]

    Args:
        expression_description: Grammar source — one of:

            - A [Grammar][SRToolkit.utils.grammar.Grammar] object (used directly, including
              any registered constraints).
            - A [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] (a generic PCFG
              is built automatically via
              [Grammar.from_symbol_library][SRToolkit.utils.grammar.Grammar.from_symbol_library]).
            - A grammar string in the NLTK notation (with an optional ``# start: nonterminal``
              line where nonterminal indicates the start symbol) used by
              [Grammar.from_grammar_string][SRToolkit.utils.grammar.Grammar.from_grammar_string].

        num_expressions: Number of expressions to generate.
        unique: If ``True``, every expression in the output is lexicographically distinct
            (though semantically equivalent expressions may still appear). Default ``True``.
        max_expression_length: Maximum token count per expression. Values ≤ ``0``
            allow unbounded length. Default ``50``.
        verbose: Display a progress bar showing total attempts, the ratio of invalid
            expressions (derivation failed or exceeded ``max_expression_length``), and —
            when ``unique=True`` — the ratio of duplicate expressions among valid ones
            and the total number of generation attempts.
        max_consecutive_generation_failures: Maximum number of consecutive attempts that
            produce an invalid expression (derivation failed *or* result exceeded
            ``max_expression_length``) before raising an exception. Resets on any valid expression.
            Default ``100``.
        max_consecutive_uniqueness_failures: Maximum number of consecutive valid expressions
            that are already in the output set before stopping early and returning what has
            been collected so far. Only relevant when ``unique=True``. Resets whenever a new
            unique expression is found. Default ``200``.
        max_derivation_steps: Maximum number of rule applications per single derivation
            attempt before it is abandoned. Increase for grammars with deep recursion that
            legitimately require many steps. Default ``1000``.
        start: Start non-terminal used when ``expression_description`` is a grammar
            string. Ignored for [Grammar][SRToolkit.utils.grammar.Grammar] and
            [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] inputs.
            Default ``"E"``.

    Returns:
        List of expressions, each represented as a list of string tokens in infix notation.
        May contain fewer than ``num_expressions`` entries if ``unique=True`` and the
        search space is exhausted before the target count is reached.

    Raises:
        Exception: If ``expression_description`` is not a
            [Grammar][SRToolkit.utils.grammar.Grammar],
            [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary], or ``str``.
        Exception: If generation fails ``max_consecutive_generation_failures`` times in a
            row, indicating the grammar or length constraint may be too restrictive.

    Warns:
        UserWarning: If more than 80 % of attempts produce an invalid expression (after at
            least ``500`` total attempts), suggesting the grammar
            or constraints are overly restrictive. Emitted at most once per call.
        UserWarning: If ``unique=True`` and no new unique expression is found after
            ``max_consecutive_uniqueness_failures`` consecutive valid attempts, indicating
            the search space may be exhausted. The expressions collected so far are returned.
    """
    if isinstance(expression_description, Grammar):
        grammar = expression_description
    elif isinstance(expression_description, SymbolLibrary):
        grammar = Grammar.from_symbol_library(expression_description)
    elif isinstance(expression_description, str):
        grammar = Grammar.from_grammar_string(expression_description, start=start)
    else:
        raise Exception("expression_description must be a Grammar, SymbolLibrary, or grammar string.")

    expressions: List[List[str]] = []
    expression_strings: set = set()

    total_attempts = 0
    total_invalid = 0
    total_duplicates = 0
    consecutive_generation_failures = 0
    consecutive_uniqueness_failures = 0
    ratio_warning_sent = False

    pbar = tqdm(total=num_expressions) if verbose else None

    def update_postfix() -> None:
        if pbar is None:
            return
        fail_pct = total_invalid / total_attempts if total_attempts > 0 else 0.0
        postfix: dict = {"attempts": total_attempts, "fail%": f"{fail_pct:.0%}"}
        if unique:
            total_valid = total_attempts - total_invalid
            dup_pct = total_duplicates / total_valid if total_valid > 0 else 0.0
            postfix["dup%"] = f"{dup_pct:.0%}"
        pbar.set_postfix(postfix)

    while len(expressions) < num_expressions:
        expr = grammar.generate_one(max_retries=1, max_steps=max_derivation_steps)
        total_attempts += 1

        if expr is None or (max_expression_length > 0 and len(expr) > max_expression_length):
            total_invalid += 1
            consecutive_generation_failures += 1

            if (
                not ratio_warning_sent
                and total_attempts >= _MIN_ATTEMPTS_BEFORE_WARNING
                and total_invalid / total_attempts > _INVALID_RATIO_THRESHOLD
            ):
                warnings.warn(
                    f"[Expression generation] {total_invalid / total_attempts:.0%} of "
                    f"{total_attempts} attempts produced an invalid expression (derivation "
                    "failed or exceeded max_expression_length). Consider relaxing grammar "
                    "constraints or increasing max_expression_length.",
                    stacklevel=2,
                )
                ratio_warning_sent = True

            if consecutive_generation_failures >= max_consecutive_generation_failures:
                if pbar is not None:
                    pbar.close()
                raise Exception(
                    f"[Expression generation] Failed to generate a valid expression "
                    f"{consecutive_generation_failures} times in a row "
                    f"({total_invalid} invalid out of {total_attempts} total attempts). "
                    "The grammar or length constraint may be too restrictive."
                )
            update_postfix()
            continue

        consecutive_generation_failures = 0

        expr_string = "".join(expr)
        if unique and expr_string in expression_strings:
            total_duplicates += 1
            consecutive_uniqueness_failures += 1
            if consecutive_uniqueness_failures >= max_consecutive_uniqueness_failures:
                warnings.warn(
                    f"[Expression generation] Failed to find a new unique expression "
                    f"{consecutive_uniqueness_failures} times in a row — stopping early "
                    f"with {len(expressions)} of {num_expressions} expressions collected. "
                    "The expression search space may be exhausted.",
                    stacklevel=2,
                )
                break
            update_postfix()
            continue

        consecutive_uniqueness_failures = 0
        expressions.append(expr)
        if unique:
            expression_strings.add(expr_string)
        if pbar is not None:
            pbar.update(1)
        update_postfix()

    if pbar is not None:
        pbar.close()
    return expressions
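The interplay of the two consecutive-failure counters can be hard to follow from the parameter table alone. The following standalone sketch mirrors, in simplified form, the control flow of the source above. It wraps a hypothetical sample_one callable that stands in for grammar.generate_one, always enforces uniqueness, and omits the progress bar and the invalid-ratio warning. It also raises RuntimeError rather than a bare Exception. One deliberate difference: it keys uniqueness on the token tuple rather than the joined string, which keeps token boundaries intact.

```python
import random
from typing import Callable, List, Optional, Set, Tuple


def sample_unique(
    sample_one: Callable[[], Optional[List[str]]],
    n: int,
    max_len: int = 50,
    max_gen_failures: int = 100,
    max_dup_failures: int = 200,
) -> List[List[str]]:
    out: List[List[str]] = []
    seen: Set[Tuple[str, ...]] = set()
    gen_failures = 0
    dup_failures = 0
    while len(out) < n:
        expr = sample_one()
        if expr is None or (max_len > 0 and len(expr) > max_len):
            gen_failures += 1  # invalid: derivation failed or too long
            if gen_failures >= max_gen_failures:
                raise RuntimeError("grammar or length constraint may be too restrictive")
            continue
        gen_failures = 0  # any valid expression resets this counter
        key = tuple(expr)
        if key in seen:
            dup_failures += 1
            if dup_failures >= max_dup_failures:
                break  # search space likely exhausted: return what we have
            continue
        dup_failures = 0  # a new unique expression resets this counter
        seen.add(key)
        out.append(expr)
    return out


rng = random.Random(0)
pool = [["X_0"], ["X_0", "+", "X_0"], ["X_0", "*", "X_0"]]
# Asking for 10 unique expressions from a 3-element space stops early with 3.
result = sample_unique(lambda: list(rng.choice(pool)), 10)
print(len(result))  # 3
```

Invalid attempts never touch the duplicate counter and vice versa, so a grammar that is both restrictive and small can end either by raising or by returning early, whichever limit is hit first.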

simplify

simplify(expr: Union[List[str], Node], symbol_library: Optional[SymbolLibrary] = None) -> Union[List[str], Node]

Simplify an expression algebraically.

Two successive steps are applied:

  1. SymPy simplification: expands and reduces the expression algebraically (e.g. X_0 * X_1 / X_0 → X_1).
  2. Constant folding: collapses any sub-expression containing no variables into a single free constant C (e.g. C * C + C → C).
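The constant-folding step can be pictured with a minimal sketch on a toy tree encoding: either a leaf string or a tuple of a symbol followed by its operands. It assumes variables are the tokens named X_<i> and folds every variable-free subtree, numeric leaves included, into "C". It does not model the SymPy step or reproduce how the library treats specific numeric literals.

```python
from typing import Tuple, Union

# Toy encoding: a leaf token string, or (symbol, operand, ...) for operators/functions.
Tree = Union[str, Tuple]


def has_variable(t: Tree) -> bool:
    if isinstance(t, str):
        return t.startswith("X_")  # assumed variable naming
    return any(has_variable(child) for child in t[1:])


def fold_constants(t: Tree) -> Tree:
    """Collapse every variable-free subtree into a single free constant "C"."""
    if not has_variable(t):
        return "C"
    if isinstance(t, str):
        return t  # a variable leaf stays as-is
    return (t[0],) + tuple(fold_constants(child) for child in t[1:])


# C * C + C  ->  C
print(fold_constants(("+", ("*", "C", "C"), "C")))
# C * C + X_0 * X_1  ->  C + X_0 * X_1 (the variable subtree is kept)
print(fold_constants(("+", ("*", "C", "C"), ("*", "X_0", "X_1"))))
```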

Examples:

>>> expr = ["C", "+", "C", "*", "C", "+", "X_0", "*", "X_1", "/", "X_0"]
>>> print("".join(simplify(expr)))
C+X_1

Parameters:

Name Type Description Default
expr Union[List[str], Node]

Expression as a token list in infix notation or a Node tree.

required
symbol_library Optional[SymbolLibrary]

Symbol library defining variables and constants. Defaults to SymbolLibrary.default_symbols.

None

Returns:

Type Description
Union[List[str], Node]

The simplified expression in the same form as the input (list if a list was given, Node if a tree was given).

Raises:

Type Description
Exception

If simplification fails or the result contains tokens absent from symbol_library.

Source code in SRToolkit/utils/expression_simplifier.py
def simplify(
    expr: Union[List[str], Node],
    symbol_library: Optional[SymbolLibrary] = None,
) -> Union[List[str], Node]:
    """
    Simplify an expression algebraically.

    Two successive steps are applied:

    1. **SymPy simplification** — expands and reduces the expression algebraically
       (e.g. ``X_0 * X_1 / X_0`` → ``X_1``).
    2. **Constant folding** — collapses any sub-expression containing no variables
       into a single free constant ``C`` (e.g. ``C * C + C`` → ``C``).

    Examples:
        >>> expr = ["C", "+", "C", "*", "C", "+", "X_0", "*", "X_1", "/", "X_0"]
        >>> print("".join(simplify(expr)))
        C+X_1

    Args:
        expr: Expression as a token list in infix notation or a [Node][SRToolkit.utils.expression_tree.Node] tree.
        symbol_library: Symbol library defining variables and constants.
            Defaults to [SymbolLibrary.default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols].

    Returns:
        The simplified expression in the same form as the input (list if a list was given, [Node][SRToolkit.utils.expression_tree.Node] if a tree was given).

    Raises:
        Exception: If simplification fails or the result contains tokens absent from
            ``symbol_library``.
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_or_default()
    is_tree = False
    if isinstance(expr, Node):
        expr = expr.to_list(symbol_library=symbol_library, notation="infix")
        is_tree = True

    variables = symbol_library.get_symbols_of_type("var")

    # We expect only one symbol for constants
    if len(symbol_library.get_symbols_of_type("const")) > 0:
        constant = symbol_library.get_symbols_of_type("const")[0]
    else:
        # In this case constants shouldn't be problematic as they are not in the SymbolLibrary
        # Just in case and to not change other functions, I changed it to __C__.
        constant = "__C__"

    expr = _simplify_expression("".join(expr), constant, variables)
    expr = sympify(_denumerate_constants(str(expr), constant), evaluate=False)
    expr = _sympy_to_sr(expr)
    if not _check_tree(expr, symbol_library):
        raise Exception(
            "Simplified expression contains invalid symbols. Possibly skip its simplification or add symbols to the SymbolLibrary."
        )

    if is_tree:
        return expr
    else:
        return expr.to_list(symbol_library=symbol_library, notation="infix")

expr_to_latex

expr_to_latex(expr: Union[Node, List[str]], symbol_library: Optional[SymbolLibrary] = None) -> str

Convert an expression to a LaTeX string.

Examples:

>>> expr_to_latex(["(", "X_0", "+", "X_1", ")"], SymbolLibrary.default_symbols())
'$X_{0} + X_{1}$'
>>> expr = Node("+", Node("X_0"), Node("1"))
>>> expr_to_latex(expr, SymbolLibrary.default_symbols())
'$1 + X_{0}$'

Parameters:

Name Type Description Default
expr Union[Node, List[str]]

Expression as a token list or a Node tree.

required
symbol_library Optional[SymbolLibrary]

Symbol library providing LaTeX templates. If None, falls back to the currently active library set via 'with SymbolLibrary(...) as sl:'. Defaults to None.

None

Returns:

Type Description
str

A LaTeX string of the form $...$, or an empty string if conversion fails.

Source code in SRToolkit/utils/expression_tree.py
def expr_to_latex(expr: Union[Node, List[str]], symbol_library: Optional[SymbolLibrary] = None) -> str:
    """
    Convert an expression to a LaTeX string.

    Examples:
        >>> expr_to_latex(["(", "X_0", "+", "X_1", ")"], SymbolLibrary.default_symbols())
        '$X_{0} + X_{1}$'
        >>> expr = Node("+", Node("X_0"), Node("1"))
        >>> expr_to_latex(expr, SymbolLibrary.default_symbols())
        '$1 + X_{0}$'

    Args:
        expr: Expression as a token list or a [Node][SRToolkit.utils.expression_tree.Node] tree.
        symbol_library: Symbol library providing LaTeX templates. If None, falls
            back to the currently active library set via
            'with SymbolLibrary(...) as sl:'. Defaults to None.

    Returns:
        A LaTeX string of the form ``$...$``, or an empty string if conversion fails.
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_active()
    try:
        if isinstance(expr, Node):
            return expr.to_latex(symbol_library)
        elif isinstance(expr, list):
            return tokens_to_tree(expr, symbol_library).to_latex(symbol_library)
        else:
            raise Exception(
                f"Invalid type for expression {str(expr)}. Should be SRToolkit.utils.Node or a list of tokens."
            )
    except Exception as e:
        print(f"Error while converting expression {str(expr)} to LaTeX: {str(e)}")
        return ""
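The second example above renders as 1 + X_{0} because, with positional arguments, Node("+", Node("X_0"), Node("1")) makes Node("1") the left child (see the warning on Node). The following standalone sketch shows that effect. It uses a hypothetical TNode class that mirrors Node's (symbol, right, left) argument order and a hypothetical template table where {0} is the left operand. The library's actual templates come from the SymbolLibrary and may differ.

```python
from typing import Dict, Optional


class TNode:
    """Hypothetical stand-in that mirrors Node(symbol, right=None, left=None)."""

    def __init__(self, symbol: str, right: Optional["TNode"] = None, left: Optional["TNode"] = None) -> None:
        self.symbol = symbol
        self.right = right
        self.left = left


# Hypothetical templates: {0} is the left operand, {1} the right operand.
TEMPLATES: Dict[str, str] = {"+": "{0} + {1}", "*": "{0} \\cdot {1}", "sin": "\\sin {0}"}


def leaf_latex(symbol: str) -> str:
    # Render X_0 as X_{0}; other leaves are emitted verbatim.
    if symbol.startswith("X_"):
        return "X_{" + symbol[2:] + "}"
    return symbol


def to_latex(node: TNode) -> str:
    if node.left is None and node.right is None:
        return leaf_latex(node.symbol)
    assert node.left is not None  # operators and functions always have a left child
    args = [to_latex(node.left)]
    if node.right is not None:
        args.append(to_latex(node.right))
    return TEMPLATES[node.symbol].format(*args)


# Positional children: the second argument is the RIGHT child.
print("$" + to_latex(TNode("+", TNode("X_0"), TNode("1"))) + "$")  # $1 + X_{0}$
# Keyword children make the intended order explicit.
print("$" + to_latex(TNode("+", left=TNode("X_0"), right=TNode("1"))) + "$")  # $X_{0} + 1$
```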

is_float

is_float(element: Any) -> bool

Return True if element can be interpreted as a floating-point number.

Examples:

>>> is_float(1.0)
True
>>> is_float("1.0")
True
>>> is_float("1")
True
>>> is_float(None)
False
>>> is_float("hello")
False

Parameters:

Name Type Description Default
element Any

Value to test.

required

Returns:

Type Description
bool

True if float(element) succeeds, False otherwise (including None).

Source code in SRToolkit/utils/expression_tree.py
def is_float(element: Any) -> bool:
    """
    Return ``True`` if ``element`` can be interpreted as a floating-point number.

    Examples:
        >>> is_float(1.0)
        True
        >>> is_float("1.0")
        True
        >>> is_float("1")
        True
        >>> is_float(None)
        False
        >>> is_float("hello")
        False

    Args:
        element: Value to test.

    Returns:
        ``True`` if ``float(element)`` succeeds, ``False`` otherwise (including ``None``).
    """
    if element is None:
        return False
    try:
        float(element)
        return True
    except (TypeError, ValueError):
        return False

tokens_to_tree

tokens_to_tree(tokens: List[str], sl: Optional[SymbolLibrary] = None) -> Node

Parse a token list into an expression tree using the shunting-yard algorithm.

Examples:

>>> tree = tokens_to_tree(["(", "X_0", "+", "X_1", ")"], SymbolLibrary.default_symbols())
>>> len(tree)
3

Parameters:

Name Type Description Default
tokens List[str]

Token list in infix notation.

required
sl Optional[SymbolLibrary]

Symbol library used to resolve token types and precedences. If None, falls back to the currently active library set via 'with SymbolLibrary(...) as sl:'. Defaults to None.

None

Returns:

Type Description
Node

Root Node of the parsed expression tree.

Raises:

Type Description
Exception

If a token is absent from sl, or if the expression is syntactically invalid.

Source code in SRToolkit/utils/expression_tree.py
def tokens_to_tree(tokens: List[str], sl: Optional[SymbolLibrary] = None) -> Node:
    """
    Parse a token list into an expression tree using the shunting-yard algorithm.

    Examples:
        >>> tree = tokens_to_tree(["(", "X_0", "+", "X_1", ")"], SymbolLibrary.default_symbols())
        >>> len(tree)
        3

    Args:
        tokens: Token list in infix notation.
        sl: Symbol library used to resolve token types and precedences. If None,
            falls back to the currently active library set via
            'with SymbolLibrary(...) as sl:'. Defaults to None.

    Returns:
        Root [Node][SRToolkit.utils.expression_tree.Node] of the parsed expression tree.

    Raises:
        Exception: If a token is absent from ``sl``, or if the expression is
            syntactically invalid.
    """
    if sl is None:
        sl = SymbolLibrary.get_active()
    num_tokens = len([t for t in tokens if t != "(" and t != ")"])
    expr_str = "".join(tokens)
    tokens = ["("] + tokens + [")"]
    operator_stack = []
    out_stack = []
    for token in tokens:
        if token == "(":
            operator_stack.append(token)
        elif sl.get_type(token) in ["var", "const", "lit"] or is_float(token):
            out_stack.append(Node(token))
        elif sl.get_type(token) == "fn":
            if token[0] == "^":
                out_stack.append(Node(token, left=out_stack.pop()))
            else:
                operator_stack.append(token)
        elif sl.get_type(token) == "op":
            while (
                len(operator_stack) > 0
                and operator_stack[-1] != "("
                and sl.get_precedence(operator_stack[-1]) >= sl.get_precedence(token)
            ):
                if sl.get_type(operator_stack[-1]) == "fn":
                    out_stack.append(Node(operator_stack.pop(), left=out_stack.pop()))
                else:
                    out_stack.append(Node(operator_stack.pop(), out_stack.pop(), out_stack.pop()))
            operator_stack.append(token)
        else:
            if token != ")":
                raise Exception(
                    f'Invalid symbol "{token}" in expression {expr_str}. Did you add token "{token}" to the symbol library?'
                )

            while len(operator_stack) > 0 and operator_stack[-1] != "(":
                if sl.get_type(operator_stack[-1]) == "fn":
                    out_stack.append(Node(operator_stack.pop(), left=out_stack.pop()))
                else:
                    out_stack.append(Node(operator_stack.pop(), out_stack.pop(), out_stack.pop()))
            operator_stack.pop()
            if len(operator_stack) > 0 and sl.get_type(operator_stack[-1]) == "fn":
                out_stack.append(Node(operator_stack.pop(), left=out_stack.pop()))
    if len(out_stack[-1]) == num_tokens:
        return out_stack[-1]
    else:
        raise Exception(f"Error while parsing expression {expr_str}.")
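For readers unfamiliar with the shunting-yard algorithm, here is a stripped-down standalone sketch. It is restricted to binary operators and parentheses, uses a hypothetical precedence table, and builds tuple trees (op, left, right) instead of Node. Like the source above, it wraps the tokens in an extra pair of parentheses and reduces stacked operators of greater or equal precedence, which makes them left-associative. Note that the source passes the first popped operand as Node's right argument, matching Node(symbol, right, left).

```python
from typing import List, Tuple, Union

Tree = Union[str, Tuple[str, "Tree", "Tree"]]
PRECEDENCE = {"+": 0, "-": 0, "*": 1, "/": 1}  # hypothetical precedence table


def parse_infix(tokens: List[str]) -> Tree:
    ops: List[str] = []
    out: List[Tree] = []

    def reduce_top() -> None:
        # Pop one operator and its two operands; the right operand is on top.
        if len(out) < 2:
            raise ValueError("missing operand")
        op = ops.pop()
        right = out.pop()
        left = out.pop()
        out.append((op, left, right))

    for tok in ["("] + tokens + [")"]:
        if tok == "(":
            ops.append(tok)
        elif tok in PRECEDENCE:
            # Reduce while the stacked operator binds at least as tightly.
            while ops and ops[-1] != "(" and PRECEDENCE[ops[-1]] >= PRECEDENCE[tok]:
                reduce_top()
            ops.append(tok)
        elif tok == ")":
            while ops and ops[-1] != "(":
                reduce_top()
            if not ops:
                raise ValueError("unbalanced parentheses")
            ops.pop()  # discard the matching "("
        else:
            out.append(tok)  # operand leaf (variable, constant, or number)
    if ops or len(out) != 1:
        raise ValueError("malformed expression")
    return out[0]


print(parse_infix(["X_0", "+", "X_1", "*", "X_2"]))  # ('+', 'X_0', ('*', 'X_1', 'X_2'))
```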

bed

bed(expr1: Union[Node, List[str], ndarray], expr2: Union[Node, List[str], ndarray], X: Optional[ndarray] = None, num_consts_sampled: int = 32, num_points_sampled: int = 64, domain_bounds: Optional[List[Tuple[float, float]]] = None, consts_bounds: Tuple[float, float] = (-5, 5), symbol_library: Optional[SymbolLibrary] = None, seed: Optional[int] = None) -> float

Compute the Behavior-aware Expression Distance (BED) between two expressions.

BED measures how similarly two expressions behave over a domain by comparing their output distributions point-by-point using the Wasserstein distance. Free constants are marginalised by sampling multiple constant vectors via Latin Hypercube Sampling.

Either X or domain_bounds must be provided when expressions are given as token lists or Node trees. Pre-computed behavior matrices can be passed directly to avoid redundant evaluation.
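The point-by-point comparison can be sketched as follows. This is a standalone illustration, not the library's implementation. Behavior matrices are plain lists of rows (one row per evaluation point, one column per sampled constant vector), and the 1-D Wasserstein-1 distance is computed exactly from the two empirical quantile functions. Three choices here are assumptions of this sketch: per-point distances are averaged, non-finite entries inside a row are dropped, and rows where neither expression is finite are skipped. Only the rule that a row finite for exactly one expression yields inf comes from the Returns description of bed.

```python
import math
from typing import List, Sequence


def wasserstein_1d(a: Sequence[float], b: Sequence[float]) -> float:
    """Exact W1 between two empirical distributions: integral of |F^-1(t) - G^-1(t)| on [0, 1]."""
    sa, sb = sorted(a), sorted(b)
    n, m = len(sa), len(sb)
    # Both quantile functions are step functions with jumps at k/n and k/m.
    breaks = sorted({i / n for i in range(1, n + 1)} | {j / m for j in range(1, m + 1)})
    total, prev = 0.0, 0.0
    for t in breaks:
        mid = (prev + t) / 2  # inside the segment, both quantiles are constant
        qa = sa[min(int(mid * n), n - 1)]
        qb = sb[min(int(mid * m), m - 1)]
        total += (t - prev) * abs(qa - qb)
        prev = t
    return total


def bed_sketch(bm1: List[List[float]], bm2: List[List[float]]) -> float:
    """Hypothetical BED-style distance between two behavior matrices."""
    if len(bm1) != len(bm2) or not bm1:
        raise ValueError("behavior matrices need the same, nonzero number of rows")
    distances: List[float] = []
    for row1, row2 in zip(bm1, bm2):
        f1 = [v for v in row1 if math.isfinite(v)]
        f2 = [v for v in row2 if math.isfinite(v)]
        if bool(f1) != bool(f2):
            return math.inf  # finite outputs for one expression but not the other
        if f1:  # assumed: rows where neither expression is finite are skipped
            distances.append(wasserstein_1d(f1, f2))
    return sum(distances) / len(distances) if distances else 0.0


print(bed_sketch([[0.0, 1.0]], [[1.0, 2.0]]))  # 1.0
```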

Examples:

>>> X = np.random.rand(10, 2) - 0.5
>>> expr1 = ["X_0", "+", "C"] # instances of SRToolkit.utils.expression_tree.Node work as well
>>> expr2 = ["X_1", "+", "C"]
>>> bed(expr1, expr2, X) < 1
True
>>> # Changing the number of sampled constants
>>> bed(expr1, expr2, X, num_consts_sampled=128, consts_bounds=(-2, 2)) < 1
True
>>> # Sampling X instead of giving it directly by defining a domain
>>> bed(expr1, expr2, domain_bounds=[(0, 1), (0, 1)]) < 1
True
>>> bed(expr1, expr2, domain_bounds=[(0, 1), (0, 1)], num_points_sampled=128) < 1
True
>>> # Behavior matrices can be passed instead of expressions (this avoids re-evaluating an expression that is compared multiple times)
>>> bm1 = create_behavior_matrix(expr1, X)
>>> bed(bm1, expr2, X) < 1
True
>>> bm2 = create_behavior_matrix(expr2, X)
>>> bed(bm1, bm2) < 1
True

Parameters:

Name Type Description Default
expr1 Union[Node, List[str], ndarray]

First expression as a token list, a Node tree, or a pre-computed behavior matrix of shape (n_samples, num_consts_sampled).

required
expr2 Union[Node, List[str], ndarray]

Second expression in the same format as expr1.

required
X Optional[ndarray]

Evaluation points of shape (n_samples, n_features). Required unless both expressions are behavior matrices, or neither is and domain_bounds is provided.

None
num_consts_sampled int

Number of constant vectors sampled per expression. Default 32.

32
num_points_sampled int

Number of points sampled from domain_bounds when X is None. Default 64.

64
domain_bounds Optional[List[Tuple[float, float]]]

Per-variable (lower, upper) bounds used to sample X via Latin Hypercube Sampling when X is None.

None
consts_bounds Tuple[float, float]

(lower, upper) bounds for constant sampling. Default (-5, 5).

(-5, 5)
symbol_library Optional[SymbolLibrary]

Symbol library used to compile expressions. Defaults to SymbolLibrary.default_symbols.

None
seed Optional[int]

Random seed for reproducible sampling. Default None.

None

Returns:

Type Description
float

BED between the expressions as a non-negative float. A value of 0.0 indicates identical behavior over the sampled domain; larger values indicate greater behavioral dissimilarity. Returns inf if, at any evaluation point, one expression produces at least one finite output while the other produces none.

Raises:

Type Description
Exception

If X and domain_bounds are both None and neither expression is a pre-computed behavior matrix.

Exception

If X is None and exactly one expression is a pre-computed behavior matrix (the two matrices would be over different domains).

ValueError

If any entry in domain_bounds has a lower bound greater than or equal to its upper bound.

ValueError

If the two behavior matrices have different numbers of rows.

ValueError

If the behavior matrices have zero rows.

Source code in SRToolkit/utils/measures.py
def bed(
    expr1: Union[Node, List[str], np.ndarray],
    expr2: Union[Node, List[str], np.ndarray],
    X: Optional[np.ndarray] = None,
    num_consts_sampled: int = 32,
    num_points_sampled: int = 64,
    domain_bounds: Optional[List[Tuple[float, float]]] = None,
    consts_bounds: Tuple[float, float] = (-5, 5),
    symbol_library: Optional[SymbolLibrary] = None,
    seed: Optional[int] = None,
) -> float:
    """
    Compute the Behavior-aware Expression Distance (BED) between two expressions.

    BED measures how similarly two expressions behave over a domain by comparing
    their output distributions point-by-point using the Wasserstein distance. Free
    constants are marginalised by sampling multiple constant vectors via Latin
    Hypercube Sampling.

    Either ``X`` or ``domain_bounds`` must be provided when expressions are given as
    token lists or [Node][SRToolkit.utils.expression_tree.Node] trees. Pre-computed behavior matrices can be passed
    directly to avoid redundant evaluation.

    Examples:
        >>> X = np.random.rand(10, 2) - 0.5
        >>> expr1 = ["X_0", "+", "C"] # instances of SRToolkit.utils.expression_tree.Node work as well
        >>> expr2 = ["X_1", "+", "C"]
        >>> bed(expr1, expr2, X) < 1
        True
        >>> # Changing the number of sampled constants
        >>> bed(expr1, expr2, X, num_consts_sampled=128, consts_bounds=(-2, 2)) < 1
        True
        >>> # Sampling X instead of giving it directly by defining a domain
        >>> bed(expr1, expr2, domain_bounds=[(0, 1), (0, 1)]) < 1
        True
        >>> bed(expr1, expr2, domain_bounds=[(0, 1), (0, 1)], num_points_sampled=128) < 1
        True
        >>> # Behavior matrices can be passed instead of expressions (this avoids re-evaluating an expression that is compared multiple times)
        >>> bm1 = create_behavior_matrix(expr1, X)
        >>> bed(bm1, expr2, X) < 1
        True
        >>> bm2 = create_behavior_matrix(expr2, X)
        >>> bed(bm1, bm2) < 1
        True

    Args:
        expr1: First expression as a token list, a [Node][SRToolkit.utils.expression_tree.Node] tree, or a pre-computed
            behavior matrix of shape ``(n_samples, num_consts_sampled)``.
        expr2: Second expression in the same format as ``expr1``.
        X: Evaluation points of shape ``(n_samples, n_features)``. Required unless both
            expressions are behavior matrices or ``domain_bounds`` is provided.
        num_consts_sampled: Number of constant vectors sampled per expression. Default ``32``.
        num_points_sampled: Number of points sampled from ``domain_bounds`` when ``X`` is
            ``None``. Default ``64``.
        domain_bounds: Per-variable ``(lower, upper)`` bounds used to sample ``X`` via
            Latin Hypercube Sampling when ``X`` is ``None``.
        consts_bounds: ``(lower, upper)`` bounds for constant sampling. Default ``(-5, 5)``.
        symbol_library: Symbol library used to compile expressions. Defaults to
            [SymbolLibrary.default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols].
        seed: Random seed for reproducible sampling. Default ``None``.

    Returns:
        BED between the expressions as a non-negative float. A value of ``0.0``
        indicates identical behavior over the sampled domain; larger values indicate
        greater behavioral dissimilarity. Returns ``inf`` if any evaluation point
        produces finite outputs for one expression but not the other.

    Raises:
        Exception: If ``X`` is ``None`` and neither ``domain_bounds`` is provided nor
            both expressions are pre-computed behavior matrices.
        Exception: If ``X`` is ``None`` and exactly one expression is a pre-computed
            behavior matrix (the two matrices would be over different domains).
        ValueError: If any entry in ``domain_bounds`` has a lower bound greater than
            or equal to its upper bound.
        ValueError: If the two behavior matrices have different numbers of rows.
        ValueError: If the behavior matrices have zero rows.
    """

    if symbol_library is None:
        symbol_library = SymbolLibrary.get_or_default()
    if X is None and not isinstance(expr1, np.ndarray) and not isinstance(expr2, np.ndarray):
        if domain_bounds is None:
            raise Exception(
                "If X is not given and both expressions are not given as a behavior matrix, "
                "then domain_bounds parameter must be given"
            )
        for i, (lb, ub) in enumerate(domain_bounds):
            if lb >= ub:
                raise ValueError(f"domain_bounds[{i}] has lower bound ({lb}) >= upper bound ({ub}).")
        interval_length = np.array([ub - lb for (lb, ub) in domain_bounds])
        lower_bound = np.array([lb for (lb, ub) in domain_bounds])
        lho = LatinHypercube(len(domain_bounds), optimization="random-cd", rng=seed)
        X = lho.random(num_points_sampled) * interval_length + lower_bound
    elif X is None and (isinstance(expr1, np.ndarray) != isinstance(expr2, np.ndarray)):
        raise Exception(
            "If X is not given, both expressions must be given as a behavior matrix or as a list of "
            "tokens/SRToolkit.utils.Node objects. Otherwise, behavior matrices are incomparable."
        )

    if isinstance(expr1, list) or isinstance(expr1, Node):
        assert X is not None
        expr1 = create_behavior_matrix(expr1, X, num_consts_sampled, consts_bounds, symbol_library, seed)

    if isinstance(expr2, list) or isinstance(expr2, Node):
        assert X is not None
        expr2 = create_behavior_matrix(expr2, X, num_consts_sampled, consts_bounds, symbol_library, seed)

    if expr1.shape[0] != expr2.shape[0]:
        raise ValueError("Behavior matrices must have the same number of rows (points on which behavior is evaluated).")
    if expr1.shape[0] == 0:
        raise ValueError(
            "Behavior matrices must have at least one row. If your expressions are given as behavior "
            "matrices, make sure they are not empty. Otherwise, if X is given, make sure it contains "
            "at least one point. If X is not given, make sure num_points_sampled is greater than 0."
        )
    n = expr1.shape[0]

    finite_any1 = np.any(np.isfinite(expr1), axis=1)
    finite_any2 = np.any(np.isfinite(expr2), axis=1)

    if np.any(finite_any1 ^ finite_any2):
        return np.inf

    active = finite_any1 & finite_any2
    e1, e2 = expr1[active], expr2[active]

    if e1.shape[0] == 0:
        return 0.0

    if np.any(np.isinf(e1)) or np.any(np.isinf(e2)):
        return np.inf

    has_nan = ~(np.all(np.isfinite(e1), axis=1) & np.all(np.isfinite(e2), axis=1))
    wds = np.empty(e1.shape[0])

    if np.any(~has_nan):
        wds[~has_nan] = _vectorized_wasserstein_batch(e1[~has_nan], e2[~has_nan])

    for i in np.where(has_nan)[0]:
        wds[i] = _custom_wasserstein(e1[i][~np.isnan(e1[i])], e2[i][~np.isnan(e2[i])])

    return float(np.sum(wds) / n)
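The aggregation step at the end of bed can be sketched in plain Python to show its rules in isolation. This is a minimal sketch, not SRToolkit code: wasserstein_1d and bed_sketch are hypothetical helpers, behavior matrices are plain nested lists, and the per-point distance is assumed to be the standard 1D Wasserstein-1 distance between the two empirical output distributions (the internal _vectorized_wasserstein_batch and _custom_wasserstein helpers are not shown on this page).

```python
import math
from typing import List, Sequence


def wasserstein_1d(a: Sequence[float], b: Sequence[float]) -> float:
    """Hypothetical helper: 1D Wasserstein-1 distance between two empirical samples.

    Integrates |F_a(t) - F_b(t)| over t, where F_a and F_b are the empirical
    CDFs, so the two samples may have different sizes (e.g. after dropping NaNs).
    """
    a_sorted, b_sorted = sorted(a), sorted(b)
    breakpoints = sorted(set(a_sorted) | set(b_sorted))
    total = 0.0
    i = j = 0
    for left, right in zip(breakpoints, breakpoints[1:]):
        # Both CDFs are constant on [left, right): count the samples <= left.
        while i < len(a_sorted) and a_sorted[i] <= left:
            i += 1
        while j < len(b_sorted) and b_sorted[j] <= left:
            j += 1
        total += abs(i / len(a_sorted) - j / len(b_sorted)) * (right - left)
    return total


def bed_sketch(rows1: List[List[float]], rows2: List[List[float]]) -> float:
    """Hypothetical helper mirroring the aggregation step of bed().

    rows1[k] and rows2[k] are the outputs of each expression at point k,
    one value per sampled constant vector.
    """
    if len(rows1) != len(rows2) or not rows1:
        raise ValueError("Behavior matrices must have the same, non-zero number of rows.")
    total = 0.0
    for r1, r2 in zip(rows1, rows2):
        finite1 = [v for v in r1 if math.isfinite(v)]
        finite2 = [v for v in r2 if math.isfinite(v)]
        if bool(finite1) != bool(finite2):
            return math.inf  # finite somewhere for one expression only
        if not finite1:
            continue  # no finite output from either expression: this point adds 0
        if any(math.isinf(v) for v in r1 + r2):
            return math.inf  # an infinite output at an otherwise comparable point
        # The remaining non-finite values are NaNs, which are simply dropped.
        total += wasserstein_1d(finite1, finite2)
    # Divide by all points, including the skipped ones.
    return total / len(rows1)


print(bed_sketch([[0.0], [1.0]], [[1.0], [1.0]]))  # 0.5
```

The sketch applies both infinity rules point by point instead of in two vectorized passes, but any point that triggers either rule yields inf in both versions.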

create_behavior_matrix

create_behavior_matrix(expr: Union[Node, List[str]], X: ndarray, num_consts_sampled: int = 32, consts_bounds: Tuple[float, float] = (-5, 5), symbol_library: Optional[SymbolLibrary] = None, seed: Optional[int] = None) -> np.ndarray

Evaluate an expression over multiple constant samples to produce a behavior matrix.

For expressions with free constants, constant vectors are drawn via Latin Hypercube Sampling within consts_bounds, and each column holds the expression's outputs for one such vector. For constant-free expressions, all columns are identical.
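The sampling scheme described above can be sketched without NumPy or SciPy. In this minimal sketch, latin_hypercube and behavior_matrix_sketch are hypothetical helpers, the compiled expression is stood in for by an ordinary Python function f(x_row, constants), and the number of free constants is passed in explicitly instead of being counted from the symbol library. It uses basic Latin Hypercube Sampling (one uniform draw per stratum in every dimension), whereas SRToolkit calls SciPy's LatinHypercube, so the sampled values will differ.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple


def latin_hypercube(n_samples: int, n_dims: int, rng: random.Random) -> List[List[float]]:
    """Hypothetical helper: basic Latin Hypercube sample in [0, 1)^n_dims."""
    columns = []
    for _ in range(n_dims):
        strata = list(range(n_samples))
        rng.shuffle(strata)
        # Stratum k covers [k / n, (k + 1) / n); draw uniformly inside it.
        columns.append([(k + rng.random()) / n_samples for k in strata])
    return [list(point) for point in zip(*columns)]


def behavior_matrix_sketch(
    f: Callable[[Sequence[float], Sequence[float]], float],
    X: Sequence[Sequence[float]],
    num_constants: int,
    num_consts_sampled: int = 32,
    consts_bounds: Tuple[float, float] = (-5, 5),
    seed: Optional[int] = None,
) -> List[List[float]]:
    """Hypothetical helper: row i holds f(X[i], c_j) for every sampled constant vector c_j."""
    lo, hi = consts_bounds
    if num_constants == 0:
        # Constant-free expression: every column is the same.
        return [[f(x, [])] * num_consts_sampled for x in X]
    rng = random.Random(seed)
    # Rescale the unit-cube sample to the requested constant bounds.
    constants = [
        [lo + u * (hi - lo) for u in point]
        for point in latin_hypercube(num_consts_sampled, num_constants, rng)
    ]
    return [[f(x, c) for c in constants] for x in X]


X = [[0.1, 0.2], [-0.3, 0.4]]
bm = behavior_matrix_sketch(lambda x, c: x[0] + c[0], X, num_constants=1, num_consts_sampled=4, seed=0)
print(len(bm), len(bm[0]))  # 2 4
```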

Examples:

>>> X = np.random.rand(10, 2) - 0.5
>>> create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32).shape
(10, 32)
>>> mean_0_1 = np.mean(create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32, consts_bounds=(0, 1)))
>>> mean_1_5 = np.mean(create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32, consts_bounds=(1, 5)))
>>> print(bool(mean_0_1 < mean_1_5))
True
>>> # Deterministic expressions always produce the same behavior matrix
>>> bm1 = create_behavior_matrix(["X_0", "+", "X_1"], X)
>>> bm2 = create_behavior_matrix(["X_0", "+", "X_1"], X)
>>> print(bool(np.array_equal(bm1, bm2)))
True

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| expr | Union[Node, List[str]] | Expression as a token list or a Node tree. | required |
| X | ndarray | Input data of shape (n_samples, n_features) at which the expression is evaluated. | required |
| num_consts_sampled | int | Number of constant vectors to sample; sets the number of output columns. Default 32. | 32 |
| consts_bounds | Tuple[float, float] | (lower, upper) bounds for constant sampling. Default (-5, 5). | (-5, 5) |
| symbol_library | Optional[SymbolLibrary] | Symbol library used to compile the expression. Defaults to SymbolLibrary.default_symbols. | None |
| seed | Optional[int] | Random seed for reproducible constant sampling. Default None. | None |

Returns:

| Type | Description |
| --- | --- |
| ndarray | Behavior matrix of shape (n_samples, num_consts_sampled). |

Raises:

| Type | Description |
| --- | --- |
| TypeError | If expr is neither a token list nor a Node. |

Source code in SRToolkit/utils/measures.py
def create_behavior_matrix(
    expr: Union[Node, List[str]],
    X: np.ndarray,
    num_consts_sampled: int = 32,
    consts_bounds: Tuple[float, float] = (-5, 5),
    symbol_library: Optional[SymbolLibrary] = None,
    seed: Optional[int] = None,
) -> np.ndarray:
    """
    Evaluate an expression over multiple constant samples to produce a behavior matrix.

    For expressions with free constants, constants are drawn via Latin Hypercube Sampling
    within ``consts_bounds``. For constant-free expressions, all columns are identical.

    Examples:
        >>> X = np.random.rand(10, 2) - 0.5
        >>> create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32).shape
        (10, 32)
        >>> mean_0_1 = np.mean(create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32, consts_bounds=(0, 1)))
        >>> mean_1_5 = np.mean(create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32, consts_bounds=(1, 5)))
        >>> print(bool(mean_0_1 < mean_1_5))
        True
        >>> # Deterministic expressions always produce the same behavior matrix
        >>> bm1 = create_behavior_matrix(["X_0", "+", "X_1"], X)
        >>> bm2 = create_behavior_matrix(["X_0", "+", "X_1"], X)
        >>> print(bool(np.array_equal(bm1, bm2)))
        True

    Args:
        expr: Expression as a token list or a [Node][SRToolkit.utils.expression_tree.Node] tree.
        X: Input data of shape ``(n_samples, n_features)`` at which the expression is
            evaluated.
        num_consts_sampled: Number of constant vectors to sample; sets the number of
            output columns. Default ``32``.
        consts_bounds: ``(lower, upper)`` bounds for constant sampling. Default ``(-5, 5)``.
        symbol_library: Symbol library used to compile the expression. Defaults to
            [SymbolLibrary.default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols].
        seed: Random seed for reproducible constant sampling. Default ``None``.

    Returns:
        Behavior matrix of shape ``(n_samples, num_consts_sampled)``.

    Raises:
        TypeError: If ``expr`` is neither a token list nor a [Node][SRToolkit.utils.expression_tree.Node].
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_or_default()
    const_symbols = symbol_library.get_symbols_of_type("const")

    if isinstance(expr, list):
        tokens = expr
    elif isinstance(expr, Node):
        tokens = expr.to_list(notation="postfix")
    else:
        raise TypeError("Expression must be given as a list of strings or a Node tree.")

    num_constants = sum(tokens.count(c) for c in const_symbols)

    callable_expr = compile_expr(expr, symbol_library, backend="stack")

    with np.errstate(divide="ignore", invalid="ignore", over="ignore", under="ignore"):
        if num_constants > 0:
            lho = LatinHypercube(num_constants, rng=seed)
            constants = lho.random(num_consts_sampled) * (consts_bounds[1] - consts_bounds[0]) + consts_bounds[0]
            ys = []
            for c in constants:
                ys.append(callable_expr(X, c))
            return np.array(ys).T
        else:
            return np.repeat(callable_expr(X, None)[:, None], num_consts_sampled, axis=1)

edit_distance

edit_distance(expr1: Union[List[str], Node], expr2: Union[List[str], Node], notation: str = 'postfix', symbol_library: Optional[SymbolLibrary] = None) -> int

Compute the edit distance between two expressions.

Both expressions are first converted to the requested notation, so the result is independent of whether a token list or a Node tree is passed. The Levenshtein distance is then computed on the serialised token sequences, so each inserted, deleted, or substituted token costs 1.
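The token-level Levenshtein computation can be sketched without the editdistance package. In this minimal sketch token_levenshtein is a hypothetical helper, and the token lists are written out by hand rather than produced by Node.to_list; the infix lists assume that parentheses appear as separate tokens, which illustrates why the postfix default avoids parenthesis artefacts.

```python
from typing import Hashable, Sequence


def token_levenshtein(a: Sequence[Hashable], b: Sequence[Hashable]) -> int:
    """Hypothetical helper: Levenshtein distance where each whole token is one symbol."""
    # prev[j] holds the distance between the prefix of `a` seen so far and b[:j].
    prev = list(range(len(b) + 1))
    for i, tok_a in enumerate(a, start=1):
        curr = [i]  # turning a[:i] into an empty sequence takes i deletions
        for j, tok_b in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,  # delete tok_a
                curr[j - 1] + 1,  # insert tok_b
                prev[j - 1] + (tok_a != tok_b),  # substitute, free when tokens match
            ))
        prev = curr
    return prev[-1]


# (X_0 + 1) * X_1 versus X_0 * X_1, written out by hand in two notations.
infix_a = ["(", "X_0", "+", "1", ")", "*", "X_1"]
infix_b = ["X_0", "*", "X_1"]
postfix_a = ["X_0", "1", "+", "X_1", "*"]
postfix_b = ["X_0", "X_1", "*"]
print(token_levenshtein(infix_a, infix_b))  # 4, two of which are bracket deletions
print(token_levenshtein(postfix_a, postfix_b))  # 2
```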

Examples:

>>> edit_distance(["X_0", "+", "1"], ["X_0", "+", "1"])
0
>>> edit_distance(["X_0", "+", "1"], ["X_0", "-", "1"])
1
>>> edit_distance(tokens_to_tree(["X_0", "+", "1"], SymbolLibrary.default_symbols(1)), tokens_to_tree(["X_0", "-", "1"], SymbolLibrary.default_symbols(1)))
1

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| expr1 | Union[List[str], Node] | First expression as a token list or a Node tree. | required |
| expr2 | Union[List[str], Node] | Second expression as a token list or a Node tree. | required |
| notation | str | Notation used for comparison: "infix", "prefix", or "postfix". Defaults to "postfix" to avoid parenthesis artefacts. | 'postfix' |
| symbol_library | Optional[SymbolLibrary] | Symbol library used when converting expressions to the target notation. Defaults to SymbolLibrary.default_symbols. | None |

Returns:

| Type | Description |
| --- | --- |
| int | Integer edit distance between the two serialised expressions. |

Source code in SRToolkit/utils/measures.py
def edit_distance(
    expr1: Union[List[str], Node],
    expr2: Union[List[str], Node],
    notation: str = "postfix",
    symbol_library: Optional[SymbolLibrary] = None,
) -> int:
    """
    Compute the edit distance between two expressions.

    Both expressions are first converted to the requested notation, so the result
    is independent of whether a token list or a [Node][SRToolkit.utils.expression_tree.Node] tree is passed.
    Levenshtein distance is then computed on the serialised token sequences.

    Examples:
        >>> edit_distance(["X_0", "+", "1"], ["X_0", "+", "1"])
        0
        >>> edit_distance(["X_0", "+", "1"], ["X_0", "-", "1"])
        1
        >>> edit_distance(tokens_to_tree(["X_0", "+", "1"], SymbolLibrary.default_symbols(1)), tokens_to_tree(["X_0", "-", "1"], SymbolLibrary.default_symbols(1)))
        1

    Args:
        expr1: First expression as a token list or a [Node][SRToolkit.utils.expression_tree.Node] tree.
        expr2: Second expression as a token list or a [Node][SRToolkit.utils.expression_tree.Node] tree.
        notation: Notation used for comparison: ``"infix"``, ``"prefix"``, or
            ``"postfix"``. Defaults to ``"postfix"`` to avoid parenthesis artefacts.
        symbol_library: Symbol library used when converting expressions to the target
            notation. Defaults to [SymbolLibrary.default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols].

    Returns:
        Integer edit distance between the two serialised expressions.
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_or_default()
    if isinstance(expr1, Node):
        expr1 = expr1.to_list(symbol_library=symbol_library, notation=notation)
    elif isinstance(expr1, list):
        expr1 = tokens_to_tree(expr1, symbol_library).to_list(symbol_library=symbol_library, notation=notation)

    if isinstance(expr2, Node):
        expr2 = expr2.to_list(symbol_library=symbol_library, notation=notation)
    elif isinstance(expr2, list):
        expr2 = tokens_to_tree(expr2, symbol_library).to_list(symbol_library=symbol_library, notation=notation)

    return editdistance.eval(expr1, expr2)

tree_edit_distance

tree_edit_distance(expr1: Union[Node, List[str]], expr2: Union[Node, List[str]], symbol_library: Optional[SymbolLibrary] = None) -> int

Compute the Zhang-Shasha tree edit distance between two expressions.

Unlike edit_distance, which operates on flattened token sequences, tree edit distance considers the expression's hierarchical structure. The cost is the minimum number of node insertions, deletions, and relabellings needed to transform one tree into the other.
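The Zhang-Shasha computation can be illustrated with a compact, dependency-free sketch. It is not SRToolkit's code path, which converts the trees with the internal _expr_to_zss helper and calls zss.simple_distance. Here zhang_shasha, _postorder, _keyroots and the (label, children) tuple encoding are hypothetical, every insertion, deletion, and relabelling is assumed to cost 1, and children are compared in left-to-right order.

```python
from typing import Dict, List, Tuple

# Hypothetical tree encoding for this sketch: (label, [child, ...]), children left to right.
Tree = Tuple[str, List["Tree"]]


def _postorder(tree: Tree) -> Tuple[List[str], List[int]]:
    """Postorder labels plus, for each node, the index of its leftmost leaf descendant."""
    labels: List[str] = []
    leftmost: List[int] = []

    def visit(node: Tree) -> int:
        label, children = node
        first_leaf = None
        for child in children:
            child_leftmost = visit(child)
            if first_leaf is None:
                first_leaf = child_leftmost
        labels.append(label)
        index = len(labels) - 1
        leftmost.append(index if first_leaf is None else first_leaf)
        return leftmost[index]

    visit(tree)
    return labels, leftmost


def _keyroots(leftmost: List[int]) -> List[int]:
    """Keyroots: the highest-numbered node for each distinct leftmost-leaf index."""
    last: Dict[int, int] = {}
    for index, leaf in enumerate(leftmost):
        last[leaf] = index
    return sorted(last.values())


def zhang_shasha(t1: Tree, t2: Tree) -> int:
    """Unit-cost tree edit distance (insert, delete, relabel) via Zhang-Shasha."""
    lab1, l1 = _postorder(t1)
    lab2, l2 = _postorder(t2)
    tree_dist = [[0] * len(lab2) for _ in lab1]

    for i in _keyroots(l1):
        for j in _keyroots(l2):
            li, lj = l1[i], l2[j]
            # fd[x][y]: distance between the forest of nodes li..li+x-1 of t1 and
            # lj..lj+y-1 of t2; row/column 0 stands for the empty forest.
            fd = [[0] * (j - lj + 2) for _ in range(i - li + 2)]
            for x in range(1, i - li + 2):
                fd[x][0] = fd[x - 1][0] + 1  # delete every node
            for y in range(1, j - lj + 2):
                fd[0][y] = fd[0][y - 1] + 1  # insert every node
            for x in range(1, i - li + 2):
                for y in range(1, j - lj + 2):
                    a, b = li + x - 1, lj + y - 1
                    if l1[a] == li and l2[b] == lj:
                        # Both forests are whole subtrees rooted at a and b.
                        fd[x][y] = min(
                            fd[x - 1][y] + 1,
                            fd[x][y - 1] + 1,
                            fd[x - 1][y - 1] + (lab1[a] != lab2[b]),
                        )
                        tree_dist[a][b] = fd[x][y]
                    else:
                        # Reuse the already-computed distance between subtrees a and b.
                        fd[x][y] = min(
                            fd[x - 1][y] + 1,
                            fd[x][y - 1] + 1,
                            fd[l1[a] - li][l2[b] - lj] + tree_dist[a][b],
                        )
    return tree_dist[-1][-1]


plus = ("+", [("X_0", []), ("1", [])])
minus = ("-", [("X_0", []), ("1", [])])
print(zhang_shasha(plus, minus))  # 1
```

Keyroots are processed in increasing order so that every subtree distance reused in the else branch has already been filled in.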

Examples:

>>> tree_edit_distance(["X_0", "+", "1"], ["X_0", "+", "1"])
0
>>> tree_edit_distance(["X_0", "+", "1"], ["X_0", "-", "1"])
1
>>> tree_edit_distance(tokens_to_tree(["X_0", "+", "1"], SymbolLibrary.default_symbols(1)), tokens_to_tree(["X_0", "-", "1"], SymbolLibrary.default_symbols(1)))
1

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| expr1 | Union[Node, List[str]] | First expression as a token list or a Node tree. | required |
| expr2 | Union[Node, List[str]] | Second expression as a token list or a Node tree. | required |
| symbol_library | Optional[SymbolLibrary] | Symbol library used when converting token lists to trees. Defaults to SymbolLibrary.default_symbols. | None |

Returns:

| Type | Description |
| --- | --- |
| int | Integer tree edit distance. |

Source code in SRToolkit/utils/measures.py
def tree_edit_distance(
    expr1: Union[Node, List[str]],
    expr2: Union[Node, List[str]],
    symbol_library: Optional[SymbolLibrary] = None,
) -> int:
    """
    Compute the Zhang-Shasha tree edit distance between two expressions.

    Unlike [edit_distance][SRToolkit.utils.measures.edit_distance], which operates on flattened
    token sequences, tree edit distance considers the expression's hierarchical structure.
    The cost is the minimum number of node insertions, deletions, and relabellings needed
    to transform one tree into the other.

    Examples:
        >>> tree_edit_distance(["X_0", "+", "1"], ["X_0", "+", "1"])
        0
        >>> tree_edit_distance(["X_0", "+", "1"], ["X_0", "-", "1"])
        1
        >>> tree_edit_distance(tokens_to_tree(["X_0", "+", "1"], SymbolLibrary.default_symbols(1)), tokens_to_tree(["X_0", "-", "1"], SymbolLibrary.default_symbols(1)))
        1

    Args:
        expr1: First expression as a token list or a [Node][SRToolkit.utils.expression_tree.Node] tree.
        expr2: Second expression as a token list or a [Node][SRToolkit.utils.expression_tree.Node] tree.
        symbol_library: Symbol library used when converting token lists to trees.
            Defaults to [SymbolLibrary.default_symbols][SRToolkit.utils.symbol_library.SymbolLibrary.default_symbols].

    Returns:
        Integer tree edit distance.
    """
    if symbol_library is None:
        symbol_library = SymbolLibrary.get_or_default()
    if isinstance(expr1, Node):
        zss1 = _expr_to_zss(expr1)
    elif isinstance(expr1, list):
        zss1 = _expr_to_zss(tokens_to_tree(expr1, symbol_library))

    if isinstance(expr2, Node):
        zss2 = _expr_to_zss(expr2)
    elif isinstance(expr2, list):
        zss2 = _expr_to_zss(tokens_to_tree(expr2, symbol_library))

    return int(zss.simple_distance(zss1, zss2))