From 58660a80d27a36b54718877c5dc2517fe0a17851 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 6 Jun 2026 02:16:35 +0100 Subject: [PATCH 1/3] Reduce SSH connection churn in SSHRemoteJobOperator under high fan-out The operator and trigger opened a new SSH connection for every remote command. A large expand() fan-out against one host drove the connection rate past the remote sshd MaxStartups limit, which drops connections and surfaces as "paramiko ... Error reading SSH protocol banner" (an immediate EOF, not a banner timeout) at submit time, and left job directories behind when the cleanup connection was dropped too. Changes: - Trigger holds one connection for the whole poll loop instead of reconnecting per command, with bounded jittered reconnect on drops and asyncssh.Error treated as reconnectable. - Operator reuses one connection for OS detection and submission. - Cleanup retries instead of orphaning the job directory on a dropped connection. - Configurable conn_retry_attempts (operator/hook) for the submit burst, plus command_timeout and max_reconnect_attempts forwarded to the trigger. - SSHHookAsync sets a keepalive on the long-lived trigger connection. 
--- .../ssh/docs/operators/ssh_remote_job.rst | 40 ++- .../src/airflow/providers/ssh/hooks/ssh.py | 25 +- .../providers/ssh/operators/ssh_remote_job.py | 227 +++++++++------ .../providers/ssh/triggers/ssh_remote_job.py | 244 +++++++++------- .../ssh/tests/unit/ssh/hooks/test_ssh.py | 20 ++ .../unit/ssh/operators/test_ssh_remote_job.py | 115 ++++++++ .../unit/ssh/triggers/test_ssh_remote_job.py | 273 +++++++++++------- 7 files changed, 657 insertions(+), 287 deletions(-) diff --git a/providers/ssh/docs/operators/ssh_remote_job.rst b/providers/ssh/docs/operators/ssh_remote_job.rst index c78a07927405f..84757a9f11d0c 100644 --- a/providers/ssh/docs/operators/ssh_remote_job.rst +++ b/providers/ssh/docs/operators/ssh_remote_job.rst @@ -164,6 +164,13 @@ Parameters * ``remote_os`` (str, optional): Remote OS type (``"auto"``, ``"posix"``, ``"windows"``). Default: ``"auto"`` * ``skip_on_exit_code`` (int or list, optional): Exit code(s) that should cause task to skip instead of fail +* ``conn_timeout`` (int, optional): SSH connection timeout in seconds +* ``banner_timeout`` (float, optional): Seconds to wait for the SSH banner. Default: 30.0 +* ``conn_retry_attempts`` (int, optional): How many times to attempt the initial SSH connection for + submission and cleanup before failing. Default: 5. Raise this for large fan-outs where the remote + ``sshd`` transiently refuses connections (see :ref:`High Fan-out <howto/operator:SSHRemoteJobOperator:fanout>`) +* ``cleanup_retries`` (int, optional): How many times to retry remote directory cleanup before giving up + and leaving the directory in place. Default: 3 Remote OS Detection ------------------- @@ -213,7 +220,9 @@ Limitations and Considerations ------------------------------- **Network Interruptions**: While the operator is resilient to disconnections during monitoring, -the initial job submission must succeed.
The connection used for submission is retried +(``conn_retry_attempts``); if every attempt fails, the task fails immediately. The trigger also +reconnects automatically if the monitoring connection drops mid-job. **Remote Process Management**: Jobs are detached using ``nohup`` (POSIX) or ``Start-Process`` (Windows). If the remote host reboots during job execution, the job will be lost. @@ -231,7 +240,34 @@ tasks can run on the same remote host without conflicts. **Cleanup**: Use ``cleanup="on_success"`` or ``cleanup="always"`` to avoid accumulating job directories on the remote host. For debugging, use ``cleanup="never"`` and manually -inspect the job directory. +inspect the job directory. Cleanup runs only when the job reaches completion, so tasks that +are killed or time out can still leave a directory behind; for those, add a server-side TTL +reaper (for example ``systemd-tmpfiles`` or a cron job) for the base directory. + +.. _howto/operator:SSHRemoteJobOperator:fanout: + +High Fan-out (Many Concurrent Tasks) +------------------------------------- + +A large ``.expand()`` fan-out points many tasks at the same SSH server at once. Every task +opens its own SSH connections, and the remote ``sshd`` throttles concurrent +*unauthenticated* connections via ``MaxStartups`` (default ``10:30:100``: start randomly +dropping 30% of new connections at 10 concurrent, rising to 100% at 100). A dropped connection surfaces on the client +as:: + + paramiko ... Error reading SSH protocol banner + +This is the server closing the socket before the handshake, not a slow banner, so raising +``banner_timeout`` does not help. + +The operator and trigger keep the connection rate low: submission reuses a single connection +for OS detection and the submit itself, and the trigger holds **one** connection for the whole +poll loop instead of reconnecting on every status check. To push a high fan-out further: + +* Raise ``MaxStartups`` (and ``MaxSessions``) on the remote ``sshd`` -- this is the direct fix.
+* Increase ``conn_retry_attempts`` so transient refusals during the initial burst are retried. +* Spread submissions out (for example with a pool or ``max_active_tis_per_dag``) rather than + releasing the entire fan-out simultaneously. Comparison with SSHOperator ---------------------------- diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py index d029a50772b35..54b7edf6008c7 100644 --- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py +++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py @@ -82,6 +82,9 @@ class SSHHook(BaseHook): lifetime of the transport :param ciphers: list of ciphers to use in order of preference :param auth_timeout: timeout (in seconds) for the attempt to authenticate with the remote_host + :param conn_retry_attempts: number of times to attempt the initial SSH connection before + giving up (default 3). Raising this helps when many tasks target the same SSH server at + once and some connections are transiently refused (e.g. ``sshd`` ``MaxStartups`` throttling). 
""" # List of classes to try loading private keys as, ordered (roughly) by most common to least common @@ -130,9 +133,11 @@ def __init__( ciphers: list[str] | None = None, auth_timeout: int | None = None, host_proxy_cmd: str | None = None, + conn_retry_attempts: int = 3, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id + self.conn_retry_attempts = max(1, conn_retry_attempts) self.remote_host = remote_host self.username = username self.password = password @@ -344,7 +349,7 @@ def log_before_sleep(retry_state): for attempt in Retrying( reraise=True, wait=wait_fixed(3) + wait_random(0, 2), - stop=stop_after_attempt(3), + stop=stop_after_attempt(self.conn_retry_attempts), before_sleep=log_before_sleep, ): with attempt: @@ -553,6 +558,7 @@ def __init__( key_file: str = "", passphrase: str = "", private_key: str = "", + keepalive_interval: int = 30, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id @@ -564,6 +570,7 @@ def __init__( self.key_file = key_file self.passphrase = passphrase self.private_key = private_key + self.keepalive_interval = keepalive_interval def _parse_extras(self, conn: Any) -> None: """Parse extra fields from the connection into instance fields.""" @@ -631,10 +638,26 @@ def _get_value(self_val, conn_val, default=None): conn_config["client_keys"] = [_private_key] if self.passphrase: conn_config["passphrase"] = self.passphrase + if self.keepalive_interval: + # The trigger holds one connection for the whole job; a keepalive stops idle + # NAT/firewall timeouts from silently dropping it between long poll intervals. + conn_config["keepalive_interval"] = self.keepalive_interval ssh_client_conn = await asyncssh.connect(**conn_config) return ssh_client_conn + async def get_conn(self): + """ + Open an asyncssh connection that can be reused for multiple commands. + + Unlike :meth:`run_command`, the returned connection is **not** closed + automatically; the caller owns its lifecycle (e.g. 
+ ``async with await hook.get_conn() as conn: ...`` or an explicit + ``conn.close()``). Reusing one connection avoids a new TCP/SSH handshake + per command, which matters when many tasks poll the same SSH server. + """ + return await self._get_conn() + async def run_command(self, command: str, timeout: float | None = None) -> tuple[int, str, str]: """ Execute a command on the remote host asynchronously. diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py index 783edb39b8fd7..cfe18a1e5c59b 100644 --- a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py +++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py @@ -19,6 +19,7 @@ from __future__ import annotations +import time import warnings from collections.abc import Container, Sequence from datetime import timedelta @@ -74,6 +75,25 @@ class SSHRemoteJobOperator(BaseOperator): :param skip_on_exit_code: Exit codes that should skip the task instead of failing :param conn_timeout: SSH connection timeout in seconds :param banner_timeout: Timeout waiting for SSH banner in seconds + :param conn_retry_attempts: How many times to attempt the initial SSH connection for + submission/cleanup before failing (default 5). Helps when many mapped tasks hit the + same host at once and ``sshd`` transiently refuses connections (``MaxStartups``). + :param cleanup_retries: How many times to attempt remote directory cleanup before + giving up and leaving the directory in place (default 3). Prevents a transient SSH + failure during cleanup from orphaning the job directory on the remote host. + :param command_timeout: Per-command timeout in seconds for the trigger's status/log polls + (default 30.0). + :param max_reconnect_attempts: Consecutive connection failures the trigger tolerates (with + backoff) before failing the task while monitoring the remote job (default 5). + + .. 
note:: + A large ``expand()`` fan-out opens many SSH connections against one host. The remote + ``sshd`` throttles concurrent unauthenticated connections via ``MaxStartups`` (default + ``10:30:100``); when exceeded it drops connections, surfacing as + ``paramiko ... Error reading SSH protocol banner``. For high fan-out, raise ``MaxStartups`` + on the server. The directory ``/tmp/airflow-ssh-jobs`` (POSIX) is only cleaned when + ``cleanup`` is set and the job reaches completion, so also consider a server-side TTL + reaper (for example ``systemd-tmpfiles``) for jobs that are killed or time out. """ template_fields: Sequence[str] = ("command", "environment", "remote_host", "remote_base_dir") @@ -104,6 +124,10 @@ def __init__( skip_on_exit_code: int | Container[int] | None = None, conn_timeout: int | None = None, banner_timeout: float = 30.0, + conn_retry_attempts: int = 5, + cleanup_retries: int = 3, + command_timeout: float = 30.0, + max_reconnect_attempts: int = 5, **kwargs, ) -> None: super().__init__(**kwargs) @@ -123,6 +147,10 @@ def __init__( self.remote_os = remote_os self.conn_timeout = conn_timeout self.banner_timeout = banner_timeout + self.conn_retry_attempts = conn_retry_attempts + self.cleanup_retries = max(1, cleanup_retries) + self.command_timeout = command_timeout + self.max_reconnect_attempts = max_reconnect_attempts self.skip_on_exit_code = ( skip_on_exit_code if isinstance(skip_on_exit_code, Container) @@ -170,67 +198,69 @@ def ssh_hook(self) -> SSHHook: remote_host=self.remote_host or "", conn_timeout=self.conn_timeout, banner_timeout=self.banner_timeout, + conn_retry_attempts=self.conn_retry_attempts, ) - def _detect_remote_os(self) -> Literal["posix", "windows"]: + def _detect_remote_os(self, ssh_client) -> Literal["posix", "windows"]: """ - Detect the remote operating system. + Detect the remote operating system on an already-open SSH connection. Uses a two-stage detection: 1. 
Try POSIX detection via `uname` (works on Linux, macOS, BSD, Solaris, AIX, etc.) 2. Try Windows detection via PowerShell 3. Raise error if both fail + + :param ssh_client: An open paramiko SSH client to reuse (avoids a second handshake). """ if self.remote_os != "auto": return self.remote_os self.log.info("Auto-detecting remote operating system...") - with self.ssh_hook.get_conn() as ssh_client: - try: - exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command( - ssh_client, - build_posix_os_detection_command(), - get_pty=False, - environment=None, - timeout=10, - ) - if exit_status == 0 and stdout: - output = stdout.decode("utf-8", errors="replace").strip().lower() - posix_systems = [ - "linux", - "darwin", - "freebsd", - "openbsd", - "netbsd", - "sunos", - "aix", - "hp-ux", - ] - if any(system in output for system in posix_systems): - self.log.info("Detected POSIX system: %s", output) - return "posix" - except Exception as e: - self.log.debug("POSIX detection failed: %s", e) - - try: - exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command( - ssh_client, - build_windows_os_detection_command(), - get_pty=False, - environment=None, - timeout=10, - ) - if exit_status == 0 and stdout: - output = stdout.decode("utf-8", errors="replace").strip() - if "WINDOWS" in output.upper(): - self.log.info("Detected Windows system") - return "windows" - except Exception as e: - self.log.debug("Windows detection failed: %s", e) + try: + exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command( + ssh_client, + build_posix_os_detection_command(), + get_pty=False, + environment=None, + timeout=10, + ) + if exit_status == 0 and stdout: + output = stdout.decode("utf-8", errors="replace").strip().lower() + posix_systems = [ + "linux", + "darwin", + "freebsd", + "openbsd", + "netbsd", + "sunos", + "aix", + "hp-ux", + ] + if any(system in output for system in posix_systems): + self.log.info("Detected POSIX system: %s", output) + return "posix" + except Exception as e: + 
self.log.debug("POSIX detection failed: %s", e) - raise AirflowException( - "Could not auto-detect remote OS. Please explicitly set remote_os='posix' or 'windows'" + try: + exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command( + ssh_client, + build_windows_os_detection_command(), + get_pty=False, + environment=None, + timeout=10, ) + if exit_status == 0 and stdout: + output = stdout.decode("utf-8", errors="replace").strip() + if "WINDOWS" in output.upper(): + self.log.info("Detected Windows system") + return "windows" + except Exception as e: + self.log.debug("Windows detection failed: %s", e) + + raise AirflowException( + "Could not auto-detect remote OS. Please explicitly set remote_os='posix' or 'windows'" + ) def execute(self, context: Context) -> None: """ @@ -241,9 +271,6 @@ def execute(self, context: Context) -> None: if not self.command: raise AirflowException("SSH operator error: command not specified.") - self._detected_os = self._detect_remote_os() - self.log.info("Remote OS: %s", self._detected_os) - ti = context["ti"] self._job_id = generate_job_id( dag_id=ti.dag_id, @@ -253,27 +280,34 @@ def execute(self, context: Context) -> None: ) self.log.info("Generated job ID: %s", self._job_id) - self._paths = RemoteJobPaths( - job_id=self._job_id, - remote_os=self._detected_os, - base_dir=self.remote_base_dir, - ) + # Reuse a single connection for OS detection (when 'auto') and submission so the + # operator opens one SSH handshake per task instead of two. Under a large fan-out + # this halves the connection burst that triggers sshd MaxStartups throttling. 
+ self.log.info("Connecting to %s", self.ssh_hook.remote_host) + with self.ssh_hook.get_conn() as ssh_client: + self._detected_os = self._detect_remote_os(ssh_client) + self.log.info("Remote OS: %s", self._detected_os) - if self._detected_os == "posix": - wrapper_cmd = build_posix_wrapper_command( - command=self.command, - paths=self._paths, - environment=self.environment, - ) - else: - wrapper_cmd = build_windows_wrapper_command( - command=self.command, - paths=self._paths, - environment=self.environment, + self._paths = RemoteJobPaths( + job_id=self._job_id, + remote_os=self._detected_os, + base_dir=self.remote_base_dir, ) - self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host) - with self.ssh_hook.get_conn() as ssh_client: + if self._detected_os == "posix": + wrapper_cmd = build_posix_wrapper_command( + command=self.command, + paths=self._paths, + environment=self.environment, + ) + else: + wrapper_cmd = build_windows_wrapper_command( + command=self.command, + paths=self._paths, + environment=self.environment, + ) + + self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host) exit_status, stdout, stderr = self.ssh_hook.exec_ssh_client_command( ssh_client, wrapper_cmd, @@ -320,6 +354,8 @@ def execute(self, context: Context) -> None: poll_interval=self.poll_interval, log_chunk_size=self.log_chunk_size, log_offset=0, + command_timeout=self.command_timeout, + max_reconnect_attempts=self.max_reconnect_attempts, ), method_name="execute_complete", timeout=timedelta(seconds=self.timeout) if self.timeout else None, @@ -361,6 +397,8 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: poll_interval=self.poll_interval, log_chunk_size=self.log_chunk_size, log_offset=event.get("log_offset", 0), + command_timeout=self.command_timeout, + max_reconnect_attempts=self.max_reconnect_attempts, ), method_name="execute_complete", timeout=timedelta(seconds=self.timeout) if self.timeout else None, @@ -389,25 +427,46 @@ def 
execute_complete(self, context: Context, event: dict[str, Any]) -> None: self.log.info("Remote job completed successfully") def _cleanup_remote_job(self, job_dir: str, remote_os: str) -> None: - """Clean up the remote job directory.""" + """ + Clean up the remote job directory, retrying on transient SSH failures. + + Under a large fan-out the cleanup connection can itself be refused by the + remote ``sshd`` (``MaxStartups``). Retrying a few times keeps a transient drop + from orphaning the job directory; if every attempt fails we log loudly and + leave the directory rather than failing the (already finished) task. + """ self.log.info("Cleaning up remote job directory: %s", job_dir) - try: - if remote_os == "posix": - cleanup_cmd = build_posix_cleanup_command(job_dir) - else: - cleanup_cmd = build_windows_cleanup_command(job_dir) + if remote_os == "posix": + cleanup_cmd = build_posix_cleanup_command(job_dir) + else: + cleanup_cmd = build_windows_cleanup_command(job_dir) - with self.ssh_hook.get_conn() as ssh_client: - self.ssh_hook.exec_ssh_client_command( - ssh_client, - cleanup_cmd, - get_pty=False, - environment=None, - timeout=30, - ) - self.log.info("Remote cleanup completed") - except Exception as e: - self.log.warning("Failed to clean up remote job directory: %s", e) + last_error: Exception | None = None + for attempt in range(1, self.cleanup_retries + 1): + try: + with self.ssh_hook.get_conn() as ssh_client: + self.ssh_hook.exec_ssh_client_command( + ssh_client, + cleanup_cmd, + get_pty=False, + environment=None, + timeout=30, + ) + self.log.info("Remote cleanup completed") + return + except Exception as e: + last_error = e + self.log.warning("Cleanup attempt %d/%d failed: %s", attempt, self.cleanup_retries, e) + if attempt < self.cleanup_retries: + time.sleep(min(2**attempt, 10)) + + self.log.warning( + "Failed to clean up remote job directory after %d attempts; leaving orphaned " + "directory %s on the remote host (last error: %s)", + self.cleanup_retries, 
+ job_dir, + last_error, + ) def on_kill(self) -> None: """ diff --git a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py index 0d4072c1ca4c4..cdb441176e389 100644 --- a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py +++ b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py @@ -20,10 +20,11 @@ from __future__ import annotations import asyncio +import random from collections.abc import AsyncIterator -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal -import tenacity +import asyncssh from airflow.providers.ssh.hooks.ssh import SSHHookAsync from airflow.providers.ssh.utils.remote_job import ( @@ -36,6 +37,16 @@ ) from airflow.triggers.base import BaseTrigger, TriggerEvent +if TYPE_CHECKING: + from asyncssh import SSHClientConnection + +# Errors that mean the connection itself is broken/refused and the poll should +# reconnect instead of failing the job. ``asyncssh.Error`` covers handshake, +# protocol and disconnect failures (e.g. an sshd that drops the connection under +# ``MaxStartups`` load); ``OSError`` covers TCP-level refusals; ``TimeoutError`` +# covers a wedged command or connection. +_CONNECTION_ERRORS = (OSError, asyncssh.Error, TimeoutError) + class SSHRemoteJobTrigger(BaseTrigger): """ @@ -44,6 +55,13 @@ class SSHRemoteJobTrigger(BaseTrigger): This trigger polls the remote host to check job completion status and reads log output incrementally. + A single SSH connection is opened and reused for the whole poll loop instead + of reconnecting for every command. Opening a fresh TCP/SSH connection per poll + multiplies the connection rate against the remote ``sshd`` (which throttles + concurrent unauthenticated connections via ``MaxStartups``), so reuse keeps the + load flat when many tasks target the same host. 
If the connection drops, the + trigger transparently reconnects with backoff up to ``max_reconnect_attempts``. + :param ssh_conn_id: SSH connection ID from Airflow Connections :param remote_host: Optional override for the remote host :param job_id: Unique identifier for the remote job @@ -54,6 +72,9 @@ class SSHRemoteJobTrigger(BaseTrigger): :param poll_interval: Seconds between polling attempts :param log_chunk_size: Maximum bytes to read per poll :param log_offset: Current byte offset in the log file + :param command_timeout: Per-command timeout in seconds + :param max_reconnect_attempts: Consecutive connection failures tolerated before the + trigger gives up and emits an error event """ def __init__( @@ -69,6 +90,7 @@ def __init__( log_chunk_size: int = 65536, log_offset: int = 0, command_timeout: float = 30.0, + max_reconnect_attempts: int = 5, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id @@ -82,6 +104,7 @@ def __init__( self.log_chunk_size = log_chunk_size self.log_offset = log_offset self.command_timeout = command_timeout + self.max_reconnect_attempts = max_reconnect_attempts def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize the trigger for storage.""" @@ -99,6 +122,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "log_chunk_size": self.log_chunk_size, "log_offset": self.log_offset, "command_timeout": self.command_timeout, + "max_reconnect_attempts": self.max_reconnect_attempts, }, ) @@ -109,18 +133,34 @@ def _get_hook(self) -> SSHHookAsync: host=self.remote_host, ) - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(multiplier=1, min=1, max=10), - retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)), - reraise=True, - ) - async def _check_completion(self, hook: SSHHookAsync) -> int | None: + async def _connect(self) -> SSHClientConnection: + """Open a reusable asyncssh connection. 
Separated out as a seam for testing.""" + return await self._get_hook().get_conn() + + @staticmethod + async def _close(conn: SSHClientConnection) -> None: + """Close a connection, swallowing teardown errors.""" + try: + conn.close() + await conn.wait_closed() + except Exception: + # Teardown is best-effort; a failing close has nothing actionable to recover. + pass + + def _reconnect_delay(self, attempt: int) -> float: + """Exponential backoff with full jitter to desynchronise reconnecting triggers.""" + base = min(2 ** (attempt - 1), 30) + return base + random.uniform(0, base) + + async def _run_command(self, conn: SSHClientConnection, command: str) -> tuple[int, str, str]: + """Run a command on an existing connection, mirroring ``SSHHookAsync.run_command``.""" + result = await conn.run(command, timeout=self.command_timeout, check=False) + return result.exit_status or 0, result.stdout or "", result.stderr or "" + + async def _check_completion(self, conn: SSHClientConnection) -> int | None: """ Check if the remote job has completed. - Retries transient network errors up to 3 times with exponential backoff. 
- :return: Exit code if completed, None if still running """ if self.remote_os == "posix": @@ -128,62 +168,32 @@ async def _check_completion(self, hook: SSHHookAsync) -> int | None: else: cmd = build_windows_completion_check_command(self.exit_code_file) - try: - _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout) - stdout = stdout.strip() - if stdout and stdout.isdigit(): - return int(stdout) - except (OSError, TimeoutError, ConnectionError) as e: - self.log.warning("Transient error checking completion (will retry): %s", e) - raise - except Exception as e: - self.log.warning("Error checking completion status: %s", e) + _, stdout, _ = await self._run_command(conn, cmd) + stdout = stdout.strip() + if stdout and stdout.isdigit(): + return int(stdout) return None - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(multiplier=1, min=1, max=10), - retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)), - reraise=True, - ) - async def _get_log_size(self, hook: SSHHookAsync) -> int: - """ - Get the current size of the log file in bytes. - - Retries transient network errors up to 3 times with exponential backoff. 
- """ + async def _get_log_size(self, conn: SSHClientConnection) -> int: + """Get the current size of the log file in bytes.""" if self.remote_os == "posix": cmd = build_posix_file_size_command(self.log_file) else: cmd = build_windows_file_size_command(self.log_file) - try: - _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout) - stdout = stdout.strip() - if stdout and stdout.isdigit(): - return int(stdout) - except (OSError, TimeoutError, ConnectionError) as e: - self.log.warning("Transient error getting log size (will retry): %s", e) - raise - except Exception as e: - self.log.warning("Error getting log file size: %s", e) + _, stdout, _ = await self._run_command(conn, cmd) + stdout = stdout.strip() + if stdout and stdout.isdigit(): + return int(stdout) return 0 - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(multiplier=1, min=1, max=10), - retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)), - reraise=True, - ) - async def _read_log_chunk(self, hook: SSHHookAsync) -> tuple[str, int]: + async def _read_log_chunk(self, conn: SSHClientConnection) -> tuple[str, int]: """ Read a chunk of logs from the current offset. - Retries transient network errors up to 3 times with exponential backoff. 
- :return: Tuple of (log_chunk, new_offset) """ - file_size = await self._get_log_size(hook) + file_size = await self._get_log_size(conn) if file_size <= self.log_offset: return "", self.log_offset @@ -195,47 +205,94 @@ async def _read_log_chunk(self, hook: SSHHookAsync) -> tuple[str, int]: else: cmd = build_windows_log_tail_command(self.log_file, self.log_offset, bytes_to_read) - try: - exit_code, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout) + _, stdout, _ = await self._run_command(conn, cmd) - # Advance offset by bytes requested, not decoded string length - new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset + # Advance offset by bytes requested, not decoded string length + new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset + return stdout, new_offset - return stdout, new_offset - except (OSError, TimeoutError, ConnectionError) as e: - self.log.warning("Transient error reading logs (will retry): %s", e) - raise - except Exception as e: - self.log.warning("Error reading log chunk: %s", e) - return "", self.log_offset + def _error_event(self, message: str) -> TriggerEvent: + return TriggerEvent( + { + "job_id": self.job_id, + "job_dir": self.job_dir, + "log_file": self.log_file, + "exit_code_file": self.exit_code_file, + "remote_os": self.remote_os, + "status": "error", + "done": True, + "exit_code": None, + "log_chunk": "", + "log_offset": self.log_offset, + "message": message, + } + ) async def run(self) -> AsyncIterator[TriggerEvent]: """ - Poll the remote job status and yield events with log chunks. + Poll the remote job status and yield a completion event. - This method runs in a loop, checking the job status and reading - logs at each poll interval. It yields a TriggerEvent each time - with the current status and any new log output. + One connection is held for the whole loop. 
On a connection-level failure the + connection is dropped and re-established (with jittered backoff) up to + ``max_reconnect_attempts`` consecutive times; any other error, or exhausting the + reconnect budget, ends the trigger with an error event. """ - hook = self._get_hook() + conn: SSHClientConnection | None = None + # Consecutive failures since the last *fully successful* poll. A successful + # handshake alone does not reset this: a connection that handshakes but whose + # command channel keeps failing (e.g. ChannelOpenError under sshd MaxSessions) + # must still exhaust the budget instead of looping forever. + failures = 0 - while True: - try: - exit_code = await self._check_completion(hook) - log_chunk, new_offset = await self._read_log_chunk(hook) + try: + while True: + if conn is None: + try: + conn = await self._connect() + except _CONNECTION_ERRORS as e: + failures += 1 + if failures > self.max_reconnect_attempts: + raise + delay = self._reconnect_delay(failures) + self.log.warning( + "Failed to connect to remote host (attempt %d/%d), retrying in %.1fs: %s", + failures, + self.max_reconnect_attempts, + delay, + e, + ) + await asyncio.sleep(delay) + continue + + try: + exit_code = await self._check_completion(conn) + log_chunk, new_offset = await self._read_log_chunk(conn) + except _CONNECTION_ERRORS as e: + failures += 1 + self.log.warning( + "Lost SSH connection while polling (attempt %d/%d), reconnecting: %s", + failures, + self.max_reconnect_attempts, + e, + ) + await self._close(conn) + conn = None + if failures > self.max_reconnect_attempts: + raise + await asyncio.sleep(self._reconnect_delay(failures)) + continue - base_event = { - "job_id": self.job_id, - "job_dir": self.job_dir, - "log_file": self.log_file, - "exit_code_file": self.exit_code_file, - "remote_os": self.remote_os, - } + # A full poll cycle succeeded on this connection; clear the failure budget. 
+ failures = 0 if exit_code is not None: yield TriggerEvent( { - **base_event, + "job_id": self.job_id, + "job_dir": self.job_dir, + "log_file": self.log_file, + "exit_code_file": self.exit_code_file, + "remote_os": self.remote_os, "status": "success" if exit_code == 0 else "failed", "done": True, "exit_code": exit_code, @@ -251,21 +308,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.log.info("%s", log_chunk.rstrip()) await asyncio.sleep(self.poll_interval) - except Exception as e: - self.log.exception("Error in SSH remote job trigger") - yield TriggerEvent( - { - "job_id": self.job_id, - "job_dir": self.job_dir, - "log_file": self.log_file, - "exit_code_file": self.exit_code_file, - "remote_os": self.remote_os, - "status": "error", - "done": True, - "exit_code": None, - "log_chunk": "", - "log_offset": self.log_offset, - "message": f"Trigger error: {e}", - } - ) - return + except Exception as e: + self.log.exception("Error in SSH remote job trigger") + yield self._error_event(f"Trigger error: {e}") + return + finally: + if conn is not None: + await self._close(conn) diff --git a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py index 289cdfa3e48dc..e277c1dc3f654 100644 --- a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py +++ b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py @@ -594,6 +594,26 @@ def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_true(self, s assert ssh_client.return_value.connect.called is True assert ssh_client.return_value.set_missing_host_key_policy.called is True + @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient") + def test_conn_retry_attempts_defaults_to_three(self, ssh_client): + hook = SSHHook(ssh_conn_id="ssh_default") + assert hook.conn_retry_attempts == 3 + + @mock.patch("time.sleep") + @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient") + def test_conn_retry_attempts_retries_until_limit(self, ssh_client, _mock_sleep): + """get_conn 
retries the configured number of times before re-raising.""" + ssh_client.return_value.connect.side_effect = paramiko.ssh_exception.SSHException( + "Error reading SSH protocol banner" + ) + hook = SSHHook(ssh_conn_id="ssh_default", conn_retry_attempts=4) + assert hook.conn_retry_attempts == 4 + + with pytest.raises(paramiko.ssh_exception.SSHException): + hook.get_conn() + + assert ssh_client.return_value.connect.call_count == 4 + @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient") def test_ssh_connection_with_host_key_where_allow_host_key_change_is_true(self, ssh_client): hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE) diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py index 871f398139384..63ed479120789 100644 --- a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py +++ b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py @@ -123,6 +123,56 @@ def test_execute_defers_to_trigger(self): assert isinstance(exc_info.value.trigger, SSHRemoteJobTrigger) assert exc_info.value.method_name == "execute_complete" + def test_execute_forwards_trigger_tuning_params(self): + """command_timeout and max_reconnect_attempts must reach the deferred trigger.""" + self.mock_hook.exec_ssh_client_command.return_value = (0, b"job", b"") + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + remote_os="posix", + command_timeout=12.5, + max_reconnect_attempts=9, + ) + mock_ti = mock.MagicMock() + mock_ti.dag_id, mock_ti.task_id, mock_ti.run_id, mock_ti.try_number = "d", "t", "r", 1 + + with pytest.raises(TaskDeferred) as exc_info: + op.execute({"ti": mock_ti}) + + trigger = exc_info.value.trigger + assert trigger.command_timeout == 12.5 + assert trigger.max_reconnect_attempts == 9 + + def test_execute_complete_re_defer_forwards_tuning_params(self): + """The re-defer path 
must also forward the trigger-tuning params.""" + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + command_timeout=7.0, + max_reconnect_attempts=4, + ) + event = { + "done": False, + "status": "running", + "job_id": "j", + "job_dir": "/tmp/airflow-ssh-jobs/j", + "log_file": "/tmp/airflow-ssh-jobs/j/stdout.log", + "exit_code_file": "/tmp/airflow-ssh-jobs/j/exit_code", + "remote_os": "posix", + "log_chunk": "", + "log_offset": 0, + "exit_code": None, + } + with pytest.raises(TaskDeferred) as exc_info: + op.execute_complete({}, event) + + trigger = exc_info.value.trigger + assert trigger.command_timeout == 7.0 + assert trigger.max_reconnect_attempts == 4 + def test_execute_raises_if_no_command(self): """Test that execute raises if command is not specified.""" op = SSHRemoteJobOperator( @@ -329,3 +379,68 @@ def test_on_kill_no_active_job(self): # Should not raise even without active job op.on_kill() + + def test_execute_uses_single_connection_for_detect_and_submit(self): + """OS auto-detection and submission must share one SSH connection (one handshake).""" + self.mock_hook.exec_ssh_client_command.side_effect = [ + (0, b"Linux", b""), # OS detection + (0, b"af_test_dag_test_task_run1_try1_abc123", b""), # submission + ] + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + remote_os="auto", + ) + + mock_ti = mock.MagicMock() + mock_ti.dag_id = "test_dag" + mock_ti.task_id = "test_task" + mock_ti.run_id = "run1" + mock_ti.try_number = 1 + + with pytest.raises(TaskDeferred): + op.execute({"ti": mock_ti}) + + # One handshake for the whole execute(): detection + submit reuse it. 
+ self.mock_hook.get_conn.assert_called_once() + assert self.mock_hook.exec_ssh_client_command.call_count == 2 + assert op._detected_os == "posix" + + def test_cleanup_retries_then_succeeds(self): + """Cleanup retries on a transient SSH failure and stops once it succeeds.""" + self.mock_hook.exec_ssh_client_command.side_effect = [ + Exception("Error reading SSH protocol banner"), + (0, b"", b""), + ] + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + cleanup_retries=3, + ) + + with mock.patch("airflow.providers.ssh.operators.ssh_remote_job.time.sleep") as mock_sleep: + op._cleanup_remote_job("/tmp/airflow-ssh-jobs/test_job_123", "posix") + + assert self.mock_hook.exec_ssh_client_command.call_count == 2 + mock_sleep.assert_called_once() + + def test_cleanup_gives_up_after_retries_without_raising(self): + """When every cleanup attempt fails the task is not failed; the dir is left in place.""" + self.mock_hook.exec_ssh_client_command.side_effect = Exception("connection refused") + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + cleanup_retries=3, + ) + + with mock.patch("airflow.providers.ssh.operators.ssh_remote_job.time.sleep"): + # Must not raise even though all attempts fail. 
+ op._cleanup_remote_job("/tmp/airflow-ssh-jobs/test_job_123", "posix") + + assert self.mock_hook.exec_ssh_client_command.call_count == 3 diff --git a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py index 67d9672a8cbeb..8e1d1e5953a6f 100644 --- a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py +++ b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py @@ -24,6 +24,20 @@ from airflow.providers.ssh.triggers.ssh_remote_job import SSHRemoteJobTrigger +def _make_trigger(**overrides): + kwargs = dict( + ssh_conn_id="test_conn", + remote_host=None, + job_id="test_job", + job_dir="/tmp/job", + log_file="/tmp/job/stdout.log", + exit_code_file="/tmp/job/exit_code", + remote_os="posix", + ) + kwargs.update(overrides) + return SSHRemoteJobTrigger(**kwargs) + + class TestSSHRemoteJobTrigger: def test_serialization(self): """Test that the trigger can be serialized correctly.""" @@ -38,6 +52,7 @@ def test_serialization(self): poll_interval=10, log_chunk_size=32768, log_offset=1000, + max_reconnect_attempts=7, ) classpath, kwargs = trigger.serialize() @@ -53,145 +68,201 @@ def test_serialization(self): assert kwargs["poll_interval"] == 10 assert kwargs["log_chunk_size"] == 32768 assert kwargs["log_offset"] == 1000 + assert kwargs["max_reconnect_attempts"] == 7 + + def test_serialization_round_trips(self): + """The serialized kwargs must be sufficient to rebuild the trigger.""" + trigger = _make_trigger(poll_interval=3, max_reconnect_attempts=2) + _, kwargs = trigger.serialize() + rebuilt = SSHRemoteJobTrigger(**kwargs) + assert rebuilt.serialize() == trigger.serialize() def test_default_values(self): """Test default parameter values.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + trigger = _make_trigger() 
assert trigger.poll_interval == 5 assert trigger.log_chunk_size == 65536 assert trigger.log_offset == 0 + assert trigger.max_reconnect_attempts == 5 @pytest.mark.asyncio async def test_run_job_completed_success(self): """Test trigger when job completes successfully.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + trigger = _make_trigger() - with mock.patch.object(trigger, "_check_completion", return_value=0): - with mock.patch.object(trigger, "_read_log_chunk", return_value=("Final output\n", 100)): - events = [] - async for event in trigger.run(): - events.append(event) + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()), + mock.patch.object(trigger, "_close", return_value=None), + mock.patch.object(trigger, "_check_completion", return_value=0), + mock.patch.object(trigger, "_read_log_chunk", return_value=("Final output\n", 100)), + ): + events = [event async for event in trigger.run()] - assert len(events) == 1 - assert events[0].payload["status"] == "success" - assert events[0].payload["done"] is True - assert events[0].payload["exit_code"] == 0 - assert events[0].payload["log_chunk"] == "Final output\n" + assert len(events) == 1 + assert events[0].payload["status"] == "success" + assert events[0].payload["done"] is True + assert events[0].payload["exit_code"] == 0 + assert events[0].payload["log_chunk"] == "Final output\n" @pytest.mark.asyncio async def test_run_job_completed_failure(self): """Test trigger when job completes with failure.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + trigger = _make_trigger() - with mock.patch.object(trigger, "_check_completion", return_value=1): - 
with mock.patch.object(trigger, "_read_log_chunk", return_value=("Error output\n", 50)): - events = [] - async for event in trigger.run(): - events.append(event) + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()), + mock.patch.object(trigger, "_close", return_value=None), + mock.patch.object(trigger, "_check_completion", return_value=1), + mock.patch.object(trigger, "_read_log_chunk", return_value=("Error output\n", 50)), + ): + events = [event async for event in trigger.run()] - assert len(events) == 1 - assert events[0].payload["status"] == "failed" - assert events[0].payload["done"] is True - assert events[0].payload["exit_code"] == 1 + assert len(events) == 1 + assert events[0].payload["status"] == "failed" + assert events[0].payload["done"] is True + assert events[0].payload["exit_code"] == 1 @pytest.mark.asyncio - async def test_run_job_polls_until_completion(self): - """Test trigger polls without yielding until job completes.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - poll_interval=0.01, - ) + async def test_run_reuses_single_connection_across_polls(self): + """The connection is opened once and reused for every poll, not per command.""" + trigger = _make_trigger(poll_interval=0.01) poll_count = 0 async def mock_check_completion(_): nonlocal poll_count poll_count += 1 - # Return None (still running) for first 2 polls, then exit code 0 - if poll_count < 3: - return None + return None if poll_count < 3 else 0 + + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect, + mock.patch.object(trigger, "_close", return_value=None) as mock_close, + mock.patch.object(trigger, "_check_completion", side_effect=mock_check_completion), + mock.patch.object(trigger, "_read_log_chunk", return_value=("output\n", 50)), + ): + events 
= [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "success" + assert poll_count == 3 + # One connect for the whole loop, closed once at teardown. + assert mock_connect.call_count == 1 + assert mock_close.call_count == 1 + + @pytest.mark.asyncio + async def test_run_reconnects_on_connection_drop(self): + """A connection-level error mid-poll drops the connection and reconnects.""" + trigger = _make_trigger(poll_interval=0.01, max_reconnect_attempts=3) + + calls = {"check": 0} + + async def flaky_check(_): + calls["check"] += 1 + if calls["check"] == 1: + raise OSError("Error reading SSH protocol banner") return 0 - with mock.patch.object(trigger, "_check_completion", side_effect=mock_check_completion): - with mock.patch.object(trigger, "_read_log_chunk", return_value=("output\n", 50)): - events = [] - async for event in trigger.run(): - events.append(event) + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect, + mock.patch.object(trigger, "_close", return_value=None) as mock_close, + mock.patch.object(trigger, "_check_completion", side_effect=flaky_check), + mock.patch.object(trigger, "_read_log_chunk", return_value=("out\n", 10)), + mock.patch("asyncio.sleep", new=mock.AsyncMock()), + ): + events = [event async for event in trigger.run()] - # Only one event should be yielded (the completion event) - assert len(events) == 1 - assert events[0].payload["status"] == "success" - assert events[0].payload["done"] is True - assert events[0].payload["exit_code"] == 0 - # Should have polled 3 times - assert poll_count == 3 + assert len(events) == 1 + assert events[0].payload["status"] == "success" + # Initial connect + one reconnect after the dropped poll. + assert mock_connect.call_count == 2 + # Dropped connection closed during reconnect, plus final teardown close. 
+ assert mock_close.call_count == 2 @pytest.mark.asyncio - async def test_run_handles_exception(self): - """Test trigger handles exceptions gracefully.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + async def test_run_gives_up_after_max_reconnects(self): + """When connections keep failing, the trigger emits a single error event.""" + trigger = _make_trigger(max_reconnect_attempts=2) + + with ( + mock.patch.object(trigger, "_connect", side_effect=OSError("connection refused")), + mock.patch.object(trigger, "_close", return_value=None), + mock.patch("asyncio.sleep", new=mock.AsyncMock()), + ): + events = [event async for event in trigger.run()] - with mock.patch.object(trigger, "_check_completion", side_effect=Exception("Connection failed")): - events = [] - async for event in trigger.run(): - events.append(event) + assert len(events) == 1 + assert events[0].payload["status"] == "error" + assert events[0].payload["done"] is True + assert events[0].payload["exit_code"] is None + assert "connection refused" in events[0].payload["message"] - assert len(events) == 1 - assert events[0].payload["status"] == "error" - assert events[0].payload["done"] is True - assert "Connection failed" in events[0].payload["message"] + @pytest.mark.asyncio + async def test_run_gives_up_when_polls_keep_failing_despite_reconnects(self): + """A connection that handshakes but whose polls keep failing must still hit the cap. + + Regression: the reconnect budget must not reset on a bare successful handshake, or a + connection that reconnects fine but never completes a poll (e.g. ChannelOpenError under + sshd MaxSessions) would loop forever and the task would defer indefinitely. 
+ """ + trigger = _make_trigger(max_reconnect_attempts=2) + + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect, + mock.patch.object(trigger, "_close", return_value=None), + mock.patch.object( + trigger, "_check_completion", side_effect=OSError("channel open failed") + ) as mock_check, + mock.patch("asyncio.sleep", new=mock.AsyncMock()), + ): + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "error" + assert "channel open failed" in events[0].payload["message"] + # Budget = 2 -> third consecutive failure ends it (handshake succeeds each round + # but never resets the counter because no poll ever completes). + assert mock_check.call_count == 3 + assert mock_connect.call_count == 3 + + @pytest.mark.asyncio + async def test_run_handles_unexpected_exception(self): + """A non-connection error surfaces immediately as an error event.""" + trigger = _make_trigger() + + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()), + mock.patch.object(trigger, "_close", return_value=None), + mock.patch.object(trigger, "_check_completion", side_effect=ValueError("boom")), + ): + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "error" + assert events[0].payload["done"] is True + assert "boom" in events[0].payload["message"] def test_get_hook(self): """Test hook creation.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host="custom.host.com", - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + trigger = _make_trigger(remote_host="custom.host.com") hook = trigger._get_hook() assert hook.ssh_conn_id == "test_conn" assert hook.host == "custom.host.com" + + @pytest.mark.asyncio + async def test_run_command_uses_existing_connection(self): + """_run_command runs on the 
passed connection without opening a new one.""" + trigger = _make_trigger(command_timeout=12.0) + + result = mock.MagicMock() + result.exit_status = 0 + result.stdout = "42" + result.stderr = "" + conn = mock.MagicMock() + conn.run = mock.AsyncMock(return_value=result) + + exit_code, stdout, stderr = await trigger._run_command(conn, "echo 42") + + conn.run.assert_awaited_once_with("echo 42", timeout=12.0, check=False) + assert (exit_code, stdout, stderr) == (0, "42", "") From a0892feb005e4aedd1127214c6b8035be907f54a Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 6 Jun 2026 03:04:27 +0100 Subject: [PATCH 2/3] Fix mypy return type and docs spellcheck in SSH remote job trigger - _run_command decodes bytes stdout/stderr so the return matches tuple[int, str, str] (asyncssh types them as bytes | str). - Drop 'jittered'/'desynchronise' from docstrings (Sphinx spellcheck). --- .../providers/ssh/triggers/ssh_remote_job.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py index cdb441176e389..2d390a3f2e64f 100644 --- a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py +++ b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py @@ -148,14 +148,22 @@ async def _close(conn: SSHClientConnection) -> None: pass def _reconnect_delay(self, attempt: int) -> float: - """Exponential backoff with full jitter to desynchronise reconnecting triggers.""" + """Exponential backoff with randomness so reconnecting triggers do not retry in lockstep.""" base = min(2 ** (attempt - 1), 30) return base + random.uniform(0, base) async def _run_command(self, conn: SSHClientConnection, command: str) -> tuple[int, str, str]: """Run a command on an existing connection, mirroring ``SSHHookAsync.run_command``.""" result = await conn.run(command, timeout=self.command_timeout, check=False) - return 
result.exit_status or 0, result.stdout or "", result.stderr or "" + stdout = result.stdout or "" + stderr = result.stderr or "" + # asyncssh types stdout/stderr as bytes | str; with the default text encoding they are + # str, but decode defensively so the helper holds if a binary connection is ever used. + if isinstance(stdout, bytes): + stdout = stdout.decode("utf-8", errors="replace") + if isinstance(stderr, bytes): + stderr = stderr.decode("utf-8", errors="replace") + return result.exit_status or 0, stdout, stderr async def _check_completion(self, conn: SSHClientConnection) -> int | None: """ @@ -233,7 +241,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: Poll the remote job status and yield a completion event. One connection is held for the whole loop. On a connection-level failure the - connection is dropped and re-established (with jittered backoff) up to + connection is dropped and re-established (with exponential backoff) up to ``max_reconnect_attempts`` consecutive times; any other error, or exhausting the reconnect budget, ends the trigger with an error event. """ From b692c002fa82641779a1e9039552d946deb89a81 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Mon, 8 Jun 2026 23:43:34 +0100 Subject: [PATCH 3/3] Address review: log hint on submit-connection failure + mapped-task docs - execute() now logs an actionable hint when the submit connection fails (sshd MaxStartups under concurrency, conn_retry_attempts, pools/max_active_tis_per_dag), then re-raises the original error. Scoped to the connect step, no error matching. - High Fan-out docs: link the dynamic task mapping limits page and note the storm is not specific to mapped tasks (parallel runs/high concurrency too). 
--- providers/ssh/docs/operators/ssh_remote_job.rst | 8 +++++--- .../providers/ssh/operators/ssh_remote_job.py | 17 ++++++++++++++++- .../unit/ssh/operators/test_ssh_remote_job.py | 16 ++++++++++++++++ 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/providers/ssh/docs/operators/ssh_remote_job.rst b/providers/ssh/docs/operators/ssh_remote_job.rst index 84757a9f11d0c..7a8783b9861cd 100644 --- a/providers/ssh/docs/operators/ssh_remote_job.rst +++ b/providers/ssh/docs/operators/ssh_remote_job.rst @@ -249,7 +249,8 @@ reaper (for example ``systemd-tmpfiles`` or a cron job) for the base directory. High Fan-out (Many Concurrent Tasks) ------------------------------------- -A large ``.expand()`` fan-out points many tasks at the same SSH server at once. Each remote +Many tasks targeting the same SSH server at once (a large ``.expand()`` fan-out, parallel DAG +runs, or just high concurrency) can overwhelm it. Each remote command opens a new SSH connection, and the remote ``sshd`` throttles concurrent *unauthenticated* connections via ``MaxStartups`` (default ``10:30:100``: start randomly dropping at 10 concurrent, reaching 100% at 100). A dropped connection surfaces on the client @@ -266,8 +267,9 @@ poll loop instead of reconnecting on every status check. To push a high fan-out * Raise ``MaxStartups`` (and ``MaxSessions``) on the remote ``sshd`` -- this is the direct fix. * Increase ``conn_retry_attempts`` so transient refusals during the initial burst are retried. -* Spread submissions out (for example with a pool or ``max_active_tis_per_dag``) rather than - releasing the entire fan-out simultaneously. +* Cap how many mapped tasks run at once with ``max_active_tis_per_dag`` (or a pool) instead of + releasing the entire fan-out simultaneously. See the "Placing Limits on Mapped Tasks" section of + :doc:`apache-airflow:authoring-and-scheduling/dynamic-task-mapping` for the available limits. 
Comparison with SSHOperator ---------------------------- diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py index cfe18a1e5c59b..dc3414eca686d 100644 --- a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py +++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py @@ -284,7 +284,22 @@ def execute(self, context: Context) -> None: # operator opens one SSH handshake per task instead of two. Under a large fan-out # this halves the connection burst that triggers sshd MaxStartups throttling. self.log.info("Connecting to %s", self.ssh_hook.remote_host) - with self.ssh_hook.get_conn() as ssh_client: + try: + ssh_conn = self.ssh_hook.get_conn() + except Exception: + self.log.error( + "Failed to connect to %s to submit the remote job. When many SSH connections reach " + "the same host at once, the server can start refusing new ones before the handshake " + "(for example sshd MaxStartups). This is not limited to mapped tasks: parallel DAG " + "runs or high concurrency can cause it too. Try raising MaxStartups/MaxSessions on " + "the server, increasing conn_retry_attempts (currently %d), or reducing concurrency " + "with a pool (or max_active_tis_per_dag for mapped tasks). 
See the " + "SSHRemoteJobOperator 'High Fan-out' docs.", + self.ssh_hook.remote_host, + self.conn_retry_attempts, + ) + raise + with ssh_conn as ssh_client: self._detected_os = self._detect_remote_os(ssh_client) self.log.info("Remote OS: %s", self._detected_os) diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py index 63ed479120789..6ecf80baef28d 100644 --- a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py +++ b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py @@ -173,6 +173,22 @@ def test_execute_complete_re_defer_forwards_tuning_params(self): assert trigger.command_timeout == 7.0 assert trigger.max_reconnect_attempts == 4 + def test_execute_connect_failure_is_reraised(self): + """A connection failure during submit is re-raised unchanged (advisory is log-only).""" + self.mock_hook.get_conn.side_effect = OSError("Error reading SSH protocol banner") + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + remote_os="posix", + ) + mock_ti = mock.MagicMock() + mock_ti.dag_id, mock_ti.task_id, mock_ti.run_id, mock_ti.try_number = "d", "t", "r", 1 + + with pytest.raises(OSError, match="Error reading SSH protocol banner"): + op.execute({"ti": mock_ti}) + def test_execute_raises_if_no_command(self): """Test that execute raises if command is not specified.""" op = SSHRemoteJobOperator(