Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions src/daggerml_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
Ctx,
Dag,
Error,
Executable,
FnDag,
Import,
Index,
Literal,
Node,
Ref,
Repo,
Resource,
unroll_datum,
)
from daggerml_cli.topology import node_info, topology
Expand Down Expand Up @@ -430,27 +430,39 @@ def op_start_fn(db, index, argv, name=None, doc=None):
@invoke_op
def op_put_literal(db, index, data, name=None, doc=None):
# TODO: refactor so that Resource.data -> Ref(datum)
def fn(args):
def maybe_to_node(args):
fn_ = None
if isinstance(args, list):
args = [fn(x) for x in args]
if any(isinstance(x, Ref) for x in args):
fn_ = db.put_datum(Resource("daggerml:list"))
args = [maybe_to_node(x) for x in args]
if any(isinstance(x, Ref) for x in args): # only insert if needed
fn_ = db.put_datum(Executable("daggerml:list"))
elif isinstance(args, dict):
args = {k: fn(v) for k, v in args.items()}
if any(isinstance(x, Ref) for x in args.values()):
fn_ = db.put_datum(Resource("daggerml:dict"))
args = {k: maybe_to_node(v) for k, v in args.items()}
if any(isinstance(x, Ref) for x in args.values()): # only insert if needed
fn_ = db.put_datum(Executable("daggerml:dict"))
args = flatten(args.items())
elif isinstance(args, set):
args = {fn(x) for x in args}
if any(isinstance(x, Ref) for x in args):
fn_ = db.put_datum(Resource("daggerml:set"))
args = {maybe_to_node(x) for x in args}
if any(isinstance(x, Ref) for x in args): # only insert if needed
fn_ = db.put_datum(Executable("daggerml:set"))
elif isinstance(args, Executable):
attrs = {}
for x in ["data", "prepop"]:
attrs[x] = {}
for k, v in (getattr(args, x) or {}).items():
attrs[x][k] = maybe_to_node(v)
if isinstance(attrs[x][k], Ref) and attrs[x][k].type == "node":
attrs[x][k] = db.get(attrs[x][k]).value
else:
attrs[x][k] = db.put_datum(attrs[x][k])
args = Executable(args.uri, adapter=args.adapter, **attrs) # so we don't mutate
return args
if fn_ is not None:
return op_start_fn(db, index, [fn_, *args])
return args

with db.tx(True):
data = fn(data)
data = maybe_to_node(data)
if isinstance(data, Ref) and data.type == "node":
return op_set_node(db, index, name, data) if name else data
result = db.put_node(Literal(db.put_datum(data)), index=index, name=name, doc=doc)
Expand Down Expand Up @@ -551,7 +563,7 @@ def inner(*_args, **_kwargs):
op, args, kwargs = data
if op in BUILTIN_FNS:
with db.tx(True):
fn = db.put_datum(Resource(f"daggerml:{op}"))
fn = db.put_datum(Executable(f"daggerml:{op}"))
fn = op_put_literal(db, index, fn, name=f"daggerml:{op}")
return op_start_fn(db, index, [fn, *args], **kwargs)
return invoke_op.fns.get(op, no_such_op(op))(db, index, *args, **kwargs)
Expand Down
1 change: 0 additions & 1 deletion src/daggerml_cli/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
import os
import sys
from functools import wraps
from pathlib import Path

Expand Down
11 changes: 8 additions & 3 deletions src/daggerml_cli/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,20 @@ class CacheError(Exception):


def serialize_resource(x):
    """Return a JSON-friendly dict describing a ``Resource`` or ``Executable``.

    ``Executable`` is tested first because it subclasses ``Resource``; testing
    ``Resource`` first would tag every executable as a plain resource.

    Returns:
        A dict tagged with ``"__type__"`` (``"executable"`` or ``"resource"``),
        or ``None`` (implicitly) when ``x`` is neither type.
    """
    # Imported at call time, as in the original
    # (presumably to avoid a circular import with daggerml_cli.repo -- TODO confirm).
    from daggerml_cli.repo import Executable, Resource

    if isinstance(x, Executable):
        # NOTE(review): ``prepop`` is not part of the serialized form -- confirm this is intentional.
        return {
            "__type__": "executable",
            "uri": x.uri,
            "data": x.data,
            "adapter": x.adapter,
        }
    if isinstance(x, Resource):
        return {
            "__type__": "resource",
            "uri": x.uri,
        }


def dbenv(path, db_types, **kw):
Expand Down
66 changes: 50 additions & 16 deletions src/daggerml_cli/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import contextmanager
from dataclasses import InitVar, dataclass, field, fields, is_dataclass
from hashlib import md5
from typing import TYPE_CHECKING, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union, cast
from urllib.parse import urlparse
from uuid import uuid4

Expand Down Expand Up @@ -70,7 +70,7 @@ def to_data(obj):
if isinstance(obj, (list, set)):
return [n[0], *[to_data(x) for x in obj]]
if isinstance(obj, dict):
return [n[0], *[[k, to_data(v)] for k, v in obj.items()]]
return [n[0], *[[k, to_data(v)] for k, v in sorted(obj.items(), key=lambda x: x[0])]]
if n in DATA_TYPE:
return [n, *[to_data(getattr(obj, x.name)) for x in fields(obj)]]
raise ValueError(f"no data encoding for type: {n}")
Expand All @@ -82,6 +82,10 @@ def get(value):
value = value()
if isinstance(value, Datum):
value = value.value
if isinstance(value, Executable):
data = {k: get(v) for k, v in value.data.items()}
prepop = {k: get(v) for k, v in value.prepop.items()}
return Executable(value.uri, adapter=value.adapter, data=data, prepop=prepop)
if isinstance(value, (type(None), str, bool, int, float, Resource)):
return value
if isinstance(value, list):
Expand Down Expand Up @@ -213,15 +217,21 @@ def __str__(self):


@repo_type(db=False)
@dataclass(frozen=True)
@dataclass
Copy link

Copilot AI Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

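The Resource class was frozen (immutable) before this change; the suggestion here is to add frozen=True back to Resource while keeping Executable mutable, to maintain the original immutability contract for basic resources. Caveat: dataclasses raises TypeError when a non-frozen dataclass inherits from a frozen one, so Executable (which subclasses Resource) cannot stay mutable if Resource is frozen — either freeze the subclasses as well or leave Resource unfrozen.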

Suggested change
@dataclass
@dataclass(frozen=True)

Copilot uses AI. Check for mistakes.
class Resource:
uri: str
data: Optional[Ref] = None # -> Datum


@repo_type(db=False)
@dataclass
class Executable(Resource):
    """A ``Resource`` that can be used as the function (first element) of an argv."""

    # Named values attached to the executable; redeclares ``Resource.data`` as a dict.
    data: Dict[str, Any] = field(default_factory=dict)  # -> Ref(datum)
    # Name of the adapter used to run the function. ``start_fn`` requires a
    # ``daggerml:`` URI scheme when this is None.
    adapter: Optional[str] = None
    # Named values that ``start_fn`` dumps alongside the argv expression and
    # ``begin`` turns into named literal nodes of the function dag.
    prepop: Dict[str, Any] = field(default_factory=dict)  # -> Ref(datum)
Comment on lines +225 to +230
Copy link

Copilot AI Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

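The Resource class was frozen (immutable) before this change; the suggestion here is to add frozen=True back to Resource while keeping Executable mutable, to maintain the original immutability contract for basic resources. Caveat: dataclasses raises TypeError when a non-frozen dataclass inherits from a frozen one, so Executable (which subclasses Resource) cannot stay mutable if Resource is frozen — either freeze the subclasses as well or leave Resource unfrozen.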

Copilot uses AI. Check for mistakes.


@repo_type
@dataclass(frozen=True)
@dataclass
class Deleted(Resource):
@classmethod
def resource(cls, obj: Resource):
Expand Down Expand Up @@ -274,6 +284,7 @@ def nameof(self, ref):
@repo_type
@dataclass
class FnDag(Dag):
    """A ``Dag`` produced by a function call; adds the call's cache key and argv node.

    Callers in this file construct it positionally, as
    ``FnDag(nodes, names, result, error, cache_key, argv)``, so the order of the
    two fields below is part of the interface.
    """

    # Identifier of the call in the cache. ``topology`` reports it via
    # ``getattr(dag, "cache_key", None)``.
    cache_key: str
    argv: Optional[Ref] = None  # -> node(expr) (in this dag)


Expand Down Expand Up @@ -315,7 +326,7 @@ def error(self):
@repo_type(db=False)
@dataclass
class Fn(Import):
    """An ``Import`` that also records the argument nodes of the function call."""

    # Defaults to an empty list (not None) so consumers can iterate it directly.
    # ``default_factory`` gives every instance its own list.
    argv: list[Ref] = field(default_factory=list)  # -> node


@repo_type
Expand Down Expand Up @@ -511,6 +522,9 @@ def walk_ordered(self, *key):
xs += [a for a in x if a not in result]
elif isinstance(x, dict):
xs += [a for a in x.values() if a not in result]
elif isinstance(x, Executable):
xs += [a for a in x.data.values() if a not in result]
xs += [a for a in x.prepop.values() if a not in result]
elif isinstance(x, (Error, Resource)):
pass # cannot recurse into these classes
elif is_dataclass(x):
Expand Down Expand Up @@ -759,11 +773,25 @@ def begin(self, *, message, name=None, dump=None):
if dump is None:
dag = self(Dag([], {}, None, None))
else:
loaded = self.load_ref(dump)
assert loaded is not None, "failed to load dump"
loaded = cast(Dict[str, Any], from_json(dump))
named_nodes = {}
with self.tx(True):
argv = self(Node(Argv(loaded)))
dag = self(FnDag([argv], {}, None, None, argv))
argv = self(Node(Argv(self.put_datum([self.load_ref(x) for x in loaded["expr"]]))))
for k, v in loaded["prepop"].items():
datum_ref = self.load_ref(v)
if not isinstance(datum_ref, Ref) or datum_ref.type != "datum":
raise ValueError(f"invalid datum ref in `begin`: {v}")
named_nodes[k] = self(Node(Literal(datum_ref)))
dag = self(
FnDag(
sorted([argv, *named_nodes.values()]),
named_nodes,
None,
None,
md5(dump.encode()).hexdigest(),
argv,
)
)
commit = Commit([ctx.head.commit], self(ctx.tree), self.user, self.user, message, dag_name=name)
index = self(Index(self(commit), dag))
return index
Expand All @@ -772,7 +800,7 @@ def put_node(self, data, index: Ref, name=None, doc=None):
ctx = Ctx.from_head(index)
node = data if isinstance(data, Ref) else self(Node(data, doc=doc))
if node not in ctx.dag.nodes:
ctx.dag.nodes.append(node)
ctx.dag.nodes = sorted([node, *ctx.dag.nodes], key=lambda x: x.to)
if name:
ctx.dag.names[name] = node
ctx.commit.tree = self(ctx.tree)
Expand All @@ -791,11 +819,10 @@ def get_node_value(self, ref: Ref):

def start_fn(self, index, *, argv, name=None, doc=None):
fn, *data = map(lambda x: x().datum, argv)
argv_datum = self.put_datum([x().value for x in argv])
if fn.adapter is None:
uri = urlparse(fn.uri)
assert uri.scheme == "daggerml", f"unexpected URI scheme: {uri.scheme!r} for null adapter"
argv_node = self(Node(Argv(argv_datum)))
argv_node = self(Node(Argv(self.put_datum([x().value for x in argv]))))
result = error = None
nodes = [argv_node]
try:
Expand All @@ -805,17 +832,24 @@ def start_fn(self, index, *, argv, name=None, doc=None):
else:
result = self(Node(Literal(self.put_datum(result))))
nodes.append(result)
fndag = self(FnDag(nodes, {}, result, error, argv_node))
fndag = self(FnDag(nodes, {}, result, error, argv_node().value.id, argv_node))
else:
assert self.cache_path, (
"cache path is required for function execution. "
"Set the cache path via the DML_CACHE_PATH environment variable or in the config file."
)
argv_datum = to_json(
{
"expr": [self.dump_ref(x().value) for x in argv],
"prepop": {k: self.dump_ref(v) for k, v in fn.prepop.items()},
}
)
cache_key = md5(argv_datum.encode()).hexdigest()
with Cache(self.cache_path, create=False) as cache_db:
cached_val = cache_db.submit(fn, argv_datum.id, self.dump_ref(argv_datum))
cached_val = cache_db.submit(unroll_datum(fn), cache_key, argv_datum)
fndag = self.load_ref(cached_val) if cached_val else None
if isinstance(fndag, Error):
fndag = self(FnDag([argv], {}, None, fndag, argv))
fndag = self(FnDag([argv], {}, None, fndag, cache_key, argv))
if fndag is not None:
node = self.put_node(Fn(fndag, None, argv), index=index, name=name, doc=doc)
raise_ex(self.get(node).error)
Expand Down
8 changes: 6 additions & 2 deletions src/daggerml_cli/topology.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from daggerml_cli.repo import Fn, Import
from typing import cast

from daggerml_cli.repo import Executable, Fn, Import
from daggerml_cli.util import flatten


Expand All @@ -18,8 +20,10 @@ def node_info(ref, *, include_argv=True):
}
if include_argv and isinstance(node.data, Fn):
info["argv"] = [node_info(x, include_argv=False) for x in node.data.argv]
info["prepop"] = cast(Executable, node.data.argv[0]().datum).prepop
Copy link

Copilot AI Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes that argv[0] always contains an Executable object, but it could contain other types; typing.cast performs no runtime check, so reading .prepop on anything else would fail. Add a type check before casting: if isinstance(node.data.argv[0]().datum, Executable):

Suggested change
info["prepop"] = cast(Executable, node.data.argv[0]().datum).prepop
datum0 = node.data.argv[0]().datum
info["prepop"] = datum0.prepop if isinstance(datum0, Executable) else None

Copilot uses AI. Check for mistakes.
elif include_argv:
info["argv"] = None
info["prepop"] = None
return info


Expand All @@ -43,7 +47,7 @@ def topology(db, ref):
return {
"id": ref,
"argv": dag.argv.to if hasattr(dag, "argv") else None,
"cache_key": dag.argv().value.id if hasattr(dag, "argv") else None,
"cache_key": getattr(dag, "cache_key", None),
"nodes": [make_node(dag.nameof(x), x) for x in dag.nodes],
"edges": edges,
"result": dag.result.to if dag.result is not None else None,
Expand Down
25 changes: 18 additions & 7 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from daggerml_cli import api
from daggerml_cli.config import Config
from daggerml_cli.db import CacheError
from daggerml_cli.repo import Error, FnDag, Node, Ref, Resource
from daggerml_cli.repo import Error, Executable, FnDag, Node, Ref, Resource
from tests.util import SimpleApi

SUM = Resource("./tests/fn/sum.py", adapter="dml-python-fork-adapter")
AER = Resource("./tests/fn/adapter_error.py", adapter="dml-python-fork-adapter")
SUM = Executable("./tests/fn/sum.py", adapter="dml-python-fork-adapter")
AER = Executable("./tests/fn/adapter_error.py", adapter="dml-python-fork-adapter")


def env(**kwargs):
Expand Down Expand Up @@ -223,8 +223,19 @@ def test_cached_errors(self):

def test_resource(self):
with SimpleApi.begin() as d0:
resource = Resource("uri:here", data={"a": 1, "b": [2, 3], "c": Resource("qwer")})
d0.put_literal(resource)
resource = Executable(
"uri:here",
data={"a": 1, "b": [2, 3], "c": Resource("qwer")},
prepop={"a": {"b": 2}},
)
node = d0.put_literal(resource, name="x")
with d0.tx():
nodeval = node().datum
assert isinstance(nodeval, Executable)
assert nodeval.uri == resource.uri
assert nodeval.prepop != resource.prepop
assert nodeval.prepop.keys() == resource.prepop.keys()
assert d0.unroll(node) == resource

def test_describe_dag(self):
with self.tmpd() as cache_path:
Expand Down Expand Up @@ -260,7 +271,7 @@ def test_describe_dag(self):
)
self.assertCountEqual(
[x["data_type"] for x in desc["nodes"]],
["resource", "int", "int", "list"],
["executable", "int", "int", "list"],
)
assert len(desc["edges"]) == len(nodes) + 2 # +1 because dag->node edge
assert {e["source"] for e in desc["edges"] if e["type"] == "node"} == {x for x in nodes}
Expand Down Expand Up @@ -307,7 +318,7 @@ def test_describe_dag_w_errs(self):
)
self.assertCountEqual(
[x["data_type"] for x in desc["nodes"]],
["resource", "int", "str", "error", "nonetype"],
["executable", "int", "str", "error", "nonetype"],
)

def test_backtrack_node(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from click.testing import CliRunner

from daggerml_cli.cli import cli, from_json, jsdumps, to_json
from daggerml_cli.repo import Resource
from daggerml_cli.repo import Executable, Resource

SUM = Resource("./tests/fn/sum.py", adapter="dml-python-fork-adapter")
SUM = Executable("./tests/fn/sum.py", adapter="dml-python-fork-adapter")


@dataclass
Expand Down
Loading