From 657be9181a16a9055a0b3f9e6e5c4310b8582ec4 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Tue, 9 Sep 2025 18:25:09 -0700 Subject: [PATCH 1/2] wip --- src/daggerml_cli/api.py | 38 +++++++++++++------- src/daggerml_cli/cli/__init__.py | 1 - src/daggerml_cli/db.py | 11 ++++-- src/daggerml_cli/repo.py | 60 ++++++++++++++++++++++++-------- src/daggerml_cli/topology.py | 8 +++-- tests/test_api.py | 25 +++++++++---- tests/test_cli.py | 4 +-- tests/test_repo.py | 20 ++++++++--- 8 files changed, 120 insertions(+), 47 deletions(-) diff --git a/src/daggerml_cli/api.py b/src/daggerml_cli/api.py index 95d9703..d8f008d 100644 --- a/src/daggerml_cli/api.py +++ b/src/daggerml_cli/api.py @@ -19,6 +19,7 @@ Ctx, Dag, Error, + Executable, FnDag, Import, Index, @@ -26,7 +27,6 @@ Node, Ref, Repo, - Resource, unroll_datum, ) from daggerml_cli.topology import node_info, topology @@ -430,27 +430,39 @@ def op_start_fn(db, index, argv, name=None, doc=None): @invoke_op def op_put_literal(db, index, data, name=None, doc=None): # TODO: refactor so that Resource.data -> Ref(datum) - def fn(args): + def maybe_to_node(args): fn_ = None if isinstance(args, list): - args = [fn(x) for x in args] - if any(isinstance(x, Ref) for x in args): - fn_ = db.put_datum(Resource("daggerml:list")) + args = [maybe_to_node(x) for x in args] + if any(isinstance(x, Ref) for x in args): # only insert if needed + fn_ = db.put_datum(Executable("daggerml:list")) elif isinstance(args, dict): - args = {k: fn(v) for k, v in args.items()} - if any(isinstance(x, Ref) for x in args.values()): - fn_ = db.put_datum(Resource("daggerml:dict")) + args = {k: maybe_to_node(v) for k, v in args.items()} + if any(isinstance(x, Ref) for x in args.values()): # only insert if needed + fn_ = db.put_datum(Executable("daggerml:dict")) args = flatten(args.items()) elif isinstance(args, set): - args = {fn(x) for x in args} - if any(isinstance(x, Ref) for x in args): - fn_ = db.put_datum(Resource("daggerml:set")) + args = {maybe_to_node(x) for x in args} + if any(isinstance(x, Ref) for x in args): # only insert if needed + fn_ = db.put_datum(Executable("daggerml:set")) + elif isinstance(args, Executable): + attrs = {} + for x in ["data", "prepop"]: + attrs[x] = {} + for k, v in (getattr(args, x) or {}).items(): + attrs[x][k] = maybe_to_node(v) + if isinstance(attrs[x][k], Ref) and attrs[x][k].type == "node": + attrs[x][k] = db.get(attrs[x][k]).value + else: + attrs[x][k] = db.put_datum(attrs[x][k]) + args = Executable(args.uri, adapter=args.adapter, **attrs) # so we don't mutate + return args if fn_ is not None: return op_start_fn(db, index, [fn_, *args]) return args with db.tx(True): - data = fn(data) + data = maybe_to_node(data) if isinstance(data, Ref) and data.type == "node": return op_set_node(db, index, name, data) if name else data result = db.put_node(Literal(db.put_datum(data)), index=index, name=name, doc=doc) @@ -551,7 +563,7 @@ def inner(*_args, **_kwargs): op, args, kwargs = data if op in BUILTIN_FNS: with db.tx(True): - fn = db.put_datum(Resource(f"daggerml:{op}")) + fn = db.put_datum(Executable(f"daggerml:{op}")) fn = op_put_literal(db, index, fn, name=f"daggerml:{op}") return op_start_fn(db, index, [fn, *args], **kwargs) return invoke_op.fns.get(op, no_such_op(op))(db, index, *args, **kwargs) diff --git a/src/daggerml_cli/cli/__init__.py b/src/daggerml_cli/cli/__init__.py index 9f38cb9..641613a 100644 --- a/src/daggerml_cli/cli/__init__.py +++ b/src/daggerml_cli/cli/__init__.py @@ -1,7 +1,6 @@ import json import logging import os -import sys from functools import wraps from pathlib import Path diff --git a/src/daggerml_cli/db.py b/src/daggerml_cli/db.py index 5b983e3..ded693f 100644 --- a/src/daggerml_cli/db.py +++ b/src/daggerml_cli/db.py @@ -22,15 +22,20 @@ class CacheError(Exception): def serialize_resource(x): - from daggerml_cli.repo import Resource + from daggerml_cli.repo import Executable, Resource - if isinstance(x, Resource): + if isinstance(x, Executable): return { - "__type__": "resource", + "__type__": "executable", "uri": x.uri, "data": x.data, "adapter": x.adapter, } + if isinstance(x, Resource): + return { + "__type__": "resource", + "uri": x.uri, + } def dbenv(path, db_types, **kw): diff --git a/src/daggerml_cli/repo.py b/src/daggerml_cli/repo.py index 8b9ec49..c09475e 100644 --- a/src/daggerml_cli/repo.py +++ b/src/daggerml_cli/repo.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from dataclasses import InitVar, dataclass, field, fields, is_dataclass from hashlib import md5 -from typing import TYPE_CHECKING, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union, cast from urllib.parse import urlparse from uuid import uuid4 @@ -70,7 +70,7 @@ def to_data(obj): if isinstance(obj, (list, set)): return [n[0], *[to_data(x) for x in obj]] if isinstance(obj, dict): - return [n[0], *[[k, to_data(v)] for k, v in obj.items()]] + return [n[0], *[[k, to_data(v)] for k, v in sorted(obj.items(), key=lambda x: x[0])]] if n in DATA_TYPE: return [n, *[to_data(getattr(obj, x.name)) for x in fields(obj)]] raise ValueError(f"no data encoding for type: {n}") @@ -82,6 +82,10 @@ def get(value): value = value() if isinstance(value, Datum): value = value.value + if isinstance(value, Executable): + data = {k: get(v) for k, v in value.data.items()} + prepop = {k: get(v) for k, v in value.prepop.items()} + return Executable(value.uri, adapter=value.adapter, data=data, prepop=prepop) if isinstance(value, (type(None), str, bool, int, float, Resource)): return value if isinstance(value, list): @@ -213,15 +217,21 @@ def __str__(self): @repo_type(db=False) -@dataclass(frozen=True) +@dataclass class Resource: uri: str - data: Optional[Ref] = None # -> Datum + + +@repo_type(db=False) +@dataclass +class Executable(Resource): + data: Dict[str, Any] = field(default_factory=dict) # -> Ref(datum) adapter: Optional[str] = None + prepop: Dict[str, Any] = field(default_factory=dict) # -> Ref(datum) @repo_type -@dataclass(frozen=True) +@dataclass class Deleted(Resource): @classmethod def resource(cls, obj: Resource): @@ -274,6 +284,7 @@ def nameof(self, ref): @repo_type @dataclass class FnDag(Dag): + cache_key: str argv: Optional[Ref] = None # -> node(expr) (in this dag) @@ -315,7 +326,7 @@ def error(self): @repo_type(db=False) @dataclass class Fn(Import): - argv: Optional[list[Ref]] = None # -> node + argv: list[Ref] = field(default_factory=list) # -> node @repo_type @@ -511,6 +522,9 @@ def walk_ordered(self, *key): xs += [a for a in x if a not in result] elif isinstance(x, dict): xs += [a for a in x.values() if a not in result] + elif isinstance(x, Executable): + xs += [a for a in x.data.values() if a not in result] + xs += [a for a in x.prepop.values() if a not in result] elif isinstance(x, (Error, Resource)): pass # cannot recurse into these classes elif is_dataclass(x): @@ -759,11 +773,24 @@ def begin(self, *, message, name=None, dump=None): if dump is None: dag = self(Dag([], {}, None, None)) else: - loaded = self.load_ref(dump) - assert loaded is not None, "failed to load dump" + loaded = cast(Dict[str, Any], from_json(dump)) + named_nodes = {} with self.tx(True): - argv = self(Node(Argv(loaded))) - dag = self(FnDag([argv], {}, None, None, argv)) + argv = self(Node(Argv(self.put_datum([self.load_ref(x) for x in loaded["expr"]])))) + for k, v in loaded["prepop"].items(): + datum_ref = self.load_ref(v) + assert isinstance(datum_ref, Ref) and datum_ref.type == "datum", f"invalid datum ref: {v}" + named_nodes[k] = self(Node(Literal(datum_ref))) + dag = self( + FnDag( + sorted([argv, *named_nodes.values()]), + named_nodes, + None, + None, + md5(dump.encode()).hexdigest(), + argv, + ) + ) commit = Commit([ctx.head.commit], self(ctx.tree), self.user, self.user, message, dag_name=name) index = self(Index(self(commit), dag)) return index @@ -772,7 +799,7 @@ def put_node(self, data, index: Ref, name=None, doc=None): ctx = Ctx.from_head(index) node = data if isinstance(data, Ref) else self(Node(data, doc=doc)) if node not in ctx.dag.nodes: - ctx.dag.nodes.append(node) + ctx.dag.nodes = sorted([node, *ctx.dag.nodes], key=lambda x: x.to) if name: ctx.dag.names[name] = node ctx.commit.tree = self(ctx.tree) @@ -791,11 +818,10 @@ def get_node_value(self, ref: Ref): def start_fn(self, index, *, argv, name=None, doc=None): fn, *data = map(lambda x: x().datum, argv) - argv_datum = self.put_datum([x().value for x in argv]) if fn.adapter is None: uri = urlparse(fn.uri) assert uri.scheme == "daggerml", f"unexpected URI scheme: {uri.scheme!r} for null adapter" - argv_node = self(Node(Argv(argv_datum))) + argv_node = self(Node(Argv(self.put_datum([x().value for x in argv])))) result = error = None nodes = [argv_node] try: @@ -811,8 +837,14 @@ def start_fn(self, index, *, argv, name=None, doc=None): "cache path is required for function execution. " "Set the cache path via the DML_CACHE_PATH environment variable or in the config file." ) + argv_datum = to_json( + { + "expr": [self.dump_ref(x().value) for x in argv], + "prepop": {k: self.dump_ref(v) for k, v in fn.prepop.items()}, + } + ) with Cache(self.cache_path, create=False) as cache_db: - cached_val = cache_db.submit(fn, argv_datum.id, self.dump_ref(argv_datum)) + cached_val = cache_db.submit(unroll_datum(fn), md5(argv_datum.encode()).hexdigest(), argv_datum) fndag = self.load_ref(cached_val) if cached_val else None if isinstance(fndag, Error): fndag = self(FnDag([argv], {}, None, fndag, argv)) diff --git a/src/daggerml_cli/topology.py b/src/daggerml_cli/topology.py index 6a604ef..98a8ccb 100644 --- a/src/daggerml_cli/topology.py +++ b/src/daggerml_cli/topology.py @@ -1,4 +1,6 @@ -from daggerml_cli.repo import Fn, Import +from typing import cast + +from daggerml_cli.repo import Executable, Fn, Import from daggerml_cli.util import flatten @@ -18,8 +20,10 @@ def node_info(ref, *, include_argv=True): } if include_argv and isinstance(node.data, Fn): info["argv"] = [node_info(x, include_argv=False) for x in node.data.argv] + info["prepop"] = cast(Executable, node.data.argv[0]().datum).prepop elif include_argv: info["argv"] = None + info["prepop"] = None return info @@ -43,7 +47,7 @@ def topology(db, ref): return { "id": ref, "argv": dag.argv.to if hasattr(dag, "argv") else None, - "cache_key": dag.argv().value.id if hasattr(dag, "argv") else None, + "cache_key": getattr(dag, "cache_key", None), "nodes": [make_node(dag.nameof(x), x) for x in dag.nodes], "edges": edges, "result": dag.result.to if dag.result is not None else None, diff --git a/tests/test_api.py b/tests/test_api.py index 0f7a413..3b1b64a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -7,11 +7,11 @@ from daggerml_cli import api from daggerml_cli.config import Config from daggerml_cli.db import CacheError -from daggerml_cli.repo import Error, FnDag, Node, Ref, Resource +from daggerml_cli.repo import Error, Executable, FnDag, Node, Ref, Resource from tests.util import SimpleApi -SUM = Resource("./tests/fn/sum.py", adapter="dml-python-fork-adapter") -AER = Resource("./tests/fn/adapter_error.py", adapter="dml-python-fork-adapter") +SUM = Executable("./tests/fn/sum.py", adapter="dml-python-fork-adapter") +AER = Executable("./tests/fn/adapter_error.py", adapter="dml-python-fork-adapter") def env(**kwargs): @@ -223,8 +223,19 @@ def test_cached_errors(self): def test_resource(self): with SimpleApi.begin() as d0: - resource = Resource("uri:here", data={"a": 1, "b": [2, 3], "c": Resource("qwer")}) - d0.put_literal(resource) + resource = Executable( + "uri:here", + data={"a": 1, "b": [2, 3], "c": Resource("qwer")}, + prepop={"a": {"b": 2}}, + ) + node = d0.put_literal(resource, name="x") + with d0.tx(): + nodeval = node().datum + assert isinstance(nodeval, Executable) + assert nodeval.uri == resource.uri + assert nodeval.prepop != resource.prepop + assert nodeval.prepop.keys() == resource.prepop.keys() + assert d0.unroll(node) == resource def test_describe_dag(self): with self.tmpd() as cache_path: @@ -260,7 +271,7 @@ def test_describe_dag(self): ) self.assertCountEqual( [x["data_type"] for x in desc["nodes"]], - ["resource", "int", "int", "list"], + ["executable", "int", "int", "list"], ) assert len(desc["edges"]) == len(nodes) + 2 # +1 because dag->node edge assert {e["source"] for e in desc["edges"] if e["type"] == "node"} == {x for x in nodes} @@ -307,7 +318,7 @@ def test_describe_dag_w_errs(self): ) self.assertCountEqual( [x["data_type"] for x in desc["nodes"]], - ["resource", "int", "str", "error", "nonetype"], + ["executable", "int", "str", "error", "nonetype"], ) def test_backtrack_node(self): diff --git a/tests/test_cli.py b/tests/test_cli.py index 9471a10..36f77ec 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,9 +12,9 @@ from click.testing import CliRunner from daggerml_cli.cli import cli, from_json, jsdumps, to_json -from daggerml_cli.repo import Resource +from daggerml_cli.repo import Executable, Resource -SUM = Resource("./tests/fn/sum.py", adapter="dml-python-fork-adapter") +SUM = Executable("./tests/fn/sum.py", adapter="dml-python-fork-adapter") @dataclass diff --git a/tests/test_repo.py b/tests/test_repo.py index 2646442..d552eec 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -6,7 +6,7 @@ import pytest -from daggerml_cli.repo import Literal, Node, Ref, Repo, Resource, unroll_datum +from daggerml_cli.repo import Executable, Literal, Node, Ref, Repo, Resource, unroll_datum @contextmanager @@ -37,13 +37,23 @@ def tmp_repo(cache_path=None): ("simple_list", [1, "string", True, None]), ("simple_dict", {"a": 1, "b": 2, "c": 3}), ("simple_set", {1, 2, 3}), - ("resource", Resource("test://uri", adapter="test-adapter")), + ("resource", Resource("test://uri")), + ( + "executable", + Executable( + "test://uri", + adapter="test-adapter", + data={"key": "value"}, + prepop={"dep1": "dep2"}, + ), + ), ( "nested_structure", { "list": [1, "string", True, None], "dict": {"a": 1, "b": [2, 3], "c": {"d": 4}}, - "resource": Resource("test://uri", adapter="test-adapter"), + "resource": Resource("test://uri"), + "executable": Executable("test://uri", adapter="test-adapter"), "set": {1, 2, 3}, }, ), @@ -94,7 +104,7 @@ def test_dump_and_load(name, test_value): ) def test_start_fn_with_builtins(op, args, expected): """Test start_fn with built-in functions using patched methods.""" - argv = [Resource(f"daggerml:{op}")] + (list(args) if isinstance(args, tuple) else [args]) + argv = [Executable(f"daggerml:{op}")] + (list(args) if isinstance(args, tuple) else [args]) with tmp_repo() as repo: with repo.tx(True): dag = repo.begin(message="test dag", name="test") @@ -105,7 +115,7 @@ def test_start_fn_with_builtins(op, args, expected): def test_adapter_called_correctly(): """Test start_fn with built-in functions using patched methods.""" - argv = [Resource("foo://bar", data={"a": "b"}, adapter="ls"), 1, 2, 3] + argv = [Executable("foo://bar", data={"a": "b"}, adapter="ls"), 1, 2, 3] with tmp_repo() as repo: with repo.tx(True): dag = repo.begin(message="test dag", name="test") From 580c324735d73505807c9063ed7e7ff1fde03f35 Mon Sep 17 00:00:00 2001 From: Aaron Niskin Date: Tue, 9 Sep 2025 20:37:21 -0700 Subject: [PATCH 2/2] wip --- src/daggerml_cli/repo.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/daggerml_cli/repo.py b/src/daggerml_cli/repo.py index c09475e..573632e 100644 --- a/src/daggerml_cli/repo.py +++ b/src/daggerml_cli/repo.py @@ -779,7 +779,8 @@ def begin(self, *, message, name=None, dump=None): argv = self(Node(Argv(self.put_datum([self.load_ref(x) for x in loaded["expr"]])))) for k, v in loaded["prepop"].items(): datum_ref = self.load_ref(v) - assert isinstance(datum_ref, Ref) and datum_ref.type == "datum", f"invalid datum ref: {v}" + if not isinstance(datum_ref, Ref) or datum_ref.type != "datum": + raise ValueError(f"invalid datum ref in `begin`: {v}") named_nodes[k] = self(Node(Literal(datum_ref))) dag = self( FnDag( @@ -831,7 +832,7 @@ def start_fn(self, index, *, argv, name=None, doc=None): else: result = self(Node(Literal(self.put_datum(result)))) nodes.append(result) - fndag = self(FnDag(nodes, {}, result, error, argv_node)) + fndag = self(FnDag(nodes, {}, result, error, argv_node().value.id, argv_node)) else: assert self.cache_path, ( "cache path is required for function execution. " @@ -843,11 +844,12 @@ def start_fn(self, index, *, argv, name=None, doc=None): "prepop": {k: self.dump_ref(v) for k, v in fn.prepop.items()}, } ) + cache_key = md5(argv_datum.encode()).hexdigest() with Cache(self.cache_path, create=False) as cache_db: - cached_val = cache_db.submit(unroll_datum(fn), md5(argv_datum.encode()).hexdigest(), argv_datum) + cached_val = cache_db.submit(unroll_datum(fn), cache_key, argv_datum) fndag = self.load_ref(cached_val) if cached_val else None if isinstance(fndag, Error): - fndag = self(FnDag([argv], {}, None, fndag, argv)) + fndag = self(FnDag([argv], {}, None, fndag, cache_key, argv)) if fndag is not None: node = self.put_node(Fn(fndag, None, argv), index=index, name=name, doc=doc) raise_ex(self.get(node).error)