Skip to content

Commit b0bafeb

Browse files
feat(decoration): add global undecorate override
1 parent 6f9a3fe commit b0bafeb

File tree

6 files changed

+186
-4
lines changed

6 files changed

+186
-4
lines changed

TODO.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# TorchLens TODO
2+
3+
- Add a first-class "undecorate / suspend decoration" escape hatch so users can
4+
restore a clean PyTorch environment when tracing is not needed.
5+
- Explore decorating on the first `log_forward_pass(...)` instead of at import
6+
time, while still leaving torch decorated afterward once the user opts in.

tests/test_decoration.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
decorate_all_once,
2323
patch_detached_references,
2424
patch_model_instance,
25+
redecorate_all_globally,
26+
undecorate_all_globally,
2527
)
2628

2729

@@ -179,6 +181,33 @@ def test_requires_grad_restored_after_exception(self):
179181
assert param.requires_grad == orig_grads[name], f"{name} requires_grad changed"
180182

181183

184+
class TestGlobalUndecorate:
    def test_undecorate_restores_original_torch_function(self):
        """After a global undecorate, torch callables are the raw originals."""
        # Import-time decoration should already have wrapped torch.cos.
        assert getattr(torch.cos, "tl_is_decorated_function", False)
        undecorate_all_globally()
        try:
            # The wrapper marker is gone, and plain ops yield untagged tensors.
            assert not getattr(torch.cos, "tl_is_decorated_function", False)
            inp = torch.randn(4)
            out = torch.cos(inp)
            assert out.shape == inp.shape
            assert not hasattr(out, "tl_tensor_label_raw")
        finally:
            # Always reinstall wrappers so later tests see a decorated torch.
            redecorate_all_globally()

    def test_redecorate_restores_logging_after_global_undecorate(self):
        """Re-decorating after a global undecorate brings logging back."""
        undecorate_all_globally()
        try:
            assert not getattr(torch.nn.functional.relu, "tl_is_decorated_function", False)
        finally:
            redecorate_all_globally()
        assert getattr(torch.nn.functional.relu, "tl_is_decorated_function", False)
        log = log_forward_pass(SimpleModel(), torch.randn(5))
        matches = [name for name in log.layer_labels if "relu" in name.lower()]
        assert matches
209+
210+
182211
# =========================================================================
183212
# 2. Torch Functions Normal When Toggle Off
184213
# =========================================================================

torchlens/CLAUDE.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class definition and the corresponding FIELD_ORDER in constants.py.
6666
1. `_state.py` must never import other torchlens modules
6767
2. RNG state capture/restore must happen BEFORE `active_logging()` context
6868
3. `pause_logging()` must wrap any internal torch ops during logging (safe_copy, activation_postfunc, get_tensor_memory_amount)
69-
4. Decorated wrappers are permanent — never undecorated
69+
4. Decorated wrappers are permanent by default, but advanced users may call
70+
`undecorate_all_globally()` / `redecorate_all_globally()` as an explicit
71+
override when they need a clean PyTorch environment.
7072
5. Field-order constants and class definitions must stay in sync
7173
6. Step 6 module suffix mutation makes `_rebuild_pass_assignments` (Step 8) NECESSARY — not just defensive

torchlens/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .data_classes.layer_pass_log import LayerPassLog, TensorLog
3131
from .data_classes import FuncCallLocation, ModuleAccessor, ModuleLog, ModulePassLog, ParamLog
3232
from .visualization import build_render_audit, model_log_to_dagua_graph, render_model_log_with_dagua
33+
from .decoration import redecorate_all_globally, undecorate_all_globally
3334

3435
# ---- Import-time decoration (side effects) --------------------------------
3536
#

torchlens/decoration/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,17 @@
11
"""Torch function wrapping and model preparation for logging."""
2+
3+
from .torch_funcs import (
4+
decorate_all_once,
5+
patch_detached_references,
6+
patch_model_instance,
7+
redecorate_all_globally,
8+
undecorate_all_globally,
9+
)
10+
11+
__all__ = [
12+
"decorate_all_once",
13+
"patch_detached_references",
14+
"patch_model_instance",
15+
"redecorate_all_globally",
16+
"undecorate_all_globally",
17+
]

torchlens/decoration/torch_funcs.py

Lines changed: 131 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Permanent torch function wrapping: decorates all torch ops at import time with toggle-gated wrappers.
1+
"""Permanent-by-default torch function wrapping with optional global override.
22
33
This module implements the core interception mechanism for TorchLens. At ``import torchlens``
44
time, every function listed in ``ORIG_TORCH_FUNCS`` is replaced with a thin wrapper that
@@ -9,8 +9,11 @@
99
1010
Key design decisions:
1111
12-
1. **Permanent decoration** avoids the fragility of repeatedly patching/unpatching torch
13-
internals. The toggle makes this safe for production use.
12+
1. **Permanent decoration by default** avoids the fragility of repeatedly
13+
patching/unpatching torch internals. The toggle makes this safe for production
14+
use. Advanced users can still call ``undecorate_all_globally()`` to restore
15+
the original torch callables and ``redecorate_all_globally()`` to re-enable
16+
TorchLens interception.
1417
1518
2. **Shared originals reuse wrappers**: If ``torch.cos`` and ``torch._VF.cos`` point to the
1619
same C builtin, only one wrapper is created and both namespaces point to it. This keeps
@@ -584,6 +587,131 @@ def decorate_all_once():
584587
torch.identity = new_identity
585588

586589

590+
def _replace_detached_references(mapping: Dict[int, Callable]) -> None:
    """Walk every loaded module and swap callable references per ``mapping``.

    ``mapping`` is keyed by ``id()`` of the callable to replace, and may point
    in either direction (original -> decorated, or decorated -> original), so
    the same crawl both applies and reverses the detached-reference patching
    of module globals, class attributes, and callable default arguments.
    """
    if not mapping:
        return

    for module in list(sys.modules.values()):
        if module is None:
            continue
        # Never rewrite torchlens' own namespaces.
        if getattr(module, "__name__", "").startswith("torchlens"):
            continue

        try:
            namespace = vars(module)
        except TypeError:
            # Some exotic module objects expose no __dict__.
            continue

        for name, value in list(namespace.items()):
            if id(value) in mapping:
                try:
                    namespace[name] = mapping[id(value)]
                except (TypeError, KeyError):
                    pass
                # A direct hit needs no class/default-arg inspection.
                continue

            # isinstance can trigger warnings (or fail) on odd proxy objects.
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    value_is_class = isinstance(value, type)
            except Exception:
                value_is_class = False

            if value_is_class:
                try:
                    class_namespace = vars(value)
                except TypeError:
                    continue
                for member_name, member in list(class_namespace.items()):
                    if id(member) in mapping:
                        try:
                            setattr(value, member_name, mapping[id(member)])
                        except (AttributeError, TypeError):
                            # Read-only class attributes are left as-is.
                            pass

            try:
                value_is_plain_callable = callable(value) and not value_is_class
            except Exception:
                value_is_plain_callable = False
            if value_is_plain_callable:
                # Rewrite default-argument references held by the callable.
                _patch_function_defaults(value, mapping)
643+
644+
645+
def undecorate_all_globally() -> None:
    """Globally restore the original (undecorated) torch callables.

    Explicit escape hatch for advanced users who want a clean PyTorch
    environment after importing TorchLens: torch namespace attributes are
    reset, and the detached-reference crawl is reversed so that module
    globals, class attributes, and function defaults again point at the
    original callables.

    TorchLens logging will not function until ``redecorate_all_globally()``
    is called again.
    """
    # Logging cannot work without the wrappers, so shut it off up front.
    _state._logging_enabled = False
    _state._active_model_log = None

    # No wrapper table means nothing was decorated (or it is already undone).
    if not _state._decorated_to_orig:
        return

    for ns_path, func_name in ORIG_TORCH_FUNCS:
        ns_obj = nested_getattr(torch, ns_path.replace("torch.", ""))
        if not hasattr(ns_obj, func_name):
            continue
        installed = getattr(ns_obj, func_name)
        original = _state._decorated_to_orig.get(id(installed))
        if original is None:
            # The attribute is not one of our wrappers; leave it untouched.
            continue
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                setattr(ns_obj, func_name, original)
        except (AttributeError, TypeError):
            # Some namespaces reject attribute assignment; skip those.
            pass

    # Reverse the sys.modules crawl and restore the identity helper too.
    _replace_detached_references(_state._decorated_to_orig)
    torch.identity = identity
680+
681+
682+
def redecorate_all_globally() -> None:
    """Reinstall TorchLens wrappers after a prior global undecoration."""
    # No wrapper table yet: fall back to the normal first-time decoration path.
    if not _state._orig_to_decorated:
        decorate_all_once()
        patch_detached_references()
        return

    for ns_path, func_name in ORIG_TORCH_FUNCS:
        ns_obj = nested_getattr(torch, ns_path.replace("torch.", ""))
        if not hasattr(ns_obj, func_name):
            continue
        live = getattr(ns_obj, func_name)

        # Resolve which wrapper belongs here: either the attribute currently
        # holds an original (look up its wrapper), or it is already wrapped.
        if id(live) in _state._orig_to_decorated:
            wrapper = _state._orig_to_decorated[id(live)]
        elif id(live) in _state._decorated_to_orig:
            wrapper = live
        else:
            continue

        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                setattr(ns_obj, func_name, wrapper)
        except (AttributeError, TypeError):
            # Some namespaces reject attribute assignment; skip those.
            pass

    _replace_detached_references(_state._orig_to_decorated)

    # Rebuild the decorated identity helper used for tensor bookkeeping.
    wrapped_identity = torch_func_decorator(identity, "identity")
    wrapped_identity.tl_is_decorated_function = True
    torch.identity = wrapped_identity
713+
714+
587715
# ---------------------------------------------------------------------------
588716
# sys.modules deep crawl
589717
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)