Skip to content

Commit c2285b8

Browse files
committed
Change dump output format to dict with value and metadata
Save each dump as {"value": tensor, "meta": {name, rank, ...}} instead of just the raw tensor. This co-locates metadata with the data and lets downstream tools (dump_comparator, dump_loader) access context without parsing filenames. Update dump_comparator._load_object and DumpLoader.load to extract the value from the new dict format while staying backward-compatible with raw tensor files.
1 parent f96242f commit c2285b8

File tree

4 files changed

+70
-3
lines changed

4 files changed

+70
-3
lines changed

python/sglang/srt/debug_utils/dump_comparator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ def _load_object(path):
280280
print(f"Skip load {path} since error {e}")
281281
return None
282282

283+
if isinstance(x, dict) and "value" in x:
284+
x = x["value"]
285+
283286
if not isinstance(x, torch.Tensor):
284287
print(f"Skip load {path} since {type(x)=} is not a Tensor ({x=})")
285288
return None

python/sglang/srt/debug_utils/dump_loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def load(self, name, **kwargs):
3434

3535
path = self._directory / row["filename"]
3636
output = torch.load(path, weights_only=False)
37+
if isinstance(output, dict) and "value" in output:
38+
output = output["value"]
3739

3840
print(
3941
f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"

python/sglang/srt/debug_utils/dumper.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,11 @@ def dump(self, name, value, save: bool = True, **kwargs):
164164

165165
if self._enable_write_file and save:
166166
path.parent.mkdir(parents=True, exist_ok=True)
167-
_torch_save(value, str(path))
167+
output_data = {
168+
"value": value,
169+
"meta": dict(**full_kwargs),
170+
}
171+
_torch_save(output_data, str(path))
168172

169173

170174
def _torch_save(value, path: str):

test/registered/debug_utils/test_dumper.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch.distributed as dist
99

1010
from sglang.srt.debug_utils.dumper import (
11+
_Dumper,
1112
_obj_to_dict,
1213
_torch_save,
1314
get_tensor_info,
@@ -161,8 +162,12 @@ def _test_file_content_func(rank, tmpdir):
161162

162163
dist.barrier()
163164
path = _find_dump_file(tmpdir, rank=rank, name="content_check")
164-
loaded = torch.load(path, map_location="cpu", weights_only=True)
165-
assert torch.equal(loaded, tensor.cpu())
165+
raw = _load_dump(path)
166+
assert isinstance(raw, dict), f"Expected dict, got {type(raw)}"
167+
assert "value" in raw and "meta" in raw
168+
assert torch.equal(raw["value"], tensor.cpu())
169+
assert raw["meta"]["name"] == "content_check"
170+
assert raw["meta"]["rank"] == rank
166171

167172

168173
class TestDumperFileWriteControl:
@@ -230,6 +235,54 @@ def _test_save_false_func(rank, tmpdir):
230235
assert len(_get_filenames(tmpdir)) == 0
231236

232237

238+
class TestDumpDictFormat:
239+
"""Verify that dump files use the dict output format: {"value": ..., "meta": {...}}."""
240+
241+
def test_dict_format_structure(self, tmp_path):
242+
dumper = _make_test_dumper(tmp_path)
243+
tensor = torch.randn(4, 4)
244+
dumper.dump("fmt_test", tensor, custom_key="hello")
245+
246+
path = _find_dump_file(str(tmp_path), rank=0, name="fmt_test")
247+
raw = _load_dump(path)
248+
249+
assert isinstance(raw, dict)
250+
assert set(raw.keys()) == {"value", "meta"}
251+
assert torch.equal(raw["value"], tensor)
252+
253+
meta = raw["meta"]
254+
assert meta["name"] == "fmt_test"
255+
assert meta["custom_key"] == "hello"
256+
assert "forward_pass_id" in meta
257+
assert "rank" in meta
258+
assert "dump_index" in meta
259+
260+
def test_dict_format_with_context(self, tmp_path):
261+
dumper = _make_test_dumper(tmp_path)
262+
dumper.set_ctx(ctx_val=42)
263+
tensor = torch.randn(2, 2)
264+
dumper.dump("ctx_fmt", tensor)
265+
266+
path = _find_dump_file(str(tmp_path), rank=0, name="ctx_fmt")
267+
raw = _load_dump(path)
268+
269+
assert raw["meta"]["ctx_val"] == 42
270+
assert torch.equal(raw["value"], tensor)
271+
272+
273+
def _make_test_dumper(tmp_path: Path, **overrides) -> _Dumper:
274+
"""Create a _Dumper for CPU testing without HTTP server or distributed."""
275+
defaults: dict = dict(
276+
enable=True,
277+
base_dir=tmp_path,
278+
partial_name="test",
279+
enable_http_server=False,
280+
)
281+
d = _Dumper(**{**defaults, **overrides})
282+
d.on_forward_pass_start()
283+
return d
284+
285+
233286
def _get_filenames(tmpdir):
234287
return {f.name for f in Path(tmpdir).glob("sglang_dump_*/*.pt")}
235288

@@ -243,6 +296,11 @@ def _assert_files(filenames, *, exist=(), not_exist=()):
243296
), f"{p} should not exist in {filenames}"
244297

245298

299+
def _load_dump(path: Path) -> dict:
300+
"""Load a dump file and return the raw dict (with 'value' and 'meta' keys)."""
301+
return torch.load(path, map_location="cpu", weights_only=False)
302+
303+
246304
def _find_dump_file(tmpdir, *, rank: int = 0, name: str) -> Path:
247305
matches = [
248306
f

0 commit comments

Comments
 (0)