diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py b/airflow/providers/microsoft/azure/hooks/container_registry.py index 2b9383e5d3f9c..ea7b129a461b5 100644 --- a/airflow/providers/microsoft/azure/hooks/container_registry.py +++ b/airflow/providers/microsoft/azure/hooks/container_registry.py @@ -21,12 +21,11 @@ from functools import cached_property from typing import Any -from azure.identity import DefaultAzureCredential from azure.mgmt.containerinstance.models import ImageRegistryCredential from azure.mgmt.containerregistry import ContainerRegistryManagementClient from airflow.hooks.base import BaseHook -from airflow.providers.microsoft.azure.utils import get_field +from airflow.providers.microsoft.azure.utils import get_default_azure_credential, get_field class AzureContainerRegistryHook(BaseHook): @@ -59,6 +58,12 @@ def get_connection_form_widgets() -> dict[str, Any]: lazy_gettext("Resource group name (optional)"), widget=BS3TextFieldWidget(), ), + "managed_identity_client_id": StringField( + lazy_gettext("Managed Identity Client ID"), widget=BS3TextFieldWidget() + ), + "workload_identity_tenant_id": StringField( + lazy_gettext("Workload Identity Tenant ID"), widget=BS3TextFieldWidget() + ), } @classmethod @@ -77,6 +82,8 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: "host": "docker image registry server", "subscription_id": "Subscription id (required for Azure AD authentication)", "resource_group": "Resource group name (required for Azure AD authentication)", + "managed_identity_client_id": "Managed Identity Client ID", + "workload_identity_tenant_id": "Workload Identity Tenant ID", }, } @@ -103,8 +110,13 @@ def get_conn(self) -> ImageRegistryCredential: extras = conn.extra_dejson subscription_id = self._get_field(extras, "subscription_id") resource_group = self._get_field(extras, "resource_group") + managed_identity_client_id = self._get_field(extras, "managed_identity_client_id") + workload_identity_tenant_id = self._get_field(extras, "workload_identity_tenant_id") client = ContainerRegistryManagementClient( - credential=DefaultAzureCredential(), subscription_id=subscription_id + credential=get_default_azure_credential( + managed_identity_client_id, workload_identity_tenant_id + ), + subscription_id=subscription_id, ) credentials = client.registries.list_credentials(resource_group, conn.login).as_dict() password = credentials["passwords"][0]["value"] diff --git a/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst b/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst index 913fcf795680d..c539d8bfefce8 100644 --- a/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst +++ b/docs/apache-airflow-providers-microsoft-azure/connections/acr.rst @@ -27,13 +27,13 @@ The Microsoft Azure Container Registry connection type enables the Azure Contain Authenticating to Azure Container Registry ------------------------------------------ -There is one way to connect to Azure Container Registry using Airflow. +There are three way to connect to Azure Container Registry using Airflow. 1. Use `Individual login with Azure AD `_ i.e. add specific credentials to the Airflow connection. -2. Fallback on `DefaultAzureCredential - `_. +2. Use managed identity by setting ``managed_identity_client_id``, ``workload_identity_tenant_id`` (under the hook, it uses DefaultAzureCredential_ with these arguments) +3. Fallback on DefaultAzureCredential_. This includes a mechanism to try different options to authenticate: Managed System Identity, environment variables, authentication through Azure CLI... Default Connection IDs @@ -48,7 +48,7 @@ Login Specify the Image Registry Username used for the initial connection. Password (optional) - Specify the Image Registry Password used for the initial connection. It can be left out to fall back on ``DefaultAzureCredential``. + Specify the Image Registry Password used for the initial connection. It can be left out to fall back on DefaultAzureCredential_. Host Specify the Image Registry Server used for the initial connection. @@ -63,6 +63,13 @@ Resource Group Name (optional) This is needed for Azure Active Directory (Azure AD) authentication. Use extra param ``resource_group`` to pass in the resource group name. +Managed Identity Client ID (optional) + The client ID of a user-assigned managed identity. If provided with ``workload_identity_tenant_id``, they'll pass to DefaultAzureCredential_. + +Workload Identity Tenant ID (optional) + ID of the application's Microsoft Entra tenant. Also called its "directory" ID. If provided with ``managed_identity_client_id``, they'll pass to DefaultAzureCredential_. + + When specifying the connection in environment variable you should specify it using URI syntax. @@ -73,3 +80,8 @@ For example: .. code-block:: bash export AIRFLOW_CONN_AZURE_CONTAINER_REGISTRY_DEFAULT='azure-container-registry://username:password@myregistry.com?tenant=tenant+id&account_name=store+name' + +.. _DefaultAzureCredential: https://docs.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#defaultazurecredential + +.. spelling:word-list:: + Entra diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py index a2b0635749b47..063a1290d2ef9 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py @@ -63,7 +63,7 @@ def test_get_conn(self, mocked_connection): @mock.patch( "airflow.providers.microsoft.azure.hooks.container_registry.ContainerRegistryManagementClient" ) - @mock.patch("airflow.providers.microsoft.azure.hooks.container_registry.DefaultAzureCredential") + @mock.patch("airflow.providers.microsoft.azure.hooks.container_registry.get_default_azure_credential") def test_get_conn_with_default_azure_credential( self, mocked_default_azure_credential, mocked_client, mocked_connection ): @@ -80,4 +80,4 @@ def test_get_conn_with_default_azure_credential( assert hook.connection.password == "password" assert hook.connection.server == "test.cr" - mocked_default_azure_credential.assert_called_with() + mocked_default_azure_credential.assert_called_with(None, None)