diff --git a/dev/registry/extract_parameters.py b/dev/registry/extract_parameters.py index 26f09320d7820..0078d40d4d6c2 100644 --- a/dev/registry/extract_parameters.py +++ b/dev/registry/extract_parameters.py @@ -821,7 +821,12 @@ def main(): parser.add_argument( "--provider", default=None, - help="Only process this provider ID (e.g. 'amazon'). Skips modules.json write.", + help=( + "Only process these provider ID(s) (space-separated, e.g. 'amazon google'). " + "Writes a partial modules.json containing only the requested providers; " + "merge_registry_data.py replaces those providers' entries in the global catalog " + "while preserving everyone else." + ), ) parser.add_argument( "--providers-json", @@ -845,21 +850,38 @@ def main(): provider_versions[p["id"]] = p["version"] generated_at = datetime.now(timezone.utc).isoformat() - _main_discover(provider_versions, generated_at, only_provider=args.provider) + _main_discover( + provider_versions, + generated_at, + requested_providers=_parse_requested_providers(args.provider), + ) print("\nDone!") +def _parse_requested_providers(provider_arg: str | None) -> set[str] | None: + """Parse --provider argument into a set of provider IDs. + + Accepts a space-separated string (matching extract_metadata.py and + extract_connections.py). Returns None when the argument is None or "" and + an empty set when it is whitespace-only; callers treat both as "all providers". + """ + if not provider_arg: + return None + return {pid.strip() for pid in provider_arg.split() if pid.strip()} + + def _main_discover( provider_versions: dict[str, str], generated_at: str, - only_provider: str | None = None, + requested_providers: set[str] | None = None, ) -> None: """Runtime discovery: find classes from provider.yaml files, produce modules.json and parameters. - When only_provider is set, only that provider is scanned and modules.json is NOT written - (it would be incomplete). 
This enables parallel backfills since the only output is - the per-provider parameters.json file. + When ``requested_providers`` is set, only those providers are scanned and the resulting + modules.json is partial (covers only the requested providers). ``merge_registry_data.py`` + handles incremental merges by replacing entries for provider IDs present in the new + modules.json while preserving all others, so partial output is safe. """ provider_yaml_paths = sorted(PROVIDERS_DIR.rglob("provider.yaml")) print(f"Found {len(provider_yaml_paths)} provider.yaml files") @@ -878,14 +900,15 @@ def _main_discover( provider_yamls_by_id[pid] = py provider_paths_by_id[pid] = yaml_path - # Filter to single provider if requested - if only_provider: - if only_provider not in provider_paths_by_id: - print(f"ERROR: provider '{only_provider}' not found in provider.yaml files") + # Filter to requested provider(s) if specified + if requested_providers: + missing = requested_providers - set(provider_paths_by_id) + if missing: + print(f"ERROR: provider(s) not found in provider.yaml files: {sorted(missing)}") sys.exit(1) - provider_paths_by_id = {only_provider: provider_paths_by_id[only_provider]} - provider_yamls_by_id = {only_provider: provider_yamls_by_id[only_provider]} - print(f"Filtering to provider: {only_provider}") + provider_paths_by_id = {pid: provider_paths_by_id[pid] for pid in requested_providers} + provider_yamls_by_id = {pid: provider_yamls_by_id[pid] for pid in requested_providers} + print(f"Filtering to provider(s): {', '.join(sorted(requested_providers))}") # Fetch Sphinx inventories in parallel print("Fetching Sphinx inventory files ...") @@ -920,21 +943,26 @@ def _main_discover( all_discovered = unique_modules print(f"Deduplicated to {len(all_discovered)} unique modules") - # Write modules.json only when doing a full build (no --provider filter). - # With --provider, the output would be incomplete and would clobber the - # full modules.json from a previous build. 
- if not only_provider: - modules_json = validate_modules_catalog({"modules": all_discovered}) - output_dirs = [SCRIPT_DIR, AIRFLOW_ROOT / "registry" / "src" / "_data"] - for out_dir in output_dirs: - if not out_dir.parent.exists(): - continue - out_dir.mkdir(parents=True, exist_ok=True) - with open(out_dir / "modules.json", "w") as f: - json.dump(modules_json, f, indent=2) - print(f"Wrote {len(all_discovered)} modules to {out_dir / 'modules.json'}") - - # Write runtime_modules.json (debug/stats file) + # Write modules.json. In --provider mode this is partial (covers only the + # requested providers); merge_registry_data.py drives module replacement + # off the provider IDs present in this file, so non-requested providers' + # entries are preserved untouched in the global catalog. + modules_json = validate_modules_catalog({"modules": all_discovered}) + scope_label = ( + f"partial, providers: {', '.join(sorted(requested_providers))}" if requested_providers else "full" + ) + output_dirs = [SCRIPT_DIR, AIRFLOW_ROOT / "registry" / "src" / "_data"] + for out_dir in output_dirs: + if not out_dir.parent.exists(): + continue + out_dir.mkdir(parents=True, exist_ok=True) + with open(out_dir / "modules.json", "w") as f: + json.dump(modules_json, f, indent=2) + print(f"Wrote {len(all_discovered)} modules ({scope_label}) to {out_dir / 'modules.json'}") + + # Write runtime_modules.json (debug/stats file). Only meaningful for full + # builds; skip in --provider mode since it would only show partial stats. 
+ if not requested_providers: runtime_output = { "generated_at": generated_at, "discovery_method": "runtime", @@ -948,8 +976,6 @@ def _main_discover( with open(runtime_json_path, "w") as f: json.dump(runtime_output, f, indent=2) print(f"Wrote {runtime_json_path}") - else: - print("Skipping modules.json write (--provider mode)") # Extract parameters print("\nExtracting parameters from runtime-discovered classes...") diff --git a/dev/registry/merge_registry_data.py b/dev/registry/merge_registry_data.py index 2bdbe50da12b4..8b5b820f228ca 100644 --- a/dev/registry/merge_registry_data.py +++ b/dev/registry/merge_registry_data.py @@ -60,16 +60,27 @@ def merge( if new_modules_path.exists(): new_modules = json.loads(new_modules_path.read_text())["modules"] - # IDs being replaced - new_ids = {p["id"] for p in new_providers} + # Provider IDs being replaced in providers.json + new_provider_ids = {p["id"] for p in new_providers} # Merge providers: keep existing (except those being replaced), add new - merged_providers = [p for p in existing_providers if p["id"] not in new_ids] + merged_providers = [p for p in existing_providers if p["id"] not in new_provider_ids] merged_providers.extend(new_providers) merged_providers.sort(key=lambda p: p["name"].lower()) - # Merge modules: keep existing (except for replaced providers), add new - merged_modules = [m for m in existing_modules if m["provider_id"] not in new_ids] + # Merge modules: replace ONLY for providers present in the new modules + # payload. Driving the drop off new_modules (rather than new_providers) + # avoids silent catalog corruption when an extract run wrote providers.json + # but no modules.json -- previously every targeted provider's modules were + # dropped with nothing to replace them. + # + # Trade-off: a provider that genuinely went from N modules to 0 between + # releases (refactored away, deprecated to config-only) leaves stale + # entries here -- they'd need a full build to clear. 
That's preferable to + # the previous behaviour where every single-provider incremental update + # silently wiped the targeted provider's modules. + new_module_provider_ids = {m["provider_id"] for m in new_modules} + merged_modules = [m for m in existing_modules if m["provider_id"] not in new_module_provider_ids] merged_modules.extend(new_modules) # Sort modules by provider's last_updated date (newest first) @@ -84,7 +95,7 @@ def merge( (output_dir / "providers.json").write_text(json.dumps(providers_payload, indent=2) + "\n") (output_dir / "modules.json").write_text(json.dumps(modules_payload, indent=2) + "\n") - print(f"Merged {len(new_ids)} updated provider(s) into {len(merged_providers)} total providers") + print(f"Merged {len(new_provider_ids)} updated provider(s) into {len(merged_providers)} total providers") print(f"Total modules: {len(merged_modules)}") diff --git a/dev/registry/tests/test_extract_parameters.py b/dev/registry/tests/test_extract_parameters.py index 30a6678391301..bfa3b05b56fd5 100644 --- a/dev/registry/tests/test_extract_parameters.py +++ b/dev/registry/tests/test_extract_parameters.py @@ -26,6 +26,7 @@ from extract_parameters import ( Module, _get_source_line, + _parse_requested_providers, _should_skip_class, compare_with_ast, discover_classes_from_provider, @@ -808,3 +809,35 @@ def test_empty_runtime_all_are_phantoms(self, tmp_path): stats = compare_with_ast([], modules_json) assert stats["ast_phantoms"] == 1 assert stats["ast_misses"] == 0 + + +# --------------------------------------------------------------------------- +# _parse_requested_providers +# --------------------------------------------------------------------------- +class TestParseRequestedProviders: + def test_none_argument_returns_none(self): + assert _parse_requested_providers(None) is None + + def test_empty_string_returns_none(self): + assert _parse_requested_providers("") is None + + def test_whitespace_only_returns_empty_set(self): + # Empty set is falsy, so downstream 
`if requested_providers` skips + # the filter just like None does. + assert _parse_requested_providers(" ") == set() + + def test_single_provider(self): + assert _parse_requested_providers("amazon") == {"amazon"} + + def test_multiple_providers_space_separated(self): + assert _parse_requested_providers("amazon google snowflake") == { + "amazon", + "google", + "snowflake", + } + + def test_extra_whitespace_is_tolerated(self): + assert _parse_requested_providers(" amazon google ") == {"amazon", "google"} + + def test_duplicate_providers_collapsed(self): + assert _parse_requested_providers("amazon amazon google") == {"amazon", "google"} diff --git a/dev/registry/tests/test_merge_registry_data.py b/dev/registry/tests/test_merge_registry_data.py index e3f78fd4b792e..2980e8bb5ef9a 100644 --- a/dev/registry/tests/test_merge_registry_data.py +++ b/dev/registry/tests/test_merge_registry_data.py @@ -242,7 +242,13 @@ def test_missing_existing_providers_file(self, tmp_path, output_dir): assert result_providers[0]["id"] == "amazon" def test_missing_new_modules_file(self, tmp_path, output_dir): - """Incremental extract with --provider skips modules.json; merge should keep existing modules.""" + """When new_modules.json is missing, merge preserves all existing modules. + + Driving module drops off the new_modules payload (rather than new_providers) + means an extract run that wrote providers.json without writing modules.json + leaves the global catalog untouched, instead of silently dropping every + targeted provider's modules. 
+ """ existing_providers = _write_json( tmp_path / "existing_providers.json", { @@ -265,16 +271,94 @@ def test_missing_new_modules_file(self, tmp_path, output_dir): tmp_path / "new_providers.json", {"providers": [_provider("amazon", "Amazon", "2025-01-01")]}, ) - # new_modules file does not exist (--provider mode skips modules.json) + # new_modules file does not exist (extract failure or skipped write) new_modules = tmp_path / "nonexistent_modules.json" merge(existing_providers, existing_modules, new_providers, new_modules, output_dir) result_modules = json.loads((output_dir / "modules.json").read_text())["modules"] - # Existing modules for non-updated providers are kept + # Both providers' existing modules are preserved -- no replacement available + assert any(m["id"] == "amazon-s3-op" for m in result_modules) + assert any(m["id"] == "google-bq-op" for m in result_modules) + + def test_empty_new_modules_preserves_existing(self, tmp_path, output_dir): + """An empty modules: [] payload behaves the same as a missing file. + + Same scenario as test_missing_new_modules_file but the file exists with an + empty list, which is what older extract runs would write before the + unconditional-write fix. 
+ """ + existing_providers = _write_json( + tmp_path / "existing_providers.json", + {"providers": [_provider("amazon", "Amazon", "2024-01-01")]}, + ) + existing_modules = _write_json( + tmp_path / "existing_modules.json", + {"modules": [_module("amazon-s3-op", "amazon")]}, + ) + new_providers = _write_json( + tmp_path / "new_providers.json", + {"providers": [_provider("amazon", "Amazon", "2025-01-01")]}, + ) + new_modules = _write_json(tmp_path / "new_modules.json", {"modules": []}) + + merge(existing_providers, existing_modules, new_providers, new_modules, output_dir) + + result_modules = json.loads((output_dir / "modules.json").read_text())["modules"] + assert any(m["id"] == "amazon-s3-op" for m in result_modules) + + def test_partial_new_modules_replaces_only_listed_providers(self, tmp_path, output_dir): + """Partial modules.json (one of several updated providers) replaces only that provider's modules. + + Mirrors --provider single-target extraction: providers.json may list multiple + provider updates, but the modules payload covers only the one whose runtime + extraction produced new module data. The merge must not touch the others. 
+ """ + existing_providers = _write_json( + tmp_path / "existing_providers.json", + { + "providers": [ + _provider("amazon", "Amazon", "2024-01-01"), + _provider("google", "Google", "2024-02-01"), + _provider("snowflake", "Snowflake", "2024-03-01"), + ] + }, + ) + existing_modules = _write_json( + tmp_path / "existing_modules.json", + { + "modules": [ + _module("amazon-s3-op", "amazon"), + _module("google-bq-op", "google"), + _module("snowflake-warehouse-op", "snowflake"), + ] + }, + ) + new_providers = _write_json( + tmp_path / "new_providers.json", + { + "providers": [ + _provider("amazon", "Amazon", "2025-01-01"), + _provider("google", "Google", "2025-02-01"), + ] + }, + ) + # Only amazon has fresh module data; google's modules are not in the payload + new_modules = _write_json( + tmp_path / "new_modules.json", + {"modules": [_module("amazon-lambda-op", "amazon")]}, + ) + + merge(existing_providers, existing_modules, new_providers, new_modules, output_dir) + + result_modules = json.loads((output_dir / "modules.json").read_text())["modules"] + # Amazon's existing module is replaced + assert any(m["id"] == "amazon-lambda-op" for m in result_modules) + assert not any(m["id"] == "amazon-s3-op" for m in result_modules) + # Google's existing module is preserved (no replacement payload) assert any(m["id"] == "google-bq-op" for m in result_modules) - # Existing modules for the updated provider are removed (no new ones to replace them) - assert not any(m["provider_id"] == "amazon" for m in result_modules) + # Snowflake (not in new_providers either) is untouched + assert any(m["id"] == "snowflake-warehouse-op" for m in result_modules) def test_output_directory_created_if_missing(self, tmp_path): output_dir = tmp_path / "does" / "not" / "exist"