diff --git a/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py b/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py index 8ce9a735e89f5..f52d08d574b03 100644 --- a/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -72,9 +72,9 @@ class _VaultClient(LoggingMixin): (for ``token`` and ``github`` auth_type). :param username: Username for Authentication (for ``ldap`` and ``userpass`` auth_types). :param password: Password for Authentication (for ``ldap`` and ``userpass`` auth_types). - :param key_id: Key ID for Authentication (for ``aws_iam`` and ''azure`` auth_type). + :param key_id: Key ID for Authentication (for ``aws_iam`` and ``azure`` auth_type). :param secret_id: Secret ID for Authentication (for ``approle``, ``aws_iam`` and ``azure`` auth_types). - :param role_id: Role ID for Authentication (for ``approle``, ``aws_iam`` auth_types). + :param role_id: Role ID for Authentication (for ``approle``, ``aws_iam`` and ``gcp`` auth_types). :param assume_role_kwargs: AWS assume role param. See AWS STS Docs: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html @@ -83,7 +83,7 @@ class _VaultClient(LoggingMixin): :param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type). :param kubernetes_jwt_path: Path for kubernetes jwt token (for ``kubernetes`` auth_type, default: ``/var/run/secrets/kubernetes.io/serviceaccount/token``). - :param gcp_key_path: Path to Google Cloud Service Account key file (JSON) (for ``gcp`` auth_type). + :param gcp_key_path: Path to Google Cloud Service Account key file (JSON) (for ``gcp`` auth_type). Mutually exclusive with gcp_keyfile_dict :param gcp_keyfile_dict: Dictionary of keyfile parameters. (for ``gcp`` auth_type). Mutually exclusive with gcp_key_path @@ -171,10 +171,6 @@ def __init__( raise VaultError("The 'gcp' authentication type requires 'gcp_scopes'") if not role_id: raise VaultError("The 'gcp' authentication type requires 'role_id'") - if not gcp_key_path and not gcp_keyfile_dict: - raise VaultError( - "The 'gcp' authentication type requires 'gcp_key_path' or 'gcp_keyfile_dict'" - ) self.kv_engine_version = kv_engine_version or 2 self.url = url @@ -352,25 +348,43 @@ def _auth_gcp(self, _client: hvac.Client) -> None: import json import time - import googleapiclient + # Determine service account email + service_account_email = getattr(credentials, "service_account_email", None) or getattr( + credentials, "client_email", None + ) + + if service_account_email is None: + # Fallback for Compute Engine credentials if email is not yet populated + try: + from google.auth import compute_engine, exceptions - if self.gcp_keyfile_dict: - creds = self.gcp_keyfile_dict - elif self.gcp_key_path: - with open(self.gcp_key_path) as f: - creds = json.load(f) + if isinstance(credentials, compute_engine.Credentials): + from google.auth import transport - service_account = creds["client_email"] + credentials.refresh(transport.requests.Request()) + service_account_email = getattr(credentials, "service_account_email", None) + except exceptions.RefreshError: + self.log.error("Failed to refresh Compute Engine credentials.") + except ImportError: + self.log.error("google-auth not installed, skipping credential refresh.") - # Generate a payload for subsequent "signJwt()" call - # Reference: https://googleapis.dev/python/google-auth/latest/reference/google.auth.jwt.html#google.auth.jwt.Credentials + if not isinstance(service_account_email, str): + raise VaultError( + f"Could not determine service account email from credentials. " + f"Expected string, got {type(service_account_email).__name__}" + ) + + # Generate a payload for subsequent "signJwt()" call. + # The 'sub' claim must be the service account email. now = int(time.time()) expires = now + 900 # 15 mins in seconds, can't be longer. - payload = {"iat": now, "exp": expires, "sub": credentials, "aud": f"vault/{self.role_id}"} + payload = {"iat": now, "exp": expires, "sub": service_account_email, "aud": f"vault/{self.role_id}"} body = {"payload": json.dumps(payload)} - name = f"projects/{project_id}/serviceAccounts/{service_account}" + name = f"projects/{project_id}/serviceAccounts/{service_account_email}" # Perform the GCP API call + import googleapiclient.discovery + iam = googleapiclient.discovery.build("iam", "v1", credentials=credentials) request = iam.projects().serviceAccounts().signJwt(name=name, body=body) resp = request.execute() diff --git a/providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py b/providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py index 6de3138a0c840..d142b7eba2440 100644 --- a/providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py +++ b/providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py @@ -80,12 +80,12 @@ class VaultHook(BaseHook): :param vault_conn_id: The id of the connection to use :param auth_type: Authentication Type for the Vault. Default is ``token``. Available values are: - ('approle', 'github', 'gcp', 'jwt', 'kubernetes', 'ldap', 'token', 'userpass') + ('approle', 'aws_iam', 'azure', 'github', 'gcp', 'jwt', 'kubernetes', 'ldap', 'radius', 'token', 'userpass') :param auth_mount_point: It can be used to define mount_point for authentication chosen Default depends on the authentication method used. :param kv_engine_version: Select the version of the engine to run (``1`` or ``2``). Defaults to version defined in connection or ``2`` if not defined in connection. - :param role_id: Role ID for ``aws_iam`` Authentication. + :param role_id: Role ID for ``approle``, ``aws_iam`` and ``gcp`` Authentication. :param region: AWS region for STS API calls (for ``aws_iam`` auth_type). :param kubernetes_role: Role for Authentication (for ``kubernetes`` auth_type) :param kubernetes_jwt_path: Path for kubernetes jwt token (for ``kubernetes`` auth_type, default: @@ -162,6 +162,10 @@ def __init__( if not region: region = self.connection.extra_dejson.get("region") + if auth_type == "gcp": + if not role_id: + role_id = self.connection.extra_dejson.get("role_id") or self.connection.login + azure_resource, azure_tenant_id = ( self._get_azure_parameters_from_connection(azure_resource, azure_tenant_id) if auth_type == "azure" diff --git a/providers/hashicorp/src/airflow/providers/hashicorp/secrets/vault.py b/providers/hashicorp/src/airflow/providers/hashicorp/secrets/vault.py index 53c2c638d5663..72ac30901b2e3 100644 --- a/providers/hashicorp/src/airflow/providers/hashicorp/secrets/vault.py +++ b/providers/hashicorp/src/airflow/providers/hashicorp/secrets/vault.py @@ -75,7 +75,7 @@ class VaultBackend(BaseSecretsBackend, LoggingMixin): :param password: Password for Authentication (for ``ldap`` and ``userpass`` auth_type). :param key_id: Key ID for Authentication (for ``aws_iam`` and ''azure`` auth_type). :param secret_id: Secret ID for Authentication (for ``approle``, ``aws_iam`` and ``azure`` auth_types). - :param role_id: Role ID for Authentication (for ``approle``, ``aws_iam`` auth_types). + :param role_id: Role ID for Authentication (for ``approle``, ``aws_iam`` and ``gcp`` auth_types). :param assume_role_kwargs: AWS assume role param. See AWS STS Docs: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html diff --git a/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py b/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py index fe27763e60cd4..0988062ca1ea2 100644 --- a/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py +++ b/providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py @@ -17,7 +17,6 @@ from __future__ import annotations import json -import time from unittest import mock from unittest.mock import call, mock_open, patch @@ -255,26 +254,26 @@ def test_azure_missing_tenant_id(self, mock_hvac): secret_id="pass", ) - @mock.patch("builtins.open", create=True) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") @mock.patch("googleapiclient.discovery.build") - def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open): - # Mock the content of the file 'path.json' - mock_file = mock.MagicMock() - mock_file.read.return_value = '{"client_email": "service_account_email"}' - mock_open.return_value.__enter__.return_value = mock_file - + @mock.patch("time.time") + def test_gcp_key( + self, mock_time, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes + ): mock_client = mock.MagicMock() mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] - mock_get_credentials.return_value = ("credentials", "project_id") + + mock_credentials = mock.Mock(spec=[]) + mock_credentials.client_email = "service_account_email" + mock_get_credentials.return_value = (mock_credentials, "project_id") # Mock the current time to use for iat and exp - current_time = int(time.time()) - iat = current_time - exp = iat + 3600 # 1 hour after iat + mock_time.return_value = 1234567890.0 + iat = 1234567890 + exp = iat + 900 # 15 minutes after iat # Mock the signJwt API to return the expected payload mock_sign_jwt = ( @@ -291,21 +290,7 @@ def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mo session=None, ) - # Preserve the original json.dumps - original_json_dumps = json.dumps - - # Inject the mocked payload into the JWT signing process - with mock.patch("json.dumps") as mock_json_dumps: - - def mocked_json_dumps(payload): - # Override the payload to inject controlled iat and exp values - payload["iat"] = iat - payload["exp"] = exp - return original_json_dumps(payload) # Use the original json.dumps - - mock_json_dumps.side_effect = mocked_json_dumps - - client = vault_client.client # Trigger the Vault client creation + client = vault_client.client # Trigger the Vault client creation # Validate that the HVAC client and other mocks are called correctly mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) @@ -321,65 +306,108 @@ def mocked_json_dumps(payload): # Assert iat and exp values are as expected assert payload["iat"] == iat assert payload["exp"] == exp - assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + assert payload["sub"] == "service_account_email" client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 - @mock.patch("builtins.open", create=True) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") @mock.patch("googleapiclient.discovery.build") - def test_gcp_different_auth_mount_point( - self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open + @mock.patch("time.time") + def test_gcp_adc( + self, mock_time, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes ): - # Mock the content of the file 'path.json' - mock_file = mock.MagicMock() - mock_file.read.return_value = '{"client_email": "service_account_email"}' - mock_open.return_value.__enter__.return_value = mock_file - mock_client = mock.MagicMock() mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] - mock_get_credentials.return_value = ("credentials", "project_id") + + mock_credentials = mock.Mock(spec=[]) + mock_credentials.service_account_email = "service_account_email" + mock_get_credentials.return_value = (mock_credentials, "project_id") mock_sign_jwt = ( mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt ) mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} - # Generate realistic iat and exp values - current_time = int(time.time()) - iat = current_time - exp = current_time + 3600 # 1 hour later + # Mock the current time to use for iat and exp + mock_time.return_value = 1234567890.0 + iat = 1234567890 + exp = iat + 900 # 15 minutes after iat vault_client = _VaultClient( auth_type="gcp", - gcp_key_path="path.json", gcp_scopes="scope1,scope2", role_id="role", url="http://localhost:8180", - auth_mount_point="other", session=None, ) - # Preserve the original json.dumps - original_json_dumps = json.dumps + client = vault_client.client # Trigger the Vault client creation + + # Validate that the HVAC client and other mocks are called correctly + mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) + mock_get_scopes.assert_called_with("scope1,scope2") + mock_get_credentials.assert_called_with(key_path=None, keyfile_dict=None, scopes=["scope1", "scope2"]) + + # Extract the arguments passed to the mocked signJwt API + args, kwargs = mock_sign_jwt.call_args + payload = json.loads(kwargs["body"]["payload"]) + + # Assert iat and exp values are as expected + assert payload["iat"] == iat + assert payload["exp"] == exp + assert payload["sub"] == "service_account_email" + + client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") + client.is_authenticated.assert_called_with() + assert vault_client.kv_engine_version == 2 + + @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") + @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") + @mock.patch("googleapiclient.discovery.build") + @mock.patch("time.time") + def test_gcp_different_auth_mount_point( + self, + mock_time, + mock_google_build, + mock_hvac_client, + mock_get_credentials, + mock_get_scopes, + ): + mock_client = mock.MagicMock() + mock_hvac_client.return_value = mock_client + mock_get_scopes.return_value = ["scope1", "scope2"] + + mock_credentials = mock.Mock(spec=[]) + mock_credentials.client_email = "service_account_email" + mock_get_credentials.return_value = (mock_credentials, "project_id") - # Inject the mocked payload into the JWT signing process - with mock.patch("json.dumps") as mock_json_dumps: + mock_sign_jwt = ( + mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt + ) + mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} - def mocked_json_dumps(payload): - # Override the payload to inject controlled iat and exp values - payload["iat"] = iat - payload["exp"] = exp - return original_json_dumps(payload) # Use the original json.dumps + # Mock the current time to use for iat and exp + mock_time.return_value = 1234567890.0 + iat = 1234567890 + exp = iat + 900 # 15 minutes after iat - mock_json_dumps.side_effect = mocked_json_dumps + vault_client = _VaultClient( + auth_type="gcp", + gcp_key_path="path.json", + gcp_scopes="scope1,scope2", + role_id="role", + url="http://localhost:8180", + auth_mount_point="other", + session=None, + ) - client = vault_client.client # Trigger the Vault client creation + client = vault_client.client # Trigger the Vault client creation # Assertions mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) @@ -394,36 +422,37 @@ def mocked_json_dumps(payload): # Assert iat and exp values are as expected assert payload["iat"] == iat assert payload["exp"] == exp - assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + assert payload["sub"] == "service_account_email" client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt", mount_point="other") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 - @mock.patch( - "builtins.open", new_callable=mock_open, read_data='{"client_email": "service_account_email"}' - ) @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") @mock.patch("googleapiclient.discovery.build") + @mock.patch("time.time") def test_gcp_dict( - self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_file + self, mock_time, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes ): mock_client = mock.MagicMock() mock_hvac_client.return_value = mock_client mock_get_scopes.return_value = ["scope1", "scope2"] - mock_get_credentials.return_value = ("credentials", "project_id") + + mock_credentials = mock.Mock(spec=[]) + mock_credentials.client_email = "service_account_email" + mock_get_credentials.return_value = (mock_credentials, "project_id") mock_sign_jwt = ( mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt ) mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"} - # Generate realistic iat and exp values - current_time = int(time.time()) - iat = current_time - exp = current_time + 3600 # 1 hour later + # Mock the current time to use for iat and exp + mock_time.return_value = 1234567890.0 + iat = 1234567890 + exp = iat + 900 # 15 minutes after iat vault_client = _VaultClient( auth_type="gcp", @@ -434,21 +463,7 @@ def test_gcp_dict( session=None, ) - # Preserve the original json.dumps - original_json_dumps = json.dumps - - # Inject the mocked payload into the JWT signing process - with mock.patch("json.dumps") as mock_json_dumps: - - def mocked_json_dumps(payload): - # Override the payload to inject controlled iat and exp values - payload["iat"] = iat - payload["exp"] = exp - return original_json_dumps(payload) # Use the original json.dumps - - mock_json_dumps.side_effect = mocked_json_dumps - - client = vault_client.client # Trigger the Vault client creation + client = vault_client.client # Trigger the Vault client creation # Assertions mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None) @@ -463,12 +478,35 @@ def mocked_json_dumps(payload): # Assert iat and exp values are as expected assert payload["iat"] == iat assert payload["exp"] == exp - assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat + assert payload["sub"] == "service_account_email" client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt") client.is_authenticated.assert_called_with() assert vault_client.kv_engine_version == 2 + @mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes") + @mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id") + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client") + def test_gcp_error_wrong_type(self, mock_hvac_client, mock_get_credentials, mock_get_scopes): + mock_client = mock.MagicMock() + mock_hvac_client.return_value = mock_client + mock_get_scopes.return_value = ["scope1"] + + # Return something that is not a string for client_email + mock_credentials = mock.Mock(spec=[]) + mock_credentials.client_email = 12345 + mock_get_credentials.return_value = (mock_credentials, "project_id") + + vault_client = _VaultClient( + auth_type="gcp", + gcp_scopes="scope1", + role_id="role", + url="http://localhost:8180", + ) + + with pytest.raises(VaultError, match="Expected string, got int"): + _ = vault_client.client + @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_github(self, mock_hvac): mock_client = mock.MagicMock() diff --git a/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py b/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py index 141f6b838267d..8de19dcb5d6c3 100644 --- a/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py +++ b/providers/hashicorp/tests/unit/hashicorp/hooks/test_vault.py @@ -442,7 +442,9 @@ def test_gcp_init_params( mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection mock_get_scopes.return_value = ["scope1", "scope2"] - mock_get_credentials.return_value = ("credentials", "project_id") + mock_credentials = MagicMock() + mock_credentials.service_account_email = "service_account_email" + mock_get_credentials.return_value = (mock_credentials, "project_id") # Mock googleapiclient.discovery.build chain mock_service = MagicMock() @@ -467,12 +469,8 @@ def test_gcp_init_params( "session": None, } - with patch( - "builtins.open", mock_open(read_data='{"client_email": "service_account_email"}') - ) as mock_file: - test_hook = VaultHook(**kwargs) - test_client = test_hook.get_conn() - mock_file.assert_called_with("path.json") + test_hook = VaultHook(**kwargs) + test_client = test_hook.get_conn() mock_get_connection.assert_called_with("vault_conn_id") mock_get_scopes.assert_called_with("scope1,scope2") @@ -497,7 +495,9 @@ def test_gcp_dejson( mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection mock_get_scopes.return_value = ["scope1", "scope2"] - mock_get_credentials.return_value = ("credentials", "project_id") + mock_credentials = MagicMock() + mock_credentials.service_account_email = "service_account_email" + mock_get_credentials.return_value = (mock_credentials, "project_id") # Mock googleapiclient.discovery.build chain mock_service = MagicMock() @@ -524,12 +524,8 @@ def test_gcp_dejson( "role_id": "role", } - with patch( - "builtins.open", mock_open(read_data='{"client_email": "service_account_email"}') - ) as mock_file: - test_hook = VaultHook(**kwargs) - test_client = test_hook.get_conn() - mock_file.assert_called_with("path.json") + test_hook = VaultHook(**kwargs) + test_client = test_hook.get_conn() mock_get_connection.assert_called_with("vault_conn_id") mock_get_scopes.assert_called_with("scope1,scope2") @@ -554,7 +550,9 @@ def test_gcp_dict_dejson( mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection mock_get_scopes.return_value = ["scope1", "scope2"] - mock_get_credentials.return_value = ("credentials", "project_id") + mock_credentials = MagicMock() + mock_credentials.service_account_email = "service_account_email" + mock_get_credentials.return_value = (mock_credentials, "project_id") # Mock googleapiclient.discovery.build chain mock_service = MagicMock()