|
35 | 35 | automatically. We detect active ``DeviceContext`` and inject the kwarg ourselves. |
36 | 36 | """ |
37 | 37 |
|
| 38 | +import ctypes |
38 | 39 | import inspect |
39 | 40 | import sys |
40 | 41 | import time |
|
61 | 62 | from ..data_classes.model_log import ModelLog |
62 | 63 |
|
63 | 64 |
|
| 65 | +# --------------------------------------------------------------------------- |
| 66 | +# CPython slot fixup for Tensor sequence protocol |
| 67 | +# --------------------------------------------------------------------------- |
| 68 | +# |
| 69 | +# When __getitem__ is replaced on a C extension type (like torch.Tensor) with |
| 70 | +# a Python function, CPython sets the sq_item slot in tp_as_sequence. This |
| 71 | +# makes PySequence_Check(tensor) return True, which causes torch.tensor() to |
| 72 | +# try iterating 0-d tensor elements as sequences -- calling len() which raises |
| 73 | +# TypeError. The sq_item slot is NEVER cleared by restoring the original |
| 74 | +# wrapper_descriptor or by delattr, because CPython's update_one_slot only |
| 75 | +# restores the exact slot the wrapper_descriptor wraps (mp_subscript), not |
| 76 | +# the collateral sq_item slot. |
| 77 | +# |
| 78 | +# We fix this by nulling sq_item directly via ctypes after any decoration or |
| 79 | +# undecoration cycle. This is safe because tensor indexing uses mp_subscript |
| 80 | +# (mapping protocol), not sq_item (sequence protocol). |
| 81 | + |
| 82 | + |
class _PySequenceMethods(ctypes.Structure):
    """Minimal ctypes mirror of CPython's ``PySequenceMethods`` struct.

    Field order and count must match CPython's C definition exactly
    (``Include/cpython/object.h``); every member is pointer-sized, so the
    layout is identical on 32- and 64-bit builds.  Only ``sq_item`` is
    ever touched here, but all fields are declared so offsets line up.
    """

    _fields_ = [
        ("sq_length", ctypes.c_void_p),
        ("sq_concat", ctypes.c_void_p),
        ("sq_repeat", ctypes.c_void_p),
        ("sq_item", ctypes.c_void_p),            # the slot cleared by the fixup
        ("was_sq_slice", ctypes.c_void_p),       # historical placeholder kept by CPython
        ("sq_ass_item", ctypes.c_void_p),
        ("was_sq_ass_slice", ctypes.c_void_p),   # historical placeholder kept by CPython
        ("sq_contains", ctypes.c_void_p),
        ("sq_inplace_concat", ctypes.c_void_p),
        ("sq_inplace_repeat", ctypes.c_void_p),
    ]
| 98 | + |
| 99 | + |
class _PyTypeObject(ctypes.Structure):
    """Partial ctypes mirror of CPython's ``PyTypeObject`` up to ``tp_as_sequence``.

    Layout is stable across CPython 3.8+ (``tp_vectorcall_offset`` replaced
    ``tp_print`` in 3.8; all earlier fields are pointer-sized regardless).

    Only the fields up to ``tp_as_mapping`` are declared — enough to reach
    ``tp_as_sequence``; the struct deliberately stops there.  ``tp_name``
    doubles as a layout sanity check for callers (it must read back as the
    expected type name before any slot is written).
    """

    _fields_ = [
        # PyObject_VAR_HEAD
        ("ob_refcnt", ctypes.c_ssize_t),
        ("ob_type", ctypes.c_void_p),
        ("ob_size", ctypes.c_ssize_t),
        # Type-object fields, in declaration order from Include/object.h
        ("tp_name", ctypes.c_char_p),
        ("tp_basicsize", ctypes.c_ssize_t),
        ("tp_itemsize", ctypes.c_ssize_t),
        ("tp_dealloc", ctypes.c_void_p),
        ("tp_vectorcall_offset", ctypes.c_ssize_t),  # was tp_print before 3.8
        ("tp_getattr", ctypes.c_void_p),
        ("tp_setattr", ctypes.c_void_p),
        ("tp_as_async", ctypes.c_void_p),
        ("tp_repr", ctypes.c_void_p),
        ("tp_as_number", ctypes.c_void_p),
        # Typed pointer so .contents exposes the sequence slots directly.
        ("tp_as_sequence", ctypes.POINTER(_PySequenceMethods)),
        ("tp_as_mapping", ctypes.c_void_p),
    ]
| 124 | + |
| 125 | + |
| 126 | +def _fix_tensor_sequence_slot() -> None: |
| 127 | + """Clear the stale sq_item C slot on torch.Tensor after dunder changes. |
| 128 | +
|
| 129 | + Wrapping ``__getitem__`` on a C extension type pollutes the ``sq_item`` |
| 130 | + slot in ``tp_as_sequence``, making ``PySequence_Check(tensor)`` return |
| 131 | + ``True``. This breaks ``torch.tensor([0-d_tensor, ...])`` because the |
| 132 | + C code then calls ``len()`` on each element. Clearing ``sq_item`` to |
| 133 | + NULL restores the clean-state behavior where tensors are NOT treated as |
| 134 | + sequences. Tensor indexing is unaffected because it goes through |
| 135 | + ``mp_subscript`` (mapping protocol). |
| 136 | +
|
| 137 | + Safe to call multiple times. Fails silently on non-CPython or if the |
| 138 | + struct layout doesn't match (verified via ``tp_name``). |
| 139 | + """ |
| 140 | + if sys.implementation.name != "cpython": |
| 141 | + return |
| 142 | + try: |
| 143 | + type_obj = _PyTypeObject.from_address(id(torch.Tensor)) |
| 144 | + # Verify struct layout by checking tp_name |
| 145 | + if type_obj.tp_name != b"Tensor": |
| 146 | + return |
| 147 | + if type_obj.tp_as_sequence: |
| 148 | + type_obj.tp_as_sequence.contents.sq_item = None |
| 149 | + except Exception: |
| 150 | + pass # Best-effort; non-CPython or unexpected layout |
| 151 | + |
| 152 | + |
64 | 153 | def _is_inside_functorch_transform() -> bool: |
65 | 154 | """Return True if inside a vmap/grad/etc. functorch transform.""" |
66 | 155 | try: |
@@ -591,6 +680,11 @@ def decorate_all_once(): |
591 | 680 | _state._decorated_identity = torch_func_decorator(identity, "identity") |
592 | 681 | _state._is_decorated = True |
593 | 682 |
|
| 683 | + # Wrapping __getitem__ on torch.Tensor pollutes the C-level sq_item slot, |
| 684 | + # making PySequence_Check(tensor) return True. Clear it so torch.tensor() |
| 685 | + # doesn't try to iterate 0-d tensor elements as sequences. |
| 686 | + _fix_tensor_sequence_slot() |
| 687 | + |
594 | 688 |
|
595 | 689 | def _replace_detached_references(mapping: Dict[int, Callable]) -> None: |
596 | 690 | """Crawl ``sys.modules`` and replace callable references using ``mapping``. |
@@ -682,6 +776,9 @@ def unwrap_torch() -> None: |
682 | 776 | _replace_detached_references(_state._decorated_to_orig) |
683 | 777 | _state._is_decorated = False |
684 | 778 |
|
| 779 | + # Restoring Tensor.__getitem__ doesn't clear the stale sq_item slot. |
| 780 | + _fix_tensor_sequence_slot() |
| 781 | + |
685 | 782 |
|
686 | 783 | def wrap_torch() -> None: |
687 | 784 | """Install (or re-install) torchlens wrappers on all torch functions. |
@@ -729,6 +826,9 @@ def wrap_torch() -> None: |
729 | 826 | _state._is_decorated = True |
730 | 827 | patch_detached_references() |
731 | 828 |
|
| 829 | + # Re-wrapping __getitem__ pollutes sq_item again; clear it. |
| 830 | + _fix_tensor_sequence_slot() |
| 831 | + |
732 | 832 |
|
733 | 833 | @contextmanager |
734 | 834 | def wrapped(): |
|
0 commit comments