diff --git a/pytype/mixin.py b/pytype/mixin.py index c34ca554f..510942adb 100644 --- a/pytype/mixin.py +++ b/pytype/mixin.py @@ -259,7 +259,8 @@ def is_abstract(self): @property def is_test_class(self): - return any(base.full_name == "unittest.TestCase" for base in self.mro) + return any(base.full_name in ("unittest.TestCase", "unittest.case.TestCase") + for base in self.mro) @property def is_protocol(self): diff --git a/pytype/tests/py2/test_classes.py b/pytype/tests/py2/test_classes.py index 38e8e8d33..df08cedf3 100644 --- a/pytype/tests/py2/test_classes.py +++ b/pytype/tests/py2/test_classes.py @@ -23,5 +23,42 @@ class A(object): x = ... # type: Any """) + def testInitTestClassInSetup(self): + ty = self.Infer("""\ + import unittest + class A(unittest.TestCase): + def setUp(self): + self.x = 10 + def fooTest(self): + return self.x + """) + self.assertTypesMatchPytd(ty, """ + import unittest + unittest = ... # type: module + class A(unittest.TestCase): + x = ... # type: int + def fooTest(self) -> int: ... + """) + + def testInitInheritedTestClassInSetup(self): + ty = self.Infer("""\ + import unittest + class A(unittest.TestCase): + def setUp(self): + self.x = 10 + class B(A): + def fooTest(self): + return self.x + """) + self.assertTypesMatchPytd(ty, """ + import unittest + unittest = ... # type: module + class A(unittest.TestCase): + x = ... # type: int + class B(A): + x = ... # type: int + def fooTest(self) -> int: ... + """) + test_base.main(globals(), __name__ == "__main__") diff --git a/pytype/tests/py3/test_classes.py b/pytype/tests/py3/test_classes.py index 693336929..97dbbab14 100644 --- a/pytype/tests/py3/test_classes.py +++ b/pytype/tests/py3/test_classes.py @@ -194,27 +194,6 @@ class Foo(object): bar: int """) - def testInitTestClassInInitAndSetup(self): - ty = self.Infer("""\ - import unittest - class A(unittest.TestCase): - def __init__(self, foo: str): - self.foo = foo - def setUp(self): - self.x = 10 - def fooTest(self): - return self.x - """) - self.assertTypesMatchPytd(ty, """ - import unittest - unittest = ... # type: module - class A(unittest.TestCase): - x = ... # type: int - foo = ... # type: str - def __init__(self, foo: str) -> NoneType - def fooTest(self) -> int: ... - """) - class ClassesTestPython3Feature(test_base.TargetPython3FeatureTest): """Tests for classes.""" @@ -287,5 +266,63 @@ def f(x: Foo): f(Bar()) """) + def testInitTestClassInSetup(self): + ty = self.Infer("""\ + import unittest + class A(unittest.TestCase): + def setUp(self): + self.x = 10 + def fooTest(self): + return self.x + """) + self.assertTypesMatchPytd(ty, """ + import unittest + unittest = ... # type: module + class A(unittest.case.TestCase): + x = ... # type: int + def fooTest(self) -> int: ... + """) + + def testInitInheritedTestClassInSetup(self): + ty = self.Infer("""\ + import unittest + class A(unittest.TestCase): + def setUp(self): + self.x = 10 + class B(A): + def fooTest(self): + return self.x + """) + self.assertTypesMatchPytd(ty, """ + import unittest + unittest = ... # type: module + class A(unittest.case.TestCase): + x = ... # type: int + class B(A): + x = ... # type: int + def fooTest(self) -> int: ... + """) + + def testInitTestClassInInitAndSetup(self): + ty = self.Infer("""\ + import unittest + class A(unittest.TestCase): + def __init__(self, foo: str): + self.foo = foo + def setUp(self): + self.x = 10 + def fooTest(self): + return self.x + """) + self.assertTypesMatchPytd(ty, """ + import unittest + unittest = ... # type: module + class A(unittest.case.TestCase): + x = ... # type: int + foo = ... # type: str + def __init__(self, foo: str) -> NoneType + def fooTest(self) -> int: ... + """) + test_base.main(globals(), __name__ == "__main__") diff --git a/pytype/tests/test_classes.py b/pytype/tests/test_classes.py index 1ee5cdde4..066075c31 100644 --- a/pytype/tests/test_classes.py +++ b/pytype/tests/test_classes.py @@ -1367,43 +1367,6 @@ class C(object): x = ... # type: int """) - def testInitTestClassInSetup(self): - ty = self.Infer("""\ - import unittest - class A(unittest.TestCase): - def setUp(self): - self.x = 10 - def fooTest(self): - return self.x - """) - self.assertTypesMatchPytd(ty, """ - import unittest - unittest = ... # type: module - class A(unittest.TestCase): - x = ... # type: int - def fooTest(self) -> int: ... - """) - - def testInitInheritedTestClassInSetup(self): - ty = self.Infer("""\ - import unittest - class A(unittest.TestCase): - def setUp(self): - self.x = 10 - class B(A): - def fooTest(self): - return self.x - """) - self.assertTypesMatchPytd(ty, """ - import unittest - unittest = ... # type: module - class A(unittest.TestCase): - x = ... # type: int - class B(A): - x = ... # type: int - def fooTest(self) -> int: ... - """) - def testPyiNestedClass(self): # Test that pytype can look up a pyi nested class in a py file and reconsume # the inferred pyi. diff --git a/typeshed b/typeshed index 4e572ae6a..50d98acc7 160000 --- a/typeshed +++ b/typeshed @@ -1 +1 @@ -Subproject commit 4e572ae6a398c82fa9f1b0e7ba7c819274de8c3a +Subproject commit 50d98acc766e9425cb099c17b711f1bccb697584