Skip to content

Commit b15c5bf

Browse files
docs: update all CLAUDE.md files with deepdive session 4 findings
Sync all project and subpackage documentation with current codebase: - Updated line counts across all 36 modules - Added elk_layout.py documentation to visualization/ - Added arg_positions.py and salient_args.py to capture/ - Documented 13 new bugs (ELK-IF-THEN, BFLOAT16-TOL, etc.) - Updated test counts (1,004 tests across 16 files) - Added known bugs sections to validation/, utils/, decoration/ - Updated data_classes/ with new fields and properties Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9b65063 commit b15c5bf

File tree

10 files changed

+223
-105
lines changed

10 files changed

+223
-105
lines changed

CLAUDE.md

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
## Project Overview
44

5-
TorchLens is a Python package for extracting activations from PyTorch models. It provides functionality for extracting model activations, visualizing computational graphs, and extracting exhaustive metadata about models.
5+
TorchLens is a Python package for extracting activations from PyTorch models. It permanently
6+
wraps all PyTorch functions at import time with toggle-gated wrappers, runs forward passes
7+
with the toggle enabled, and logs every operation into ModelLog/LayerLog/LayerPassLog objects.
8+
~20,800 lines core code (36 modules across 7 subpackages), ~1,004 tests across 16 test files.
69

710
## Commit Convention
811

@@ -38,12 +41,31 @@ If there is no issue, omit the issue reference — but prefer having an issue fo
3841

3942
## Testing
4043

41-
- Run tests: `pytest tests/`
42-
- Linting: `black --check .`
44+
- Run all tests: `pytest tests/`
45+
- Smoke tests (~6s): `pytest tests/ -m smoke`
46+
- Skip slow tests: `pytest tests/ -m "not slow"`
47+
- Linting: `ruff format` + `ruff check --fix`
4348

4449
## Project Structure
4550

46-
- `torchlens/` — main package source
47-
- `tests/` — test suite
51+
- `torchlens/` — main package source ([see subpackage docs](torchlens/CLAUDE.md))
52+
- `tests/` — test suite ([see test docs](tests/CLAUDE.md))
53+
- `scripts/` — development utilities ([see scripts docs](scripts/CLAUDE.md))
54+
- `.github/` — CI/CD workflows ([see CI docs](.github/CLAUDE.md))
4855
- `images/` — documentation images
49-
- `local_jmt/` — local development scripts (not packaged)
56+
57+
## Architecture Quick Reference
58+
59+
```
60+
import torchlens -> decorate_all_once() wraps ~2000 torch functions (ONE TIME)
61+
62+
log_forward_pass(model, input)
63+
1. Prepare model (decoration/model_prep.py)
64+
2. Run forward pass with logging (capture/)
65+
3. 18-step postprocess pipeline (postprocess/)
66+
4. Return ModelLog with all data
67+
```
68+
69+
Key packages: `capture/` (7 files), `data_classes/` (10 files), `decoration/` (2 files),
70+
`postprocess/` (6 files), `validation/` (3 files), `visualization/` (2 files: rendering.py + elk_layout.py),
71+
`utils/` (7 files).

tests/CLAUDE.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
11
# tests/ — Test Suite
22

33
## Overview
4-
~951 tests across 15 test files. Uses pytest with deterministic torch seeding.
4+
~1,004 tests across 16 test files + 2 support files. Uses pytest with deterministic torch seeding.
55

66
## Test Files
77

88
| File | Tests | What It Covers |
99
|------|-------|----------------|
1010
| `conftest.py` || Fixtures, deterministic seeding, output directory setup, coverage reporting |
11-
| `example_models.py` || 250 toy model class definitions for controlled testing |
11+
| `example_models.py` || ~5,400 lines, model definitions for controlled testing |
1212
| `test_toy_models.py` | 258 | Validation + visualization for toy models (14 sections by category) |
1313
| `test_real_world_models.py` | 185 | Real-world architectures (20 fast, 165 `@pytest.mark.slow`) |
14-
| `test_metadata.py` | 107 | Field-level coverage for ModelLog and LayerPassLog |
15-
| `test_module_log.py` | 45 | ModuleLog/ModulePassLog/ModuleAccessor |
14+
| `test_metadata.py` | 129 | Field-level metadata tests, conditional branch detection |
1615
| `test_param_log.py` | 70 | ParamLog/ParamAccessor |
1716
| `test_decoration.py` | 61 | Permanent decoration architecture (toggle, crawl, JIT, signals) |
1817
| `test_validation.py` | 59 | Validation subpackage (registries, perturbation, invariants A-R) |
19-
| `test_large_graphs.py` | 51 | Large graph rendering, RandomGraphModel, ELK layout engine |
18+
| `test_module_log.py` | 45 | ModuleLog/ModulePassLog/ModuleAccessor |
19+
| `test_large_graphs.py` | 43 | Large graph rendering, RandomGraphModel, ELK layout engine |
20+
| `test_internals.py` | 41 | Internal implementation details |
21+
| `test_func_config.py` | 34 | func_config / salient_args metadata |
2022
| `test_layer_log.py` | 34 | LayerLog aggregate class |
21-
| `test_internals.py` | 36 | Internal implementation details |
2223
| `test_save_new_activations.py` | 21 | `save_new_activations()` re-logging |
23-
| `test_profiling.py` | 1 | Performance profiling + decoration overhead benchmarks |
2424
| `test_output_aesthetics.py` | 12 | Aesthetic report + vis PDFs for human review |
2525
| `test_gc.py` | 10 | GC correctness, memory leak detection, param ref release |
26+
| `test_profiling.py` | 1 | Performance profiling + decoration overhead benchmarks |
2627
| `test_arg_positions.py` | 1 | ArgSpec lookup table coverage (runs last) |
2728

2829
## Running Tests

torchlens/CLAUDE.md

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,28 @@ each forward pass. Entry point: `log_forward_pass()` in `user_funcs.py`.
99

1010
```
1111
import torchlens (ONE TIME)
12-
├─ decorate_all_once() — wraps ~2000 torch functions
13-
├─ patch_detached_references() — patches `from torch import cos` style imports
14-
12+
|- decorate_all_once() — wraps ~2000 torch functions
13+
|- patch_detached_references() — patches `from torch import cos` style imports
14+
|
1515
log_forward_pass(model, input)
16-
├─ decoration/model_prep.py — prepare model (once + per-session)
17-
├─ capture/trace.py — run forward pass with logging enabled
18-
├─ capture/output_tensors.py — log each tensor operation
19-
├─ postprocess/ — 18-step pipeline (graph, loops, labels, modules)
20-
└─ Returns ModelLog with all logged data
16+
|- decoration/model_prep.py — prepare model (once + per-session)
17+
|- capture/trace.py — run forward pass with logging enabled
18+
|- capture/output_tensors.py — log each tensor operation
19+
|- postprocess/ — 18-step pipeline (graph, loops, labels, modules)
20+
+- Returns ModelLog with all logged data
2121
```
2222

2323
Two-pass strategy: when `layers_to_save` is a specific list, Pass 1 runs exhaustive
2424
(metadata only), Pass 2 runs fast (saves only requested layers).
2525

2626
## Files in This Directory
2727

28-
| File | Purpose |
29-
|------|---------|
30-
| `__init__.py` | Public API exports + import-time decoration trigger |
31-
| `_state.py` | Global toggle, session state, context managers. **No imports from other torchlens modules** (prevents circular deps) |
32-
| `constants.py` | Field-order lists (MODEL_LOG_FIELD_ORDER, LAYER_PASS_LOG_FIELD_ORDER), function discovery (ORIG_TORCH_FUNCS, IGNORED_FUNCS) |
33-
| `user_funcs.py` | User-facing API: `log_forward_pass`, `validate_forward_pass`, `show_model_graph`, `get_model_metadata` |
28+
| File | ~Lines | Purpose |
29+
|------|--------|---------|
30+
| `__init__.py` | 26 | Public API exports + import-time decoration trigger |
31+
| `_state.py` | 208 | Global toggle, session state, context managers, WeakSet, pre-computed mappings. **No imports from other torchlens modules** (prevents circular deps) |
32+
| `constants.py` | 645 | 7 FIELD_ORDER tuples, function discovery (~90 IGNORED_FUNCS, ORIG_TORCH_FUNCS) |
33+
| `user_funcs.py` | 664 | User-facing API: `log_forward_pass`, `validate_forward_pass`, `show_model_graph`, `get_model_metadata` |
3434

3535
## Key Concepts
3636

@@ -54,17 +54,18 @@ match but preserves all fields (no stripping). When adding new fields, update bo
5454
class definition and the corresponding FIELD_ORDER in constants.py.
5555

5656
## Subpackages
57-
- **[capture/](capture/CLAUDE.md)** — Real-time tensor operation logging during forward pass
58-
- **[data_classes/](data_classes/CLAUDE.md)** — ModelLog, LayerLog, LayerPassLog, ModuleLog, ParamLog, etc.
59-
- **[decoration/](decoration/CLAUDE.md)** — One-time torch function wrapping + model preparation
60-
- **[postprocess/](postprocess/CLAUDE.md)** — 18-step pipeline: graph cleanup, loop detection, labeling
61-
- **[utils/](utils/CLAUDE.md)** — Arg handling, tensor ops, RNG, hashing, display helpers
62-
- **[validation/](validation/CLAUDE.md)** — Forward replay, perturbation checks, metadata invariants
63-
- **[visualization/](visualization/CLAUDE.md)** — Graphviz-based computational graph rendering
57+
- **[capture/](capture/CLAUDE.md)** — Real-time tensor operation logging during forward pass (7 files: trace, output_tensors, source_tensors, tensor_tracking, arg_positions, salient_args, flops)
58+
- **[data_classes/](data_classes/CLAUDE.md)** — ModelLog, LayerLog, LayerPassLog, ModuleLog, ParamLog, etc. (10 files)
59+
- **[decoration/](decoration/CLAUDE.md)** — One-time torch function wrapping + model preparation (2 files)
60+
- **[postprocess/](postprocess/CLAUDE.md)** — 18-step pipeline: graph cleanup, loop detection, labeling (6 files)
61+
- **[utils/](utils/CLAUDE.md)** — Arg handling, tensor ops, RNG, hashing, display helpers (7 files)
62+
- **[validation/](validation/CLAUDE.md)** — Forward replay, perturbation checks, metadata invariants (3 files)
63+
- **[visualization/](visualization/CLAUDE.md)** — Graphviz + ELK-based computational graph rendering (2 files)
6464

6565
## Critical Invariants
6666
1. `_state.py` must never import other torchlens modules
6767
2. RNG state capture/restore must happen BEFORE `active_logging()` context
6868
3. `pause_logging()` must wrap any internal torch ops during logging (safe_copy, activation_postfunc, get_tensor_memory_amount)
6969
4. Decorated wrappers are permanent — never undecorated
7070
5. Field-order constants and class definitions must stay in sync
71+
6. Step 6 module suffix mutation makes `_rebuild_pass_assignments` (Step 8) NECESSARY — not just defensive

torchlens/capture/CLAUDE.md

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ saves only requested tensor data).
99

1010
| File | ~Lines | Purpose |
1111
|------|--------|---------|
12-
| `trace.py` | 347 | Forward-pass orchestration: input normalization, model execution, output extraction |
13-
| `output_tensors.py` | 737 | Core logging: builds LayerPassLog entries for each operation's output tensors |
14-
| `source_tensors.py` | 304 | Logs input and buffer tensors as source nodes in the graph |
15-
| `tensor_tracking.py` | 346 | Barcode system, parent-child links, backward hooks, operation equivalence fingerprinting |
16-
| `flops.py` | 1311 | Per-operation FLOPs computation (3-tier: zero/elementwise/specialty, ~290 ops) |
12+
| `trace.py` | 500 | Forward-pass orchestration: input normalization, model execution, session setup/cleanup |
13+
| `output_tensors.py` | 898 | Core logging: builds LayerPassLog entries, exhaustive/fast path split, identity detection |
14+
| `source_tensors.py` | 357 | Logs input and buffer tensors as source nodes in the graph |
15+
| `tensor_tracking.py` | 407 | Barcode system, parent-child links, backward hooks, arg hashing |
16+
| `arg_positions.py` | 961 | O(1) tensor extraction via 3-tier lookup: static table (639 entries), dynamic cache, BFS fallback |
17+
| `salient_args.py` | 444 | Extracts significant function args (hyperparameters) for metadata. 27 extractors for 50+ layer types |
18+
| `flops.py` | 1393 | Per-operation FLOPs computation (3-tier: zero/elementwise/specialty, ~290 ops) |
1719

1820
## Key Functions
1921

@@ -29,6 +31,7 @@ saves only requested tensor data).
2931
- `log_function_output_tensors_fast()` — Maps operation counter to existing raw label,
3032
validates function name matches, updates only tensor data + timing + RNG
3133
- `_output_should_be_logged()` — Tensor logged if unlabeled OR bottom-level function
34+
- `cond_branch_then_children` field added for conditional branch THEN detection
3235

3336
### source_tensors.py
3437
- `log_source_tensor_exhaustive()` / `log_source_tensor_fast()` — Marks input/buffer
@@ -41,6 +44,17 @@ saves only requested tensor data).
4144
(only 2 nesting levels tracked — deeper parents may get wrong arg positions)
4245
- `_add_backward_hook()` — Registers gradient capture (uses weakref to avoid GC leaks)
4346

47+
### arg_positions.py
48+
- 3-tier O(1) lookup: static `FUNC_ARG_SPECS` table → dynamic `_DYNAMIC_SPEC_CACHE` → BFS fallback
49+
- `ArgSpec` frozen dataclass: `tensor_args`, `tensor_kwargs`, `param_args`, `param_kwargs`
50+
- `extract_tensors_and_params()` — Main entry point, returns (tensors, params) from args/kwargs
51+
52+
### salient_args.py
53+
- `@_register()` pattern for extractors per layer type
54+
- `_build_arg_name_map()` maps positional args to named params
55+
- Failure-safe: try-except returns `{}` on any error
56+
- `_get()` helper returns `None` on missing keys (graceful degradation)
57+
4458
### flops.py
4559
- 3-tier system: ZERO_FLOPS_OPS (view, reshape = 0), ELEMENTWISE_FLOPS (relu, sigmoid),
4660
SPECIALTY_HANDLERS (conv2d, matmul — shape-aware computation)
@@ -51,6 +65,12 @@ saves only requested tensor data).
5165
- Function outputs: `{type}_{num}_{counter}_raw` (e.g., `"conv2d_1_5_raw"`)
5266
- Labels are raw during capture; renamed to final labels in postprocess/labeling.py
5367

68+
## Known Bugs
69+
- **ARG-KWARGS-MISSING**: `extract_tensors_and_params()` doesn't extract tensors passed as
70+
keyword args for many common functions (linear, cat, where). `tensor_kwargs=()` in static
71+
entries means `linear(x, weight=w, bias=b)` only finds `x`.
72+
- **salient_args silent drop**: `*args` silently dropped in `_build_arg_name_map` (lines 52-56)
73+
5474
## Gotchas
5575
- **In-place ops**: `safe_copy` strips `tl_tensor_label_raw` from clone, ensuring
5676
in-place ops are always logged as new operations. Label propagated back after logging.
@@ -62,6 +82,8 @@ saves only requested tensor data).
6282
Graph divergence between passes raises an error.
6383
- **pause_logging()**: Must wrap `activation_postfunc` calls and `get_tensor_memory_amount()`
6484
to prevent recursive logging of internal torch operations.
85+
- **arg_positions dynamic cache**: Never cleared on torch version upgrades — could serve
86+
stale specs if torch updates function signatures.
6587

6688
## Related
6789
- [decoration/](../decoration/CLAUDE.md) — Provides the decorated wrappers that call into this package

torchlens/data_classes/CLAUDE.md

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ All data structures for storing logged forward-pass information. Organized as a
55

66
```
77
ModelLog (top-level container)
8-
├─ LayerLog (aggregate per-layer, groups passes)
9-
└─ LayerPassLog (per-pass tensor operation entry, ~80+ fields)
10-
└─ BufferLog (extends LayerPassLog for buffer tensors)
11-
├─ ModuleLog (per module in model)
12-
└─ ModulePassLog (per invocation of a module)
13-
├─ ParamLog (per model parameter)
14-
└─ FuncCallLocation (structured call stack frame)
8+
|- LayerLog (aggregate per-layer, groups passes)
9+
| +- LayerPassLog (per-pass tensor operation entry, ~85+ fields)
10+
| +- BufferLog (extends LayerPassLog for buffer tensors)
11+
|- ModuleLog (per module in model)
12+
| +- ModulePassLog (per invocation of a module)
13+
|- ParamLog (per model parameter)
14+
+- FuncCallLocation (structured call stack frame)
1515
```
1616

1717
Each main class has a companion Accessor (LayerAccessor, ModuleAccessor, ParamAccessor,
@@ -21,34 +21,34 @@ BufferAccessor) providing dict-like indexing by name, index, or substring.
2121

2222
| File | ~Lines | Purpose |
2323
|------|--------|---------|
24-
| `model_log.py` | 316 | ModelLog: top-level container, FLOPs properties, `_init_module_build_data()` |
25-
| `layer_pass_log.py` | 477 | LayerPassLog: per-pass entry (~80+ fields), `save_tensor_data()`, `copy()` |
26-
| `layer_log.py` | 553 | LayerLog: aggregate class, `__getattr__` delegation, LayerAccessor |
27-
| `buffer_log.py` | 108 | BufferLog(LayerPassLog): `name`/`module_address` computed properties, BufferAccessor |
28-
| `module_log.py` | 353 | ModuleLog, ModulePassLog, ModuleAccessor |
29-
| `param_log.py` | 196 | ParamLog (lazy grad via `_param_ref`), ParamAccessor |
30-
| `func_call_location.py` | 230 | FuncCallLocation: lazy source loading via linecache |
31-
| `internal_types.py` | 41 | FuncExecutionContext + VisualizationOverrides dataclasses |
32-
| `interface.py` | 430 | ModelLog query methods: `__getitem__`, `__str__`, `to_pandas()` |
33-
| `cleanup.py` | 157 | Post-session teardown: destroy entries, free GPU memory |
24+
| `model_log.py` | 497 | ModelLog: top-level container, 70+ attrs, FLOPs properties, `conditional_then_edges` |
25+
| `layer_pass_log.py` | 677 | LayerPassLog: per-pass entry (~85+ fields, 18 @properties), `cond_branch_then_children`, `func_config` |
26+
| `layer_log.py` | 671 | LayerLog: aggregate class, 13 direct + 38 @properties, `__getattr__` delegation |
27+
| `buffer_log.py` | 132 | BufferLog(LayerPassLog): `name`/`module_address` computed properties, BufferAccessor |
28+
| `module_log.py` | 525 | ModuleLog, ModulePassLog, ModuleAccessor, shared alias support, FLOPs aggregation |
29+
| `param_log.py` | 248 | ParamLog (lazy grad via `_param_ref`), `release_param_ref()` for GC, ParamAccessor |
30+
| `func_call_location.py` | 265 | FuncCallLocation: lazy properties via linecache, dual construction paths, `_SENTINEL` |
31+
| `internal_types.py` | 61 | FuncExecutionContext + VisualizationOverrides (`@dataclass(slots=True)`) |
32+
| `interface.py` | 508 | ModelLog query methods: `__getitem__`, `__str__`, `to_pandas()`, 7-step lookup cascade |
33+
| `cleanup.py` | 237 | Post-session teardown: O(N+M) batch removal, `conditional_then_edges` filtering, `release_param_refs` |
3434

3535
## Key Access Patterns
3636

3737
```python
3838
# ModelLog access
39-
log["conv2d_1_5"] # LayerLog (aggregate)
40-
log["conv2d_1_5:2"] # LayerPassLog (specific pass)
41-
log[3] # LayerPassLog (by ordinal)
42-
log.layers # LayerAccessor (all LayerLogs)
43-
log.modules # ModuleAccessor
44-
log.params # ParamAccessor
45-
log.buffers # BufferAccessor
39+
log["conv2d_1_5"] # -> LayerLog (aggregate)
40+
log["conv2d_1_5:2"] # -> LayerPassLog (specific pass)
41+
log[3] # -> LayerPassLog (by ordinal)
42+
log.layers # -> LayerAccessor (all LayerLogs)
43+
log.modules # -> ModuleAccessor
44+
log.params # -> ParamAccessor
45+
log.buffers # -> BufferAccessor
4646

4747
# LayerLog delegation
4848
layer = log.layers["conv2d_1_1"]
49-
layer.tensor_contents # delegates to passes[1] for single-pass layers
50-
layer.child_layers # union of no-pass labels across all passes
51-
layer.passes # Dict[int, LayerPassLog]
49+
layer.tensor_contents # -> delegates to passes[1] for single-pass layers
50+
layer.child_layers # -> union of no-pass labels across all passes
51+
layer.passes # -> Dict[int, LayerPassLog]
5252
```
5353

5454
## Design Decisions
@@ -73,15 +73,24 @@ Subclass of LayerPassLog. `name` and `module_address` live only on BufferLog, no
7373
LayerLog (too generic for the aggregate). Single-pass buffer LayerLogs access them
7474
via `__getattr__` delegation.
7575

76+
### _build_layer_logs Multi-Pass Merge
77+
Only 3 fields merged across passes (has_input_ancestor OR, input_output_address char-merge,
78+
is_bottom_level_submodule_output OR). All other 78 fields use first-pass values.
79+
`cond_branch_start_children` and `cond_branch_then_children` use first pass only.
80+
7681
## Circular References (GC concern)
7782
```
78-
ModelLog LayerPassLog source_model_log ModelLog (CYCLE)
79-
ModelLog ModuleLog _source_model_log ModelLog (CYCLE)
80-
ParamLog _param_ref nn.Parameter (PINS MODEL)
83+
ModelLog -> LayerPassLog -> source_model_log -> ModelLog (CYCLE)
84+
ModelLog -> ModuleLog -> _source_model_log -> ModelLog (CYCLE)
85+
ParamLog -> _param_ref -> nn.Parameter (PINS MODEL)
8186
```
8287
All rely on Python's cyclic GC rather than ref-counting. `cleanup()` in cleanup.py
8388
can be called explicitly to break cycles.
8489

90+
## Known Bugs
91+
- **TO-PANDAS-NEW-FIELDS**: `to_pandas()` missing `func_config` and `cond_branch_then_children` columns
92+
- **COND-THEN-MULTIPASS**: `cond_branch_then_children` not merged for multi-pass LayerLog (first pass only)
93+
8594
## Gotchas
8695
- Adding new fields: update class definition AND `constants.py` FIELD_ORDER
8796
- `copy()` on LayerPassLog: shallow-copies 8 specific fields, deep-copies rest.
@@ -91,6 +100,8 @@ can be called explicitly to break cycles.
91100
- `equivalent_operations` per-LayerPassLog holds direct reference to ModelLog-level
92101
sets; becomes stale after rename step 11 (cosmetic, not read downstream)
93102
- `grad_contents` is a bare reference (no clone) — shared with parent tensor
103+
- `FuncCallLocation._frame_func_obj` set at construction but only released in
104+
`_load_source()` (lazy property trigger) — leaks if properties never accessed
94105

95106
## Related
96107
- [capture/](../capture/CLAUDE.md) — Creates LayerPassLog entries during logging

0 commit comments

Comments
 (0)