Skip to content

Commit b2c6085

Browse files
fix(decoration): clear stale sq_item C slot after wrapping Tensor.__getitem__
When __getitem__ is replaced on a C extension type with a Python function, CPython sets the sq_item slot in tp_as_sequence. This makes PySequence_Check(tensor) return True (was False in clean PyTorch), causing torch.tensor([0-d_tensor, ...]) to iterate elements as sequences and call len() -- which raises TypeError for 0-d tensors. The slot is never cleared by restoring the original wrapper_descriptor or by delattr. Fix: null sq_item via ctypes after every decoration/undecoration cycle (decorate_all_once, unwrap_torch, wrap_torch). Safe because tensor indexing uses mp_subscript (mapping protocol), not sq_item (sequence protocol). Verified via tp_name guard; fails silently on non-CPython. Adds 9 regression tests covering all lifecycle paths.
1 parent 84d4644 commit b2c6085

File tree

2 files changed

+193
-0
lines changed

2 files changed

+193
-0
lines changed

tests/test_decoration.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,99 @@ def test_wrap_restores_logging_after_unwrap(self):
317317
assert relu_layers
318318

319319

320+
# =========================================================================
# 1b. Sequence Slot Fix (sq_item pollution from __getitem__ wrapping)
# =========================================================================


class TestSequenceSlotFix:
    """Regression tests for the sq_item C-slot fix.

    Wrapping ``Tensor.__getitem__`` with a Python function fills in CPython's
    ``sq_item`` slot, which flips ``PySequence_Check(tensor)`` to True and
    breaks ``torch.tensor()`` on lists of 0-d tensors. Each test below pins
    the fix on a different decoration lifecycle path.
    """

    @pytest.fixture(autouse=True)
    def _ensure_wrapped(self):
        """Every test begins wrapped and the suite is re-wrapped afterwards."""
        wrap_torch()
        yield
        wrap_torch()

    @staticmethod
    def _check_tensor_from_0d(msg: str = "") -> None:
        """Build a 1-d tensor from two 0-d scalars; any TypeError fails it."""
        src = torch.randn(3, 3)
        scalars = [src[0, 0], src[1, 1]]
        built = torch.tensor(scalars)
        assert built.shape == (2,), f"wrong shape {msg}"
        assert torch.allclose(built, torch.stack(scalars)), msg

    @pytest.mark.smoke
    def test_tensor_from_0d_while_wrapped(self):
        """Constructing from 0-d tensors works with decoration active."""
        self._check_tensor_from_0d("while wrapped")

    @pytest.mark.smoke
    def test_tensor_from_0d_after_unwrap(self):
        """Constructing from 0-d tensors works once unwrap_torch() has run."""
        unwrap_torch()
        self._check_tensor_from_0d("after unwrap")

    def test_tensor_from_0d_after_wrap_unwrap_cycle(self):
        """The fix holds through repeated unwrap/wrap cycles."""
        for cycle in range(3):
            unwrap_torch()
            self._check_tensor_from_0d(f"unwrap cycle {cycle}")
            wrap_torch()
            self._check_tensor_from_0d(f"wrap cycle {cycle}")

    def test_tensor_from_0d_wrapped_context_manager(self):
        """The fix holds both inside the wrapped() context and after exit."""
        unwrap_torch()
        with wrapped():
            self._check_tensor_from_0d("inside wrapped()")
        self._check_tensor_from_0d("after wrapped() exit")

    def test_tensor_from_0d_after_forward_pass(self):
        """The fix survives an actual logged forward pass."""
        model = SimpleModel()
        log_forward_pass(model, torch.randn(5))
        self._check_tensor_from_0d("after forward pass")

    def test_tensor_from_0d_after_unwrap_when_done(self):
        """The fix survives log_forward_pass(..., unwrap_when_done=True)."""
        model = SimpleModel()
        log_forward_pass(model, torch.randn(5), unwrap_when_done=True)
        self._check_tensor_from_0d("after unwrap_when_done")

    def test_tensor_from_0d_nested_list(self):
        """Nested lists of 0-d tensors round-trip through torch.tensor."""
        src = torch.randn(2, 2)
        nested = [[src[0, 0], src[0, 1]], [src[1, 0], src[1, 1]]]
        rebuilt = torch.tensor(nested)
        assert rebuilt.shape == (2, 2)
        assert torch.allclose(rebuilt, src)

    def test_tensor_indexing_still_works(self):
        """Nulling sq_item must leave ordinary indexing untouched."""
        x = torch.randn(3, 4, 5)
        assert x[0].shape == (4, 5)
        assert x[0, 1].shape == (5,)
        assert x[0, 1, 2].shape == ()
        assert x[:, 1:3].shape == (3, 2, 5)
        assert x[torch.tensor([0, 2])].shape == (2, 4, 5)

    def test_sequence_check_false_for_tensors(self):
        """PySequence_Check must return False for tensors after decoration."""
        if sys.implementation.name != "cpython":
            pytest.skip("ctypes slot check only works on CPython")
        import ctypes

        seq_check = ctypes.pythonapi.PySequence_Check
        seq_check.argtypes = [ctypes.py_object]
        seq_check.restype = ctypes.c_int
        scalar = torch.tensor(0.5)
        assert seq_check(scalar) == 0, "PySequence_Check(tensor) should be False"
411+
412+
320413
# =========================================================================
321414
# 2. Torch Functions Normal When Toggle Off
322415
# =========================================================================

torchlens/decoration/torch_funcs.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
automatically. We detect active ``DeviceContext`` and inject the kwarg ourselves.
3636
"""
3737

38+
import ctypes
3839
import inspect
3940
import sys
4041
import time
@@ -61,6 +62,94 @@
6162
from ..data_classes.model_log import ModelLog
6263

6364

65+
# ---------------------------------------------------------------------------
66+
# CPython slot fixup for Tensor sequence protocol
67+
# ---------------------------------------------------------------------------
68+
#
69+
# When __getitem__ is replaced on a C extension type (like torch.Tensor) with
70+
# a Python function, CPython sets the sq_item slot in tp_as_sequence. This
71+
# makes PySequence_Check(tensor) return True, which causes torch.tensor() to
72+
# try iterating 0-d tensor elements as sequences -- calling len() which raises
73+
# TypeError. The sq_item slot is NEVER cleared by restoring the original
74+
# wrapper_descriptor or by delattr, because CPython's update_one_slot only
75+
# restores the exact slot the wrapper_descriptor wraps (mp_subscript), not
76+
# the collateral sq_item slot.
77+
#
78+
# We fix this by nulling sq_item directly via ctypes after any decoration or
79+
# undecoration cycle. This is safe because tensor indexing uses mp_subscript
80+
# (mapping protocol), not sq_item (sequence protocol).
81+
82+
83+
class _PySequenceMethods(ctypes.Structure):
84+
"""Minimal ctypes mirror of CPython's PySequenceMethods struct."""
85+
86+
_fields_ = [
87+
("sq_length", ctypes.c_void_p),
88+
("sq_concat", ctypes.c_void_p),
89+
("sq_repeat", ctypes.c_void_p),
90+
("sq_item", ctypes.c_void_p),
91+
("was_sq_slice", ctypes.c_void_p),
92+
("sq_ass_item", ctypes.c_void_p),
93+
("was_sq_ass_slice", ctypes.c_void_p),
94+
("sq_contains", ctypes.c_void_p),
95+
("sq_inplace_concat", ctypes.c_void_p),
96+
("sq_inplace_repeat", ctypes.c_void_p),
97+
]
98+
99+
100+
class _PyTypeObject(ctypes.Structure):
101+
"""Partial ctypes mirror of CPython's PyTypeObject up to tp_as_sequence.
102+
103+
Layout is stable across CPython 3.8+ (tp_vectorcall_offset replaced
104+
tp_print in 3.8; all earlier fields are pointer-sized regardless).
105+
"""
106+
107+
_fields_ = [
108+
("ob_refcnt", ctypes.c_ssize_t),
109+
("ob_type", ctypes.c_void_p),
110+
("ob_size", ctypes.c_ssize_t),
111+
("tp_name", ctypes.c_char_p),
112+
("tp_basicsize", ctypes.c_ssize_t),
113+
("tp_itemsize", ctypes.c_ssize_t),
114+
("tp_dealloc", ctypes.c_void_p),
115+
("tp_vectorcall_offset", ctypes.c_ssize_t),
116+
("tp_getattr", ctypes.c_void_p),
117+
("tp_setattr", ctypes.c_void_p),
118+
("tp_as_async", ctypes.c_void_p),
119+
("tp_repr", ctypes.c_void_p),
120+
("tp_as_number", ctypes.c_void_p),
121+
("tp_as_sequence", ctypes.POINTER(_PySequenceMethods)),
122+
("tp_as_mapping", ctypes.c_void_p),
123+
]
124+
125+
126+
def _fix_tensor_sequence_slot() -> None:
127+
"""Clear the stale sq_item C slot on torch.Tensor after dunder changes.
128+
129+
Wrapping ``__getitem__`` on a C extension type pollutes the ``sq_item``
130+
slot in ``tp_as_sequence``, making ``PySequence_Check(tensor)`` return
131+
``True``. This breaks ``torch.tensor([0-d_tensor, ...])`` because the
132+
C code then calls ``len()`` on each element. Clearing ``sq_item`` to
133+
NULL restores the clean-state behavior where tensors are NOT treated as
134+
sequences. Tensor indexing is unaffected because it goes through
135+
``mp_subscript`` (mapping protocol).
136+
137+
Safe to call multiple times. Fails silently on non-CPython or if the
138+
struct layout doesn't match (verified via ``tp_name``).
139+
"""
140+
if sys.implementation.name != "cpython":
141+
return
142+
try:
143+
type_obj = _PyTypeObject.from_address(id(torch.Tensor))
144+
# Verify struct layout by checking tp_name
145+
if type_obj.tp_name != b"Tensor":
146+
return
147+
if type_obj.tp_as_sequence:
148+
type_obj.tp_as_sequence.contents.sq_item = None
149+
except Exception:
150+
pass # Best-effort; non-CPython or unexpected layout
151+
152+
64153
def _is_inside_functorch_transform() -> bool:
65154
"""Return True if inside a vmap/grad/etc. functorch transform."""
66155
try:
@@ -591,6 +680,11 @@ def decorate_all_once():
591680
_state._decorated_identity = torch_func_decorator(identity, "identity")
592681
_state._is_decorated = True
593682

683+
# Wrapping __getitem__ on torch.Tensor pollutes the C-level sq_item slot,
684+
# making PySequence_Check(tensor) return True. Clear it so torch.tensor()
685+
# doesn't try to iterate 0-d tensor elements as sequences.
686+
_fix_tensor_sequence_slot()
687+
594688

595689
def _replace_detached_references(mapping: Dict[int, Callable]) -> None:
596690
"""Crawl ``sys.modules`` and replace callable references using ``mapping``.
@@ -682,6 +776,9 @@ def unwrap_torch() -> None:
682776
_replace_detached_references(_state._decorated_to_orig)
683777
_state._is_decorated = False
684778

779+
# Restoring Tensor.__getitem__ doesn't clear the stale sq_item slot.
780+
_fix_tensor_sequence_slot()
781+
685782

686783
def wrap_torch() -> None:
687784
"""Install (or re-install) torchlens wrappers on all torch functions.
@@ -729,6 +826,9 @@ def wrap_torch() -> None:
729826
_state._is_decorated = True
730827
patch_detached_references()
731828

829+
# Re-wrapping __getitem__ pollutes sq_item again; clear it.
830+
_fix_tensor_sequence_slot()
831+
732832

733833
@contextmanager
734834
def wrapped():

0 commit comments

Comments
 (0)