diff --git a/.github/workflows/meta-sync-labels.yaml b/.github/workflows/meta-sync-labels.yaml index b4acaac3..366f2b7e 100644 --- a/.github/workflows/meta-sync-labels.yaml +++ b/.github/workflows/meta-sync-labels.yaml @@ -25,7 +25,7 @@ jobs: private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" - name: Set up git repository - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index fd96acc7..a2ed593b 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -30,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index d213b697..5ed356b5 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 diff --git a/.github/workflows/renovate.yaml b/.github/workflows/renovate.yaml index 5f3f3754..abeccb6e 100644 --- a/.github/workflows/renovate.yaml +++ b/.github/workflows/renovate.yaml @@ -59,7 +59,7 @@ jobs: private-key: "${{ secrets.BOT_APP_PRIVATE_KEY }}" - name: Checkout - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/rigging_pr_description.yaml b/.github/workflows/rigging_pr_description.yaml index 9df57c34..3c709c88 100644 --- a/.github/workflows/rigging_pr_description.yaml +++ b/.github/workflows/rigging_pr_description.yaml @@ -13,12 +13,12 @@ jobs: contents: read steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 0 # full history for proper diffing - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.13" diff --git a/.github/workflows/semgrep.yaml b/.github/workflows/semgrep.yaml index f87a0070..aa9ef6cd 100644 --- a/.github/workflows/semgrep.yaml +++ b/.github/workflows/semgrep.yaml @@ -38,7 +38,7 @@ jobs: steps: - name: Set up git repository - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/template-sync.yaml b/.github/workflows/template-sync.yaml index c68a4b23..384633f6 100644 --- a/.github/workflows/template-sync.yaml +++ b/.github/workflows/template-sync.yaml @@ -50,7 +50,7 @@ jobs: owner: "${{ github.repository_owner }}" - name: Checkout - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: "${{ steps.app-token.outputs.token }}" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 34bdbfb5..968127ba 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -25,7 +25,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Install uv uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7.1.4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3ab681ed..13b0307b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: # Python code security - repo: https://github.com/PyCQA/bandit - rev: 1.9.1 + rev: 1.9.2 hooks: - id: bandit name: Code security checks diff --git a/docs/sdk/api.mdx b/docs/sdk/api.mdx index 93c77838..c193e8b4 100644 --- a/docs/sdk/api.mdx +++ b/docs/sdk/api.mdx @@ -164,7 +164,8 @@ def create_dataset( ```python create_project( - name: str | UUID | None = None, + name: str, + key: str, workspace_id: UUID | None = None, organization_id: UUID | None = None, ) -> Project @@ -175,9 +176,7 @@ Creates a new project. **Parameters:** * **`name`** - (`str | UUID | None`, default: - `None` - ) + (`str`) –The name of the project. If None, a default name will be used. * **`workspace_id`** (`UUID | None`, default: @@ -199,7 +198,8 @@ Creates a new project. ```python def create_project( self, - name: str | UUID | None = None, + name: str, + key: str, workspace_id: UUID | None = None, organization_id: UUID | None = None, ) -> Project: @@ -214,8 +214,8 @@ def create_project( Project: The created Project object. """ payload: dict[str, t.Any] = {} - if name is not None: - payload["name"] = name + payload["name"] = name + payload["key"] = key if workspace_id is not None: payload["workspace_id"] = str(workspace_id) if organization_id is not None: @@ -995,7 +995,7 @@ Retrieves details of a specific project. * **`project_identifier`** (`str | UUID`) - –The project identifier. ID, name, or slug. + –The project identifier. ID or key. **Returns:** @@ -1008,7 +1008,7 @@ def get_project(self, project_identifier: str | UUID, workspace_id: UUID) -> Pro """Retrieves details of a specific project. Args: - project_identifier (str | UUID): The project identifier. ID, name, or slug. + project_identifier (str | UUID): The project identifier. ID or key. Returns: Project: The Project object. diff --git a/dreadnode/__init__.py b/dreadnode/__init__.py index 9db45a31..9abb387d 100644 --- a/dreadnode/__init__.py +++ b/dreadnode/__init__.py @@ -56,7 +56,8 @@ push_update = DEFAULT_INSTANCE.push_update tag = DEFAULT_INSTANCE.tag load_dataset = DEFAULT_INSTANCE.load_dataset -save_dataset = DEFAULT_INSTANCE.save_dataset +save_dataset_to_disk = DEFAULT_INSTANCE.save_dataset_to_disk +push_dataset = DEFAULT_INSTANCE.push_dataset get_run_context = DEFAULT_INSTANCE.get_run_context continue_run = DEFAULT_INSTANCE.continue_run log_metric = DEFAULT_INSTANCE.log_metric @@ -134,9 +135,10 @@ "logging", "meta", "optimization", + "push_dataset", "push_update", "run", - "save_dataset", + "save_dataset_to_disk", "scorer", "scorers", "shutdown", diff --git a/dreadnode/api/client.py b/dreadnode/api/client.py index 09e2e131..ebea04b9 100644 --- a/dreadnode/api/client.py +++ b/dreadnode/api/client.py @@ -16,12 +16,12 @@ from dreadnode.api.models import ( AccessRefreshTokenResponse, ContainerRegistryCredentials, + CreateDatasetRequest, + CreateDatasetResponse, DatasetDownloadRequest, DatasetDownloadResponse, DatasetMetadata, - DatasetUploadComplete, - DatasetUploadRequest, - DatasetUploadResponse, + DatasetUploadCompleteRequest, DeviceCodeResponse, ExportFormat, GithubTokenResponse, @@ -294,7 +294,7 @@ def get_project(self, project_identifier: str | UUID, workspace_id: UUID) -> Pro """Retrieves details of a specific project. Args: - project_identifier (str | UUID): The project identifier. ID, name, or slug. + project_identifier (str | UUID): The project identifier. ID or key. Returns: Project: The Project object. @@ -308,7 +308,8 @@ def get_project(self, project_identifier: str | UUID, workspace_id: UUID) -> Pro def create_project( self, - name: str | UUID | None = None, + name: str, + key: str, workspace_id: UUID | None = None, organization_id: UUID | None = None, ) -> Project: @@ -323,8 +324,8 @@ def create_project( Project: The created Project object. """ payload: dict[str, t.Any] = {} - if name is not None: - payload["name"] = name + payload["name"] = name + payload["key"] = key if workspace_id is not None: payload["workspace_id"] = str(workspace_id) if organization_id is not None: @@ -759,14 +760,26 @@ def export_timeseries( # User data access - def get_user_data_credentials(self) -> UserDataCredentials: + def get_user_data_credentials( + self, + organization_id: UUID | None = None, + workspace_id: UUID | None = None, + dataset_id: UUID | None = None, + ) -> UserDataCredentials: """ Retrieves user data credentials for secondary storage access. Returns: The user data credentials object. """ - response = self._request("GET", "/user-data/credentials") + params: dict[str, str] = {} + if organization_id: + params["org_id"] = str(organization_id) + if workspace_id: + params["workspace_id"] = str(workspace_id) + if dataset_id: + params["dataset_id"] = str(dataset_id) + response = self.request("GET", "/user-data/credentials", params=params) return UserDataCredentials(**response.json()) # Container registry access @@ -917,8 +930,8 @@ def delete_workspace(self, workspace_id: str | UUID) -> None: def create_dataset( self, - request: DatasetUploadRequest, - ) -> DatasetUploadResponse: + request: CreateDatasetRequest, + ) -> CreateDatasetResponse: """ Creates a new dataset. @@ -929,13 +942,11 @@ def create_dataset( DatasetUploadResponse: The dataset upload response object. """ - payload: dict[str, t.Any] = request.model_dump() - - response = self.request("POST", "/datasets/upload", json_data=payload) + response = self.request("POST", "/datasets", json_data=request.model_dump()) - return DatasetUploadResponse.model_validate(response.json()) + return CreateDatasetResponse(**response.json()) - def upload_complete(self, request: DatasetUploadComplete) -> None: + def upload_complete(self, request: DatasetUploadCompleteRequest) -> DatasetMetadata: """ Marks a dataset upload as complete. @@ -943,9 +954,10 @@ def upload_complete(self, request: DatasetUploadComplete) -> None: request (DatasetUploadComplete): The dataset upload completion request object. """ - payload: dict[str, t.Any] = request - - self.request("POST", "/datasets/upload/complete", json_data=payload) + response = self.request( + "POST", "/datasets/upload-complete", json_data=request.model_dump(mode="json") + ) + return DatasetMetadata(**response.json()) def download_dataset(self, request: DatasetDownloadRequest) -> DatasetDownloadResponse: """ @@ -959,7 +971,7 @@ def download_dataset(self, request: DatasetDownloadRequest) -> DatasetDownloadRe """ response = self.request( "GET", - f"/datasets/{request.dataset_uri}/download/?version={request.version}", + f"/datasets/{request.dataset_uri}/download?version={request.version}", ) return DatasetDownloadResponse.model_validate(response.json()) @@ -983,7 +995,7 @@ def get_dataset( def update_dataset( self, dataset_id_or_key: str | UUID, - dataset: DatasetUploadRequest, + dataset: CreateDatasetRequest, ) -> DatasetMetadata: """ Updates an existing dataset. @@ -1000,3 +1012,13 @@ def update_dataset( response = self.request("PUT", f"/datasets/{dataset_id_or_key}", json_data=payload) return DatasetMetadata(**response.json()) + + def delete_dataset(self, dataset_id_or_key: str | UUID) -> None: + """ + Deletes a specific dataset. + + Args: + dataset_id_or_key (str | UUID): The dataset identifier. + """ + + self.request("DELETE", f"/datasets/{dataset_id_or_key}") diff --git a/dreadnode/api/models.py b/dreadnode/api/models.py index 7887d447..ceef2b1b 100644 --- a/dreadnode/api/models.py +++ b/dreadnode/api/models.py @@ -1,4 +1,5 @@ import contextlib +import re import typing as t from datetime import datetime from functools import cached_property @@ -7,6 +8,7 @@ import requests from pydantic import ( BaseModel, + BeforeValidator, ConfigDict, Field, PrivateAttr, @@ -18,6 +20,17 @@ AnyDict = dict[str, t.Any] + +def _validate_key(key: str) -> str: + """Validate that a key only contains alphanumeric characters and dashes.""" + pattern = r"^(?=.{3,100}$)[a-z0-9]+(?:-[a-z0-9]+)*$" + if not bool(re.match(pattern, key)): + raise ValidationError( + detail="Key can only contain lowercase alphanumeric characters and dashes." + ) + return key + + # User @@ -42,6 +55,10 @@ class UserDataCredentials(BaseModel): prefix: str endpoint: str | None + @property + def upload_uri(self) -> str: + return f"dn://{self.bucket}/{self.prefix}" + class ContainerRegistryCredentials(BaseModel): registry: str @@ -556,42 +573,29 @@ class DatasetMetadata(BaseModel): A data model representing the metadata of a dataset. """ - id: UUID - """Unique identifier for the dataset.""" - org_id: UUID - """Unique identifier for the organization owning the dataset.""" - repo_id: UUID - """Unique identifier for the repository containing the dataset.""" - name: str - """Name of the dataset.""" - description: str | None = None - """Description of the dataset.""" - version: str | None = None - """Version of the dataset.""" - license: str | None = None - """License of the dataset.""" - tags: list[str] | None = None - """Tags associated with the dataset.""" - ds_schema: dict[str, t.Any] | None = None - """Schema of the dataset.""" - file_pointers: list[str] | None = None - """List of file pointers for the dataset files.""" - - -class DatasetUploadRequest(BaseModel): + id: UUID = Field(..., description="Dataset ID") + key: t.Annotated[str, BeforeValidator(_validate_key)] = Field(..., description="Dataset name") + tags: list[str] | None = Field(None, description="Dataset tags") + download_count: int | None = Field( + None, description="Number of times dataset has been downloaded" + ) + created_at: datetime = Field(..., description="Creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + is_public: bool = Field(..., description="Whether the dataset is public") + + +class CreateDatasetRequest(BaseModel): """ A data model representing the request body for creating a new dataset. """ - id: str | None - """Unique identifier for the dataset.""" - name: str | None - """Name of the dataset.""" - manifest: dict[str, t.Any] | None = None - """Manifest of the dataset.""" + org_key: t.Annotated[str, BeforeValidator(_validate_key)] + """Unique identifier for the organization owning the dataset.""" + key: t.Annotated[str, BeforeValidator(_validate_key)] + """Unique identifier of the dataset.""" -class DatasetUploadResponse(BaseModel): +class CreateDatasetResponse(BaseModel): """ A data model representing the response after creating a new dataset. @@ -601,34 +605,23 @@ class DatasetUploadResponse(BaseModel): status_code (int): HTTP status code of the upload request. """ - id: str + dataset_id: str """Unique identifier for the dataset.""" - upload_uri: str + user_data_access_response: UserDataCredentials """URI to upload the dataset files.""" - status_code: int - """HTTP status code of the upload request.""" -class DatasetUploadComplete(BaseModel): +class DatasetUploadCompleteRequest(BaseModel): """ A data model representing the request body for completing a dataset upload. """ - id: str + dataset_id: str """Unique identifier for the dataset.""" - success: bool + complete: bool """Status code indicating the result of the upload.""" -class DatasetUploadCompleteResponse(BaseModel): - """ - A data model representing the response after completing a dataset upload. - """ - - status_code: int - """HTTP status code of the upload completion request.""" - - class DatasetDownloadRequest(BaseModel): """ A data model representing the request body for downloading a dataset. diff --git a/dreadnode/cli/datasets/__init__.py b/dreadnode/cli/datasets/__init__.py deleted file mode 100644 index 70b21eb8..00000000 --- a/dreadnode/cli/datasets/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dreadnode.cli.datasets.cli import cli - -__all__ = ["cli"] diff --git a/dreadnode/cli/datasets/cli.py b/dreadnode/cli/datasets/cli.py deleted file mode 100644 index 81219956..00000000 --- a/dreadnode/cli/datasets/cli.py +++ /dev/null @@ -1,255 +0,0 @@ -from pathlib import Path - -import cyclopts - -cli = cyclopts.App("dataset", help="Run and manage datasets.") - - -@cli.command(name="list") -def list() -> None: - """ - List available datasets on the Dreadnode platform. - """ - print("Listing datasets available on the Dreadnode platform.") - - -@cli.command(name="push") -def push(dataset: Path) -> None: - """ - Push a dataset to the Dreadnode platform. - """ - print(f"Pushing dataset from {dataset} to Dreadnode platform.") - - -@cli.command(name="pull") -def pull(dataset_id: str, destination: Path | None = None) -> None: - """ - Pull a dataset from the Dreadnode platform. - """ - - # def log_artifact( - # self, - # local_uri: str | Path, - # ) -> None: - # """ - # Logs a local file or directory as an artifact to the object store. - # Preserves directory structure and uses content hashing for deduplication. - - # Args: - # local_uri: Path to the local file or directory - - # Returns: - # DirectoryNode representing the artifact's tree structure - - # Raises: - # FileNotFoundError: If the path doesn't exist - # """ - # artifact_tree = self._artifact_tree_builder.process_artifact(local_uri) - # self._artifact_merger.add_tree(artifact_tree) - # self._artifacts = self._artifact_merger.get_merged_trees() - - -import shutil -from collections.abc import Callable -from pathlib import Path -from typing import Any - -from loguru import logger - -from dreadnode.api.models import UserDataCredentials -from dreadnode.constants import DATASETS_CACHE, METADATA_FILE -from dreadnode.storage.base import BaseStorage -from dreadnode.storage.datasets.metadata import DatasetMetadata - - -class DatasetStorage(BaseStorage): - """ - High-level client for dataset operations. - - This is the main interface users interact with. - """ - - def __init__( - self, - credential_fetcher: Callable[[], UserDataCredentials] | None = None, - cache_dir: Path | None = None, - ): - """ - Initialize dataset client. - - Args: - credential_fetcher: Function to get S3 credentials - cache_dir: Custom cache directory - """ - self._credential_fetcher = credential_fetcher - self.cache_dir = cache_dir or DATASETS_CACHE - - def list_cached_datasets(self) -> list[DatasetMetadata]: - """ - List all datasets in cache. - - Returns: - List of metadata for cached datasets - """ - datasets = [] - - for org_dir in self.cache_dir.iterdir(): - if not org_dir.is_dir(): - continue - - for dataset_dir in org_dir.iterdir(): - if not dataset_dir.is_dir(): - continue - - for version_dir in dataset_dir.iterdir(): - if not version_dir.is_dir(): - continue - - metadata_path = version_dir / METADATA_FILE - if metadata_path.exists(): - try: - metadata = DatasetMetadata.load(metadata_path) - datasets.append(metadata) - except Exception as e: - logger.warning(f"Failed to load metadata from {metadata_path}: {e}") - - return datasets - - def delete_dataset( - self, - uri: str, - version: str | None = None, - *, - cache_only: bool = True, - ) -> bool: - """ - Remove dataset from cache and optionally remote. - - Args: - uri: Dataset URI - version: Specific version (removes all if None) - cache_only: If True, only remove from cache - - Returns: - True if removed successfully - """ - parsed_uri, parsed_version = self.parse_uri(uri) - version = version or parsed_version - - logger.info(f"Removing dataset {parsed_uri}@{version or 'all'}") - - # Remove from cache - if version: - cache_path = self.get_dataset_path(uri, version) - if cache_path.exists(): - logger.info(f"Removing dataset from cache: {cache_path}") - shutil.rmtree(cache_path) - return True - return False - - if not cache_only and self._credential_fetcher: - # Remove from remote (implementation depends on backend) - logger.warning("Remote deletion not implemented") - - return True - - def list_datasets( - self, - *, - remote: bool = False, - cache_only: bool = True, - ) -> list[any]: - """ - List available datasets. - - Args: - remote: If True, list from remote storage - cache_only: If True, only list cached datasets - - Returns: - List of dataset metadata - """ - if cache_only or not remote: - return self.list_cached_datasets() - - # Remote listing would require API implementation - logger.warning("Remote dataset listing not implemented") - return [] - - def search_datasets( - self, - name_pattern: str | None = None, - tags: list[str] | None = None, - version: str | None = None, - *, - remote: bool = False, - ) -> list[any]: - """ - Search for datasets matching criteria. - - Args: - name_pattern: Pattern to match in name - tags: Required tags - version: Specific version - remote: Search remote storage - - Returns: - List of matching dataset metadata - """ - if remote and not self._credential_fetcher: - logger.warning("Remote search requires credential fetcher") - remote = False - - all_datasets = self.list_cached_datasets() - - results = [ds for ds in all_datasets if ds.matches_filter(name_pattern, tags, version)] - - if remote: - # Remote search would require API implementation - logger.warning("Remote dataset search not implemented") - - return results - - def get_cache_info(self) -> dict[str, Any]: - """ - Get information about the cache. - - Returns: - Dictionary with cache statistics - - Examples: - >>> client = DatasetStorage() - >>> info = client.get_cache_info() - >>> print(f"Cache size: {info['size_gb']:.2f} GB") - """ - size_bytes = self.get_cache_size() - datasets = self.list_cached_datasets() - - return { - "cache_dir": str(self.cache_dir), - "size_bytes": size_bytes, - "size_mb": size_bytes / (1024 * 1024), - "size_gb": size_bytes / (1024 * 1024 * 1024), - "dataset_count": len(datasets), - "datasets": [ - { - "name": ds.name, - "version": ds.version, - "uri": ds.uri, - } - for ds in datasets - ], - } - - def get_cache_size(self) -> int: - """ - Get total size of cache in bytes. - - Returns: - Cache size in bytes - """ - total_size = 0 - for path in self.cache_dir.rglob("*"): - if path.is_file(): - total_size += path.stat().st_size - return total_size diff --git a/dreadnode/constants.py b/dreadnode/constants.py index 47b2ccaa..7c10fd5f 100644 --- a/dreadnode/constants.py +++ b/dreadnode/constants.py @@ -35,6 +35,8 @@ DEFAULT_WORKSPACE_NAME = "Personal Workspace" # default project name DEFAULT_PROJECT_NAME = "Default" +# default project key +DEFAULT_PROJECT_KEY = "default" # diff --git a/dreadnode/dataset.py b/dreadnode/dataset.py index 6207cb04..d537f8de 100644 --- a/dreadnode/dataset.py +++ b/dreadnode/dataset.py @@ -8,6 +8,7 @@ from pyarrow.fs import FileSystem from dreadnode.constants import MANIFEST_FILE, METADATA_FILE +from dreadnode.logging_ import print_info from dreadnode.storage.datasets.manager import DatasetManager from dreadnode.storage.datasets.manifest import DatasetManifest, create_manifest from dreadnode.storage.datasets.metadata import DatasetMetadata, VersionInfo @@ -39,7 +40,7 @@ def __init__( self.ds = ds.to_table() if not metadata: - print("[*] No metadata provided, check your dataset!") + print_info("[*] No metadata provided, check your dataset!") def update_metadata(self, metadata: DatasetMetadata) -> None: self.metadata = metadata @@ -110,7 +111,7 @@ def save_dataset( path: str, fs: FileSystem, *, - to_cache: bool = False, + create_dir: bool = False, **kwargs: Any, ) -> None: """ @@ -123,53 +124,118 @@ def save_dataset( format="parquet", filesystem=fs, existing_data_behavior="overwrite_or_ignore", - create_dir=to_cache, + create_dir=create_dir, **kwargs, ) -def save_dataset( +def _persist_dataset( dataset: Dataset, + path_str: str, *, - to_cache: bool = False, + create_dir: bool = False, fsm: DatasetManager, **kwargs: Any, ) -> None: - if to_cache: - path_str = fsm.get_cache_save_uri(metadata=dataset.metadata) - print("[*] Saving dataset to local cache") - else: - path_str = fsm.get_remote_save_uri(metadata=dataset.metadata) - print("[*] Saving dataset to remote storage") + """Persists a dataset to the given path. + Args: + dataset: The Dataset to persist. + path_str: The path to persist the dataset to. + create_dir: Whether to create the directory if it doesn't exist. Defaults to False. + fsm: The DatasetManager instance. + kwargs: Additional arguments to pass to pyarrow.dataset.write_dataset. + """ fs, base_path = fsm.get_fs_and_path(path_str) - try: - fsm.ensure_dir(fs, base_path) + fsm.ensure_dir(fs, base_path) + data_path = f"{base_path}/data" + fsm.ensure_dir(fs, data_path) - dataset.save_dataset(path=f"{base_path}/data", fs=fs, to_cache=to_cache, **kwargs) + dataset.save_dataset(path=data_path, fs=fs, create_dir=create_dir, **kwargs) - dataset.save_metadata(path=f"{base_path}/{METADATA_FILE}", fs=fs) + dataset.save_metadata(path=f"{base_path}/{METADATA_FILE}", fs=fs) - manifest = create_manifest( - path=base_path, - version=dataset.metadata.version, - previous_manifest=dataset.manifest if dataset.manifest else None, - fs=fs, - ) - manifest.save(f"{base_path}/{MANIFEST_FILE}", fs=fs) - except Exception as e: - # if remote save failed, notify API - if not to_cache: - fsm.remote_save_complete(success=False, dataset_id=dataset.metadata.id) - print(f"[!] Failed to save dataset: {e}") - raise + manifest = create_manifest( + path=base_path, + version=dataset.metadata.version, + previous_manifest=dataset.manifest if dataset.manifest else None, + fs=fs, + ) + manifest.save(f"{base_path}/{MANIFEST_FILE}", fs=fs) + + print_info("[+] Saved dataset successfully") - # if remote save succeeded, notify API - if not to_cache: - fsm.remote_save_complete(success=True, dataset_id=dataset.metadata.id) - print("[+] Saved dataset successfully") +def save_dataset_to_disk( + dataset: Dataset, + *, + fsm: DatasetManager, + **kwargs: Any, +) -> None: + """Saves a dataset to local disk cache. + + Args: + dataset: The Dataset to save. + fsm: The DatasetManager instance. + kwargs: Additional arguments to pass to pyarrow.dataset.write_dataset. + + Returns: + None + + """ + path_str = fsm.get_cache_save_uri(metadata=dataset.metadata) + print_info("[*] Saving dataset to local cache") + + _persist_dataset( + dataset=dataset, + path_str=path_str, + create_dir=True, + fsm=fsm, + **kwargs, + ) + + +def push_dataset( + dataset: Dataset, + *, + to_cache: bool = True, + fsm: DatasetManager, + **kwargs: Any, +) -> None: + """Pushes a dataset to remote storage. + + Args: + dataset: The Dataset to push. + to_cache: Whether to save to local cache first. Defaults to True. + fsm: The DatasetManager instance. + kwargs: Additional arguments to pass to pyarrow.dataset.write_dataset. + + Returns: + None + """ + if to_cache: + save_dataset_to_disk( + dataset=dataset, + fsm=fsm, + **kwargs, + ) + + dataset_id, path_str = fsm.get_remote_save_uri(metadata=dataset.metadata) + dataset.metadata.id = dataset_id + print_info("[*] Saving dataset to remote storage") + try: + _persist_dataset( + dataset=dataset, + path_str=path_str, + fsm=fsm, + **kwargs, + ) + fsm.remote_save_complete(complete=True, dataset_id=dataset.metadata.id) + except Exception: + # if remote save failed, remove the record from API + fsm.delete_remote_dataset_record(dataset_id_or_key=dataset.metadata.id) + raise def load_dataset( @@ -203,7 +269,7 @@ def load_dataset( if protocol in ("file", "local", ""): # check cache first if not fsm.check_cache(uri, version): - print("[+] Dataset not found in cache. Loading dataset from local path...") + print_info("[+] Dataset not found in cache. Loading dataset from local path...") # load directly from local path fs, fs_path = fsm.get_fs_and_path(uri) @@ -220,7 +286,7 @@ def load_dataset( return Dataset(ds=dataset, metadata=metadata, materialize=materialize) # if in cache, load from cache - print("[+] Loading dataset from cache...") + print_info("[+] Loading dataset from cache...") # get the filesystem and path fs, fs_path = fsm.get_fs_and_path(uri) @@ -242,7 +308,7 @@ def load_dataset( return Dataset(ds=dataset, materialize=materialize, metadata=metadata, manifest=manifest) # if not local path, and not in cache, load from remote - print("[+] Loading from remote storage...") + print_info("[+] Loading from remote storage...") try: # get remote URI remote_uri = fsm.get_remote_load_uri(uri=strip_protocol(uri), version=version) @@ -264,13 +330,13 @@ def load_dataset( is_valid = manifest.validate(fs_path, fs) if not is_valid: # invalid manifest, sync from remote - print("[!] Remote dataset manifest validation failed.") + print_info("[!] Remote dataset manifest validation failed.") # load dataset dataset = ds.dataset(f"{fs_path}/data", format=format, filesystem=fs, **kwargs) return Dataset(ds=dataset, metadata=metadata, manifest=manifest, materialize=materialize) except Exception as e: - print(f"[!] Failed to load dataset from remote: {e}") + print_info(f"[!] Failed to load dataset from remote: {e}") raise FileNotFoundError(f"[!] Dataset not found: {uri}") diff --git a/dreadnode/main.py b/dreadnode/main.py index af463786..3924389e 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -38,6 +38,7 @@ ) from dreadnode.constants import ( DEFAULT_LOCAL_STORAGE_DIR, + DEFAULT_PROJECT_KEY, DEFAULT_PROJECT_NAME, DEFAULT_SERVER_URL, ENV_API_KEY, @@ -254,7 +255,7 @@ def _resolve_organization(self) -> None: if len(organizations) > 1: # We should not presume to choose an organization - org_list = "\t\n".join([f"- {o.name}" for o in organizations]) + org_list = "\t\n".join([f"- {o.key}" for o in organizations]) raise RuntimeError( f"You are part of multiple organizations. Please specify an organization from:\n{org_list}" ) @@ -360,8 +361,8 @@ def _resolve_project(self) -> None: """ Resolve the project to use based on configuration. - If a project is specified by name and doesn't exist, it will be created. - If no project is specified, it will use or create one named 'default'. + If a project is specified by key and doesn't exist, it will be created. + If no project is specified, it will use or create one with key 'default'. Raises: RuntimeError: If the API client is not initialized. @@ -378,7 +379,7 @@ def _resolve_project(self) -> None: found_project: Project | None = None try: found_project = self._api.get_project( - project_identifier=self.project or DEFAULT_PROJECT_NAME, + project_identifier=self.project or DEFAULT_PROJECT_KEY, workspace_id=self._workspace.id, ) except RuntimeError as e: @@ -391,6 +392,7 @@ def _resolve_project(self) -> None: # create it in the workspace found_project = self._api.create_project( name=self.project or DEFAULT_PROJECT_NAME, + key=self.project or DEFAULT_PROJECT_KEY, workspace_id=self._workspace.id, ) # This is what's used in all of the Traces/Spans/Runs @@ -701,8 +703,11 @@ def initialize(self) -> None: if self._api is not None: api = self._api self._credential_manager = CredentialManager( - credential_fetcher=lambda: api.get_user_data_credentials() + credential_fetcher=lambda: api.get_user_data_credentials( + organization_id=self._organization.id, workspace_id=self._workspace.id + ) ) + self._credential_manager.initialize() self._fs = self._credential_manager.get_filesystem() @@ -725,6 +730,7 @@ def initialize(self) -> None: self._fs_manager = DatasetManager().configure( api=self._api, # type: ignore[return-value] organization=self._organization.key, + organization_id=self._organization.id, ) self._initialized = True @@ -1330,24 +1336,39 @@ def load_dataset( fsm=self._fs_manager, ) - def save_dataset( + def save_dataset_to_disk( self, ds: dataset.Dataset, - *, - to_cache: bool = False, - ) -> str: + ) -> None: + """ + Save a dataset to the local cache. + + Example: + ``` + dreadnode.save_dataset_to_disk(my_dataset) + ``` + """ + + dataset.save_dataset_to_disk( + dataset=ds, + fsm=self._fs_manager, + ) + + def push_dataset( + self, + ds: dataset.Dataset, + ) -> None: """ - Save a dataset to the local cache and optionally to the Dreadnode server. + Push a dataset to the Dreadnode server. Example: ``` - uri = dreadnode.save_dataset(my_dataset) + dreadnode.push_dataset(my_dataset) ``` """ - dataset.save_dataset( + dataset.push_dataset( dataset=ds, - to_cache=to_cache, fsm=self._fs_manager, ) diff --git a/dreadnode/optimization/stop.py b/dreadnode/optimization/stop.py index 49feaccd..cb972885 100644 --- a/dreadnode/optimization/stop.py +++ b/dreadnode/optimization/stop.py @@ -63,10 +63,11 @@ def score_value( """ def stop(trials: list[Trial]) -> bool: # noqa: PLR0911 - if not trials: + finished_trials = [t for t in trials if t.status == "finished"] + if not finished_trials: return False - trial = trials[-1] + trial = finished_trials[-1] value_to_check = trial.scores.get(metric_name) if metric_name else trial.score if value_to_check is None: return False diff --git a/dreadnode/storage/datasets/manager.py b/dreadnode/storage/datasets/manager.py index 71c91f7a..e3dc6a36 100644 --- a/dreadnode/storage/datasets/manager.py +++ b/dreadnode/storage/datasets/manager.py @@ -2,15 +2,16 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any +from uuid import UUID import pyarrow.fs as pafs # The Native FS from pyarrow.fs import FileSystem, FileType from dreadnode.api import ApiClient from dreadnode.api.models import ( + CreateDatasetRequest, DatasetDownloadRequest, - DatasetUploadComplete, - DatasetUploadRequest, + DatasetUploadCompleteRequest, ) from dreadnode.constants import ( DEFAULT_LOCAL_STORAGE_DIR, @@ -20,6 +21,7 @@ METADATA_FILE, ) from dreadnode.logging_ import console as logging_console +from dreadnode.logging_ import print_info from dreadnode.storage.datasets.metadata import DatasetMetadata from dreadnode.util import resolve_endpoint @@ -37,6 +39,7 @@ class DatasetManager: """ organization: str | None = None + organization_id: UUID | None = None _instance: "DatasetManager | None" = None _api: ApiClient | None = None @@ -57,10 +60,12 @@ def configure( cls, api: ApiClient | None = None, organization: str | None = None, + organization_id: UUID | None = None, ) -> "DatasetManager": instance = cls() instance._api = api instance.organization = organization + instance.organization_id = organization_id return instance def metadata_exists(self, path: str) -> bool: @@ -128,7 +133,7 @@ def get_cache_load_uri( return str(dataset_uri / latest) - def get_remote_save_uri(self, metadata: DatasetMetadata) -> str: + def get_remote_save_uri(self, metadata: DatasetMetadata) -> tuple[UUID, str]: """ Constructs the full remote storage URI. Example: dreadnode://datasets/main/my-dataset @@ -137,13 +142,15 @@ def get_remote_save_uri(self, metadata: DatasetMetadata) -> str: if not self._api: raise ValueError("No client configured") - upload_request = DatasetUploadRequest.model_validate(metadata.model_dump()) - - response = self._api.upload_dataset_request(request=upload_request) + upload_request = CreateDatasetRequest( + org_key=metadata.organization, + key=metadata.name, + ) - return response.upload_uri + response = self._api.create_dataset(request=upload_request) + return response.dataset_id, response.user_data_access_response.upload_uri - def remote_save_complete(self, dataset_id: str, *, success: bool) -> None: + def remote_save_complete(self, dataset_id: str, *, complete: bool) -> None: """ Notifies the API that the remote upload is complete. """ @@ -151,9 +158,9 @@ def remote_save_complete(self, dataset_id: str, *, success: bool) -> None: if not self._api: raise ValueError("No client configured") - request = DatasetUploadComplete(id=dataset_id, success=success) + request = DatasetUploadCompleteRequest(dataset_id=dataset_id, complete=complete) - self._api.upload_complete(request=request.model_dump()) + self._api.upload_complete(request=request) def get_remote_load_uri(self, uri: str, version: str | None = "latest") -> str: """ @@ -167,17 +174,19 @@ def get_remote_load_uri(self, uri: str, version: str | None = "latest") -> str: response = self._api.download_dataset(request) - print(f"[*] Download URI: {response.download_uri}") + print_info(f"[*] Download URI: {response.download_uri}") return response.download_uri - def get_s3_config(self) -> dict[str, Any]: + def get_s3_config(self, dataset_id: UUID) -> dict[str, Any]: """ Translates your UserDataCredentials into PyArrow S3 arguments. """ if not self._api: raise ValueError("No client configured") - creds = self._api.get_user_data_credentials() + creds = self._api.get_user_data_credentials( + organization_id=self.organization_id, dataset_id=dataset_id + ) self._credentials_expiry = creds.expiration resolved_endpoint = resolve_endpoint(creds.endpoint) @@ -220,9 +229,13 @@ def get_fs_and_path(self, uri: str) -> tuple[FileSystem, str]: if self._cached_s3_fs is None or self.needs_refresh(): try: - config = self.get_s3_config() + # Try to extract dataset ID from URI which expect is of the form dn:///datasets/ + dataset_id = UUID(path_body.split("/")[-1]) + config = self.get_s3_config(dataset_id=dataset_id) self._cached_s3_fs = pafs.S3FileSystem(**config) - + except ValueError: + logging_console.print(f"[red]Invalid dataset ID in URI: [green]{uri}[/green][/red]") + raise except Exception as e: logging_console.print(f"Auth failed: {e}") raise @@ -244,7 +257,7 @@ def resolve_latest_version(self, uri: str, fs: FileSystem) -> str: ] latest = sorted(versions, reverse=True)[0] - print(f"[*] Resolved latest version {latest}") + print_info(f"[*] Resolved latest version {latest}") return latest def ensure_dir(self, fs: FileSystem, path: str) -> None: @@ -255,3 +268,13 @@ def ensure_dir(self, fs: FileSystem, path: str) -> None: return with contextlib.suppress(OSError): fs.create_dir(path, recursive=True) + + def delete_remote_dataset_record(self, dataset_id_or_key: UUID | str) -> None: + """ + Deletes a remote dataset via the API. + """ + + if not self._api: + raise ValueError("No client configured") + + self._api.delete_dataset(dataset_id_or_key=dataset_id_or_key) diff --git a/dreadnode/storage/datasets/manifest.py b/dreadnode/storage/datasets/manifest.py index 64dc9237..ef64fd27 100644 --- a/dreadnode/storage/datasets/manifest.py +++ b/dreadnode/storage/datasets/manifest.py @@ -132,7 +132,8 @@ def compute_file_hash( ) -> str: try: with fs.open_input_stream(file_path) as f: - return hashlib.file_digest(f, algorithm) + digest = hashlib.file_digest(f, algorithm) + return digest.hexdigest() except Exception as e: logging_console.print(f"Failed to hash {file_path}: {e}") return "" diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index a5d6e9f7..6ba1f64b 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -227,8 +227,10 @@ def duration(self) -> float: """Get the duration of the span in seconds.""" if self._span is None: return 0.0 - end_time = self.end_time or time.time() - return (end_time - self.start_time) if self.start_time else 0.0 + end_time = self.end_time or time.time_ns() + if not self.start_time: + return 0.0 + return (end_time - self.start_time) / 1e9 def set_tags(self, tags: t.Sequence[str]) -> None: tags = [tags] if isinstance(tags, str) else list(tags)