from __future__ import annotations
import ast
import importlib
import inspect
from collections.abc import Callable, Collection
from types import FunctionType
from typing import cast
from pyiron_snippets import versions
from flowrep import base_models, edge_models
from flowrep.nodes import helper_models, union_types, workflow_recipe
from flowrep.parsers import (
atomic_parser,
for_parser,
if_parser,
label_helpers,
object_scope,
parser_helpers,
parser_protocol,
symbol_scope,
try_parser,
while_parser,
)
[docs]
def workflow(
    func: FunctionType | str | None = None,
    /,
    *output_labels: str,
    version_scraping: versions.VersionScrapingMap | None = None,
    forbid_main: bool = False,
    forbid_locals: bool = False,
    require_version: bool = False,
) -> FunctionType | Callable[[FunctionType], FunctionType]:
    """
    Decorator that attaches a
    :class:`~flowrep.nodes.workflow_recipe.WorkflowRecipe` to the
    ``flowrep_recipe`` attribute of a function, provided the function body can be
    parsed as a workflow recipe.

    This is a thin wrapper: the decorator mechanics live in
    ``parser_helpers.parser2decorator`` and the actual parsing is done by
    :func:`parse_workflow`, which receives all keyword options below unchanged.

    Can be used with or without arguments.

    Args:
        func: The function to decorate. Passed positionally by Python when the
            decorator is used without parentheses. The annotation also admits a
            string -- presumably the first output label when the decorator is
            called with arguments; confirm against
            ``parser_helpers.parser2decorator``.
        *output_labels: Explicit names for the workflow's output ports. When
            provided, their count must match the number of returned symbols.
        version_scraping: Optional mapping from top-level package names to
            callables that return a version string.
        forbid_main: If ``True``, raise if the function's module is ``__main__``.
        forbid_locals: If ``True``, raise if the function's qualname contains
            ``<locals>``.
        require_version: If ``True``, raise if no version can be determined for
            the function's package.

    Returns:
        The original function with a ``flowrep_recipe`` attribute holding a
        :class:`~flowrep.nodes.workflow_recipe.WorkflowRecipe`, or -- when used
        with arguments -- a decorator producing that.
    """
    # Options that are not consumed here but handed on verbatim to the parser.
    provenance_options = dict(
        version_scraping=version_scraping,
        forbid_main=forbid_main,
        forbid_locals=forbid_locals,
        require_version=require_version,
    )
    return parser_helpers.parser2decorator(
        func,
        output_labels,
        parser=parse_workflow,
        decorator_name="@workflow",
        parser_kwargs=provenance_options,
    )
[docs]
def parse_workflow(
    func: FunctionType,
    *output_labels: str,
    version_scraping: versions.VersionScrapingMap | None = None,
    forbid_main: bool = False,
    forbid_locals: bool = False,
    require_version: bool = False,
):
    """
    Build a :class:`~flowrep.nodes.workflow_recipe.WorkflowRecipe` by statically
    analysing a Python function's AST.

    The function body (minus a leading docstring) is walked statement by
    statement. Assignments whose right-hand side is a call become child nodes,
    empty-list assignments declare accumulators, and the supported control-flow
    structures (``for``, ``while``, ``if``, ``try``) are handed to their dedicated
    parsers. A single top-level ``return`` statement defines the workflow's output
    ports.

    Args:
        func: The function to parse into a workflow graph.
        *output_labels: Explicit output port names. When provided, their count
            must match the number of returned symbols.
        version_scraping: Optional version-scraping overrides, forwarded to
            ``versions.VersionInfoFactory``.
        forbid_main: If ``True``, raise if the function's module is ``__main__``.
        forbid_locals: If ``True``, raise if the function's qualname contains
            ``<locals>``.
        require_version: If ``True``, raise if no version can be determined.

    Returns:
        The recipe built from the walked body. Its inputs are the names from the
        function signature and its description is the function's docstring.

    Raises:
        ValueError: If the function has no return, more than one return, returns
            duplicate or undefined symbols, or if the number of *output_labels*
            does not match the number of returned symbols. The ``forbid_*`` /
            ``require_*`` constraints are enforced by the version-info factory.
        TypeError: If the function body contains unsupported AST statement types.
    """
    factory = versions.VersionInfoFactory(
        version_scraping=version_scraping,
        forbid_main=forbid_main,
        forbid_locals=forbid_locals,
        require_version=require_version,
    )
    # Provenance is resolved first so constraint violations surface before any
    # signature or AST inspection happens.
    provenance = factory.of(func)
    signature = parser_helpers.SignatureInfo.of(func)
    input_names = signature.names

    reference = base_models.PythonReference(
        info=provenance,
        inputs_with_defaults=signature.have_defaults,
        restricted_input_kinds=signature.have_restricted_kinds,
    )

    scope = object_scope.get_scope(func)
    # Every signature parameter starts out as a workflow-level input source.
    input_sources = {
        name: edge_models.InputSource(port=name) for name in input_names
    }
    parser = _WorkflowFunctionParser(
        scope,
        symbol_scope.SymbolScope(input_sources),
        source=reference,
        info_factory=factory,
        func=func,
        output_labels=output_labels,
    )

    function_node = parser_helpers.get_ast_function_node(func)
    parser.walk(skip_docstring(function_node.body))

    if parser.found_return:
        return parser.build_model(
            inputs_override=input_names,
            description=inspect.getdoc(func),
        )
    raise ValueError("Workflow python definitions must have a return statement.")
[docs]
def skip_docstring(body: list[ast.stmt]) -> list[ast.stmt]:
    """
    Return *body* without its leading docstring statement, if it has one.

    A docstring is recognised as a first statement that is a bare expression
    holding a string constant. When there is none (or *body* is empty), the very
    same list object is returned; otherwise a new list without the first
    statement is returned.
    """
    if not body:
        return body

    first = body[0]
    if not isinstance(first, ast.Expr):
        return body

    literal = first.value
    if isinstance(literal, ast.Constant) and isinstance(literal.value, str):
        return body[1:]
    return body
[docs]
class WorkflowParser(ast.NodeVisitor, parser_protocol.BodyWalker):
    """
    Aggregates state until there is enough data to successfully build the pydantic
    data model (see :meth:`build_model`).

    Treatment for the different ast statements lives in the ``visit_*`` methods
    (dispatched by :meth:`ast.NodeVisitor.visit`) and their private ``_handle_*``
    helpers; the aim is to keep all state mutation of _this object_ directly in
    those methods. Statements without a dedicated visitor end up in
    :meth:`generic_visit`, which raises.

    Other callers reference these methods as they walk through some ast tree,
    e.g. to build a top-level workflow from a function definition
    (`ast.FunctionDef`), or to dynamically build a workflow from the body of some
    control flow.
    """

    def __init__(
        self,
        scope: object_scope.ScopeProxy,
        symbol_map: symbol_scope.SymbolScope,
        info_factory: versions.VersionInfoFactory,
        source: base_models.PythonReference | None = None,
    ):
        """
        Args:
            scope: Name-to-object lookup used to resolve called objects; imports
                found while walking are registered on it.
            symbol_map: Tracks which symbols are available and collects the
                resulting inputs, outputs and edges.
            info_factory: Version-info factory handed on to child-node parsing and
                to forked walkers.
            source: Optional python reference stored on the built recipe.
        """
        self.scope = scope
        self.symbol_map = symbol_map
        self.info_factory = info_factory
        # Child recipes keyed by their unique label.
        self.nodes: union_types.Recipes = {}
        self.source = source

    # The graph bookkeeping lives on the symbol map; these are read-only views.
    @property
    def inputs(self) -> list[str]:
        return self.symbol_map.inputs

    @property
    def input_edges(self) -> edge_models.InputEdges:
        return self.symbol_map.input_edges

    @property
    def edges(self) -> edge_models.Edges:
        return self.symbol_map.edges

    @property
    def output_edges(self) -> edge_models.OutputEdges:
        return self.symbol_map.output_edges

    @property
    def outputs(self) -> list[str]:
        return self.symbol_map.outputs

    def build_model(
        self,
        inputs_override: list[str] | None = None,
        description: str | None = None,
    ) -> workflow_recipe.WorkflowRecipe:
        """
        Assemble the recipe from the state collected so far.

        Args:
            inputs_override: Input names to use instead of the ones tracked by the
                symbol map. ``None`` keeps the tracked inputs.
            description: Optional free-text description stored on the recipe.
        """
        return workflow_recipe.WorkflowRecipe(
            inputs=self.inputs if inputs_override is None else inputs_override,
            outputs=self.outputs,
            description=description,
            nodes=self.nodes,
            input_edges=self.input_edges,
            edges=self.edges,
            output_edges=self.output_edges,
            reference=self.source,
        )

    def fork(
        self,
        *,
        new_symbol_map: symbol_scope.SymbolScope,
        new_scope: object_scope.ScopeProxy | None = None,
    ) -> WorkflowParser:
        """Create a child walker with a new symbol map and, optionally, a new scope.

        Configuration (version scraping, constraints, etc.) is propagated
        from this walker through the shared ``info_factory``. If *new_scope* is
        ``None``, ``self.scope`` is reused (shared, not copied).

        The child is always a plain :class:`WorkflowParser` that starts without
        nodes and without a ``source``, regardless of the type of ``self``.
        """
        return WorkflowParser(
            scope=self.scope if new_scope is None else new_scope,
            symbol_map=new_symbol_map,
            info_factory=self.info_factory,
        )

    def walk(self, statements: list[ast.stmt]) -> None:
        """Visit each statement in order, mutating this walker's state."""
        for statement in statements:
            self.visit(statement)

    def visit_Assign(self, stmt: ast.Assign) -> None:
        self._handle_assign(stmt)

    def visit_AnnAssign(self, stmt: ast.AnnAssign) -> None:
        self._handle_assign(stmt)

    def _handle_assign(self, body: ast.Assign | ast.AnnAssign) -> None:
        """
        Turn an assignment into a child node or an accumulator declaration.

        Raises:
            ValueError: For chained assignments, for an empty-list assignment to
                more than one symbol, and for any right-hand side that is neither
                a call nor an empty list.
        """
        # Get returned symbols from the left-hand side
        if isinstance(body, ast.Assign):
            # `a = b = f(x)` has several targets; reading only the first would
            # silently drop the others, so refuse it explicitly.
            if len(body.targets) != 1:
                raise ValueError(
                    f"Workflow python definitions cannot interpret chained "
                    f"assignments (e.g. `a = b = f(x)`), but got an assignment "
                    f"with {len(body.targets)} targets."
                )
            lhs = body.targets[0]
        else:
            lhs = body.target
        new_symbols = parser_helpers.resolve_symbols_to_strings(lhs)

        rhs = body.value
        if isinstance(rhs, ast.Call):
            child = atomic_parser.get_labeled_recipe(
                rhs,
                self.nodes.keys(),
                self.scope,
                self.info_factory,
            )
            self.nodes[child.label] = child.node
            parser_helpers.consume_call_arguments(self.symbol_map, rhs, child)
            self.symbol_map.register(new_symbols, child)
        elif isinstance(rhs, ast.List) and len(rhs.elts) == 0:
            if len(new_symbols) != 1:
                raise ValueError(
                    f"Empty list assignment must target exactly one symbol, "
                    f"got {new_symbols}"
                )
            self.symbol_map.register_accumulator(new_symbols[0])
        else:
            raise ValueError(
                f"Workflow python definitions can only interpret assignments with "
                f"a call or empty list on the right-hand-side, but ast found "
                f"{type(rhs)}"
            )

    def _digest_flow_control(
        self, label_prefix: str, node: union_types.RecipeDiscrimination
    ) -> None:
        """Store a control-flow recipe under a unique label and wire it up."""
        label = label_helpers.unique_suffix(label_prefix, self.nodes)
        self.nodes[label] = node
        self._connect_node_to_enclosing_scope(label, node)

    def _connect_node_to_enclosing_scope(
        self, label: str, node: union_types.RecipeDiscrimination
    ) -> None:
        """Consume the node's inputs from, and register its outputs on, the symbol map."""
        # Node ports carry the same names as the enclosing-scope symbols.
        for port in node.inputs:
            self.symbol_map.consume(port, label, port)
        labeled_node = helper_models.LabeledRecipe(label=label, node=node)
        self.symbol_map.register(new_symbols=node.outputs, child=labeled_node)

    def visit_For(self, tree: ast.For) -> None:
        for_recipe = for_parser.parse_for_node(self, tree)
        # Accumulators consumed by the for body are no longer available here
        self.symbol_map.declared_accumulators -= set(for_recipe.outputs)
        self._digest_flow_control("for_each", for_recipe)

    def visit_While(self, tree: ast.While) -> None:
        while_recipe = while_parser.parse_while_node(self, tree)
        self._digest_flow_control("while", while_recipe)

    def visit_If(self, tree: ast.If) -> None:
        if_recipe = if_parser.parse_if_node(self, tree)
        self._digest_flow_control("if", if_recipe)

    def visit_Try(self, tree: ast.Try) -> None:
        try_recipe = try_parser.parse_try_node(self, tree)
        self._digest_flow_control("try", try_recipe)

    def visit_Expr(self, stmt: ast.Expr) -> None:
        """Accept ``<name>.append(...)`` statements; reject every other bare expression."""
        if is_append_call(stmt.value):
            self._handle_appending_to_accumulator(cast(ast.Call, stmt.value))
        else:
            self.generic_visit(stmt)

    def visit_Import(self, node: ast.Import) -> None:
        """
        Handle ``import foo`` and ``import foo as bar`` statements.

        Resolves the imported module and registers it in the current
        :class:`ScopeProxy` so that subsequent attribute-based calls
        (e.g. ``foo.func(x)``) can be resolved.
        """
        for alias in node.names:
            # Import the full dotted path in both branches so that submodules are
            # actually loaded, exactly as the import statement would do.
            module = importlib.import_module(alias.name)
            if alias.asname is not None:
                # import numpy as np → register "np" → numpy module
                self.scope.register(alias.asname, module)
            else:
                # import os.path → register "os" → os module (top-level only)
                top_level_name = alias.name.split(".")[0]
                top_level_module = importlib.import_module(top_level_name)
                self.scope.register(top_level_name, top_level_module)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """
        Handle ``from foo import bar`` and ``from foo import bar as baz``.

        Resolves each imported name and registers it in the current scope. As in
        the regular import system, a name that is not an attribute of the module
        is retried as the submodule ``foo.bar``.

        Raises:
            ValueError: For relative imports.
            AttributeError: If a name is neither an attribute nor an importable
                submodule of the module.
        """
        if node.module is None or node.level > 0:
            raise ValueError(
                f"Relative imports are not supported in workflow definitions. "
                f"Encountered importing from {node.module}."
            )
        module = importlib.import_module(node.module)
        for alias in node.names:
            try:
                obj = getattr(module, alias.name)
            except AttributeError as attribute_error:
                # `from pkg import sub` is legal even if `pkg` has not loaded
                # `sub` yet; fall back to importing the submodule.
                try:
                    obj = importlib.import_module(f"{node.module}.{alias.name}")
                except ImportError:
                    # Keep the original failure type for existing handlers.
                    raise attribute_error from None
            local_name = alias.asname if alias.asname is not None else alias.name
            self.scope.register(local_name, obj)

    def _handle_appending_to_accumulator(self, append_call: ast.Call) -> None:
        """
        Record ``<accumulator>.append(<symbol>)`` on the symbol map.

        Callers are expected to have checked the call with ``is_append_call``, so
        the callee is known to be ``<name>.append``.

        Raises:
            ValueError: If the call does not have exactly one positional argument
                that is a plain name.
        """
        if len(append_call.args) != 1 or append_call.keywords:
            raise ValueError(
                "Workflow python definitions can only append exactly one symbol "
                "to an accumulator, e.g. `acc.append(x)`."
            )
        appended = append_call.args[0]
        if not isinstance(appended, ast.Name):
            raise ValueError(
                f"Workflow python definitions can only append a plain symbol to "
                f"an accumulator, e.g. `acc.append(x)`, but ast found "
                f"{type(appended)}"
            )

        used_accumulator = cast(
            ast.Name, cast(ast.Attribute, append_call.func).value
        ).id
        appended_symbol = appended.id
        self.symbol_map.use_accumulator(used_accumulator, appended_symbol)
        appended_source = self.symbol_map[appended_symbol]
        if isinstance(appended_source, edge_models.SourceHandle):
            self.symbol_map.produce(appended_symbol)

    def generic_visit(self, stmt: ast.AST) -> None:
        """Reject every node type that has no dedicated ``visit_*`` method."""
        raise TypeError(
            f"Workflow python definitions can only interpret a subset of assignments, "
            f"and flow controls (for/while/if/try) and (when parsing a function "
            f"definition) a return, but ast found "
            f"{type(stmt)}"
        )
class _WorkflowFunctionParser(WorkflowParser):
    """
    Walker for the top-level body of a function definition.

    Extends :class:`WorkflowParser` with handling of the ``return`` statement,
    which maps returned symbols onto the workflow's output ports and may occur
    only once.
    """

    def __init__(
        self,
        scope: object_scope.ScopeProxy,
        symbol_map: symbol_scope.SymbolScope,
        info_factory: versions.VersionInfoFactory,
        *,
        source: base_models.PythonReference | None = None,
        func: FunctionType,
        output_labels: Collection[str],
    ):
        """
        Args:
            scope: Passed through to :class:`WorkflowParser`.
            symbol_map: Passed through to :class:`WorkflowParser`.
            info_factory: Passed through to :class:`WorkflowParser`.
            source: Passed through to :class:`WorkflowParser`.
            func: The function being parsed; used to scrape annotated output
                labels when the return statement is handled.
            output_labels: Explicit output port names; may be empty.
        """
        super().__init__(scope, symbol_map, info_factory, source=source)
        self._found_return = False
        self._output_labels = output_labels
        self._func = func

    @property
    def found_return(self) -> bool:
        """Whether a return statement has been visited so far."""
        return self._found_return

    def visit_Return(self, stmt: ast.Return) -> None:
        """Handle the first return statement; raise on any further one."""
        if not self._found_return:
            # Flag first, so the state is consistent even if handling raises.
            self._found_return = True
            self.handle_return(stmt, self._func, self._output_labels)
            return
        raise ValueError(
            "Workflow python definitions must have exactly one return."
        )

    def handle_return(
        self,
        body: ast.Return,
        func: FunctionType,
        output_labels: Collection[str],
    ) -> None:
        """
        Map the returned symbols onto output ports of the workflow.

        Port names are the explicit *output_labels* when given; otherwise they
        come from merging the labels annotated on *func* with the returned symbol
        names as fallback.

        Raises:
            ValueError: If returned symbols are not unique, if explicit labels do
                not match the returned symbols in number, or if a returned symbol
                is not known to the symbol map.
        """
        symbols = parser_helpers.resolve_symbols_to_strings(body.value)
        base_models.validate_unique(
            symbols,
            message=f"Workflow python definitions must have unique returns, but "
            f"got duplicates in: {symbols}",
        )

        # Scraped labels are computed even when explicit labels are supplied, so
        # an annotation/return mismatch is reported in either case.
        scraped = label_helpers.merge_labels(
            first_choice=label_helpers.get_annotated_output_labels(func),
            fallback=symbols,
            message_prefix="Annotation labels and returned symbols mis-match. ",
        )

        if output_labels:
            if len(output_labels) != len(symbols):
                raise ValueError(
                    f"When output_labels are specified ({output_labels}), workflow "
                    f"python definitions have a matching number of returned symbols "
                    f"({symbols})."
                )
            ports = list(output_labels)
        else:
            ports = scraped

        for symbol, port in zip(symbols, ports, strict=True):
            if symbol in self.symbol_map:
                self.symbol_map.produce(port, symbol)
                continue
            raise ValueError(
                f"Return symbol '{symbol}' is not defined. "
                f"Available: {list(self.symbol_map)}"
            )
[docs]
def is_append_call(node: ast.expr | ast.Expr) -> bool:
    """
    Check if node has the syntactic shape ``<name>.append(...)``.

    Only the shape is inspected: a call whose callee is the attribute ``append``
    on a plain name. Whether that name refers to a declared accumulator, and what
    the call's arguments are, is not checked here.
    """
    if not isinstance(node, ast.Call):
        return False

    callee = node.func
    if not isinstance(callee, ast.Attribute) or callee.attr != "append":
        return False

    # `a.b.append(x)` or `f().append(x)` do not qualify: the owner must be a name.
    return isinstance(callee.value, ast.Name)