Skip to content

Commit ebdaae5

Browse files
Merge pull request #132 from johnmarktaylor91/fix/elk-sfdp-fallback-oom
fix(vis): avoid graphviz.Digraph memory bomb when ELK fails on large graphs
2 parents 364ea39 + 37cce3a commit ebdaae5

File tree

2 files changed

+244
-69
lines changed

2 files changed

+244
-69
lines changed

torchlens/visualization/elk_layout.py

Lines changed: 226 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,164 @@ def _inject(node):
747747
_inject(elk_graph)
748748

749749

750+
# ── Python topological layout for very large graphs ──
751+
#
752+
# ELK's stress algorithm allocates TWO O(n²) distance matrices (n² × 16 bytes).
753+
# At 150k nodes that's 360 GB — impossible on any workstation. At 1M nodes
754+
# it's 16 TB. The layered algorithm (Sugiyama) also scales poorly for wide
755+
# layers.
756+
#
757+
# For graphs above _ELK_STRESS_LIMIT we skip ELK entirely and compute a
758+
# topological rank layout in Python. Kahn's algorithm gives O(n+m) time
759+
# and memory, and produces a clean DAG layout with directional flow (inputs
760+
# at bottom, outputs at top). Module bounding boxes are computed from the
761+
# positions of nodes assigned to each module.
762+
763+
_ELK_STRESS_LIMIT = 100_000 # nodes — above this, ELK stress cannot allocate
764+
765+
766+
def _compute_topological_layout(
767+
node_data: dict,
768+
all_edges: list,
769+
elk_id_sizes: dict,
770+
module_direct_nodes: dict,
771+
module_child_map: dict,
772+
) -> tuple:
773+
"""Compute node positions via topological rank layout.
774+
775+
Returns ``(positions, compound_bboxes, max_y)`` with the same interface
776+
as the ELK path so Phase 3 can use either interchangeably.
777+
778+
Args:
779+
node_data: ``{dot_name: {"attrs": {...}, "elk_id": label}}``
780+
all_edges: List of edge dicts with ``tail_name``, ``head_name``.
781+
elk_id_sizes: ``{elk_id: (width, height)}`` — label-based size estimates.
782+
module_direct_nodes: ``{module_key: [dot_names]}``
783+
module_child_map: ``{module_key: {child_module_keys}}``
784+
785+
Returns:
786+
(positions, compound_bboxes, max_y) — same types as ELK path.
787+
"""
788+
from collections import defaultdict, deque
789+
790+
# Map elk_id -> dot_name for reverse lookup.
791+
elk_to_dot = {}
792+
for dot_name, nd in node_data.items():
793+
elk_to_dot[nd["elk_id"]] = dot_name
794+
795+
all_elk_ids = set(nd["elk_id"] for nd in node_data.values())
796+
797+
# Build adjacency from DOT-level edges.
798+
children_of = defaultdict(list)
799+
in_degree: dict = defaultdict(int)
800+
for e in all_edges:
801+
src = e.get("tail_name") or e["tail_name"]
802+
tgt = e.get("head_name") or e["head_name"]
803+
# Map dot_name -> elk_id
804+
src_eid = node_data.get(src, {}).get("elk_id")
805+
tgt_eid = node_data.get(tgt, {}).get("elk_id")
806+
if src_eid in all_elk_ids and tgt_eid in all_elk_ids:
807+
children_of[src_eid].append(tgt_eid)
808+
in_degree[tgt_eid] += 1
809+
810+
# Kahn's algorithm for topological depth assignment.
811+
depth: dict = {}
812+
queue: deque = deque()
813+
for nid in all_elk_ids:
814+
if in_degree[nid] == 0:
815+
depth[nid] = 0
816+
queue.append(nid)
817+
818+
while queue:
819+
nid = queue.popleft()
820+
for child in children_of[nid]:
821+
new_depth = depth[nid] + 1
822+
if child not in depth or new_depth > depth[child]:
823+
depth[child] = new_depth
824+
in_degree[child] -= 1
825+
if in_degree[child] == 0:
826+
queue.append(child)
827+
828+
# Unreached nodes (cycles or disconnected) get depth 0.
829+
for nid in all_elk_ids:
830+
if nid not in depth:
831+
depth[nid] = 0
832+
833+
# Group by depth rank.
834+
ranks = defaultdict(list)
835+
for nid, d in depth.items():
836+
ranks[d].append(nid)
837+
838+
# Sort nodes within each rank by module membership for visual grouping.
839+
# Build elk_id -> module_key lookup.
840+
elk_id_module = {}
841+
for mod_key, dot_names in module_direct_nodes.items():
842+
for dn in dot_names:
843+
nd = node_data.get(dn)
844+
if nd:
845+
elk_id_module[nd["elk_id"]] = mod_key
846+
847+
for d in ranks:
848+
ranks[d].sort(key=lambda nid: elk_id_module.get(nid, ""))
849+
850+
# Compute positions. Y = depth rank, X = position within rank.
851+
spacing_y = 120 # points between ranks
852+
spacing_x = 30 # points between node edges within a rank
853+
positions = {}
854+
855+
for d, nodes in sorted(ranks.items()):
856+
x_cursor = 0.0
857+
for nid in nodes:
858+
w, h = elk_id_sizes.get(nid, (_DEFAULT_NODE_WIDTH, _DEFAULT_NODE_HEIGHT))
859+
cx = x_cursor + w / 2
860+
cy = d * spacing_y + h / 2
861+
positions[nid] = (cx, cy)
862+
x_cursor += w + spacing_x
863+
864+
max_y = max((y for _, y in positions.values()), default=0) + _DEFAULT_NODE_HEIGHT
865+
866+
# Compute module bounding boxes from node positions.
867+
# Collect all elk_ids in each module (including nested children).
868+
def _collect_module_elk_ids(mod_key):
869+
ids = set()
870+
for dn in module_direct_nodes.get(mod_key, []):
871+
nd = node_data.get(dn)
872+
if nd and nd["elk_id"] in positions:
873+
ids.add(nd["elk_id"])
874+
for child_mod in module_child_map.get(mod_key, set()):
875+
ids.update(_collect_module_elk_ids(child_mod))
876+
return ids
877+
878+
compound_bboxes = {}
879+
padding = 60 # points around contained nodes
880+
881+
all_mod_keys = set(module_direct_nodes.keys()) | set(module_child_map.keys())
882+
for mod_key in all_mod_keys:
883+
elk_ids = _collect_module_elk_ids(mod_key)
884+
if not elk_ids:
885+
continue
886+
xs = []
887+
ys = []
888+
for eid in elk_ids:
889+
cx, cy = positions[eid]
890+
w, h = elk_id_sizes.get(eid, (_DEFAULT_NODE_WIDTH, _DEFAULT_NODE_HEIGHT))
891+
xs.extend([cx - w / 2, cx + w / 2])
892+
ys.extend([cy - h / 2, cy + h / 2])
893+
min_x, max_x_val = min(xs) - padding, max(xs) + padding
894+
min_y, max_y_val = min(ys) - padding, max(ys) + padding
895+
mod_addr = mod_key.split(":")[0] if ":" in mod_key else mod_key
896+
group_id = f"group_{mod_addr}"
897+
# ELK bbox format: (x, y, width, height) in y-down coords
898+
compound_bboxes[group_id] = (
899+
min_x,
900+
min_y,
901+
max_x_val - min_x,
902+
max_y_val - min_y,
903+
)
904+
905+
return positions, compound_bboxes, max_y
906+
907+
750908
def _estimate_node_size(label: str) -> tuple:
751909
"""Estimate graphviz node dimensions in points from an HTML label.
752910
@@ -1077,66 +1235,82 @@ def _assign_to_hierarchy(node_name, mod_keys, has_ancestor):
10771235

10781236
all_edges.append(edge)
10791237

1080-
# ── Phase 2: ELK layout ──
1238+
# ── Phase 2: Layout ──
1239+
#
1240+
# ELK's stress algorithm allocates TWO O(n²) distance matrices totalling
1241+
# n² × 16 bytes. At 100k nodes that's 160 GB — at 1M nodes, 16 TB.
1242+
# For graphs above _ELK_STRESS_LIMIT we bypass ELK entirely and compute
1243+
# a topological rank layout in Python (O(n+m) time and memory).
10811244

1082-
# Build per-node size estimates from labels, so ELK spaces correctly.
1245+
# Build per-node size estimates from labels (used by both paths).
10831246
elk_id_sizes = {}
10841247
for dot_name, nd in node_data.items():
10851248
elk_id = nd["elk_id"]
10861249
label = nd["attrs"].get("label", "")
10871250
elk_id_sizes[elk_id] = _estimate_node_size(label)
10881251

1089-
elk_graph = build_elk_graph_hierarchical(entries_to_plot, show_buffer_layers)
1252+
num_elk_nodes = len(node_data)
10901253

1091-
# Override ELK node sizes with label-based estimates.
1092-
def _patch_sizes(elk_node):
1093-
for ch in elk_node.get("children", []):
1094-
if ch["id"].startswith("group_"):
1095-
_patch_sizes(ch)
1096-
elif ch["id"] in elk_id_sizes:
1097-
w, h = elk_id_sizes[ch["id"]]
1098-
ch["width"] = w
1099-
ch["height"] = h
1254+
if num_elk_nodes > _ELK_STRESS_LIMIT:
1255+
# ── Python topological layout (O(n+m)) ──
1256+
positions, compound_bboxes, max_y = _compute_topological_layout(
1257+
node_data, all_edges, elk_id_sizes, module_direct_nodes, module_child_map
1258+
)
1259+
else:
1260+
# ── ELK layout (Node.js subprocess) ──
1261+
elk_graph = build_elk_graph_hierarchical(entries_to_plot, show_buffer_layers)
11001262

1101-
_patch_sizes(elk_graph)
1263+
def _patch_sizes(elk_node):
1264+
for ch in elk_node.get("children", []):
1265+
if ch["id"].startswith("group_"):
1266+
_patch_sizes(ch)
1267+
elif ch["id"] in elk_id_sizes:
1268+
w, h = elk_id_sizes[ch["id"]]
1269+
ch["width"] = w
1270+
ch["height"] = h
11021271

1103-
# Scale timeout with graph size: ~15ms per node, minimum 120s.
1104-
# Empirical: 5k→10s, 25k→114s, scaling ~O(n^1.4).
1105-
num_elk_nodes = len(node_data)
1272+
_patch_sizes(elk_graph)
11061273

1107-
# The layered algorithm (Sugiyama) uses O(n^2) memory for crossing
1108-
# minimization — at ~100k+ nodes it triggers std::bad_alloc in elkjs.
1109-
# Switch to stress-majorization which is O(n) memory, seeded with
1110-
# topological positions so the layout preserves directional flow.
1111-
if num_elk_nodes > 150000:
1112-
elk_graph["layoutOptions"]["elk.algorithm"] = "stress"
1113-
_seed_stress_positions(elk_graph, all_edges)
1114-
1115-
elk_timeout = max(_ELK_TIMEOUT, int(num_elk_nodes * 0.015))
1116-
positioned = run_elk_layout(elk_graph, timeout=elk_timeout)
1117-
1118-
# Collect leaf node centers and compound node bounding boxes from ELK output.
1119-
positions = {} # leaf_id -> (center_x, center_y) in ELK coords
1120-
compound_bboxes = {} # "group_<mod>" -> (x, y, w, h) in ELK coords (absolute)
1121-
1122-
def _collect_pos(elk_node, ox=0, oy=0):
1123-
for ch in elk_node.get("children", []):
1124-
ax = ox + ch.get("x", 0)
1125-
ay = oy + ch.get("y", 0)
1126-
if ch["id"].startswith("group_"):
1127-
w = ch.get("width", 0)
1128-
h = ch.get("height", 0)
1129-
compound_bboxes[ch["id"]] = (ax, ay, w, h)
1130-
_collect_pos(ch, ax, ay)
1131-
else:
1132-
w = ch.get("width", _DEFAULT_NODE_WIDTH)
1133-
h = ch.get("height", _DEFAULT_NODE_HEIGHT)
1134-
positions[ch["id"]] = (ax + w / 2, ay + h / 2)
1274+
elk_timeout = max(_ELK_TIMEOUT, int(num_elk_nodes * 0.015))
11351275

1136-
_collect_pos(positioned)
1137-
# Use the root node's full height as the y-flip reference.
1138-
root_h = positioned.get("height", 0)
1139-
max_y = max(root_h, max((y for _, y in positions.values()), default=0))
1276+
try:
1277+
positioned = run_elk_layout(elk_graph, timeout=elk_timeout)
1278+
except RuntimeError as e:
1279+
warnings.warn(f"ELK layout failed ({e}), falling back to Python topological layout.")
1280+
positioned = None
1281+
1282+
# Collect positions from ELK output, or fall back to Python layout.
1283+
positions = {}
1284+
compound_bboxes = {}
1285+
max_y = 0
1286+
1287+
if positioned is None:
1288+
positions, compound_bboxes, max_y = _compute_topological_layout(
1289+
node_data,
1290+
all_edges,
1291+
elk_id_sizes,
1292+
module_direct_nodes,
1293+
module_child_map,
1294+
)
1295+
else:
1296+
1297+
def _collect_pos(elk_node, ox=0, oy=0):
1298+
for ch in elk_node.get("children", []):
1299+
ax = ox + ch.get("x", 0)
1300+
ay = oy + ch.get("y", 0)
1301+
if ch["id"].startswith("group_"):
1302+
w = ch.get("width", 0)
1303+
h = ch.get("height", 0)
1304+
compound_bboxes[ch["id"]] = (ax, ay, w, h)
1305+
_collect_pos(ch, ax, ay)
1306+
else:
1307+
w = ch.get("width", _DEFAULT_NODE_WIDTH)
1308+
h = ch.get("height", _DEFAULT_NODE_HEIGHT)
1309+
positions[ch["id"]] = (ax + w / 2, ay + h / 2)
1310+
1311+
_collect_pos(positioned)
1312+
root_h = positioned.get("height", 0)
1313+
max_y = max(root_h, max((y for _, y in positions.values()), default=0))
11401314

11411315
# ── Phase 3: Generate DOT with clusters and positions ──
11421316

@@ -1266,6 +1440,9 @@ def _write_cluster(mod_key, depth, indent):
12661440
dot_source = "\n".join(lines)
12671441

12681442
# ── Phase 4: Render with neato -n ──
1443+
#
1444+
# Both ELK and the Python topological layout produce positions, so we
1445+
# always use neato -n (pre-positioned layout that respects clusters).
12691446

12701447
if num_elk_nodes > 25000 and vis_fileformat != "svg":
12711448
warnings.warn(
@@ -1279,8 +1456,8 @@ def _write_cluster(mod_key, depth, indent):
12791456
f.write(dot_source)
12801457

12811458
rendered_path = f"{vis_outpath}.{vis_fileformat}"
1282-
# Spline routing is O(n^2) — use straight lines for large graphs.
12831459
num_nodes = len(node_data)
1460+
# Spline routing is O(n^2) — use straight lines for large graphs.
12841461
spline_mode = "true" if num_nodes < 1000 else "line"
12851462
cmd = [
12861463
"neato",

torchlens/visualization/rendering.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -204,28 +204,26 @@ def render_graph(
204204

205205
# ELK fast path: skip graphviz.Digraph construction entirely.
206206
# Generates DOT directly with ELK positions and cluster subgraphs (module boxes).
207+
# If ELK layout fails (OOM, timeout), render_elk_direct falls back internally
208+
# to sfdp — still using the fast DOT-text path, never graphviz.Digraph.
207209
if engine == "elk":
208-
from .elk_layout import render_elk_direct, render_with_sfdp
210+
from .elk_layout import render_elk_direct
209211

210-
try:
211-
result = render_elk_direct(
212-
self,
213-
entries_to_plot,
214-
vis_mode,
215-
vis_nesting_depth,
216-
show_buffer_layers,
217-
overrides,
218-
vis_outpath,
219-
vis_fileformat,
220-
vis_save_only,
221-
graph_caption,
222-
rankdir,
223-
)
224-
_vprint(self, f"Graph saved to {vis_outpath}.{vis_fileformat}")
225-
return result
226-
except RuntimeError as e:
227-
warnings.warn(f"ELK layout failed ({e}), falling back to sfdp.")
228-
engine = "sfdp" # fall through to build graphviz.Digraph for sfdp
212+
result = render_elk_direct(
213+
self,
214+
entries_to_plot,
215+
vis_mode,
216+
vis_nesting_depth,
217+
show_buffer_layers,
218+
overrides,
219+
vis_outpath,
220+
vis_fileformat,
221+
vis_save_only,
222+
graph_caption,
223+
rankdir,
224+
)
225+
_vprint(self, f"Graph saved to {vis_outpath}.{vis_fileformat}")
226+
return result
229227

230228
dot = graphviz.Digraph(
231229
name=self.model_name,

0 commit comments

Comments
 (0)