diff --git a/src/daggerml/core.py b/src/daggerml/core.py index 1c08f58..eeba04b 100644 --- a/src/daggerml/core.py +++ b/src/daggerml/core.py @@ -356,15 +356,17 @@ def result(self) -> "Node": assert isinstance(ref, Ref), f"'{self.__class__.__name__}' dag has not been committed yet" return make_node(self, ref) - def import_(self, dag_name: str, *, name=None, doc=None) -> "Node": - """Import a dag result into this dag + def load(self, dag_name: str, key: str = "result", *, name=None, doc=None) -> "Node": + """Load a node from a different dag into this one Parameters ---------- dag_name : str Name of the dag to load + key : str, default="result" + The name of the node (or "result") to import from the loaded dag. By default, it imports the result node. name : str, optional - Name for the node + Name to assign the resulting node in this dag doc : str, optional Documentation for the node @@ -378,11 +380,14 @@ def import_(self, dag_name: str, *, name=None, doc=None) -> "Node": >>> dml = Dml.temporary() >>> dml.new("my-dag-0", "going to import this").commit(42) >>> dag = dml.new("my-dag-1", "importing my-dag-0") - >>> node = dag.import_("my-dag-0") + >>> node = dag.load("my-dag-0") >>> node.value() 42 """ - return self.put(self.dml.load(dag_name).result, name=name, doc=doc) + resp = getattr(self.dml.load(dag_name), key, None) + if resp is None: + raise_ex(Error(f"dag '{dag_name}' has no '{key}'", origin="dml", type="KeyError")) + return self.put(resp, name=name, doc=doc) @overload def put(self, value: Union[list, set, "ListNode"], *, name=None, doc=None) -> "ListNode": ... @@ -705,17 +710,7 @@ def __len__(self): # python requires this to be an int Error If the node isn't a collection (e.g. list, set, or dict). """ - if self._info["length"]: - return self._info["length"] - raise Error(f"Cannot get length of type: {self._info['data_type']}", origin="dml", type="TypeError") - - def get(self, key, default=None, *, name=None, doc=None) -> "Node": - """ - For a dict node, return the value for key if key exists, else default. - - If default is not given, it defaults to None, so that this method never raises a KeyError. - """ - return make_node(self.dag, self.dag.dml.get(self, key, default, name=name, doc=doc)) + return self._info["length"] class ListNode(CollectionNode): # noqa: F811 @@ -816,6 +811,14 @@ def __iter__(self): for k in self.keys(): yield k + def get(self, key, default=None, *, name=None, doc=None) -> "Node": + """ + For a dict node, return the value for key if key exists, else default. + + If default is not given, it defaults to None, so that this method never raises a KeyError. + """ + return make_node(self.dag, self.dag.dml.get(self, key, default, name=name, doc=doc)) + def items(self) -> Iterator[tuple[str, "Node"]]: """ Iterate over key-value pairs of a dictionary node. diff --git a/tests/test_core.py b/tests/test_core.py index 2d456e1..809f480 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,6 +2,8 @@ from tempfile import TemporaryDirectory from unittest import TestCase, mock +import pytest + from daggerml.core import Dag, Dml, Error, Executable, Node, from_data SUM = Executable("./tests/assets/fns/sum.py", adapter="dml-python-fork-adapter") @@ -10,6 +12,161 @@ TIMEOUT = Executable("./tests/assets/fns/timeout.py", adapter="dml-python-fork-adapter") +class TestSetAttrs: + @pytest.mark.parametrize("x", [[0], (0,), [], ["asdf", None]]) # none contain 1 + def test_list_attrs(self, x): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as dag: + n0 = dag.put(x) + assert n0.contains(1).value() is False + assert 1 not in n0 + assert len(n0) == len(x) + for index, item_node in enumerate(n0): + item = x[index] + assert item_node.value() == item + assert n0.contains(item).value() is True + assert item in n0 + assert n0[index].value() == item + assert n0.append(1).value() == [*x, 1] + assert n0.conj(1).value() == [*x, 1] + + @pytest.mark.parametrize("x", [{}, {"a": 1}, {"x": 42, "y": {"k0": None}}]) # none contain 'z' + def test_dict_attrs(self, x): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as dag: + n0 = dag.put(x) + assert n0.contains("z").value() is False + assert "z" not in n0 + assert len(n0) == len(x) + assert n0.get("z", default=123).value() == 123 + for key in n0: + item = x[key] + assert n0[key].value() == item + assert n0.contains(key).value() is True + assert key in n0 + assert n0.get(key).value() == item + assert [(k, v.value()) for k, v in n0.items()] == list(x.items()) + assert n0.keys() == list(x.keys()) + assert [x.value() for x in n0.values()] == list(x.values()) + assert n0.assoc("y", 3).value() == {**x, "y": 3} + assert n0.update({"z": 1, "a": 2}).value() == {**x, "z": 1, "a": 2} + + def test_load_reboot(self): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as dag: + dag.put(42, name="n0") + dag.commit("foo") + with dml.new("d1", "d1") as dag: + node = dag.load("d0", name="n1") + assert node.dag == dag + assert node.value() == "foo" + assert node.load().n0.value() == 42 + assert dag.load("d0", key="n0").value() == 42 + + def test_node_call_w_literal_deps(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): + with dml.new("d0", "d0") as dag: + fn = Executable( + "./tests/assets/fns/sum.py", + adapter="dml-python-fork-adapter", + prepop={"x": 10}, + ) + result = dag.call(fn, *nums) + assert result.value() == sum(nums) + assert "x" in result.load().keys() + assert result.load().x.value() == 10 + + def test_node_call_w_node_deps(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): + with dml.new("d0", "d0") as dag: + fn = Executable( + "./tests/assets/fns/sum.py", + adapter="dml-python-fork-adapter", + prepop={"x": dag.put(10)}, + ) + result = dag.call(fn, *nums) + assert result.value() == sum(nums) + assert "x" in result.load().keys() + assert result.load().x.value() == 10 + + def test_node_call(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): + with dml.new("d0", "d0") as dag: + fn = dag.put(SUM) + result = fn(*nums) + assert result.value() == sum(nums) + + def test_load_recursing(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): + with dml.new("d0", "d0") as dag: + dag.commit(dag.call(SUM, *nums, name="n1")) + d1 = dml.new("d1", "d1") + n1 = d1.put(dml.load("d0").n1, name="n1_1") + assert n1.dag == d1 + n2 = n1.load().n1.load().num_args + assert n2.value() == len(nums) + assert n1.value() == sum(nums) + + def test_caching(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + config_dir = dml.config_dir + with dml.new("d0", "d0") as d1: + n1 = d1.call(SUM, *nums) + assert n1.value() == sum(nums) + assert isinstance(n1.load(), Dag) + uid = n1.load().uuid.value() + with Dml.temporary(cache_path=cache_path) as dml: + assert dml.config_dir != config_dir, "Config dir should not be the same" + with dml.new("d1", "d0") as d1: + n1 = d1.call(SUM, *nums) + uid1 = n1.load().uuid.value() + assert uid == uid1, "Cached dag should have the same UUID" + + def test_no_caching(self): + nums = [1, 2, 3] + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + config_dir = dml.config_dir + with dml.new("d0", "d0") as d1: + n1 = d1.call(SUM, *nums) + uid = n1.load().uuid.value() + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + assert dml.config_dir != config_dir, "Config dir should not be the same" + with dml.new("d1", "d0") as d1: + n1 = d1.call(SUM, *nums) + uid1 = n1.load().uuid.value() + assert uid != uid1, "Cached dag should have the same UUID" + + def test_nodemap(self): + with TemporaryDirectory(prefix="dml-cache-") as cache_path: + with Dml.temporary(cache_path=cache_path) as dml: + with dml.new("d0", "d0") as d0: + d0.a = 23 + node = d0.put(42, name="b") + other = d0.put(420) + assert d0.a.value() == 23 + assert list(d0) == ["a", "b"] + d0.commit([node, other]) + + class TestBasic(TestCase): def test_init(self): with Dml.temporary() as dml: @@ -131,17 +288,6 @@ def message_handler(dump): dml("dag", "delete", dag["name"], "Deleting dag") dml("repo", "gc", as_text=True) - def test_list_attrs(self): - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as dag: - n0 = dag.put([0]) - assert n0.contains(1).value() is False - assert n0.contains(0).value() is True - assert 0 in n0 - n1 = n0.append(1) - assert n1.value() == [0, 1] - def test_set_attrs(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: @@ -153,20 +299,6 @@ def test_set_attrs(self): n1 = n0.append(1) assert n1.value() == {0, 1} - def test_dict_attrs(self): - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as dag: - n0 = dag.put({"x": 42}, name="n0") - assert n0.contains("y").value() is False - assert n0.contains("x").value() is True - assert "y" not in n0 - assert "x" in n0 - n1 = n0.assoc("y", 3) - assert n1.value() == {"x": 42, "y": 3} - n2 = n1.update({"z": 1, "a": 2}) - assert n2.value() == {"a": 2, "x": 42, "y": 3, "z": 1} - def test_load_constructors(self): with TemporaryDirectory(prefix="dml-cache-") as cache_path: with Dml.temporary(cache_path=cache_path) as dml: @@ -235,115 +367,3 @@ def test_load(self): assert isinstance(dl, Dag) self.assertEqual(dl.n0.value(), 42) self.assertEqual(dl.result.value(), "foo") - - def test_load_reboot(self): - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as dag: - dag.put(42, name="n0") - dag.commit("foo") - with dml.new("d1", "d1") as dag: - node = dag.import_("d0", name="n1") - assert node.dag == dag - assert node.value() == "foo" - assert node.load().n0.value() == 42 - - def test_node_call_w_literal_deps(self): - nums = [1, 2, 3] - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): - with dml.new("d0", "d0") as dag: - fn = Executable( - "./tests/assets/fns/sum.py", - adapter="dml-python-fork-adapter", - prepop={"x": 10}, - ) - result = dag.call(fn, *nums) - assert result.value() == sum(nums) - assert "x" in result.load().keys() - assert result.load().x.value() == 10 - - def test_node_call_w_node_deps(self): - nums = [1, 2, 3] - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): - with dml.new("d0", "d0") as dag: - fn = Executable( - "./tests/assets/fns/sum.py", - adapter="dml-python-fork-adapter", - prepop={"x": dag.put(10)}, - ) - result = dag.call(fn, *nums) - assert result.value() == sum(nums) - assert "x" in result.load().keys() - assert result.load().x.value() == 10 - - def test_node_call(self): - nums = [1, 2, 3] - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): - with dml.new("d0", "d0") as dag: - fn = dag.put(SUM) - result = fn(*nums) - assert result.value() == sum(nums) - - def test_load_recursing(self): - nums = [1, 2, 3] - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]): - with dml.new("d0", "d0") as dag: - dag.commit(dag.call(SUM, *nums, name="n1")) - d1 = dml.new("d1", "d1") - n1 = d1.put(dml.load("d0").n1, name="n1_1") - assert n1.dag == d1 - n2 = n1.load().n1.load().num_args - assert n2.value() == len(nums) - assert n1.value() == sum(nums) - - def test_caching(self): - nums = [1, 2, 3] - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - config_dir = dml.config_dir - with dml.new("d0", "d0") as d1: - n1 = d1.call(SUM, *nums) - assert n1.value() == sum(nums) - assert isinstance(n1.load(), Dag) - uid = n1.load().uuid.value() - with Dml.temporary(cache_path=cache_path) as dml: - assert dml.config_dir != config_dir, "Config dir should not be the same" - with dml.new("d1", "d0") as d1: - n1 = d1.call(SUM, *nums) - uid1 = n1.load().uuid.value() - assert uid == uid1, "Cached dag should have the same UUID" - - def test_no_caching(self): - nums = [1, 2, 3] - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - config_dir = dml.config_dir - with dml.new("d0", "d0") as d1: - n1 = d1.call(SUM, *nums) - uid = n1.load().uuid.value() - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - assert dml.config_dir != config_dir, "Config dir should not be the same" - with dml.new("d1", "d0") as d1: - n1 = d1.call(SUM, *nums) - uid1 = n1.load().uuid.value() - assert uid != uid1, "Cached dag should have the same UUID" - - def test_nodemap(self): - with TemporaryDirectory(prefix="dml-cache-") as cache_path: - with Dml.temporary(cache_path=cache_path) as dml: - with dml.new("d0", "d0") as d0: - d0.a = 23 - node = d0.put(42, name="b") - other = d0.put(420) - assert d0.a.value() == 23 - assert list(d0) == ["a", "b"] - d0.commit([node, other])