Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ requests and appreciate your help in improving this project.
- Add or update unit tests for any new features or bug fixes.
- Use [pytest](https://pytest.org/) for running tests.
- The testing requirements are included in the `test` feature for the library.
- You can run tests using [hatch](https://hatch.pypa.io/):
- You can run tests using [hatch](https://hatch.pypa.io/):
```
hatch run pytest .
```
- If you're using vscode, you can create a venv with the `test` feature and run tests with the command palette:
```
Python: Run Tests
```
- Or install the `test` feature with pip and run tests:
- Or install the `test` feature with pip and run tests:
```
pip install -e </path/to/library>[test]
pytest .
Expand All @@ -55,6 +55,10 @@ requests and appreciate your help in improving this project.
```
pytest -m "not slow" .
```
- We mark tests that require `daggerml-cli` to be installed with `@pytest.mark.needs_dml`. You can exclude those tests with:
```
pytest -m "not needs_dml" .
```
- Run all tests locally before submitting a pull request:
- Ensure your code passes all tests and does not decrease code coverage.
- If your changes introduce new dependencies, please update `pyproject.toml`, but we prefer to keep the dependencies to a minimum.
Expand Down
39 changes: 29 additions & 10 deletions src/daggerml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,18 @@ def make_node(dag: "Dag", ref: Ref) -> "Node":
"""
info = dag.dml("node", "describe", ref.to)
if info["data_type"] == "list":
return ListNode(dag, ref, _info=info)
if info["data_type"] == "dict":
return DictNode(dag, ref, _info=info)
if info["data_type"] == "set":
return ListNode(dag, ref, _info=info)
if info["data_type"] == "executable":
return ExecutableNode(dag, ref, _info=info)
return ScalarNode(dag, ref, _info=info)
node = ListNode(dag, ref, _info=info)
elif info["data_type"] == "dict":
node = DictNode(dag, ref, _info=info)
elif info["data_type"] == "set":
node = ListNode(dag, ref, _info=info)
elif info["data_type"] == "executable":
node = ExecutableNode(dag, ref, _info=info)
else:
node = ScalarNode(dag, ref, _info=info)
if info["doc"]:
object.__setattr__(node, "__doc__", info["doc"])
return node


@dataclass
Expand Down Expand Up @@ -444,6 +448,7 @@ def call(
doc: Optional[str] = None,
sleep: Optional[callable] = None,
timeout: int = -1,
**kw,
) -> "Node":
"""
Call a function node with arguments.
Expand All @@ -462,6 +467,8 @@ def call(
A nullary function that returns sleep time in milliseconds
timeout : int, default=-1
Maximum time to wait in milliseconds. If <= 0, wait indefinitely.
**kw : dict
Keyword arguments override any prepop values in the Executable (fn).

Returns
-------
Expand All @@ -475,6 +482,16 @@ def call(
Error
If the function returns an error
"""
if len(kw) > 0:
if isinstance(fn, Node):
fn = fn.value()
if set(kw) - set(fn.prepop):
extras = sorted(set(kw) - set(fn.prepop))
msg = f"Function called with extraneous kwargs (not in `fn.prepop`): {extras}"
raise Error(msg, origin="dml", type="KeyError")
fn = Executable(uri=fn.uri, data=fn.data, adapter=fn.adapter, prepop={**fn.prepop, **kw})
# FIXME: replace fails: `TypeError: Executable.__init__() missing 1 required positional argument: 'uri'`
# fn = replace(fn, prepop={**fn.prepop, **kw})
Comment on lines +493 to +494
Copy link

Copilot AI Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The commented FIXME indicates incomplete implementation. Consider either implementing the replace functionality properly or removing the comment if the current approach is intentional.

Suggested change
# FIXME: replace fails: `TypeError: Executable.__init__() missing 1 required positional argument: 'uri'`
# fn = replace(fn, prepop={**fn.prepop, **kw})

Copilot uses AI. Check for mistakes.
sleep = sleep or BackoffWithJitter()
expr = [self.put(x) for x in [fn, *args]]
end = current_time_millis() + timeout
Expand Down Expand Up @@ -602,7 +619,7 @@ class ScalarNode(Node):


class ExecutableNode(Node):
def __call__(self, *args, name=None, doc=None, sleep=None, timeout=-1) -> "Node":
def __call__(self, *args, name=None, doc=None, sleep=None, timeout=-1, **kw) -> "Node":
"""
Call this node as a function.

Expand All @@ -618,6 +635,8 @@ def __call__(self, *args, name=None, doc=None, sleep=None, timeout=-1) -> "Node"
A nullary function that returns sleep time in milliseconds
timeout : int, default=-1
Maximum time to wait in milliseconds. -1 means wait forever.
**kw : dict
Keyword arguments override any prepop values in the Executable (fn).

Returns
-------
Expand All @@ -631,7 +650,7 @@ def __call__(self, *args, name=None, doc=None, sleep=None, timeout=-1) -> "Node"
Error
If the function returns an error
"""
return self.dag.call(self, *args, name=name, doc=doc, sleep=sleep, timeout=timeout)
return self.dag.call(self, *args, name=name, doc=doc, sleep=sleep, timeout=timeout, **kw)


class CollectionNode(Node): # noqa: F811
Expand Down
43 changes: 43 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Common test fixtures for dml-util tests."""

import logging
import os
from unittest.mock import patch

import pytest

from daggerml import Dml


@pytest.fixture(autouse=True)
def clear_envvars():
with patch.dict(os.environ):
# Clear AWS environment variables before any tests run
for k in os.environ:
if k.startswith("AWS_") or k.startswith("DML_"):
del os.environ[k]
os.environ["AWS_SHARED_CREDENTIALS_FILE"] = "/dev/null"
yield


@pytest.fixture(autouse=True)
def debug(clear_envvars):
"""Fixture to set debug mode for tests."""
with patch.dict(os.environ, {"DML_DEBUG": "1"}):
logging.basicConfig(level=logging.DEBUG)
yield


@pytest.fixture
def dml(tmpdir):
with Dml.temporary(cache_path=str(tmpdir)) as _dml:
with patch.dict(os.environ, DML_FN_CACHE_DIR=_dml.kwargs["config_dir"], **_dml.envvars):
yield _dml


@pytest.fixture
def fake_dml():
# patches Dml and Dag so that neither does anything
with patch("daggerml.core.Dml", autospec=True) as mock_dml:
with patch("daggerml.core.Dag", autospec=True) as mock_dag:
yield mock_dml, mock_dag
Loading