Source code for flowrep.parsers.atomic_parser

import ast
import dataclasses
import inspect
from collections.abc import Callable, Iterable
from types import FunctionType
from typing import Annotated, cast, get_args, get_origin, get_type_hints

from pyiron_snippets import versions

from flowrep import base_models
from flowrep.nodes import atomic_recipe, helper_models
from flowrep.parsers import label_helpers, object_scope, parser_helpers
from flowrep.parsers.label_helpers import default_output_label


def atomic(
    func: FunctionType | str | None = None,
    /,
    *output_labels: str,
    unpack_mode: atomic_recipe.UnpackMode = atomic_recipe.UnpackMode.TUPLE,
    version_scraping: versions.VersionScrapingMap | None = None,
    forbid_main: bool = False,
    forbid_locals: bool = False,
    require_version: bool = False,
) -> FunctionType | Callable[[FunctionType], FunctionType]:
    """
    Decorate a function so that it carries an
    :class:`~flowrep.models.nodes.atomic_recipe.AtomicRecipe` on its
    ``flowrep_recipe`` attribute.

    Works both bare (``@atomic``) and with arguments (``@atomic(...)``). The
    recipe itself is produced by :func:`parse_atomic`; this function only bundles
    the keyword options and hands everything to
    ``parser_helpers.parser2decorator``.

    Args:
        func: The function being decorated. Python supplies it positionally when
            the decorator is applied without parentheses.
        *output_labels: Explicit names for the output ports. If given, there must
            be exactly as many as the outputs inferred from the function under
            the chosen ``unpack_mode``.
        unpack_mode: How the return value is split into output ports. See
            :class:`~flowrep.models.nodes.atomic_recipe.UnpackMode`.
        version_scraping: Optional mapping from top-level package names to
            callables returning a version string, for packages without a
            ``__version__``. Forwarded to
            :meth:`~pyiron_snippets.versions.VersionInfo.of`.
        forbid_main: If ``True``, raise when the function's module is
            ``__main__``.
        forbid_locals: If ``True``, raise when the function's qualname contains
            ``<locals>`` (i.e. it was defined inside another function).
        require_version: If ``True``, raise when no version can be determined
            for the function's package.

    Returns:
        The original function, now holding a ``flowrep_recipe`` attribute (or a
        decorator producing it, when called with arguments).
    """
    # Everything keyword-only is passed through to the parser untouched.
    forwarded_options = dict(
        unpack_mode=unpack_mode,
        version_scraping=version_scraping,
        forbid_main=forbid_main,
        forbid_locals=forbid_locals,
        require_version=require_version,
    )
    return parser_helpers.parser2decorator(
        func,
        output_labels,
        parser=parse_atomic,
        decorator_name="@atomic",
        parser_kwargs=forwarded_options,
    )
def parse_atomic(
    func: FunctionType,
    *output_labels: str,
    unpack_mode: atomic_recipe.UnpackMode = atomic_recipe.UnpackMode.TUPLE,
    version_scraping: versions.VersionScrapingMap | None = None,
    forbid_main: bool = False,
    forbid_locals: bool = False,
    require_version: bool = False,
) -> atomic_recipe.AtomicRecipe:
    """
    Turn a plain Python function into an
    :class:`~flowrep.models.nodes.atomic_recipe.AtomicRecipe`.

    Collects provenance via :meth:`~pyiron_snippets.versions.VersionInfo.of`,
    signature details via ``parser_helpers.SignatureInfo.of``, the docstring via
    :func:`inspect.getdoc`, and output port names via the unpack-mode-specific
    label scrapers in this module.

    Args:
        func: The function to represent as an atomic node.
        *output_labels: Explicit output port names. If given, their count must
            equal the number of outputs inferred for ``unpack_mode``; they then
            replace the inferred names.
        unpack_mode: How the return value is split into output ports.
        version_scraping: Optional version-scraping overrides, forwarded to
            :meth:`~pyiron_snippets.versions.VersionInfo.of`.
        forbid_main: If ``True``, raise if the function's module is ``__main__``.
        forbid_locals: If ``True``, raise if the function's qualname contains
            ``<locals>``.
        require_version: If ``True``, raise if no version can be determined.

    Returns:
        The assembled :class:`AtomicRecipe`.

    Raises:
        ValueError: If the number of explicit ``output_labels`` differs from the
            inferred output count, or if any ``forbid_*`` / ``require_*``
            constraint is violated.
    """
    provenance = versions.VersionInfo.of(
        func,
        version_scraping=version_scraping,
        forbid_main=forbid_main,
        forbid_locals=forbid_locals,
        require_version=require_version,
    )
    signature = parser_helpers.SignatureInfo.of(func)
    description = inspect.getdoc(func)
    inferred_labels = _get_output_labels(func, unpack_mode)

    # Explicit labels are optional, but when present they must line up 1:1 with
    # what the function analysis found.
    explicit_count = len(output_labels)
    labels_given = explicit_count > 0
    if labels_given and explicit_count != len(inferred_labels):
        raise ValueError(
            "Explicitly provided output labels must match the function analysis and "
            f"unpack_mode: expected {len(inferred_labels)} labels for "
            f"unpack_mode='{unpack_mode}', got {explicit_count} labels "
            f"{output_labels}; inferred labels were {inferred_labels}."
        )

    python_reference = base_models.PythonReference(
        info=provenance,
        inputs_with_defaults=signature.have_defaults,
        restricted_input_kinds=signature.have_restricted_kinds,
    )
    if labels_given:
        outputs = list(output_labels)
    else:
        outputs = inferred_labels

    return atomic_recipe.AtomicRecipe(
        reference=python_reference,
        inputs=signature.names,
        outputs=outputs,
        description=description,
        unpack_mode=unpack_mode,
    )
def _get_output_labels( func: FunctionType, unpack_mode: atomic_recipe.UnpackMode ) -> list[str]: if unpack_mode == atomic_recipe.UnpackMode.NONE: return _parse_return_label_without_unpacking(func) elif unpack_mode == atomic_recipe.UnpackMode.TUPLE: return _parse_tuple_return_labels(func) elif unpack_mode == atomic_recipe.UnpackMode.DATACLASS: return _parse_dataclass_return_labels(func) raise TypeError( f"Invalid unpack mode: {unpack_mode}. Possible values are " f"{', '.join(atomic_recipe.UnpackMode.__members__.values())}" ) def _parse_return_label_without_unpacking(func: FunctionType) -> list[str]: """ Get output label for UnpackMode.NONE. Looks for annotation on the return type itself (not tuple elements). For `-> Annotated[T, {"label": "x"}]` or `-> Annotated[tuple[...], {"label": "x"}]` """ try: hints = get_type_hints(func, include_extras=True) except Exception: return [label_helpers.default_output_label(0)] return_hint = hints.get("return") if return_hint is None: return [label_helpers.default_output_label(0)] # Extract label from the outermost Annotated wrapper label = label_helpers.extract_label_from_annotated(return_hint) return [label] if label is not None else [label_helpers.default_output_label(0)] def _parse_tuple_return_labels(func: FunctionType) -> list[str]: func_node = parser_helpers.get_ast_function_node(func) return_labels = _extract_combined_return_labels(func_node) if not all(len(ret) == len(return_labels[0]) for ret in return_labels): raise ValueError( f"All return statements must have the same number of elements, got " f"{return_labels}" ) # Get AST-scraped labels scraped = list( ( label if all(other_branch[i] == label for other_branch in return_labels) else label_helpers.default_output_label(i) ) for i, label in enumerate(return_labels[0]) ) # Override with annotation-based labels where available annotated = label_helpers.get_annotated_output_labels(func) return label_helpers.merge_labels( first_choice=annotated, fallback=scraped, 
message_prefix="Annotations and scraped return labels mis-match. ", ) def _extract_combined_return_labels( func_node: ast.FunctionDef, ) -> list[tuple[str, ...]]: return_stmts = [n for n in ast.walk(func_node) if isinstance(n, ast.Return)] return_labels: list[tuple[str, ...]] = [()] if len(return_stmts) == 0 else [] for ret in return_stmts: return_labels.append(_extract_return_labels(ret)) return return_labels def _extract_return_labels(ret: ast.Return) -> tuple[str, ...]: if ret.value is None: return_labels: tuple[str, ...] = () return return_labels elif isinstance(ret.value, ast.Tuple): return tuple( elt.id if isinstance(elt, ast.Name) else default_output_label(i) for i, elt in enumerate(ret.value.elts) ) else: return ( (ret.value.id,) if isinstance(ret.value, ast.Name) else (default_output_label(0),) ) def _parse_dataclass_return_labels(func: FunctionType) -> list[str]: source_code_return = _parse_tuple_return_labels(func) if len(source_code_return) != 1: raise ValueError( f"Dataclass unpack mode requires function code to returns to consist of " f"exactly one value, i.e. the dataclass instance, but got " f"{source_code_return}" ) sig = inspect.signature(func) ann = sig.return_annotation # unwrap Annotated origin = get_origin(ann) return_annotation = get_args(ann)[0] if origin is Annotated else ann if dataclasses.is_dataclass(return_annotation): return [f.name for f in dataclasses.fields(return_annotation)] raise ValueError( f"Dataclass unpack mode requires a return type annotation that is a " f"(perhaps Annotated) dataclass, but got {ann}" )
def get_labeled_recipe(
    ast_call: ast.Call,
    existing_names: Iterable[str],
    scope: object_scope.ScopeProxy,
    info_factory: versions.VersionInfoFactory,
) -> helper_models.LabeledRecipe:
    """
    Resolve the callee of ``ast_call`` to a recipe and pair it with a label.

    The callee is looked up in ``scope``. Three cases are handled:

    - it already is a :class:`~flowrep.base_models.NodeRecipe`: used as-is;
    - it carries a ``flowrep_recipe`` attribute: that recipe is used, and, if it
      has a ``reference`` whose ``info`` is a
      :class:`~pyiron_snippets.versions.VersionInfo`, the info is validated
      against the constraints held by ``info_factory``;
    - otherwise: it is parsed on the spot with :func:`parse_atomic`, using the
      settings held by ``info_factory``.

    The label is a name prefix made unique against ``existing_names`` via
    ``label_helpers.unique_suffix``.

    Args:
        ast_call: The call expression whose callee should become a node.
        existing_names: Labels already in use.
        scope: Where the callee symbol is resolved.
        info_factory: Source of version-scraping and constraint settings.

    Returns:
        The label together with the recipe.
    """
    resolved = object_scope.resolve_symbol_to_object(ast_call.func, scope)

    if isinstance(resolved, base_models.NodeRecipe):
        recipe = resolved
        prefix = _infer_node_name(recipe, ast_call.func)
    else:
        # Not a recipe object, so it is either a function that was decorated
        # earlier or one we parse right now -- in both cases it is expected to
        # be a FunctionType (the cast is for the type checker only).
        function_call = cast(FunctionType, resolved)
        prefix = function_call.__name__
        if not hasattr(function_call, "flowrep_recipe"):
            recipe = parse_atomic(
                function_call,
                version_scraping=info_factory.version_scraping,
                forbid_main=info_factory.forbid_main,
                forbid_locals=info_factory.forbid_locals,
                require_version=info_factory.require_version,
            )
        else:
            recipe = function_call.flowrep_recipe
            # A pre-attached recipe was built under its own settings; re-check
            # it against the constraints requested here.
            has_version_info = hasattr(recipe, "reference") and isinstance(
                recipe.reference.info, versions.VersionInfo
            )
            if has_version_info:
                recipe.reference.info.validate_constraints(
                    forbid_main=info_factory.forbid_main,
                    forbid_locals=info_factory.forbid_locals,
                    require_version=info_factory.require_version,
                )

    unique_label = label_helpers.unique_suffix(prefix, existing_names)
    return helper_models.LabeledRecipe(label=unique_label, node=recipe)
def _infer_node_name(node: base_models.NodeRecipe, ast_call: ast.expr) -> str: reference = getattr(node, "reference", None) if reference is not None: underlying_function_name = reference.info.qualname.rsplit(".", 1)[-1] return underlying_function_name elif isinstance(ast_call, ast.Name): variable_name = ast_call.id return variable_name elif isinstance(ast_call, ast.Attribute): proximate_attribute_name = ast_call.attr return proximate_attribute_name else: raise ValueError(f"Unexpected node type: {type(ast_call)}")