
Constraints

SRToolkit.utils.grammar.constraints

Constraint protocol for grammar-guided expression generation.

A Constraint is a hard filter: at each derivation step its allows method decides which candidate rules survive.

All constraints operate on an ExpansionContext view that exposes the current derivation position together with the constraint's own per-slot local state.

Per-derivation state comes in two flavours:

  • Local — per-slot, inherited from parent to children. The engine maintains a stack parallel to the open frontier; each constraint sees only its own local value for the current slot.
  • Global — per-derivation scalar, lives on the engine alongside the frontier. Suitable for counters and accumulators.

Both flavours are owned by the engine; the Derivation initialises them at Grammar.start_derivation and threads them through every allows / update call. Constraint instances carry only construction-time configuration and are safe to share across parallel derivations.
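The division of labour between local and global state can be illustrated with a toy engine. This is a minimal sketch, not the SRToolkit Derivation. It assumes a single constraint, a left-most derivation over a frontier stack, and an update callback shaped like Constraint.update without the slot object. ToyDerivation, depth_and_count and NONTERMINALS are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Any, Callable

NONTERMINALS = frozenset({"E"})

# Hypothetical update signature: (parent_local, rhs, global_) -> (child_locals, new_global)
Update = Callable[[Any, list, Any], tuple]


@dataclass
class ToyDerivation:
    """Toy engine: a frontier stack with a parallel stack of local states."""
    frontier: list = field(default_factory=list)  # open non-terminal slots
    locals_: list = field(default_factory=list)   # one local value per open slot
    global_: Any = None                           # per-derivation scalar

    def start(self, start: str, initial_local: Any, initial_global: Any) -> None:
        self.frontier = [start]
        self.locals_ = [initial_local]
        self.global_ = initial_global

    def apply(self, rhs: list, update: Update) -> None:
        # Pop the slot being expanded together with its own local state.
        self.frontier.pop()
        parent_local = self.locals_.pop()
        child_locals, self.global_ = update(parent_local, rhs, self.global_)
        children = [s for s in rhs if s in NONTERMINALS]
        assert len(child_locals) == len(children)
        # Push in reverse so the leftmost child sits on top of the stack.
        for child, child_local in reversed(list(zip(children, child_locals))):
            self.frontier.append(child)
            self.locals_.append(child_local)


def depth_and_count(local: int, rhs: list, global_: int) -> tuple:
    """Local: remaining depth budget. Global: number of applications so far."""
    n = sum(1 for s in rhs if s in NONTERMINALS)
    return [local - 1] * n, global_ + 1


d = ToyDerivation()
d.start("E", initial_local=2, initial_global=0)
d.apply(["E", "+", "E"], depth_and_count)  # frontier ['E', 'E'], locals [1, 1], global 1
d.apply(["x"], depth_and_count)            # frontier ['E'], locals [1], global 2
```

Because the state lives on the engine rather than on the callback's owner, the same update function can serve any number of concurrent derivations, which is the property the paragraph above describes for Constraint instances.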

Built-in constraints

TRANSCENDENTAL_FNS module-attribute

TRANSCENDENTAL_FNS: frozenset[str] = frozenset({'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2', 'arcsin', 'arccos', 'arctan', 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh', 'exp', 'exp2', 'log', 'log2', 'log10', 'ln'})

Functions that both require dimensionless input and produce dimensionless output. Used by DimensionalConsistency.

UNIT_PRESERVING_FNS module-attribute

UNIT_PRESERVING_FNS: frozenset[str] = frozenset({'abs', 'fabs', 'floor', 'ceil', 'round', 'sign', 'sgn', 'neg'})

Functions whose output carries the same unit as their input. Used by DimensionalConsistency.

SQRT_FNS module-attribute

SQRT_FNS: frozenset[str] = frozenset({'sqrt'})

Square-root functions: output unit = input unit raised to the power ½. Used by DimensionalConsistency.

CBRT_FNS module-attribute

CBRT_FNS: frozenset[str] = frozenset({'cbrt'})

Cube-root functions: output unit = input unit raised to the power ⅓. Used by DimensionalConsistency.
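The four sets above each imply a different rule for the output unit of a one-argument function. The sketch below spells those rules out as the docstrings state them, using the dict-of-Fraction unit representation described under DimensionalConsistency. output_unit and scale are hypothetical helpers; how DimensionalConsistency applies the sets internally is not shown on this page.

```python
from fractions import Fraction
from typing import Optional

# Unit: base dimension -> rational exponent; {} is dimensionless.
Unit = dict

# Copies of the documented module attributes.
TRANSCENDENTAL_FNS = frozenset({
    "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "arcsin", "arccos",
    "arctan", "sinh", "cosh", "tanh", "arcsinh", "arccosh", "arctanh",
    "exp", "exp2", "log", "log2", "log10", "ln",
})
UNIT_PRESERVING_FNS = frozenset({"abs", "fabs", "floor", "ceil", "round", "sign", "sgn", "neg"})
SQRT_FNS = frozenset({"sqrt"})
CBRT_FNS = frozenset({"cbrt"})


def scale(unit: Unit, factor: Fraction) -> Unit:
    """Raise a unit to a rational power by scaling every exponent."""
    return {dim: exp * factor for dim, exp in unit.items() if exp * factor != 0}


def output_unit(fn: str, arg_unit: Unit) -> Optional[Unit]:
    """Return the unit of fn(arg), or None if the argument is dimensionally invalid."""
    if fn in TRANSCENDENTAL_FNS:
        return {} if not arg_unit else None  # input must be dimensionless
    if fn in UNIT_PRESERVING_FNS:
        return dict(arg_unit)
    if fn in SQRT_FNS:
        return scale(arg_unit, Fraction(1, 2))
    if fn in CBRT_FNS:
        return scale(arg_unit, Fraction(1, 3))
    raise ValueError(f"unclassified function: {fn}")
```

For example, output_unit("sqrt", {"m": Fraction(2)}) yields {"m": Fraction(1)}, while output_unit("sin", {"s": Fraction(1)}) yields None because sin needs a dimensionless argument.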

AncestorInfo dataclass

AncestorInfo(nonterminal: str, rule: Rule, child_index: int)

One entry in an ExpansionContext's ancestor chain.

Attributes:

  • nonterminal (str): The non-terminal symbol at this ancestor position.
  • rule (Rule): The rule applied at this position.
  • child_index (int): The index within rule.rhs that leads toward the current slot, counting only non-terminal positions.
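Because child_index counts only non-terminal positions, it differs from the raw list index into rule.rhs whenever terminals precede the child. A minimal sketch of the mapping, assuming rhs is a list of symbol strings; child_index here is a hypothetical helper, not part of the API.

```python
def child_index(rhs: list[str], nonterminals: frozenset[str], position: int) -> int:
    """Map a raw position in rule.rhs to its index among non-terminal positions only."""
    if rhs[position] not in nonterminals:
        raise ValueError(f"rhs[{position}] = {rhs[position]!r} is a terminal")
    return sum(1 for s in rhs[:position] if s in nonterminals)


NT = frozenset({"E"})
print(child_index(["E", "+", "E"], NT, 2))            # second E: raw index 2, child_index 1
print(child_index(["sin", "(", "E", ")"], NT, 2))     # only child: child_index 0
```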

ExpansionContext dataclass

ExpansionContext(nonterminal: str, local: L, steps: int, parent_rule: Optional[Rule], child_index: Optional[int], ancestors: tuple[AncestorInfo, ...], partial_tree: ParseTreeNode, nonterminals: frozenset[str], frontier_size: int = 1)

Bases: Generic[L]

Read-only view of the current derivation position passed to constraints.

The engine constructs an ExpansionContext on demand for each (constraint, candidate_rule) pair evaluated in Derivation.options and for each (constraint, selected_rule) pair in Derivation.apply.

Attributes:

  • nonterminal (str): The non-terminal symbol being expanded at this position.
  • local (L): This constraint's per-slot inherited state.
  • steps (int): Number of rule applications made so far in the derivation (i.e. how many times apply() has been called).
  • parent_rule (Optional[Rule]): The rule whose application created this slot (None for the start symbol).
  • child_index (Optional[int]): Index of this slot in parent_rule.rhs, counting only non-terminal positions (None for the start symbol).
  • ancestors (tuple[AncestorInfo, ...]): Ancestor chain from the root to this slot's immediate parent, in root-first order.
  • partial_tree (ParseTreeNode): Root of the partially built parse tree. Read-only by contract: constraints must not mutate it.
  • nonterminals (frozenset[str]): Frozen set of all non-terminal symbols in the grammar.
  • frontier_size (int): Number of open (unexpanded) non-terminal slots in the current derivation frontier, including this slot. Useful for budget constraints such as MaxNodes, which must account for the minimum number of rule applications still required to complete the derivation.
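The ancestors chain lets a constraint ask contextual questions without keeping any local state. The sketch below uses stand-in dataclasses that mirror only the fields documented here; the lhs field name on Rule and the helper under_symbol are assumptions for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Rule:  # hypothetical stand-in with only the fields used here
    lhs: str
    rhs: tuple


@dataclass(frozen=True)
class AncestorInfo:  # mirrors the documented fields
    nonterminal: str
    rule: Rule
    child_index: int


def under_symbol(ancestors: tuple, symbol: str) -> bool:
    """True if any ancestor rule (root-first chain) emitted `symbol` in its rhs."""
    return any(symbol in a.rule.rhs for a in ancestors)


# Slot: the second operand of E + E, which itself sits inside sin( ... ).
root = AncestorInfo("E", Rule("E", ("sin", "(", "E", ")")), 0)
mid = AncestorInfo("E", Rule("E", ("E", "+", "E")), 1)
ancestors = (root, mid)
print(under_symbol(ancestors, "sin"), under_symbol(ancestors, "cos"), len(ancestors))
```

Scanning the chain costs time proportional to depth on every call, whereas an inherited local flag (the approach NoNested takes) is constant-time; the chain is convenient when the question is ad hoc or needs more than one bit of history.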

Constraint

Bases: Generic[L, G]

Base class for derivation constraints (hard filters).

Subclass and override the methods you need. Constraints carry only construction-time configuration; all per-derivation state is managed by the engine and threaded through method arguments.

Scoping — set nonterminals and/or rule_names to restrict allows to a subset of slots or rules. A scope miss is treated as implicit acceptance. update is always called regardless of scope so that global counters stay accurate.

Examples:

>>> from SRToolkit.utils.grammar import Constraint, Grammar, Rule
>>> g = Grammar([Rule("E", ["x"]), Rule("E", ["y"])])
>>> class RejectY(Constraint):
...     def allows(self, slot, rule, global_): return rule.rhs != ["y"]
>>> g.add_constraint(RejectY())
>>> d = g.start_derivation("E")
>>> [r.rhs for r in d.options()]
[['x']]
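The scoping rule ("a scope miss is treated as implicit acceptance") can be read as a wrapper around allows. A minimal sketch, assuming that a scope attribute left as None means "no restriction"; scoped_allows is a hypothetical function, not the engine's actual code path.

```python
from typing import Callable, Optional


def scoped_allows(
    nonterminals: Optional[frozenset],
    rule_names: Optional[frozenset],
    slot_nonterminal: str,
    rule_name: Optional[str],
    allows: Callable[[], bool],
) -> bool:
    """Evaluate `allows` only when the slot and rule fall inside the scope."""
    if nonterminals is not None and slot_nonterminal not in nonterminals:
        return True  # scope miss on the slot: implicit acceptance
    if rule_names is not None and rule_name not in rule_names:
        return True  # scope miss on the rule: implicit acceptance
    return allows()


reject_all = lambda: False
print(scoped_allows(frozenset({"F"}), None, "E", "E_x", reject_all))  # True: E is out of scope
print(scoped_allows(frozenset({"E"}), None, "E", "E_x", reject_all))  # False: in scope, rejected
```

Scoping only gates allows; update still runs for every applied rule, so global counters see out-of-scope applications too.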

to_dict

to_dict() -> dict

Serialise this constraint to a JSON-safe dictionary.

The base implementation stores only the fully-qualified class path under constraint_class. Subclasses should call super().to_dict() and add their own constructor arguments so that from_dict can reconstruct the instance faithfully. See from_dict for an example.

Returns:

  • dict: Dictionary with at least the key constraint_class.

Source code in SRToolkit/utils/grammar/constraints.py
def to_dict(self) -> dict:
    """
    Serialise this constraint to a JSON-safe dictionary.

    The base implementation stores only the fully-qualified class path under
    ``constraint_class``. Subclasses should call ``super().to_dict()`` and
    add their own constructor arguments so that ``from_dict`` can reconstruct
    the instance faithfully. See
    [from_dict][SRToolkit.utils.grammar.constraints.Constraint.from_dict]
    for an example.

    Returns:
        Dictionary with at least the key ``constraint_class``.
    """
    return {"constraint_class": f"{self.__class__.__module__}.{self.__class__.__qualname__}"}

from_dict classmethod

from_dict(d: dict) -> 'Constraint'

Reconstruct a constraint from a dictionary produced by to_dict.

When called on the base Constraint class, dispatches to the correct subclass via the constraint_class key using importlib — both built-in and user-defined subclasses are supported. When called on a concrete subclass, the subclass must override this method.

The dictionary must contain at minimum the key constraint_class, whose value is the fully-qualified class path (e.g. "mymodule.MyConstraint"). Any additional keys are forwarded to the subclass override.

To make a custom subclass serialisable, override both to_dict and from_dict:

class MyConstraint(Constraint):
    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def to_dict(self) -> dict:
        return {**super().to_dict(), "threshold": self.threshold}

    @classmethod
    def from_dict(cls, d: dict) -> "MyConstraint":
        return cls(d["threshold"])

Examples:

>>> from SRToolkit.utils.grammar import Constraint, MaxDepth
>>> c = Constraint.from_dict(MaxDepth(5).to_dict())
>>> isinstance(c, MaxDepth) and c.limit == 5
True

Parameters:

  • d (dict, required): Dictionary previously returned by to_dict.

Returns:

  • Constraint: A reconstructed Constraint instance.

Raises:

  • KeyError: If constraint_class is missing from d (dispatch path).
  • ImportError: If the module cannot be imported (dispatch path).
  • AttributeError: If the class cannot be found in the module (dispatch path).
  • NotImplementedError: If called on a subclass that has not overridden this method.

Source code in SRToolkit/utils/grammar/constraints.py
@classmethod
def from_dict(cls, d: dict) -> "Constraint":
    """
    Reconstruct a constraint from a dictionary produced by ``to_dict``.

    When called on the base [Constraint][SRToolkit.utils.grammar.constraints.Constraint]
    class, dispatches to the correct subclass via the ``constraint_class`` key using
    `importlib` — both built-in and user-defined subclasses are supported.
    When called on a concrete subclass, the subclass must override this method.

    The dictionary must contain at minimum the key ``constraint_class``, whose value
    is the fully-qualified class path (e.g. ``"mymodule.MyConstraint"``).  Any
    additional keys are forwarded to the subclass override.

    To make a custom subclass serialisable, override both ``to_dict`` and
    ``from_dict``::

        class MyConstraint(Constraint):
            def __init__(self, threshold: float) -> None:
                self.threshold = threshold

            def to_dict(self) -> dict:
                return {**super().to_dict(), "threshold": self.threshold}

            @classmethod
            def from_dict(cls, d: dict) -> "MyConstraint":
                return cls(d["threshold"])

    Examples:
        >>> from SRToolkit.utils.grammar import Constraint, MaxDepth
        >>> c = Constraint.from_dict(MaxDepth(5).to_dict())
        >>> isinstance(c, MaxDepth) and c.limit == 5
        True

    Args:
        d: Dictionary previously returned by ``to_dict``.

    Returns:
        A reconstructed [Constraint][SRToolkit.utils.grammar.constraints.Constraint] instance.

    Raises:
        KeyError: If ``constraint_class`` is missing from ``d`` (dispatch path).
        ImportError: If the module cannot be imported (dispatch path).
        AttributeError: If the class cannot be found in the module (dispatch path).
        NotImplementedError: If called on a subclass that has not overridden this method.
    """
    if cls is Constraint:
        class_path = d["constraint_class"]
        module_path, cls_name = class_path.rsplit(".", 1)
        resolved = getattr(importlib.import_module(module_path), cls_name)
        return resolved.from_dict(d)
    raise NotImplementedError(f"{cls.__name__}.from_dict is not implemented.")
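The dispatch path relies only on importlib and a dotted class path, so the mechanism can be exercised with any importable class. The sketch below reproduces the path construction from to_dict and the lookup from from_dict, using fractions.Fraction as a stand-in for a constraint class; class_path and resolve_class are hypothetical names.

```python
import importlib
import json
from fractions import Fraction


def class_path(cls: type) -> str:
    """Build the dotted path the way the base to_dict does."""
    return f"{cls.__module__}.{cls.__qualname__}"


def resolve_class(path: str) -> type:
    """Import a class from 'package.module.ClassName', as the base from_dict does."""
    module_path, cls_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), cls_name)


# Round-trip through JSON to confirm the payload is JSON-safe.
payload = json.loads(json.dumps({"constraint_class": class_path(Fraction)}))
print(payload["constraint_class"])                        # fractions.Fraction
print(resolve_class(payload["constraint_class"]) is Fraction)
```

As written, the lookup splits on the last dot, so it appears to assume a module-level class: for a class nested inside another class, __qualname__ contributes dots of its own and the module import step would fail.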

initial_local

initial_local(start: str) -> L

Return the initial local state for the constraint.

Source code in SRToolkit/utils/grammar/constraints.py
def initial_local(self, start: str) -> L:
    """Return the initial local state for the constraint."""
    return None  # type: ignore[return-value]

initial_global

initial_global() -> G

Return the initial global state for the constraint.

Source code in SRToolkit/utils/grammar/constraints.py
def initial_global(self) -> G:
    """Return the initial global state for the constraint."""
    return None  # type: ignore[return-value]

allows

allows(slot: ExpansionContext[L], rule: Rule, global_: G) -> bool

Return True if rule may be applied at slot.

Called only when the slot's non-terminal and rule name are within scope.

Parameters:

  • slot (ExpansionContext[L], required): Current derivation position with this constraint's local state.
  • rule (Rule, required): Candidate production rule.
  • global_ (G, required): Current per-derivation global state.

Returns:

  • bool: True to keep the rule in the candidate set.

Source code in SRToolkit/utils/grammar/constraints.py
def allows(self, slot: ExpansionContext[L], rule: Rule, global_: G) -> bool:
    """
    Return ``True`` if ``rule`` may be applied at ``slot``.

    Called only when the slot's non-terminal and rule name are within scope.

    Args:
        slot: Current derivation position with this constraint's local state.
        rule: Candidate production rule.
        global_: Current per-derivation global state.

    Returns:
        ``True`` to keep the rule in the candidate set.
    """
    return True

update

update(slot: ExpansionContext[L], rule: Rule, global_: G) -> tuple[list[L], G]

Called after rule is applied at slot.

Returns per-child local states (one per non-terminal in rule.rhs, in order) and the new global state. May be used to update global counters by returning a new global_ value.

Parameters:

  • slot (ExpansionContext[L], required): Derivation position immediately before the rule was applied.
  • rule (Rule, required): The rule that was applied.
  • global_ (G, required): Global state before this application.

Returns:

  • tuple[list[L], G]: (child_locals, new_global), where len(child_locals) equals the number of non-terminals in rule.rhs.

Source code in SRToolkit/utils/grammar/constraints.py
def update(self, slot: ExpansionContext[L], rule: Rule, global_: G) -> tuple[list[L], G]:
    """
    Called after ``rule`` is applied at ``slot``.

    Returns per-child local states (one per non-terminal in ``rule.rhs``,
    in order) and the new global state. May be used to update global
    counters by returning a new ``global_`` value.

    Args:
        slot: Derivation position immediately before the rule was applied.
        rule: The rule that was applied.
        global_: Global state before this application.

    Returns:
        ``(child_locals, new_global)`` where ``len(child_locals)`` equals
        the number of non-terminals in ``rule.rhs``.
    """
    n = sum(1 for s in rule.rhs if s in slot.nonterminals)
    return [slot.local] * n, global_
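The default update copies the parent's local state to every non-terminal child. Returning a list matters when children need different states. The sketch below contrasts the default with a hypothetical override that marks only the exponent operand of E ^ E; the functions work on plain values instead of an ExpansionContext, and exponent_flag_update is an illustrative name.

```python
NONTERMINALS = frozenset({"E"})


def default_update(local, rhs: list, global_):
    """Base behaviour: every non-terminal child inherits the parent's local state."""
    n = sum(1 for s in rhs if s in NONTERMINALS)
    return [local] * n, global_


def exponent_flag_update(local: bool, rhs: list, global_: int):
    """Hypothetical override: flag the child that follows '^' as 'in exponent'."""
    child_locals = []
    for pos, sym in enumerate(rhs):
        if sym not in NONTERMINALS:
            continue
        is_exponent = pos > 0 and rhs[pos - 1] == "^"
        child_locals.append(local or is_exponent)
    return child_locals, global_ + 1  # also count applications globally


print(default_update(False, ["E", "^", "E"], 0))        # ([False, False], 0)
print(exponent_flag_update(False, ["E", "^", "E"], 0))  # ([False, True], 1)
```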

MaxDepth

MaxDepth(limit: int)

Bases: Constraint[int, None]

Hard limit on derivation depth.

Local state is the remaining depth budget at each slot. A rule is rejected when its application would require at least one recursive non-terminal child but the budget has reached zero.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
...     Rule("E", ["E", "+", "E"]),
...     Rule("E", ["x"]),
... ])
>>> g.add_constraint(MaxDepth(0))
>>> d = g.start_derivation("E")
>>> all(r.rhs == ["x"] for r in d.options())
True

Parameters:

  • limit (int, required): Maximum nesting depth (number of non-terminal expansion levels).
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(self, limit: int) -> None:
    self.limit = limit
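The depth budget described above can be sketched as a pair of plain functions. This assumes the local value starts at limit and decreases by one per level, which matches the MaxDepth(0) example but is not taken from the implementation; the exact off-by-one convention of the real class may differ.

```python
NONTERMINALS = frozenset({"E"})


def allows(remaining: int, rhs: list) -> bool:
    """Reject rules that open a non-terminal child once the depth budget is spent."""
    has_child = any(s in NONTERMINALS for s in rhs)
    return remaining > 0 or not has_child


def update(remaining: int, rhs: list) -> list:
    """Each non-terminal child inherits one level less of budget."""
    n = sum(1 for s in rhs if s in NONTERMINALS)
    return [remaining - 1] * n


print(allows(0, ["E", "+", "E"]), allows(0, ["x"]))  # False True
print(update(1, ["E", "+", "E"]))                    # [0, 0]
```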

MaxNodes

MaxNodes(limit: int)

Bases: Constraint[None, int]

Hard limit on the total number of rule applications in a derivation.

Global state is a running count of applications so far. A rule is rejected when applying it would push the count over the limit, taking into account the non-terminal children it introduces (each of which will require at least one more application).

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
...     Rule("E", ["E", "+", "E"]),
...     Rule("E", ["x"]),
... ])
>>> g.add_constraint(MaxNodes(3))
>>> d = g.start_derivation("E")
>>> # Once the budget is nearly spent, only the terminal rule E -> x survives
>>> tokens = d.generate()
>>> len(tokens) <= 5
True

Parameters:

  • limit (int, required): Maximum number of rule applications.
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(self, limit: int) -> None:
    self.limit = limit
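The budget check combines the global count with ExpansionContext.frontier_size. A minimal sketch of the lower-bound arithmetic, assuming every open slot and every new child costs at least one more application; this follows the prose above, not the actual source.

```python
NONTERMINALS = frozenset({"E"})


def allows(count: int, frontier_size: int, rhs: list, limit: int) -> bool:
    """Lower-bound the total applications if `rhs` is applied at this slot."""
    new_children = sum(1 for s in rhs if s in NONTERMINALS)
    # this application + other open slots (each >= 1) + new children (each >= 1)
    minimum_total = count + 1 + (frontier_size - 1) + new_children
    return minimum_total <= limit


# limit = 3 with E -> E + E | x
print(allows(0, 1, ["E", "+", "E"], 3))  # True: 0 + 1 + 0 + 2 = 3
print(allows(1, 2, ["E", "+", "E"], 3))  # False: 1 + 1 + 1 + 2 = 5
print(allows(1, 2, ["x"], 3))            # True: 1 + 1 + 1 + 0 = 3
```

Under these assumptions, a limit of 3 admits at most one E -> E + E expansion, consistent with the bound checked in the doctest.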

MaxOccurrences

MaxOccurrences(symbol: str, limit: int)

Bases: Constraint[None, int]

Hard limit on how many times a specific terminal symbol may appear.

Global state counts occurrences committed so far. A rule whose rhs contains the tracked symbol is rejected once the count equals the limit.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
...     Rule("E", ["E", "+", "E"]),
...     Rule("E", ["x"]),
...     Rule("E", ["y"]),
... ])
>>> g.add_constraint(MaxOccurrences("x", 1))
>>> d = g.start_derivation("E")
>>> tokens = d.generate(limit=200)
>>> tokens.count("x") <= 1
True

Parameters:

  • symbol (str, required): Terminal token to track.
  • limit (int, required): Maximum allowed occurrences.
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(self, symbol: str, limit: int) -> None:
    self.symbol = symbol
    self.limit = limit
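The counting rule can be sketched as two small functions over a plain integer global. This mirrors the description above literally; whether the real class also handles a single rule that emits the tracked symbol more than once is not stated here.

```python
def allows(count: int, rhs: list, symbol: str, limit: int) -> bool:
    """Block rules emitting `symbol` once the committed count reaches the limit."""
    return symbol not in rhs or count < limit


def update(count: int, rhs: list, symbol: str) -> int:
    """Add every occurrence of `symbol` in the applied rule."""
    return count + rhs.count(symbol)


count = 0
print(allows(count, ["x"], "x", 1))  # True: nothing committed yet
count = update(count, ["x"], "x")    # count is now 1
print(allows(count, ["x"], "x", 1))  # False: limit reached
print(allows(count, ["y"], "x", 1))  # True: rule does not contain "x"
```

Read literally, the check allows a rule such as E -> x * x while the count is below the limit, so such a rule could overshoot by one; grammars that emit the tracked symbol several times per rule may need a check on count + rhs.count(symbol) instead.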

NoNested

NoNested(symbols: str | Iterable[str])

Bases: Constraint[bool, None]

Prevent any symbol in a group from appearing nested inside any other symbol in the same group.

Local state is a boolean "currently under any symbol in the group". Any rule whose rhs contains a group symbol is rejected while the local state is True. Children inherit True once such a rule is applied.

Pass a single string to forbid self-nesting only; pass multiple symbols to forbid cross-nesting within the group (e.g. NoNested(["sin", "cos"]) prevents sin(cos(x)) as well as sin(sin(x))).

Warning

This constraint propagates the "inside" flag to all non-terminal children of any rule whose rhs contains a group symbol. It works correctly when each group symbol appears in a rule with exactly one non-terminal child (typical prefix-function rules such as E -> sin ( E )) and for infix operators where every child should be blocked (e.g. E -> E + E). It does not work correctly for mixed rules that combine a function call with additional non-terminal siblings in a single production (e.g. E -> sin ( E ) + F): the sibling F will be incorrectly treated as inside the group. If your grammar uses such rules, implement a custom Constraint instead.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> g = Grammar([
...     Rule("E", ["sin", "(", "E", ")"]),
...     Rule("E", ["cos", "(", "E", ")"]),
...     Rule("E", ["x"]),
... ])
>>> g.add_constraint(NoNested(["sin", "cos"]))
>>> d = g.start_derivation("E")
>>> d.apply(g.rules_for("E")[0])   # apply sin(E)
>>> # Now inside sin; both sin and cos rules should be gone
>>> opts = d.options()
>>> all("sin" not in r.rhs and "cos" not in r.rhs for r in opts)
True

Parameters:

  • symbols (str | Iterable[str], required): A single terminal token or an iterable of terminal tokens that form the nesting group.
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(self, symbols: str | Iterable[str]) -> None:
    if isinstance(symbols, str):
        self.symbols: frozenset[str] = frozenset({symbols})
    else:
        self.symbols = frozenset(symbols)
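The inherited "inside the group" flag, and the limitation described in the warning, can both be seen in a small sketch. These are plain functions that follow the prose description, not the class's source; NONTERMINALS and GROUP are illustrative.

```python
NONTERMINALS = frozenset({"E", "F"})
GROUP = frozenset({"sin", "cos"})


def allows(inside: bool, rhs: list, group: frozenset) -> bool:
    """Reject any rule containing a group symbol while already inside the group."""
    return not (inside and any(s in group for s in rhs))


def update(inside: bool, rhs: list, group: frozenset) -> list:
    """Every non-terminal child inherits True once a group symbol is applied."""
    n = sum(1 for s in rhs if s in NONTERMINALS)
    flag = inside or any(s in group for s in rhs)
    return [flag] * n


print(update(False, ["sin", "(", "E", ")"], GROUP))    # [True]
print(allows(True, ["cos", "(", "E", ")"], GROUP))     # False: cross-nesting blocked
# Mixed rule from the warning: the sibling F is flagged too.
print(update(False, ["sin", "(", "E", ")", "+", "F"], GROUP))  # [True, True]
```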

DimensionalConsistency

DimensionalConsistency(variable_units: dict[str, dict], target_unit: dict, constant_units: Optional[dict[str, dict]] = None, symbol_library: Optional[SymbolLibrary] = None, allow_unit_polymorphic_constants: bool = False)

Bases: Constraint[Optional[Unit], None]

Stateful constraint that enforces dimensional analysis during expression generation.

Local state at each slot is the required unit (Optional[Unit]). None means "any unit is acceptable here" — used for the children of multiplicative operators, whose individual units are underdetermined.

Unit representation: units are dict[str, Fraction] mapping base dimension names (e.g. "m", "s", "kg") to rational exponents. Dimensionless is {} (empty dict).

Free constants (const type): treated as dimensionless by default. Set allow_unit_polymorphic_constants=True to let them absorb any required unit.

Named physical constants (lit type, e.g. "g", "c"): declared via constant_units; checked like variables.

Undeclared variables or literals: conservatively accepted.

Rule classification: uses rule.name when available (see Grammar.from_symbol_library for the naming scheme). Falls back to rhs-shape heuristics for unnamed rules.

Examples:

>>> from SRToolkit.utils.grammar import Grammar, Rule
>>> from SRToolkit.utils.grammar import DimensionalConsistency
>>> from fractions import Fraction
>>> g = Grammar()
>>> g.add_rule(Rule("E", ["E", "+", "F"], weight=0.6, name="E_add_+"))
>>> g.add_rule(Rule("E", ["F"], weight=0.4, name="E_to_F"))
>>> g.add_rule(Rule("F", ["v"], weight=0.5, name="F_v"))
>>> g.add_rule(Rule("F", ["t"], weight=0.5, name="F_t"))
>>> dc = DimensionalConsistency(
...     variable_units={"v": {"m": 1, "s": -1}, "t": {"s": 1}},
...     target_unit={"m": 1, "s": -1},
... )
>>> g.add_constraint(dc)
>>> d = g.start_derivation("E")
>>> all(r.rhs != ["t"] for r in d.options())
True

Parameters:

  • variable_units (dict[str, dict], required): Mapping from variable token to its unit.
  • target_unit (dict, required): Required unit of the generated expression.
  • constant_units (Optional[dict[str, dict]], default None): Units for named physical constants (lit type).
  • symbol_library (Optional[SymbolLibrary], default None): Used to classify tokens by type and precedence.
  • allow_unit_polymorphic_constants (bool, default False): If True, free constants absorb whatever unit their slot requires.
Source code in SRToolkit/utils/grammar/constraints.py
def __init__(
    self,
    variable_units: dict[str, dict],
    target_unit: dict,
    constant_units: Optional[dict[str, dict]] = None,
    symbol_library: Optional[SymbolLibrary] = None,
    allow_unit_polymorphic_constants: bool = False,
) -> None:
    self._var_units: dict[str, Unit] = {k: _to_unit(v) for k, v in variable_units.items()}
    self._target: Unit = _to_unit(target_unit)
    self._const_units: dict[str, Unit] = {k: _to_unit(v) for k, v in (constant_units or {}).items()}
    self._sl = symbol_library
    self._allow_poly_const = allow_unit_polymorphic_constants
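The unit representation supports a small algebra: multiplication adds exponents, division subtracts them, and addition requires identical units. The sketch below assumes to_unit behaves like the private _to_unit helper used in __init__ (normalising exponents to Fraction and dropping zeros); that helper's actual implementation is not shown, so to_unit, mul, div and add_ok are hypothetical names.

```python
from fractions import Fraction


def to_unit(raw: dict) -> dict:
    """Normalise {'m': 1, 's': -1} to Fraction exponents, dropping zeros."""
    return {dim: Fraction(exp) for dim, exp in raw.items() if Fraction(exp) != 0}


def mul(a: dict, b: dict) -> dict:
    """Multiplying quantities adds exponents per base dimension."""
    out = dict(a)
    for dim, exp in b.items():
        out[dim] = out.get(dim, Fraction(0)) + exp
    return {d: e for d, e in out.items() if e != 0}


def div(a: dict, b: dict) -> dict:
    """Dividing quantities subtracts exponents."""
    return mul(a, {d: -e for d, e in b.items()})


def add_ok(a: dict, b: dict) -> bool:
    """Addition and subtraction need identical units on both sides."""
    return a == b


v = to_unit({"m": 1, "s": -1})  # velocity
t = to_unit({"s": 1})           # time
print(mul(v, t))                # {'m': Fraction(1, 1)}
print(add_ok(v, t))             # False
```

This also suggests why the children of multiplicative operators carry a local state of None: many pairs of operand units multiply to the same target (m/s can come from m times 1/s or from m/s times a dimensionless factor), so the required unit of each child is underdetermined until one of them is chosen.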