diff --git a/packages/jumpstarter-driver-shell/README.md b/packages/jumpstarter-driver-shell/README.md index 6e1afde1d..d80f497bf 100644 --- a/packages/jumpstarter-driver-shell/README.md +++ b/packages/jumpstarter-driver-shell/README.md @@ -63,3 +63,53 @@ methods will be generated dynamically, and they will be available as follows: :returns: A tuple(stdout, stderr, return_code) ``` + +## CLI Usage + +The shell driver also provides a CLI when using `jmp shell`. All configured methods become available as CLI commands, except for methods starting with `_` which are considered private and hidden from the end user: + +```console +$ jmp shell --exporter shell-exporter +$ j shell +Usage: j shell [OPTIONS] COMMAND [ARGS]... + + Shell command executor + +Commands: + env_var Execute the env_var shell method + ls Execute the ls shell method + method2 Execute the method2 shell method + method3 Execute the method3 shell method +``` + +### CLI Command Usage + +Each configured method becomes a CLI command with the following options: + +```console +$ j shell ls --help +Usage: j shell ls [OPTIONS] [ARGS]... + + Execute the ls shell method + +Options: + -e, --env TEXT Environment variables in KEY=VALUE format + --help Show this message and exit. +``` + +### Examples + +```console +# Execute simple commands +$ j shell ls +file1.txt file2.txt directory/ + +# Pass arguments to shell methods +$ j shell method3 "first arg" "second arg" +Hello World first arg +Hello World second arg + +# Set environment variables +$ j shell env_var arg1 arg2 --env ENV_VAR=myvalue +arg1,arg2,myvalue +``` diff --git a/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/client.py b/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/client.py index 257f87cd3..ef8cee2f6 100644 --- a/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/client.py +++ b/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/client.py @@ -1,5 +1,8 @@ +import sys from dataclasses import dataclass +import click + from jumpstarter.client import DriverClient @@ -11,8 +14,8 @@ class ShellClient(DriverClient): Client interface for Shell driver. This client dynamically checks that the method is configured - on the driver, and if it is, it will call it and get the results - in the form of (stdout, stderr, returncode). + on the driver, and if it is, it will call it with live streaming output. + Output chunks are displayed as they arrive. """ def _check_method_exists(self, method): @@ -24,4 +27,59 @@ def _check_method_exists(self, method): ## capture any method calls dynamically def __getattr__(self, name): self._check_method_exists(name) - return lambda *args, **kwargs: tuple(self.call("call_method", name, kwargs, *args)) + def execute(*args, **kwargs): + returncode = 0 + for stdout, stderr, code in self.streamingcall("call_method", name, kwargs, *args): + if stdout: + print(stdout, end='', flush=True) + if stderr: + print(stderr, end='', file=sys.stderr, flush=True) + if code is not None: + returncode = code + return returncode + return execute + + def cli(self): + """Create CLI interface for dynamically configured shell methods""" + @click.group + def base(): + """Shell command executor""" + pass + + # Get available methods from the driver + if self._methods is None: + self._methods = self.call("get_methods") + + # Create a command for each configured method + for method_name in self._methods: + self._add_method_command(base, method_name) + + return base + + def _add_method_command(self, group, method_name): + """Add a Click command for a specific shell method""" + @group.command( + name=method_name, + context_settings={"ignore_unknown_options": True, "allow_interspersed_args": False}, + ) + @click.argument('args', nargs=-1, type=click.UNPROCESSED) + @click.option('--env', '-e', multiple=True, + help='Environment variables in KEY=VALUE format') + def method_command(args, env): + # Parse environment variables + env_dict = {} + for env_var in env: + if '=' in env_var: + key, value = env_var.split('=', 1) + env_dict[key] = value + else: + raise click.BadParameter(f"Invalid --env value '{env_var}'. Use KEY=VALUE.") + + returncode = getattr(self, method_name)(*args, **env_dict) + + # Exit with the same return code as the shell command + if returncode != 0: + raise click.exceptions.Exit(returncode) + + # Update the docstring dynamically + method_command.__doc__ = f"Execute the {method_name} shell method" diff --git a/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver.py b/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver.py index 979747a4a..bbabadfc0 100644 --- a/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver.py +++ b/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver.py @@ -1,6 +1,9 @@ +import asyncio import os +import signal import subprocess from dataclasses import dataclass, field +from typing import AsyncGenerator from jumpstarter.driver import Driver, export @@ -27,41 +30,38 @@ def get_methods(self) -> list[str]: return methods @export - def call_method(self, method: str, env, *args): + async def call_method(self, method: str, env, *args) -> AsyncGenerator[tuple[str, str, int | None], None]: + """ + Execute a shell method with live streaming output. + Yields (stdout_chunk, stderr_chunk, returncode) tuples. + returncode is None until the process completes, then it's the final return code. + """ self.logger.info(f"calling {method} with args: {args} and kwargs as env: {env}") if method not in self.methods: raise ValueError(f"Method '{method}' not found in available methods: {list(self.methods.keys())}") script = self.methods[method] self.logger.debug(f"running script: {script}") + try: - result = self._run_inline_shell_script(method, script, *args, env_vars=env) - if result.returncode != 0: - self.logger.info(f"{method} return code: {result.returncode}") - if result.stderr != "": - stderr = result.stderr.rstrip("\n") - self.logger.debug(f"{method} stderr:\n{stderr}") - if result.stdout != "": - stdout = result.stdout.rstrip("\n") - self.logger.debug(f"{method} stdout:\n{stdout}") - return result.stdout, result.stderr, result.returncode + async for stdout_chunk, stderr_chunk, returncode in self._run_inline_shell_script( + method, script, *args, env_vars=env + ): + if stdout_chunk: + self.logger.debug(f"{method} stdout:\n{stdout_chunk.rstrip()}") + if stderr_chunk: + self.logger.debug(f"{method} stderr:\n{stderr_chunk.rstrip()}") + + if returncode is not None and returncode != 0: + self.logger.info(f"{method} return code: {returncode}") + + yield stdout_chunk, stderr_chunk, returncode except subprocess.TimeoutExpired as e: self.logger.error(f"Timeout expired while running {method}: {e}") - return "", f"Timeout expired while running {method}: {e}", 199 - - def _run_inline_shell_script(self, method, script, *args, env_vars=None): - """ - Run the given shell script (as a string) with optional arguments and - environment variables. Returns a CompletedProcess with stdout, stderr, and returncode. - - :param script: The shell script contents as a string. - :param args: Arguments to pass to the script (mapped to $1, $2, etc. in the script). - :param env_vars: A dict of environment variables to make available to the script. - - :return: A subprocess.CompletedProcess object (Python 3.5+). - """ + yield "", f"\nTimeout expired while running {method}: {e}\n", 199 + def _validate_script_params(self, script, args, env_vars): + """Validate script parameters and return combined environment.""" # Merge parent environment with the user-supplied env_vars - # so that we don't lose existing environment variables. combined_env = os.environ.copy() if env_vars: # Validate environment variable names @@ -82,16 +82,108 @@ def _run_inline_shell_script(self, method, script, *args, env_vars=None): if self.cwd and not os.path.isdir(self.cwd): raise ValueError(f"Working directory does not exist: {self.cwd}") + return combined_env + + async def _read_process_output(self, process, read_all=False): + """Read data from stdout and stderr streams. + + :param process: The subprocess to read from + :param read_all: If True, read all remaining data. If False, read with timeout. + :return: Tuple of (stdout_data, stderr_data) + """ + stdout_data = "" + stderr_data = "" + + # Read from stdout + if process.stdout: + try: + if read_all: + chunk = await process.stdout.read() + else: + chunk = await asyncio.wait_for(process.stdout.read(1024), timeout=0.01) + if chunk: + stdout_data = chunk.decode('utf-8', errors='replace') + except (asyncio.TimeoutError, Exception): + pass + + # Read from stderr + if process.stderr: + try: + if read_all: + chunk = await process.stderr.read() + else: + chunk = await asyncio.wait_for(process.stderr.read(1024), timeout=0.01) + if chunk: + stderr_data = chunk.decode('utf-8', errors='replace') + except (asyncio.TimeoutError, Exception): + pass + + return stdout_data, stderr_data + + async def _run_inline_shell_script( + self, method, script, *args, env_vars=None + ) -> AsyncGenerator[tuple[str, str, int | None], None]: + """ + Run the given shell script with live streaming output. + + :param method: The method name (for logging). + :param script: The shell script contents as a string. + :param args: Arguments to pass to the script (mapped to $1, $2, etc. in the script). + :param env_vars: A dict of environment variables to make available to the script. + + :yields: Tuples of (stdout_chunk, stderr_chunk, returncode). + returncode is None until the process completes. + """ + combined_env = self._validate_script_params(script, args, env_vars) cmd = self.shell + [script, method] + list(args) - # Run the command - result = subprocess.run( - cmd, - capture_output=True, # Captures stdout and stderr - text=True, # Returns stdout/stderr as strings (not bytes) - env=combined_env, # Pass our merged environment - cwd=self.cwd, # Run in the working directory (if set) - timeout=self.timeout, + # Start the process with pipes for streaming and new process group + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=combined_env, + cwd=self.cwd, + start_new_session=True, # Create new process group ) - return result + # Create a task to monitor the process timeout + start_time = asyncio.get_event_loop().time() + + # Read output in real-time + while process.returncode is None: + self.logger.debug(f"running {method} with cmd: {cmd} and env: {combined_env} and args: {args}") + if asyncio.get_event_loop().time() - start_time > self.timeout: + # Send SIGTERM to entire process group for graceful termination + try: + os.killpg(process.pid, signal.SIGTERM) + except (ProcessLookupError, OSError): + # Process group might already be gone + pass + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + try: + os.killpg(process.pid, signal.SIGKILL) + self.logger.warning(f"SIGTERM failed to terminate {process.pid}, sending SIGKILL") + except (ProcessLookupError, OSError): + pass + raise subprocess.TimeoutExpired(cmd, self.timeout) from None + + try: + stdout_data, stderr_data = await self._read_process_output(process, read_all=False) + + # Yield any data we got + if stdout_data or stderr_data: + yield stdout_data, stderr_data, None + + # Small delay to prevent busy waiting + await asyncio.sleep(0.1) + + except Exception: + break + + # Process completed, get return code and final output + returncode = process.returncode + remaining_stdout, remaining_stderr = await self._read_process_output(process, read_all=True) + yield remaining_stdout, remaining_stderr, returncode diff --git a/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py b/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py index 7787bbb9c..4e03b5f64 100644 --- a/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py +++ b/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py @@ -1,9 +1,28 @@ + import pytest from .driver import Shell from jumpstarter.common.utils import serve +def _collect_streaming_output(client, method_name, env_vars=None, *args): + """Helper function to collect streaming output for testing""" + stdout_parts = [] + stderr_parts = [] + final_returncode = None + + env_vars = env_vars or {} + for stdout_chunk, stderr_chunk, returncode in client.streamingcall("call_method", method_name, env_vars, *args): + if stdout_chunk: + stdout_parts.append(stdout_chunk) + if stderr_chunk: + stderr_parts.append(stderr_chunk) + if returncode is not None: + final_returncode = returncode + + return "".join(stdout_parts), "".join(stderr_parts), final_returncode + + @pytest.fixture def client(): instance = Shell( @@ -21,23 +40,38 @@ def client(): def test_normal_args(client): - assert client.echo("hello") == ("hello\n", "", 0) + stdout, stderr, returncode = _collect_streaming_output(client, "echo", {}, "hello") + assert stdout == "hello\n" + assert stderr == "" + assert returncode == 0 def test_env_vars(client): - assert client.env(ENV1="world") == ("world\n", "", 0) + stdout, stderr, returncode = _collect_streaming_output(client, "env", {"ENV1": "world"}) + assert stdout == "world\n" + assert stderr == "" + assert returncode == 0 def test_multi_line_scripts(client): - assert client.multi_line("a", "b", "c") == ("a\nb\nc\n", "", 0) + stdout, stderr, returncode = _collect_streaming_output(client, "multi_line", {}, "a", "b", "c") + assert stdout == "a\nb\nc\n" + assert stderr == "" + assert returncode == 0 def test_return_codes(client): - assert client.exit1() == ("", "", 1) + stdout, stderr, returncode = _collect_streaming_output(client, "exit1") + assert stdout == "" + assert stderr == "" + assert returncode == 1 def test_stderr(client): - assert client.stderr("error") == ("", "error\n", 0) + stdout, stderr, returncode = _collect_streaming_output(client, "stderr", {}, "error") + assert stdout == "" + assert stderr == "error\n" + assert returncode == 0 def test_unknown_method(client): @@ -47,3 +81,94 @@ def test_unknown_method(client): assert "method unknown not found in" in str(e) else: raise AssertionError("Expected AttributeError") + + +def test_cli_interface(client): + """Test that the CLI interface is created with all methods""" + cli = client.cli() + + # Check that it's a Click group + assert hasattr(cli, 'commands') + + # Check that all configured methods are available as commands + expected_methods = {"echo", "env", "multi_line", "exit1", "stderr"} + available_commands = set(cli.commands.keys()) + + assert expected_methods == available_commands, f"Expected {expected_methods}, got {available_commands}" + + +def test_cli_method_execution(client): + """Test that CLI methods can be executed""" + cli = client.cli() + + # Test that we can get the echo command + echo_command = cli.commands.get('echo') + assert echo_command is not None + assert echo_command.name == 'echo' + + +def test_cli_includes_all_methods(): + """Test that CLI includes all methods""" + from .driver import Shell + from jumpstarter.common.utils import serve + + shell_instance = Shell( + log_level="DEBUG", + methods={ + "method1": "echo method1", + "method2": "echo method2", + "method3": "echo method3", + }, + ) + + with serve(shell_instance) as test_client: + cli = test_client.cli() + available_commands = set(cli.commands.keys()) + + # All methods should be available + expected_methods = {"method1", "method2", "method3"} + assert available_commands == expected_methods, f"Expected {expected_methods}, got {available_commands}" + + +def test_cli_exit_codes(): + """Test that CLI commands preserve shell command exit codes""" + import click + + from .driver import Shell + from jumpstarter.common.utils import serve + + # Create a shell instance with methods that have different exit codes + shell_instance = Shell( + log_level="DEBUG", + methods={ + "success": "exit 0", + "fail_1": "exit 1", + "fail_42": "exit 42", + }, + ) + + with serve(shell_instance) as test_client: + cli = test_client.cli() + + # Test successful command (exit 0) - should not raise + success_cmd = cli.commands['success'] + try: + success_cmd.callback([], []) # Call with empty args and env + except click.exceptions.Exit: + raise AssertionError("Success command should not raise Exit exception") from None + + # Test command that exits with code 1 - should raise Exit(1) + fail1_cmd = cli.commands['fail_1'] + try: + fail1_cmd.callback([], []) + raise AssertionError("Command should have raised Exit exception") + except click.exceptions.Exit as e: + assert e.exit_code == 1 + + # Test command that exits with code 42 - should raise Exit(42) + fail42_cmd = cli.commands['fail_42'] + try: + fail42_cmd.callback([], []) + raise AssertionError("Command should have raised Exit exception") + except click.exceptions.Exit as e: + assert e.exit_code == 42 diff --git a/packages/jumpstarter-driver-shell/pyproject.toml b/packages/jumpstarter-driver-shell/pyproject.toml index 08b642387..ca2641272 100644 --- a/packages/jumpstarter-driver-shell/pyproject.toml +++ b/packages/jumpstarter-driver-shell/pyproject.toml @@ -6,7 +6,7 @@ readme = "README.md" authors = [{ name = "Miguel Angel Ajo", email = "miguelangel@ajo.es" }] requires-python = ">=3.11" license = "Apache-2.0" -dependencies = ["anyio>=4.10.0", "jumpstarter"] +dependencies = ["anyio>=4.10.0", "jumpstarter", "click>=8.1.8"] [project.entry-points."jumpstarter.drivers"] Shell = "jumpstarter_driver_shell.driver:Shell"