from __future__ import annotations
import ast
import dataclasses
from pyiron_snippets import versions
from flowrep import edge_models, subgraph_validation
from flowrep.nodes import helper_models
from flowrep.parsers import (
atomic_parser,
object_scope,
parser_helpers,
parser_protocol,
symbol_scope,
)
[docs]
def parse_case(
    test: ast.expr,
    scope: object_scope.ScopeProxy,
    symbol_map: symbol_scope.SymbolScope,
    info_factory: versions.VersionInfoFactory,
    label: str,
) -> tuple[helper_models.LabeledRecipe, edge_models.InputEdges]:
    """
    Turn the test expression of a conditional into a labeled condition node.

    The expression must be a function call, and the recipe built for that call
    must expose exactly one output; anything else is rejected.

    Args:
        test: The AST expression used as the branch condition.
        scope: Object scope forwarded to the atomic parser.
        symbol_map: Symbol scope; only a fork of it is used here to consume
            the call's arguments.
        info_factory: Version-info factory forwarded to the atomic parser.
        label: Label given to the returned node and used as the target node
            of the returned input edges.

    Returns:
        The condition recipe relabeled to ``label``, together with the input
        edges needed to feed it (retargeted at ``label``).

    Raises:
        ValueError: If ``test`` is not an :class:`ast.Call`, or if the parsed
            condition does not have exactly one output.
    """
    if not isinstance(test, ast.Call):
        found = type(test).__name__
        raise ValueError(f"Test conditions must be a function call, but got {found}")

    labeled_condition = atomic_parser.get_labeled_recipe(
        test, set(), scope, info_factory
    )
    condition_outputs = labeled_condition.node.outputs
    if len(condition_outputs) != 1:
        raise ValueError(
            f"If/elif condition must return exactly one value (and it had better be "
            f"truthy), but got {condition_outputs}"
        )

    # The call arguments are consumed against a fork of the symbol map, and the
    # input edges are read back from that same fork.
    forked_symbols = symbol_map.fork()
    parser_helpers.consume_call_arguments(forked_symbols, test, labeled_condition)
    return _relabel_node_data(labeled_condition, forked_symbols.input_edges, label)
def _relabel_node_data(
    labeled_node: helper_models.LabeledRecipe,
    inputs: edge_models.InputEdges,
    new_label: str,
) -> tuple[helper_models.LabeledRecipe, edge_models.InputEdges]:
    """
    Re-issue a labeled recipe and its input edges under a different label.

    Args:
        labeled_node: The recipe whose ``node`` is carried over unchanged.
        inputs: Input edges keyed by target handle; only each target's
            ``port`` is kept, its node is replaced by ``new_label``.
        new_label: Label for the new recipe and for every edge target.

    Returns:
        The relabeled recipe and a new edge mapping pointing at it.
    """
    renamed = helper_models.LabeledRecipe(label=new_label, node=labeled_node.node)

    retargeted: edge_models.InputEdges = {}
    for old_target, source in inputs.items():
        handle = edge_models.TargetHandle(node=new_label, port=old_target.port)
        retargeted[handle] = source

    return renamed, retargeted
[docs]
@dataclasses.dataclass
class WalkedBranch:
    """
    A branch body together with the walker that processed it.

    Instances are created by :func:`walk_branch` and consumed by
    :func:`wire_outputs`.
    """

    # Branch label; wire_outputs uses it as the source node name of the
    # branch's output edges.
    label: str
    # Walker holding the branch's state; its build_model() supplies the node.
    walker: parser_protocol.BodyWalker
    # Symbols collected for the branch (walk_branch fills this from the forked
    # symbol scope's ``assigned_symbols``).
    assigned: list[str]

    def to_labeled_node(self) -> helper_models.LabeledRecipe:
        """Build the walker's model and pair it with this branch's label."""
        return helper_models.LabeledRecipe(
            label=self.label,
            node=self.walker.build_model(),
        )
[docs]
def walk_branch(
    walker: parser_protocol.BodyWalker,
    label: str,
    stmts: list[ast.stmt],
) -> WalkedBranch:
    """
    Walk the statements of one conditional branch with a forked walker.

    Forks are taken of both the walker's :class:`SymbolScope` and its
    :class:`ScopeProxy`, so that symbol assignments *and* import-based scope
    extensions made inside this branch do not leak into sibling or parent
    branches.

    Args:
        walker: The parent walker; it is forked, not walked directly.
        label: Label recorded on the returned :class:`WalkedBranch`.
        stmts: The statements making up the branch body.

    Returns:
        The branch label, the forked walker after walking ``stmts``, and the
        symbols the forked symbol scope reports as assigned.
    """
    forked_symbols = walker.symbol_map.fork()
    child = walker.fork(
        new_symbol_map=forked_symbols,
        new_scope=walker.scope.fork(),
    )
    child.walk(stmts)

    # Everything assigned in the branch is handed back to the forked scope via
    # produce_symbols (presumably marking them as branch outputs -- confirm
    # against SymbolScope).
    branch_symbols = forked_symbols.assigned_symbols
    forked_symbols.produce_symbols(branch_symbols)
    return WalkedBranch(label=label, walker=child, assigned=branch_symbols)
[docs]
def wire_outputs(
    branches: list[WalkedBranch],
) -> tuple[list[str], subgraph_validation.ProspectiveOutputEdges]:
    """
    Collect output names and prospective output edges from walked branches.

    Args:
        branches: The walked branches, in the order they should be considered.

    Returns:
        A pair ``(outputs, edges)``. ``outputs`` is the union of every
        branch's ``assigned`` symbols in first-seen order. ``edges`` maps an
        output target for each of those names to the list of source handles
        (one per branch that assigned the name, in branch order).
    """
    # dict.fromkeys keeps insertion order and ignores repeats, which yields
    # the first-seen ordered union across all branches.
    outputs = list(
        dict.fromkeys(sym for branch in branches for sym in branch.assigned)
    )

    # Each output is fed by every branch body node that assigned it.
    prospective_output_edges: subgraph_validation.ProspectiveOutputEdges = {
        edge_models.OutputTarget(port=name): [
            edge_models.SourceHandle(node=branch.label, port=name)
            for branch in branches
            if name in branch.assigned
        ]
        for name in outputs
    }
    return outputs, prospective_output_edges