"""Structural protocols and validation helpers for recipe subgraphs (``flowrep.subgraph_validation``)."""

import collections
from collections.abc import Collection
from typing import Protocol, runtime_checkable

from flowrep import base_models, edge_models

# Output wiring whose concrete source is only decided at runtime: each output
# target maps to the list of source handles that may feed it (each list is
# checked by validate_prospective_sources_list).
ProspectiveOutputEdges = dict[edge_models.OutputTarget, list[edge_models.SourceHandle]]


class RecipeProtocol(Protocol):
    """
    Minimal structural interface of a recipe that can live inside a subgraph.

    Only the members consumed by the validators in this module are declared.
    """

    # Labels of the recipe's input ports.
    inputs: base_models.Labels
    # Labels of the recipe's output ports.
    outputs: base_models.Labels

    @property
    def inputs_with_defaults(self) -> base_models.Labels:
        """
        Input labels that can be left unconnected because a default exists
        (presumably a subset of ``inputs``).
        """
        ...

    def validate_internal_data_completeness(self):
        """
        Recipe-specific self-check, invoked after the sourcing checks of the
        surrounding subgraph have passed; presumably raises on failure.
        """
        ...
# Mapping from a node label to the recipe stored under that label in a subgraph.
NodesAlias = dict[base_models.Label, RecipeProtocol]
@runtime_checkable
class StaticSubgraphOwner(Protocol):
    """
    Owns a concrete subgraph known at definition time (WorkflowRecipe).

    Because the protocol is runtime-checkable, ``isinstance`` only verifies
    that every attribute below is present; attribute types are not inspected.
    """

    # Interface of the owner itself.
    inputs: base_models.Labels
    outputs: base_models.Labels
    # The concrete child nodes, keyed by label.
    nodes: NodesAlias
    # Wiring from the owner's inputs to child-node inputs (target -> source).
    input_edges: edge_models.InputEdges
    # Wiring between child nodes (target -> source).
    edges: edge_models.Edges
    # Wiring that feeds the owner's outputs.
    output_edges: edge_models.OutputEdges
class DynamicSubgraphOwner(Protocol):
    """
    Owns a subgraph instantiated at runtime (ForEachRecipe, WhileRecipe,
    IfRecipe, TryRecipe).

    No concrete child nodes exist at definition time; instead the owner exposes
    its prospective nodes (presumably the recipes it may instantiate) through a
    read-only property.
    """

    # Interface of the owner itself.
    inputs: base_models.Labels
    outputs: base_models.Labels
    # Wiring from the owner's inputs to prospective-node inputs (target -> source).
    input_edges: edge_models.InputEdges

    @property
    def prospective_nodes(self) -> NodesAlias:
        """Label -> recipe mapping of the nodes this owner may instantiate."""
        ...
@runtime_checkable
class DynamicSubgraphStaticOutput(DynamicSubgraphOwner, Protocol):
    """
    Dynamic subgraph with output interface known a-priori (ForEachRecipe,
    WhileRecipe).

    Nodes are created at runtime, but the wiring onto the owner's outputs is
    fixed and stored as ordinary output edges.
    """

    # Fixed wiring that feeds the owner's outputs.
    output_edges: edge_models.OutputEdges
@runtime_checkable
class DynamicSubgraphDynamicOutput(DynamicSubgraphOwner, Protocol):
    """
    Dynamic subgraph with output interface known at runtime (IfRecipe,
    TryRecipe).

    Each owner output lists several candidate source handles instead of a
    single fixed one.
    """

    # Output target -> list of candidate source handles.
    prospective_output_edges: ProspectiveOutputEdges
def validate_input_edge_sources(
    input_edges: edge_models.InputEdges,
    available_inputs: base_models.Labels,
) -> None:
    """
    Ensure every input edge is sourced from one of the owner's input ports.

    Args:
        input_edges: Mapping of target handle -> source handle.
        available_inputs: Port labels a source may legally refer to.

    Raises:
        ValueError: Listing the serialized sources whose port is not among
            ``available_inputs``.
    """
    bad_sources = set()
    for source in input_edges.values():
        if source.port not in available_inputs:
            bad_sources.add(source.serialize())
    if bad_sources:
        raise ValueError(f"Invalid input_edges source ports: {bad_sources}")
def validate_input_edge_targets(
    input_edges: edge_models.InputEdges,
    target_nodes: NodesAlias,
) -> None:
    """
    Ensure every input edge lands on an existing node and one of its inputs.

    Node existence is checked for all edges first, so the port check can
    safely index ``target_nodes``.

    Args:
        input_edges: Mapping of target handle -> source handle; only the
            targets (keys) are inspected.
        target_nodes: Nodes that input edges may point at.

    Raises:
        ValueError: If a target names an unknown node, or an unknown input
            port of a known node.
    """
    unknown_nodes = {
        target.serialize()
        for target in input_edges
        if target.node not in target_nodes
    }
    if unknown_nodes:
        raise ValueError(
            f"Invalid input_edges target nodes {unknown_nodes}, "
            f"available: {tuple(target_nodes.keys())}"
        )

    unknown_ports = set()
    for target in input_edges:
        if target.port not in target_nodes[target.node].inputs:
            unknown_ports.add(target.serialize())
    if unknown_ports:
        raise ValueError(f"Invalid input_edges target ports: {unknown_ports}")
def validate_output_edge_targets(
    output_targets: Collection[edge_models.OutputTarget],
    available_outputs: base_models.Labels,
) -> None:
    """
    Ensure the output targets cover the owner's outputs exactly.

    Args:
        output_targets: Handles whose ``port`` names an owner output.
        available_outputs: The owner's declared output labels.

    Raises:
        ValueError: If a target names an undeclared output, or if some
            declared output has no edge (checked in that order).
    """
    declared = set(available_outputs)
    connected = {target.port for target in output_targets}

    unexpected = connected - declared
    if unexpected:
        raise ValueError(f"Invalid output target ports: {unexpected}")

    unconnected = declared - connected
    if unconnected:
        raise ValueError(f"Missing output edge for: {unconnected}")
def validate_output_edge_sources(
    sources: Collection[edge_models.SourceHandle | edge_models.InputSource],
    source_nodes: NodesAlias,
    inputs: base_models.Labels,
) -> None:
    """
    Ensure every output-edge source refers to something that exists.

    A source whose ``node`` is ``None`` refers to one of the owner's own
    ``inputs``; otherwise it must name a node in ``source_nodes`` and one of
    that node's outputs. Node existence is checked for all sources before any
    port is checked.

    Raises:
        ValueError: If a source names an unknown node, or a port that the
            referenced node (or the owner's inputs) does not provide.
    """
    unknown_nodes = {
        source.serialize()
        for source in sources
        if source.node is not None and source.node not in source_nodes
    }
    if unknown_nodes:
        raise ValueError(f"Invalid output source nodes: {unknown_nodes}")

    def port_is_unknown(source) -> bool:
        # Owner-level sources are looked up in the owner's inputs,
        # node-level sources in that node's outputs.
        if source.node is None:
            return source.port not in inputs
        return source.port not in source_nodes[source.node].outputs

    unknown_ports = {
        source.serialize() for source in sources if port_is_unknown(source)
    }
    if unknown_ports:
        raise ValueError(f"Invalid output source ports: {unknown_ports}")
def validate_prospective_sources_list(
    target: edge_models.OutputTarget,
    sources: Collection[edge_models.SourceHandle],
) -> None:
    """
    Validate the candidate sources listed for one prospective output.

    The list must be non-empty, and no source node may appear in it more than
    once.

    Raises:
        ValueError: If ``sources`` is empty, or if a node is referenced by more
            than one source handle.
    """
    if not sources:
        raise ValueError(f"Sources for '{target.serialize()}' cannot be empty.")

    per_node = collections.Counter(handle.node for handle in sources)
    repeated = {node for node, occurrences in per_node.items() if occurrences > 1}
    if not repeated:
        return
    raise ValueError(
        f"Sources for {target.serialize()} must be unique. "
        f"Duplicate source nodes: {repeated}"
    )
def validate_sibling_edges(
    edges: edge_models.Edges,
    target_nodes: NodesAlias,
    source_nodes: NodesAlias | None = None,
) -> None:
    """
    Ensure every node-to-node edge connects existing ports.

    For each edge (target -> source), the checks run in this order: target
    node exists, source node exists, target port is an input of the target
    node, and source port is an output of the source node. The first failure
    raises.

    Args:
        edges: Mapping of target handle -> source handle.
        target_nodes: Nodes that edges may point into.
        source_nodes: Nodes that edges may originate from; defaults to
            ``target_nodes`` when omitted.

    Raises:
        ValueError: Describing the first invalid handle encountered.
    """
    origins = target_nodes if source_nodes is None else source_nodes

    for target, source in edges.items():
        if target.node not in target_nodes:
            problem = f"Invalid edge target node: {target.serialize()}"
        elif source.node not in origins:
            problem = f"Invalid edge source node: {source.serialize()}"
        elif target.port not in target_nodes[target.node].inputs:
            problem = f"Invalid edge target port: {target.serialize()}"
        elif source.port not in origins[source.node].outputs:
            problem = f"Invalid edge source port: {source.serialize()}"
        else:
            continue
        raise ValueError(problem)
def validate_acyclic_edges(
    edges: edge_models.Edges, message: str = "Edges contain cycle(s)"
) -> None:
    """
    Raise if the node-level graph induced by ``edges`` contains a cycle.

    Each edge (target -> source) contributes a dependency ``source.node ->
    target.node``. Edges where either end has ``node is None`` (presumably
    connections to the owner's own ports) are ignored. Self-loops count as
    cycles.

    Uses Kahn's algorithm: nodes with no remaining incoming edges are removed
    one at a time; if some node is never removed, it lies on a cycle or
    downstream of one.

    Args:
        edges: Mapping of target handle -> source handle.
        message: Text of the ``ValueError`` raised when a cycle is found.

    Raises:
        ValueError: With ``message`` if the graph is not acyclic.
    """
    in_degree: dict[str, int] = {}
    successors: dict[str, list[str]] = {}
    for target, source in edges.items():
        if target.node is None or source.node is None:
            continue
        upstream, downstream = source.node, target.node
        # Register the upstream node even if nothing points at it.
        in_degree.setdefault(upstream, 0)
        in_degree[downstream] = in_degree.get(downstream, 0) + 1
        successors.setdefault(upstream, []).append(downstream)

    # LIFO work list: processing order does not affect the acyclicity verdict.
    ready = [node for node, degree in in_degree.items() if degree == 0]
    visited = 0
    while ready:
        node = ready.pop()
        visited += 1
        for neighbor in successors.get(node, ()):
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                ready.append(neighbor)

    if visited != len(in_degree):
        raise ValueError(message)
def validate_nodes_are_fully_sourced(
    nodes: NodesAlias,
    context: Collection[edge_models.TargetHandle],
) -> None:
    """
    Ensure every required input of every node is satisfied, then let each node
    validate itself.

    An input port is satisfied when it is listed in the node's
    ``inputs_with_defaults`` or when the ``TargetHandle`` for (label, port) is
    contained in ``context``. All nodes are checked for sourcing before any
    node's ``validate_internal_data_completeness`` is called.

    Args:
        nodes: Label -> recipe mapping to check.
        context: Target handles known to receive an incoming edge.

    Raises:
        ValueError: For the first input that has neither a source nor a
            default. Anything raised by a node's
            ``validate_internal_data_completeness`` propagates unchanged.
    """
    for label, node in nodes.items():
        # Loop-invariant per node: read the (possibly computed) property once.
        defaults = node.inputs_with_defaults
        for port in node.inputs:
            if port in defaults:
                # A default makes a source unnecessary; skip building the handle.
                continue
            target = edge_models.TargetHandle(node=label, port=port)
            if target not in context:
                raise ValueError(
                    f"Could not find a source or default for the target: {label}.{port}"
                )
    for node in nodes.values():
        node.validate_internal_data_completeness()