diff --git a/providers/ssh/docs/operators/ssh_remote_job.rst b/providers/ssh/docs/operators/ssh_remote_job.rst index c78a07927405f..7a8783b9861cd 100644 --- a/providers/ssh/docs/operators/ssh_remote_job.rst +++ b/providers/ssh/docs/operators/ssh_remote_job.rst @@ -164,6 +164,13 @@ Parameters * ``remote_os`` (str, optional): Remote OS type (``"auto"``, ``"posix"``, ``"windows"``). Default: ``"auto"`` * ``skip_on_exit_code`` (int or list, optional): Exit code(s) that should cause task to skip instead of fail +* ``conn_timeout`` (int, optional): SSH connection timeout in seconds +* ``banner_timeout`` (float, optional): Seconds to wait for the SSH banner. Default: 30.0 +* ``conn_retry_attempts`` (int, optional): How many times to attempt the initial SSH connection for + submission and cleanup before failing. Default: 5. Raise this for large fan-outs where the remote + ``sshd`` transiently refuses connections (see :ref:`High Fan-out <howto/operator:SSHRemoteJobOperator:fanout>`) +* ``cleanup_retries`` (int, optional): How many times to retry remote directory cleanup before giving up + and leaving the directory in place. Default: 3 Remote OS Detection ------------------- @@ -213,7 +220,9 @@ Limitations and Considerations ------------------------------- **Network Interruptions**: While the operator is resilient to disconnections during monitoring, -the initial job submission must succeed. If submission fails, the task will fail immediately. +the initial job submission must succeed. The connection used for submission is retried +(``conn_retry_attempts``); if every attempt fails, the task fails immediately. The trigger also +reconnects automatically if the monitoring connection drops mid-job. **Remote Process Management**: Jobs are detached using ``nohup`` (POSIX) or ``Start-Process`` (Windows). If the remote host reboots during job execution, the job will be lost. @@ -231,7 +240,36 @@ tasks can run on the same remote host without conflicts. 
**Cleanup**: Use ``cleanup="on_success"`` or ``cleanup="always"`` to avoid accumulating job directories on the remote host. For debugging, use ``cleanup="never"`` and manually -inspect the job directory. +inspect the job directory. Cleanup runs only when the job reaches completion, so tasks that +are killed or time out can still leave a directory behind; for those, add a server-side TTL +reaper (for example ``systemd-tmpfiles`` or a cron job) for the base directory. + +.. _howto/operator:SSHRemoteJobOperator:fanout: + +High Fan-out (Many Concurrent Tasks) +------------------------------------- + +Many tasks targeting the same SSH server at once (a large ``.expand()`` fan-out, parallel DAG +runs, or just high concurrency) can overwhelm it. Each remote +command opens a new SSH connection, and the remote ``sshd`` throttles concurrent +*unauthenticated* connections via ``MaxStartups`` (default ``10:30:100``: start randomly +dropping at 10 concurrent, reaching 100% at 100). A dropped connection surfaces on the client +as:: + + paramiko ... Error reading SSH protocol banner + +This is the server closing the socket before the handshake, not a slow banner, so raising +``banner_timeout`` does not help. + +The operator and trigger keep the connection rate low: submission reuses a single connection +for OS detection and the submit itself, and the trigger holds **one** connection for the whole +poll loop instead of reconnecting on every status check. To push a high fan-out further: + +* Raise ``MaxStartups`` (and ``MaxSessions``) on the remote ``sshd`` -- this is the direct fix. +* Increase ``conn_retry_attempts`` so transient refusals during the initial burst are retried. +* Cap how many mapped tasks run at once with ``max_active_tis_per_dag`` (or a pool) instead of + releasing the entire fan-out simultaneously. See the "Placing Limits on Mapped Tasks" section of + :doc:`apache-airflow:authoring-and-scheduling/dynamic-task-mapping` for the available limits. 
Comparison with SSHOperator ---------------------------- diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py index d029a50772b35..54b7edf6008c7 100644 --- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py +++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py @@ -82,6 +82,9 @@ class SSHHook(BaseHook): lifetime of the transport :param ciphers: list of ciphers to use in order of preference :param auth_timeout: timeout (in seconds) for the attempt to authenticate with the remote_host + :param conn_retry_attempts: number of times to attempt the initial SSH connection before + giving up (default 3). Raising this helps when many tasks target the same SSH server at + once and some connections are transiently refused (e.g. ``sshd`` ``MaxStartups`` throttling). """ # List of classes to try loading private keys as, ordered (roughly) by most common to least common @@ -130,9 +133,11 @@ def __init__( ciphers: list[str] | None = None, auth_timeout: int | None = None, host_proxy_cmd: str | None = None, + conn_retry_attempts: int = 3, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id + self.conn_retry_attempts = max(1, conn_retry_attempts) self.remote_host = remote_host self.username = username self.password = password @@ -344,7 +349,7 @@ def log_before_sleep(retry_state): for attempt in Retrying( reraise=True, wait=wait_fixed(3) + wait_random(0, 2), - stop=stop_after_attempt(3), + stop=stop_after_attempt(self.conn_retry_attempts), before_sleep=log_before_sleep, ): with attempt: @@ -553,6 +558,7 @@ def __init__( key_file: str = "", passphrase: str = "", private_key: str = "", + keepalive_interval: int = 30, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id @@ -564,6 +570,7 @@ def __init__( self.key_file = key_file self.passphrase = passphrase self.private_key = private_key + self.keepalive_interval = keepalive_interval def _parse_extras(self, conn: Any) -> None: """Parse extra 
fields from the connection into instance fields.""" @@ -631,10 +638,26 @@ def _get_value(self_val, conn_val, default=None): conn_config["client_keys"] = [_private_key] if self.passphrase: conn_config["passphrase"] = self.passphrase + if self.keepalive_interval: + # The trigger holds one connection for the whole job; a keepalive stops idle + # NAT/firewall timeouts from silently dropping it between long poll intervals. + conn_config["keepalive_interval"] = self.keepalive_interval ssh_client_conn = await asyncssh.connect(**conn_config) return ssh_client_conn + async def get_conn(self): + """ + Open an asyncssh connection that can be reused for multiple commands. + + Unlike :meth:`run_command`, the returned connection is **not** closed + automatically; the caller owns its lifecycle (e.g. + ``async with await hook.get_conn() as conn: ...`` or an explicit + ``conn.close()``). Reusing one connection avoids a new TCP/SSH handshake + per command, which matters when many tasks poll the same SSH server. + """ + return await self._get_conn() + async def run_command(self, command: str, timeout: float | None = None) -> tuple[int, str, str]: """ Execute a command on the remote host asynchronously. 
diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py index 783edb39b8fd7..dc3414eca686d 100644 --- a/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py +++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh_remote_job.py @@ -19,6 +19,7 @@ from __future__ import annotations +import time import warnings from collections.abc import Container, Sequence from datetime import timedelta @@ -74,6 +75,25 @@ class SSHRemoteJobOperator(BaseOperator): :param skip_on_exit_code: Exit codes that should skip the task instead of failing :param conn_timeout: SSH connection timeout in seconds :param banner_timeout: Timeout waiting for SSH banner in seconds + :param conn_retry_attempts: How many times to attempt the initial SSH connection for + submission/cleanup before failing (default 5). Helps when many mapped tasks hit the + same host at once and ``sshd`` transiently refuses connections (``MaxStartups``). + :param cleanup_retries: How many times to attempt remote directory cleanup before + giving up and leaving the directory in place (default 3). Prevents a transient SSH + failure during cleanup from orphaning the job directory on the remote host. + :param command_timeout: Per-command timeout in seconds for the trigger's status/log polls + (default 30.0). + :param max_reconnect_attempts: Consecutive connection failures the trigger tolerates (with + backoff) before failing the task while monitoring the remote job (default 5). + + .. note:: + A large ``expand()`` fan-out opens many SSH connections against one host. The remote + ``sshd`` throttles concurrent unauthenticated connections via ``MaxStartups`` (default + ``10:30:100``); when exceeded it drops connections, surfacing as + ``paramiko ... Error reading SSH protocol banner``. For high fan-out, raise ``MaxStartups`` + on the server. 
The directory ``/tmp/airflow-ssh-jobs`` (POSIX) is only cleaned when + ``cleanup`` is set and the job reaches completion, so also consider a server-side TTL + reaper (for example ``systemd-tmpfiles``) for jobs that are killed or time out. """ template_fields: Sequence[str] = ("command", "environment", "remote_host", "remote_base_dir") @@ -104,6 +124,10 @@ def __init__( skip_on_exit_code: int | Container[int] | None = None, conn_timeout: int | None = None, banner_timeout: float = 30.0, + conn_retry_attempts: int = 5, + cleanup_retries: int = 3, + command_timeout: float = 30.0, + max_reconnect_attempts: int = 5, **kwargs, ) -> None: super().__init__(**kwargs) @@ -123,6 +147,10 @@ def __init__( self.remote_os = remote_os self.conn_timeout = conn_timeout self.banner_timeout = banner_timeout + self.conn_retry_attempts = conn_retry_attempts + self.cleanup_retries = max(1, cleanup_retries) + self.command_timeout = command_timeout + self.max_reconnect_attempts = max_reconnect_attempts self.skip_on_exit_code = ( skip_on_exit_code if isinstance(skip_on_exit_code, Container) @@ -170,67 +198,69 @@ def ssh_hook(self) -> SSHHook: remote_host=self.remote_host or "", conn_timeout=self.conn_timeout, banner_timeout=self.banner_timeout, + conn_retry_attempts=self.conn_retry_attempts, ) - def _detect_remote_os(self) -> Literal["posix", "windows"]: + def _detect_remote_os(self, ssh_client) -> Literal["posix", "windows"]: """ - Detect the remote operating system. + Detect the remote operating system on an already-open SSH connection. Uses a two-stage detection: 1. Try POSIX detection via `uname` (works on Linux, macOS, BSD, Solaris, AIX, etc.) 2. Try Windows detection via PowerShell 3. Raise error if both fail + + :param ssh_client: An open paramiko SSH client to reuse (avoids a second handshake). 
""" if self.remote_os != "auto": return self.remote_os self.log.info("Auto-detecting remote operating system...") - with self.ssh_hook.get_conn() as ssh_client: - try: - exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command( - ssh_client, - build_posix_os_detection_command(), - get_pty=False, - environment=None, - timeout=10, - ) - if exit_status == 0 and stdout: - output = stdout.decode("utf-8", errors="replace").strip().lower() - posix_systems = [ - "linux", - "darwin", - "freebsd", - "openbsd", - "netbsd", - "sunos", - "aix", - "hp-ux", - ] - if any(system in output for system in posix_systems): - self.log.info("Detected POSIX system: %s", output) - return "posix" - except Exception as e: - self.log.debug("POSIX detection failed: %s", e) - - try: - exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command( - ssh_client, - build_windows_os_detection_command(), - get_pty=False, - environment=None, - timeout=10, - ) - if exit_status == 0 and stdout: - output = stdout.decode("utf-8", errors="replace").strip() - if "WINDOWS" in output.upper(): - self.log.info("Detected Windows system") - return "windows" - except Exception as e: - self.log.debug("Windows detection failed: %s", e) + try: + exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command( + ssh_client, + build_posix_os_detection_command(), + get_pty=False, + environment=None, + timeout=10, + ) + if exit_status == 0 and stdout: + output = stdout.decode("utf-8", errors="replace").strip().lower() + posix_systems = [ + "linux", + "darwin", + "freebsd", + "openbsd", + "netbsd", + "sunos", + "aix", + "hp-ux", + ] + if any(system in output for system in posix_systems): + self.log.info("Detected POSIX system: %s", output) + return "posix" + except Exception as e: + self.log.debug("POSIX detection failed: %s", e) - raise AirflowException( - "Could not auto-detect remote OS. 
Please explicitly set remote_os='posix' or 'windows'" + try: + exit_status, stdout, _ = self.ssh_hook.exec_ssh_client_command( + ssh_client, + build_windows_os_detection_command(), + get_pty=False, + environment=None, + timeout=10, ) + if exit_status == 0 and stdout: + output = stdout.decode("utf-8", errors="replace").strip() + if "WINDOWS" in output.upper(): + self.log.info("Detected Windows system") + return "windows" + except Exception as e: + self.log.debug("Windows detection failed: %s", e) + + raise AirflowException( + "Could not auto-detect remote OS. Please explicitly set remote_os='posix' or 'windows'" + ) def execute(self, context: Context) -> None: """ @@ -241,9 +271,6 @@ def execute(self, context: Context) -> None: if not self.command: raise AirflowException("SSH operator error: command not specified.") - self._detected_os = self._detect_remote_os() - self.log.info("Remote OS: %s", self._detected_os) - ti = context["ti"] self._job_id = generate_job_id( dag_id=ti.dag_id, @@ -253,27 +280,49 @@ def execute(self, context: Context) -> None: ) self.log.info("Generated job ID: %s", self._job_id) - self._paths = RemoteJobPaths( - job_id=self._job_id, - remote_os=self._detected_os, - base_dir=self.remote_base_dir, - ) - - if self._detected_os == "posix": - wrapper_cmd = build_posix_wrapper_command( - command=self.command, - paths=self._paths, - environment=self.environment, + # Reuse a single connection for OS detection (when 'auto') and submission so the + # operator opens one SSH handshake per task instead of two. Under a large fan-out + # this halves the connection burst that triggers sshd MaxStartups throttling. + self.log.info("Connecting to %s", self.ssh_hook.remote_host) + try: + ssh_conn = self.ssh_hook.get_conn() + except Exception: + self.log.error( + "Failed to connect to %s to submit the remote job. 
When many SSH connections reach " + "the same host at once, the server can start refusing new ones before the handshake " + "(for example sshd MaxStartups). This is not limited to mapped tasks: parallel DAG " + "runs or high concurrency can cause it too. Try raising MaxStartups/MaxSessions on " + "the server, increasing conn_retry_attempts (currently %d), or reducing concurrency " + "with a pool (or max_active_tis_per_dag for mapped tasks). See the " + "SSHRemoteJobOperator 'High Fan-out' docs.", + self.ssh_hook.remote_host, + self.conn_retry_attempts, ) - else: - wrapper_cmd = build_windows_wrapper_command( - command=self.command, - paths=self._paths, - environment=self.environment, + raise + with ssh_conn as ssh_client: + self._detected_os = self._detect_remote_os(ssh_client) + self.log.info("Remote OS: %s", self._detected_os) + + self._paths = RemoteJobPaths( + job_id=self._job_id, + remote_os=self._detected_os, + base_dir=self.remote_base_dir, ) - self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host) - with self.ssh_hook.get_conn() as ssh_client: + if self._detected_os == "posix": + wrapper_cmd = build_posix_wrapper_command( + command=self.command, + paths=self._paths, + environment=self.environment, + ) + else: + wrapper_cmd = build_windows_wrapper_command( + command=self.command, + paths=self._paths, + environment=self.environment, + ) + + self.log.info("Submitting remote job to %s", self.ssh_hook.remote_host) exit_status, stdout, stderr = self.ssh_hook.exec_ssh_client_command( ssh_client, wrapper_cmd, @@ -320,6 +369,8 @@ def execute(self, context: Context) -> None: poll_interval=self.poll_interval, log_chunk_size=self.log_chunk_size, log_offset=0, + command_timeout=self.command_timeout, + max_reconnect_attempts=self.max_reconnect_attempts, ), method_name="execute_complete", timeout=timedelta(seconds=self.timeout) if self.timeout else None, @@ -361,6 +412,8 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: 
poll_interval=self.poll_interval, log_chunk_size=self.log_chunk_size, log_offset=event.get("log_offset", 0), + command_timeout=self.command_timeout, + max_reconnect_attempts=self.max_reconnect_attempts, ), method_name="execute_complete", timeout=timedelta(seconds=self.timeout) if self.timeout else None, @@ -389,25 +442,46 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: self.log.info("Remote job completed successfully") def _cleanup_remote_job(self, job_dir: str, remote_os: str) -> None: - """Clean up the remote job directory.""" + """ + Clean up the remote job directory, retrying on transient SSH failures. + + Under a large fan-out the cleanup connection can itself be refused by the + remote ``sshd`` (``MaxStartups``). Retrying a few times keeps a transient drop + from orphaning the job directory; if every attempt fails we log loudly and + leave the directory rather than failing the (already finished) task. + """ self.log.info("Cleaning up remote job directory: %s", job_dir) - try: - if remote_os == "posix": - cleanup_cmd = build_posix_cleanup_command(job_dir) - else: - cleanup_cmd = build_windows_cleanup_command(job_dir) + if remote_os == "posix": + cleanup_cmd = build_posix_cleanup_command(job_dir) + else: + cleanup_cmd = build_windows_cleanup_command(job_dir) - with self.ssh_hook.get_conn() as ssh_client: - self.ssh_hook.exec_ssh_client_command( - ssh_client, - cleanup_cmd, - get_pty=False, - environment=None, - timeout=30, - ) - self.log.info("Remote cleanup completed") - except Exception as e: - self.log.warning("Failed to clean up remote job directory: %s", e) + last_error: Exception | None = None + for attempt in range(1, self.cleanup_retries + 1): + try: + with self.ssh_hook.get_conn() as ssh_client: + self.ssh_hook.exec_ssh_client_command( + ssh_client, + cleanup_cmd, + get_pty=False, + environment=None, + timeout=30, + ) + self.log.info("Remote cleanup completed") + return + except Exception as e: + last_error = e + 
self.log.warning("Cleanup attempt %d/%d failed: %s", attempt, self.cleanup_retries, e) + if attempt < self.cleanup_retries: + time.sleep(min(2**attempt, 10)) + + self.log.warning( + "Failed to clean up remote job directory after %d attempts; leaving orphaned " + "directory %s on the remote host (last error: %s)", + self.cleanup_retries, + job_dir, + last_error, + ) def on_kill(self) -> None: """ diff --git a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py index 0d4072c1ca4c4..2d390a3f2e64f 100644 --- a/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py +++ b/providers/ssh/src/airflow/providers/ssh/triggers/ssh_remote_job.py @@ -20,10 +20,11 @@ from __future__ import annotations import asyncio +import random from collections.abc import AsyncIterator -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal -import tenacity +import asyncssh from airflow.providers.ssh.hooks.ssh import SSHHookAsync from airflow.providers.ssh.utils.remote_job import ( @@ -36,6 +37,16 @@ ) from airflow.triggers.base import BaseTrigger, TriggerEvent +if TYPE_CHECKING: + from asyncssh import SSHClientConnection + +# Errors that mean the connection itself is broken/refused and the poll should +# reconnect instead of failing the job. ``asyncssh.Error`` covers handshake, +# protocol and disconnect failures (e.g. an sshd that drops the connection under +# ``MaxStartups`` load); ``OSError`` covers TCP-level refusals; ``TimeoutError`` +# covers a wedged command or connection. +_CONNECTION_ERRORS = (OSError, asyncssh.Error, TimeoutError) + class SSHRemoteJobTrigger(BaseTrigger): """ @@ -44,6 +55,13 @@ class SSHRemoteJobTrigger(BaseTrigger): This trigger polls the remote host to check job completion status and reads log output incrementally. + A single SSH connection is opened and reused for the whole poll loop instead + of reconnecting for every command. 
Opening a fresh TCP/SSH connection per poll + multiplies the connection rate against the remote ``sshd`` (which throttles + concurrent unauthenticated connections via ``MaxStartups``), so reuse keeps the + load flat when many tasks target the same host. If the connection drops, the + trigger transparently reconnects with backoff up to ``max_reconnect_attempts``. + :param ssh_conn_id: SSH connection ID from Airflow Connections :param remote_host: Optional override for the remote host :param job_id: Unique identifier for the remote job @@ -54,6 +72,9 @@ class SSHRemoteJobTrigger(BaseTrigger): :param poll_interval: Seconds between polling attempts :param log_chunk_size: Maximum bytes to read per poll :param log_offset: Current byte offset in the log file + :param command_timeout: Per-command timeout in seconds + :param max_reconnect_attempts: Consecutive connection failures tolerated before the + trigger gives up and emits an error event """ def __init__( @@ -69,6 +90,7 @@ def __init__( log_chunk_size: int = 65536, log_offset: int = 0, command_timeout: float = 30.0, + max_reconnect_attempts: int = 5, ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id @@ -82,6 +104,7 @@ def __init__( self.log_chunk_size = log_chunk_size self.log_offset = log_offset self.command_timeout = command_timeout + self.max_reconnect_attempts = max_reconnect_attempts def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize the trigger for storage.""" @@ -99,6 +122,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "log_chunk_size": self.log_chunk_size, "log_offset": self.log_offset, "command_timeout": self.command_timeout, + "max_reconnect_attempts": self.max_reconnect_attempts, }, ) @@ -109,18 +133,42 @@ def _get_hook(self) -> SSHHookAsync: host=self.remote_host, ) - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(multiplier=1, min=1, max=10), - retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)), - 
reraise=True, - ) - async def _check_completion(self, hook: SSHHookAsync) -> int | None: + async def _connect(self) -> SSHClientConnection: + """Open a reusable asyncssh connection. Separated out as a seam for testing.""" + return await self._get_hook().get_conn() + + @staticmethod + async def _close(conn: SSHClientConnection) -> None: + """Close a connection, swallowing teardown errors.""" + try: + conn.close() + await conn.wait_closed() + except Exception: + # Teardown is best-effort; a failing close has nothing actionable to recover. + pass + + def _reconnect_delay(self, attempt: int) -> float: + """Exponential backoff with randomness so reconnecting triggers do not retry in lockstep.""" + base = min(2 ** (attempt - 1), 30) + return base + random.uniform(0, base) + + async def _run_command(self, conn: SSHClientConnection, command: str) -> tuple[int, str, str]: + """Run a command on an existing connection, mirroring ``SSHHookAsync.run_command``.""" + result = await conn.run(command, timeout=self.command_timeout, check=False) + stdout = result.stdout or "" + stderr = result.stderr or "" + # asyncssh types stdout/stderr as bytes | str; with the default text encoding they are + # str, but decode defensively so the helper holds if a binary connection is ever used. + if isinstance(stdout, bytes): + stdout = stdout.decode("utf-8", errors="replace") + if isinstance(stderr, bytes): + stderr = stderr.decode("utf-8", errors="replace") + return result.exit_status or 0, stdout, stderr + + async def _check_completion(self, conn: SSHClientConnection) -> int | None: """ Check if the remote job has completed. - Retries transient network errors up to 3 times with exponential backoff. 
- :return: Exit code if completed, None if still running """ if self.remote_os == "posix": @@ -128,62 +176,32 @@ async def _check_completion(self, hook: SSHHookAsync) -> int | None: else: cmd = build_windows_completion_check_command(self.exit_code_file) - try: - _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout) - stdout = stdout.strip() - if stdout and stdout.isdigit(): - return int(stdout) - except (OSError, TimeoutError, ConnectionError) as e: - self.log.warning("Transient error checking completion (will retry): %s", e) - raise - except Exception as e: - self.log.warning("Error checking completion status: %s", e) + _, stdout, _ = await self._run_command(conn, cmd) + stdout = stdout.strip() + if stdout and stdout.isdigit(): + return int(stdout) return None - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(multiplier=1, min=1, max=10), - retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)), - reraise=True, - ) - async def _get_log_size(self, hook: SSHHookAsync) -> int: - """ - Get the current size of the log file in bytes. - - Retries transient network errors up to 3 times with exponential backoff. 
- """ + async def _get_log_size(self, conn: SSHClientConnection) -> int: + """Get the current size of the log file in bytes.""" if self.remote_os == "posix": cmd = build_posix_file_size_command(self.log_file) else: cmd = build_windows_file_size_command(self.log_file) - try: - _, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout) - stdout = stdout.strip() - if stdout and stdout.isdigit(): - return int(stdout) - except (OSError, TimeoutError, ConnectionError) as e: - self.log.warning("Transient error getting log size (will retry): %s", e) - raise - except Exception as e: - self.log.warning("Error getting log file size: %s", e) + _, stdout, _ = await self._run_command(conn, cmd) + stdout = stdout.strip() + if stdout and stdout.isdigit(): + return int(stdout) return 0 - @tenacity.retry( - stop=tenacity.stop_after_attempt(3), - wait=tenacity.wait_exponential(multiplier=1, min=1, max=10), - retry=tenacity.retry_if_exception_type((OSError, TimeoutError, ConnectionError)), - reraise=True, - ) - async def _read_log_chunk(self, hook: SSHHookAsync) -> tuple[str, int]: + async def _read_log_chunk(self, conn: SSHClientConnection) -> tuple[str, int]: """ Read a chunk of logs from the current offset. - Retries transient network errors up to 3 times with exponential backoff. 
- :return: Tuple of (log_chunk, new_offset) """ - file_size = await self._get_log_size(hook) + file_size = await self._get_log_size(conn) if file_size <= self.log_offset: return "", self.log_offset @@ -195,47 +213,94 @@ async def _read_log_chunk(self, hook: SSHHookAsync) -> tuple[str, int]: else: cmd = build_windows_log_tail_command(self.log_file, self.log_offset, bytes_to_read) - try: - exit_code, stdout, _ = await hook.run_command(cmd, timeout=self.command_timeout) + _, stdout, _ = await self._run_command(conn, cmd) - # Advance offset by bytes requested, not decoded string length - new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset + # Advance offset by bytes requested, not decoded string length + new_offset = self.log_offset + bytes_to_read if stdout else self.log_offset + return stdout, new_offset - return stdout, new_offset - except (OSError, TimeoutError, ConnectionError) as e: - self.log.warning("Transient error reading logs (will retry): %s", e) - raise - except Exception as e: - self.log.warning("Error reading log chunk: %s", e) - return "", self.log_offset + def _error_event(self, message: str) -> TriggerEvent: + return TriggerEvent( + { + "job_id": self.job_id, + "job_dir": self.job_dir, + "log_file": self.log_file, + "exit_code_file": self.exit_code_file, + "remote_os": self.remote_os, + "status": "error", + "done": True, + "exit_code": None, + "log_chunk": "", + "log_offset": self.log_offset, + "message": message, + } + ) async def run(self) -> AsyncIterator[TriggerEvent]: """ - Poll the remote job status and yield events with log chunks. + Poll the remote job status and yield a completion event. - This method runs in a loop, checking the job status and reading - logs at each poll interval. It yields a TriggerEvent each time - with the current status and any new log output. + One connection is held for the whole loop. 
On a connection-level failure the + connection is dropped and re-established (with exponential backoff) up to + ``max_reconnect_attempts`` consecutive times; any other error, or exhausting the + reconnect budget, ends the trigger with an error event. """ - hook = self._get_hook() + conn: SSHClientConnection | None = None + # Consecutive failures since the last *fully successful* poll. A successful + # handshake alone does not reset this: a connection that handshakes but whose + # command channel keeps failing (e.g. ChannelOpenError under sshd MaxSessions) + # must still exhaust the budget instead of looping forever. + failures = 0 - while True: - try: - exit_code = await self._check_completion(hook) - log_chunk, new_offset = await self._read_log_chunk(hook) + try: + while True: + if conn is None: + try: + conn = await self._connect() + except _CONNECTION_ERRORS as e: + failures += 1 + if failures > self.max_reconnect_attempts: + raise + delay = self._reconnect_delay(failures) + self.log.warning( + "Failed to connect to remote host (attempt %d/%d), retrying in %.1fs: %s", + failures, + self.max_reconnect_attempts, + delay, + e, + ) + await asyncio.sleep(delay) + continue + + try: + exit_code = await self._check_completion(conn) + log_chunk, new_offset = await self._read_log_chunk(conn) + except _CONNECTION_ERRORS as e: + failures += 1 + self.log.warning( + "Lost SSH connection while polling (attempt %d/%d), reconnecting: %s", + failures, + self.max_reconnect_attempts, + e, + ) + await self._close(conn) + conn = None + if failures > self.max_reconnect_attempts: + raise + await asyncio.sleep(self._reconnect_delay(failures)) + continue - base_event = { - "job_id": self.job_id, - "job_dir": self.job_dir, - "log_file": self.log_file, - "exit_code_file": self.exit_code_file, - "remote_os": self.remote_os, - } + # A full poll cycle succeeded on this connection; clear the failure budget. 
+ failures = 0 if exit_code is not None: yield TriggerEvent( { - **base_event, + "job_id": self.job_id, + "job_dir": self.job_dir, + "log_file": self.log_file, + "exit_code_file": self.exit_code_file, + "remote_os": self.remote_os, "status": "success" if exit_code == 0 else "failed", "done": True, "exit_code": exit_code, @@ -251,21 +316,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.log.info("%s", log_chunk.rstrip()) await asyncio.sleep(self.poll_interval) - except Exception as e: - self.log.exception("Error in SSH remote job trigger") - yield TriggerEvent( - { - "job_id": self.job_id, - "job_dir": self.job_dir, - "log_file": self.log_file, - "exit_code_file": self.exit_code_file, - "remote_os": self.remote_os, - "status": "error", - "done": True, - "exit_code": None, - "log_chunk": "", - "log_offset": self.log_offset, - "message": f"Trigger error: {e}", - } - ) - return + except Exception as e: + self.log.exception("Error in SSH remote job trigger") + yield self._error_event(f"Trigger error: {e}") + return + finally: + if conn is not None: + await self._close(conn) diff --git a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py index 289cdfa3e48dc..e277c1dc3f654 100644 --- a/providers/ssh/tests/unit/ssh/hooks/test_ssh.py +++ b/providers/ssh/tests/unit/ssh/hooks/test_ssh.py @@ -594,6 +594,26 @@ def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_true(self, s assert ssh_client.return_value.connect.called is True assert ssh_client.return_value.set_missing_host_key_policy.called is True + @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient") + def test_conn_retry_attempts_defaults_to_three(self, ssh_client): + hook = SSHHook(ssh_conn_id="ssh_default") + assert hook.conn_retry_attempts == 3 + + @mock.patch("time.sleep") + @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient") + def test_conn_retry_attempts_retries_until_limit(self, ssh_client, _mock_sleep): + """get_conn 
retries the configured number of times before re-raising.""" + ssh_client.return_value.connect.side_effect = paramiko.ssh_exception.SSHException( + "Error reading SSH protocol banner" + ) + hook = SSHHook(ssh_conn_id="ssh_default", conn_retry_attempts=4) + assert hook.conn_retry_attempts == 4 + + with pytest.raises(paramiko.ssh_exception.SSHException): + hook.get_conn() + + assert ssh_client.return_value.connect.call_count == 4 + @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient") def test_ssh_connection_with_host_key_where_allow_host_key_change_is_true(self, ssh_client): hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE) diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py index 871f398139384..6ecf80baef28d 100644 --- a/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py +++ b/providers/ssh/tests/unit/ssh/operators/test_ssh_remote_job.py @@ -123,6 +123,72 @@ def test_execute_defers_to_trigger(self): assert isinstance(exc_info.value.trigger, SSHRemoteJobTrigger) assert exc_info.value.method_name == "execute_complete" + def test_execute_forwards_trigger_tuning_params(self): + """command_timeout and max_reconnect_attempts must reach the deferred trigger.""" + self.mock_hook.exec_ssh_client_command.return_value = (0, b"job", b"") + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + remote_os="posix", + command_timeout=12.5, + max_reconnect_attempts=9, + ) + mock_ti = mock.MagicMock() + mock_ti.dag_id, mock_ti.task_id, mock_ti.run_id, mock_ti.try_number = "d", "t", "r", 1 + + with pytest.raises(TaskDeferred) as exc_info: + op.execute({"ti": mock_ti}) + + trigger = exc_info.value.trigger + assert trigger.command_timeout == 12.5 + assert trigger.max_reconnect_attempts == 9 + + def test_execute_complete_re_defer_forwards_tuning_params(self): + """The re-defer path 
must also forward the trigger-tuning params.""" + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + command_timeout=7.0, + max_reconnect_attempts=4, + ) + event = { + "done": False, + "status": "running", + "job_id": "j", + "job_dir": "/tmp/airflow-ssh-jobs/j", + "log_file": "/tmp/airflow-ssh-jobs/j/stdout.log", + "exit_code_file": "/tmp/airflow-ssh-jobs/j/exit_code", + "remote_os": "posix", + "log_chunk": "", + "log_offset": 0, + "exit_code": None, + } + with pytest.raises(TaskDeferred) as exc_info: + op.execute_complete({}, event) + + trigger = exc_info.value.trigger + assert trigger.command_timeout == 7.0 + assert trigger.max_reconnect_attempts == 4 + + def test_execute_connect_failure_is_reraised(self): + """A connection failure during submit is re-raised unchanged (advisory is log-only).""" + self.mock_hook.get_conn.side_effect = OSError("Error reading SSH protocol banner") + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + remote_os="posix", + ) + mock_ti = mock.MagicMock() + mock_ti.dag_id, mock_ti.task_id, mock_ti.run_id, mock_ti.try_number = "d", "t", "r", 1 + + with pytest.raises(OSError, match="Error reading SSH protocol banner"): + op.execute({"ti": mock_ti}) + def test_execute_raises_if_no_command(self): """Test that execute raises if command is not specified.""" op = SSHRemoteJobOperator( @@ -329,3 +395,68 @@ def test_on_kill_no_active_job(self): # Should not raise even without active job op.on_kill() + + def test_execute_uses_single_connection_for_detect_and_submit(self): + """OS auto-detection and submission must share one SSH connection (one handshake).""" + self.mock_hook.exec_ssh_client_command.side_effect = [ + (0, b"Linux", b""), # OS detection + (0, b"af_test_dag_test_task_run1_try1_abc123", b""), # submission + ] + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + 
command="/path/to/script.sh", + remote_os="auto", + ) + + mock_ti = mock.MagicMock() + mock_ti.dag_id = "test_dag" + mock_ti.task_id = "test_task" + mock_ti.run_id = "run1" + mock_ti.try_number = 1 + + with pytest.raises(TaskDeferred): + op.execute({"ti": mock_ti}) + + # One handshake for the whole execute(): detection + submit reuse it. + self.mock_hook.get_conn.assert_called_once() + assert self.mock_hook.exec_ssh_client_command.call_count == 2 + assert op._detected_os == "posix" + + def test_cleanup_retries_then_succeeds(self): + """Cleanup retries on a transient SSH failure and stops once it succeeds.""" + self.mock_hook.exec_ssh_client_command.side_effect = [ + Exception("Error reading SSH protocol banner"), + (0, b"", b""), + ] + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + cleanup_retries=3, + ) + + with mock.patch("airflow.providers.ssh.operators.ssh_remote_job.time.sleep") as mock_sleep: + op._cleanup_remote_job("/tmp/airflow-ssh-jobs/test_job_123", "posix") + + assert self.mock_hook.exec_ssh_client_command.call_count == 2 + mock_sleep.assert_called_once() + + def test_cleanup_gives_up_after_retries_without_raising(self): + """When every cleanup attempt fails the task is not failed; the dir is left in place.""" + self.mock_hook.exec_ssh_client_command.side_effect = Exception("connection refused") + + op = SSHRemoteJobOperator( + task_id="test_task", + ssh_conn_id="test_conn", + command="/path/to/script.sh", + cleanup_retries=3, + ) + + with mock.patch("airflow.providers.ssh.operators.ssh_remote_job.time.sleep"): + # Must not raise even though all attempts fail. 
+ op._cleanup_remote_job("/tmp/airflow-ssh-jobs/test_job_123", "posix") + + assert self.mock_hook.exec_ssh_client_command.call_count == 3 diff --git a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py index 67d9672a8cbeb..8e1d1e5953a6f 100644 --- a/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py +++ b/providers/ssh/tests/unit/ssh/triggers/test_ssh_remote_job.py @@ -24,6 +24,20 @@ from airflow.providers.ssh.triggers.ssh_remote_job import SSHRemoteJobTrigger +def _make_trigger(**overrides): + kwargs = dict( + ssh_conn_id="test_conn", + remote_host=None, + job_id="test_job", + job_dir="/tmp/job", + log_file="/tmp/job/stdout.log", + exit_code_file="/tmp/job/exit_code", + remote_os="posix", + ) + kwargs.update(overrides) + return SSHRemoteJobTrigger(**kwargs) + + class TestSSHRemoteJobTrigger: def test_serialization(self): """Test that the trigger can be serialized correctly.""" @@ -38,6 +52,7 @@ def test_serialization(self): poll_interval=10, log_chunk_size=32768, log_offset=1000, + max_reconnect_attempts=7, ) classpath, kwargs = trigger.serialize() @@ -53,145 +68,201 @@ def test_serialization(self): assert kwargs["poll_interval"] == 10 assert kwargs["log_chunk_size"] == 32768 assert kwargs["log_offset"] == 1000 + assert kwargs["max_reconnect_attempts"] == 7 + + def test_serialization_round_trips(self): + """The serialized kwargs must be sufficient to rebuild the trigger.""" + trigger = _make_trigger(poll_interval=3, max_reconnect_attempts=2) + _, kwargs = trigger.serialize() + rebuilt = SSHRemoteJobTrigger(**kwargs) + assert rebuilt.serialize() == trigger.serialize() def test_default_values(self): """Test default parameter values.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + trigger = _make_trigger() 
assert trigger.poll_interval == 5 assert trigger.log_chunk_size == 65536 assert trigger.log_offset == 0 + assert trigger.max_reconnect_attempts == 5 @pytest.mark.asyncio async def test_run_job_completed_success(self): """Test trigger when job completes successfully.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + trigger = _make_trigger() - with mock.patch.object(trigger, "_check_completion", return_value=0): - with mock.patch.object(trigger, "_read_log_chunk", return_value=("Final output\n", 100)): - events = [] - async for event in trigger.run(): - events.append(event) + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()), + mock.patch.object(trigger, "_close", return_value=None), + mock.patch.object(trigger, "_check_completion", return_value=0), + mock.patch.object(trigger, "_read_log_chunk", return_value=("Final output\n", 100)), + ): + events = [event async for event in trigger.run()] - assert len(events) == 1 - assert events[0].payload["status"] == "success" - assert events[0].payload["done"] is True - assert events[0].payload["exit_code"] == 0 - assert events[0].payload["log_chunk"] == "Final output\n" + assert len(events) == 1 + assert events[0].payload["status"] == "success" + assert events[0].payload["done"] is True + assert events[0].payload["exit_code"] == 0 + assert events[0].payload["log_chunk"] == "Final output\n" @pytest.mark.asyncio async def test_run_job_completed_failure(self): """Test trigger when job completes with failure.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + trigger = _make_trigger() - with mock.patch.object(trigger, "_check_completion", return_value=1): - 
with mock.patch.object(trigger, "_read_log_chunk", return_value=("Error output\n", 50)): - events = [] - async for event in trigger.run(): - events.append(event) + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()), + mock.patch.object(trigger, "_close", return_value=None), + mock.patch.object(trigger, "_check_completion", return_value=1), + mock.patch.object(trigger, "_read_log_chunk", return_value=("Error output\n", 50)), + ): + events = [event async for event in trigger.run()] - assert len(events) == 1 - assert events[0].payload["status"] == "failed" - assert events[0].payload["done"] is True - assert events[0].payload["exit_code"] == 1 + assert len(events) == 1 + assert events[0].payload["status"] == "failed" + assert events[0].payload["done"] is True + assert events[0].payload["exit_code"] == 1 @pytest.mark.asyncio - async def test_run_job_polls_until_completion(self): - """Test trigger polls without yielding until job completes.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - poll_interval=0.01, - ) + async def test_run_reuses_single_connection_across_polls(self): + """The connection is opened once and reused for every poll, not per command.""" + trigger = _make_trigger(poll_interval=0.01) poll_count = 0 async def mock_check_completion(_): nonlocal poll_count poll_count += 1 - # Return None (still running) for first 2 polls, then exit code 0 - if poll_count < 3: - return None + return None if poll_count < 3 else 0 + + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect, + mock.patch.object(trigger, "_close", return_value=None) as mock_close, + mock.patch.object(trigger, "_check_completion", side_effect=mock_check_completion), + mock.patch.object(trigger, "_read_log_chunk", return_value=("output\n", 50)), + ): + events 
= [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "success" + assert poll_count == 3 + # One connect for the whole loop, closed once at teardown. + assert mock_connect.call_count == 1 + assert mock_close.call_count == 1 + + @pytest.mark.asyncio + async def test_run_reconnects_on_connection_drop(self): + """A connection-level error mid-poll drops the connection and reconnects.""" + trigger = _make_trigger(poll_interval=0.01, max_reconnect_attempts=3) + + calls = {"check": 0} + + async def flaky_check(_): + calls["check"] += 1 + if calls["check"] == 1: + raise OSError("Error reading SSH protocol banner") return 0 - with mock.patch.object(trigger, "_check_completion", side_effect=mock_check_completion): - with mock.patch.object(trigger, "_read_log_chunk", return_value=("output\n", 50)): - events = [] - async for event in trigger.run(): - events.append(event) + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect, + mock.patch.object(trigger, "_close", return_value=None) as mock_close, + mock.patch.object(trigger, "_check_completion", side_effect=flaky_check), + mock.patch.object(trigger, "_read_log_chunk", return_value=("out\n", 10)), + mock.patch("asyncio.sleep", new=mock.AsyncMock()), + ): + events = [event async for event in trigger.run()] - # Only one event should be yielded (the completion event) - assert len(events) == 1 - assert events[0].payload["status"] == "success" - assert events[0].payload["done"] is True - assert events[0].payload["exit_code"] == 0 - # Should have polled 3 times - assert poll_count == 3 + assert len(events) == 1 + assert events[0].payload["status"] == "success" + # Initial connect + one reconnect after the dropped poll. + assert mock_connect.call_count == 2 + # Dropped connection closed during reconnect, plus final teardown close. 
+ assert mock_close.call_count == 2 @pytest.mark.asyncio - async def test_run_handles_exception(self): - """Test trigger handles exceptions gracefully.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host=None, - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + async def test_run_gives_up_after_max_reconnects(self): + """When connections keep failing, the trigger emits a single error event.""" + trigger = _make_trigger(max_reconnect_attempts=2) + + with ( + mock.patch.object(trigger, "_connect", side_effect=OSError("connection refused")), + mock.patch.object(trigger, "_close", return_value=None), + mock.patch("asyncio.sleep", new=mock.AsyncMock()), + ): + events = [event async for event in trigger.run()] - with mock.patch.object(trigger, "_check_completion", side_effect=Exception("Connection failed")): - events = [] - async for event in trigger.run(): - events.append(event) + assert len(events) == 1 + assert events[0].payload["status"] == "error" + assert events[0].payload["done"] is True + assert events[0].payload["exit_code"] is None + assert "connection refused" in events[0].payload["message"] - assert len(events) == 1 - assert events[0].payload["status"] == "error" - assert events[0].payload["done"] is True - assert "Connection failed" in events[0].payload["message"] + @pytest.mark.asyncio + async def test_run_gives_up_when_polls_keep_failing_despite_reconnects(self): + """A connection that handshakes but whose polls keep failing must still hit the cap. + + Regression: the reconnect budget must not reset on a bare successful handshake, or a + connection that reconnects fine but never completes a poll (e.g. ChannelOpenError under + sshd MaxSessions) would loop forever and the task would defer indefinitely. 
+ """ + trigger = _make_trigger(max_reconnect_attempts=2) + + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()) as mock_connect, + mock.patch.object(trigger, "_close", return_value=None), + mock.patch.object( + trigger, "_check_completion", side_effect=OSError("channel open failed") + ) as mock_check, + mock.patch("asyncio.sleep", new=mock.AsyncMock()), + ): + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "error" + assert "channel open failed" in events[0].payload["message"] + # Budget = 2 -> third consecutive failure ends it (handshake succeeds each round + # but never resets the counter because no poll ever completes). + assert mock_check.call_count == 3 + assert mock_connect.call_count == 3 + + @pytest.mark.asyncio + async def test_run_handles_unexpected_exception(self): + """A non-connection error surfaces immediately as an error event.""" + trigger = _make_trigger() + + with ( + mock.patch.object(trigger, "_connect", return_value=mock.MagicMock()), + mock.patch.object(trigger, "_close", return_value=None), + mock.patch.object(trigger, "_check_completion", side_effect=ValueError("boom")), + ): + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "error" + assert events[0].payload["done"] is True + assert "boom" in events[0].payload["message"] def test_get_hook(self): """Test hook creation.""" - trigger = SSHRemoteJobTrigger( - ssh_conn_id="test_conn", - remote_host="custom.host.com", - job_id="test_job", - job_dir="/tmp/job", - log_file="/tmp/job/stdout.log", - exit_code_file="/tmp/job/exit_code", - remote_os="posix", - ) + trigger = _make_trigger(remote_host="custom.host.com") hook = trigger._get_hook() assert hook.ssh_conn_id == "test_conn" assert hook.host == "custom.host.com" + + @pytest.mark.asyncio + async def test_run_command_uses_existing_connection(self): + """_run_command runs on the 
passed connection without opening a new one.""" + trigger = _make_trigger(command_timeout=12.0) + + result = mock.MagicMock() + result.exit_status = 0 + result.stdout = "42" + result.stderr = "" + conn = mock.MagicMock() + conn.run = mock.AsyncMock(return_value=result) + + exit_code, stdout, stderr = await trigger._run_command(conn, "echo 42") + + conn.run.assert_awaited_once_with("echo 42", timeout=12.0, check=False) + assert (exit_code, stdout, stderr) == (0, "42", "")