diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8e2ddef2edff7..0f68fec20b5d1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -410,7 +410,13 @@ repos: entry: ./scripts/ci/prek/check_shared_distributions_structure.py language: python pass_filenames: false - files: ^shared/.*$ + files: > + (?x) + ^shared/.*$| + ^.*/pyproject\.toml$| + ^.*/_shared/.*$| + ^airflow-core/.*\.py$| + ^task-sdk/.*\.py$ - id: check-shared-distributions-usage name: Check shared distributions usage entry: ./scripts/ci/prek/check_shared_distributions_usage.py diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index df9c44910fa20..948f560d2af2c 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -397,6 +397,7 @@ def task_test(args, dag: DAG | None = None) -> None: # airflow.task would redirect to a file, but here we want it to propagate # up to the normal airflow handler. + # TODO: move bulk of this logic into the SDK: https://github.com/apache/airflow/issues/54658 from airflow.sdk._shared.secrets_masker import SecretsMasker SecretsMasker.enable_log_masking() diff --git a/airflow-core/src/airflow/cli/commands/triggerer_command.py b/airflow-core/src/airflow/cli/commands/triggerer_command.py index 8b9ee178a19f8..aff89b3e10766 100644 --- a/airflow-core/src/airflow/cli/commands/triggerer_command.py +++ b/airflow-core/src/airflow/cli/commands/triggerer_command.py @@ -63,6 +63,7 @@ def triggerer_run( @providers_configuration_loaded def triggerer(args): """Start Airflow Triggerer.""" + # TODO: Decide what to do with Triggerer from airflow.sdk._shared.secrets_masker import SecretsMasker SecretsMasker.enable_log_masking() diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py index 47f84bbb28e9c..4cde1e5e371cb 100644 --- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py +++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py @@ -176,6 +176,7 @@ def __hash__(self): r"^Dockerfile", r"^scripts/ci/docker-compose", r"^scripts/ci/kubernetes", + r"^scripts/ci/prek", r"^scripts/docker", r"^scripts/in_container", r"^generated/provider_dependencies.json$", @@ -1930,7 +1931,9 @@ def provider_dependency_bump(self) -> bool: console_print( "[warning]Only Release Managers should change >= conditions for apache-airflow-providers " "dependencies.[/]\n\nIf you want to refer to a future version of the dependency, please add a " - "comment [info]'# use next version'[/info] in the line of the dependency instead.\n" + "comment [info]'# use next version'[/info] in the line of the dependency instead.\n\n" + "If there is already the same provider before the `# Additional devel dependencies` without " + "version, move the provider dependency to required ones and add `# use next version` there." ) console_print() console_print( diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index 21296618fe0f6..63acc7c29fe08 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -1603,9 +1603,9 @@ def test_excluded_providers(): pytest.param( ("scripts/ci/prek/file.sh",), { - "full-tests-needed": "false", + "full-tests-needed": "true", }, - id="No full tests needed when prek scripts change", + id="Full tests needed when prek scripts change", ) ), ( diff --git a/scripts/ci/prek/check_shared_distributions_structure.py b/scripts/ci/prek/check_shared_distributions_structure.py index 3a7366bd300cd..b8a2f19870455 100755 --- a/scripts/ci/prek/check_shared_distributions_structure.py +++ b/scripts/ci/prek/check_shared_distributions_structure.py @@ -28,6 +28,8 @@ from __future__ import annotations +import ast +import itertools import re import sys from pathlib import Path @@ -40,6 +42,9 @@ from common_prek_utils import AIRFLOW_ROOT_PATH, console SHARED_DIR = AIRFLOW_ROOT_PATH / "shared" +TASK_SDK_DIR = AIRFLOW_ROOT_PATH / "task-sdk" +AIRFLOW_CORE_DIR = AIRFLOW_ROOT_PATH / "airflow-core" +DEVEL_COMMON_DIR = AIRFLOW_ROOT_PATH / "devel-common" def check_pyproject_exists(shared_path: Path) -> bool: @@ -174,6 +179,199 @@ def check_ruff_lint_rules(ruff: dict, shared_path: Path) -> bool: return False +def _parse_python_file(py_file: Path, base_path: Path) -> ast.Module | None: + """Parse a Python file and return an AST tree, handling errors gracefully.""" + try: + with open(py_file, encoding="utf-8") as f: + return ast.parse(f.read(), filename=str(py_file)) + except SyntaxError as e: + console.print(f" [yellow]Warning: Could not parse {py_file.relative_to(base_path)}: {e}[/yellow]") + return None + except Exception as e: + console.print(f" [yellow]Warning: Error reading {py_file.relative_to(base_path)}: {e}[/yellow]") + return None + + +def _is_type_checking_guard(node: ast.If) -> bool: + """Check if an ``If`` node is a ``TYPE_CHECKING`` guard.""" + test = node.test + if isinstance(test, ast.Name) and test.id == "TYPE_CHECKING": + return True + if isinstance(test, ast.Attribute) and test.attr == "TYPE_CHECKING": + return True + return False + + +def _collect_type_checking_node_ids(tree: ast.AST) -> set[int]: + """Return the ``id()`` of every AST node nested inside an ``if TYPE_CHECKING`` block.""" + guarded: set[int] = set() + for node in ast.walk(tree): + if isinstance(node, ast.If) and _is_type_checking_guard(node): + for child in ast.walk(node): + guarded.add(id(child)) + return guarded + + +def _check_imports_in_files( + py_files: list[Path], + base_path: Path, + import_predicate, + dist_name: str, +) -> list[tuple[Path, int, str]]: + """ + Check imports in Python files based on a predicate function. + + Imports nested inside ``if TYPE_CHECKING:`` blocks are skipped — they are type-only + and never resolved at runtime, so they cannot create real dependencies between + distributions. + + Args: + py_files: List of Python files to check + base_path: Base path for relative path calculation + import_predicate: Function that takes (node, alias/module) and returns (should_report, import_stmt) + dist_name: Distribution name for logging + + Returns: + List of violations (file_path, lineno, import_stmt) + """ + violations = [] + console.print(f" Checking imports in {len(py_files)} files for [magenta]{dist_name}[/magenta]") + for py_file in py_files: + tree = _parse_python_file(py_file, base_path) + if tree is None: + continue + + type_checking_ids = _collect_type_checking_node_ids(tree) + + for node in ast.walk(tree): + if id(node) in type_checking_ids: + continue + if isinstance(node, ast.Import): + for alias in node.names: + should_report, import_stmt = import_predicate(node, alias.name, is_from_import=False) + if should_report: + violations.append((py_file, node.lineno, import_stmt)) + + elif isinstance(node, ast.ImportFrom): + if node.module: + should_report, import_stmt = import_predicate(node, node.module, is_from_import=True) + if should_report: + violations.append((py_file, node.lineno, import_stmt)) + + return violations + + +def check_no_airflow_dependencies(pyproject: dict, shared_path: Path) -> bool: + """Check that airflow packages are not listed as dependencies.""" + forbidden_packages = {"apache-airflow", "apache-airflow-core", "apache-airflow-task-sdk"} + violations = [] + + # Check regular dependencies + dependencies = pyproject.get("project", {}).get("dependencies", []) + for dep in dependencies: + # Parse dependency string (may include version specifiers) + dep_name = dep.split("[")[0].split(">")[0].split("<")[0].split("=")[0].split("!")[0].strip() + if dep_name in forbidden_packages: + violations.append(("dependencies", dep)) + + # Check optional dependencies + optional_deps = pyproject.get("project", {}).get("optional-dependencies", {}) + for group_name, deps in optional_deps.items(): + for dep in deps: + dep_name = dep.split("[")[0].split(">")[0].split("<")[0].split("=")[0].split("!")[0].strip() + if dep_name in forbidden_packages: + violations.append((f"optional-dependencies.{group_name}", dep)) + + # Check dependency groups (PEP 735) + dependency_groups = pyproject.get("dependency-groups", {}) + for group_name, deps in dependency_groups.items(): + for dep in deps: + dep_name = dep.split("[")[0].split(">")[0].split("<")[0].split("=")[0].split("!")[0].strip() + if dep_name in forbidden_packages: + violations.append((f"dependency-groups.{group_name}", dep)) + + if violations: + console.print( + f" [red]Found forbidden airflow dependencies in [magenta]{shared_path.name}[/magenta]:[/red]" + ) + for location, dep in violations: + console.print(f" [red]{location}: {dep}[/red]") + console.print() + console.print(f" [red]Shared distributions must not depend on {', '.join(forbidden_packages)}[/red]") + return False + + console.print( + f" No forbidden airflow dependencies in [magenta]{shared_path.name}[/magenta] [bold green]OK[/bold green]" + ) + return True + + +def check_no_airflow_imports(shared_path: Path) -> bool: + """Check that no Python files in src/ import from airflow using AST.""" + src_path = shared_path / "src" + test_path = shared_path / "tests" + if not src_path.exists(): + console.print( + f" [yellow]src/ directory does not exist for [magenta]{shared_path.name}[/magenta][/yellow]" + ) + return True + + def airflow_import_predicate(node, module_name, is_from_import): + """Check if import is from airflow package.""" + if module_name == "airflow" or module_name.startswith("airflow."): + if is_from_import: + imported_names = ", ".join(alias.name for alias in node.names) + return True, f"from {module_name} import {imported_names}" + return True, f"import {module_name}" + return False, "" + + py_files = list(itertools.chain(src_path.rglob("*.py"), test_path.rglob("*.py"))) + violations = _check_imports_in_files(py_files, shared_path, airflow_import_predicate, shared_path.name) + + if violations: + console.print(f" [red]Found airflow imports in [magenta]{shared_path.name}[/magenta]:[/red]") + for file_path, lineno, import_stmt in violations: + rel_path = file_path.relative_to(shared_path) + console.print(f" [red]{rel_path}:{lineno}: {import_stmt}[/red]") + console.print() + console.print( + f" [red]Please remove airflow imports from [magenta]{shared_path.name}[/magenta][/red] " + f"in the way described below:" + ) + console.print() + console.print(" [yellow]How to fix: When you see `airflow` import in `src`:[/]") + console.print() + console.print( + "[bright_blue]" + " * When you want to use `from airflow.**._shared.ANOTHER_DISTRIBUTIONS - use \n" + " relative import `from ..ANOTHER_DISTRIBUTION` (might need ../../ or ../../..)" + ) + console.print( + " [bright_blue]" + " * When you want to use `from airflow.MODULE`, move the module to another shared \n" + " distribution (sometimes it might cause loss of DRY)" + ) + console.print() + console.print(" [yellow]How to fix: When you see `airflow` import in `tests`:[/]") + console.print() + console.print( + "[bright_blue]" + " * When you want to use `from airflow.**._shared.MY_DISTRIBUTION`\n" + " use `from airflow_shared.MY_DISTRIBUTION import ...` " + ) + console.print( + "[bright_blue]" + " * When you want to use `from airflow.MODULE` where it is some test-related code \n" + " this code should likely be moved to `devel-common`" + ) + return False + + console.print( + f" No airflow imports found in [magenta]{shared_path.name}[/magenta] [bold green]OK[/bold green]" + ) + return True + + def check_shared_distribution(shared_path: Path) -> bool: pyproject_path = shared_path / "pyproject.toml" console.print(f"\n[bold blue]Checking:[/bold blue] [magenta]{shared_path.name}[/magenta] shared project") @@ -201,26 +399,222 @@ def check_shared_distribution(shared_path: Path) -> bool: return False if not check_ruff_lint_rules(ruff, shared_path): return False + if not check_no_airflow_dependencies(pyproject, shared_path): + return False + if not check_no_airflow_imports(shared_path): + return False console.print(f"[bold green]Summary: {shared_path.name} is OK[/bold green]") return True +def check_no_airflow_shared_imports(dist_path: Path, dist_name: str) -> bool: + """Check that no Python files use airflow_shared imports.""" + src_path = dist_path / "src" + if not src_path.exists(): + console.print(f" [yellow]src/ directory does not exist for [magenta]{dist_name}[/magenta][/yellow]") + return True + + def airflow_shared_import_predicate(node, module_name, is_from_import): + """Check if import is from airflow_shared package.""" + if module_name == "airflow_shared" or module_name.startswith("airflow_shared."): + if is_from_import: + imported_names = ", ".join(alias.name for alias in node.names) + return True, f"from {module_name} import {imported_names}" + return True, f"import {module_name}" + return False, "" + + py_files = list(src_path.rglob("*.py")) + violations = _check_imports_in_files(py_files, dist_path, airflow_shared_import_predicate, dist_name) + + if violations: + console.print(f" [red]Found airflow_shared imports in [magenta]{dist_name}[/magenta]:[/red]") + for file_path, lineno, import_stmt in violations: + rel_path = file_path.relative_to(dist_path) + console.print(f" [red]{rel_path}:{lineno}: {import_stmt}[/red]") + console.print() + console.print( + f" [red]Please do not use airflow_shared imports in [magenta]{dist_name}[/magenta][/red]" + ) + console.print( + " [yellow]Use proper _shared imports instead (e.g., airflow._shared.* or airflow.sdk._shared.*)[/yellow]" + ) + return False + + console.print( + f" No airflow_shared imports found in [magenta]{dist_name}[/magenta] [bold green]OK[/bold green]" + ) + return True + + +def check_only_allowed_shared_imports(dist_path: Path, dist_name: str, allowed_prefix: str) -> bool: + """Check that only imports with the allowed _shared prefix are used.""" + src_path = dist_path / "src" + if not src_path.exists(): + console.print(f" [yellow]src/ directory does not exist for [magenta]{dist_name}[/magenta][/yellow]") + return True + + def allowed_shared_import_predicate(node, module_name, is_from_import): + """Check if _shared import uses the correct prefix.""" + if "._shared" in module_name or module_name.endswith("._shared"): + if not module_name.startswith(allowed_prefix): + if is_from_import: + imported_names = ", ".join(alias.name for alias in node.names) + return True, f"from {module_name} import {imported_names}" + return True, f"import {module_name}" + return False, "" + + py_files = list(src_path.rglob("*.py")) + violations = _check_imports_in_files(py_files, dist_path, allowed_shared_import_predicate, dist_name) + + if violations: + console.print(f" [red]Found disallowed _shared imports in [magenta]{dist_name}[/magenta]:[/red]") + for file_path, lineno, import_stmt in violations: + rel_path = file_path.relative_to(dist_path) + console.print(f" [red]{rel_path}:{lineno}: {import_stmt}[/red]") + console.print() + console.print( + f" [red]Only imports starting with '{allowed_prefix}' are allowed in [magenta]{dist_name}[/magenta][/red]" + ) + return False + + console.print( + f" Only allowed _shared imports found in [magenta]{dist_name}[/magenta] [bold green]OK[/bold green]" + ) + return True + + +def check_distribution(dist_path: Path, dist_name: str, allowed_shared_prefix: str) -> bool: + """ + Check a distribution for proper _shared imports usage. + + Args: + dist_path: Path to the distribution directory + dist_name: Name of the distribution for display + allowed_shared_prefix: Allowed prefix for _shared imports (e.g., 'airflow.sdk._shared') + + Returns: + True if all checks pass, False otherwise + """ + console.print(f"\n[bold blue]Checking:[/bold blue] [magenta]{dist_name}[/magenta] distribution") + + if not dist_path.exists(): + console.print(f" [yellow]{dist_name} directory does not exist[/yellow]") + return True + + all_ok = True + + # Check 1: No airflow_shared imports + if not check_no_airflow_shared_imports(dist_path, dist_name): + all_ok = False + + # Check 2: Only allowed _shared imports + if not check_only_allowed_shared_imports(dist_path, dist_name, allowed_shared_prefix): + all_ok = False + + if all_ok: + console.print(f"[bold green]Summary: {dist_name} is OK[/bold green]") + + return all_ok + + +def check_task_sdk_distribution() -> bool: + """Check task-sdk distribution for proper _shared imports usage.""" + return check_distribution(TASK_SDK_DIR, "task-sdk", "airflow.sdk._shared") + + +def check_airflow_core_distribution() -> bool: + """Check airflow-core distribution for proper _shared imports usage.""" + return check_distribution(AIRFLOW_CORE_DIR, "airflow-core", "airflow._shared") + + +def check_no_airflow_imports_devel_common(dist_path: Path) -> bool: + """Check that no Python files in devel-common use airflow imports.""" + src_path = dist_path / "src" + if not src_path.exists(): + console.print(" [yellow]src/ directory does not exist for [magenta]devel-common[/magenta][/yellow]") + return True + + def airflow_import_predicate(node, module_name, is_from_import): + """Check if import is from airflow package.""" + if module_name == "airflow" or module_name.startswith("airflow."): + if is_from_import: + imported_names = ", ".join(alias.name for alias in node.names) + return True, f"from {module_name} import {imported_names}" + return True, f"import {module_name}" + return False, "" + + py_files = list(src_path.rglob("*.py")) + violations = _check_imports_in_files(py_files, dist_path, airflow_import_predicate, "devel-common") + + if violations: + console.print(" [red]Found airflow imports in [magenta]devel-common[/magenta]:[/red]") + for file_path, lineno, import_stmt in violations: + rel_path = file_path.relative_to(dist_path) + console.print(f" [red]{rel_path}:{lineno}: {import_stmt}[/red]") + console.print() + console.print(" [red]Please remove airflow imports from [magenta]devel-common[/magenta][/red]") + console.print( + " [yellow]devel-common should not depend on airflow packages to remain independent[/yellow]\n\n" + " [yellow]Those imports should be converted to `from airflow_shared` or " + "moved to the devel-common distribution.[/yellow]" + ) + return False + + console.print(" No airflow imports found in [magenta]devel-common[/magenta] [bold green]OK[/bold green]") + return True + + +def check_devel_common_distribution() -> bool: + """Check devel-common distribution for proper imports usage.""" + console.print("\n[bold blue]Checking:[/bold blue] [magenta]devel-common[/magenta] distribution") + + if not DEVEL_COMMON_DIR.exists(): + console.print(" [yellow]devel-common directory does not exist[/yellow]") + return True + + all_ok = True + + # Check: No airflow imports + if not check_no_airflow_imports_devel_common(DEVEL_COMMON_DIR): + all_ok = False + + if all_ok: + console.print("[bold green]Summary: devel-common is OK[/bold green]") + + return all_ok + + def main() -> None: - if not SHARED_DIR.exists(): - print("No shared directory found.") - sys.exit(1) all_ok = True - for shared_project in SHARED_DIR.iterdir(): - if shared_project.is_dir(): - # Not having a pyproject.toml means no distribution here, so using it as a marker to skip - if not (shared_project / "pyproject.toml").exists(): - console.print( - f"\n[yellow]Skipping empty directory:[/yellow] [magenta]{shared_project.name}[/magenta]" - ) - continue - ok = check_shared_distribution(shared_project) - if not ok: - all_ok = False + # Check shared distributions + if SHARED_DIR.exists(): + for shared_project in SHARED_DIR.iterdir(): + if shared_project.is_dir(): + # Not having a pyproject.toml means no distribution here, so using it as a marker to skip + if not (shared_project / "pyproject.toml").exists(): + console.print( + f"\n[yellow]Skipping empty directory:[/yellow] [magenta]{shared_project.name}[/magenta]" + ) + continue + ok = check_shared_distribution(shared_project) + if not ok: + all_ok = False + else: + console.print("[yellow]No shared directory found.[/yellow]") + sys.exit(1) + + # Check task-sdk distribution + if not check_task_sdk_distribution(): + all_ok = False + + # Check airflow-core distribution + if not check_airflow_core_distribution(): + all_ok = False + + # Check devel-common distribution + if not check_devel_common_distribution(): + all_ok = False + if not all_ok: sys.exit(2) diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py index 7164eeec07672..2ab20d056b9d5 100644 --- a/task-sdk/src/airflow/sdk/definitions/deadline.py +++ b/task-sdk/src/airflow/sdk/definitions/deadline.py @@ -98,7 +98,7 @@ def serialize_reference(self) -> dict[str, Any]: @classmethod def deserialize_reference(cls, reference_data: dict[str, Any]) -> FixedDatetimeDeadline: - from airflow._shared.timezones import timezone + from airflow.sdk._shared.timezones import timezone return cls(_datetime=timezone.from_timestamp(reference_data["datetime"]))