Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

import pkg_resources
import importlib_metadata

from airflow import settings
from airflow.utils.file import find_path_from_directory
Expand Down Expand Up @@ -88,15 +88,16 @@ def __html__(self):
class EntryPointSource(AirflowPluginSource):
"""Class used to define Plugins loaded from entrypoint."""

def __init__(self, entrypoint):
self.dist = str(entrypoint.dist)
def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution):
self.dist = dist.metadata['name']
self.version = dist.version
self.entrypoint = str(entrypoint)

def __str__(self):
return f"{self.dist}: {self.entrypoint}"
return f"{self.dist}=={self.version}: {self.entrypoint}"

def __html__(self):
return f"<em>{self.dist}:</em> {self.entrypoint}"
return f"<em>{self.dist}=={self.version}:</em> {self.entrypoint}"


class AirflowPluginException(Exception):
Expand Down Expand Up @@ -169,6 +170,23 @@ def is_valid_plugin(plugin_obj):
return False


def entry_points_with_dist(group: str):
    """
    Return EntryPoint objects of the given group, along with the distribution information.

    This is like the ``entry_points()`` function from importlib.metadata,
    except it also returns the distribution the entry_point was loaded from.

    :param group: Filter results to only this entrypoint group
    :return: Generator of (EntryPoint, Distribution) objects for the specified groups
    """
    # Walk the distributions ourselves (rather than calling ``entry_points()``) so that
    # each entry point can be paired with the distribution object that declares it.
    for dist in importlib_metadata.distributions():
        for entry_point in dist.entry_points:
            if entry_point.group == group:
                yield entry_point, dist


def load_entrypoint_plugins():
"""
Load and register plugins AirflowPlugin subclasses from the entrypoints.
Expand All @@ -177,20 +195,20 @@ def load_entrypoint_plugins():
global import_errors # pylint: disable=global-statement
global plugins # pylint: disable=global-statement

entry_points = pkg_resources.iter_entry_points('airflow.plugins')

log.debug("Loading plugins from entrypoints")

for entry_point in entry_points: # pylint: disable=too-many-nested-blocks
for entry_point, dist in entry_points_with_dist('airflow.plugins'):
log.debug('Importing entry_point plugin %s', entry_point.name)
try:
plugin_class = entry_point.load()
if is_valid_plugin(plugin_class):
plugin_instance = plugin_class()
if callable(getattr(plugin_instance, 'on_load', None)):
plugin_instance.on_load()
plugin_instance.source = EntryPointSource(entry_point)
plugins.append(plugin_instance)
if not is_valid_plugin(plugin_class):
continue

plugin_instance = plugin_class()
if callable(getattr(plugin_instance, 'on_load', None)):
plugin_instance.on_load()
plugin_instance.source = EntryPointSource(entry_point, dist)
plugins.append(plugin_instance)
except Exception as e: # pylint: disable=broad-except
log.exception("Failed to import plugin %s", entry_point.name)
import_errors[entry_point.module_name] = str(e)
Expand Down
19 changes: 6 additions & 13 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing import Dict, Tuple

import jsonschema
import pkg_resources
import yaml

try:
Expand Down Expand Up @@ -90,19 +89,13 @@ def _discover_all_providers_from_packages(self) -> None:
via the 'apache_airflow_provider' entrypoint as a dictionary conforming to the
'airflow/provider.yaml.schema.json' schema.
"""
for entry_point in pkg_resources.iter_entry_points('apache_airflow_provider'):
package_name = entry_point.dist.project_name
from airflow.plugins_manager import entry_points_with_dist

for (entry_point, dist) in entry_points_with_dist('apache_airflow_provider'):
package_name = dist.metadata['name']
log.debug("Loading %s from package %s", entry_point, package_name)
version = entry_point.dist.version
try:
provider_info = entry_point.load()()
except pkg_resources.VersionConflict as e:
log.warning(
"The provider package %s could not be registered because of version conflict : %s",
package_name,
e,
)
continue
version = dist.version
provider_info = entry_point.load()()
self._validator.validate(provider_info)
provider_info_package_name = provider_info['package-name']
if package_name != provider_info_package_name:
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ install_requires =
funcsigs>=1.0.0, <2.0.0
graphviz>=0.12
gunicorn>=19.5.0, <20.0
importlib_metadata~=1.7 # We could work with 3.1, but argparse needs <2

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean argcomplete, perhaps? I note that argcomplete has been updated to support importlib_metadata 3.x (and otherwise only has bugfixes, so it looks like a desirable upgrade).

And if 3.1 works, please widen the pin to >=1.7,<4; the argcomplete~=1.10 pin above already lets you use importlib_metadata 3.1 if you have the newer argcomplete installed.

Also, Python 3.8 doesn't need importlib_metadata anymore, and argcomplete correctly won't depend on it in 3.8 and up. Please do so here too:

     importlib_metadata>=1.7,<4;python_version<="3.7"

and make the imports conditional:

try:
    # 3.8 and up
    from importlib import metadata as importlib_metadata
except ImportError:
    # use the backport
    import importlib_metadata

If you update the argcomplete dependency too, then you'd not even have to set the >=1.7 lower bound, ~=3.1 would do.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, yes, I did mean argcomplete. Though I also found virtualenv (dep of pylint) was also pinning this, so it may not help us yet.

This line in the readme https://github.com/python/importlib_metadata/ :

As of Python 3.8, this functionality has been added to the Python standard library. This package supplies backports of that functionality including improvements added to subsequent Python versions

Makes me wonder if/when we should continue to use it anyway on Py 3.8?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, yes, I did mean argcomplete. Though I also found virtualenv (dep of pylint) was also pinning this, so it may not help us yet.

Ah though virtualenv has been updated too.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes me wonder if/when we should continue to use it anyway on Py 3.8?

You're not using the 3.9 / importlib_metadata 1.6 feature ("Added module and attr attributes to EntryPoint"); everything else is bug fixes or perf improvements, all of which have been backported to 3.8.

Even if you were using those extra attributes, using python_version<="3.8" (and inverting the import guard) would be a good idea to keep the dependencies flexible and avoid limiting what packages people can use with Airflow.

I don't see where virtualenv depends on importlib_metadata? Nor does pylint depend on virtualenv. Might be in older versions. pre-commit depends on virtualenv though. And jsonlint depends on importlib_metadata but sets no pin.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

God, I'm getting my dist names all mixed up today. I did mean pre-commit -> virtualenv. I'll have to check what version of virtualenv pre-commit specifies.

I'll open a new PR to add the guard as you suggest.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've already opened a PR :-) See #12703

importlib_resources~=1.4
iso8601>=0.1.12
itsdangerous>=1.1.0
Expand Down
74 changes: 40 additions & 34 deletions tests/plugins/test_plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,30 +78,6 @@ def test_app_blueprints(self):
self.assertTrue('test_plugin' in self.app.blueprints)
self.assertEqual(self.app.blueprints['test_plugin'].name, bp.name)

@mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was in the "TestPluginsRBAC" class -- I moved it to the "right" place.

def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_ep_plugins):
"""
Test that Airflow does not raise an Error if there is any Exception because of the
Plugin.
"""
from airflow.plugins_manager import import_errors, load_entrypoint_plugins

mock_entrypoint = mock.Mock()
mock_entrypoint.name = 'test-entrypoint'
mock_entrypoint.module_name = 'test.plugins.test_plugins_manager'
mock_entrypoint.load.side_effect = Exception('Version Conflict')
mock_ep_plugins.return_value = [mock_entrypoint]

with self.assertLogs("airflow.plugins_manager", level="ERROR") as log_output:
load_entrypoint_plugins()

received_logs = log_output.output[0]
# Assert Traceback is shown too
assert "Traceback (most recent call last):" in received_logs
assert "Version Conflict" in received_logs
assert "Failed to import plugin test-entrypoint" in received_logs
assert ("test.plugins.test_plugins_manager", "Version Conflict") in import_errors.items()


class TestPluginsManager:
def test_no_log_when_no_plugins(self, caplog):
Expand Down Expand Up @@ -210,6 +186,33 @@ class AirflowAdminMenuLinksPlugin(AirflowPlugin):

assert caplog.record_tuples == []

def test_entrypoint_plugin_errors_dont_raise_exceptions(self, caplog):
    """
    Check that an entrypoint plugin which fails to load is logged and recorded, not raised.
    """
    from airflow.plugins_manager import import_errors, load_entrypoint_plugins

    # An entry point in the 'airflow.plugins' group whose load() always fails.
    failing_entrypoint = mock.Mock()
    failing_entrypoint.name = 'test-entrypoint'
    failing_entrypoint.group = 'airflow.plugins'
    failing_entrypoint.module_name = 'test.plugins.test_plugins_manager'
    failing_entrypoint.load.side_effect = ImportError('my_fake_module not found')

    # The only distribution visible to the plugin manager exposes just that entry point.
    fake_dist = mock.Mock()
    fake_dist.entry_points = [failing_entrypoint]

    patched_distributions = mock.patch('importlib_metadata.distributions', return_value=[fake_dist])
    error_capture = caplog.at_level(logging.ERROR, logger='airflow.plugins_manager')
    with patched_distributions, error_capture:
        load_entrypoint_plugins()

    captured = caplog.text
    expected_fragments = (
        # The traceback must be shown too, not just the message.
        "Traceback (most recent call last):",
        "my_fake_module not found",
        "Failed to import plugin test-entrypoint",
    )
    for fragment in expected_fragments:
        assert fragment in captured
    assert ("test.plugins.test_plugins_manager", "my_fake_module not found") in import_errors.items()


class TestPluginsDirectorySource(unittest.TestCase):
def test_should_return_correct_path_name(self):
Expand All @@ -221,20 +224,23 @@ def test_should_return_correct_path_name(self):
self.assertEqual("<em>$PLUGINS_FOLDER/</em>test_plugins_manager.py", source.__html__())


class TestEntryPointSource(unittest.TestCase):
@mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')
def test_should_return_correct_source_details(self, mock_ep_plugins):
class TestEntryPointSource:
def test_should_return_correct_source_details(self):
from airflow import plugins_manager

mock_entrypoint = mock.Mock()
mock_entrypoint.name = 'test-entrypoint-plugin'
mock_entrypoint.module_name = 'module_name_plugin'
mock_entrypoint.dist = 'test-entrypoint-plugin==1.0.0'
mock_ep_plugins.return_value = [mock_entrypoint]

plugins_manager.load_entrypoint_plugins()
mock_dist = mock.Mock()
mock_dist.metadata = {'name': 'test-entrypoint-plugin'}
mock_dist.version = '1.0.0'
mock_dist.entry_points = [mock_entrypoint]

with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]):
plugins_manager.load_entrypoint_plugins()

source = plugins_manager.EntryPointSource(mock_entrypoint)
self.assertEqual(str(mock_entrypoint), source.entrypoint)
self.assertEqual("test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint), str(source))
self.assertEqual("<em>test-entrypoint-plugin==1.0.0:</em> " + str(mock_entrypoint), source.__html__())
source = plugins_manager.EntryPointSource(mock_entrypoint, mock_dist)
assert str(mock_entrypoint) == source.entrypoint
assert "test-entrypoint-plugin==1.0.0: " + str(mock_entrypoint) == str(source)
assert "<em>test-entrypoint-plugin==1.0.0:</em> " + str(mock_entrypoint) == source.__html__()
46 changes: 10 additions & 36 deletions tests/www/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.bash import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.plugins_manager import AirflowPlugin, EntryPointSource, PluginsDirectorySource
from airflow.plugins_manager import AirflowPlugin, EntryPointSource
from airflow.security import permissions
from airflow.ti_deps.dependencies_states import QUEUEABLE_STATES, RUNNABLE_STATES
from airflow.utils import dates, timezone
Expand All @@ -67,6 +67,7 @@
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_runs
from tests.test_utils.mock_plugins import mock_plugin_manager


class TemplateWithContext(NamedTuple):
Expand Down Expand Up @@ -337,10 +338,6 @@ class PluginOperator(BaseOperator):
pass


class EntrypointPlugin(AirflowPlugin):
name = 'test-entrypoint-testpluginview'


class TestPluginView(TestBase):
def test_should_list_plugins_on_page_with_details(self):
resp = self.client.get('/plugin')
Expand All @@ -349,45 +346,22 @@ def test_should_list_plugins_on_page_with_details(self):
self.check_content_in_response("source", resp)
self.check_content_in_response("<em>$PLUGINS_FOLDER/</em>test_plugin.py", resp)

@mock.patch('airflow.plugins_manager.pkg_resources.iter_entry_points')
def test_should_list_entrypoint_plugins_on_page_with_details(self, mock_ep_plugins):
from airflow.plugins_manager import load_entrypoint_plugins

mock_entrypoint = mock.Mock()
mock_entrypoint.name = 'test-entrypoint-testpluginview'
mock_entrypoint.module_name = 'module_name_testpluginview'
mock_entrypoint.dist = 'test-entrypoint-testpluginview==1.0.0'
mock_entrypoint.load.return_value = EntrypointPlugin
mock_ep_plugins.return_value = [mock_entrypoint]
def test_should_list_entrypoint_plugins_on_page_with_details(self):

load_entrypoint_plugins()
resp = self.client.get('/plugin')
mock_plugin = AirflowPlugin()
mock_plugin.name = "test_plugin"
mock_plugin.source = EntryPointSource(
mock.Mock(), mock.Mock(version='1.0.0', metadata={'name': 'test-entrypoint-testpluginview'})
)
with mock_plugin_manager(plugins=[mock_plugin]):
resp = self.client.get('/plugin')

self.check_content_in_response("test_plugin", resp)
self.check_content_in_response("Airflow Plugins", resp)
self.check_content_in_response("source", resp)
self.check_content_in_response("<em>test-entrypoint-testpluginview==1.0.0:</em> <Mock id=", resp)


class TestPluginsDirectorySource(unittest.TestCase):
def test_should_provide_correct_attribute_values(self):
source = PluginsDirectorySource("./test_views.py")
self.assertEqual("$PLUGINS_FOLDER/../../test_views.py", str(source))
self.assertEqual("<em>$PLUGINS_FOLDER/</em>../../test_views.py", source.__html__())
self.assertEqual("../../test_views.py", source.path)


class TestEntryPointSource(unittest.TestCase):
def test_should_provide_correct_attribute_values(self):
mock_entrypoint = mock.Mock()
mock_entrypoint.dist = 'test-entrypoint-dist==1.0.0'
source = EntryPointSource(mock_entrypoint)
self.assertEqual("test-entrypoint-dist==1.0.0", source.dist)
self.assertEqual(str(mock_entrypoint), source.entrypoint)
self.assertEqual("test-entrypoint-dist==1.0.0: " + str(mock_entrypoint), str(source))
self.assertEqual("<em>test-entrypoint-dist==1.0.0:</em> " + str(mock_entrypoint), source.__html__())


Comment on lines -372 to -390

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were duplicating tests in test_plugins_manager.py -- they don't belong here.

class TestPoolModelView(TestBase):
def setUp(self):
super().setUp()
Expand Down