|
1 | | -"""Permanent torch function wrapping: decorates all torch ops at import time with toggle-gated wrappers. |
| 1 | +"""Permanent-by-default torch function wrapping with optional global override. |
2 | 2 |
|
3 | 3 | This module implements the core interception mechanism for TorchLens. At ``import torchlens`` |
4 | 4 | time, every function listed in ``ORIG_TORCH_FUNCS`` is replaced with a thin wrapper that |
|
9 | 9 |
|
10 | 10 | Key design decisions: |
11 | 11 |
|
12 | | -1. **Permanent decoration** avoids the fragility of repeatedly patching/unpatching torch |
13 | | - internals. The toggle makes this safe for production use. |
| 12 | +1. **Permanent decoration by default** avoids the fragility of repeatedly |
| 13 | + patching/unpatching torch internals. The toggle makes this safe for production |
| 14 | + use. Advanced users can still call ``undecorate_all_globally()`` to restore |
| 15 | + the original torch callables and ``redecorate_all_globally()`` to re-enable |
| 16 | + TorchLens interception. |
14 | 17 |
|
15 | 18 | 2. **Shared originals reuse wrappers**: If ``torch.cos`` and ``torch._VF.cos`` point to the |
16 | 19 | same C builtin, only one wrapper is created and both namespaces point to it. This keeps |
@@ -584,6 +587,131 @@ def decorate_all_once(): |
584 | 587 | torch.identity = new_identity |
585 | 588 |
|
586 | 589 |
|
def _replace_detached_references(mapping: Dict[int, Callable]) -> None:
    """Walk ``sys.modules`` and swap callable references per ``mapping``.

    ``mapping`` is keyed by ``id()`` of the callable to be replaced and may
    run in either direction (original->decorated or decorated->original), so
    the identical crawl serves both decoration and undecoration: module
    globals, class attributes, and function default arguments are all
    rewritten through the same table.
    """
    if not mapping:
        return

    for module in list(sys.modules.values()):
        if module is None:
            continue
        # Never rewrite torchlens' own modules.
        if getattr(module, "__name__", "").startswith("torchlens"):
            continue

        try:
            module_namespace = vars(module)
        except TypeError:
            # Some sys.modules entries are exotic objects without a __dict__.
            continue

        for attr_name, value in list(module_namespace.items()):
            if id(value) in mapping:
                try:
                    module_namespace[attr_name] = mapping[id(value)]
                except (TypeError, KeyError):
                    pass
                continue

            # isinstance on odd proxy objects can warn or raise; treat any
            # failure as "not a class".
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    value_is_class = isinstance(value, type)
            except Exception:
                value_is_class = False

            if value_is_class:
                try:
                    class_namespace = vars(value)
                except TypeError:
                    continue
                for member_name, member in list(class_namespace.items()):
                    if id(member) in mapping:
                        try:
                            setattr(value, member_name, mapping[id(member)])
                        except (AttributeError, TypeError):
                            pass

            # Plain callables (not classes) may hold mapped callables as
            # default argument values; patch those in place.
            try:
                value_is_plain_callable = callable(value) and not value_is_class
            except Exception:
                value_is_plain_callable = False
            if value_is_plain_callable:
                _patch_function_defaults(value, mapping)
| 644 | + |
def undecorate_all_globally() -> None:
    """Globally restore the original, unwrapped torch callables.

    Explicit override for advanced users who want a clean PyTorch
    environment after importing TorchLens: torch namespace attributes are
    reset to their originals, and the detached-reference crawl is run in
    reverse so module globals, class attributes, and function defaults also
    point back at the original callables.

    TorchLens logging stops working until ``redecorate_all_globally()`` is
    called again.
    """
    # Shut logging down first so no wrapper fires mid-restoration.
    _state._logging_enabled = False
    _state._active_model_log = None

    decorated_to_orig = _state._decorated_to_orig
    if not decorated_to_orig:
        return  # nothing was ever decorated; nothing to restore

    for namespace_name, func_name in ORIG_TORCH_FUNCS:
        namespace_obj = nested_getattr(torch, namespace_name.replace("torch.", ""))
        if not hasattr(namespace_obj, func_name):
            continue
        candidate = getattr(namespace_obj, func_name)
        original = decorated_to_orig.get(id(candidate))
        if original is None:
            # Already undecorated, or replaced by third-party code; skip.
            continue
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                setattr(namespace_obj, func_name, original)
        except (AttributeError, TypeError):
            pass  # read-only namespace attribute; best effort

    _replace_detached_references(decorated_to_orig)
    torch.identity = identity
| 680 | + |
| 681 | + |
def redecorate_all_globally() -> None:
    """Reinstall TorchLens wrappers after a prior global undecoration.

    If no wrapper table exists yet (first-ever decoration), falls back to
    the normal import-time path. Otherwise, reuses the previously built
    wrappers, reinstalls them on the torch namespaces, and re-runs the
    detached-reference crawl in the forward direction. Note that this does
    not itself flip ``_state._logging_enabled``.
    """
    orig_to_decorated = _state._orig_to_decorated
    if not orig_to_decorated:
        # No wrappers have ever been created: do a full first-time pass.
        decorate_all_once()
        patch_detached_references()
        return

    for namespace_name, func_name in ORIG_TORCH_FUNCS:
        namespace_obj = nested_getattr(torch, namespace_name.replace("torch.", ""))
        if not hasattr(namespace_obj, func_name):
            continue
        current = getattr(namespace_obj, func_name)
        if id(current) in orig_to_decorated:
            wrapper = orig_to_decorated[id(current)]
        elif id(current) in _state._decorated_to_orig:
            wrapper = current  # already decorated; reinstall as-is
        else:
            continue  # unknown callable (e.g. third-party patch); leave it
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                setattr(namespace_obj, func_name, wrapper)
        except (AttributeError, TypeError):
            pass  # read-only namespace attribute; best effort

    _replace_detached_references(orig_to_decorated)

    # torch.identity is TorchLens-specific and always rebuilt fresh.
    wrapped_identity = torch_func_decorator(identity, "identity")
    wrapped_identity.tl_is_decorated_function = True
    torch.identity = wrapped_identity
| 713 | + |
| 714 | + |
587 | 715 | # --------------------------------------------------------------------------- |
588 | 716 | # sys.modules deep crawl |
589 | 717 | # --------------------------------------------------------------------------- |
|
0 commit comments