Skip to content

Commit f5563ee

Browse files
fix(vis): avoid graphviz.Digraph memory bomb when ELK fails on large graphs
When ELK layout fails (OOM or timeout) on graphs with 1M+ nodes, the fallback path previously built a graphviz.Digraph in Python; its nested subgraph body-list copies caused memory use to explode and the process to hang indefinitely. Now render_elk_direct handles the failure internally: it reuses the already-collected Phase 1 data to generate DOT text without positions and renders it directly with sfdp, bypassing graphviz.Digraph entirely.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 69099e8 commit f5563ee

File tree

2 files changed

+83
-52
lines changed

2 files changed

+83
-52
lines changed

torchlens/visualization/elk_layout.py

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,30 +1113,45 @@ def _patch_sizes(elk_node):
11131113
_seed_stress_positions(elk_graph, all_edges)
11141114

11151115
elk_timeout = max(_ELK_TIMEOUT, int(num_elk_nodes * 0.015))
1116-
positioned = run_elk_layout(elk_graph, timeout=elk_timeout)
1116+
1117+
# If ELK layout fails (OOM, timeout, etc.), fall back to generating DOT
1118+
# without positions and rendering with sfdp. This avoids the catastrophic
1119+
# graphviz.Digraph construction path in rendering.py which explodes on
1120+
# very large graphs (1M+ nodes) due to nested subgraph body-list copies.
1121+
elk_failed = False
1122+
try:
1123+
positioned = run_elk_layout(elk_graph, timeout=elk_timeout)
1124+
except RuntimeError as e:
1125+
warnings.warn(
1126+
f"ELK layout failed ({e}), generating DOT without positions and rendering with sfdp."
1127+
)
1128+
elk_failed = True
11171129

11181130
# Collect leaf node centers and compound node bounding boxes from ELK output.
11191131
positions = {} # leaf_id -> (center_x, center_y) in ELK coords
11201132
compound_bboxes = {} # "group_<mod>" -> (x, y, w, h) in ELK coords (absolute)
1133+
max_y = 0
1134+
1135+
if not elk_failed:
1136+
1137+
def _collect_pos(elk_node, ox=0, oy=0):
1138+
for ch in elk_node.get("children", []):
1139+
ax = ox + ch.get("x", 0)
1140+
ay = oy + ch.get("y", 0)
1141+
if ch["id"].startswith("group_"):
1142+
w = ch.get("width", 0)
1143+
h = ch.get("height", 0)
1144+
compound_bboxes[ch["id"]] = (ax, ay, w, h)
1145+
_collect_pos(ch, ax, ay)
1146+
else:
1147+
w = ch.get("width", _DEFAULT_NODE_WIDTH)
1148+
h = ch.get("height", _DEFAULT_NODE_HEIGHT)
1149+
positions[ch["id"]] = (ax + w / 2, ay + h / 2)
11211150

1122-
def _collect_pos(elk_node, ox=0, oy=0):
1123-
for ch in elk_node.get("children", []):
1124-
ax = ox + ch.get("x", 0)
1125-
ay = oy + ch.get("y", 0)
1126-
if ch["id"].startswith("group_"):
1127-
w = ch.get("width", 0)
1128-
h = ch.get("height", 0)
1129-
compound_bboxes[ch["id"]] = (ax, ay, w, h)
1130-
_collect_pos(ch, ax, ay)
1131-
else:
1132-
w = ch.get("width", _DEFAULT_NODE_WIDTH)
1133-
h = ch.get("height", _DEFAULT_NODE_HEIGHT)
1134-
positions[ch["id"]] = (ax + w / 2, ay + h / 2)
1135-
1136-
_collect_pos(positioned)
1137-
# Use the root node's full height as the y-flip reference.
1138-
root_h = positioned.get("height", 0)
1139-
max_y = max(root_h, max((y for _, y in positions.values()), default=0))
1151+
_collect_pos(positioned)
1152+
# Use the root node's full height as the y-flip reference.
1153+
root_h = positioned.get("height", 0)
1154+
max_y = max(root_h, max((y for _, y in positions.values()), default=0))
11401155

11411156
# ── Phase 3: Generate DOT with clusters and positions ──
11421157

@@ -1265,7 +1280,9 @@ def _write_cluster(mod_key, depth, indent):
12651280
lines.append("}")
12661281
dot_source = "\n".join(lines)
12671282

1268-
# ── Phase 4: Render with neato -n ──
1283+
# ── Phase 4: Render ──
1284+
# When ELK succeeded, render with neato -n (pre-positioned layout).
1285+
# When ELK failed, render with sfdp (spring-embedder, computes own layout).
12691286

12701287
if num_elk_nodes > 25000 and vis_fileformat != "svg":
12711288
warnings.warn(
@@ -1279,24 +1296,40 @@ def _write_cluster(mod_key, depth, indent):
12791296
f.write(dot_source)
12801297

12811298
rendered_path = f"{vis_outpath}.{vis_fileformat}"
1282-
# Spline routing is O(n^2) — use straight lines for large graphs.
12831299
num_nodes = len(node_data)
1284-
spline_mode = "true" if num_nodes < 1000 else "line"
1285-
cmd = [
1286-
"neato",
1287-
"-n",
1288-
f"-Gsplines={spline_mode}",
1289-
f"-T{vis_fileformat}",
1290-
"-o",
1291-
rendered_path,
1292-
source_path,
1293-
]
1300+
1301+
if elk_failed:
1302+
# sfdp computes its own layout — no positions needed in DOT.
1303+
# Use overlap removal for readability.
1304+
cmd = [
1305+
"sfdp",
1306+
"-Goverlap=prism",
1307+
f"-T{vis_fileformat}",
1308+
"-o",
1309+
rendered_path,
1310+
source_path,
1311+
]
1312+
else:
1313+
# neato -n uses ELK-computed positions.
1314+
# Spline routing is O(n^2) — use straight lines for large graphs.
1315+
spline_mode = "true" if num_nodes < 1000 else "line"
1316+
cmd = [
1317+
"neato",
1318+
"-n",
1319+
f"-Gsplines={spline_mode}",
1320+
f"-T{vis_fileformat}",
1321+
"-o",
1322+
rendered_path,
1323+
source_path,
1324+
]
1325+
12941326
render_timeout = max(_SFDP_TIMEOUT, int(num_nodes * 0.01))
12951327
try:
12961328
result = subprocess.run(cmd, timeout=render_timeout, capture_output=True, text=True)
12971329
if result.returncode != 0:
1330+
engine_name = "sfdp" if elk_failed else "neato"
12981331
raise RuntimeError(
1299-
f"neato rendering failed (exit {result.returncode}):\n{result.stderr}"
1332+
f"{engine_name} rendering failed (exit {result.returncode}):\n{result.stderr}"
13001333
)
13011334
if not vis_save_only:
13021335
import graphviz

torchlens/visualization/rendering.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -204,28 +204,26 @@ def render_graph(
204204

205205
# ELK fast path: skip graphviz.Digraph construction entirely.
206206
# Generates DOT directly with ELK positions and cluster subgraphs (module boxes).
207+
# If ELK layout fails (OOM, timeout), render_elk_direct falls back internally
208+
# to sfdp — still using the fast DOT-text path, never graphviz.Digraph.
207209
if engine == "elk":
208-
from .elk_layout import render_elk_direct, render_with_sfdp
210+
from .elk_layout import render_elk_direct
209211

210-
try:
211-
result = render_elk_direct(
212-
self,
213-
entries_to_plot,
214-
vis_mode,
215-
vis_nesting_depth,
216-
show_buffer_layers,
217-
overrides,
218-
vis_outpath,
219-
vis_fileformat,
220-
vis_save_only,
221-
graph_caption,
222-
rankdir,
223-
)
224-
_vprint(self, f"Graph saved to {vis_outpath}.{vis_fileformat}")
225-
return result
226-
except RuntimeError as e:
227-
warnings.warn(f"ELK layout failed ({e}), falling back to sfdp.")
228-
engine = "sfdp" # fall through to build graphviz.Digraph for sfdp
212+
result = render_elk_direct(
213+
self,
214+
entries_to_plot,
215+
vis_mode,
216+
vis_nesting_depth,
217+
show_buffer_layers,
218+
overrides,
219+
vis_outpath,
220+
vis_fileformat,
221+
vis_save_only,
222+
graph_caption,
223+
rankdir,
224+
)
225+
_vprint(self, f"Graph saved to {vis_outpath}.{vis_fileformat}")
226+
return result
229227

230228
dot = graphviz.Digraph(
231229
name=self.model_name,

0 commit comments

Comments
 (0)