Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions dev/registry/extract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,14 +878,17 @@ def _main_discover(
provider_yamls_by_id[pid] = py
provider_paths_by_id[pid] = yaml_path

# Filter to single provider if requested
# Filter to requested providers if --provider was given (supports space-separated list,
# matching extract_metadata.py and extract_connections.py behaviour)
if only_provider:
if only_provider not in provider_paths_by_id:
print(f"ERROR: provider '{only_provider}' not found in provider.yaml files")
requested_providers = {pid.strip() for pid in only_provider.split() if pid.strip()}
missing = requested_providers - set(provider_paths_by_id)
if missing:
print(f"ERROR: provider(s) {sorted(missing)} not found in provider.yaml files")
sys.exit(1)
provider_paths_by_id = {only_provider: provider_paths_by_id[only_provider]}
provider_yamls_by_id = {only_provider: provider_yamls_by_id[only_provider]}
print(f"Filtering to provider: {only_provider}")
provider_paths_by_id = {pid: provider_paths_by_id[pid] for pid in requested_providers}
provider_yamls_by_id = {pid: provider_yamls_by_id[pid] for pid in requested_providers}
print(f"Filtering to provider(s): {', '.join(sorted(requested_providers))}")

# Fetch Sphinx inventories in parallel
print("Fetching Sphinx inventory files ...")
Expand Down
92 changes: 92 additions & 0 deletions dev/registry/tests/test_extract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,95 @@ def test_empty_runtime_all_are_phantoms(self, tmp_path):
stats = compare_with_ast([], modules_json)
assert stats["ast_phantoms"] == 1
assert stats["ast_misses"] == 0


# ---------------------------------------------------------------------------
# _main_discover – provider filter (space-separated list support)
# ---------------------------------------------------------------------------
class TestMainDiscoverProviderFilter:
"""Verify that _main_discover correctly filters providers when only_provider is given.

Heavy I/O (filesystem scan, Sphinx fetches, class discovery, file writes) is mocked
so the tests focus purely on the filter logic introduced to fix the multi-provider bug.
"""

def _make_yaml_paths(self, tmp_path, provider_ids):
"""Write minimal provider.yaml files and return the list of paths."""
paths = []
for pid in provider_ids:
d = tmp_path / pid
d.mkdir(parents=True, exist_ok=True)
yaml_file = d / "provider.yaml"
yaml_file.write_text(
f"package-name: apache-airflow-providers-{pid}\n"
f"name: {pid.title()} Provider\n"
f"description: {pid}\n"
f"versions:\n - 1.0.0\n"
)
paths.append(yaml_file)
return paths

def _run_discover(self, tmp_path, all_providers, only_provider, monkeypatch):
"""Run _main_discover with heavy calls mocked, return provider IDs that were discovered."""
import extract_parameters as ep
from extract_parameters import _main_discover

self._make_yaml_paths(tmp_path, all_providers)

discovered_pids = []

def fake_discover(yaml_path, base_classes, inventory=None, version=""):
pid = yaml_path.parent.name
discovered_pids.append(pid)
return []

monkeypatch.setattr(ep, "PROVIDERS_DIR", tmp_path)
monkeypatch.setattr(ep, "load_base_classes", lambda: {})
monkeypatch.setattr(ep, "_fetch_inventories", lambda pids, yamls: {})
monkeypatch.setattr(ep, "discover_classes_from_provider", fake_discover)
monkeypatch.setattr(ep, "_extract_params_from_modules", lambda modules: ({}, {}, 0, 0, 0))
monkeypatch.setattr(ep, "_write_parameter_files", lambda *a, **kw: None)
# Prevent any modules.json writes
monkeypatch.setattr(ep, "validate_modules_catalog", lambda d: d)
monkeypatch.setattr(ep, "SCRIPT_DIR", tmp_path)

_main_discover(
provider_versions={pid: "1.0.0" for pid in all_providers},
generated_at="2026-01-01T00:00:00+00:00",
only_provider=only_provider,
)
return discovered_pids

def test_single_provider(self, tmp_path, monkeypatch):
"""Single provider ID works as before."""
pids = self._run_discover(tmp_path, ["amazon", "celery", "google"], "amazon", monkeypatch)
assert pids == ["amazon"]

def test_space_separated_two_providers(self, tmp_path, monkeypatch):
"""Space-separated list of two IDs filters to exactly those two."""
pids = self._run_discover(tmp_path, ["amazon", "celery", "google"], "amazon celery", monkeypatch)
assert sorted(pids) == ["amazon", "celery"]

def test_space_separated_three_providers(self, tmp_path, monkeypatch):
"""Space-separated list of three IDs (the real CI scenario) works."""
pids = self._run_discover(
tmp_path, ["amazon", "akeyless", "celery", "google"], "amazon akeyless celery", monkeypatch
)
assert sorted(pids) == ["akeyless", "amazon", "celery"]

def test_no_filter_runs_all(self, tmp_path, monkeypatch):
"""When only_provider is None, all providers are processed."""
pids = self._run_discover(tmp_path, ["amazon", "celery", "google"], None, monkeypatch)
assert sorted(pids) == ["amazon", "celery", "google"]

def test_unknown_provider_exits(self, tmp_path, monkeypatch):
"""Requesting a provider that doesn't exist exits with code 1."""
with pytest.raises(SystemExit) as exc:
self._run_discover(tmp_path, ["amazon", "celery"], "nonexistent", monkeypatch)
assert exc.value.code == 1

def test_partial_unknown_provider_exits(self, tmp_path, monkeypatch):
"""A space-separated list where one ID is unknown exits with code 1."""
with pytest.raises(SystemExit) as exc:
self._run_discover(tmp_path, ["amazon", "celery"], "amazon nonexistent", monkeypatch)
assert exc.value.code == 1
Loading