diff --git a/airflow-core/docs/administration-and-deployment/dag-bundles.rst b/airflow-core/docs/administration-and-deployment/dag-bundles.rst index 7e1eaf0123b20..095b7dce05e19 100644 --- a/airflow-core/docs/administration-and-deployment/dag-bundles.rst +++ b/airflow-core/docs/administration-and-deployment/dag-bundles.rst @@ -56,6 +56,9 @@ Airflow supports multiple types of Dag Bundles, each catering to specific use ca **airflow.providers.google.cloud.bundles.gcs.GCSDagBundle** These bundles reference a GCS bucket containing Dag files. They do not support versioning of the bundle, meaning tasks always run using the latest code. +**airflow.providers.microsoft.azure.bundles.wasb.WasbDagBundle** + These bundles reference an Azure Blob Storage container containing Dag files. They do not support versioning of the bundle, meaning tasks always run using the latest code. + Configuring Dag bundles ----------------------- diff --git a/providers/microsoft/azure/docs/bundles/index.rst b/providers/microsoft/azure/docs/bundles/index.rst new file mode 100644 index 0000000000000..89d4eeec1460f --- /dev/null +++ b/providers/microsoft/azure/docs/bundles/index.rst @@ -0,0 +1,80 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Bundles +####### + +Dag bundles allow Airflow to load Dags from external sources. For a general overview see +:doc:`apache-airflow:administration-and-deployment/dag-bundles`. + +WasbDagBundle +============= + +Use the :class:`~airflow.providers.microsoft.azure.bundles.wasb.WasbDagBundle` to configure an Azure Blob +Storage bundle in your Airflow's ``[dag_processor] dag_bundle_config_list``. The bundle does not support +versioning; tasks always use the latest blobs synced to the local bundle directory. + +Example of using the WasbDagBundle: + +**JSON format example**: + +.. code-block:: bash + + export AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST='[ + { + "name": "my-wasb-dags", + "classpath": "airflow.providers.microsoft.azure.bundles.wasb.WasbDagBundle", + "kwargs": { + "wasb_conn_id": "wasb_default", + "container_name": "airflow-dags", + "prefix": "dags/", + "refresh_interval": 60 + } + } + ]' + +Authentication +-------------- + +The bundle uses a ``wasb`` Connection (``wasb_conn_id``). Authentication is the same as for +:class:`~airflow.providers.microsoft.azure.hooks.wasb.WasbHook` — see :doc:`../connections/wasb`. On Azure-hosted +Airflow, managed identity via ``DefaultAzureCredential`` is typical. + +Permissions +----------- + +The identity needs read access to list and download blobs in the target container. Assign +`Storage Blob Data Reader `_ +at the storage account or container scope. + +Container and prefix +-------------------- + +Set ``container_name`` to the blob container that holds your Dag files. Use ``prefix`` for an optional +virtual folder inside the container. + +Networking +---------- + +The Dag processor needs outbound HTTPS to the blob endpoint. Storage firewalls and private endpoints +must allow access from Airflow, as for any WASB client. + +Reusing the Connection in Dags +------------------------------ + +You can use the same ``wasb`` Connection ID in ``wasb_conn_id`` for the bundle and for operators or sensors +that use ``WasbHook``. diff --git a/providers/microsoft/azure/docs/index.rst b/providers/microsoft/azure/docs/index.rst index 307bacafc51f0..d40aa62d772df 100644 --- a/providers/microsoft/azure/docs/index.rst +++ b/providers/microsoft/azure/docs/index.rst @@ -34,6 +34,7 @@ :maxdepth: 1 :caption: Guides + Bundles Connection types Message queues Operators diff --git a/providers/microsoft/azure/provider.yaml b/providers/microsoft/azure/provider.yaml index 71bac83b5f199..462eade93c06a 100644 --- a/providers/microsoft/azure/provider.yaml +++ b/providers/microsoft/azure/provider.yaml @@ -311,6 +311,11 @@ hooks: python-modules: - airflow.providers.microsoft.azure.hooks.powerbi +bundles: + - integration-name: Microsoft Azure Blob Storage + python-modules: + - airflow.providers.microsoft.azure.bundles.wasb + triggers: - integration-name: Microsoft Azure Batch python-modules: diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/bundles/__init__.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/bundles/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/bundles/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/bundles/wasb.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/bundles/wasb.py new file mode 100644 index 0000000000000..465f9d6725482 --- /dev/null +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/bundles/wasb.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from pathlib import Path + +import structlog + +from airflow.dag_processing.bundles.base import BaseDagBundle +from airflow.providers.microsoft.azure.hooks.wasb import WasbHook + + +class WasbDagBundle(BaseDagBundle): + """ + WASB Dag bundle - exposes a directory in Azure Blob Storage as a Dag bundle. + + This allows Airflow to load Dags directly from an Azure Blob Storage container. + + :param wasb_conn_id: Airflow connection ID for Azure Blob Storage. Defaults to WasbHook.default_conn_name. + :param container_name: The name of the blob container containing the Dag files. + :param prefix: Optional subdirectory within the container where the Dags are stored. + If empty, Dags are assumed to be at the root of the container. + """ + + supports_versioning = False + + def __init__( + self, + *, + wasb_conn_id: str = WasbHook.default_conn_name, + container_name: str, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.wasb_conn_id = wasb_conn_id + self.container_name = container_name + self.prefix = prefix + self.wasb_dags_dir: Path = self.base_dir + + log = structlog.get_logger(__name__) + self._log = log.bind( + bundle_name=self.name, + version=self.version, + container_name=self.container_name, + prefix=self.prefix, + wasb_conn_id=self.wasb_conn_id, + ) + self._wasb_hook: WasbHook | None = None + + def _initialize(self): + with self.lock(): + if not self.wasb_dags_dir.exists(): + self._log.info("Creating local Dags directory: %s", self.wasb_dags_dir) + os.makedirs(self.wasb_dags_dir) + + if not self.wasb_dags_dir.is_dir(): + raise NotADirectoryError(f"Local Dags path: {self.wasb_dags_dir} is not a directory.") + + if not self.wasb_hook.check_for_container(container_name=self.container_name): + raise ValueError(f"WASB container '{self.container_name}' does not exist.") + + if self.prefix: + if not self.wasb_hook.check_for_prefix( + container_name=self.container_name, prefix=self.prefix, delimiter="/" + ): + raise ValueError( + f"WASB prefix 'wasb://{self.container_name}/{self.prefix}' does not exist." + ) + self.refresh() + + def initialize(self) -> None: + self._initialize() + super().initialize() + + @property + def wasb_hook(self): + if self._wasb_hook is None: + self._wasb_hook = WasbHook(wasb_conn_id=self.wasb_conn_id) + return self._wasb_hook + + def __repr__(self): + return ( + f"" + ) + + def get_current_version(self) -> str | None: + """Return the current version of the Dag bundle. Currently not supported.""" + return None + + @property + def path(self) -> Path: + """Return the local path to the Dag files.""" + return self.wasb_dags_dir + + def refresh(self) -> None: + """Refresh the Dag bundle by re-downloading the Dags from Azure Blob Storage.""" + if self.version: + raise ValueError("Refreshing a specific version is not supported") + + with self.lock(): + self._log.debug( + "Downloading Dags from wasb://%s/%s to %s", + self.container_name, + self.prefix, + self.wasb_dags_dir, + ) + self.wasb_hook.sync_to_local_dir( + container_name=self.container_name, + prefix=self.prefix, + local_dir=self.wasb_dags_dir, + delete_stale=True, + ) + + def view_url(self, version: str | None = None) -> str | None: + """ + Return a URL for viewing the Dags in Azure Blob Storage. Currently, versioning is not supported. + + This method is deprecated and will be removed when the minimum supported Airflow version is 3.1. + Use `view_url_template` instead. + """ + return self.view_url_template() + + def view_url_template(self) -> str | None: + """Return a URL for viewing the Dags in Azure Blob Storage. Currently, versioning is not supported.""" + if self.version: + raise ValueError("WASB url with version is not supported") + if hasattr(self, "_view_url_template") and self._view_url_template: + return self._view_url_template + account_url = self.wasb_hook.blob_service_client.url + url = f"{account_url.rstrip('/')}/{self.container_name}" + if self.prefix: + url += f"/{self.prefix}" + return url diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py index ee2e388980fdc..e657f8b328655 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py @@ -283,6 +283,12 @@ def get_provider_info(): "python-modules": ["airflow.providers.microsoft.azure.hooks.powerbi"], }, ], + "bundles": [ + { + "integration-name": "Microsoft Azure Blob Storage", + "python-modules": ["airflow.providers.microsoft.azure.bundles.wasb"], + } + ], "triggers": [ { "integration-name": "Microsoft Azure Batch", diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py index c363c227690ed..d4bfb248932cc 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py @@ -28,6 +28,7 @@ import logging import os +from pathlib import Path from typing import TYPE_CHECKING, Any, cast from azure.core.credentials import AzureSasCredential @@ -274,6 +275,17 @@ def check_for_prefix(self, container_name: str, prefix: str, **kwargs) -> bool: blobs = self.get_blobs_list(container_name=container_name, prefix=prefix, **kwargs) return bool(blobs) + def check_for_container(self, container_name: str) -> bool: + """ + Check if a container exists on Azure Blob Storage. + + :param container_name: Name of the container. + :return: True if the container exists, False otherwise. + """ + container = self._get_container_client(container_name) + self.check_for_variable_type("container", container, ContainerClient) + return cast("ContainerClient", container).exists() + def check_for_variable_type(self, variable_name: str, container: Any, expected_type: type[Any]) -> None: if not isinstance(container, expected_type): raise TypeError( @@ -463,6 +475,98 @@ def download( # TODO: rework the interface as it might also return Awaitable return blob_client.download_blob(offset=offset, length=length, **kwargs) # type: ignore[return-value] + def _sync_to_local_dir_delete_stale_local_files( + self, current_wasb_objects: list[Path], local_dir: Path + ) -> None: + current_wasb_keys = {key.resolve() for key in current_wasb_objects} + + for item in local_dir.rglob("*"): + if item.is_file() and item.resolve() not in current_wasb_keys: + self.log.debug("Deleted stale local file: %s", item) + item.unlink() + for root, dirs, _ in os.walk(local_dir, topdown=False): + for d in dirs: + dir_path = os.path.join(root, d) + if not os.listdir(dir_path): + self.log.debug("Deleted stale empty directory: %s", dir_path) + os.rmdir(dir_path) + + def _sync_to_local_dir_if_changed( + self, container_name: str, blob: BlobProperties, local_target_path: Path + ) -> None: + should_download = False + download_logs: list[str] = [] + download_log_params: list[Any] = [] + + if not local_target_path.exists(): + should_download = True + download_logs.append("Local file %s does not exist.") + download_log_params.append(local_target_path) + else: + local_stats = local_target_path.stat() + if blob.size != local_stats.st_size: + should_download = True + download_logs.append("Blob size (%s) and local file size (%s) differ.") + download_log_params.extend([blob.size, local_stats.st_size]) + + blob_last_modified = blob.last_modified + if blob_last_modified and local_stats.st_mtime < blob_last_modified.timestamp(): + should_download = True + download_logs.append("Blob last modified (%s) and local file last modified (%s) differ.") + download_log_params.extend([blob_last_modified.timestamp(), local_stats.st_mtime]) + + if should_download: + self.get_file( + file_path=str(local_target_path), + container_name=container_name, + blob_name=blob.name, + ) + download_logs.append("Downloaded %s to %s") + download_log_params.extend([blob.name, local_target_path.as_posix()]) + self.log.debug(" ".join(download_logs), *download_log_params) + else: + self.log.debug( + "Local file %s is up-to-date with blob %s. Skipping download.", + local_target_path.as_posix(), + blob.name, + ) + + def sync_to_local_dir( + self, + container_name: str, + local_dir: Path, + prefix: str = "", + delete_stale: bool = True, + ) -> None: + """Download files from an Azure Blob Storage container to a local directory.""" + self.log.debug("Downloading data from wasb://%s/%s to %s", container_name, prefix, local_dir) + + local_wasb_objects: list[Path] = [] + container = self._get_container_client(container_name) + self.check_for_variable_type("container", container, ContainerClient) + container = cast("ContainerClient", container) + + for blob in container.list_blobs(name_starts_with=prefix or None): + if blob.name.endswith("/"): + continue + blob_path = Path(blob.name) + if prefix: + local_target_path = local_dir.joinpath(blob_path.relative_to(prefix)) + else: + local_target_path = local_dir.joinpath(blob_path) + if not local_target_path.parent.exists(): + local_target_path.parent.mkdir(parents=True, exist_ok=True) + self.log.debug("Created local directory: %s", local_target_path.parent) + self._sync_to_local_dir_if_changed( + container_name=container_name, blob=blob, local_target_path=local_target_path + ) + local_wasb_objects.append(local_target_path) + + if delete_stale: + self._sync_to_local_dir_delete_stale_local_files( + current_wasb_objects=local_wasb_objects, local_dir=local_dir + ) + def create_container(self, container_name: str) -> None: """ Create container object if not already existing. diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/bundles/__init__.py b/providers/microsoft/azure/tests/unit/microsoft/azure/bundles/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/bundles/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/bundles/test_wasb.py b/providers/microsoft/azure/tests/unit/microsoft/azure/bundles/test_wasb.py new file mode 100644 index 0000000000000..d13bc3c612419 --- /dev/null +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/bundles/test_wasb.py @@ -0,0 +1,263 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, PropertyMock, call, patch + +import pytest + +import airflow.version +from airflow.models import Connection + +from tests_common.test_utils.config import conf_vars + +if airflow.version.version.strip().startswith("3"): + from airflow.providers.microsoft.azure.bundles.wasb import WasbDagBundle + +WASB_CONN_ID = "wasb_dags_connection" +CONTAINER_NAME = "my-airflow-dags-container" +CONTAINER_PREFIX = "project1/dags" +ACCOUNT_NAME = "myaccount" +ACCOUNT_URL = f"https://{ACCOUNT_NAME}.blob.core.windows.net" + + +@pytest.fixture(autouse=True) +def bundle_temp_dir(tmp_path): + with conf_vars({("dag_processor", "dag_bundle_storage_path"): str(tmp_path)}): + yield tmp_path + + +@pytest.mark.skipif(not airflow.version.version.strip().startswith("3"), reason="Airflow >=3.0.0 test") +class TestWasbDagBundle: + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + create_connection_without_db( + Connection( + conn_id=WASB_CONN_ID, + conn_type="wasb", + login=ACCOUNT_NAME, + password="password", + ) + ) + + @patch( + "airflow.providers.microsoft.azure.bundles.wasb.WasbDagBundle.wasb_hook", new_callable=PropertyMock + ) + def test_view_url_generates_blob_url(self, mock_wasb_hook_property): + mock_hook = MagicMock() + mock_hook.blob_service_client.url = ACCOUNT_URL + mock_wasb_hook_property.return_value = mock_hook + + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix=CONTAINER_PREFIX, + container_name=CONTAINER_NAME, + ) + url: str = bundle.view_url() + assert url == f"{ACCOUNT_URL}/{CONTAINER_NAME}/{CONTAINER_PREFIX}" + + @patch( + "airflow.providers.microsoft.azure.bundles.wasb.WasbDagBundle.wasb_hook", new_callable=PropertyMock + ) + def test_view_url_template_generates_blob_url(self, mock_wasb_hook_property): + mock_hook = MagicMock() + mock_hook.blob_service_client.url = ACCOUNT_URL + mock_wasb_hook_property.return_value = mock_hook + + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix=CONTAINER_PREFIX, + container_name=CONTAINER_NAME, + ) + url: str = bundle.view_url_template() + assert url == f"{ACCOUNT_URL}/{CONTAINER_NAME}/{CONTAINER_PREFIX}" + + def test_supports_versioning(self): + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix=CONTAINER_PREFIX, + container_name=CONTAINER_NAME, + ) + assert WasbDagBundle.supports_versioning is False + + bundle.version = "test_version" + + with pytest.raises(ValueError, match="Refreshing a specific version is not supported"): + bundle.refresh() + with pytest.raises(ValueError, match="WASB url with version is not supported"): + bundle.view_url("test_version") + + def test_local_dags_path_is_not_a_directory(self, bundle_temp_dir): + bundle_name = "test" + file_path = bundle_temp_dir / bundle_name + file_path.touch() + + bundle = WasbDagBundle( + name=bundle_name, + wasb_conn_id=WASB_CONN_ID, + prefix="project1_dags", + container_name="airflow_dags", + ) + with pytest.raises(NotADirectoryError, match=f"Local Dags path: {file_path} is not a directory."): + bundle.initialize() + + def test_correct_bundle_path_used(self): + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix="project1_dags", + container_name="airflow_dags", + ) + assert str(bundle.base_dir) == str(bundle.wasb_dags_dir) + + @patch( + "airflow.providers.microsoft.azure.bundles.wasb.WasbDagBundle.wasb_hook", new_callable=PropertyMock + ) + def test_wasb_container_and_prefix_validated(self, mock_wasb_hook_property): + mock_hook = MagicMock() + mock_wasb_hook_property.return_value = mock_hook + + mock_hook.check_for_container.return_value = False + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix="project1_dags", + container_name="non-existing-container", + ) + with pytest.raises(ValueError, match="WASB container 'non-existing-container' does not exist."): + bundle.initialize() + mock_hook.check_for_container.assert_called_once_with(container_name="non-existing-container") + + mock_hook.check_for_container.return_value = True + mock_hook.check_for_prefix.return_value = False + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix="non-existing-prefix", + container_name=CONTAINER_NAME, + ) + with pytest.raises( + ValueError, + match=f"WASB prefix 'wasb://{CONTAINER_NAME}/non-existing-prefix' does not exist.", + ): + bundle.initialize() + mock_hook.check_for_prefix.assert_called_once_with( + container_name=CONTAINER_NAME, prefix="non-existing-prefix", delimiter="/" + ) + + mock_hook.check_for_prefix.return_value = True + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix=CONTAINER_PREFIX, + container_name=CONTAINER_NAME, + ) + bundle.initialize() + + mock_hook.check_for_prefix.reset_mock() + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix="", + container_name=CONTAINER_NAME, + ) + bundle.initialize() + mock_hook.check_for_prefix.assert_not_called() + + @patch( + "airflow.providers.microsoft.azure.bundles.wasb.WasbDagBundle.wasb_hook", new_callable=PropertyMock + ) + def test_refresh(self, mock_wasb_hook_property): + mock_hook = MagicMock() + mock_hook.check_for_container.return_value = True + mock_hook.check_for_prefix.return_value = True + mock_wasb_hook_property.return_value = mock_hook + + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + prefix=CONTAINER_PREFIX, + container_name=CONTAINER_NAME, + ) + bundle._log.debug = MagicMock() + download_log_call = call( + "Downloading Dags from wasb://%s/%s to %s", + CONTAINER_NAME, + CONTAINER_PREFIX, + bundle.wasb_dags_dir, + ) + sync_call = call( + container_name=CONTAINER_NAME, + prefix=CONTAINER_PREFIX, + local_dir=bundle.wasb_dags_dir, + delete_stale=True, + ) + + bundle.initialize() + assert bundle._log.debug.call_count == 1 + assert bundle._log.debug.call_args_list == [download_log_call] + assert mock_hook.sync_to_local_dir.call_count == 1 + assert mock_hook.sync_to_local_dir.call_args_list == [sync_call] + + bundle.refresh() + assert bundle._log.debug.call_count == 2 + assert bundle._log.debug.call_args_list == [download_log_call, download_log_call] + assert mock_hook.sync_to_local_dir.call_count == 2 + assert mock_hook.sync_to_local_dir.call_args_list == [sync_call, sync_call] + + @patch( + "airflow.providers.microsoft.azure.bundles.wasb.WasbDagBundle.wasb_hook", new_callable=PropertyMock + ) + def test_refresh_without_prefix(self, mock_wasb_hook_property): + mock_hook = MagicMock() + mock_hook.check_for_container.return_value = True + mock_wasb_hook_property.return_value = mock_hook + + bundle = WasbDagBundle( + name="test", + wasb_conn_id=WASB_CONN_ID, + container_name=CONTAINER_NAME, + ) + bundle._log.debug = MagicMock() + download_log_call = call( + "Downloading Dags from wasb://%s/%s to %s", + CONTAINER_NAME, + "", + bundle.wasb_dags_dir, + ) + sync_call = call( + container_name=CONTAINER_NAME, + prefix="", + local_dir=bundle.wasb_dags_dir, + delete_stale=True, + ) + + assert bundle.prefix == "" + bundle.initialize() + assert bundle._log.debug.call_count == 1 + assert bundle._log.debug.call_args_list == [download_log_call] + assert mock_hook.sync_to_local_dir.call_count == 1 + assert mock_hook.sync_to_local_dir.call_args_list == [sync_call] + + bundle.refresh() + assert bundle._log.debug.call_count == 2 + assert bundle._log.debug.call_args_list == [download_log_call, download_log_call] + assert mock_hook.sync_to_local_dir.call_count == 2 + assert mock_hook.sync_to_local_dir.call_args_list == [sync_call, sync_call] diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_wasb.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_wasb.py index 45670510944ca..0d9087bf84b19 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_wasb.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_wasb.py @@ -19,6 +19,7 @@ import os import re +from datetime import datetime, timezone from unittest import mock from unittest.mock import create_autospec @@ -457,6 +458,37 @@ def test_check_for_prefix_empty(self, get_blobs_list): assert not hook.check_for_prefix("container", "prefix", timeout=3) get_blobs_list.assert_called_once_with(container_name="container", prefix="prefix", timeout=3) + def test_check_for_container(self, mocked_blob_service_client): + mock_container = create_autospec(ContainerClient, instance=True) + mock_container.exists.return_value = True + mocked_blob_service_client.return_value.get_container_client.return_value = mock_container + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) + assert hook.check_for_container("mycontainer") is True + mocked_blob_service_client.return_value.get_container_client.assert_called_once_with("mycontainer") + mock_container.exists.assert_called_once_with() + + def test_check_for_container_not_found(self, mocked_blob_service_client): + mock_container = create_autospec(ContainerClient, instance=True) + mock_container.exists.return_value = False + mocked_blob_service_client.return_value.get_container_client.return_value = mock_container + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) + assert hook.check_for_container("missing-container") is False + mock_container.exists.assert_called_once_with() + + def test_check_for_container_raises_type_error_for_invalid_client(self, mocked_blob_service_client): + mocked_blob_service_client.return_value.get_container_client.return_value = mock.MagicMock() + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) + with pytest.raises(TypeError, match="container for WasbHook must be ContainerClient"): + hook.check_for_container("mycontainer") + + def test_check_for_container_propagates_unexpected_errors(self, mocked_blob_service_client): + mock_container = create_autospec(ContainerClient, instance=True) + mock_container.exists.side_effect = OSError("network down") + mocked_blob_service_client.return_value.get_container_client.return_value = mock_container + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) + with pytest.raises(OSError, match="network down"): + hook.check_for_container("mycontainer") + def test_get_blobs_list(self, mocked_blob_service_client): mock_container = create_autospec(ContainerClient, instance=True) mocked_blob_service_client.return_value.get_container_client.return_value = mock_container @@ -565,6 +597,101 @@ def test_download(self, mocked_blob_service_client): blob_client.assert_called_once_with(container="mycontainer", blob="myblob") blob_client.return_value.download_blob.assert_called_once_with(offset=2, length=4) + def test_sync_to_local_dir_behaviour(self, mocked_blob_service_client, tmp_path): + def get_logs_string(call_args_list): + return "".join([args[0][0] % args[0][1:] for args in call_args_list]) + + def make_blob(name, size=9, last_modified=None): + blob = mock.MagicMock(name=f"BLOB:{name}") + blob.name = name + blob.size = size + blob.last_modified = last_modified or datetime(2026, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + return blob + + container_name = "test_container" + mock_container = create_autospec(ContainerClient, instance=True) + mocked_blob_service_client.return_value.get_container_client.return_value = mock_container + + blobs = [ + make_blob("dag_01.py"), + make_blob("dag_02.py"), + make_blob("subproject1/dag_a.py"), + make_blob("subproject1/dag_b.py"), + ] + mock_container.list_blobs.return_value = blobs + + sync_local_dir = tmp_path / "wasb_sync_dir" + hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) + hook.log.debug = mock.MagicMock() + hook.get_file = mock.MagicMock() + + hook.sync_to_local_dir( + container_name=container_name, local_dir=sync_local_dir, prefix="", delete_stale=True + ) + logs_string = get_logs_string(hook.log.debug.call_args_list) + assert f"Downloading data from wasb://{container_name}/ to {sync_local_dir}" in logs_string + assert f"Local file {sync_local_dir}/dag_01.py does not exist." in logs_string + assert f"Downloaded dag_01.py to {sync_local_dir.as_posix()}/dag_01.py" in logs_string + assert f"Local file {sync_local_dir}/subproject1/dag_a.py does not exist." in logs_string + assert ( + f"Downloaded subproject1/dag_a.py to {sync_local_dir.as_posix()}/subproject1/dag_a.py" + in logs_string + ) + assert hook.get_file.call_count == 4 + + for blob in blobs: + p = sync_local_dir / blob.name + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text("test data") + os.utime(p, (blob.last_modified.timestamp(), blob.last_modified.timestamp())) + + hook.log.debug = mock.MagicMock() + hook.get_file.reset_mock() + new_blob = make_blob("dag_03.py") + mock_container.list_blobs.return_value = blobs + [new_blob] + hook.sync_to_local_dir( + container_name=container_name, local_dir=sync_local_dir, prefix="", delete_stale=True + ) + logs_string = get_logs_string(hook.log.debug.call_args_list) + assert ( + f"Local file {(sync_local_dir / 'subproject1' / 'dag_b.py').as_posix()} is up-to-date " + "with blob subproject1/dag_b.py. Skipping download." + ) in logs_string + assert f"Local file {sync_local_dir}/dag_03.py does not exist." in logs_string + assert f"Downloaded dag_03.py to {sync_local_dir.as_posix()}/dag_03.py" in logs_string + hook.get_file.assert_called_once() + (sync_local_dir / "dag_03.py").write_text("test data") + os.utime( + sync_local_dir / "dag_03.py", + (new_blob.last_modified.timestamp(), new_blob.last_modified.timestamp()), + ) + + local_file_that_should_be_deleted = sync_local_dir / "file_that_should_be_deleted.py" + local_file_that_should_be_deleted.write_text("test dag") + local_folder_should_be_deleted = sync_local_dir / "local_folder_should_be_deleted" + local_folder_should_be_deleted.mkdir(exist_ok=True) + hook.log.debug = mock.MagicMock() + hook.get_file.reset_mock() + hook.sync_to_local_dir( + container_name=container_name, local_dir=sync_local_dir, prefix="", delete_stale=True + ) + logs_string = get_logs_string(hook.log.debug.call_args_list) + assert f"Deleted stale local file: {local_file_that_should_be_deleted.as_posix()}" in logs_string + assert f"Deleted stale empty directory: {local_folder_should_be_deleted.as_posix()}" in logs_string + assert not hook.get_file.called + + hook.log.debug = mock.MagicMock() + hook.get_file.reset_mock() + updated_blob = make_blob("dag_03.py", size=15) + mock_container.list_blobs.return_value = blobs + [updated_blob] + hook.sync_to_local_dir( + container_name=container_name, local_dir=sync_local_dir, prefix="", delete_stale=True + ) + logs_string = get_logs_string(hook.log.debug.call_args_list) + assert "Blob size (15) and local file size (9) differ." in logs_string + assert f"Downloaded dag_03.py to {sync_local_dir.as_posix()}/dag_03.py" in logs_string + hook.get_file.assert_called_once() + def test_get_container_client(self, mocked_blob_service_client): hook = WasbHook(wasb_conn_id=self.azure_shared_key_test) hook._get_container_client("mycontainer")