diff --git a/docs/source/changes.md b/docs/source/changes.md
index 08604c9ba..eaa1292d0 100644
--- a/docs/source/changes.md
+++ b/docs/source/changes.md
@@ -24,9 +24,10 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
 - {pull}`402` replaces ABCs with protocols allowing for more flexibility for users
   implementing their own nodes.
 - {pull}`404` allows to use function returns to define task products.
-- {pull}`405` allows to match function returns to node annotations with prefix trees.
-- {pull}`406` removes `.value` from `Node` protocol.
-- {pull}`407` make `.from_annot` an optional feature of nodes.
+- {pull}`406` allows to match function returns to node annotations with prefix trees.
+- {pull}`408` removes `.value` from `Node` protocol.
+- {pull}`409` makes `.from_annot` an optional feature of nodes.
+- {pull}`410` allows to pass functions to `PythonNode(hash=...)`.
 
 ## 0.3.2 - 2023-06-07
 
diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py
index ca506f261..c653304aa 100644
--- a/src/_pytask/nodes.py
+++ b/src/_pytask/nodes.py
@@ -174,7 +174,7 @@ class PythonNode(Node):
     """Name of the node."""
     value: Any = None
     """Value of the node."""
-    hash: bool = False  # noqa: A003
-    """Whether the value should be hashed to determine the state."""
+    hash: bool | Callable[[Any], int | str] = False  # noqa: A003
+    """Whether the value should be hashed, or a callable that computes the hash."""
 
     def load(self) -> Any:
@@ -206,6 +206,8 @@ def state(self) -> str | None:
         If ``hash = False``, the function returns ``"0"``, a constant hash value, so the
         :class:`PythonNode` is ignored when checking for a changed state of the task.
 
+        If ``hash`` is a callable, it is called with the value to compute the state.
+
         If ``hash = True``, :func:`hash` is used for all types except strings.
 
         The hash for strings is calculated using hashlib because ``hash("asd")`` returns
@@ -214,6 +216,8 @@ def state(self) -> str | None:
 
         """
         if self.hash:
+            if callable(self.hash):
+                return str(self.hash(self.value))
             if isinstance(self.value, str):
                 return str(hashlib.sha256(self.value.encode()).hexdigest())
             return str(hash(self.value))
diff --git a/tests/test_nodes.py b/tests/test_nodes.py
index a58d7e501..067df04d3 100644
--- a/tests/test_nodes.py
+++ b/tests/test_nodes.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 
 import pytest
+from _pytask.nodes import PythonNode
 from _pytask.shared import reduce_node_name
 from pytask import PathNode
 
@@ -35,3 +36,20 @@ def test_reduce_node_name(node, paths, expectation, expected):
     with expectation:
         result = reduce_node_name(node, paths)
         assert result == expected
+
+
+@pytest.mark.unit()
+@pytest.mark.parametrize(
+    ("value", "hash_", "expected"),
+    [
+        (0, False, "0"),
+        (0, True, "0"),
+        (0, lambda x: 1, "1"),  # noqa: ARG005
+        ("0", False, "0"),
+        ("0", True, "5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91b46729d73a27fb57e9"),
+    ],
+)
+def test_hash_of_python_node(value, hash_, expected):
+    node = PythonNode(name="test", value=value, hash=hash_)
+    state = node.state()
+    assert state == expected