# Source code for flowrep.wfms

"""
This module holds a prototypical, minimal Workflow Management System (WfMS) for
flowrep recipes.

Intended for use in tests and documentation, and as an example implementation to which
fully-fledged WfMS can refer.
"""

from __future__ import annotations

import dataclasses
import heapq
import itertools
from collections.abc import Collection
from typing import Any, cast

from pyiron_snippets import retrieve

from flowrep import base_models, edge_models, retrospective
from flowrep.nodes import (
    atomic_recipe,
    for_recipe,
    helper_models,
    if_recipe,
    try_recipe,
    union_types,
    while_recipe,
    workflow_recipe,
)
from flowrep.parsers import label_helpers


[docs] def run_recipe( recipe: union_types.RecipeDiscrimination, **kwargs: Any ) -> retrospective.NodeData: """ Execute a flowrep recipe, returning a populated :class:`LiveNode`. All inputs are passed as keyword arguments matching the recipe's input port names. """ match recipe: case atomic_recipe.AtomicRecipe(): return _run_atomic(recipe, **kwargs) case workflow_recipe.WorkflowRecipe(): return _run_workflow(recipe, **kwargs) case for_recipe.ForEachRecipe(): return _run_for(recipe, **kwargs) case if_recipe.IfRecipe(): return _run_if(recipe, **kwargs) case try_recipe.TryRecipe(): return _run_try(recipe, **kwargs) case while_recipe.WhileRecipe(): return _run_while(recipe, **kwargs) case _: raise TypeError(f"Unsupported recipe type: {type(recipe).__name__}")
# ---------------------------------------------------------------------------
# Atomic
# ---------------------------------------------------------------------------


def _run_atomic(
    recipe: atomic_recipe.AtomicRecipe, **kwargs: Any
) -> retrospective.AtomicData:
    """
    Execute an atomic recipe.

    A fresh :class:`AtomicData` node is built from the recipe, *kwargs* are written to
    its input ports, the node's function is called, and the result is stored on the
    output ports. The populated node is returned.
    """
    node = retrospective.AtomicData.from_recipe(recipe)
    _populate_input_ports(node, kwargs)
    result = _call_atomic(node)
    _store_atomic_outputs(node, result)
    return node


def _call_atomic(node: retrospective.AtomicData) -> Any:
    """
    Invoke the underlying function, respecting positional-only parameter kinds.

    Values are drawn from the input data ports; if a port has no value, its default is
    used. A :class:`ValueError` is raised when neither is available.
    """
    recipe = node.recipe
    # Type narrowing only; the node was built from an atomic recipe.
    assert isinstance(recipe, atomic_recipe.AtomicRecipe)

    positional: list[Any] = []
    keyword: dict[str, Any] = {}
    for name in recipe.inputs:
        port = node.input_ports[name]
        val = (
            port.value
            if not isinstance(port.value, retrospective.NotData)
            else port.default
        )
        if isinstance(val, retrospective.NotData):
            raise ValueError(f"Input port '{name}' has no value and no default")
        kind = recipe.reference.restricted_input_kinds.get(name)
        # Positional-only parameters cannot be passed by keyword; everything else is.
        if kind == base_models.RestrictedParamKind.POSITIONAL_ONLY:
            positional.append(val)
        else:
            keyword[name] = val
    return node.function(*positional, **keyword)


def _store_atomic_outputs(node: retrospective.AtomicData, result: Any) -> None:
    """
    Write the function's return value onto the node's output ports.

    How *result* is distributed depends on the recipe's unpack mode:

    - ``NONE``: the whole result goes to the first output port.
    - ``TUPLE``: with a single output port the whole result is stored there;
      otherwise the result is iterated and assigned port-by-port (lengths must match).
    - ``DATACLASS``: the dataclass fields of the result are paired, in order, with the
      recipe's outputs (lengths must match).

    Any other unpack mode leaves the output ports untouched.
    """
    recipe = node.recipe
    # Type narrowing only; the node was built from an atomic recipe.
    assert isinstance(recipe, atomic_recipe.AtomicRecipe)
    output_names = list(node.output_ports.keys())
    if recipe.unpack_mode == atomic_recipe.UnpackMode.NONE:
        node.output_ports[output_names[0]].value = result
    elif recipe.unpack_mode == atomic_recipe.UnpackMode.TUPLE:
        if len(output_names) == 1:
            node.output_ports[output_names[0]].value = result
        else:
            for name, val in zip(output_names, result, strict=True):
                node.output_ports[name].value = val
    elif recipe.unpack_mode == atomic_recipe.UnpackMode.DATACLASS:
        fields = dataclasses.fields(result)
        for label, field in zip(node.recipe.outputs, fields, strict=True):
            node.output_ports[label].value = getattr(result, field.name)


# ---------------------------------------------------------------------------
# Workflow
# ---------------------------------------------------------------------------


def _run_workflow(
    recipe: workflow_recipe.WorkflowRecipe, **kwargs: Any
) -> retrospective.DagData:
    """
    Execute a workflow recipe by running its children in topological order.

    Each child receives the inputs resolved from the workflow's input ports and from
    the outputs of already-executed siblings; the executed child replaces the entry
    under its label in ``node.nodes``. Finally the workflow's own output ports are
    filled from the recipe's output edges.
    """
    node = retrospective.DagData.from_recipe(recipe)
    _populate_input_ports(node, kwargs)

    for child_label in _topo_sort_children(recipe):
        child_inputs = _gather_child_inputs(child_label, recipe, node)
        child_recipe = recipe.nodes[child_label]
        child_node = run_recipe(child_recipe, **child_inputs)
        node.nodes[child_label] = child_node  # Overwrite with _executed_ child

    _populate_workflow_outputs(node, recipe)
    return node


def _topo_sort_children(recipe: workflow_recipe.WorkflowRecipe) -> list[str]:
    """Kahn's algorithm over sibling edges; deterministic tie-breaking by label."""
    in_degree: dict[str, int] = {label: 0 for label in recipe.nodes}
    successors: dict[str, list[str]] = {label: [] for label in recipe.nodes}

    # Edges are keyed by target handle and map to the source handle feeding it.
    for target, source in recipe.edges.items():
        in_degree[target.node] += 1
        successors[source.node].append(target.node)

    # A min-heap of ready labels gives the deterministic (smallest-label-first) order.
    queue = [label for label in recipe.nodes if in_degree[label] == 0]
    heapq.heapify(queue)
    order: list[str] = []
    while queue:
        label = heapq.heappop(queue)
        order.append(label)
        for succ in successors.get(label, []):
            in_degree[succ] -= 1
            if in_degree[succ] == 0:
                heapq.heappush(queue, succ)

    if len(order) != len(recipe.nodes):  # pragma: no cover
        raise ValueError(
            "Cycle detected in workflow edges. This should have been caught by the "
            "underlying recipe validation. Please raise a GitHub issue reporting "
            "how you got here!"
        )
    return order


def _gather_child_inputs(
    child_label: str,
    recipe: workflow_recipe.WorkflowRecipe,
    workflow_node: retrospective.DagData,
) -> dict[str, Any]:
    """
    Resolve input values for a child node from workflow input ports and sibling
    output ports according to the recipe edges.

    Ports not covered by any edge are omitted — the child's own defaults (if any)
    will be used downstream.
    """
    child_recipe = recipe.nodes[child_label]
    inputs: dict[str, Any] = {}
    for port in child_recipe.inputs:
        th = edge_models.TargetHandle(node=child_label, port=port)
        if th in recipe.input_edges:
            parent_source = recipe.input_edges[th]
            # NOTE(review): parent inputs are read via ``get_data()`` while sibling
            # outputs are read via ``.value`` — presumably ``get_data()`` falls back
            # to the port default; confirm against the port model.
            inputs[port] = workflow_node.input_ports[parent_source.port].get_data()
        elif th in recipe.edges:
            sibling_source = recipe.edges[th]
            sibling = workflow_node.nodes[sibling_source.node]
            inputs[port] = sibling.output_ports[sibling_source.port].value
        # else: port has a default on the child, _call_atomic will handle it
    return inputs


def _populate_workflow_outputs(
    node: retrospective.DagData, recipe: workflow_recipe.WorkflowRecipe
) -> None:
    """
    Fill the workflow's output ports from the recipe's output edges.

    An output may be sourced directly from one of the workflow's own inputs
    (:class:`InputSource`) or from an output port of an executed child.
    """
    for target, source in recipe.output_edges.items():
        if isinstance(source, edge_models.InputSource):
            val = node.input_ports[source.port].get_data()
        else:
            child = node.nodes[source.node]
            val = child.output_ports[source.port].value
        node.output_ports[target.port].value = val


# ---------------------------------------------------------------------------
# For
# ---------------------------------------------------------------------------


def _run_for(
    recipe: for_recipe.ForEachRecipe, **kwargs: Any
) -> retrospective.FlowControlData:
    """
    Execute a for-node by scattering iterated inputs across body instances and
    collecting outputs into lists.

    Nested ports drive a Cartesian product; zipped ports are iterated in lockstep.
    Broadcast (non-iterated) inputs are passed unchanged to every body instance.

    Transferred outputs collect the per-iteration value of a scattered input,
    preserving the link between input element and body output element.

    Raises:
        ValueError: If the zipped inputs do not all have the same length.
    """
    node = retrospective.FlowControlData.from_recipe(recipe)
    _populate_input_ports(node, kwargs)

    body_label = recipe.body_node.label
    body_recipe = recipe.body_node.node
    iterated_ports = recipe.iterated_ports

    # body iterated port -> for-node input name
    body_to_for: dict[str, str] = {}
    for port in iterated_ports:
        th = edge_models.TargetHandle(node=body_label, port=port)
        body_to_for[port] = recipe.input_edges[th].port

    # Reverse mapping for transferred outputs
    for_to_body: dict[str, str] = {v: k for k, v in body_to_for.items()}

    # Broadcast inputs (non-iterated body ports sourced from for-node inputs)
    broadcast: dict[str, Any] = {}
    for port in body_recipe.inputs:
        if port not in iterated_ports:
            th = edge_models.TargetHandle(node=body_label, port=port)
            if th in recipe.input_edges:
                src = recipe.input_edges[th]
                broadcast[port] = node.input_ports[src.port].value

    # Build iteration axes
    nested_iters = [
        cast(Collection, node.input_ports[body_to_for[p]].value)
        for p in recipe.nested_ports
    ]
    zipped_iters = [
        cast(Collection, node.input_ports[body_to_for[p]].value)
        for p in recipe.zipped_ports
    ]
    # Note that we simply cast iterated input values to the form we expect, and let the
    # user pay the price if runtime data is non-compliant.

    # With no axis of a given kind, a single empty tuple keeps the outer product
    # below at exactly one pass for that kind.
    nested_combos = list(itertools.product(*nested_iters)) if nested_iters else [()]
    if zipped_iters:
        zip_len = len(zipped_iters[0])
        for zi in zipped_iters:
            if len(zi) != zip_len:
                raise ValueError("Zipped inputs must have equal lengths")
        zipped_combos = list(zip(*zipped_iters, strict=True))
    else:
        zipped_combos = [()]

    accumulators: dict[str, list[Any]] = {port: [] for port in recipe.outputs}

    for nested_vals, zipped_vals in itertools.product(nested_combos, zipped_combos):
        body_kwargs = dict(broadcast)
        for port, val in zip(recipe.nested_ports, nested_vals, strict=True):
            body_kwargs[port] = val
        for port, val in zip(recipe.zipped_ports, zipped_vals, strict=True):
            body_kwargs[port] = val

        child = run_recipe(body_recipe, **body_kwargs)
        # Body instances are stored under an index-suffixed label, in run order.
        idx = len(node.nodes)
        node.nodes[label_helpers.index_label(body_label, idx)] = child

        for target, source in recipe.output_edges.items():
            if isinstance(source, edge_models.SourceHandle):
                accumulators[target.port].append(child.output_ports[source.port].value)
            else:
                # Transferred output: collect the scattered input element
                body_port = for_to_body[source.port]
                accumulators[target.port].append(body_kwargs[body_port])

    for port, values in accumulators.items():
        node.output_ports[port].value = values

    return node


# ---------------------------------------------------------------------------
# While
# ---------------------------------------------------------------------------


def _run_while(
    recipe: while_recipe.WhileRecipe, **kwargs: Any
) -> retrospective.FlowControlData:
    """
    Execute a while-node by repeatedly evaluating a condition and running a body.

    On each iteration the body outputs (which share names with a subset of inputs)
    feed back into the next condition/body evaluation. If the condition is false on
    the first check, outputs are sourced from the initial input values.

    The loop ends only when the condition evaluates falsy; no iteration cap is
    applied here. Every condition and body run is stored in ``node.nodes`` under an
    iteration-indexed label.
    """
    node = retrospective.FlowControlData.from_recipe(recipe)
    _populate_input_ports(node, kwargs)

    cond_label = recipe.case.condition.label
    body_label = recipe.case.body.label
    cond_recipe = recipe.case.condition.node
    body_recipe = recipe.case.body.node

    # Working copy of current values — starts from inputs, body outputs update it
    current: dict[str, Any] = {
        name: node.input_ports[name].value for name in recipe.inputs
    }

    iteration = 0
    while True:
        # --- condition ---
        cond_kwargs = _gather_dynamic_child_inputs(
            cond_label, recipe.input_edges, current
        )
        cond_node = run_recipe(cond_recipe, **cond_kwargs)
        node.nodes[label_helpers.index_label(cond_label, iteration)] = cond_node

        if not _evaluate_condition(recipe.case, cond_node):
            break

        # --- body ---
        body_kwargs = _gather_dynamic_child_inputs(
            body_label, recipe.input_edges, current
        )
        body_node = run_recipe(body_recipe, **body_kwargs)
        node.nodes[label_helpers.index_label(body_label, iteration)] = body_node

        # Feed body outputs back into current values
        for target, source in recipe.output_edges.items():
            current[target.port] = body_node.output_ports[source.port].value

        iteration += 1

    for name in recipe.outputs:
        node.output_ports[name].value = current[name]

    return node


# ---------------------------------------------------------------------------
# If
# ---------------------------------------------------------------------------


def _run_if(recipe: if_recipe.IfRecipe, **kwargs: Any) -> retrospective.FlowControlData:
    """
    Execute an if-node by walking cases until a condition evaluates positively,
    then executing the matching body (or the else case).

    Output ports that have no source from the executed branch remain NOT_DATA.
    """
    node = retrospective.FlowControlData.from_recipe(recipe)
    _populate_input_ports(node, kwargs)

    for case in recipe.cases:
        # --- condition ---
        cond_kwargs = _gather_dynamic_child_inputs(
            case.condition.label, recipe.input_edges, node
        )
        cond_node = run_recipe(case.condition.node, **cond_kwargs)
        node.nodes[case.condition.label] = cond_node

        if _evaluate_condition(case, cond_node):
            _execute_if_branch(node, recipe, case.body)
            return node

    # No case matched — try else
    if recipe.else_case is not None:
        _execute_if_branch(node, recipe, recipe.else_case)

    return node


def _execute_if_branch(
    node: retrospective.FlowControlData,
    recipe: if_recipe.IfRecipe,
    branch: helper_models.LabeledRecipe,
) -> None:
    """
    Run one branch of an if-node, record it on *node*, and wire its outputs to the
    if-node's output ports.
    """
    branch_kwargs = _gather_dynamic_child_inputs(branch.label, recipe.input_edges, node)
    branch_node = run_recipe(branch.node, **branch_kwargs)
    node.nodes[branch.label] = branch_node
    _populate_prospective_outputs(node, recipe.prospective_output_edges, branch.label)


# ---------------------------------------------------------------------------
# Try
# ---------------------------------------------------------------------------


def _run_try(
    recipe: try_recipe.TryRecipe, **kwargs: Any
) -> retrospective.FlowControlData:
    """
    Execute a try-node: run the try body and, on exception, walk exception cases
    for a matching handler. If no handler matches, the exception propagates.

    Only the execution of the try body is guarded. Recording the executed body and
    wiring its outputs happen on the success path, outside the guarded region, so a
    failure in that bookkeeping is never mistaken for a failure of the body itself.
    """
    node = retrospective.FlowControlData.from_recipe(recipe)
    _populate_input_ports(node, kwargs)

    try_kwargs = _gather_dynamic_child_inputs(
        recipe.try_node.label, recipe.input_edges, node
    )
    try:
        try_node = run_recipe(recipe.try_node.node, **try_kwargs)
    except BaseException as exc:
        # BaseException is caught (rather than Exception) presumably so that a recipe
        # can declare a handler for any exception class; anything that no case
        # matches is re-raised unchanged below.
        for case in recipe.exception_cases:
            exc_types = tuple(
                retrieve.import_from_string(info.fully_qualified_name)
                for info in case.exceptions
            )
            if isinstance(exc, exc_types):
                handler_kwargs = _gather_dynamic_child_inputs(
                    case.body.label, recipe.input_edges, node
                )
                handler_node = run_recipe(case.body.node, **handler_kwargs)
                node.nodes[case.body.label] = handler_node
                _populate_prospective_outputs(
                    node, recipe.prospective_output_edges, case.body.label
                )
                return node
        raise
    else:
        node.nodes[recipe.try_node.label] = try_node
        _populate_prospective_outputs(
            node, recipe.prospective_output_edges, recipe.try_node.label
        )
        return node


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


def _populate_input_ports(node: retrospective.NodeData, values: dict[str, Any]) -> None:
    """
    Write *values* onto the matching input ports of *node*.

    Raises:
        ValueError: If a key in *values* does not name an input port of *node*.
            Ports visited before the offending key have already been assigned.
    """
    for name, val in values.items():
        if name in node.input_ports:
            node.input_ports[name].value = val
        else:
            raise ValueError(
                f"Input port '{name}' not found -- please select among "
                f"{node.recipe.inputs}"
            )


def _gather_dynamic_child_inputs(
    child_label: str,
    input_edges: edge_models.InputEdges,
    source: retrospective.NodeData | dict[str, Any],
) -> dict[str, Any]:
    """
    Gather inputs for a dynamic subgraph child.

    *source* can be a :class:`NodeData` (reads from ``input_ports``) or a plain dict
    (used by the while-node where current values are tracked in a dict).

    Only edges whose target node is *child_label* contribute; the result maps the
    child's port name to the value found at the edge's source port.
    """
    result: dict[str, Any] = {}
    for target, edge_source in input_edges.items():
        if target.node == child_label:
            if isinstance(source, dict):
                result[target.port] = source[edge_source.port]
            else:
                result[target.port] = source.input_ports[edge_source.port].value
    return result


def _evaluate_condition(
    case: helper_models.ConditionalCase,
    cond_node: retrospective.NodeData,
) -> bool:
    """
    Return the truthiness of an executed condition node's result.

    The port named by ``case.condition_output`` is used when it is set; otherwise
    the first output port of *cond_node* is used.

    Raises:
        ValueError: If no condition output is named and *cond_node* has no output
            ports to fall back on.
    """
    if case.condition_output is not None:
        output_name = case.condition_output
    else:
        # A default for next() avoids leaking a bare StopIteration to the caller.
        output_name = next(iter(cond_node.output_ports), None)
        if output_name is None:
            raise ValueError(
                "Cannot evaluate condition: the condition node has no output ports"
            )
    return bool(cond_node.output_ports[output_name].value)


def _populate_prospective_outputs(
    node: retrospective.FlowControlData,
    prospective_output_edges: dict[
        edge_models.OutputTarget, list[edge_models.SourceHandle]
    ],
    active_label: str,
) -> None:
    """Wire outputs from the branch that actually executed."""
    for target, sources in prospective_output_edges.items():
        for source in sources:
            # Each target lists one candidate source per branch; only the candidate
            # belonging to the executed branch is used, and at most one per target.
            if source.node == active_label and source.node in node.nodes:
                child = node.nodes[source.node]
                node.output_ports[target.port].value = child.output_ports[
                    source.port
                ].value
                break