From f72f9d0589fdaa24d39df174a6ca0f349152b2b1 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Tue, 25 Mar 2025 21:17:34 -0700 Subject: [PATCH 1/7] wip --- submodules/daggerml_cli | 2 +- tests/assets/fns/async.py | 8 ++++++-- tests/assets/fns/error.py | 19 ------------------- tests/assets/fns/sum.py | 7 ++++++- tests/test_core.py | 7 +++---- 5 files changed, 16 insertions(+), 27 deletions(-) delete mode 100644 tests/assets/fns/error.py diff --git a/submodules/daggerml_cli b/submodules/daggerml_cli index 623ef91..223d244 160000 --- a/submodules/daggerml_cli +++ b/submodules/daggerml_cli @@ -1 +1 @@ -Subproject commit 623ef912682e6f074c52887f0dd31d58fd45dfb1 +Subproject commit 223d244dc8dd01d8c0cbb68e0d890c38787abe89 diff --git a/tests/assets/fns/async.py b/tests/assets/fns/async.py index 83549e3..e913761 100644 --- a/tests/assets/fns/async.py +++ b/tests/assets/fns/async.py @@ -4,7 +4,11 @@ from daggerml import Dml -# print(sys.stdin.read(), file=sys.stderr) + +def pr(dump): + print(json.dumps({"dump": dump})) + + with Dml() as dml: stdin = json.loads(sys.stdin.read()) cache_dir = os.getenv("DML_FN_CACHE_DIR", "") @@ -15,7 +19,7 @@ f.write("ASYNC EXECUTING\n") if os.path.isfile(cache_file): - with dml.new("test", "test", stdin["dump"], print) as d0: + with dml.new("test", "test", stdin["dump"], pr) as d0: d0.result = sum(d0.argv[1:].value()) else: open(cache_file, "w").close() diff --git a/tests/assets/fns/error.py b/tests/assets/fns/error.py deleted file mode 100644 index 81b8b3d..0000000 --- a/tests/assets/fns/error.py +++ /dev/null @@ -1,19 +0,0 @@ -import json -import os -import sys - -from daggerml import Dml - -with Dml(data=json.loads(sys.stdin.read())["dump"]) as dml: - cache_dir = os.getenv("DML_FN_CACHE_DIR", "") - cache_file = os.path.join(cache_dir, dml.cache_key) - debug_file = os.path.join(cache_dir, "debug") - - with open(debug_file, "a") as f: - f.write("ASYNC EXECUTING\n") - - if os.path.isfile(cache_file): - with dml.new("test", "test", json.loads(sys.stdin.read())["dump"], print) as d0: - d0.result = 1 / 0 - else: - open(cache_file, "w").close() diff --git a/tests/assets/fns/sum.py b/tests/assets/fns/sum.py index 01987bb..9176194 100644 --- a/tests/assets/fns/sum.py +++ b/tests/assets/fns/sum.py @@ -3,8 +3,13 @@ from daggerml import Dml + +def pr(dump): + print(json.dumps({"dump": dump})) + + with Dml() as dml: - with dml.new("test", "test", json.loads(sys.stdin.read())["dump"], print) as d0: + with dml.new("test", "test", json.loads(sys.stdin.read())["dump"], pr) as d0: d0.num_args = len(d0.argv[1:]) d0.n0 = sum(d0.argv[1:].value()) d0.result = d0.n0 diff --git a/tests/test_core.py b/tests/test_core.py index 3a16244..6aeda46 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,7 +6,6 @@ SUM = Resource("./tests/assets/fns/sum.py", adapter="dml-python-fork-adapter") ASYNC = Resource("./tests/assets/fns/async.py", adapter="dml-python-fork-adapter") -ERROR = Resource("./tests/assets/fns/error.py", adapter="dml-python-fork-adapter") TIMEOUT = Resource("./tests/assets/fns/timeout.py", adapter="dml-python-fork-adapter") @@ -128,10 +127,10 @@ def test_async_fn_error(self): with TemporaryDirectory() as fn_cache_dir: with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=fn_cache_dir): with Dml() as dml: - with self.assertRaises(Error): + with self.assertRaisesRegex(Error, r".*unsupported operand type.*"): with dml.new("d0", "d0") as d0: - d0.n0 = ERROR - d0.n1 = d0.n0(1, 2, 3) + d0.n0 = SUM + d0.n1 = d0.n0(1, 2, "asdf") info = [x for x in dml("dag", "list") if x["name"] == "d0"] self.assertEqual(len(info), 1) From b99f7d329391b79fde57dc45eccb3e676bb54a54 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Wed, 26 Mar 2025 23:32:09 -0700 Subject: [PATCH 2/7] wip --- submodules/daggerml_cli | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submodules/daggerml_cli b/submodules/daggerml_cli index 223d244..aeab3f4 160000 --- a/submodules/daggerml_cli +++ b/submodules/daggerml_cli @@ -1 +1 @@ -Subproject commit 223d244dc8dd01d8c0cbb68e0d890c38787abe89 +Subproject commit aeab3f4bdb11fd479b24afbdff9f3485634d8eed From 470c8889f16d876674b672da28d4e7092415a967 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Sun, 20 Apr 2025 08:06:31 -0700 Subject: [PATCH 3/7] updated cli and moved to types in ids --- src/daggerml/core.py | 80 +++---------------------------------- submodules/daggerml_cli | 2 +- tests/assets/fns/envvars.py | 13 ++++++ tests/test_core.py | 28 ++++++++++++- 4 files changed, 46 insertions(+), 77 deletions(-) create mode 100644 tests/assets/fns/envvars.py diff --git a/src/daggerml/core.py b/src/daggerml/core.py index a5392bd..95c5f92 100644 --- a/src/daggerml/core.py +++ b/src/daggerml/core.py @@ -75,36 +75,10 @@ def to_data(obj): def from_json(text): - """ - Parse JSON string into Python objects. - - Parameters - ---------- - text : str - JSON string to parse - - Returns - ------- - Any - Deserialized Python object - """ return from_data(json.loads(text)) def to_json(obj): - """ - Convert Python object to JSON string. - - Parameters - ---------- - obj : Any - Object to serialize - - Returns - ------- - str - JSON string representation - """ return json.dumps(to_data(obj), separators=(",", ":")) @@ -112,7 +86,7 @@ def to_json(obj): @dataclass(frozen=True) class Ref: # noqa: F811 """ - Reference to a DaggerML node. + Reference to a DaggerML object. Parameters ---------- @@ -127,7 +101,7 @@ class Ref: # noqa: F811 @dataclass(frozen=True) class Resource: # noqa: F811 """ - Representation of an external resource. + Representation of an externally managed object with an identifier. Parameters ---------- @@ -183,24 +157,6 @@ def __str__(self): class Dml: # noqa: F811 - """ - Main DaggerML interface for creating and managing DAGs. - - Parameters - ---------- - data : Any, optional - Initial data for the DML instance - **kwargs : dict - Additional configuration options - - Examples - -------- - >>> from daggerml import Dml - >>> with Dml() as dml: - ... with dml.new("d0", "message") as dag: - ... pass - """ - def __init__(self, **kwargs): self.kwargs = kwargs self.opts = kwargs2opts(**kwargs) @@ -208,30 +164,6 @@ def __init__(self, **kwargs): self.tmpdirs = None def __call__(self, *args: str, input=None, as_text: bool = False) -> Any: - """ - Call the dml cli with the given arguments. - - Parameters - ---------- - *args : str - Arguments to pass to the dml cli - input : str, optional - data to pipe to `dml`. - as_text : bool, optional - If True, return the result as text, otherwise json - - Returns - ------- - Any - Result of the execution - - Examples - ----- - >>> dml = Dml() - >>> _ = dml("repo", "list") - - is equivalent to `dml repo list`. - """ resp = None path = shutil.which("dml") argv = [path, *self.opts, *args] @@ -256,8 +188,8 @@ def __enter__(self): "Use temporary config and project directories." self.tmpdirs = [TemporaryDirectory() for _ in range(2)] self.kwargs = { - "config_dir": getenv("DML_CONFIG_DIR") or self.tmpdirs[0].__enter__(), - "project_dir": getenv("DML_PROJECT_DIR") or self.tmpdirs[1].__enter__(), + "config_dir": getenv("DML_CONFIG_DIR") or self.tmpdirs[0].name, + "project_dir": getenv("DML_PROJECT_DIR") or self.tmpdirs[1].name, "repo": getenv("DML_REPO") or "test", "user": getenv("DML_USER") or "test", "branch": getenv("DML_BRANCH") or "main", @@ -266,12 +198,10 @@ def __enter__(self): self.opts = kwargs2opts(**self.kwargs) if self.kwargs["repo"] not in [x["name"] for x in self("repo", "list")]: self("repo", "create", self.kwargs["repo"]) - if self.kwargs["branch"] not in self("branch", "list"): - self("branch", "create", self.kwargs["branch"]) return self def __exit__(self, exc_type, exc_value, traceback): - [x.__exit__(exc_type, exc_value, traceback) for x in self.tmpdirs] + [x.cleanup() for x in self.tmpdirs] @property def envvars(self): diff --git a/submodules/daggerml_cli b/submodules/daggerml_cli index aeab3f4..b51538a 160000 --- a/submodules/daggerml_cli +++ b/submodules/daggerml_cli @@ -1 +1 @@ -Subproject commit aeab3f4bdb11fd479b24afbdff9f3485634d8eed +Subproject commit b51538a21dbd7eafbeaedca39f595aebe164f786 diff --git a/tests/assets/fns/envvars.py b/tests/assets/fns/envvars.py new file mode 100644 index 0000000..ee05f19 --- /dev/null +++ b/tests/assets/fns/envvars.py @@ -0,0 +1,13 @@ +import json +import sys + +from daggerml import Dml + + +def pr(dump): + print(json.dumps({"dump": dump})) + + +with Dml() as dml: + with dml.new("test", "test", json.loads(sys.stdin.read())["dump"], pr) as d0: + d0.result = dml.kwargs diff --git a/tests/test_core.py b/tests/test_core.py index 6aeda46..2a1a27d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,6 +6,7 @@ SUM = Resource("./tests/assets/fns/sum.py", adapter="dml-python-fork-adapter") ASYNC = Resource("./tests/assets/fns/async.py", adapter="dml-python-fork-adapter") +ENVVARS = Resource("./tests/assets/fns/envvars.py", adapter="dml-python-fork-adapter") TIMEOUT = Resource("./tests/assets/fns/timeout.py", adapter="dml-python-fork-adapter") @@ -34,6 +35,30 @@ def test_init(self): }, ) + def test_init_kwargs(self): + with Dml(repo="does-not-exist", branch="unique-name") as dml: + self.assertDictEqual( + dml("status"), + { + "repo": "does-not-exist", + "branch": "unique-name", + "user": dml.kwargs.get("user"), + "config_dir": dml.kwargs.get("config_dir"), + "project_dir": dml.kwargs.get("project_dir"), + }, + ) + self.assertEqual(dml.envvars["DML_CONFIG_DIR"], dml.kwargs.get("config_dir")) + self.assertEqual( + dml.envvars, + { + "DML_REPO": "does-not-exist", + "DML_BRANCH": "unique-name", + "DML_USER": dml.kwargs.get("user"), + "DML_CONFIG_DIR": dml.kwargs.get("config_dir"), + "DML_PROJECT_DIR": dml.kwargs.get("project_dir"), + }, + ) + def test_dag(self): local_value = None @@ -73,7 +98,8 @@ def message_handler(dump): d0.result = result = d0.n0 self.assertIsInstance(local_value, str) dag = dml("dag", "list")[0] - self.assertEqual(dag["result"], result.ref.to.split("/", 1)[1]) + self.assertEqual(dag["result"], result.ref.to) + assert len(dml("dag", "list", "--all")) > 1 dml("dag", "delete", dag["name"], "Deleting dag") dml("repo", "gc", as_text=True) From 886429ff7fbe6e2c142fee264f875cb9928fbba8 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Sat, 28 Jun 2025 00:03:55 -0700 Subject: [PATCH 4/7] collection backtracking --- pyproject.toml | 9 +- src/daggerml/__init__.py | 4 +- src/daggerml/core.py | 200 +++++++++++++----- submodules/daggerml_cli | 2 +- tests/assets/fns/async.py | 28 ++- tests/assets/fns/envvars.py | 13 -- tests/assets/fns/sum.py | 19 +- tests/assets/fns/timeout.py | 3 +- tests/test_core.py | 393 +++++++++++++++++++++--------------- 9 files changed, 403 insertions(+), 268 deletions(-) delete mode 100644 tests/assets/fns/envvars.py diff --git a/pyproject.toml b/pyproject.toml index b3a6f11..acf5f22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ minversion = "6.0" addopts = "-ra --ignore=submodules/" testpaths = [ "tests", + "src/daggerml", ] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", @@ -75,10 +76,10 @@ artifacts = [ [tool.hatch.envs.default] python="3.10" -dependencies = [ - "twine", - "daggerml_cli @ {root:uri}/submodules/daggerml_cli", - "daggerml[test]", +features = ["test"] +dependencies = ["twine"] +pre-install-commands = [ + "pip install -e {root:uri}/submodules/daggerml_cli", ] [tool.hatch.envs.default.scripts] diff --git a/src/daggerml/__init__.py b/src/daggerml/__init__.py index d334db2..607f9a5 100644 --- a/src/daggerml/__init__.py +++ b/src/daggerml/__init__.py @@ -5,7 +5,7 @@ with strong typing support and a context-manager based interface. """ -from daggerml.core import Dml, Error, Node, Resource +from daggerml.core import Dag, Dml, Error, Node, Resource try: from daggerml.__about__ import __version__ @@ -13,4 +13,4 @@ __version__ = "local" -__all__ = ("Dml", "Error", "Node", "Resource") +__all__ = ("Dag", "Dml", "Error", "Node", "Resource") diff --git a/src/daggerml/core.py b/src/daggerml/core.py index 95c5f92..e76926e 100644 --- a/src/daggerml/core.py +++ b/src/daggerml/core.py @@ -4,10 +4,9 @@ import subprocess import time from dataclasses import dataclass, field, fields -from os import getenv from tempfile import TemporaryDirectory from traceback import format_exception -from typing import Any, Callable, NewType, Optional, Union +from typing import Any, Callable, Optional, Union from daggerml.util import ( BackoffWithJitter, @@ -24,13 +23,7 @@ DATA_TYPE = {} -Node = NewType("Node", None) -Resource = NewType("Resource", None) -Error = NewType("Error", None) -Ref = NewType("Ref", None) -Dml = NewType("Dml", None) -Dag = NewType("Dag", None) -Scalar = Union[str, int, float, bool, type(None), Resource, Node] +Scalar = Union[str, int, float, bool, type(None), "Resource", "Node"] Collection = Union[list, tuple, set, dict] @@ -156,17 +149,67 @@ def __str__(self): return "".join(self.context.get("trace", [self.message])) -class Dml: # noqa: F811 - def __init__(self, **kwargs): - self.kwargs = kwargs - self.opts = kwargs2opts(**kwargs) - self.token = None - self.tmpdirs = None +@dataclass +class Dml: + """ + DaggerML cli client wrapper + """ + + config_dir: Union[str, None] = None + project_dir: Union[str, None] = None + cache_path: Union[str, None] = None + repo: Union[str, None] = None + user: Union[str, None] = None + branch: Union[str, None] = None + token: Union[str, None] = None + tmpdirs: dict[str, TemporaryDirectory] = field(default_factory=dict) + + @property + def kwargs(self) -> dict: + out = { + "config_dir": self.config_dir, + "project_dir": self.project_dir, + "cache_path": self.cache_path, + "repo": self.repo, + "user": self.user, + "branch": self.branch, + } + return {k: v for k, v in out.items() if v is not None} + + @classmethod + def temporary(cls, repo="test", user="user", branch="main", cache_path=None, **kwargs) -> "Dml": + """ + Create a temporary Dml instance with specified parameters. + + Parameters + ---------- + repo : str, default="test" + user : str, default="user" + branch : str, default="main" + **kwargs : dict + Additional keyword arguments for configuration include `config_dir`, `project_dir`, and `cache_path`. + If any of those is provided, it will not create a temporary directory for that parameter. If provided and + set to None, the dml default will be used. + """ + tmpdirs = {k: TemporaryDirectory(prefix="dml-") for k in ["config_dir", "project_dir"] if k not in kwargs} + self = cls( + repo=repo, + user=user, + branch=branch, + cache_path=cache_path, + **{k: v.name for k, v in tmpdirs.items()}, + tmpdirs=tmpdirs, + ) + if self.kwargs["repo"] not in [x["name"] for x in self("repo", "list")]: + self("repo", "create", self.kwargs["repo"]) + return self + + def cleanup(self): + [x.cleanup() for x in self.tmpdirs.values()] def __call__(self, *args: str, input=None, as_text: bool = False) -> Any: - resp = None path = shutil.which("dml") - argv = [path, *self.opts, *args] + argv = [path, *kwargs2opts(**self.kwargs), *args] resp = subprocess.run(argv, check=True, capture_output=True, text=True, input=input) if resp.stderr: log.error(resp.stderr.rstrip()) @@ -185,34 +228,21 @@ def invoke(*args, **kwargs): return invoke def __enter__(self): - "Use temporary config and project directories." - self.tmpdirs = [TemporaryDirectory() for _ in range(2)] - self.kwargs = { - "config_dir": getenv("DML_CONFIG_DIR") or self.tmpdirs[0].name, - "project_dir": getenv("DML_PROJECT_DIR") or self.tmpdirs[1].name, - "repo": getenv("DML_REPO") or "test", - "user": getenv("DML_USER") or "test", - "branch": getenv("DML_BRANCH") or "main", - **self.kwargs, - } - self.opts = kwargs2opts(**self.kwargs) - if self.kwargs["repo"] not in [x["name"] for x in self("repo", "list")]: - self("repo", "create", self.kwargs["repo"]) return self def __exit__(self, exc_type, exc_value, traceback): - [x.cleanup() for x in self.tmpdirs] + self.cleanup() @property def envvars(self): return {f"DML_{k.upper()}": str(v) for k, v in self.kwargs.items()} - def new(self, name="", message="", data=None, message_handler=None): + def new(self, name="", message="", data=None, message_handler=None) -> "Dag": opts = kwargs2opts(dump="-") if data else [] token = self("api", "create", *opts, name, message, input=data, as_text=True) return Dag(replace(self, token=token), message_handler) - def load(self, name: Union[str, Node], recurse=False) -> Dag: + def load(self, name: Union[str, "Node"], recurse=False) -> "Dag": return Dag(replace(self, token=None), _ref=self.get_dag(name, recurse=recurse)) @@ -222,7 +252,7 @@ class Boxed: @dataclass -class Dag: # noqa: F811 +class Dag: _dml: Dml _message_handler: Optional[Callable] = None _ref: Optional[Ref] = None @@ -244,16 +274,16 @@ def __exit__(self, exc_type, exc_value, traceback): if exc_value is not None: self._commit(Error(exc_value)) - def __getitem__(self, name): + def __getitem__(self, name) -> "Node": return Node(self, self._dml.get_node(name, self._ref)) - def __setitem__(self, name, value): + def __setitem__(self, name, value) -> "Node": assert not self._ref if isinstance(value, Ref): return self._dml.set_node(name, value) return self._put(value, name=name) - def __len__(self): + def __len__(self) -> int: return len(self._dml.get_names(self._ref)) def __iter__(self): @@ -278,12 +308,12 @@ def __getattr__(self, name): return self.__getitem__(name) @property - def argv(self) -> Node: + def argv(self) -> "Node": "Access the dag's argv node" return Node(self, self._dml.get_argv(self._ref)) @property - def result(self) -> Node: + def result(self) -> "Node": ref = self._dml.get_result(self._ref) assert ref, f"'{self.__class__.__name__}' has no attribute 'result'" return Node(self, ref) if ref else ref @@ -297,14 +327,14 @@ def keys(self) -> list[str]: return lambda: self._dml.get_names(self._ref).keys() @property - def values(self) -> list[Node]: + def values(self) -> list["Node"]: def result(): nodes = self._dml.get_names(self._ref).values() return [Node(self, x) for x in nodes] return result - def _put(self, value: Union[Scalar, Collection], *, name=None, doc=None) -> Node: + def _put(self, value: Union[Scalar, Collection], *, name=None, doc=None) -> "Node": """ Add a value to the DAG. @@ -329,7 +359,7 @@ def _put(self, value: Union[Scalar, Collection], *, name=None, doc=None) -> Node ) return Node(self, self._dml.put_literal(value, name=name, doc=doc)) - def _load(self, dag_name, node=None, *, name=None, doc=None) -> Node: + def _load(self, dag_name, node=None, *, name=None, doc=None) -> "Node": """ Load a DAG by name. @@ -350,7 +380,7 @@ def _load(self, dag_name, node=None, *, name=None, doc=None) -> Node: dag = dag_name if isinstance(dag_name, str) else dag_name._ref return Node(self, self._dml.put_load(dag, node, name=name, doc=doc)) - def _commit(self, value) -> Node: + def _commit(self, value) -> "Node": """ Commit a value to the DAG. @@ -389,7 +419,63 @@ def __repr__(self): def __hash__(self): return hash(self.ref) - def __getitem__(self, key: Union[slice, str, int, Node]) -> Node: + @property + def argv(self) -> "Node": + "Access the node's argv list" + return [Node(self.dag, x) for x in self.dag._dml.get_argv(self)] + + def load(self, *keys: Union[str, int], recurse: bool = False) -> Dag: + """ + Convenience wrapper around `dml.load(node)` + + If `key` is provided, it considers this node to be a collection created + by the appropriate method and loads the dag that corresponds to this key + + Parameters + ---------- + key : str, optional + Key to load from the DAG. If not provided, the entire DAG is loaded. + + Returns + ------- + Dag + The dag that this node was imported from (or in the case of a function call, this returns the fndag) + d0 = dml.new("d0", "d0") + l0 = d0._put(42) + c0 = d0._put({"a": 1, "b": [l0, "23"]}) + assert c0.load("b", 0) == l0 + assert c0.load("b").load(0) == l0 + assert c0["b"][0] != l0 + + Examples + -------- + >>> dml = Dml.temporary() + >>> dag = dml.new("test", "test") + >>> l0 = dag._put(42) + >>> c0 = dag._put({"a": 1, "b": [l0, "23"]}) + >>> assert c0.load("b", 0) == l0 + >>> assert c0.load("b").load(0) == l0 + >>> assert c0["b"][0] != l0 # this is a different node, not the same as l0 + >>> dml.cleanup() + """ + if len(keys) == 0: + return self.dag._dml.load(self, recurse=recurse) + keys = list(keys) + while len(keys) > 0: + key = keys.pop(0) + fn, *args = (x.value() for x in self.argv) + if fn.uri == "daggerml:list": + assert isinstance(key, int), "list keys must be integers" + elif fn.uri == "daggerml:dict": + assert isinstance(key, str), "dict keys must be strings" + i = args.index(key) + key = i + 1 + else: + raise Error(f"{fn.uri} is not a collection constructor") + self = self.argv[key + 1] + return self + + def __getitem__(self, key: Union[slice, str, int, "Node"]) -> "Node": """ Get the `key` item. It should be the same as if you were working on the actual value. @@ -406,8 +492,16 @@ def __getitem__(self, key: Union[slice, str, int, Node]) -> Node: Examples -------- - >>> node = dag._put({"a": 1, "b": 5}) - >>> assert node["a"].value() == 1 + >>> dml = Dml.temporary() + >>> dag = dml.new("test", "test") + >>> node = dag._put({"a": 1, "b": [5, 6]}) + >>> nested = node["a"] + >>> isinstance(nested, Node) + True + >>> nested.value() + 1 + >>> node["b"][0].value() # lists too + 5 """ if isinstance(key, slice): key = [key.start, key.stop, key.step] @@ -467,7 +561,7 @@ def __iter__(self): for k in self.keys(): yield k - def __call__(self, *args, name=None, doc=None, retry=False, sleep=None, timeout=0) -> Node: + def __call__(self, *args, name=None, doc=None, sleep=None, timeout=0) -> "Node": """ Call this node as a function. @@ -479,8 +573,6 @@ def __call__(self, *args, name=None, doc=None, retry=False, sleep=None, timeout= Name for the result node doc : str, optional Documentation - retry : bool, default=False - Retry a failed run? sleep : callable, optional A nullary function that returns sleep time in milliseconds timeout : int, default=30000 @@ -501,16 +593,14 @@ def __call__(self, *args, name=None, doc=None, retry=False, sleep=None, timeout= sleep = sleep or BackoffWithJitter() args = [self.dag._put(x) for x in args] end = current_time_millis() + timeout - kw = {"retry": retry} while timeout <= 0 or current_time_millis() < end: - resp = self.dag._dml.start_fn([self, *args], name=name, doc=doc, **kw) + resp = self.dag._dml.start_fn([self, *args], name=name, doc=doc) if resp: return Node(self.dag, resp) - kw["retry"] = False time.sleep(sleep() / 1000) raise TimeoutError(f"invoking function: {self.value()}") - def keys(self, *, name=None, doc=None) -> Node: + def keys(self, *, name=None, doc=None) -> "Node": """ Get the keys of a dictionary node. @@ -528,7 +618,7 @@ def keys(self, *, name=None, doc=None) -> Node: """ return Node(self.dag, self.dag._dml.keys(self, name=name, doc=doc)) - def len(self, *, name=None, doc=None) -> Node: + def len(self, *, name=None, doc=None) -> "Node": """ Get the length of a collection node. @@ -546,7 +636,7 @@ def len(self, *, name=None, doc=None) -> Node: """ return Node(self.dag, self.dag._dml.len(self, name=name, doc=doc)) - def type(self, *, name=None, doc=None) -> Node: + def type(self, *, name=None, doc=None) -> "Node": """ Get the type of this node. diff --git a/submodules/daggerml_cli b/submodules/daggerml_cli index b51538a..40a7d10 160000 --- a/submodules/daggerml_cli +++ b/submodules/daggerml_cli @@ -1 +1 @@ -Subproject commit b51538a21dbd7eafbeaedca39f595aebe164f786 +Subproject commit 40a7d10b97170640f3419ec6188430139380be16 diff --git a/tests/assets/fns/async.py b/tests/assets/fns/async.py index e913761..295f795 100644 --- a/tests/assets/fns/async.py +++ b/tests/assets/fns/async.py @@ -4,22 +4,18 @@ from daggerml import Dml - -def pr(dump): - print(json.dumps({"dump": dump})) - - -with Dml() as dml: +if __name__ == "__main__": stdin = json.loads(sys.stdin.read()) - cache_dir = os.getenv("DML_FN_CACHE_DIR", "") - cache_file = os.path.join(cache_dir, stdin["cache_key"]) - debug_file = os.path.join(cache_dir, "debug") + with Dml.temporary(cache_path=stdin["cache_path"]) as dml: + cache_dir = os.getenv("DML_FN_CACHE_DIR", "") + cache_file = os.path.join(cache_dir, stdin["cache_key"]) + debug_file = os.path.join(cache_dir, "debug") - with open(debug_file, "a") as f: - f.write("ASYNC EXECUTING\n") + with open(debug_file, "a") as f: + f.write("ASYNC EXECUTING\n") - if os.path.isfile(cache_file): - with dml.new("test", "test", stdin["dump"], pr) as d0: - d0.result = sum(d0.argv[1:].value()) - else: - open(cache_file, "w").close() + if os.path.isfile(cache_file): + with dml.new("test", "test", stdin["dump"], print) as d0: + d0.result = sum(d0.argv[1:].value()) + else: + open(cache_file, "w").close() diff --git a/tests/assets/fns/envvars.py b/tests/assets/fns/envvars.py deleted file mode 100644 index ee05f19..0000000 --- a/tests/assets/fns/envvars.py +++ /dev/null @@ -1,13 +0,0 @@ -import json -import sys - -from daggerml import Dml - - -def pr(dump): - print(json.dumps({"dump": dump})) - - -with Dml() as dml: - with dml.new("test", "test", json.loads(sys.stdin.read())["dump"], pr) as d0: - d0.result = dml.kwargs diff --git a/tests/assets/fns/sum.py b/tests/assets/fns/sum.py index 9176194..29ca54a 100644 --- a/tests/assets/fns/sum.py +++ b/tests/assets/fns/sum.py @@ -1,15 +1,14 @@ import json import sys +from uuid import uuid4 from daggerml import Dml - -def pr(dump): - print(json.dumps({"dump": dump})) - - -with Dml() as dml: - with dml.new("test", "test", json.loads(sys.stdin.read())["dump"], pr) as d0: - d0.num_args = len(d0.argv[1:]) - d0.n0 = sum(d0.argv[1:].value()) - d0.result = d0.n0 +if __name__ == "__main__": + stdin = json.loads(sys.stdin.read()) + with Dml.temporary(cache_path=stdin["cache_path"]) as dml: + with dml.new("test", "test", stdin["dump"], print) as d0: + d0.num_args = len(d0.argv[1:]) + d0.n0 = sum(d0.argv[1:].value()) + d0.uuid = str(uuid4()) + d0.result = d0.n0 diff --git a/tests/assets/fns/timeout.py b/tests/assets/fns/timeout.py index 1669f3c..56c09fe 100644 --- a/tests/assets/fns/timeout.py +++ b/tests/assets/fns/timeout.py @@ -1 +1,2 @@ -exit() +if __name__ == "__main__": + exit() diff --git a/tests/test_core.py b/tests/test_core.py index 2a1a27d..5a6355e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -12,9 +12,10 @@ class TestBasic(TestCase): def test_init(self): - with Dml() as dml: + with Dml.temporary() as dml: + status = dml("status") self.assertDictEqual( - dml("status"), + {k: v for k, v in status.items() if k != "cache_path"}, { "repo": dml.kwargs.get("repo"), "branch": dml.kwargs.get("branch"), @@ -23,9 +24,10 @@ def test_init(self): "project_dir": dml.kwargs.get("project_dir"), }, ) + assert status["cache_path"].startswith(os.path.expanduser("~")) self.assertEqual(dml.envvars["DML_CONFIG_DIR"], dml.kwargs.get("config_dir")) self.assertEqual( - dml.envvars, + {k: v for k, v in dml.envvars.items() if k != "DML_CACHE_PATH"}, { "DML_REPO": dml.kwargs.get("repo"), "DML_BRANCH": dml.kwargs.get("branch"), @@ -36,28 +38,31 @@ def test_init(self): ) def test_init_kwargs(self): - with Dml(repo="does-not-exist", branch="unique-name") as dml: - self.assertDictEqual( - dml("status"), - { - "repo": "does-not-exist", - "branch": "unique-name", - "user": dml.kwargs.get("user"), - "config_dir": dml.kwargs.get("config_dir"), - "project_dir": dml.kwargs.get("project_dir"), - }, - ) - self.assertEqual(dml.envvars["DML_CONFIG_DIR"], dml.kwargs.get("config_dir")) - self.assertEqual( - dml.envvars, - { - "DML_REPO": "does-not-exist", - "DML_BRANCH": "unique-name", - "DML_USER": dml.kwargs.get("user"), - "DML_CONFIG_DIR": dml.kwargs.get("config_dir"), - "DML_PROJECT_DIR": dml.kwargs.get("project_dir"), - }, - ) + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(repo="does-not-exist", branch="unique-name", cache_path=cache_path) as dml: + self.assertDictEqual( + dml("status"), + { + "repo": "does-not-exist", + "branch": "unique-name", + "user": dml.kwargs.get("user"), + "config_dir": dml.kwargs.get("config_dir"), + "project_dir": dml.kwargs.get("project_dir"), + "cache_path": dml.kwargs.get("cache_path"), + }, + ) + self.assertEqual(dml.envvars["DML_CONFIG_DIR"], dml.kwargs.get("config_dir")) + self.assertEqual( + dml.envvars, + { + "DML_REPO": "does-not-exist", + "DML_BRANCH": "unique-name", + "DML_USER": dml.kwargs.get("user"), + "DML_CONFIG_DIR": dml.kwargs.get("config_dir"), + "DML_PROJECT_DIR": dml.kwargs.get("project_dir"), + "DML_CACHE_PATH": cache_path, + }, + ) def test_dag(self): local_value = None @@ -66,167 +71,223 @@ def message_handler(dump): nonlocal local_value local_value = dump - with Dml() as dml: - d0 = dml.new("d0", "d0", message_handler=message_handler) - d0.n0 = [42] - self.assertIsInstance(d0.n0, Node) - self.assertEqual(d0.n0.value(), [42]) - self.assertEqual(d0.n0.len().value(), 1) - self.assertEqual(d0.n0.type().value(), "list") - d0["x0"] = d0.n0 - self.assertEqual(d0["x0"], d0.n0) - self.assertEqual(d0.x0, d0.n0) - d0.x1 = 42 - self.assertEqual(d0["x1"].value(), 42) - self.assertEqual(d0.x1.value(), 42) - d0.n1 = d0.n0[0] - self.assertIsInstance(d0.n1, Node) - self.assertEqual([x for x in d0.n0], [d0.n1]) - self.assertEqual(d0.n1.value(), 42) - d0.n2 = {"x": d0.n0, "y": "z"} - self.assertNotEqual(d0.n2["x"], d0.n0) - self.assertEqual(d0.n2["x"].value(), d0.n0.value()) - d0.n3 = list(d0.n2.items()) - self.assertIsInstance([x for x in d0.n3], list) - self.assertDictEqual( - {k.value(): v.value() for k, v in d0.n2.items()}, - {"x": d0.n0.value(), "y": "z"}, - ) - d0.n4 = [1, 2, 3, 4, 5] - d0.n5 = d0.n4[1:] - self.assertListEqual([x.value() for x in d0.n5], [2, 3, 4, 5]) - d0.result = result = d0.n0 - self.assertIsInstance(local_value, str) - dag = dml("dag", "list")[0] - self.assertEqual(dag["result"], result.ref.to) - assert len(dml("dag", "list", "--all")) > 1 - dml("dag", "delete", dag["name"], "Deleting dag") - dml("repo", "gc", as_text=True) + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + d0 = dml.new("d0", "d0", message_handler=message_handler) + d0.n0 = [42] + self.assertIsInstance(d0.n0, Node) + self.assertEqual(d0.n0.value(), [42]) + self.assertEqual(d0.n0.len().value(), 1) + self.assertEqual(d0.n0.type().value(), "list") + d0["x0"] = d0.n0 + self.assertEqual(d0["x0"], d0.n0) + self.assertEqual(d0.x0, d0.n0) + d0.x1 = 42 + self.assertEqual(d0["x1"].value(), 42) + self.assertEqual(d0.x1.value(), 42) + d0.n1 = d0.n0[0] + self.assertIsInstance(d0.n1, Node) + self.assertEqual([x.value() for x in d0.n0], [d0.n1.value()]) + self.assertEqual(d0.n1.value(), 42) + d0.n2 = {"x": d0.n0, "y": "z"} + self.assertNotEqual(d0.n2["x"], d0.n0) + self.assertEqual(d0.n2["x"].value(), d0.n0.value()) + d0.n3 = list(d0.n2.items()) + self.assertIsInstance([x for x in d0.n3], list) + self.assertDictEqual( + {k.value(): v.value() for k, v in d0.n2.items()}, + {"x": d0.n0.value(), "y": "z"}, + ) + d0.n4 = [1, 2, 3, 4, 5] + d0.n5 = d0.n4[1:] + self.assertListEqual([x.value() for x in d0.n5], [2, 3, 4, 5]) + d0.result = result = d0.n0 + self.assertIsInstance(local_value, str) + dag = dml("dag", "list")[0] + self.assertEqual(dag["result"], result.ref.to) + assert len(dml("dag", "list", "--all")) > 1 + dml("dag", "delete", dag["name"], "Deleting dag") + dml("repo", "gc", as_text=True) def test_list_attrs(self): - with Dml() as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = [0] - assert d0.n0.contains(1).value() is False - assert d0.n0.contains(0).value() is True - assert 0 in d0.n0 - d0.n1 = d0.n0.append(1) - assert d0.n1.value() == [0, 1] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as d0: + d0.n0 = [0] + assert d0.n0.contains(1).value() is False + assert d0.n0.contains(0).value() is True + assert 0 in d0.n0 + d0.n1 = d0.n0.append(1) + assert d0.n1.value() == [0, 1] def test_set_attrs(self): - with Dml() as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = {0} - assert d0.n0.contains(1).value() is False - assert d0.n0.contains(0).value() is True - assert 0 in d0.n0 - d0.n1 = d0.n0.append(1) - assert d0.n1.value() == {0, 1} + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as d0: + d0.n0 = {0} + assert d0.n0.contains(1).value() is False + assert d0.n0.contains(0).value() is True + assert 0 in d0.n0 + d0.n1 = d0.n0.append(1) + assert d0.n1.value() == {0, 1} def test_dict_attrs(self): - with Dml() as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = {"x": 42} - assert d0.n0.contains("y").value() is False - assert d0.n0.contains("x").value() is True - assert "y" not in d0.n0 - assert "x" in d0.n0 - d0.n1 = d0.n0.assoc("y", 3) - assert d0.n1.value() == {"x": 42, "y": 3} - d0.n2 = d0.n1.update({"z": 1, "a": 2}) - assert d0.n2.value() == {"a": 2, "x": 42, "y": 3, "z": 1} + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as d0: + d0.n0 = {"x": 42} + assert d0.n0.contains("y").value() is False + assert d0.n0.contains("x").value() is True + assert "y" not in d0.n0 + assert "x" in d0.n0 + d0.n1 = d0.n0.assoc("y", 3) + assert d0.n1.value() == {"x": 42, "y": 3} + d0.n2 = d0.n1.update({"z": 1, "a": 2}) + assert d0.n2.value() == {"a": 2, "x": 42, "y": 3, "z": 1} + + def test_load_constructors(self): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + d0 = dml.new("d0", "d0") + l0 = d0._put(42) + c0 = d0._put({"a": 1, "b": [l0, "23"]}) + assert c0.load("b", 0) == l0 + assert c0.load("b").load(0) == l0 + assert c0["b"][0] != l0 def test_async_fn_ok(self): - with TemporaryDirectory() as fn_cache_dir: + with TemporaryDirectory(prefix="dml-test-") as fn_cache_dir: with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=fn_cache_dir): debug_file = os.path.join(fn_cache_dir, "debug") - with Dml() as dml: - with dml.new("d0", "d0") as d0: - d0.n0 = ASYNC - d0.n1 = d0.n0(1, 2, 3) - d0.result = result = d0.n1 - self.assertEqual(result.value(), 6) - with open(debug_file, "r") as f: - self.assertEqual(len([1 for _ in f]), 2) + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as d0: + d0.n0 = ASYNC + d0.n1 = d0.n0(1, 2, 3) + d0.result = result = d0.n1 + self.assertEqual(result.value(), 6) + with open(debug_file, "r") as f: + self.assertEqual(len([1 for _ in f]), 2) def test_async_fn_error(self): - with TemporaryDirectory() as fn_cache_dir: + with TemporaryDirectory(prefix="dml-test-") as fn_cache_dir: with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=fn_cache_dir): - with Dml() as dml: - with self.assertRaisesRegex(Error, r".*unsupported operand type.*"): - with dml.new("d0", "d0") as d0: - d0.n0 = SUM - d0.n1 = d0.n0(1, 2, "asdf") - info = [x for x in dml("dag", "list") if x["name"] == "d0"] - self.assertEqual(len(info), 1) + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with self.assertRaisesRegex(Error, r".*unsupported operand type.*"): + with dml.new("d0", "d0") as d0: + d0.n0 = ASYNC + d0.n1 = d0.n0(1, 2, "asdf") + info = [x for x in dml("dag", "list") if x["name"] == "d0"] + self.assertEqual(len(info), 1) def test_async_fn_timeout(self): - with Dml() as dml: - with self.assertRaises(TimeoutError): - with dml.new("d0", "d0") as d0: - d0.n0 = TIMEOUT - d0.n0(1, 2, 3, timeout=1000) + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with self.assertRaises(TimeoutError): + with dml.new("d0", "d0") as d0: + d0.n0 = TIMEOUT + d0.n0(1, 2, 3, timeout=1000) def test_load(self): - with Dml() as dml: - with dml.new("d0", "d0") as d0: - # only fn dags have an argv attribute, expect AttributeError - with self.assertRaises(Error): - d0.argv # noqa: B018 - # d0.result hasn't been assigned yet but it can't raise an - # AttributeError because we also have __getitem__ implemented - # which would then be called, so an AssertionError is raised. - with self.assertRaises(AssertionError): - d0.result # noqa: B018 - d0.n0 = 42 - self.assertEqual(type(d0.n0), Node) - d0.n1 = 420 - d0.result = d0.n0 - dl = dml.load("d0") - self.assertEqual(type(dl), Dag) - self.assertEqual(type(dl.n0), Node) - self.assertEqual(dl.n0.value(), 42) - self.assertEqual(type(dl.result), Node) - self.assertEqual(dl.result.value(), 42) - self.assertEqual(len(dl), 2) - self.assertEqual(set(dl.keys()), {"n0", "n1"}) - self.assertEqual(set(dl.values()), {dl.n0, dl.n1}) - for x in dl.values(): - self.assertIsInstance(x, Node) - with dml.new("d1", "d1") as d1: - d0 = dml.load("d0") - self.assertEqual(d0.result.value(), 42) - self.assertEqual(d0.n0.value(), 42) - self.assertEqual(d0["n0"].value(), 42) - - self.assertEqual(len(d0), 2) - self.assertEqual(set(d0.keys()), {"n0", "n1"}) - self.assertEqual(set(d0.values()), {d0.n0, d0.n1}) - # d0 has been committed: its nodes are now imports - for x in d0.values(): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as d0: + # only fn dags have an argv attribute, expect AttributeError + with self.assertRaises(Error): + d0.argv # noqa: B018 + # d0.result hasn't been assigned yet but it can't raise an + # AttributeError because we also have __getitem__ implemented + # which would then be called, so an AssertionError is raised. + with self.assertRaises(AssertionError): + d0.result # noqa: B018 + d0.n0 = 42 + self.assertEqual(type(d0.n0), Node) + d0.n1 = 420 + d0.result = d0.n0 + dl = dml.load("d0") + self.assertEqual(type(dl), Dag) + self.assertEqual(type(dl.n0), Node) + self.assertEqual(dl.n0.value(), 42) + self.assertEqual(type(dl.result), Node) + self.assertEqual(dl.result.value(), 42) + self.assertEqual(len(dl), 2) + self.assertEqual(set(dl.keys()), {"n0", "n1"}) + self.assertEqual(set(dl.values()), {dl.n0, dl.n1}) + for x in dl.values(): self.assertIsInstance(x, Node) + with dml.new("d1", "d1") as d1: + d0 = dml.load("d0") + self.assertEqual(d0.result.value(), 42) + self.assertEqual(d0.n0.value(), 42) + self.assertEqual(d0["n0"].value(), 42) + + self.assertEqual(len(d0), 2) + self.assertEqual(set(d0.keys()), {"n0", "n1"}) + self.assertEqual(set(d0.values()), {d0.n0, d0.n1}) + # d0 has been committed: its nodes are now imports + for x in d0.values(): + self.assertIsInstance(x, Node) - d1.n0 = 42 - d1.n1 = 420 - self.assertEqual(set(d1.keys()), {"n0", "n1"}) - # d1 has not yet been committed: its nodes are of type Node - for x in d1.values(): - self.assertEqual(type(x), Node) + d1.n0 = 42 + d1.n1 = 420 + self.assertEqual(set(d1.keys()), {"n0", "n1"}) + # d1 has not yet been committed: its nodes are of type Node + for x in d1.values(): + self.assertEqual(type(x), Node) - d1.result = d0.result + d1.result = d0.result def test_load_recursing(self): nums = [1, 2, 3] - with Dml() as dml: - with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): - with dml.new("d0", "d0") as d0: - d0.n0 = SUM - d0.n1 = d0.n0(*nums) - assert d0.n1.dag == d0 - d0.result = d0.n1 - d1 = dml.new("d1", "d1") - d1.n1 = dml.load("d0").n1 - assert d1.n1.dag == d1 - d1.n2 = dml.load(d1.n1, recurse=True).num_args - assert d1.n2.value() == len(nums) - assert d1.n1.value() == sum(nums) + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): + with dml.new("d0", "d0") as d0: + d0.n0 = SUM + d0.n1 = d0.n0(*nums) + assert d0.n1.dag == d0 + d0.result = d0.n1 + d1 = dml.new("d1", "d1") + d1.n1 = dml.load("d0").n1 + assert d1.n1.dag == d1 + d1.n2 = dml.load(d1.n1, recurse=True).num_args + assert d1.n2.value() == len(nums) + assert d1.n1.value() == sum(nums) + assert isinstance(d1.n1.load(), Dag) + + def test_caching(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + config_dir = dml.config_dir + with dml.new("d0", "d0") as d1: + d1.n0 = SUM + d1.n1 = d1.n0(*nums) + uid = d1.n1.load().uuid.value() + with Dml.temporary(cache_path=cache_path) as dml: + assert dml.config_dir != config_dir, "Config dir should not be the same" + with dml.new("d1", "d0") as d1: + d1.n0 = SUM + d1.n1 = d1.n0(*nums) + uid1 = d1.n1.load().uuid.value() + assert uid == uid1, "Cached dag should have the same UUID" + + def test_no_caching(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + config_dir = dml.config_dir + with dml.new("d0", "d0") as d1: + d1.n0 = SUM + d1.n1 = d1.n0(*nums) + assert isinstance(d1.n1, Node) + uid = d1.n1.load().uuid.value() + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + assert dml.config_dir != config_dir, "Config dir should not be the same" + with dml.new("d1", "d0") as d1: + d1.n0 = SUM + d1.n1 = d1.n0(*nums) + uid1 = d1.n1.load().uuid.value() + assert uid != uid1, "Cached dag should have the same UUID" From 855f9a59831531a1b6785ef5852f85ae46efc100 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Sat, 28 Jun 2025 09:23:50 -0700 Subject: [PATCH 5/7] cleanup --- src/daggerml/core.py | 28 +++++----------------------- submodules/daggerml_cli | 2 +- tests/test_core.py | 1 + 3 files changed, 7 insertions(+), 24 deletions(-) diff --git a/src/daggerml/core.py b/src/daggerml/core.py index e76926e..5b54c43 100644 --- a/src/daggerml/core.py +++ b/src/daggerml/core.py @@ -424,7 +424,7 @@ def argv(self) -> "Node": "Access the node's argv list" return [Node(self.dag, x) for x in self.dag._dml.get_argv(self)] - def load(self, *keys: Union[str, int], recurse: bool = False) -> Dag: + def load(self, *keys: Union[str, int]) -> Dag: """ Convenience wrapper around `dml.load(node)` @@ -433,19 +433,13 @@ def load(self, *keys: Union[str, int], recurse: bool = False) -> Dag: Parameters ---------- - key : str, optional + *keys : str, optional Key to load from the DAG. If not provided, the entire DAG is loaded. Returns ------- Dag The dag that this node was imported from (or in the case of a function call, this returns the fndag) - d0 = dml.new("d0", "d0") - l0 = d0._put(42) - c0 = d0._put({"a": 1, "b": [l0, "23"]}) - assert c0.load("b", 0) == l0 - assert c0.load("b").load(0) == l0 - assert c0["b"][0] != l0 Examples -------- @@ -459,21 +453,9 @@ def load(self, *keys: Union[str, int], recurse: bool = False) -> Dag: >>> dml.cleanup() """ if len(keys) == 0: - return self.dag._dml.load(self, recurse=recurse) - keys = list(keys) - while len(keys) > 0: - key = keys.pop(0) - fn, *args = (x.value() for x in self.argv) - if fn.uri == "daggerml:list": - assert isinstance(key, int), "list keys must be integers" - elif fn.uri == "daggerml:dict": - assert isinstance(key, str), "dict keys must be strings" - i = args.index(key) - key = i + 1 - else: - raise Error(f"{fn.uri} is not a collection constructor") - self = self.argv[key + 1] - return self + return self.dag._dml.load(self) + data = self.dag._dml("node", "backtrack", self.ref.to, *map(str, keys)) + return Node(self.dag, from_data(data)) def __getitem__(self, key: Union[slice, str, int, "Node"]) -> "Node": """ diff --git a/submodules/daggerml_cli b/submodules/daggerml_cli index 40a7d10..37ed44a 160000 --- a/submodules/daggerml_cli +++ b/submodules/daggerml_cli @@ -1 +1 @@ -Subproject commit 40a7d10b97170640f3419ec6188430139380be16 +Subproject commit 37ed44aa000bef7195f139bbd07fe9f037a595f7 diff --git a/tests/test_core.py b/tests/test_core.py index 5a6355e..66b60fd 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -152,6 +152,7 @@ def test_load_constructors(self): l0 = d0._put(42) c0 = d0._put({"a": 1, "b": [l0, "23"]}) assert c0.load("b", 0) == l0 + assert c0.load("b", 1).value() == "23" assert c0.load("b").load(0) == l0 assert c0["b"][0] != l0 From b6a52e96a1a93da31f92ffb0348d3bff26750053 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Sat, 28 Jun 2025 20:40:26 -0700 Subject: [PATCH 6/7] wip --- src/daggerml/core.py | 5 ++++- submodules/daggerml_cli | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/daggerml/core.py b/src/daggerml/core.py index 5b54c43..869759a 100644 --- a/src/daggerml/core.py +++ b/src/daggerml/core.py @@ -210,7 +210,10 @@ def cleanup(self): def __call__(self, *args: str, input=None, as_text: bool = False) -> Any: path = shutil.which("dml") argv = [path, *kwargs2opts(**self.kwargs), *args] - resp = subprocess.run(argv, check=True, capture_output=True, text=True, input=input) + resp = subprocess.run(argv, check=False, capture_output=True, text=True, input=input) + if resp.returncode != 0: + raise_ex(Error(resp.stderr or "DML command failed", code="DmlError")) + log.debug("dml command stderr: %s", resp.stderr) if resp.stderr: log.error(resp.stderr.rstrip()) try: diff --git a/submodules/daggerml_cli b/submodules/daggerml_cli index 37ed44a..32717a0 160000 --- a/submodules/daggerml_cli +++ b/submodules/daggerml_cli @@ -1 +1 @@ -Subproject commit 37ed44aa000bef7195f139bbd07fe9f037a595f7 +Subproject commit 32717a04cf55e71153af1dfd7291d767ac2e6987 From e3fb174a03dba2ffbbc188bfc3ab9c49c0d85393 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Sat, 28 Jun 2025 22:03:25 -0700 Subject: [PATCH 7/7] increment cli version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index acf5f22..a064c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ ] [project.optional-dependencies] -cli = ["daggerml-cli"] +cli = ["daggerml-cli>=0.0.29"] test = [ "pytest", "pytest-cov",