diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py index 20ed57b3a2be3..2c548a5c45cf3 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py @@ -247,21 +247,22 @@ def _handle_status( error_message = execution_message raise RuntimeError(error_message) - def get_project_s3_path(self, domain_identifier: str, project_id: str) -> str: + def get_project_s3_path(self, domain_identifier: str, project_id: str) -> tuple[str, str]: """ - Look up the S3 bucket path for a SageMaker Unified Studio project. + Look up the S3 location for a SageMaker Unified Studio project. - The bucket path is read from the ``s3BucketPath`` provisioned resource of - the project's default ("Tooling") environment via the DataZone APIs: - ``GetEnvironment(GetProjectDefaultEnvironment(...))``. This mirrors how - SageMaker Unified Studio resolves the project bucket, and accommodates projects - whose bucket name does not follow the - ``amazon-sagemaker-{account_id}-{region}-{project_id}`` template (for - example, BYOR-bucket projects). + The bucket and key prefix are read from the ``s3BucketPath`` provisioned + resource of the project's default ("Tooling") environment via the + DataZone APIs. This mirrors how SageMaker Unified Studio resolves the + project bucket and accommodates projects whose bucket name does not + follow the ``amazon-sagemaker-{account_id}-{region}-{project_id}`` + template (for example, BYOR-bucket projects). :param domain_identifier: The ID of the DataZone domain. :param project_id: The ID of the DataZone project. - :return: The S3 bucket name for the project. + :return: A ``(bucket, prefix)`` tuple. ``bucket`` is the S3 bucket name. 
+ ``prefix`` is the path component of the project's + ``s3BucketPath`` (with no leading or trailing ``/``). :raises RuntimeError: If the default tooling environment or the ``s3BucketPath`` provisioned resource cannot be found. """ @@ -277,7 +278,9 @@ def get_project_s3_path(self, domain_identifier: str, project_id: str) -> str: f"environment {environment_id} for project {project_id} in domain " f"{domain_identifier}" ) - # value looks like "s3:///"; return the bucket name only. + # value looks like "s3://<bucket>/shared/" (IAM) or + # "s3://<bucket>/<domain_id>/<project_id>/dev/" (IDC). Return both + # parts so callers can construct project-scoped keys. parts = urlparse(value, allow_fragments=False) bucket = parts.netloc if not bucket: @@ -286,7 +289,8 @@ def get_project_s3_path(self, domain_identifier: str, project_id: str) -> str: f"'{value}' in default tooling environment {environment_id} for " f"project {project_id} in domain {domain_identifier}" ) - return bucket + prefix = parts.path.strip("/") + return bucket, prefix raise RuntimeError( f"s3BucketPath provisioned resource not found in default tooling environment " @@ -419,10 +423,10 @@ def get_notebook_outputs( """ log = logging.getLogger(__name__) try: - bucket = self.get_project_s3_path(domain_identifier, owning_project_identifier) + bucket, prefix = self.get_project_s3_path(domain_identifier, owning_project_identifier) except Exception: log.warning( - "Failed to resolve project S3 bucket for project %s in domain %s, " + "Failed to resolve project S3 location for project %s in domain %s, " "skipping notebook outputs read.", owning_project_identifier, domain_identifier, @@ -430,7 +434,12 @@ def get_notebook_outputs( ) return {} - key = f"sys/notebooks/{notebook_identifier}/runs/{notebook_run_id}/notebook_outputs.json" + # IDC domains have a non-empty prefix (e.g. "<domain_id>/<project_id>/dev") + # and the project role's IAM policy only allows S3 reads under that prefix. + # IAM domains have prefix == "" and the key is unchanged from the + # legacy bucket-root layout. 
+ run_key = f".sys/notebooks/{notebook_identifier}/runs/{notebook_run_id}/notebook_outputs.json" + key = f"{prefix}/{run_key}" if prefix else run_key log.info("Reading notebook outputs from s3://%s/%s", bucket, key) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio_notebook.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio_notebook.py index 9202343116153..728ec55ca02c1 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio_notebook.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio_notebook.py @@ -334,7 +334,7 @@ def test_get_project_s3_path_uses_default_tooling_environment(self): result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID) - assert result == bucket + assert result == (bucket, f"dzd_x/{PROJECT_ID}/dev") self.mock_client.list_environment_blueprints.assert_any_call( domainIdentifier=DOMAIN_ID, managed=True, @@ -368,7 +368,7 @@ def test_get_project_s3_path_picks_lowest_deployment_order(self): result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID) - assert result == bucket + assert result == (bucket, "p") self.mock_client.get_environment.assert_called_once_with( domainIdentifier=DOMAIN_ID, identifier=env_id, @@ -389,7 +389,7 @@ def test_get_project_s3_path_falls_back_to_tooling_lite(self): result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID) - assert result == bucket + assert result == (bucket, "p") # Both blueprint lookups happened. 
assert self.mock_client.list_environment_blueprints.call_count == 2 self.mock_client.get_environment.assert_called_once_with( @@ -414,7 +414,7 @@ def test_get_project_s3_path_falls_back_to_first_when_no_deployment_order(self): result = self.hook.get_project_s3_path(DOMAIN_ID, PROJECT_ID) - assert result == bucket + assert result == (bucket, "p") self.mock_client.get_environment.assert_called_once_with( domainIdentifier=DOMAIN_ID, identifier=env_id, @@ -509,7 +509,33 @@ def test_get_notebook_outputs_success(self): ) assert result == outputs - expected_key = f"sys/notebooks/{NOTEBOOK_ID}/runs/{NOTEBOOK_RUN_ID}/notebook_outputs.json" + expected_key = f"dzd_x/{PROJECT_ID}/dev/.sys/notebooks/{NOTEBOOK_ID}/runs/{NOTEBOOK_RUN_ID}/notebook_outputs.json" + mock_s3_hook_cls.return_value.read_key.assert_called_once_with(key=expected_key, bucket_name=bucket) + + def test_get_notebook_outputs_iam_mode_no_prefix(self): + """IAM-mode projects (s3BucketPath is bucket-only) read from the bucket root.""" + outputs = {"name": "Alice"} + bucket = "iam-mode-bucket" + # Tooling env returns s3BucketPath without a path component. + self._stub_tooling_blueprint_lookup( + environments=[{"id": "env-1", "name": "Tooling", "deploymentOrder": 1}] + ) + self.mock_client.get_environment.return_value = { + "id": "env-1", + "provisionedResources": [{"name": "s3BucketPath", "value": f"s3://{bucket}"}], + } + + with patch(f"{HOOK_MODULE}.S3Hook") as mock_s3_hook_cls: + mock_s3_hook_cls.return_value.read_key.return_value = json.dumps(outputs) + result = self.hook.get_notebook_outputs( + notebook_identifier=NOTEBOOK_ID, + notebook_run_id=NOTEBOOK_RUN_ID, + domain_identifier=DOMAIN_ID, + owning_project_identifier=PROJECT_ID, + ) + + assert result == outputs + expected_key = f".sys/notebooks/{NOTEBOOK_ID}/runs/{NOTEBOOK_RUN_ID}/notebook_outputs.json" mock_s3_hook_cls.return_value.read_key.assert_called_once_with(key=expected_key, bucket_name=bucket) def test_get_notebook_outputs_no_such_key(self):