diff --git a/providers/amazon/src/airflow/providers/amazon/aws/secrets/systems_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/secrets/systems_manager.py index a6749d9f995c6..542a35f5f910c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/secrets/systems_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/secrets/systems_manager.py @@ -142,7 +142,7 @@ def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | No if self.connections_prefix is None: return None - return self._get_secret(self.connections_prefix, conn_id, self.connections_lookup_pattern) + return self._get_secret(self.connections_prefix, conn_id, self.connections_lookup_pattern, team_name) def get_variable(self, key: str, team_name: str | None = None) -> str | None: """ @@ -155,7 +155,7 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None: if self.variables_prefix is None: return None - return self._get_secret(self.variables_prefix, key, self.variables_lookup_pattern) + return self._get_secret(self.variables_prefix, key, self.variables_lookup_pattern, team_name) def get_config(self, key: str) -> str | None: """ @@ -169,7 +169,9 @@ def get_config(self, key: str) -> str | None: return self._get_secret(self.config_prefix, key, self.config_lookup_pattern) - def _get_secret(self, path_prefix: str, secret_id: str, lookup_pattern: str | None) -> str | None: + def _get_secret( + self, path_prefix: str, secret_id: str, lookup_pattern: str | None, team_name: str | None = None + ) -> str | None: """ Get secret value from Parameter Store. @@ -177,11 +179,15 @@ def _get_secret(self, path_prefix: str, secret_id: str, lookup_pattern: str | No :param secret_id: Secret Key :param lookup_pattern: If provided, `secret_id` must match this pattern to look up the secret in Systems Manager + :param team_name: Team name associated to the task trying to access the variable (if any) """ if lookup_pattern and not re.match(lookup_pattern, secret_id, re.IGNORECASE): return None - - ssm_path = self.build_path(path_prefix, secret_id) + if team_name: + ssm_path = self.build_path(path_prefix, team_name) + ssm_path = self.build_path(ssm_path, secret_id) + else: + ssm_path = self.build_path(path_prefix, secret_id) ssm_path = self._ensure_leading_slash(ssm_path) try: diff --git a/providers/amazon/tests/unit/amazon/aws/secrets/test_systems_manager.py b/providers/amazon/tests/unit/amazon/aws/secrets/test_systems_manager.py index 3aed3819f88af..244c8bc5ce839 100644 --- a/providers/amazon/tests/unit/amazon/aws/secrets/test_systems_manager.py +++ b/providers/amazon/tests/unit/amazon/aws/secrets/test_systems_manager.py @@ -100,6 +100,18 @@ def test_get_conn_value_non_existent_key(self): assert ssm_backend.get_conn_value(conn_id=conn_id) is None assert ssm_backend.get_connection(conn_id=conn_id) is None + @mock_aws + def test_get_conn_value_with_team_name(self): + param = { + "Name": "/airflow/connections/my_team/test_postgres", + "Type": "String", + "Value": "postgresql://airflow:airflow@host:5432/airflow", + } + ssm_backend = SystemsManagerParameterStoreBackend() + ssm_backend.client.put_parameter(**param) + returned_uri = ssm_backend.get_conn_value(conn_id="test_postgres", team_name="my_team") + assert returned_uri == "postgresql://airflow:airflow@host:5432/airflow" + @mock_aws def test_get_variable(self): param = {"Name": "/airflow/variables/hello", "Type": "String", "Value": "world"} @@ -145,6 +157,15 @@ def test_get_variable_non_existent_key(self): assert ssm_backend.get_variable("test_mysql") is None + @mock_aws + def test_get_variable_with_team_name(self): + param = {"Name": "/airflow/variables/my_team/hello", "Type": "String", "Value": "world"} + + ssm_backend = SystemsManagerParameterStoreBackend() + ssm_backend.client.put_parameter(**param) + + assert ssm_backend.get_variable(key="hello", team_name="my_team") == "world" + @conf_vars( { ("secrets", "backend"): "airflow.providers.amazon.aws.secrets.systems_manager."