Commit 18f8ae9

docs(maintenance): split CLAUDE.md/AGENTS.md into architect vs implementation roles
Break the symlink mirroring convention: CLAUDE.md now holds architect-level context (what, why, how it connects) while AGENTS.md holds implementation-level context (conventions, gotchas, known bugs, test commands). Pure-implementation subdirs (.github, scripts, tests, utils) get AGENTS.md only. Also populates .project-context/ templates (architecture, conventions, gotchas, decisions).
1 parent b0bafeb commit 18f8ae9


30 files changed (+1131, -711 lines)

.github/AGENTS.md

Lines changed: 0 additions & 1 deletion
This file was deleted (previously a one-line symlink mirroring CLAUDE.md; replaced by the standalone file below).

.github/AGENTS.md

Lines changed: 21 additions & 0 deletions
# .github/ — CI/CD Configuration

## Workflows

| File | Trigger | What It Does |
|------|---------|-------------|
| `workflows/lint.yml` | Push/PR | Auto-linting with ruff (`ruff format` + `ruff check --fix`), auto-commits fixes via GitHub App |
| `workflows/quality.yml` | Push/PR | Two jobs: (1) mypy type-checking on Python 3.11, (2) pip-audit dependency audit. Both use CPU torch. |
| `workflows/release.yml` | Push to main | Semantic-release v9 (conventional commits). Publishes to PyPI via trusted OIDC + GitHub Releases. |

## Release Pipeline Details
- Semantic-release v9 (pinned `>=9,<10`), `major_on_zero = true`
- `fetch-tags: true` in checkout step for proper version calculation
- PyPI trusted publishing via OIDC (no API tokens)
- GitHub App (`torchlens-release`) for auth
- Branch protection via rulesets

## Conventions
- Conventional commits required: `fix(scope):`, `feat(scope):`, `chore(scope):`
- `fix:` → patch bump, `feat:` → minor bump, `feat!:` → major bump
- `chore:`, `docs:`, `ci:`, `test:` → no release
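As a sketch, the bump rules above amount to a small lookup. `next_bump` is a hypothetical helper written only for illustration; semantic-release implements this logic itself:

```python
from typing import Optional

def next_bump(subject: str) -> Optional[str]:
    """Map a conventional-commit subject to a version bump (rules listed above)."""
    head = subject.split(":", 1)[0]            # e.g. "feat(vis)" or "feat!"
    ctype = head.rstrip("!").split("(", 1)[0]  # strip "!" and "(scope)"
    if ctype == "feat":
        return "major" if head.endswith("!") else "minor"
    if ctype == "fix":
        return "patch"
    return None                                # chore/docs/ci/test: no release

assert next_bump("fix(postprocess): orphan removal") == "patch"
assert next_bump("feat(vis): dagua renderer") == "minor"
assert next_bump("feat!: breaking change") == "major"
assert next_bump("docs(maintenance): split agent docs") is None
```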

.github/CLAUDE.md

Lines changed: 0 additions & 21 deletions
This file was deleted.

.gitignore

Lines changed: 13 additions & 0 deletions
Added after the existing `/tests/test_outputs/` entry:

```
# Architect-Worker system: ephemeral task state
.project-context/tasks/

# Python
__pycache__/
*.pyc
.mypy_cache/
.ruff_cache/
.pytest_cache/
*.egg-info/
dist/
build/
```

.project-context/architecture.md

Lines changed: 164 additions & 0 deletions
# TorchLens Architecture

## Module Map

### `torchlens/_state.py` (~208 lines)
Global toggle, session state, context managers. Single source of truth for the `_logging_enabled` bool checked by every decorated wrapper. Also stores pre-computed lookup tables, a WeakSet of prepared models, and the active ModelLog reference. **Must never import other torchlens modules** (prevents circular deps).

### `torchlens/user_funcs.py` (~664 lines)
Public API: `log_forward_pass()`, `show_model_graph()`, `validate_forward_pass()`, `get_model_metadata()`, `validate_batch_of_models_and_inputs()`. Orchestrates the two-pass strategy when selective layers are requested.

### `torchlens/constants.py` (~645 lines)
7 FIELD_ORDER tuples (canonical field sets for LayerPassLog, ModelLog, etc.) and function discovery sets (~90 IGNORED_FUNCS; ORIG_TORCH_FUNCS listing ~2000 functions to decorate).

### `torchlens/decoration/` (2 files, ~1,710 lines)
- `torch_funcs.py` — One-time decoration of ~2000 torch functions. Core interceptor with barcode nesting detection, in-place detection, DeviceContext bypass.
- `model_prep.py` — Two-phase model preparation (permanent `_prepare_model_once` + per-session `_prepare_model_session`). Module forward decorator with exhaustive/fast-path split.

### `torchlens/capture/` (7 files, ~4,960 lines)
Real-time tensor operation logging during the forward pass.
- `trace.py` — Forward-pass orchestration, session setup/cleanup
- `output_tensors.py` — Core logging: builds LayerPassLog entries, exhaustive/fast dispatch
- `source_tensors.py` — Logs input and buffer tensors as source nodes
- `tensor_tracking.py` — Barcode system, parent-child links, backward hooks
- `arg_positions.py` — O(1) tensor extraction via 3-tier lookup (639 static entries)
- `salient_args.py` — Extracts significant function args for metadata
- `flops.py` — Per-operation FLOPs computation (~290 ops)

### `torchlens/postprocess/` (6 files, ~3,179 lines)
18-step pipeline. Order is critical — many steps depend on prior output.
- `graph_traversal.py` — Steps 1-4: output layers, ancestor marking, orphan removal, distance flood
- `control_flow.py` — Steps 5-7: conditional branches (backward-only flood + AST THEN detection), module fixing, buffer cleanup
- `loop_detection.py` — Step 8: isomorphic subgraph expansion, layer assignment
- `labeling.py` — Steps 9-12: label generation, rename, trim/reorder, lookup keys
- `finalization.py` — Steps 13-18: undecorate, ParamLog, ModuleLog, LayerLog, mark complete

### `torchlens/data_classes/` (10 files, ~3,821 lines)
- `model_log.py` — ModelLog: top-level container, 70+ attrs
- `layer_pass_log.py` — LayerPassLog: per-pass entry (~85+ fields)
- `layer_log.py` — LayerLog: aggregate class grouping passes
- `buffer_log.py` — BufferLog(LayerPassLog): buffer-specific computed properties
- `module_log.py` — ModuleLog, ModulePassLog, ModuleAccessor
- `param_log.py` — ParamLog (lazy grad via `_param_ref`)
- `func_call_location.py` — Structured call stack frame with lazy properties
- `internal_types.py` — FuncExecutionContext, VisualizationOverrides
- `interface.py` — ModelLog query methods: `__getitem__`, `to_pandas()`, 7-step lookup cascade
- `cleanup.py` — Post-session teardown, cycle breaking

### `torchlens/validation/` (3 files, ~2,795 lines)
- `core.py` — BFS orchestration, forward replay, perturbation checks
- `exemptions.py` — 4 data-driven exemption registries + 16 posthoc checks
- `invariants.py` — 18 metadata invariant categories (A-R): structural + semantic

### `torchlens/visualization/` (3 files, ~2,777+ lines)
- `rendering.py` — Graphviz rendering: nodes, edges, module subgraphs, IF/THEN labels, override system
- `elk_layout.py` — ELK-based layout for large graphs, Worker thread, sfdp fallback
- `dagua_bridge.py` — ModelLog → DaguaGraph conversion for the dagua renderer

### `torchlens/utils/` (7 files, ~950 lines)
Stateless helpers: arg handling, tensor ops (safe_copy, tensor_nanequal), RNG capture/restore, barcode hashing, object introspection, display formatting, collection manipulation.
## Data Flow

```
import torchlens
  → decorate_all_once()            # wraps ~2000 torch functions permanently
  → patch_detached_references()    # patches `from torch import cos` style refs

log_forward_pass(model, input)
  → _prepare_model_once(model)     # permanent: tl_module_address, forward wrappers
  → _prepare_model_session(model)  # per-call: requires_grad, buffers, session attrs
  → active_logging(model_log)      # enables _logging_enabled toggle
  → model(input)                   # forward pass — each torch op hits decorated wrapper
      → torch_func_decorator       # barcode nesting → bottom-level ops logged
          → log_function_output_tensors_exhaustive()  # builds LayerPassLog entry
          → OR log_function_output_tensors_fast()     # reuses prior graph structure
  → postprocess(model_log)         # 18-step pipeline
      → Steps 1-4: graph cleanup (outputs, ancestors, orphans, distances)
      → Steps 5-7: control flow (conditionals, module fixing, buffer dedup)
      → Step 8: loop detection (isomorphic subgraph expansion)
      → Steps 9-12: labeling (raw→final labels, rename, reorder, lookup keys)
      → Steps 13-18: finalization (undecorate, ParamLog, ModuleLog, LayerLog)
  → return ModelLog
```

Key types flowing between modules:
- `Dict[str, Dict]` — raw tensor dict during capture (`_raw_tensor_dict` on ModelLog)
- `LayerPassLog` — per-pass tensor operation entry (~85+ fields)
- `LayerLog` — aggregate grouping passes of the same layer
- `ModuleLog` / `ModulePassLog` — per-module metadata
- `ParamLog` — per-parameter metadata with lazy gradient access
## Key Abstractions

### Toggle Architecture
A single `_logging_enabled` bool in `_state.py`. Wrappers check it on every call — when False, the cost is one branch check, negligible overhead. No re-wrapping/un-wrapping per forward pass.
### Two-Pass Strategy
When the user requests specific layers (not "all"/"none"), Pass 1 runs exhaustive to discover the full graph structure, and Pass 2 runs fast, saving only the requested activations. Counter alignment between passes is maintained via identical increment logic.
### Barcode Nesting Detection
Random 8-char barcodes distinguish bottom-level functions from wrapper functions. A barcode is set on the tensor before the call; if it is unchanged afterward, no nested torch calls ran, so the call is logged. If it changed, a nested call already logged it.
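A toy mimic of the mechanism (a plain object stands in for a tensor; the real torchlens implementation differs):

```python
# Toy mimic of barcode nesting detection: stamp a fresh barcode on the
# argument before calling; if a nested wrapped call restamped it, the
# outer call knows it was not bottom-level and skips logging.
import secrets

logged = []

class Box:
    """Stands in for a tensor that can carry a barcode attribute."""
    def __init__(self, value):
        self.value = value
        self.barcode = None

def make_wrapped(func):
    def wrapper(box):
        probe = secrets.token_hex(4)      # random 8-char barcode
        box.barcode = probe
        out = func(box)
        if box.barcode == probe:          # unchanged: no nested wrapped call ran
            logged.append(func.__name__)  # bottom-level op: log it
        return out
    return wrapper

@make_wrapped
def double(box):        # "bottom-level" op: calls no wrapped functions
    return Box(box.value * 2)

@make_wrapped
def quadruple(box):     # "wrapper" op: internally calls the wrapped double()
    return double(double(box))

result = quadruple(Box(3))
print(result.value, logged)   # 12 ['double', 'double'] — quadruple itself not logged
```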
### Operation Equivalence Types
Structural fingerprint: `{func_name}_{arg_hash}[_outindex{i}][_module{origin}]`. Used by loop detection (Step 8) to group operations into layers.
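An illustrative construction of this key format (hypothetical helper; the real arg hash inside torchlens is computed differently):

```python
# Hypothetical illustration of the equivalence-key format above.
import hashlib

def equivalence_key(func_name, args_repr, out_index=None, module_origin=None):
    """Build `{func_name}_{arg_hash}[_outindex{i}][_module{origin}]`."""
    arg_hash = hashlib.md5(args_repr.encode()).hexdigest()[:8]
    key = f"{func_name}_{arg_hash}"
    if out_index is not None:
        key += f"_outindex{out_index}"
    if module_origin is not None:
        key += f"_module{module_origin}"
    return key

a = equivalence_key("conv2d", "shape=(3,64),stride=1", module_origin="features.0")
b = equivalence_key("conv2d", "shape=(3,64),stride=1", module_origin="features.0")
c = equivalence_key("conv2d", "shape=(64,128),stride=2", module_origin="features.3")
assert a == b    # structurally identical ops land in the same group
assert a != c    # different structure (or module origin) -> different group
```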
### LayerLog Delegation
Single-pass layers: `__getattr__` delegates to `passes[1]`. Accessing a per-pass field on a multi-pass layer raises **ValueError** (not AttributeError, to avoid Python's property/__getattr__ trap).
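A sketch of the delegation pattern with hypothetical, simplified classes (the real LayerLog has far more fields):

```python
# Sketch of LayerLog delegation: single-pass access forwards to passes[1];
# multi-pass access to a per-pass field raises ValueError, not AttributeError,
# so it cannot be confused with a missing attribute by __getattr__ machinery.
class LayerLogSketch:
    def __init__(self, passes):
        self.passes = passes                     # dict: pass number -> per-pass log

    def __getattr__(self, name):
        if len(self.passes) == 1:
            return getattr(self.passes[1], name)
        raise ValueError(
            f"{name!r} is per-pass; pick a pass, e.g. layer.passes[2].{name}"
        )

class PassLogSketch:
    def __init__(self, shape):
        self.tensor_shape = shape

single = LayerLogSketch({1: PassLogSketch((2, 64))})
print(single.tensor_shape)                       # delegated to passes[1]

multi = LayerLogSketch({1: PassLogSketch((2, 64)), 2: PassLogSketch((2, 64))})
try:
    multi.tensor_shape
    caught = None
except ValueError as e:
    caught = type(e).__name__
print(caught)                                    # ValueError
```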
## Dependency Graph
```
_state.py     ← imported by everything (no outgoing torchlens imports)
constants.py  ← imported by capture/, postprocess/, data_classes/
utils/        ← imported by capture/, postprocess/, data_classes/, validation/
decoration/   → calls capture/ (via decorated wrappers)
              → reads _state.py
capture/      → creates data_classes/ entries (LayerPassLog)
              → reads _state.py, constants.py
postprocess/  → mutates data_classes/ entries
              → reads constants.py
data_classes/ → references _state.py (TYPE_CHECKING only)
validation/   → reads data_classes/, calls original torch funcs
visualization/ → reads data_classes/ (LayerLog, ModelLog)
user_funcs.py → orchestrates decoration/, capture/, postprocess/, validation/, visualization/
```
## Known Complexity

### Loop Detection (postprocess/loop_detection.py)
The most complex single module. BFS expansion of isomorphic subgraphs, iso-group refinement with direction-aware neighbor connectivity, adjacency union-find for layer assignment. Step 6's module suffix mutation makes `_rebuild_pass_assignments` necessary (not defensive). ~826 lines.

### Exhaustive/Fast-Path Split (capture/output_tensors.py)
Two parallel code paths that must maintain counter alignment. The fast path skips most metadata but must match the exhaustive path's operation ordering exactly.

### ELK Layout (visualization/elk_layout.py)
Node.js subprocess with V8 heap sizing, a Worker thread to prevent stack overflow, a stress algorithm with O(n^2) memory (NEVER use for >100k nodes), and Kahn's topological sort for seeding.

### Circular References (data_classes/)
ModelLog ↔ LayerPassLog ↔ ModelLog cycles. ModuleLog ↔ ModelLog cycles. ParamLog pins nn.Parameter. All rely on Python's cyclic GC; an explicit `cleanup()` is available.

### DeviceContext Bypass (decoration/torch_funcs.py)
Python wrappers bypass C-level TorchFunctionMode dispatch. Factory functions need manual device kwarg injection when a `torch.device('meta')` context is active (HuggingFace use case).

.project-context/conventions.md

Lines changed: 116 additions & 0 deletions
# TorchLens Conventions

## Naming

### Files & Modules
- Snake_case for all Python files
- Subpackage CLAUDE.md files document each package

### Variables & Attributes
- `tl_` prefix on tensor/module attributes during logging
- Permanent attrs (survive sessions): `tl_module_address`, `tl_module_type`
- Session attrs (cleaned per-call): `tl_source_model_log`, `tl_module_pass_num`, etc.
- `_raw_` prefix for pre-postprocessing state (e.g., `tl_tensor_label_raw`)
- `_final_` prefix for post-processed state
- `_orig_` prefix for original (pre-decoration) references
- `clean_` prefix for pre-decoration torch function imports (e.g., `clean_clone = torch.clone`)

### Labels
- Source tensors: `{type}_{num}_raw` during capture (e.g., `input_0_raw`, `buffer_1_raw`)
- Function outputs: `{type}_{num}_{counter}_raw` during capture
- Final labels: human-readable after postprocess/labeling.py (e.g., `conv2d_1_5`)
- Pass-qualified: `{label}:{pass_num}` (e.g., `conv2d_1_5:2`)
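A tiny parser for the pass-qualified form (hypothetical helper, not a torchlens API):

```python
# Illustrative parser for the label grammar above: "conv2d_1_5:2" splits
# into the base label and an integer pass number; plain labels get None.
def split_pass_qualified(label):
    base, sep, pass_num = label.partition(":")
    return (base, int(pass_num)) if sep else (base, None)

print(split_pass_qualified("conv2d_1_5"))    # ('conv2d_1_5', None)
print(split_pass_qualified("conv2d_1_5:2"))  # ('conv2d_1_5', 2)
```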
### Classes
- PascalCase: `ModelLog`, `LayerPassLog`, `LayerLog`, `BufferLog`, `ModuleLog`, `ParamLog`
- Accessors: `LayerAccessor`, `ModuleAccessor`, `ParamAccessor`, `BufferAccessor`
- Internal: `FuncExecutionContext`, `VisualizationOverrides`, `FuncCallLocation`

### Constants
- UPPER_SNAKE_CASE: `FIELD_ORDER`, `ORIG_TORCH_FUNCS`, `IGNORED_FUNCS`
- `_DEVICE_CONSTRUCTOR_NAMES`, `_ATTR_SKIP_SET` for internal sets

## Error Handling
- Validation errors: `MetadataInvariantError(check_name, message)` — named checks A through R
- LayerLog multi-pass access: raises **ValueError** (not AttributeError) to avoid Python's property/__getattr__ trap
- `salient_args.py` extractors: try-except returns `{}` on any error (failure-safe)
- Validation replay: exceptions caught and returned as None (Bug #151 — known silent pass)
- `FuncCallLocation`: lazy properties loaded via `linecache` on first access, not at construction

## Testing Patterns

### Fixtures (tests/conftest.py)
- `default_input1` through `default_input4`: `(6,3,224,224)` standard image tensors
- `zeros_input`, `ones_input`: edge-case inputs
- `vector_input` `(5,)`, `input_2d` `(5,5)`, `input_complex` `(3,3)` complex
- `small_input` `(2,3,32,32)`: fast metadata tests
- Deterministic seeding: `torch.manual_seed(0)`, `torch.use_deterministic_algorithms(True)`

### Markers
- `@pytest.mark.slow` — real-world model tests taking >5 min
- `@pytest.mark.smoke` — 18 critical-path tests for fast validation (~6s total)
- `@pytest.mark.rare` — always excluded unless `-m rare` specified

### Test Categories
- **Toy models** (`test_toy_models.py`): `validate_saved_activations()` + `show_model_graph()` for every test
- **Real-world** (`test_real_world_models.py`): `pytest.importorskip()` for optional deps
- **Metadata** (`test_metadata.py`): `log_forward_pass()` directly, assert field properties
- **Aesthetic** (`test_output_aesthetics.py`): generates PDFs for human visual inspection

### Model Definitions
All test models live in `tests/example_models.py` (~5,400 lines). New models go here.

### Output
All test outputs → `tests/test_outputs/` (gitignored):
- `reports/` — coverage, aesthetic report, profiling
- `visualizations/` — PDFs organized by model family subdirectories

## Import Order
stdlib → third-party → local (enforced by ruff)

```python
import os
from typing import Dict, List, Optional

import torch
from torch import nn

from .utils.tensor_utils import safe_copy
from ._state import _logging_enabled
```

## Documentation
- Docstring format: NumPy style
- Type hints on all functions (including internal)
- Top-level file comments on `.py` files where purpose isn't obvious
- Each subpackage has a `CLAUDE.md` with file table, key functions, gotchas, known bugs

## Git

### Commit Messages
Conventional commits for semantic-release:
```
<type>(<scope>): <description> (#<issue>)
```

Types: `fix`, `feat`, `chore`, `docs`, `ci`, `refactor`, `test`, `style`

Scopes (common): `logging`, `vis`, `postprocess`, `capture`, `validation`, `decoration`, `data`, `state`, `utils`, `ci`, `release`, `types`

### Branch Naming
- Feature branches: `codex/<task-id>` (kebab-case task IDs)
- One branch at a time besides main

### CI/CD
- `lint.yml`: ruff format + check on push/PR, auto-commits fixes
- `quality.yml`: mypy + pip-audit on push/PR
- `release.yml`: semantic-release v9 on push to main, PyPI via OIDC

## Field Management
FIELD_ORDER tuples in `constants.py` define complete field sets. When adding a new field:
1. Add to class definition (LayerPassLog, ModelLog, etc.)
2. Add to corresponding FIELD_ORDER in `constants.py`
3. Add test in `test_metadata.py`
4. Update `to_pandas()` if user-facing
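This checklist lends itself to a guard test. A miniature sketch with hypothetical names (the real tuples and classes live in `constants.py` and `data_classes/`):

```python
# Hypothetical miniature of a FIELD_ORDER sync check: every field in the
# order tuple must exist on the class, and vice versa. Names are stand-ins.
from dataclasses import dataclass, fields

LAYER_PASS_FIELD_ORDER = ("layer_label", "tensor_shape", "pass_num")

@dataclass
class LayerPassLogSketch:
    layer_label: str
    tensor_shape: tuple
    pass_num: int

def test_field_order_in_sync():
    # Fails if a field is added to the class but not the tuple, or vice versa.
    assert {f.name for f in fields(LayerPassLogSketch)} == set(LAYER_PASS_FIELD_ORDER)

test_field_order_in_sync()
print("FIELD_ORDER and class definition agree")
```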
.project-context/decisions.md

Lines changed: 44 additions & 0 deletions

# TorchLens Architectural Decisions

## 2024 — Toggle Architecture (Permanent Decoration)
Context: Originally TorchLens re-wrapped/un-wrapped torch functions on every `log_forward_pass` call.
Decision: Wrap all ~2000 torch functions once at `import torchlens` time, gate with a single `_logging_enabled` bool.
Rationale: Eliminates per-call decoration overhead (~200ms) and makes wrappers stateless. A single bool check when disabled is negligible.
Alternatives considered: Context-manager-based decoration (too slow), monkey-patching per call (fragile).

## 2024 — Global State in _state.py
Context: Decorated wrappers need access to session state (active ModelLog, toggle, etc.).
Decision: A single `_state.py` module holds all mutable state. No imports from other torchlens modules.
Rationale: Prevents circular imports. Wrappers only need to import `_state`, not heavy torchlens modules.
Alternatives considered: Thread-local state (too complex), class-based state (no benefit over module globals).

## 2025 — LayerLog Class Hierarchy (PR #92)
Context: TensorLog was both per-pass and aggregate. RolledTensorLog was a separate class for rolled views.
Decision: Split into LayerPassLog (per-pass) and LayerLog (aggregate). Eliminate RolledTensorLog.
Rationale: Clean separation of concerns. LayerLog delegates to the single-pass LayerPassLog via `__getattr__`.
Alternatives considered: Keep RolledTensorLog (too much duplication).

## 2025 — BufferLog Stays as Subclass
Context: BufferLog has `name`/`module_address` fields that don't apply to generic LayerLog.
Decision: BufferLog(LayerPassLog) keeps buffer-specific properties. Single-pass LayerLogs access them via delegation.
Rationale: LayerLog is too generic for buffer metadata. Delegation handles the single-pass common case.

## 2025 — Backward-Only Conditional Flood (Bug #88, PR #127)
Context: A bidirectional flood from terminal booleans falsely marked non-conditional children.
Decision: `_mark_conditional_branches` floods backward-only (parent_layers). AST-based THEN detection when `save_source_context=True`.
Rationale: A forward flood follows data flow, not control flow. Backward-only correctly marks ancestors of the branch decision.

## 2026 — ELK Stress Bypass for >100k Nodes (PR #132)
Context: ELK stress allocates two n^2 × 8-byte distance matrices. 100k nodes = 160GB.
Decision: >100k nodes bypass ELK entirely → Python topological layout (Kahn's algorithm, O(n+m)).
Rationale: No size guard is possible in elkjs. The old >150k stress switch was fundamentally broken.
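The 160 GB figure in the ELK entry is direct arithmetic:

```python
# Two dense n^2 distance matrices of 8-byte doubles at n = 100,000 nodes.
n = 100_000
total_bytes = 2 * n**2 * 8
print(total_bytes / 10**9, "GB")   # 160.0 GB
```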
## 2026 — Dagua Integration (Opt-In)
Context: Graphviz rendering has limitations for large graphs and interactive exploration.
Decision: Add dagua as an optional renderer (`vis_renderer="dagua"`). Graphviz remains the default.
Rationale: Dagua provides GPU-accelerated layout and a richer interaction model, but its visual semantics are still under iteration. Keep the stable default.

## 2026 — Global Undecorate Override (PR latest)
Context: Advanced users need a clean PyTorch environment for benchmarking, profiling, or debugging decorator interactions.
Decision: Expose `undecorate_all_globally()` / `redecorate_all_globally()` as an explicit user API.
Rationale: Permanent decoration is the right default, but an escape hatch is needed for power users.
