Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 56 additions & 30 deletions dev/registry/extract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,12 @@ def main():
parser.add_argument(
"--provider",
default=None,
help="Only process this provider ID (e.g. 'amazon'). Skips modules.json write.",
help=(
"Only process these provider ID(s) (space-separated, e.g. 'amazon google'). "
"Writes a partial modules.json containing only the requested providers; "
"merge_registry_data.py replaces those providers' entries in the global catalog "
"while preserving everyone else."
),
)
parser.add_argument(
"--providers-json",
Expand All @@ -845,21 +850,38 @@ def main():
provider_versions[p["id"]] = p["version"]

generated_at = datetime.now(timezone.utc).isoformat()
_main_discover(provider_versions, generated_at, only_provider=args.provider)
_main_discover(
provider_versions,
generated_at,
requested_providers=_parse_requested_providers(args.provider),
)

print("\nDone!")


def _parse_requested_providers(provider_arg: str | None) -> set[str] | None:
"""Parse --provider argument into a set of provider IDs.

Accepts a space-separated string (matching extract_metadata.py and
extract_connections.py). Returns None when the argument is empty so
callers can distinguish "all providers" from "explicit empty set".
"""
if not provider_arg:
return None
return {pid.strip() for pid in provider_arg.split() if pid.strip()}


def _main_discover(
provider_versions: dict[str, str],
generated_at: str,
only_provider: str | None = None,
requested_providers: set[str] | None = None,
) -> None:
"""Runtime discovery: find classes from provider.yaml files, produce modules.json and parameters.

When only_provider is set, only that provider is scanned and modules.json is NOT written
(it would be incomplete). This enables parallel backfills since the only output is
the per-provider parameters.json file.
When ``requested_providers`` is set, only those providers are scanned and the resulting
modules.json is partial (covers only the requested providers). ``merge_registry_data.py``
handles incremental merges by replacing entries for provider IDs present in the new
modules.json while preserving all others, so partial output is safe.
"""
provider_yaml_paths = sorted(PROVIDERS_DIR.rglob("provider.yaml"))
print(f"Found {len(provider_yaml_paths)} provider.yaml files")
Expand All @@ -878,14 +900,15 @@ def _main_discover(
provider_yamls_by_id[pid] = py
provider_paths_by_id[pid] = yaml_path

# Filter to single provider if requested
if only_provider:
if only_provider not in provider_paths_by_id:
print(f"ERROR: provider '{only_provider}' not found in provider.yaml files")
# Filter to requested provider(s) if specified
if requested_providers:
missing = requested_providers - set(provider_paths_by_id)
if missing:
print(f"ERROR: provider(s) not found in provider.yaml files: {sorted(missing)}")
sys.exit(1)
provider_paths_by_id = {only_provider: provider_paths_by_id[only_provider]}
provider_yamls_by_id = {only_provider: provider_yamls_by_id[only_provider]}
print(f"Filtering to provider: {only_provider}")
provider_paths_by_id = {pid: provider_paths_by_id[pid] for pid in requested_providers}
provider_yamls_by_id = {pid: provider_yamls_by_id[pid] for pid in requested_providers}
print(f"Filtering to provider(s): {', '.join(sorted(requested_providers))}")

# Fetch Sphinx inventories in parallel
print("Fetching Sphinx inventory files ...")
Expand Down Expand Up @@ -920,21 +943,26 @@ def _main_discover(
all_discovered = unique_modules
print(f"Deduplicated to {len(all_discovered)} unique modules")

# Write modules.json only when doing a full build (no --provider filter).
# With --provider, the output would be incomplete and would clobber the
# full modules.json from a previous build.
if not only_provider:
modules_json = validate_modules_catalog({"modules": all_discovered})
output_dirs = [SCRIPT_DIR, AIRFLOW_ROOT / "registry" / "src" / "_data"]
for out_dir in output_dirs:
if not out_dir.parent.exists():
continue
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "modules.json", "w") as f:
json.dump(modules_json, f, indent=2)
print(f"Wrote {len(all_discovered)} modules to {out_dir / 'modules.json'}")

# Write runtime_modules.json (debug/stats file)
# Write modules.json. In --provider mode this is partial (covers only the
# requested providers); merge_registry_data.py drives module replacement
# off the provider IDs present in this file, so non-requested providers'
# entries are preserved untouched in the global catalog.
modules_json = validate_modules_catalog({"modules": all_discovered})
scope_label = (
f"partial, providers: {', '.join(sorted(requested_providers))}" if requested_providers else "full"
)
output_dirs = [SCRIPT_DIR, AIRFLOW_ROOT / "registry" / "src" / "_data"]
for out_dir in output_dirs:
if not out_dir.parent.exists():
continue
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "modules.json", "w") as f:
json.dump(modules_json, f, indent=2)
print(f"Wrote {len(all_discovered)} modules ({scope_label}) to {out_dir / 'modules.json'}")

# Write runtime_modules.json (debug/stats file). Only meaningful for full
# builds; skip in --provider mode since it would only show partial stats.
if not requested_providers:
runtime_output = {
"generated_at": generated_at,
"discovery_method": "runtime",
Expand All @@ -948,8 +976,6 @@ def _main_discover(
with open(runtime_json_path, "w") as f:
json.dump(runtime_output, f, indent=2)
print(f"Wrote {runtime_json_path}")
else:
print("Skipping modules.json write (--provider mode)")

# Extract parameters
print("\nExtracting parameters from runtime-discovered classes...")
Expand Down
23 changes: 17 additions & 6 deletions dev/registry/merge_registry_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,27 @@ def merge(
if new_modules_path.exists():
new_modules = json.loads(new_modules_path.read_text())["modules"]

# IDs being replaced
new_ids = {p["id"] for p in new_providers}
# Provider IDs being replaced in providers.json
new_provider_ids = {p["id"] for p in new_providers}

# Merge providers: keep existing (except those being replaced), add new
merged_providers = [p for p in existing_providers if p["id"] not in new_ids]
merged_providers = [p for p in existing_providers if p["id"] not in new_provider_ids]
merged_providers.extend(new_providers)
merged_providers.sort(key=lambda p: p["name"].lower())

# Merge modules: keep existing (except for replaced providers), add new
merged_modules = [m for m in existing_modules if m["provider_id"] not in new_ids]
# Merge modules: replace ONLY for providers present in the new modules
# payload. Driving the drop off new_modules (rather than new_providers)
# avoids silent catalog corruption when an extract run wrote providers.json
# but no modules.json -- previously every targeted provider's modules were
# dropped with nothing to replace them.
#
# Trade-off: a provider that genuinely went from N modules to 0 between
# releases (refactored away, deprecated to config-only) leaves stale
# entries here -- they'd need a full build to clear. That's preferable to
# the previous behaviour where every single-provider incremental update
# silently wiped the targeted provider's modules.
new_module_provider_ids = {m["provider_id"] for m in new_modules}
merged_modules = [m for m in existing_modules if m["provider_id"] not in new_module_provider_ids]
merged_modules.extend(new_modules)

# Sort modules by provider's last_updated date (newest first)
Expand All @@ -84,7 +95,7 @@ def merge(
(output_dir / "providers.json").write_text(json.dumps(providers_payload, indent=2) + "\n")
(output_dir / "modules.json").write_text(json.dumps(modules_payload, indent=2) + "\n")

print(f"Merged {len(new_ids)} updated provider(s) into {len(merged_providers)} total providers")
print(f"Merged {len(new_provider_ids)} updated provider(s) into {len(merged_providers)} total providers")
print(f"Total modules: {len(merged_modules)}")


Expand Down
33 changes: 33 additions & 0 deletions dev/registry/tests/test_extract_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from extract_parameters import (
Module,
_get_source_line,
_parse_requested_providers,
_should_skip_class,
compare_with_ast,
discover_classes_from_provider,
Expand Down Expand Up @@ -808,3 +809,35 @@ def test_empty_runtime_all_are_phantoms(self, tmp_path):
stats = compare_with_ast([], modules_json)
assert stats["ast_phantoms"] == 1
assert stats["ast_misses"] == 0


# ---------------------------------------------------------------------------
# _parse_requested_providers
# ---------------------------------------------------------------------------
class TestParseRequestedProviders:
def test_none_argument_returns_none(self):
assert _parse_requested_providers(None) is None

def test_empty_string_returns_none(self):
assert _parse_requested_providers("") is None

def test_whitespace_only_returns_empty_set(self):
# Empty set is falsy, so downstream `if requested_providers` skips
# the filter just like None does.
assert _parse_requested_providers(" ") == set()

def test_single_provider(self):
assert _parse_requested_providers("amazon") == {"amazon"}

def test_multiple_providers_space_separated(self):
assert _parse_requested_providers("amazon google snowflake") == {
"amazon",
"google",
"snowflake",
}

def test_extra_whitespace_is_tolerated(self):
assert _parse_requested_providers(" amazon google ") == {"amazon", "google"}

def test_duplicate_providers_collapsed(self):
assert _parse_requested_providers("amazon amazon google") == {"amazon", "google"}
94 changes: 89 additions & 5 deletions dev/registry/tests/test_merge_registry_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,13 @@ def test_missing_existing_providers_file(self, tmp_path, output_dir):
assert result_providers[0]["id"] == "amazon"

def test_missing_new_modules_file(self, tmp_path, output_dir):
"""Incremental extract with --provider skips modules.json; merge should keep existing modules."""
"""When new_modules.json is missing, merge preserves all existing modules.

Driving module drops off the new_modules payload (rather than new_providers)
means an extract run that wrote providers.json without writing modules.json
leaves the global catalog untouched, instead of silently dropping every
targeted provider's modules.
"""
existing_providers = _write_json(
tmp_path / "existing_providers.json",
{
Expand All @@ -265,16 +271,94 @@ def test_missing_new_modules_file(self, tmp_path, output_dir):
tmp_path / "new_providers.json",
{"providers": [_provider("amazon", "Amazon", "2025-01-01")]},
)
# new_modules file does not exist (--provider mode skips modules.json)
# new_modules file does not exist (extract failure or skipped write)
new_modules = tmp_path / "nonexistent_modules.json"

merge(existing_providers, existing_modules, new_providers, new_modules, output_dir)

result_modules = json.loads((output_dir / "modules.json").read_text())["modules"]
# Existing modules for non-updated providers are kept
# Both providers' existing modules are preserved -- no replacement available
assert any(m["id"] == "amazon-s3-op" for m in result_modules)
assert any(m["id"] == "google-bq-op" for m in result_modules)

def test_empty_new_modules_preserves_existing(self, tmp_path, output_dir):
"""An empty modules: [] payload behaves the same as a missing file.

Same scenario as test_missing_new_modules_file but the file exists with an
empty list, which is what older extract runs would write before the
unconditional-write fix.
"""
existing_providers = _write_json(
tmp_path / "existing_providers.json",
{"providers": [_provider("amazon", "Amazon", "2024-01-01")]},
)
existing_modules = _write_json(
tmp_path / "existing_modules.json",
{"modules": [_module("amazon-s3-op", "amazon")]},
)
new_providers = _write_json(
tmp_path / "new_providers.json",
{"providers": [_provider("amazon", "Amazon", "2025-01-01")]},
)
new_modules = _write_json(tmp_path / "new_modules.json", {"modules": []})

merge(existing_providers, existing_modules, new_providers, new_modules, output_dir)

result_modules = json.loads((output_dir / "modules.json").read_text())["modules"]
assert any(m["id"] == "amazon-s3-op" for m in result_modules)

def test_partial_new_modules_replaces_only_listed_providers(self, tmp_path, output_dir):
"""Partial modules.json (one of several updated providers) replaces only that provider's modules.

Mirrors --provider single-target extraction: providers.json may list multiple
provider updates, but the modules payload covers only the one whose runtime
extraction produced new module data. The merge must not touch the others.
"""
existing_providers = _write_json(
tmp_path / "existing_providers.json",
{
"providers": [
_provider("amazon", "Amazon", "2024-01-01"),
_provider("google", "Google", "2024-02-01"),
_provider("snowflake", "Snowflake", "2024-03-01"),
]
},
)
existing_modules = _write_json(
tmp_path / "existing_modules.json",
{
"modules": [
_module("amazon-s3-op", "amazon"),
_module("google-bq-op", "google"),
_module("snowflake-warehouse-op", "snowflake"),
]
},
)
new_providers = _write_json(
tmp_path / "new_providers.json",
{
"providers": [
_provider("amazon", "Amazon", "2025-01-01"),
_provider("google", "Google", "2025-02-01"),
]
},
)
# Only amazon has fresh module data; google's modules are not in the payload
new_modules = _write_json(
tmp_path / "new_modules.json",
{"modules": [_module("amazon-lambda-op", "amazon")]},
)

merge(existing_providers, existing_modules, new_providers, new_modules, output_dir)

result_modules = json.loads((output_dir / "modules.json").read_text())["modules"]
# Amazon's existing module is replaced
assert any(m["id"] == "amazon-lambda-op" for m in result_modules)
assert not any(m["id"] == "amazon-s3-op" for m in result_modules)
# Google's existing module is preserved (no replacement payload)
assert any(m["id"] == "google-bq-op" for m in result_modules)
# Existing modules for the updated provider are removed (no new ones to replace them)
assert not any(m["provider_id"] == "amazon" for m in result_modules)
# Snowflake (not in new_providers either) is untouched
assert any(m["id"] == "snowflake-warehouse-op" for m in result_modules)

def test_output_directory_created_if_missing(self, tmp_path):
output_dir = tmp_path / "does" / "not" / "exist"
Expand Down
Loading