class forward_stream:
    """Context manager mirroring writes on a ``sys`` stream into worker events.

    While the context is active, the ``write``, ``flush`` and ``close``
    methods of the selected stream object are wrapped. Everything written is
    still passed to the original method, and is additionally collected in a
    buffer that is handed to ``worker.log_event("print", ...)``. The buffer is
    sent whenever a written chunk contains ``"\\n"`` or ``"\\r"`` (mimicking
    line buffering) and whenever the stream is flushed or closed.

    Parameters
    ----------
    stream : str
        Name of the stream to intercept; either ``"stdout"`` or ``"stderr"``.
        The stream object is looked up on ``sys`` when the instance is created.
    worker
        Object providing ``log_event(topic, msg)``; it receives one
        ``"print"`` event per buffered batch.

    Raises
    ------
    ValueError
        If ``stream`` is neither ``"stdout"`` nor ``"stderr"``.
    """

    def __init__(self, stream, worker):
        # Validate first: looking the name up on ``sys`` before this check
        # would turn an unknown name into an AttributeError instead.
        if stream == "stdout":
            self._file = 1
        elif stream == "stderr":
            self._file = 2
        else:
            raise ValueError(
                f"Expected stream to be 'stdout' or 'stderr'; got '{stream}'"
            )

        self._worker = worker
        self._original_methods = {}
        self._stream = getattr(sys, stream)
        self._buffer = []

    def _write(self, write_fn, data):
        """Buffer ``data`` for forwarding, then write it to the real stream."""
        self._forward(data)
        # Hand back whatever the wrapped ``write`` returns so that callers
        # relying on its return value keep working.
        return write_fn(data)

    def _forward(self, data):
        """Append ``data`` to the buffer and send it once a line is complete."""
        self._buffer.append(data)
        # Mimic line buffering
        if "\n" in data or "\r" in data:
            self._send()

    def _send(self):
        """Log the buffered chunks as a single ``"print"`` event."""
        if not self._buffer:
            # Nothing was written since the last send; a bare flush or close
            # should not produce an empty event.
            return
        msg = {"args": self._buffer, "file": self._file, "sep": "", "end": ""}
        self._worker.log_event("print", msg)
        self._buffer = []

    def _flush(self, flush_fn):
        """Send pending output, then flush the real stream."""
        self._send()
        return flush_fn()

    def _close(self, close_fn):
        """Send pending output, then close the real stream."""
        self._send()
        return close_fn()

    def _intercept(self, method_name, interceptor):
        """Replace ``method_name`` on the stream, remembering the original."""
        original_method = getattr(self._stream, method_name)
        self._original_methods[method_name] = original_method
        setattr(
            self._stream, method_name, functools.partial(interceptor, original_method)
        )

    def __enter__(self):
        self._intercept("write", self._write)
        self._intercept("flush", self._flush)
        self._intercept("close", self._close)
        return self._stream

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            # Goes through the interceptor, so pending output is sent too.
            self._stream.flush()
        finally:
            # Restore the stream even if flushing failed; otherwise it would
            # stay patched after the context has ended.
            for attr, original in self._original_methods.items():
                setattr(self._stream, attr, original)
            self._original_methods = {}
class ForwardOutput(WorkerPlugin):
    """A Worker Plugin that forwards ``stdout`` and ``stderr`` from workers to clients

    This plugin forwards all output sent to ``stdout`` and ``stderr`` on all workers
    to all clients where it is written to the respective streams. Analogous to the
    terminal, this plugin uses line buffering. To ensure that an output is written
    without a newline, make sure to flush the stream.

    .. warning::

        Using this plugin will forward **all** output in ``stdout`` and ``stderr`` from
        every worker to every client. If the output is very chatty, this will add
        significant strain on the scheduler. Proceed with caution!

    Examples
    --------
    >>> from dask.distributed import ForwardOutput
    >>> plugin = ForwardOutput()

    >>> client.register_worker_plugin(plugin)
    """

    def setup(self, worker):
        """Install forwarders on ``stdout`` and ``stderr`` for ``worker``."""
        self._exit_stack = contextlib.ExitStack()
        try:
            for name in ("stdout", "stderr"):
                self._exit_stack.enter_context(forward_stream(name, worker=worker))
        except BaseException:
            # Undo any forwarder that was already installed so a failed setup
            # does not leave a stream patched.
            self._exit_stack.close()
            raise

    def teardown(self, worker):
        """Remove the forwarders installed by ``setup``."""
        self._exit_stack.close()
@gen_cluster(client=True, Worker=Nanny)
async def test_forward_output(c, s, a, b, capsys):
    """Output printed in tasks reaches the client only while the plugin is registered."""
    keys = itertools.count()

    def print_stdout(*args, **kwargs):
        print(*args, file=sys.stdout, **kwargs)

    def print_stderr(*args, **kwargs):
        print(*args, file=sys.stderr, **kwargs)

    async def run(fn, *args, **kwargs):
        # Execute ``fn`` under a fresh key and return what was captured locally
        await c.submit(fn, *args, key=next(keys), **kwargs)
        return capsys.readouterr()

    plugin = ForwardOutput()
    capsys.readouterr()

    # Without the plugin installed, we should not see any output
    # Note that we use nannies so workers run in subprocesses
    silent_stdout_calls = [
        (("foo",), {}),
        (("bar\n",), {}),
        (("baz",), {"end": "", "flush": True}),
        ((1, 2), {"end": "\n", "sep": "\n"}),
    ]
    for args, kwargs in silent_stdout_calls:
        await c.submit(print_stdout, *args, key=next(keys), **kwargs)
    out, err = capsys.readouterr()
    assert "" == out
    assert "" == err

    silent_stderr_calls = [
        (("foo",), {}),
        (("bar\n",), {}),
        (("baz",), {"flush": True}),
        ((1, 2), {"end": "\n", "sep": "\n"}),
    ]
    for args, kwargs in silent_stderr_calls:
        await c.submit(print_stderr, *args, key=next(keys), **kwargs)
    out, err = capsys.readouterr()
    assert "" == out
    assert "" == err

    # After installing, output should be forwarded
    await c.register_worker_plugin(plugin, "forward")
    await asyncio.sleep(0.1)  # Let setup messages come in
    capsys.readouterr()

    forwarded_cases = [
        (print_stdout, ("foo",), {}, "foo\n", ""),
        (print_stdout, ("bar\n",), {}, "bar\n\n", ""),
        (print_stdout, ("baz",), {"end": "", "flush": True}, "baz", ""),
        (print_stdout, ("first\nsecond",), {"end": ""}, "first\nsecond", ""),
        (print_stdout, (1, 2), {"sep": ":"}, "1:2\n", ""),
        (print_stderr, ("fatal",), {}, "", "fatal\n"),
    ]
    for fn, args, kwargs, expected_out, expected_err in forwarded_cases:
        out, err = await run(fn, *args, **kwargs)
        assert expected_out == out
        assert expected_err == err

    # Registering the plugin is idempotent
    other_plugin = ForwardOutput()
    await c.register_worker_plugin(other_plugin, "forward")
    await asyncio.sleep(0.1)  # Let teardown/setup messages come in
    capsys.readouterr()

    out, err = await run(print_stdout, "foo")
    assert "foo\n" == out
    assert "" == err

    # After unregistering the plugin, we should once again not see any output
    await c.unregister_worker_plugin("forward")
    await asyncio.sleep(0.1)  # Let teardown messages come in
    capsys.readouterr()

    for fn, text in ((print_stdout, "foo"), (print_stderr, "fatal")):
        out, err = await run(fn, text)
        assert "" == out
        assert "" == err