Skip to content

Commit 5b0baa8

Browse files
feat(data): add ModuleLog, ModulePassLog, and ModuleAccessor
Introduce structured per-module metadata classes following the ParamLog/ParamAccessor pattern. log.modules["features.3"] now returns a rich ModuleLog with class, params, layers, source info, hierarchy, hooks, forward signature, and nesting depth. Multi-pass modules support per-call access via passes dict and pass notation (e.g. "fc1:2"). - ModulePassLog: per-(module, pass) lightweight container - ModuleLog: per-module-object user-facing class with delegating properties for single-pass modules - ModuleAccessor: dict-like accessor with summary()/to_pandas() - Metadata captured in prepare_model() before cleanup strips tl_* attrs - Forward args/kwargs captured per pass in module_forward_decorator() - _build_module_logs() in postprocess Step 17 assembles everything - Old module_* dicts kept alive for vis.py backward compat - 44 new tests, 315 total passing Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0132773 commit 5b0baa8

File tree

9 files changed

+1191
-2
lines changed

9 files changed

+1191
-2
lines changed

tests/test_module_log.py

Lines changed: 392 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,392 @@
1+
"""Tests for ModuleLog, ModulePassLog, and ModuleAccessor."""
2+
3+
import pytest
4+
import torch
5+
import torch.nn as nn
6+
7+
import example_models
8+
from torchlens import ModuleLog, ModulePassLog, log_forward_pass
9+
from torchlens.data_classes import ModuleAccessor, ParamAccessor
10+
11+
12+
# ---------------------------------------------------------------------------
13+
# Helpers
14+
# ---------------------------------------------------------------------------
15+
16+
17+
def _make_simple_model():
18+
return nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
19+
20+
21+
def _simple_input():
22+
return torch.randn(1, 10)
23+
24+
25+
def _make_nested_model():
26+
"""Model with nested submodules for hierarchy tests."""
27+
return nn.Sequential(
28+
nn.Sequential(nn.Linear(10, 8), nn.ReLU()),
29+
nn.Sequential(nn.Linear(8, 4), nn.Sigmoid()),
30+
nn.Linear(4, 2),
31+
)
32+
33+
34+
def _nested_input():
35+
return torch.randn(1, 10)
36+
37+
38+
# ---------------------------------------------------------------------------
39+
# TestModuleLogBasic
40+
# ---------------------------------------------------------------------------
41+
42+
43+
class TestModuleLogBasic:
44+
def test_modules_accessor_exists(self):
45+
log = log_forward_pass(_make_simple_model(), _simple_input())
46+
assert isinstance(log.modules, ModuleAccessor)
47+
48+
def test_root_module_exists(self):
49+
log = log_forward_pass(_make_simple_model(), _simple_input())
50+
root = log.modules["self"]
51+
assert isinstance(root, ModuleLog)
52+
assert root.address == "self"
53+
54+
def test_root_module_alias(self):
55+
log = log_forward_pass(_make_simple_model(), _simple_input())
56+
assert log.modules[""] is log.modules["self"]
57+
58+
def test_root_module_property(self):
59+
log = log_forward_pass(_make_simple_model(), _simple_input())
60+
assert log.root_module is log.modules["self"]
61+
62+
def test_module_count(self):
63+
log = log_forward_pass(_make_simple_model(), _simple_input())
64+
# Sequential has 3 children: Linear, ReLU, Linear → 3 submodules + root = 4
65+
assert len(log.modules) >= 4
66+
67+
def test_access_by_address(self):
68+
log = log_forward_pass(_make_simple_model(), _simple_input())
69+
ml = log.modules["0"]
70+
assert isinstance(ml, ModuleLog)
71+
assert ml.address == "0"
72+
73+
def test_access_by_index(self):
74+
log = log_forward_pass(_make_simple_model(), _simple_input())
75+
ml = log.modules[0]
76+
assert isinstance(ml, ModuleLog)
77+
78+
def test_access_by_pass_notation(self):
79+
log = log_forward_pass(_make_simple_model(), _simple_input())
80+
# All modules have 1 pass in a non-recurrent model
81+
addresses = [ml.address for ml in log.modules if ml.address != "self"]
82+
if addresses:
83+
addr = addresses[0]
84+
mpl = log.modules[f"{addr}:1"]
85+
assert isinstance(mpl, ModulePassLog)
86+
87+
def test_contains(self):
88+
log = log_forward_pass(_make_simple_model(), _simple_input())
89+
assert "0" in log.modules
90+
assert "self" in log.modules
91+
assert "nonexistent" not in log.modules
92+
93+
def test_iter(self):
94+
log = log_forward_pass(_make_simple_model(), _simple_input())
95+
modules_list = list(log.modules)
96+
assert len(modules_list) == len(log.modules)
97+
assert all(isinstance(ml, ModuleLog) for ml in modules_list)
98+
99+
def test_getitem_multi_pass_returns_module_log(self, input_2d):
100+
"""log["fc1"] for a multi-pass module should return ModuleLog (instead of error)."""
101+
model = example_models.RecurrentParamsSimple()
102+
log = log_forward_pass(model, input_2d)
103+
result = log["fc1"]
104+
assert isinstance(result, ModuleLog)
105+
106+
107+
# ---------------------------------------------------------------------------
108+
# TestModuleLogFields
109+
# ---------------------------------------------------------------------------
110+
111+
112+
class TestModuleLogFields:
113+
def test_identity_fields(self):
114+
log = log_forward_pass(_make_simple_model(), _simple_input())
115+
ml = log.modules["0"]
116+
assert ml.address == "0"
117+
assert ml.name == "0"
118+
assert ml.module_class_name == "Linear"
119+
120+
def test_source_info(self):
121+
log = log_forward_pass(_make_simple_model(), _simple_input())
122+
ml = log.modules["0"]
123+
assert ml.source_file is not None # nn.Linear has inspectable source
124+
assert ml.forward_signature is not None
125+
126+
def test_hierarchy_address(self):
127+
log = log_forward_pass(_make_nested_model(), _nested_input())
128+
# "0.0" is Linear inside first Sequential
129+
ml = log.modules["0.0"]
130+
assert ml.address_parent == "0"
131+
assert ml.address_depth == 2
132+
133+
# "0" is the first Sequential
134+
parent = log.modules["0"]
135+
assert parent.address_parent == "self"
136+
assert "0.0" in parent.address_children
137+
138+
def test_hierarchy_call(self):
139+
log = log_forward_pass(_make_simple_model(), _simple_input())
140+
root = log.modules["self"]
141+
# Root's call_children should include top-level modules
142+
assert len(root.call_children) > 0
143+
144+
def test_nesting_depth(self):
145+
log = log_forward_pass(_make_nested_model(), _nested_input())
146+
root = log.modules["self"]
147+
assert root.nesting_depth == 0
148+
149+
# Top-level module should be depth 1
150+
top = log.modules["0"]
151+
assert top.nesting_depth == 1
152+
153+
# Nested inside "0" should be depth 2
154+
nested = log.modules["0.0"]
155+
assert nested.nesting_depth == 2
156+
157+
def test_address_depth(self):
158+
log = log_forward_pass(_make_nested_model(), _nested_input())
159+
assert log.modules["self"].address_depth == 0
160+
assert log.modules["0"].address_depth == 1
161+
assert log.modules["0.0"].address_depth == 2
162+
163+
def test_layers_populated(self):
164+
log = log_forward_pass(_make_simple_model(), _simple_input())
165+
ml = log.modules["0"]
166+
assert len(ml.all_layers) > 0
167+
assert ml.num_layers == len(ml.all_layers)
168+
169+
def test_params_accessor(self):
170+
log = log_forward_pass(_make_simple_model(), _simple_input())
171+
ml = log.modules["0"] # Linear layer
172+
assert isinstance(ml.params, ParamAccessor)
173+
assert len(ml.params) == 2 # weight + bias
174+
175+
def test_training_mode(self):
176+
model = _make_simple_model()
177+
model.eval()
178+
log = log_forward_pass(model, _simple_input())
179+
for ml in log.modules:
180+
if ml.address != "self":
181+
assert ml.training_mode is False
182+
183+
def test_hooks_detected_false(self):
184+
log = log_forward_pass(_make_simple_model(), _simple_input())
185+
for ml in log.modules:
186+
assert ml.has_forward_hooks is False
187+
assert ml.has_backward_hooks is False
188+
189+
def test_repr(self):
190+
log = log_forward_pass(_make_simple_model(), _simple_input())
191+
ml = log.modules["0"]
192+
r = repr(ml)
193+
assert "ModuleLog" in r
194+
assert ml.address in r
195+
assert ml.module_class_name in r
196+
197+
198+
# ---------------------------------------------------------------------------
199+
# TestModulePassLog
200+
# ---------------------------------------------------------------------------
201+
202+
203+
class TestModulePassLog:
204+
def test_pass_layers(self):
205+
log = log_forward_pass(_make_simple_model(), _simple_input())
206+
ml = log.modules["0"]
207+
mpl = ml.passes[1]
208+
assert isinstance(mpl, ModulePassLog)
209+
# Pass layers should be a subset of parent all_layers
210+
assert all(label in ml.all_layers for label in mpl.layers)
211+
212+
def test_input_output_layers(self):
213+
log = log_forward_pass(_make_simple_model(), _simple_input())
214+
ml = log.modules["0"]
215+
mpl = ml.passes[1]
216+
assert isinstance(mpl.input_layers, list)
217+
assert isinstance(mpl.output_layers, list)
218+
219+
def test_call_children(self):
220+
log = log_forward_pass(_make_nested_model(), _nested_input())
221+
# "0" contains "0.0" and "0.1" as submodules
222+
ml = log.modules["0"]
223+
mpl = ml.passes[1]
224+
assert isinstance(mpl.call_children, list)
225+
226+
def test_repr(self):
227+
log = log_forward_pass(_make_simple_model(), _simple_input())
228+
ml = log.modules["0"]
229+
mpl = ml.passes[1]
230+
r = repr(mpl)
231+
assert "ModulePassLog" in r
232+
assert len(r) > 0
233+
234+
235+
# ---------------------------------------------------------------------------
236+
# TestMultiPassModules
237+
# ---------------------------------------------------------------------------
238+
239+
240+
class TestMultiPassModules:
241+
def test_num_passes_gt_1(self, input_2d):
242+
model = example_models.RecurrentParamsSimple()
243+
log = log_forward_pass(model, input_2d)
244+
# fc1 is used 4 times
245+
ml = log.modules["fc1"]
246+
assert ml.num_passes >= 2
247+
248+
def test_per_call_field_raises(self, input_2d):
249+
model = example_models.RecurrentParamsSimple()
250+
log = log_forward_pass(model, input_2d)
251+
ml = log.modules["fc1"]
252+
assert ml.num_passes > 1
253+
with pytest.raises(AttributeError, match="passes"):
254+
_ = ml.layers
255+
256+
def test_pass_access(self, input_2d):
257+
model = example_models.RecurrentParamsSimple()
258+
log = log_forward_pass(model, input_2d)
259+
ml = log.modules["fc1"]
260+
assert 1 in ml.passes
261+
assert 2 in ml.passes
262+
assert isinstance(ml.passes[1], ModulePassLog)
263+
assert isinstance(ml.passes[2], ModulePassLog)
264+
265+
def test_pass_notation_accessor(self, input_2d):
266+
model = example_models.RecurrentParamsSimple()
267+
log = log_forward_pass(model, input_2d)
268+
mpl = log.modules["fc1:2"]
269+
assert isinstance(mpl, ModulePassLog)
270+
assert mpl.pass_num == 2
271+
272+
273+
# ---------------------------------------------------------------------------
274+
# TestSinglePassDelegation
275+
# ---------------------------------------------------------------------------
276+
277+
278+
class TestSinglePassDelegation:
279+
def test_layers_delegates(self):
280+
log = log_forward_pass(_make_simple_model(), _simple_input())
281+
ml = log.modules["0"]
282+
assert ml.num_passes == 1
283+
# Should delegate to passes[1].layers
284+
assert ml.layers == ml.passes[1].layers
285+
286+
def test_forward_args_delegates(self):
287+
log = log_forward_pass(_make_simple_model(), _simple_input())
288+
ml = log.modules["0"]
289+
# forward_args should be accessible for single-pass
290+
_ = ml.forward_args # should not raise
291+
292+
293+
# ---------------------------------------------------------------------------
294+
# TestModuleAccessorSummary
295+
# ---------------------------------------------------------------------------
296+
297+
298+
class TestModuleAccessorSummary:
299+
def test_to_pandas(self):
300+
log = log_forward_pass(_make_simple_model(), _simple_input())
301+
df = log.modules.to_pandas()
302+
assert len(df) == len(log.modules)
303+
assert "address" in df.columns
304+
assert "module_class_name" in df.columns
305+
assert "nesting_depth" in df.columns
306+
assert "num_params" in df.columns
307+
308+
def test_summary(self):
309+
log = log_forward_pass(_make_simple_model(), _simple_input())
310+
s = log.modules.summary()
311+
assert isinstance(s, str)
312+
assert len(s) > 0
313+
assert "Address" in s
314+
315+
def test_repr(self):
316+
log = log_forward_pass(_make_simple_model(), _simple_input())
317+
r = repr(log.modules)
318+
assert "ModuleAccessor" in r
319+
320+
321+
# ---------------------------------------------------------------------------
322+
# TestModuleLogIntegration
323+
# ---------------------------------------------------------------------------
324+
325+
326+
class TestModuleLogIntegration:
327+
def test_root_all_layers_equals_model_layers(self):
328+
log = log_forward_pass(_make_simple_model(), _simple_input())
329+
root = log.root_module
330+
assert root.all_layers == log.layer_labels
331+
332+
def test_root_params_count(self):
333+
log = log_forward_pass(_make_simple_model(), _simple_input())
334+
root = log.root_module
335+
assert root.num_params == log.total_params
336+
337+
def test_module_class_name_matches(self):
338+
log = log_forward_pass(_make_simple_model(), _simple_input())
339+
# Module "0" is Linear
340+
assert log.modules["0"].module_class_name == "Linear"
341+
# Module "1" is ReLU
342+
assert log.modules["1"].module_class_name == "ReLU"
343+
# Module "2" is Linear
344+
assert log.modules["2"].module_class_name == "Linear"
345+
346+
def test_nested_model_hierarchy(self):
347+
log = log_forward_pass(_make_nested_model(), _nested_input())
348+
# Check that nesting is consistent
349+
for ml in log.modules:
350+
if ml.address == "self":
351+
continue
352+
# address_parent should be a valid module
353+
assert ml.address_parent in log.modules
354+
355+
def test_module_log_to_pandas(self):
356+
log = log_forward_pass(_make_simple_model(), _simple_input())
357+
ml = log.modules["0"]
358+
df = ml.to_pandas()
359+
assert len(df) == ml.num_layers
360+
361+
def test_module_log_iter(self):
362+
log = log_forward_pass(_make_simple_model(), _simple_input())
363+
ml = log.modules["0"]
364+
entries = list(ml)
365+
assert len(entries) == ml.num_layers
366+
367+
def test_module_log_getitem(self):
368+
log = log_forward_pass(_make_simple_model(), _simple_input())
369+
ml = log.modules["0"]
370+
if ml.num_layers > 0:
371+
entry = ml[0]
372+
assert entry.layer_label == ml.all_layers[0]
373+
374+
def test_old_module_dicts_still_exist(self):
375+
"""Old module_* dicts should still exist for vis.py compatibility."""
376+
log = log_forward_pass(_make_simple_model(), _simple_input())
377+
assert hasattr(log, "module_types")
378+
assert hasattr(log, "module_addresses")
379+
assert hasattr(log, "module_layers")
380+
assert hasattr(log, "module_pass_layers")
381+
assert hasattr(log, "module_pass_children")
382+
383+
def test_nested_modules_model(self, input_2d):
384+
"""Integration test with the NestedModules example model."""
385+
model = example_models.NestedModules()
386+
log = log_forward_pass(model, input_2d)
387+
assert len(log.modules) > 1
388+
root = log.root_module
389+
assert root.address == "self"
390+
# Should have nested hierarchy
391+
max_depth = max(ml.nesting_depth for ml in log.modules)
392+
assert max_depth >= 2 # At least 3 levels of nesting

torchlens/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
)
1212
from .data_classes.model_log import ModelLog
1313
from .data_classes.tensor_log import TensorLog, RolledTensorLog
14-
from .data_classes import FuncCallLocation, ParamLog
14+
from .data_classes import FuncCallLocation, ModuleAccessor, ModuleLog, ModulePassLog, ParamLog

0 commit comments

Comments
 (0)