Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Lazly override __init__ method of a Protocol subclasses
  • Loading branch information
uriyyo committed Sep 2, 2021
commit 084840a29bf70abbea70ef4c11da3ef14fc042f7
5 changes: 0 additions & 5 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,11 +1015,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
# Does this class have a post-init function?
has_post_init = hasattr(cls, _POST_INIT_NAME)

# typing.Protocol can override __init__ method to object.__init__
inherits_from_protocol = any(getattr(c, '_is_protocol', False) for c in cls.__bases__)
if inherits_from_protocol and cls.__dict__.get('__init__') is object.__init__:
del cls.__init__

_set_new_attribute(cls, '__init__',
_init_fn(all_init_fields,
std_init_fields,
Expand Down
2 changes: 2 additions & 0 deletions Lib/test/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,8 @@ def __init__(self, x):
self.assertEqual(C(5).x, 10)

def test_inherit_from_protocol(self):
# See bpo-45081.

class P(Protocol):
a: int

Expand Down
29 changes: 19 additions & 10 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,25 @@ def _is_callable_members_only(cls):


def _no_init(self, *args, **kwargs):
raise TypeError('Protocols cannot be instantiated')
cls = type(self)

if cls._is_protocol:
raise TypeError('Protocols cannot be instantiated')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: this is the previous behavior.


# set correct __init__ method on a first initialization
# so all further initialization will call it directly
# see bpo-45081
for base in cls.__mro__:
init = base.__dict__.get('__init__', _no_init)
if init is not _no_init:
cls.__init__ = init
break
else:
# should not happen
cls.__init__ = object.__init__

cls.__init__(self, *args, **kwargs)


def _caller(depth=1, default='__main__'):
try:
Expand Down Expand Up @@ -1541,15 +1559,6 @@ def _proto_hook(other):

# We have nothing more to do for non-protocols...
if not cls._is_protocol:
if cls.__init__ == _no_init:
for base in cls.__mro__:
init = base.__dict__.get('__init__', _no_init)
if init != _no_init:
cls.__init__ = init
break
else:
# should not happen
cls.__init__ = object.__init__
return

# ... otherwise check consistency of bases, and prohibit instantiation.
Expand Down