Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import abc
import contextlib
import functools
import logging
import os
import socket
Expand Down Expand Up @@ -601,3 +603,94 @@ async def setup(self, nanny):
sys.path.insert(0, path)

os.remove(fn)


class forward_stream:
def __init__(self, stream, worker):
self._worker = worker
self._original_methods = {}
self._stream = getattr(sys, stream)
if stream == "stdout":
self._file = 1
elif stream == "stderr":
self._file = 2
else:
raise ValueError(
f"Expected stream to be 'stdout' or 'stderr'; got '{stream}'"
)

self._file = 1 if stream == "stdout" else 2
self._buffer = []

def _write(self, write_fn, data):
self._forward(data)
write_fn(data)

def _forward(self, data):
self._buffer.append(data)
# Mimic line buffering
if "\n" in data or "\r" in data:
self._send()

def _send(self):
msg = {"args": self._buffer, "file": self._file, "sep": "", "end": ""}
self._worker.log_event("print", msg)
self._buffer = []

def _flush(self, flush_fn):
self._send()
flush_fn()

def _close(self, close_fn):
self._send()
close_fn()

def _intercept(self, method_name, interceptor):
original_method = getattr(self._stream, method_name)
self._original_methods[method_name] = original_method
setattr(
self._stream, method_name, functools.partial(interceptor, original_method)
)

def __enter__(self):
self._intercept("write", self._write)
self._intercept("flush", self._flush)
self._intercept("close", self._close)
return self._stream

def __exit__(self, exc_type, exc_value, traceback):
self._stream.flush()
for attr, original in self._original_methods.items():
setattr(self._stream, attr, original)
self._original_methods = {}


class ForwardOutput(WorkerPlugin):
"""A Worker Plugin that forwards ``stdout`` and ``stderr`` from workers to clients

This plugin forwards all output sent to ``stdout`` and ``stderr` on all workers
to all clients where it is written to the respective streams. Analogous to the
terminal, this plugin uses line buffering. To ensure that an output is written
without a newline, make sure to flush the stream.

.. warning::

Using this plugin will forward **all** output in ``stdout`` and ``stderr`` from
every worker to every client. If the output is very chatty, this will add
significant strain on the scheduler. Proceed with caution!

Examples
--------
>>> from dask.distributed import ForwardOutput
>>> plugin = ForwardOutput()

>>> client.register_worker_plugin(plugin)
"""

def setup(self, worker):
self._exit_stack = contextlib.ExitStack()
self._exit_stack.enter_context(forward_stream("stdout", worker=worker))
self._exit_stack.enter_context(forward_stream("stderr", worker=worker))

def teardown(self, worker):
self._exit_stack.close()
112 changes: 111 additions & 1 deletion distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import gc
import importlib
import itertools
import logging
import os
import sys
Expand Down Expand Up @@ -44,7 +45,12 @@
from distributed.compatibility import LINUX, WINDOWS, to_thread
from distributed.core import CommClosedError, Status, rpc
from distributed.diagnostics import nvml
from distributed.diagnostics.plugin import CondaInstall, PackageInstall, PipInstall
from distributed.diagnostics.plugin import (
CondaInstall,
ForwardOutput,
PackageInstall,
PipInstall,
)
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
Expand Down Expand Up @@ -3741,3 +3747,107 @@ async def test_worker_log_memory_limit_too_high(s):
for snippets in expected_snippets:
# assert any(snip in caplog.text for snip in snippets)
assert any(snip in caplog.getvalue().lower() for snip in snippets)


@gen_cluster(client=True, Worker=Nanny)
async def test_forward_output(c, s, a, b, capsys):
def print_stdout(*args, **kwargs):
print(*args, file=sys.stdout, **kwargs)

def print_stderr(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)

plugin = ForwardOutput()
out, err = capsys.readouterr()

counter = itertools.count()

# Without the plugin installed, we should not see any output
# Note that we use nannies so workers run in subprocesses
await c.submit(print_stdout, "foo", key=next(counter))
await c.submit(print_stdout, "bar\n", key=next(counter))
await c.submit(print_stdout, "baz", end="", flush=True, key=next(counter))
await c.submit(print_stdout, 1, 2, end="\n", sep="\n", key=next(counter))
out, err = capsys.readouterr()

assert "" == out
assert "" == err

await c.submit(print_stderr, "foo", key=next(counter))
await c.submit(print_stderr, "bar\n", key=next(counter))
await c.submit(print_stderr, "baz", flush=True, key=next(counter))
await c.submit(print_stderr, 1, 2, end="\n", sep="\n", key=next(counter))
out, err = capsys.readouterr()

assert "" == out
assert "" == err

# After installing, output should be forwarded
await c.register_worker_plugin(plugin, "forward")
await asyncio.sleep(0.1) # Let setup messages come in
capsys.readouterr()

await c.submit(print_stdout, "foo", key=next(counter))
out, err = capsys.readouterr()

assert "foo\n" == out
assert "" == err

await c.submit(print_stdout, "bar\n", key=next(counter))
out, err = capsys.readouterr()

assert "bar\n\n" == out
assert "" == err

await c.submit(print_stdout, "baz", end="", flush=True, key=next(counter))
out, err = capsys.readouterr()

assert "baz" == out
assert "" == err

await c.submit(print_stdout, "first\nsecond", end="", key=next(counter))
out, err = capsys.readouterr()

assert "first\nsecond" == out
assert "" == err

await c.submit(print_stdout, 1, 2, sep=":", key=next(counter))
out, err = capsys.readouterr()

assert "1:2\n" == out
assert err == ""

await c.submit(print_stderr, "fatal", key=next(counter))
out, err = capsys.readouterr()

assert "" == out
assert "fatal\n" == err

# Registering the plugin is idempotent
other_plugin = ForwardOutput()
await c.register_worker_plugin(other_plugin, "forward")
await asyncio.sleep(0.1) # Let teardown/setup messages come in
out, err = capsys.readouterr()

await c.submit(print_stdout, "foo", key=next(counter))
out, err = capsys.readouterr()

assert "foo\n" == out
assert "" == err

# After unregistering the plugin, we should once again not see any output
await c.unregister_worker_plugin("forward")
await asyncio.sleep(0.1) # Let teardown messages come in
capsys.readouterr()

await c.submit(print_stdout, "foo", key=next(counter))
out, err = capsys.readouterr()

assert "" == out
assert "" == err

await c.submit(print_stderr, "fatal", key=next(counter))
out, err = capsys.readouterr()

assert "" == out
assert "" == err