
Grammar-Guided Expression Generation

SRToolkit includes a context-free grammar (CFG) and probabilistic context-free grammar (PCFG) engine for generating expressions with fine-grained structural control. Grammars let you specify exactly which expression shapes are legal — restricting depth, operator combinations, variable usage, or physical units — without modifying your search algorithm.

Core concepts

A grammar is a set of production rules. Each rule expands a non-terminal symbol into a sequence of terminals and non-terminals. A symbol is a non-terminal if and only if it appears on the left-hand side of at least one rule; everything else is a terminal.

E → E '+' F | F
F → 'x' | 'y'

Expressions are generated by starting from a chosen non-terminal (the start symbol) and repeatedly replacing non-terminals with one of their alternatives until only terminals remain.
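You can illustrate this replacement loop without SRToolkit. The following standalone sketch is not the library's implementation. It stores the example grammar above as a plain dict and performs a leftmost derivation. The `choose` callback is hypothetical and picks which alternative to use at each step.

```python
# Standalone sketch: a leftmost derivation over a dict-based grammar.
# It mirrors the idea of the engine, not SRToolkit's actual code.
GRAMMAR = {
    "E": [["E", "+", "F"], ["F"]],
    "F": [["x"], ["y"]],
}

def derive(grammar, start, choose):
    """Expand the leftmost non-terminal until only terminals remain."""
    sentential = [start]
    steps = [list(sentential)]
    while any(sym in grammar for sym in sentential):
        # Index of the leftmost non-terminal in the current sentential form
        i = next(k for k, sym in enumerate(sentential) if sym in grammar)
        alternatives = grammar[sentential[i]]
        rhs = alternatives[choose(sentential[i], alternatives)]
        sentential[i:i + 1] = rhs          # replace the non-terminal by its rhs
        steps.append(list(sentential))
    return steps

# Fixed choices: E -> E '+' F, then E -> F, then F -> 'x', then F -> 'y'
plan = iter([0, 1, 0, 1])
for form in derive(GRAMMAR, "E", lambda nt, alts: next(plan)):
    print(" ".join(form))
# E
# E + F
# F + F
# x + F
# x + y
```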

Building a grammar

Import the building blocks:

from SRToolkit.utils.grammar import Grammar, Rule

Add rules at construction time or incrementally with Grammar.add_rule:

g = Grammar([
    Rule("E", ["E", "+", "F"], weight=0.4, name="E_add"),
    Rule("E", ["F"],            weight=0.6, name="E_to_F"),
    Rule("F", ["x"],                        name="F_x"),
    Rule("F", ["y"],                        name="F_y"),
], start="E")

# Or incrementally:
g2 = Grammar(start="E")
g2.add_rule(Rule("E", ["x"]))
g2.add_rule(Rule("E", ["y"]))

Each Rule has four fields:

| Field | Type | Default | Purpose |
|---|---|---|---|
| lhs | str | required | Non-terminal being expanded |
| rhs | list[str] | required | Ordered replacement symbols |
| weight | float | 1.0 | Unnormalised sampling weight |
| name | str \| None | None | Stable identifier used by constraints |

When every rule keeps the default weight of 1.0, the grammar is a CFG and the alternatives for each non-terminal are sampled uniformly. As soon as any weight differs from 1.0, it becomes a PCFG. The engine normalises weights within each non-terminal's group of rules at sampling time.
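The following standalone sketch makes the normalisation concrete. It mirrors the fields above with a hypothetical dataclass (`ToyRule`, not SRToolkit's `Rule`), divides each weight by the total for its left-hand side, and applies the same "any weight other than 1.0" test that `is_pcfg()` is documented to use.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class ToyRule:
    # Hypothetical stand-in mirroring the documented Rule fields
    lhs: str
    rhs: tuple
    weight: float = 1.0
    name: Optional[str] = None

def normalised_probabilities(rules):
    """Map each rule to weight / (sum of weights of rules sharing its lhs)."""
    totals = defaultdict(float)
    for r in rules:
        totals[r.lhs] += r.weight
    return {r: r.weight / totals[r.lhs] for r in rules}

def is_pcfg(rules):
    # Any weight other than the default 1.0 makes the grammar probabilistic
    return any(r.weight != 1.0 for r in rules)

rules = [
    ToyRule("E", ("E", "+", "F"), 0.4, "E_add"),
    ToyRule("E", ("F",), 0.6, "E_to_F"),
    ToyRule("F", ("x",), name="F_x"),
    ToyRule("F", ("y",), name="F_y"),
]
probs = normalised_probabilities(rules)
print({r.name: round(p, 2) for r, p in probs.items()})
# {'E_add': 0.4, 'E_to_F': 0.6, 'F_x': 0.5, 'F_y': 0.5}
print(is_pcfg(rules))  # True
```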

Useful introspection properties:

g.nonterminals          # set[str] of all LHS symbols
g.rules_for("E")        # list[Rule] for that non-terminal, insertion order
g.is_pcfg()             # True if any rule has weight != 1.0

Generating expressions

Call Grammar.generate_one for a single expression:

tokens = g.generate_one()   # e.g. ['x', '+', 'y']

generate_one retries up to max_retries times (default 10) whenever an attempt exceeds max_steps rule applications (default 1000). It returns None when every attempt fails. This signals that you should relax the grammar or raise the limits:

tokens = g.generate_one(start="E", max_steps=500, max_retries=20)
if tokens is None:
    print("could not generate an expression within the budget")
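The retry behaviour can be sketched in plain Python. The function below, `generate_with_budget`, is hypothetical and is not part of SRToolkit. It abandons an attempt once it has used `max_steps` rule applications and returns `None` after `max_retries` failed attempts. The grammar is a dict of `(rhs, weight)` pairs.

```python
import random

WEIGHTED_GRAMMAR = {
    "E": [(["E", "+", "F"], 0.4), (["F"], 0.6)],
    "F": [(["x"], 1.0), (["y"], 1.0)],
}

def generate_with_budget(grammar, start, max_steps=1000, max_retries=10, rng=None):
    rng = rng or random.Random()
    for _ in range(max_retries):
        stack, tokens, steps = [start], [], 0
        while stack:
            sym = stack.pop()
            if sym not in grammar:          # terminal: emit it
                tokens.append(sym)
                continue
            if steps == max_steps:          # budget exhausted: abandon this attempt
                break
            alternatives = grammar[sym]
            rhs = rng.choices([a for a, _ in alternatives],
                              weights=[w for _, w in alternatives])[0]
            steps += 1
            stack.extend(reversed(rhs))     # leftmost symbol is popped next
        else:
            return tokens                   # loop finished without break: success
    return None                             # every attempt ran out of budget

print(generate_with_budget(WEIGHTED_GRAMMAR, "E", rng=random.Random(0)))
print(generate_with_budget(WEIGHTED_GRAMMAR, "E", max_steps=1))  # None: one application never finishes
```

Because a fresh attempt restarts from the start symbol, the function can still succeed on a later attempt when an unlucky early one recurses too deeply.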

For more control, start a stateful Derivation and drive it manually or automatically:

d = g.start_derivation("E")

# Option A — automatic: sample to completion, get tokens directly
tokens = d.generate()                # default limit=1000

# Option B — manual: inspect and apply rules yourself
d2 = g.start_derivation("E")
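while not d2.complete:
    rule = d2.options()[-1]          # last candidate here is non-recursive; options()[0] would pick E -> E '+' F forever
    d2.apply(rule)                   # alternatively, d2.sample() picks a rule by weight and applies it itself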
tokens = d2.to_token_list()          # flat token list from completed derivation
tree   = d2.to_parse_tree()          # ParseTree for structural analysis

Derivation methods at a glance

| Method | Returns | Notes |
|---|---|---|
| d.options() | list[Rule] | Constraint-filtered candidates for the current slot |
| d.apply(rule) | None | Apply one rule; raises if the rule's lhs is wrong or the derivation is already complete |
| d.sample() | None | Apply one rule sampled by weight; raises if there are no candidates |
| d.generate(limit=1000) | list[str] | Sample to completion, return the token list |
| d.to_token_list() | list[str] | Token list of a completed derivation |
| d.to_parse_tree() | ParseTree | Parse tree of a completed derivation |
| d.complete | bool | True when no unexpanded non-terminals remain |

to_token_list() and to_parse_tree() both raise RuntimeError if called before the derivation is complete.

ParseTree

ParseTree wraps the root ParseTreeNode and provides two methods:

tree.to_token_list()      # ['x', '+', 'y'] — leaf symbols in order
tree.productions_used()   # [Rule, ...] — applied rules in pre-order

productions_used() is useful for counting operators, computing expression complexity, or feeding rule-usage statistics back to a search algorithm.
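The following sketch shows that kind of post-processing. It uses a hypothetical minimal tree type (`Node`, not SRToolkit's `ParseTreeNode`) that records a rule name at each internal node. A pre-order walk yields the applied rules. From that list, you can derive operator counts and a simple complexity score.

```python
from collections import Counter
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    symbol: str
    rule_name: Optional[str] = None               # None for terminal leaves
    children: List["Node"] = field(default_factory=list)

def productions_preorder(node):
    """Yield rule names parent-first, children left to right (pre-order)."""
    if node.rule_name is not None:
        yield node.rule_name
    for child in node.children:
        yield from productions_preorder(child)

# Tree for x + y under E -> E '+' F | F,  F -> 'x' | 'y'
tree = Node("E", "E_add", [
    Node("E", "E_to_F", [Node("F", "F_x", [Node("x")])]),
    Node("+"),
    Node("F", "F_y", [Node("y")]),
])
used = list(productions_preorder(tree))
print(used)                     # ['E_add', 'E_to_F', 'F_x', 'F_y']
print(Counter(used)["E_add"])   # 1 (one addition)
print(len(used))                # 4 (rule applications, a crude complexity measure)
```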

Validating a parse tree

Grammar.validate checks that a ParseTree was produced by this grammar and satisfies all registered constraints:

ok = g.validate(tree)                    # root can be any NT
ok = g.validate(tree, require_start=True)  # root must equal g.start

Invalid trees produce False rather than an exception. This covers a tree that uses a foreign rule, has the wrong number of children, contains an unexpanded non-terminal leaf, or violates a constraint. The only error raised is ValueError, when require_start=True is passed but g.start is None.

Derivation-generated trees always pass validation.
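The structural checks listed above can be sketched against a plain rule table. The `validate` function below is a hypothetical, simplified validator that skips constraint checks. It works on nested `(symbol, children)` tuples and, as described, returns `False` instead of raising.

```python
# Grammar as lhs -> list of allowed right-hand sides
RULES = {
    "E": [("E", "+", "F"), ("F",)],
    "F": [("x",), ("y",)],
}

def validate(node, rules, start=None):
    """node is (symbol, children); terminal leaves have an empty children tuple."""
    symbol, children = node
    if start is not None and symbol != start:
        return False                      # root must equal the start symbol
    if not children:
        return symbol not in rules        # an unexpanded non-terminal leaf is invalid
    rhs = tuple(child[0] for child in children)
    if symbol not in rules or rhs not in rules[symbol]:
        return False                      # foreign rule or wrong child count
    return all(validate(child, rules) for child in children)

def leaf(s):
    return (s, ())

good = ("E", (
    ("E", (("F", (leaf("x"),)),)),
    leaf("+"),
    ("F", (leaf("y"),)),
))
print(validate(good, RULES, start="E"))     # True
print(validate(("E", (("F", ()),)), RULES))  # False: F left unexpanded
print(validate(("E", (leaf("x"),)), RULES))  # False: E -> 'x' is not a rule
```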

Grammar from a SymbolLibrary

Grammar.from_symbol_library builds a PCFG automatically from a SymbolLibrary using a standard operator-precedence hierarchy:

from SRToolkit.utils import SymbolLibrary
from SRToolkit.utils.grammar import Grammar

sl = SymbolLibrary.from_symbol_list(["+", "-", "*", "sin", "^2"], num_variables=2)
g = Grammar.from_symbol_library(sl)

tokens = g.generate_one()   # e.g. ['X_0', '+', 'sin', '(', 'X_1', ')']

The resulting grammar uses heuristic weights that favour shallower, more readable expressions. Rule names follow a predictable scheme that the constraint system relies on for fast classification:

| Pattern | Example | Meaning |
|---|---|---|
| E_add_{op} | E_add_+ | Additive binary operator |
| F_mul_{op} | F_mul_* | Multiplicative binary operator |
| B_pow_{op} | B_pow_^ | Power binary operator |
| R_fn_{fn} | R_fn_sin | Prefix function |
| R_postfix_{op} | R_postfix_^2 | Postfix unary operator |
| V_{var} | V_X_0 | Variable leaf |
| K_{sym} | K_C, K_pi | Constant or literal leaf |
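The scheme is purely prefix-based, so classification reduces to string checks. The `classify` helper below is a hypothetical sketch of that idea. It uses the prefixes from the table and is not the code that DimensionalConsistency actually runs.

```python
PREFIXES = [
    ("E_add_", "additive"),
    ("F_mul_", "multiplicative"),
    ("B_pow_", "power"),
    ("R_fn_", "function"),
    ("R_postfix_", "postfix"),
    ("V_", "variable"),
    ("K_", "constant"),
]

def classify(rule_name):
    """Return (category, symbol) for a generated rule name, or None if unnamed or unknown."""
    if rule_name is None:
        return None
    for prefix, category in PREFIXES:
        if rule_name.startswith(prefix):
            return category, rule_name[len(prefix):]
    return None

print(classify("E_add_+"))       # ('additive', '+')
print(classify("R_postfix_^2"))  # ('postfix', '^2')
print(classify("V_X_0"))         # ('variable', 'X_0')
print(classify(None))            # None
```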

Grammar from a string

For quick prototyping, Grammar.from_grammar_string parses NLTK-style production-rule notation:

g = Grammar.from_grammar_string("""
    # start: E
    E -> E '+' F [0.4] | F [0.6]
    F -> 'x' | 'y'
""")
  • Terminals are enclosed in single quotes; unquoted tokens are non-terminals.
  • An optional [weight] at the end of an alternative sets the sampling weight.
  • A # start: <NT> comment line sets the start symbol (ignored if start= is passed explicitly).
  • Rule names are not preserved in this format.
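The small standalone parser below shows how this notation maps onto rules. It covers the subset described in the bullets: quoted terminals, an optional `[weight]`, and the `# start:` comment. It is a hypothetical sketch that returns `(lhs, rhs, weight)` tuples, not SRToolkit's parser. It does not handle quoted terminals that contain spaces, `|`, or `[`.

```python
import re

def parse_grammar_string(text):
    """Return (start, rules), where rules is a list of (lhs, rhs_tuple, weight)."""
    start, rules = None, []
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        m = re.match(r"#\s*start:\s*(\S+)", line)
        if m:
            start = m.group(1)
            continue
        if line.startswith("#"):
            continue                                   # ordinary comment
        lhs, body = (part.strip() for part in line.split("->", 1))
        for alt in body.split("|"):
            alt = alt.strip()
            weight = 1.0
            wm = re.search(r"\[([^\]]+)\]\s*$", alt)   # optional trailing [weight]
            if wm:
                weight = float(wm.group(1))
                alt = alt[:wm.start()].strip()
            # Quotes are dropped; a symbol is a non-terminal iff it appears on some lhs
            rhs = tuple(tok.strip("'") for tok in alt.split())
            rules.append((lhs, rhs, weight))
    return start, rules

start, rules = parse_grammar_string("""
    # start: E
    E -> E '+' F [0.4] | F [0.6]
    F -> 'x' | 'y'
""")
print(start)  # E
for r in rules:
    print(r)
# ('E', ('E', '+', 'F'), 0.4)
# ('E', ('F',), 0.6)
# ('F', ('x',), 1.0)
# ('F', ('y',), 1.0)
```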

Convert back to the same notation with Grammar.to_grammar_string:

print(g.to_grammar_string())
# Output:
#   # start: E
#   E -> E '+' F [0.4] | F [0.6]
#   F -> 'x' [0.5] | 'y' [0.5]

Note

to_grammar_string / from_grammar_string round-trip preserves rules, weights (normalized), and the start symbol (via the # start: comment). Rule names and constraints are dropped. Use to_dict / from_dict for a full round-trip.
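The weight normalisation in this round-trip is easy to reproduce. The `format_grammar` function below is a hypothetical sketch that renders `(lhs, rhs, weight)` tuples in the same notation. It divides each weight by its left-hand side's total, which explains why `F -> 'x' | 'y'` comes back with `[0.5]` on each alternative. It is not meant to replicate `to_grammar_string` exactly. As in the text, it treats exactly the LHS symbols as non-terminals.

```python
from collections import defaultdict

def format_grammar(rules, start=None):
    """rules: list of (lhs, rhs_tuple, weight). Weights are normalised per lhs."""
    nonterminals = {lhs for lhs, _, _ in rules}
    totals = defaultdict(float)
    grouped = defaultdict(list)      # dict keys keep first-seen lhs order
    for lhs, rhs, w in rules:
        totals[lhs] += w
        grouped[lhs].append((rhs, w))
    lines = [f"# start: {start}"] if start else []
    for lhs, alts in grouped.items():
        rendered = []
        for rhs, w in alts:
            # Quote terminals, leave non-terminals bare
            symbols = " ".join(s if s in nonterminals else f"'{s}'" for s in rhs)
            rendered.append(f"{symbols} [{round(w / totals[lhs], 6):g}]")
        lines.append(f"{lhs} -> " + " | ".join(rendered))
    return "\n".join(lines)

rules = [
    ("E", ("E", "+", "F"), 0.4), ("E", ("F",), 0.6),
    ("F", ("x",), 1.0), ("F", ("y",), 1.0),
]
print(format_grammar(rules, start="E"))
# # start: E
# E -> E '+' F [0.4] | F [0.6]
# F -> 'x' [0.5] | 'y' [0.5]
```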

Adding constraints

Constraints are hard filters applied at every derivation step. Register them with Grammar.add_constraint:

from SRToolkit.utils.grammar import MaxDepth, MaxNodes, MaxOccurrences, NoNested

g.add_constraint(MaxDepth(4))          # at most 4 levels of nesting
g.add_constraint(MaxNodes(20))         # at most 20 rule applications total
g.add_constraint(MaxOccurrences("x", 2))  # 'x' may appear at most twice
g.add_constraint(NoNested(["sin", "cos"]))  # no sin/cos inside each other

Multiple constraints compose: a rule is offered only when every constraint accepts it.

MaxDepth

MaxDepth limits the nesting depth. A rule is rejected at any slot whose remaining depth budget is zero if applying it would introduce further non-terminal children.

g = Grammar([
    Rule("E", ["E", "+", "E"]),
    Rule("E", ["x"]),
])
g.add_constraint(MaxDepth(0))

d = g.start_derivation("E")
print([r.rhs for r in d.options()])  # [['x']] — recursive rule banned

MaxNodes

MaxNodes limits the total number of rule applications. It accounts for the open frontier, ensuring the derivation can always be completed within the budget.
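The frontier accounting can be expressed as a simple feasibility check. The exact bookkeeping inside MaxNodes is not shown here. This sketch assumes that every non-terminal has an alternative that closes it in one application, as `F -> 'x'` does. Under that assumption, each open slot costs at least one more application. A rule is therefore safe only if three things fit in the budget together: the steps already taken, this application, and one application for every slot left open afterwards. The function name and signature are hypothetical.

```python
def max_nodes_allows(steps, frontier_size, new_children, limit):
    """
    steps         : rule applications made so far
    frontier_size : open slots, including the one being expanded now
    new_children  : non-terminals the candidate rule would add
    limit         : node budget
    """
    open_after = frontier_size - 1 + new_children   # slots still open afterwards
    return steps + 1 + open_after <= limit          # each needs at least one more application

# Budget 5, two applications used, two slots open.
print(max_nodes_allows(steps=2, frontier_size=2, new_children=2, limit=5))  # False
print(max_nodes_allows(steps=2, frontier_size=2, new_children=1, limit=5))  # True
print(max_nodes_allows(steps=2, frontier_size=2, new_children=0, limit=5))  # True
```

Rejecting a rule like `E -> E '+' F` (two new children) here, rather than later, prevents the derivation from reaching a state where no completion fits the budget.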

MaxOccurrences

MaxOccurrences limits how many times a specific terminal appears in the finished expression:

g.add_constraint(MaxOccurrences("sin", 1))  # at most one sin

NoNested

NoNested prevents any symbol in a group from appearing inside any other symbol in the same group:

g.add_constraint(NoNested(["sin", "cos"]))  # no sin(cos(x)), sin(sin(x)), cos(cos(x)), or cos(sin(x))
g.add_constraint(NoNested("exp"))           # no exp(exp(x))
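One way to implement this check is to look at the function symbols introduced by the slot's ancestors. The sketch below is hypothetical. It works on a plain list of ancestor symbols rather than SRToolkit's `AncestorInfo` objects. The string-or-list argument mirrors the two call forms above.

```python
def no_nested_allows(group, ancestor_symbols, candidate_symbol):
    """Reject candidate_symbol if it belongs to `group` and some ancestor already does."""
    group = {group} if isinstance(group, str) else set(group)
    if candidate_symbol not in group:
        return True
    return not any(sym in group for sym in ancestor_symbols)

print(no_nested_allows(["sin", "cos"], ["+", "sin"], "cos"))  # False: cos inside sin
print(no_nested_allows(["sin", "cos"], ["+", "*"], "sin"))    # True: no sin/cos ancestor
print(no_nested_allows("exp", ["sin"], "exp"))                # True: sin is not in the group
```

Siblings such as `sin(x) + cos(x)` remain legal, because neither function is an ancestor of the other.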

Dimensional consistency

DimensionalConsistency enforces physical units during generation. It tracks the required unit at each open slot and rejects rules whose output unit would conflict.

Units are represented as dict[str, int | Fraction] mapping base dimension names to exponents. Dimensionless is {}.
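Under this representation, multiplying quantities adds exponents and raising to a power scales them. Exponents that reach zero are dropped, so a dimensionless result stays `{}`. Addition requires identical units. The helper functions below are a standalone sketch, not SRToolkit API.

```python
from fractions import Fraction

def mul_units(a, b):
    """Unit of a product: add exponents, drop dimensions that cancel."""
    out = dict(a)
    for dim, exp in b.items():
        out[dim] = out.get(dim, 0) + exp
    return {d: e for d, e in out.items() if e != 0}

def pow_units(a, p):
    """Unit of a power: multiply every exponent by p (int or Fraction)."""
    return {d: e * p for d, e in a.items() if e * p != 0}

def div_units(a, b):
    return mul_units(a, pow_units(b, -1))

def add_units(a, b):
    """Unit of a sum: only defined when both operands agree."""
    if a != b:
        raise ValueError(f"cannot add {a} and {b}")
    return dict(a)

velocity = {"m": 1, "s": -1}
duration = {"s": 1}
print(mul_units(velocity, duration))          # {'m': 1}
print(div_units(duration, duration))          # {}
print(pow_units({"m": 2}, Fraction(1, 2)))    # {'m': Fraction(1, 1)}
```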

from SRToolkit.utils.grammar import DimensionalConsistency

dc = DimensionalConsistency(
    variable_units={
        "X_0": {"m": 1, "s": -1},   # velocity [m/s]; keys must match the library's variable names
        "X_1": {"s": 1},            # time [s]
    },
    target_unit={"m": 1},           # target: length [m]
)

g = Grammar.from_symbol_library(sl)   # sl from the previous section (variables X_0, X_1)
g.add_constraint(dc)

tokens = g.generate_one()   # when not None, the expression has unit [m]

Named physical constants (type lit) are declared via constant_units. Free constants (type const) are treated as dimensionless by default; set allow_unit_polymorphic_constants=True to let them absorb any required unit.

dc = DimensionalConsistency(
    variable_units={
        "F": {"kg": 1, "m": 1, "s": -2},  # force [kg·m/s²]
        "m": {"kg": 1},                    # mass [kg]; the variable "m" is unrelated to the unit key "m" (metres)
    },
    target_unit={"m": 1, "s": -2},     # acceleration [m/s²]
    constant_units={"g": {"m": 1, "s": -2}},  # gravitational acceleration
)

Note

DimensionalConsistency uses rule names (e.g. E_add_+, R_fn_sin) to classify rules efficiently. Grammars built with Grammar.from_symbol_library include these names automatically. For hand-written grammars, add names to rules whose classification matters; unnamed rules fall back to shape heuristics.

Serialization

Grammars and all built-in constraints support to_dict / from_dict for JSON round-trips:

import json

d = g.to_dict()
with open("grammar.json", "w") as f:
    json.dump(d, f)

with open("grammar.json") as f:
    g2 = Grammar.from_dict(json.load(f))

to_dict stores the start symbol, all rules (with names and weights), and all constraints. Constraints are identified by their fully-qualified class path, so from_dict uses importlib to reconstruct them — the same approach used for callbacks and samplers throughout SRToolkit.
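The class-path mechanism works roughly as follows. This standalone sketch uses two hypothetical helpers, `class_path` and `load_class`. It demonstrates them on standard-library classes so that it runs anywhere. SRToolkit's own `constraint_class` handling may differ in detail.

```python
import importlib
from collections import Counter
from fractions import Fraction

def class_path(cls):
    """Fully-qualified path stored in the serialized dict."""
    return f"{cls.__module__}.{cls.__qualname__}"

def load_class(path):
    """Import the module part, then fetch the (top-level) class by name."""
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

def from_dict(d):
    """Dispatch to cls.from_dict(d) when the class defines it, else call cls() with no arguments."""
    cls = load_class(d["constraint_class"])
    return cls.from_dict(d) if hasattr(cls, "from_dict") else cls()

print(class_path(Fraction))                                    # fractions.Fraction
print(load_class("collections.Counter") is Counter)            # True
print(from_dict({"constraint_class": "fractions.Fraction"}))   # 0 (Fraction() used as a stand-in)
```

Because the path is resolved at load time, no central registry is needed. Any class importable by its module path can be reconstructed.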

DimensionalConsistency serializes unit exponents as "p/q" strings for exact Fraction fidelity. The symbol_library reference is not serialized; reconstruct it from the original symbol set if needed.
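Here is a minimal sketch of the `"p/q"` encoding, using hypothetical helper names. In Python, `str(Fraction(1, 2))` already gives `'1/2'` (and `'-1'` for integers), and `Fraction(text)` parses that form back. Exponents therefore survive JSON exactly, with no float rounding.

```python
import json
from fractions import Fraction

def units_to_json(unit):
    """Encode exponents as strings: Fraction(1, 2) -> '1/2', -1 -> '-1'."""
    return {dim: str(Fraction(exp)) for dim, exp in unit.items()}

def units_from_json(data):
    """Decode 'p/q' strings back into exact Fractions."""
    return {dim: Fraction(text) for dim, text in data.items()}

unit = {"m": Fraction(1, 2), "s": -1}
payload = json.dumps(units_to_json(unit))
print(payload)                                       # {"m": "1/2", "s": "-1"}
print(units_from_json(json.loads(payload)) == unit)  # True
```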

Custom constraints

Subclass Constraint and override the methods you need:

from SRToolkit.utils.grammar import Constraint, Grammar, Rule

class NoConstants(Constraint):
    """Reject any rule whose rhs contains 'C'."""

    def allows(self, slot, rule, global_):
        return "C" not in rule.rhs

g = Grammar([Rule("E", ["C"]), Rule("E", ["x"])])
g.add_constraint(NoConstants())
d = g.start_derivation("E")
print([r.rhs for r in d.options()])  # [['x']]

Constraint instances carry only construction-time configuration and are safe to share across parallel derivations. All per-derivation state is managed by the engine and passed through allows / update arguments.

The slot object (ExpansionContext)

The first argument to allows and update is an ExpansionContext describing the current expansion site:

| Field | Type | Description |
|---|---|---|
| slot.nonterminal | str | The non-terminal being expanded |
| slot.local | Any | Per-slot state for this constraint (see below) |
| slot.nonterminals | frozenset[str] | All non-terminals in the grammar |
| slot.frontier_size | int | Number of open slots at this point |
| slot.steps | int | Total rule applications so far |
| slot.parent_rule | Rule \| None | Rule that opened this slot, or None for the root |
| slot.child_index | int \| None | NT-child index within parent_rule.rhs, or None for the root |
| slot.ancestors | tuple[AncestorInfo, ...] | Chain from root to this slot (outermost first) |
| slot.partial_tree | ParseTreeNode | The partial tree node for this slot |

Each AncestorInfo in slot.ancestors has .nonterminal, .rule, and .child_index fields describing the ancestor expansion.

Local and global state

The engine provides two kinds of per-derivation state for constraints that need to track history:

Local state is per-slot and inherited from parent to children. The engine maintains a stack of local values alongside the open frontier; each slot starts with the value propagated from its parent's update call. Use local state for anything that depends on position in the tree, such as a remaining depth budget.

Global state is a single value shared across the entire derivation. It is updated after every rule application and is useful for running counters and accumulators, such as a total node count.

Both are initialised once per derivation — local state via initial_local(start), global state via initial_global() — and the engine threads the current values into every allows and update call.

# Local state: remaining depth budget inherited by each child slot
class BudgetDepth(Constraint):
    def __init__(self, limit):
        self.limit = limit

    def initial_local(self, start):
        return self.limit                        # every slot starts with full budget

    def allows(self, slot, rule, global_):
        if slot.local <= 0:
            return not any(s in slot.nonterminals for s in rule.rhs)
        return True

    def update(self, slot, rule, global_):
        n = sum(1 for s in rule.rhs if s in slot.nonterminals)
        return [slot.local - 1] * n, global_    # children receive budget - 1


# Global state: running count of rule applications across the whole derivation
class ApplicationCounter(Constraint):
    def __init__(self, limit):
        self.limit = limit

    def initial_global(self):
        return 0                                 # counter starts at zero

    def allows(self, slot, rule, global_):
        return global_ < self.limit

    def update(self, slot, rule, global_):
        n = sum(1 for s in rule.rhs if s in slot.nonterminals)
        return [None] * n, global_ + 1          # increment counter after each application

Both flavours can be combined in a single constraint by implementing all four methods together, as DimensionalConsistency does with its per-slot unit tracking.
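The heavily simplified loop below shows how an engine can thread both kinds of state. It is a standalone, assumption-laden sketch, not SRToolkit's engine. Constraint methods receive the local value directly instead of an `ExpansionContext`, and there is a single constraint. The `choose` callback is hypothetical and picks among the allowed rules. Each child slot is pushed onto the frontier stack together with the local value that `update` returned for it, while the global value is carried from one application to the next.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyRule:
    lhs: str
    rhs: tuple

class DepthAndCount:
    """Local state: remaining depth budget. Global state: applications so far."""
    def __init__(self, depth):
        self.depth = depth

    def initial_local(self, start):
        return self.depth

    def initial_global(self):
        return 0

    def allows(self, local, rule, global_, nonterminals):
        if local <= 0:  # out of depth: only rules that close the slot
            return not any(s in nonterminals for s in rule.rhs)
        return True

    def update(self, local, rule, global_, nonterminals):
        n = sum(1 for s in rule.rhs if s in nonterminals)
        return [local - 1] * n, global_ + 1

def derive(rules, start, constraint, choose):
    nonterminals = {r.lhs for r in rules}
    frontier = [(start, constraint.initial_local(start))]  # stack of (NT, local)
    global_ = constraint.initial_global()
    applied = []
    while frontier:
        nt, local = frontier.pop()
        candidates = [r for r in rules
                      if r.lhs == nt and constraint.allows(local, r, global_, nonterminals)]
        if not candidates:
            return None, global_                            # dead end
        rule = choose(candidates)
        child_locals, global_ = constraint.update(local, rule, global_, nonterminals)
        children = [s for s in rule.rhs if s in nonterminals]
        # Push right-to-left so the leftmost child is expanded next
        for child, child_local in reversed(list(zip(children, child_locals))):
            frontier.append((child, child_local))
        applied.append(rule)
    return applied, global_

rules = [ToyRule("E", ("E", "+", "E")), ToyRule("E", ("x",))]
# Greedy choice: always the first allowed rule (the recursive one when permitted)
applied, count = derive(rules, "E", DepthAndCount(depth=1), lambda c: c[0])
print([" ".join(r.rhs) for r in applied])   # ['E + E', 'x', 'x']
print(count)                                # 3
```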

To make a custom constraint serializable, implement to_dict and from_dict:

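class NoConstants(Constraint):
    """Reject any rule whose rhs contains 'C'."""

    def allows(self, slot, rule, global_):
        return "C" not in rule.rhs

    def to_dict(self):
        return {**super().to_dict()}   # constraint_class is enough; no configuration to store

    @classmethod
    def from_dict(cls, d):
        return cls()                   # no constructor arguments to restore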

constraint_from_dict dispatches via importlib, so any class reachable by its module path is supported automatically.

Scoping

Set nonterminals and/or rule_names on the class to limit when allows is called. Slots outside scope are accepted unconditionally; update is always called regardless of scope so that global counters stay accurate.

class DepthLimitForE(Constraint):
    nonterminals = frozenset({"E"})   # only active at E slots

    def __init__(self, limit):
        self.limit = limit

    def initial_local(self, start):
        return self.limit

    def allows(self, slot, rule, global_):
        if slot.local <= 0:
            return not any(s in slot.nonterminals for s in rule.rhs)
        return True

    def update(self, slot, rule, global_):
        n = sum(1 for s in rule.rhs if s in slot.nonterminals)
        return [slot.local - 1] * n, global_
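The scoping behaviour can be pictured as a small gate in front of `allows`. The sketch below is a hypothetical illustration that covers only the `nonterminals` scope. How SRToolkit combines `nonterminals` with `rule_names` is not modelled here, and neither is the unconditional `update` call.

```python
class AlwaysReject:
    """Toy constraint scoped to E slots; rejects every rule it is asked about."""
    nonterminals = frozenset({"E"})

    def allows(self, nonterminal, rule_rhs):
        return False

def gated_allows(constraint, nonterminal, rule_rhs):
    """Out-of-scope slots are accepted without consulting the constraint."""
    scope = getattr(constraint, "nonterminals", None)
    if scope is not None and nonterminal not in scope:
        return True
    return constraint.allows(nonterminal, rule_rhs)

c = AlwaysReject()
print(gated_allows(c, "E", ("x",)))  # False: E is in scope, so allows() decides
print(gated_allows(c, "F", ("x",)))  # True: F is out of scope, allows() is skipped
```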