From c62b8708a4981f583816b885f3392ef24f2df5b2 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 29 Jul 2023 19:52:05 +0200 Subject: [PATCH 1/8] Convert ABCs to protocols. --- docs/source/reference_guides/api.md | 2 +- src/_pytask/collect.py | 14 ++++++------ src/_pytask/collect_command.py | 8 +++---- src/_pytask/collect_utils.py | 24 ++++++++++----------- src/_pytask/dag.py | 11 +++++----- src/_pytask/database_utils.py | 4 ++-- src/_pytask/execute.py | 6 +++--- src/_pytask/hookspecs.py | 6 +++--- src/_pytask/node_protocols.py | 33 +++++++++++++++++++++++++++++ src/_pytask/nodes.py | 27 +++++++---------------- src/_pytask/profile.py | 4 ++-- src/_pytask/report.py | 2 +- src/_pytask/shared.py | 11 ++++------ src/pytask/__init__.py | 6 +++--- tests/test_collect_command.py | 17 +++++++-------- tests/test_dag.py | 4 ++-- tests/test_execute.py | 4 ++-- tests/test_nodes.py | 8 +++---- 18 files changed, 103 insertions(+), 88 deletions(-) create mode 100644 src/_pytask/node_protocols.py diff --git a/docs/source/reference_guides/api.md b/docs/source/reference_guides/api.md index 2227723db..bfd922e29 100644 --- a/docs/source/reference_guides/api.md +++ b/docs/source/reference_guides/api.md @@ -246,7 +246,7 @@ from {class}`pytask.MetaNode`. Then, different kinds of nodes can be implemented. ```{eval-rst} -.. autoclass:: pytask.FilePathNode +.. autoclass:: pytask.PathNode :members: ``` diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index d4d135c96..26db423fd 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -21,8 +21,8 @@ from _pytask.exceptions import CollectionError from _pytask.mark_utils import has_mark from _pytask.models import NodeInfo -from _pytask.nodes import FilePathNode -from _pytask.nodes import MetaNode +from _pytask.node_protocols import Node +from _pytask.nodes import PathNode from _pytask.nodes import PythonNode from _pytask.nodes import Task from _pytask.outcomes import CollectionOutcome @@ -95,7 +95,7 @@ def pytask_collect_file_protocol( ) flat_reports = list(itertools.chain.from_iterable(new_reports)) except Exception: # noqa: BLE001 - node = FilePathNode.from_path(path) + node = PathNode.from_path(path) flat_reports = [ CollectionReport.from_exception( outcome=CollectionOutcome.FAIL, node=node, exc_info=sys.exc_info() @@ -204,8 +204,8 @@ def pytask_collect_task( @hookimpl(trylast=True) -def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> MetaNode: - """Collect a node of a task as a :class:`pytask.nodes.FilePathNode`. +def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Node: + """Collect a node of a task as a :class:`pytask.nodes.PathNode`. Strings are assumed to be paths. This might be a strict assumption, but since this hook is executed at last and possible errors will be shown, it seems reasonable and @@ -223,7 +223,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Me node.name = node_info.arg_name + suffix return node - if isinstance(node, MetaNode): + if isinstance(node, Node): return node if isinstance(node, Path): @@ -243,7 +243,7 @@ def pytask_collect_node(session: Session, path: Path, node_info: NodeInfo) -> Me if str(node) != str(case_sensitive_path): raise ValueError(_TEMPLATE_ERROR.format(node, case_sensitive_path)) - return FilePathNode.from_path(node) + return PathNode.from_path(node) suffix = "-" + "-".join(map(str, node_info.path)) if node_info.path else "" node_name = node_info.arg_name + suffix diff --git a/src/_pytask/collect_command.py b/src/_pytask/collect_command.py index e615a8ea9..e2ee15c06 100644 --- a/src/_pytask/collect_command.py +++ b/src/_pytask/collect_command.py @@ -20,7 +20,7 @@ from _pytask.exceptions import ResolvingDependenciesError from _pytask.mark import select_by_keyword from _pytask.mark import select_by_mark -from _pytask.nodes import FilePathNode +from _pytask.node_protocols import PPathNode from _pytask.outcomes import ExitCode from _pytask.path import find_common_ancestor from _pytask.path import relative_to @@ -125,9 +125,7 @@ def _find_common_ancestor_of_all_nodes( all_paths.append(task.path) if show_nodes: all_paths.extend( - x.path - for x in tree_leaves(task.depends_on) - if isinstance(x, FilePathNode) + x.path for x in tree_leaves(task.depends_on) if isinstance(x, PPathNode) ) all_paths.extend(x.path for x in tree_leaves(task.produces)) @@ -205,7 +203,7 @@ def _print_collected_tasks( file_path_nodes = list(tree_leaves(task.depends_on)) sorted_nodes = sorted(file_path_nodes, key=lambda x: x.name) for node in sorted_nodes: - if isinstance(node, FilePathNode): + if isinstance(node, PPathNode): reduced_node_name = relative_to(node.path, common_ancestor) url_style = create_url_style_for_path( node.path, editor_url_scheme diff --git a/src/_pytask/collect_utils.py b/src/_pytask/collect_utils.py index b333d2c66..906285cfa 100644 --- a/src/_pytask/collect_utils.py +++ b/src/_pytask/collect_utils.py @@ -17,7 +17,7 @@ from _pytask.mark_utils import has_mark from _pytask.mark_utils import remove_marks from _pytask.models import NodeInfo -from _pytask.nodes import MetaNode +from _pytask.node_protocols import Node from _pytask.nodes import ProductType from _pytask.nodes import PythonNode from _pytask.shared import find_duplicates @@ -80,7 +80,7 @@ def parse_nodes( objects = _extract_nodes_from_function_markers(obj, parser) nodes = _convert_objects_to_node_dictionary(objects, arg_name) nodes = tree_map( - lambda x: _collect_decorator_nodes( + lambda x: _collect_decorator_node( session, path, name, NodeInfo(arg_name, (), x) ), nodes, @@ -250,7 +250,7 @@ def parse_dependencies_from_task_function( # noqa: C901 if "depends_on" in kwargs: has_depends_on_argument = True dependencies["depends_on"] = tree_map( - lambda x: _collect_decorator_nodes( + lambda x: _collect_decorator_node( session, path, name, NodeInfo(arg_name="depends_on", path=(), value=x) ), kwargs["depends_on"], @@ -281,7 +281,7 @@ def _evolve(x: Any) -> Any: return x nodes = tree_map_with_path( - lambda p, x: _collect_dependencies( + lambda p, x: _collect_dependency( session, path, name, @@ -302,7 +302,7 @@ def _evolve(x: Any) -> Any: return dependencies -def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, MetaNode]: +def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, Node]: """Find args with node annotations.""" annotations = get_annotations(func, eval_str=True) metas = { @@ -314,9 +314,7 @@ def _find_args_with_node_annotation(func: Callable[..., Any]) -> dict[str, MetaN args_with_node_annotation = {} for name, meta in metas.items(): annot = [ - i - for i in meta - if not isinstance(i, ProductType) and isinstance(i, MetaNode) + i for i in meta if not isinstance(i, ProductType) and isinstance(i, Node) ] if len(annot) >= 2: # noqa: PLR2004 raise ValueError( @@ -456,9 +454,9 @@ def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]: """ -def _collect_decorator_nodes( +def _collect_decorator_node( session: Session, path: Path, name: str, node_info: NodeInfo -) -> dict[str, MetaNode]: +) -> Node: """Collect nodes for a task. Raises @@ -495,9 +493,9 @@ def _collect_decorator_nodes( return collected_node -def _collect_dependencies( +def _collect_dependency( session: Session, path: Path, name: str, node_info: NodeInfo -) -> dict[str, MetaNode]: +) -> Node: """Collect nodes for a task. Raises @@ -525,7 +523,7 @@ def _collect_product( task_name: str, node_info: NodeInfo, is_string_allowed: bool = False, -) -> dict[str, MetaNode]: +) -> Node: """Collect products for a task. Defining products with strings is only allowed when using the decorator. Parameter diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index bee58b988..050aae214 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -21,8 +21,9 @@ from _pytask.mark import Mark from _pytask.mark_utils import get_marks from _pytask.mark_utils import has_mark -from _pytask.nodes import FilePathNode -from _pytask.nodes import MetaNode +from _pytask.node_protocols import MetaNode +from _pytask.node_protocols import Node +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.path import find_common_ancestor_of_nodes from _pytask.report import DagReport @@ -140,13 +141,13 @@ def pytask_dag_has_node_changed(node: MetaNode, task_name: str) -> bool: if db_state is None: return True - if isinstance(node, (FilePathNode, Task)): + if isinstance(node, (PPathNode, Task)): # If the modification times match, the node has not been changed. if node_state == db_state.modification_time: return False # If the modification time changed, quickly return for non-tasks. - if isinstance(node, FilePathNode): + if not isinstance(node, Task): return True # When modification times changed, we are still comparing the hash of the file @@ -238,7 +239,7 @@ def _check_if_root_nodes_are_available(dag: nx.DiGraph) -> None: def _check_if_tasks_are_skipped( - node: MetaNode, dag: nx.DiGraph, is_task_skipped: dict[str, bool] + node: Node, dag: nx.DiGraph, is_task_skipped: dict[str, bool] ) -> tuple[bool, dict[str, bool]]: """Check for a given node whether it is only used by skipped tasks.""" are_all_tasks_skipped = [] diff --git a/src/_pytask/database_utils.py b/src/_pytask/database_utils.py index 05a09de1b..19f5724b7 100644 --- a/src/_pytask/database_utils.py +++ b/src/_pytask/database_utils.py @@ -4,7 +4,7 @@ import hashlib from _pytask.dag_utils import node_and_neighbors -from _pytask.nodes import FilePathNode +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.session import Session from sqlalchemy import Column @@ -80,7 +80,7 @@ def update_states_in_database(session: Session, task_name: str) -> None: if isinstance(node, Task): modification_time = node.state() hash_ = hashlib.sha256(node.path.read_bytes()).hexdigest() - elif isinstance(node, FilePathNode): + elif isinstance(node, PPathNode): modification_time = node.state() hash_ = "" else: diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index bcc3fb418..a4d703117 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -21,7 +21,7 @@ from _pytask.exceptions import NodeNotFoundError from _pytask.mark import Mark from _pytask.mark_utils import has_mark -from _pytask.nodes import FilePathNode +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.outcomes import count_outcomes from _pytask.outcomes import Exit @@ -129,7 +129,7 @@ def pytask_execute_task_setup(session: Session, task: Task) -> None: # method for the node classes. for product in session.dag.successors(task.name): node = session.dag.nodes[product]["node"] - if isinstance(node, FilePathNode): + if isinstance(node, PPathNode): node.path.parent.mkdir(parents=True, exist_ok=True) would_be_executed = has_mark(task, "would_be_executed") @@ -159,7 +159,7 @@ def pytask_execute_task(session: Session, task: Task) -> bool: @hookimpl def pytask_execute_task_teardown(session: Session, task: Task) -> None: - """Check if :class:`_pytask.nodes.FilePathNode` are produced by a task.""" + """Check if :class:`_pytask.nodes.PathNode` are produced by a task.""" missing_nodes = [] for product in session.dag.successors(task.name): node = session.dag.nodes[product]["node"] diff --git a/src/_pytask/hookspecs.py b/src/_pytask/hookspecs.py index 3eb56599c..ef460c342 100644 --- a/src/_pytask/hookspecs.py +++ b/src/_pytask/hookspecs.py @@ -14,11 +14,11 @@ import networkx import pluggy from _pytask.models import NodeInfo +from _pytask.node_protocols import Node if TYPE_CHECKING: from _pytask.session import Session - from _pytask.nodes import MetaNode from _pytask.nodes import Task from _pytask.outcomes import CollectionOutcome from _pytask.outcomes import TaskOutcome @@ -196,7 +196,7 @@ def pytask_collect_task_teardown(session: Session, task: Task) -> None: @hookspec(firstresult=True) def pytask_collect_node( session: Session, path: pathlib.Path, node_info: NodeInfo -) -> MetaNode | None: +) -> Node | None: """Collect a node which is a dependency or a product of a task.""" @@ -266,7 +266,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: networkx.DiGraph) -> @hookspec(firstresult=True) def pytask_dag_has_node_changed( - session: Session, dag: networkx.DiGraph, node: MetaNode, task_name: str + session: Session, dag: networkx.DiGraph, node: Node, task_name: str ) -> None: """Select the subgraph which needs to be executed. diff --git a/src/_pytask/node_protocols.py b/src/_pytask/node_protocols.py new file mode 100644 index 000000000..f3062cb4c --- /dev/null +++ b/src/_pytask/node_protocols.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from abc import abstractmethod +from pathlib import Path +from typing import Any +from typing import Protocol +from typing import runtime_checkable + + +@runtime_checkable +class MetaNode(Protocol): + """Protocol for an intersection between nodes and tasks.""" + + name: str | None + """The name of node that must be unique.""" + + @abstractmethod + def state(self) -> Any: + ... + + +@runtime_checkable +class Node(MetaNode, Protocol): + """Protocol for nodes.""" + + value: Any + + +@runtime_checkable +class PPathNode(Node, Protocol): + """Nodes with paths.""" + + path: Path diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index b41c906a9..0dc8619ba 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -3,13 +3,13 @@ import functools import hashlib -from abc import ABCMeta -from abc import abstractmethod from pathlib import Path from typing import Any from typing import Callable from typing import TYPE_CHECKING +from _pytask.node_protocols import MetaNode +from _pytask.node_protocols import Node from _pytask.tree_util import PyTree from attrs import define from attrs import field @@ -19,7 +19,7 @@ from _pytask.mark import Mark -__all__ = ["FilePathNode", "MetaNode", "Product", "Task"] +__all__ = ["PathNode", "Product", "Task"] @define(frozen=True) @@ -30,17 +30,6 @@ class ProductType: Product = ProductType() -class MetaNode(metaclass=ABCMeta): - """Meta class for nodes.""" - - name: str - """str: The name of node that must be unique.""" - - @abstractmethod - def state(self) -> Any: - ... - - @define(kw_only=True) class Task(MetaNode): """The class for tasks which are Python functions.""" @@ -55,9 +44,9 @@ class Task(MetaNode): """The name of the task.""" short_name: str | None = field(default=None, init=False) """The shortest uniquely identifiable name for task for display.""" - depends_on: PyTree[MetaNode] = field(factory=dict) + depends_on: PyTree[Node] = field(factory=dict) """A list of dependencies of task.""" - produces: PyTree[MetaNode] = field(factory=dict) + produces: PyTree[Node] = field(factory=dict) """A list of products of task.""" markers: list[Mark] = field(factory=list) """A list of markers attached to the task function.""" @@ -92,7 +81,7 @@ def add_report_section(self, when: str, key: str, content: str) -> None: @define(kw_only=True) -class FilePathNode(MetaNode): +class PathNode(Node): """The class for a node which is a path.""" name: str = "" @@ -104,7 +93,7 @@ class FilePathNode(MetaNode): @classmethod @functools.lru_cache - def from_path(cls, path: Path) -> FilePathNode: + def from_path(cls, path: Path) -> PathNode: """Instantiate class from path to file. The `lru_cache` decorator ensures that the same object is not collected twice. @@ -126,7 +115,7 @@ def state(self) -> str | None: @define(kw_only=True) -class PythonNode(MetaNode): +class PythonNode(Node): """The class for a node which is a Python object.""" name: str = "" diff --git a/src/_pytask/profile.py b/src/_pytask/profile.py index e0b29d3bf..5c8a33e57 100644 --- a/src/_pytask/profile.py +++ b/src/_pytask/profile.py @@ -23,7 +23,7 @@ from _pytask.database_utils import DatabaseSession from _pytask.exceptions import CollectionError from _pytask.exceptions import ConfigurationError -from _pytask.nodes import FilePathNode +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.outcomes import ExitCode from _pytask.outcomes import TaskOutcome @@ -228,7 +228,7 @@ def pytask_profile_add_info_on_task( sum_bytes = 0 for successor in successors: node = session.dag.nodes[successor]["node"] - if isinstance(node, FilePathNode): + if isinstance(node, PPathNode): with suppress(FileNotFoundError): sum_bytes += node.path.stat().st_size diff --git a/src/_pytask/report.py b/src/_pytask/report.py index 0bc0cf03d..7df151fc3 100644 --- a/src/_pytask/report.py +++ b/src/_pytask/report.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: - from _pytask.nodes import MetaNode + from _pytask.node_protocols import MetaNode from _pytask.nodes import Task diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 4747813d7..6059bd1b2 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -10,8 +10,8 @@ import click import networkx as nx from _pytask.console import format_task_id -from _pytask.nodes import FilePathNode -from _pytask.nodes import MetaNode +from _pytask.node_protocols import MetaNode +from _pytask.node_protocols import PPathNode from _pytask.nodes import Task from _pytask.path import find_closest_ancestor from _pytask.path import find_common_ancestor @@ -67,7 +67,7 @@ def reduce_node_name(node: MetaNode, paths: Sequence[str | Path]) -> str: path from one path in ``session.config["paths"]`` to the node. """ - if isinstance(node, (Task, FilePathNode)): + if isinstance(node, (PPathNode, Task)): ancestor = find_closest_ancestor(node.path, paths) if ancestor is None: try: @@ -75,10 +75,7 @@ def reduce_node_name(node: MetaNode, paths: Sequence[str | Path]) -> str: except ValueError: ancestor = node.path.parents[-1] - if isinstance(node, MetaNode): - name = relative_to(node.path, ancestor).as_posix() - else: - raise TypeError(f"Unknown node {node} with type {type(node)!r}.") + name = relative_to(node.path, ancestor).as_posix() return name return node.name diff --git a/src/pytask/__init__.py b/src/pytask/__init__.py index 1b5eca521..d626d4dd3 100644 --- a/src/pytask/__init__.py +++ b/src/pytask/__init__.py @@ -36,8 +36,8 @@ from _pytask.mark_utils import set_marks from _pytask.models import CollectionMetadata from _pytask.models import NodeInfo -from _pytask.nodes import FilePathNode -from _pytask.nodes import MetaNode +from _pytask.node_protocols import MetaNode +from _pytask.nodes import PathNode from _pytask.nodes import Product from _pytask.nodes import PythonNode from _pytask.nodes import Task @@ -84,7 +84,7 @@ "ExecutionReport", "Exit", "ExitCode", - "FilePathNode", + "PathNode", "Mark", "MarkDecorator", "MarkGenerator", diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index 6ab330119..f28064930 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -7,11 +7,10 @@ import pytest from _pytask.collect_command import _find_common_ancestor_of_all_nodes from _pytask.collect_command import _print_collected_tasks -from _pytask.nodes import FilePathNode +from _pytask.nodes import PathNode from attrs import define from pytask import cli from pytask import ExitCode -from pytask import MetaNode from pytask import Task @@ -343,7 +342,7 @@ def task_example_2(): @define -class MetaNode(MetaNode): +class Node: path: Path def state(self): @@ -362,8 +361,8 @@ def test_print_collected_tasks_without_nodes(capsys): base_name="function", path=Path("task_path.py"), function=function, - depends_on={0: MetaNode("in.txt")}, - produces={0: MetaNode("out.txt")}, + depends_on={0: Node("in.txt")}, + produces={0: Node("out.txt")}, ) ] } @@ -386,12 +385,12 @@ def test_print_collected_tasks_with_nodes(capsys): path=Path("task_path.py"), function=function, depends_on={ - "depends_on": FilePathNode( + "depends_on": PathNode( name="in.txt", value=Path("in.txt"), path=Path("in.txt") ) }, produces={ - 0: FilePathNode( + 0: PathNode( name="out.txt", value=Path("out.txt"), path=Path("out.txt") ) }, @@ -418,10 +417,10 @@ def test_find_common_ancestor_of_all_nodes(show_nodes, expected_add): path=Path.cwd() / "src" / "task_path.py", function=function, depends_on={ - "depends_on": FilePathNode.from_path(Path.cwd() / "src" / "in.txt") + "depends_on": PathNode.from_path(Path.cwd() / "src" / "in.txt") }, produces={ - 0: FilePathNode.from_path( + 0: PathNode.from_path( Path.cwd().joinpath("..", "bld", "out.txt").resolve() ) }, diff --git a/tests/test_dag.py b/tests/test_dag.py index 017374eba..d78cfa862 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -14,12 +14,12 @@ from attrs import define from pytask import cli from pytask import ExitCode -from pytask import FilePathNode +from pytask import PathNode from pytask import Task @define -class Node(FilePathNode): +class Node(PathNode): """See https://github.com/python-attrs/attrs/issues/293 for property hack.""" name: str diff --git a/tests/test_execute.py b/tests/test_execute.py index a86dc1768..21b3909a2 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -530,11 +530,11 @@ def test_error_with_multiple_different_dep_annotations(runner, tmp_path): source = """ from pathlib import Path from typing_extensions import Annotated - from pytask import Product, PythonNode, FilePathNode + from pytask import Product, PythonNode, PathNode from typing import Any def task_example( - dependency: Annotated[Any, PythonNode(), FilePathNode()] = "hello", + dependency: Annotated[Any, PythonNode(), PathNode()] = "hello", path: Annotated[Path, Product] = Path("out.txt") ) -> None: path.write_text(dependency) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 9165a2f91..a58d7e501 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -5,7 +5,7 @@ import pytest from _pytask.shared import reduce_node_name -from pytask import FilePathNode +from pytask import PathNode _ROOT = Path.cwd() @@ -16,14 +16,14 @@ ("node", "paths", "expectation", "expected"), [ pytest.param( - FilePathNode.from_path(_ROOT.joinpath("src/module.py")), + PathNode.from_path(_ROOT.joinpath("src/module.py")), [_ROOT.joinpath("alternative_src")], does_not_raise(), "pytask/src/module.py", - id="Common path found for FilePathNode not in 'paths' and 'paths'", + id="Common path found for PathNode not in 'paths' and 'paths'", ), pytest.param( - FilePathNode.from_path(_ROOT.joinpath("top/src/module.py")), + PathNode.from_path(_ROOT.joinpath("top/src/module.py")), [_ROOT.joinpath("top/src")], does_not_raise(), "src/module.py", From e728ebacdffd2482fe195320c3bf7600a9613423 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 29 Jul 2023 21:00:45 +0200 Subject: [PATCH 2/8] Add test with custom node. --- tests/test_node_protocols.py | 46 ++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/test_node_protocols.py diff --git a/tests/test_node_protocols.py b/tests/test_node_protocols.py new file mode 100644 index 000000000..4d74cdc85 --- /dev/null +++ b/tests/test_node_protocols.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import pickle +import textwrap + +from pytask import cli +from pytask import ExitCode + + +def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import Product + from pathlib import Path + from attrs import define + import pickle + + @define + class PickleFile: + name: str + path: Path + + @property + def value(self): + with self.path.open("rb") as f: + out = pickle.load(f) + return out + + def state(self): + return str(self.path.stat().st_mtime) + + + _PATH = Path(__file__).parent.joinpath("in.pkl") + + def task_example( + data = PickleFile(_PATH.as_posix(), _PATH), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("in.pkl").write_bytes(pickle.dumps("text")) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").read_text() == "text" From 35affb351dce6a86755b63bbca02221888a794ba Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 29 Jul 2023 21:10:12 +0200 Subject: [PATCH 3/8] Add another test. --- tests/test_collect_command.py | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index f28064930..66ee25ec4 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import pickle import textwrap from pathlib import Path @@ -517,3 +518,42 @@ def task_example( assert "task_example>" in captured assert "" in result.output assert "Product" in captured + + +def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import Product + from pathlib import Path + from attrs import define + import pickle + + @define + class PickleFile: + name: str + path: Path + + @property + def value(self): + with self.path.open("rb") as f: + out = pickle.load(f) + return out + + def state(self): + return str(self.path.stat().st_mtime) + + + _PATH = Path(__file__).parent.joinpath("in.pkl") + + def task_example( + data = PickleFile(_PATH.as_posix(), _PATH), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("in.pkl").write_bytes(pickle.dumps("text")) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "in.pkl" in result.output From b9abc5ff5949ea818fef129b825409d6c1622217 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 29 Jul 2023 21:54:22 +0200 Subject: [PATCH 4/8] Add tests for custom nodes without paths. --- tests/test_collect_command.py | 29 +++++++++++++++++++++++++++++ tests/test_node_protocols.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index 66ee25ec4..1ac7e5710 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -520,6 +520,35 @@ def task_example( assert "Product" in captured +def test_node_protocol_for_custom_nodes(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import Product + from attrs import define + from pathlib import Path + + @define + class CustomNode: + name: str + value: str + + def state(self): + return self.value + + + def task_example( + data = CustomNode("custom", "text"), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "" in result.output + + def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path): source = """ from typing_extensions import Annotated diff --git a/tests/test_node_protocols.py b/tests/test_node_protocols.py index 4d74cdc85..4cf1944c0 100644 --- a/tests/test_node_protocols.py +++ b/tests/test_node_protocols.py @@ -7,6 +7,35 @@ from pytask import ExitCode +def test_node_protocol_for_custom_nodes(runner, tmp_path): + source = """ + from typing_extensions import Annotated + from pytask import Product + from attrs import define + from pathlib import Path + + @define + class CustomNode: + name: str + value: str + + def state(self): + return self.value + + + def task_example( + data = CustomNode("custom", "text"), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").read_text() == "text" + + def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path): source = """ from typing_extensions import Annotated From 2f9fe6150b3374d0c2bdd53af0ec9d57ece17136 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sun, 30 Jul 2023 09:38:00 +0200 Subject: [PATCH 5/8] Fix tests. --- tests/test_console.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/test_console.py b/tests/test_console.py index 92cb4fc99..bd0b5eca6 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -127,14 +127,7 @@ def test_render_to_string(color_system, text, strip_styles, expected): None, Text( _THIS_FILE.as_posix() + "::task_a", - spans=[ - Span(0, len(_THIS_FILE.as_posix()) + 2, "dim"), - Span( - len(_THIS_FILE.as_posix()) + 2, - len(_THIS_FILE.as_posix()) + 2 + 6, - Style(), - ), - ], + spans=[Span(0, len(_THIS_FILE.as_posix()) + 2, "dim")], ), id="format full id", ), @@ -146,7 +139,7 @@ def test_render_to_string(color_system, text, strip_styles, expected): None, Text( "test_console.py::task_a", - spans=[Span(0, 17, "dim"), Span(17, 23, Style())], + spans=[Span(0, 17, "dim")], ), id="format short id", ), @@ -158,7 +151,7 @@ def test_render_to_string(color_system, text, strip_styles, expected): _THIS_FILE.parent, Text( "tests/test_console.py::task_a", - spans=[Span(0, 23, "dim"), Span(23, 29, Style())], + spans=[Span(0, 23, "dim")], ), id="format relative to id", ), From 49beafc54e9f64d578ce4117c8dfb6402c905dab Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sun, 30 Jul 2023 15:02:54 +0200 Subject: [PATCH 6/8] small fix. --- src/_pytask/nodes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 0dc8619ba..50c19df11 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -89,7 +89,7 @@ class PathNode(Node): value: Path | None = None """Value passed to the decorator which can be requested inside the function.""" path: Path | None = None - """Path to the FilePathNode.""" + """Path to the file.""" @classmethod @functools.lru_cache @@ -100,7 +100,7 @@ def from_path(cls, path: Path) -> PathNode: """ if not path.is_absolute(): - raise ValueError("FilePathNode must be instantiated from absolute path.") + raise ValueError("Node must be instantiated from absolute path.") return cls(name=path.as_posix(), value=path, path=path) def state(self) -> str | None: @@ -142,4 +142,4 @@ def state(self) -> str | None: if isinstance(self.value, str): return str(hashlib.sha256(self.value.encode()).hexdigest()) return str(hash(self.value)) - return str(0) + return "0" From 4aa2398da55c55cf4b96ce7f8eb2f376808d8936 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sun, 30 Jul 2023 20:00:40 +0200 Subject: [PATCH 7/8] Refactor evolution of nodes. --- src/_pytask/collect_command.py | 17 +++++++++----- src/_pytask/collect_utils.py | 43 ++++++++++++++++++++-------------- src/_pytask/hookspecs.py | 3 ++- src/_pytask/nodes.py | 22 +++++++++++++---- tests/test_collect.py | 21 +++++++++++++++++ tests/test_collect_command.py | 31 +++++++++++++++++------- tests/test_dag.py | 5 ---- 7 files changed, 101 insertions(+), 41 deletions(-) diff --git a/src/_pytask/collect_command.py b/src/_pytask/collect_command.py index e2ee15c06..a8b3fcf75 100644 --- a/src/_pytask/collect_command.py +++ b/src/_pytask/collect_command.py @@ -200,24 +200,29 @@ def _print_collected_tasks( ) if show_nodes: - file_path_nodes = list(tree_leaves(task.depends_on)) - sorted_nodes = sorted(file_path_nodes, key=lambda x: x.name) + nodes = list(tree_leaves(task.depends_on)) + sorted_nodes = sorted(nodes, key=lambda x: x.name) for node in sorted_nodes: if isinstance(node, PPathNode): - reduced_node_name = relative_to(node.path, common_ancestor) + if node.path.as_posix() in node.name: + reduced_node_name = str( + relative_to(node.path, common_ancestor) + ) + else: + reduced_node_name = node.name url_style = create_url_style_for_path( node.path, editor_url_scheme ) - text = Text(str(reduced_node_name), style=url_style) + text = Text(reduced_node_name, style=url_style) else: text = node.name task_branch.add(Text.assemble(FILE_ICON, "")) for node in sorted(tree_leaves(task.produces), key=lambda x: x.path): - reduced_node_name = relative_to(node.path, common_ancestor) + reduced_node_name = str(relative_to(node.path, common_ancestor)) url_style = create_url_style_for_path(node.path, editor_url_scheme) - text = Text(str(reduced_node_name), style=url_style) + text = Text(reduced_node_name, style=url_style) task_branch.add(Text.assemble(FILE_ICON, "")) console.print(tree) diff --git a/src/_pytask/collect_utils.py b/src/_pytask/collect_utils.py index 906285cfa..a5e6e507c 100644 --- a/src/_pytask/collect_utils.py +++ b/src/_pytask/collect_utils.py @@ -1,6 +1,7 @@ """This module provides utility functions for :mod:`_pytask.collect`.""" from __future__ import annotations +import functools import itertools import uuid import warnings @@ -11,13 +12,13 @@ from typing import Iterable from typing import TYPE_CHECKING -import attrs from _pytask._inspect import get_annotations from _pytask.exceptions import NodeNotCollectedError from _pytask.mark_utils import has_mark from _pytask.mark_utils import remove_marks from _pytask.models import NodeInfo from _pytask.node_protocols import Node +from _pytask.node_protocols import PPathNode from _pytask.nodes import ProductType from _pytask.nodes import PythonNode from _pytask.shared import find_duplicates @@ -228,7 +229,7 @@ def _merge_dictionaries(list_of_dicts: list[dict[Any, Any]]) -> dict[Any, Any]: """ -def parse_dependencies_from_task_function( # noqa: C901 +def parse_dependencies_from_task_function( session: Session, path: Path, name: str, obj: Any ) -> dict[str, Any]: """Parse dependencies from task function.""" @@ -269,23 +270,17 @@ def parse_dependencies_from_task_function( # noqa: C901 if parameter_name == "depends_on": continue - if parameter_name in parameters_with_node_annot: - - def _evolve(x: Any) -> Any: - instance = parameters_with_node_annot[parameter_name] # noqa: B023 - return attrs.evolve(instance, value=x) # type: ignore[misc] - - else: - - def _evolve(x: Any) -> Any: - return x + partialed_evolve = functools.partial( + _evolve_instance, + instance_from_annot=parameters_with_node_annot.get(parameter_name), + ) nodes = tree_map_with_path( lambda p, x: _collect_dependency( session, path, name, - NodeInfo(parameter_name, p, _evolve(x)), # noqa: B023 + NodeInfo(parameter_name, p, partialed_evolve(x)), # noqa: B023 ), value, ) @@ -295,7 +290,7 @@ def _evolve(x: Any) -> Any: are_all_nodes_python_nodes_without_hash = all( isinstance(x, PythonNode) and not x.hash for x in tree_leaves(nodes) ) - if are_all_nodes_python_nodes_without_hash: + if not isinstance(nodes, Node) and are_all_nodes_python_nodes_without_hash: dependencies[parameter_name] = PythonNode(value=value, name=parameter_name) else: dependencies[parameter_name] = nodes @@ -378,6 +373,7 @@ def parse_products_from_task_function( kwargs = {**signature_defaults, **task_kwargs} parameters_with_product_annot = _find_args_with_product_annotation(obj) + parameters_with_node_annot = _find_args_with_node_annotation(obj) # Parse products from task decorated with @task and that uses produces. if "produces" in kwargs: @@ -402,13 +398,17 @@ def parse_products_from_task_function( has_annotation = True for parameter_name in parameters_with_product_annot: if parameter_name in kwargs: - # Use _collect_new_node to not collect strings. + partialed_evolve = functools.partial( + _evolve_instance, + instance_from_annot=parameters_with_node_annot.get(parameter_name), + ) + collected_products = tree_map_with_path( lambda p, x: _collect_product( session, path, name, - NodeInfo(parameter_name, p, x), # noqa: B023 + NodeInfo(parameter_name, p, partialed_evolve(x)), # noqa: B023 is_string_allowed=False, ), kwargs[parameter_name], @@ -544,7 +544,7 @@ def _collect_product( f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}." ) # The parameter defaults only support Path objects. - if not isinstance(node, Path) and not is_string_allowed: + if not isinstance(node, (Path, PPathNode)) and not is_string_allowed: raise ValueError( "If you declare products with 'Annotated[..., Product]', only values of " "type 'pathlib.Path' optionally nested in tuples, lists, and " @@ -564,3 +564,12 @@ def _collect_product( ) return collected_node + + +def _evolve_instance(x: Any, instance_from_annot: Node | None) -> Any: + """Evolve a value to a node if it is given by annotations.""" + if not instance_from_annot: + return x + + instance_from_annot.value = x + return instance_from_annot diff --git a/src/_pytask/hookspecs.py b/src/_pytask/hookspecs.py index ef460c342..dfa83f92f 100644 --- a/src/_pytask/hookspecs.py +++ b/src/_pytask/hookspecs.py @@ -14,6 +14,7 @@ import networkx import pluggy from _pytask.models import NodeInfo +from _pytask.node_protocols import MetaNode from _pytask.node_protocols import Node @@ -266,7 +267,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: networkx.DiGraph) -> @hookspec(firstresult=True) def pytask_dag_has_node_changed( - session: Session, dag: networkx.DiGraph, node: Node, task_name: str + session: Session, dag: networkx.DiGraph, node: MetaNode, task_name: str ) -> None: """Select the subgraph which needs to be executed. diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 50c19df11..b86503109 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -86,10 +86,24 @@ class PathNode(Node): name: str = "" """Name of the node which makes it identifiable in the DAG.""" - value: Path | None = None + _value: Path | None = None """Value passed to the decorator which can be requested inside the function.""" - path: Path | None = None - """Path to the file.""" + + @property + def path(self) -> Path: + return self.value + + @property + def value(self) -> Path: + return self._value + + @value.setter + def value(self, value: Path) -> None: + if not isinstance(value, Path): + raise TypeError("'value' must be a 'pathlib.Path'.") + if not self.name: + self.name = value.as_posix() + self._value = value @classmethod @functools.lru_cache @@ -101,7 +115,7 @@ def from_path(cls, path: Path) -> PathNode: """ if not path.is_absolute(): raise ValueError("Node must be instantiated from absolute path.") - return cls(name=path.as_posix(), value=path, path=path) + return cls(name=path.as_posix(), value=path) def state(self) -> str | None: """Calculate the state of the node. diff --git a/tests/test_collect.py b/tests/test_collect.py index db951b15e..3e6c117b4 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -422,3 +422,24 @@ def task_write_text(depends_on, produces): assert "FutureWarning" in result.output assert "Using strings to specify a dependency" in result.output assert "Using strings to specify a product" in result.output + + +@pytest.mark.end_to_end() +def test_setting_name_for_path_node_via_annotation(tmp_path): + source = """ + from pathlib import Path + from typing_extensions import Annotated + from pytask import Product, PathNode + from typing import Any + + def task_example( + path: Annotated[Path, Product, PathNode(name="product")] = Path("out.txt"), + ) -> None: + path.write_text("text") + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + session = main({"paths": [tmp_path]}) + assert session.exit_code == ExitCode.OK + product = session.tasks[0].produces["path"] + assert product.name == "product" diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index 1ac7e5710..4856db532 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -386,15 +386,9 @@ def test_print_collected_tasks_with_nodes(capsys): path=Path("task_path.py"), function=function, depends_on={ - "depends_on": PathNode( - name="in.txt", value=Path("in.txt"), path=Path("in.txt") - ) - }, - produces={ - 0: PathNode( - name="out.txt", value=Path("out.txt"), path=Path("out.txt") - ) + "depends_on": PathNode(name="in.txt", value=Path("in.txt")) }, + produces={0: PathNode(name="out.txt", value=Path("out.txt"))}, ) ] } @@ -586,3 +580,24 @@ def task_example( result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) assert result.exit_code == ExitCode.OK assert "in.pkl" in result.output + + +@pytest.mark.end_to_end() +def test_setting_name_for_python_node_via_annotation(runner, tmp_path): + source = """ + from pathlib import Path + from typing_extensions import Annotated + from pytask import Product, PythonNode + from typing import Any + + def task_example( + input: Annotated[str, PythonNode(name="node-name")] = "text", + path: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + path.write_text(input) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert "node-name" in result.output diff --git a/tests/test_dag.py b/tests/test_dag.py index d78cfa862..09259abc8 100644 --- a/tests/test_dag.py +++ b/tests/test_dag.py @@ -3,7 +3,6 @@ import textwrap from contextlib import ExitStack as does_not_raise # noqa: N813 from pathlib import Path -from typing import Any import networkx as nx import pytest @@ -22,10 +21,6 @@ class Node(PathNode): """See https://github.com/python-attrs/attrs/issues/293 for property hack.""" - name: str - value: Any - path: Path - def state(self): if "missing" in self.name: raise NodeNotFoundError From eb0885577911b127b6d7e6696324888626de6728 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sun, 30 Jul 2023 20:40:14 +0200 Subject: [PATCH 8/8] Fix changes. --- docs/source/changes.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/changes.md b/docs/source/changes.md index 18c6c6c6d..7a535918c 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -17,8 +17,12 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`395` refactors all occurrences of pybaum to {mod}`_pytask.tree_util`. - {pull}`396` replaces pybaum with optree and adds paths to the name of {class}`pytask.PythonNode`'s allowing for better hashing. -- {class}`397` adds support for {class}`typing.NamedTuple` and attrs classes in +- {pull}`397` adds support for {class}`typing.NamedTuple` and attrs classes in `@pytask.mark.task(kwargs=...)`. +- {pull}`398` deprecates the decorators `@pytask.mark.depends_on` and + `@pytask.mark.produces`. +- {pull}`402` replaces ABCs with protocols allowing for more flexibility for users + implementing their own nodes. ## 0.3.2 - 2023-06-07