import collections
from collections.abc import Collection
from typing import Protocol, runtime_checkable
from flowrep import base_models, edge_models
# Type alias: each output target maps to a list of source handles.
ProspectiveOutputEdges = dict[edge_models.OutputTarget, list[edge_models.SourceHandle]]
[docs]
class RecipeProtocol(Protocol):
    """
    Structural interface for a recipe.

    An implementer exposes ``inputs`` and ``outputs`` labels, an
    ``inputs_with_defaults`` property, and a
    ``validate_internal_data_completeness`` method.
    """

    inputs: base_models.Labels
    outputs: base_models.Labels

    @property
    def inputs_with_defaults(self) -> base_models.Labels:
        """Labels of inputs that have a default (presumably a subset of ``inputs``)."""
        ...

    def validate_internal_data_completeness(self):
        """Validate the recipe's own data; exact semantics are up to the implementer."""
        ...
# Type alias: node label -> object satisfying RecipeProtocol.
NodesAlias = dict[base_models.Label, RecipeProtocol]
[docs]
@runtime_checkable
class StaticSubgraphOwner(Protocol):
    """
    Owns a concrete subgraph known at definition time (WorkflowRecipe).

    Being ``runtime_checkable``, ``isinstance`` checks only that all of the
    attributes below are present on the object.
    """

    # External interface of the owner.
    inputs: base_models.Labels
    outputs: base_models.Labels

    # The child nodes, keyed by label.
    nodes: NodesAlias

    # Edge collections of the owned subgraph.
    input_edges: edge_models.InputEdges
    edges: edge_models.Edges
    output_edges: edge_models.OutputEdges
[docs]
class DynamicSubgraphOwner(Protocol):
"""
Owns a subgraph instantiated at runtime (ForEachRecipe, WhileRecipe, IfRecipe,
TryRecipe).
"""
inputs: base_models.Labels
outputs: base_models.Labels
input_edges: edge_models.InputEdges
@property
def prospective_nodes(self) -> NodesAlias: ...
[docs]
@runtime_checkable
class DynamicSubgraphStaticOutput(DynamicSubgraphOwner, Protocol):
    """
    Dynamic subgraph with output interface known a-priori (ForEachRecipe, WhileRecipe).

    Extends the dynamic-owner interface with a concrete ``output_edges`` attribute.
    """

    # Output edges are fixed up front for this kind of owner.
    output_edges: edge_models.OutputEdges
[docs]
@runtime_checkable
class DynamicSubgraphDynamicOutput(DynamicSubgraphOwner, Protocol):
    """
    Dynamic subgraph with output interface known at runtime (IfRecipe, TryRecipe).

    Extends the dynamic-owner interface with ``prospective_output_edges``, which
    maps each output target to a list of source handles.
    """

    # One list of source handles per output target.
    prospective_output_edges: ProspectiveOutputEdges
[docs]
def validate_output_edge_targets(
    output_targets: Collection[edge_models.OutputTarget],
    available_outputs: base_models.Labels,
) -> None:
    """
    Check that the targeted output ports match the available outputs exactly.

    Args:
        output_targets: Output edge targets; only each target's ``port`` is read.
        available_outputs: Output labels, each of which must be targeted.

    Raises:
        ValueError: If a targeted port is not an available output, or if an
            available output has no target. Unknown ports are reported first.
    """
    target_ports = {target.port for target in output_targets}
    declared = set(available_outputs)

    invalid = target_ports - declared
    if invalid:
        raise ValueError(f"Invalid output target ports: {invalid}")

    missing = declared - target_ports
    if missing:
        raise ValueError(f"Missing output edge for: {missing}")
[docs]
def validate_output_edge_sources(
    sources: Collection[edge_models.SourceHandle | edge_models.InputSource],
    source_nodes: NodesAlias,
    inputs: base_models.Labels,
) -> None:
    """
    Check that every output edge source refers to an existing node and port.

    A source whose ``node`` is ``None`` is checked against ``inputs``; any other
    source is checked against the ``outputs`` of the node it names.

    Args:
        sources: The sources feeding output edges.
        source_nodes: Nodes a source may refer to, keyed by label.
        inputs: Labels a node-less source may refer to.

    Raises:
        ValueError: If any source names a node missing from ``source_nodes``
            (all offenders are collected and reported together), or, once all
            nodes are known to exist, if any source names an unknown port.
    """
    # Pass 1: nodes. This must complete before ports are checked, because the
    # port check indexes into ``source_nodes``.
    invalid_nodes = set()
    for source in sources:
        if source.node is None or source.node in source_nodes:
            continue
        invalid_nodes.add(source.serialize())
    if invalid_nodes:
        raise ValueError(f"Invalid output source nodes: {invalid_nodes}")

    # Pass 2: ports, looked up in whichever namespace the source belongs to.
    invalid_ports = set()
    for source in sources:
        if source.node is None:
            known_ports = inputs
        else:
            known_ports = source_nodes[source.node].outputs
        if source.port not in known_ports:
            invalid_ports.add(source.serialize())
    if invalid_ports:
        raise ValueError(f"Invalid output source ports: {invalid_ports}")
[docs]
def validate_prospective_sources_list(
    target: edge_models.OutputTarget,
    sources: Collection[edge_models.SourceHandle],
) -> None:
    """
    Check the list of prospective sources attached to a single output target.

    Args:
        target: The output target the sources belong to; used in error messages.
        sources: The prospective sources for ``target``.

    Raises:
        ValueError: If ``sources`` is empty, or if two or more sources share the
            same ``node``.
    """
    if not sources:
        raise ValueError(f"Sources for '{target.serialize()}' cannot be empty.")

    # Tally how many sources come from each node.
    node_counts = collections.Counter()
    for source in sources:
        node_counts[source.node] += 1

    duplicate_nodes = {node for node in node_counts if node_counts[node] > 1}
    if duplicate_nodes:
        raise ValueError(
            f"Sources for {target.serialize()} must be unique. "
            f"Duplicate source nodes: {duplicate_nodes}"
        )
[docs]
def validate_sibling_edges(
    edges: edge_models.Edges,
    target_nodes: NodesAlias,
    source_nodes: NodesAlias | None = None,
) -> None:
    """
    Check that every edge connects an existing output port to an existing input port.

    Args:
        edges: Mapping of edge target to edge source.
        target_nodes: Nodes that edge targets may refer to, keyed by label.
        source_nodes: Nodes that edge sources may refer to. When ``None``, the
            ``target_nodes`` mapping is used for sources as well.

    Raises:
        ValueError: On the first offending edge. Per edge, the checks run in the
            order: target node, source node, target port, source port.
    """
    source_nodes = target_nodes if source_nodes is None else source_nodes

    for target, source in edges.items():
        # Both nodes must exist before either is indexed for its ports.
        if target.node not in target_nodes:
            raise ValueError(f"Invalid edge target node: {target.serialize()}")
        if source.node not in source_nodes:
            raise ValueError(f"Invalid edge source node: {source.serialize()}")

        accepted_ports = target_nodes[target.node].inputs
        if target.port not in accepted_ports:
            raise ValueError(f"Invalid edge target port: {target.serialize()}")

        offered_ports = source_nodes[source.node].outputs
        if source.port not in offered_ports:
            raise ValueError(f"Invalid edge source port: {source.serialize()}")
[docs]
def validate_acyclic_edges(
    edges: edge_models.Edges, message="Edges contain cycle(s)"
) -> None:
    """
    Check that the node-to-node graph described by ``edges`` has no cycle.

    Edges whose target or source has ``node`` set to ``None`` are ignored. Ports
    are ignored too: the graph is built over node labels only.

    Args:
        edges: Mapping of edge target to edge source.
        message: Text of the error raised when a cycle is found.

    Raises:
        ValueError: If the edges contain at least one cycle (including a node
            connected to itself).
    """
    in_degree: dict[str, int] = {}
    successors: dict[str, list[str]] = {}
    for target, source in edges.items():
        upstream, downstream = source.node, target.node
        if upstream is None or downstream is None:
            continue
        in_degree.setdefault(upstream, 0)
        in_degree[downstream] = in_degree.get(downstream, 0) + 1
        successors.setdefault(upstream, []).append(downstream)

    # Kahn's algorithm: repeatedly peel off nodes with no unresolved
    # predecessors. Anything left over at the end sits on a cycle.
    ready = [label for label, degree in in_degree.items() if degree == 0]
    unresolved = len(in_degree)
    while ready:
        current = ready.pop()
        unresolved -= 1
        for follower in successors.get(current, ()):
            in_degree[follower] -= 1
            if not in_degree[follower]:
                ready.append(follower)

    if unresolved:
        raise ValueError(f"{message}")
[docs]
def validate_nodes_are_fully_sourced(
    nodes: NodesAlias,
    context: Collection[edge_models.TargetHandle],
):
    """
    Check that every node input has a source or a default, then validate each node.

    An input port is satisfied when it is listed in the node's
    ``inputs_with_defaults`` or when its target handle appears in ``context``.

    Args:
        nodes: The nodes to check, keyed by label.
        context: Target handles that already have a source.

    Raises:
        ValueError: On the first input port with neither a default nor an entry
            in ``context``. Errors raised by a node's own
            ``validate_internal_data_completeness`` are not caught here.
    """
    for label, node in nodes.items():
        for port in node.inputs:
            target = edge_models.TargetHandle(node=label, port=port)
            if port in node.inputs_with_defaults or target in context:
                continue
            raise ValueError(
                f"Could not find a source or default for the target: {label}.{port}"
            )

    # Only once every input is accounted for do the nodes check themselves.
    for child in nodes.values():
        child.validate_internal_data_completeness()