# TorchLens Architecture

## Module Map

### `torchlens/_state.py` (~208 lines)
Global toggle, session state, context managers. Single source of truth for the `_logging_enabled`
bool checked by every decorated wrapper. Also stores pre-computed lookup tables, a WeakSet of
prepared models, and the active ModelLog reference. **Must never import other torchlens modules**
(prevents circular dependencies).

### `torchlens/user_funcs.py` (~664 lines)
Public API: `log_forward_pass()`, `show_model_graph()`, `validate_forward_pass()`,
`get_model_metadata()`, `validate_batch_of_models_and_inputs()`. Orchestrates the two-pass
strategy when selective layers are requested.

### `torchlens/constants.py` (~645 lines)
7 FIELD_ORDER tuples (canonical field sets for LayerPassLog, ModelLog, etc.), function
discovery sets (~90 IGNORED_FUNCS, ORIG_TORCH_FUNCS listing ~2000 functions to decorate).

### `torchlens/decoration/` (2 files, ~1,710 lines)
- `torch_funcs.py` — One-time decoration of ~2000 torch functions. Core interceptor with
  barcode nesting detection, in-place detection, DeviceContext bypass.
- `model_prep.py` — Two-phase model preparation (permanent `_prepare_model_once` + per-session
  `_prepare_model_session`). Module forward decorator with exhaustive/fast-path split.

### `torchlens/capture/` (7 files, ~4,960 lines)
Real-time tensor operation logging during the forward pass.
- `trace.py` — Forward-pass orchestration, session setup/cleanup
- `output_tensors.py` — Core logging: builds LayerPassLog entries, exhaustive/fast dispatch
- `source_tensors.py` — Logs input and buffer tensors as source nodes
- `tensor_tracking.py` — Barcode system, parent-child links, backward hooks
- `arg_positions.py` — O(1) tensor extraction via 3-tier lookup (639 static entries)
- `salient_args.py` — Extracts significant function args for metadata
- `flops.py` — Per-operation FLOPs computation (~290 ops)

### `torchlens/postprocess/` (6 files, ~3,179 lines)
18-step pipeline. Order is critical — many steps depend on prior output.
- `graph_traversal.py` — Steps 1-4: output layers, ancestor marking, orphan removal, distance flood
- `control_flow.py` — Steps 5-7: conditional branches (backward-only flood + AST THEN detection),
  module fixing, buffer cleanup
- `loop_detection.py` — Step 8: isomorphic subgraph expansion, layer assignment
- `labeling.py` — Steps 9-12: label generation, rename, trim/reorder, lookup keys
- `finalization.py` — Steps 13-18: undecorate, ParamLog, ModuleLog, LayerLog, mark complete

### `torchlens/data_classes/` (10 files, ~3,821 lines)
- `model_log.py` — ModelLog: top-level container, 70+ attrs
- `layer_pass_log.py` — LayerPassLog: per-pass entry (~85 fields)
- `layer_log.py` — LayerLog: aggregate class grouping passes
- `buffer_log.py` — BufferLog(LayerPassLog): buffer-specific computed properties
- `module_log.py` — ModuleLog, ModulePassLog, ModuleAccessor
- `param_log.py` — ParamLog (lazy grad via `_param_ref`)
- `func_call_location.py` — Structured call stack frame with lazy properties
- `internal_types.py` — FuncExecutionContext, VisualizationOverrides
- `interface.py` — ModelLog query methods: `__getitem__`, `to_pandas()`, 7-step lookup cascade
- `cleanup.py` — Post-session teardown, cycle breaking

### `torchlens/validation/` (3 files, ~2,795 lines)
- `core.py` — BFS orchestration, forward replay, perturbation checks
- `exemptions.py` — 4 data-driven exemption registries + 16 posthoc checks
- `invariants.py` — 18 metadata invariant categories (A-R): structural + semantic

### `torchlens/visualization/` (3 files, ~2,777+ lines)
- `rendering.py` — Graphviz rendering: nodes, edges, module subgraphs, IF/THEN labels, override system
- `elk_layout.py` — ELK-based layout for large graphs, Worker thread, sfdp fallback
- `dagua_bridge.py` — ModelLog → DaguaGraph conversion for the dagua renderer

### `torchlens/utils/` (7 files, ~950 lines)
Stateless helpers: arg handling, tensor ops (safe_copy, tensor_nanequal), RNG capture/restore,
barcode hashing, object introspection, display formatting, collection manipulation.

## Data Flow

```
import torchlens
  → decorate_all_once()              # wraps ~2000 torch functions permanently
  → patch_detached_references()      # patches `from torch import cos` style refs

log_forward_pass(model, input)
  → _prepare_model_once(model)       # permanent: tl_module_address, forward wrappers
  → _prepare_model_session(model)    # per-call: requires_grad, buffers, session attrs
  → active_logging(model_log)        # enables _logging_enabled toggle
  → model(input)                     # forward pass — each torch op hits decorated wrapper
      → torch_func_decorator         # barcode nesting → bottom-level ops logged
          → log_function_output_tensors_exhaustive()  # builds LayerPassLog entry
          → OR log_function_output_tensors_fast()     # reuses prior graph structure
  → postprocess(model_log)           # 18-step pipeline
      → Steps 1-4: graph cleanup (outputs, ancestors, orphans, distances)
      → Steps 5-7: control flow (conditionals, module fixing, buffer dedup)
      → Step 8: loop detection (isomorphic subgraph expansion)
      → Steps 9-12: labeling (raw→final labels, rename, reorder, lookup keys)
      → Steps 13-18: finalization (undecorate, ParamLog, ModuleLog, LayerLog)
  → return ModelLog
```

Key types flowing between modules:
- `Dict[str, Dict]` — raw tensor dict during capture (`_raw_tensor_dict` on ModelLog)
- `LayerPassLog` — per-pass tensor operation entry (~85 fields)
- `LayerLog` — aggregate grouping passes of the same layer
- `ModuleLog` / `ModulePassLog` — per-module metadata
- `ParamLog` — per-parameter metadata with lazy gradient access

## Key Abstractions

### Toggle Architecture
A single `_logging_enabled` bool in `_state.py`. Wrappers check it on every call — when False,
the cost is one branch check, so overhead is negligible. No re-wrapping/un-wrapping per forward pass.

### Two-Pass Strategy
When the user requests specific layers (rather than "all" or "none"), Pass 1 runs the exhaustive
path to discover the full graph structure, and Pass 2 runs the fast path, saving only the requested
activations. Counter alignment between passes is maintained via identical increment logic.
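
A minimal sketch of the counter-alignment idea (all names and data structures here are illustrative assumptions): both passes walk the same operation stream with identical counter logic, so entry *k* of the fast pass lines up with entry *k* of the discovered graph.

```python
# Toy two-pass dispatch: exhaustive mode stores everything; fast mode keeps the
# same indices but stores tensors only for requested layers.
def run_pass(ops, mode, layers_to_save=None):
    entries = []
    counter = 0                                  # identical increment logic in both modes
    for op_name, result in ops:
        counter += 1
        if mode == "exhaustive":
            entries.append({"index": counter, "op": op_name, "tensor": result})
        else:  # fast path: preserve structure, save activations only when requested
            keep = layers_to_save is not None and op_name in layers_to_save
            entries.append({"index": counter, "op": op_name,
                            "tensor": result if keep else None})
    return entries

ops = [("conv", [1.0]), ("relu", [2.0]), ("linear", [3.0])]
pass1 = run_pass(ops, "exhaustive")                       # discovers full structure
pass2 = run_pass(ops, "fast", layers_to_save={"relu"})    # saves only requested layers
```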

### Barcode Nesting Detection
Random 8-character barcodes distinguish bottom-level ops from wrapper functions. A barcode is set
on the tensor before the call: if it is unchanged afterward, no nested torch calls ran, so the op
is logged; if it changed, a nested call already logged it.

### Operation Equivalence Types
Structural fingerprint: `{func_name}_{arg_hash}[_outindex{i}][_module{origin}]`. Used by
loop detection (Step 8) to group operations into layers.
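
A sketch in the spirit of that fingerprint format (the helper name and hashing scheme are assumptions, not the actual implementation):

```python
# Build a structural fingerprint so that operations with the same function,
# argument structure, and module origin collapse to one key.
import hashlib

def equivalence_type(func_name, args, out_index=None, module_origin=None):
    arg_hash = hashlib.md5(repr(args).encode()).hexdigest()[:8]
    key = f"{func_name}_{arg_hash}"
    if out_index is not None:
        key += f"_outindex{out_index}"
    if module_origin is not None:
        key += f"_module{module_origin}"
    return key

# Two calls with identical structure produce the same key, which is how
# repeated loop iterations can be grouped into a single layer.
k1 = equivalence_type("conv2d", ("shape(1,3,32,32)", "stride1"), module_origin="features.0")
k2 = equivalence_type("conv2d", ("shape(1,3,32,32)", "stride1"), module_origin="features.0")
```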

### LayerLog Delegation
Single-pass layers: `__getattr__` delegates to `passes[1]`. Per-pass fields on multi-pass layers
raise **ValueError** (not AttributeError, which Python's property/`__getattr__` fallback would
silently swallow).
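
A minimal sketch of the delegation rule (class and field names are borrowed from this document; the bodies are illustrative):

```python
# Single-pass layers forward attribute access to their only pass. Multi-pass
# layers refuse per-pass fields with ValueError: an AttributeError raised here
# would be indistinguishable from "attribute missing" and get swallowed by
# Python's property/__getattr__ fallback machinery.
class LayerPassLog:
    def __init__(self, tensor_shape):
        self.tensor_shape = tensor_shape

class LayerLog:
    def __init__(self, passes):
        self.passes = passes                 # dict: pass number → LayerPassLog

    def __getattr__(self, name):
        passes = self.__dict__["passes"]     # avoid recursive __getattr__ lookup
        if len(passes) == 1:
            return getattr(passes[1], name)  # single pass: delegate transparently
        raise ValueError(
            f"'{name}' is per-pass; access it via layer.passes[n].{name}"
        )

single = LayerLog({1: LayerPassLog((8, 16))})
multi = LayerLog({1: LayerPassLog((8, 16)), 2: LayerPassLog((8, 32))})
```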

## Dependency Graph
```
_state.py      ← imported by everything (no outgoing torchlens imports)
constants.py   ← imported by capture/, postprocess/, data_classes/
utils/         ← imported by capture/, postprocess/, data_classes/, validation/
decoration/    → calls capture/ (via decorated wrappers)
               → reads _state.py
capture/       → creates data_classes/ entries (LayerPassLog)
               → reads _state.py, constants.py
postprocess/   → mutates data_classes/ entries
               → reads constants.py
data_classes/  → references _state.py (TYPE_CHECKING only)
validation/    → reads data_classes/, calls original torch funcs
visualization/ → reads data_classes/ (LayerLog, ModelLog)
user_funcs.py  → orchestrates decoration/, capture/, postprocess/, validation/, visualization/
```

## Known Complexity

### Loop Detection (postprocess/loop_detection.py)
The most complex single module (~826 lines). BFS expansion of isomorphic subgraphs, iso-group
refinement with direction-aware neighbor connectivity, adjacency union-find for layer assignment.
Step 6's module suffix mutation makes `_rebuild_pass_assignments` necessary (not merely defensive).

### Exhaustive/Fast-Path Split (capture/output_tensors.py)
Two parallel code paths that must maintain counter alignment. The fast path skips most metadata
but must match the exhaustive path's operation ordering exactly.

### ELK Layout (visualization/elk_layout.py)
Node.js subprocess with V8 heap sizing, a Worker thread to prevent stack overflow, a stress
algorithm with O(n²) memory (never use it for >100k nodes), and Kahn's topological sort for seeding.
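
Kahn's algorithm itself is standard; a self-contained sketch of the seeding sort (the node and edge representations are assumptions):

```python
# Kahn's topological sort: repeatedly emit nodes with no remaining incoming
# edges, decrementing in-degrees as their outgoing edges are "removed".
from collections import deque

def kahn_topo_sort(nodes, edges):
    """Return nodes in topological order; raise if the graph has a cycle."""
    indegree = {n: 0 for n in nodes}
    adjacency = {n: [] for n in nodes}
    for src, dst in edges:
        adjacency[src].append(dst)
        indegree[dst] += 1
    queue = deque(n for n in nodes if indegree[n] == 0)   # start from source nodes
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in adjacency[node]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:                        # all prerequisites emitted
                queue.append(nxt)
    if len(order) != len(nodes):
        raise ValueError("graph has a cycle")
    return order

order = kahn_topo_sort(["input", "conv", "relu", "output"],
                       [("input", "conv"), ("conv", "relu"), ("relu", "output")])
```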

### Circular References (data_classes/)
ModelLog ↔ LayerPassLog ↔ ModelLog cycles. ModuleLog ↔ ModelLog cycles. ParamLog pins
nn.Parameter. All rely on Python's cyclic GC. Explicit `cleanup()` available.
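
A toy illustration of why explicit cleanup helps (class names are borrowed from above, but the bodies are illustrative): cyclic references survive reference counting and must wait for the cyclic garbage collector, whereas breaking the links eagerly releases memory deterministically.

```python
# Parent holds children, children hold a back-reference to the parent: a cycle.
# cleanup() breaks the links so plain refcounting can free everything.
import weakref

class ModelLog:
    def __init__(self):
        self.layers = []

    def cleanup(self):
        for layer in self.layers:
            layer.model_log = None     # break the child → parent back-reference
        self.layers = []               # break the parent → child references

class LayerPassLog:
    def __init__(self, model_log):
        self.model_log = model_log     # child → parent back-reference
        model_log.layers.append(self)  # parent → child reference: cycle complete

log = ModelLog()
LayerPassLog(log)
ref = weakref.ref(log)
log.cleanup()                   # break the cycle explicitly...
del log
collected = ref() is None       # ...so refcounting frees it without gc.collect()
```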

### DeviceContext Bypass (decoration/torch_funcs.py)
Python wrappers bypass C-level TorchFunctionMode dispatch, so factory functions need manual
device-kwarg injection when a `torch.device('meta')` context is active (the HuggingFace use case).