from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import pydantic
from flowrep import base_models, edge_models, subgraph_validation
from flowrep.nodes import helper_models
if TYPE_CHECKING:
from flowrep.nodes.union_types import Recipes
class ForEachRecipe(base_models.NodeRecipe):
    """
    Loop over a body node and collect outputs as a list.

    Each loop step is to be treated independently, such that the overall loop behaves
    as a map.

    This is a dynamic node, which must actualize the body of its subgraph at runtime.

    Loops can be done with a combination of nested iteration and zipping values.

    Output edges whose source is an `InputSource` indicate data forwarded directly from
    the for-node's own inputs. In the event that these are inputs that are scattered to
    body nodes by the iteration, it is the responsibility of the WfMS to collect
    these into lists alongside the body node outputs. This allows outputs to be linked
    directly to the input that generated them.

    Intended recipe realization:

    1. Assess the number of body executions necessary by examining the lengths of
       nested and zipped ports, and the length of data on the corresponding inputs.
        a) The data in all inputs being passed to zipped ports should be
           length-validated
        b) The count scales multiplicatively with the data in each input passed to
           nested ports, and finally multiplied once more by the zipped length (or
           directly the zipped length if no nested ports are present).
    2. Create the appropriate number of body node instances in the subgraph
    3. Broadcast input edges not involved in nested or zipped ports to each child
    4. Decompose input for input edges used for zipped or nested ports and scatter
       edges to each child accordingly
        a) The manner of this decomposition is an implementation detail for which the
           WfMS is responsible
    5. Collect output of child nodes into list fields and connect these to the output
       according to the output edges.
        a) The manner of this collection is an implementation detail for which the
           WfMS is responsible
    6. For output edges sourced from InputSource (rather than body SourceHandle),
       collect the corresponding input values used for each iteration and connect
       to output accordingly.

    Attributes:
        type: The node type -- always "for_each".
        inputs: The available input port names.
        outputs: The available output port names.
        body_node: The labeled node to execute for each iteration.
        input_edges: Edges from workflow inputs to inputs of body node instances.
        output_edges: Edges from body node outputs or for-node inputs to workflow
            outputs. Sources that are InputSource values indicate forwarded input data
            (collected per-iteration); SourceHandle values indicate body node outputs.
        nested_ports: The body node ports over which to do nested iteration. Input
            edges will map parent input elements to each child node accordingly.
        zipped_ports: The body node ports over which to do zipped iteration. Input
            edges will map parent input elements to each child node accordingly.

    Notes:
        At runtime, iterated input values should themselves be iterable. It is
        recommended to pass values conforming to `collections.abc.Collection`. This is
        a runtime behaviour, and is thus not enforced here at the recipe level in any
        way.

        All iterated output — whether collected from body executions or forwarded from
        scattered inputs — should have the same length. Thus, forwarded inputs empower
        the node output to precisely provide which input was used to produce each
        output element.
    """

    type: Literal[base_models.RecipeElementType.FOR_EACH] = pydantic.Field(
        default=base_models.RecipeElementType.FOR_EACH, frozen=True
    )
    body_node: helper_models.LabeledRecipe
    input_edges: edge_models.InputEdges
    output_edges: edge_models.OutputEdges
    nested_ports: base_models.Labels = pydantic.Field(default_factory=list)
    zipped_ports: base_models.Labels = pydantic.Field(default_factory=list)

    @property
    def prospective_nodes(self) -> Recipes:
        """The body node recipe, keyed by the body node's label."""
        return {self.body_node.label: self.body_node.node}

    @property
    def iterated_ports(self) -> base_models.Labels:
        """All iterated body ports: the nested ports followed by the zipped ports."""
        return self.nested_ports + self.zipped_ports

    @property
    def transferred_outputs(self) -> edge_models.OutputEdges:
        """
        Output edges sourced from iterated (nested/zipped) inputs.

        These inputs are scattered across body executions, so the WfMS must
        collect them back into lists correlated with body node outputs. This is a
        helper property for the WfMS to more easily find these.
        """
        # Evaluate the derived property once instead of once per output edge.
        iterated_inputs = self._iterated_input_ports
        return {
            target: source
            for target, source in self.output_edges.items()
            if isinstance(source, edge_models.InputSource)
            and source.port in iterated_inputs
        }

    @property
    def _iterated_input_ports(self) -> set[str]:
        """For-node input ports that feed into iterated (nested/zipped) body ports."""
        # `iterated_ports` builds a new list on every access; read it once up front.
        iterated = self.iterated_ports
        return {
            source.port
            for target, source in self.input_edges.items()
            if target.port in iterated
        }

    @pydantic.model_validator(mode="after")
    def validate_io_edges(self):
        """
        Validate the input and output edges against this node and its body node.

        Edge sources and targets are checked by the `subgraph_validation` helpers
        using this node's `inputs`, `outputs`, and `prospective_nodes`. On top of
        that, every output edge sourced from an `InputSource` must refer to an input
        that feeds an iterated (nested/zipped) body port.

        Returns:
            The validated instance.

        Raises:
            ValueError: If an output edge forwards an input that is not being
                iterated on. The `subgraph_validation` helpers presumably raise on
                their own failures as well -- confirm against that module.
        """
        subgraph_validation.validate_input_edge_sources(self.input_edges, self.inputs)
        subgraph_validation.validate_input_edge_targets(
            self.input_edges,
            self.prospective_nodes,
        )
        subgraph_validation.validate_output_edge_targets(
            self.output_edges, self.outputs
        )
        subgraph_validation.validate_output_edge_sources(
            self.output_edges.values(),
            self.prospective_nodes,
            self.inputs,
        )
        # Disallow passthrough outputs -- there is no way to generate these at this
        # scope from python, so simply disallow them for simplicity and consistency
        iterated_inputs = self._iterated_input_ports  # computed once for all edges
        if passthrough := {
            source.serialize()
            for source in self.output_edges.values()
            if isinstance(source, edge_models.InputSource)
            and source.port not in iterated_inputs
        }:
            raise ValueError(
                f"Output edges from input sources are only allowed if the input is "
                f"being iterated on, but got: {passthrough}"
            )
        return self

    @pydantic.model_validator(mode="after")
    def validate_some_iteration(self):
        """
        Require at least one nested or zipped port.

        Returns:
            The validated instance.

        Raises:
            ValueError: If both `nested_ports` and `zipped_ports` are empty.
        """
        if not (self.nested_ports or self.zipped_ports):
            raise ValueError("For node must have at least one nested or zipped port")
        return self

    @pydantic.model_validator(mode="after")
    def validate_non_overlapping_iterators(self):
        """
        Require that no port is both nested and zipped.

        Returns:
            The validated instance.

        Raises:
            ValueError: If `nested_ports` and `zipped_ports` share any label.
        """
        if not set(self.nested_ports).isdisjoint(self.zipped_ports):
            raise ValueError(
                f"Loop values in nested_ports or zipped_ports must not overlap, but "
                f"share {set(self.nested_ports).intersection(self.zipped_ports)}."
            )
        return self

    @pydantic.model_validator(mode="after")
    def validate_iterated_ports_exist(self):
        """
        Require that every iterated port is an input of the body node.

        Returns:
            The validated instance.

        Raises:
            ValueError: If any nested or zipped port is not among the body node's
                inputs.
        """
        body_inputs = self.body_node.node.inputs
        if invalid := {
            port for port in self.iterated_ports if port not in body_inputs
        }:
            raise ValueError(
                f"For node must iterate on body node ports "
                f"({self.body_node.node.inputs}) but got: {invalid}"
            )
        return self

    @pydantic.model_validator(mode="after")
    def validate_internal_data_completeness(self):
        """
        Check the body node against the input edges for complete sourcing.

        Delegates to `subgraph_validation.validate_nodes_are_fully_sourced` with
        `prospective_nodes` and `input_edges`; judging by its name, that helper
        verifies every body node input receives data -- confirm against that module.

        Returns:
            The validated instance.
        """
        subgraph_validation.validate_nodes_are_fully_sourced(
            self.prospective_nodes, self.input_edges
        )
        return self