Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | No
if self.connections_prefix is None:
return None

return self._get_secret(self.connections_prefix, conn_id, self.connections_lookup_pattern)
return self._get_secret(self.connections_prefix, conn_id, self.connections_lookup_pattern, team_name)

def get_variable(self, key: str, team_name: str | None = None) -> str | None:
"""
Expand All @@ -155,7 +155,7 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None:
if self.variables_prefix is None:
return None

return self._get_secret(self.variables_prefix, key, self.variables_lookup_pattern)
return self._get_secret(self.variables_prefix, key, self.variables_lookup_pattern, team_name)

def get_config(self, key: str) -> str | None:
"""
Expand All @@ -169,19 +169,25 @@ def get_config(self, key: str) -> str | None:

return self._get_secret(self.config_prefix, key, self.config_lookup_pattern)

def _get_secret(self, path_prefix: str, secret_id: str, lookup_pattern: str | None) -> str | None:
def _get_secret(
self, path_prefix: str, secret_id: str, lookup_pattern: str | None, team_name: str | None = None
) -> str | None:
"""
Get secret value from Parameter Store.

:param path_prefix: Prefix for the Path to get Secret
:param secret_id: Secret Key
:param lookup_pattern: If provided, `secret_id` must match this pattern to look up the secret in
Systems Manager
:param team_name: Team name associated to the task trying to access the variable (if any)
"""
if lookup_pattern and not re.match(lookup_pattern, secret_id, re.IGNORECASE):
return None

ssm_path = self.build_path(path_prefix, secret_id)
if team_name:
ssm_path = self.build_path(path_prefix, team_name)
ssm_path = self.build_path(ssm_path, secret_id)
else:
ssm_path = self.build_path(path_prefix, secret_id)
ssm_path = self._ensure_leading_slash(ssm_path)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ def test_get_conn_value_non_existent_key(self):
assert ssm_backend.get_conn_value(conn_id=conn_id) is None
assert ssm_backend.get_connection(conn_id=conn_id) is None

@mock_aws
def test_get_conn_value_with_team_name(self):
param = {
"Name": "/airflow/connections/my_team/test_postgres",
"Type": "String",
"Value": "postgresql://airflow:airflow@host:5432/airflow",
}
ssm_backend = SystemsManagerParameterStoreBackend()
ssm_backend.client.put_parameter(**param)
returned_uri = ssm_backend.get_conn_value(conn_id="test_postgres", team_name="my_team")
assert returned_uri == "postgresql://airflow:airflow@host:5432/airflow"

@mock_aws
def test_get_variable(self):
param = {"Name": "/airflow/variables/hello", "Type": "String", "Value": "world"}
Expand Down Expand Up @@ -145,6 +157,15 @@ def test_get_variable_non_existent_key(self):

assert ssm_backend.get_variable("test_mysql") is None

@mock_aws
def test_get_variable_with_team_name(self):
param = {"Name": "/airflow/variables/my_team/hello", "Type": "String", "Value": "world"}

ssm_backend = SystemsManagerParameterStoreBackend()
ssm_backend.client.put_parameter(**param)

assert ssm_backend.get_variable(key="hello", team_name="my_team") == "world"

@conf_vars(
{
("secrets", "backend"): "airflow.providers.amazon.aws.secrets.systems_manager."
Expand Down
Loading