Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 70 additions & 112 deletions python-stdlib/unittest/unittest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import io
import sys

try:
import traceback
except ImportError:
traceback = None


class SkipTest(Exception):
pass
Expand All @@ -19,9 +14,12 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, tb):
if self.expected is None:
# Used by assertWarns, do nothing.
return
self.exception = exc_value
if exc_type is None:
assert False, "%r not raised" % self.expected
raise AssertionError("%r not raised" % self.expected)
if issubclass(exc_type, self.expected):
# store exception for later retrieval
self.exception = exc_value
Expand All @@ -42,8 +40,8 @@ def __init__(self, msg=None, params=None):
def __enter__(self):
pass

def __exit__(self, *exc_info):
if exc_info[0] is not None:
def __exit__(self, exc_type, exc_value, tb):
if exc_type is not None:
# Exception raised
global __test_result__, __current_test__
test_details = __current_test__
Expand All @@ -53,19 +51,11 @@ def __exit__(self, *exc_info):
detail = ", ".join("%s=%s" % k_v for k_v in self.params.items())
test_details += (" (%s)" % detail,)

_handle_test_exception(test_details, __test_result__, exc_info, False)
_handle_test_exception(test_details, __test_result__, exc_value, False)
# Suppress the exception as we've captured it above
return True


class NullContext:
def __enter__(self):
pass

def __exit__(self, exc_type, exc_value, traceback):
pass


class TestCase:
def __init__(self):
pass
Expand All @@ -88,29 +78,25 @@ def skipTest(self, reason):
raise SkipTest(reason)

def fail(self, msg=""):
assert False, msg
raise AssertionError(msg)

def assertEqual(self, x, y, msg=""):
if not msg:
msg = "%r vs (expected) %r" % (x, y)
assert x == y, msg
def assertEqual(self, x, y, msg=None):
if not x == y:
raise AssertionError(msg or "%r vs (expected) %r" % (x, y))

def assertNotEqual(self, x, y, msg=""):
if not msg:
msg = "%r not expected to be equal %r" % (x, y)
assert x != y, msg
def assertNotEqual(self, x, y, msg=None):
if not x != y:
raise AssertionError(msg or "%r not expected to be equal %r" % (x, y))

def assertLessEqual(self, x, y, msg=None):
if msg is None:
msg = "%r is expected to be <= %r" % (x, y)
assert x <= y, msg
if not x <= y:
raise AssertionError(msg or "%r is expected to be <= %r" % (x, y))

def assertGreaterEqual(self, x, y, msg=None):
if msg is None:
msg = "%r is expected to be >= %r" % (x, y)
assert x >= y, msg
if not x >= y:
raise AssertionError(msg or "%r is expected to be >= %r" % (x, y))

def assertAlmostEqual(self, x, y, places=None, msg="", delta=None):
def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None):
if x == y:
return
if delta is not None and places is not None:
Expand All @@ -119,74 +105,60 @@ def assertAlmostEqual(self, x, y, places=None, msg="", delta=None):
if delta is not None:
if abs(x - y) <= delta:
return
if not msg:
msg = "%r != %r within %r delta" % (x, y, delta)
raise AssertionError(msg or "%r != %r within %r delta" % (x, y, delta))
else:
if places is None:
places = 7
if round(abs(y - x), places) == 0:
return
if not msg:
msg = "%r != %r within %r places" % (x, y, places)
raise AssertionError(msg or "%r != %r within %r places" % (x, y, places))

assert False, msg

def assertNotAlmostEqual(self, x, y, places=None, msg="", delta=None):
def assertNotAlmostEqual(self, x, y, places=None, msg=None, delta=None):
if delta is not None and places is not None:
raise TypeError("specify delta or places not both")

if delta is not None:
if not (x == y) and abs(x - y) > delta:
return
if not msg:
msg = "%r == %r within %r delta" % (x, y, delta)
raise AssertionError(msg or "%r == %r within %r delta" % (x, y, delta))
else:
if places is None:
places = 7
if not (x == y) and round(abs(y - x), places) != 0:
return
if not msg:
msg = "%r == %r within %r places" % (x, y, places)

assert False, msg

def assertIs(self, x, y, msg=""):
if not msg:
msg = "%r is not %r" % (x, y)
assert x is y, msg

def assertIsNot(self, x, y, msg=""):
if not msg:
msg = "%r is %r" % (x, y)
assert x is not y, msg

def assertIsNone(self, x, msg=""):
if not msg:
msg = "%r is not None" % x
assert x is None, msg

def assertIsNotNone(self, x, msg=""):
if not msg:
msg = "%r is None" % x
assert x is not None, msg

def assertTrue(self, x, msg=""):
if not msg:
msg = "Expected %r to be True" % x
assert x, msg

def assertFalse(self, x, msg=""):
if not msg:
msg = "Expected %r to be False" % x
assert not x, msg

def assertIn(self, x, y, msg=""):
if not msg:
msg = "Expected %r to be in %r" % (x, y)
assert x in y, msg
raise AssertionError(msg or "%r == %r within %r places" % (x, y, places))

def assertIs(self, x, y, msg=None):
if not x is y:
raise AssertionError(msg or "%r is not %r" % (x, y))

def assertIsNot(self, x, y, msg=None):
if not x is not y:
raise AssertionError(msg or "%r is %r" % (x, y))

def assertIsNone(self, x, msg=None):
if not x is None:
raise AssertionError(msg or "%r is not None" % x)

def assertIsNotNone(self, x, msg=None):
if not x is not None:
raise AssertionError(msg or "%r is None" % x)

def assertTrue(self, x, msg=None):
if not x:
raise AssertionError(msg or "Expected %r to be True" % x)

def assertFalse(self, x, msg=None):
if x:
raise AssertionError(msg or "Expected %r to be False" % x)

def assertIn(self, x, y, msg=None):
if not x in y:
raise AssertionError(msg or "Expected %r to be in %r" % (x, y))

def assertIsInstance(self, x, y, msg=""):
assert isinstance(x, y), msg
if not isinstance(x, y):
raise AssertionError(msg)

def assertRaises(self, exc, func=None, *args, **kwargs):
if func is None:
Expand All @@ -199,10 +171,10 @@ def assertRaises(self, exc, func=None, *args, **kwargs):
return
raise e

assert False, "%r not raised" % exc
raise AssertionError("%r not raised" % exc)

def assertWarns(self, warn):
return NullContext()
return AssertRaisesContext(None)


def skip(msg):
Expand All @@ -217,15 +189,11 @@ def _inner(self):


def skipIf(cond, msg):
if not cond:
return lambda x: x
return skip(msg)
return skip(msg) if cond else lambda x: x


def skipUnless(cond, msg):
if cond:
return lambda x: x
return skip(msg)
return skipIf(not cond, msg)


def expectedFailure(test):
Expand All @@ -235,7 +203,7 @@ def test_exp_fail(*args, **kwargs):
except:
pass
else:
assert False, "unexpected success"
raise AssertionError("unexpected success")

return test_exp_fail

Expand Down Expand Up @@ -332,28 +300,19 @@ def __add__(self, other):
return self


def _capture_exc(exc, exc_traceback):
buf = io.StringIO()
if hasattr(sys, "print_exception"):
sys.print_exception(exc, buf)
elif traceback is not None:
traceback.print_exception(None, exc, exc_traceback, file=buf)
return buf.getvalue()


def _handle_test_exception(
current_test: tuple, test_result: TestResult, exc_info: tuple, verbose=True
current_test: tuple, test_result: TestResult, exc: Exception, verbose=True
):
exc = exc_info[1]
traceback = exc_info[2]
ex_str = _capture_exc(exc, traceback)
if isinstance(exc, SkipTest):
reason = exc.args[0]
test_result.skippedNum += 1
test_result.skipped.append((current_test, reason))
print(" skipped:", reason)
return
elif isinstance(exc, AssertionError):
buf = io.StringIO()
sys.print_exception(exc, buf)
ex_str = buf.getvalue()
if isinstance(exc, AssertionError):
test_result.failuresNum += 1
test_result.failures.append((current_test, ex_str))
if verbose:
Expand All @@ -375,10 +334,11 @@ def _run_suite(c, test_result: TestResult, suite_name=""):
o = c()
else:
o = c
set_up_class = getattr(o, "setUpClass", lambda: None)
tear_down_class = getattr(o, "tearDownClass", lambda: None)
set_up = getattr(o, "setUp", lambda: None)
tear_down = getattr(o, "tearDown", lambda: None)
nop = lambda: None
set_up_class = getattr(o, "setUpClass", nop)
tear_down_class = getattr(o, "tearDownClass", nop)
set_up = getattr(o, "setUp", nop)
tear_down = getattr(o, "tearDown", nop)
exceptions = []
try:
suite_name += "." + c.__qualname__
Expand All @@ -402,9 +362,7 @@ def run_one(test_function):
else:
print(" ok")
except Exception as ex:
_handle_test_exception(
current_test=(name, c), test_result=test_result, exc_info=(type(ex), ex, None)
)
_handle_test_exception((name, c), test_result, ex)
# Uncomment to investigate failure in detail
# raise ex
finally:
Expand Down
Loading