diff --git a/python-stdlib/unittest/unittest/__init__.py b/python-stdlib/unittest/unittest/__init__.py index 8014e2828..5f4c21a29 100644 --- a/python-stdlib/unittest/unittest/__init__.py +++ b/python-stdlib/unittest/unittest/__init__.py @@ -1,11 +1,6 @@ import io import sys -try: - import traceback -except ImportError: - traceback = None - class SkipTest(Exception): pass @@ -19,9 +14,12 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, tb): + if self.expected is None: + # Used by assertWarns, do nothing. + return self.exception = exc_value if exc_type is None: - assert False, "%r not raised" % self.expected + raise AssertionError("%r not raised" % self.expected) if issubclass(exc_type, self.expected): # store exception for later retrieval self.exception = exc_value @@ -42,8 +40,8 @@ def __init__(self, msg=None, params=None): def __enter__(self): pass - def __exit__(self, *exc_info): - if exc_info[0] is not None: + def __exit__(self, exc_type, exc_value, tb): + if exc_type is not None: # Exception raised global __test_result__, __current_test__ test_details = __current_test__ @@ -53,19 +51,11 @@ def __exit__(self, *exc_info): detail = ", ".join("%s=%s" % k_v for k_v in self.params.items()) test_details += (" (%s)" % detail,) - _handle_test_exception(test_details, __test_result__, exc_info, False) + _handle_test_exception(test_details, __test_result__, exc_value, False) # Suppress the exception as we've captured it above return True -class NullContext: - def __enter__(self): - pass - - def __exit__(self, exc_type, exc_value, traceback): - pass - - class TestCase: def __init__(self): pass @@ -88,29 +78,25 @@ def skipTest(self, reason): raise SkipTest(reason) def fail(self, msg=""): - assert False, msg + raise AssertionError(msg) - def assertEqual(self, x, y, msg=""): - if not msg: - msg = "%r vs (expected) %r" % (x, y) - assert x == y, msg + def assertEqual(self, x, y, msg=None): + if not x == y: + raise AssertionError(msg or "%r vs (expected) %r" % (x, y)) - def assertNotEqual(self, x, y, msg=""): - if not msg: - msg = "%r not expected to be equal %r" % (x, y) - assert x != y, msg + def assertNotEqual(self, x, y, msg=None): + if not x != y: + raise AssertionError(msg or "%r not expected to be equal %r" % (x, y)) def assertLessEqual(self, x, y, msg=None): - if msg is None: - msg = "%r is expected to be <= %r" % (x, y) - assert x <= y, msg + if not x <= y: + raise AssertionError(msg or "%r is expected to be <= %r" % (x, y)) def assertGreaterEqual(self, x, y, msg=None): - if msg is None: - msg = "%r is expected to be >= %r" % (x, y) - assert x >= y, msg + if not x >= y: + raise AssertionError(msg or "%r is expected to be >= %r" % (x, y)) - def assertAlmostEqual(self, x, y, places=None, msg="", delta=None): + def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None): if x == y: return if delta is not None and places is not None: @@ -119,74 +105,60 @@ def assertAlmostEqual(self, x, y, places=None, msg="", delta=None): if delta is not None: if abs(x - y) <= delta: return - if not msg: - msg = "%r != %r within %r delta" % (x, y, delta) + raise AssertionError(msg or "%r != %r within %r delta" % (x, y, delta)) else: if places is None: places = 7 if round(abs(y - x), places) == 0: return - if not msg: - msg = "%r != %r within %r places" % (x, y, places) + raise AssertionError(msg or "%r != %r within %r places" % (x, y, places)) - assert False, msg - - def assertNotAlmostEqual(self, x, y, places=None, msg="", delta=None): + def assertNotAlmostEqual(self, x, y, places=None, msg=None, delta=None): if delta is not None and places is not None: raise TypeError("specify delta or places not both") if delta is not None: if not (x == y) and abs(x - y) > delta: return - if not msg: - msg = "%r == %r within %r delta" % (x, y, delta) + raise AssertionError(msg or "%r == %r within %r delta" % (x, y, delta)) else: if places is None: places = 7 if not (x == y) and round(abs(y - x), places) != 0: return - if not msg: - msg = "%r == %r within %r places" % (x, y, places) - - assert False, msg - - def assertIs(self, x, y, msg=""): - if not msg: - msg = "%r is not %r" % (x, y) - assert x is y, msg - - def assertIsNot(self, x, y, msg=""): - if not msg: - msg = "%r is %r" % (x, y) - assert x is not y, msg - - def assertIsNone(self, x, msg=""): - if not msg: - msg = "%r is not None" % x - assert x is None, msg - - def assertIsNotNone(self, x, msg=""): - if not msg: - msg = "%r is None" % x - assert x is not None, msg - - def assertTrue(self, x, msg=""): - if not msg: - msg = "Expected %r to be True" % x - assert x, msg - - def assertFalse(self, x, msg=""): - if not msg: - msg = "Expected %r to be False" % x - assert not x, msg - - def assertIn(self, x, y, msg=""): - if not msg: - msg = "Expected %r to be in %r" % (x, y) - assert x in y, msg + raise AssertionError(msg or "%r == %r within %r places" % (x, y, places)) + + def assertIs(self, x, y, msg=None): + if not x is y: + raise AssertionError(msg or "%r is not %r" % (x, y)) + + def assertIsNot(self, x, y, msg=None): + if not x is not y: + raise AssertionError(msg or "%r is %r" % (x, y)) + + def assertIsNone(self, x, msg=None): + if not x is None: + raise AssertionError(msg or "%r is not None" % x) + + def assertIsNotNone(self, x, msg=None): + if not x is not None: + raise AssertionError(msg or "%r is None" % x) + + def assertTrue(self, x, msg=None): + if not x: + raise AssertionError(msg or "Expected %r to be True" % x) + + def assertFalse(self, x, msg=None): + if x: + raise AssertionError(msg or "Expected %r to be False" % x) + + def assertIn(self, x, y, msg=None): + if not x in y: + raise AssertionError(msg or "Expected %r to be in %r" % (x, y)) def assertIsInstance(self, x, y, msg=""): - assert isinstance(x, y), msg + if not isinstance(x, y): + raise AssertionError(msg) def assertRaises(self, exc, func=None, *args, **kwargs): if func is None: @@ -199,10 +171,10 @@ def assertRaises(self, exc, func=None, *args, **kwargs): return raise e - assert False, "%r not raised" % exc + raise AssertionError("%r not raised" % exc) def assertWarns(self, warn): - return NullContext() + return AssertRaisesContext(None) def skip(msg): @@ -217,15 +189,11 @@ def _inner(self): def skipIf(cond, msg): - if not cond: - return lambda x: x - return skip(msg) + return skip(msg) if cond else lambda x: x def skipUnless(cond, msg): - if cond: - return lambda x: x - return skip(msg) + return skipIf(not cond, msg) def expectedFailure(test): @@ -235,7 +203,7 @@ def test_exp_fail(*args, **kwargs): except: pass else: - assert False, "unexpected success" + raise AssertionError("unexpected success") return test_exp_fail @@ -332,28 +300,19 @@ def __add__(self, other): return self -def _capture_exc(exc, exc_traceback): - buf = io.StringIO() - if hasattr(sys, "print_exception"): - sys.print_exception(exc, buf) - elif traceback is not None: - traceback.print_exception(None, exc, exc_traceback, file=buf) - return buf.getvalue() - - def _handle_test_exception( - current_test: tuple, test_result: TestResult, exc_info: tuple, verbose=True + current_test: tuple, test_result: TestResult, exc: Exception, verbose=True ): - exc = exc_info[1] - traceback = exc_info[2] - ex_str = _capture_exc(exc, traceback) if isinstance(exc, SkipTest): reason = exc.args[0] test_result.skippedNum += 1 test_result.skipped.append((current_test, reason)) print(" skipped:", reason) return - elif isinstance(exc, AssertionError): + buf = io.StringIO() + sys.print_exception(exc, buf) + ex_str = buf.getvalue() + if isinstance(exc, AssertionError): test_result.failuresNum += 1 test_result.failures.append((current_test, ex_str)) if verbose: @@ -375,10 +334,11 @@ def _run_suite(c, test_result: TestResult, suite_name=""): o = c() else: o = c - set_up_class = getattr(o, "setUpClass", lambda: None) - tear_down_class = getattr(o, "tearDownClass", lambda: None) - set_up = getattr(o, "setUp", lambda: None) - tear_down = getattr(o, "tearDown", lambda: None) + nop = lambda: None + set_up_class = getattr(o, "setUpClass", nop) + tear_down_class = getattr(o, "tearDownClass", nop) + set_up = getattr(o, "setUp", nop) + tear_down = getattr(o, "tearDown", nop) exceptions = [] try: suite_name += "." + c.__qualname__ @@ -402,9 +362,7 @@ def run_one(test_function): else: print(" ok") except Exception as ex: - _handle_test_exception( - current_test=(name, c), test_result=test_result, exc_info=(type(ex), ex, None) - ) + _handle_test_exception((name, c), test_result, ex) # Uncomment to investigate failure in detail # raise ex finally: