diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index fe94e215cba89..dadae6af724dc 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -27,7 +27,7 @@ import types from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type -import pkg_resources +import importlib_metadata from airflow import settings from airflow.utils.file import find_path_from_directory @@ -88,15 +88,16 @@ def __html__(self): class EntryPointSource(AirflowPluginSource): """Class used to define Plugins loaded from entrypoint.""" - def __init__(self, entrypoint): - self.dist = str(entrypoint.dist) + def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution): + self.dist = dist.metadata['name'] + self.version = dist.version self.entrypoint = str(entrypoint) def __str__(self): - return f"{self.dist}: {self.entrypoint}" + return f"{self.dist}=={self.version}: {self.entrypoint}" def __html__(self): - return f"{self.dist}: {self.entrypoint}" + return f"{self.dist}=={self.version}: {self.entrypoint}" class AirflowPluginException(Exception): @@ -169,6 +170,23 @@ def is_valid_plugin(plugin_obj): return False +def entry_points_with_dist(group: str): + """ + Return EntryPoint objects of the given group, along with the distribution information. + + This is like the ``entry_points()`` function from importlib.metadata, + except it also returns the distribution the entry_point was loaded from. + + :param group: FIlter results to only this entrypoint group + :return: Generator of (EntryPoint, Distribution) objects for the specified groups + """ + for dist in importlib_metadata.distributions(): + for e in dist.entry_points: + if e.group != group: + continue + yield (e, dist) + + def load_entrypoint_plugins(): """ Load and register plugins AirflowPlugin subclasses from the entrypoints. @@ -177,20 +195,20 @@ def load_entrypoint_plugins(): global import_errors # pylint: disable=global-statement global plugins # pylint: disable=global-statement - entry_points = pkg_resources.iter_entry_points('airflow.plugins') - log.debug("Loading plugins from entrypoints") - for entry_point in entry_points: # pylint: disable=too-many-nested-blocks + for entry_point, dist in entry_points_with_dist('airflow.plugins'): log.debug('Importing entry_point plugin %s', entry_point.name) try: plugin_class = entry_point.load() - if is_valid_plugin(plugin_class): - plugin_instance = plugin_class() - if callable(getattr(plugin_instance, 'on_load', None)): - plugin_instance.on_load() - plugin_instance.source = EntryPointSource(entry_point) - plugins.append(plugin_instance) + if not is_valid_plugin(plugin_class): + continue + + plugin_instance = plugin_class() + if callable(getattr(plugin_instance, 'on_load', None)): + plugin_instance.on_load() + plugin_instance.source = EntryPointSource(entry_point, dist) + plugins.append(plugin_instance) except Exception as e: # pylint: disable=broad-except log.exception("Failed to import plugin %s", entry_point.name) import_errors[entry_point.module_name] = str(e) diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 30041c92270c7..44821f7d90a3d 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -24,7 +24,6 @@ from typing import Dict, Tuple import jsonschema -import pkg_resources import yaml try: @@ -90,19 +89,13 @@ def _discover_all_providers_from_packages(self) -> None: via the 'apache_airflow_provider' entrypoint as a dictionary conforming to the 'airflow/provider.yaml.schema.json' schema. """ - for entry_point in pkg_resources.iter_entry_points('apache_airflow_provider'): - package_name = entry_point.dist.project_name + from airflow.plugins_manager import entry_points_with_dist + + for (entry_point, dist) in entry_points_with_dist('apache_airflow_provider'): + package_name = dist.metadata['name'] log.debug("Loading %s from package %s", entry_point, package_name) - version = entry_point.dist.version - try: - provider_info = entry_point.load()() - except pkg_resources.VersionConflict as e: - log.warning( - "The provider package %s could not be registered because of version conflict : %s", - package_name, - e, - ) - continue + version = dist.version + provider_info = entry_point.load()() self._validator.validate(provider_info) provider_info_package_name = provider_info['package-name'] if package_name != provider_info_package_name: diff --git a/setup.cfg b/setup.cfg index 38e05ef82ad9e..282ceff8acabd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -100,6 +100,7 @@ install_requires = funcsigs>=1.0.0, <2.0.0 graphviz>=0.12 gunicorn>=19.5.0, <20.0 + importlib_metadata~=1.7 # We could work with 3.1, but argparse needs <2 importlib_resources~=1.4 iso8601>=0.1.12 itsdangerous>=1.1.0 diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index 48407ae0757a6..117df989c4c23 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -78,30 +78,6 @@ def test_app_blueprints(self): self.assertTrue('test_plugin' in self.app.blueprints) self.assertEqual(self.app.blueprints['test_plugin'].name, bp.name) - @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points') - def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_ep_plugins): - """ - Test that Airflow does not raise an Error if there is any Exception because of the - Plugin. - """ - from airflow.plugins_manager import import_errors, load_entrypoint_plugins - - mock_entrypoint = mock.Mock() - mock_entrypoint.name = 'test-entrypoint' - mock_entrypoint.module_name = 'test.plugins.test_plugins_manager' - mock_entrypoint.load.side_effect = Exception('Version Conflict') - mock_ep_plugins.return_value = [mock_entrypoint] - - with self.assertLogs("airflow.plugins_manager", level="ERROR") as log_output: - load_entrypoint_plugins() - - received_logs = log_output.output[0] - # Assert Traceback is shown too - assert "Traceback (most recent call last):" in received_logs - assert "Version Conflict" in received_logs - assert "Failed to import plugin test-entrypoint" in received_logs - assert ("test.plugins.test_plugins_manager", "Version Conflict") in import_errors.items() - class TestPluginsManager: def test_no_log_when_no_plugins(self, caplog): @@ -210,6 +186,33 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin): assert caplog.record_tuples == [] + def test_entrypoint_plugin_errors_dont_raise_exceptions(self, caplog): + """ + Test that Airflow does not raise an error if there is any Exception because of a plugin. + """ + from airflow.plugins_manager import import_errors, load_entrypoint_plugins + + mock_dist = mock.Mock() + + mock_entrypoint = mock.Mock() + mock_entrypoint.name = 'test-entrypoint' + mock_entrypoint.group = 'airflow.plugins' + mock_entrypoint.module_name = 'test.plugins.test_plugins_manager' + mock_entrypoint.load.side_effect = ImportError('my_fake_module not found') + mock_dist.entry_points = [mock_entrypoint] + + with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]), caplog.at_level( + logging.ERROR, logger='airflow.plugins_manager' + ): + load_entrypoint_plugins() + + received_logs = caplog.text + # Assert Traceback is shown too + assert "Traceback (most recent call last):" in received_logs + assert "my_fake_module not found" in received_logs + assert "Failed to import plugin test-entrypoint" in received_logs + assert ("test.plugins.test_plugins_manager", "my_fake_module not found") in import_errors.items() + class TestPluginsDirectorySource(unittest.TestCase): def test_should_return_correct_path_name(self): @@ -221,20 +224,23 @@ def test_should_return_correct_path_name(self): self.assertEqual("$PLUGINS_FOLDER/test_plugins_manager.py", source.__html__()) -class TestEntryPointSource(unittest.TestCase): - @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points') - def test_should_return_correct_source_details(self, mock_ep_plugins): +class TestEntryPointSource: + def test_should_return_correct_source_details(self): from airflow import plugins_manager mock_entrypoint = mock.Mock() mock_entrypoint.name = 'test-entrypoint-plugin' mock_entrypoint.module_name = 'module_name_plugin' - mock_entrypoint.dist = 'test-entrypoint-plugin==1.0.0' - mock_ep_plugins.return_value = [mock_entrypoint] - plugins_manager.load_entrypoint_plugins() + mock_dist = mock.Mock() + mock_dist.metadata = {'name': 'test-entrypoint-plugin'} + mock_dist.version = '1.0.0' + mock_dist.entry_points = [mock_entrypoint] + + with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]): + plugins_manager.load_entrypoint_plugins() - source = plugins_manager.EntryPointSource(mock_entrypoint) - self.assertEqual(str(mock_entrypoint), source.entrypoint) - self.assertEqual("test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint), str(source)) - self.assertEqual("test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint), source.__html__()) + source = plugins_manager.EntryPointSource(mock_entrypoint, mock_dist) + assert str(mock_entrypoint) == source.entrypoint + assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == str(source) + assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == source.__html__() diff --git a/tests/www/test_views.py b/tests/www/test_views.py index 9208f328a2590..d4d05720fb646 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -52,7 +52,7 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.bash import BashOperator from airflow.operators.dummy_operator import DummyOperator -from airflow.plugins_manager import AirflowPlugin, EntryPointSource, PluginsDirectorySource +from airflow.plugins_manager import AirflowPlugin, EntryPointSource from airflow.security import permissions from airflow.ti_deps.dependencies_states import QUEUEABLE_STATES, RUNNABLE_STATES from airflow.utils import dates, timezone @@ -67,6 +67,7 @@ from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs +from tests.test_utils.mock_plugins import mock_plugin_manager class TemplateWithContext(NamedTuple): @@ -337,10 +338,6 @@ class PluginOperator(BaseOperator): pass -class EntrypointPlugin(AirflowPlugin): - name = 'test-entrypoint-testpluginview' - - class TestPluginView(TestBase): def test_should_list_plugins_on_page_with_details(self): resp = self.client.get('/plugin') @@ -349,19 +346,15 @@ def test_should_list_plugins_on_page_with_details(self): self.check_content_in_response("source", resp) self.check_content_in_response("$PLUGINS_FOLDER/test_plugin.py", resp) - @mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points') - def test_should_list_entrypoint_plugins_on_page_with_details(self, mock_ep_plugins): - from airflow.plugins_manager import load_entrypoint_plugins - - mock_entrypoint = mock.Mock() - mock_entrypoint.name = 'test-entrypoint-testpluginview' - mock_entrypoint.module_name = 'module_name_testpluginview' - mock_entrypoint.dist = 'test-entrypoint-testpluginview==1.0.0' - mock_entrypoint.load.return_value = EntrypointPlugin - mock_ep_plugins.return_value = [mock_entrypoint] + def test_should_list_entrypoint_plugins_on_page_with_details(self): - load_entrypoint_plugins() - resp = self.client.get('/plugin') + mock_plugin = AirflowPlugin() + mock_plugin.name = "test_plugin" + mock_plugin.source = EntryPointSource( + mock.Mock(), mock.Mock(version='1.0.0', metadata={'name': 'test-entrypoint-testpluginview'}) + ) + with mock_plugin_manager(plugins=[mock_plugin]): + resp = self.client.get('/plugin') self.check_content_in_response("test_plugin", resp) self.check_content_in_response("Airflow Plugins", resp) @@ -369,25 +362,6 @@ def test_should_list_entrypoint_plugins_on_page_with_details(self, mock_ep_plugi self.check_content_in_response("test-entrypoint-testpluginview==1.0.0: $PLUGINS_FOLDER/../../test_views.py", source.__html__()) - self.assertEqual("../../test_views.py", source.path) - - -class TestEntryPointSource(unittest.TestCase): - def test_should_provide_correct_attribute_values(self): - mock_entrypoint = mock.Mock() - mock_entrypoint.dist = 'test-entrypoint-dist==1.0.0' - source = EntryPointSource(mock_entrypoint) - self.assertEqual("test-entrypoint-dist==1.0.0", source.dist) - self.assertEqual(str(mock_entrypoint), source.entrypoint) - self.assertEqual("test-entrypoint-dist==1.0.0: " + str(mock_entrypoint), str(source)) - self.assertEqual("test-entrypoint-dist==1.0.0: " + str(mock_entrypoint), source.__html__()) - - class TestPoolModelView(TestBase): def setUp(self): super().setUp()