diff --git a/airflow-core/src/airflow/provider.yaml.schema.json b/airflow-core/src/airflow/provider.yaml.schema.json index 50dacbc79634c..49e1426261620 100644 --- a/airflow-core/src/airflow/provider.yaml.schema.json +++ b/airflow-core/src/airflow/provider.yaml.schema.json @@ -199,6 +199,27 @@ ] } }, + "object-storage-providers": { + "type": "array", + "description": "Array of object storage providers mapped to provider class names", + "items": { + "type": "object", + "properties": { + "storage-type": { + "description": "Type of object storage (e.g. s3, gcs, azure)", + "type": "string" + }, + "provider-class-name": { + "description": "Provider class name that implements the ObjectStorageProvider", + "type": "string" + } + }, + "required": [ + "storage-type", + "provider-class-name" + ] + } + }, "hooks": { "type": "array", "items": { diff --git a/airflow-core/src/airflow/provider_info.schema.json b/airflow-core/src/airflow/provider_info.schema.json index b3e9bf74b6284..0ef7de0c35563 100644 --- a/airflow-core/src/airflow/provider_info.schema.json +++ b/airflow-core/src/airflow/provider_info.schema.json @@ -206,6 +206,23 @@ } } }, + "object-storage-providers": { + "description": "List of object storage providers the provider provides", + "type": "array", + "items": { + "type": "object", + "properties": { + "storage-type": { + "description": "Type of object storage (e.g. s3, gcs, azure)", + "type": "string" + }, + "provider-class-name": { + "description": "Class name that implements the ObjectStorageProvider", + "type": "string" + } + } + } + }, "transfers": { "description": "List of transfer operators the provider provides", "type": "array", diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py index 8e270209a8b0a..b36a4cc4d6d1f 100644 --- a/airflow-core/src/airflow/providers_manager.py +++ b/airflow-core/src/airflow/providers_manager.py @@ -208,6 +208,14 @@ class DialectInfo(NamedTuple): provider_name: str +class ObjectStorageProviderInfo(NamedTuple): + """Object storage provider class and Provider it comes from.""" + + name: str + provider_class_name: str + provider_name: str + + class TriggerInfo(NamedTuple): """Trigger class and provider it comes from.""" @@ -428,6 +436,7 @@ def __init__(self): # keeps mapping between connection_types and hook class, package they come from self._hook_provider_dict: dict[str, HookClassProvider] = {} self._dialect_provider_dict: dict[str, DialectInfo] = {} + self._object_storage_provider_dict: dict[str, ObjectStorageProviderInfo] = {} # Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache() # Keeps hook display names read from provider.yaml (hook-name field) @@ -789,6 +798,7 @@ def _discover_hooks(self) -> None: duplicated_connection_types: set[str] = set() hook_class_names_registered: set[str] = set() self._discover_provider_dialects(package_name, provider) + self._discover_object_storage_providers(package_name, provider) provider_uses_connection_types = self._discover_hooks_from_connection_types( hook_class_names_registered, duplicated_connection_types, package_name, provider ) @@ -815,6 +825,24 @@ def _discover_provider_dialects(self, provider_name: str, provider: ProviderInfo } ) + def _discover_object_storage_providers(self, provider_name: str, provider: ProviderInfo): + entries = provider.data.get("object-storage-providers", []) + for item in entries: + storage_type = item["storage-type"] + if storage_type in self._object_storage_provider_dict: + existing = self._object_storage_provider_dict[storage_type] + log.warning( + "ObjectStorageProvider for '%s' already registered by %s, overriding with %s", + storage_type, + existing.provider_name, + provider_name, + ) + self._object_storage_provider_dict[storage_type] = ObjectStorageProviderInfo( + name=storage_type, + provider_class_name=item["provider-class-name"], + provider_name=provider_name, + ) + @provider_info_cache("import_all_hooks") def _import_info_from_all_hooks(self): """Force-import all hooks and initialize the connections/fields.""" @@ -1450,6 +1478,12 @@ def dialects(self) -> MutableMapping[str, DialectInfo]: # When we return dialects here it will only be used to retrieve dialect information return self._dialect_provider_dict + @property + def object_storage_providers(self) -> MutableMapping[str, ObjectStorageProviderInfo]: + """Return dictionary of storage-type to ObjectStorageProviderInfo mapping.""" + self.initialize_providers_hooks() + return self._object_storage_provider_dict + @property def plugins(self) -> list[PluginInfo]: """Returns information about plugins available in providers.""" @@ -1579,6 +1613,7 @@ def _cleanup(self): self._taskflow_decorators.clear() self._hook_provider_dict.clear() self._dialect_provider_dict.clear() + self._object_storage_provider_dict.clear() self._hooks_lazy_dict.clear() self._connection_form_widgets.clear() self._field_behaviours.clear() diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py index 25e3774d4f7e7..673e317534967 100644 --- a/airflow-core/tests/unit/always/test_providers_manager.py +++ b/airflow-core/tests/unit/always/test_providers_manager.py @@ -34,6 +34,7 @@ from airflow.providers_manager import ( DialectInfo, LazyDictWithCache, + ObjectStorageProviderInfo, PluginInfo, ProviderInfo, ProvidersManager, @@ -341,6 +342,101 @@ def test_dialects(self): assert len(dialect_class_names) == 3 assert dialect_class_names == ["default", "mssql", "postgresql"] + def test_object_storage_providers(self): + provider_manager = ProvidersManager() + storage_types = sorted(provider_manager.object_storage_providers) + assert "s3" in storage_types + info = provider_manager.object_storage_providers["s3"] + assert ( + info.provider_class_name + == "airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider" + ) + assert info.provider_name == "apache-airflow-providers-amazon" + + def test_discover_object_storage_providers(self): + providers_manager = ProvidersManager() + providers_manager._provider_dict = LazyDictWithCache() + providers_manager._provider_dict["airflow.providers.amazon"] = ProviderInfo( + version="1.0.0", + data={ + "object-storage-providers": [ + { + "storage-type": "s3", + "provider-class-name": "airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider", + } + ] + }, + ) + providers_manager._discover_hooks() + assert len(providers_manager._object_storage_provider_dict) == 1 + assert providers_manager._object_storage_provider_dict["s3"] == ObjectStorageProviderInfo( + name="s3", + provider_class_name="airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider", + provider_name="airflow.providers.amazon", + ) + + def test_discover_object_storage_providers_empty(self): + providers_manager = ProvidersManager() + providers_manager._provider_dict = LazyDictWithCache() + providers_manager._provider_dict["airflow.providers.example"] = ProviderInfo( + version="1.0.0", + data={}, + ) + providers_manager._discover_hooks() + assert len(providers_manager._object_storage_provider_dict) == 0 + + def test_discover_object_storage_providers_duplicate_logs_warning(self, caplog): + providers_manager = ProvidersManager() + providers_manager._provider_dict = LazyDictWithCache() + providers_manager._provider_dict["airflow.providers.first"] = ProviderInfo( + version="1.0.0", + data={ + "object-storage-providers": [ + { + "storage-type": "s3", + "provider-class-name": "first.S3Provider", + } + ] + }, + ) + providers_manager._provider_dict["airflow.providers.second"] = ProviderInfo( + version="1.0.0", + data={ + "object-storage-providers": [ + { + "storage-type": "s3", + "provider-class-name": "second.S3Provider", + } + ] + }, + ) + with caplog.at_level(logging.WARNING): + providers_manager._discover_hooks() + assert "already registered" in caplog.text + assert ( + providers_manager._object_storage_provider_dict["s3"].provider_class_name == "second.S3Provider" + ) + + def test_object_storage_providers_property(self): + providers_manager = ProvidersManager() + providers_manager._provider_dict = LazyDictWithCache() + providers_manager._provider_dict["airflow.providers.amazon"] = ProviderInfo( + version="1.0.0", + data={ + "object-storage-providers": [ + { + "storage-type": "s3", + "provider-class-name": "airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider", + } + ] + }, + ) + providers_manager._discover_hooks() + osp = providers_manager._object_storage_provider_dict + assert isinstance(osp, dict) + assert "s3" in osp + assert osp["s3"].provider_name == "airflow.providers.amazon" + class TestWithoutCheckProviderManager: @pytest.fixture(autouse=True) diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index 2ebbef397176b..2c1803c6bb2aa 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -646,6 +646,10 @@ dataset-uris: filesystems: - airflow.providers.amazon.aws.fs.s3 +object-storage-providers: + - storage-type: s3 + provider-class-name: airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider + hooks: - integration-name: Amazon Athena python-modules: diff --git a/providers/amazon/pyproject.toml b/providers/amazon/pyproject.toml index ef24eaea8f167..53c70a1d5cdd8 100644 --- a/providers/amazon/pyproject.toml +++ b/providers/amazon/pyproject.toml @@ -94,6 +94,9 @@ dependencies = [ "cncf.kubernetes" = [ "apache-airflow-providers-cncf-kubernetes>=7.2.0", ] +"datafusion" = [ + "datafusion>=50.0.0,<52.0.0", +] "s3fs" = [ "s3fs>=2023.10.0", ] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/datafusion/__init__.py b/providers/amazon/src/airflow/providers/amazon/aws/datafusion/__init__.py new file mode 100644 index 0000000000000..21d298ede6ed3 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/datafusion/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations diff --git a/providers/amazon/src/airflow/providers/amazon/aws/datafusion/object_storage.py b/providers/amazon/src/airflow/providers/amazon/aws/datafusion/object_storage.py new file mode 100644 index 0000000000000..b160779b84cfd --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/datafusion/object_storage.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datafusion.object_store import AmazonS3 + +from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook +from airflow.providers.common.compat.sdk import BaseHook +from airflow.providers.common.sql.config import ConnectionConfig, StorageType +from airflow.providers.common.sql.datafusion.base import ObjectStorageProvider +from airflow.providers.common.sql.datafusion.exceptions import ObjectStoreCreationException + + +class S3ObjectStorageProvider(ObjectStorageProvider): + """S3 Object Storage Provider using DataFusion's AmazonS3.""" + + @property + def get_storage_type(self) -> StorageType: + """Return the storage type.""" + return StorageType.S3 + + def create_object_store(self, path: str, connection_config: ConnectionConfig | None = None): + """Create an S3 object store using DataFusion's AmazonS3.""" + if connection_config is None: + raise ValueError(f"connection_config must be provided for {self.get_storage_type}") + + try: + conn = BaseHook.get_connection(connection_config.conn_id) + aws_hook: AwsGenericHook = AwsGenericHook(aws_conn_id=conn.conn_id, client_type="s3") + creds = aws_hook.get_credentials() + + credentials = { + "access_key_id": conn.login or creds.access_key, + "secret_access_key": conn.password or creds.secret_key, + "session_token": creds.token if creds.token else None, + } + credentials = {k: v for k, v in credentials.items() if v is not None} + extra_config = {k: conn.extra_dejson[k] for k in ["region", "endpoint"] if k in conn.extra_dejson} + + bucket = self.get_bucket(path) + s3_store = AmazonS3(**credentials, **extra_config, bucket_name=bucket) + self.log.info("Created S3 object store for bucket %s", bucket) + return s3_store + + except ObjectStoreCreationException: + raise + except Exception as e: + raise ObjectStoreCreationException(f"Failed to create S3 object store: {e}") + + def get_scheme(self) -> str: + """Return the scheme for S3.""" + return "s3://" diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 3a407bb1b2fa6..e815647a847dd 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -694,6 +694,12 @@ def get_provider_info(): }, ], "filesystems": ["airflow.providers.amazon.aws.fs.s3"], + "object-storage-providers": [ + { + "storage-type": "s3", + "provider-class-name": "airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider", + } + ], "hooks": [ { "integration-name": "Amazon Athena", diff --git a/providers/amazon/tests/unit/amazon/aws/datafusion/__init__.py b/providers/amazon/tests/unit/amazon/aws/datafusion/__init__.py new file mode 100644 index 0000000000000..21d298ede6ed3 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/datafusion/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations diff --git a/providers/amazon/tests/unit/amazon/aws/datafusion/test_object_storage.py b/providers/amazon/tests/unit/amazon/aws/datafusion/test_object_storage.py new file mode 100644 index 0000000000000..7e1965f1706c3 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/datafusion/test_object_storage.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.models import Connection +from airflow.providers.common.sql.config import ConnectionConfig, StorageType +from airflow.providers.common.sql.datafusion.exceptions import ObjectStoreCreationException + + +class TestS3ObjectStorageProvider: + """Tests for S3ObjectStorageProvider in the amazon provider package.""" + + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + create_connection_without_db( + Connection( + conn_id="aws_default", + conn_type="aws", + login="fake_id", + password="fake_secret", + extra='{"region": "us-east-1"}', + ) + ) + + @patch( + "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3", + autospec=True, + ) + @patch( + "airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook", + autospec=True, + ) + def test_s3_provider_with_login_password(self, mock_hook_cls, mock_s3): + """Login/password on the connection override hook credentials.""" + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + + mock_creds = MagicMock() + mock_creds.access_key = "hook_key" + mock_creds.secret_key = "hook_secret" + mock_creds.token = None + mock_hook_cls.return_value.get_credentials.return_value = mock_creds + + provider = S3ObjectStorageProvider() + config = ConnectionConfig(conn_id="aws_default") + + store = provider.create_object_store("s3://demo-data/path", connection_config=config) + + mock_s3.assert_called_once_with( + access_key_id="fake_id", + secret_access_key="fake_secret", + region="us-east-1", + bucket_name="demo-data", + ) + assert store == mock_s3.return_value + assert provider.get_storage_type == StorageType.S3 + assert provider.get_scheme() == "s3://" + + @patch( + "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3", + autospec=True, + ) + @patch( + "airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook", + autospec=True, + ) + def test_s3_provider_falls_back_to_hook_credentials(self, mock_hook_cls, mock_s3): + """When login/password are empty, hook credentials are used.""" + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + + mock_creds = MagicMock() + mock_creds.access_key = "hook_key" + mock_creds.secret_key = "hook_secret" + mock_creds.token = "session_tok" + mock_hook_cls.return_value.get_credentials.return_value = mock_creds + + provider = S3ObjectStorageProvider() + config = ConnectionConfig(conn_id="aws_no_login") + + with patch( + "airflow.providers.amazon.aws.datafusion.object_storage.BaseHook.get_connection", + return_value=Connection( + conn_id="aws_no_login", + conn_type="aws", + extra='{"endpoint": "http://localhost:4566"}', + ), + ): + store = provider.create_object_store("s3://bucket/path", connection_config=config) + + mock_s3.assert_called_once_with( + access_key_id="hook_key", + secret_access_key="hook_secret", + session_token="session_tok", + endpoint="http://localhost:4566", + bucket_name="bucket", + ) + assert store == mock_s3.return_value + + @patch( + "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3", + autospec=True, + ) + @patch( + "airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook", + autospec=True, + ) + def test_s3_provider_session_token(self, mock_hook_cls, mock_s3): + """Session token from hook is forwarded when present.""" + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + + mock_creds = MagicMock() + mock_creds.access_key = "hook_key" + mock_creds.secret_key = "hook_secret" + mock_creds.token = "my_session_token" + mock_hook_cls.return_value.get_credentials.return_value = mock_creds + + provider = S3ObjectStorageProvider() + config = ConnectionConfig(conn_id="aws_default") + + store = provider.create_object_store("s3://bucket/path", connection_config=config) + + call_kwargs = mock_s3.call_args.kwargs + assert call_kwargs["session_token"] == "my_session_token" + assert store == mock_s3.return_value + + def test_s3_provider_missing_connection_config(self): + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + + provider = S3ObjectStorageProvider() + with pytest.raises(ValueError, match="connection_config must be provided"): + provider.create_object_store("s3://bucket/path", connection_config=None) + + @patch( + "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3", + autospec=True, + ) + @patch( + "airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook", + autospec=True, + ) + def test_s3_provider_creation_failure(self, mock_hook_cls, mock_s3): + """Internal exceptions are wrapped in ObjectStoreCreationException.""" + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + + mock_creds = MagicMock() + mock_creds.access_key = "k" + mock_creds.secret_key = "s" + mock_creds.token = None + mock_hook_cls.return_value.get_credentials.return_value = mock_creds + mock_s3.side_effect = Exception("boom") + + provider = S3ObjectStorageProvider() + config = ConnectionConfig(conn_id="aws_default") + + with pytest.raises(ObjectStoreCreationException, match="Failed to create S3 object store"): + provider.create_object_store("s3://bucket/path", connection_config=config) + + def test_s3_provider_bucket_extraction(self): + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + + provider = S3ObjectStorageProvider() + assert provider.get_bucket("s3://my-bucket/prefix/file.parquet") == "my-bucket" + assert provider.get_bucket("s3://another-bucket/") == "another-bucket" + assert provider.get_bucket("file:///local/path") is None diff --git a/providers/common/sql/docs/operators.rst b/providers/common/sql/docs/operators.rst index 88e326a25fe5b..361019bab4d4b 100644 --- a/providers/common/sql/docs/operators.rst +++ b/providers/common/sql/docs/operators.rst @@ -266,7 +266,7 @@ The Analytics Operator is ideal for performing efficient, high-performance analy Supported Storage Systems ------------------------- -- S3 +- S3 (requires ``apache-airflow-providers-amazon[datafusion]``) - Local File System .. note:: diff --git a/providers/common/sql/pyproject.toml b/providers/common/sql/pyproject.toml index ddadb5ceb6908..7ca5a84b80fff 100644 --- a/providers/common/sql/pyproject.toml +++ b/providers/common/sql/pyproject.toml @@ -87,9 +87,6 @@ dependencies = [ "sqlalchemy" = [ "sqlalchemy>=1.4.54", ] -"amazon" = [ - "apache-airflow-providers-amazon" -] # DataFusion 52.0.0 crate is not supported at the moment with iceberg-core "datafusion" = [ "datafusion>=50.0.0,<52.0.0", @@ -100,6 +97,9 @@ dependencies = [ "apache.iceberg" = [ "apache-airflow-providers-apache-iceberg" ] +"amazon" = [ + "apache-airflow-providers-amazon" +] [dependency-groups] dev = [ diff --git a/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py b/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py index 21e0b72390b6f..66a749aa6d6a8 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py +++ b/providers/common/sql/src/airflow/providers/common/sql/datafusion/engine.py @@ -20,7 +20,6 @@ from datafusion import SessionContext -from airflow.providers.common.compat.sdk import BaseHook, Connection from airflow.providers.common.sql.config import ConnectionConfig, DataSourceConfig, StorageType from airflow.providers.common.sql.datafusion.exceptions import ( ObjectStoreCreationException, @@ -36,7 +35,6 @@ class DataFusionEngine(LoggingMixin): def __init__(self): super().__init__() - # TODO: session context has additional parameters via SessionConfig see what's possible we can use Possible via DataFusionHook ? self.df_ctx = SessionContext() self.registered_tables: dict[str, str] = {} @@ -122,61 +120,8 @@ def execute_query(self, query: str, max_rows: int | None = None) -> dict[str, li raise QueryExecutionException(f"Error while executing query: {e}") def _get_connection_config(self, conn_id: str) -> ConnectionConfig: - - airflow_conn = BaseHook.get_connection(conn_id) - - credentials, extra_config = self._get_credentials(airflow_conn) - - return ConnectionConfig( - conn_id=airflow_conn.conn_id, - credentials=credentials, - extra_config=extra_config, - ) - - def _get_credentials(self, conn: Connection) -> tuple[dict[str, Any], dict[str, Any]]: - - credentials = {} - extra_config = {} - - def _fetch_extra_configs(keys: list[str]) -> dict[str, Any]: - conf = {} - extra_dejson = conn.extra_dejson - for key in keys: - if key in extra_dejson: - conf[key] = conn.extra_dejson[key] - return conf - - match conn.conn_type: - case "aws": - try: - from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook - except ImportError: - from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException - - raise AirflowOptionalProviderFeatureException( - "Failed to import AwsGenericHook. To use the S3 storage functionality, please install the " - "apache-airflow-providers-amazon package." - ) - aws_hook: AwsGenericHook = AwsGenericHook(aws_conn_id=conn.conn_id, client_type="s3") - creds = aws_hook.get_credentials() - credentials.update( - { - "access_key_id": conn.login or creds.access_key, - "secret_access_key": conn.password or creds.secret_key, - "session_token": creds.token if creds.token else None, - } - ) - credentials = self._remove_none_values(credentials) - extra_config = _fetch_extra_configs(["region", "endpoint"]) - - case _: - raise ValueError(f"Unknown connection type {conn.conn_type}") - return credentials, extra_config - - @staticmethod - def _remove_none_values(params: dict[str, Any]) -> dict[str, Any]: - """Filter out None values from the dictionary.""" - return {k: v for k, v in params.items() if v is not None} + """Build a ConnectionConfig; credential resolution is delegated to the provider.""" + return ConnectionConfig(conn_id=conn_id) def get_schema(self, table_name: str): """Get the schema of a table.""" diff --git a/providers/common/sql/src/airflow/providers/common/sql/datafusion/object_storage_provider.py b/providers/common/sql/src/airflow/providers/common/sql/datafusion/object_storage_provider.py index 1ce917a1bb2da..5c94f28e1ba04 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/datafusion/object_storage_provider.py +++ b/providers/common/sql/src/airflow/providers/common/sql/datafusion/object_storage_provider.py @@ -16,41 +16,15 @@ # under the License. from __future__ import annotations -from datafusion.object_store import AmazonS3, LocalFileSystem +import warnings +from typing import Any +from datafusion.object_store import LocalFileSystem + +from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.common.compat.module_loading import import_string from airflow.providers.common.sql.config import ConnectionConfig, StorageType from airflow.providers.common.sql.datafusion.base import ObjectStorageProvider -from airflow.providers.common.sql.datafusion.exceptions import ObjectStoreCreationException - - -class S3ObjectStorageProvider(ObjectStorageProvider): - """S3 Object Storage Provider using DataFusion's AmazonS3.""" - - @property - def get_storage_type(self) -> StorageType: - """Return the storage type.""" - return StorageType.S3 - - def create_object_store(self, path: str, connection_config: ConnectionConfig | None = None): - """Create an S3 object store using DataFusion's AmazonS3.""" - if connection_config is None: - raise ValueError("connection_config must be provided for %s", self.get_storage_type) - - try: - credentials = connection_config.credentials - bucket = self.get_bucket(path) - - s3_store = AmazonS3(**credentials, **connection_config.extra_config, bucket_name=bucket) - self.log.info("Created S3 object store for bucket %s", bucket) - - return s3_store - - except Exception as e: - raise ObjectStoreCreationException(f"Failed to create S3 object store: {e}") - - def get_scheme(self) -> str: - """Return the scheme for S3.""" - return "s3://" class LocalObjectStorageProvider(ObjectStorageProvider): @@ -70,18 +44,58 @@ def get_scheme(self) -> str: return "file://" +_STORAGE_TYPE_PROVIDER_HINTS: dict[str, str] = { + "s3": "apache-airflow-providers-amazon[datafusion]", +} + + +def _missing_provider_message(type_key: str) -> str: + hint = _STORAGE_TYPE_PROVIDER_HINTS.get(type_key, "the appropriate provider package") + return f"No ObjectStorageProvider registered for storage type '{type_key}'. Install or upgrade {hint}." + + +def _get_legacy_object_storage_provider(type_key: str) -> ObjectStorageProvider: + if type_key == StorageType.S3.value: + try: + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + except ImportError as err: + raise ValueError(_missing_provider_message(type_key)) from err + return S3ObjectStorageProvider() + + raise ValueError(_missing_provider_message(type_key)) + + def get_object_storage_provider(storage_type: StorageType) -> ObjectStorageProvider: """Get an object storage provider based on the storage type.""" - # TODO: Add support for GCS, Azure, HTTP: https://datafusion.apache.org/python/autoapi/datafusion/object_store/index.html - providers: dict[StorageType, type] = { - StorageType.S3: S3ObjectStorageProvider, - StorageType.LOCAL: LocalObjectStorageProvider, - } - - if storage_type not in providers: - raise ValueError( - f"Unsupported storage type: {storage_type}. Supported types: {list(providers.keys())}" + if storage_type == StorageType.LOCAL: + return LocalObjectStorageProvider() + + type_key = storage_type.value + + from airflow.providers_manager import ProvidersManager + + manager = ProvidersManager() + if not hasattr(manager, "object_storage_providers"): + return _get_legacy_object_storage_provider(type_key) + + registry = manager.object_storage_providers + if type_key in registry: + provider_cls = import_string(registry[type_key].provider_class_name) + return provider_cls() + + raise ValueError(_missing_provider_message(type_key)) + + +def __getattr__(name: str) -> Any: + if name == "S3ObjectStorageProvider": + warnings.warn( + "Importing S3ObjectStorageProvider from " + "airflow.providers.common.sql.datafusion.object_storage_provider is deprecated. " + "Import it from airflow.providers.amazon.aws.datafusion.object_storage instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, ) + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider - provider_class = providers[storage_type] - return provider_class() + return S3ObjectStorageProvider + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py b/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py index d48c705046f5e..a3b8b1fc29b77 100644 --- a/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py +++ b/providers/common/sql/tests/unit/common/sql/datafusion/test_engine.py @@ -23,7 +23,6 @@ import pytest from datafusion import SessionContext -from airflow.models import Connection from airflow.providers.common.sql.config import ConnectionConfig, DataSourceConfig from airflow.providers.common.sql.datafusion.base import ObjectStorageProvider from airflow.providers.common.sql.datafusion.engine import DataFusionEngine @@ -32,30 +31,10 @@ QueryExecutionException, ) -TEST_CONNECTION_CONFIG = ConnectionConfig( - conn_id="aws_default", - credentials={ - "access_key_id": "test", - "secret_access_key": "test", - "session_token": None, - }, - extra_config={"region_name": "us-east-1"}, -) +TEST_CONNECTION_CONFIG = ConnectionConfig(conn_id="test_conn") class TestDataFusionEngine: - @pytest.fixture(autouse=True) - def setup_connections(self, create_connection_without_db): - create_connection_without_db( - Connection( - conn_id="aws_default", - conn_type="aws", - login="fake_id", - password="fake_secret", - extra='{"region": "us-east-1"}', - ) - ) - def test_init(self): engine = DataFusionEngine() assert engine.df_ctx is not None @@ -88,7 +67,7 @@ def test_register_datasource_success(self, mock_get_conn, mock_factory, storage_ engine = DataFusionEngine() datasource_config = DataSourceConfig( - conn_id="aws_default", table_name="test_table", uri=f"{scheme}://bucket/path", format=format + conn_id="test_conn", table_name="test_table", uri=f"{scheme}://bucket/path", format=format ) engine.df_ctx = MagicMock(spec=SessionContext) @@ -118,7 +97,7 @@ def test_register_datasource_object_store_exception(self, mock_get_conn, mock_fa engine = DataFusionEngine() datasource_config = DataSourceConfig( - conn_id="aws_default", table_name="test_table", uri="s3://bucket/path", format="parquet" + conn_id="test_conn", table_name="test_table", uri="s3://bucket/path", format="parquet" ) with pytest.raises(ObjectStoreCreationException, match="Error while creating object store"): @@ -131,7 +110,7 @@ def test_register_datasource_duplicate_table(self, mock_get_conn): engine.registered_tables["test_table"] = "s3://old/path" datasource_config = DataSourceConfig( - conn_id="aws_default", table_name="test_table", uri="s3://new/path", format="parquet" + conn_id="test_conn", table_name="test_table", uri="s3://new/path", format="parquet" ) with patch.object(engine, "_register_object_store"): @@ -235,7 +214,7 @@ def test_register_datasource_with_options(self, mock_get_conn, mock_factory): engine = DataFusionEngine() datasource_config = DataSourceConfig( - conn_id="aws_default", + conn_id="test_conn", table_name="test_table", uri="s3://bucket/path/", format="parquet", @@ -260,34 +239,13 @@ def test_register_datasource_with_options(self, mock_get_conn, mock_factory): assert engine.registered_tables == {"test_table": "s3://bucket/path/"} - def test_remove_none_values(self): - result = DataFusionEngine._remove_none_values({"a": 1, "b": None, "c": "test", "d": None}) - assert result == {"a": 1, "c": "test"} - - def test_get_connection_config(self): - + def test_get_connection_config_delegates_to_provider(self): + """_get_connection_config only passes conn_id; credential resolution is the provider's job.""" engine = DataFusionEngine() - - result = engine._get_connection_config("aws_default") - expected = ConnectionConfig( - conn_id="aws_default", - credentials={ - "access_key_id": "fake_id", - "secret_access_key": "fake_secret", - }, - extra_config={"region": "us-east-1"}, - ) - assert result.conn_id == expected.conn_id - assert result.credentials == expected.credentials - assert result.extra_config == expected.extra_config - - def test_get_credentials_unknown_type(self): - mock_conn = MagicMock() - mock_conn.conn_type = "dummy" - engine = DataFusionEngine() - - with pytest.raises(ValueError, match="Unknown connection type dummy"): - engine._get_credentials(mock_conn) + result = engine._get_connection_config("my_conn") + assert result == ConnectionConfig(conn_id="my_conn") + assert result.credentials == {} + assert result.extra_config == {} def test_get_schema_success(self): engine = DataFusionEngine() diff --git a/providers/common/sql/tests/unit/common/sql/datafusion/test_object_storage_provider.py b/providers/common/sql/tests/unit/common/sql/datafusion/test_object_storage_provider.py index 9b0ff756a1fa4..e998b696c0aeb 100644 --- a/providers/common/sql/tests/unit/common/sql/datafusion/test_object_storage_provider.py +++ b/providers/common/sql/tests/unit/common/sql/datafusion/test_object_storage_provider.py @@ -16,59 +16,136 @@ # under the License. from __future__ import annotations -from unittest.mock import patch +import sys +from unittest.mock import MagicMock, patch import pytest -from airflow.providers.common.sql.config import ConnectionConfig, StorageType -from airflow.providers.common.sql.datafusion.exceptions import ObjectStoreCreationException +from airflow.providers.common.sql.config import StorageType from airflow.providers.common.sql.datafusion.object_storage_provider import ( LocalObjectStorageProvider, - S3ObjectStorageProvider, get_object_storage_provider, ) -class TestObjectStorageProvider: - @patch("airflow.providers.common.sql.datafusion.object_storage_provider.AmazonS3") - def test_s3_provider_success(self, mock_s3): - provider = S3ObjectStorageProvider() - connection_config = ConnectionConfig( - conn_id="aws_default", - credentials={"access_key_id": "fake_key", "secret_access_key": "fake_secret"}, - ) +class TestLocalObjectStorageProvider: + @patch( + "airflow.providers.common.sql.datafusion.object_storage_provider.LocalFileSystem", + autospec=True, + ) + def test_local_provider(self, mock_local): + provider = LocalObjectStorageProvider() + assert provider.get_storage_type == StorageType.LOCAL + assert provider.get_scheme() == "file://" + local_store = provider.create_object_store("file://path") + assert local_store == mock_local.return_value + + +class TestGetObjectStorageProvider: + def test_returns_local_provider_directly(self): + provider = get_object_storage_provider(StorageType.LOCAL) + assert isinstance(provider, LocalObjectStorageProvider) - store = provider.create_object_store("s3://demo-data/path", connection_config) + @patch("airflow.providers.common.sql.datafusion.object_storage_provider.import_string", autospec=True) + @patch("airflow.providers_manager.ProvidersManager", autospec=True) + def test_resolves_s3_via_registry(self, mock_pm_cls, mock_import_string): + mock_provider_cls = MagicMock() + mock_import_string.return_value = mock_provider_cls - mock_s3.assert_called_once_with( - access_key_id="fake_key", secret_access_key="fake_secret", bucket_name="demo-data" + mock_pm_cls.return_value.object_storage_providers = { + "s3": MagicMock( + provider_class_name="airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider", + ), + } + + provider = get_object_storage_provider(StorageType.S3) + + mock_import_string.assert_called_once_with( + "airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider" ) - assert store == mock_s3.return_value - assert provider.get_storage_type == StorageType.S3 - assert provider.get_scheme() == "s3://" + assert provider == mock_provider_cls.return_value + + @patch("airflow.providers_manager.ProvidersManager", autospec=True) + def test_unregistered_storage_type_raises(self, mock_pm_cls): + mock_pm_cls.return_value.object_storage_providers = {} + + with pytest.raises(ValueError, match="No ObjectStorageProvider registered.*Install or upgrade"): + get_object_storage_provider(StorageType.S3) + + def test_error_message_includes_install_hint_for_s3(self): + with patch("airflow.providers_manager.ProvidersManager", autospec=True) as mock_pm_cls: + mock_pm_cls.return_value.object_storage_providers = {} + + with pytest.raises(ValueError, match="apache-airflow-providers-amazon"): + get_object_storage_provider(StorageType.S3) + + def test_legacy_core_resolves_s3_via_amazon_direct_import(self): + """On an older core without the registry property, S3 resolves via a direct amazon import.""" + pytest.importorskip("airflow.providers.amazon") + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + + with patch("airflow.providers_manager.ProvidersManager") as mock_pm_cls: + mock_pm_cls.return_value = MagicMock(spec=[]) + + provider = get_object_storage_provider(StorageType.S3) - def test_s3_provider_failure(self): - provider = S3ObjectStorageProvider() - connection_config = ConnectionConfig(conn_id="aws_default") + assert isinstance(provider, S3ObjectStorageProvider) - with patch( - "airflow.providers.common.sql.datafusion.object_storage_provider.AmazonS3", - side_effect=Exception("Error"), + def test_legacy_core_s3_without_amazon_raises_install_hint(self): + """On an older core without the registry and amazon missing, raise the hinted ValueError.""" + with patch("airflow.providers_manager.ProvidersManager") as mock_pm_cls: + mock_pm_cls.return_value = MagicMock(spec=[]) + + with patch.dict( + sys.modules, + {"airflow.providers.amazon.aws.datafusion.object_storage": None}, + ): + with pytest.raises(ValueError, match="apache-airflow-providers-amazon"): + get_object_storage_provider(StorageType.S3) + + def test_no_amazon_imports_at_module_level(self): + """Verify common-sql no longer statically imports amazon provider code at the top level.""" + import airflow.providers.common.sql.datafusion.object_storage_provider as mod + + top_level_names = [ + name + for name, obj in vars(mod).items() + if not name.startswith("_") + and hasattr(obj, "__module__") + and "amazon" in getattr(obj, "__module__", "") + ] + assert top_level_names == [], f"Amazon symbols found at module level: {top_level_names}" + + +class TestS3DeprecationShim: + def test_old_import_path_emits_deprecation_warning(self): + """Importing S3ObjectStorageProvider from the old path still works but warns.""" + pytest.importorskip("airflow.providers.amazon") + import airflow.providers.common.sql.datafusion.object_storage_provider as mod + + with pytest.warns( + match="Import it from airflow.providers.amazon", ): - with pytest.raises(ObjectStoreCreationException, match="Failed to create S3 object store"): - provider.create_object_store("s3://demo-data/path", connection_config) + cls = mod.S3ObjectStorageProvider - @patch("airflow.providers.common.sql.datafusion.object_storage_provider.LocalFileSystem") - def test_local_provider(self, mock_local): - provider = LocalObjectStorageProvider() - assert provider.get_storage_type == StorageType.LOCAL - assert provider.get_scheme() == "file://" - local_store = provider.create_object_store("file://path") - assert local_store == mock_local.return_value + assert cls.__name__ == "S3ObjectStorageProvider" + + def test_old_import_path_returns_same_class(self): + """The shim re-exports the exact same class from the new location.""" + pytest.importorskip("airflow.providers.amazon") + import airflow.providers.common.sql.datafusion.object_storage_provider as mod + + with pytest.warns( + match="Import it from airflow.providers.amazon", + ): + old_cls = mod.S3ObjectStorageProvider + + from airflow.providers.amazon.aws.datafusion.object_storage import S3ObjectStorageProvider + + assert old_cls is S3ObjectStorageProvider - def test_get_object_storage_provider(self): - assert isinstance(get_object_storage_provider(StorageType.S3), S3ObjectStorageProvider) - assert isinstance(get_object_storage_provider(StorageType.LOCAL), LocalObjectStorageProvider) + def test_unknown_attr_raises_attribute_error(self): + import airflow.providers.common.sql.datafusion.object_storage_provider as mod - with pytest.raises(ValueError, match="Unsupported storage type"): - get_object_storage_provider("invalid") + with pytest.raises(AttributeError, match="has no attribute"): + _ = mod.NonExistentClass diff --git a/uv.lock b/uv.lock index 709e8ee883e4d..c972795b9df58 100644 --- a/uv.lock +++ b/uv.lock @@ -3044,6 +3044,9 @@ cncf-kubernetes = [ common-messaging = [ { name = "apache-airflow-providers-common-messaging" }, ] +datafusion = [ + { name = "datafusion" }, +] exasol = [ { name = "apache-airflow-providers-exasol" }, ] @@ -3155,6 +3158,7 @@ requires-dist = [ { name = "asgiref", marker = "python_full_version >= '3.14'", specifier = ">=3.11.1" }, { name = "boto3", specifier = ">=1.41.0" }, { name = "botocore", specifier = ">=1.41.0" }, + { name = "datafusion", marker = "extra == 'datafusion'", specifier = ">=50.0.0,<52.0.0" }, { name = "inflection", specifier = ">=0.5.1" }, { name = "jmespath", specifier = ">=0.7.0" }, { name = "jsonpath-ng", specifier = ">=1.5.3" }, @@ -3173,7 +3177,7 @@ requires-dist = [ { name = "watchtower", specifier = ">=3.3.1,<4" }, { name = "xmlsec", marker = "python_full_version < '3.13' and extra == 'python3-saml'", specifier = ">=1.3.14" }, ] -provides-extras = ["aiobotocore", "cncf-kubernetes", "s3fs", "python3-saml", "apache-hive", "exasol", "fab", "ftp", "google", "imap", "microsoft-azure", "mongo", "pandas", "openlineage", "salesforce", "ssh", "standard", "common-messaging", "sqlalchemy"] +provides-extras = ["aiobotocore", "cncf-kubernetes", "datafusion", "s3fs", "python3-saml", "apache-hive", "exasol", "fab", "ftp", "google", "imap", "microsoft-azure", "mongo", "pandas", "openlineage", "salesforce", "ssh", "standard", "common-messaging", "sqlalchemy"] [package.metadata.requires-dev] dev = [ @@ -4660,7 +4664,7 @@ requires-dist = [ { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=1.4.54" }, { name = "sqlparse", specifier = ">=0.5.1" }, ] -provides-extras = ["pandas", "openlineage", "polars", "sqlalchemy", "amazon", "datafusion", "pyiceberg-core", "apache-iceberg"] +provides-extras = ["pandas", "openlineage", "polars", "sqlalchemy", "datafusion", "pyiceberg-core", "apache-iceberg", "amazon"] [package.metadata.requires-dev] dev = [