Skip to content

EDHiE

SRToolkit.approaches.EDHiE

EDHiE approach — equation discovery with hierarchical variational autoencoders by Mežnar et al.

EDHiEConfig dataclass

EDHiEConfig(name: str = 'EDHiE', approach_class: str = '', latent_size: int = 32, num_expressions: int = 20000, max_expression_length: int = 30, only_unique_expressions: bool = True, epochs: int = 20, batch_size: int = 32, max_beta: float = 0.035, population_size: int = 40, weights_path: Optional[str] = None, verbose: bool = True)

Bases: ApproachConfig

Configuration dataclass for the EDHiE approach.

Examples:

>>> cfg = EDHiEConfig(latent_size=32, epochs=10)
>>> cfg.name
'EDHiE'
>>> d = cfg.to_dict()
>>> EDHiEConfig.from_dict(d).latent_size
32

EDHiE

EDHiE(latent_size: int = 32, num_expressions: int = 20000, max_expression_length: int = 30, only_unique_expressions: bool = True, epochs: int = 20, batch_size: int = 32, max_beta: float = 0.035, population_size: int = 40, weights_path: Optional[str] = None, verbose: bool = True)

Bases: SR_approach

EDHiE — Equation Discovery with Hierarchical variational autoEncoders by Mežnar et al.

Trains a Hierarchical VAE (HVAE) on randomly generated expressions from the symbol library (adapt), then explores the learned latent space with a genetic algorithm to find expressions that best fit the target dataset (search).

Examples:

>>> model = EDHiE(latent_size=24, num_expressions=100, epochs=1, verbose=False)
>>> model.name
'EDHiE'

Parameters:

Name Type Description Default
latent_size int

Dimensionality of the HVAE latent space.

32
num_expressions int

Number of random expressions generated to train the HVAE.

20000
max_expression_length int

Maximum token length of generated training expressions.

30
only_unique_expressions bool

If True, only unique expressions are generated.

True
epochs int

Number of training epochs for the HVAE.

20
batch_size int

Batch size used during HVAE training.

32
max_beta float

Maximum value of the KL annealing coefficient (controls regularization strength).

0.035
population_size int

GA population size used during latent space search.

40
weights_path Optional[str]

Optional path to load/save HVAE weights from/to disk.

None
verbose bool

If True, prints training loss and new best expressions during search.

True
Source code in SRToolkit/approaches/EDHiE.py
def __init__(
    self,
    latent_size: int = 32,
    num_expressions: int = 20000,
    max_expression_length: int = 30,
    only_unique_expressions: bool = True,
    epochs: int = 20,
    batch_size: int = 32,
    max_beta: float = 0.035,
    population_size: int = 40,
    weights_path: Optional[str] = None,
    verbose: bool = True,
) -> None:
    r"""
    EDHiE — Equation Discovery with Hierarchical variational autoEncoders by Mežnar et al.

    An HVAE is first trained on expressions sampled at random from the symbol library
    (``adapt``); a genetic algorithm then searches the learned latent space for the
    expression that best fits the target dataset (``search``).

    Examples:
        >>> model = EDHiE(latent_size=24, num_expressions=100, epochs=1, verbose=False)
        >>> model.name
        'EDHiE'

    Args:
        latent_size: Dimensionality of the HVAE latent space.
        num_expressions: Number of random expressions generated to train the HVAE.
        max_expression_length: Maximum token length of generated training expressions.
        only_unique_expressions: If ``True``, only unique expressions are generated.
        epochs: Number of training epochs for the HVAE.
        batch_size: Batch size used during HVAE training.
        max_beta: Maximum value of the KL annealing coefficient (controls regularization strength).
        population_size: GA population size used during latent space search.
        weights_path: Optional path to load/save HVAE weights from/to disk.
        verbose: If ``True``, prints training loss and new best expressions during search.
    """
    config = EDHiEConfig(
        latent_size=latent_size,
        num_expressions=num_expressions,
        max_expression_length=max_expression_length,
        only_unique_expressions=only_unique_expressions,
        epochs=epochs,
        batch_size=batch_size,
        max_beta=max_beta,
        population_size=population_size,
        weights_path=weights_path,
        verbose=verbose,
    )
    super().__init__(config)
    check_dependencies(["pytorch", "pymoo"])

    # Mirror the configuration values as plain attributes for convenient access.
    self.latent_size = latent_size
    self.num_expressions = num_expressions
    self.max_expression_length = max_expression_length
    self.only_unique_expressions = only_unique_expressions
    self.epochs = epochs
    self.batch_size = batch_size
    self.max_beta = max_beta
    self.population_size = population_size
    self.verbose = verbose

    # Trained lazily by adapt(), or restored immediately from a checkpoint.
    self.model: Optional[HVAE] = None
    if weights_path is not None:
        self.load_adapted_state(weights_path)

prepare

prepare() -> None

Reset per-experiment state.

The trained HVAE weights are preserved across experiments; only the evolutionary search is re-initialised on each call to search.

Returns:

Type Description
None

None

Source code in SRToolkit/approaches/EDHiE.py
def prepare(self) -> None:
    """
    Reset per-experiment state.

    The HVAE weights learned in ``adapt`` are kept across experiments; the evolutionary
    search itself is re-initialised inside every call to
    [search][SRToolkit.approaches.EDHiE.EDHiE.search], so there is nothing to reset here.

    Returns:
        None
    """
    return None

adapt

adapt(X: ndarray, symbol_library: SymbolLibrary) -> None

Train the HVAE on randomly generated expressions from the symbol library.

Parameters:

Name Type Description Default
X ndarray

Input data from the domain (shape (n_samples, n_variables)). Only used to determine the number of variables; target values are not accessed.

required
symbol_library SymbolLibrary

Symbol library defining available tokens.

required

Returns:

Type Description
None

None

Source code in SRToolkit/approaches/EDHiE.py
def adapt(self, X: np.ndarray, symbol_library: SymbolLibrary) -> None:
    """
    Train the HVAE on randomly generated expressions from the symbol library.

    Args:
        X: Input data from the domain (shape ``(n_samples, n_variables)``). Only used to
            determine the number of variables; target values are not accessed.
        symbol_library: Symbol library defining available tokens.

    Returns:
        None
    """
    # Sample a training corpus of random expressions from the library.
    sampled = generate_n_expressions(
        symbol_library,
        self.num_expressions,
        max_expression_length=self.max_expression_length,
        unique=self.only_unique_expressions,
    )
    trees: List[Optional[Node]] = [tokens_to_tree(tokens, symbol_library) for tokens in sampled]
    dataset = TreeDataset(trees)

    # Build a fresh HVAE and fit it to the sampled expression trees.
    self.model = HVAE(len(symbol_library), self.latent_size, symbol_library)
    train_hvae(
        self.model,
        dataset,
        symbol_library,
        epochs=self.epochs,
        batch_size=self.batch_size,
        max_beta=self.max_beta,
        verbose=self.verbose,
    )

save_adapted_state

save_adapted_state(path: str) -> None

Save the trained HVAE model weights and architecture metadata to disk.

Saves a single checkpoint file containing the symbol library, latent size, hidden size, max height, and model weights — enough to fully reconstruct the model in load_adapted_state without needing to call adapt first.

Parameters:

Name Type Description Default
path str

File path to save the checkpoint to, including the file extension, e.g. path/model.pt.

required
Source code in SRToolkit/approaches/EDHiE.py
def save_adapted_state(self, path: str) -> None:
    """
    Save the trained HVAE model weights and architecture metadata to disk.

    The checkpoint bundles the symbol library, latent size, hidden size, max height,
    and model weights — everything needed to rebuild the model in
    [load_adapted_state][SRToolkit.approaches.EDHiE.EDHiE.load_adapted_state]
    without a prior call to [adapt][SRToolkit.approaches.EDHiE.EDHiE.adapt].

    Args:
        path: File path to save the checkpoint to, including the file extension, e.g. ``path/model.pt``.
    """
    if self.model is None:
        # Nothing has been trained or loaded yet; warn instead of writing an empty file.
        warnings.warn("[EDHiE.save_adapted_state] No model to save.")
        return

    decoder = self.model.decoder
    checkpoint = {
        "symbol_library": decoder.symbol_library.to_dict(),
        "latent_size": self.latent_size,
        "hidden_size": decoder.hidden_size,
        "max_height": decoder.max_height,
        "state_dict": self.model.state_dict(),
    }
    torch.save(checkpoint, path)

load_adapted_state

load_adapted_state(path: str) -> None

Restore a previously trained HVAE model from disk.

Reconstructs the HVAE from the architecture metadata saved alongside the weights, so no prior call to adapt is required.

Parameters:

Name Type Description Default
path str

File path previously passed to save_adapted_state.

required
Source code in SRToolkit/approaches/EDHiE.py
def load_adapted_state(self, path: str) -> None:
    """
    Restore a previously trained HVAE model from disk.

    Rebuilds the [HVAE][SRToolkit.approaches.EDHiE.HVAE] from the architecture metadata
    stored alongside the weights, so [adapt][SRToolkit.approaches.EDHiE.EDHiE.adapt]
    does not need to be called first.

    Args:
        path: File path previously passed to
            [save_adapted_state][SRToolkit.approaches.EDHiE.EDHiE.save_adapted_state].
    """
    # NOTE: weights_only=False is required because the checkpoint stores plain
    # Python metadata next to the tensors; only load trusted checkpoints.
    ckpt = torch.load(path, weights_only=False)
    library = SymbolLibrary.from_dict(ckpt["symbol_library"])

    # Recreate the architecture from the stored metadata, then restore weights.
    model = HVAE(
        input_size=len(library),
        output_size=ckpt["latent_size"],
        symbol_library=library,
        hidden_size=ckpt["hidden_size"],
        max_height=ckpt["max_height"],
    )
    model.load_state_dict(ckpt["state_dict"])
    self.model = model

search

search(sr_evaluator: SR_evaluator, seed: Optional[int] = None) -> None

Explore the HVAE latent space with a genetic algorithm to find the best expression.

Parameters:

Name Type Description Default
sr_evaluator SR_evaluator

SR_evaluator used to score candidate expressions.

required
seed Optional[int]

Optional random seed for reproducibility.

None

Returns:

Type Description
None

None

Raises:

Type Description
RuntimeError

If adapt has not been called before search.

Source code in SRToolkit/approaches/EDHiE.py
def search(self, sr_evaluator: SR_evaluator, seed: Optional[int] = None) -> None:
    """
    Explore the HVAE latent space with a genetic algorithm to find the best expression.

    Args:
        sr_evaluator: [SR_evaluator][SRToolkit.evaluation.sr_evaluator.SR_evaluator] used to
            score candidate expressions.
        seed: Optional random seed for reproducibility.

    Returns:
        None

    Raises:
        RuntimeError: If [adapt][SRToolkit.approaches.EDHiE.EDHiE.adapt] has not been called
            before ``search``.
    """
    if self.model is None:
        raise RuntimeError("EDHiE.adapt() must be called before search().")

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)

    problem = _HVAEProblem(self.model, sr_evaluator, self.latent_size)
    ga = GA(
        pop_size=self.population_size,
        sampling=_TorchNormalSampling(),
        crossover=_LICrossover(),
        mutation=_RandomMutation(),
        eliminate_duplicates=False,
    )
    minimize(
        problem,
        ga,
        _BestTermination(),
        verbose=False,
    )

BatchedNode

BatchedNode(symbol2index: Dict[str, int], size: int = 0, trees: Optional[List[Optional[Node]]] = None)

Batched binary tree node for vectorised HVAE forward/backward passes.

Mirrors the recursive structure of Node but stores a batch of trees simultaneously. Each position in symbols corresponds to one tree in the batch; empty strings "" mark padding (absent subtrees). left and right are themselves BatchedNode instances (or None) representing the batched subtrees.

After all trees have been added via add_tree, call create_target to build the one-hot target tensors and masks required by the HVAE loss.

We recommend using SRToolkit.approaches.EDHiE.create_batch to create batched trees.

Parameters:

Name Type Description Default
symbol2index Dict[str, int]

Mapping from symbol strings to vocabulary indices.

required
size int

Number of padding slots to pre-allocate (filled with "").

0
trees Optional[List[Optional[Node]]]

Optional list of Node trees to add immediately via add_tree.

None
Source code in SRToolkit/approaches/EDHiE.py
def __init__(self, symbol2index: Dict[str, int], size: int = 0, trees: Optional[List[Optional[Node]]] = None):
    """
    Args:
        symbol2index: Mapping from symbol strings to vocabulary indices.
        size: Number of padding slots to pre-allocate (filled with ``""``).
        trees: Optional list of [Node][SRToolkit.utils.expression_tree.Node] trees to add
            immediately via [add_tree][SRToolkit.approaches.EDHiE.BatchedNode.add_tree].
    """
    # Pre-allocate `size` empty slots so this node lines up with trees already in the batch.
    self.symbols: List[str] = [""] * size
    self.left: Optional[BatchedNode] = None
    self.right: Optional[BatchedNode] = None
    self.symbol2index: Dict[str, int] = symbol2index

    # Tensors populated later by create_target() and the decoder.
    self.mask: Optional[torch.Tensor] = None
    self.target: Optional[torch.Tensor] = None
    self.target_indices: Optional[torch.Tensor] = None
    self.prediction: Optional[torch.Tensor] = None

    if trees is not None:
        for tree in trees:
            self.add_tree(tree)

add_tree

add_tree(tree: Optional[Node] = None) -> None

Append one tree (or a padding slot) to the batch.

If tree is None, an empty padding entry ("") is appended at every level. Otherwise the tree's symbol is appended and its subtrees are recursively merged into self.left / self.right, creating new BatchedNode children as needed.

Parameters:

Name Type Description Default
tree Optional[Node]

A Node to add, or None to append a padding slot.

None

Returns:

Type Description
None

None

Source code in SRToolkit/approaches/EDHiE.py
def add_tree(self, tree: Optional[Node] = None) -> None:
    """
    Append one tree (or a padding slot) to the batch.

    Passing ``None`` appends an empty padding entry (``""``) at this node and,
    recursively, in every existing child batch, keeping all nodes the same width.
    Otherwise the tree's symbol is appended here and its subtrees are merged into
    ``self.left`` / ``self.right``, allocating new ``BatchedNode`` children on demand.

    Args:
        tree: A [Node][SRToolkit.utils.expression_tree.Node] to add, or ``None`` to append
            a padding slot.

    Returns:
        None
    """
    if tree is None:
        # Padding entry: propagate the empty slot down every existing child batch.
        self.symbols.append("")
        if self.left is not None:
            self.left.add_tree()
        if self.right is not None:
            self.right.add_tree()
        return

    self.symbols.append(tree.symbol)

    # Merge the left subtree into the left child batch.
    if isinstance(self.left, BatchedNode):
        if isinstance(tree.left, Node):
            self.left.add_tree(tree.left)
        else:
            # This tree has no left child — pad the existing child batch instead.
            self.left.add_tree()
    elif isinstance(tree.left, Node):
        # First tree in the batch with a left child: allocate the child batch
        # pre-padded for every tree added so far (this one excluded).
        self.left = BatchedNode(self.symbol2index, size=len(self.symbols) - 1)
        self.left.add_tree(tree.left)

    # Merge the right subtree into the right child batch (same logic as above).
    if isinstance(self.right, BatchedNode):
        if isinstance(tree.right, Node):
            self.right.add_tree(tree.right)
        else:
            self.right.add_tree()
    elif isinstance(tree.right, Node):
        self.right = BatchedNode(self.symbol2index, size=len(self.symbols) - 1)
        self.right.add_tree(tree.right)

create_target

create_target() -> None

Build one-hot target tensors, integer target indices, and binary masks for the entire subtree.

Populates self.target (one-hot, shape (batch, vocab)), self.target_indices (integer class indices, shape (batch,), -1 for padding), and self.mask (1.0 for real symbols, 0.0 for padding). Recurses into left and right.

Returns:

Type Description
None

None

Source code in SRToolkit/approaches/EDHiE.py
def create_target(self) -> None:
    """
    Build one-hot target tensors, integer target indices, and binary masks for the entire subtree.

    Populates ``self.target`` (one-hot, shape ``(batch, vocab)``), ``self.target_indices``
    (integer class indices, shape ``(batch,)``, ``-1`` for padding), and ``self.mask``
    (``1.0`` for real symbols, ``0.0`` for padding). Recurses into ``left`` and ``right``.

    Returns:
        None
    """
    batch = len(self.symbols)
    vocab = len(self.symbol2index)
    target = torch.zeros((batch, vocab))
    mask = torch.ones(batch)
    target_indices = torch.full((batch,), -1, dtype=torch.long)

    for i, symbol in enumerate(self.symbols):
        if symbol:
            index = self.symbol2index[symbol]
            target[i, index] = 1
            target_indices[i] = index
        else:
            # Padding slot: excluded from the loss via the mask.
            mask[i] = 0

    self.mask = mask
    self.target = target
    self.target_indices = target_indices

    if self.left is not None:
        self.left.create_target()
    if self.right is not None:
        self.right.create_target()

to_expr_list

to_expr_list() -> List[Optional[Node]]

Extract one Node tree per batch element.

Returns:

Type Description
List[Optional[Node]]

A list of Node trees (or None for padding

List[Optional[Node]]

slots), one per position in the batch.

Source code in SRToolkit/approaches/EDHiE.py
def to_expr_list(self) -> List[Optional[Node]]:
    """
    Extract one [Node][SRToolkit.utils.expression_tree.Node] tree per batch element.

    Returns:
        A list of [Node][SRToolkit.utils.expression_tree.Node] trees (or ``None`` for padding
        slots), one per position in the batch.
    """
    return [self.get_expr_at_idx(i) for i in range(len(self.symbols))]

get_expr_at_idx

get_expr_at_idx(idx: int) -> Optional[Node]

Reconstruct the expression tree for a single batch element.

Parameters:

Name Type Description Default
idx int

Batch index to retrieve.

required

Returns:

Type Description
Optional[Node]

A Node tree, or None if the slot is padding.

Source code in SRToolkit/approaches/EDHiE.py
def get_expr_at_idx(self, idx: int) -> Optional[Node]:
    """
    Reconstruct the expression tree for a single batch element.

    Args:
        idx: Batch index to retrieve.

    Returns:
        A [Node][SRToolkit.utils.expression_tree.Node] tree, or ``None`` if the slot is padding.
    """
    token = self.symbols[idx]
    if token == "":
        # Padding slot — no tree at this position.
        return None

    left_child = None
    if isinstance(self.left, BatchedNode):
        left_child = self.left.get_expr_at_idx(idx)
    right_child = None
    if isinstance(self.right, BatchedNode):
        right_child = self.right.get_expr_at_idx(idx)

    return Node(token, left=left_child, right=right_child)

get_prediction

get_prediction() -> torch.Tensor

Collect decoder logit predictions from all nodes in infix order.

Returns:

Type Description
Tensor

Stacked logit tensor of shape (batch, vocab, n_nodes).

Source code in SRToolkit/approaches/EDHiE.py
def get_prediction(self) -> torch.Tensor:
    """
    Collect decoder logit predictions from all nodes of the subtree.

    Returns:
        Stacked logit tensor of shape ``(batch, vocab, n_nodes)``.
    """
    return torch.stack(self.get_prediction_rec(), dim=2)

get_target

get_target() -> torch.Tensor

Collect integer target indices from all nodes in infix order.

Returns:

Type Description
Tensor

Stacked target tensor of shape (batch, n_nodes) with -1 for padding.

Source code in SRToolkit/approaches/EDHiE.py
def get_target(self) -> torch.Tensor:
    """
    Collect integer target indices from all nodes of the subtree.

    Returns:
        Stacked target tensor of shape ``(batch, n_nodes)`` with ``-1`` for padding.
    """
    return torch.stack(self.get_target_rec(), dim=1)

HVAE

HVAE(input_size: int, output_size: int, symbol_library: SymbolLibrary, hidden_size: Optional[int] = None, max_height: int = 20)

Bases: Module

Hierarchical Variational Autoencoder (HVAE) for expression trees.

Combines a tree-recursive Encoder with a tree-recursive Decoder to learn a continuous latent representation of symbolic expressions.

Parameters:

Name Type Description Default
input_size int

Vocabulary size (number of symbols in the symbol library).

required
output_size int

Dimensionality of the latent space.

required
symbol_library SymbolLibrary

SymbolLibrary used by the decoder to constrain leaf/non-leaf symbol selection.

required
hidden_size Optional[int]

Hidden state size for the GRU cells. Defaults to output_size.

None
max_height int

Maximum tree depth the decoder will generate before forcing leaf symbols.

20
Source code in SRToolkit/approaches/EDHiE.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    symbol_library: SymbolLibrary,
    hidden_size: Optional[int] = None,
    max_height: int = 20,
):
    """
    Args:
        input_size: Vocabulary size (number of symbols in the symbol library).
        output_size: Dimensionality of the latent space.
        symbol_library: [SymbolLibrary][SRToolkit.utils.symbol_library.SymbolLibrary] used by
            the decoder to constrain leaf/non-leaf symbol selection.
        hidden_size: Hidden state size for the GRU cells. Defaults to ``output_size``.
        max_height: Maximum tree depth the decoder will generate before forcing leaf symbols.
    """
    super().__init__()

    # Default the GRU hidden size to the latent dimensionality.
    effective_hidden = output_size if hidden_size is None else hidden_size

    self.encoder = Encoder(input_size, effective_hidden, output_size)
    self.decoder = Decoder(output_size, effective_hidden, input_size, symbol_library, max_height)

forward

forward(tree: BatchedNode) -> Tuple[torch.Tensor, torch.Tensor, BatchedNode]

Run the full VAE forward pass: encode, sample latent vector, decode.

Parameters:

Name Type Description Default
tree BatchedNode

Batched expression trees (training mode — teacher-forced decoding).

required

Returns:

Type Description
Tensor

A 3-tuple (mu, logvar, reconstructed_tree) where mu and logvar are the

Tensor

approximate posterior parameters and reconstructed_tree is the decoder output.

Source code in SRToolkit/approaches/EDHiE.py
def forward(self, tree: BatchedNode) -> Tuple[torch.Tensor, torch.Tensor, BatchedNode]:
    """
    Run the full VAE forward pass: encode, sample latent vector, decode.

    Args:
        tree: Batched expression trees (training mode — teacher-forced decoding).

    Returns:
        A 3-tuple ``(mu, logvar, reconstructed_tree)`` where ``mu`` and ``logvar`` are the
        approximate posterior parameters and ``reconstructed_tree`` is the decoder output.
    """
    mu, logvar = self.encoder(tree)
    z = self.sample(mu, logvar)
    out = self.decoder(z, tree)
    return mu, logvar, out

sample

sample(mu: Tensor, logvar: Tensor) -> torch.Tensor

Draw a latent sample using the reparameterisation trick.

Parameters:

Name Type Description Default
mu Tensor

Mean of the approximate posterior, shape (batch, latent_size).

required
logvar Tensor

Log-variance of the approximate posterior, same shape.

required

Returns:

Type Description
Tensor

Sampled latent vector z = mu + eps * exp(logvar / 2).

Source code in SRToolkit/approaches/EDHiE.py
def sample(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """
    Draw a latent sample using the reparameterisation trick.

    Args:
        mu: Mean of the approximate posterior, shape ``(batch, latent_size)``.
        logvar: Log-variance of the approximate posterior, same shape.

    Returns:
        Sampled latent vector ``z = mu + eps * exp(logvar / 2)``.
    """
    noise = torch.randn_like(mu)
    std = torch.exp(0.5 * logvar)
    return noise * std + mu

encode

encode(tree: BatchedNode) -> Tuple[torch.Tensor, torch.Tensor]

Encode a batch of trees to posterior parameters without sampling.

Parameters:

Name Type Description Default
tree BatchedNode

Batched expression trees.

required

Returns:

Type Description
Tuple[Tensor, Tensor]

A 2-tuple (mu, logvar) — the approximate posterior mean and log-variance.

Source code in SRToolkit/approaches/EDHiE.py
def encode(self, tree: BatchedNode) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode a batch of trees to posterior parameters without sampling.

    Args:
        tree: Batched expression trees.

    Returns:
        A 2-tuple ``(mu, logvar)`` — the approximate posterior mean and log-variance.
    """
    mu, logvar = self.encoder(tree)
    return mu, logvar

decode

decode(z: Tensor) -> List[Optional[Node]]

Decode a batch of latent vectors to expression trees (inference mode).

Parameters:

Name Type Description Default
z Tensor

Latent vectors of shape (batch, latent_size).

required

Returns:

Type Description
List[Optional[Node]]

A list of Node trees (or None for failed

List[Optional[Node]]

decodes), one per row of z.

Source code in SRToolkit/approaches/EDHiE.py
def decode(self, z: torch.Tensor) -> List[Optional[Node]]:
    """
    Decode a batch of latent vectors to expression trees (inference mode).

    Args:
        z: Latent vectors of shape ``(batch, latent_size)``.

    Returns:
        A list of [Node][SRToolkit.utils.expression_tree.Node] trees (or ``None`` for failed
        decodes), one per row of ``z``.
    """
    return self.decoder.decode(z)

Encoder

Encoder(input_size: int, hidden_size: int, output_size: int)

Bases: Module

Tree-recursive GRU encoder that maps a BatchedNode tree to approximate posterior parameters (mu, logvar).

Uses a GRU221 cell to combine left/right child hidden states bottom-up, then projects to mu and logvar via linear layers.

Parameters:

Name Type Description Default
input_size int

Vocabulary size (one-hot input dimension).

required
hidden_size int

Hidden state dimensionality of the GRU.

required
output_size int

Dimensionality of the latent space (mu and logvar output size).

required
Source code in SRToolkit/approaches/EDHiE.py
def __init__(self, input_size: int, hidden_size: int, output_size: int):
    """
    Args:
        input_size: Vocabulary size (one-hot input dimension).
        hidden_size: Hidden state dimensionality of the GRU.
        output_size: Dimensionality of the latent space (``mu`` and ``logvar`` output size).
    """
    super().__init__()
    self.hidden_size = hidden_size
    self.gru = GRU221(input_size=input_size, hidden_size=hidden_size)

    # Linear heads projecting the root hidden state to the posterior parameters.
    self.mu = nn.Linear(in_features=hidden_size, out_features=output_size)
    self.logvar = nn.Linear(in_features=hidden_size, out_features=output_size)
    for head in (self.mu, self.logvar):
        torch.nn.init.xavier_uniform_(head.weight)

forward

forward(tree: BatchedNode) -> Tuple[torch.Tensor, torch.Tensor]

Encode a batched tree to posterior parameters.

Parameters:

Name Type Description Default
tree BatchedNode

Batched expression tree with targets and masks populated.

required

Returns:

Type Description
Tuple[Tensor, Tensor]

A 2-tuple (mu, logvar) of shape (batch, latent_size) each.

Source code in SRToolkit/approaches/EDHiE.py
def forward(self, tree: BatchedNode) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode a batched tree to posterior parameters.

    Args:
        tree: Batched expression tree with targets and masks populated.

    Returns:
        A 2-tuple ``(mu, logvar)`` of shape ``(batch, latent_size)`` each.
    """
    if tree.target is None or tree.mask is None:
        raise RuntimeError(
            "[Encoder.forward] BatchedNode.create_target() must be called before encoding. We recommend using SRToolkit.approaches.EDHiE.create_batch() to create batched trees."
        )
    tree_encoding = self.recursive_forward(tree)
    mu = self.mu(tree_encoding)
    logvar = self.logvar(tree_encoding)
    return mu, logvar

recursive_forward

recursive_forward(tree: BatchedNode) -> torch.Tensor

Recursively encode a subtree bottom-up using the GRU cell.

Parameters:

Name Type Description Default
tree BatchedNode

Current subtree node (batched).

required

Returns:

Type Description
Tensor

Hidden state tensor of shape (batch, hidden_size) for this subtree root.

Source code in SRToolkit/approaches/EDHiE.py
def recursive_forward(self, tree: BatchedNode) -> torch.Tensor:
    """
    Recursively encode a subtree bottom-up using the GRU cell.

    Args:
        tree: Current subtree node (batched).

    Returns:
        Hidden state tensor of shape ``(batch, hidden_size)`` for this subtree root.
    """
    batch = len(tree.symbols)

    # Missing children contribute an all-zero hidden state.
    # Type checking suppressed: forward() verifies target/mask exist, and
    # create_target() populates them recursively for every node.
    if isinstance(tree.left, BatchedNode):
        h_left = self.recursive_forward(tree.left)
    else:
        h_left = torch.zeros(batch, self.hidden_size, device=tree.target.device)  # type: ignore[union-attr]

    if isinstance(tree.right, BatchedNode):
        h_right = self.recursive_forward(tree.right)
    else:
        h_right = torch.zeros(batch, self.hidden_size, device=tree.target.device)  # type: ignore[union-attr]

    combined = self.gru(tree.target, h_left, h_right)
    # Zero out hidden states at padding slots so they do not leak upward.
    return combined * tree.mask[:, None]  # type: ignore[index]

Decoder

Decoder(input_size: int, hidden_size: int, output_size: int, symbol_library: SymbolLibrary, max_height: int = 20)

Bases: Module

Tree-recursive GRU decoder that maps a latent vector to an expression tree.

Uses a GRU122 cell to split each parent hidden state top-down into left and right child hidden states, then predicts the symbol at each node.

Parameters:

Name Type Description Default
input_size int

Latent space dimensionality (decoder input).

required
hidden_size int

Hidden state dimensionality of the GRU.

required
output_size int

Vocabulary size (logit output dimension).

required
symbol_library SymbolLibrary

Used to identify leaf symbols and enforce the max_height constraint.

required
max_height int

Maximum tree depth; beyond this depth only leaf symbols are sampled.

20
Source code in SRToolkit/approaches/EDHiE.py
def __init__(
    self, input_size: int, hidden_size: int, output_size: int, symbol_library: SymbolLibrary, max_height: int = 20
):
    """
    Args:
        input_size: Latent space dimensionality (decoder input).
        hidden_size: Hidden state dimensionality of the GRU.
        output_size: Vocabulary size (logit output dimension).
        symbol_library: Used to identify leaf symbols and to enforce ``max_height``.
        max_height: Maximum tree depth; past this depth only leaf symbols are sampled.
    """
    super().__init__()
    self.hidden_size = hidden_size
    self.z2h = nn.Linear(input_size, hidden_size)
    self.h2o = nn.Linear(hidden_size, output_size)
    self.gru = GRU122(input_size=output_size, hidden_size=hidden_size)

    self.max_height = max_height
    self.symbol_library = symbol_library
    self.symbol2index = symbol_library.symbols2index()
    self.index2symbol = {idx: sym for sym, idx in self.symbol2index.items()}
    # 1.0 at each vocabulary index whose symbol is a leaf (variable/constant/literal),
    # 0.0 elsewhere; used to restrict sampling to leaves once max_height is reached.
    leaf_symbols_mask = torch.zeros(len(self.symbol2index))
    leaf_indices = [
        idx
        for sym, idx in self.symbol2index.items()
        if self.symbol_library.get_type(sym) in ("var", "const", "lit")
    ]
    leaf_symbols_mask[leaf_indices] = 1
    self.register_buffer("leaf_symbols_mask", leaf_symbols_mask)

    torch.nn.init.xavier_uniform_(self.z2h.weight)
    torch.nn.init.xavier_uniform_(self.h2o.weight)

forward

forward(z: Tensor, tree: BatchedNode) -> BatchedNode

Teacher-forced decoding pass used during training.

Stores logit predictions in each node of tree (in-place) using the ground-truth tree structure to guide recursion.

Parameters:

Name Type Description Default
z Tensor

Latent vectors of shape (batch, latent_size).

required
tree BatchedNode

Ground-truth batched tree; predictions are written into each node's prediction attribute.

required

Returns:

Type Description
BatchedNode

The same tree object with prediction populated at every node.

Source code in SRToolkit/approaches/EDHiE.py
def forward(self, z: torch.Tensor, tree: BatchedNode) -> BatchedNode:
    """
    Run a teacher-forced decoding pass (training only).

    The ground-truth structure of ``tree`` guides the recursion; logit predictions
    are written in-place into the ``prediction`` attribute of every node.

    Args:
        z: Latent vectors of shape ``(batch, latent_size)``.
        tree: Ground-truth batched tree to annotate with predictions.

    Returns:
        The same ``tree`` instance, with ``prediction`` filled in at every node.
    """
    root_hidden = self.z2h(z)
    self.recursive_forward(root_hidden, tree)
    return tree

recursive_forward

recursive_forward(hidden: Tensor, tree: BatchedNode) -> None

Recursively write predictions into a subtree (teacher-forced, training only).

Parameters:

Name Type Description Default
hidden Tensor

Current node hidden state, shape (batch, hidden_size).

required
tree BatchedNode

Current subtree node; prediction is set in-place.

required

Returns:

Type Description
None

None

Source code in SRToolkit/approaches/EDHiE.py
def recursive_forward(self, hidden: torch.Tensor, tree: BatchedNode) -> None:  # type: ignore[override]
    """
    Write teacher-forced logit predictions into a subtree, in-place (training only).

    Args:
        hidden: Hidden state of the current node, shape ``(batch, hidden_size)``.
        tree: Current subtree node; its ``prediction`` attribute is populated.

    Returns:
        None
    """
    logits = self.h2o(hidden)
    tree.prediction = logits
    descend_left = isinstance(tree.left, BatchedNode)
    descend_right = isinstance(tree.right, BatchedNode)
    if not (descend_left or descend_right):
        return
    # Split the parent state into child states, driven by the predicted symbol distribution.
    child_left, child_right = self.gru(F.softmax(logits, dim=1), hidden)
    if descend_left:
        self.recursive_forward(child_left, tree.left)  # type: ignore[arg-type]
    if descend_right:
        self.recursive_forward(child_right, tree.right)  # type: ignore[arg-type]

decode

decode(z: Tensor) -> List[Optional[Node]]

Autoregressively decode latent vectors to expression trees (inference mode).

Parameters:

Name Type Description Default
z Tensor

Latent vectors of shape (batch, latent_size).

required

Returns:

Type Description
List[Optional[Node]]

A list of Node trees, one per row of z.

Source code in SRToolkit/approaches/EDHiE.py
def decode(self, z: torch.Tensor) -> List[Optional[Node]]:
    """
    Decode latent vectors into expression trees autoregressively (inference mode).

    Args:
        z: Latent vectors of shape ``(batch, latent_size)``.

    Returns:
        One [Node][SRToolkit.utils.expression_tree.Node] tree per row of ``z``.
    """
    with torch.no_grad():
        # All batch elements start out active.
        active = torch.ones(z.size(0)).bool()
        batched = self.recursive_decode(self.z2h(z), active)
        return batched.to_expr_list()

recursive_decode

recursive_decode(hidden: Tensor, mask: Tensor, height: int = 0) -> BatchedNode

Recursively decode hidden states into a batched subtree (inference mode).

Parameters:

Name Type Description Default
hidden Tensor

Current node hidden state, shape (batch, hidden_size).

required
mask Tensor

Boolean mask; True for active (non-padding) batch elements.

required
height int

Current tree depth (used to enforce max_height).

0

Returns:

Type Description
BatchedNode

A BatchedNode representing this subtree.

Source code in SRToolkit/approaches/EDHiE.py
def recursive_decode(self, hidden: torch.Tensor, mask: torch.Tensor, height: int = 0) -> BatchedNode:
    """
    Decode hidden states into a batched subtree (inference mode), recursively.

    Args:
        hidden: Hidden state of the current node, shape ``(batch, hidden_size)``.
        mask: Boolean mask; ``True`` marks active (non-padding) batch elements.
        height: Depth of the current node, used to enforce ``max_height``.

    Returns:
        A [BatchedNode][SRToolkit.approaches.EDHiE.BatchedNode] for this subtree.
    """
    symbol_probs = F.softmax(self.h2o(hidden), dim=1)
    # Pick one symbol per active batch element and decide where to recurse.
    symbols, left_mask, right_mask = self.sample_symbol(symbol_probs, mask, height)
    left_subtree = None
    right_subtree = None
    if torch.any(left_mask) or torch.any(right_mask):
        h_left, h_right = self.gru(symbol_probs, hidden)
        if torch.any(left_mask):
            left_subtree = self.recursive_decode(h_left, left_mask, height + 1)
        if torch.any(right_mask):
            right_subtree = self.recursive_decode(h_right, right_mask, height + 1)

    node = BatchedNode(self.symbol2index)
    node.symbols = symbols
    node.left = left_subtree
    node.right = right_subtree
    return node

sample_symbol

sample_symbol(prediction: Tensor, mask: Tensor, height: int) -> Tuple[List[str], torch.Tensor, torch.Tensor]

Greedily select one symbol per batch element and compute child masks.

Parameters:

Name Type Description Default
prediction Tensor

Softmax probabilities over the vocabulary, shape (batch, vocab).

required
mask Tensor

Boolean mask for active batch elements.

required
height int

Current tree depth; at max_height only leaf symbols are eligible.

required

Returns:

Type Description
List[str]

A 3-tuple (symbols, left_mask, right_mask) where symbols is a list of

Tensor

selected symbol strings ("" for masked positions), and the masks indicate which

Tensor

batch elements should recurse into the left/right child.

Source code in SRToolkit/approaches/EDHiE.py
def sample_symbol(
    self, prediction: torch.Tensor, mask: torch.Tensor, height: int
) -> Tuple[List[str], torch.Tensor, torch.Tensor]:
    """
    Pick the highest-probability symbol for each active batch element.

    Args:
        prediction: Softmax probabilities over the vocabulary, shape ``(batch, vocab)``.
        mask: Boolean mask for active batch elements.
        height: Current tree depth; once ``max_height`` is reached, only leaf symbols
            remain eligible.

    Returns:
        A 3-tuple ``(symbols, left_mask, right_mask)``: the chosen symbol strings
        (``""`` at masked positions) and boolean masks telling which batch elements
        should recurse into the left/right child.
    """
    if height >= self.max_height:
        # Force termination: zero out every non-leaf symbol's probability.
        prediction = prediction * self.leaf_symbols_mask

    symbols: List[str] = []
    left_mask = mask.clone()
    right_mask = mask.clone()
    for i, active in enumerate(mask):
        if not active:
            symbols.append("")
            continue
        chosen = self.index2symbol[int(torch.argmax(prediction[i, :]))]
        symbols.append(chosen)
        symbol_type = self.symbol_library.get_type(chosen)
        if symbol_type == "fn":
            # Unary functions only use the left child.
            right_mask[i] = False
        if symbol_type in ("var", "const", "lit"):
            # Leaves have no children at all.
            left_mask[i] = False
            right_mask[i] = False
    return symbols, left_mask, right_mask

GRU221

GRU221(input_size: int, hidden_size: int)

Bases: Module

Two-input, one-output GRU cell used by the Encoder.

Combines an input vector with two hidden states (left child and right child) into a single parent hidden state. The two child states are concatenated before being passed through the standard GRU gating equations.

Parameters:

Name Type Description Default
input_size int

Dimensionality of the input vector (vocabulary one-hot size).

required
hidden_size int

Dimensionality of each child hidden state and the output hidden state.

required
Source code in SRToolkit/approaches/EDHiE.py
def __init__(self, input_size: int, hidden_size: int):
    """
    Args:
        input_size: Dimensionality of the input vector (vocabulary one-hot size).
        hidden_size: Dimensionality of each child hidden state and of the output state.
    """
    super().__init__()
    self.hidden_size = hidden_size
    # Input-to-hidden (wi*) and hidden-to-hidden (wh*) maps for the reset (r),
    # update (z) and candidate (n) gates; hidden maps see both children concatenated.
    for gate in ("r", "z", "n"):
        setattr(self, f"wi{gate}", nn.Linear(in_features=input_size, out_features=hidden_size))
        setattr(self, f"wh{gate}", nn.Linear(in_features=2 * hidden_size, out_features=hidden_size))
    for layer in (self.wir, self.whr, self.wiz, self.whz, self.win, self.whn):
        torch.nn.init.xavier_uniform_(layer.weight)

forward

forward(x: Tensor, h1: Tensor, h2: Tensor) -> torch.Tensor

Compute parent hidden state from input and two child hidden states.

Parameters:

Name Type Description Default
x Tensor

Input (symbol embedding / one-hot), shape (batch, input_size).

required
h1 Tensor

Left child hidden state, shape (batch, hidden_size).

required
h2 Tensor

Right child hidden state, shape (batch, hidden_size).

required

Returns:

Type Description
Tensor

Parent hidden state of shape (batch, hidden_size).

Source code in SRToolkit/approaches/EDHiE.py
def forward(self, x: torch.Tensor, h1: torch.Tensor, h2: torch.Tensor) -> torch.Tensor:
    """
    Merge two child hidden states into the parent hidden state.

    Args:
        x: Input (symbol embedding / one-hot), shape ``(batch, input_size)``.
        h1: Left child hidden state, shape ``(batch, hidden_size)``.
        h2: Right child hidden state, shape ``(batch, hidden_size)``.

    Returns:
        Parent hidden state of shape ``(batch, hidden_size)``.
    """
    children = torch.cat([h1, h2], dim=1)
    reset = torch.sigmoid(self.wir(x) + self.whr(children))
    update = torch.sigmoid(self.wiz(x) + self.whz(children))
    candidate = torch.tanh(self.win(x) + reset * self.whn(children))
    # Blend the candidate state with the average of the two child states.
    return (1 - update) * candidate + (update / 2) * h1 + (update / 2) * h2

GRU122

GRU122(input_size: int, hidden_size: int)

Bases: Module

One-input, two-output GRU cell used by the Decoder.

Splits a parent hidden state into two child hidden states (left and right) driven by an input vector. The output is split evenly along the hidden dimension.

Parameters:

Name Type Description Default
input_size int

Dimensionality of the input vector (symbol probability distribution).

required
hidden_size int

Dimensionality of each output child hidden state (output is 2 * hidden_size before splitting).

required
Source code in SRToolkit/approaches/EDHiE.py
def __init__(self, input_size: int, hidden_size: int):
    """
    Args:
        input_size: Dimensionality of the input vector (symbol probability distribution).
        hidden_size: Dimensionality of each output child hidden state (the cell's raw
            output is ``2 * hidden_size`` wide before being split).
    """
    super().__init__()
    self.hidden_size = hidden_size
    # Input-to-hidden (wi*) and hidden-to-hidden (wh*) maps for the reset (r),
    # update (z) and candidate (n) gates; outputs are doubled so they can be
    # split into left/right child states.
    for gate in ("r", "z", "n"):
        setattr(self, f"wi{gate}", nn.Linear(in_features=input_size, out_features=2 * hidden_size))
        setattr(self, f"wh{gate}", nn.Linear(in_features=hidden_size, out_features=2 * hidden_size))
    for layer in (self.wir, self.whr, self.wiz, self.whz, self.win, self.whn):
        torch.nn.init.xavier_uniform_(layer.weight)

forward

forward(x: Tensor, h: Tensor) -> Tuple[torch.Tensor, torch.Tensor]

Compute two child hidden states from input and parent hidden state.

Parameters:

Name Type Description Default
x Tensor

Input vector (symbol probabilities), shape (batch, input_size).

required
h Tensor

Parent hidden state, shape (batch, hidden_size).

required

Returns:

Type Description
Tensor

A 2-tuple (h_left, h_right) of child hidden states, each of shape

Tensor

(batch, hidden_size).

Source code in SRToolkit/approaches/EDHiE.py
def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Split a parent hidden state into two child hidden states.

    Args:
        x: Input vector (symbol probabilities), shape ``(batch, input_size)``.
        h: Parent hidden state, shape ``(batch, hidden_size)``.

    Returns:
        A 2-tuple ``(h_left, h_right)`` of child hidden states, each of shape
        ``(batch, hidden_size)``.
    """
    reset = torch.sigmoid(self.wir(x) + self.whr(h))
    update = torch.sigmoid(self.wiz(x) + self.whz(h))
    candidate = torch.tanh(self.win(x) + reset * self.whn(h))
    # Standard GRU blend, with the parent state duplicated to match the doubled width.
    blended = (1 - update) * candidate + update * h.repeat(1, 2)
    h_left, h_right = torch.split(blended, self.hidden_size, dim=1)
    return h_left, h_right

TreeBatchSampler

TreeBatchSampler(batch_size: int, num_eq: int)

Bases: Sampler[List[int]]

PyTorch sampler that yields randomly permuted mini-batches of tree indices.

Each call to __iter__ produces a fresh permutation. Batches are contiguous slices of the permutation; the last incomplete batch is dropped.

Parameters:

Name Type Description Default
batch_size int

Number of trees per mini-batch.

required
num_eq int

Total number of trees in the dataset.

required
Source code in SRToolkit/approaches/EDHiE.py
def __init__(self, batch_size: int, num_eq: int):
    """
    Args:
        batch_size: Number of trees per mini-batch.
        num_eq: Total number of trees in the dataset.
    """
    # Plain attribute storage; the permutation is drawn lazily in __iter__.
    self.num_eq = num_eq
    self.batch_size = batch_size

TreeDataset

TreeDataset(trees: List[Optional[Node]])

Bases: Dataset

Minimal PyTorch Dataset wrapping a list of expression trees.

Parameters:

Name Type Description Default
trees List[Optional[Node]]

List of Node expression trees.

required
Source code in SRToolkit/approaches/EDHiE.py
def __init__(self, trees: List[Optional[Node]]):
    """
    Args:
        trees: List of [Node][SRToolkit.utils.expression_tree.Node] expression trees;
            entries that are not ``Node`` instances (e.g. ``None``) are dropped.
    """
    # Keep only well-formed trees; generation may yield None placeholders.
    self.trees: List[Node] = list(filter(lambda t: isinstance(t, Node), trees))

create_batch

create_batch(trees: List[Optional[Node]], symbol2index: Dict[str, int]) -> BatchedNode

Pack a list of expression trees into a single BatchedNode with targets and masks populated.

Parameters:

Name Type Description Default
trees List[Optional[Node]]

List of Node trees to batch.

required
symbol2index Dict[str, int]

Mapping from symbol strings to vocabulary indices.

required

Returns:

Type Description
BatchedNode

A BatchedNode ready for the HVAE forward pass.

Source code in SRToolkit/approaches/EDHiE.py
def create_batch(trees: List[Optional[Node]], symbol2index: Dict[str, int]) -> BatchedNode:
    """
    Pack a list of expression trees into one [BatchedNode][SRToolkit.approaches.EDHiE.BatchedNode]
    with targets and masks filled in.

    Args:
        trees: [Node][SRToolkit.utils.expression_tree.Node] trees to batch.
        symbol2index: Mapping from symbol strings to vocabulary indices.

    Returns:
        A [BatchedNode][SRToolkit.approaches.EDHiE.BatchedNode] ready for the HVAE forward pass.
    """
    batched = BatchedNode(symbol2index, trees=trees)
    batched.create_target()
    return batched

annealing_schedule

annealing_schedule(it: int, total_iters: int, supremum: float = 0.04) -> float

Linear KL annealing schedule mapping iteration count to a [0, supremum] coefficient.

Parameters:

Name Type Description Default
it int

Current iteration index.

required
total_iters int

Total number of training iterations.

required
supremum float

Maximum value returned at it == total_iters.

0.04

Returns:

Type Description
float

Annealing coefficient supremum * (it / total_iters).

Source code in SRToolkit/approaches/EDHiE.py
def annealing_schedule(it: int, total_iters: int, supremum: float = 0.04) -> float:
    """
    Linear KL annealing schedule mapping iteration count to a ``[0, supremum]`` coefficient.

    The coefficient grows linearly from ``0`` (at ``it == 0``) to ``supremum``
    (at ``it >= total_iters``).

    Args:
        it: Current iteration index.
        total_iters: Total number of training iterations.
        supremum: Maximum value returned once ``it`` reaches ``total_iters``.

    Returns:
        Annealing coefficient ``supremum * min(max(it / total_iters, 0), 1)``;
        ``supremum`` if ``total_iters <= 0``.
    """
    # Guard against tiny training sets: train_hvae computes
    # total_iters = epochs * (len(trainset) // batch_size), which is 0 when the
    # dataset is smaller than one batch and would raise ZeroDivisionError here.
    if total_iters <= 0:
        return supremum
    # Clamp so the schedule honours the documented [0, supremum] range even if
    # the iteration counter overshoots total_iters.
    fraction = min(max(it / total_iters, 0.0), 1.0)
    return fraction * supremum

hvae_loss

hvae_loss(outputs: BatchedNode, mu: Tensor, logvar: Tensor, lmbda: float, criterion: Module) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Compute the HVAE ELBO loss: reconstruction (cross-entropy) + KL divergence.

Parameters:

Name Type Description Default
outputs BatchedNode

Decoded BatchedNode produced by the HVAE forward pass.

required
mu Tensor

Mean of the approximate posterior, shape (batch, latent_size).

required
logvar Tensor

Log-variance of the approximate posterior, same shape as mu.

required
lmbda float

KL annealing coefficient that scales the KL term.

required
criterion Module

Reconstruction loss callable (typically CrossEntropyLoss).

required

Returns:

Type Description
Tuple[Tensor, Tensor, Tensor]

A 3-tuple (total_loss, bce, kld) where total_loss = bce + lmbda * kld.

Source code in SRToolkit/approaches/EDHiE.py
def hvae_loss(
    outputs: "BatchedNode",
    mu: torch.Tensor,
    logvar: torch.Tensor,
    lmbda: float,
    criterion: nn.Module,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    HVAE ELBO loss: cross-entropy reconstruction plus an annealed KL divergence.

    Args:
        outputs: Decoded [BatchedNode][SRToolkit.approaches.EDHiE.BatchedNode] produced by the HVAE forward pass.
        mu: Mean of the approximate posterior, shape ``(batch, latent_size)``.
        logvar: Log-variance of the approximate posterior, same shape as ``mu``.
        lmbda: KL annealing coefficient that scales the KL term.
        criterion: Reconstruction loss callable (typically ``CrossEntropyLoss``).

    Returns:
        A 3-tuple ``(total_loss, bce, kld)`` where ``total_loss = bce + lmbda * kld``.
    """
    reconstruction = criterion(outputs.get_prediction(), outputs.get_target())
    # Analytic KL between N(mu, exp(logvar)) and N(0, I), averaged over the batch.
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.size(0)
    total = reconstruction + lmbda * kl_divergence
    return total, reconstruction, kl_divergence

train_hvae

train_hvae(model: HVAE, trainset: TreeDataset, symbol_library: SymbolLibrary, epochs: int = 20, batch_size: int = 32, max_beta: float = 0.04, verbose: bool = True) -> None

Train an HVAE on a dataset of expression trees.

Uses the Adam optimiser and cross-entropy reconstruction loss with KL annealing controlled by annealing_schedule. If verbose=True, prints one decoded reconstruction sample per epoch at the midpoint batch.

Parameters:

Name Type Description Default
model HVAE

The HVAE to train (modified in-place).

required
trainset TreeDataset

TreeDataset of expression trees.

required
symbol_library SymbolLibrary

Symbol library defining the vocabulary.

required
epochs int

Number of training epochs.

20
batch_size int

Mini-batch size.

32
max_beta float

Maximum KL annealing coefficient.

0.04
verbose bool

If True, prints loss statistics and a sample reconstruction per epoch.

True

Returns:

Type Description
None

None

Source code in SRToolkit/approaches/EDHiE.py
def train_hvae(
    model: HVAE,
    trainset: TreeDataset,
    symbol_library: SymbolLibrary,
    epochs: int = 20,
    batch_size: int = 32,
    max_beta: float = 0.04,
    verbose: bool = True,
) -> None:
    """
    Train an [HVAE][SRToolkit.approaches.EDHiE.HVAE] on a dataset of expression trees.

    Uses the Adam optimiser and cross-entropy reconstruction loss with KL annealing controlled
    by [annealing_schedule][SRToolkit.approaches.EDHiE.annealing_schedule]. If ``verbose=True``,
    prints one decoded reconstruction sample per epoch at the midpoint batch.

    Args:
        model: The [HVAE][SRToolkit.approaches.EDHiE.HVAE] to train (modified in-place).
        trainset: [TreeDataset][SRToolkit.approaches.EDHiE.TreeDataset] of expression trees.
        symbol_library: Symbol library defining the vocabulary.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.
        max_beta: Maximum KL annealing coefficient.
        verbose: If ``True``, prints loss statistics and a sample reconstruction per epoch.

    Returns:
        None
    """
    symbol2index = symbol_library.symbols2index()
    optimizer = Adam(model.parameters())
    criterion = CrossEntropyLoss(ignore_index=-1, reduction="mean")

    iter_counter = 0
    total_iters = epochs * (len(trainset) // batch_size)
    midpoint = len(trainset) // (2 * batch_size)
    sampler = TreeBatchSampler(batch_size, len(trainset))

    for epoch in range(epochs):
        bce, kl, loss_sum, num_iters = 0.0, 0.0, 0.0, 0.0

        with tqdm(total=len(trainset), desc=f"Training HVAE - Epoch: {epoch + 1}/{epochs}", unit="chunks") as prog_bar:
            for i, tree_ids in enumerate(sampler):
                lmbda = annealing_schedule(iter_counter, total_iters, max_beta)
                iter_counter += 1

                batch = create_batch([trainset[j] for j in tree_ids], symbol2index)

                optimizer.zero_grad()
                mu, logvar, outputs = model(batch)
                loss, bcel, kll = hvae_loss(outputs, mu, logvar, lmbda, criterion)
                loss.backward()
                optimizer.step()

                num_iters += 1
                bce += bcel.detach().item()
                kl += kll.detach().item()
                loss_sum += loss.detach().item()
                prog_bar.set_postfix(
                    **{"run:": "HVAE", "loss": loss_sum / num_iters, "BCE": bce / num_iters, "KLD": kl / num_iters}
                )
                prog_bar.update(batch_size)

                if verbose and i == midpoint:
                    original_trees = batch.to_expr_list()
                    decoded_trees = model.decode(mu.detach())
                    print()
                    if original_trees[0] is not None and decoded_trees[0] is not None:
                        print(f"O: {''.join(original_trees[0].to_list(symbol_library=symbol_library))}")
                        print(f"P: {''.join(decoded_trees[0].to_list(symbol_library=symbol_library))}")