Utils Submodule
SRToolkit.utils
Utilities for expression representation, compilation, generation, and evaluation.
Modules:
| Name | Description |
|---|---|
| symbol_library | The SymbolLibrary class, which manages the token vocabulary and token properties. |
| expression_tree | The Node binary-tree representation and conversion utilities for expressions. |
| expression_compiler | Compiles token-list or tree expressions into executable Python callables. |
| expression_simplifier | SymPy-backed algebraic simplification, including constant folding. |
| expression_generator | PCFG construction from a SymbolLibrary and Monte-Carlo expression sampling. |
| grammar | CFG/PCFG representation, constraint protocol, and stateful derivation: Rule, Grammar, Constraint, Derivation, and more. |
| measures | Distance and similarity measures: edit distance, tree edit distance, and Behavior-aware Expression Distance (BED). |
| serialization | Internal JSON serialization utilities for numpy types. |
Node
A node in a binary expression tree.
- Binary operators ("op") set both left and right.
- Unary functions ("fn") set only left; right is None.
- Leaves (variables, constants, literals, numeric values) have both children as None.
Examples:
Warning
The second positional argument is right, not left. When passing
children positionally (e.g. Node("+", Node("a"), Node("b"))),
Node("a") becomes the right child and Node("b") the left.
Use keyword arguments to avoid confusion: Node("+", right=Node("a"), left=Node("b")).
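The argument-order pitfall described in the warning can be reproduced with a small stand-in class. This is only a sketch: SketchNode is a hypothetical class that mirrors the documented (symbol, right, left) parameter order and is not the library's Node.

```python
from typing import Optional


class SketchNode:
    """Hypothetical stand-in mirroring the documented (symbol, right, left) order."""

    def __init__(self, symbol: str, right: Optional["SketchNode"] = None,
                 left: Optional["SketchNode"] = None) -> None:
        self.symbol = symbol
        self.right = right
        self.left = left


# Positional children: the first child argument lands in `right`.
positional = SketchNode("-", SketchNode("a"), SketchNode("b"))

# Keyword children: the intent ("b - a") is explicit.
keyword = SketchNode("-", left=SketchNode("b"), right=SketchNode("a"))

print(positional.left.symbol, positional.right.symbol)  # b a
print(keyword.left.symbol, keyword.right.symbol)        # b a
```

For a non-commutative operator such as subtraction, the positional form silently encodes b - a rather than a - b, which is why keyword arguments are the safer habit.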
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| symbol | str | Token string stored at this node. | required |
| right | Optional[Node] | Right operand (binary operators only). | None |
| left | Optional[Node] | Left operand (operators and unary functions). | None |
Source code in SRToolkit/utils/expression_tree.py
to_list
Transforms the tree rooted at this node into a list of tokens.
Examples:
>>> node = Node("+", Node("X_0"), Node("1"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['1', '+', 'X_0']
>>> node.to_list(notation="postfix")
['1', 'X_0', '+']
>>> node.to_list(notation="prefix")
['+', '1', 'X_0']
>>> node = Node("+", Node("*", Node("X_0"), Node("X_1")), Node("1"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['1', '+', 'X_1', '*', 'X_0']
>>> node.to_list(notation="infix")
['1', '+', '(', 'X_1', '*', 'X_0', ')']
>>> node = Node("sin", None, Node("X_0"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['sin', '(', 'X_0', ')']
>>> node = Node("^2", None, Node("X_0"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['X_0', '^2']
>>> node.to_list()
['(', 'X_0', ')', '^2']
>>> node = Node("*", Node("*", Node("X_0"), Node("X_0")), Node("X_0"))
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols(),notation="infix")
['X_0', '*', '(', 'X_0', '*', 'X_0', ')']
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| symbol_library | Optional[SymbolLibrary] | Symbol library used to determine token types and precedences during infix reconstruction. If | None |
| notation | str | Output notation: | 'infix' |
Returns:
| Type | Description |
|---|---|
| List[str] | Token list representing the subtree rooted at this node. |
Raises:
| Type | Description |
|---|---|
| Exception | If |
| Exception | If |
Source code in SRToolkit/utils/expression_tree.py
to_latex
Transforms the tree rooted at this node into a LaTeX expression.
Examples:
>>> node = Node("+", right=Node("X_0"), left=Node("1"))
>>> node.to_latex(symbol_library=SymbolLibrary.default_symbols())
'$1 + X_{0}$'
>>> node = Node("+", right=Node("*", right=Node("X_0"), left=Node("X_1")), left=Node("1"))
>>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
$1 + X_{1} \cdot X_{0}$
>>> node = Node("sin", None, Node("X_0"))
>>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
$\sin X_{0}$
>>> node = Node("+", right=Node("*", right=Node("X_0"), left=Node("C")), left=Node("C"))
>>> print(node.to_latex(symbol_library=SymbolLibrary.default_symbols()))
$C_{0} + C_{1} \cdot X_{0}$
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| symbol_library | Optional[SymbolLibrary] | Symbol library providing LaTeX templates for each token. If None, falls back to the currently active library set via 'with SymbolLibrary(...) as sl:'. Defaults to None. | None |
Returns:
| Type | Description |
|---|---|
| str | A LaTeX string of the form |
Raises:
| Type | Description |
|---|---|
| Exception | If the tree contains a token whose type cannot be resolved in |
Source code in SRToolkit/utils/expression_tree.py
height
Return the height of the subtree rooted at this node.
A single-node tree has height 1.
Examples:
Returns:
| Type | Description |
|---|---|
| int | Height of the subtree. |
Source code in SRToolkit/utils/expression_tree.py
__len__
Return the number of nodes in the subtree rooted at this node.
Examples:
Returns:
| Type | Description |
|---|---|
| int | Total node count of the subtree. |
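The two size measures, height (a single node counts as 1) and node count, can be sketched recursively. SketchNode, height, and node_count below are hypothetical names used only for illustration; the library's own implementation may differ.

```python
from typing import Optional


class SketchNode:
    """Hypothetical stand-in for a binary expression-tree node."""

    def __init__(self, symbol: str, right: Optional["SketchNode"] = None,
                 left: Optional["SketchNode"] = None) -> None:
        self.symbol = symbol
        self.right = right
        self.left = left


def height(node: Optional[SketchNode]) -> int:
    """A missing child contributes 0, so a lone node has height 1."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def node_count(node: Optional[SketchNode]) -> int:
    """Count this node plus every node in both subtrees."""
    if node is None:
        return 0
    return 1 + node_count(node.left) + node_count(node.right)


# 1 + X_1 * X_0, built with keyword arguments
tree = SketchNode(
    "+",
    left=SketchNode("1"),
    right=SketchNode("*", left=SketchNode("X_1"), right=SketchNode("X_0")),
)
print(height(SketchNode("X_0")))  # 1
print(height(tree))               # 3
print(node_count(tree))           # 5
```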
Source code in SRToolkit/utils/expression_tree.py
__str__
Return the expression as a concatenated string using default infix notation that may contain redundant parentheses.
Examples:
Returns:
| Type | Description |
|---|---|
| str | Concatenated token string with no spaces. |
Source code in SRToolkit/utils/expression_tree.py
__copy__
Return a deep copy of the subtree rooted at this node.
Examples:
>>> node = Node("+", Node("X_0"), Node("1"))
>>> new_node = copy(node)
>>> node.to_list(symbol_library=SymbolLibrary.default_symbols())
['1', '+', 'X_0']
>>> new_node.to_list(symbol_library=SymbolLibrary.default_symbols())
['1', '+', 'X_0']
>>> node == node
True
>>> node == new_node
False
Returns:
| Type | Description |
|---|---|
| Node | An independent copy of the subtree. |
Source code in SRToolkit/utils/expression_tree.py
AncestorInfo
dataclass
One entry in an ExpansionContext's ancestor chain.
Attributes:
| Name | Type | Description |
|---|---|---|
| nonterminal | str | The non-terminal symbol at this ancestor position. |
| rule | Rule | The rule applied at this position. |
| child_index | int | The index within |
Constraint
Bases: Generic[L, G]
Base class for derivation constraints (hard filters).
Subclass and override the methods you need. Constraints carry only construction-time configuration; all per-derivation state is managed by the engine and threaded through method arguments.
Scoping: set nonterminals and/or rule_names to restrict allows to a subset of slots or rules. A scope miss is treated as implicit acceptance. update is always called regardless of scope so that global counters stay accurate.
Examples:
>>> from SRToolkit.utils.grammar import Constraint, Grammar, Rule
>>> g = Grammar([Rule("E", ["x"]), Rule("E", ["y"])])
>>> class RejectY(Constraint):
...     def allows(self, slot, rule, global_): return rule.rhs != ["y"]
>>> g.add_constraint(RejectY())
>>> d = g.start_derivation("E")
>>> [r.rhs for r in d.options()]
[['x']]
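The scoping rule (a scope miss counts as acceptance) can be sketched as a thin wrapper around a constraint's allows check. Everything here is hypothetical: scoped_allows is an invented helper with simplified arguments, and the assumption that both scopes must match when both are set is not confirmed by the documentation.

```python
from typing import Callable, Optional, Set


def scoped_allows(
    allows: Callable[[str, Optional[str]], bool],
    nonterminal: str,
    rule_name: Optional[str],
    nonterminals: Optional[Set[str]] = None,
    rule_names: Optional[Set[str]] = None,
) -> bool:
    """Consult `allows` only when the slot and the rule are both in scope."""
    if nonterminals is not None and nonterminal not in nonterminals:
        return True  # scope miss: implicit acceptance
    if rule_names is not None and rule_name not in rule_names:
        return True  # scope miss: implicit acceptance
    return allows(nonterminal, rule_name)


def reject_all(nonterminal: str, rule_name: Optional[str]) -> bool:
    """A deliberately strict check, to make the effect of scoping visible."""
    return False


print(scoped_allows(reject_all, "E", "E_add", nonterminals={"F"}))  # True
print(scoped_allows(reject_all, "F", "F_v", nonterminals={"F"}))    # False
```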
to_dict
Serialise this constraint to a JSON-safe dictionary.
The base implementation stores only the fully-qualified class path under
constraint_class. Subclasses should call super().to_dict() and
add their own constructor arguments so that from_dict can reconstruct
the instance faithfully. See
from_dict
for an example.
Returns:
| Type | Description |
|---|---|
| dict | Dictionary with at least the key |
Source code in SRToolkit/utils/grammar/constraints.py
from_dict
classmethod
Reconstruct a constraint from a dictionary produced by to_dict.
When called on the base Constraint
class, dispatches to the correct subclass via the constraint_class key using
importlib — both built-in and user-defined subclasses are supported.
When called on a concrete subclass, the subclass must override this method.
The dictionary must contain at minimum the key constraint_class, whose value
is the fully-qualified class path (e.g. "mymodule.MyConstraint"). Any
additional keys are forwarded to the subclass override.
To make a custom subclass serialisable, override both to_dict and
from_dict::
    class MyConstraint(Constraint):
        def __init__(self, threshold: float) -> None:
            self.threshold = threshold

        def to_dict(self) -> dict:
            return {**super().to_dict(), "threshold": self.threshold}

        @classmethod
        def from_dict(cls, d: dict) -> "MyConstraint":
            return cls(d["threshold"])
Examples:
>>> from SRToolkit.utils.grammar import Constraint, MaxDepth
>>> c = Constraint.from_dict(MaxDepth(5).to_dict())
>>> isinstance(c, MaxDepth) and c.limit == 5
True
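The dispatch step, resolving a fully-qualified class path with importlib, might look roughly like the sketch below. resolve_class and rebuild are hypothetical helpers, a standard-library class stands in for a constraint subclass, and the sketch calls the constructor directly, whereas the documented method forwards the dictionary to the subclass's from_dict override.

```python
import importlib
from typing import Any, Dict


def resolve_class(class_path: str) -> type:
    """Import the module part of a dotted path and fetch the class from it."""
    module_path, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_path)  # may raise ImportError
    return getattr(module, class_name)             # may raise AttributeError


def rebuild(payload: Dict[str, Any]) -> Any:
    """Rebuild an object, passing every key except the class path as a keyword."""
    cls = resolve_class(payload["constraint_class"])  # KeyError if the key is missing
    kwargs = {k: v for k, v in payload.items() if k != "constraint_class"}
    return cls(**kwargs)


# A standard-library class stands in for a constraint subclass here.
payload = {"constraint_class": "fractions.Fraction", "numerator": 3, "denominator": 4}
restored = rebuild(payload)
print(type(restored).__name__, restored)  # Fraction 3/4
```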
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| d | dict | Dictionary previously returned by | required |
Returns:
| Type | Description |
|---|---|
| 'Constraint' | A reconstructed Constraint instance. |
Raises:
| Type | Description |
|---|---|
| KeyError | If |
| ImportError | If the module cannot be imported (dispatch path). |
| AttributeError | If the class cannot be found in the module (dispatch path). |
| NotImplementedError | If called on a subclass that has not overridden this method. |
Source code in SRToolkit/utils/grammar/constraints.py
initial_local
initial_global
allows
Return True if rule may be applied at slot.
Called only when the slot's non-terminal and rule name are within scope.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| slot | ExpansionContext[L] | Current derivation position with this constraint's local state. | required |
| rule | Rule | Candidate production rule. | required |
| global_ | G | Current per-derivation global state. | required |
Returns:
| Type | Description |
|---|---|
| bool | |
Source code in SRToolkit/utils/grammar/constraints.py
update
Called after rule is applied at slot.
Returns per-child local states (one per non-terminal in rule.rhs,
in order) and the new global state. May be used to update global
counters by returning a new global_ value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| slot | ExpansionContext[L] | Derivation position immediately before the rule was applied. | required |
| rule | Rule | The rule that was applied. | required |
| global_ | G | Global state before this application. | required |
Returns:
| Type | Description |
|---|---|
| list[L] | |
| G | the number of non-terminals in |
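The state-threading contract (allows reads state, update returns one local state per non-terminal child plus a new global state) can be illustrated with a depth limit. DepthLimit is a hypothetical class with simplified signatures: it takes the local state, the rule's right-hand side, and the set of non-terminals directly instead of ExpansionContext and Rule objects.

```python
from typing import List, Set, Tuple


class DepthLimit:
    """Hypothetical constraint: local state is the depth of the slot."""

    def __init__(self, limit: int) -> None:
        self.limit = limit  # construction-time configuration only

    def initial_local(self) -> int:
        return 0  # the start slot sits at depth 0

    def initial_global(self) -> int:
        return 0  # counts rule applications across the derivation

    def allows(self, local: int, rhs: List[str], nonterminals: Set[str],
               global_: int) -> bool:
        """At the depth limit, accept only rules that open no new slots."""
        expands = any(sym in nonterminals for sym in rhs)
        return local < self.limit or not expands

    def update(self, local: int, rhs: List[str], nonterminals: Set[str],
               global_: int) -> Tuple[List[int], int]:
        """Return one child state per non-terminal in rhs, plus the new global."""
        children = [local + 1 for sym in rhs if sym in nonterminals]
        return children, global_ + 1


c = DepthLimit(1)
nts = {"E", "F"}
print(c.allows(1, ["E", "+", "F"], nts, 0))  # False
print(c.allows(1, ["x"], nts, 0))            # True
print(c.update(0, ["E", "+", "F"], nts, 0))  # ([1, 1], 1)
```

Because the instance stores only its limit, the same object can be consulted by several derivations at once; each derivation owns the local and global values it passes in.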
Source code in SRToolkit/utils/grammar/constraints.py
Derivation
Stateful leftmost derivation over a Grammar.
A derivation begins at a start non-terminal and proceeds by repeatedly choosing a rule to apply to the leftmost unexpanded non-terminal. options returns candidate rules filtered by all registered constraints; apply advances the derivation by one production.
Per-slot local state and per-derivation global state for each registered constraint are maintained internally; constraint instances carry only construction-time configuration and are safe to share across parallel derivations.
Obtain a Derivation via Grammar.start_derivation.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.complete
False
>>> opts = d.options()
>>> len(opts)
1
>>> d.apply(opts[0])
>>> d.complete
True
>>> d.to_token_list()
['x']
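The leftmost-expansion mechanics can be sketched without constraints or weights. LeftmostDerivation below is a hypothetical miniature that represents rules as a plain dictionary; it is not the library's Derivation class.

```python
from typing import Dict, List

Rules = Dict[str, List[List[str]]]


class LeftmostDerivation:
    """Hypothetical, constraint-free miniature of a leftmost derivation."""

    def __init__(self, rules: Rules, start: str) -> None:
        self.rules = rules
        self.form: List[str] = [start]  # current sentential form

    def _leftmost(self) -> int:
        """Index of the leftmost unexpanded non-terminal, or -1 if none is left."""
        for i, sym in enumerate(self.form):
            if sym in self.rules:
                return i
        return -1

    @property
    def complete(self) -> bool:
        return self._leftmost() == -1

    def options(self) -> List[List[str]]:
        i = self._leftmost()
        if i == -1:
            raise RuntimeError("derivation is already complete")
        return self.rules[self.form[i]]

    def apply(self, rhs: List[str]) -> None:
        i = self._leftmost()
        if i == -1:
            raise RuntimeError("derivation is already complete")
        if rhs not in self.rules[self.form[i]]:
            raise ValueError("rule does not expand the leftmost non-terminal")
        self.form[i:i + 1] = rhs  # replace the slot with the rule's right-hand side


rules = {"E": [["E", "+", "F"], ["x"]], "F": [["y"]]}
d = LeftmostDerivation(rules, "E")
d.apply(["E", "+", "F"])  # E -> E + F
d.apply(["x"])            # leftmost E -> x
d.apply(["y"])            # F -> y
print(d.complete, d.form)  # True ['x', '+', 'y']
```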
Source code in SRToolkit/utils/grammar/derivation.py
complete
property
local_stack
Return the local state stack for constraint across the open frontier,
leftmost slot first.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| constraint | Constraint | A constraint previously registered on the grammar. | required |
Returns:
| Type | Description |
|---|---|
| list[Any] | One entry per open non-terminal in left-to-right order. |
Source code in SRToolkit/utils/grammar/derivation.py
global_state
Return the per-derivation global state for constraint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| constraint | Constraint | A constraint previously registered on the grammar. | required |
Returns:
| Type | Description |
|---|---|
| Any | The current global value owned by this derivation for |
Source code in SRToolkit/utils/grammar/derivation.py
options
Return candidate rules for the current leftmost unexpanded non-terminal, filtered by every registered constraint's allows.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> g.add_rule(Rule("E", ["y"]))
>>> d = g.start_derivation("E")
>>> len(d.options())
2
Returns:
| Type | Description |
|---|---|
| list[Rule] | List of Rule objects that every constraint accepts. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the derivation is already complete. |
Source code in SRToolkit/utils/grammar/derivation.py
apply
Apply rule to the current leftmost unexpanded non-terminal.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["E", "+", "F"]))
>>> g.add_rule(Rule("E", ["x"]))
>>> g.add_rule(Rule("F", ["y"]))
>>> d = g.start_derivation("E")
>>> d.apply(g.rules_for("E")[0]) # E -> E + F
>>> d.apply(g.rules_for("E")[1]) # E -> x
>>> d.apply(g.rules_for("F")[0]) # F -> y
>>> d.to_token_list()
['x', '+', 'y']
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rule | Rule | A rule whose | required |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the derivation is already complete. |
| ValueError | If |
Source code in SRToolkit/utils/grammar/derivation.py
sample
Apply one rule chosen proportionally by weight (PCFG) or uniformly (CFG) from the surviving candidates.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.sample()
>>> d.complete
True
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the derivation is already complete. |
| RuntimeError | If all candidate rules are filtered out by allows. |
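Weight-proportional selection among the surviving candidates can be sketched with the standard library's random.choices. The pick helper and the (name, weight) tuples are hypothetical; with every weight equal to 1.0 the same call reduces to a uniform choice, which matches the CFG case.

```python
import random
from collections import Counter
from typing import List, Tuple


def pick(candidates: List[Tuple[str, float]], rng: random.Random) -> str:
    """Choose one candidate with probability proportional to its weight."""
    if not candidates:
        raise RuntimeError("every candidate rule was filtered out")
    names = [name for name, _ in candidates]
    weights = [weight for _, weight in candidates]
    return rng.choices(names, weights=weights, k=1)[0]


rng = random.Random(0)  # seeded so repeated runs draw the same sequence
draws = Counter(
    pick([("E -> E + F", 0.9), ("E -> x", 0.1)], rng) for _ in range(2000)
)
print(draws.most_common())
```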
Source code in SRToolkit/utils/grammar/derivation.py
generate
Run the derivation to completion and return the token list.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> g.start_derivation("E").generate()
['x']
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| limit | int | Maximum number of rule applications. A negative value means unlimited. Default | 1000 |
Returns:
| Type | Description |
|---|---|
| list[str] | Flat list of terminal tokens in left-to-right order. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the derivation does not complete within |
Source code in SRToolkit/utils/grammar/derivation.py
to_token_list
Return the completed expression as a flat token list.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.apply(d.options()[0])
>>> d.to_token_list()
['x']
Returns:
| Type | Description |
|---|---|
| list[str] | Flat list of terminal tokens in left-to-right order. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the derivation is not yet complete. |
Source code in SRToolkit/utils/grammar/derivation.py
to_parse_tree
Return the completed derivation as a ParseTree.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation("E")
>>> d.apply(d.options()[0])
>>> isinstance(d.to_parse_tree(), ParseTree)
True
Returns:
| Type | Description |
|---|---|
| ParseTree | The ParseTree rooted at the start symbol. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If the derivation is not yet complete. |
Source code in SRToolkit/utils/grammar/derivation.py
DimensionalConsistency
DimensionalConsistency(variable_units: dict[str, dict], target_unit: dict, constant_units: Optional[dict[str, dict]] = None, symbol_library: Optional[SymbolLibrary] = None, allow_unit_polymorphic_constants: bool = False)
Bases: Constraint[Optional[Unit], None]
Stateful constraint that enforces dimensional analysis during expression generation.
Local state at each slot is the required unit (Optional[Unit]).
None means "any unit is acceptable here" — used for the children of
multiplicative operators, whose individual units are underdetermined.
Unit representation: units are dict[str, Fraction] mapping base
dimension names (e.g. "m", "s", "kg") to rational exponents.
Dimensionless is {} (empty dict).
Free constants (const type): treated as dimensionless by default.
Set allow_unit_polymorphic_constants=True to let them absorb any
required unit.
Named physical constants (lit type, e.g. "g", "c"):
declared via constant_units; checked like variables.
Undeclared variables or literals: conservatively accepted.
Rule classification: uses rule.name when available (see
Grammar.from_symbol_library
for the naming scheme). Falls back to rhs-shape heuristics for
unnamed rules.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> from SRToolkit.utils.grammar import DimensionalConsistency
>>> from fractions import Fraction
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["E", "+", "F"], weight=0.6, name="E_add_+"))
>>> g.add_rule(Rule("E", ["F"], weight=0.4, name="E_to_F"))
>>> g.add_rule(Rule("F", ["v"], weight=0.5, name="F_v"))
>>> g.add_rule(Rule("F", ["t"], weight=0.5, name="F_t"))
>>> dc = DimensionalConsistency(
... variable_units={"v": {"m": 1, "s": -1}, "t": {"s": 1}},
... target_unit={"m": 1, "s": -1},
... )
>>> g.add_constraint(dc)
>>> d = g.start_derivation("E")
>>> all(r.rhs != ["t"] for r in d.options())
True
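The unit representation (base dimensions mapped to rational exponents, with {} as dimensionless) lends itself to a few small helpers. The functions normalise, multiply, and power are hypothetical and only illustrate the arithmetic; they are not the library's Unit API.

```python
from fractions import Fraction
from typing import Dict, Mapping, Union

Unit = Dict[str, Fraction]
Exponent = Union[int, Fraction]


def normalise(unit: Mapping[str, Exponent]) -> Unit:
    """Convert exponents to Fraction and drop zero entries ({} is dimensionless)."""
    return {dim: Fraction(exp) for dim, exp in unit.items() if exp != 0}


def multiply(a: Mapping[str, Exponent], b: Mapping[str, Exponent]) -> Unit:
    """Multiplying quantities adds the exponents of each base dimension."""
    out = dict(normalise(a))
    for dim, exp in normalise(b).items():
        out[dim] = out.get(dim, Fraction(0)) + exp
    return normalise(out)


def power(a: Mapping[str, Exponent], exponent: Exponent) -> Unit:
    """Raising a quantity to a power scales every exponent."""
    return normalise({dim: Fraction(exp) * Fraction(exponent) for dim, exp in a.items()})


velocity = {"m": 1, "s": -1}
time = {"s": 1}
print(multiply(velocity, time) == normalise({"m": 1}))       # True
print(normalise(velocity) == normalise(time))                # False
print(power({"m": 2}, Fraction(1, 2)) == normalise({"m": 1}))  # True
```

Addition requires both operands to carry exactly the target unit, so v + t is rejected above. A product only fixes the sum of its operands' exponents, which is why the individual children of a multiplicative operator are left unconstrained.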
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| variable_units | dict[str, dict] | Mapping from variable token to its unit. | required |
| target_unit | dict | Required unit of the generated expression. | required |
| constant_units | Optional[dict[str, dict]] | Units for named physical constants ( | None |
| symbol_library | Optional[SymbolLibrary] | Used to classify tokens by type and precedence. | None |
| allow_unit_polymorphic_constants | bool | If | False |
Source code in SRToolkit/utils/grammar/constraints.py
ExpansionContext
dataclass
ExpansionContext(nonterminal: str, local: L, steps: int, parent_rule: Optional[Rule], child_index: Optional[int], ancestors: tuple[AncestorInfo, ...], partial_tree: ParseTreeNode, nonterminals: frozenset[str], frontier_size: int = 1)
Bases: Generic[L]
Read-only view of the current derivation position passed to constraints.
The engine constructs an ExpansionContext on demand
for each (constraint, candidate_rule) pair evaluated in
Derivation.options and for
each (constraint, selected_rule) pair in
Derivation.apply.
Attributes:
| Name | Type | Description |
|---|---|---|
| nonterminal | str | The non-terminal symbol being expanded at this position. |
| local | L | This constraint's per-slot inherited state. |
| steps | int | Number of rule applications made so far in the derivation (i.e. how many times |
| parent_rule | Optional[Rule] | The rule whose application created this slot ( |
| child_index | Optional[int] | Index of this slot in |
| ancestors | tuple[AncestorInfo, ...] | Ancestor chain from the root to this slot's immediate parent, in root-first order. |
| partial_tree | ParseTreeNode | Root of the partially-built parse tree. Read-only by contract: constraints must not mutate it. |
| nonterminals | frozenset[str] | Frozen set of all non-terminal symbols in the grammar. |
| frontier_size | int | Number of open (unexpanded) non-terminal slots in the current derivation frontier, including this slot. Useful for budget constraints such as MaxNodes that need to account for the minimum number of rule applications still required to complete the derivation. |
Grammar
A context-free grammar (CFG) or probabilistic context-free grammar (PCFG).
Rules are added via add_rule. The set of
non-terminals is derived automatically: a symbol is a non-terminal if and only
if it appears as the lhs of at least one rule.
A grammar is a CFG when every rule carries the default weight (1.0), making
sampling uniform within each group. It is a PCFG when any rule has a weight
that differs from 1.0.
Constraints are registered via add_constraint. During a derivation, options returns only rules that every constraint's allows accepts.
Examples:
>>> g = Grammar([
... Rule("E", ["E", "+", "F"], weight=0.4),
... Rule("E", ["F"], weight=0.6),
... Rule("F", ["x"]),
... ])
>>> "E" in g.nonterminals
True
>>> g.is_pcfg()
True
>>> len(g.rules_for("E"))
2
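The two derived properties described above (non-terminals are exactly the symbols that appear as a left-hand side, and a grammar counts as a PCFG once any weight differs from 1.0) reduce to one-liners. SketchRule and the two functions are hypothetical stand-ins, not the library's Rule or Grammar.

```python
from typing import List, NamedTuple, Set


class SketchRule(NamedTuple):
    """Hypothetical stand-in for a production rule."""
    lhs: str
    rhs: List[str]
    weight: float = 1.0


def nonterminals(rules: List[SketchRule]) -> Set[str]:
    """A symbol is a non-terminal iff it is the lhs of at least one rule."""
    return {rule.lhs for rule in rules}


def is_pcfg(rules: List[SketchRule]) -> bool:
    """True as soon as any rule deviates from the default weight of 1.0."""
    return any(rule.weight != 1.0 for rule in rules)


rules = [
    SketchRule("E", ["E", "+", "F"], weight=0.4),
    SketchRule("E", ["F"], weight=0.6),
    SketchRule("F", ["x"]),
]
print(sorted(nonterminals(rules)))  # ['E', 'F']
print(is_pcfg(rules))               # True
print(is_pcfg([SketchRule("E", ["x"]), SketchRule("E", ["y"])]))  # False
```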
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rules | Optional[list[Rule]] | | None |
| start | Optional[str] | Default start non-terminal used by start_derivation when no | None |
Source code in SRToolkit/utils/grammar/grammar.py
nonterminals
property
add_rule
Add a production rule to the grammar.
Examples:
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rule | Rule | The Rule to register. | required |
Source code in SRToolkit/utils/grammar/grammar.py
add_constraint
Register a constraint applied at each derivation step.
A rule is offered as an option only when every registered constraint's
allows returns True
for it.
Examples:
>>> from SRToolkit.utils.grammar import MaxDepth
>>> g = Grammar([Rule("E", ["E", "+", "E"]), Rule("E", ["x"])])
>>> g.add_constraint(MaxDepth(0))
>>> d = g.start_derivation("E")
>>> [r.rhs for r in d.options()]
[['x']]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| constraint | 'Constraint' | A Constraint instance: typically a built-in such as MaxDepth, MaxNodes, MaxOccurrences, NoNested, or DimensionalConsistency, or a user-defined subclass. | required |
Source code in SRToolkit/utils/grammar/grammar.py
rules_for
Return all rules whose lhs matches nonterminal.
Examples:
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["F"]))
>>> g.add_rule(Rule("E", ["x"]))
>>> len(g.rules_for("E"))
2
>>> g.rules_for("Z")
[]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| nonterminal | str | Non-terminal to look up. | required |
Returns:
| Type | Description |
|---|---|
| list[Rule] | List of matching Rule objects in insertion order. |
Source code in SRToolkit/utils/grammar/grammar.py
is_pcfg
Return True if any rule deviates from the default weight of 1.0.
Examples:
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["x"]))
>>> g.add_rule(Rule("E", ["y"]))
>>> g.is_pcfg()
False
>>> g.add_rule(Rule("F", ["z"], weight=0.3))
>>> g.is_pcfg()
True
Returns:
| Type | Description |
|---|---|
| bool | |
Source code in SRToolkit/utils/grammar/grammar.py
validate
Validate that parse_tree is structurally valid, uses only productions
from this grammar, and satisfies every registered constraint.
By default the root of the tree may be any non-terminal in the grammar,
not just self.start. This lets you validate sub-trees (e.g. an
F-rooted fragment) in addition to full expressions. Pass
require_start=True to additionally enforce that the root symbol equals
self.start (useful when validating complete, top-level expressions).
The check replays the parse tree through a fresh
Derivation rooted at
parse_tree.root.symbol, walking nodes in leftmost order (mirroring the
derivation frontier). At each internal node it checks that:
- The applied rule is permitted at the current frontier position, meaning it exists in the grammar and every registered constraint accepts it.
- The number of children equals len(rule.rhs).
- Each children[i].symbol equals rule.rhs[i].
Terminal leaves are accepted when their symbol is not a non-terminal in this grammar. A non-terminal symbol appearing as a leaf (unexpanded) is rejected.
Examples:
>>> g = Grammar(start="E")
>>> r = Rule("E", ["x"])
>>> g.add_rule(r)
>>> leaf = ParseTreeNode("x", None)
>>> root = ParseTreeNode("E", r, [leaf])
>>> g.validate(ParseTree(root))
True
>>> foreign = Rule("E", ["y"])
>>> root2 = ParseTreeNode("E", foreign, [ParseTreeNode("y", None)])
>>> g.validate(ParseTree(root2))
False
>>> g.add_rule(Rule("F", ["x"]))
>>> r_f = Rule("F", ["x"])
>>> g.add_rule(r_f)
>>> f_root = ParseTreeNode("F", r_f, [ParseTreeNode("x", None)])
>>> g.validate(ParseTree(f_root))
True
>>> g.validate(ParseTree(f_root), require_start=True)
False
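The structural part of this check can be sketched as a recursive walk over a nested-tuple tree. The is_valid helper and the (symbol, children) tree encoding are hypothetical; unlike the documented method, the sketch infers the applied rule from the children's symbols and ignores constraints entirely.

```python
from typing import Dict, List, Optional, Tuple

Rules = Dict[str, List[List[str]]]
# A tree node is (symbol, children); children is None for a leaf.
Tree = Tuple[str, Optional[list]]


def is_valid(tree: Tree, rules: Rules) -> bool:
    """Check that every internal node matches a production of its symbol."""
    symbol, children = tree
    if children is None:
        return symbol not in rules  # a leaf must be a terminal
    if symbol not in rules:
        return False  # only non-terminals may have children
    rhs = [child[0] for child in children]
    if rhs not in rules[symbol]:
        return False  # no production yields these child symbols
    return all(is_valid(child, rules) for child in children)


rules = {"E": [["E", "+", "F"], ["x"]], "F": [["y"]]}
good = ("E", [("E", [("x", None)]), ("+", None), ("F", [("y", None)])])
foreign = ("E", [("y", None)])                                    # E -> y is not a rule
unexpanded = ("E", [("E", None), ("+", None), ("F", [("y", None)])])

print(is_valid(good, rules))        # True
print(is_valid(foreign, rules))     # False
print(is_valid(unexpanded, rules))  # False
```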
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| parse_tree | ParseTree | The ParseTree to validate. | required |
| require_start | bool | When | False |
Returns:
| Type | Description |
|---|---|
| bool | in this grammar, and all constraints would have permitted the derivation (and the root matches |
Raises:
| Type | Description |
|---|---|
| ValueError | If |
Source code in SRToolkit/utils/grammar/grammar.py
start_derivation
Begin a new derivation.
Examples:
>>> g = Grammar(start="E")
>>> g.add_rule(Rule("E", ["x"]))
>>> d = g.start_derivation()
>>> d.complete
False
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start | Optional[str] | Start non-terminal. Defaults to | None |
Returns:
| Type | Description |
|---|---|
| 'Derivation' | A Derivation at the first expansion step. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the resolved start symbol is not a non-terminal in this grammar. |
Source code in SRToolkit/utils/grammar/grammar.py
generate_one
generate_one(start: Optional[str] = None, max_steps: int = 1000, max_retries: int = 10) -> Optional[list[str]]
Generate a single expression by sampling the grammar to completion.
Convenience wrapper around start_derivation and Derivation.generate.
Each attempt runs a fresh derivation from scratch. When a derivation
exceeds max_steps rule applications (e.g. because random sampling
kept choosing recursive rules), the attempt is discarded and a new one
starts. None is returned only when every attempt fails, signalling
that the caller should either relax the grammar, add more liberal
constraints, or increase max_steps / max_retries.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar(start="E")
>>> g.add_rule(Rule("E", ["x"]))
>>> g.generate_one()
['x']
>>> g.generate_one(max_steps=0, max_retries=1) is None
True
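The retry behaviour described above (run a fresh derivation, discard it when the step budget is exhausted, give up after a fixed number of attempts) can be sketched as follows. derive and generate_with_retries are hypothetical helpers over a plain rule dictionary with uniform sampling; the seed parameter is an addition for reproducibility and is not part of the documented signature.

```python
import random
from typing import Dict, List, Optional

Rules = Dict[str, List[List[str]]]


def derive(rules: Rules, start: str, max_steps: int, rng: random.Random) -> List[str]:
    """Expand the leftmost non-terminal until none remain or the budget runs out."""
    form = [start]
    steps = 0
    while True:
        index = next((i for i, sym in enumerate(form) if sym in rules), None)
        if index is None:
            return form  # only terminals are left
        if max_steps >= 0 and steps >= max_steps:
            raise RuntimeError("step limit reached before completion")
        form[index:index + 1] = rng.choice(rules[form[index]])
        steps += 1


def generate_with_retries(rules: Rules, start: str, max_steps: int = 1000,
                          max_retries: int = 10, seed: int = 0) -> Optional[List[str]]:
    """Return the first derivation that completes, or None if every attempt fails."""
    rng = random.Random(seed)
    for _ in range(max_retries):
        try:
            return derive(rules, start, max_steps, rng)
        except RuntimeError:
            continue  # discard this attempt and start from scratch
    return None


simple = {"E": [["x"]]}
print(generate_with_retries(simple, "E"))                              # ['x']
print(generate_with_retries(simple, "E", max_steps=0, max_retries=1))  # None
```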
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| start | Optional[str] | Start non-terminal. Defaults to | None |
| max_steps | int | Maximum rule applications per attempt. A negative number means unlimited (no retry logic applies). | 1000 |
| max_retries | int | Maximum number of fresh attempts before returning | 10 |
Returns:
| Type | Description |
|---|---|
| Optional[list[str]] | A list of terminal tokens in left-to-right order, or every attempt exceeded |
Raises:
| Type | Description |
|---|---|
| ValueError | If the resolved start symbol is not a non-terminal in this grammar. |
Source code in SRToolkit/utils/grammar/grammar.py
to_dict
Serialise this grammar to a JSON-safe dictionary.
Constraints are serialised via their own to_dict method. User-defined
constraint subclasses must implement to_dict/from_dict to survive the
round-trip; built-in constraints are fully supported.
Returns:
| Type | Description |
|---|---|
| dict | Dictionary with keys |
Source code in SRToolkit/utils/grammar/grammar.py
from_dict
classmethod
Reconstruct a Grammar from a dictionary produced by to_dict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| d | dict | Dictionary with keys | required |
Returns:
| Type | Description |
|---|---|
| 'Grammar' | A new Grammar with all rules and constraints registered. |
Source code in SRToolkit/utils/grammar/grammar.py
from_symbol_library
classmethod
from_symbol_library(symbol_library: Optional[SymbolLibrary] = None, start: Optional[str] = None) -> 'Grammar'
Build a PCFG from a SymbolLibrary using a generic operator-precedence non-terminal hierarchy.
One non-terminal is created per unique OP precedence level present in
the library. Known precedences map to readable names; custom precedences
fall back to L_{p}:
- OP_ADDITIVE → E
- OP_MULTIPLICATIVE → F
- OP_POWER → B
- custom precedence p → L_{p}
The chain runs from the lowest-precedence level (the grammar start) down
to T, K, R, and V. Only levels that have at least one
operator are generated; there are no empty pass-through non-terminals.
Heuristic weights are calibrated against empirical expression distributions (e.g. Wikipedia mathematical formulae): E recursion ~30 %, F recursion ~40 %, B recursion ~5 %, custom-level recursion ~40 %.
Examples:
>>> sl = SymbolLibrary.from_symbol_list(["+", "-", "*", "sin", "^2"], 2)
>>> g = Grammar.from_symbol_library(sl)
>>> "E" in g.nonterminals and "V" in g.nonterminals
True
>>> g.is_pcfg()
True
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| symbol_library | Optional[SymbolLibrary] | Token vocabulary. Falls back to the active library from the context manager when | None |
| start | Optional[str] | Override the name of the lowest-precedence non-terminal (and | None |
Returns:
| Type | Description |
|---|---|
| 'Grammar' | A Grammar with heuristic PCFG weights. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the symbol library contains no variables, constants, or literals, making it impossible to generate any terminal expression. |
Source code in SRToolkit/utils/grammar/grammar.py
from_grammar_string
classmethod
Construct a Grammar from a string in NLTK production-rule notation.
Each non-empty, non-comment line must have the form::
    LHS -> RHS_1 | RHS_2 | ...
where each alternative is a space-separated sequence of symbols.
Terminals are enclosed in single quotes (e.g. '+'); unquoted
tokens are treated as non-terminals. An optional weight in square
brackets at the end of an alternative is parsed as a
float (e.g. E '+' F [0.4]); alternatives without a weight
default to 1.0.
The start symbol may be embedded as a comment header emitted by to_grammar_string::
    # start: E
The explicit start parameter takes precedence over this comment.
When start is None, only the first # start: comment encountered
is used; subsequent ones are ignored along with all other comment lines.
Rule names are not preserved in NLTK notation and are set to None
on all returned Rule objects.
Examples:
>>> g = Grammar.from_grammar_string("E -> E '+' F | F\nF -> 'x'", start="E")
>>> sorted(g.nonterminals)
['E', 'F']
>>> g.rules_for("F")[0].rhs
['x']
>>> g.start
'E'
>>> g2 = Grammar.from_grammar_string(
... "# start: E\nE -> E '+' F [0.4] | F [0.6]\nF -> 'x' [1.0]"
... )
>>> g2.start
'E'
>>> g2.is_pcfg()
True
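The start-symbol precedence described above can be sketched in isolation. This is a minimal stand-alone illustration, not the library's parser: `resolve_start` is a hypothetical helper, and it assumes the comment header has exactly the form `# start: <symbol>`.

```python
from typing import Optional


def resolve_start(text: str, start: Optional[str] = None) -> str:
    """Explicit argument first, then the first '# start:' comment (hypothetical helper)."""
    if start is not None:
        return start
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped.startswith("#"):
            continue
        body = stripped[1:].strip()
        if body.startswith("start:"):
            candidate = body[len("start:"):].strip()
            if candidate:
                return candidate  # later '# start:' comments are never reached
    raise ValueError("start symbol not given and no '# start:' comment found")


text = "# start: E\n# start: F\nE -> E '+' F | F\nF -> 'x'"
print(resolve_start(text))             # E
print(resolve_start(text, start="F"))  # F
```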
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
text
|
str
|
Grammar specification in NLTK notation, optionally with a
|
required |
start
|
Optional[str]
|
Start non-terminal stored on the returned grammar. Takes
precedence over a |
None
|
Returns:
| Type | Description |
|---|---|
'Grammar'
|
A new Grammar populated with the parsed rules. |
Raises:
| Type | Description |
|---|---|
ValueError
|
If a content line contains no |
ValueError
|
If a weight token cannot be converted to |
ValueError
|
If |
ValueError
|
If the start symbol cannot be determined from either
the parameter or the |
Source code in SRToolkit/utils/grammar/grammar.py
to_grammar_string
Serialise this grammar to a string in NLTK production-rule notation.
All rules sharing the same left-hand side are written on a single
line, separated by |. Symbols that are non-terminals (i.e.
appear as the lhs of at least one rule) are written unquoted;
all other symbols are enclosed in single quotes. When
is_pcfg returns True,
the probability of each alternative (its weight divided by the sum of
weights for that non-terminal) is appended in square brackets.
Rule names and registered constraints are not included in the output.
When self.start is set, a # start: <symbol> comment is
prepended so that
from_grammar_string
can reconstruct the grammar without requiring an explicit start
argument.
Examples:
>>> g = Grammar([
... Rule("E", ["E", "+", "F"]),
... Rule("E", ["F"]),
... Rule("F", ["x"]),
... ], start="E")
>>> print(g.to_grammar_string())
# start: E
E -> E '+' F | F
F -> 'x'
>>> gp = Grammar([
... Rule("E", ["x"], weight=1.0),
... Rule("E", ["y"], weight=3.0),
... ])
>>> print(gp.to_grammar_string())
E -> 'x' [0.25] | 'y' [0.75]
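The quoting and probability rules described above can be reproduced with a short stand-alone sketch. It uses plain `(lhs, rhs, weight)` tuples instead of Rule objects, and it assumes a grammar counts as a PCFG as soon as any weight differs from 1; the real `is_pcfg` check may differ. `serialise` is a hypothetical name.

```python
from typing import List, Optional, Tuple

RuleT = Tuple[str, List[str], float]  # (lhs, rhs, weight): a stand-in for Rule


def serialise(rules: List[RuleT], start: Optional[str] = None) -> str:
    """Group rules by lhs, quote terminals, append normalised probabilities for PCFGs."""
    grouped = {}
    for lhs, rhs, weight in rules:
        grouped.setdefault(lhs, []).append((rhs, weight))
    nonterminals = set(grouped)  # a symbol is a non-terminal if it is some rule's lhs
    # Assumption: any weight other than 1 makes the grammar a PCFG.
    pcfg = any(weight != 1.0 for _, _, weight in rules)
    lines = [f"# start: {start}"] if start is not None else []
    for lhs, alternatives in grouped.items():
        total = sum(weight for _, weight in alternatives)
        parts = []
        for rhs, weight in alternatives:
            symbols = " ".join(s if s in nonterminals else f"'{s}'" for s in rhs)
            parts.append(f"{symbols} [{weight / total}]" if pcfg else symbols)
        lines.append(f"{lhs} -> " + " | ".join(parts))
    return "\n".join(lines)


print(serialise([("E", ["x"], 1.0), ("E", ["y"], 3.0)]))
# E -> 'x' [0.25] | 'y' [0.75]
```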
Returns:
| Type | Description |
|---|---|
str
|
Multi-line string in NLTK production-rule notation, optionally |
str
|
preceded by a |
Source code in SRToolkit/utils/grammar/grammar.py
MaxDepth
Bases: Constraint[int, None]
Hard limit on derivation depth.
Local state is the remaining depth budget at each slot. A rule is rejected when its application would require at least one recursive non-terminal child but the budget has reached zero.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
... Rule("E", ["E", "+", "E"]),
... Rule("E", ["x"]),
... ])
>>> g.add_constraint(MaxDepth(0))
>>> d = g.start_derivation("E")
>>> all(r.rhs == ["x"] for r in d.options())
True
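The rejection rule can be illustrated without the library. The sketch below is one reading of the description, not the actual constraint class: `options_with_budget` is a hypothetical helper, rules are plain `(lhs, rhs)` tuples, and every non-terminal child is treated as requiring further depth.

```python
from typing import List, Set, Tuple

RuleT = Tuple[str, List[str]]  # (lhs, rhs): a stand-in for Rule


def options_with_budget(rules: List[RuleT], nonterminals: Set[str], budget: int) -> List[RuleT]:
    """Keep a rule unless it introduces a non-terminal child while the budget is exhausted."""
    allowed = []
    for lhs, rhs in rules:
        has_nonterminal_child = any(symbol in nonterminals for symbol in rhs)
        if has_nonterminal_child and budget <= 0:
            continue  # no depth left for the child to expand into
        allowed.append((lhs, rhs))
    return allowed


rules = [("E", ["E", "+", "E"]), ("E", ["x"])]
print(options_with_budget(rules, {"E"}, budget=0))  # [('E', ['x'])]
```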
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
limit
|
int
|
Maximum nesting depth (number of non-terminal expansion levels). |
required |
Source code in SRToolkit/utils/grammar/constraints.py
MaxNodes
Bases: Constraint[None, int]
Hard limit on the total number of rule applications in a derivation.
Global state is a running count of applications so far. A rule is rejected when applying it would push the count over the limit, taking into account the non-terminal children it introduces (each of which will require at least one more application).
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
... Rule("E", ["E", "+", "E"]),
... Rule("E", ["x"]),
... ])
>>> g.add_constraint(MaxNodes(3))
>>> d = g.start_derivation("E")
>>> # At node-count 3 only the terminal rule survives
>>> tokens = d.generate()
>>> len(tokens) <= 5
True
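The look-ahead in the rejection rule amounts to a small piece of arithmetic. The following sketch is an interpretation of the description only: `fits` is a hypothetical helper, and the real constraint may also account for non-terminals still pending elsewhere in the derivation.

```python
from typing import List, Set


def fits(count: int, limit: int, rhs: List[str], nonterminals: Set[str]) -> bool:
    """One application for the rule itself, plus at least one per non-terminal child."""
    nonterminal_children = sum(1 for symbol in rhs if symbol in nonterminals)
    return count + 1 + nonterminal_children <= limit


nts = {"E"}
print(fits(0, 3, ["E", "+", "E"], nts))  # True:  0 + 1 + 2 == 3
print(fits(1, 3, ["E", "+", "E"], nts))  # False: 1 + 1 + 2 > 3
print(fits(1, 3, ["x"], nts))            # True:  1 + 1 + 0 <= 3
```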
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
limit
|
int
|
Maximum number of rule applications. |
required |
Source code in SRToolkit/utils/grammar/constraints.py
MaxOccurrences
Bases: Constraint[None, int]
Hard limit on how many times a specific terminal symbol may appear.
Global state counts occurrences committed so far. A rule whose rhs
contains the tracked symbol is rejected once the count equals the limit.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
... Rule("E", ["E", "+", "E"]),
... Rule("E", ["x"]),
... Rule("E", ["y"]),
... ])
>>> g.add_constraint(MaxOccurrences("x", 1))
>>> d = g.start_derivation("E")
>>> tokens = d.generate(limit=200)
>>> tokens.count("x") <= 1
True
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol
|
str
|
Terminal token to track. |
required |
limit
|
int
|
Maximum allowed occurrences. |
required |
Source code in SRToolkit/utils/grammar/constraints.py
NoNested
Bases: Constraint[bool, None]
Prevent any symbol in a group from appearing nested inside any other symbol in the same group.
Local state is a boolean "currently under any symbol in the group". Any
rule whose rhs contains a group symbol is rejected while the local
state is True. Children inherit True once such a rule is applied.
Pass a single string to forbid self-nesting only; pass multiple symbols to
forbid cross-nesting within the group (e.g. NoNested(TRIG_FNS) prevents
sin(cos(x)) as well as sin(sin(x))).
Warning
This constraint propagates the "inside" flag to all non-terminal
children of any rule whose rhs contains a group symbol. It works
correctly when each group symbol appears in a rule with exactly one
non-terminal child (typical prefix-function rules such as
E -> sin ( E )) and for infix operators where every child should be
blocked (e.g. E -> E + E). It does not work correctly for
mixed rules that combine a function call with additional non-terminal
siblings in a single production (e.g. E -> sin ( E ) + F): the
sibling F will be incorrectly treated as inside the group. If your
grammar uses such rules, implement a custom
Constraint instead.
Examples:
>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
... Rule("E", ["sin", "(", "E", ")"]),
... Rule("E", ["cos", "(", "E", ")"]),
... Rule("E", ["x"]),
... ])
>>> g.add_constraint(NoNested(["sin", "cos"]))
>>> d = g.start_derivation("E")
>>> d.apply(g.rules_for("E")[0]) # apply sin(E)
>>> # Now inside sin; both sin and cos rules should be gone
>>> opts = d.options()
>>> all("sin" not in r.rhs and "cos" not in r.rhs for r in opts)
True
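The flag propagation, including the limitation noted in the warning, can be sketched with two small functions. These are hypothetical stand-ins (`rule_allowed`, `child_flag`) operating on plain token lists; they mirror the description rather than the library's code.

```python
from typing import FrozenSet, List


def rule_allowed(rhs: List[str], group: FrozenSet[str], inside: bool) -> bool:
    """A rule mentioning a group symbol is rejected while the local flag is set."""
    return not (inside and any(symbol in group for symbol in rhs))


def child_flag(rhs: List[str], group: FrozenSet[str], inside: bool) -> bool:
    """Every non-terminal child of the rule inherits this flag, siblings included."""
    return inside or any(symbol in group for symbol in rhs)


group = frozenset({"sin", "cos"})
inside = child_flag(["sin", "(", "E", ")"], group, inside=False)
print(rule_allowed(["cos", "(", "E", ")"], group, inside))  # False
print(rule_allowed(["x"], group, inside))                   # True

# Mixed rule from the warning: the flag is computed once per rule, so the
# sibling F in  E -> sin ( E ) + F  is also marked as "inside".
print(child_flag(["sin", "(", "E", ")", "+", "F"], group, inside=False))  # True
```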
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbols
|
str | Iterable[str]
|
A single terminal token or an iterable of terminal tokens that form the nesting group. |
required |
Source code in SRToolkit/utils/grammar/constraints.py
ParseTree
Full derivation history of an expression under a grammar.
Unlike Node, which holds only terminal symbols for expression evaluation, a ParseTree retains every non-terminal and every production applied during the derivation.
Examples:
>>> r = Rule("E", ["x"])
>>> leaf = ParseTreeNode("x", None)
>>> root = ParseTreeNode("E", r, [leaf])
>>> pt = ParseTree(root)
>>> pt.to_token_list()
['x']
>>> pt.productions_used()
[Rule(lhs='E', rhs=['x'], weight=1.0, name=None)]
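For a tree with more than one level, the two traversals behave as follows. This is a stand-alone sketch with a hypothetical `PTNode` dataclass and rules as plain tuples; it assumes leaves are exactly the nodes with `rule_applied=None`, as stated for ParseTreeNode.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

RuleT = Tuple[str, Tuple[str, ...]]  # (lhs, rhs): a stand-in for Rule


@dataclass
class PTNode:
    symbol: str
    rule_applied: Optional[RuleT] = None
    children: List["PTNode"] = field(default_factory=list)


def to_token_list(node: PTNode) -> List[str]:
    """Collect terminal leaves left to right."""
    if node.rule_applied is None:
        return [node.symbol]
    tokens: List[str] = []
    for child in node.children:
        tokens.extend(to_token_list(child))
    return tokens


def productions_used(node: PTNode) -> List[RuleT]:
    """Collect applied rules in pre-order: the node first, then its children."""
    if node.rule_applied is None:
        return []
    used = [node.rule_applied]
    for child in node.children:
        used.extend(productions_used(child))
    return used


r_add, r_ef, r_fx = ("E", ("E", "+", "F")), ("E", ("F",)), ("F", ("x",))
root = PTNode("E", r_add, [
    PTNode("E", r_ef, [PTNode("F", r_fx, [PTNode("x")])]),
    PTNode("+"),
    PTNode("F", r_fx, [PTNode("x")]),
])
print(to_token_list(root))  # ['x', '+', 'x']
```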
Attributes:
| Name | Type | Description |
|---|---|---|
root |
The root node of the parse tree. |
Source code in SRToolkit/utils/grammar/grammar.py
to_token_list
Collect all terminal tokens in left-to-right order.
Examples:
>>> leaf1 = ParseTreeNode("a", None)
>>> leaf2 = ParseTreeNode("+", None)
>>> leaf3 = ParseTreeNode("b", None)
>>> root = ParseTreeNode("E", Rule("E", ["a", "+", "b"]), [leaf1, leaf2, leaf3])
>>> ParseTree(root).to_token_list()
['a', '+', 'b']
Returns:
| Type | Description |
|---|---|
list[str]
|
Flat list of terminal tokens in left-to-right order. |
Source code in SRToolkit/utils/grammar/grammar.py
productions_used
Return all rules applied in the derivation, in pre-order.
Examples:
>>> leaf = ParseTreeNode("x", None)
>>> r = Rule("E", ["x"])
>>> root = ParseTreeNode("E", r, [leaf])
>>> ParseTree(root).productions_used()
[Rule(lhs='E', rhs=['x'], weight=1.0, name=None)]
Returns:
| Type | Description |
|---|---|
list[Rule]
|
List of Rule objects in pre-order traversal order. |
Source code in SRToolkit/utils/grammar/grammar.py
ParseTreeNode
dataclass
A node in a derivation parse tree.
Terminal leaves have rule_applied=None and children=[].
Internal nodes store the rule that expanded the non-terminal at this position.
Attributes:
| Name | Type | Description |
|---|---|---|
symbol |
str
|
The grammar symbol at this node (terminal or non-terminal name). |
rule_applied |
Optional[Rule]
|
The Rule used to expand this node.
|
children |
list[ParseTreeNode]
|
Ordered child nodes corresponding to the symbols in
|
Rule
dataclass
A single production rule in a grammar.
A rule is a CFG production when weights remain 1 for all rules in the grammar; when weights differ across rules, the grammar is treated as a PCFG and productions are sampled proportionally. Weights are unnormalised — the grammar normalises them at sampling time.
Examples:
>>> r = Rule("E", ["E", "+", "F"], weight=0.4, name="E_add_+")
>>> r.lhs
'E'
>>> r.rhs
['E', '+', 'F']
>>> r.weight
0.4
>>> r.name
'E_add_+'
>>> Rule("E", ["F"]).weight
1.0
Attributes:
| Name | Type | Description |
|---|---|---|
lhs |
str
|
The non-terminal being expanded, e.g. |
rhs |
list[str]
|
Ordered sequence of symbols the non-terminal expands to. Each
element is either a terminal token (e.g. |
weight |
float
|
Unnormalised sampling weight. Defaults to |
name |
Optional[str]
|
Optional stable identifier for this rule. Used by constraints
for scoping and identification. |
from_line
classmethod
Parse one NLTK production line into a list of Rule objects.
The line must have the form:
LHS -> RHS_1 | RHS_2 | ...
where each alternative is a space-separated sequence of symbols.
Terminals are enclosed in single quotes; unquoted tokens are non-terminals.
An optional weight in square brackets at the end of an alternative is parsed
as a float; alternatives without a weight default to 1.0.
Examples:
>>> Rule.from_line("E -> E '+' F [0.4] | F [0.6]")
[Rule(lhs='E', rhs=['E', '+', 'F'], weight=0.4, name=None), Rule(lhs='E', rhs=['F'], weight=0.6, name=None)]
>>> Rule.from_line("F -> 'x'")
[Rule(lhs='F', rhs=['x'], weight=1.0, name=None)]
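The line format can be parsed with a few string operations. The sketch below is a simplified, stand-alone parser (`parse_line` is hypothetical and returns tuples rather than Rule objects). It assumes the missing-arrow case is an error, and it does not handle a quoted `'|'` terminal, which the real parser may treat differently.

```python
from typing import List, Tuple

ParsedRule = Tuple[str, List[str], float]  # (lhs, rhs, weight)


def parse_line(line: str) -> List[ParsedRule]:
    """Parse 'LHS -> A | B [w]' into one tuple per alternative (simplified sketch)."""
    if "->" not in line:
        raise ValueError(f"missing '->' in {line!r}")  # assumed error case
    lhs, rhs_text = line.split("->", 1)
    lhs = lhs.strip()
    if not lhs:
        raise ValueError("empty left-hand side")
    rules: List[ParsedRule] = []
    for alternative in rhs_text.split("|"):
        tokens = alternative.split()
        if not tokens:
            continue
        weight = 1.0  # default when no [weight] suffix is present
        last = tokens[-1]
        if last.startswith("[") and last.endswith("]"):
            weight = float(last[1:-1])  # float() raises ValueError on a bad weight
            tokens = tokens[:-1]
        # Strip the single quotes that mark terminals; leave non-terminals as-is.
        rhs = [t[1:-1] if len(t) >= 2 and t[0] == t[-1] == "'" else t for t in tokens]
        rules.append((lhs, rhs, weight))
    if not rules:
        raise ValueError("no alternatives on the right-hand side")
    return rules


print(parse_line("E -> E '+' F [0.4] | F [0.6]"))
# [('E', ['E', '+', 'F'], 0.4), ('E', ['F'], 0.6)]
```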
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
line
|
str
|
A single non-empty, non-comment production line. |
required |
Returns:
| Type | Description |
|---|---|
list['Rule']
|
List of Rule objects, one per alternative. |
Raises:
| Type | Description |
|---|---|
ValueError
|
If |
ValueError
|
If the left-hand side is empty. |
ValueError
|
If a weight token cannot be converted to |
ValueError
|
If no alternatives are parsed from the right-hand side. |
Source code in SRToolkit/utils/grammar/grammar.py
SymbolLibrary
SymbolLibrary(symbols: Optional[List[str]] = None, num_variables: int = 0, preamble: Optional[List[str]] = None)
A registry of tokens and their properties, used throughout the toolkit to parse, compile, and generate symbolic expressions.
By default, the library uses NumPy for operator and function evaluation. To use a
different backend, pass the required import statements via preamble.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x", "x")
>>> library.get_type("x")
'var'
>>> library.get_precedence("x")
0
>>> library.get_np_fn("x")
'x'
>>> library.remove_symbol("x")
>>> library = SymbolLibrary.default_symbols()
>>> # You can also initialize the library with a list of symbols (listed in SymbolLibrary.default_symbols)
>>> # and the number of variables.
>>> library2 = SymbolLibrary(["+", "*", "sin"], num_variables=2)
>>> len(library2)
5
>>> # Use as a context manager to avoid passing sl explicitly
>>> # with SymbolLibrary.default_symbols(num_variables=2) as sl:
>>> # tree = tokens_to_tree(["X_0", "+", "X_1", "*", "C"])
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbols
|
Optional[List[str]]
|
Symbols to pre-populate from the default set. |
None
|
num_variables
|
int
|
Number of variable tokens to add, labeled |
0
|
preamble
|
Optional[List[str]]
|
Import statements prepended to compiled expression functions.
Defaults to |
None
|
Attributes:
| Name | Type | Description |
|---|---|---|
symbols |
Mapping from token string to its property dict (type, precedence, NumPy function string, LaTeX template). |
Source code in SRToolkit/utils/symbol_library.py
add_symbol
add_symbol(symbol: str, symbol_type: str, precedence: int, np_fn: str, latex_str: Optional[str] = None, cython_id: int = -1)
Add a token to the library with its associated type, precedence, NumPy function string, and LaTeX template.
Symbol types:
"op": binary operator (e.g.+,*)."fn": unary function (e.g.sin,sqrt)."lit": literal with a fixed value (e.g.pi,e)."const": free constant whose value is optimised during parameter estimation (e.g.C). Using a single"const"token is recommended; multiple tokens increase complexity and reduce readability."var": input variable whose values are read from the data arrayX.
If latex_str is omitted, a default template is generated based on the symbol type:
"{} \text{symb} {}" for operators (e.g. + → {} \text{+} {}),
"\text{symb} {}" for functions (e.g. sin → \text{sin} {}),
and "\text{symb}" for all other types.
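The default-template rule can be written out as a small function. This is an illustrative sketch only: `default_latex` is a hypothetical name, and it assumes `symb` in the templates above stands for the token being registered, as the `+` and `sin` examples suggest.

```python
def default_latex(symbol: str, symbol_type: str) -> str:
    """Build the fallback LaTeX template for a token, by symbol type (sketch)."""
    if symbol_type == "op":
        return r"{} \text{" + symbol + "} {}"   # operand, operator, operand
    if symbol_type == "fn":
        return r"\text{" + symbol + "} {}"      # function name, then its argument
    return r"\text{" + symbol + "}"             # literals, constants, variables


print(default_latex("+", "op"))    # {} \text{+} {}
print(default_latex("sin", "fn"))  # \text{sin} {}
print(default_latex("pi", "lit"))  # \text{pi}
```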
Note
For best performance, we recommend selecting symbols from the predefined set via the constructor or from_symbol_list rather than adding custom ones. Predefined symbols have C implementations and benefit from Cython acceleration; custom symbols always fall back to Python callables.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
>>> library.add_symbol("C", "const", 5, "C[{}]", r"c_{}")
>>> library.add_symbol("X_0", "var", 5, "X[:, 0]", r"X_0")
>>> library.add_symbol("pi", "lit", 5, "np.pi", r"\pi")
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol
|
str
|
Token string to register. |
required |
symbol_type
|
str
|
One of |
required |
precedence
|
int
|
Operator precedence, used for infix reconstruction and PCFG generation.
For |
required |
np_fn
|
str
|
Python/NumPy expression string used in compiled callables
(e.g. |
required |
latex_str
|
Optional[str]
|
LaTeX template string with |
None
|
cython_id
|
int
|
Integer dispatch ID used by the Cython-based evaluator.
|
-1
|
Raises:
| Type | Description |
|---|---|
ValueError
|
If |
Source code in SRToolkit/utils/symbol_library.py
remove_symbol
Remove a token from the library.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> len(library.symbols)
1
>>> library.remove_symbol("x")
>>> len(library.symbols)
0
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol
|
str
|
Token string to remove. |
required |
Raises:
| Type | Description |
|---|---|
KeyError
|
If |
Source code in SRToolkit/utils/symbol_library.py
get_type
Return the type of a symbol.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.get_type("x")
'var'
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol
|
str
|
Token to look up. |
required |
Returns:
| Type | Description |
|---|---|
str
|
The type string ( |
Source code in SRToolkit/utils/symbol_library.py
get_precedence
Return the precedence of a symbol.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.get_precedence("x")
0
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol
|
str
|
Token to look up. |
required |
Returns:
| Type | Description |
|---|---|
int
|
The precedence value if the symbol is in the library, otherwise |
Source code in SRToolkit/utils/symbol_library.py
get_np_fn
Return the NumPy function string for a symbol.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.get_np_fn("x")
'x'
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol
|
str
|
Token to look up. |
required |
Returns:
| Type | Description |
|---|---|
str
|
The NumPy function string if the symbol is in the library, otherwise an empty string. |
Source code in SRToolkit/utils/symbol_library.py
get_cython_id
Return the Cython stack machine dispatch ID for a symbol.
Returns -1 for symbols without a C implementation (they will fall
back to a Python callable during Cython evaluation).
Examples:
>>> library = SymbolLibrary.default_symbols()
>>> library.get_cython_id("+")
0
>>> library.get_cython_id("sin")
10
>>> library.get_cython_id("unknown_symbol")
-1
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol
|
str
|
Token to look up. |
required |
Returns:
| Type | Description |
|---|---|
int
|
Integer dispatch ID, or |
int
|
or has no Cython implementation. |
Source code in SRToolkit/utils/symbol_library.py
get_latex_str
Return the LaTeX template string for a symbol.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x", "test")
>>> library.get_latex_str("x")
'test'
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol
|
str
|
Token to look up. |
required |
Returns:
| Type | Description |
|---|---|
str
|
The LaTeX template string if the symbol is in the library, otherwise an empty string. |
Source code in SRToolkit/utils/symbol_library.py
get_symbols_of_type
Return all symbols of a given type.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.add_symbol("y", "var", 0, "y")
>>> library.get_symbols_of_type("var")
['x', 'y']
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbol_type
|
str
|
Type to filter by. One of |
required |
Returns:
| Type | Description |
|---|---|
List[str]
|
List of token strings matching the requested type. Returns an empty list |
List[str]
|
if no symbols match or if |
Source code in SRToolkit/utils/symbol_library.py
symbols2index
Return a mapping from each token to its index in insertion order.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x")
>>> library.add_symbol("y", "var", 0, "y")
>>> print(library.symbols2index())
{'x': 0, 'y': 1}
>>> library.remove_symbol("x")
>>> print(library.symbols2index())
{'y': 0}
Returns:
| Type | Description |
|---|---|
Dict[str, int]
|
Dict mapping each token string to its zero-based position in the library. |
Source code in SRToolkit/utils/symbol_library.py
from_symbol_list
staticmethod
Create a SymbolLibrary containing only the specified subset of default symbols.
The supported token names are those defined in default_symbols.
Examples:
>>> library = SymbolLibrary.from_symbol_list(["+", "*", "C"], num_variables=2)
>>> len(library.symbols)
5
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
symbols
|
List[str]
|
Token strings to include. Symbols not in the default set are silently ignored. |
required |
num_variables
|
int
|
Number of variable tokens ( |
25
|
Returns:
| Type | Description |
|---|---|
SymbolLibrary
|
A SymbolLibrary restricted to the requested symbols and variables. |
Source code in SRToolkit/utils/symbol_library.py
default_symbols
staticmethod
Return a SymbolLibrary pre-populated with standard mathematical symbols.
Supported tokens:
- Operators ("op"): +, -, *, /, ^
- Functions ("fn"): u-, sqrt, sin, cos, exp, tan, arcsin, arccos, arctan, sinh, cosh, tanh, floor, ceil, ln, log, ^-1, ^2, ^3, ^4, ^5
- Literals ("lit"): pi, e
- Free constant ("const"): C
- Variables ("var"): X_0 through X_{num_variables-1}, mapped to columns of the input array in order.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_variables
|
int
|
Number of variable tokens to include. Default is |
25
|
Returns:
| Type | Description |
|---|---|
SymbolLibrary
|
A SymbolLibrary populated with the symbols listed above. |
Source code in SRToolkit/utils/symbol_library.py
get_active
staticmethod
Return the currently active SymbolLibrary.
Checks, in order: (1) the context manager stack, (2) the module-level default set via set_default.
Returns:
| Type | Description |
|---|---|
SymbolLibrary
|
The active SymbolLibrary instance. |
Raises:
| Type | Description |
|---|---|
RuntimeError
|
If no library is active and no default has been set. |
Source code in SRToolkit/utils/symbol_library.py
get_or_default
staticmethod
Return the active SymbolLibrary, falling back to default_symbols when nothing is active.
Checks, in order: (1) the context manager stack, (2) the module-level default set via set_default, (3) a freshly constructed default library.
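The lookup order can be sketched with a toy class. `Lib` is a hypothetical stand-in, not SymbolLibrary itself; it assumes the innermost `with` block wins and uses a plain class-level list as the stack, whereas the real implementation may use a different mechanism.

```python
from typing import List, Optional


class Lib:
    """Toy registry mirroring the documented lookup order (hypothetical)."""

    _stack: List["Lib"] = []         # pushed/popped by the context manager
    _default: Optional["Lib"] = None  # module-level fallback

    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self) -> "Lib":
        Lib._stack.append(self)
        return self

    def __exit__(self, *exc) -> None:
        Lib._stack.pop()

    @staticmethod
    def set_default(lib: Optional["Lib"]) -> None:
        Lib._default = lib  # passing None clears the default

    @staticmethod
    def get_active() -> "Lib":
        if Lib._stack:
            return Lib._stack[-1]
        if Lib._default is not None:
            return Lib._default
        raise RuntimeError("no active library and no default set")

    @staticmethod
    def get_or_default() -> "Lib":
        try:
            return Lib.get_active()
        except RuntimeError:
            return Lib("fresh-default")  # stands in for default_symbols()


Lib.set_default(Lib("session-default"))
with Lib("outer"):
    with Lib("inner"):
        print(Lib.get_active().name)  # inner
print(Lib.get_active().name)          # session-default
Lib.set_default(None)
print(Lib.get_or_default().name)      # fresh-default
```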
Returns:
| Type | Description |
|---|---|
SymbolLibrary
|
The active or default SymbolLibrary instance. |
Source code in SRToolkit/utils/symbol_library.py
set_default
staticmethod
Set (or clear) a module-level default SymbolLibrary.
The default is used as a fallback by get_active and get_or_default when no context manager is active. It is module-global (not per-thread or per-task) and intended for scripts and notebooks where a single library is used throughout a session.
Pass None to clear the default.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
sl
|
Optional[SymbolLibrary]
|
Library to set as the module-level default, or |
required |
Source code in SRToolkit/utils/symbol_library.py
to_dict
Serialize the library to a JSON-safe dictionary.
Examples:
>>> library = SymbolLibrary.from_symbol_list(["+"], num_variables=1)
>>> d = library.to_dict()
>>> d["format_version"]
1
>>> d["num_variables"]
1
Returns:
| Type | Description |
|---|---|
dict
|
A dictionary suitable for passing to from_dict. |
Source code in SRToolkit/utils/symbol_library.py
from_dict
staticmethod
Reconstruct a SymbolLibrary from a dictionary produced by to_dict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
d
|
dict
|
Dictionary representation of the library, as produced by to_dict. |
required |
Returns:
| Type | Description |
|---|---|
SymbolLibrary
|
The reconstructed SymbolLibrary. |
Raises:
| Type | Description |
|---|---|
ValueError
|
If |
Source code in SRToolkit/utils/symbol_library.py
__len__
Return the number of symbols currently in the library.
Examples:
>>> library = SymbolLibrary.default_symbols(5)
>>> len(library)
34
>>> library.add_symbol("a", "lit", 5, "a", "a")
>>> len(library)
35
Returns:
| Type | Description |
|---|---|
int
|
Number of tokens registered in the library. |
Source code in SRToolkit/utils/symbol_library.py
__str__
Return a comma-separated string of all registered token strings.
Examples:
>>> library = SymbolLibrary()
>>> library.add_symbol("x", "var", 0, "x", "x")
>>> str(library)
'x'
>>> library.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
>>> str(library)
'x, sin'
Returns:
| Type | Description |
|---|---|
str
|
All token names joined by |
Source code in SRToolkit/utils/symbol_library.py
__copy__
Return a copy of the library with independent copies of all attributes.
Examples:
>>> old_symbols = SymbolLibrary()
>>> old_symbols.add_symbol("x", "var", 0, "x", "x")
>>> print(old_symbols)
x
>>> new_symbols = copy.copy(old_symbols)
>>> new_symbols.add_symbol("sin", "fn", 5, "np.sin({})", r"\sin {}")
>>> print(old_symbols)
x
>>> print(new_symbols)
x, sin
Returns:
| Type | Description |
|---|---|
SymbolLibrary
|
A new SymbolLibrary instance with deep-copied symbols and preamble. |
Source code in SRToolkit/utils/symbol_library.py
EstimationSettings
Bases: TypedDict
Shared settings for parameter estimation and BED evaluation.
Passed as **kwargs to SR_dataset, SR_evaluator, and
ParameterEstimator. All fields are optional.
Examples:
>>> settings: EstimationSettings = {"method": "L-BFGS-B", "max_iter": 200}
>>> settings.get("method")
'L-BFGS-B'
>>> settings.get("tol", 1e-6)
1e-06
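Because every field is optional, a settings dict can carry only the keys the caller cares about and be unpacked as keyword arguments. The sketch below uses a hypothetical `SketchSettings` type with three of the documented field names and a hypothetical `summarise` consumer; it does not call SR_dataset, SR_evaluator, or ParameterEstimator.

```python
from typing import TypedDict


class SketchSettings(TypedDict, total=False):
    """Hypothetical stand-in; total=False makes every key optional."""
    method: str
    tol: float
    max_iter: int


def summarise(**kwargs) -> str:
    """Hypothetical consumer: reports only the keys that were supplied."""
    return ", ".join(f"{key}={value!r}" for key, value in sorted(kwargs.items()))


settings: SketchSettings = {"method": "L-BFGS-B", "max_iter": 200}
print(summarise(**settings))  # max_iter=200, method='L-BFGS-B'
```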
Attributes:
| Name | Type | Description |
|---|---|---|
method |
str
|
Optimization algorithm for parameter fitting. Default: |
tol |
float
|
Termination tolerance for the optimizer. Default: |
gtol |
float
|
Gradient-norm termination tolerance. Default: |
max_iter |
int
|
Maximum optimizer iterations. Default: |
constant_bounds |
Tuple[float, float]
|
|
initialization |
str
|
Constant initialization strategy — |
max_constants |
int
|
Maximum number of free constants permitted in a single
expression. Expressions exceeding this limit score |
max_expr_length |
int
|
Maximum expression length in tokens. |
backend |
str
|
Evaluation backend used by |
num_points_sampled |
int
|
Number of domain points used when evaluating expression
behavior for BED. |
bed_X |
Optional[ndarray]
|
Fixed evaluation points for BED. If |
num_consts_sampled |
int
|
Number of constant vectors sampled per expression for
BED. Default: |
domain_bounds |
Optional[List[Tuple[float, float]]]
|
Per-variable |
EvalResult
dataclass
EvalResult(min_error: float, best_expr: str, num_evaluated: int, evaluation_calls: int, top_models: List[ModelResult], all_models: List[ModelResult], approach_name: str, success: bool, dataset_name: Optional[str] = None, metadata: Optional[dict] = None, augmentations: Dict[str, Dict[str, Any]] = dict())
Result for a single SR experiment, as returned by SR_results[i].
Examples:
>>> model = ModelResult(expr=["X_0"], error=0.05)
>>> result = EvalResult(
... min_error=0.05,
... best_expr="X_0",
... num_evaluated=500,
... evaluation_calls=612,
... top_models=[model],
... all_models=[model],
... approach_name="MyApproach",
... success=True,
... )
>>> result.min_error
0.05
>>> result.success
True
>>> result.dataset_name is None
True
Attributes:
| Name | Type | Description |
|---|---|---|
min_error |
float
|
Lowest error achieved across all evaluated expressions. |
best_expr |
str
|
String representation of the best expression found. |
num_evaluated |
int
|
Number of unique expressions evaluated. |
evaluation_calls |
int
|
Number of times |
top_models |
List[ModelResult]
|
Top-k models sorted by error. |
all_models |
List[ModelResult]
|
All evaluated models sorted by error. |
approach_name |
str
|
Name of the SR approach, or empty string if not provided. |
success |
bool
|
Whether |
dataset_name |
Optional[str]
|
Name of the dataset. |
metadata |
Optional[dict]
|
Arbitrary metadata dict associated with the dataset. |
augmentations |
Dict[str, Dict[str, Any]]
|
Per-augmenter data keyed by augmenter name. Populated by ResultAugmenter subclasses via add_augmentation. |
add_augmentation
Attach augmentation data produced by a ResultAugmenter to this result.
If name is already present in augmentations, a numeric suffix is
appended (name_1, name_2, …) to avoid overwriting existing data.
Examples:
>>> model = ModelResult(expr=["X_0"], error=0.05)
>>> result = EvalResult(
... min_error=0.05, best_expr="X_0", num_evaluated=10,
... evaluation_calls=10, top_models=[model], all_models=[model],
... approach_name="MyApproach", success=True,
... )
>>> result.add_augmentation("complexity", {"value": 3}, "ComplexityAugmenter")
>>> result.augmentations["complexity"]["value"]
3
>>> result.add_augmentation("complexity", {"value": 5}, "ComplexityAugmenter")
>>> "complexity_1" in result.augmentations
True
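The suffixing behaviour can be sketched as a key-selection helper. `unique_key` is a hypothetical function that mirrors the description (name, then name_1, name_2, and so on); how the library stores the augmenter type alongside the data is not shown here.

```python
from typing import Any, Dict


def unique_key(name: str, existing: Dict[str, Any]) -> str:
    """Return name if free, otherwise the first free name_1, name_2, ... (sketch)."""
    if name not in existing:
        return name
    suffix = 1
    while f"{name}_{suffix}" in existing:
        suffix += 1
    return f"{name}_{suffix}"


store: Dict[str, Dict[str, Any]] = {}
for value in (3, 5, 8):
    store[unique_key("complexity", store)] = {"value": value}
print(list(store))  # ['complexity', 'complexity_1', 'complexity_2']
```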
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name
|
str
|
Key under which the augmentation is stored in augmentations. |
required |
data
|
Dict[str, Any]
|
Arbitrary dict of augmentation data. Any existing |
required |
aug_type
|
str
|
Augmenter class name, stored as |
required |
Source code in SRToolkit/utils/types.py
to_dict
Serialize this evaluation result to a JSON-safe dictionary.
NumPy arrays and scalars within nested ModelResult entries are
converted to native Python types so the result can be passed directly
to json.dump.
Examples:
>>> model = ModelResult(expr=["X_0"], error=0.05)
>>> result = EvalResult(
... min_error=0.05, best_expr="X_0", num_evaluated=10,
... evaluation_calls=10, top_models=[model], all_models=[model],
... approach_name="MyApproach", success=True,
... )
>>> d = result.to_dict()
>>> d["min_error"]
0.05
>>> d["approach_name"]
'MyApproach'
>>> len(d["top_models"])
1
Returns:
| Type | Description |
|---|---|
dict
|
A JSON-safe dictionary suitable for passing to from_dict. |
Source code in SRToolkit/utils/types.py
from_dict
staticmethod
Reconstruct an EvalResult from a dictionary produced by to_dict.
Examples:
>>> model = ModelResult(expr=["X_0"], error=0.05)
>>> result = EvalResult(
... min_error=0.05, best_expr="X_0", num_evaluated=10,
... evaluation_calls=10, top_models=[model], all_models=[model],
... approach_name="MyApproach", success=True,
... )
>>> result2 = EvalResult.from_dict(result.to_dict())
>>> result2.min_error
0.05
>>> result2.best_expr
'X_0'
>>> len(result2.top_models)
1
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data
|
dict
|
Dictionary representation of an EvalResult, as produced by to_dict. |
required |
Returns:
| Type | Description |
|---|---|
EvalResult
|
The reconstructed EvalResult. |
Source code in SRToolkit/utils/types.py
ModelResult
dataclass
ModelResult(expr: List[str], error: float, parameters: Optional[ndarray] = None, augmentations: Dict[str, Dict[str, Any]] = dict())
A single model entry in EvalResult.top_models and EvalResult.all_models.
Examples:
>>> result = ModelResult(expr=["C", "*", "X_0"], error=0.42)
>>> result.expr
['C', '*', 'X_0']
>>> result.error
0.42
>>> result.parameters is None
True
Attributes:
| Name | Type | Description |
|---|---|---|
expr |
List[str]
|
Token list representing the expression, e.g. |
error |
float
|
Numeric error under the ranking function (RMSE or BED). |
parameters |
Optional[ndarray]
|
Fitted constant values. Present for RMSE ranking only, |
augmentations |
Dict[str, Dict[str, Any]]
|
Per-augmenter data keyed by augmenter name. Populated by ResultAugmenter subclasses via add_augmentation. |
add_augmentation
Attach augmentation data produced by a ResultAugmenter to this result.
If name is already present in augmentations, a numeric suffix is
appended (name_1, name_2, …) to avoid overwriting existing data.
Examples:
>>> result = ModelResult(expr=["X_0"], error=0.1)
>>> result.add_augmentation("latex", {"value": "$X_0$"}, "LaTeXAugmenter")
>>> result.augmentations["latex"]["value"]
'$X_0$'
>>> result.add_augmentation("latex", {"value": "$X_0$"}, "LaTeXAugmenter")
>>> "latex_1" in result.augmentations
True
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name
|
str
|
Key under which the augmentation is stored in augmentations. |
required |
data
|
Dict[str, Any]
|
Arbitrary dict of augmentation data. Any existing |
required |
aug_type
|
str
|
Augmenter class name, stored as |
required |
Source code in SRToolkit/utils/types.py
to_dict
Serialize this model result to a JSON-safe dictionary.
NumPy arrays and scalars are converted to native Python types so the
result can be passed directly to json.dump.
Examples:
>>> result = ModelResult(expr=["X_0", "+", "C"], error=0.25)
>>> d = result.to_dict()
>>> d["expr"]
['X_0', '+', 'C']
>>> d["error"]
0.25
>>> d["parameters"] is None
True
Returns:
| Type | Description |
|---|---|
| dict | A JSON-safe dictionary suitable for passing to from_dict. |
Source code in SRToolkit/utils/types.py
from_dict
staticmethod
Reconstruct a ModelResult from a dictionary produced by to_dict.
Examples:
>>> result = ModelResult(expr=["X_0", "+", "C"], error=0.25)
>>> result2 = ModelResult.from_dict(result.to_dict())
>>> result2.expr
['X_0', '+', 'C']
>>> result2.error
0.25
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict | Dictionary representation of a ModelResult, as produced by to_dict. | required |
Returns:
| Type | Description |
|---|---|
ModelResult
|
The reconstructed ModelResult. |
Source code in SRToolkit/utils/types.py
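The to_dict / from_dict round trip can be illustrated with a small stand-in dataclass. `MiniResult` is hypothetical and only mirrors three of the fields; `array.array` stands in for a NumPy array so the sketch has no third-party dependency. Converting through `tolist()` is an assumption about how array-like values might be made JSON-safe, and this sketch leaves parameters as a plain list after loading.

```python
import json
from array import array
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class MiniResult:
    """Hypothetical stand-in for ModelResult (illustration only)."""

    expr: List[str]
    error: float
    parameters: Optional[Any] = None

    def to_dict(self) -> Dict[str, Any]:
        params = self.parameters
        # Array-like objects expose tolist(); use it to get native Python values.
        if params is not None and hasattr(params, "tolist"):
            params = params.tolist()
        return {"expr": list(self.expr), "error": float(self.error), "parameters": params}

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "MiniResult":
        return MiniResult(expr=data["expr"], error=data["error"], parameters=data["parameters"])


original = MiniResult(["X_0", "+", "C"], 0.25, array("d", [1.5]))
# Serialize to a JSON string and back to show the dictionary is JSON-safe.
restored = MiniResult.from_dict(json.loads(json.dumps(original.to_dict())))
print(restored.expr, restored.error, restored.parameters)
```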
compile_expr
compile_expr(expr: Union[List[str], Node], symbol_library: Optional[SymbolLibrary] = None, backend: str = 'stack') -> Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]
Compile an expression into a callable f(X, C) → np.ndarray.
Examples:
>>> f = compile_expr(["X_0", "+", "1"])
>>> f(np.array([[1.0], [2.0], [3.0]]), None)
array([2., 3., 4.])
>>> f = compile_expr(["X_0", "+", "1"], backend="codegen")
>>> f(np.array([[1], [2], [3]]), np.array([]))
array([2, 3, 4])
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr
|
Union[List[str], Node]
|
Expression as a token list in infix notation or a Node tree. |
required |
symbol_library
|
Optional[SymbolLibrary]
|
Symbol library used to look up token types. Defaults to SymbolLibrary.default_symbols. |
None
|
backend
|
str
|
Evaluation backend. One of:
|
'stack'
|
Returns:
| Type | Description |
|---|---|
Callable[[ndarray, Optional[ndarray]], ndarray]
|
A callable |
Callable[[ndarray, Optional[ndarray]], ndarray]
|
|
Callable[[ndarray, Optional[ndarray]], ndarray]
|
(pass |
Callable[[ndarray, Optional[ndarray]], ndarray]
|
Returns a 1-D output array of shape |
Raises:
| Type | Description |
|---|---|
ValueError
|
If |
Source code in SRToolkit/utils/expression_compiler.py
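As a rough picture of what a stack-style evaluator does, the sketch below compiles a postfix token list into a callable f(X, C). It is an illustration under stated assumptions, not the compile_expr implementation: `compile_postfix` is a hypothetical name, input is postfix rather than infix, rows are plain Python lists, and only four binary operators are supported.

```python
import operator
from typing import Callable, List, Optional, Sequence

# Assumed operator set for this sketch only.
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}


def compile_postfix(tokens: List[str]) -> Callable[..., List[float]]:
    """Return f(X, C) that evaluates the postfix expression once per row of X."""

    def f(X: Sequence[Sequence[float]], C: Optional[Sequence[float]] = None) -> List[float]:
        out = []
        for row in X:
            stack: List[float] = []
            const_idx = 0  # each "C" token consumes the next constant
            for tok in tokens:
                if tok in OPS:
                    right = stack.pop()
                    left = stack.pop()
                    stack.append(OPS[tok](left, right))
                elif tok.startswith("X_"):
                    stack.append(row[int(tok[2:])])
                elif tok == "C":
                    stack.append(C[const_idx])
                    const_idx += 1
                else:
                    stack.append(float(tok))  # numeric literal
            out.append(stack.pop())
        return out

    return f


f = compile_postfix(["X_0", "1", "+"])  # X_0 + 1 in postfix
result = f([[1.0], [2.0], [3.0]], None)
print(result)  # [2.0, 3.0, 4.0]
```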
compile_expr_rmse
compile_expr_rmse(expr: Union[List[str], Node], symbol_library: Optional[SymbolLibrary] = None, backend: str = 'stack', X: Optional[ndarray] = None) -> Callable[[np.ndarray, np.ndarray, np.ndarray], float]
Compile an expression into an RMSE callable f(X, C, y) → float.
Examples:
>>> f = compile_expr_rmse(["X_0", "+", "1"])
>>> f(np.array([[1.0], [2.0], [3.0]]), np.array([]), np.array([2.0, 3.0, 4.0]))
0.0
>>> f = compile_expr_rmse(["X_0", "+", "1"], backend="codegen")
>>> print(float(f(np.array([[1], [2], [3]]), np.array([]), np.array([2, 3, 4]))))
0.0
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr
|
Union[List[str], Node]
|
Expression as a token list in infix notation or a Node tree. |
required |
symbol_library
|
Optional[SymbolLibrary]
|
Symbol library used to look up token types. Defaults to SymbolLibrary.default_symbols. |
None
|
backend
|
str
|
Evaluation backend. One of:
|
'stack'
|
X
|
Optional[ndarray]
|
Optional input data of shape |
None
|
Returns:
| Type | Description |
|---|---|
Callable[[ndarray, ndarray, ndarray], float]
|
A callable |
Raises:
| Type | Description |
|---|---|
ValueError
|
If |
Source code in SRToolkit/utils/expression_compiler.py
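The value returned by the compiled callable is a root-mean-square error, RMSE = sqrt(mean((ŷ − y)²)). A minimal pure-Python sketch of that formula follows; `rmse` is a hypothetical helper operating on plain lists, whereas the library works on NumPy arrays.

```python
import math
from typing import Sequence


def rmse(predicted: Sequence[float], target: Sequence[float]) -> float:
    """Root-mean-square error between two equally long sequences."""
    if len(predicted) != len(target):
        raise ValueError("predicted and target must have the same length")
    squared = [(p - t) ** 2 for p, t in zip(predicted, target)]
    return math.sqrt(sum(squared) / len(squared))


print(rmse([2.0, 3.0, 4.0], [2.0, 3.0, 4.0]))  # 0.0
print(rmse([1.0, 3.0], [2.0, 2.0]))            # 1.0
```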
generate_n_expressions
generate_n_expressions(expression_description: Union[str, SymbolLibrary, Grammar], num_expressions: int, unique: bool = True, max_expression_length: int = 50, verbose: bool = True, max_consecutive_generation_failures: int = 100, max_consecutive_uniqueness_failures: int = 200, max_derivation_steps: int = 1000, start: str = 'E') -> List[List[str]]
Sample num_expressions expressions from a grammar or symbol library.
Examples:
>>> len(generate_n_expressions(SymbolLibrary.default_symbols(5), 100, verbose=False))
100
>>> generate_n_expressions(SymbolLibrary.from_symbol_list([], 1), 3, unique=False, verbose=False, max_expression_length=1)
[['X_0'], ['X_0'], ['X_0']]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expression_description
|
Union[str, SymbolLibrary, Grammar]
|
Grammar source — one of:
|
required |
num_expressions
|
int
|
Number of expressions to generate. |
required |
unique
|
bool
|
If |
True
|
max_expression_length
|
int
|
Maximum token count per expression. Values ≤ |
50
|
verbose
|
bool
|
Display a progress bar showing total attempts, the ratio of invalid
expressions (derivation failed or exceeded |
True
|
max_consecutive_generation_failures
|
int
|
Maximum number of consecutive attempts that
produce an invalid expression (derivation failed or result exceeded
|
100
|
max_consecutive_uniqueness_failures
|
int
|
Maximum number of consecutive valid expressions
that are already in the output set before stopping early and returning what has
been collected so far. Only relevant when |
200
|
max_derivation_steps
|
int
|
Maximum number of rule applications per single derivation
attempt before it is abandoned. Increase for grammars with deep recursion that
legitimately require many steps. Default |
1000
|
start
|
str
|
Start non-terminal used when |
'E'
|
Returns:
| Type | Description |
|---|---|
List[List[str]]
|
List of expressions, each represented as a list of string tokens in infix notation. |
List[List[str]]
|
May contain fewer than |
List[List[str]]
|
search space is exhausted before the target count is reached. |
Raises:
| Type | Description |
|---|---|
Exception
|
If |
Exception
|
If generation fails |
Warns:
| Type | Description |
|---|---|
UserWarning
|
If more than 80 % of attempts produce an invalid expression (after at
least |
UserWarning
|
If |
Source code in SRToolkit/utils/expression_generator.py
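The interplay of the two consecutive-failure limits can be sketched as a plain rejection loop. Everything here is illustrative: `sample_unique` and `draw` are hypothetical names, `draw` returning None models an invalid derivation, and raising RuntimeError after too many invalid draws is an assumption of this sketch.

```python
from itertools import cycle
from typing import Callable, List, Optional


def sample_unique(
    draw: Callable[[], Optional[List[str]]],
    num_expressions: int,
    max_generation_failures: int = 100,
    max_uniqueness_failures: int = 200,
) -> List[List[str]]:
    collected: List[List[str]] = []
    seen = set()
    generation_failures = 0   # consecutive invalid draws
    uniqueness_failures = 0   # consecutive valid draws that were duplicates
    while len(collected) < num_expressions:
        expr = draw()
        if expr is None:
            generation_failures += 1
            if generation_failures >= max_generation_failures:
                raise RuntimeError("too many consecutive invalid expressions")
            continue
        generation_failures = 0
        key = tuple(expr)
        if key in seen:
            uniqueness_failures += 1
            if uniqueness_failures >= max_uniqueness_failures:
                break  # stop early and return what was collected
            continue
        uniqueness_failures = 0
        seen.add(key)
        collected.append(expr)
    return collected


# Deterministic stand-in for a grammar sampler: only two distinct expressions exist.
candidates = cycle([["X_0"], None, ["X_0", "+", "1"], ["X_0"]])
expressions = sample_unique(lambda: next(candidates), 5, max_uniqueness_failures=3)
print(expressions)  # [['X_0'], ['X_0', '+', '1']]
```

Because the search space holds only two expressions, the loop returns fewer than the five requested once three duplicates in a row have been drawn.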
simplify
simplify(expr: Union[List[str], Node], symbol_library: Optional[SymbolLibrary] = None) -> Union[List[str], Node]
Simplify an expression algebraically.
Two successive steps are applied:
- SymPy simplification: expands and reduces the expression algebraically (e.g. X_0 * X_1 / X_0 → X_1).
- Constant folding: collapses any sub-expression containing no variables into a single free constant C (e.g. C * C + C → C).
Examples:
>>> expr = ["C", "+", "C", "*", "C", "+", "X_0", "*", "X_1", "/", "X_0"]
>>> print("".join(simplify(expr)))
C+X_1
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| expr | Union[List[str], Node] | Expression as a token list in infix notation or a Node tree. | required |
| symbol_library | Optional[SymbolLibrary] | Symbol library defining variables and constants. Defaults to SymbolLibrary.default_symbols. | None |
Returns:
| Type | Description |
|---|---|
| Union[List[str], Node] | The simplified expression in the same form as the input (list if a list was given, Node if a tree was given). |
Raises:
| Type | Description |
|---|---|
Exception
|
If simplification fails or the result contains tokens absent from
|
Source code in SRToolkit/utils/expression_simplifier.py
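The constant-folding step can be pictured on a tiny tuple-based tree. This is a sketch under assumptions, not the library's routine: trees are `(op, left, right)` tuples or leaf strings, only the token C counts as a constant, and `fold_constants` is a hypothetical name.

```python
from typing import Tuple, Union

Tree = Union[str, Tuple[str, "Tree", "Tree"]]


def fold_constants(tree: Tree) -> Tree:
    """Collapse every binary sub-expression made only of C leaves into one C."""
    if isinstance(tree, str):
        return tree  # leaf: variable or constant, left unchanged
    op, left, right = tree
    left, right = fold_constants(left), fold_constants(right)
    if left == "C" and right == "C":
        return "C"  # no variables below this node
    return (op, left, right)


# C * C + C  ->  C
print(fold_constants(("+", ("*", "C", "C"), "C")))
# (C + C * C) + X_1  ->  C + X_1
print(fold_constants(("+", ("+", "C", ("*", "C", "C")), "X_1")))
```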
expr_to_latex
Convert an expression to a LaTeX string.
Examples:
>>> expr_to_latex(["(", "X_0", "+", "X_1", ")"], SymbolLibrary.default_symbols())
'$X_{0} + X_{1}$'
>>> expr = Node("+", Node("X_0"), Node("1"))
>>> expr_to_latex(expr, SymbolLibrary.default_symbols())
'$1 + X_{0}$'
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr
|
Union[Node, List[str]]
|
Expression as a token list or a Node tree. |
required |
symbol_library
|
Optional[SymbolLibrary]
|
Symbol library providing LaTeX templates. If None, falls back to the currently active library set via 'with SymbolLibrary(...) as sl:'. Defaults to None. |
None
|
Returns:
| Type | Description |
|---|---|
str
|
A LaTeX string of the form |
Source code in SRToolkit/utils/expression_tree.py
is_float
Return True if element can be interpreted as a floating-point number.
Examples:
>>> is_float(1.0)
True
>>> is_float("1.0")
True
>>> is_float("1")
True
>>> is_float(None)
False
>>> is_float("hello")
False
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| element | Any | Value to test. | required |
Returns:
| Type | Description |
|---|---|
bool
|
|
Source code in SRToolkit/utils/expression_tree.py
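One common way to implement such a check is to attempt the conversion and catch the failure. The sketch below reproduces the documented examples; `looks_like_float` is a hypothetical name and the library's own code may differ. Note that under this approach strings such as "nan" and "inf" also count as floats.

```python
from typing import Any


def looks_like_float(element: Any) -> bool:
    """True if float(element) succeeds."""
    try:
        float(element)
        return True
    except (TypeError, ValueError):
        # TypeError: e.g. None; ValueError: e.g. "hello"
        return False


for value in (1.0, "1.0", "1", None, "hello"):
    print(repr(value), looks_like_float(value))
```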
tokens_to_tree
Parse a token list into an expression tree using the shunting-yard algorithm.
Examples:
>>> tree = tokens_to_tree(["(", "X_0", "+", "X_1", ")"], SymbolLibrary.default_symbols())
>>> len(tree)
3
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tokens | List[str] | Token list in infix notation. | required |
| sl | Optional[SymbolLibrary] | Symbol library used to resolve token types and precedences. If None, falls back to the currently active library set via 'with SymbolLibrary(...) as sl:'. Defaults to None. | None |
Returns:
| Type | Description |
|---|---|
| Node | Root Node of the parsed expression tree. |
Raises:
| Type | Description |
|---|---|
Exception
|
If a token is absent from |
Source code in SRToolkit/utils/expression_tree.py
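For readers unfamiliar with the shunting-yard algorithm, here is a compact sketch that builds a tree from infix tokens. It is not the library's parser: the precedence table is an assumption, only left-associative binary operators and parentheses are handled, and the result is a nested `(op, left, right)` tuple rather than a Node.

```python
from typing import List, Tuple, Union

Tree = Union[str, Tuple[str, "Tree", "Tree"]]

# Assumed precedences for this sketch; the library reads them from the SymbolLibrary.
PRECEDENCE = {"+": 1, "-": 1, "*": 2, "/": 2}


def parse_infix(tokens: List[str]) -> Tree:
    output: List[Tree] = []  # stack of finished subtrees
    ops: List[str] = []      # stack of pending operators and "("

    def reduce() -> None:
        op = ops.pop()
        right = output.pop()
        left = output.pop()
        output.append((op, left, right))

    for tok in tokens:
        if tok == "(":
            ops.append(tok)
        elif tok == ")":
            while ops and ops[-1] != "(":
                reduce()
            if not ops:
                raise ValueError("unbalanced parentheses")
            ops.pop()  # discard the matching "("
        elif tok in PRECEDENCE:
            # Left-associative: apply pending operators of equal or higher precedence first.
            while ops and ops[-1] != "(" and PRECEDENCE[ops[-1]] >= PRECEDENCE[tok]:
                reduce()
            ops.append(tok)
        else:
            output.append(tok)  # operand
    while ops:
        if ops[-1] == "(":
            raise ValueError("unbalanced parentheses")
        reduce()
    if len(output) != 1:
        raise ValueError("malformed expression")
    return output[0]


print(parse_infix(["(", "X_0", "+", "X_1", ")"]))      # ('+', 'X_0', 'X_1')
print(parse_infix(["X_0", "+", "X_1", "*", "2"]))      # ('+', 'X_0', ('*', 'X_1', '2'))
```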
bed
bed(expr1: Union[Node, List[str], ndarray], expr2: Union[Node, List[str], ndarray], X: Optional[ndarray] = None, num_consts_sampled: int = 32, num_points_sampled: int = 64, domain_bounds: Optional[List[Tuple[float, float]]] = None, consts_bounds: Tuple[float, float] = (-5, 5), symbol_library: Optional[SymbolLibrary] = None, seed: Optional[int] = None) -> float
Compute the Behavior-aware Expression Distance (BED) between two expressions.
BED measures how similarly two expressions behave over a domain by comparing their output distributions point-by-point using the Wasserstein distance. Free constants are marginalised by sampling multiple constant vectors via Latin Hypercube Sampling.
Either X or domain_bounds must be provided when expressions are given as
token lists or Node trees. Pre-computed behavior matrices can be passed
directly to avoid redundant evaluation.
Examples:
>>> X = np.random.rand(10, 2) - 0.5
>>> expr1 = ["X_0", "+", "C"] # instances of SRToolkit.utils.expression_tree.Node work as well
>>> expr2 = ["X_1", "+", "C"]
>>> bed(expr1, expr2, X) < 1
True
>>> # Changing the number of sampled constants
>>> bed(expr1, expr2, X, num_consts_sampled=128, consts_bounds=(-2, 2)) < 1
True
>>> # Sampling X instead of giving it directly by defining a domain
>>> bed(expr1, expr2, domain_bounds=[(0, 1), (0, 1)]) < 1
True
>>> bed(expr1, expr2, domain_bounds=[(0, 1), (0, 1)], num_points_sampled=128) < 1
True
>>> # Behavior matrices can be passed instead of expressions (this can save computation when the same expression is used multiple times)
>>> bm1 = create_behavior_matrix(expr1, X)
>>> bed(bm1, expr2, X) < 1
True
>>> bm2 = create_behavior_matrix(expr2, X)
>>> bed(bm1, bm2) < 1
True
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr1
|
Union[Node, List[str], ndarray]
|
First expression as a token list, a Node tree, or a pre-computed
behavior matrix of shape |
required |
expr2
|
Union[Node, List[str], ndarray]
|
Second expression in the same format as |
required |
X
|
Optional[ndarray]
|
Evaluation points of shape |
None
|
num_consts_sampled
|
int
|
Number of constant vectors sampled per expression. Default |
32
|
num_points_sampled
|
int
|
Number of points sampled from |
64
|
domain_bounds
|
Optional[List[Tuple[float, float]]]
|
Per-variable |
None
|
consts_bounds
|
Tuple[float, float]
|
|
(-5, 5)
|
symbol_library
|
Optional[SymbolLibrary]
|
Symbol library used to compile expressions. Defaults to SymbolLibrary.default_symbols. |
None
|
seed
|
Optional[int]
|
Random seed for reproducible sampling. Default |
None
|
Returns:
| Type | Description |
|---|---|
float
|
BED between the expressions as a non-negative float. A value of |
float
|
indicates identical behavior over the sampled domain; larger values indicate |
float
|
greater behavioral dissimilarity. Returns |
float
|
produces finite outputs for one expression but not the other. |
Raises:
| Type | Description |
|---|---|
Exception
|
If |
Exception
|
If |
ValueError
|
If any entry in |
ValueError
|
If the two behavior matrices have different numbers of rows. |
ValueError
|
If the behavior matrices have zero rows. |
Source code in SRToolkit/utils/measures.py
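The point-by-point comparison can be sketched on two small behavior matrices (rows are evaluation points, columns are constant samples). This is an illustration only: `wasserstein_1d` and `bed_sketch` are hypothetical names, both matrices are assumed to have the same number of columns, and averaging the per-point distances is an assumption about how they are aggregated. For two equally sized samples, the 1-D Wasserstein distance equals the mean absolute difference of the sorted values.

```python
from typing import List, Sequence


def wasserstein_1d(a: Sequence[float], b: Sequence[float]) -> float:
    """1-D Wasserstein distance between two equally sized empirical samples."""
    if len(a) != len(b):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


def bed_sketch(bm1: List[List[float]], bm2: List[List[float]]) -> float:
    """Average per-point Wasserstein distance between two behavior matrices."""
    if len(bm1) != len(bm2):
        raise ValueError("behavior matrices must have the same number of rows")
    if not bm1:
        raise ValueError("behavior matrices must have at least one row")
    return sum(wasserstein_1d(r1, r2) for r1, r2 in zip(bm1, bm2)) / len(bm1)


bm_a = [[0.0, 1.0], [2.0, 3.0]]
bm_b = [[1.0, 0.0], [3.0, 4.0]]  # row 0: same values reordered; row 1: shifted by 1
print(bed_sketch(bm_a, bm_a))  # 0.0
print(bed_sketch(bm_a, bm_b))  # 0.5
```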
create_behavior_matrix
create_behavior_matrix(expr: Union[Node, List[str]], X: ndarray, num_consts_sampled: int = 32, consts_bounds: Tuple[float, float] = (-5, 5), symbol_library: Optional[SymbolLibrary] = None, seed: Optional[int] = None) -> np.ndarray
Evaluate an expression over multiple constant samples to produce a behavior matrix.
For expressions with free constants, constants are drawn via Latin Hypercube Sampling
within consts_bounds. For constant-free expressions, all columns are identical.
Examples:
>>> X = np.random.rand(10, 2) - 0.5
>>> create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32).shape
(10, 32)
>>> mean_0_1 = np.mean(create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32, consts_bounds=(0, 1)))
>>> mean_1_5 = np.mean(create_behavior_matrix(["X_0", "+", "C"], X, num_consts_sampled=32, consts_bounds=(1, 5)))
>>> print(bool(mean_0_1 < mean_1_5))
True
>>> # Deterministic expressions always produce the same behavior matrix
>>> bm1 = create_behavior_matrix(["X_0", "+", "X_1"], X)
>>> bm2 = create_behavior_matrix(["X_0", "+", "X_1"], X)
>>> print(bool(np.array_equal(bm1, bm2)))
True
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr
|
Union[Node, List[str]]
|
Expression as a token list or a Node tree. |
required |
X
|
ndarray
|
Input data of shape |
required |
num_consts_sampled
|
int
|
Number of constant vectors to sample; sets the number of
output columns. Default |
32
|
consts_bounds
|
Tuple[float, float]
|
|
(-5, 5)
|
symbol_library
|
Optional[SymbolLibrary]
|
Symbol library used to compile the expression. Defaults to SymbolLibrary.default_symbols. |
None
|
seed
|
Optional[int]
|
Random seed for reproducible constant sampling. Default |
None
|
Returns:
| Type | Description |
|---|---|
ndarray
|
Behavior matrix of shape |
Raises:
| Type | Description |
|---|---|
Exception
|
If |
Source code in SRToolkit/utils/measures.py
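A behavior matrix can be pictured as "one row per input point, one column per sampled constant vector". The sketch below uses a hand-written Latin Hypercube sampler from the standard library; `latin_hypercube` and `behavior_matrix_sketch` are hypothetical names, the expression is passed as an ordinary Python function, and none of this is the library's sampling code.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple


def latin_hypercube(
    n_samples: int, n_dims: int, bounds: Tuple[float, float], seed: Optional[int] = None
) -> List[Tuple[float, ...]]:
    """Draw n_samples points so that each dimension has one point per stratum."""
    rng = random.Random(seed)
    low, high = bounds
    columns = []
    for _ in range(n_dims):
        # One uniform draw inside each of n_samples equal-width strata of [0, 1).
        points = [(i + rng.random()) / n_samples for i in range(n_samples)]
        rng.shuffle(points)  # decouple strata across dimensions
        columns.append([low + p * (high - low) for p in points])
    return [tuple(col[i] for col in columns) for i in range(n_samples)]


def behavior_matrix_sketch(
    f: Callable[[Sequence[float], Sequence[float]], float],
    X: Sequence[Sequence[float]],
    n_consts: int,
    num_consts_sampled: int = 32,
    consts_bounds: Tuple[float, float] = (-5, 5),
    seed: Optional[int] = None,
) -> List[List[float]]:
    samples = latin_hypercube(num_consts_sampled, n_consts, consts_bounds, seed)
    return [[f(row, c) for c in samples] for row in X]


X = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
# X_0 + C, written as a plain function of one row and one constant vector.
bm = behavior_matrix_sketch(lambda row, c: row[0] + c[0], X, n_consts=1, seed=0)
print(len(bm), len(bm[0]))  # 3 32
```

With `n_consts=0` every sampled constant vector is empty, so all columns of the matrix are identical, which matches the behavior described for constant-free expressions.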
edit_distance
edit_distance(expr1: Union[List[str], Node], expr2: Union[List[str], Node], notation: str = 'postfix', symbol_library: Optional[SymbolLibrary] = None) -> int
Compute the edit distance between two expressions.
Both expressions are first converted to the requested notation, so the result is independent of whether a token list or a Node tree is passed. Levenshtein distance is then computed on the serialised token sequences.
Examples:
>>> edit_distance(["X_0", "+", "1"], ["X_0", "+", "1"])
0
>>> edit_distance(["X_0", "+", "1"], ["X_0", "-", "1"])
1
>>> edit_distance(tokens_to_tree(["X_0", "+", "1"], SymbolLibrary.default_symbols(1)), tokens_to_tree(["X_0", "-", "1"], SymbolLibrary.default_symbols(1)))
1
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr1
|
Union[List[str], Node]
|
First expression as a token list or a Node tree. |
required |
expr2
|
Union[List[str], Node]
|
Second expression as a token list or a Node tree. |
required |
notation
|
str
|
Notation used for comparison: |
'postfix'
|
symbol_library
|
Optional[SymbolLibrary]
|
Symbol library used when converting expressions to the target notation. Defaults to SymbolLibrary.default_symbols. |
None
|
Returns:
| Type | Description |
|---|---|
| int | Integer edit distance between the two serialised expressions. |
Source code in SRToolkit/utils/measures.py
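The Levenshtein step operates on whole tokens rather than characters. A minimal dynamic-programming sketch follows; `levenshtein` is a hypothetical helper and the inputs are assumed to be already serialised in the same notation (postfix here).

```python
from typing import List, Sequence


def levenshtein(a: Sequence[str], b: Sequence[str]) -> int:
    """Minimum number of token insertions, deletions and substitutions turning a into b."""
    previous: List[int] = list(range(len(b) + 1))
    for i, tok_a in enumerate(a, start=1):
        current = [i]
        for j, tok_b in enumerate(b, start=1):
            current.append(
                min(
                    previous[j] + 1,                      # delete tok_a
                    current[j - 1] + 1,                   # insert tok_b
                    previous[j - 1] + (tok_a != tok_b),   # substitute or match
                )
            )
        previous = current
    return previous[-1]


print(levenshtein(["X_0", "1", "+"], ["X_0", "1", "+"]))  # 0
print(levenshtein(["X_0", "1", "+"], ["X_0", "1", "-"]))  # 1
```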
tree_edit_distance
tree_edit_distance(expr1: Union[Node, List[str]], expr2: Union[Node, List[str]], symbol_library: Optional[SymbolLibrary] = None) -> int
Compute the Zhang-Shasha tree edit distance between two expressions.
Unlike edit_distance, which operates on flattened token sequences, tree edit distance considers the expression's hierarchical structure. The cost is the minimum number of node insertions, deletions, and relabellings needed to transform one tree into the other.
Examples:
>>> tree_edit_distance(["X_0", "+", "1"], ["X_0", "+", "1"])
0
>>> tree_edit_distance(["X_0", "+", "1"], ["X_0", "-", "1"])
1
>>> tree_edit_distance(tokens_to_tree(["X_0", "+", "1"], SymbolLibrary.default_symbols(1)), tokens_to_tree(["X_0", "-", "1"], SymbolLibrary.default_symbols(1)))
1
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expr1
|
Union[Node, List[str]]
|
First expression as a token list or a Node tree. |
required |
expr2
|
Union[Node, List[str]]
|
Second expression as a token list or a Node tree. |
required |
symbol_library
|
Optional[SymbolLibrary]
|
Symbol library used when converting token lists to trees. Defaults to SymbolLibrary.default_symbols. |
None
|
Returns:
| Type | Description |
|---|---|
int
|
Integer tree edit distance. |
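The unit-cost tree edit distance can be written as a short memoized recursion over ordered forests. This sketch is not the Zhang-Shasha algorithm itself, which computes the same quantity more efficiently; it is only meant to make the definition concrete. Trees are hypothetical `(label, children)` tuples and `tree_distance` is an illustrative name.

```python
from functools import lru_cache
from typing import Tuple

# A tree is (label, (child, child, ...)); a forest is a tuple of trees.


def leaf(label: str) -> Tuple[str, tuple]:
    return (label, ())


@lru_cache(maxsize=None)
def forest_distance(f1: tuple, f2: tuple) -> int:
    if not f1 and not f2:
        return 0
    if not f1:
        _, children = f2[-1]
        return forest_distance(f1, f2[:-1] + children) + 1  # insert rightmost root of f2
    if not f2:
        _, children = f1[-1]
        return forest_distance(f1[:-1] + children, f2) + 1  # delete rightmost root of f1
    (label1, children1), (label2, children2) = f1[-1], f2[-1]
    return min(
        forest_distance(f1[:-1] + children1, f2) + 1,        # delete root of rightmost tree in f1
        forest_distance(f1, f2[:-1] + children2) + 1,        # insert root of rightmost tree in f2
        forest_distance(f1[:-1], f2[:-1])                    # match the two rightmost trees:
        + forest_distance(children1, children2)              #   align their children
        + (label1 != label2),                                #   relabel if the roots differ
    )


def tree_distance(t1: tuple, t2: tuple) -> int:
    return forest_distance((t1,), (t2,))


plus = ("+", (leaf("X_0"), leaf("1")))
minus = ("-", (leaf("X_0"), leaf("1")))
print(tree_distance(plus, plus))         # 0
print(tree_distance(plus, minus))        # 1
print(tree_distance(plus, leaf("X_0")))  # 2
```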