Grammar-Guided Expression Generation
SRToolkit includes a context-free grammar (CFG) and probabilistic context-free grammar (PCFG) engine for generating expressions with fine-grained structural control. Grammars let you specify exactly which expression shapes are legal — restricting depth, operator combinations, variable usage, or physical units — without modifying your search algorithm.
Core concepts
A grammar is a set of production rules. Each rule expands a non-terminal symbol into a sequence of terminals and non-terminals. A symbol is a non-terminal if and only if it appears on the left-hand side of at least one rule; everything else is a terminal.
Expressions are generated by starting from a chosen non-terminal (the start symbol) and repeatedly replacing non-terminals with one of their alternatives until only terminals remain.
Building a grammar
Import the building blocks from SRToolkit.utils.grammar, then add rules at construction time or incrementally with Grammar.add_rule:
```python
from SRToolkit.utils.grammar import Grammar, Rule

g = Grammar([
    Rule("E", ["E", "+", "F"], weight=0.4, name="E_add"),
    Rule("E", ["F"], weight=0.6, name="E_to_F"),
    Rule("F", ["x"], name="F_x"),
    Rule("F", ["y"], name="F_y"),
], start="E")

# Or incrementally:
g2 = Grammar(start="E")
g2.add_rule(Rule("E", ["x"]))
g2.add_rule(Rule("E", ["y"]))
```
Each Rule has four fields:
| Field | Type | Default | Purpose |
|---|---|---|---|
| lhs | str | (required) | Non-terminal being expanded |
| rhs | list[str] | (required) | Ordered replacement symbols |
| weight | float | 1.0 | Unnormalised sampling weight |
| name | str \| None | None | Stable identifier used by constraints |
When every rule keeps the default weight of 1.0, the grammar is a CFG and each non-terminal's alternatives are sampled uniformly. As soon as any rule has a weight other than 1.0, it becomes a PCFG; the engine normalises weights within each non-terminal's group of rules at sampling time.
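To make the normalisation concrete, the sketch below divides each weight by the sum of the weights in its non-terminal group, using the weights of grammar g above (E_add 0.4, E_to_F 0.6, and the two F rules at the default 1.0). It is a hand-written illustration with plain tuples standing in for Rule, not SRToolkit's sampling code.

```python
from collections import defaultdict

# (lhs, rhs, weight) tuples standing in for Rule objects (illustration only)
rules = [
    ("E", ("E", "+", "F"), 0.4),
    ("E", ("F",), 0.6),
    ("F", ("x",), 1.0),
    ("F", ("y",), 1.0),
]


def normalise(rules):
    """Return {lhs: [(rhs, probability), ...]}; probabilities sum to 1 per lhs."""
    totals = defaultdict(float)
    for lhs, _, weight in rules:
        totals[lhs] += weight
    probs = defaultdict(list)
    for lhs, rhs, weight in rules:
        probs[lhs].append((rhs, weight / totals[lhs]))
    return dict(probs)


for lhs, alternatives in normalise(rules).items():
    print(lhs, alternatives)
```

The two F rules share the default weight, so each ends up with probability 0.5; weights of 1 and 3 on two alternatives would likewise become 0.25 and 0.75.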
Useful introspection properties:
```python
g.nonterminals     # set[str] of all LHS symbols
g.rules_for("E")   # list[Rule] for that non-terminal, insertion order
g.is_pcfg()        # True if any rule has weight != 1.0
```
Generating expressions
Call Grammar.generate_one for a single expression. It retries up to max_retries times (default 10) whenever sampling exceeds max_steps (default 1000) rule applications, and returns None when every attempt fails, which is a signal to relax the grammar or increase the limits:
```python
tokens = g.generate_one(start="E", max_steps=500, max_retries=20)
if tokens is None:
    print("could not generate an expression within the budget")
```
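The retry behaviour can be pictured with the self-contained sketch below. It is not SRToolkit's implementation: it assumes leftmost expansion, weighted choice with random.choices, and a step counter that abandons an attempt once max_steps rule applications are used up. Here max_retries counts total attempts; the toy GRAMMAR dictionary is hypothetical.

```python
import random

# Hypothetical toy grammar: lhs -> list of (rhs, weight)
GRAMMAR = {
    "E": [(["E", "+", "F"], 0.4), (["F"], 0.6)],
    "F": [(["x"], 1.0), (["y"], 1.0)],
}


def sample_once(grammar, start, max_steps, rng):
    """Expand the leftmost non-terminal until none remain; None if the budget runs out."""
    symbols = [start]
    steps = 0
    while True:
        idx = next((i for i, s in enumerate(symbols) if s in grammar), None)
        if idx is None:
            return symbols          # only terminals left
        if steps >= max_steps:
            return None             # budget exhausted for this attempt
        alternatives = grammar[symbols[idx]]
        rhs = rng.choices([r for r, _ in alternatives],
                          weights=[w for _, w in alternatives])[0]
        symbols[idx:idx + 1] = rhs  # replace the slot with the chosen rhs
        steps += 1


def generate_one(grammar, start, max_steps=1000, max_retries=10, seed=None):
    rng = random.Random(seed)
    for _ in range(max_retries):
        tokens = sample_once(grammar, start, max_steps, rng)
        if tokens is not None:
            return tokens
    return None  # every attempt exceeded max_steps
```

With max_steps=0 no attempt can finish, so the sketch returns None, mirroring the failure signal described above.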
For more control, start a stateful Derivation and drive it manually or automatically:
```python
d = g.start_derivation("E")
# Option A: automatic. Sample to completion and get tokens directly
tokens = d.generate()  # default limit=1000

# Option B: manual. Inspect and apply rules yourself
d2 = g.start_derivation("E")
while not d2.complete:
    # Assuming options() keeps insertion order, the last E candidate for g is the
    # non-recursive E_to_F; always taking options()[0] (E_add) would never finish.
    rule = d2.options()[-1]
    d2.apply(rule)  # alternatively, d2.sample() applies a weight-sampled rule
tokens = d2.to_token_list()  # flat token list from the completed derivation
tree = d2.to_parse_tree()    # ParseTree for structural analysis
```
Derivation methods at a glance
| Method | Returns | Notes |
|---|---|---|
| d.options() | list[Rule] | Constraint-filtered candidates for the current slot |
| d.apply(rule) | None | Apply one rule; raises if wrong lhs or already complete |
| d.sample() | None | Apply one rule sampled by weight; raises if no candidates |
| d.generate(limit=1000) | list[str] | Sample to completion, return token list |
| d.to_token_list() | list[str] | Token list of a completed derivation |
| d.to_parse_tree() | ParseTree | Parse tree of a completed derivation |
| d.complete | bool | True when no unexpanded non-terminals remain |
to_token_list() and to_parse_tree() both raise RuntimeError if called before the derivation is complete.
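One plausible way to implement this interface is a flat symbol list in which the leftmost unexpanded non-terminal is always the current slot. The class below is a hypothetical, stripped-down sketch (no constraints, weights, or parse trees, and rules are plain (lhs, rhs) tuples); only the method names mirror the table.

```python
class ToyDerivation:
    """Hypothetical minimal derivation: a flat symbol list, expanded leftmost-first."""

    def __init__(self, rules, start):
        self.rules = rules  # list of (lhs, rhs) pairs
        self.nonterminals = {lhs for lhs, _ in rules}
        self.symbols = [start]

    def _current(self):
        """Index of the leftmost unexpanded non-terminal, or None."""
        return next((i for i, s in enumerate(self.symbols) if s in self.nonterminals), None)

    @property
    def complete(self):
        return self._current() is None

    def options(self):
        i = self._current()
        return [] if i is None else [r for r in self.rules if r[0] == self.symbols[i]]

    def apply(self, rule):
        i = self._current()
        if i is None:
            raise RuntimeError("derivation is already complete")
        if rule[0] != self.symbols[i]:
            raise ValueError(f"rule lhs {rule[0]!r} does not match slot {self.symbols[i]!r}")
        self.symbols[i:i + 1] = list(rule[1])  # replace the slot with the rule's rhs

    def to_token_list(self):
        if not self.complete:
            raise RuntimeError("derivation is not complete")
        return list(self.symbols)


rules = [("E", ("E", "+", "F")), ("E", ("F",)), ("F", ("x",)), ("F", ("y",))]
d = ToyDerivation(rules, "E")
for rule in rules:  # E -> E + F, E -> F, F -> x, F -> y, in that order
    d.apply(rule)
print(d.to_token_list())  # ['x', '+', 'y']
```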
ParseTree
ParseTree wraps the root ParseTreeNode and provides two methods:
```python
tree.to_token_list()      # ['x', '+', 'y']: leaf symbols in order
tree.productions_used()   # [Rule, ...]: applied rules in pre-order
```
productions_used() is useful for counting operators, computing expression complexity, or feeding rule-usage statistics back to a search algorithm.
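For example, operator counts fall out of a Counter over rule names. In the sketch below a namedtuple stands in for Rule, and a hard-coded list with simplified, made-up lhs symbols takes the place of a real tree.productions_used() result; the names follow the E_add_{op} / R_fn_{fn} pattern used by Grammar.from_symbol_library. Treating "number of operator rules" as complexity is just one illustrative choice.

```python
from collections import Counter, namedtuple

# Stand-in for SRToolkit's Rule; only the fields used below
Rule = namedtuple("Rule", ["lhs", "rhs", "name"])

# Hypothetical pre-order rule list for X_0 + sin(X_1)
used = [
    Rule("E", ("E", "+", "F"), "E_add_+"),
    Rule("E", ("F",), None),
    Rule("F", ("X_0",), "V_X_0"),
    Rule("F", ("sin", "(", "E", ")"), "R_fn_sin"),
    Rule("E", ("F",), None),
    Rule("F", ("X_1",), "V_X_1"),
]

OPERATOR_PREFIXES = ("E_add_", "F_mul_", "B_pow_", "R_fn_", "R_postfix_")


def operator_counts(rules):
    """Count operator/function rules by name, ignoring unnamed structural rules."""
    return Counter(r.name for r in rules if r.name and r.name.startswith(OPERATOR_PREFIXES))


counts = operator_counts(used)
complexity = sum(counts.values())
print(counts, complexity)
```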
Validating a parse tree
Grammar.validate checks that a ParseTree was produced by this grammar and satisfies all registered constraints:
```python
ok = g.validate(tree)                       # root can be any NT
ok = g.validate(tree, require_start=True)   # root must equal g.start
```
It returns False (never raises an exception) when the tree uses a foreign rule, has the wrong child count, contains an unexpanded non-terminal leaf, or violates a constraint. Passing require_start=True when g.start is None raises ValueError.
Derivation-generated trees always pass validation.
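The structural checks can be sketched as a recursive walk. The function below is a hypothetical stand-in, not Grammar.validate itself: trees are (symbol, children) tuples, rules are (lhs, rhs) pairs, and constraint checking is omitted.

```python
def validate(tree, rules, start=None, require_start=False):
    """True if every internal node matches a rule and no leaf is a non-terminal."""
    if require_start and start is None:
        raise ValueError("require_start=True needs a start symbol")
    rule_set = {(lhs, tuple(rhs)) for lhs, rhs in rules}
    nonterminals = {lhs for lhs, _ in rules}
    if require_start and tree[0] != start:
        return False

    def ok(node):
        symbol, children = node
        if not children:
            return symbol not in nonterminals     # an unexpanded non-terminal leaf fails
        rhs = tuple(child[0] for child in children)
        if (symbol, rhs) not in rule_set:         # foreign rule or wrong child count
            return False
        return all(ok(child) for child in children)

    return ok(tree)


def leaf(symbol):
    return (symbol, [])


rules = [("E", ["E", "+", "F"]), ("E", ["F"]), ("F", ["x"]), ("F", ["y"])]
tree = ("E", [("E", [("F", [leaf("x")])]), leaf("+"), ("F", [leaf("y")])])
print(validate(tree, rules, start="E", require_start=True))        # True
print(validate(("E", [leaf("E"), leaf("+"), ("F", [leaf("y")])]), rules))  # False
```

The second call fails because the left E was never expanded.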
Grammar from a SymbolLibrary
Grammar.from_symbol_library builds a PCFG automatically from a SymbolLibrary using a standard operator-precedence hierarchy:
```python
from SRToolkit.utils import SymbolLibrary
from SRToolkit.utils.grammar import Grammar

sl = SymbolLibrary.from_symbol_list(["+", "-", "*", "sin", "^2"], num_variables=2)
g = Grammar.from_symbol_library(sl)
tokens = g.generate_one()  # e.g. ['X_0', '+', 'sin', '(', 'X_1', ')']
```
The resulting grammar uses heuristic weights that favour shallower, more readable expressions. Rule names follow a predictable scheme that the constraint system relies on for fast classification:
| Pattern | Example | Meaning |
|---|---|---|
| E_add_{op} | E_add_+ | Additive binary operator |
| F_mul_{op} | F_mul_* | Multiplicative binary operator |
| B_pow_{op} | B_pow_^ | Power binary operator |
| R_fn_{fn} | R_fn_sin | Prefix function |
| R_postfix_{op} | R_postfix_^2 | Postfix unary operator |
| V_{var} | V_X_0 | Variable leaf |
| K_{sym} | K_C, K_pi | Constant or literal leaf |
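Because the patterns are plain prefixes, a rule can be classified with a few string checks instead of inspecting its shape. The function below is a hypothetical illustration of that idea, not SRToolkit's classifier; the category labels are made up.

```python
# Prefix -> category, mirroring the naming table (hypothetical labels)
PREFIXES = [
    ("E_add_", "additive"),
    ("F_mul_", "multiplicative"),
    ("B_pow_", "power"),
    ("R_fn_", "function"),
    ("R_postfix_", "postfix"),
    ("V_", "variable"),
    ("K_", "constant"),
]


def classify(name):
    """Return (category, symbol) for a conventionally named rule, else None."""
    if name is None:
        return None  # unnamed rules would need shape heuristics instead
    for prefix, category in PREFIXES:
        if name.startswith(prefix):
            return category, name[len(prefix):]
    return None
```

For instance, classify("R_fn_sin") gives ("function", "sin"), and a name outside the scheme such as "E_to_F" gives None.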
Grammar from a string
For quick prototyping, Grammar.from_grammar_string parses NLTK production-rule notation:
- Terminals are enclosed in single quotes; unquoted tokens are non-terminals.
- An optional [weight] at the end of an alternative sets the sampling weight.
- A # start: <NT> comment line sets the start symbol (ignored if start= is passed explicitly).
- Rule names are not preserved in this format.
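The sketch below parses a small grammar by hand to show how the notation maps onto rule fields. It is a hypothetical, simplified parser written for illustration (one production per line, whitespace-separated symbols, no | or spaces inside quoted terminals), not Grammar.from_grammar_string. It returns the start symbol and (lhs, rhs, weight) triples, with weight 1.0 when no [weight] is given, and strips the quotes from terminals because a symbol counts as a non-terminal only if it appears on a left-hand side.

```python
import re

GRAMMAR_TEXT = """
# start: E
E -> E '+' F [0.4] | F [0.6]
F -> 'x' | 'y'
"""


def parse_grammar_string(text):
    start, rules = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        m = re.match(r"#\s*start:\s*(\S+)", line)
        if m:
            start = m.group(1)          # start-symbol comment
            continue
        if line.startswith("#"):
            continue                    # ordinary comment
        lhs, body = (part.strip() for part in line.split("->", 1))
        for alt in body.split("|"):
            alt = alt.strip()
            weight = 1.0
            w = re.search(r"\[([^\]]+)\]\s*$", alt)
            if w:                       # trailing [weight]
                weight = float(w.group(1))
                alt = alt[:w.start()].strip()
            rhs = [tok.strip("'") for tok in alt.split()]  # drop terminal quotes
            rules.append((lhs, rhs, weight))
    return start, rules


start, rules = parse_grammar_string(GRAMMAR_TEXT)
print(start, rules)
```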
Grammar.to_grammar_string converts a grammar back to the same notation.
Note
A to_grammar_string / from_grammar_string round trip preserves rules, weights (normalised), and the start symbol (via the # start: comment). Rule names and constraints are dropped; use to_dict / from_dict for a full round trip.
Adding constraints
Constraints are hard filters applied at every derivation step. Register them with Grammar.add_constraint:
```python
from SRToolkit.utils.grammar import MaxDepth, MaxNodes, MaxOccurrences, NoNested

g.add_constraint(MaxDepth(4))                # at most 4 levels of nesting
g.add_constraint(MaxNodes(20))               # at most 20 rule applications total
g.add_constraint(MaxOccurrences("x", 2))     # 'x' may appear at most twice
g.add_constraint(NoNested(["sin", "cos"]))   # no sin/cos inside each other
```
Multiple constraints compose: a rule is offered only when every constraint accepts it.
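Composition is a logical AND over the registered constraints. A minimal sketch of that filtering step, with plain functions standing in for constraint objects (hypothetical names; real allows methods also receive the slot and global state):

```python
def max_rhs_len(limit):
    """Hypothetical check: reject rules with more than `limit` rhs symbols."""
    return lambda rule: len(rule[1]) <= limit


def forbid(symbol):
    """Hypothetical check: reject rules whose rhs contains `symbol`."""
    return lambda rule: symbol not in rule[1]


def offered(candidates, checks):
    """Keep only the candidates that every check accepts."""
    return [rule for rule in candidates if all(check(rule) for check in checks)]


candidates = [("E", ["E", "+", "E"]), ("E", ["C"]), ("E", ["x"])]
print(offered(candidates, [max_rhs_len(1), forbid("C")]))  # [('E', ['x'])]
```

With no constraints registered, all() over an empty list is True, so every candidate is offered.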
MaxDepth
MaxDepth limits nesting depth: at any slot whose remaining depth budget is zero, rules that would introduce further non-terminal children are rejected.
```python
g = Grammar([
    Rule("E", ["E", "+", "E"]),
    Rule("E", ["x"]),
])
g.add_constraint(MaxDepth(0))
d = g.start_derivation("E")
print([r.rhs for r in d.options()])  # [['x']]: the recursive rule is banned
```
MaxNodes
MaxNodes limits the total number of rule applications. It accounts for the open frontier, ensuring the derivation can always be completed within the budget.
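One way to make "can always be completed" precise, assuming every open non-terminal can be closed with a single terminal-only rule: a candidate is safe when the applications already made, plus this one, plus one for every slot still open afterwards, fit within the limit. The check below is a hypothetical sketch of that bound, not SRToolkit's exact formula.

```python
def max_nodes_allows(steps, frontier_size, new_children, limit):
    """
    steps:          rule applications made so far
    frontier_size:  open slots, including the one being expanded now
    new_children:   non-terminals in the candidate rule's rhs
    """
    open_after = frontier_size - 1 + new_children   # slots left once this rule is applied
    return steps + 1 + open_after <= limit          # each open slot needs >= 1 more step
```

For example, with limit 5, two applications made, and two open slots, the recursive rule E -> E + E leaves three open slots (2 + 1 + 3 = 6 > 5) and is rejected, while E -> x leaves one (2 + 1 + 1 = 4) and is allowed.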
MaxOccurrences
MaxOccurrences limits how many times a specific terminal appears in the finished expression; for example, MaxOccurrences("x", 2) allows at most two x tokens.
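A natural way to implement this is a running count in global state. The class below is a hypothetical sketch that follows the allows / update shape of the custom-constraint examples on this page but does not inherit from SRToolkit's Constraint, so it runs on its own; SimpleNamespace objects stand in for the slot and the rule.

```python
from types import SimpleNamespace


class ToyMaxOccurrences:
    """Hypothetical stand-in for MaxOccurrences using a global running count."""

    def __init__(self, symbol, limit):
        self.symbol = symbol
        self.limit = limit

    def initial_global(self):
        return 0  # no occurrences emitted yet

    def allows(self, slot, rule, global_):
        # Reject the rule if its own copies of the symbol would push past the limit
        return global_ + rule.rhs.count(self.symbol) <= self.limit

    def update(self, slot, rule, global_):
        n = sum(1 for s in rule.rhs if s in slot.nonterminals)
        return [None] * n, global_ + rule.rhs.count(self.symbol)


c = ToyMaxOccurrences("x", 2)
slot = SimpleNamespace(nonterminals=frozenset({"E"}))
add_xx = SimpleNamespace(rhs=["x", "+", "x"])
count = c.initial_global()
print(c.allows(slot, add_xx, count))                      # True: 0 + 2 <= 2
_, count = c.update(slot, add_xx, count)
print(c.allows(slot, SimpleNamespace(rhs=["x"]), count))  # False: 2 + 1 > 2
```

This sketch does not look ahead, so on its own it cannot tell whether the remaining open slots can still be finished without another x.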
NoNested
NoNested prevents any symbol in a group from appearing inside any other symbol in the same group:
```python
g.add_constraint(NoNested(["sin", "cos"]))  # no sin(cos(x)), sin(sin(x)), cos(cos(x)), or cos(sin(x))
g.add_constraint(NoNested("exp"))           # no exp(exp(x))
```
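Detecting nesting only needs the ancestor chain of the current slot. The function below is a hypothetical sketch in which each ancestor is represented by the rhs of the rule that expanded it (outermost first), roughly the information slot.ancestors exposes through AncestorInfo.rule. It rejects a rule that introduces a group symbol when some ancestor already did; accepting a bare string as a one-symbol group mirrors the NoNested("exp") call above.

```python
def no_nested_allows(group, ancestor_rhss, rule_rhs):
    """Reject rule_rhs if it introduces a group symbol beneath an ancestor that also did."""
    group = {group} if isinstance(group, str) else set(group)
    if not group.intersection(rule_rhs):
        return True  # this rule introduces none of the grouped symbols
    return not any(group.intersection(rhs) for rhs in ancestor_rhss)


sin_rule = ["sin", "(", "E", ")"]
print(no_nested_allows(["sin", "cos"], [], sin_rule))                          # True
print(no_nested_allows(["sin", "cos"], [["cos", "(", "E", ")"]], sin_rule))    # False
print(no_nested_allows("exp", [sin_rule], ["exp", "(", "E", ")"]))             # True
```

The second call is sin inside cos, which the group forbids; the third is exp inside sin, which a group containing only exp allows.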
Dimensional consistency
DimensionalConsistency enforces physical units during generation. It tracks the required unit at each open slot and rejects rules whose output unit would conflict.
Units are represented as dict[str, int | Fraction] mapping base dimension names to exponents. Dimensionless is {}.
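The bookkeeping behind this representation is ordinary exponent arithmetic: multiplication adds exponents, division subtracts them, raising to a power multiplies them, and + or - requires both operands to share the same unit. A minimal sketch with hypothetical helper names, dropping zero exponents so that dimensionless stays {}:

```python
from fractions import Fraction


def _clean(unit):
    """Drop zero exponents so equal units compare equal."""
    return {dim: exp for dim, exp in unit.items() if exp != 0}


def mul(a, b):
    out = dict(a)
    for dim, exp in b.items():
        out[dim] = out.get(dim, 0) + exp
    return _clean(out)


def div(a, b):
    return mul(a, {dim: -exp for dim, exp in b.items()})


def power(a, p):
    return _clean({dim: exp * p for dim, exp in a.items()})


def add(a, b):
    if _clean(a) != _clean(b):
        raise ValueError(f"cannot add {a} and {b}")
    return _clean(a)


velocity, time = {"m": 1, "s": -1}, {"s": 1}
print(mul(velocity, time))              # {'m': 1}
print(power({"m": 2}, Fraction(1, 2)))  # {'m': Fraction(1, 1)}
```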
```python
from SRToolkit.utils.grammar import DimensionalConsistency

# sl is the SymbolLibrary from above; its variables are named X_0 and X_1
dc = DimensionalConsistency(
    variable_units={
        "X_0": {"m": 1, "s": -1},  # velocity [m/s]
        "X_1": {"s": 1},           # time [s]
    },
    target_unit={"m": 1},  # target: length [m]
)
g = Grammar.from_symbol_library(sl)
g.add_constraint(dc)
tokens = g.generate_one()  # any returned expression is guaranteed to have unit [m]
```
Named physical constants (type lit) are declared via constant_units. Free constants (type const) are treated as dimensionless by default; set allow_unit_polymorphic_constants=True to let them absorb any required unit.
```python
dc = DimensionalConsistency(
    variable_units={"F": {"kg": 1, "m": 1, "s": -2}, "m": {"kg": 1}},
    target_unit={"m": 1, "s": -2},              # acceleration [m/s²]
    constant_units={"g": {"m": 1, "s": -2}},    # gravitational acceleration
)
```
Note
DimensionalConsistency uses rule names (e.g. E_add_+, R_fn_sin) to classify rules efficiently. Grammars built with Grammar.from_symbol_library include these names automatically. For hand-written grammars, add names to rules whose classification matters; unnamed rules fall back to shape heuristics.
Serialization
Grammars and all built-in constraints support to_dict / from_dict for JSON round-trips:
```python
import json

d = g.to_dict()
with open("grammar.json", "w") as f:
    json.dump(d, f)

with open("grammar.json") as f:
    g2 = Grammar.from_dict(json.load(f))
```
to_dict stores the start symbol, all rules (with names and weights), and all constraints. Constraints are identified by their fully-qualified class path, so from_dict uses importlib to reconstruct them — the same approach used for callbacks and samplers throughout SRToolkit.
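The class-path mechanism amounts to storing module.ClassName and resolving it with importlib.import_module plus getattr. A minimal sketch of both directions (hypothetical helper names), exercised on the standard-library Fraction class so it runs anywhere:

```python
import importlib
from fractions import Fraction


def class_path(cls):
    """Fully qualified path such as 'fractions.Fraction'."""
    return f"{cls.__module__}.{cls.__qualname__}"


def load_class(path):
    """Inverse of class_path for top-level classes."""
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


path = class_path(Fraction)
print(path)                          # fractions.Fraction
print(load_class(path) is Fraction)  # True
```

Nested classes have dotted qualified names and would need a more careful split; constraints defined at module level resolve directly.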
DimensionalConsistency serializes unit exponents as "p/q" strings for exact Fraction fidelity. The symbol_library reference is not serialized; reconstruct it from the original symbol set if needed.
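Strings avoid the rounding a float would introduce (1/3 has no exact binary representation). A sketch of an exact JSON round trip, assuming the "p/q" form is the numerator and denominator joined by a slash; fractions.Fraction parses such strings directly, and the helper names are hypothetical.

```python
import json
from fractions import Fraction


def exponent_to_str(exp):
    f = Fraction(exp)
    return f"{f.numerator}/{f.denominator}"  # e.g. 1/3 -> "1/3", -2 -> "-2/1"


def exponent_from_str(text):
    return Fraction(text)


unit = {"m": Fraction(1, 3), "s": -2}
encoded = json.dumps({dim: exponent_to_str(e) for dim, e in unit.items()})
decoded = {dim: exponent_from_str(s) for dim, s in json.loads(encoded).items()}
print(encoded)          # {"m": "1/3", "s": "-2/1"}
print(decoded == unit)  # True
```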
Custom constraints
Subclass Constraint and override the methods you need:
```python
from SRToolkit.utils.grammar import Constraint, Grammar, Rule


class NoConstants(Constraint):
    """Reject any rule whose rhs contains 'C'."""

    def allows(self, slot, rule, global_):
        return "C" not in rule.rhs


g = Grammar([Rule("E", ["C"]), Rule("E", ["x"])])
g.add_constraint(NoConstants())
d = g.start_derivation("E")
print([r.rhs for r in d.options()])  # [['x']]
```
Constraint instances carry only construction-time configuration and are safe to share across parallel derivations. All per-derivation state is managed by the engine and passed through allows / update arguments.
The slot object (ExpansionContext)
The first argument to allows and update is an ExpansionContext describing the current expansion site:
| Field | Type | Description |
|---|---|---|
| slot.nonterminal | str | The non-terminal being expanded |
| slot.local | Any | Per-slot state for this constraint (see below) |
| slot.nonterminals | frozenset[str] | All non-terminals in the grammar |
| slot.frontier_size | int | Number of open slots at this point |
| slot.steps | int | Total rule applications so far |
| slot.parent_rule | Rule \| None | Rule that opened this slot, or None for the root |
| slot.child_index | int \| None | NT-child index within parent_rule.rhs, or None for root |
| slot.ancestors | tuple[AncestorInfo, ...] | Chain from root to this slot (outermost first) |
| slot.partial_tree | ParseTreeNode | The partial tree node for this slot |
Each AncestorInfo in slot.ancestors has .nonterminal, .rule, and .child_index fields describing the ancestor expansion.
Local and global state
The engine provides two kinds of per-derivation state for constraints that need to track history:
Local state is per-slot and inherited from parent to children. The engine maintains a stack of local values alongside the open frontier; each slot starts with the value propagated from its parent's update call. Use local state for anything that depends on position in the tree, such as a remaining depth budget.
Global state is a single value shared across the entire derivation. It is updated after every rule application and is useful for running counters and accumulators, such as a total node count.
Both are initialised once per derivation — local state via initial_local(start), global state via initial_global() — and the engine threads the current values into every allows and update call.
```python
# Local state: remaining depth budget inherited by each child slot
class BudgetDepth(Constraint):
    def __init__(self, limit):
        self.limit = limit

    def initial_local(self, start):
        return self.limit  # the root slot starts with the full budget

    def allows(self, slot, rule, global_):
        if slot.local <= 0:
            return not any(s in slot.nonterminals for s in rule.rhs)
        return True

    def update(self, slot, rule, global_):
        n = sum(1 for s in rule.rhs if s in slot.nonterminals)
        return [slot.local - 1] * n, global_  # children receive budget - 1


# Global state: running count of rule applications across the whole derivation
class ApplicationCounter(Constraint):
    def __init__(self, limit):
        self.limit = limit

    def initial_global(self):
        return 0  # counter starts at zero

    def allows(self, slot, rule, global_):
        return global_ < self.limit

    def update(self, slot, rule, global_):
        n = sum(1 for s in rule.rhs if s in slot.nonterminals)
        return [None] * n, global_ + 1  # increment counter after each application
```
Both flavours can be combined in a single constraint by implementing all four methods together, as DimensionalConsistency does with its per-slot unit tracking.
To make a custom constraint serializable, implement to_dict and from_dict:
```python
class NoConstants(Constraint):
    def allows(self, slot, rule, global_):
        return "C" not in rule.rhs

    def to_dict(self):
        return {**super().to_dict()}  # no parameters: constraint_class is enough

    @classmethod
    def from_dict(cls, d):
        return cls()
```
constraint_from_dict dispatches via importlib, so any class reachable by its module path is supported automatically.
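For a constraint with construction-time parameters, the same pattern extends the base dictionary with those parameters. The sketch below uses a hypothetical stand-in base class (ConstraintBase) whose to_dict records only the class path under a "constraint_class" key; SRToolkit's actual base class may store more. Only the serialization hooks of BudgetDepth are shown.

```python
import json


class ConstraintBase:
    """Hypothetical stand-in for Constraint's serialization hooks."""

    def to_dict(self):
        cls = type(self)
        return {"constraint_class": f"{cls.__module__}.{cls.__qualname__}"}

    @classmethod
    def from_dict(cls, d):
        return cls()


class BudgetDepth(ConstraintBase):
    def __init__(self, limit):
        self.limit = limit

    def to_dict(self):
        return {**super().to_dict(), "limit": self.limit}  # add the parameter

    @classmethod
    def from_dict(cls, d):
        return cls(d["limit"])  # rebuild from the stored parameter


data = json.loads(json.dumps(BudgetDepth(3).to_dict()))
restored = BudgetDepth.from_dict(data)
print(restored.limit)  # 3
```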
Scoping
Set nonterminals and/or rule_names on the class to limit when allows is called. Slots outside scope are accepted unconditionally; update is always called regardless of scope so that global counters stay accurate.
```python
class DepthLimitForE(Constraint):
    nonterminals = frozenset({"E"})  # only active at E slots

    def __init__(self, limit):
        self.limit = limit

    def initial_local(self, start):
        return self.limit

    def allows(self, slot, rule, global_):
        if slot.local <= 0:
            return not any(s in slot.nonterminals for s in rule.rhs)
        return True

    def update(self, slot, rule, global_):
        n = sum(1 for s in rule.rhs if s in slot.nonterminals)
        return [slot.local - 1] * n, global_
```
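On the engine side, scoping reduces to a membership test before allows is consulted; update would still run for every slot. A hypothetical sketch, assuming an unset scope attribute is None and that, when both nonterminals and rule_names are set, a candidate must match both:

```python
class ScopedExample:
    """Hypothetical constraint scoped to E slots, with no rule-name restriction."""
    nonterminals = frozenset({"E"})
    rule_names = None


def in_scope(constraint, nonterminal, rule_name):
    """True if allows() should be consulted for this slot and candidate rule."""
    nts = getattr(constraint, "nonterminals", None)
    names = getattr(constraint, "rule_names", None)
    if nts is not None and nonterminal not in nts:
        return False
    if names is not None and rule_name not in names:
        return False
    return True


def accepts(constraint, nonterminal, rule_name, allows):
    """Out-of-scope candidates pass unconditionally; in-scope ones defer to allows()."""
    return allows() if in_scope(constraint, nonterminal, rule_name) else True


c = ScopedExample()
print(in_scope(c, "E", "E_add_+"))                # True
print(accepts(c, "F", "F_mul_*", lambda: False))  # True: F is out of scope
print(accepts(c, "E", "E_add_+", lambda: False))  # False: allows() rejected it
```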