From a83184352c68051d21de03f057abd102d1b2f87f Mon Sep 17 00:00:00 2001 From: Lakshit Bahl Date: Fri, 3 Jul 2026 02:39:46 +0200 Subject: [PATCH] [#28420] Fix type inference for closures over local variables --- sdks/python/apache_beam/typehints/opcodes.py | 18 ++--- .../typehints/trivial_inference.py | 73 ++++++++++++++----- .../typehints/trivial_inference_test.py | 64 +++++++++++++++- 3 files changed, 126 insertions(+), 29 deletions(-) diff --git a/sdks/python/apache_beam/typehints/opcodes.py b/sdks/python/apache_beam/typehints/opcodes.py index 53eabdadc4af..2e74c9156650 100644 --- a/sdks/python/apache_beam/typehints/opcodes.py +++ b/sdks/python/apache_beam/typehints/opcodes.py @@ -40,6 +40,7 @@ from apache_beam.typehints import typehints from apache_beam.typehints.trivial_inference import BoundMethod from apache_beam.typehints.trivial_inference import Const +from apache_beam.typehints.trivial_inference import _TypeInCell from apache_beam.typehints.trivial_inference import element_type from apache_beam.typehints.trivial_inference import key_value_types from apache_beam.typehints.trivial_inference import resolve_dataclass_field_type @@ -570,18 +571,17 @@ def gen_start(state, arg): def load_closure(state, arg): - # The arg is no longer offset by len(covar_names) as of 3.11 + # closure_type performs version-aware resolution of the index: as of 3.11 + # it refers to the frame's localsplus storage (in which the cell of a + # captured parameter shares the parameter's slot) rather than + # co_cellvars + co_freevars. # See https://docs.python.org/3/library/dis.html#opcode-LOAD_CLOSURE - if (sys.version_info.major, sys.version_info.minor) >= (3, 11): - arg -= len(state.co.co_varnames) state.stack.append(state.closure_type(arg)) def load_deref(state, arg): - # The arg is no longer offset by len(covar_names) as of 3.11 - # See https://docs.python.org/3/library/dis.html#opcode-LOAD_DEREF - if (sys.version_info.major, sys.version_info.minor) >= (3, 11): - arg -= len(state.co.co_varnames) + # See load_closure above for how the index is resolved. + # https://docs.python.org/3/library/dis.html#opcode-LOAD_DEREF state.stack.append(state.closure_type(arg)) @@ -615,7 +615,7 @@ def make_function(state, arg): closureTuplePos = -2 else: closureTuplePos = -3 - closure = tuple((lambda _: lambda: _)(t).__closure__[0] + closure = tuple((lambda _: lambda: _)(_TypeInCell(t)).__closure__[0] for t in state.stack[closureTuplePos].tuple_types) func = types.FunctionType(func_code, globals, name=func_name, closure=closure) @@ -629,7 +629,7 @@ def set_function_attribute(state, arg): attr = state.stack.pop().value closure = None if arg & 0x08: - closure = tuple((lambda _: lambda: _)(t).__closure__[0] + closure = tuple((lambda _: lambda: _)(_TypeInCell(t)).__closure__[0] for t in state.stack[attr].tuple_types) new_func = types.FunctionType( func.code, func.globals, name=func.name, closure=closure) diff --git a/sdks/python/apache_beam/typehints/trivial_inference.py b/sdks/python/apache_beam/typehints/trivial_inference.py index 69edfc309281..f8740ab976ec 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference.py +++ b/sdks/python/apache_beam/typehints/trivial_inference.py @@ -123,6 +123,19 @@ def unwrap_all(xs): return [Const.unwrap(x) for x in xs] +class _TypeInCell(object): + """Marker for an inferred type stored in a synthetic closure cell. + + When MAKE_FUNCTION is emulated, the closure cells of the function it + creates are populated with the inferred types of the captured variables + rather than with actual runtime values. This wrapper marks such cells so + that FrameState.closure_type can distinguish them from the cells of a real + closure, which hold runtime values. + """ + def __init__(self, value): + self.value = value + + class FrameState(object): """Stores the state of the frame at a particular point of execution. """ @@ -145,20 +158,42 @@ def copy(self): def const_type(self, i): return Const(self.co.co_consts[i]) - def get_closure(self, i): - num_cellvars = len(self.co.co_cellvars) - if i < num_cellvars: - return self.vars[i] - else: - return self.f.__closure__[i - num_cellvars].cell_contents - def closure_type(self, i): - """Returns a TypeConstraint or Const.""" - val = self.get_closure(i) - if isinstance(val, typehints.TypeConstraint): - return val + """Returns the type of the cell or free variable with the given index. + + The index is the raw oparg of a LOAD_CLOSURE or LOAD_DEREF instruction. + For Python < 3.11 it indexes co_cellvars + co_freevars. From Python 3.11 + on it indexes the frame's "fast locals" (localsplus) storage, in which + the cell of a captured parameter shares the slot of the parameter + itself, so the slot layout is co_varnames, followed by the cell + variables that are not parameters, followed by the free variables. + """ + if sys.version_info >= (3, 11): + names = self.co.co_varnames + tuple( + c for c in self.co.co_cellvars + if c not in self.co.co_varnames) + self.co.co_freevars else: + names = self.co.co_cellvars + self.co.co_freevars + name = names[i] + if name in self.co.co_freevars: + # A free variable: its cell belongs to the function's closure. + val = self.f.__closure__[self.co.co_freevars.index(name)].cell_contents + if isinstance(val, _TypeInCell): + # A synthetic cell produced while emulating MAKE_FUNCTION: it holds + # the inferred type of the captured variable, not an actual runtime + # value. + return val.value return Const(val) + try: + # A cell variable of the current frame. The frame state tracks the + # inferred *type* of each local variable rather than its value, so the + # tracked type can be returned directly. + return self.vars[self.co.co_varnames.index(name)] + except ValueError: + # The cell variable does not correspond to a tracked local variable + # (e.g. it is only ever assigned via STORE_DEREF, which is not + # modeled), so its type is unknown. + return typehints.Any def get_global(self, i): name = self.get_name(i) @@ -468,13 +503,15 @@ def infer_return_type_func(f, input_types, debug=False, depth=0): print('(' + dis.cmp_op[arg] + ')', end=' ') elif op in dis.hasfree: if free is None: - free = co.co_cellvars + co.co_freevars - # From 3.11 on the arg is no longer offset by len(co_varnames) - # so we adjust it back - print_arg = arg - if (sys.version_info.major, sys.version_info.minor) >= (3, 11): - print_arg = arg - len(co.co_varnames) - print('(' + free[print_arg] + ')', end=' ') + # From 3.11 on the arg indexes the localsplus storage, in which + # the cell of a captured parameter shares the parameter's slot. + if (sys.version_info.major, sys.version_info.minor) >= (3, 11): + free = co.co_varnames + tuple( + c for c in co.co_cellvars + if c not in co.co_varnames) + co.co_freevars + else: + free = co.co_cellvars + co.co_freevars + print('(' + free[arg] + ')', end=' ') # Actually emulate the op. if state is None and states[start] is None: diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index dcb0bac97e80..763a08bb89b8 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -20,6 +20,7 @@ # pytype: skip-file import dataclasses +import sys import types import unittest @@ -276,7 +277,11 @@ def testClosure(self): y = 1.0 self.assertReturnType(typehints.Tuple[int, float], lambda: (x, y)) - @unittest.skip("https://github.com/apache/beam/issues/28420") + @unittest.skipIf( + sys.version_info >= (3, 13), + 'MAKE_FUNCTION closure emulation is not yet supported from Python ' + '3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE ' + 'after function creation: https://github.com/apache/beam/issues/28420') def testLocalClosure(self): self.assertReturnType( typehints.Tuple[int, int], lambda x: (x, (lambda: x)()), [int]) @@ -515,5 +520,60 @@ class MyDataClass: "lambda x: (x.id, x.name, x.tags, x.custom)"), [MyDataClass]) +class ClosureTypeInferenceTest(unittest.TestCase): + def assertReturnType(self, expected, f, inputs=(), depth=5): + self.assertEqual( + expected, + trivial_inference.infer_return_type( + f, inputs, debug=False, depth=depth)) + + @unittest.skipIf( + sys.version_info >= (3, 13), + 'MAKE_FUNCTION closure emulation is not yet supported from Python ' + '3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE ' + 'after function creation: https://github.com/apache/beam/issues/28420') + def testClosureCallingCapturedArgument(self): + # https://github.com/apache/beam/issues/28420 + self.assertReturnType( + typehints.Tuple[int, int], lambda x: (x, (lambda: x)()), [int]) + + @unittest.skipIf( + sys.version_info >= (3, 13), + 'MAKE_FUNCTION closure emulation is not yet supported from Python ' + '3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE ' + 'after function creation: https://github.com/apache/beam/issues/28420') + def testClosureCallingCapturedArgumentOnly(self): + self.assertReturnType(int, lambda x: (lambda: x)(), [int]) + + @unittest.skipIf( + sys.version_info >= (3, 13), + 'MAKE_FUNCTION closure emulation is not yet supported from Python ' + '3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE ' + 'after function creation: https://github.com/apache/beam/issues/28420') + def testNestedClosureCallingCapturedArgument(self): + self.assertReturnType(int, lambda x: (lambda: (lambda: x)())(), [int]) + + @unittest.skipIf( + sys.version_info >= (3, 13), + 'MAKE_FUNCTION closure emulation is not yet supported from Python ' + '3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE ' + 'after function creation: https://github.com/apache/beam/issues/28420') + def testClosureCapturedArgumentMixedWithParameter(self): + self.assertReturnType( + typehints.Tuple[int, str], lambda x, y: (lambda z: (x, z))(y), + [int, str]) + + def testRealClosureCellsHoldValues(self): + # Cells of an actual (non-emulated) closure contain runtime values and + # must still be inferred as constants of the value's type. + v = 123 + self.assertReturnType(int, lambda: v) + + def testRealClosureCellHoldingClass(self): + # A captured class is a constant and calling it constructs an instance. + cls = int + self.assertReturnType(int, lambda: cls('7')) + + if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file