diff --git a/src/daggerml/__init__.py b/src/daggerml/__init__.py index 607f9a5..a1ab66a 100644 --- a/src/daggerml/__init__.py +++ b/src/daggerml/__init__.py @@ -5,7 +5,7 @@ with strong typing support and a context-manager based interface. """ -from daggerml.core import Dag, Dml, Error, Node, Resource +from daggerml.core import Dag, Dml, Error, Executable, Node, Resource try: from daggerml.__about__ import __version__ @@ -13,4 +13,4 @@ __version__ = "local" -__all__ = ("Dag", "Dml", "Error", "Node", "Resource") +__all__ = ("Dag", "Dml", "Error", "Executable", "Node", "Resource") diff --git a/src/daggerml/core.py b/src/daggerml/core.py index 7271d74..f9f2b35 100644 --- a/src/daggerml/core.py +++ b/src/daggerml/core.py @@ -6,24 +6,15 @@ import traceback as tb from dataclasses import dataclass, field, fields from tempfile import TemporaryDirectory -from typing import Any, Callable, Optional, Union, cast - -from daggerml.util import ( - BackoffWithJitter, - current_time_millis, - kwargs2opts, - postwalk, - properties, - raise_ex, - replace, - setter, -) +from typing import Any, Callable, Dict, Iterator, Optional, Union, cast, overload + +from daggerml.util import BackoffWithJitter, current_time_millis, kwargs2opts, raise_ex, replace log = logging.getLogger(__name__) DATA_TYPE = {} -Scalar = Union[str, int, float, bool, type(None), "Resource", "Node"] +Scalar = Union[str, int, float, bool, type(None), "Resource", "Executable"] Collection = Union[list, tuple, set, dict] @@ -91,11 +82,26 @@ class Ref: # noqa: F811 @dml_type -@dataclass(frozen=True) +@dataclass class Resource: # noqa: F811 """ Representation of an externally managed object with an identifier. + Parameters + ---------- + uri : str + Resource URI + """ + + uri: str + + +@dml_type +@dataclass +class Executable(Resource): # noqa: F811 + """ + Representation of an executable externally managed object with an identifier. + Parameters ---------- uri : str @@ -103,12 +109,12 @@ class Resource: # noqa: F811 data : str, optional Associated data adapter : str, optional - Resource adapter name + Adapter cli script """ - uri: str - data: Optional[str] = None + data: dict = field(default_factory=dict) adapter: Optional[str] = None + prepop: Dict[str, Union["Node", Scalar, Collection]] = field(default_factory=dict) @dml_type @@ -163,6 +169,11 @@ class Dml: token: Union[str, None] = None tmpdirs: dict[str, TemporaryDirectory] = field(default_factory=dict) + @property + def index(self) -> Optional[str]: + if self.token: + return json.loads(self.token)[-1] + @property def kwargs(self) -> dict: out = { @@ -245,12 +256,7 @@ def new(self, name="", message="", data=None, message_handler=None) -> "Dag": return Dag(replace(self, token=token), message_handler) def load(self, name: Union[str, "Node"], recurse=False) -> "Dag": - return Dag(replace(self, token=None), _ref=self.get_dag(name, recurse=recurse)) - - -@dataclass -class Boxed: - value: Any + return Dag(replace(self, token=None), ref=self.get_dag(name, recurse=recurse)) def make_node(dag: "Dag", ref: Ref) -> "Node": @@ -269,27 +275,29 @@ def make_node(dag: "Dag", ref: Ref) -> "Node": Node A Node instance representing the reference in the DAG. """ - info = dag._dml("node", "describe", ref.to) + info = dag.dml("node", "describe", ref.to) if info["data_type"] == "list": return ListNode(dag, ref, _info=info) if info["data_type"] == "dict": return DictNode(dag, ref, _info=info) if info["data_type"] == "set": return ListNode(dag, ref, _info=info) + if info["data_type"] == "executable": + return ExecutableNode(dag, ref, _info=info) if info["data_type"] == "resource": - return ResourceNode(dag, ref, _info=info) - return Node(dag, ref, _info=info) + return ExecutableNode(dag, ref, _info=info) + return ScalarNode(dag, ref, _info=info) @dataclass class Dag: - _dml: Dml - _message_handler: Optional[Callable] = None - _ref: Optional[Ref] = None - _init_complete: bool = False + dml: Dml + message_handler: Optional[Callable] = None + ref: Optional[Ref] = None - def __post_init__(self): - self._init_complete = True + def __repr__(self): + to = self.ref.to if self.ref else self.dml.index or "NA" + return f"Dag({to})" def __hash__(self): "Useful only for tests." @@ -297,74 +305,98 @@ def __hash__(self): def __enter__(self): "Catch exceptions and commit an Error" - assert not self._ref + assert not self.ref return self def __exit__(self, exc_type, exc_value, traceback): if exc_value is not None: - self._commit(Error.from_ex(exc_value)) + self.commit(Error.from_ex(exc_value)) - def __getitem__(self, name) -> "Node": - return make_node(self, self._dml.get_node(name, self._ref)) + def __getitem__(self, name): + return make_node(self, self.dml.get_node(name, self.ref)) - def __setitem__(self, name, value) -> "Node": - assert not self._ref + def __setitem__(self, name, value): + assert not self.ref if isinstance(value, Ref): - return self._dml.set_node(name, value) - return self._put(value, name=name) + return self.dml.set_node(name, value) + return self.put(value, name=name) + + def __setattr__(self, name, value): + if name in [x.name for x in fields(self.__class__)]: + return super().__setattr__(name, value) + return self.__setitem__(name, value) + + def __getattr__(self, name): + if name in [x.name for x in fields(self.__class__)]: + return super().__getattribute__(name) + return self.__getitem__(name) def __len__(self) -> int: - return len(self._dml.get_names(self._ref)) + return len(self.dml.get_names(self.ref)) def __iter__(self): - for k in self.keys(): - yield k + yield from self.keys() - def __setattr__(self, name, value): - priv = name.startswith("_") - flds = name in {x.name for x in fields(self)} - prps = name in properties(self) - init = not self._init_complete - boxd = isinstance(value, Boxed) - if (flds and init) or (not self._ref and ((not flds and not priv) or prps or boxd)): - value = value.value if boxd else value - if flds or (prps and setter(self, name)): - return super(Dag, self).__setattr__(name, value) - elif not prps: - return self.__setitem__(name, value) - raise AttributeError(f"can't set attribute: '{name}'") + def keys(self) -> list[str]: + """Get the list of all node names in the dag""" + return self.dml.get_names(self.ref).keys() - def __getattr__(self, name): - return self.__getitem__(name) + def values(self) -> list["Node"]: + """Get the list of all nodes in the dag""" + nodes = self.dml.get_names(self.ref).values() + return [make_node(self, x) for x in nodes] @property - def argv(self) -> "Node": + def argv(self) -> "ListNode": "Access the dag's argv node" - return make_node(self, self._dml.get_argv(self._ref)) + return make_node(self, self.dml.get_argv(self.ref)) @property def result(self) -> "Node": - ref = self._dml.get_result(self._ref) - assert ref, f"'{self.__class__.__name__}' has no attribute 'result'" + """Get the result node of the dag""" + ref = self.dml.get_result(self.ref) + assert isinstance(ref, Ref), f"'{self.__class__.__name__}' dag has not been committed yet" return make_node(self, ref) - @result.setter - def result(self, value): - return self._commit(value) - - @property - def keys(self) -> list[str]: - return lambda: self._dml.get_names(self._ref).keys() + def import_(self, dag_name: str, *, name=None, doc=None) -> "Node": + """Import a dag result into this dag - @property - def values(self) -> list["Node"]: - def result(): - nodes = self._dml.get_names(self._ref).values() - return [make_node(self, x) for x in nodes] + Parameters + ---------- + dag_name : str + Name of the dag to load + name : str, optional + Name for the node + doc : str, optional + Documentation for the node - return result + Returns + ------- + Node + Import Node representing the result of the loaded dag - def _put(self, value: Union[Scalar, Collection], *, name=None, doc=None) -> "Node": + Examples + -------- + >>> dml = Dml.temporary() + >>> dml.new("my-dag-0", "going to import this").commit(42) + >>> dag = dml.new("my-dag-1", "importing my-dag-0") + >>> node = dag.import_("my-dag-0") + >>> node.value() + 42 + """ + return self.put(self.dml.load(dag_name).result, name=name, doc=doc) + + @overload + def put(self, value: Union[list, set, "ListNode"], *, name=None, doc=None) -> "ListNode": ... + @overload + def put(self, value: Union[dict, "DictNode"], *, name=None, doc=None) -> "DictNode": ... + @overload + def put(self, value: Union[Executable, "ExecutableNode"], *, name=None, doc=None) -> "ExecutableNode": ... + @overload + def put(self, value: Union[Scalar, "ScalarNode"], *, name=None, doc=None) -> "ScalarNode": ... + @overload + def put(self, value: "Node", *, name=None, doc=None) -> "Node": ... + def put(self, value: Union[Scalar, Collection, "Node"], *, name=None, doc=None) -> "Node": """ Add a value to the DAG. @@ -381,36 +413,99 @@ def _put(self, value: Union[Scalar, Collection], *, name=None, doc=None) -> "Nod ------- Node Node representing the value + + Examples + -------- + >>> dml = Dml.temporary() + >>> dag = dml.new("test", "test") + >>> n1 = dag.put(42, name="answer", doc="the answer to life, the universe, and everything") + >>> n1.value() + 42 + >>> n2 = dag.put({"a": 1, "b": [n1, "23"]}) + >>> n2.value() + {'a': 1, 'b': [42, '23']} + >>> dml.new("other-dag", "we'll import from here").commit(308) # create and commit another dag to import + >>> n3 = dag.load("other-dag") + >>> n3.value() + 308 """ - value = postwalk( - value, - lambda x: isinstance(x, Node) and x.dag._ref, - lambda x: self._load(x.dag, x.ref), - ) - return make_node(self, self._dml.put_literal(value, name=name, doc=doc)) + if isinstance(value, Node) and value.dag != self: + return make_node(self, self.dml.put_load(value.dag.ref, value.ref, name=name, doc=doc)) + return make_node(self, self.dml.put_literal(value, name=name, doc=doc)) - def _load(self, dag_name, node=None, *, name=None, doc=None) -> "Node": + def get(self, name: str) -> "Node": """ - Load a DAG by name. + Get a node reference by name. Parameters ---------- - dag_name : str - Name of the DAG to load + name : str + Name of the node + + Returns + ------- + Node + Node representing the named node + + Raises + ------ + KeyError + If the node is not found + """ + if name in self.nodes: + return self.nodes[name] + raise KeyError(f"Node '{name}' not found in DAG") + + def call( + self, + fn: Union[Executable, "ExecutableNode"], + *args: Union["Node", Scalar, Collection], + name: Optional[str] = None, + doc: Optional[str] = None, + sleep: Optional[callable] = None, + timeout: int = 30000, + ) -> "Node": + """ + Call a function node with arguments. + + Parameters + ---------- + fn : Union[Executable, ExecutableNode] + Function to call + *args : Union[Node, Scalar, Collection] + Arguments to pass to the function name : str, optional - Name for the node + Name for the result node doc : str, optional Documentation + sleep : callable, optional + A nullary function that returns sleep time in milliseconds + timeout : int, default=30000 + Maximum time to wait in milliseconds Returns ------- Node - Node representing the loaded DAG + Result node + + Raises + ------ + TimeoutError + If the function call exceeds the timeout + Error + If the function returns an error """ - dag = dag_name if isinstance(dag_name, str) else dag_name._ref - return make_node(self, self._dml.put_load(dag, node, name=name, doc=doc)) + sleep = sleep or BackoffWithJitter() + expr = [self.put(x) for x in [fn, *args]] + end = current_time_millis() + timeout + while timeout <= 0 or current_time_millis() < end: + resp = self.dml.start_fn(expr, name=name, doc=doc) + if resp: + return make_node(self, resp) + time.sleep(sleep() / 1000) + raise TimeoutError(f"invoking function: {expr[0].value()}") - def _commit(self, value) -> "Node": + def commit(self, value) -> None: """ Commit a value to the DAG. @@ -419,11 +514,11 @@ def _commit(self, value) -> "Node": value : Union[Node, Error, Any] Value to commit """ - value = value if isinstance(value, (Node, Error)) else self._put(value) - ref = cast(Ref, self._dml.commit(value)) - if self._message_handler: - self._message_handler(self._dml("ref", "dump", to_json(ref), as_text=True)) - self._ref = Boxed(ref) + value = value if isinstance(value, (Node, Error)) else self.put(value) + ref = cast(Ref, self.dml.commit(value)) + if self.message_handler: + self.message_handler(self.dml("ref", "dump", to_json(ref), as_text=True)) + self.ref = ref @dataclass(frozen=True) @@ -453,19 +548,17 @@ def __hash__(self): @property def argv(self) -> "Node": "Access the node's argv list" - return [make_node(self.dag, x) for x in self.dag._dml.get_argv(self)] + return [make_node(self.dag, x) for x in self.dag.dml.get_argv(self)] - def load(self, *keys: Union[str, int]) -> Dag: + def backtrack(self, *keys: Union[str, int]) -> "Node": """ - Convenience wrapper around `dml.load(node)` - If `key` is provided, it considers this node to be a collection created by the appropriate method and loads the dag that corresponds to this key Parameters ---------- *keys : str, optional - Key to load from the DAG. If not provided, the entire DAG is loaded. + Keys to backtrack through the node's structure Returns ------- @@ -476,23 +569,42 @@ def load(self, *keys: Union[str, int]) -> Dag: -------- >>> dml = Dml.temporary() >>> dag = dml.new("test", "test") - >>> l0 = dag._put(42) - >>> c0 = dag._put({"a": 1, "b": [l0, "23"]}) - >>> assert c0.load("b", 0) == l0 - >>> assert c0.load("b").load(0) == l0 + >>> l0 = dag.put(42) + >>> c0 = dag.put({"a": 1, "b": [l0, "23"]}) + >>> assert c0.backtrack("b", 0) == l0 + >>> assert c0.backtrack("b").backtrack(0) == l0 >>> assert c0["b"][0] != l0 # this is a different node, not the same as l0 >>> dml.cleanup() """ - if len(keys) == 0: - return self.dag._dml.load(self) - data = self.dag._dml("node", "backtrack", self.ref.to, *map(str, keys)) + data = self.dag.dml("node", "backtrack", self.ref.to, *map(str, keys)) return make_node(self.dag, from_data(data)) + def load(self) -> Dag: + """ + Convenience wrapper around `dml.load(node)` + + Returns + ------- + Dag + The dag that this node was imported from (or in the case of a function call, this returns the fndag) + """ + return self.dag.dml.load(self) + @property def type(self): """Get the data type of the node.""" return self._info["data_type"] + @overload + def value(self: "ScalarNode") -> Scalar: ... + @overload + def value(self: "ListNode") -> list: ... + @overload + def value(self: "DictNode") -> dict: ... + @overload + def value(self: "ExecutableNode") -> Executable: ... + @overload + def value(self: "Node") -> Any: ... def value(self): """ Get the concrete value of this node. @@ -502,10 +614,14 @@ def value(self): Any The actual value represented by this node """ - return self.dag._dml.get_node_value(self.ref) + return self.dag.dml.get_node_value(self.ref) + + +class ScalarNode(Node): + pass -class ResourceNode(Node): +class ExecutableNode(Node): def __call__(self, *args, name=None, doc=None, sleep=None, timeout=0) -> "Node": """ Call this node as a function. @@ -536,10 +652,10 @@ def __call__(self, *args, name=None, doc=None, sleep=None, timeout=0) -> "Node": If the function returns an error """ sleep = sleep or BackoffWithJitter() - args = [self.dag._put(x) for x in args] + args = [self.dag.put(x) for x in args] end = current_time_millis() + timeout while timeout <= 0 or current_time_millis() < end: - resp = self.dag._dml.start_fn([self, *args], name=name, doc=doc) + resp = self.dag.dml.start_fn([self, *args], name=name, doc=doc) if resp: return make_node(self.dag, resp) time.sleep(sleep() / 1000) @@ -558,7 +674,11 @@ class CollectionNode(Node): # noqa: F811 Node reference """ - def __getitem__(self, key: Union[slice, str, int, "Node"]) -> "Node": + @overload + def __getitem__(self, key: slice) -> "ListNode": ... + @overload + def __getitem__(self, key: Union[str, int, "Node"]) -> Any: ... + def __getitem__(self, key: Union[slice, str, int, "Node"]) -> Any: """ Get the `key` item. It should be the same as if you were working on the actual value. @@ -577,7 +697,7 @@ def __getitem__(self, key: Union[slice, str, int, "Node"]) -> "Node": -------- >>> dml = Dml.temporary() >>> dag = dml.new("test", "test") - >>> node = dag._put({"a": 1, "b": [5, 6]}) + >>> node = dag.put({"a": 1, "b": [5, 6]}) >>> nested = node["a"] >>> isinstance(nested, Node) True @@ -588,9 +708,9 @@ def __getitem__(self, key: Union[slice, str, int, "Node"]) -> "Node": """ if isinstance(key, slice): key = [key.start, key.stop, key.step] - return make_node(self.dag, self.dag._dml.get(self, key)) + return make_node(self.dag, self.dag.dml.get(self, key)) - def contains(self, item, *, name=None, doc=None): + def contains(self, item, *, name=None, doc=None) -> "ScalarNode": """ For collection nodes, checks to see if `item` is in `self` @@ -599,7 +719,7 @@ def contains(self, item, *, name=None, doc=None): Node Node with the boolean of is `item` in `self` """ - return make_node(self.dag, self.dag._dml.contains(self, item, name=name, doc=doc)) + return make_node(self.dag, self.dag.dml.contains(self, item, name=name, doc=doc)) def __contains__(self, item): return self.contains(item).value() # has to return boolean @@ -622,13 +742,13 @@ def __len__(self): # python requires this to be an int return self._info["length"] raise Error(f"Cannot get length of type: {self._info['data_type']}", origin="dml", type="TypeError") - def get(self, key, default=None, *, name=None, doc=None): + def get(self, key, default=None, *, name=None, doc=None) -> "Node": """ For a dict node, return the value for key if key exists, else default. If default is not given, it defaults to None, so that this method never raises a KeyError. """ - return make_node(self.dag, self.dag._dml.get(self, key, default, name=name, doc=doc)) + return make_node(self.dag, self.dag.dml.get(self, key, default, name=name, doc=doc)) class ListNode(CollectionNode): # noqa: F811 @@ -661,7 +781,7 @@ def __iter__(self): for i in range(len(self)): yield self[i] - def conj(self, item, *, name=None, doc=None): + def conj(self, item, *, name=None, doc=None) -> "ListNode": """ For a list or set node, append an item @@ -674,9 +794,9 @@ def conj(self, item, *, name=None, doc=None): ----- `append` is an alias `conj` """ - return make_node(self.dag, self.dag._dml.conj(self, item, name=name, doc=doc)) + return make_node(self.dag, self.dag.dml.conj(self, item, name=name, doc=doc)) - def append(self, item, *, name=None, doc=None): + def append(self, item, *, name=None, doc=None) -> "ListNode": """ For a list or set node, append an item @@ -729,7 +849,7 @@ def __iter__(self): for k in self.keys(): yield k - def items(self): + def items(self) -> Iterator[tuple[str, "Node"]]: """ Iterate over key-value pairs of a dictionary node. @@ -761,7 +881,7 @@ def values(self) -> list["Node"]: """ return [self[k] for k in self] - def assoc(self, key, value, *, name=None, doc=None): + def assoc(self, key, value, *, name=None, doc=None) -> "DictNode": """ For a dict node, associate a new value into the map @@ -770,9 +890,9 @@ def assoc(self, key, value, *, name=None, doc=None): Node Node containing the new dict """ - return make_node(self.dag, self.dag._dml.assoc(self, key, value, name=name, doc=doc)) + return make_node(self.dag, self.dag.dml.assoc(self, key, value, name=name, doc=doc)) - def update(self, update): + def update(self, update) -> "DictNode": """ For a dict node, update like python dicts diff --git a/submodules/daggerml_cli b/submodules/daggerml_cli index cbb1a1a..21d02d0 160000 --- a/submodules/daggerml_cli +++ b/submodules/daggerml_cli @@ -1 +1 @@ -Subproject commit cbb1a1af1da0dafd96857c52fa01519975129b25 +Subproject commit 21d02d0f14da6f96b0ee7a5d4fbf85c61c62aa88 diff --git a/tests/assets/fns/async.py b/tests/assets/fns/async.py index 295f795..9b5de33 100644 --- a/tests/assets/fns/async.py +++ b/tests/assets/fns/async.py @@ -16,6 +16,6 @@ if os.path.isfile(cache_file): with dml.new("test", "test", stdin["dump"], print) as d0: - d0.result = sum(d0.argv[1:].value()) + d0.commit(sum(d0.argv[1:].value())) else: open(cache_file, "w").close() diff --git a/tests/assets/fns/sum.py b/tests/assets/fns/sum.py index 29ca54a..688804e 100644 --- a/tests/assets/fns/sum.py +++ b/tests/assets/fns/sum.py @@ -7,8 +7,8 @@ if __name__ == "__main__": stdin = json.loads(sys.stdin.read()) with Dml.temporary(cache_path=stdin["cache_path"]) as dml: - with dml.new("test", "test", stdin["dump"], print) as d0: - d0.num_args = len(d0.argv[1:]) - d0.n0 = sum(d0.argv[1:].value()) - d0.uuid = str(uuid4()) - d0.result = d0.n0 + with dml.new("test", "test", stdin["dump"], print) as dag: + dag.put(len(dag.argv[1:]), name="num_args") + dag.put(sum(dag.argv[1:].value()), name="n0") + dag.put(str(uuid4()), name="uuid") + dag.commit(dag.n0) diff --git a/tests/test_core.py b/tests/test_core.py index 1a77b1d..2b2348d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,12 +2,12 @@ from tempfile import TemporaryDirectory from unittest import TestCase, mock -from daggerml.core import Dag, Dml, Error, Node, Resource +from daggerml.core import Dag, Dml, Error, Executable, Node, from_data -SUM = Resource("./tests/assets/fns/sum.py", adapter="dml-python-fork-adapter") -ASYNC = Resource("./tests/assets/fns/async.py", adapter="dml-python-fork-adapter") -ENVVARS = Resource("./tests/assets/fns/envvars.py", adapter="dml-python-fork-adapter") -TIMEOUT = Resource("./tests/assets/fns/timeout.py", adapter="dml-python-fork-adapter") +SUM = Executable("./tests/assets/fns/sum.py", adapter="dml-python-fork-adapter") +ASYNC = Executable("./tests/assets/fns/async.py", adapter="dml-python-fork-adapter") +ENVVARS = Executable("./tests/assets/fns/envvars.py", adapter="dml-python-fork-adapter") +TIMEOUT = Executable("./tests/assets/fns/timeout.py", adapter="dml-python-fork-adapter") class TestBasic(TestCase): @@ -64,6 +64,25 @@ def test_init_kwargs(self): }, ) + def test_message_handler_load(self): + local_value = None + + def message_handler(dump): + nonlocal local_value + local_value = dump + + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + d0 = dml.new("d0", "d0", message_handler=message_handler) + data = {"key": "value", "list": [1, 2, 3], "dict": {"a": 1, "b": 2}, "resource": SUM} + n0 = d0.put(data, name="n0") + d0.commit(n0) + assert isinstance(local_value, str) + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + ref = from_data(dml("ref", "load", local_value)) + assert len(dml("dag", "describe", ref.to)["nodes"]) == 1 + def test_dag(self): local_value = None @@ -75,11 +94,11 @@ def message_handler(dump): with Dml.temporary(cache_path=cache_path) as dml: d0 = dml.new("d0", "d0", message_handler=message_handler) self.assertIsInstance(d0, Dag) - # d0.n0 = [42] - n0 = d0._put([42], name="n0") + n0 = d0.put([42], name="n0") assert isinstance(n0, Node) self.assertIsInstance(n0, Node) self.assertEqual(n0.value(), [42]) + assert len(d0) == 1 self.assertEqual(len(n0), 1) self.assertEqual(n0.type, "list") d0["x0"] = n0 @@ -104,10 +123,10 @@ def message_handler(dump): d0.n4 = [1, 2, 3, 4, 5] d0.n5 = d0.n4[1:] self.assertListEqual([x.value() for x in d0.n5], [2, 3, 4, 5]) - d0.result = result = n0 + d0.commit(n0) self.assertIsInstance(local_value, str) dag = dml("dag", "list")[0] - self.assertEqual(dag["result"], result.ref.to) + self.assertEqual(dag["result"], n0.ref.to) assert len(dml("dag", "list", "--all")) > 1 dml("dag", "delete", dag["name"], "Deleting dag") dml("repo", "gc", as_text=True) @@ -115,48 +134,48 @@ def message_handler(dump): def test_list_attrs(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = [0] - assert d0.n0.contains(1).value() is False - assert d0.n0.contains(0).value() is True - assert 0 in d0.n0 - d0.n1 = d0.n0.append(1) - assert d0.n1.value() == [0, 1] + with dml.new("d0", "d0") as dag: + n0 = dag.put([0]) + assert n0.contains(1).value() is False + assert n0.contains(0).value() is True + assert 0 in n0 + n1 = n0.append(1) + assert n1.value() == [0, 1] def test_set_attrs(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = {0} - assert d0.n0.contains(1).value() is False - assert d0.n0.contains(0).value() is True - assert 0 in d0.n0 - d0.n1 = d0.n0.append(1) - assert d0.n1.value() == {0, 1} + with dml.new("d0", "d0") as dag: + n0 = dag.put({0}) + assert n0.contains(1).value() is False + assert n0.contains(0).value() is True + assert 0 in n0 + n1 = n0.append(1) + assert n1.value() == {0, 1} def test_dict_attrs(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = {"x": 42} - assert d0.n0.contains("y").value() is False - assert d0.n0.contains("x").value() is True - assert "y" not in d0.n0 - assert "x" in d0.n0 - d0.n1 = d0.n0.assoc("y", 3) - assert d0.n1.value() == {"x": 42, "y": 3} - d0.n2 = d0.n1.update({"z": 1, "a": 2}) - assert d0.n2.value() == {"a": 2, "x": 42, "y": 3, "z": 1} + with dml.new("d0", "d0") as dag: + n0 = dag.put({"x": 42}, name="n0") + assert n0.contains("y").value() is False + assert n0.contains("x").value() is True + assert "y" not in n0 + assert "x" in n0 + n1 = n0.assoc("y", 3) + assert n1.value() == {"x": 42, "y": 3} + n2 = n1.update({"z": 1, "a": 2}) + assert n2.value() == {"a": 2, "x": 42, "y": 3, "z": 1} def test_load_constructors(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: - d0 = dml.new("d0", "d0") - l0 = d0._put(42) - c0 = d0._put({"a": 1, "b": [l0, "23"]}) - assert c0.load("b", 0) == l0 - assert c0.load("b", 1).value() == "23" - assert c0.load("b").load(0) == l0 + dag = dml.new("d0", "d0") + l0 = dag.put(42) + c0 = dag.put({"a": 1, "b": [l0, "23"]}) + assert c0.backtrack("b", 0) == l0 + assert c0.backtrack("b", 1).value() == "23" + assert c0.backtrack("b").backtrack(0) == l0 assert c0["b"][0] != l0 def test_fn_ok_cache(self): @@ -164,12 +183,11 @@ def test_fn_ok_cache(self): with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=fn_cache_dir): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = SUM - nodes = [d0.n0(i, 1, 2) for i in range(2)] # unique function applications - d0.n0(0, 1, 2) # add a repeat outside so `nodes` is still unique - d0.result = nodes[0] - self.assertEqual(d0.result.value(), 3) + with dml.new("d0", "d0") as dag: + nodes = [dag.call(SUM, i, 1, 2) for i in range(2)] # unique function applications + dag.call(SUM, 0, 1, 2) # add a repeat outside so `nodes` is still unique + dag.commit(nodes[0]) + self.assertEqual(dag.result.value(), 3) cache_list = dml("cache", "list", as_text=True) # response is jsonlines format assert len([x for x in cache_list if x.rstrip() == "{"]) == 2 # this gets us unique maps @@ -179,11 +197,10 @@ def test_async_fn_ok(self): debug_file = os.path.join(fn_cache_dir, "debug") with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = ASYNC - d0.n1 = d0.n0(1, 2, 3) - d0.result = result = d0.n1 - self.assertEqual(result.value(), 6) + with dml.new("d0", "d0") as dag: + n1 = dag.call(ASYNC, 1, 2, 3) + dag.commit(n1) + self.assertEqual(n1.value(), 6) with open(debug_file, "r") as f: self.assertEqual(len([1 for _ in f]), 2) @@ -193,9 +210,8 @@ def test_async_fn_error(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: with self.assertRaisesRegex(Error, r".*unsupported operand type.*"): - with dml.new("d0", "d0") as d0: - d0.n0 = ASYNC - d0.n1 = d0.n0(1, 2, "asdf") + with dml.new("d0", "d0") as dag: + dag.call(ASYNC, 1, 2, "asdf") info = [x for x in dml("dag", "list") if x["name"] == "d0"] self.assertEqual(len(info), 1) @@ -203,40 +219,87 @@ def test_async_fn_timeout(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: with self.assertRaises(TimeoutError): - with dml.new("d0", "d0") as d0: - d0.n0 = TIMEOUT - d0.n0(1, 2, 3, timeout=1000) + with dml.new("d0", "d0") as dag: + dag.call(TIMEOUT, 1, 2, 3, timeout=1000) def test_load(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = 42 - d0.result = "foo" + with dml.new("d0", "d0") as dag: + dag.put(42, name="n0") + dag.commit("foo") dl = dml.load("d0") assert isinstance(dl, Dag) - self.assertEqual(type(dl.n0), Node) self.assertEqual(dl.n0.value(), 42) - self.assertEqual(type(dl.result), Node) self.assertEqual(dl.result.value(), "foo") + def test_load_reboot(self): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as dag: + dag.put(42, name="n0") + dag.commit("foo") + with dml.new("d1", "d1") as dag: + node = dag.import_("d0", name="n1") + assert node.dag == dag + assert node.value() == "foo" + assert node.load().n0.value() == 42 + + def test_node_call_w_literal_deps(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): + with dml.new("d0", "d0") as dag: + fn = Executable( + "./tests/assets/fns/sum.py", + adapter="dml-python-fork-adapter", + prepop={"x": 10}, + ) + result = dag.call(fn, *nums) + assert result.value() == sum(nums) + assert "x" in result.load().keys() + assert result.load().x.value() == 10 + + def test_node_call_w_node_deps(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): + with dml.new("d0", "d0") as dag: + fn = Executable( + "./tests/assets/fns/sum.py", + adapter="dml-python-fork-adapter", + prepop={"x": dag.put(10)}, + ) + result = dag.call(fn, *nums) + assert result.value() == sum(nums) + assert "x" in result.load().keys() + assert result.load().x.value() == 10 + + def test_node_call(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): + with dml.new("d0", "d0") as dag: + fn = dag.put(SUM) + result = fn(*nums) + assert result.value() == sum(nums) + def test_load_recursing(self): nums = [1, 2, 3] with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): - with dml.new("d0", "d0") as d0: - d0.n0 = SUM - d0.n1 = d0.n0(*nums) - assert d0.n1.dag == d0 - d0.result = d0.n1 + with dml.new("d0", "d0") as dag: + dag.commit(dag.call(SUM, *nums, name="n1")) d1 = dml.new("d1", "d1") - d1.n1 = dml.load("d0").n1 - assert d1.n1.dag == d1 - d1.n2 = dml.load(d1.n1, recurse=True).num_args - assert d1.n2.value() == len(nums) - assert d1.n1.value() == sum(nums) - assert isinstance(d1.n1.load(), Dag) + n1 = d1.put(dml.load("d0").n1, name="n1_1") + assert n1.dag == d1 + n2 = n1.load().n1.load().num_args + assert n2.value() == len(nums) + assert n1.value() == sum(nums) def test_caching(self): nums = [1, 2, 3] @@ -244,17 +307,15 @@ def test_caching(self): with Dml.temporary(cache_path=cache_path) as dml: config_dir = dml.config_dir with dml.new("d0", "d0") as d1: - d1.sum_fn = SUM - n1 = d1.sum_fn(*nums, name="n1") + n1 = d1.call(SUM, *nums) assert n1.value() == sum(nums) assert isinstance(n1.load(), Dag) - uid = d1.n1.load().uuid.value() + uid = n1.load().uuid.value() with Dml.temporary(cache_path=cache_path) as dml: assert dml.config_dir != config_dir, "Config dir should not be the same" with dml.new("d1", "d0") as d1: - d1.sum_fn = SUM - d1.n1 = d1.sum_fn(*nums) - uid1 = d1.n1.load().uuid.value() + n1 = d1.call(SUM, *nums) + uid1 = n1.load().uuid.value() assert uid == uid1, "Cached dag should have the same UUID" def test_no_caching(self): @@ -263,15 +324,24 @@ def test_no_caching(self): with Dml.temporary(cache_path=cache_path) as dml: config_dir = dml.config_dir with dml.new("d0", "d0") as d1: - d1.n0 = SUM - d1.n1 = d1.n0(*nums) - assert isinstance(d1.n1, Node) - uid = d1.n1.load().uuid.value() + n1 = d1.call(SUM, *nums) + uid = n1.load().uuid.value() with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: assert dml.config_dir != config_dir, "Config dir should not be the same" with dml.new("d1", "d0") as d1: - d1.n0 = SUM - d1.n1 = d1.n0(*nums) - uid1 = d1.n1.load().uuid.value() + n1 = d1.call(SUM, *nums) + uid1 = n1.load().uuid.value() assert uid != uid1, "Cached dag should have the same UUID" + + + def test_nodemap(self): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as d0: + d0.a = 23 + node = d0.put(42, name="b") + other = d0.put(420) + assert d0.a.value() == 23 + assert list(d0) == ["a", "b"] + d0.commit([node, other])