Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,13 @@ repos:
entry: ./scripts/ci/prek/check_shared_distributions_structure.py
language: python
pass_filenames: false
files: ^shared/.*$
files: >
(?x)
^shared/.*$|
^.*/pyproject\.toml$|
^.*/_shared/.*$|
^airflow-core/.*\.py$|
^task-sdk/.*\.py$
- id: check-shared-distributions-usage
name: Check shared distributions usage
entry: ./scripts/ci/prek/check_shared_distributions_usage.py
Expand Down
1 change: 1 addition & 0 deletions dev/breeze/src/airflow_breeze/utils/selective_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __hash__(self):
r"^Dockerfile",
r"^scripts/ci/docker-compose",
r"^scripts/ci/kubernetes",
r"^scripts/ci/prek",
r"^scripts/docker",
r"^scripts/in_container",
r"^generated/provider_dependencies.json$",
Expand Down
4 changes: 2 additions & 2 deletions dev/breeze/tests/test_selective_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,9 +1603,9 @@ def test_excluded_providers():
pytest.param(
("scripts/ci/prek/file.sh",),
{
"full-tests-needed": "false",
"full-tests-needed": "true",
},
id="No full tests needed when prek scripts change",
id="Full tests needed when prek scripts change",
)
),
(
Expand Down
325 changes: 325 additions & 0 deletions scripts/ci/prek/check_shared_distributions_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

from __future__ import annotations

import ast
import itertools
import re
import sys
from pathlib import Path
Expand All @@ -40,6 +42,8 @@
from common_prek_utils import AIRFLOW_ROOT_PATH, console

# Directory under the repository root that holds the shared distributions.
SHARED_DIR = AIRFLOW_ROOT_PATH / "shared"
# Distributions whose sources are scanned for disallowed shared-code imports.
TASK_SDK_DIR = AIRFLOW_ROOT_PATH / "task-sdk"
AIRFLOW_CORE_DIR = AIRFLOW_ROOT_PATH / "airflow-core"


def check_pyproject_exists(shared_path: Path) -> bool:
Expand Down Expand Up @@ -174,6 +178,199 @@ def check_ruff_lint_rules(ruff: dict, shared_path: Path) -> bool:
return False


def _parse_python_file(py_file: Path, base_path: Path) -> ast.Module | None:
"""Parse a Python file and return an AST tree, handling errors gracefully."""
try:
with open(py_file, encoding="utf-8") as f:
return ast.parse(f.read(), filename=str(py_file))
except SyntaxError as e:
console.print(f" [yellow]Warning: Could not parse {py_file.relative_to(base_path)}: {e}[/yellow]")
return None
except Exception as e:
console.print(f" [yellow]Warning: Error reading {py_file.relative_to(base_path)}: {e}[/yellow]")
return None


def _is_type_checking_guard(node: ast.If) -> bool:
    """Tell whether an ``If`` node is a ``TYPE_CHECKING`` guard.

    Both the bare name (``if TYPE_CHECKING:``) and any attribute access ending
    in that name (``if typing.TYPE_CHECKING:``) are recognised. Negated or
    compound conditions are not treated as guards.
    """
    test = node.test
    if isinstance(test, ast.Name) and test.id == "TYPE_CHECKING":
        return True
    if isinstance(test, ast.Attribute) and test.attr == "TYPE_CHECKING":
        return True
    return False


def _collect_type_checking_node_ids(tree: ast.AST) -> set[int]:
    """Return the ``id()`` of every AST node nested in the body of an ``if TYPE_CHECKING`` block.

    Only the ``body`` of the guard is collected. The ``else``/``elif`` branch of
    such an ``if`` is what actually executes at runtime, so imports placed there
    are real dependencies and must not be marked as type-only.

    The returned ids are only meaningful while ``tree`` is kept alive by the caller.
    """
    guarded: set[int] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.If) and _is_type_checking_guard(node):
            for statement in node.body:
                guarded.update(id(child) for child in ast.walk(statement))
    return guarded


def _check_imports_in_files(
    py_files: list[Path],
    base_path: Path,
    import_predicate,
    dist_name: str,
) -> list[tuple[Path, int, str]]:
    """
    Scan Python files and collect the imports flagged by ``import_predicate``.

    Import statements whose node is listed by ``_collect_type_checking_node_ids``
    (i.e. those under an ``if TYPE_CHECKING:`` guard) are ignored, since they are
    type-only and never resolved at runtime. Files that cannot be parsed are skipped.

    Args:
        py_files: Python files to scan.
        base_path: Base path passed to the parser for relative-path warnings.
        import_predicate: Callable invoked as
            ``import_predicate(node, module_name, is_from_import=...)`` that returns
            a ``(should_report, import_stmt)`` pair. For ``import a, b`` it is called
            once per imported name; for ``from m import ...`` once with ``m``.
            ``from . import x`` (no module name) is never passed to it.
        dist_name: Distribution name, used only in the progress message.

    Returns:
        List of ``(file_path, lineno, import_stmt)`` tuples, one per flagged import.
    """
    console.print(f" Checking imports in {len(py_files)} files for [magenta]{dist_name}[/magenta]")
    found = []
    for source_file in py_files:
        module_tree = _parse_python_file(source_file, base_path)
        if module_tree is None:
            continue

        skipped_ids = _collect_type_checking_node_ids(module_tree)

        for stmt in ast.walk(module_tree):
            if id(stmt) in skipped_ids:
                continue

            # Normalise both import flavours into (module name, is_from_import) pairs.
            if isinstance(stmt, ast.Import):
                candidates = [(alias.name, False) for alias in stmt.names]
            elif isinstance(stmt, ast.ImportFrom) and stmt.module:
                candidates = [(stmt.module, True)]
            else:
                continue

            for module_name, from_import in candidates:
                flagged, rendered = import_predicate(stmt, module_name, is_from_import=from_import)
                if flagged:
                    found.append((source_file, stmt.lineno, rendered))

    return found


def _requirement_name(requirement: str) -> str:
    """Return the normalized distribution name at the start of a requirement string.

    Everything after the leading name (extras, version specifiers such as ``~=``,
    ``;`` environment markers, ``@`` direct references) is ignored. The name is
    lower-cased and runs of ``-``, ``_`` and ``.`` are collapsed to ``-`` so that
    spelling variants of the same distribution compare equal. Returns an empty
    string when no name can be found.
    """
    match = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)", requirement)
    if match is None:
        return ""
    return re.sub(r"[-_.]+", "-", match.group(1)).lower()


def check_no_airflow_dependencies(pyproject: dict, shared_path: Path) -> bool:
    """Check that airflow packages are not listed as dependencies.

    Inspects ``project.dependencies``, every ``project.optional-dependencies``
    group and every top-level ``dependency-groups`` group of the parsed
    ``pyproject.toml``.

    Args:
        pyproject: Parsed content of the distribution's ``pyproject.toml``.
        shared_path: Path of the shared distribution; its name is used in messages.

    Returns:
        ``True`` when no forbidden dependency is declared, ``False`` otherwise
        (the offending entries are printed).
    """
    forbidden_packages = {"apache-airflow", "apache-airflow-core", "apache-airflow-task-sdk"}

    project = pyproject.get("project", {})
    sections = [("dependencies", project.get("dependencies", []))]
    sections.extend(
        (f"optional-dependencies.{group_name}", deps)
        for group_name, deps in project.get("optional-dependencies", {}).items()
    )
    sections.extend(
        (f"dependency-groups.{group_name}", deps)
        for group_name, deps in pyproject.get("dependency-groups", {}).items()
    )

    violations = [
        (location, dep)
        for location, deps in sections
        for dep in deps
        # Dependency groups may contain ``{include-group = "..."}`` tables,
        # which are not requirement strings and carry no package name.
        if isinstance(dep, str) and _requirement_name(dep) in forbidden_packages
    ]

    if violations:
        console.print(
            f" [red]Found forbidden airflow dependencies in [magenta]{shared_path.name}[/magenta]:[/red]"
        )
        for location, dep in violations:
            console.print(f" [red]{location}: {dep}[/red]")
        console.print()
        # Sorted so the message is stable from run to run (set order is not).
        console.print(
            f" [red]Shared distributions must not depend on {', '.join(sorted(forbidden_packages))}[/red]"
        )
        return False

    console.print(
        f" No forbidden airflow dependencies in [magenta]{shared_path.name}[/magenta] [bold green]OK[/bold green]"
    )
    return True


def check_no_airflow_imports(shared_path: Path) -> bool:
    """Verify that ``src/`` and ``tests/`` of a shared distribution never import ``airflow``.

    Imports are detected from the AST. When the distribution has no ``src/``
    directory the check is skipped and reported as passing.

    Args:
        shared_path: Root directory of the shared distribution.

    Returns:
        ``True`` when no ``airflow`` / ``airflow.*`` import is found (or ``src/`` is
        missing), ``False`` otherwise; violations and fix instructions are printed.
    """
    sources_root = shared_path / "src"
    tests_root = shared_path / "tests"
    if not sources_root.exists():
        console.print(
            f" [yellow]src/ directory does not exist for [magenta]{shared_path.name}[/magenta][/yellow]"
        )
        return True

    def airflow_import_predicate(node, module_name, is_from_import):
        """Flag imports of ``airflow`` itself or of any of its submodules."""
        if module_name != "airflow" and not module_name.startswith("airflow."):
            return False, ""
        if not is_from_import:
            return True, f"import {module_name}"
        names = ", ".join(alias.name for alias in node.names)
        return True, f"from {module_name} import {names}"

    candidates = [*sources_root.rglob("*.py"), *tests_root.rglob("*.py")]
    offenders = _check_imports_in_files(candidates, shared_path, airflow_import_predicate, shared_path.name)

    if not offenders:
        console.print(
            f" No airflow imports found in [magenta]{shared_path.name}[/magenta] [bold green]OK[/bold green]"
        )
        return True

    console.print(f" [red]Found airflow imports in [magenta]{shared_path.name}[/magenta]:[/red]")
    for offending_file, line_number, statement in offenders:
        console.print(f" [red]{offending_file.relative_to(shared_path)}:{line_number}: {statement}[/red]")

    # Remediation guide, printed in order; ``None`` stands for an empty line.
    guidance = (
        None,
        (
            f" [red]Please remove airflow imports from [magenta]{shared_path.name}[/magenta][/red] "
            "in the way described below:"
        ),
        None,
        " [yellow]How to fix: When you see `airflow` import in `src`:[/]",
        None,
        (
            "[bright_blue]"
            " * When you want to use `from airflow.**._shared.ANOTHER_DISTRIBUTION` - use \n"
            " relative import `from ..ANOTHER_DISTRIBUTION` (might need ../../ or ../../..)"
        ),
        (
            " [bright_blue]"
            " * When you want to use `from airflow.MODULE`, move the module to another shared \n"
            " distribution (sometimes it might cause loss of DRY)"
        ),
        None,
        " [yellow]How to fix: When you see `airflow` import in `tests`:[/]",
        None,
        (
            "[bright_blue]"
            " * When you want to use `from airflow.**._shared.MY_DISTRIBUTION`\n"
            " use `from airflow_shared.MY_DISTRIBUTION import ...` "
        ),
        (
            "[bright_blue]"
            " * When you want to use `from airflow.MODULE` where it is some test-related code \n"
            " this code should likely be moved to `devel-common`"
        ),
    )
    for message in guidance:
        if message is None:
            console.print()
        else:
            console.print(message)
    return False


def check_shared_distribution(shared_path: Path) -> bool:
pyproject_path = shared_path / "pyproject.toml"
console.print(f"\n[bold blue]Checking:[/bold blue] [magenta]{shared_path.name}[/magenta] shared project")
Expand Down Expand Up @@ -201,10 +398,134 @@ def check_shared_distribution(shared_path: Path) -> bool:
return False
if not check_ruff_lint_rules(ruff, shared_path):
return False
if not check_no_airflow_dependencies(pyproject, shared_path):
return False
if not check_no_airflow_imports(shared_path):
return False
console.print(f"[bold green]Summary: {shared_path.name} is OK[/bold green]")
return True


def check_no_airflow_shared_imports(dist_path: Path, dist_name: str) -> bool:
    """Verify that no Python file under ``src/`` imports from ``airflow_shared``.

    ``airflow_shared`` is the namespace shared distributions expose to consumers
    installing them as PyPI packages. Distributions inside the monorepo, such as
    ``task-sdk`` and ``airflow-core``, reach shared code through the ``_shared``
    symlinks (e.g. ``airflow.sdk._shared.*``, ``airflow._shared.*``) and must NOT
    import the PyPI namespace directly.

    Args:
        dist_path: Root directory of the distribution to scan.
        dist_name: Distribution name used in messages.

    Returns:
        ``True`` when no such import is found (or ``src/`` is missing),
        ``False`` otherwise; violations are printed.
    """
    sources_root = dist_path / "src"
    if not sources_root.exists():
        console.print(f" [yellow]src/ directory does not exist for [magenta]{dist_name}[/magenta][/yellow]")
        return True

    def airflow_shared_import_predicate(node, module_name, is_from_import):
        """Flag imports of ``airflow_shared`` itself or of any of its submodules."""
        if module_name != "airflow_shared" and not module_name.startswith("airflow_shared."):
            return False, ""
        if not is_from_import:
            return True, f"import {module_name}"
        names = ", ".join(alias.name for alias in node.names)
        return True, f"from {module_name} import {names}"

    candidates = list(sources_root.rglob("*.py"))
    offenders = _check_imports_in_files(candidates, dist_path, airflow_shared_import_predicate, dist_name)

    if not offenders:
        console.print(
            f" No airflow_shared imports found in [magenta]{dist_name}[/magenta] [bold green]OK[/bold green]"
        )
        return True

    console.print(f" [red]Found airflow_shared imports in [magenta]{dist_name}[/magenta]:[/red]")
    for offending_file, line_number, statement in offenders:
        console.print(f" [red]{offending_file.relative_to(dist_path)}:{line_number}: {statement}[/red]")
    console.print()
    console.print(
        f" [red]Please do not use airflow_shared imports in [magenta]{dist_name}[/magenta][/red]"
    )
    console.print(
        " [yellow]Use proper _shared imports instead (e.g., airflow._shared.* or airflow.sdk._shared.*)[/yellow]"
    )
    return False


def check_only_allowed_shared_imports(dist_path: Path, dist_name: str, allowed_prefix: str) -> bool:
    """Check that every ``_shared`` import in ``src/`` lives under the allowed package.

    A module is considered a ``_shared`` import when ``_shared`` is one of its
    dotted components other than the first. Such an import is accepted only when
    the module is ``allowed_prefix`` itself or one of its submodules. The
    comparison is made component by component, so a sibling whose name merely
    starts with the same characters is not accepted.

    Args:
        dist_path: Root directory of the distribution to scan.
        dist_name: Distribution name used in messages.
        allowed_prefix: Dotted package under which ``_shared`` imports are
            permitted, e.g. ``airflow.sdk._shared``.

    Returns:
        ``True`` when only allowed ``_shared`` imports are found (or ``src/`` is
        missing), ``False`` otherwise; violations are printed.
    """
    src_path = dist_path / "src"
    if not src_path.exists():
        console.print(f" [yellow]src/ directory does not exist for [magenta]{dist_name}[/magenta][/yellow]")
        return True

    allowed_parts = allowed_prefix.split(".")

    def allowed_shared_import_predicate(node, module_name, is_from_import):
        """Flag ``_shared`` imports that are not under the allowed package."""
        parts = module_name.split(".")
        # A leading ``_shared`` component is left alone, as before: only
        # ``<package>._shared`` style modules are subject to this check.
        if "_shared" not in parts[1:]:
            return False, ""
        if parts[: len(allowed_parts)] == allowed_parts:
            return False, ""
        if is_from_import:
            imported_names = ", ".join(alias.name for alias in node.names)
            return True, f"from {module_name} import {imported_names}"
        return True, f"import {module_name}"

    py_files = list(src_path.rglob("*.py"))
    violations = _check_imports_in_files(py_files, dist_path, allowed_shared_import_predicate, dist_name)

    if violations:
        console.print(f" [red]Found disallowed _shared imports in [magenta]{dist_name}[/magenta]:[/red]")
        for file_path, lineno, import_stmt in violations:
            rel_path = file_path.relative_to(dist_path)
            console.print(f" [red]{rel_path}:{lineno}: {import_stmt}[/red]")
        console.print()
        console.print(
            f" [red]Only imports starting with '{allowed_prefix}' are allowed in [magenta]{dist_name}[/magenta][/red]"
        )
        return False

    console.print(
        f" Only allowed _shared imports found in [magenta]{dist_name}[/magenta] [bold green]OK[/bold green]"
    )
    return True


def check_task_sdk_distribution() -> bool:
    """Run the import checks for the ``task-sdk`` distribution.

    Both checks are always executed, so every violation is reported in a single
    run: no ``airflow_shared`` imports, and ``_shared`` imports only under
    ``airflow.sdk._shared``.

    Returns:
        ``True`` when all checks pass or the distribution directory is missing.
    """
    dist_name = "task-sdk"
    console.print(f"\n[bold blue]Checking:[/bold blue] [magenta]{dist_name}[/magenta] distribution")
    if not TASK_SDK_DIR.exists():
        console.print(f" [yellow]{dist_name} directory does not exist[/yellow]")
        return True

    # Built eagerly on purpose: a failing first check must not skip the second.
    outcomes = [
        check_no_airflow_shared_imports(TASK_SDK_DIR, dist_name),
        check_only_allowed_shared_imports(TASK_SDK_DIR, dist_name, "airflow.sdk._shared"),
    ]
    passed = all(outcomes)
    if passed:
        console.print(f"[bold green]Summary: {dist_name} is OK[/bold green]")
    return passed


def check_airflow_core_distribution() -> bool:
    """Run the import checks for the ``airflow-core`` distribution.

    Only the ``airflow_shared`` import check is run.

    Note: ``check_only_allowed_shared_imports`` is intentionally not run here yet —
    airflow-core currently imports a few symbols from ``airflow.sdk._shared`` (e.g.
    ``SecretsMasker``) that need to be relocated before that check can be enabled.

    Returns:
        ``True`` when the check passes or the distribution directory is missing.
    """
    dist_name = "airflow-core"
    console.print(f"\n[bold blue]Checking:[/bold blue] [magenta]{dist_name}[/magenta] distribution")
    if not AIRFLOW_CORE_DIR.exists():
        console.print(f" [yellow]{dist_name} directory does not exist[/yellow]")
        return True

    passed = True
    if not check_no_airflow_shared_imports(AIRFLOW_CORE_DIR, dist_name):
        passed = False
    else:
        console.print(f"[bold green]Summary: {dist_name} is OK[/bold green]")
    return passed


def main() -> None:
if not SHARED_DIR.exists():
print("No shared directory found.")
Expand All @@ -221,6 +542,10 @@ def main() -> None:
ok = check_shared_distribution(shared_project)
if not ok:
all_ok = False
if not check_task_sdk_distribution():
all_ok = False
if not check_airflow_core_distribution():
all_ok = False
if not all_ok:
sys.exit(2)

Expand Down
Loading
Loading