Skip to content

Commit 23ef8d8

Browse files
fix(vis): pass heap limits to ELK Worker thread to prevent OOM on 1M nodes
The Node.js Worker running ELK layout had no explicit maxOldGenerationSizeMb in its resourceLimits — only stackSizeMb was set. The --max-old-space-size flag controls the main thread's V8 isolate, not the Worker's. This caused the Worker to OOM at ~16GB on 1M-node graphs despite the main thread being configured for up to 64GB.

- Add maxOldGenerationSizeMb and maxYoungGenerationSizeMb to the Worker's resourceLimits, passed via the _TL_HEAP_MB env var.
- Add _available_memory_mb() to detect system RAM and cap heap allocation to (available - 4GB), preventing competition with the Python process.
- Include available system memory in OOM diagnostic messages.

Also includes field/param renames from the feat/grand-rename branch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 254337f commit 23ef8d8

File tree

1 file changed

+65
-33
lines changed

1 file changed

+65
-33
lines changed

torchlens/visualization/elk_layout.py

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,25 @@
3535
pass
3636

3737

38+
def _available_memory_mb() -> int:
    """Return available system memory in MB, or 0 if it cannot be determined.

    Two probes are tried in order, each one best-effort:

    1. The ``MemAvailable`` field of ``/proc/meminfo`` (value reported in kB).
    2. ``os.sysconf`` with ``SC_AVPHYS_PAGES`` and ``SC_PAGE_SIZE``.

    A failing probe never raises; it falls through to the next probe and
    finally to the ``0`` sentinel, which means "unknown".

    Returns:
        Available memory in whole megabytes, or 0 when no probe succeeds.
    """
    try:
        with open("/proc/meminfo") as f:
            for line in f:
                if line.startswith("MemAvailable:"):
                    return int(line.split()[1]) // 1024  # KB -> MB
    except (OSError, ValueError, IndexError):
        # File missing/unreadable or the line is malformed: try the next probe.
        pass
    try:
        pages = os.sysconf("SC_AVPHYS_PAGES")
        page_size = os.sysconf("SC_PAGE_SIZE")
        if pages > 0 and page_size > 0:
            return (pages * page_size) // (1024 * 1024)
    except (AttributeError, ValueError, OSError):
        # AttributeError: os.sysconf is not provided on this platform.
        # ValueError/OSError: the configuration name is unknown or unsupported.
        pass
    return 0
55+
56+
3857
_ELK_NODE_THRESHOLD = 3500
3958
_ELK_TIMEOUT = 120 # seconds for Node.js subprocess
4059
_SFDP_TIMEOUT = 120 # seconds for sfdp/neato subprocess
@@ -52,6 +71,7 @@
5271
// flag for preventing "Maximum call stack size exceeded" in deeply recursive
5372
// ELK layout on large graphs (100k+ nodes).
5473
const stackMb = parseInt(process.env._TL_STACK_MB || '64', 10);
74+
const heapMb = parseInt(process.env._TL_HEAP_MB || '16384', 10);
5575
5676
const workerCode = `
5777
const { parentPort, workerData } = require('worker_threads');
@@ -67,7 +87,11 @@
6787
const worker = new Worker(workerCode, {
6888
eval: true,
6989
workerData: input,
70-
resourceLimits: { stackSizeMb: stackMb },
90+
resourceLimits: {
91+
stackSizeMb: stackMb,
92+
maxOldGenerationSizeMb: heapMb,
93+
maxYoungGenerationSizeMb: Math.min(2048, Math.floor(heapMb / 8)),
94+
},
7195
});
7296
worker.on('message', (result) => {
7397
process.stdout.write(result);
@@ -212,7 +236,7 @@ def build_elk_graph_hierarchical(entries_to_plot, show_buffer_layers: bool = Fal
212236
from collections import defaultdict
213237

214238
# Step 1: Collect all nodes and their module paths.
215-
# module_path is containing_modules_origin_nested with pass info stripped.
239+
# module_path is containing_modules with pass info stripped.
216240
node_module_map = {} # node_label -> [module_addr, ...]
217241
node_labels = []
218242
edges = []
@@ -226,7 +250,7 @@ def build_elk_graph_hierarchical(entries_to_plot, show_buffer_layers: bool = Fal
226250

227251
# Strip pass numbers from module addresses for grouping.
228252
modules = []
229-
for mod in node.containing_modules_origin_nested:
253+
for mod in node.containing_modules:
230254
addr = mod.split(":")[0]
231255
if addr not in modules:
232256
modules.append(addr)
@@ -424,12 +448,19 @@ def run_elk_layout(elk_graph: dict, timeout: Optional[int] = None) -> dict:
424448
# Cap heap at 64GB — V8 only allocates what it actually needs, and
425449
# unbounded values (e.g. 5.6TB for 1M nodes) are nonsensical.
426450
heap_mb = min(65536, max(16384, graph_kb * 48))
451+
452+
# Further cap to available system memory (leave 4GB for Python + OS).
453+
avail_mb = _available_memory_mb()
454+
if avail_mb > 0:
455+
heap_mb = min(heap_mb, max(4096, avail_mb - 4096))
456+
427457
# Worker thread stack via resourceLimits.stackSizeMb (MB).
428458
# Floor of 4096 MB (matches CHANGELOG), cap at 8192 MB.
429459
stack_mb = min(8192, max(4096, graph_kb // 8))
430460

431461
env = _node_env()
432462
env["_TL_STACK_MB"] = str(stack_mb)
463+
env["_TL_HEAP_MB"] = str(heap_mb)
433464

434465
# Write JSON to a temp file so Node.js reads from disk instead of stdin.
435466
# This lets us free the graph_json string before the subprocess runs,
@@ -473,7 +504,8 @@ def run_elk_layout(elk_graph: dict, timeout: Optional[int] = None) -> dict:
473504
if result.stderr
474505
else (
475506
f"Node.js exited with code {result.returncode} (no stderr). "
476-
f"Likely OOM — JSON was {graph_kb} KB, heap was {heap_mb} MB. "
507+
f"Likely OOM — JSON was {graph_kb} KB, heap was {heap_mb} MB"
508+
f"{f', system had {avail_mb} MB available' if avail_mb > 0 else ''}. "
477509
f"Try reducing vis_nesting_depth to collapse modules."
478510
)
479511
)
@@ -549,7 +581,7 @@ def render_with_sfdp(
549581
source_path: str,
550582
vis_outpath: str,
551583
vis_fileformat: str,
552-
save_only: bool = False,
584+
vis_save_only: bool = False,
553585
timeout: Optional[int] = None,
554586
) -> None:
555587
"""Render a DOT source file using Graphviz sfdp engine.
@@ -558,7 +590,7 @@ def render_with_sfdp(
558590
source_path: Path to the DOT source file.
559591
vis_outpath: Output path (without extension).
560592
vis_fileformat: Output format (pdf, png, etc.).
561-
save_only: If True, don't open viewer.
593+
vis_save_only: If True, don't open viewer.
562594
timeout: Subprocess timeout in seconds.
563595
"""
564596
import graphviz
@@ -569,7 +601,7 @@ def render_with_sfdp(
569601
rendered_path = f"{vis_outpath}.{vis_fileformat}"
570602
cmd = ["sfdp", f"-T{vis_fileformat}", "-Goverlap=prism", "-o", rendered_path, source_path]
571603
subprocess.run(cmd, timeout=timeout, check=True, capture_output=True)
572-
if not save_only:
604+
if not vis_save_only:
573605
graphviz.backend.viewing.view(rendered_path)
574606

575607

@@ -578,7 +610,7 @@ def render_with_elk(
578610
source_path: str,
579611
vis_outpath: str,
580612
vis_fileformat: str,
581-
save_only: bool = False,
613+
vis_save_only: bool = False,
582614
entries_to_plot=None,
583615
show_buffer_layers: bool = False,
584616
) -> None:
@@ -589,7 +621,7 @@ def render_with_elk(
589621
source_path: Path where DOT source was saved.
590622
vis_outpath: Output path (without extension).
591623
vis_fileformat: Output format (pdf, png, etc.).
592-
save_only: If True, don't open viewer.
624+
vis_save_only: If True, don't open viewer.
593625
entries_to_plot: If provided, build hierarchical ELK graph with
594626
module grouping. Otherwise falls back to flat DOT parsing.
595627
show_buffer_layers: Whether to include buffer layers.
@@ -620,7 +652,7 @@ def render_with_elk(
620652
]
621653
try:
622654
subprocess.run(cmd, timeout=_SFDP_TIMEOUT, check=True, capture_output=True)
623-
if not save_only:
655+
if not vis_save_only:
624656
graphviz.backend.viewing.view(rendered_path)
625657
finally:
626658
import os
@@ -763,13 +795,13 @@ def _dot_id(name: str) -> str:
763795
def render_elk_direct(
764796
model_log,
765797
entries_to_plot: dict,
766-
vis_opt: str,
798+
vis_mode: str,
767799
vis_nesting_depth: int,
768800
show_buffer_layers: bool,
769801
overrides,
770802
vis_outpath: str,
771803
vis_fileformat: str,
772-
save_only: bool,
804+
vis_save_only: bool,
773805
graph_caption: str,
774806
rankdir: str,
775807
) -> str:
@@ -788,13 +820,13 @@ def render_elk_direct(
788820
Args:
789821
model_log: The ModelLog instance.
790822
entries_to_plot: Dict of node_barcode -> LayerPassLog/LayerLog.
791-
vis_opt: ``'unrolled'`` or ``'rolled'``.
823+
vis_mode: ``'unrolled'`` or ``'rolled'``.
792824
vis_nesting_depth: Module nesting depth for collapsed modules.
793825
show_buffer_layers: Whether to include buffer layers.
794826
overrides: VisualizationOverrides instance.
795827
vis_outpath: Output file path (without extension).
796828
vis_fileformat: Output format (pdf, png, svg, etc.).
797-
save_only: If True, don't open viewer.
829+
vis_save_only: If True, don't open viewer.
798830
graph_caption: HTML label for the graph title.
799831
rankdir: Graphviz rank direction (BT, TB, LR).
800832
@@ -841,10 +873,10 @@ def render_elk_direct(
841873
def _module_keys_for_node(node, is_collapsed_mod):
842874
"""Get module hierarchy keys for a node."""
843875
if is_collapsed_mod:
844-
mods = list(node.containing_modules_origin_nested[: vis_nesting_depth - 1])
876+
mods = list(node.containing_modules[: vis_nesting_depth - 1])
845877
else:
846-
mods = list(node.containing_modules_origin_nested)
847-
if vis_opt == "rolled":
878+
mods = list(node.containing_modules)
879+
if vis_mode == "rolled":
848880
return list(dict.fromkeys(m.split(":")[0] for m in mods))
849881
return mods
850882

@@ -868,18 +900,18 @@ def _assign_to_hierarchy(node_name, mod_keys, has_ancestor):
868900
is_collapsed = _is_collapsed_module(node, vis_nesting_depth)
869901

870902
if is_collapsed:
871-
mod_w_pass = node.containing_modules_origin_nested[vis_nesting_depth - 1]
903+
mod_w_pass = node.containing_modules[vis_nesting_depth - 1]
872904
mod_parts = mod_w_pass.rsplit(":", 1)
873905
mod_addr, pass_num = mod_parts
874-
node_name = "pass".join(mod_parts) if vis_opt == "unrolled" else mod_addr
906+
node_name = "pass".join(mod_parts) if vis_mode == "unrolled" else mod_addr
875907
elk_id = node.layer_label
876908

877909
if node_name not in collapsed_set:
878910
collapsed_set.add(node_name)
879911
ml = model_log.modules[mod_addr]
880912
mod_out = model_log[mod_w_pass]
881913

882-
if vis_opt == "unrolled":
914+
if vis_mode == "unrolled":
883915
mpl = model_log.modules[mod_w_pass]
884916
n_tensors = mpl.num_layers
885917
has_anc = any(model_log[la].has_input_ancestor for la in mpl.layers)
@@ -890,7 +922,7 @@ def _assign_to_hierarchy(node_name, mod_keys, has_ancestor):
890922
np_ = ml.num_passes
891923
if np_ == 1:
892924
title = f"<b>@{mod_addr}</b>"
893-
elif vis_opt == "unrolled":
925+
elif vis_mode == "unrolled":
894926
title = f"<b>@{mod_addr}:{pass_num}</b>"
895927
else:
896928
title = f"<b>@{mod_addr} (x{np_})</b>"
@@ -929,7 +961,7 @@ def _assign_to_hierarchy(node_name, mod_keys, has_ancestor):
929961
ls = "solid" if has_anc else "dashed"
930962
lbl = (
931963
f"<{title}<br/>{ml.module_class_name}<br/>"
932-
f"{ss} ({mod_out.tensor_fsize_nice})<br/>"
964+
f"{ss} ({mod_out.tensor_memory_str})<br/>"
933965
f"{n_tensors} layers total<br/>{pd}>"
934966
)
935967
attrs = {
@@ -962,7 +994,7 @@ def _assign_to_hierarchy(node_name, mod_keys, has_ancestor):
962994
)
963995
bg = _get_node_bg_color(model_log, node)
964996
ls = "solid" if node.has_input_ancestor else "dashed"
965-
lbl = _make_node_label(node, addr, vis_opt)
997+
lbl = _make_node_label(node, addr, vis_mode)
966998

967999
attrs = {
9681000
"label": lbl,
@@ -985,7 +1017,7 @@ def _assign_to_hierarchy(node_name, mod_keys, has_ancestor):
9851017

9861018
# ── Collect edges (this node → its children) ──
9871019
for child_label in node.child_layers:
988-
if vis_opt == "unrolled":
1020+
if vis_mode == "unrolled":
9891021
child_node = model_log.layer_dict_main_keys.get(child_label)
9901022
else:
9911023
child_node = model_log.layer_logs.get(child_label)
@@ -1003,19 +1035,19 @@ def _assign_to_hierarchy(node_name, mod_keys, has_ancestor):
10031035
# Resolve head name
10041036
child_is_collapsed = _is_collapsed_module(child_node, vis_nesting_depth)
10051037
if child_is_collapsed:
1006-
c_mod_w_pass = child_node.containing_modules_origin_nested[vis_nesting_depth - 1]
1038+
c_mod_w_pass = child_node.containing_modules[vis_nesting_depth - 1]
10071039
c_parts = c_mod_w_pass.rsplit(":", 1)
1008-
head_name = "pass".join(c_parts) if vis_opt == "unrolled" else c_parts[0]
1040+
head_name = "pass".join(c_parts) if vis_mode == "unrolled" else c_parts[0]
10091041
else:
10101042
head_name = child_node.layer_label.replace(":", "pass")
10111043

10121044
# Intra-module skip for two collapsed nodes in the same module
10131045
if is_collapsed and child_is_collapsed and tail_name != head_name:
1014-
p_mods = node.containing_modules_origin_nested[:]
1015-
c_mods = child_node.containing_modules_origin_nested[:]
1016-
if node.is_bottom_level_submodule_output:
1046+
p_mods = node.containing_modules[:]
1047+
c_mods = child_node.containing_modules[:]
1048+
if node.is_leaf_module_output:
10171049
p_mods = p_mods[:-1]
1018-
if child_node.is_bottom_level_submodule_output:
1050+
if child_node.is_leaf_module_output:
10191051
c_mods = c_mods[:-1]
10201052
if p_mods[:vis_nesting_depth] == c_mods[:vis_nesting_depth]:
10211053
continue
@@ -1165,9 +1197,9 @@ def _write_cluster(mod_key, depth, indent):
11651197
mod_type = ml.module_class_name if ml else "Module"
11661198
np_ = ml.num_passes if ml else 1
11671199

1168-
if vis_opt == "unrolled" and np_ > 1 and ":" in mod_key:
1200+
if vis_mode == "unrolled" and np_ > 1 and ":" in mod_key:
11691201
title = mod_key
1170-
elif vis_opt == "rolled" and np_ > 1:
1202+
elif vis_mode == "rolled" and np_ > 1:
11711203
title = f"{mod_addr} (x{np_})"
11721204
else:
11731205
title = mod_addr
@@ -1266,7 +1298,7 @@ def _write_cluster(mod_key, depth, indent):
12661298
raise RuntimeError(
12671299
f"neato rendering failed (exit {result.returncode}):\n{result.stderr}"
12681300
)
1269-
if not save_only:
1301+
if not vis_save_only:
12701302
import graphviz
12711303

12721304
graphviz.backend.viewing.view(rendered_path)

0 commit comments

Comments
 (0)