Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions src/daggerml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,15 +356,17 @@ def result(self) -> "Node":
assert isinstance(ref, Ref), f"'{self.__class__.__name__}' dag has not been committed yet"
return make_node(self, ref)

def import_(self, dag_name: str, *, name=None, doc=None) -> "Node":
"""Import a dag result into this dag
def load(self, dag_name: str, key: str = "result", *, name=None, doc=None) -> "Node":
"""Load a node from a different dag into this one

Parameters
----------
dag_name : str
Name of the dag to load
key : str, default="result"
The name of the node (or "result") to import from the loaded dag. By default, it imports the result node.
name : str, optional
Name for the node
Name to assign the resulting node in this dag
doc : str, optional
Documentation for the node

Expand All @@ -378,11 +380,14 @@ def import_(self, dag_name: str, *, name=None, doc=None) -> "Node":
>>> dml = Dml.temporary()
>>> dml.new("my-dag-0", "going to import this").commit(42)
>>> dag = dml.new("my-dag-1", "importing my-dag-0")
>>> node = dag.import_("my-dag-0")
>>> node = dag.load("my-dag-0")
>>> node.value()
42
"""
return self.put(self.dml.load(dag_name).result, name=name, doc=doc)
resp = getattr(self.dml.load(dag_name), key, None)
if resp is None:
raise_ex(Error(f"dag '{dag_name}' has no '{key}'", origin="dml", type="KeyError"))
return self.put(resp, name=name, doc=doc)

@overload
def put(self, value: Union[list, set, "ListNode"], *, name=None, doc=None) -> "ListNode": ...
Expand Down Expand Up @@ -705,17 +710,7 @@ def __len__(self): # python requires this to be an int
Error
If the node isn't a collection (e.g. list, set, or dict).
"""
if self._info["length"]:
return self._info["length"]
raise Error(f"Cannot get length of type: {self._info['data_type']}", origin="dml", type="TypeError")

def get(self, key, default=None, *, name=None, doc=None) -> "Node":
"""
For a dict node, return the value for key if key exists, else default.

If default is not given, it defaults to None, so that this method never raises a KeyError.
"""
return make_node(self.dag, self.dag.dml.get(self, key, default, name=name, doc=doc))
return self._info["length"]
Copy link

Copilot AI Oct 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix removes the conditional check for self._info["length"] but doesn't handle the case where the length might be None or missing. This could cause a KeyError or return None unexpectedly when len() is called on nodes that don't have length information.

Copilot uses AI. Check for mistakes.


class ListNode(CollectionNode): # noqa: F811
Expand Down Expand Up @@ -816,6 +811,14 @@ def __iter__(self):
for k in self.keys():
yield k

def get(self, key, default=None, *, name=None, doc=None) -> "Node":
    """
    Look up ``key`` on this dict node, falling back to ``default`` when the key is absent.

    ``default`` is None unless the caller supplies one, so a missing key never raises a KeyError.
    The lookup is delegated to ``self.dag.dml.get`` and its result is wrapped by ``make_node``.
    """
    fetched = self.dag.dml.get(self, key, default, name=name, doc=doc)
    return make_node(self.dag, fetched)

def items(self) -> Iterator[tuple[str, "Node"]]:
"""
Iterate over key-value pairs of a dictionary node.
Expand Down
294 changes: 157 additions & 137 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from tempfile import TemporaryDirectory
from unittest import TestCase, mock

import pytest

from daggerml.core import Dag, Dml, Error, Executable, Node, from_data

SUM = Executable("./tests/assets/fns/sum.py", adapter="dml-python-fork-adapter")
Expand All @@ -10,6 +12,161 @@
TIMEOUT = Executable("./tests/assets/fns/timeout.py", adapter="dml-python-fork-adapter")


class TestSetAttrs:
@pytest.mark.parametrize("x", [[0], (0,), [], ["asdf", None]]) # none contain 1
def test_list_attrs(self, x):
# For each parametrized sequence, put it into a fresh dag and check the resulting node's
# membership, length, iteration, indexing, and append/conj behavior against the Python input.
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as dag:
n0 = dag.put(x)
# 1 is absent from every parametrized input, so both membership forms must report it missing.
assert n0.contains(1).value() is False
assert 1 not in n0
assert len(n0) == len(x)
# Iteration is expected to yield item nodes in the same order as the Python input.
for index, item_node in enumerate(n0):
item = x[index]
assert item_node.value() == item
assert n0.contains(item).value() is True
assert item in n0
assert n0[index].value() == item
# Both append and conj are expected to give the input items followed by 1, as a list
# (the comparison is against a list even when the input was a tuple).
assert n0.append(1).value() == [*x, 1]
assert n0.conj(1).value() == [*x, 1]

@pytest.mark.parametrize("x", [{}, {"a": 1}, {"x": 42, "y": {"k0": None}}]) # none contain 'z'
def test_dict_attrs(self, x):
# For each parametrized dict, put it into a fresh dag and check the resulting node's
# membership, length, get-with-default, key iteration, items/keys/values, assoc and update.
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as dag:
n0 = dag.put(x)
# "z" is not a key of any parametrized input.
assert n0.contains("z").value() is False
assert "z" not in n0
assert len(n0) == len(x)
# A missing key must fall back to the supplied default.
assert n0.get("z", default=123).value() == 123
# Iterating the node yields its keys, which are used to index both the node and the input.
for key in n0:
item = x[key]
assert n0[key].value() == item
assert n0.contains(key).value() is True
assert key in n0
assert n0.get(key).value() == item
assert [(k, v.value()) for k, v in n0.items()] == list(x.items())
assert n0.keys() == list(x.keys())
# The comprehension variable `x` shadows the parameter only inside the comprehension;
# `x.values()` on the right-hand side still refers to the parametrized dict.
assert [x.value() for x in n0.values()] == list(x.values())
assert n0.assoc("y", 3).value() == {**x, "y": 3}
assert n0.update({"z": 1, "a": 2}).value() == {**x, "z": 1, "a": 2}

def test_load_reboot(self):
# Commit dag "d0" (named node n0 = 42, result "foo"), then load from it inside dag "d1"
# and check both the default (result) load and a load by node name.
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as dag:
dag.put(42, name="n0")
dag.commit("foo")
with dml.new("d1", "d1") as dag:
# With no key given, the loaded node is expected to carry d0's committed result.
node = dag.load("d0", name="n1")
assert node.dag == dag
assert node.value() == "foo"
assert node.load().n0.value() == 42
# key="n0" selects the named node from d0 instead of its result.
assert dag.load("d0", key="n0").value() == 42

def test_node_call_w_literal_deps(self):
# Call the sum Executable built with a literal prepop value and check that the call result
# equals sum(nums) and that the prepop entry "x" is visible on the dag loaded from the result.
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
# Point DML_FN_CACHE_DIR at this Dml's config dir for the duration of the call.
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]):
with dml.new("d0", "d0") as dag:
fn = Executable(
"./tests/assets/fns/sum.py",
adapter="dml-python-fork-adapter",
prepop={"x": 10},
)
result = dag.call(fn, *nums)
assert result.value() == sum(nums)
assert "x" in result.load().keys()
assert result.load().x.value() == 10

def test_node_call_w_node_deps(self):
# Same scenario as the literal-prepop test, except the prepop value is a node created with
# dag.put(10); the entry "x" on the dag loaded from the result must still have value 10.
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
# Point DML_FN_CACHE_DIR at this Dml's config dir for the duration of the call.
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]):
with dml.new("d0", "d0") as dag:
fn = Executable(
"./tests/assets/fns/sum.py",
adapter="dml-python-fork-adapter",
prepop={"x": dag.put(10)},
)
result = dag.call(fn, *nums)
assert result.value() == sum(nums)
assert "x" in result.load().keys()
assert result.load().x.value() == 10

def test_node_call(self):
# Put the SUM executable into the dag and invoke the resulting node directly as a callable;
# the call result must equal sum(nums).
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
# Point DML_FN_CACHE_DIR at this Dml's config dir for the duration of the call.
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]):
with dml.new("d0", "d0") as dag:
fn = dag.put(SUM)
result = fn(*nums)
assert result.value() == sum(nums)

def test_load_recursing(self):
# Commit dag "d0" whose result is a SUM call named "n1", put that node into a second dag,
# and check that load() can be chained from the copied node down to the call's `num_args`.
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]):
with dml.new("d0", "d0") as dag:
dag.commit(dag.call(SUM, *nums, name="n1"))
# d1 is created without a `with` block here.
d1 = dml.new("d1", "d1")
n1 = d1.put(dml.load("d0").n1, name="n1_1")
assert n1.dag == d1
# NOTE(review): presumably the first load() returns d0 and the second returns the dag
# produced by the SUM call, which exposes `num_args` — confirm against Node.load.
n2 = n1.load().n1.load().num_args
assert n2.value() == len(nums)
assert n1.value() == sum(nums)

def test_caching(self):
# Run the same SUM call under two separate Dml.temporary instances that are given the same
# cache_path, and check that the dag loaded from each call result reports the same uuid.
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
config_dir = dml.config_dir
with dml.new("d0", "d0") as d1:
n1 = d1.call(SUM, *nums)
assert n1.value() == sum(nums)
assert isinstance(n1.load(), Dag)
uid = n1.load().uuid.value()
# Second Dml instance: same cache_path, but it must have its own config dir.
with Dml.temporary(cache_path=cache_path) as dml:
assert dml.config_dir != config_dir, "Config dir should not be the same"
with dml.new("d1", "d0") as d1:
n1 = d1.call(SUM, *nums)
uid1 = n1.load().uuid.value()
assert uid == uid1, "Cached dag should have the same UUID"

def test_no_caching(self):
    """Check that the same SUM call under two different cache directories is not shared.

    Each run gets its own ``TemporaryDirectory`` as ``cache_path``; the dags loaded from the
    two call results must report different uuids.
    """
    # NOTE(review): indentation was reconstructed from a whitespace-stripped diff; confirm the
    # nesting of the `with` blocks against the real file.
    nums = [1, 2, 3]
    with TemporaryDirectory(prefix="dml-cache-") as cache_path:
        with Dml.temporary(cache_path=cache_path) as dml:
            config_dir = dml.config_dir
            with dml.new("d0", "d0") as d1:
                n1 = d1.call(SUM, *nums)
                uid = n1.load().uuid.value()
    # A fresh cache directory, so nothing from the first run can be reused.
    with TemporaryDirectory(prefix="dml-cache-") as cache_path:
        with Dml.temporary(cache_path=cache_path) as dml:
            assert dml.config_dir != config_dir, "Config dir should not be the same"
            with dml.new("d1", "d0") as d1:
                n1 = d1.call(SUM, *nums)
                uid1 = n1.load().uuid.value()
    # The failure message now describes the inequality actually being asserted.
    assert uid != uid1, "Uncached dag should have a different UUID"

def test_nodemap(self):
# Check that nodes can be created by attribute assignment on the dag, that named nodes are
# readable back as attributes, and that iterating the dag lists only the named nodes.
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as d0:
d0.a = 23
node = d0.put(42, name="b")
# `other` is put without a name and is expected to be absent from list(d0) below.
other = d0.put(420)
assert d0.a.value() == 23
assert list(d0) == ["a", "b"]
d0.commit([node, other])


class TestBasic(TestCase):
def test_init(self):
with Dml.temporary() as dml:
Expand Down Expand Up @@ -131,17 +288,6 @@ def message_handler(dump):
dml("dag", "delete", dag["name"], "Deleting dag")
dml("repo", "gc", as_text=True)

def test_list_attrs(self):
# Put the list [0] into a fresh dag and check membership (via contains and `in`) and append.
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as dag:
n0 = dag.put([0])
assert n0.contains(1).value() is False
assert n0.contains(0).value() is True
assert 0 in n0
# append is expected to give a node holding the original item followed by the new one.
n1 = n0.append(1)
assert n1.value() == [0, 1]

def test_set_attrs(self):
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
Expand All @@ -153,20 +299,6 @@ def test_set_attrs(self):
n1 = n0.append(1)
assert n1.value() == {0, 1}

def test_dict_attrs(self):
# Put the dict {"x": 42} into a fresh dag and check membership, assoc, and update.
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as dag:
n0 = dag.put({"x": 42}, name="n0")
assert n0.contains("y").value() is False
assert n0.contains("x").value() is True
assert "y" not in n0
assert "x" in n0
# assoc and update each return a node; n2 is built on top of n1, so it carries "y" too.
n1 = n0.assoc("y", 3)
assert n1.value() == {"x": 42, "y": 3}
n2 = n1.update({"z": 1, "a": 2})
assert n2.value() == {"a": 2, "x": 42, "y": 3, "z": 1}

def test_load_constructors(self):
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
Expand Down Expand Up @@ -235,115 +367,3 @@ def test_load(self):
assert isinstance(dl, Dag)
self.assertEqual(dl.n0.value(), 42)
self.assertEqual(dl.result.value(), "foo")

def test_load_reboot(self):
# Commit dag "d0" (named node n0 = 42, result "foo"), then bring its result into dag "d1"
# with import_ and check the resulting node's owner, value, and loaded contents.
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as dag:
dag.put(42, name="n0")
dag.commit("foo")
with dml.new("d1", "d1") as dag:
node = dag.import_("d0", name="n1")
# The imported node belongs to d1 but carries d0's committed result.
assert node.dag == dag
assert node.value() == "foo"
assert node.load().n0.value() == 42

def test_node_call_w_literal_deps(self):
# Call the sum Executable built with a literal prepop value and check that the call result
# equals sum(nums) and that the prepop entry "x" is visible on the dag loaded from the result.
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
# Point DML_FN_CACHE_DIR at this Dml's config dir for the duration of the call.
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]):
with dml.new("d0", "d0") as dag:
fn = Executable(
"./tests/assets/fns/sum.py",
adapter="dml-python-fork-adapter",
prepop={"x": 10},
)
result = dag.call(fn, *nums)
assert result.value() == sum(nums)
assert "x" in result.load().keys()
assert result.load().x.value() == 10

def test_node_call_w_node_deps(self):
# Same scenario as the literal-prepop test, except the prepop value is a node created with
# dag.put(10); the entry "x" on the dag loaded from the result must still have value 10.
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
# Point DML_FN_CACHE_DIR at this Dml's config dir for the duration of the call.
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]):
with dml.new("d0", "d0") as dag:
fn = Executable(
"./tests/assets/fns/sum.py",
adapter="dml-python-fork-adapter",
prepop={"x": dag.put(10)},
)
result = dag.call(fn, *nums)
assert result.value() == sum(nums)
assert "x" in result.load().keys()
assert result.load().x.value() == 10

def test_node_call(self):
# Put the SUM executable into the dag and invoke the resulting node directly as a callable;
# the call result must equal sum(nums).
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
# Point DML_FN_CACHE_DIR at this Dml's config dir for the duration of the call.
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]):
with dml.new("d0", "d0") as dag:
fn = dag.put(SUM)
result = fn(*nums)
assert result.value() == sum(nums)

def test_load_recursing(self):
# Commit dag "d0" whose result is a SUM call named "n1", put that node into a second dag,
# and check that load() can be chained from the copied node down to the call's `num_args`.
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=dml.kwargs["config_dir"]):
with dml.new("d0", "d0") as dag:
dag.commit(dag.call(SUM, *nums, name="n1"))
# d1 is created without a `with` block here.
d1 = dml.new("d1", "d1")
n1 = d1.put(dml.load("d0").n1, name="n1_1")
assert n1.dag == d1
# NOTE(review): presumably the first load() returns d0 and the second returns the dag
# produced by the SUM call, which exposes `num_args` — confirm against Node.load.
n2 = n1.load().n1.load().num_args
assert n2.value() == len(nums)
assert n1.value() == sum(nums)

def test_caching(self):
# Run the same SUM call under two separate Dml.temporary instances that are given the same
# cache_path, and check that the dag loaded from each call result reports the same uuid.
nums = [1, 2, 3]
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
config_dir = dml.config_dir
with dml.new("d0", "d0") as d1:
n1 = d1.call(SUM, *nums)
assert n1.value() == sum(nums)
assert isinstance(n1.load(), Dag)
uid = n1.load().uuid.value()
# Second Dml instance: same cache_path, but it must have its own config dir.
with Dml.temporary(cache_path=cache_path) as dml:
assert dml.config_dir != config_dir, "Config dir should not be the same"
with dml.new("d1", "d0") as d1:
n1 = d1.call(SUM, *nums)
uid1 = n1.load().uuid.value()
assert uid == uid1, "Cached dag should have the same UUID"

def test_no_caching(self):
    """Check that the same SUM call under two different cache directories is not shared.

    Each run gets its own ``TemporaryDirectory`` as ``cache_path``; the dags loaded from the
    two call results must report different uuids.
    """
    # NOTE(review): indentation was reconstructed from a whitespace-stripped diff; confirm the
    # nesting of the `with` blocks against the real file.
    nums = [1, 2, 3]
    with TemporaryDirectory(prefix="dml-cache-") as cache_path:
        with Dml.temporary(cache_path=cache_path) as dml:
            config_dir = dml.config_dir
            with dml.new("d0", "d0") as d1:
                n1 = d1.call(SUM, *nums)
                uid = n1.load().uuid.value()
    # A fresh cache directory, so nothing from the first run can be reused.
    with TemporaryDirectory(prefix="dml-cache-") as cache_path:
        with Dml.temporary(cache_path=cache_path) as dml:
            assert dml.config_dir != config_dir, "Config dir should not be the same"
            with dml.new("d1", "d0") as d1:
                n1 = d1.call(SUM, *nums)
                uid1 = n1.load().uuid.value()
    # The failure message now describes the inequality actually being asserted.
    assert uid != uid1, "Uncached dag should have a different UUID"

def test_nodemap(self):
# Check that nodes can be created by attribute assignment on the dag, that named nodes are
# readable back as attributes, and that iterating the dag lists only the named nodes.
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as d0:
d0.a = 23
node = d0.put(42, name="b")
# `other` is put without a name and is expected to be absent from list(d0) below.
other = d0.put(420)
assert d0.a.value() == 23
assert list(d0) == ["a", "b"]
d0.commit([node, other])