diff --git a/pyproject.toml b/pyproject.toml index d3e9268..2b965f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ dependencies = [] [project.optional-dependencies] -cli = ["daggerml-cli>=0.0.33"] +cli = ["daggerml-cli>=0.0.37"] dev = [ "pytest", "pytest-cov", diff --git a/src/daggerml/core.py b/src/daggerml/core.py index 64ef70e..7271d74 100644 --- a/src/daggerml/core.py +++ b/src/daggerml/core.py @@ -6,7 +6,7 @@ import traceback as tb from dataclasses import dataclass, field, fields from tempfile import TemporaryDirectory -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast from daggerml.util import ( BackoffWithJitter, @@ -306,7 +306,6 @@ def __exit__(self, exc_type, exc_value, traceback): def __getitem__(self, name) -> "Node": return make_node(self, self._dml.get_node(name, self._ref)) - # return Node(self, self._dml.get_node(name, self._ref)) def __setitem__(self, name, value) -> "Node": assert not self._ref @@ -342,14 +341,12 @@ def __getattr__(self, name): def argv(self) -> "Node": "Access the dag's argv node" return make_node(self, self._dml.get_argv(self._ref)) - # return Node(self, self._dml.get_argv(self._ref)) @property def result(self) -> "Node": ref = self._dml.get_result(self._ref) assert ref, f"'{self.__class__.__name__}' has no attribute 'result'" return make_node(self, ref) - # return Node(self, ref) if ref else ref @result.setter def result(self, value): @@ -423,10 +420,10 @@ def _commit(self, value) -> "Node": Value to commit """ value = value if isinstance(value, (Node, Error)) else self._put(value) - dump = self._dml.commit(value) + ref = cast(Ref, self._dml.commit(value)) if self._message_handler: - self._message_handler(dump) - self._ref = Boxed(Ref(json.loads(dump)[-1][1][1])) + self._message_handler(self._dml("ref", "dump", to_json(ref), as_text=True)) + self._ref = Boxed(ref) @dataclass(frozen=True) diff --git a/submodules/daggerml_cli b/submodules/daggerml_cli index 7bb1820..cbb1a1a 160000 --- a/submodules/daggerml_cli +++ b/submodules/daggerml_cli @@ -1 +1 @@ -Subproject commit 7bb1820c7eacf3c571974109650c1a0de37791a6 +Subproject commit cbb1a1af1da0dafd96857c52fa01519975129b25 diff --git a/tests/test_core.py b/tests/test_core.py index 23de2c9..1a77b1d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -159,6 +159,20 @@ def test_load_constructors(self): assert c0.load("b").load(0) == l0 assert c0["b"][0] != l0 + def test_fn_ok_cache(self): + with TemporaryDirectory(prefix="dml-test-") as fn_cache_dir: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=fn_cache_dir): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as d0: + d0.n0 = SUM + nodes = [d0.n0(i, 1, 2) for i in range(2)] # unique function applications + d0.n0(0, 1, 2) # add a repeat outside so `nodes` is still unique + d0.result = nodes[0] + self.assertEqual(d0.result.value(), 3) + cache_list = dml("cache", "list", as_text=True) # response is jsonlines format + assert len([x for x in cache_list if x.rstrip() == "{"]) == 2 # this gets us unique maps + def test_async_fn_ok(self): with TemporaryDirectory(prefix="dml-test-") as fn_cache_dir: with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=fn_cache_dir):