from __future__ import annotations
import ast
import dataclasses
import inspect
import textwrap
from collections.abc import Callable
from types import FunctionType
from typing import Any, cast
from flowrep import base_models
from flowrep.nodes import helper_models
from flowrep.parsers import symbol_scope
[docs]
class SourceCodeUnavailableError(ValueError):
    """Raised when the source code of a function cannot be obtained.

    Derives from ``ValueError``, so handlers catching ``ValueError`` also
    catch it.
    """
[docs]
def parser2decorator(
    func: FunctionType | str | None,
    output_labels: tuple[str, ...],
    *,
    parser: Callable[..., Any],
    decorator_name: str,
    parser_kwargs: dict[str, Any] | None = None,
) -> FunctionType | Callable[[FunctionType], FunctionType]:
    """Adapt a recipe ``parser`` into a decorator usable with or without arguments.

    Two usage forms are supported:

    * Bare decoration (``@decorator``): ``func`` is the function itself. It is
      parsed right away and returned with a ``flowrep_recipe`` attribute. In
      this form ``output_labels`` is not forwarded to the parser.
    * Decoration with arguments (``@decorator(...)`` or
      ``@decorator("label", ...)``): ``func`` is ``None`` or the first output
      label. A decorator is returned that parses the function it is applied to.

    Args:
        func: The decorated function, the first output label, or ``None``.
        output_labels: Further output labels, placed after ``func`` when
            ``func`` is a string.
        parser: Invoked as ``parser(f, *labels, **parser_kwargs)``. Its result
            is stored on ``f.flowrep_recipe``.
        decorator_name: Name shown in error messages.
        parser_kwargs: Extra keyword arguments forwarded to ``parser``.

    Returns:
        The decorated function (bare form) or a decorator (argument form).

    Raises:
        TypeError: If ``func`` is not a function, a string, or ``None``. The
            returned decorator also raises it when applied to a non-function.
    """
    extra_kwargs = parser_kwargs if parser_kwargs else {}

    if isinstance(func, FunctionType):
        # Bare form: @workflow / @atomic -- no labels reach the parser.
        labels: tuple[str, ...] = ()
    elif func is None or isinstance(func, str):
        # Argument form: @decorator(...) or @decorator("label", ...)
        labels = output_labels if func is None else (func,) + output_labels
    else:
        raise TypeError(
            f"{decorator_name} can only decorate functions, got {type(func).__name__}"
        )

    def attach_recipe(target: FunctionType) -> FunctionType:
        ensure_function(target, decorator_name)
        recipe = parser(target, *labels, **extra_kwargs)
        target.flowrep_recipe = recipe  # type: ignore[attr-defined]
        return target

    if isinstance(func, FunctionType):
        return attach_recipe(func)
    return attach_recipe
[docs]
def ensure_function(f: Any, decorator_name: str) -> None:
    """Raise ``TypeError`` unless ``f`` is an instance of ``types.FunctionType``.

    Args:
        f: The object a decorator is being applied to.
        decorator_name: Name of the decorator, used in the error message.

    Raises:
        TypeError: If ``f`` is not a ``FunctionType`` instance.
    """
    if isinstance(f, FunctionType):
        return
    raise TypeError(
        f"{decorator_name} can only decorate functions, got {type(f).__name__}"
    )
[docs]
def get_function_definition(tree: ast.Module) -> ast.FunctionDef:
    """Return the sole ``ast.FunctionDef`` in a parsed module.

    Args:
        tree: A module expected to contain exactly one top-level statement,
            which is a (non-async) function definition.

    Returns:
        That function definition node.

    Raises:
        ValueError: If the body has any other shape. This includes an empty
            body, several statements, or an ``async def``.
    """
    statements = tree.body
    candidate = statements[0] if len(statements) == 1 else None
    if isinstance(candidate, ast.FunctionDef):
        return candidate
    raise ValueError(
        "Expected ast to receive a single function definition, but got a body of "
        f"{[type(stmt) for stmt in statements]}"
    )
[docs]
def get_source_code(func: FunctionType) -> str:
    """Return the source code of ``func`` with common indentation removed.

    Args:
        func: The function whose source is wanted.

    Returns:
        The text from :func:`inspect.getsource`, passed through
        :func:`textwrap.dedent`.

    Raises:
        SourceCodeUnavailableError: If ``func`` is a lambda, or if
            :func:`inspect.getsource` raises ``OSError`` or ``TypeError``. In
            the latter case the original error is chained as the cause.
    """
    if func.__name__ == "<lambda>":
        raise SourceCodeUnavailableError(
            "Cannot parse return labels for lambda functions. "
            "Use a named function with @atomic decorator."
        )
    try:
        raw_source = inspect.getsource(func)
    except (OSError, TypeError) as err:
        raise SourceCodeUnavailableError(
            f"Cannot parse return labels for {func.__qualname__}: "
            "source code unavailable (lambdas, dynamically defined functions, "
            "and compiled code are not supported)"
        ) from err
    # Strip shared leading whitespace so indented definitions (e.g. nested
    # functions) yield text that starts at column 0.
    return textwrap.dedent(raw_source)
[docs]
def get_available_source_code(func: FunctionType) -> str | None:
    """Return ``func``'s source via ``get_source_code``, or ``None`` if unavailable.

    This is a non-raising variant: ``SourceCodeUnavailableError`` is mapped to
    ``None``, and any other exception propagates unchanged.
    """
    source: str | None
    try:
        source = get_source_code(func)
    except SourceCodeUnavailableError:
        source = None
    return source
[docs]
@dataclasses.dataclass(frozen=True)
class SignatureInfo:
    """Summary of a function's parameters, built from :func:`inspect.signature`.

    The dataclass is frozen, but its fields are a list, a list and a dict. The
    generated ``__hash__`` would therefore fail on them, so instances are not
    hashable.
    """

    # Parameter names, in signature order.
    names: list[str]
    # Names of the parameters that declare a default value.
    have_defaults: list[str]
    # Parameter name -> RestrictedParamKind derived from the parameter's kind.
    # Parameters for which ``from_param_kind`` returns ``None`` are omitted.
    have_restricted_kinds: dict[str, base_models.RestrictedParamKind]

    @classmethod
    def of(cls, func: FunctionType) -> SignatureInfo:
        """Build an instance describing the signature of ``func``.

        Instantiates ``cls`` rather than a hard-coded class name, so calling
        this on a subclass returns an instance of that subclass.
        """
        parameters = inspect.signature(func).parameters
        return cls(
            names=list(parameters),
            have_defaults=[
                label
                for label, param in parameters.items()
                if param.default is not inspect.Parameter.empty
            ],
            have_restricted_kinds={
                label: restricted
                for label, param in parameters.items()
                if (
                    restricted := base_models.RestrictedParamKind.from_param_kind(
                        param.kind
                    )
                )
                is not None
            },
        )
[docs]
def get_ast_function_node(func: FunctionType) -> ast.FunctionDef:
    """Parse ``func``'s source code and return its ``ast.FunctionDef`` node.

    Raises:
        SourceCodeUnavailableError: If the source cannot be retrieved
            (propagated from ``get_source_code``).
        ValueError: If the parsed source is not exactly one (non-async)
            function definition (propagated from ``get_function_definition``).
    """
    module_tree = ast.parse(get_source_code(func))
    return get_function_definition(module_tree)
[docs]
def resolve_symbols_to_strings(
node: (
ast.expr | None
), # Expecting a Name or Tuple[Name], and will otherwise TypeError
) -> list[str]:
if isinstance(node, ast.Name):
return [node.id]
elif isinstance(node, ast.Tuple) and all(
isinstance(elt, ast.Name) for elt in node.elts
):
return [cast(ast.Name, elt).id for elt in node.elts]
else:
raise TypeError(
f"Expected to receive a symbol or tuple of symbols from ast.Name or "
f"ast.Tuple, but could not parse this from {type(node)}."
)
[docs]
def consume_call_arguments(
    scope: symbol_scope.SymbolScope,
    ast_call: ast.Call,
    child: helper_models.LabeledRecipe,
) -> None:
    """Record all argument->port consumptions for a node-creating call.

    Positional argument ``i`` is consumed into port ``child.node.inputs[i]``.
    Each keyword argument is consumed into the port named by its keyword. Every
    call is ``scope.consume(symbol, child.label, port)``.

    Args:
        scope: Symbol scope that records the consumptions.
        ast_call: The call whose arguments are being wired.
        child: The labeled recipe the call creates.

    Raises:
        TypeError: If an argument value is not an ``ast.Name``, or if a keyword
            has no name (a ``**kwargs`` splat, not expected to reach here).
    """

    def _symbol_id(expr: ast.expr) -> str:
        # Only bare symbols can be matched against names known to the scope.
        if isinstance(expr, ast.Name):
            return expr.id
        raise TypeError(
            f"Workflow python definitions can only interpret function "
            f"calls with symbolic input, and thus expected to find an "
            f"ast.Name, but when parsing input for {child.label}, found a "
            f"type {type(expr)}"
        )

    for position, positional in enumerate(ast_call.args):
        symbol = _symbol_id(positional)
        scope.consume(symbol, child.label, child.node.inputs[position])

    for keyword in ast_call.keywords:
        symbol = _symbol_id(keyword.value)
        port = keyword.arg
        if not isinstance(port, str):  # pragma: no cover
            raise TypeError(
                "How did you get here? A `None` value should be possible for "
                "**kwargs, but variadics should have been excluded before "
                "this. Please raise a GitHub issue."
            )
        scope.consume(symbol, child.label, port)