Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions sdks/python/apache_beam/typehints/opcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from apache_beam.typehints import typehints
from apache_beam.typehints.trivial_inference import BoundMethod
from apache_beam.typehints.trivial_inference import Const
from apache_beam.typehints.trivial_inference import _TypeInCell
from apache_beam.typehints.trivial_inference import element_type
from apache_beam.typehints.trivial_inference import key_value_types
from apache_beam.typehints.trivial_inference import resolve_dataclass_field_type
Expand Down Expand Up @@ -570,18 +571,17 @@ def gen_start(state, arg):


def load_closure(state, arg):
# The arg is no longer offset by len(covar_names) as of 3.11
# closure_type performs version-aware resolution of the index: as of 3.11
# it refers to the frame's localsplus storage (in which the cell of a
# captured parameter shares the parameter's slot) rather than
# co_cellvars + co_freevars.
# See https://docs.python.org/3/library/dis.html#opcode-LOAD_CLOSURE
if (sys.version_info.major, sys.version_info.minor) >= (3, 11):
arg -= len(state.co.co_varnames)
state.stack.append(state.closure_type(arg))


def load_deref(state, arg):
# The arg is no longer offset by len(covar_names) as of 3.11
# See https://docs.python.org/3/library/dis.html#opcode-LOAD_DEREF
if (sys.version_info.major, sys.version_info.minor) >= (3, 11):
arg -= len(state.co.co_varnames)
# See load_closure above for how the index is resolved.
# https://docs.python.org/3/library/dis.html#opcode-LOAD_DEREF
state.stack.append(state.closure_type(arg))


Expand Down Expand Up @@ -615,7 +615,7 @@ def make_function(state, arg):
closureTuplePos = -2
else:
closureTuplePos = -3
closure = tuple((lambda _: lambda: _)(t).__closure__[0]
closure = tuple((lambda _: lambda: _)(_TypeInCell(t)).__closure__[0]
for t in state.stack[closureTuplePos].tuple_types)

func = types.FunctionType(func_code, globals, name=func_name, closure=closure)
Expand All @@ -629,7 +629,7 @@ def set_function_attribute(state, arg):
attr = state.stack.pop().value
closure = None
if arg & 0x08:
closure = tuple((lambda _: lambda: _)(t).__closure__[0]
closure = tuple((lambda _: lambda: _)(_TypeInCell(t)).__closure__[0]
for t in state.stack[attr].tuple_types)
new_func = types.FunctionType(
func.code, func.globals, name=func.name, closure=closure)
Expand Down
73 changes: 55 additions & 18 deletions sdks/python/apache_beam/typehints/trivial_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,19 @@ def unwrap_all(xs):
return [Const.unwrap(x) for x in xs]


class _TypeInCell(object):
"""Marker for an inferred type stored in a synthetic closure cell.

When MAKE_FUNCTION is emulated, the closure cells of the function it
creates are populated with the inferred types of the captured variables
rather than with actual runtime values. This wrapper marks such cells so
that FrameState.closure_type can distinguish them from the cells of a real
closure, which hold runtime values.
"""
def __init__(self, value):
self.value = value


class FrameState(object):
"""Stores the state of the frame at a particular point of execution.
"""
Expand All @@ -145,20 +158,42 @@ def copy(self):
def const_type(self, i):
return Const(self.co.co_consts[i])

def get_closure(self, i):
num_cellvars = len(self.co.co_cellvars)
if i < num_cellvars:
return self.vars[i]
else:
return self.f.__closure__[i - num_cellvars].cell_contents

def closure_type(self, i):
"""Returns a TypeConstraint or Const."""
val = self.get_closure(i)
if isinstance(val, typehints.TypeConstraint):
return val
"""Returns the type of the cell or free variable with the given index.

The index is the raw oparg of a LOAD_CLOSURE or LOAD_DEREF instruction.
For Python < 3.11 it indexes co_cellvars + co_freevars. From Python 3.11
on it indexes the frame's "fast locals" (localsplus) storage, in which
the cell of a captured parameter shares the slot of the parameter
itself, so the slot layout is co_varnames, followed by the cell
variables that are not parameters, followed by the free variables.
"""
if sys.version_info >= (3, 11):
names = self.co.co_varnames + tuple(
c for c in self.co.co_cellvars
if c not in self.co.co_varnames) + self.co.co_freevars
else:
names = self.co.co_cellvars + self.co.co_freevars
name = names[i]
if name in self.co.co_freevars:
# A free variable: its cell belongs to the function's closure.
val = self.f.__closure__[self.co.co_freevars.index(name)].cell_contents
if isinstance(val, _TypeInCell):
# A synthetic cell produced while emulating MAKE_FUNCTION: it holds
# the inferred type of the captured variable, not an actual runtime
# value.
return val.value
return Const(val)
try:
# A cell variable of the current frame. The frame state tracks the
# inferred *type* of each local variable rather than its value, so the
# tracked type can be returned directly.
return self.vars[self.co.co_varnames.index(name)]
except ValueError:
# The cell variable does not correspond to a tracked local variable
# (e.g. it is only ever assigned via STORE_DEREF, which is not
# modeled), so its type is unknown.
return typehints.Any

def get_global(self, i):
name = self.get_name(i)
Expand Down Expand Up @@ -468,13 +503,15 @@ def infer_return_type_func(f, input_types, debug=False, depth=0):
print('(' + dis.cmp_op[arg] + ')', end=' ')
elif op in dis.hasfree:
if free is None:
free = co.co_cellvars + co.co_freevars
# From 3.11 on the arg is no longer offset by len(co_varnames)
# so we adjust it back
print_arg = arg
if (sys.version_info.major, sys.version_info.minor) >= (3, 11):
print_arg = arg - len(co.co_varnames)
print('(' + free[print_arg] + ')', end=' ')
# From 3.11 on the arg indexes the localsplus storage, in which
# the cell of a captured parameter shares the parameter's slot.
if (sys.version_info.major, sys.version_info.minor) >= (3, 11):
free = co.co_varnames + tuple(
c for c in co.co_cellvars
if c not in co.co_varnames) + co.co_freevars
else:
free = co.co_cellvars + co.co_freevars
print('(' + free[arg] + ')', end=' ')

# Actually emulate the op.
if state is None and states[start] is None:
Expand Down
64 changes: 62 additions & 2 deletions sdks/python/apache_beam/typehints/trivial_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# pytype: skip-file

import dataclasses
import sys
import types
import unittest

Expand Down Expand Up @@ -276,7 +277,11 @@ def testClosure(self):
y = 1.0
self.assertReturnType(typehints.Tuple[int, float], lambda: (x, y))

@unittest.skip("https://github.com/apache/beam/issues/28420")
@unittest.skipIf(
sys.version_info >= (3, 13),
'MAKE_FUNCTION closure emulation is not yet supported from Python '
'3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE '
'after function creation: https://github.com/apache/beam/issues/28420')
def testLocalClosure(self):
self.assertReturnType(
typehints.Tuple[int, int], lambda x: (x, (lambda: x)()), [int])
Expand Down Expand Up @@ -515,5 +520,60 @@ class MyDataClass:
"lambda x: (x.id, x.name, x.tags, x.custom)"), [MyDataClass])


class ClosureTypeInferenceTest(unittest.TestCase):
def assertReturnType(self, expected, f, inputs=(), depth=5):
self.assertEqual(
expected,
trivial_inference.infer_return_type(
f, inputs, debug=False, depth=depth))

@unittest.skipIf(
sys.version_info >= (3, 13),
'MAKE_FUNCTION closure emulation is not yet supported from Python '
'3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE '
'after function creation: https://github.com/apache/beam/issues/28420')
def testClosureCallingCapturedArgument(self):
# https://github.com/apache/beam/issues/28420
self.assertReturnType(
typehints.Tuple[int, int], lambda x: (x, (lambda: x)()), [int])

@unittest.skipIf(
sys.version_info >= (3, 13),
'MAKE_FUNCTION closure emulation is not yet supported from Python '
'3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE '
'after function creation: https://github.com/apache/beam/issues/28420')
def testClosureCallingCapturedArgumentOnly(self):
self.assertReturnType(int, lambda x: (lambda: x)(), [int])

@unittest.skipIf(
sys.version_info >= (3, 13),
'MAKE_FUNCTION closure emulation is not yet supported from Python '
'3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE '
'after function creation: https://github.com/apache/beam/issues/28420')
def testNestedClosureCallingCapturedArgument(self):
self.assertReturnType(int, lambda x: (lambda: (lambda: x)())(), [int])

@unittest.skipIf(
sys.version_info >= (3, 13),
'MAKE_FUNCTION closure emulation is not yet supported from Python '
'3.13, in which closures are attached via SET_FUNCTION_ATTRIBUTE '
'after function creation: https://github.com/apache/beam/issues/28420')
def testClosureCapturedArgumentMixedWithParameter(self):
self.assertReturnType(
typehints.Tuple[int, str], lambda x, y: (lambda z: (x, z))(y),
[int, str])

def testRealClosureCellsHoldValues(self):
# Cells of an actual (non-emulated) closure contain runtime values and
# must still be inferred as constants of the value's type.
v = 123
self.assertReturnType(int, lambda: v)

def testRealClosureCellHoldingClass(self):
# A captured class is a constant and calling it constructs an instance.
cls = int
self.assertReturnType(int, lambda: cls('7'))


if __name__ == '__main__':
unittest.main()
unittest.main()
Loading