From ec0aa885268b97335f985d9614d2646c35ba346c Mon Sep 17 00:00:00 2001 From: Raja Sekhar Rao Dheekonda Date: Wed, 30 Jul 2025 23:48:25 -0700 Subject: [PATCH 1/3] Fix auto refresh s3 access token logic --- dreadnode/artifact/storage.py | 75 +++++++++---------- dreadnode/constants.py | 2 +- dreadnode/credential_manager.py | 128 ++++++++++++++++++++++++++++++++ dreadnode/main.py | 74 +++--------------- dreadnode/storage_utils.py | 37 --------- dreadnode/tracing/span.py | 70 +++++++---------- pyproject.toml | 4 +- 7 files changed, 204 insertions(+), 186 deletions(-) create mode 100644 dreadnode/credential_manager.py delete mode 100644 dreadnode/storage_utils.py diff --git a/dreadnode/artifact/storage.py b/dreadnode/artifact/storage.py index f6fa180a..65b46866 100644 --- a/dreadnode/artifact/storage.py +++ b/dreadnode/artifact/storage.py @@ -4,12 +4,9 @@ """ import hashlib -import typing as t from pathlib import Path -import fsspec # type: ignore[import-untyped] - -from dreadnode.storage_utils import with_credential_refresh +from dreadnode.credential_manager import CredentialManager from dreadnode.util import logger CHUNK_SIZE = 8 * 1024 * 1024 # 8MB @@ -24,27 +21,15 @@ class ArtifactStorage: - Batch uploads for directories handled by fsspec """ - def __init__( - self, - file_system: fsspec.AbstractFileSystem, - credential_refresher: t.Callable[[], bool] | None = None, - ): + def __init__(self, credential_manager: CredentialManager): """ - Initialize artifact storage with a file system and prefix path. + Initialize artifact storage with credential manager. Args: - file_system: FSSpec-compatible file system - credential_refresher: Optional function to refresh credentials when it's about to expire + credential_manager: Optional credential manager for S3 operations """ - self._file_system = file_system - self._credential_refresher = credential_refresher - - def _refresh_credentials_if_needed(self) -> None: - """Refresh credentials if refresher is available.""" - if self._credential_refresher: - self._credential_refresher() + self._credential_manager: CredentialManager = credential_manager - @with_credential_refresh def store_file(self, file_path: Path, target_key: str) -> str: """ Store a file in the storage system, using multipart upload for large files. @@ -56,13 +41,19 @@ def store_file(self, file_path: Path, target_key: str) -> str: Returns: Full URI with protocol to the stored file """ - if not self._file_system.exists(target_key): - self._file_system.put(str(file_path), target_key) - logger.debug("Artifact successfully stored at %s", target_key) - else: - logger.debug("Artifact already exists at %s, skipping upload.", target_key) - return str(self._file_system.unstrip_protocol(target_key)) + def store_operation() -> str: + filesystem = self._credential_manager.get_filesystem() + + if not filesystem.exists(target_key): + filesystem.put(str(file_path), target_key) + logger.info("Artifact successfully stored at %s", target_key) + else: + logger.info("Artifact already exists at %s, skipping upload.", target_key) + + return str(filesystem.unstrip_protocol(target_key)) + + return self._credential_manager.execute_with_retry(store_operation) def batch_upload_files(self, source_paths: list[str], target_paths: list[str]) -> list[str]: """ @@ -78,23 +69,26 @@ def batch_upload_files(self, source_paths: list[str], target_paths: list[str]) - if not source_paths: return [] - logger.debug("Batch uploading %d files", len(source_paths)) + def batch_upload_operation() -> list[str]: + filesystem = self._credential_manager.get_filesystem() - srcs = [] - dsts = [] + srcs = [] + dsts = [] - for src, dst in zip(source_paths, target_paths, strict=False): - if not self._file_system.exists(dst): - srcs.append(src) - dsts.append(dst) + for src, dst in zip(source_paths, target_paths, strict=False): + if not filesystem.exists(dst): + srcs.append(src) + dsts.append(dst) - if srcs: - self._file_system.put(srcs, dsts) - logger.debug("Batch upload completed for %d files", len(srcs)) - else: - logger.debug("All files already exist, skipping upload") + if srcs: + filesystem.put(srcs, dsts) + logger.info("Batch upload completed for %d files", len(srcs)) + else: + logger.info("All files already exist, skipping upload") - return [str(self._file_system.unstrip_protocol(target)) for target in target_paths] + return [str(filesystem.unstrip_protocol(target)) for target in target_paths] + + return self._credential_manager.execute_with_retry(batch_upload_operation) def compute_file_hash(self, file_path: Path, stream_threshold_mb: int = 10) -> str: """ @@ -107,8 +101,9 @@ def compute_file_hash(self, file_path: Path, stream_threshold_mb: int = 10) -> s Returns: First 16 chars of SHA1 hash """ + file_size = file_path.stat().st_size - stream_threshold = stream_threshold_mb * 1024 * 1024 # Convert MB to bytes + stream_threshold = stream_threshold_mb * 1024 * 1024 sha1 = hashlib.sha1() # noqa: S324 # nosec diff --git a/dreadnode/constants.py b/dreadnode/constants.py index 3212e404..f2888347 100644 --- a/dreadnode/constants.py +++ b/dreadnode/constants.py @@ -58,4 +58,4 @@ ) # Default values for the file system credential management -FS_CREDENTIAL_REFRESH_BUFFER = 300 # 5 minutes in seconds +FS_CREDENTIAL_REFRESH_BUFFER = 900 # 15 minutes in seconds diff --git a/dreadnode/credential_manager.py b/dreadnode/credential_manager.py new file mode 100644 index 00000000..d90405c7 --- /dev/null +++ b/dreadnode/credential_manager.py @@ -0,0 +1,128 @@ +import time +from collections.abc import Callable +from datetime import datetime, timezone +from typing import TYPE_CHECKING, TypeVar + +from botocore.exceptions import ClientError +from s3fs import S3FileSystem + +from dreadnode.constants import FS_CREDENTIAL_REFRESH_BUFFER +from dreadnode.util import logger, resolve_endpoint + +if TYPE_CHECKING: + from dreadnode.api.models import UserDataCredentials + + +T = TypeVar("T") + + +class CredentialManager: + """Simple credential manager that handles S3 credential refresh automatically.""" + + def __init__(self, credential_fetcher: Callable[[], "UserDataCredentials"]): + """ + Initialize credential manager. + + Args: + credential_fetcher: Function that returns new UserDataCredentials when called + """ + self._credential_fetcher = credential_fetcher + self._credentials: UserDataCredentials | None = None + self._credentials_expiry: datetime | None = None + self._filesystem = None + self._prefix = "" + + def initialize(self) -> None: + """Initialize with fresh credentials.""" + self._refresh_credentials() + + def get_filesystem(self) -> S3FileSystem: + """Get current filesystem, refreshing credentials if needed.""" + if self._needs_refresh(): + self._refresh_credentials() + return self._filesystem + + def get_prefix(self) -> str: + """Get current prefix path.""" + return self._prefix + + def _needs_refresh(self) -> bool: + """Check if credentials need refreshing.""" + if not self._credentials_expiry or not self._filesystem: + return True + + now = datetime.now(timezone.utc) + time_left = (self._credentials_expiry - now).total_seconds() + return time_left < FS_CREDENTIAL_REFRESH_BUFFER + + def _refresh_credentials(self) -> None: + """Refresh credentials and create new filesystem.""" + try: + logger.info("Refreshing storage credentials") + new_credentials = self._credential_fetcher() + resolved_endpoint = resolve_endpoint(new_credentials.endpoint) + + new_filesystem = S3FileSystem( + key=new_credentials.access_key_id, + secret=new_credentials.secret_access_key, + token=new_credentials.session_token, + client_kwargs={ + "endpoint_url": resolved_endpoint, + "region_name": new_credentials.region, + }, + use_listings_cache=False, + listings_expiry_time=0, + skip_instance_cache=True, + ) + + # Update internal state + self._credentials = new_credentials + self._credentials_expiry = new_credentials.expiration + self._filesystem = new_filesystem + self._prefix = f"{new_credentials.bucket}/{new_credentials.prefix}/" + + logger.info("Storage credentials refreshed, valid until %s", self._credentials_expiry) + + except Exception: + logger.exception("Failed to refresh storage credentials") + raise + + def execute_with_retry(self, operation: Callable[[], T], max_retries: int = 3) -> T: + """ + Execute an operation with automatic credential refresh on auth errors. + + Args: + operation: Function to execute (should use self.get_filesystem()) + max_retries: Maximum number of retry attempts + + Returns: + Result of the operation + """ + for attempt in range(max_retries): + try: + return operation() + except ClientError as e: # noqa: PERF203 + error_code = e.response.get("Error", {}).get("Code", "") + if error_code in ["ExpiredToken", "InvalidAccessKeyId", "SignatureDoesNotMatch"]: + logger.info( + "Credential error on attempt %d/%d, refreshing...", attempt + 1, max_retries + ) + + try: + self._refresh_credentials() + except Exception: + logger.exception("Failed to refresh credentials") + if attempt == max_retries - 1: + raise + + if attempt < max_retries - 1: + time.sleep(attempt + 1) + continue + else: + raise + except Exception: + raise + + raise RuntimeError( + f"Operation failed after {max_retries} attempts due to credential issues" + ) diff --git a/dreadnode/main.py b/dreadnode/main.py index f4c84ccf..f4c8b93a 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -21,7 +21,6 @@ from opentelemetry.exporter.otlp.proto.http import Compression from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.trace.export import BatchSpanProcessor -from s3fs import S3FileSystem # type: ignore [import-untyped] from dreadnode.api.client import ApiClient from dreadnode.config import UserConfig @@ -35,8 +34,8 @@ ENV_PROJECT, ENV_SERVER, ENV_SERVER_URL, - FS_CREDENTIAL_REFRESH_BUFFER, ) +from dreadnode.credential_manager import CredentialManager from dreadnode.metric import ( Metric, MetricAggMode, @@ -65,7 +64,7 @@ Inherited, JsonValue, ) -from dreadnode.util import clean_str, handle_internal_errors, logger, resolve_endpoint +from dreadnode.util import clean_str, handle_internal_errors from dreadnode.version import VERSION if t.TYPE_CHECKING: @@ -74,8 +73,6 @@ from opentelemetry.sdk.trace import SpanProcessor from opentelemetry.trace import Tracer - from dreadnode.api.models import UserDataCredentials - ToObject = t.Literal["task-or-run", "run"] @@ -132,7 +129,7 @@ def __init__( self.otel_scope = otel_scope self._api: ApiClient | None = None - + self._credential_manager: CredentialManager | None = None self._logfire = logfire.DEFAULT_LOGFIRE_INSTANCE self._logfire.config.ignore_no_config = True @@ -140,8 +137,6 @@ def __init__( self._fs_prefix: str = ".dreadnode/storage/" self._initialized = False - self._credentials: UserDataCredentials | None = None - self._credentials_expiry: datetime | None = None def _get_profile_server(self, profile: str | None = None) -> str | None: with contextlib.suppress(Exception): @@ -352,19 +347,13 @@ def initialize(self) -> None: # ) # ) # ) - self._credentials = self._api.get_user_data_credentials() - self._credentials_expiry = self._credentials.expiration - resolved_endpoint = resolve_endpoint(self._credentials.endpoint) - self._fs = S3FileSystem( - key=self._credentials.access_key_id, - secret=self._credentials.secret_access_key, - token=self._credentials.session_token, - client_kwargs={ - "endpoint_url": resolved_endpoint, - "region_name": self._credentials.region, - }, + self._credential_manager = CredentialManager( + credential_fetcher=lambda: self._api.get_user_data_credentials() ) - self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/" + self._credential_manager.initialize() + + self._fs = self._credential_manager.get_filesystem() + self._fs_prefix = self._credential_manager.get_prefix() self._logfire = logfire.configure( local=not self.is_default, @@ -411,43 +400,6 @@ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClie return self._api - def _refresh_storage_credentials(self) -> bool: - """Refresh storage credentials if they are about to expire.""" - if not self._api or not self._credentials: - return False - - now = datetime.now(timezone.utc) - - if ( - self._credentials_expiry is None - or (self._credentials_expiry - now).total_seconds() < FS_CREDENTIAL_REFRESH_BUFFER - ): - try: - logger.info("Refreshing storage credentials") - self._credentials = self._api.get_user_data_credentials() - self._credentials_expiry = self._credentials.expiration - - resolved_endpoint = resolve_endpoint(self._credentials.endpoint) - self._fs = S3FileSystem( - key=self._credentials.access_key_id, - secret=self._credentials.secret_access_key, - token=self._credentials.session_token, - client_kwargs={ - "endpoint_url": resolved_endpoint, - "region_name": self._credentials.region, - }, - ) - logger.info( - f"Storage credentials refreshed, valid until {self._credentials_expiry}" - ) - return True # noqa: TRY300 - - except Exception as e: # noqa: BLE001 - logger.error(f"Failed to refresh storage credentials: {e}") - return False - - return True - def _get_tracer(self, *, is_span_tracer: bool = True) -> "Tracer": return self._logfire._tracer_provider.get_tracer( # noqa: SLF001 self.otel_scope, @@ -817,10 +769,8 @@ def run( tracer=self._get_tracer(), params=params, tags=tags, - file_system=self._fs, - prefix_path=self._fs_prefix, + credential_manager=self._credential_manager, # type: ignore[arg-type] autolog=autolog, - credential_refresher=self._refresh_storage_credentials if self._credentials else None, ) def get_run_context(self) -> RunContext: @@ -865,9 +815,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan: return RunSpan.from_context( context=run_context, tracer=self._get_tracer(), - file_system=self._fs, - prefix_path=self._fs_prefix, - credential_refresher=self._refresh_storage_credentials if self._credentials else None, + credential_manager=self._credential_manager, # type: ignore[arg-type] ) def tag(self, *tag: str, to: ToObject = "task-or-run") -> None: diff --git a/dreadnode/storage_utils.py b/dreadnode/storage_utils.py deleted file mode 100644 index 9599a238..00000000 --- a/dreadnode/storage_utils.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import typing as t - -from dreadnode.util import logger - - -def with_credential_refresh(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - """Decorator that automatically handles credential refresh on storage errors.""" - - @functools.wraps(func) - def wrapper(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: - # Try to refresh credentials before operation - if hasattr(self, "_refresh_credentials_if_needed"): - self._refresh_credentials_if_needed() - - try: - return func(self, *args, **kwargs) - except Exception as e: - error_str = str(e) - if any( - error in error_str - for error in [ - "ExpiredToken", - "TokenRefreshRequired", - "InvalidAccessKeyId", - "The Access Key Id you provided does not exist", - ] - ): - logger.info("Storage credential error, forcing refresh and retrying") - - if hasattr(self, "_refresh_credentials_if_needed"): - self._refresh_credentials_if_needed() - - return func(self, *args, **kwargs) - raise - - return wrapper diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index ee8ec07b..7d605e7e 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -9,7 +9,6 @@ from pathlib import Path import typing_extensions as te -from fsspec import AbstractFileSystem # type: ignore [import-untyped] from logfire._internal.json_encoder import logfire_json_dumps as json_dumps from logfire._internal.json_schema import ( JsonSchemaProperties, @@ -33,10 +32,10 @@ from dreadnode.artifact.tree_builder import ArtifactTreeBuilder, DirectoryNode from dreadnode.constants import DEFAULT_MAX_INLINE_OBJECT_BYTES from dreadnode.convert import run_span_to_graph +from dreadnode.credential_manager import CredentialManager from dreadnode.metric import Metric, MetricAggMode, MetricsDict from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal from dreadnode.serialization import Serialized, serialize -from dreadnode.storage_utils import with_credential_refresh from dreadnode.tracing.constants import ( EVENT_ATTRIBUTE_LINK_HASH, EVENT_ATTRIBUTE_OBJECT_HASH, @@ -73,6 +72,7 @@ if t.TYPE_CHECKING: import networkx as nx # type: ignore [import-untyped] + logger = logging.getLogger(__name__) R = t.TypeVar("R") @@ -355,8 +355,7 @@ def __init__( name: str, project: str, tracer: Tracer, - file_system: AbstractFileSystem, - prefix_path: str, + credential_manager: CredentialManager, *, attributes: AnyDict | None = None, params: AnyDict | None = None, @@ -366,7 +365,6 @@ def __init__( update_frequency: int = 5, run_id: str | ULID | None = None, type: SpanType = "run", - credential_refresher: t.Callable[[], bool] | None = None, ) -> None: self.autolog = autolog self.project = project @@ -377,14 +375,16 @@ def __init__( self._object_schemas: dict[str, JsonDict] = {} self._inputs: list[ObjectRef] = [] self._outputs: list[ObjectRef] = [] - self._artifact_storage = ArtifactStorage( - file_system=file_system, credential_refresher=credential_refresher - ) + + # Credential manager for S3 operations + self._credential_manager = credential_manager + + # Initialize artifact components + self._artifact_storage = ArtifactStorage(credential_manager=credential_manager) self._artifacts: list[DirectoryNode] = [] self._artifact_merger = ArtifactMerger() self._artifact_tree_builder = ArtifactTreeBuilder( - storage=self._artifact_storage, - prefix_path=prefix_path, + storage=self._artifact_storage, prefix_path=self._credential_manager.get_prefix() ) # Update mechanics @@ -397,12 +397,9 @@ def __init__( self._pending_objects = deepcopy(self._objects) self._pending_object_schemas = deepcopy(self._object_schemas) - self._context_token: Token[RunSpan | None] | None = None # contextvars context - self._remote_context: dict[str, str] | None = None # remote run trace context + self._context_token: Token[RunSpan | None] | None = None + self._remote_context: dict[str, str] | None = None self._remote_token: object | None = None - self._file_system = file_system - self._prefix_path = prefix_path - self._tasks: list[TaskSpan[t.Any]] = [] attributes = { @@ -410,7 +407,7 @@ def __init__( SPAN_ATTRIBUTE_PROJECT: project, **(attributes or {}), } - self._credential_refresher = credential_refresher + super().__init__(name, tracer, attributes=attributes, type=type, tags=tags) @classmethod @@ -418,24 +415,19 @@ def from_context( cls, context: RunContext, tracer: Tracer, - file_system: AbstractFileSystem, - prefix_path: str, - credential_refresher: t.Callable[[], bool] | None = None, + credential_manager: CredentialManager, ) -> "RunSpan": self = RunSpan( name=f"run.{context['run_id']}.fragment", project=context["project"], attributes={}, tracer=tracer, - file_system=file_system, - prefix_path=prefix_path, type="run_fragment", run_id=context["run_id"], - credential_refresher=credential_refresher, + credential_manager=credential_manager, ) self._remote_context = context["trace_context"] - return self def __enter__(self) -> te.Self: @@ -507,10 +499,6 @@ def __exit__( if self._context_token is not None: current_run_span.reset(self._context_token) - def _refresh_credentials_if_needed(self) -> None: - if self._credential_refresher: - self._credential_refresher() - def push_update(self, *, force: bool = False) -> None: if self._span is None: return @@ -615,24 +603,19 @@ def log_object( return composite_hash - @with_credential_refresh - def _store_file_by_hash(self, data: bytes, full_path: str) -> str: - """ - Writes data to the given full_path in the object store if it doesn't already exist. + def _store_file_by_hash(self, data_bytes: bytes, full_path: str) -> str: + """Store file with automatic credential refresh.""" - Args: - data: Content to write. - full_path: The path in the object store (e.g., S3 key or local path). + def store_operation() -> str: + filesystem = self._credential_manager.get_filesystem() - Returns: - The unstrip_protocol version of the full path (for object store URI). - """ - if not self._file_system.exists(full_path): - logger.debug("Storing new object at: %s", full_path) - with self._file_system.open(full_path, "wb") as f: - f.write(data) + if not filesystem.exists(full_path): + with filesystem.open(full_path, "wb") as f: + f.write(data_bytes) + + return str(filesystem.unstrip_protocol(full_path)) - return str(self._file_system.unstrip_protocol(full_path)) + return self._credential_manager.execute_with_retry(store_operation) def _create_object_by_hash(self, serialized: Serialized, object_hash: str) -> Object: """Create an ObjectVal or ObjectUri depending on size with a specific hash.""" @@ -652,7 +635,8 @@ def _create_object_by_hash(self, serialized: Serialized, object_hash: str) -> Ob # Offload to file system (e.g., S3) # For storage efficiency, still use just the data_hash for the file path # This ensures we don't duplicate storage for the same data - full_path = f"{self._prefix_path.rstrip('/')}/{data_hash}" + prefix = self._credential_manager.get_prefix() + full_path = f"{prefix.rstrip('/')}/{data_hash}" object_uri = self._store_file_by_hash(data_bytes, full_path) return ObjectUri( diff --git a/pyproject.toml b/pyproject.toml index a70c9d94..620acfb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,12 @@ [project] name = "dreadnode" -version = "1.13.2" +version = "1.13.3" description = "Dreadnode SDK" requires-python = ">=3.10,<3.14" [tool.poetry] name = "dreadnode" -version = "1.13.0" +version = "1.13.3" description = "Dreadnode SDK" authors = ["Nick Landers "] repository = "https://github.com/dreadnode/sdk" From 10c6ae9194990e90b77ae69c471502d60d05442d Mon Sep 17 00:00:00 2001 From: Raja Sekhar Rao Dheekonda Date: Thu, 31 Jul 2025 00:08:23 -0700 Subject: [PATCH 2/3] Fix mypy errors --- dreadnode/credential_manager.py | 5 +++-- dreadnode/main.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/dreadnode/credential_manager.py b/dreadnode/credential_manager.py index d90405c7..07f90c2d 100644 --- a/dreadnode/credential_manager.py +++ b/dreadnode/credential_manager.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, TypeVar from botocore.exceptions import ClientError -from s3fs import S3FileSystem +from s3fs import S3FileSystem # type: ignore[import-untyped] from dreadnode.constants import FS_CREDENTIAL_REFRESH_BUFFER from dreadnode.util import logger, resolve_endpoint @@ -29,7 +29,7 @@ def __init__(self, credential_fetcher: Callable[[], "UserDataCredentials"]): self._credential_fetcher = credential_fetcher self._credentials: UserDataCredentials | None = None self._credentials_expiry: datetime | None = None - self._filesystem = None + self._filesystem: S3FileSystem | None = None self._prefix = "" def initialize(self) -> None: @@ -40,6 +40,7 @@ def get_filesystem(self) -> S3FileSystem: """Get current filesystem, refreshing credentials if needed.""" if self._needs_refresh(): self._refresh_credentials() + assert self._filesystem is not None # noqa: S101 return self._filesystem def get_prefix(self) -> str: diff --git a/dreadnode/main.py b/dreadnode/main.py index f4c8b93a..ebcfbe87 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -347,13 +347,15 @@ def initialize(self) -> None: # ) # ) # ) - self._credential_manager = CredentialManager( - credential_fetcher=lambda: self._api.get_user_data_credentials() - ) - self._credential_manager.initialize() + if self._api is not None: + api = self._api + self._credential_manager = CredentialManager( + credential_fetcher=lambda: api.get_user_data_credentials() + ) + self._credential_manager.initialize() - self._fs = self._credential_manager.get_filesystem() - self._fs_prefix = self._credential_manager.get_prefix() + self._fs = self._credential_manager.get_filesystem() + self._fs_prefix = self._credential_manager.get_prefix() self._logfire = logfire.configure( local=not self.is_default, From b9c352098e7273c4cec218f57a0c9b0b16d473f9 Mon Sep 17 00:00:00 2001 From: Raja Sekhar Rao Dheekonda Date: Thu, 31 Jul 2025 00:09:08 -0700 Subject: [PATCH 3/3] Docs update --- docs/sdk/artifact.mdx | 83 ++++++++++++++++++++----------------------- docs/sdk/main.mdx | 34 +++++++----------- 2 files changed, 51 insertions(+), 66 deletions(-) diff --git a/docs/sdk/artifact.mdx b/docs/sdk/artifact.mdx index 7c6495f1..25adc4d2 100644 --- a/docs/sdk/artifact.mdx +++ b/docs/sdk/artifact.mdx @@ -244,10 +244,7 @@ ArtifactStorage --------------- ```python -ArtifactStorage( - file_system: AbstractFileSystem, - credential_refresher: Callable[[], bool] | None = None, -) +ArtifactStorage(credential_manager: CredentialManager) ``` Storage for artifacts with efficient handling of large files and directories. @@ -256,35 +253,24 @@ Supports: - Content-based deduplication using SHA1 hashing - Batch uploads for directories handled by fsspec -Initialize artifact storage with a file system and prefix path. +Initialize artifact storage with credential manager. **Parameters:** -* **`file_system`** - (`AbstractFileSystem`) - –FSSpec-compatible file system -* **`credential_refresher`** - (`Callable[[], bool] | None`, default: - `None` - ) - –Optional function to refresh credentials when it's about to expire +* **`credential_manager`** + (`CredentialManager`) + –Optional credential manager for S3 operations ```python -def __init__( - self, - file_system: fsspec.AbstractFileSystem, - credential_refresher: t.Callable[[], bool] | None = None, -): +def __init__(self, credential_manager: CredentialManager): """ - Initialize artifact storage with a file system and prefix path. + Initialize artifact storage with credential manager. Args: - file_system: FSSpec-compatible file system - credential_refresher: Optional function to refresh credentials when it's about to expire + credential_manager: Optional credential manager for S3 operations """ - self._file_system = file_system - self._credential_refresher = credential_refresher + self._credential_manager: CredentialManager = credential_manager ``` @@ -330,23 +316,26 @@ def batch_upload_files(self, source_paths: list[str], target_paths: list[str]) - if not source_paths: return [] - logger.debug("Batch uploading %d files", len(source_paths)) + def batch_upload_operation() -> list[str]: + filesystem = self._credential_manager.get_filesystem() - srcs = [] - dsts = [] + srcs = [] + dsts = [] - for src, dst in zip(source_paths, target_paths, strict=False): - if not self._file_system.exists(dst): - srcs.append(src) - dsts.append(dst) + for src, dst in zip(source_paths, target_paths, strict=False): + if not filesystem.exists(dst): + srcs.append(src) + dsts.append(dst) - if srcs: - self._file_system.put(srcs, dsts) - logger.debug("Batch upload completed for %d files", len(srcs)) - else: - logger.debug("All files already exist, skipping upload") + if srcs: + filesystem.put(srcs, dsts) + logger.info("Batch upload completed for %d files", len(srcs)) + else: + logger.info("All files already exist, skipping upload") - return [str(self._file_system.unstrip_protocol(target)) for target in target_paths] + return [str(filesystem.unstrip_protocol(target)) for target in target_paths] + + return self._credential_manager.execute_with_retry(batch_upload_operation) ``` @@ -391,8 +380,9 @@ def compute_file_hash(self, file_path: Path, stream_threshold_mb: int = 10) -> s Returns: First 16 chars of SHA1 hash """ + file_size = file_path.stat().st_size - stream_threshold = stream_threshold_mb * 1024 * 1024 # Convert MB to bytes + stream_threshold = stream_threshold_mb * 1024 * 1024 sha1 = hashlib.sha1() # noqa: S324 # nosec @@ -478,7 +468,6 @@ Store a file in the storage system, using multipart upload for large files. ```python -@with_credential_refresh def store_file(self, file_path: Path, target_key: str) -> str: """ Store a file in the storage system, using multipart upload for large files. @@ -490,13 +479,19 @@ def store_file(self, file_path: Path, target_key: str) -> str: Returns: Full URI with protocol to the stored file """ - if not self._file_system.exists(target_key): - self._file_system.put(str(file_path), target_key) - logger.debug("Artifact successfully stored at %s", target_key) - else: - logger.debug("Artifact already exists at %s, skipping upload.", target_key) - return str(self._file_system.unstrip_protocol(target_key)) + def store_operation() -> str: + filesystem = self._credential_manager.get_filesystem() + + if not filesystem.exists(target_key): + filesystem.put(str(file_path), target_key) + logger.info("Artifact successfully stored at %s", target_key) + else: + logger.info("Artifact already exists at %s, skipping upload.", target_key) + + return str(filesystem.unstrip_protocol(target_key)) + + return self._credential_manager.execute_with_retry(store_operation) ``` diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 5af97c3e..c59e8d74 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -57,7 +57,7 @@ def __init__( self.otel_scope = otel_scope self._api: ApiClient | None = None - + self._credential_manager: CredentialManager | None = None self._logfire = logfire.DEFAULT_LOGFIRE_INSTANCE self._logfire.config.ignore_no_config = True @@ -65,8 +65,6 @@ def __init__( self._fs_prefix: str = ".dreadnode/storage/" self._initialized = False - self._credentials: UserDataCredentials | None = None - self._credentials_expiry: datetime | None = None ``` @@ -380,9 +378,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan: return RunSpan.from_context( context=run_context, tracer=self._get_tracer(), - file_system=self._fs, - prefix_path=self._fs_prefix, - credential_refresher=self._refresh_storage_credentials if self._credentials else None, + credential_manager=self._credential_manager, # type: ignore[arg-type] ) ``` @@ -526,19 +522,15 @@ def initialize(self) -> None: # ) # ) # ) - self._credentials = self._api.get_user_data_credentials() - self._credentials_expiry = self._credentials.expiration - resolved_endpoint = resolve_endpoint(self._credentials.endpoint) - self._fs = S3FileSystem( - key=self._credentials.access_key_id, - secret=self._credentials.secret_access_key, - token=self._credentials.session_token, - client_kwargs={ - "endpoint_url": resolved_endpoint, - "region_name": self._credentials.region, - }, - ) - self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/" + if self._api is not None: + api = self._api + self._credential_manager = CredentialManager( + credential_fetcher=lambda: api.get_user_data_credentials() + ) + self._credential_manager.initialize() + + self._fs = self._credential_manager.get_filesystem() + self._fs_prefix = self._credential_manager.get_prefix() self._logfire = logfire.configure( local=not self.is_default, @@ -1723,10 +1715,8 @@ def run( tracer=self._get_tracer(), params=params, tags=tags, - file_system=self._fs, - prefix_path=self._fs_prefix, + credential_manager=self._credential_manager, # type: ignore[arg-type] autolog=autolog, - credential_refresher=self._refresh_storage_credentials if self._credentials else None, ) ```