Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,21 +247,22 @@ def _handle_status(
error_message = execution_message
raise RuntimeError(error_message)

def get_project_s3_path(self, domain_identifier: str, project_id: str) -> str:
def get_project_s3_path(self, domain_identifier: str, project_id: str) -> tuple[str, str]:
"""
Look up the S3 bucket path for a SageMaker Unified Studio project.
Look up the S3 location for a SageMaker Unified Studio project.

The bucket path is read from the ``s3BucketPath`` provisioned resource of
the project's default ("Tooling") environment via the DataZone APIs:
``GetEnvironment(GetProjectDefaultEnvironment(...))``. This mirrors how
SageMaker Unified Studio resolves the project bucket, and accommodates projects
whose bucket name does not follow the
``amazon-sagemaker-{account_id}-{region}-{project_id}`` template (for
example, BYOR-bucket projects).
The bucket and key prefix are read from the ``s3BucketPath`` provisioned
resource of the project's default ("Tooling") environment via the
DataZone APIs. This mirrors how SageMaker Unified Studio resolves the
project bucket and accommodates projects whose bucket name does not
follow the ``amazon-sagemaker-{account_id}-{region}-{project_id}``
template (for example, BYOR-bucket projects).

:param domain_identifier: The ID of the DataZone domain.
:param project_id: The ID of the DataZone project.
:return: The S3 bucket name for the project.
:return: A ``(bucket, prefix)`` tuple. ``bucket`` is the S3 bucket name.
``prefix`` is the path component of the project's
``s3BucketPath`` (with no leading or trailing ``/``).
:raises RuntimeError: If the default tooling environment or the
``s3BucketPath`` provisioned resource cannot be found.
"""
Expand All @@ -277,7 +278,9 @@ def get_project_s3_path(self, domain_identifier: str, project_id: str) -> str:
f"environment {environment_id} for project {project_id} in domain "
f"{domain_identifier}"
)
# value looks like "s3://<bucket>/<prefix>"; return the bucket name only.
# value looks like "s3://<bucket>/shared/<suffix>" (IAM) or
# "s3://<bucket>/<domain>/<project>/dev/<suffix>" (IDC). Return both
# parts so callers can construct project-scoped keys.
parts = urlparse(value, allow_fragments=False)
bucket = parts.netloc
if not bucket:
Expand All @@ -286,7 +289,8 @@ def get_project_s3_path(self, domain_identifier: str, project_id: str) -> str:
f"'{value}' in default tooling environment {environment_id} for "
f"project {project_id} in domain {domain_identifier}"
)
return bucket
prefix = parts.path.strip("/")
return bucket, prefix

raise RuntimeError(
f"s3BucketPath provisioned resource not found in default tooling environment "
Expand Down Expand Up @@ -419,18 +423,23 @@ def get_notebook_outputs(
"""
log = logging.getLogger(__name__)
try:
bucket = self.get_project_s3_path(domain_identifier, owning_project_identifier)
bucket, prefix = self.get_project_s3_path(domain_identifier, owning_project_identifier)
except Exception:
log.warning(
"Failed to resolve project S3 bucket for project %s in domain %s, "
"Failed to resolve project S3 location for project %s in domain %s, "
"skipping notebook outputs read.",
owning_project_identifier,
domain_identifier,
exc_info=True,
)
return {}

key = f"sys/notebooks/{notebook_identifier}/runs/{notebook_run_id}/notebook_outputs.json"
# IDC domains have a non-empty prefix (e.g. "<domain>/<project>/<scope>")
# and the project role's IAM policy only allows S3 reads under that prefix.
# IAM domains have prefix == "" and the key is unchanged from the
# legacy bucket-root layout.
run_key = f".sys/notebooks/{notebook_identifier}/runs/{notebook_run_id}/notebook_outputs.json"
key = f"{prefix}/{run_key}" if prefix else run_key

log.info("Reading notebook outputs from s3://%s/%s", bucket, key)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def test_get_project_s3_path_uses_default_tooling_environment(self):

result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)

assert result == bucket
assert result == (bucket, f"dzd_x/{PROJECT_ID}/dev")
self.mock_client.list_environment_blueprints.assert_any_call(
domainIdentifier=DOMAIN_ID,
managed=True,
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_get_project_s3_path_picks_lowest_deployment_order(self):

result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)

assert result == bucket
assert result == (bucket, "p")
self.mock_client.get_environment.assert_called_once_with(
domainIdentifier=DOMAIN_ID,
identifier=env_id,
Expand All @@ -389,7 +389,7 @@ def test_get_project_s3_path_falls_back_to_tooling_lite(self):

result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)

assert result == bucket
assert result == (bucket, "p")
# Both blueprint lookups happened.
assert self.mock_client.list_environment_blueprints.call_count == 2
self.mock_client.get_environment.assert_called_once_with(
Expand All @@ -414,7 +414,7 @@ def test_get_project_s3_path_falls_back_to_first_when_no_deployment_order(self):

result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID)

assert result == bucket
assert result == (bucket, "p")
self.mock_client.get_environment.assert_called_once_with(
domainIdentifier=DOMAIN_ID,
identifier=env_id,
Expand Down Expand Up @@ -509,7 +509,33 @@ def test_get_notebook_outputs_success(self):
)

assert result == outputs
expected_key = f"sys/notebooks/{NOTEBOOK_ID}/runs/{NOTEBOOK_RUN_ID}/notebook_outputs.json"
expected_key = f"dzd_x/{PROJECT_ID}/dev/.sys/notebooks/{NOTEBOOK_ID}/runs/{NOTEBOOK_RUN_ID}/notebook_outputs.json"
mock_s3_hook_cls.return_value.read_key.assert_called_once_with(key=expected_key, bucket_name=bucket)

def test_get_notebook_outputs_iam_mode_no_prefix(self):
    """A bucket-only ``s3BucketPath`` (IAM mode) yields an un-prefixed outputs key."""
    expected_outputs = {"name": "Alice"}
    bucket_name = "iam-mode-bucket"

    # The tooling environment reports an s3BucketPath with no path component,
    # so the hook has no prefix to prepend to the run key.
    tooling_environment = {"id": "env-1", "name": "Tooling", "deploymentOrder": 1}
    self._stub_tooling_blueprint_lookup(environments=[tooling_environment])
    s3_path_resource = {"name": "s3BucketPath", "value": f"s3://{bucket_name}"}
    self.mock_client.get_environment.return_value = {
        "id": "env-1",
        "provisionedResources": [s3_path_resource],
    }

    with patch(f"{HOOK_MODULE}.S3Hook") as s3_hook_cls:
        read_key = s3_hook_cls.return_value.read_key
        read_key.return_value = json.dumps(expected_outputs)
        actual_outputs = self.hook.get_notebook_outputs(
            notebook_identifier=NOTEBOOK_ID,
            notebook_run_id=NOTEBOOK_RUN_ID,
            domain_identifier=DOMAIN_ID,
            owning_project_identifier=PROJECT_ID,
        )

    assert actual_outputs == expected_outputs
    # The key starts directly at ".sys/..." with nothing in front of it.
    read_key.assert_called_once_with(
        key=f".sys/notebooks/{NOTEBOOK_ID}/runs/{NOTEBOOK_RUN_ID}/notebook_outputs.json",
        bucket_name=bucket_name,
    )

def test_get_notebook_outputs_no_such_key(self):
Expand Down
Loading