Skip to content
Merged
Prev Previous commit
Next Next commit
refactor(test): extract _capture_stdout context manager for stdout re…
…direction
  • Loading branch information
fzyzcjy committed Feb 19, 2026
commit ff367582ea0bc738699db24838f14440dba57d32
37 changes: 18 additions & 19 deletions test/registered/debug_utils/test_dumper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import io
import sys
import threading
import time
from contextlib import contextmanager
from pathlib import Path

import pytest
Expand All @@ -27,6 +29,17 @@
register_amd_ci(est_time=60, suite="nightly-amd", nightly=True)


@contextmanager
def _capture_stdout():
captured = io.StringIO()
old_stdout = sys.stdout
sys.stdout = captured
try:
yield captured
finally:
sys.stdout = old_stdout


class TestDumperPureFunctions:
def test_get_truncated_value(self):
assert get_truncated_value(None) is None
Expand Down Expand Up @@ -90,23 +103,17 @@ def test_silent_skip(self, tmp_path, capsys):

class TestCollectiveTimeout:
def test_watchdog_fires_on_timeout(self):
import io

block_event = threading.Event()

old_stdout = sys.stdout
captured = io.StringIO()
sys.stdout = captured
captured_output = [None]

def run_with_timeout():
try:
with _capture_stdout() as captured:
_collective_with_timeout(
lambda: block_event.wait(),
operation_name="test_blocked_op",
timeout_seconds=2,
)
finally:
sys.stdout = old_stdout
captured_output[0] = captured.getvalue()

worker = threading.Thread(target=run_with_timeout)
worker.start()
Expand All @@ -115,7 +122,7 @@ def run_with_timeout():
block_event.set()
worker.join(timeout=5)

output = captured.getvalue()
output = captured_output[0]
assert "WARNING" in output
assert "test_blocked_op" in output
assert "2s" in output
Expand Down Expand Up @@ -166,8 +173,6 @@ def test_collective_timeout(self):

@staticmethod
def _test_collective_timeout_func(rank):
import io

dumper = _Dumper(
enable=True,
base_dir=Path("/tmp"),
Expand All @@ -176,16 +181,10 @@ def _test_collective_timeout_func(rank):
collective_timeout=3,
)

captured = io.StringIO()
old_stdout = sys.stdout
sys.stdout = captured

try:
with _capture_stdout() as captured:
if rank != 0:
time.sleep(6)
dumper.on_forward_pass_start()
finally:
sys.stdout = old_stdout

output = captured.getvalue()

Expand Down
Loading