From f49bdbd8d3123ffa8ada8a10f84e53aaace899a3 Mon Sep 17 00:00:00 2001 From: "dreadnode-renovate-bot[bot]" <184170622+dreadnode-renovate-bot[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:44:07 -0600 Subject: [PATCH 1/9] chore(deps): update actions/checkout action to v6 (#238) | datasource | package | from | to | | ----------- | ---------------- | ------ | ------ | | github-tags | actions/checkout | v5.0.1 | v6.0.0 | Co-authored-by: dreadnode-renovate-bot[bot] <184170622+dreadnode-renovate-bot[bot]@users.noreply.github.com> --- .github/workflows/meta-sync-labels.yaml | 2 +- .github/workflows/pre-commit.yaml | 2 +- .github/workflows/publish.yaml | 2 +- .github/workflows/renovate.yaml | 2 +- .github/workflows/rigging_pr_description.yaml | 2 +- .github/workflows/semgrep.yaml | 2 +- .github/workflows/template-sync.yaml | 2 +- .github/workflows/test.yaml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/meta-sync-labels.yaml b/.github/workflows/meta-sync-labels.yaml index b4acaac3..b83ab6f8 100644 --- a/.github/workflows/meta-sync-labels.yaml +++ b/.github/workflows/meta-sync-labels.yaml @@ -25,7 +25,7 @@ jobs: private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" - name: Set up git repository - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index fd96acc7..188679d4 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -30,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index d213b697..5b3ef7e6 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 diff --git a/.github/workflows/renovate.yaml b/.github/workflows/renovate.yaml index 5f3f3754..7586cfe1 100644 --- a/.github/workflows/renovate.yaml +++ b/.github/workflows/renovate.yaml @@ -59,7 +59,7 @@ jobs: private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" - name: Checkout - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/rigging_pr_description.yaml b/.github/workflows/rigging_pr_description.yaml index 9df57c34..7fee8c03 100644 --- a/.github/workflows/rigging_pr_description.yaml +++ b/.github/workflows/rigging_pr_description.yaml @@ -13,7 +13,7 @@ jobs: contents: read steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: fetch-depth: 0 # full history for proper diffing diff --git a/.github/workflows/semgrep.yaml b/.github/workflows/semgrep.yaml index f87a0070..703c6069 100644 --- a/.github/workflows/semgrep.yaml +++ b/.github/workflows/semgrep.yaml @@ -38,7 +38,7 @@ jobs: steps: - name: Set up git repository - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/template-sync.yaml b/.github/workflows/template-sync.yaml index c68a4b23..cea60084 100644 --- a/.github/workflows/template-sync.yaml +++ b/.github/workflows/template-sync.yaml @@ -50,7 +50,7 @@ jobs: owner: "${{ github.repository_owner }}" - name: Checkout - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 34bdbfb5..7635f365 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -25,7 +25,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 From ca1077a09c0ee132c33412906455d94812f03487 Mon Sep 17 00:00:00 2001 From: "dreadnode-renovate-bot[bot]" <184170622+dreadnode-renovate-bot[bot]@users.noreply.github.com> Date: Wed, 26 Nov 2025 00:19:30 +0000 Subject: [PATCH 2/9] chore(deps): update pre-commit hook pycqa/bandit to v1.9.2 (#241) | datasource | package | from | to | | ----------- | ------------ | ----- | ----- | | github-tags | PyCQA/bandit | 1.9.1 | 1.9.2 | Co-authored-by: dreadnode-renovate-bot[bot] <184170622+dreadnode-renovate-bot[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3ab681ed..13b0307b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: # Python code security - repo: https://github.com/PyCQA/bandit - rev: 1.9.1 + rev: 1.9.2 hooks: - id: bandit name: Code security checks From 40ae38d701bdc1bdfeda97bfaa7a63531af7ab91 Mon Sep 17 00:00:00 2001 From: "dreadnode-renovate-bot[bot]" <184170622+dreadnode-renovate-bot[bot]@users.noreply.github.com> Date: Wed, 26 Nov 2025 00:19:36 +0000 Subject: [PATCH 3/9] chore(deps): update actions/setup-python action to v6.1.0 (#242) | datasource | package | from | to | | ----------- | -------------------- | ------ | ------ | | github-tags | actions/setup-python | v6.0.0 | v6.1.0 | Co-authored-by: dreadnode-renovate-bot[bot] <184170622+dreadnode-renovate-bot[bot]@users.noreply.github.com> --- .github/workflows/rigging_pr_description.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rigging_pr_description.yaml b/.github/workflows/rigging_pr_description.yaml index 7fee8c03..93b78b03 100644 --- a/.github/workflows/rigging_pr_description.yaml +++ b/.github/workflows/rigging_pr_description.yaml @@ -18,7 +18,7 @@ jobs: fetch-depth: 0 # full history for proper diffing - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.13" From 8d464ff64246c2a387c93237261c7b455f397f95 Mon Sep 17 00:00:00 2001 From: Brian Greunke <44581702+briangreunke@users.noreply.github.com> Date: Wed, 26 Nov 2025 11:45:26 -0600 Subject: [PATCH 4/9] feat(user-data): Scope credentials fetch by organization and workspace (#232) --- dreadnode/api/client.py | 7 +++++-- dreadnode/main.py | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index f2ceacca..b6290386 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -753,14 +753,17 @@ def export_timeseries( # User data access - def get_user_data_credentials(self) -> UserDataCredentials: + def get_user_data_credentials( + self, organization_id: UUID, workspace_id: UUID + ) -> UserDataCredentials: """ Retrieves user data credentials for secondary storage access. Returns: The user data credentials object. """ - response = self._request("GET", "/user-data/credentials") + params = {"org_id": str(organization_id), "workspace_id": str(workspace_id)} + response = self._request("GET", "/user-data/credentials", params=params) return UserDataCredentials(**response.json()) # Container registry access diff --git a/dreadnode/main.py b/dreadnode/main.py index a3d094ac..877b2468 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -687,8 +687,11 @@ def initialize(self) -> None: if self._api is not None: api = self._api self._credential_manager = CredentialManager( - credential_fetcher=lambda: api.get_user_data_credentials() + credential_fetcher=lambda: api.get_user_data_credentials( + self._organization.id, self._workspace.id + ) ) + self._credential_manager.initialize() self._fs = self._credential_manager.get_filesystem() From 558694feb22dd7fbcdab956e4de0006624b317ce Mon Sep 17 00:00:00 2001 From: Brian Greunke <44581702+briangreunke@users.noreply.github.com> Date: Wed, 26 Nov 2025 12:26:29 -0600 Subject: [PATCH 5/9] feat(project): add project key for identification and creation (#240) * feat(project): add project key for identification and creation * chore: updated docs --- docs/sdk/agent_tools.mdx | 1015 ++++++++++++++++++++++++++++++++------ docs/sdk/api.mdx | 18 +- dreadnode/api/client.py | 9 +- dreadnode/constants.py | 2 + dreadnode/main.py | 11 +- 5 files changed, 877 insertions(+), 178 deletions(-) diff --git a/docs/sdk/agent_tools.mdx b/docs/sdk/agent_tools.mdx index 859670a9..4ec7af9b 100644 --- a/docs/sdk/agent_tools.mdx +++ b/docs/sdk/agent_tools.mdx @@ -545,8 +545,13 @@ async def python(code: str, *, timeout: int = 120) -> str: -Filesystem ----------- +FilesystemBase +-------------- + +Base class for filesystem operations with common interface. + +This abstract base class defines the standard interface for filesystem operations +and provides common utilities like path resolution and validation. ### fs\_options @@ -556,6 +561,14 @@ fs_options: AnyDict | None = Config(default=None) Extra options for the universal filesystem. +### max\_concurrent\_reads + +```python +max_concurrent_reads: int = Config(default=25) +``` + +Maximum number of concurrent file reads for grep operations. + ### multi\_modal ```python @@ -574,77 +587,6 @@ path: str | Path | UPath = Config( Base path to work from. -### cp - -```python -cp( - src: Annotated[str, "Source file"], - dest: Annotated[str, "Destination path"], -) -> FilesystemItem -``` - -Copy a file to a new location. - - -```python -@tool_method(variants=["write"], catch=True) -def cp( - self, - src: t.Annotated[str, "Source file"], - dest: t.Annotated[str, "Destination path"], -) -> FilesystemItem: - """Copy a file to a new location.""" - src_path = self._resolve(src) - dest_path = self._resolve(dest) - - if not src_path.exists(): - raise ValueError(f"'{src}' not found") - - if not src_path.is_file(): - raise ValueError(f"'{src}' is not a file") - - dest_path.parent.mkdir(parents=True, exist_ok=True) - - with src_path.open("rb") as src_file, dest_path.open("wb") as dest_file: - dest_file.write(src_file.read()) - - return FilesystemItem.from_path(dest_path, self._upath) -``` - - - - -### delete - -```python -delete(path: Annotated[str, File or directory]) -> bool -``` - -Delete a file or directory. - - -```python -@tool_method(variants=["write"], catch=True) -def delete( - self, - path: t.Annotated[str, "File or directory"], -) -> bool: - """Delete a file or directory.""" - _path = self._resolve(path) - if not _path.exists(): - raise ValueError(f"'{path}' not found") - - if _path.is_dir(): - _path.rmdir() - else: - _path.unlink() - - return True -``` - - - - ### glob ```python @@ -661,7 +603,7 @@ include \*\* for recursive matching, such as '/path/\**/dir/*.py'. ```python @tool_method(catch=True) -def glob( +async def glob( self, pattern: t.Annotated[str, "Glob pattern for file matching"], ) -> list[FilesystemItem]: @@ -669,7 +611,7 @@ def glob( Returns a list of paths matching a valid glob pattern. The pattern can include ** for recursive matching, such as '/path/**/dir/*.py'. """ - matches = list(self._upath.glob(pattern)) + matches = await asyncio.to_thread(lambda: list(self._upath.glob(pattern))) # Check to make sure all matches are within the base path for match in matches: @@ -699,7 +641,7 @@ grep( recursive: Annotated[ bool, "Search recursively in directories" ] = False, -) -> list[GrepMatch] +) -> list[GrepMatch | str] ``` Search for pattern in files and return matches with line numbers and context. @@ -709,14 +651,14 @@ For directories, all text files will be searched. ```python @tool_method(variants=["read", "write"], catch=True) -def grep( +async def grep( self, pattern: t.Annotated[str, "Regular expression pattern to search for"], path: t.Annotated[str, "File or directory path to search in"], *, max_results: t.Annotated[int, "Maximum number of results to return"] = 100, recursive: t.Annotated[bool, "Search recursively in directories"] = False, -) -> list[GrepMatch]: +) -> list[GrepMatch | str]: """ Search for pattern in files and return matches with line numbers and context. @@ -734,25 +676,31 @@ def grep( files_to_search.append(target_path) elif target_path.is_dir(): files_to_search.extend( - list(target_path.rglob("*") if recursive else target_path.glob("*")), + await asyncio.to_thread( + lambda: list(target_path.rglob("*") if recursive else target_path.glob("*")) + ), ) - matches: list[GrepMatch] = [] - for file_path in [f for f in files_to_search if f.is_file()]: - if len(matches) >= max_results: - break + # Filter to files only and check size + files_to_search = [ + f for f in files_to_search if f.is_file() and f.stat().st_size <= MAX_GREP_FILE_SIZE + ] - if file_path.stat().st_size > MAX_GREP_FILE_SIZE: - continue + async def search_file(file_path: UPath) -> list[GrepMatch | str]: + """Search a single file for matches.""" + file_matches: list[GrepMatch | str] = [] + try: + # Use the subclass's read_file method + content = await self.read_file(self._relative(file_path)) + if isinstance(content, bytes): + content = content.decode("utf-8") + elif isinstance(content, rg.ContentImageUrl): + # Can't grep images + return [] - with contextlib.suppress(Exception): - with file_path.open("r") as f: - lines = f.readlines() + lines = content.splitlines(keepends=True) for i, line in enumerate(lines): - if len(matches) >= max_results: - break - if regex.search(line): line_num = i + 1 context_start = max(0, i - 1) @@ -765,7 +713,7 @@ def grep( context.append(f"{prefix} {j + 1}: {shorten_string(line_text, 80)}") rel_path = self._relative(file_path) - matches.append( + file_matches.append( GrepMatch( path=rel_path, line_number=line_num, @@ -773,8 +721,38 @@ def grep( context=context, ), ) + except ( + FileNotFoundError, + PermissionError, + IsADirectoryError, + UnicodeDecodeError, + OSError, + ValueError, + ) as e: + file_matches.append(f"Error occurred while searching file {file_path}: {e}") + + return file_matches + + # Search files in parallel with concurrency limit + semaphore = asyncio.Semaphore(self.max_concurrent_reads) + + async def search_file_limited(file_path: UPath) -> list[GrepMatch | str]: + """Search a single file with semaphore to limit concurrency.""" + async with semaphore: + return await search_file(file_path) + + all_matches: list[GrepMatch | str] = [] + results = await asyncio.gather( + *[search_file_limited(file_path) for file_path in files_to_search] + ) + + # Flatten results and respect max_results + for file_matches in results: + all_matches.extend(file_matches) + if len(all_matches) >= max_results: + break - return matches + return all_matches[:max_results] ``` @@ -793,7 +771,7 @@ List the contents of a directory. ```python @tool_method(variants=["read", "write"], catch=True) -def ls( +async def ls( self, path: t.Annotated[str, "Directory path to list"] = "", ) -> list[FilesystemItem]: @@ -806,7 +784,7 @@ def ls( if not _path.is_dir(): raise ValueError(f"'{path}' is not a directory.") - items = list(_path.iterdir()) + items = await asyncio.to_thread(lambda: list(_path.iterdir())) return [FilesystemItem.from_path(item, self._upath) for item in items] ``` @@ -826,18 +804,218 @@ Create a directory and any necessary parent directories. ```python @tool_method(variants=["write"], catch=True) -def mkdir( +async def mkdir( self, path: t.Annotated[str, "Directory path to create"], ) -> FilesystemItem: """Create a directory and any necessary parent directories.""" dir_path = self._resolve(path) - dir_path.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(lambda: dir_path.mkdir(parents=True, exist_ok=True)) return FilesystemItem.from_path(dir_path, self._upath) ``` + + +### read\_file + +```python +read_file( + path: Annotated[str, "Path to the file to read"], +) -> rg.ContentImageUrl | str +``` + +Must be implemented in subclasses + + +```python +async def read_file( + self, path: t.Annotated[str, "Path to the file to read"] +) -> rg.ContentImageUrl | str: + """Must be implemented in subclasses""" + raise NotImplementedError("Subclasses must implement") +``` + + + + +FilesystemItem +-------------- + +```python +FilesystemItem( + type: Literal["file", "dir"], + name: str, + size: int | None = None, + modified: str | None = None, +) +``` + +Item in the filesystem + +### from\_path + +```python +from_path( + path: UPath, relative_base: UPath +) -> FilesystemItem +``` + +Create an Item from a UPath. + +**Parameters:** + +* **`path`** + (`UPath`) + –The UPath to create an item from +* **`relative_base`** + (`UPath`) + –The base path to calculate relative paths from + +**Returns:** + +* `FilesystemItem` + –FilesystemItem representing the path + +**Raises:** + +* `ValueError` + –If the path is neither a file nor a directory + + +```python +@classmethod +def from_path(cls, path: "UPath", relative_base: "UPath") -> "FilesystemItem": + """Create an Item from a UPath. + + Args: + path: The UPath to create an item from + relative_base: The base path to calculate relative paths from + + Returns: + FilesystemItem representing the path + + Raises: + ValueError: If the path is neither a file nor a directory + """ + base_path: str = str(relative_base.resolve()) + full_path: str = str(path.resolve()) + relative: str = full_path[len(base_path) :] + + if path.is_dir(): + return cls(type="dir", name=relative, size=None, modified=None) + + if path.is_file(): + return cls( + type="file", + name=relative, + size=path.stat().st_size, + modified=datetime.fromtimestamp(path.stat().st_mtime, tz=timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S", + ), + ) + + raise ValueError(f"'{relative}' is not a valid file or directory.") +``` + + + + +GrepMatch +--------- + +```python +GrepMatch( + path: str, + line_number: int, + line: str, + context: list[str], +) +``` + +Individual search match + +LocalFilesystem +--------------- + +Local filesystem implementation using aiofiles. + +Supports operations on the local disk using async file I/O. + +### cp + +```python +cp( + src: Annotated[str, "Source file"], + dest: Annotated[str, "Destination path"], +) -> FilesystemItem +``` + +Copy a file to a new location. + + +```python +@tool_method(variants=["write"], catch=True) +async def cp( + self, + src: t.Annotated[str, "Source file"], + dest: t.Annotated[str, "Destination path"], +) -> FilesystemItem: + """Copy a file to a new location.""" + src_path = self._resolve(src) + dest_path = self._resolve(dest) + + if not src_path.exists(): + raise ValueError(f"'{src}' not found") + + if not src_path.is_file(): + raise ValueError(f"'{src}' is not a file") + + await asyncio.to_thread(lambda: dest_path.parent.mkdir(parents=True, exist_ok=True)) + + async with ( + aiofiles.open(src_path, "rb") as src_file, + aiofiles.open(dest_path, "wb") as dest_file, + ): + content = await src_file.read() + await dest_file.write(content) + + return FilesystemItem.from_path(dest_path, self._upath) +``` + + + + +### delete + +```python +delete(path: Annotated[str, File or directory]) -> bool +``` + +Delete a file or directory. + + +```python +@tool_method(variants=["write"], catch=True) +async def delete( + self, + path: t.Annotated[str, "File or directory"], +) -> bool: + """Delete a file or directory.""" + _path = self._resolve(path) + if not _path.exists(): + raise ValueError(f"'{path}' not found") + + if _path.is_dir(): + await asyncio.to_thread(_path.rmdir) + else: + await asyncio.to_thread(_path.unlink) + + return True +``` + + ### mv @@ -854,7 +1032,7 @@ Move a file or directory to a new location. ```python @tool_method(variants=["write"], catch=True) -def mv( +async def mv( self, src: t.Annotated[str, "Source path"], dest: t.Annotated[str, "Destination path"], @@ -866,9 +1044,9 @@ def mv( if not src_path.exists(): raise ValueError(f"'{src}' not found") - dest_path.parent.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(lambda: dest_path.parent.mkdir(parents=True, exist_ok=True)) - src_path.rename(dest_path) + await asyncio.to_thread(lambda: src_path.rename(dest_path)) return FilesystemItem.from_path(dest_path, self._upath) ``` @@ -886,19 +1064,40 @@ read_file( Read a file and return its contents. +**Returns:** + +* `ContentImageUrl | str` + –+ str: The file contents decoded as UTF-8 if possible. +* `ContentImageUrl | str` + –+ rg.ContentImageUrl: If the file is non-text and multi\_modal is True. + + +Callers should be prepared to handle raw bytes if the file is not valid UTF-8 and multi\_modal is False. + + ```python @tool_method(variants=["read", "write"], catch=True) -def read_file( +async def read_file( self, path: t.Annotated[str, "Path to the file to read"], ) -> rg.ContentImageUrl | str: - """Read a file and return its contents.""" + """ + Read a file and return its contents. + + Returns: + - str: The file contents decoded as UTF-8 if possible. + - rg.ContentImageUrl: If the file is non-text and multi_modal is True. + + Note: + Callers should be prepared to handle raw bytes if the file is not valid UTF-8 and multi_modal is False. + """ _path = self._resolve(path) - content = _path.read_bytes() + async with aiofiles.open(_path, "rb") as f: + content = await f.read() try: - return content.decode("utf-8") + return str(content.decode("utf-8")) except UnicodeDecodeError as e: if self.multi_modal: return rg.ContentImageUrl.from_file(path) @@ -926,7 +1125,7 @@ Negative line numbers count from the end. ```python @tool_method(variants=["read", "write"], catch=True) -def read_lines( +async def read_lines( self, path: t.Annotated[str, "Path to the file to read"], start_line: t.Annotated[int, "Start line number (0-indexed)"] = 0, @@ -944,8 +1143,8 @@ def read_lines( if not _path.is_file(): raise ValueError(f"'{path}' is not a file.") - with _path.open("r") as f: - lines = f.readlines() + async with aiofiles.open(_path) as f: + lines = await f.readlines() if start_line < 0: start_line = len(lines) + start_line @@ -978,15 +1177,15 @@ Create or overwrite a file with the given contents. ```python @tool_method(variants=["write"], catch=True) -def write_file( +async def write_file( self, path: t.Annotated[str, "Path to write the file to"], contents: t.Annotated[str, "Content to write to the file"], ) -> FilesystemItem: """Create or overwrite a file with the given contents.""" - _path = self._safe_create_file(path) - with _path.open("w") as f: - f.write(contents) + _path = await self._safe_create_file(path) + async with aiofiles.open(_path, "w") as f: + await f.write(contents) return FilesystemItem.from_path(_path, self._upath) ``` @@ -994,12 +1193,44 @@ def write_file( -### write\_lines +### write\_file\_bytes ```python -write_lines( - path: Annotated[str, "Path to write to"], - contents: Annotated[str, "Content to write"], +write_file_bytes( + path: Annotated[str, "Path to write the file to"], + byte_data: Annotated[ + bytes, "Bytes to write to the file" + ], +) -> FilesystemItem +``` + +Create or overwrite a file with the given bytes. + + +```python +@tool_method(variants=["write"], catch=True) +async def write_file_bytes( + self, + path: t.Annotated[str, "Path to write the file to"], + byte_data: t.Annotated[bytes, "Bytes to write to the file"], +) -> FilesystemItem: + """Create or overwrite a file with the given bytes.""" + _path = await self._safe_create_file(path) + async with aiofiles.open(_path, "wb") as f: + await f.write(byte_data) + + return FilesystemItem.from_path(_path, self._upath) +``` + + + + +### write\_lines + +```python +write_lines( + path: Annotated[str, "Path to write to"], + contents: Annotated[str, "Content to write"], insert_line: Annotated[ int, "Line number to insert at (negative counts from end)", @@ -1016,7 +1247,7 @@ Mode can be 'insert' to add lines or 'overwrite' to replace lines. ```python @tool_method(variants=["write"], catch=True) -def write_lines( +async def write_lines( self, path: t.Annotated[str, "Path to write to"], contents: t.Annotated[str, "Content to write"], @@ -1030,11 +1261,11 @@ def write_lines( if mode not in ["insert", "overwrite"]: raise ValueError("Invalid mode. Use 'insert' or 'overwrite'") - _path = self._safe_create_file(path) + _path = await self._safe_create_file(path) lines: list[str] = [] - with _path.open("r") as f: - lines = f.readlines() + async with aiofiles.open(_path) as f: + lines = await f.readlines() # Normalize line endings in content content_lines = [ @@ -1054,8 +1285,8 @@ def write_lines( elif mode == "overwrite": lines[insert_line : insert_line + len(content_lines)] = content_lines - with _path.open("w") as f: - f.writelines(lines) + async with aiofiles.open(_path, "w") as f: + await f.writelines(lines) return FilesystemItem.from_path(_path, self._upath) ``` @@ -1063,72 +1294,534 @@ def write_lines( -FilesystemItem --------------- +S3Filesystem +------------ + +S3 filesystem implementation using aioboto3. + +Supports operations on AWS S3 buckets with async I/O. +Requires aioboto3 and properly configured AWS credentials. + +### cp ```python -FilesystemItem( - type: Literal["file", "dir"], - name: str, - size: int | None = None, - modified: str | None = None, -) +cp( + src: Annotated[str, "Source file"], + dest: Annotated[str, "Destination path"], +) -> FilesystemItem ``` -Item in the filesystem +Copy a file to a new location within S3. -### from\_path + +```python +@tool_method(variants=["write"], catch=True) +async def cp( + self, + src: t.Annotated[str, "Source file"], + dest: t.Annotated[str, "Destination path"], +) -> FilesystemItem: + """Copy a file to a new location within S3.""" + src_path = self._resolve(src) + dest_path = self._resolve(dest) + + if not src_path.exists(): + raise ValueError(f"'{src}' not found") + + if not src_path.is_file(): + raise ValueError(f"'{src}' is not a file") + + src_bucket, src_key = self._get_s3_parts(src_path) + dest_bucket, dest_key = self._get_s3_parts(dest_path) + + session = self._get_session() + async with session.client("s3") as s3_client: + # Use S3 copy_object for efficient server-side copy + copy_source = {"Bucket": src_bucket, "Key": src_key} + await s3_client.copy_object(CopySource=copy_source, Bucket=dest_bucket, Key=dest_key) + + # Return FilesystemItem without calling stat + relative = self._relative(dest_path) + return FilesystemItem(type="file", name=relative, size=None, modified=None) +``` + + + + +### delete ```python -from_path( - path: UPath, relative_base: UPath +delete(path: Annotated[str, File or directory]) -> bool +``` + +Delete a file from S3. + + +```python +@tool_method(variants=["write"], catch=True) +async def delete( + self, + path: t.Annotated[str, "File or directory"], +) -> bool: + """Delete a file from S3.""" + _path = self._resolve(path) + + if not _path.exists(): + raise ValueError(f"'{path}' not found") + + bucket, key = self._get_s3_parts(_path) + + session = self._get_session() + async with session.client("s3") as s3_client: + await s3_client.delete_object(Bucket=bucket, Key=key) + + return True +``` + + + + +### mkdir + +```python +mkdir( + path: Annotated[str, "Directory path to create"], ) -> FilesystemItem ``` -Create an Item from a UPath +Create a directory marker in S3. + +Note: S3 doesn't have true directories. This creates an empty object +with a trailing slash to simulate a directory for compatibility. ```python -@classmethod -def from_path(cls, path: "UPath", relative_base: "UPath") -> "FilesystemItem": - """Create an Item from a UPath""" +@tool_method(variants=["write"], catch=True) +async def mkdir( + self, + path: t.Annotated[str, "Directory path to create"], +) -> FilesystemItem: + """ + Create a directory marker in S3. - base_path = str(relative_base.resolve()) - full_path = str(path.resolve()) - relative = full_path[len(base_path) :] + Note: S3 doesn't have true directories. This creates an empty object + with a trailing slash to simulate a directory for compatibility. + """ + _path = self._resolve(path) + bucket, key = self._get_s3_parts(_path) - if path.is_dir(): - return cls(type="dir", name=relative, size=None, modified=None) + # Ensure key ends with slash for directory marker + if not key.endswith("/"): + key += "/" - if path.is_file(): - return cls( - type="file", - name=relative, - size=path.stat().st_size, - modified=datetime.fromtimestamp(path.stat().st_mtime, tz=timezone.utc).strftime( - "%Y-%m-%d %H:%M:%S", - ), - ) + session = self._get_session() + async with session.client("s3") as s3_client: + # Create empty object with trailing slash + await s3_client.put_object(Bucket=bucket, Key=key, Body=b"") - raise ValueError(f"'{relative}' is not a valid file or directory.") + relative = self._relative(_path) + return FilesystemItem(type="dir", name=relative, size=None, modified=None) ``` -GrepMatch ---------- +### mv ```python -GrepMatch( - path: str, - line_number: int, - line: str, - context: list[str], -) +mv( + src: Annotated[str, "Source path"], + dest: Annotated[str, "Destination path"], +) -> FilesystemItem ``` -Individual search match +Move a file to a new location within S3 (copy then delete). + + +```python +@tool_method(variants=["write"], catch=True) +async def mv( + self, + src: t.Annotated[str, "Source path"], + dest: t.Annotated[str, "Destination path"], +) -> FilesystemItem: + """Move a file to a new location within S3 (copy then delete).""" + # Copy to destination + result = await self.cp(src, dest) + + # Delete source + await self.delete(src) + + return result +``` + + + + +### read\_file + +```python +read_file( + path: Annotated[str, "Path to the file to read"], +) -> str +``` + +Read a file from S3 and return its contents. + +**Returns:** + +* `str` + –+ str: The file contents decoded as UTF-8 if possible. + + +multi\_modal support for S3 is limited as we can't easily determine +image types without downloading. Returns bytes for non-UTF-8 content. + + + +```python +@tool_method(variants=["read", "write"], catch=True) +async def read_file( + self, + path: t.Annotated[str, "Path to the file to read"], +) -> str: + """ + Read a file from S3 and return its contents. + + Returns: + - str: The file contents decoded as UTF-8 if possible. + + Note: + multi_modal support for S3 is limited as we can't easily determine + image types without downloading. Returns bytes for non-UTF-8 content. + """ + _path = self._resolve(path) + bucket, key = self._get_s3_parts(_path) + + session = self._get_session() + async with session.client("s3") as s3_client: + response = await s3_client.get_object(Bucket=bucket, Key=key) + content = await response["Body"].read() + + try: + return str(content.decode("utf-8")) + except UnicodeDecodeError as e: + raise ValueError("File is not a valid text file.") from e +``` + + + + +### read\_lines + +```python +read_lines( + path: Annotated[str, "Path to the file to read"], + start_line: Annotated[ + int, "Start line number (0-indexed)" + ] = 0, + end_line: Annotated[int, "End line number"] = -1, +) -> str +``` + +Read a partial file from S3 and return the contents. +Negative line numbers count from the end. + + +```python +@tool_method(variants=["read", "write"], catch=True) +async def read_lines( + self, + path: t.Annotated[str, "Path to the file to read"], + start_line: t.Annotated[int, "Start line number (0-indexed)"] = 0, + end_line: t.Annotated[int, "End line number"] = -1, +) -> str: + """ + Read a partial file from S3 and return the contents. + Negative line numbers count from the end. + """ + content = await self.read_file(path) + if isinstance(content, bytes): + content = content.decode("utf-8") + elif isinstance(content, rg.ContentImageUrl): + raise TypeError("Cannot read lines from non-text content") + + lines = content.splitlines(keepends=True) + + if start_line < 0: + start_line = len(lines) + start_line + + if end_line < 0: + end_line = len(lines) + end_line + 1 + + start_line = max(0, min(start_line, len(lines))) + end_line = max(start_line, min(end_line, len(lines))) + + return "".join(lines[start_line:end_line]) +``` + + + + +### write\_file + +```python +write_file( + path: Annotated[str, "Path to write the file to"], + contents: Annotated[ + str, "Content to write to the file" + ], +) -> FilesystemItem +``` + +Create or overwrite a file in S3 with the given contents. + + +```python +@tool_method(variants=["write"], catch=True) +async def write_file( + self, + path: t.Annotated[str, "Path to write the file to"], + contents: t.Annotated[str, "Content to write to the file"], +) -> FilesystemItem: + """Create or overwrite a file in S3 with the given contents.""" + _path = self._resolve(path) + bucket, key = self._get_s3_parts(_path) + + session = self._get_session() + async with session.client("s3") as s3_client: + await s3_client.put_object(Bucket=bucket, Key=key, Body=contents.encode("utf-8")) + + # Return FilesystemItem without calling stat (S3 put is async) + relative = self._relative(_path) + return FilesystemItem( + type="file", + name=relative, + size=len(contents.encode("utf-8")), + modified=None, + ) +``` + + + + +### write\_file\_bytes + +```python +write_file_bytes( + path: Annotated[str, "Path to write the file to"], + byte_data: Annotated[ + bytes, "Bytes to write to the file" + ], +) -> FilesystemItem +``` + +Create or overwrite a file in S3 with the given bytes. + + +```python +@tool_method(variants=["write"], catch=True) +async def write_file_bytes( + self, + path: t.Annotated[str, "Path to write the file to"], + byte_data: t.Annotated[bytes, "Bytes to write to the file"], +) -> FilesystemItem: + """Create or overwrite a file in S3 with the given bytes.""" + _path = self._resolve(path) + bucket, key = self._get_s3_parts(_path) + + session = self._get_session() + async with session.client("s3") as s3_client: + await s3_client.put_object(Bucket=bucket, Key=key, Body=byte_data) + + # Return FilesystemItem without calling stat (S3 put is async) + relative = self._relative(_path) + return FilesystemItem(type="file", name=relative, size=len(byte_data), modified=None) +``` + + + + +### write\_lines + +```python +write_lines( + path: Annotated[str, "Path to write to"], + contents: Annotated[str, "Content to write"], + insert_line: Annotated[ + int, + "Line number to insert at (negative counts from end)", + ] = -1, + mode: Annotated[ + str, "insert" or "overwrite" + ] = "insert", +) -> FilesystemItem | str +``` + +Write content to a specific line in an S3 file. +Mode can be 'insert' to add lines or 'overwrite' to replace lines. + + +```python +@tool_method(variants=["write"], catch=True) +async def write_lines( + self, + path: t.Annotated[str, "Path to write to"], + contents: t.Annotated[str, "Content to write"], + insert_line: t.Annotated[int, "Line number to insert at (negative counts from end)"] = -1, + mode: t.Annotated[str, "'insert' or 'overwrite'"] = "insert", +) -> FilesystemItem | str: + """ + Write content to a specific line in an S3 file. + Mode can be 'insert' to add lines or 'overwrite' to replace lines. + """ + if mode not in ["insert", "overwrite"]: + raise TypeError("Invalid mode. Use 'insert' or 'overwrite'") + + # Read existing content + try: + existing_content = await self.read_file(path) + if isinstance(existing_content, bytes): + existing_content = existing_content.decode("utf-8") + elif isinstance(existing_content, rg.ContentImageUrl): + logger.warning("Cannot write lines to non-text content") + lines = [] + lines = existing_content.splitlines(keepends=True) + except FileNotFoundError: + # File doesn't exist, start with empty lines + lines = [] + except (PermissionError, IsADirectoryError, ClientError, BotoCoreError, ValueError) as e: + # File doesn't exist or can't be read, start with empty lines + return f"Error occurred while trying to write to the supplied filepath {path}: {e}" + + # Normalize line endings in content + content_lines = [ + line + "\n" if not line.endswith("\n") else line + for line in contents.splitlines(keepends=False) + ] + + # Calculate insert position and ensure it's within bounds + if insert_line < 0: + insert_line = len(lines) + insert_line + 1 + + insert_line = max(0, min(insert_line, len(lines))) + + # Apply the update + if mode == "insert": + lines[insert_line:insert_line] = content_lines + elif mode == "overwrite": + lines[insert_line : insert_line + len(content_lines)] = content_lines + + # Write back + new_content = "".join(lines) + return await self.write_file(path, new_content) +``` + + + + +Filesystem +---------- + +```python +Filesystem( + path: str | Path | UPath, **kwargs: Any +) -> LocalFilesystem | S3Filesystem +``` + +Factory function to create the appropriate filesystem instance based on path. + +Automatically detects the filesystem type from the path protocol and returns +the corresponding implementation (LocalFilesystem or S3Filesystem). + +**Parameters:** + +* **`path`** + (`str | Path | UPath`) + –Local path, S3 URL (s3://), or other supported protocol +* **`**kwargs`** + (`Any`, default: + `{}` + ) + –Additional arguments passed to the filesystem constructor + +**Returns:** + +* `LocalFilesystem | S3Filesystem` + –LocalFilesystem for local paths, S3Filesystem for S3 URLs + +**Examples:** + +```python +>>> # Local filesystem +>>> fs = Filesystem(path="/tmp/data") +>>> isinstance(fs, LocalFilesystem) +True +``` + +```python +>>> # S3 filesystem +>>> fs = Filesystem(path="s3://my-bucket/data") +>>> isinstance(fs, S3Filesystem) +True +``` + + +```python +def Filesystem( # noqa: N802 + path: str | Path | UPath, **kwargs: t.Any +) -> LocalFilesystem | S3Filesystem: + """ + Factory function to create the appropriate filesystem instance based on path. + + Automatically detects the filesystem type from the path protocol and returns + the corresponding implementation (LocalFilesystem or S3Filesystem). + + Args: + path: Local path, S3 URL (s3://), or other supported protocol + **kwargs: Additional arguments passed to the filesystem constructor + + Returns: + LocalFilesystem for local paths, S3Filesystem for S3 URLs + + Examples: + >>> # Local filesystem + >>> fs = Filesystem(path="/tmp/data") + >>> isinstance(fs, LocalFilesystem) + True + + >>> # S3 filesystem + >>> fs = Filesystem(path="s3://my-bucket/data") + >>> isinstance(fs, S3Filesystem) + True + """ + # Check if it's a string starting with s3:// + if isinstance(path, str) and path.startswith("s3://"): + return S3Filesystem(path=path, **kwargs) + + # Check if it's a UPath with S3 protocol + if isinstance(path, UPath) and path.protocol in ["s3", "s3a"]: + return S3Filesystem(path=path, **kwargs) + + # Try to create UPath and check protocol + try: + fs_options = kwargs.get("fs_options", {}) + upath = UPath(str(path), **fs_options) + if upath.protocol in ["s3", "s3a"]: + return S3Filesystem(path=path, **kwargs) + except (TypeError, ValueError) as e: + # If UPath creation fails, fall through to local + logger.warning( + f"Upath initialization failed ({type(e).__name__}: {e}), defaulting to local path" + ) + return LocalFilesystem(path=path, **kwargs) + + # Default to local filesystem + return LocalFilesystem(path=path, **kwargs) +``` + + + Memory ------ diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index e892a8d1..e50127db 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -117,7 +117,8 @@ def __init__( ```python create_project( - name: str | UUID | None = None, + name: str, + key: str, workspace_id: UUID | None = None, organization_id: UUID | None = None, ) -> Project @@ -128,9 +129,7 @@ Creates a new project. **Parameters:** * **`name`** - (`str | UUID | None`, default: - `None` - ) + (`str`) –The name of the project. If None, a default name will be used. * **`workspace_id`** (`UUID | None`, default: @@ -152,7 +151,8 @@ Creates a new project. ```python def create_project( self, - name: str | UUID | None = None, + name: str, + key: str, workspace_id: UUID | None = None, organization_id: UUID | None = None, ) -> Project: @@ -167,8 +167,8 @@ def create_project( Project: The created Project object. """ payload: dict[str, t.Any] = {} - if name is not None: - payload["name"] = name + payload["name"] = name + payload["key"] = key if workspace_id is not None: payload["workspace_id"] = str(workspace_id) if organization_id is not None: @@ -861,7 +861,7 @@ Retrieves details of a specific project. * **`project_identifier`** (`str | UUID`) - –The project identifier. ID, name, or slug. + –The project identifier. ID or key. **Returns:** @@ -874,7 +874,7 @@ def get_project(self, project_identifier: str | UUID, workspace_id: UUID) -> Pro """Retrieves details of a specific project. Args: - project_identifier (str | UUID): The project identifier. ID, name, or slug. + project_identifier (str | UUID): The project identifier. ID or key. Returns: Project: The Project object. diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index b6290386..36905968 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -288,7 +288,7 @@ def get_project(self, project_identifier: str | UUID, workspace_id: UUID) -> Pro """Retrieves details of a specific project. Args: - project_identifier (str | UUID): The project identifier. ID, name, or slug. + project_identifier (str | UUID): The project identifier. ID or key. Returns: Project: The Project object. @@ -302,7 +302,8 @@ def get_project(self, project_identifier: str | UUID, workspace_id: UUID) -> Pro def create_project( self, - name: str | UUID | None = None, + name: str, + key: str, workspace_id: UUID | None = None, organization_id: UUID | None = None, ) -> Project: @@ -317,8 +318,8 @@ def create_project( Project: The created Project object. """ payload: dict[str, t.Any] = {} - if name is not None: - payload["name"] = name + payload["name"] = name + payload["key"] = key if workspace_id is not None: payload["workspace_id"] = str(workspace_id) if organization_id is not None: diff --git a/dreadnode/constants.py b/dreadnode/constants.py index 6b0b7c10..f4b96d54 100644 --- a/dreadnode/constants.py +++ b/dreadnode/constants.py @@ -35,6 +35,8 @@ DEFAULT_WORKSPACE_NAME = "Personal Workspace" # default project name DEFAULT_PROJECT_NAME = "Default" +# default project key +DEFAULT_PROJECT_KEY = "default" # # Environment Variable Names diff --git a/dreadnode/main.py b/dreadnode/main.py index 877b2468..87ed592d 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -32,6 +32,7 @@ ) from dreadnode.constants import ( DEFAULT_LOCAL_STORAGE_DIR, + DEFAULT_PROJECT_KEY, DEFAULT_PROJECT_NAME, DEFAULT_SERVER_URL, ENV_API_KEY, @@ -348,8 +349,8 @@ def _resolve_project(self) -> None: """ Resolve the project to use based on configuration. - If a project is specified by name and doesn't exist, it will be created. - If no project is specified, it will use or create one named 'default'. + If a project is specified by key and doesn't exist, it will be created. + If no project is specified, it will use or create one with key 'default'. Raises: RuntimeError: If the API client is not initialized. @@ -366,7 +367,7 @@ def _resolve_project(self) -> None: found_project: Project | None = None try: found_project = self._api.get_project( - project_identifier=self.project or DEFAULT_PROJECT_NAME, + project_identifier=self.project or DEFAULT_PROJECT_KEY, workspace_id=self._workspace.id, ) except RuntimeError as e: @@ -378,7 +379,9 @@ def _resolve_project(self) -> None: if not found_project: # create it in the workspace found_project = self._api.create_project( - name=self.project or DEFAULT_PROJECT_NAME, workspace_id=self._workspace.id + name=self.project or DEFAULT_PROJECT_NAME, + key=self.project or DEFAULT_PROJECT_KEY, + workspace_id=self._workspace.id, ) # This is what's used in all of the Traces/Spans/Runs self._project = found_project From 462599adfca7cee0e82f955d56ae9f29ddab0ee8 Mon Sep 17 00:00:00 2001 From: Vincent Abruzzo <6225496+vabruzzo@users.noreply.github.com> Date: Tue, 2 Dec 2025 18:32:21 -0500 Subject: [PATCH 6/9] fix: score value race condition (#244) * bugfix score value race condition * revert agent change --- dreadnode/optimization/stop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dreadnode/optimization/stop.py b/dreadnode/optimization/stop.py index 49feaccd..cb972885 100644 --- a/dreadnode/optimization/stop.py +++ b/dreadnode/optimization/stop.py @@ -63,10 +63,11 @@ def score_value( """ def stop(trials: list[Trial]) -> bool: # noqa: PLR0911 - if not trials: + finished_trials = [t for t in trials if t.status == "finished"] + if not finished_trials: return False - trial = trials[-1] + trial = finished_trials[-1] value_to_check = trial.scores.get(metric_name) if metric_name else trial.score if value_to_check is None: return False From 7d91c211faaeaa5da4e4aa2a13077aad59a7e283 Mon Sep 17 00:00:00 2001 From: Michael Kouremetis Date: Tue, 2 Dec 2025 18:33:33 -0500 Subject: [PATCH 7/9] fix: span duration property has inconsistent time formats (#245) * fix * fix --------- Co-authored-by: Michael Kouremetis --- dreadnode/tracing/span.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 3d6ae852..05b17599 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -225,8 +225,10 @@ def duration(self) -> float: """Get the duration of the span in seconds.""" if self._span is None: return 0.0 - end_time = self.end_time or time.time() - return (end_time - self.start_time) if self.start_time else 0.0 + end_time = self.end_time or time.time_ns() + if not self.start_time: + return 0.0 + return (end_time - self.start_time) / 1e9 def set_tags(self, tags: t.Sequence[str]) -> None: tags = [tags] if isinstance(tags, str) else list(tags) From 1611e2a9073563d33cc347f32043fc3061bed99c Mon Sep 17 00:00:00 2001 From: "dreadnode-renovate-bot[bot]" <184170622+dreadnode-renovate-bot[bot]@users.noreply.github.com> Date: Wed, 3 Dec 2025 00:19:34 +0000 Subject: [PATCH 8/9] chore(deps): update actions/checkout action to v6.0.1 (#246) | datasource | package | from | to | | ----------- | ---------------- | ------ | ------ | | github-tags | actions/checkout | v6.0.0 | v6.0.1 | Co-authored-by: dreadnode-renovate-bot[bot] <184170622+dreadnode-renovate-bot[bot]@users.noreply.github.com> --- .github/workflows/meta-sync-labels.yaml | 2 +- .github/workflows/pre-commit.yaml | 2 +- .github/workflows/publish.yaml | 2 +- .github/workflows/renovate.yaml | 2 +- .github/workflows/rigging_pr_description.yaml | 2 +- .github/workflows/semgrep.yaml | 2 +- .github/workflows/template-sync.yaml | 2 +- .github/workflows/test.yaml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/meta-sync-labels.yaml b/.github/workflows/meta-sync-labels.yaml index b83ab6f8..366f2b7e 100644 --- a/.github/workflows/meta-sync-labels.yaml +++ b/.github/workflows/meta-sync-labels.yaml @@ -25,7 +25,7 @@ jobs: private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" - name: Set up git repository - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 188679d4..a2ed593b 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -30,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 5b3ef7e6..5ed356b5 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 diff --git a/.github/workflows/renovate.yaml b/.github/workflows/renovate.yaml index 7586cfe1..abeccb6e 100644 --- a/.github/workflows/renovate.yaml +++ b/.github/workflows/renovate.yaml @@ -59,7 +59,7 @@ jobs: private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" - name: Checkout - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/rigging_pr_description.yaml b/.github/workflows/rigging_pr_description.yaml index 93b78b03..3c709c88 100644 --- a/.github/workflows/rigging_pr_description.yaml +++ b/.github/workflows/rigging_pr_description.yaml @@ -13,7 +13,7 @@ jobs: contents: read steps: - - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 0 # full history for proper diffing diff --git a/.github/workflows/semgrep.yaml b/.github/workflows/semgrep.yaml index 703c6069..aa9ef6cd 100644 --- a/.github/workflows/semgrep.yaml +++ b/.github/workflows/semgrep.yaml @@ -38,7 +38,7 @@ jobs: steps: - name: Set up git repository - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/template-sync.yaml b/.github/workflows/template-sync.yaml index cea60084..384633f6 100644 --- a/.github/workflows/template-sync.yaml +++ b/.github/workflows/template-sync.yaml @@ -50,7 +50,7 @@ jobs: owner: "${{ github.repository_owner }}" - name: Checkout - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7635f365..968127ba 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -25,7 +25,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 From 4a51777e4aa0aa2326f9cdf926fa3aa8e78eb7aa Mon Sep 17 00:00:00 2001 From: Brian Greunke Date: Fri, 5 Dec 2025 21:32:16 -0600 Subject: [PATCH 9/9] feat(dataset): add push, save-to-disk, and refactor API --- dreadnode/__init__.py | 6 +- dreadnode/api/client.py | 54 ++++-- dreadnode/api/models.py | 85 ++++----- dreadnode/cli/datasets/__init__.py | 3 - dreadnode/cli/datasets/cli.py | 255 ------------------------- dreadnode/dataset.py | 140 ++++++++++---- dreadnode/main.py | 36 +++- dreadnode/storage/datasets/manager.py | 55 ++++-- dreadnode/storage/datasets/manifest.py | 3 +- 9 files changed, 249 insertions(+), 388 deletions(-) delete mode 100644 dreadnode/cli/datasets/__init__.py delete mode 100644 dreadnode/cli/datasets/cli.py diff --git a/dreadnode/__init__.py b/dreadnode/__init__.py index 9db45a31..9abb387d 100644 --- a/dreadnode/__init__.py +++ b/dreadnode/__init__.py @@ -56,7 +56,8 @@ push_update = DEFAULT_INSTANCE.push_update tag = DEFAULT_INSTANCE.tag load_dataset = DEFAULT_INSTANCE.load_dataset -save_dataset = DEFAULT_INSTANCE.save_dataset +save_dataset_to_disk = DEFAULT_INSTANCE.save_dataset_to_disk +push_dataset = DEFAULT_INSTANCE.push_dataset get_run_context = DEFAULT_INSTANCE.get_run_context continue_run = DEFAULT_INSTANCE.continue_run log_metric = DEFAULT_INSTANCE.log_metric @@ -134,9 +135,10 @@ "logging", "meta", "optimization", + "push_dataset", "push_update", "run", - "save_dataset", + "save_dataset_to_disk", "scorer", "scorers", "shutdown", diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index 623fff8d..ebea04b9 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -16,12 +16,12 @@ from dreadnode.api.models import ( AccessRefreshTokenResponse, ContainerRegistryCredentials, + CreateDatasetRequest, + CreateDatasetResponse, DatasetDownloadRequest, DatasetDownloadResponse, DatasetMetadata, - DatasetUploadComplete, - DatasetUploadRequest, - DatasetUploadResponse, + DatasetUploadCompleteRequest, DeviceCodeResponse, ExportFormat, GithubTokenResponse, @@ -761,7 +761,10 @@ def export_timeseries( # User data access def get_user_data_credentials( - self, organization_id: UUID, workspace_id: UUID + self, + organization_id: UUID | None = None, + workspace_id: UUID | None = None, + dataset_id: UUID | None = None, ) -> UserDataCredentials: """ Retrieves user data credentials for secondary storage access. @@ -769,8 +772,14 @@ def get_user_data_credentials( Returns: The user data credentials object. """ - params = {"org_id": str(organization_id), "workspace_id": str(workspace_id)} - response = self._request("GET", "/user-data/credentials", params=params) + params: dict[str, str] = {} + if organization_id: + params["org_id"] = str(organization_id) + if workspace_id: + params["workspace_id"] = str(workspace_id) + if dataset_id: + params["dataset_id"] = str(dataset_id) + response = self.request("GET", "/user-data/credentials", params=params) return UserDataCredentials(**response.json()) # Container registry access @@ -921,8 +930,8 @@ def delete_workspace(self, workspace_id: str | UUID) -> None: def create_dataset( self, - request: DatasetUploadRequest, - ) -> DatasetUploadResponse: + request: CreateDatasetRequest, + ) -> CreateDatasetResponse: """ Creates a new dataset. @@ -933,13 +942,11 @@ def create_dataset( DatasetUploadResponse: The dataset upload response object. """ - payload: dict[str, t.Any] = request.model_dump() + response = self.request("POST", "/datasets", json_data=request.model_dump()) - response = self.request("POST", "/datasets/upload", json_data=payload) + return CreateDatasetResponse(**response.json()) - return DatasetUploadResponse.model_validate(response.json()) - - def upload_complete(self, request: DatasetUploadComplete) -> None: + def upload_complete(self, request: DatasetUploadCompleteRequest) -> DatasetMetadata: """ Marks a dataset upload as complete. @@ -947,9 +954,10 @@ def upload_complete(self, request: DatasetUploadComplete) -> None: request (DatasetUploadComplete): The dataset upload completion request object. """ - payload: dict[str, t.Any] = request - - self.request("POST", "/datasets/upload/complete", json_data=payload) + response = self.request( + "POST", "/datasets/upload-complete", json_data=request.model_dump(mode="json") + ) + return DatasetMetadata(**response.json()) def download_dataset(self, request: DatasetDownloadRequest) -> DatasetDownloadResponse: """ @@ -963,7 +971,7 @@ def download_dataset(self, request: DatasetDownloadRequest) -> DatasetDownloadRe """ response = self.request( "GET", - f"/datasets/{request.dataset_uri}/download/?version={request.version}", + f"/datasets/{request.dataset_uri}/download?version={request.version}", ) return DatasetDownloadResponse.model_validate(response.json()) @@ -987,7 +995,7 @@ def get_dataset( def update_dataset( self, dataset_id_or_key: str | UUID, - dataset: DatasetUploadRequest, + dataset: CreateDatasetRequest, ) -> DatasetMetadata: """ Updates an existing dataset. @@ -1004,3 +1012,13 @@ def update_dataset( response = self.request("PUT", f"/datasets/{dataset_id_or_key}", json_data=payload) return DatasetMetadata(**response.json()) + + def delete_dataset(self, dataset_id_or_key: str | UUID) -> None: + """ + Deletes a specific dataset. + + Args: + dataset_id_or_key (str | UUID): The dataset identifier. + """ + + self.request("DELETE", f"/datasets/{dataset_id_or_key}") diff --git a/dreadnode/api/models.py b/dreadnode/api/models.py index 7887d447..ceef2b1b 100644 --- a/dreadnode/api/models.py +++ b/dreadnode/api/models.py @@ -1,4 +1,5 @@ import contextlib +import re import typing as t from datetime import datetime from functools import cached_property @@ -7,6 +8,7 @@ import requests from pydantic import ( BaseModel, + BeforeValidator, ConfigDict, Field, PrivateAttr, @@ -18,6 +20,17 @@ AnyDict = dict[str, t.Any] + +def _validate_key(key: str) -> str: + """Validate that a key only contains alphanumeric characters and dashes.""" + pattern = r"^(?=.{3,100}$)[a-z0-9]+(?:-[a-z0-9]+)*$" + if not bool(re.match(pattern, key)): + raise ValidationError( + detail="Key can only contain lowercase alphanumeric characters and dashes." + ) + return key + + # User @@ -42,6 +55,10 @@ class UserDataCredentials(BaseModel): prefix: str endpoint: str | None + @property + def upload_uri(self) -> str: + return f"dn://{self.bucket}/{self.prefix}" + class ContainerRegistryCredentials(BaseModel): registry: str @@ -556,42 +573,29 @@ class DatasetMetadata(BaseModel): A data model representing the metadata of a dataset. """ - id: UUID - """Unique identifier for the dataset.""" - org_id: UUID - """Unique identifier for the organization owning the dataset.""" - repo_id: UUID - """Unique identifier for the repository containing the dataset.""" - name: str - """Name of the dataset.""" - description: str | None = None - """Description of the dataset.""" - version: str | None = None - """Version of the dataset.""" - license: str | None = None - """License of the dataset.""" - tags: list[str] | None = None - """Tags associated with the dataset.""" - ds_schema: dict[str, t.Any] | None = None - """Schema of the dataset.""" - file_pointers: list[str] | None = None - """List of file pointers for the dataset files.""" - - -class DatasetUploadRequest(BaseModel): + id: UUID = Field(..., description="Dataset ID") + key: t.Annotated[str, BeforeValidator(_validate_key)] = Field(..., description="Dataset name") + tags: list[str] | None = Field(None, description="Dataset tags") + download_count: int | None = Field( + None, description="Number of times dataset has been downloaded" + ) + created_at: datetime = Field(..., description="Creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + is_public: bool = Field(..., description="Whether the dataset is public") + + +class CreateDatasetRequest(BaseModel): """ A data model representing the request body for creating a new dataset. """ - id: str | None - """Unique identifier for the dataset.""" - name: str | None - """Name of the dataset.""" - manifest: dict[str, t.Any] | None = None - """Manifest of the dataset.""" + org_key: t.Annotated[str, BeforeValidator(_validate_key)] + """Unique identifier for the organization owning the dataset.""" + key: t.Annotated[str, BeforeValidator(_validate_key)] + """Unique identifier of the dataset.""" -class DatasetUploadResponse(BaseModel): +class CreateDatasetResponse(BaseModel): """ A data model representing the response after creating a new dataset. @@ -601,34 +605,23 @@ class DatasetUploadResponse(BaseModel): status_code (int): HTTP status code of the upload request. """ - id: str + dataset_id: str """Unique identifier for the dataset.""" - upload_uri: str + user_data_access_response: UserDataCredentials """URI to upload the dataset files.""" - status_code: int - """HTTP status code of the upload request.""" -class DatasetUploadComplete(BaseModel): +class DatasetUploadCompleteRequest(BaseModel): """ A data model representing the request body for completing a dataset upload. """ - id: str + dataset_id: str """Unique identifier for the dataset.""" - success: bool + complete: bool """Status code indicating the result of the upload.""" -class DatasetUploadCompleteResponse(BaseModel): - """ - A data model representing the response after completing a dataset upload. - """ - - status_code: int - """HTTP status code of the upload completion request.""" - - class DatasetDownloadRequest(BaseModel): """ A data model representing the request body for downloading a dataset. diff --git a/dreadnode/cli/datasets/__init__.py b/dreadnode/cli/datasets/__init__.py deleted file mode 100644 index 70b21eb8..00000000 --- a/dreadnode/cli/datasets/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dreadnode.cli.datasets.cli import cli - -__all__ = ["cli"] diff --git a/dreadnode/cli/datasets/cli.py b/dreadnode/cli/datasets/cli.py deleted file mode 100644 index 81219956..00000000 --- a/dreadnode/cli/datasets/cli.py +++ /dev/null @@ -1,255 +0,0 @@ -from pathlib import Path - -import cyclopts - -cli = cyclopts.App("dataset", help="Run and manage datasets.") - - -@cli.command(name="list") -def list() -> None: - """ - List available datasets on the Dreadnode platform. - """ - print("Listing datasets available on the Dreadnode platform.") - - -@cli.command(name="push") -def push(dataset: Path) -> None: - """ - Push a dataset to the Dreadnode platform. - """ - print(f"Pushing dataset from {dataset} to Dreadnode platform.") - - -@cli.command(name="pull") -def pull(dataset_id: str, destination: Path | None = None) -> None: - """ - Pull a dataset from the Dreadnode platform. - """ - - # def log_artifact( - # self, - # local_uri: str | Path, - # ) -> None: - # """ - # Logs a local file or directory as an artifact to the object store. - # Preserves directory structure and uses content hashing for deduplication. - - # Args: - # local_uri: Path to the local file or directory - - # Returns: - # DirectoryNode representing the artifact's tree structure - - # Raises: - # FileNotFoundError: If the path doesn't exist - # """ - # artifact_tree = self._artifact_tree_builder.process_artifact(local_uri) - # self._artifact_merger.add_tree(artifact_tree) - # self._artifacts = self._artifact_merger.get_merged_trees() - - -import shutil -from collections.abc import Callable -from pathlib import Path -from typing import Any - -from loguru import logger - -from dreadnode.api.models import UserDataCredentials -from dreadnode.constants import DATASETS_CACHE, METADATA_FILE -from dreadnode.storage.base import BaseStorage -from dreadnode.storage.datasets.metadata import DatasetMetadata - - -class DatasetStorage(BaseStorage): - """ - High-level client for dataset operations. - - This is the main interface users interact with. - """ - - def __init__( - self, - credential_fetcher: Callable[[], UserDataCredentials] | None = None, - cache_dir: Path | None = None, - ): - """ - Initialize dataset client. - - Args: - credential_fetcher: Function to get S3 credentials - cache_dir: Custom cache directory - """ - self._credential_fetcher = credential_fetcher - self.cache_dir = cache_dir or DATASETS_CACHE - - def list_cached_datasets(self) -> list[DatasetMetadata]: - """ - List all datasets in cache. - - Returns: - List of metadata for cached datasets - """ - datasets = [] - - for org_dir in self.cache_dir.iterdir(): - if not org_dir.is_dir(): - continue - - for dataset_dir in org_dir.iterdir(): - if not dataset_dir.is_dir(): - continue - - for version_dir in dataset_dir.iterdir(): - if not version_dir.is_dir(): - continue - - metadata_path = version_dir / METADATA_FILE - if metadata_path.exists(): - try: - metadata = DatasetMetadata.load(metadata_path) - datasets.append(metadata) - except Exception as e: - logger.warning(f"Failed to load metadata from {metadata_path}: {e}") - - return datasets - - def delete_dataset( - self, - uri: str, - version: str | None = None, - *, - cache_only: bool = True, - ) -> bool: - """ - Remove dataset from cache and optionally remote. - - Args: - uri: Dataset URI - version: Specific version (removes all if None) - cache_only: If True, only remove from cache - - Returns: - True if removed successfully - """ - parsed_uri, parsed_version = self.parse_uri(uri) - version = version or parsed_version - - logger.info(f"Removing dataset {parsed_uri}@{version or 'all'}") - - # Remove from cache - if version: - cache_path = self.get_dataset_path(uri, version) - if cache_path.exists(): - logger.info(f"Removing dataset from cache: {cache_path}") - shutil.rmtree(cache_path) - return True - return False - - if not cache_only and self._credential_fetcher: - # Remove from remote (implementation depends on backend) - logger.warning("Remote deletion not implemented") - - return True - - def list_datasets( - self, - *, - remote: bool = False, - cache_only: bool = True, - ) -> list[any]: - """ - List available datasets. - - Args: - remote: If True, list from remote storage - cache_only: If True, only list cached datasets - - Returns: - List of dataset metadata - """ - if cache_only or not remote: - return self.list_cached_datasets() - - # Remote listing would require API implementation - logger.warning("Remote dataset listing not implemented") - return [] - - def search_datasets( - self, - name_pattern: str | None = None, - tags: list[str] | None = None, - version: str | None = None, - *, - remote: bool = False, - ) -> list[any]: - """ - Search for datasets matching criteria. - - Args: - name_pattern: Pattern to match in name - tags: Required tags - version: Specific version - remote: Search remote storage - - Returns: - List of matching dataset metadata - """ - if remote and not self._credential_fetcher: - logger.warning("Remote search requires credential fetcher") - remote = False - - all_datasets = self.list_cached_datasets() - - results = [ds for ds in all_datasets if ds.matches_filter(name_pattern, tags, version)] - - if remote: - # Remote search would require API implementation - logger.warning("Remote dataset search not implemented") - - return results - - def get_cache_info(self) -> dict[str, Any]: - """ - Get information about the cache. - - Returns: - Dictionary with cache statistics - - Examples: - >>> client = DatasetStorage() - >>> info = client.get_cache_info() - >>> print(f"Cache size: {info['size_gb']:.2f} GB") - """ - size_bytes = self.get_cache_size() - datasets = self.list_cached_datasets() - - return { - "cache_dir": str(self.cache_dir), - "size_bytes": size_bytes, - "size_mb": size_bytes / (1024 * 1024), - "size_gb": size_bytes / (1024 * 1024 * 1024), - "dataset_count": len(datasets), - "datasets": [ - { - "name": ds.name, - "version": ds.version, - "uri": ds.uri, - } - for ds in datasets - ], - } - - def get_cache_size(self) -> int: - """ - Get total size of cache in bytes. - - Returns: - Cache size in bytes - """ - total_size = 0 - for path in self.cache_dir.rglob("*"): - if path.is_file(): - total_size += path.stat().st_size - return total_size diff --git a/dreadnode/dataset.py b/dreadnode/dataset.py index 6207cb04..d537f8de 100644 --- a/dreadnode/dataset.py +++ b/dreadnode/dataset.py @@ -8,6 +8,7 @@ from pyarrow.fs import FileSystem from dreadnode.constants import MANIFEST_FILE, METADATA_FILE +from dreadnode.logging_ import print_info from dreadnode.storage.datasets.manager import DatasetManager from dreadnode.storage.datasets.manifest import DatasetManifest, create_manifest from dreadnode.storage.datasets.metadata import DatasetMetadata, VersionInfo @@ -39,7 +40,7 @@ def __init__( self.ds = ds.to_table() if not metadata: - print("[*] No metadata provided, check your dataset!") + print_info("[*] No metadata provided, check your dataset!") def update_metadata(self, metadata: DatasetMetadata) -> None: self.metadata = metadata @@ -110,7 +111,7 @@ def save_dataset( path: str, fs: FileSystem, *, - to_cache: bool = False, + create_dir: bool = False, **kwargs: Any, ) -> None: """ @@ -123,53 +124,118 @@ def save_dataset( format="parquet", filesystem=fs, existing_data_behavior="overwrite_or_ignore", - create_dir=to_cache, + create_dir=create_dir, **kwargs, ) -def save_dataset( +def _persist_dataset( dataset: Dataset, + path_str: str, *, - to_cache: bool = False, + create_dir: bool = False, fsm: DatasetManager, **kwargs: Any, ) -> None: - if to_cache: - path_str = fsm.get_cache_save_uri(metadata=dataset.metadata) - print("[*] Saving dataset to local cache") - else: - path_str = fsm.get_remote_save_uri(metadata=dataset.metadata) - print("[*] Saving dataset to remote storage") + """Persists a dataset to the given path. + Args: + dataset: The Dataset to persist. + path_str: The path to persist the dataset to. + create_dir: Whether to create the directory if it doesn't exist. Defaults to False. + fsm: The DatasetManager instance. + kwargs: Additional arguments to pass to pyarrow.dataset.write_dataset. + """ fs, base_path = fsm.get_fs_and_path(path_str) - try: - fsm.ensure_dir(fs, base_path) + fsm.ensure_dir(fs, base_path) + data_path = f"{base_path}/data" + fsm.ensure_dir(fs, data_path) - dataset.save_dataset(path=f"{base_path}/data", fs=fs, to_cache=to_cache, **kwargs) + dataset.save_dataset(path=data_path, fs=fs, create_dir=create_dir, **kwargs) - dataset.save_metadata(path=f"{base_path}/{METADATA_FILE}", fs=fs) + dataset.save_metadata(path=f"{base_path}/{METADATA_FILE}", fs=fs) - manifest = create_manifest( - path=base_path, - version=dataset.metadata.version, - previous_manifest=dataset.manifest if dataset.manifest else None, - fs=fs, - ) - manifest.save(f"{base_path}/{MANIFEST_FILE}", fs=fs) - except Exception as e: - # if remote save failed, notify API - if not to_cache: - fsm.remote_save_complete(success=False, dataset_id=dataset.metadata.id) - print(f"[!] Failed to save dataset: {e}") - raise + manifest = create_manifest( + path=base_path, + version=dataset.metadata.version, + previous_manifest=dataset.manifest if dataset.manifest else None, + fs=fs, + ) + manifest.save(f"{base_path}/{MANIFEST_FILE}", fs=fs) + + print_info("[+] Saved dataset successfully") - # if remote save succeeded, notify API - if not to_cache: - fsm.remote_save_complete(success=True, dataset_id=dataset.metadata.id) - print("[+] Saved dataset successfully") +def save_dataset_to_disk( + dataset: Dataset, + *, + fsm: DatasetManager, + **kwargs: Any, +) -> None: + """Saves a dataset to local disk cache. + + Args: + dataset: The Dataset to save. + fsm: The DatasetManager instance. + kwargs: Additional arguments to pass to pyarrow.dataset.write_dataset. + + Returns: + None + + """ + path_str = fsm.get_cache_save_uri(metadata=dataset.metadata) + print_info("[*] Saving dataset to local cache") + + _persist_dataset( + dataset=dataset, + path_str=path_str, + create_dir=True, + fsm=fsm, + **kwargs, + ) + + +def push_dataset( + dataset: Dataset, + *, + to_cache: bool = True, + fsm: DatasetManager, + **kwargs: Any, +) -> None: + """Pushes a dataset to remote storage. + + Args: + dataset: The Dataset to push. + to_cache: Whether to save to local cache first. Defaults to True. + fsm: The DatasetManager instance. + kwargs: Additional arguments to pass to pyarrow.dataset.write_dataset. + + Returns: + None + """ + if to_cache: + save_dataset_to_disk( + dataset=dataset, + fsm=fsm, + **kwargs, + ) + + dataset_id, path_str = fsm.get_remote_save_uri(metadata=dataset.metadata) + dataset.metadata.id = dataset_id + print_info("[*] Saving dataset to remote storage") + try: + _persist_dataset( + dataset=dataset, + path_str=path_str, + fsm=fsm, + **kwargs, + ) + fsm.remote_save_complete(complete=True, dataset_id=dataset.metadata.id) + except Exception: + # if remote save failed, remove the record from API + fsm.delete_remote_dataset_record(dataset_id_or_key=dataset.metadata.id) + raise def load_dataset( @@ -203,7 +269,7 @@ def load_dataset( if protocol in ("file", "local", ""): # check cache first if not fsm.check_cache(uri, version): - print("[+] Dataset not found in cache. Loading dataset from local path...") + print_info("[+] Dataset not found in cache. Loading dataset from local path...") # load directly from local path fs, fs_path = fsm.get_fs_and_path(uri) @@ -220,7 +286,7 @@ def load_dataset( return Dataset(ds=dataset, metadata=metadata, materialize=materialize) # if in cache, load from cache - print("[+] Loading dataset from cache...") + print_info("[+] Loading dataset from cache...") # get the filesystem and path fs, fs_path = fsm.get_fs_and_path(uri) @@ -242,7 +308,7 @@ def load_dataset( return Dataset(ds=dataset, materialize=materialize, metadata=metadata, manifest=manifest) # if not local path, and not in cache, load from remote - print("[+] Loading from remote storage...") + print_info("[+] Loading from remote storage...") try: # get remote URI remote_uri = fsm.get_remote_load_uri(uri=strip_protocol(uri), version=version) @@ -264,13 +330,13 @@ def load_dataset( is_valid = manifest.validate(fs_path, fs) if not is_valid: # invalid manifest, sync from remote - print("[!] Remote dataset manifest validation failed.") + print_info("[!] Remote dataset manifest validation failed.") # load dataset dataset = ds.dataset(f"{fs_path}/data", format=format, filesystem=fs, **kwargs) return Dataset(ds=dataset, metadata=metadata, manifest=manifest, materialize=materialize) except Exception as e: - print(f"[!] Failed to load dataset from remote: {e}") + print_info(f"[!] Failed to load dataset from remote: {e}") raise FileNotFoundError(f"[!] Dataset not found: {uri}") diff --git a/dreadnode/main.py b/dreadnode/main.py index b076c9fe..3924389e 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -255,7 +255,7 @@ def _resolve_organization(self) -> None: if len(organizations) > 1: # We should not presume to choose an organization - org_list = "\t\n".join([f"- {o.name}" for o in organizations]) + org_list = "\t\n".join([f"- {o.key}" for o in organizations]) raise RuntimeError( f"You are part of multiple organizations. Please specify an organization from:\n{org_list}" ) @@ -704,7 +704,7 @@ def initialize(self) -> None: api = self._api self._credential_manager = CredentialManager( credential_fetcher=lambda: api.get_user_data_credentials( - self._organization.id, self._workspace.id + organization_id=self._organization.id, workspace_id=self._workspace.id ) ) @@ -730,6 +730,7 @@ def initialize(self) -> None: self._fs_manager = DatasetManager().configure( api=self._api, # type: ignore[return-value] organization=self._organization.key, + organization_id=self._organization.id, ) self._initialized = True @@ -1335,24 +1336,39 @@ def load_dataset( fsm=self._fs_manager, ) - def save_dataset( + def save_dataset_to_disk( self, ds: dataset.Dataset, - *, - to_cache: bool = False, - ) -> str: + ) -> None: + """ + Save a dataset to the local cache. + + Example: + ``` + dreadnode.save_dataset_to_disk(my_dataset) + ``` + """ + + dataset.save_dataset_to_disk( + dataset=ds, + fsm=self._fs_manager, + ) + + def push_dataset( + self, + ds: dataset.Dataset, + ) -> None: """ - Save a dataset to the local cache and optionally to the Dreadnode server. + Push a dataset to the Dreadnode server. Example: ``` - uri = dreadnode.save_dataset(my_dataset) + dreadnode.push_dataset(my_dataset) ``` """ - dataset.save_dataset( + dataset.push_dataset( dataset=ds, - to_cache=to_cache, fsm=self._fs_manager, ) diff --git a/dreadnode/storage/datasets/manager.py b/dreadnode/storage/datasets/manager.py index 71c91f7a..e3dc6a36 100644 --- a/dreadnode/storage/datasets/manager.py +++ b/dreadnode/storage/datasets/manager.py @@ -2,15 +2,16 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any +from uuid import UUID import pyarrow.fs as pafs # The Native FS from pyarrow.fs import FileSystem, FileType from dreadnode.api import ApiClient from dreadnode.api.models import ( + CreateDatasetRequest, DatasetDownloadRequest, - DatasetUploadComplete, - DatasetUploadRequest, + DatasetUploadCompleteRequest, ) from dreadnode.constants import ( DEFAULT_LOCAL_STORAGE_DIR, @@ -20,6 +21,7 @@ METADATA_FILE, ) from dreadnode.logging_ import console as logging_console +from dreadnode.logging_ import print_info from dreadnode.storage.datasets.metadata import DatasetMetadata from dreadnode.util import resolve_endpoint @@ -37,6 +39,7 @@ class DatasetManager: """ organization: str | None = None + organization_id: UUID | None = None _instance: "DatasetManager | None" = None _api: ApiClient | None = None @@ -57,10 +60,12 @@ def configure( cls, api: ApiClient | None = None, organization: str | None = None, + organization_id: UUID | None = None, ) -> "DatasetManager": instance = cls() instance._api = api instance.organization = organization + instance.organization_id = organization_id return instance def metadata_exists(self, path: str) -> bool: @@ -128,7 +133,7 @@ def get_cache_load_uri( return str(dataset_uri / latest) - def get_remote_save_uri(self, metadata: DatasetMetadata) -> str: + def get_remote_save_uri(self, metadata: DatasetMetadata) -> tuple[UUID, str]: """ Constructs the full remote storage URI. Example: dreadnode://datasets/main/my-dataset @@ -137,13 +142,15 @@ def get_remote_save_uri(self, metadata: DatasetMetadata) -> str: if not self._api: raise ValueError("No client configured") - upload_request = DatasetUploadRequest.model_validate(metadata.model_dump()) - - response = self._api.upload_dataset_request(request=upload_request) + upload_request = CreateDatasetRequest( + org_key=metadata.organization, + key=metadata.name, + ) - return response.upload_uri + response = self._api.create_dataset(request=upload_request) + return response.dataset_id, response.user_data_access_response.upload_uri - def remote_save_complete(self, dataset_id: str, *, success: bool) -> None: + def remote_save_complete(self, dataset_id: str, *, complete: bool) -> None: """ Notifies the API that the remote upload is complete. """ @@ -151,9 +158,9 @@ def remote_save_complete(self, dataset_id: str, *, success: bool) -> None: if not self._api: raise ValueError("No client configured") - request = DatasetUploadComplete(id=dataset_id, success=success) + request = DatasetUploadCompleteRequest(dataset_id=dataset_id, complete=complete) - self._api.upload_complete(request=request.model_dump()) + self._api.upload_complete(request=request) def get_remote_load_uri(self, uri: str, version: str | None = "latest") -> str: """ @@ -167,17 +174,19 @@ def get_remote_load_uri(self, uri: str, version: str | None = "latest") -> str: response = self._api.download_dataset(request) - print(f"[*] Download URI: {response.download_uri}") + print_info(f"[*] Download URI: {response.download_uri}") return response.download_uri - def get_s3_config(self) -> dict[str, Any]: + def get_s3_config(self, dataset_id: UUID) -> dict[str, Any]: """ Translates your UserDataCredentials into PyArrow S3 arguments. """ if not self._api: raise ValueError("No client configured") - creds = self._api.get_user_data_credentials() + creds = self._api.get_user_data_credentials( + organization_id=self.organization_id, dataset_id=dataset_id + ) self._credentials_expiry = creds.expiration resolved_endpoint = resolve_endpoint(creds.endpoint) @@ -220,9 +229,13 @@ def get_fs_and_path(self, uri: str) -> tuple[FileSystem, str]: if self._cached_s3_fs is None or self.needs_refresh(): try: - config = self.get_s3_config() + # Try to extract dataset ID from URI which expect is of the form dn:///datasets/ + dataset_id = UUID(path_body.split("/")[-1]) + config = self.get_s3_config(dataset_id=dataset_id) self._cached_s3_fs = pafs.S3FileSystem(**config) - + except ValueError: + logging_console.print(f"[red]Invalid dataset ID in URI: [green]{uri}[/green][/red]") + raise except Exception as e: logging_console.print(f"Auth failed: {e}") raise @@ -244,7 +257,7 @@ def resolve_latest_version(self, uri: str, fs: FileSystem) -> str: ] latest = sorted(versions, reverse=True)[0] - print(f"[*] Resolved latest version {latest}") + print_info(f"[*] Resolved latest version {latest}") return latest def ensure_dir(self, fs: FileSystem, path: str) -> None: @@ -255,3 +268,13 @@ def ensure_dir(self, fs: FileSystem, path: str) -> None: return with contextlib.suppress(OSError): fs.create_dir(path, recursive=True) + + def delete_remote_dataset_record(self, dataset_id_or_key: UUID | str) -> None: + """ + Deletes a remote dataset via the API. + """ + + if not self._api: + raise ValueError("No client configured") + + self._api.delete_dataset(dataset_id_or_key=dataset_id_or_key) diff --git a/dreadnode/storage/datasets/manifest.py b/dreadnode/storage/datasets/manifest.py index 64dc9237..ef64fd27 100644 --- a/dreadnode/storage/datasets/manifest.py +++ b/dreadnode/storage/datasets/manifest.py @@ -132,7 +132,8 @@ def compute_file_hash( ) -> str: try: with fs.open_input_stream(file_path) as f: - return hashlib.file_digest(f, algorithm) + digest = hashlib.file_digest(f, algorithm) + return digest.hexdigest() except Exception as e: logging_console.print(f"Failed to hash {file_path}: {e}") return ""