diff --git a/Lib/copy.py b/Lib/copy.py index f86040a33c55478..38d0207254df4ad 100644 --- a/Lib/copy.py +++ b/Lib/copy.py @@ -50,7 +50,7 @@ class instances). import types import weakref -from copyreg import dispatch_table +from copyreg import dispatch_table, getcallable class Error(Exception): pass @@ -63,6 +63,7 @@ class Error(Exception): __all__ = ["Error", "copy", "deepcopy"] + def copy(x): """Shallow copy operation on arbitrary Python objects. @@ -83,7 +84,7 @@ def copy(x): # treat it as a regular class: return _copy_immutable(x) - copier = getattr(cls, "__copy__", None) + copier = getcallable(cls, "__copy__", None) if copier: return copier(x) @@ -91,11 +92,11 @@ def copy(x): if reductor: rv = reductor(x) else: - reductor = getattr(x, "__reduce_ex__", None) + reductor = getcallable(x, "__reduce_ex__", None) if reductor: rv = reductor(4) else: - reductor = getattr(x, "__reduce__", None) + reductor = getcallable(x, "__reduce__", None) if reductor: rv = reductor() else: @@ -113,10 +114,7 @@ def _copy_immutable(x): for t in (type(None), int, float, bool, complex, str, tuple, bytes, frozenset, type, range, slice, types.BuiltinFunctionType, type(Ellipsis), type(NotImplemented), - types.FunctionType, weakref.ref): - d[t] = _copy_immutable -t = getattr(types, "CodeType", None) -if t is not None: + types.FunctionType, weakref.ref, types.CodeType): d[t] = _copy_immutable d[list] = list.copy @@ -156,7 +154,7 @@ def deepcopy(x, memo=None, _nil=[]): if issc: y = _deepcopy_atomic(x, memo) else: - copier = getattr(x, "__deepcopy__", None) + copier = getcallable(x, "__deepcopy__", None) if copier: y = copier(memo) else: @@ -164,11 +162,11 @@ def deepcopy(x, memo=None, _nil=[]): if reductor: rv = reductor(x) else: - reductor = getattr(x, "__reduce_ex__", None) + reductor = getcallable(x, "__reduce_ex__", None) if reductor: rv = reductor(4) else: - reductor = getattr(x, "__reduce__", None) + reductor = getcallable(x, "__reduce__", None) if reductor: rv = reductor() else: @@ -278,8 +276,9 @@ def _reconstruct(x, memo, func, args, if state is not None: if deep: state = deepcopy(state, memo) - if hasattr(y, '__setstate__'): - y.__setstate__(state) + setstate = getcallable(y, '__setstate__', None) + if setstate: + setstate(state) else: if isinstance(state, tuple) and len(state) == 2: state, slotstate = state @@ -310,4 +309,5 @@ def _reconstruct(x, memo, func, args, y[key] = value return y + del types, weakref, PyStringMap diff --git a/Lib/copyreg.py b/Lib/copyreg.py index bbe1af4e2e7e717..a876357fbc26c16 100644 --- a/Lib/copyreg.py +++ b/Lib/copyreg.py @@ -65,9 +65,8 @@ def _reduce_ex(self, proto): raise TypeError("can't pickle %s objects" % base.__name__) state = base(self) args = (self.__class__, base, state) - try: - getstate = self.__getstate__ - except AttributeError: + getstate = getcallable(self, '__getstate__', None) + if not getstate: if getattr(self, "__slots__", None): raise TypeError("a class that defines __slots__ without " "defining __getstate__ cannot be pickled") from None @@ -144,6 +143,36 @@ class found there. (This assumes classes don't modify their return names +def getcallable(obj, name, fallback=None): + """Specialized getattr(): swallows exceptions, checks callability.""" + try: + arg = getattr(obj, name) + except Exception: + # Custom __getattr__ implementations can raise anything on an empty + # cls.__new__() instance, most commonly RecursionError. + return fallback + + # Custom __getattr__ implementations might return anything on + # missing attributes. We could do even more detailed checks here but + # that might break custom proxies. + if not callable(arg): + return fallback + + if name == '__reduce_ex__': + # reduce_newobj in typeobject.c and _reduce_ex in copyreg.py are + # using the __getstate__ method if available. Need to make sure + # that attribue is callable, too. + try: + getstate = getattr(obj, '__getstate__') + if not callable(getstate): + return fallback + except Exception: + # Custom __getattr__ implementations can raise anything on an empty + # cls.__new__() instance, most commonly RecursionError. + pass + + return arg + # A registry of extension codes. This is an ad-hoc compression # mechanism. Whenever a global reference to , is about # to be pickled, the (, ) tuple is looked up here to see diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py index 45a692022f29bc0..e46135b801b0c6f 100644 --- a/Lib/test/test_copy.py +++ b/Lib/test/test_copy.py @@ -881,6 +881,27 @@ def m(self): self.assertIs(g.b.__self__, g) g.b() + def test_deepcopy_custom_getattr_recursion_limit_exceeded(self): + class Proxy(object): + def __init__(self, proxied_object): + self.proxied_object = proxied_object + + def __getattr__(self, name): + return getattr(self.proxied_object, name) + one = Proxy(1) + two = copy.deepcopy(one) + self.assertEqual(one.proxied_object, two.proxied_object) + + def test_deepcopy_custom_gettattr_non_callable(self): + class AttrDict(dict): + def __getattr__(self, name): + return self.get(name) + + one = AttrDict() + one.update({'a': 1, 'b': 2}) + two = copy.deepcopy(one) + self.assertEqual(one, two) + def global_foo(x, y): return x+y