-
Notifications
You must be signed in to change notification settings - Fork 0
Implemented pre-populated nodes on Executors. #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,6 @@ | ||
| import json | ||
| import logging | ||
| import os | ||
| import sys | ||
| from functools import wraps | ||
| from pathlib import Path | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
| from contextlib import contextmanager | ||
| from dataclasses import InitVar, dataclass, field, fields, is_dataclass | ||
| from hashlib import md5 | ||
| from typing import TYPE_CHECKING, Dict, Optional, Type, Union | ||
| from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union, cast | ||
| from urllib.parse import urlparse | ||
| from uuid import uuid4 | ||
|
|
||
|
|
@@ -70,7 +70,7 @@ def to_data(obj): | |
| if isinstance(obj, (list, set)): | ||
| return [n[0], *[to_data(x) for x in obj]] | ||
| if isinstance(obj, dict): | ||
| return [n[0], *[[k, to_data(v)] for k, v in obj.items()]] | ||
| return [n[0], *[[k, to_data(v)] for k, v in sorted(obj.items(), key=lambda x: x[0])]] | ||
| if n in DATA_TYPE: | ||
| return [n, *[to_data(getattr(obj, x.name)) for x in fields(obj)]] | ||
| raise ValueError(f"no data encoding for type: {n}") | ||
|
|
@@ -82,6 +82,10 @@ def get(value): | |
| value = value() | ||
| if isinstance(value, Datum): | ||
| value = value.value | ||
| if isinstance(value, Executable): | ||
| data = {k: get(v) for k, v in value.data.items()} | ||
| prepop = {k: get(v) for k, v in value.prepop.items()} | ||
| return Executable(value.uri, adapter=value.adapter, data=data, prepop=prepop) | ||
| if isinstance(value, (type(None), str, bool, int, float, Resource)): | ||
| return value | ||
| if isinstance(value, list): | ||
|
|
@@ -213,15 +217,21 @@ def __str__(self): | |
|
|
||
|
|
||
| @repo_type(db=False) | ||
| @dataclass(frozen=True) | ||
| @dataclass | ||
| class Resource: | ||
| uri: str | ||
| data: Optional[Ref] = None # -> Datum | ||
|
|
||
|
|
||
| @repo_type(db=False) | ||
| @dataclass | ||
| class Executable(Resource): | ||
| data: Dict[str, Any] = field(default_factory=dict) # -> Ref(datum) | ||
| adapter: Optional[str] = None | ||
| prepop: Dict[str, Any] = field(default_factory=dict) # -> Ref(datum) | ||
|
Comment on lines +225 to +230
|
||
|
|
||
|
|
||
| @repo_type | ||
| @dataclass(frozen=True) | ||
| @dataclass | ||
| class Deleted(Resource): | ||
| @classmethod | ||
| def resource(cls, obj: Resource): | ||
|
|
@@ -274,6 +284,7 @@ def nameof(self, ref): | |
| @repo_type | ||
| @dataclass | ||
| class FnDag(Dag): | ||
| cache_key: str | ||
| argv: Optional[Ref] = None # -> node(expr) (in this dag) | ||
|
|
||
|
|
||
|
|
@@ -315,7 +326,7 @@ def error(self): | |
| @repo_type(db=False) | ||
| @dataclass | ||
| class Fn(Import): | ||
| argv: Optional[list[Ref]] = None # -> node | ||
| argv: list[Ref] = field(default_factory=list) # -> node | ||
|
|
||
|
|
||
| @repo_type | ||
|
|
@@ -511,6 +522,9 @@ def walk_ordered(self, *key): | |
| xs += [a for a in x if a not in result] | ||
| elif isinstance(x, dict): | ||
| xs += [a for a in x.values() if a not in result] | ||
| elif isinstance(x, Executable): | ||
| xs += [a for a in x.data.values() if a not in result] | ||
| xs += [a for a in x.prepop.values() if a not in result] | ||
| elif isinstance(x, (Error, Resource)): | ||
| pass # cannot recurse into these classes | ||
| elif is_dataclass(x): | ||
|
|
@@ -759,11 +773,25 @@ def begin(self, *, message, name=None, dump=None): | |
| if dump is None: | ||
| dag = self(Dag([], {}, None, None)) | ||
| else: | ||
| loaded = self.load_ref(dump) | ||
| assert loaded is not None, "failed to load dump" | ||
| loaded = cast(Dict[str, Any], from_json(dump)) | ||
| named_nodes = {} | ||
| with self.tx(True): | ||
| argv = self(Node(Argv(loaded))) | ||
| dag = self(FnDag([argv], {}, None, None, argv)) | ||
| argv = self(Node(Argv(self.put_datum([self.load_ref(x) for x in loaded["expr"]])))) | ||
| for k, v in loaded["prepop"].items(): | ||
| datum_ref = self.load_ref(v) | ||
| if not isinstance(datum_ref, Ref) or datum_ref.type != "datum": | ||
| raise ValueError(f"invalid datum ref in `begin`: {v}") | ||
| named_nodes[k] = self(Node(Literal(datum_ref))) | ||
| dag = self( | ||
| FnDag( | ||
| sorted([argv, *named_nodes.values()]), | ||
| named_nodes, | ||
| None, | ||
| None, | ||
| md5(dump.encode()).hexdigest(), | ||
| argv, | ||
| ) | ||
| ) | ||
| commit = Commit([ctx.head.commit], self(ctx.tree), self.user, self.user, message, dag_name=name) | ||
| index = self(Index(self(commit), dag)) | ||
| return index | ||
|
|
@@ -772,7 +800,7 @@ def put_node(self, data, index: Ref, name=None, doc=None): | |
| ctx = Ctx.from_head(index) | ||
| node = data if isinstance(data, Ref) else self(Node(data, doc=doc)) | ||
| if node not in ctx.dag.nodes: | ||
| ctx.dag.nodes.append(node) | ||
| ctx.dag.nodes = sorted([node, *ctx.dag.nodes], key=lambda x: x.to) | ||
| if name: | ||
| ctx.dag.names[name] = node | ||
| ctx.commit.tree = self(ctx.tree) | ||
|
|
@@ -791,11 +819,10 @@ def get_node_value(self, ref: Ref): | |
|
|
||
| def start_fn(self, index, *, argv, name=None, doc=None): | ||
| fn, *data = map(lambda x: x().datum, argv) | ||
| argv_datum = self.put_datum([x().value for x in argv]) | ||
| if fn.adapter is None: | ||
| uri = urlparse(fn.uri) | ||
| assert uri.scheme == "daggerml", f"unexpected URI scheme: {uri.scheme!r} for null adapter" | ||
| argv_node = self(Node(Argv(argv_datum))) | ||
| argv_node = self(Node(Argv(self.put_datum([x().value for x in argv])))) | ||
| result = error = None | ||
| nodes = [argv_node] | ||
| try: | ||
|
|
@@ -805,17 +832,24 @@ def start_fn(self, index, *, argv, name=None, doc=None): | |
| else: | ||
| result = self(Node(Literal(self.put_datum(result)))) | ||
| nodes.append(result) | ||
| fndag = self(FnDag(nodes, {}, result, error, argv_node)) | ||
| fndag = self(FnDag(nodes, {}, result, error, argv_node().value.id, argv_node)) | ||
| else: | ||
| assert self.cache_path, ( | ||
| "cache path is required for function execution. " | ||
| "Set the cache path via the DML_CACHE_PATH environment variable or in the config file." | ||
| ) | ||
| argv_datum = to_json( | ||
| { | ||
| "expr": [self.dump_ref(x().value) for x in argv], | ||
| "prepop": {k: self.dump_ref(v) for k, v in fn.prepop.items()}, | ||
| } | ||
| ) | ||
| cache_key = md5(argv_datum.encode()).hexdigest() | ||
| with Cache(self.cache_path, create=False) as cache_db: | ||
| cached_val = cache_db.submit(fn, argv_datum.id, self.dump_ref(argv_datum)) | ||
| cached_val = cache_db.submit(unroll_datum(fn), cache_key, argv_datum) | ||
| fndag = self.load_ref(cached_val) if cached_val else None | ||
| if isinstance(fndag, Error): | ||
| fndag = self(FnDag([argv], {}, None, fndag, argv)) | ||
| fndag = self(FnDag([argv], {}, None, fndag, cache_key, argv)) | ||
| if fndag is not None: | ||
| node = self.put_node(Fn(fndag, None, argv), index=index, name=name, doc=doc) | ||
| raise_ex(self.get(node).error) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,6 @@ | ||||||||
| from daggerml_cli.repo import Fn, Import | ||||||||
| from typing import cast | ||||||||
|
|
||||||||
| from daggerml_cli.repo import Executable, Fn, Import | ||||||||
| from daggerml_cli.util import flatten | ||||||||
|
|
||||||||
|
|
||||||||
|
|
@@ -18,8 +20,10 @@ def node_info(ref, *, include_argv=True): | |||||||
| } | ||||||||
| if include_argv and isinstance(node.data, Fn): | ||||||||
| info["argv"] = [node_info(x, include_argv=False) for x in node.data.argv] | ||||||||
| info["prepop"] = cast(Executable, node.data.argv[0]().datum).prepop | ||||||||
|
||||||||
| info["prepop"] = cast(Executable, node.data.argv[0]().datum).prepop | |
| datum0 = node.data.argv[0]().datum | |
| info["prepop"] = datum0.prepop if isinstance(datum0, Executable) else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
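The `Resource` class should remain frozen (immutable) as it was before, but `Executable` should be mutable. Consider adding `frozen=True` back to the `Resource` class and keeping `Executable` mutable to maintain the original immutability contract for basic resources. Note that Python's `dataclasses` raises `TypeError` when a non-frozen dataclass inherits from a frozen one, so keeping `Resource` frozen means `Executable` (and `Deleted`) must either be frozen too or stop subclassing `Resource` directly.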