from __future__ import annotations
import ast
from typing import NamedTuple
from flowrep import edge_models
from flowrep.nodes import for_recipe, helper_models
from flowrep.parsers import parser_protocol, symbol_scope
#: Label of the single child node that holds the for-loop body.
FOR_BODY_LABEL: str = "body"


class _IterationAxis(NamedTuple):
    """One iteration axis of a for-header, e.g. ``x`` and ``xs`` in ``for x in xs``."""

    # Name bound by the loop target (``x`` in ``for x in xs``).
    variable: str
    # Name of the iterated-over symbol (``xs`` in ``for x in xs``).
    collection: str
#: Maps each accumulator name (``xs``) to the symbol name appended to it (``x``),
#: as in statements like ``xs.append(x)``.
AccumulatorMap = dict[str, str]
def parse_for_node(
    walker: parser_protocol.BodyWalker, tree: ast.For
) -> for_recipe.ForEachRecipe:
    """
    Walk a for-loop and build the corresponding for-each recipe.

    The iteration header(s) are parsed first. Then the loop body is walked with a
    forked walker whose symbol map replaces each iterated-over symbol with its
    iteration variable. The walked body becomes a single labeled child node
    (``FOR_BODY_LABEL``), and the for-node's outputs are the accumulators the body
    appends to.

    Args:
        walker: A walker to fork and use for collecting state inside the tree.
        tree: The top-level ``ast.For`` node (may contain immediately
            nested for-headers that declare additional iteration axes).

    Returns:
        A ``ForEachRecipe`` wiring broadcast and iterated inputs into the body
        node, and the accumulated values out to the for-node's outputs.

    Raises:
        ValueError: If an iteration header is unsupported, if the body uses no
            accumulator, if an iteration variable is never used, or if the body
            reassigns a symbol from the enclosing scope.
    """
    # Parse the iteration header — pure AST, no parser state needed
    nested_iters, zipped_iters, body_tree = _parse_for_iterations(tree)
    all_iters = nested_iters + zipped_iters

    # When we fork the scope here, we replace iterated-over symbols with iteration
    # variables, all as InputSources from the body's perspective
    body_symbol_map = walker.symbol_map.fork(
        {src: var for var, src in all_iters},
        available_accumulators=walker.symbol_map.declared_accumulators.copy(),
    )
    body_walker = walker.fork(new_symbol_map=body_symbol_map, new_scope=walker.scope)
    body_walker.walk(body_tree.body)

    consumed = body_walker.symbol_map.consumed_accumulators
    _validate_some_output_exists(consumed)
    _validate_no_unused_iterators(all_iters, body_walker, consumed)
    _validate_no_leaked_reassignments(
        all_iters, body_walker, consumed, walker.symbol_map
    )

    nested_ports = [var for var, _ in nested_iters]
    zipped_ports = [var for var, _ in zipped_iters]
    inputs, input_edges = _wire_inputs(body_walker, all_iters)
    outputs, output_edges = _wire_outputs(body_walker, input_edges)

    body_node = helper_models.LabeledRecipe(
        label=FOR_BODY_LABEL, node=body_walker.build_model()
    )
    return for_recipe.ForEachRecipe(
        inputs=inputs,
        outputs=outputs,
        body_node=body_node,
        input_edges=input_edges,
        output_edges=output_edges,
        nested_ports=nested_ports,
        zipped_ports=zipped_ports,
    )
def _validate_some_output_exists(consumed: AccumulatorMap) -> None:
    """
    Ensure the loop body appends to at least one accumulator.

    Accumulators are the only outputs a for-node may have, so a body that
    consumes none of them would yield a node without outputs.

    Args:
        consumed: Accumulator names mapped to the symbol appended to each.

    Raises:
        ValueError: If ``consumed`` is empty.
    """
    if not consumed:
        raise ValueError("For nodes must use up at least one accumulator symbol.")
def _validate_no_unused_iterators(
    all_iters: list[_IterationAxis],
    body_walker: parser_protocol.BodyWalker,
    consumed: AccumulatorMap,
):
    """
    Reject iteration variables that the loop body never touches.

    A variable counts as used when it is one of the body's inputs or is
    appended directly to an accumulator. An untouched iterator most likely
    signals a mistake; if only the structural effect (e.g. the repetition
    count) is wanted, the dependency should be made explicit.

    Raises:
        ValueError: Listing the unused variable names in sorted order.
    """
    used_symbols = {*body_walker.inputs, *consumed.values()}
    unused = sorted(
        {loop_var for loop_var, _ in all_iters if loop_var not in used_symbols}
    )
    if not unused:
        return
    raise ValueError(
        f"For-node iteration variable(s) {unused} are never "
        f"used inside the node body. Either use them or remove them "
        f"from the iteration header."
    )
def _validate_no_leaked_reassignments(
    all_iters: list[_IterationAxis],
    body_walker: parser_protocol.BodyWalker,
    consumed: AccumulatorMap,
    symbol_map: symbol_scope.SymbolScope,
):
    """
    Reject body reassignments of enclosing-scope symbols that would go uncaptured.

    A for-node only exposes accumulated (iterated) outputs, so reassigning an
    outer symbol inside the body is an error. The exception is when that symbol
    is a consumed accumulator or an iteration variable.

    Args:
        all_iters: Iteration axes of the loop header.
        body_walker: Walker that has already walked the loop body.
        consumed: Accumulator names mapped to the symbol appended to each.
        symbol_map: Symbol map of the enclosing scope.

    Raises:
        ValueError: Listing the leaked symbol names in sorted order.
    """
    allowed = set(consumed) | {loop_var for loop_var, _ in all_iters}
    outer_symbols = set(symbol_map.keys())
    leaked = sorted(
        name
        for name in set(body_walker.symbol_map.reassigned_symbols)
        if name not in allowed and name in outer_symbols
    )
    if not leaked:
        return
    raise ValueError(
        f"For-loop body reassigns symbol(s) {leaked} "
        f"from the enclosing scope. This is not supported because for-node "
        f"outputs are determined by accumulators. If you need the reassigned "
        f"value after the loop, accumulate it explicitly."
    )
def _wire_inputs(
    body_walker: parser_protocol.BodyWalker, all_iters: list[_IterationAxis]
) -> tuple[list[str], edge_models.InputEdges]:
    """
    Build the for-node's input ports and the edges feeding the body node.

    Broadcast symbols are the body's inputs that are neither appended to an
    accumulator nor iteration variables. They keep the order of
    ``body_walker.inputs`` and are passed through under the same port name.
    Scattered symbols are the iterated-over collections. Each one feeds the body
    port named after its iteration variable.

    Args:
        body_walker: Walker that has already walked the loop body.
        all_iters: Iteration axes of the loop header.

    Returns:
        ``(inputs, input_edges)``: the for-node input port names (broadcast first,
        then scattered) and the mapping from body target handles to for-node
        input sources.
    """
    consumed = body_walker.symbol_map.consumed_accumulators
    # Computed once up front rather than once per body input.
    appended_symbols = set(consumed.values())
    iterating_symbols = {iterating_symbol for iterating_symbol, _ in all_iters}
    # Need to keep it consistently ordered, so don't use a simple set op
    broadcast_symbols = [
        s
        for s in body_walker.inputs
        if s not in appended_symbols and s not in iterating_symbols
    ]
    scattered_symbols = [scattered_symbol for _, scattered_symbol in all_iters]
    inputs = broadcast_symbols + scattered_symbols

    broadcast_inputs = {
        edge_models.TargetHandle(
            node=FOR_BODY_LABEL, port=port
        ): edge_models.InputSource(port=port)
        for port in broadcast_symbols
    }
    scattered_inputs = {
        edge_models.TargetHandle(
            node=FOR_BODY_LABEL, port=body_port
        ): edge_models.InputSource(port=for_port)
        for body_port, for_port in all_iters
    }
    input_edges = broadcast_inputs | scattered_inputs
    return inputs, input_edges
def _wire_outputs(
    body_walker: parser_protocol.BodyWalker, input_edges: edge_models.InputEdges
) -> tuple[list[str], edge_models.OutputEdges]:
    """
    Connect each consumed accumulator to a for-node output port of the same name.

    When the appended symbol is one of the body's outputs, the port is fed from
    that body output. Otherwise the port reuses the source already wired to the
    body input of that name in ``input_edges``; a ``KeyError`` propagates if
    there is none.

    Returns:
        ``(outputs, output_edges)``: the accumulator names, in the order given by
        the consumed-accumulator map, and the output edge mapping.
    """
    accumulators = body_walker.symbol_map.consumed_accumulators
    output_ports = list(accumulators)
    edges: edge_models.OutputEdges = {}
    for accumulator_name, appended in accumulators.items():
        port = edge_models.OutputTarget(port=accumulator_name)
        if appended not in body_walker.outputs:
            # Not produced by the body: forward whatever feeds that body input.
            edges[port] = input_edges[
                edge_models.TargetHandle(node=FOR_BODY_LABEL, port=appended)
            ]
            continue
        edges[port] = edge_models.SourceHandle(node=FOR_BODY_LABEL, port=appended)
    return output_ports, edges
def _parse_for_iterations(
    for_stmt: ast.For,
) -> tuple[list[_IterationAxis], list[_IterationAxis], ast.For]:
    """
    Parse for-node iteration structure, handling zip and immediately nested iterations.

    Headers are followed downward only while a ``for`` body consists of exactly
    one statement that is itself a ``for``. For example::

        for a in as_:
            for b, c in zip(bs, cs):
                ...

    declares the nested axis ``a`` and the zipped axes ``b`` and ``c``. If an
    inner ``for`` has sibling statements, the header chain stops there. That
    inner loop then remains part of the body, so its siblings are not dropped.

    Args:
        for_stmt: The outermost ``ast.For`` node.

    Returns:
        ``(nested_iterations, zipped_iterations, innermost_for)``. The first two
        are lists of ``_IterationAxis(variable, collection)``, in header order,
        from non-zip and ``zip(...)`` headers respectively. The last is the
        innermost ``ast.For``, whose ``body`` is the actual loop body.

    Raises:
        ValueError: If a header is rejected by ``_parse_single_for_header`` or
            uses a ``for ... else`` clause.
    """
    nested: list[_IterationAxis] = []
    zipped: list[_IterationAxis] = []
    current = for_stmt
    while isinstance(current, ast.For):
        is_zip, pairs = _parse_single_for_header(current)
        if is_zip:
            zipped.extend(pairs)
        else:
            nested.extend(pairs)
        if current.orelse:
            # Only the innermost ``body`` is ever walked, so an else-clause
            # would otherwise be silently discarded.
            raise ValueError("For nodes do not support a 'for ... else' clause.")
        # Descend only when the body is a single statement that is another For;
        # otherwise sibling statements after the inner loop would be lost.
        if len(current.body) == 1 and isinstance(current.body[0], ast.For):
            current = current.body[0]
        else:
            break
    return nested, zipped, current
def _parse_single_for_header(
for_stmt: ast.For,
) -> tuple[bool, list[_IterationAxis]]:
"""
Parse a single for-header.
Returns (is_zipped, [(var, source), ...]).
"""
iter_expr = for_stmt.iter
target = for_stmt.target
# Check for zip()
if isinstance(iter_expr, ast.Call) and _is_zip_call(iter_expr):
if not isinstance(target, ast.Tuple):
raise ValueError("zip() iteration requires tuple unpacking")
vars_list = [elt.id for elt in target.elts if isinstance(elt, ast.Name)]
if len(vars_list) != len(target.elts):
raise ValueError("zip() iteration targets must be simple names")
sources = []
for arg in iter_expr.args:
if not isinstance(arg, ast.Name):
raise ValueError("zip() arguments must be simple symbols")
sources.append(arg.id)
if len(vars_list) != len(sources):
raise ValueError(
f"zip() variable count ({len(vars_list)}) must match "
f"argument count ({len(sources)})"
)
return True, [
_IterationAxis(v, s) for v, s in zip(vars_list, sources, strict=True)
]
# Simple iteration: for x in xs
if not isinstance(iter_expr, ast.Name):
raise ValueError(
"For iteration must iterate over a symbol (not an inline expression)"
)
if isinstance(target, ast.Name):
return False, [_IterationAxis(target.id, iter_expr.id)]
elif isinstance(target, ast.Tuple):
# for a, b in items (tuple unpacking without zip)
raise ValueError(
"Tuple unpacking in for-nodes requires zip(). "
"Use 'for a, b in zip(as, bs):' instead of 'for a, b in items:'"
)
else:
raise ValueError(f"Unsupported for iteration target: {type(target)}")
def _is_zip_call(node: ast.Call) -> bool:
"""Check if a Call node is a call to zip()."""
return isinstance(node.func, ast.Name) and node.func.id == "zip"