Result Augmentation

SRToolkit.evaluation.result_augmentation

Concrete ResultAugmenter implementations that post-process SR results with additional information.

Available augmenters: ExpressionToLatex, ExpressionSimplifier, RMSE, BED, R2. Custom augmenters work without registration: subclass ResultAugmenter and implement the required methods.
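
For example, a custom augmenter can be sketched as below. This is illustrative only: it assumes ResultAugmenter is importable from this module and that the interface matches what the built-in implementations on this page use (a name passed to super().__init__, a write_results hook, the two format_* classmethods, and the inherited _type attribute). The ModelComplexity name and the token-counting logic are hypothetical.

from typing import Any, Dict

from SRToolkit.evaluation.result_augmentation import ResultAugmenter
from SRToolkit.utils.types import EvalResult


class ModelComplexity(ResultAugmenter):
    """Illustrative augmenter that records expression length (token count)."""

    def __init__(self, name: str = "ModelComplexity") -> None:
        super().__init__(name)

    def write_results(self, results: EvalResult) -> None:
        # Same pattern as the built-in augmenters: one dict for the
        # experiment, one per top model.
        results.add_augmentation(
            self.name, {"best_expr_length": len(results.top_models[0].expr)}, self._type
        )
        for model in results.top_models:
            model.add_augmentation(self.name, {"expr_length": len(model.expr)}, self._type)

    @classmethod
    def format_eval_result(cls, data: Dict[str, Any]) -> str:
        val = data.get("best_expr_length", "")
        return f"Best expression length: {val}" if val != "" else ""

    @classmethod
    def format_model_result(cls, data: Dict[str, Any]) -> str:
        val = data.get("expr_length", "")
        return f"Length={val}" if val != "" else ""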

ExpressionToLatex

ExpressionToLatex(symbol_library: SymbolLibrary, scope: str = 'top', verbose: bool = False, name: str = 'ExpressionToLatex')

Bases: ResultAugmenter

Converts expressions inside the results to LaTeX strings.

Parameters:

  • symbol_library (SymbolLibrary, required): Symbol library used to produce LaTeX templates for each token.
  • scope (str, default 'top'): Which expressions to convert.
      • "best": only the best expression.
      • "top": the best expression and all top-k models.
      • "all": everything in "top" plus all evaluated expressions.
  • verbose (bool, default False): If True, emits a warning when LaTeX conversion fails for an expression.
  • name (str, default 'ExpressionToLatex'): Key used in the augmentations dict of EvalResult and ModelResult.
Source code in SRToolkit/evaluation/result_augmentation.py
def __init__(
    self,
    symbol_library: SymbolLibrary,
    scope: str = "top",
    verbose: bool = False,
    name: str = "ExpressionToLatex",
) -> None:
    """
    Converts expressions inside the results to LaTeX strings.

    Args:
        symbol_library: Symbol library used to produce LaTeX templates for each token.
        scope: Which expressions to convert.

            - ``"best"``: only the best expression.
            - ``"top"``: the best expression and all top-k models.
            - ``"all"``: everything in ``"top"`` plus all evaluated expressions.
        verbose: If ``True``, emits a warning when LaTeX conversion fails for an expression.
            Default ``False``.
        name: Key used in
            ``augmentations`` dict of [EvalResult][SRToolkit.utils.types.EvalResult] and
            [ModelResult][SRToolkit.utils.types.ModelResult].
            Default ``"ExpressionToLatex"``.
    """
    super().__init__(name)
    self.symbol_library = symbol_library

    if scope not in ["best", "top", "all"]:
        raise Exception(f"[RMSE augmenter] Invalid scope: {scope}. Must be one of 'best', 'top', 'all'.")
    self.scope = scope

    self.verbose = verbose
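
A minimal usage sketch (symbol_library and results are placeholders for objects produced by an earlier SR run; the layout of the augmentations dict is inferred from the parameter docs above):

# `symbol_library` and `results` are assumed to exist already.
augmenter = ExpressionToLatex(symbol_library, scope="top", verbose=True)
augmenter.write_results(results)

# Experiment-level data, keyed by the augmenter's name.
print(results.augmentations["ExpressionToLatex"].get("best_expr_latex", ""))

# Per-model data for each of the top-k models.
for model in results.top_models:
    print(model.augmentations["ExpressionToLatex"].get("expr_latex", ""))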

write_results

write_results(results: EvalResult) -> None

Write LaTeX representations into results and its models.

Stores {"best_expr_latex": ...} in EvalResult augmentations. Also stores {"expr_latex": ...} in each model's augmentations when scope is "top" or "all".

Parameters:

  • results (EvalResult, required): The EvalResult to augment.
Source code in SRToolkit/evaluation/result_augmentation.py
def write_results(self, results: EvalResult) -> None:
    """
    Write LaTeX representations into *results* and its models.

    Stores ``{"best_expr_latex": ...}`` in
    [EvalResult][SRToolkit.utils.types.EvalResult] ``augmentations``.
    Also stores ``{"expr_latex": ...}`` in each model's augmentations when
    ``scope`` is ``"top"`` or ``"all"``.

    Args:
        results: The [EvalResult][SRToolkit.utils.types.EvalResult] to augment.
    """
    eval_data: Dict[str, Any] = {}
    try:
        eval_data["best_expr_latex"] = tokens_to_tree(results.top_models[0].expr, self.symbol_library).to_latex(
            self.symbol_library
        )
    except Exception as e:
        if self.verbose:
            warnings.warn(f"Unable to convert best expression to LaTeX: {e}")
    results.add_augmentation(self.name, eval_data, self._type)

    if self.scope == "top" or self.scope == "all":
        for model in results.top_models:
            try:
                model.add_augmentation(
                    self.name,
                    {"expr_latex": tokens_to_tree(model.expr, self.symbol_library).to_latex(self.symbol_library)},
                    self._type,
                )
            except Exception as e:
                if self.verbose:
                    warnings.warn(f"Unable to convert expression {''.join(model.expr)} to LaTeX: {e}")

    if self.scope == "all":
        for model in results.all_models:
            try:
                model.add_augmentation(
                    self.name,
                    {"expr_latex": tokens_to_tree(model.expr, self.symbol_library).to_latex(self.symbol_library)},
                    self._type,
                )
            except Exception as e:
                if self.verbose:
                    warnings.warn(f"Unable to convert expression {''.join(model.expr)} to LaTeX: {e}")

format_eval_result classmethod

format_eval_result(data: Dict[str, Any]) -> str

Format experiment-level LaTeX augmentation data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "best_expr_latex".

Returns:

  str: A human-readable string, or empty string if no data is present.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_eval_result(cls, data: Dict[str, Any]) -> str:
    """
    Format experiment-level LaTeX augmentation data for display.

    Args:
        data: Augmentation dict containing ``"best_expr_latex"``.

    Returns:
        A human-readable string, or empty string if no data is present.
    """
    latex = data.get("best_expr_latex", "")
    return f"LaTeX of the best expression: {latex}" if latex else ""

format_model_result classmethod

format_model_result(data: Dict[str, Any]) -> str

Format per-model LaTeX augmentation data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "expr_latex".

Returns:

  str: A human-readable string, or empty string if no data is present.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_model_result(cls, data: Dict[str, Any]) -> str:
    """
    Format per-model LaTeX augmentation data for display.

    Args:
        data: Augmentation dict containing ``"expr_latex"``.

    Returns:
        A human-readable string, or empty string if no data is present.
    """
    latex = data.get("expr_latex", "")
    return f"LaTeX: {latex}" if latex else ""

to_dict

to_dict(base_path: str, name: str) -> dict

Creates a dictionary representation of the ExpressionToLatex augmenter.

Parameters:

  • base_path (str, required): Unused and ignored.
  • name (str, required): Unused and ignored.

Returns:

  dict: A dictionary containing the necessary information to recreate the augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
def to_dict(self, base_path: str, name: str) -> dict:
    """
    Creates a dictionary representation of the ExpressionToLatex augmenter.

    Args:
        base_path: Unused and ignored
        name: Unused and ignored

    Returns:
        A dictionary containing the necessary information to recreate the augmenter.
    """
    return {
        "format_version": 1,
        "type": "ExpressionToLatex",
        "name": self.name,
        "symbol_library": self.symbol_library.to_dict(),
        "scope": self.scope,
        "verbose": self.verbose,
    }

from_dict staticmethod

from_dict(data: dict) -> ExpressionToLatex

Creates an instance of the ExpressionToLatex augmenter from a dictionary.

Parameters:

  • data (dict, required): A dictionary containing the necessary information to recreate the augmenter.

Returns:

  ExpressionToLatex: An instance of the ExpressionToLatex augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
@staticmethod
def from_dict(data: dict) -> "ExpressionToLatex":
    """
    Creates an instance of the ExpressionToLatex augmenter from a dictionary.

    Args:
        data: A dictionary containing the necessary information to recreate the augmenter.

    Returns:
        An instance of the ExpressionToLatex augmenter.
    """
    if data.get("format_version", 1) != 1:
        raise ValueError(
            f"[ExpressionToLatex.from_dict] Unsupported format_version: {data.get('format_version')!r}. Expected 1."
        )
    return ExpressionToLatex(
        # assumed: SymbolLibrary.from_dict mirrors the to_dict call used in to_dict above
        symbol_library=SymbolLibrary.from_dict(data["symbol_library"]),
        scope=data["scope"],
        verbose=data["verbose"],
        name=data["name"],
    )
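
Together, to_dict and from_dict form a persistence round trip. A sketch (JSON is an illustrative choice, and it is assumed that the nested symbol-library dict is rebuilt symmetrically, as noted in from_dict above):

import json

d = augmenter.to_dict(base_path="", name="")  # both arguments are ignored here
restored = ExpressionToLatex.from_dict(json.loads(json.dumps(d)))
assert restored.scope == augmenter.scope and restored.name == augmenter.name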

ExpressionSimplifier

ExpressionSimplifier(symbol_library: SymbolLibrary, scope: str = 'top', verbose: bool = False, name: str = 'ExpressionSimplifier')

Bases: ResultAugmenter

Algebraically simplifies expressions inside the results using SymPy.

Parameters:

  • symbol_library (SymbolLibrary, required): Symbol library used by the simplifier to resolve token types.
  • scope (str, default 'top'): Which expressions to simplify.
      • "best": only the best expression.
      • "top": the best expression and all top-k models.
      • "all": everything in "top" plus all evaluated expressions.
  • verbose (bool, default False): If True, emits a warning when simplification fails for an expression.
  • name (str, default 'ExpressionSimplifier'): Key used in the augmentations dict of EvalResult and ModelResult.
Source code in SRToolkit/evaluation/result_augmentation.py
def __init__(
    self,
    symbol_library: SymbolLibrary,
    scope: str = "top",
    verbose: bool = False,
    name: str = "ExpressionSimplifier",
) -> None:
    """
    Algebraically simplifies expressions inside the results using SymPy.

    Args:
        symbol_library: Symbol library used by the simplifier to resolve token types.
        scope: Which expressions to simplify.

            - ``"best"``: only the best expression.
            - ``"top"``: the best expression and all top-k models.
            - ``"all"``: everything in ``"top"`` plus all evaluated expressions.
        verbose: If ``True``, emits a warning when simplification fails for an expression.
            Default ``False``.
        name: Key used in
            ``augmentations`` dict of [EvalResult][SRToolkit.utils.types.EvalResult] and
            [ModelResult][SRToolkit.utils.types.ModelResult].
            Default ``"ExpressionSimplifier"``.
    """
    super().__init__(name)
    self.symbol_library = symbol_library

    if scope not in ["best", "top", "all"]:
        raise Exception(f"[RMSE augmenter] Invalid scope: {scope}. Must be one of 'best', 'top', 'all'.")
    self.scope = scope

    self.verbose = verbose

write_results

write_results(results: EvalResult) -> None

Write simplified expressions into results and its models.

Stores {"simplified_best_expr": ...} in EvalResult augmentations if simplification succeeds. Also stores {"simplified_expr": ...} in each model's augmentations when scope is "top" or "all".

Parameters:

  • results (EvalResult, required): The EvalResult to augment.
Source code in SRToolkit/evaluation/result_augmentation.py
def write_results(self, results: EvalResult) -> None:
    """
    Write simplified expressions into *results* and its models.

    Stores ``{"simplified_best_expr": ...}`` in
    [EvalResult][SRToolkit.utils.types.EvalResult] ``augmentations`` if
    simplification succeeds. Also stores ``{"simplified_expr": ...}`` in each model's
    augmentations when ``scope`` is ``"top"`` or ``"all"``.

    Args:
        results: The [EvalResult][SRToolkit.utils.types.EvalResult] to augment.
    """
    eval_data: Dict[str, Any] = {}
    try:
        simplified_expr = simplify(results.top_models[0].expr, self.symbol_library)
        if isinstance(simplified_expr, list):
            eval_data["simplified_best_expr"] = "".join(simplified_expr)
        elif isinstance(simplified_expr, Node):
            eval_data["simplified_best_expr"] = "".join(simplified_expr.to_list(self.symbol_library))
        else:
            raise Exception(f"Simplified expression is not a list or Node: {simplified_expr}")
    except Exception as e:
        if self.verbose:
            warnings.warn(f"Unable to simplify {results.best_expr}: {e}")
    results.add_augmentation(self.name, eval_data, self._type)

    if self.scope == "top" or self.scope == "all":
        for model in results.top_models:
            top_model_data: Dict[str, Any] = {}
            try:
                simplified_expr = simplify(model.expr, self.symbol_library)
                if isinstance(simplified_expr, list):
                    top_model_data["simplified_expr"] = "".join(simplified_expr)
                elif isinstance(simplified_expr, Node):
                    top_model_data["simplified_expr"] = "".join(simplified_expr.to_list(self.symbol_library))
                else:
                    raise Exception(f"Simplified expression is not a list or Node: {simplified_expr}")
            except Exception as e:
                if self.verbose:
                    warnings.warn(f"Unable to simplify {''.join(model.expr)}: {e}")
            model.add_augmentation(self.name, top_model_data, self._type)

    if self.scope == "all":
        for model in results.all_models:
            all_model_data: Dict[str, Any] = {}
            try:
                simplified_expr = simplify(model.expr, self.symbol_library)
                if isinstance(simplified_expr, list):
                    all_model_data["simplified_expr"] = "".join(simplified_expr)
                elif isinstance(simplified_expr, Node):
                    all_model_data["simplified_expr"] = "".join(simplified_expr.to_list(self.symbol_library))
                else:
                    raise Exception(f"Simplified expression is not a list or Node: {simplified_expr}")
            except Exception as e:
                if self.verbose:
                    warnings.warn(f"Unable to simplify {''.join(model.expr)}: {e}")
            model.add_augmentation(self.name, all_model_data, self._type)
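
The list-vs-Node branch above is repeated three times; the normalization it performs can be captured in one hypothetical helper (simplify and Node as used in the source):

def _to_expr_string(simplified, symbol_library) -> str:
    # simplify() may return either a token list or a parse-tree Node;
    # both are flattened into a single expression string.
    if isinstance(simplified, list):
        return "".join(simplified)
    if isinstance(simplified, Node):
        return "".join(simplified.to_list(symbol_library))
    raise Exception(f"Simplified expression is not a list or Node: {simplified}")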

format_eval_result classmethod

format_eval_result(data: Dict[str, Any]) -> str

Format experiment-level simplification data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "simplified_best_expr".

Returns:

  str: A human-readable string, or empty string if no data is present.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_eval_result(cls, data: Dict[str, Any]) -> str:
    """
    Format experiment-level simplification data for display.

    Args:
        data: Augmentation dict containing ``"simplified_best_expr"``.

    Returns:
        A human-readable string, or empty string if no data is present.
    """
    simplified = data.get("simplified_best_expr", "")
    return f"Simplified: {simplified}" if simplified else ""

format_model_result classmethod

format_model_result(data: Dict[str, Any]) -> str

Format per-model simplification data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "simplified_expr".

Returns:

  str: A human-readable string, or empty string if no data is present.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_model_result(cls, data: Dict[str, Any]) -> str:
    """
    Format per-model simplification data for display.

    Args:
        data: Augmentation dict containing ``"simplified_expr"``.

    Returns:
        A human-readable string, or empty string if no data is present.
    """
    simplified = data.get("simplified_expr", "")
    return f"Simplified: {simplified}" if simplified else ""

to_dict

to_dict(base_path: str, name: str) -> dict

Creates a dictionary representation of the ExpressionSimplifier augmenter.

Parameters:

  • base_path (str, required): Unused and ignored.
  • name (str, required): Unused and ignored.

Returns:

  dict: A dictionary containing the necessary information to recreate the augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
def to_dict(self, base_path: str, name: str) -> dict:
    """
    Creates a dictionary representation of the ExpressionSimplifier augmenter.

    Args:
        base_path: Unused and ignored
        name: Unused and ignored

    Returns:
        A dictionary containing the necessary information to recreate the augmenter.
    """
    return {
        "format_version": 1,
        "type": "ExpressionSimplifier",
        "name": self.name,
        "symbol_library": self.symbol_library.to_dict(),
        "scope": self.scope,
        "verbose": self.verbose,
    }

from_dict staticmethod

from_dict(data: dict) -> ExpressionSimplifier

Creates an instance of the ExpressionSimplifier augmenter from a dictionary.

Parameters:

  • data (dict, required): A dictionary containing the necessary information to recreate the augmenter.

Returns:

  ExpressionSimplifier: An instance of the ExpressionSimplifier augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
@staticmethod
def from_dict(data: dict) -> "ExpressionSimplifier":
    """
    Creates an instance of the ExpressionSimplifier augmenter from a dictionary.

    Args:
        data: A dictionary containing the necessary information to recreate the augmenter.

    Returns:
        An instance of the ExpressionSimplifier augmenter.
    """
    if data.get("format_version", 1) != 1:
        raise ValueError(
            f"[ExpressionSimplifier.from_dict] Unsupported format_version: {data.get('format_version')!r}. Expected 1."
        )
    return ExpressionSimplifier(
        # assumed: SymbolLibrary.from_dict mirrors the to_dict call used in to_dict above
        symbol_library=SymbolLibrary.from_dict(data["symbol_library"]),
        scope=data["scope"],
        verbose=data["verbose"],
        name=data["name"],
    )

RMSE

RMSE(evaluator: SR_evaluator, scope: str = 'top', name: str = 'RMSE')

Bases: ResultAugmenter

Computes RMSE for the top models using a separate evaluator (e.g. a held-out test set).

Parameters:

  • evaluator (SR_evaluator, required): SR_evaluator used to score the models. Must be initialized with ranking_function="rmse" and a non-None y.
  • scope (str, default 'top'): Which expressions to score.
      • "best": only the best expression.
      • "top": the best expression and all top-k models.
      • "all": everything in "top" plus all evaluated expressions.
  • name (str, default 'RMSE'): Key used in the augmentations dict of EvalResult and ModelResult.

Raises:

  Exception: If evaluator.ranking_function != "rmse" or evaluator.y is None.

Source code in SRToolkit/evaluation/result_augmentation.py
def __init__(self, evaluator: SR_evaluator, scope: str = "top", name: str = "RMSE") -> None:  # noqa: F821
    """
    Computes RMSE for the top models using a separate evaluator (e.g. a held-out test set).

    Args:
        evaluator: [SR_evaluator][SRToolkit.evaluation.sr_evaluator.SR_evaluator] used to
            score the models. Must be initialized with ``ranking_function="rmse"`` and a
            non-``None`` ``y``.
        scope: Which expressions to score.

            - ``"best"``: only the best expression.
            - ``"top"``: the best expression and all top-k models.
            - ``"all"``: everything in ``"top"`` plus all evaluated expressions.
        name: Key used in
            ``augmentations`` dict of [EvalResult][SRToolkit.utils.types.EvalResult] and
            [ModelResult][SRToolkit.utils.types.ModelResult].
            Default ``"RMSE"``.

    Raises:
        Exception: If ``evaluator.ranking_function != "rmse"`` or ``evaluator.y is None``.
    """
    super().__init__(name)
    self.evaluator = evaluator

    if scope not in ["best", "top", "all"]:
        raise Exception(f"[RMSE augmenter] Invalid scope: {scope}. Must be one of 'best', 'top', 'all'.")
    self.scope = scope

    if self.evaluator.ranking_function != "rmse":
        raise Exception("[RMSE augmenter] Ranking function of the evaluator must be set to 'rmse' to compute RMSE.")
    if self.evaluator.y is None:
        raise Exception("[RMSE augmenter] y in the evaluator must not be None to compute RMSE.")

write_results

write_results(results: EvalResult) -> None

Write RMSE scores into results and its models.

Stores {"min_error": ...} in EvalResult augmentations and {"error": ..., "parameters": ...} in each model's augmentations when scope is "top" or "all".

Parameters:

  • results (EvalResult, required): The EvalResult to augment.
Source code in SRToolkit/evaluation/result_augmentation.py
def write_results(self, results: EvalResult) -> None:
    """
    Write RMSE scores into *results* and its models.

    Stores ``{"min_error": ...}`` in
    [EvalResult][SRToolkit.utils.types.EvalResult] ``augmentations`` and
    ``{"error": ..., "parameters": ...}`` in each model's augmentations when ``scope``
    is ``"top"`` or ``"all"``.

    Args:
        results: The [EvalResult][SRToolkit.utils.types.EvalResult] to augment.
    """
    eval_data: Dict[str, Any] = {"min_error": self.evaluator.evaluate_expr(results.top_models[0].expr)}
    results.add_augmentation(self.name, eval_data, self._type)

    if self.scope == "top" or self.scope == "all":
        for model in results.top_models:
            key = "".join(model.expr)
            top_model_data: Dict[str, Any] = {
                "error": self.evaluator.evaluate_expr(model.expr),
                "parameters": self.evaluator.models[key].parameters,
            }
            model.add_augmentation(self.name, top_model_data, self._type)

    if self.scope == "all":
        for model in results.all_models:
            key = "".join(model.expr)
            all_model_data: Dict[str, Any] = {
                "error": self.evaluator.evaluate_expr(model.expr),
                "parameters": self.evaluator.models[key].parameters,
            }
            model.add_augmentation(self.name, all_model_data, self._type)

format_eval_result classmethod

format_eval_result(data: Dict[str, Any]) -> str

Format experiment-level RMSE data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "min_error".

Returns:

  str: A human-readable string, or empty string if no data is present.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_eval_result(cls, data: Dict[str, Any]) -> str:
    """
    Format experiment-level RMSE data for display.

    Args:
        data: Augmentation dict containing ``"min_error"``.

    Returns:
        A human-readable string, or empty string if no data is present.
    """
    val = data.get("min_error", "")
    return f"Test RMSE: {val}" if val != "" else ""

format_model_result classmethod

format_model_result(data: Dict[str, Any]) -> str

Format per-model RMSE data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "error" and optionally "parameters".

Returns:

  str: A human-readable string with RMSE and fitted parameters.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_model_result(cls, data: Dict[str, Any]) -> str:
    """
    Format per-model RMSE data for display.

    Args:
        data: Augmentation dict containing ``"error"`` and optionally ``"parameters"``.

    Returns:
        A human-readable string with RMSE and fitted parameters.
    """
    parts = [f"RMSE={data['error']:.6g}"]
    if "parameters" in data and data["parameters"] is not None:
        parts.append(f"params={np.round(data['parameters'], 4).tolist()}")
    return ", ".join(parts)

to_dict

to_dict(base_path: str, name: str) -> dict

Creates a dictionary representation of the RMSE augmenter.

Parameters:

  • base_path (str, required): Used to save the data of the evaluator to disk.
  • name (str, required): Used to save the data of the evaluator to disk.

Returns:

  dict: A dictionary containing the necessary information to recreate the augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
def to_dict(self, base_path: str, name: str) -> dict:
    """
    Creates a dictionary representation of the RMSE augmenter.

    Args:
        base_path: Used to save the data of the evaluator to disk.
        name: Used to save the data of the evaluator to disk.

    Returns:
        A dictionary containing the necessary information to recreate the augmenter.
    """
    return {
        "format_version": 1,
        "name": self.name,
        "type": "RMSE",
        "scope": self.scope,
        "evaluator": self.evaluator.to_dict(base_path, name + "_RMSE_augmenter"),
    }

from_dict staticmethod

from_dict(data: dict) -> RMSE

Creates an instance of the RMSE augmenter from a dictionary.

Parameters:

  • data (dict, required): A dictionary containing the necessary information to recreate the augmenter.

Returns:

  RMSE: An instance of the RMSE augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
@staticmethod
def from_dict(data: dict) -> "RMSE":
    """
    Creates an instance of the RMSE augmenter from a dictionary.

    Args:
        data: A dictionary containing the necessary information to recreate the augmenter.

    Returns:
        An instance of the RMSE augmenter.
    """
    if data.get("format_version", 1) != 1:
        raise ValueError(
            f"[RMSE.from_dict] Unsupported format_version: {data.get('format_version')!r}. Expected 1."
        )
    evaluator = SR_evaluator.from_dict(data["evaluator"])
    return RMSE(evaluator, scope=data["scope"], name=data["name"])

BED

BED(evaluator: SR_evaluator, scope: str = 'top', name: str = 'BED')

Bases: ResultAugmenter

Computes BED for the top models using a separate evaluator (e.g. a held-out test set).

Parameters:

  • evaluator (SR_evaluator, required): SR_evaluator used to score the models. Must be initialized with ranking_function="bed".
  • scope (str, default 'top'): Which expressions to score.
      • "best": only the best expression.
      • "top": the best expression and all top-k models.
      • "all": everything in "top" plus all evaluated expressions.
  • name (str, default 'BED'): Key used in the augmentations dict of EvalResult and ModelResult.

Raises:

  Exception: If evaluator.ranking_function != "bed".

Source code in SRToolkit/evaluation/result_augmentation.py
def __init__(self, evaluator: SR_evaluator, scope: str = "top", name: str = "BED") -> None:  # noqa: F821
    """
    Computes BED for the top models using a separate evaluator (e.g. a held-out test set).

    Args:
        evaluator: [SR_evaluator][SRToolkit.evaluation.sr_evaluator.SR_evaluator] used to
            score the models. Must be initialized with ``ranking_function="bed"``.
        scope: Which expressions to score.

            - ``"best"``: only the best expression.
            - ``"top"``: the best expression and all top-k models.
            - ``"all"``: everything in ``"top"`` plus all evaluated expressions.
        name: Key used in
            ``augmentations`` dict of [EvalResult][SRToolkit.utils.types.EvalResult] and
            [ModelResult][SRToolkit.utils.types.ModelResult].
            Default ``"BED"``.

    Raises:
        Exception: If ``evaluator.ranking_function != "bed"``.
    """
    super().__init__(name)
    self.evaluator = evaluator

    if scope not in ["best", "top", "all"]:
        raise Exception(f"[BED augmenter] Invalid scope: {scope}. Must be one of 'best', 'top', 'all'.")
    self.scope = scope

    if self.evaluator.ranking_function != "bed":
        raise Exception("[BED augmenter] Ranking function of the evaluator must be set to 'bed' to compute BED.")

write_results

write_results(results: EvalResult) -> None

Write BED scores into results and its models.

Stores {"best_expr_bed": ...} in EvalResult augmentations and {"bed": ...} in each model's augmentations when scope is "top" or "all".

Parameters:

  • results (EvalResult, required): The EvalResult to augment.
Source code in SRToolkit/evaluation/result_augmentation.py
def write_results(
    self,
    results: EvalResult,
) -> None:
    """
    Write BED scores into *results* and its models.

    Stores ``{"best_expr_bed": ...}`` in
    [EvalResult][SRToolkit.utils.types.EvalResult] ``augmentations`` and
    ``{"bed": ...}`` in each model's augmentations when ``scope`` is ``"top"`` or ``"all"``.

    Args:
        results: The [EvalResult][SRToolkit.utils.types.EvalResult] to augment.
    """
    eval_data: Dict[str, Any] = {"best_expr_bed": self.evaluator.evaluate_expr(results.top_models[0].expr)}
    results.add_augmentation(self.name, eval_data, self._type)

    if self.scope == "top" or self.scope == "all":
        for model in results.top_models:
            top_model_data: Dict[str, Any] = {"bed": self.evaluator.evaluate_expr(model.expr)}
            model.add_augmentation(self.name, top_model_data, self._type)

    if self.scope == "all":
        for model in results.all_models:
            all_model_data: Dict[str, Any] = {"bed": self.evaluator.evaluate_expr(model.expr)}
            model.add_augmentation(self.name, all_model_data, self._type)

format_eval_result classmethod

format_eval_result(data: Dict[str, Any]) -> str

Format experiment-level BED data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "best_expr_bed".

Returns:

  str: A human-readable string, or empty string if no data is present.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_eval_result(cls, data: Dict[str, Any]) -> str:
    """
    Format experiment-level BED data for display.

    Args:
        data: Augmentation dict containing ``"best_expr_bed"``.

    Returns:
        A human-readable string, or empty string if no data is present.
    """
    val = data.get("best_expr_bed", "")
    return f"Test BED: {val}" if val != "" else ""

format_model_result classmethod

format_model_result(data: Dict[str, Any]) -> str

Format per-model BED data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "bed".

Returns:

  str: A human-readable string, or empty string if no data is present.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_model_result(cls, data: Dict[str, Any]) -> str:
    """
    Format per-model BED data for display.

    Args:
        data: Augmentation dict containing ``"bed"``.

    Returns:
        A human-readable string, or empty string if no data is present.
    """
    val = data.get("bed", "")
    return f"BED={val}" if val != "" else ""

to_dict

to_dict(base_path: str, name: str) -> dict

Creates a dictionary representation of the BED augmenter.

Parameters:

  • base_path (str, required): Used to save the data of the evaluator to disk.
  • name (str, required): Used to save the data of the evaluator to disk.

Returns:

  dict: A dictionary containing the necessary information to recreate the augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
def to_dict(self, base_path: str, name: str) -> dict:
    """
    Creates a dictionary representation of the BED augmenter.

    Args:
        base_path: Used to save the data of the evaluator to disk.
        name: Used to save the data of the evaluator to disk.

    Returns:
        A dictionary containing the necessary information to recreate the augmenter.
    """
    return {
        "format_version": 1,
        "name": self.name,
        "type": "BED",
        "scope": self.scope,
        "evaluator": self.evaluator.to_dict(base_path, name + "_BED_augmenter"),
    }

from_dict staticmethod

from_dict(data: dict) -> BED

Creates an instance of the BED augmenter from a dictionary.

Parameters:

  • data (dict, required): A dictionary containing the necessary information to recreate the augmenter.

Returns:

  BED: An instance of the BED augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
@staticmethod
def from_dict(data: dict) -> "BED":
    """
    Creates an instance of the BED augmenter from a dictionary.

    Args:
        data: A dictionary containing the necessary information to recreate the augmenter.

    Returns:
        An instance of the BED augmenter.
    """
    if data.get("format_version", 1) != 1:
        raise ValueError(f"[BED.from_dict] Unsupported format_version: {data.get('format_version')!r}. Expected 1.")
    evaluator = SR_evaluator.from_dict(data["evaluator"])
    return BED(evaluator, scope=data["scope"], name=data["name"])

R2

R2(evaluator: SR_evaluator, scope: str = 'top', name: str = 'R2')

Bases: ResultAugmenter

Computes R² for the top models using a separate evaluator (e.g. a held-out test set).

The same evaluator instance can be shared with RMSE to avoid loading test data twice.
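
For instance, one "rmse" evaluator can back both augmenters (same assumed SR_evaluator construction as in the RMSE example):

test_evaluator = SR_evaluator(X_test, y=y_test, symbol_library=symbol_library,
                              ranking_function="rmse")
# The evaluator and its test data are created once, then shared.
for augmenter in (RMSE(test_evaluator), R2(test_evaluator)):
    augmenter.write_results(results)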

Parameters:

  • evaluator (SR_evaluator, required): SR_evaluator used to score the models. Must be initialized with ranking_function="rmse" and a non-None y.
  • scope (str, default 'top'): Which expressions to score.
      • "best": only the best expression.
      • "top": the best expression and all top-k models.
      • "all": everything in "top" plus all evaluated expressions.
  • name (str, default 'R2'): Key used in the augmentations dict of EvalResult and ModelResult.

Raises:

  Exception: If evaluator.ranking_function != "rmse" or evaluator.y is None.

Source code in SRToolkit/evaluation/result_augmentation.py
def __init__(self, evaluator: SR_evaluator, scope: str = "top", name: str = "R2") -> None:  # noqa: F821
    """
    Computes R² for the top models using a separate evaluator (e.g. a held-out test set).

    The same evaluator instance can be shared with
    [RMSE][SRToolkit.evaluation.result_augmentation.RMSE] to avoid loading test data twice.

    Args:
        evaluator: [SR_evaluator][SRToolkit.evaluation.sr_evaluator.SR_evaluator] used to
            score the models. Must be initialized with ``ranking_function="rmse"`` and a
            non-``None`` ``y``.
        scope: Which expressions to score.

            - ``"best"``: only the best expression.
            - ``"top"``: the best expression and all top-k models.
            - ``"all"``: everything in ``"top"`` plus all evaluated expressions.
        name: Key used in
            ``augmentations`` dict of [EvalResult][SRToolkit.utils.types.EvalResult] and
            [ModelResult][SRToolkit.utils.types.ModelResult].
            Default ``"R2"``.

    Raises:
        Exception: If ``evaluator.ranking_function != "rmse"`` or ``evaluator.y is None``.
    """
    super().__init__(name)

    if scope not in ["best", "top", "all"]:
        raise Exception(f"[R2 augmenter] Invalid scope: {scope}. Must be one of 'best', 'top', 'all'.")
    self.scope = scope

    self.evaluator = evaluator
    if self.evaluator.ranking_function != "rmse":
        raise Exception("[R2 augmenter] Ranking function of the evaluator must be set to 'rmse' to compute R^2.")
    if self.evaluator.y is None:
        raise Exception("[R2 augmenter] y in the evaluator must not be None to compute R^2.")
    self.ss_tot = np.sum((self.evaluator.y - np.mean(self.evaluator.y)) ** 2)
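
ss_tot above is the denominator of the coefficient of determination, R² = 1 - SS_res / SS_tot. The private _compute_r2 helper is not reproduced on this page; since the evaluator returns an RMSE and RMSE² = SS_res / n, a plausible reconstruction is:

def _compute_r2(self, model) -> float:
    # Hypothetical sketch of the private helper (not shown on this page).
    rmse = self.evaluator.evaluate_expr(model.expr)
    n = len(self.evaluator.y)
    return 1.0 - (n * rmse ** 2) / self.ss_tot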

write_results

write_results(results: EvalResult) -> None

Write R² scores into results and its models.

Stores {"best_expr_r^2": ...} in EvalResult augmentations and {"r^2": ..., "parameters_r^2": ...} in each model's augmentations when scope is "top" or "all".

Parameters:

  • results (EvalResult, required): The EvalResult to augment.
Source code in SRToolkit/evaluation/result_augmentation.py
def write_results(self, results: EvalResult) -> None:
    """
    Write R² scores into *results* and its models.

    Stores ``{"best_expr_r^2": ...}`` in
    [EvalResult][SRToolkit.utils.types.EvalResult] ``augmentations`` and
    ``{"r^2": ..., "parameters_r^2": ...}`` in each model's augmentations when ``scope``
    is ``"top"`` or ``"all"``.

    Args:
        results: The [EvalResult][SRToolkit.utils.types.EvalResult] to augment.
    """
    eval_data: Dict[str, Any] = {"best_expr_r^2": self._compute_r2(results.top_models[0])}
    results.add_augmentation(self.name, eval_data, self._type)

    if self.scope == "top" or self.scope == "all":
        for model in results.top_models:
            key = "".join(model.expr)
            top_model_data: Dict[str, Any] = {
                "r^2": self._compute_r2(model),
                "parameters_r^2": self.evaluator.models[key].parameters,
            }
            model.add_augmentation(self.name, top_model_data, self._type)

    if self.scope == "all":
        for model in results.all_models:
            key = "".join(model.expr)
            all_model_data: Dict[str, Any] = {
                "r^2": self._compute_r2(model),
                "parameters_r^2": self.evaluator.models[key].parameters,
            }
            model.add_augmentation(self.name, all_model_data, self._type)

format_eval_result classmethod

format_eval_result(data: Dict[str, Any]) -> str

Format experiment-level R² data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "best_expr_r^2".

Returns:

  str: A human-readable string, or empty string if no data is present.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_eval_result(cls, data: Dict[str, Any]) -> str:
    """
    Format experiment-level R² data for display.

    Args:
        data: Augmentation dict containing ``"best_expr_r^2"``.

    Returns:
        A human-readable string, or empty string if no data is present.
    """
    val = data.get("best_expr_r^2", "")
    return f"Test R²: {val}" if val != "" else ""

format_model_result classmethod

format_model_result(data: Dict[str, Any]) -> str

Format per-model R² data for display.

Parameters:

  • data (Dict[str, Any], required): Augmentation dict containing "r^2" and optionally "parameters_r^2".

Returns:

  str: A human-readable string with R² and fitted parameters.

Source code in SRToolkit/evaluation/result_augmentation.py
@classmethod
def format_model_result(cls, data: Dict[str, Any]) -> str:
    """
    Format per-model R² data for display.

    Args:
        data: Augmentation dict containing ``"r^2"`` and optionally ``"parameters_r^2"``.

    Returns:
        A human-readable string with R² and fitted parameters.
    """
    parts = [f"R²={data['r^2']:.4g}"]
    if "parameters_r^2" in data and data["parameters_r^2"] is not None:
        parts.append(f"params={np.round(data['parameters_r^2'], 4).tolist()}")
    return ", ".join(parts)

to_dict

to_dict(base_path: str, name: str) -> dict

Creates a dictionary representation of the R2 augmenter.

Parameters:

  • base_path (str, required): Used to save the data of the evaluator to disk.
  • name (str, required): Used to save the data of the evaluator to disk.

Returns:

  dict: A dictionary containing the necessary information to recreate the augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
def to_dict(self, base_path: str, name: str) -> dict:
    """
    Creates a dictionary representation of the R2 augmenter.

    Args:
        base_path: Used to save the data of the evaluator to disk.
        name: Used to save the data of the evaluator to disk.

    Returns:
        A dictionary containing the necessary information to recreate the augmenter.
    """
    return {
        "format_version": 1,
        "name": self.name,
        "type": "R2",
        "scope": self.scope,
        "evaluator": self.evaluator.to_dict(base_path, name + "_R2_augmenter"),
    }

from_dict staticmethod

from_dict(data: dict) -> R2

Creates an instance of the R2 augmenter from a dictionary.

Parameters:

  • data (dict, required): A dictionary containing the necessary information to recreate the augmenter.

Returns:

  R2: An instance of the R2 augmenter.

Source code in SRToolkit/evaluation/result_augmentation.py
@staticmethod
def from_dict(data: dict) -> "R2":
    """
    Creates an instance of the R2 augmenter from a dictionary.

    Args:
        data: A dictionary containing the necessary information to recreate the augmenter.

    Returns:
        An instance of the R2 augmenter.
    """
    if data.get("format_version", 1) != 1:
        raise ValueError(f"[R2.from_dict] Unsupported format_version: {data.get('format_version')!r}. Expected 1.")
    evaluator = SR_evaluator.from_dict(data["evaluator"])
    return R2(evaluator, scope=data["scope"], name=data["name"])