EDHiE
SRToolkit.approaches.EDHiE
EDHiE approach — equation discovery with hierarchical variational autoencoders by Mežnar et al.
EDHiEConfig
dataclass
EDHiEConfig(name: str = 'EDHiE', approach_class: str = '', latent_size: int = 32, num_expressions: int = 20000, max_expression_length: int = 30, only_unique_expressions: bool = True, epochs: int = 20, batch_size: int = 32, max_beta: float = 0.035, population_size: int = 40, weights_path: Optional[str] = None, verbose: bool = True)
Bases: ApproachConfig
Configuration dataclass for the EDHiE approach.
Examples:
>>> cfg = EDHiEConfig(latent_size=32, epochs=10)
>>> cfg.name
'EDHiE'
>>> d = cfg.to_dict()
>>> EDHiEConfig.from_dict(d).latent_size
32
EDHiE
EDHiE(latent_size: int = 32, num_expressions: int = 20000, max_expression_length: int = 30, only_unique_expressions: bool = True, epochs: int = 20, batch_size: int = 32, max_beta: float = 0.035, population_size: int = 40, weights_path: Optional[str] = None, verbose: bool = True)
Bases: SR_approach
EDHiE — Equation Discovery with Hierarchical variational autoEncoders by Mežnar et al.
Trains a Hierarchical VAE (HVAE) on randomly generated expressions from the symbol library
(adapt), then explores the learned latent space with a genetic algorithm to find
expressions that best fit the target dataset (search).
Examples:
>>> model = EDHiE(latent_size=24, num_expressions=100, epochs=1, verbose=False)
>>> model.name
'EDHiE'
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| latent_size | int | Dimensionality of the HVAE latent space. | 32 |
| num_expressions | int | Number of random expressions generated to train the HVAE. | 20000 |
| max_expression_length | int | Maximum token length of generated training expressions. | 30 |
| only_unique_expressions | bool | If True, only unique expressions are kept in the generated training set. | True |
| epochs | int | Number of training epochs for the HVAE. | 20 |
| batch_size | int | Batch size used during HVAE training. | 32 |
| max_beta | float | Maximum value of the KL annealing coefficient (controls regularization strength). | 0.035 |
| population_size | int | GA population size used during latent space search. | 40 |
| weights_path | Optional[str] | Optional path to load/save HVAE weights from/to disk. | None |
| verbose | bool | If True, prints progress information during training and search. | True |
Source code in SRToolkit/approaches/EDHiE.py
prepare
Reset per-experiment state.
The trained HVAE weights are preserved across experiments; only the evolutionary search is re-initialised on each call to search.
Returns:

| Type | Description |
|---|---|
| None | None |
adapt
Train the HVAE on randomly generated expressions from the symbol library.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | ndarray | Input data from the domain. | required |
| symbol_library | SymbolLibrary | Symbol library defining available tokens. | required |

Returns:

| Type | Description |
|---|---|
| None | None |
save_adapted_state
Save the trained HVAE model weights and architecture metadata to disk.
Saves a single checkpoint file containing the symbol library, latent size, hidden size, max height, and model weights — enough to fully reconstruct the model in load_adapted_state without needing to call adapt first.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str | File path to save the checkpoint to, including the file extension. | required |
load_adapted_state
Restore a previously trained HVAE model from disk.
Reconstructs the HVAE from the architecture metadata saved alongside the weights, so no prior call to adapt is required.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str | File path previously passed to save_adapted_state. | required |
search
Explore the HVAE latent space with a genetic algorithm to find the best expression.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sr_evaluator | SR_evaluator | SR_evaluator used to score candidate expressions. | required |
| seed | Optional[int] | Optional random seed for reproducibility. | None |

Returns:

| Type | Description |
|---|---|
| None | None |

Raises:

| Type | Description |
|---|---|
| RuntimeError | If adapt has not been called before search. |
BatchedNode
BatchedNode(symbol2index: Dict[str, int], size: int = 0, trees: Optional[List[Optional[Node]]] = None)
Batched binary tree node for vectorised HVAE forward/backward passes.
Mirrors the recursive structure of Node but stores a
batch of trees simultaneously. Each position in symbols corresponds to one tree in the
batch; empty strings "" mark padding (absent subtrees). left and right are
themselves BatchedNode instances (or None) representing the batched subtrees.
After all trees have been added via add_tree, call create_target to build the one-hot target tensors and masks required by the HVAE loss.
We recommend using create_batch to create batched trees.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| symbol2index | Dict[str, int] | Mapping from symbol strings to vocabulary indices. | required |
| size | int | Number of padding slots to pre-allocate (filled with ""). | 0 |
| trees | Optional[List[Optional[Node]]] | Optional list of trees used to initialise the batch. | None |
add_tree
Append one tree (or a padding slot) to the batch.
If tree is None, an empty padding entry ("") is appended at every level.
Otherwise the tree's symbol is appended and its subtrees are recursively merged into
self.left / self.right, creating new BatchedNode children as needed.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | Optional[Node] | A Node to add, or None to append a padding slot. | None |

Returns:

| Type | Description |
|---|---|
| None | None |
create_target
Build one-hot target tensors, integer target indices, and binary masks for the entire subtree.
Populates self.target (one-hot, shape (batch, vocab)), self.target_indices
(integer class indices, shape (batch,), -1 for padding), and self.mask
(1.0 for real symbols, 0.0 for padding). Recurses into left and right.
Returns:

| Type | Description |
|---|---|
| None | None |
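The target construction described above can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation; make_targets and its exact return layout are hypothetical names chosen here for clarity.

```python
import numpy as np

def make_targets(symbols, symbol2index, vocab_size):
    """Build integer target indices, one-hot targets, and a padding mask
    for one level of a batched tree. The empty string "" marks padding."""
    indices = np.array([symbol2index[s] if s else -1 for s in symbols])
    mask = (indices >= 0).astype(np.float32)          # 1.0 real symbol, 0.0 padding
    one_hot = np.zeros((len(symbols), vocab_size), dtype=np.float32)
    rows = np.nonzero(indices >= 0)[0]
    one_hot[rows, indices[rows]] = 1.0                # one-hot only for real symbols
    return indices, one_hot, mask

# batch of three trees at this node: "+", padding, "x"
indices, one_hot, mask = make_targets(["+", "", "x"], {"+": 0, "x": 1}, 2)
```

Padding rows stay all-zero in the one-hot tensor and are zeroed out of the loss via the mask, which matches the -1 index convention described above.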
to_expr_list
Extract one Node tree per batch element.
Returns:

| Type | Description |
|---|---|
| List[Optional[Node]] | A list of Node trees (or None slots), one per position in the batch. |
get_expr_at_idx
Reconstruct the expression tree for a single batch element.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| idx | int | Batch index to retrieve. | required |

Returns:

| Type | Description |
|---|---|
| Optional[Node] | A Node tree, or None. |
get_prediction
Collect decoder logit predictions from all nodes in infix order.
Returns:

| Type | Description |
|---|---|
| Tensor | Stacked logit tensor. |
get_target
Collect one-hot target indices from all nodes in infix order.
Returns:

| Type | Description |
|---|---|
| Tensor | Stacked target tensor. |
HVAE
HVAE(input_size: int, output_size: int, symbol_library: SymbolLibrary, hidden_size: Optional[int] = None, max_height: int = 20)
Bases: Module
Hierarchical Variational Autoencoder (HVAE) for expression trees.
Combines a tree-recursive Encoder with a tree-recursive Decoder to learn a continuous latent representation of symbolic expressions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_size | int | Vocabulary size (number of symbols in the symbol library). | required |
| output_size | int | Dimensionality of the latent space. | required |
| symbol_library | SymbolLibrary | SymbolLibrary used by the decoder to constrain leaf/non-leaf symbol selection. | required |
| hidden_size | Optional[int] | Hidden state size for the GRU cells. | None |
| max_height | int | Maximum tree depth the decoder will generate before forcing leaf symbols. | 20 |
forward
Run the full VAE forward pass: encode, sample latent vector, decode.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | BatchedNode | Batched expression trees (training mode; teacher-forced decoding). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | A 3-tuple containing the approximate posterior parameters (mu, logvar) and the decoded tree predictions. |
sample
Draw a latent sample using the reparameterisation trick.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mu | Tensor | Mean of the approximate posterior. | required |
| logvar | Tensor | Log-variance of the approximate posterior, same shape. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Sampled latent vector. |
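The reparameterisation trick named above has a standard closed form: z = mu + sigma * eps with eps drawn from a standard normal, where logvar stores log(sigma^2). A NumPy sketch (the library itself works on torch tensors):

```python
import numpy as np

def sample(mu, logvar, rng):
    """Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, I).
    Since logvar = log(sigma^2), sigma = exp(0.5 * logvar)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

rng = np.random.default_rng(0)
z = sample(np.zeros((4, 32)), np.zeros((4, 32)), rng)  # unit-variance posterior
```

Writing the sample as a deterministic function of (mu, logvar, eps) is what lets gradients flow through mu and logvar during training.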
encode
Encode a batch of trees to posterior parameters without sampling.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | BatchedNode | Batched expression trees. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Tensor, Tensor] | A 2-tuple (mu, logvar). |
decode
Decode a batch of latent vectors to expression trees (inference mode).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| z | Tensor | Latent vectors. | required |

Returns:

| Type | Description |
|---|---|
| List[Optional[Node]] | A list of Node trees (or None for failed decodes), one per row of z. |
Encoder
Bases: Module
Tree-recursive GRU encoder that maps a BatchedNode
tree to approximate posterior parameters (mu, logvar).
Uses a GRU221 cell to combine left/right child hidden
states bottom-up, then projects to mu and logvar via linear layers.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_size | int | Vocabulary size (one-hot input dimension). | required |
| hidden_size | int | Hidden state dimensionality of the GRU. | required |
| output_size | int | Dimensionality of the latent space. | required |
forward
Encode a batched tree to posterior parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | BatchedNode | Batched expression tree with targets and masks populated. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Tensor, Tensor] | A 2-tuple (mu, logvar). |
recursive_forward
Recursively encode a subtree bottom-up using the GRU cell.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | BatchedNode | Current subtree node (batched). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Hidden state tensor. |
Decoder
Decoder(input_size: int, hidden_size: int, output_size: int, symbol_library: SymbolLibrary, max_height: int = 20)
Bases: Module
Tree-recursive GRU decoder that maps a latent vector to an expression tree.
Uses a GRU122 cell to split each parent hidden state top-down into left and right child hidden states, then predicts the symbol at each node.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_size | int | Latent space dimensionality (decoder input). | required |
| hidden_size | int | Hidden state dimensionality of the GRU. | required |
| output_size | int | Vocabulary size (logit output dimension). | required |
| symbol_library | SymbolLibrary | Used to identify leaf symbols and constrain symbol selection. | required |
| max_height | int | Maximum tree depth; beyond this depth only leaf symbols are sampled. | 20 |
forward
Teacher-forced decoding pass used during training.
Stores logit predictions in each node of tree (in-place) using the ground-truth
tree structure to guide recursion.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| z | Tensor | Latent vectors. | required |
| tree | BatchedNode | Ground-truth batched tree; predictions are written into each node in-place. | required |

Returns:

| Type | Description |
|---|---|
| BatchedNode | The same tree with predictions populated. |
recursive_forward
Recursively write predictions into a subtree (teacher-forced, training only).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden | Tensor | Current node hidden state. | required |
| tree | BatchedNode | Current subtree node; predictions are written in-place. | required |

Returns:

| Type | Description |
|---|---|
| None | None |
decode
Autoregressively decode latent vectors to expression trees (inference mode).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| z | Tensor | Latent vectors. | required |

Returns:

| Type | Description |
|---|---|
| List[Optional[Node]] | A list of Node trees, one per row of z. |
recursive_decode
Recursively decode hidden states into a batched subtree (inference mode).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden | Tensor | Current node hidden state. | required |
| mask | Tensor | Boolean mask of active batch elements. | required |
| height | int | Current tree depth (used to enforce max_height). | 0 |

Returns:

| Type | Description |
|---|---|
| BatchedNode | A BatchedNode representing this subtree. |
sample_symbol
sample_symbol(prediction: Tensor, mask: Tensor, height: int) -> Tuple[List[str], torch.Tensor, torch.Tensor]
Greedily select one symbol per batch element and compute child masks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Tensor | Softmax probabilities over the vocabulary. | required |
| mask | Tensor | Boolean mask for active batch elements. | required |
| height | int | Current tree depth; at max_height only leaf symbols are selected. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[List[str], Tensor, Tensor] | A 3-tuple of selected symbol strings and two boolean masks indicating which batch elements should recurse into the left/right child. |
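The greedy selection with leaf forcing can be illustrated in a few lines of NumPy. This is a simplified sketch under assumed names (greedy_symbols, leaf_mask, and index2symbol are hypothetical); the actual method also derives the left/right recursion masks from the arities of the chosen symbols.

```python
import numpy as np

def greedy_symbols(probs, active, leaf_mask, height, max_height, index2symbol):
    """Pick the argmax symbol per batch element. At max_height, zero out
    non-leaf probabilities so only leaves can win. Inactive slots get ""."""
    if height >= max_height:
        probs = probs * leaf_mask          # keep probability mass on leaves only
    picks = probs.argmax(axis=-1)
    return [index2symbol[i] if a else "" for i, a in zip(picks, active)]

# toy vocabulary: index 0 = "+" (non-leaf), index 1 = "x" (leaf)
probs = np.array([[0.9, 0.1], [0.9, 0.1]])
leaf_mask = np.array([0.0, 1.0])
print(greedy_symbols(probs, [True, True], leaf_mask, 20, 20, ["+", "x"]))  # prints ['x', 'x']
```

Below the depth limit the same probabilities would select "+"; the leaf mask is what guarantees decoding terminates.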
GRU221
Bases: Module
Two-input, one-output GRU cell used by the Encoder.
Combines an input vector with two hidden states (left child and right child) into a single parent hidden state. The two child states are concatenated before being passed through the standard GRU gating equations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_size | int | Dimensionality of the input vector (vocabulary one-hot size). | required |
| hidden_size | int | Dimensionality of each child hidden state and the output hidden state. | required |
forward
Compute parent hidden state from input and two child hidden states.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input (symbol embedding / one-hot). | required |
| h1 | Tensor | Left child hidden state. | required |
| h2 | Tensor | Right child hidden state. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Parent hidden state. |
GRU122
Bases: Module
One-input, two-output GRU cell used by the Decoder.
Splits a parent hidden state into two child hidden states (left and right) driven by an input vector. The output is split evenly along the hidden dimension.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_size | int | Dimensionality of the input vector (symbol probability distribution). | required |
| hidden_size | int | Dimensionality of each output child hidden state (the raw output is split evenly into two). | required |
forward
Compute two child hidden states from input and parent hidden state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input vector (symbol probabilities). | required |
| h | Tensor | Parent hidden state. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Tensor, Tensor] | A 2-tuple of left and right child hidden states. |
TreeBatchSampler
Bases: Sampler[List[int]]
PyTorch sampler that yields randomly permuted mini-batches of tree indices.
Each call to __iter__ produces a fresh permutation. Batches are contiguous slices
of the permutation; the last incomplete batch is dropped.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch_size | int | Number of trees per mini-batch. | required |
| num_eq | int | Total number of trees in the dataset. | required |
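The sampling scheme described above (fresh permutation per epoch, contiguous slices, last incomplete batch dropped) can be sketched as a plain generator; tree_batches is a hypothetical name for illustration only.

```python
import numpy as np

def tree_batches(num_eq, batch_size, rng):
    """Yield contiguous batch_size slices of a fresh random permutation
    of the dataset indices; the last incomplete batch is dropped."""
    perm = rng.permutation(num_eq)
    for start in range(0, num_eq - batch_size + 1, batch_size):
        yield perm[start:start + batch_size].tolist()

batches = list(tree_batches(10, 3, np.random.default_rng(0)))  # 3 batches; 1 index dropped
```

Dropping the remainder keeps every mini-batch the same size, which simplifies the batched tree construction downstream.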
TreeDataset
create_batch
Pack a list of expression trees into a single BatchedNode with targets and masks populated.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| trees | List[Optional[Node]] | List of Node trees to batch. | required |
| symbol2index | Dict[str, int] | Mapping from symbol strings to vocabulary indices. | required |

Returns:

| Type | Description |
|---|---|
| BatchedNode | A BatchedNode ready for the HVAE forward pass. |
annealing_schedule
Linear KL annealing schedule mapping iteration count to a [0, supremum] coefficient.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| it | int | Current iteration index. | required |
| total_iters | int | Total number of training iterations. | required |
| supremum | float | Maximum value, returned at it = total_iters. | 0.04 |

Returns:

| Type | Description |
|---|---|
| float | Annealing coefficient in [0, supremum]. |
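One plausible linear schedule consistent with the description above (the library's exact formula may differ, e.g. in how it clamps past total_iters):

```python
def annealing_schedule(it, total_iters, supremum=0.04):
    """Linearly ramp the KL coefficient from 0 at it=0 up to supremum
    at it=total_iters, then hold it there."""
    return supremum * min(1.0, it / total_iters)
```

Starting the KL weight near zero lets the model first learn to reconstruct expressions before the latent space is regularised toward the prior.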
hvae_loss
hvae_loss(outputs: BatchedNode, mu: Tensor, logvar: Tensor, lmbda: float, criterion: Module) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Compute the HVAE ELBO loss: reconstruction (cross-entropy) + KL divergence.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| outputs | BatchedNode | Decoded BatchedNode produced by the HVAE forward pass. | required |
| mu | Tensor | Mean of the approximate posterior. | required |
| logvar | Tensor | Log-variance of the approximate posterior, same shape as mu. | required |
| lmbda | float | KL annealing coefficient that scales the KL term. | required |
| criterion | Module | Reconstruction loss callable (typically cross-entropy). | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Tensor, Tensor, Tensor] | A 3-tuple of the total loss and its reconstruction and KL components. |
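The KL term of the ELBO has the standard closed form for a diagonal Gaussian posterior against a unit-normal prior. The sketch below shows the shape of the computation in NumPy; the reduction over the batch (mean here) and the function names are assumptions, not the library's exact code.

```python
import numpy as np

def kl_divergence(mu, logvar):
    """KL(q(z|x) || N(0, I)) for a diagonal Gaussian:
    -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), averaged over the batch."""
    kl_per_example = -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=-1)
    return kl_per_example.mean()

def elbo_loss(reconstruction_ce, mu, logvar, lmbda):
    """Total loss = reconstruction cross-entropy + annealed KL term."""
    kl = kl_divergence(mu, logvar)
    return reconstruction_ce + lmbda * kl, reconstruction_ce, kl
```

When the posterior matches the prior exactly (mu = 0, logvar = 0) the KL term vanishes and the loss reduces to the reconstruction term alone.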
train_hvae
train_hvae(model: HVAE, trainset: TreeDataset, symbol_library: SymbolLibrary, epochs: int = 20, batch_size: int = 32, max_beta: float = 0.04, verbose: bool = True) -> None
Train an HVAE on a dataset of expression trees.
Uses the Adam optimiser and cross-entropy reconstruction loss with KL annealing controlled
by annealing_schedule. If verbose=True,
prints one decoded reconstruction sample per epoch at the midpoint batch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | HVAE | The HVAE to train (modified in-place). | required |
| trainset | TreeDataset | TreeDataset of expression trees. | required |
| symbol_library | SymbolLibrary | Symbol library defining the vocabulary. | required |
| epochs | int | Number of training epochs. | 20 |
| batch_size | int | Mini-batch size. | 32 |
| max_beta | float | Maximum KL annealing coefficient. | 0.04 |
| verbose | bool | If True, prints training progress and a decoded reconstruction sample each epoch. | True |

Returns:

| Type | Description |
|---|---|
| None | None |