Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
dependencies = []

[project.optional-dependencies]
cli = ["daggerml-cli>=0.0.33"]
cli = ["daggerml-cli>=0.0.37"]
dev = [
"pytest",
"pytest-cov",
Expand Down
11 changes: 4 additions & 7 deletions src/daggerml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import traceback as tb
from dataclasses import dataclass, field, fields
from tempfile import TemporaryDirectory
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast

from daggerml.util import (
BackoffWithJitter,
Expand Down Expand Up @@ -306,7 +306,6 @@ def __exit__(self, exc_type, exc_value, traceback):

def __getitem__(self, name) -> "Node":
return make_node(self, self._dml.get_node(name, self._ref))
# return Node(self, self._dml.get_node(name, self._ref))

def __setitem__(self, name, value) -> "Node":
assert not self._ref
Expand Down Expand Up @@ -342,14 +341,12 @@ def __getattr__(self, name):
def argv(self) -> "Node":
"Access the dag's argv node"
return make_node(self, self._dml.get_argv(self._ref))
# return Node(self, self._dml.get_argv(self._ref))

@property
def result(self) -> "Node":
ref = self._dml.get_result(self._ref)
assert ref, f"'{self.__class__.__name__}' has no attribute 'result'"
return make_node(self, ref)
# return Node(self, ref) if ref else ref

@result.setter
def result(self, value):
Expand Down Expand Up @@ -423,10 +420,10 @@ def _commit(self, value) -> "Node":
Value to commit
"""
value = value if isinstance(value, (Node, Error)) else self._put(value)
dump = self._dml.commit(value)
ref = cast(Ref, self._dml.commit(value))
if self._message_handler:
self._message_handler(dump)
self._ref = Boxed(Ref(json.loads(dump)[-1][1][1]))
self._message_handler(self._dml("ref", "dump", to_json(ref), as_text=True))
self._ref = Boxed(ref)


@dataclass(frozen=True)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,20 @@ def test_load_constructors(self):
assert c0.load("b").load(0) == l0
assert c0["b"][0] != l0

def test_fn_ok_cache(self):
with TemporaryDirectory(prefix="dml-test-") as fn_cache_dir:
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=fn_cache_dir):
with TemporaryDirectory(prefix="dml-cache-") as cache_path:
with Dml.temporary(cache_path=cache_path) as dml:
with dml.new("d0", "d0") as d0:
d0.n0 = SUM
nodes = [d0.n0(i, 1, 2) for i in range(2)] # unique function applications
d0.n0(0, 1, 2) # add a repeat outside so `nodes` is still unique
d0.result = nodes[0]
self.assertEqual(d0.result.value(), 3)
cache_list = dml("cache", "list", as_text=True) # response is jsonlines format
assert len([x for x in cache_list if x.rstrip() == "{"]) == 2 # this gets us unique maps
Copy link

Copilot AI Aug 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assertion logic x.rstrip() == "{" is unclear and fragile. It appears to be counting JSON objects by looking for opening braces, but this could break if the JSON format changes. Consider parsing the JSON properly or using a more robust method to count cache entries.

Suggested change
assert len([x for x in cache_list if x.rstrip() == "{"]) == 2 # this gets us unique maps
# Count valid JSON objects in the jsonlines output
assert len([x for x in cache_list if x.strip() and _is_json_object(x)]) == 2

Copilot uses AI. Check for mistakes.

def test_async_fn_ok(self):
with TemporaryDirectory(prefix="dml-test-") as fn_cache_dir:
with mock.patch.dict(os.environ, DML_FN_CACHE_DIR=fn_cache_dir):
Expand Down