Skip to content

Commit f0d7452

Browse files
feat: rename all data structure fields and function args for clarity
Rename ~68 fields across all 8 data structures (ModelLog, LayerPassLog, LayerLog, ParamLog, ModuleLog, BufferLog, ModulePassLog, FuncCallLocation) plus user-facing function arguments. Key changes: - tensor_contents → activation, grad_contents → gradient - All *_fsize* → *_memory* (e.g. tensor_fsize → tensor_memory) - func_applied_name → func_name, gradfunc → grad_fn_name - is_bottom_level_submodule_output → is_leaf_module_output - containing_module_origin → containing_module - spouse_layers → co_parent_layers, orig_ancestors → root_ancestors - model_is_recurrent → is_recurrent, elapsed_time_* → time_* - vis_opt → vis_mode, save_only → vis_save_only - Fix typo: output_descendents → output_descendants Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 23ef8d8 commit f0d7452

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2328
-2377
lines changed

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ class SimpleRecurrent(nn.Module):
3737
simple_recurrent = SimpleRecurrent()
3838
model_history = tl.log_forward_pass(simple_recurrent, x,
3939
layers_to_save='all',
40-
vis_opt='rolled')
41-
print(model_history['linear_1_1:2'].tensor_contents) # second pass of first linear layer
40+
vis_mode='rolled')
41+
print(model_history['linear_1_1:2'].activation) # second pass of first linear layer
4242

4343
'''
4444
tensor([[-0.0690, -1.3957, -0.3231, -0.1980, 0.7197],
@@ -88,7 +88,7 @@ import torchlens as tl
8888

8989
alexnet = torchvision.models.alexnet()
9090
x = torch.rand(1, 3, 224, 224)
91-
model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_opt='unrolled')
91+
model_history = tl.log_forward_pass(alexnet, x, layers_to_save='all', vis_mode='unrolled')
9292
print(model_history)
9393

9494
'''
@@ -166,16 +166,16 @@ Layer conv2d_3_7, operation 8/24:
166166
Params: Computed from params with shape (384,), (384, 192, 3, 3); 663936 params total (2.5 MB)
167167
Parent Layers: maxpool2d_2_6
168168
Child Layers: relu_3_8
169-
Function: conv2d (gradfunc=ConvolutionBackward0)
169+
Function: conv2d (grad_fn=ConvolutionBackward0)
170170
Computed inside module: features.6
171171
Time elapsed: 5.670E-04s
172172
Output of modules: features.6
173173
Output of bottom-level module: features.6
174174
Lookup keys: -17, 7, conv2d_3_7, conv2d_3_7:1, features.6, features.6:1
175175
'''
176176

177-
# You can pull out the actual output activations from a layer with the tensor_contents field:
178-
print(model_history['conv2d_3_7'].tensor_contents)
177+
# You can pull out the actual output activations from a layer with the activation field:
178+
print(model_history['conv2d_3_7'].activation)
179179
'''
180180
tensor([[[[-0.0867, -0.0787, -0.0817, ..., -0.0820, -0.0655, -0.0195],
181181
[-0.1213, -0.1130, -0.1386, ..., -0.1331, -0.1118, -0.0520],
@@ -194,7 +194,7 @@ will pull out all conv layers):
194194

195195
```python
196196
# Pull out conv2d_3_7, the output of the 'features' module, the fifth-to-last layer, and all linear (i.e., fc) layers:
197-
model_history = tl.log_forward_pass(alexnet, x, vis_opt='unrolled',
197+
model_history = tl.log_forward_pass(alexnet, x, vis_mode='unrolled',
198198
layers_to_save=['conv2d_3_7', 'features', -5, 'linear'])
199199
print(model_history.layer_labels)
200200
'''

scripts/render_large_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def main():
5555

5656
# Render from ModelLog metadata (model tensors no longer in memory).
5757
ml.render_graph(
58-
vis_opt="rolled",
58+
vis_mode="rolled",
5959
vis_outpath=os.path.join(args.outdir, label),
60-
save_only=True,
60+
vis_save_only=True,
6161
vis_fileformat=args.format,
6262
vis_node_placement="elk",
6363
)

tests/CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_model_descriptive_name(default_input1):
6464
model = example_models.MyNewModel()
6565
assert validate_saved_activations(model, default_input1)
6666
show_model_graph(
67-
model, default_input1, save_only=True, vis_opt="unrolled",
67+
model, default_input1, vis_save_only=True, vis_mode="unrolled",
6868
vis_outpath=opj(VIS_OUTPUT_DIR, "toy-networks", "my_new_model"),
6969
)
7070
```

tests/example_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1740,7 +1740,7 @@ def forward(self, x, c, t, context_mask):
17401740

17411741
class ViewMutationUnsqueeze(nn.Module):
17421742
"""Mutation through unsqueeze view: y = x.unsqueeze(0); y.fill_(0); return x.
1743-
The fill_ mutates x's storage through the view, so x's tensor_contents at
1743+
The fill_ mutates x's storage through the view, so x's activation at
17441744
logging time differs from what children actually receive."""
17451745

17461746
def __init__(self):

tests/test_gc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def test_release_param_refs_preserves_grad_metadata(self):
174174
has_any_grad = True
175175
assert pl._grad_shape is not None
176176
assert pl._grad_dtype is not None
177-
assert pl._grad_fsize > 0
177+
assert pl._grad_memory > 0
178178
assert has_any_grad, "Expected at least one param to have grad metadata cached"
179179
model_log.cleanup()
180180

tests/test_internals.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -463,14 +463,14 @@ def test_cleanup_no_crash(self):
463463

464464
class TestNestedTupleArgs:
465465
def test_nested_tuple_independence(self):
466-
"""Nested tuples/lists in creation_args should be independent copies."""
466+
"""Nested tuples/lists in captured_args should be independent copies."""
467467
model = _SimpleLinear()
468468
x = torch.randn(2, 10)
469469
log = log_forward_pass(model, x, save_function_args=True)
470470
found_args = False
471471
for label in log.layer_labels:
472472
entry = log[label]
473-
if entry.creation_args is not None and len(entry.creation_args) > 0:
473+
if entry.captured_args is not None and len(entry.captured_args) > 0:
474474
found_args = True
475475
break
476476
assert found_args or True # OK if no args (model-dependent)
@@ -495,5 +495,5 @@ def test_shape_matches_capture_time(self):
495495
log = log_forward_pass(model, x, layers_to_save="all")
496496
for label in log.layer_labels:
497497
entry = log[label]
498-
if entry.tensor_contents is not None:
498+
if entry.activation is not None:
499499
assert entry.tensor_shape is not None

tests/test_large_graphs.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,7 @@ def test_nesting_depth(self):
127127
"""Module nesting creates expected hierarchy."""
128128
model = RandomGraphModel(target_nodes=500, nesting_depth=4, seed=42)
129129
ml = log_forward_pass(model, torch.randn(2, 64))
130-
max_depth = max(
131-
len(ml[label].containing_modules_origin_nested) for label in ml.layer_labels
132-
)
130+
max_depth = max(len(ml[label].containing_modules) for label in ml.layer_labels)
133131
assert max_depth >= 3
134132
ml.cleanup()
135133

@@ -325,7 +323,7 @@ def test_dot_renders_small_graph(self):
325323
model,
326324
torch.randn(2, 64),
327325
vis_node_placement="dot",
328-
save_only=True,
326+
vis_save_only=True,
329327
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "dot_200"),
330328
)
331329

@@ -337,7 +335,7 @@ def test_sfdp_renders_large_graph(self):
337335
model,
338336
torch.randn(2, 64),
339337
vis_node_placement="sfdp",
340-
save_only=True,
338+
vis_save_only=True,
341339
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "sfdp_3k"),
342340
)
343341

@@ -350,7 +348,7 @@ def test_elk_renders_3k(self):
350348
model,
351349
torch.randn(2, 64),
352350
vis_node_placement="elk",
353-
save_only=True,
351+
vis_save_only=True,
354352
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "elk_3k"),
355353
)
356354

@@ -363,7 +361,7 @@ def test_elk_renders_5k(self):
363361
model,
364362
torch.randn(2, 64),
365363
vis_node_placement="elk",
366-
save_only=True,
364+
vis_save_only=True,
367365
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "elk_5k"),
368366
)
369367

@@ -376,7 +374,7 @@ def test_elk_renders_10k(self):
376374
model,
377375
torch.randn(2, 64),
378376
vis_node_placement="elk",
379-
save_only=True,
377+
vis_save_only=True,
380378
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "elk_10k"),
381379
)
382380

@@ -389,7 +387,7 @@ def test_elk_renders_20k(self):
389387
model,
390388
torch.randn(2, 64),
391389
vis_node_placement="elk",
392-
save_only=True,
390+
vis_save_only=True,
393391
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "elk_20k"),
394392
)
395393

@@ -402,7 +400,7 @@ def test_elk_renders_50k(self):
402400
model,
403401
torch.randn(2, 64),
404402
vis_node_placement="elk",
405-
save_only=True,
403+
vis_save_only=True,
406404
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "elk_50k"),
407405
)
408406

@@ -416,7 +414,7 @@ def test_elk_renders_100k(self):
416414
model,
417415
torch.randn(2, 64),
418416
vis_node_placement="elk",
419-
save_only=True,
417+
vis_save_only=True,
420418
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "elk_100k"),
421419
)
422420

@@ -430,7 +428,7 @@ def test_elk_renders_250k(self):
430428
torch.randn(2, 64),
431429
vis_node_placement="elk",
432430
vis_fileformat="svg",
433-
save_only=True,
431+
vis_save_only=True,
434432
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "elk_250k"),
435433
)
436434

@@ -444,7 +442,7 @@ def test_elk_renders_1M(self):
444442
torch.randn(2, 64),
445443
vis_node_placement="elk",
446444
vis_fileformat="svg",
447-
save_only=True,
445+
vis_save_only=True,
448446
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "elk_1M"),
449447
)
450448

@@ -453,8 +451,8 @@ def test_vis_node_placement_forwarded(self):
453451
model = RandomGraphModel(target_nodes=200, seed=42)
454452
ml = log_forward_pass(model, torch.randn(2, 64))
455453
ml.render_graph(
456-
vis_opt="unrolled",
457-
save_only=True,
454+
vis_mode="unrolled",
455+
vis_save_only=True,
458456
vis_outpath=os.path.join(VIS_OUTPUT_DIR, "placement_test"),
459457
vis_node_placement="dot",
460458
)
@@ -485,17 +483,17 @@ def _render_both(self, model, x, name):
485483
show_model_graph(
486484
model,
487485
x,
488-
save_only=True,
489-
vis_opt="unrolled",
486+
vis_save_only=True,
487+
vis_mode="unrolled",
490488
vis_node_placement="dot",
491489
vis_outpath=os.path.join(self.COMPARE_DIR, f"{name}_dot"),
492490
)
493491
# ELK
494492
show_model_graph(
495493
model,
496494
x,
497-
save_only=True,
498-
vis_opt="unrolled",
495+
vis_save_only=True,
496+
vis_mode="unrolled",
499497
vis_node_placement="elk",
500498
vis_outpath=os.path.join(self.COMPARE_DIR, f"{name}_elk"),
501499
)
@@ -553,8 +551,8 @@ def test_benchmark_dot_scaling(self):
553551
start = time.time()
554552
try:
555553
ml.render_graph(
556-
vis_opt="unrolled",
557-
save_only=True,
554+
vis_mode="unrolled",
555+
vis_save_only=True,
558556
vis_outpath=os.path.join(VIS_OUTPUT_DIR, f"bench_{target}"),
559557
vis_node_placement="dot",
560558
)

tests/test_layer_log.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ def test_back_reference_set(self, simple_log):
8383
class TestSinglePassDelegation:
8484
@pytest.mark.smoke
8585
def test_tensor_contents_delegation(self, simple_log):
86-
"""tensor_contents delegates to passes[1] for single-pass."""
86+
"""activation delegates to passes[1] for single-pass."""
8787
for layer_log in simple_log.layer_logs.values():
8888
pass_log = layer_log.passes[1]
89-
assert layer_log.tensor_contents is pass_log.tensor_contents
89+
assert layer_log.activation is pass_log.activation
9090

9191
def test_has_saved_activations_delegation(self, simple_log):
9292
for layer_log in simple_log.layer_logs.values():
@@ -121,12 +121,12 @@ def test_aggregate_fields_match_first_pass(self, simple_log):
121121
for layer_log in simple_log.layer_logs.values():
122122
fp = layer_log.passes[1]
123123
assert layer_log.layer_type == fp.layer_type
124-
assert layer_log.func_applied_name == fp.func_applied_name
124+
assert layer_log.func_name == fp.func_name
125125
assert layer_log.tensor_shape == fp.tensor_shape
126126
assert layer_log.tensor_dtype == fp.tensor_dtype
127127
assert layer_log.is_input_layer == fp.is_input_layer
128128
assert layer_log.is_output_layer == fp.is_output_layer
129-
assert layer_log.computed_with_params == fp.computed_with_params
129+
assert layer_log.uses_params == fp.uses_params
130130

131131
def test_layer_label_is_no_pass(self, simple_log):
132132
for layer_log in simple_log.layer_logs.values():
@@ -155,7 +155,7 @@ def test_multi_pass_tensor_contents_raises(self, recurrent_log):
155155
for layer_log in recurrent_log.layer_logs.values():
156156
if layer_log.num_passes > 1:
157157
with pytest.raises(ValueError, match="has .* passes"):
158-
_ = layer_log.tensor_contents
158+
_ = layer_log.activation
159159
break
160160

161161
def test_multi_pass_getattr_raises(self, recurrent_log):
@@ -234,17 +234,13 @@ def test_repr_eq_str(self, simple_log):
234234

235235

236236
class TestConvenienceAliases:
237-
def test_layer_passes_total_alias(self, simple_log):
238-
for layer_log in simple_log.layer_logs.values():
239-
assert layer_log.layer_passes_total == layer_log.num_passes
240-
241237
def test_layer_label_no_pass_alias(self, simple_log):
242238
for layer_log in simple_log.layer_logs.values():
243239
assert layer_log.layer_label_no_pass == layer_log.layer_label
244240

245241
def test_params_accessor(self, recurrent_log):
246242
for layer_log in recurrent_log.layer_logs.values():
247-
if layer_log.computed_with_params:
243+
if layer_log.uses_params:
248244
params = layer_log.params
249245
assert params is not None
250246
break

0 commit comments

Comments
 (0)