From 7c0c61cf539ca312fcb9c0302837758a6a78e75f Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 5 Aug 2022 21:20:45 +0200 Subject: [PATCH] various small fixes and improved test output --- flake8_trio.py | 83 +++++++++++++++--------------- tests/test_flake8_trio.py | 104 ++++++++++++++++++++++++++------------ tests/trio100_py39.py | 4 +- tests/trio101.py | 1 - tests/trio102.py | 8 +-- tox.ini | 3 ++ 6 files changed, 125 insertions(+), 78 deletions(-) diff --git a/flake8_trio.py b/flake8_trio.py index 49deb493..7bd27fca 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -11,13 +11,14 @@ import ast import tokenize -from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" __version__ = "22.7.6" - Error = Tuple[int, int, str, Type[Any]] + + checkpoint_node_types = (ast.Await, ast.AsyncFor, ast.AsyncWith) cancel_scope_names = ( "fail_after", @@ -39,13 +40,13 @@ def make_error(error: str, lineno: int, col: int, *args: Any, **kwargs: Any) -> class Flake8TrioVisitor(ast.NodeVisitor): def __init__(self): super().__init__() - self.problems: List[Error] = [] + self._problems: List[Error] = [] @classmethod - def run(cls, tree: ast.AST) -> Generator[Error, None, None]: + def run(cls, tree: ast.AST) -> Iterable[Error]: visitor = cls() visitor.visit(tree) - yield from visitor.problems + yield from visitor._problems def visit_nodes( self, *nodes: Union[ast.AST, Iterable[ast.AST]], generic: bool = False @@ -62,7 +63,16 @@ def visit_nodes( visit(node) def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any): - self.problems.append(make_error(error, lineno, col, *args, **kwargs)) + self._problems.append(make_error(error, lineno, col, *args, **kwargs)) + + def get_state(self, *attrs: str) -> Dict[str, Any]: + if not attrs: + attrs = tuple(self.__dict__.keys()) + return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"} + + def set_state(self, attrs: Dict[str, Any]): + for attr, value in attrs.items(): + setattr(self, attr, value) class TrioScope: @@ -87,8 +97,6 @@ def __init__(self, node: ast.Call, funcname: str, packagename: str): def __str__(self): # Not supporting other ways of importing trio - # if self.packagename is None: - # return self.funcname return f"{self.packagename}.{self.funcname}" @@ -100,7 +108,6 @@ def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]: and node.func.value.id == "trio" and node.func.attr in names ): - # return "trio." + node.func.attr return TrioScope(node, node.func.attr, node.func.value.id) return None @@ -124,7 +131,7 @@ def __init__(self): def visit_With(self, node: Union[ast.With, ast.AsyncWith]): self.check_for_trio100(node) - outer_yie = self._yield_is_error + outer = self.get_state("_yield_is_error") # Check for a `with trio.` if not self._safe_decorator: @@ -139,13 +146,13 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]): self.generic_visit(node) # reset yield_is_error - self._yield_is_error = outer_yie + self.set_state(outer) def visit_AsyncWith(self, node: ast.AsyncWith): self.visit_With(node) def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): - outer = self._safe_decorator, self._yield_is_error + outer = self.get_state() self._yield_is_error = False # check for @ and @. @@ -154,14 +161,14 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): self.generic_visit(node) - self._safe_decorator, self._yield_is_error = outer + self.set_state(outer) def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): self.visit_FunctionDef(node) def visit_Yield(self, node: ast.Yield): if self._yield_is_error: - self.problems.append(make_error(TRIO101, node.lineno, node.col_offset)) + self.error(TRIO101, node.lineno, node.col_offset) self.generic_visit(node) @@ -173,19 +180,17 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]): isinstance(x, checkpoint_node_types) and x != node for x in ast.walk(node) ): - self.problems.append( - make_error(TRIO100, item.lineno, item.col_offset, call) - ) + self.error(TRIO100, item.lineno, item.col_offset, call) def visit_ImportFrom(self, node: ast.ImportFrom): if node.module == "trio": - self.problems.append(make_error(TRIO106, node.lineno, node.col_offset)) + self.error(TRIO106, node.lineno, node.col_offset) self.generic_visit(node) def visit_Import(self, node: ast.Import): for name in node.names: if name.name == "trio" and name.asname is not None: - self.problems.append(make_error(TRIO106, node.lineno, node.col_offset)) + self.error(TRIO106, node.lineno, node.col_offset) def critical_except(node: ast.ExceptHandler) -> Optional[Tuple[int, int, str]]: @@ -239,9 +244,7 @@ def visit_Await( cm.has_timeout and cm.shielded for cm in self._trio_context_managers ) ): - self.problems.append( - make_error(TRIO102, node.lineno, node.col_offset, *self._critical_scope) - ) + self.error(TRIO102, node.lineno, node.col_offset, *self._critical_scope) if visit_children: self.generic_visit(node) @@ -275,14 +278,15 @@ def visit_AsyncWith(self, node: ast.AsyncWith): self.visit_With(node) def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): - outer_cm = self._safe_decorator + outer = self.get_state("_safe_decorator") # check for @ and @. if has_decorator(node.decorator_list, *context_manager_names): self._safe_decorator = True self.generic_visit(node) - self._safe_decorator = outer_cm + + self.set_state(outer) visit_AsyncFunctionDef = visit_FunctionDef @@ -292,13 +296,13 @@ def critical_visit( block: Tuple[int, int, str], generic: bool = False, ): - outer = self._critical_scope, self._trio_context_managers + outer = self.get_state("_critical_scope", "_trio_context_managers") self._trio_context_managers = [] self._critical_scope = block self.visit_nodes(node, generic=generic) - self._critical_scope, self._trio_context_managers = outer + self.set_state(outer) def visit_Try(self, node: ast.Try): # There's no visit_Finally, so we need to manually visit the Try fields. @@ -345,7 +349,7 @@ def __init__(self): # then there might be a code path that doesn't re-raise. def visit_ExceptHandler(self, node: ast.ExceptHandler): - outer = (self.unraised, self.except_name, self.loop_depth) + outer = self.get_state() marker = critical_except(node) # we need to *not* unset self.unraised if this is non-critical, to still @@ -362,10 +366,9 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler): self.generic_visit(node) if self.unraised and marker is not None: - # print(marker) - self.problems.append(make_error(TRIO103, *marker)) + self.error(TRIO103, *marker) - (self.unraised, self.except_name, self.loop_depth) = outer + self.set_state(outer) def visit_Raise(self, node: ast.Raise): # if there's an unraised critical exception, the raise isn't bare, @@ -375,7 +378,7 @@ def visit_Raise(self, node: ast.Raise): and node.exc is not None and not (isinstance(node.exc, ast.Name) and node.exc.id == self.except_name) ): - self.problems.append(make_error(TRIO104, node.lineno, node.col_offset)) + self.error(TRIO104, node.lineno, node.col_offset) # treat it as safe regardless, to avoid unnecessary error messages. self.unraised = False @@ -385,7 +388,7 @@ def visit_Raise(self, node: ast.Raise): def visit_Return(self, node: Union[ast.Return, ast.Yield]): if self.unraised: # Error: must re-raise - self.problems.append(make_error(TRIO104, node.lineno, node.col_offset)) + self.error(TRIO104, node.lineno, node.col_offset) self.generic_visit(node) visit_Yield = visit_Return @@ -434,20 +437,22 @@ def visit_If(self, node: ast.If): # we completely disregard them when checking coverage by resetting the # effects of them afterwards def visit_For(self, node: Union[ast.For, ast.While]): - outer_unraised = self.unraised + outer = self.get_state("unraised") + self.loop_depth += 1 for n in node.body: self.visit(n) self.loop_depth -= 1 for n in node.orelse: self.visit(n) - self.unraised = outer_unraised + + self.set_state(outer) visit_While = visit_For def visit_Break(self, node: Union[ast.Break, ast.Continue]): if self.unraised and self.loop_depth == 0: - self.problems.append(make_error(TRIO104, node.lineno, node.col_offset)) + self.error(TRIO104, node.lineno, node.col_offset) self.generic_visit(node) visit_Continue = visit_Break @@ -492,9 +497,7 @@ def visit_Call(self, node: ast.Call): or not isinstance(self.node_stack[-2], ast.Await) ) ): - self.problems.append( - make_error(TRIO105, node.lineno, node.col_offset, node.func.attr) - ) + self.error(TRIO105, node.lineno, node.col_offset, node.func.attr) self.generic_visit(node) @@ -615,7 +618,7 @@ def from_filename(cls, filename: str) -> "Plugin": source = f.read() return cls(ast.parse(source)) - def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]: + def run(self) -> Iterable[Error]: for v in Flake8TrioVisitor.__subclasses__(): yield from v.run(self._tree) @@ -625,7 +628,7 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]: TRIO102 = "TRIO102: await inside {2} on line {0} must have shielded cancel scope with a timeout" TRIO103 = "TRIO103: {} block with a code path that doesn't re-raise the error" TRIO104 = "TRIO104: Cancelled (and therefore BaseException) must be re-raised" -TRIO105 = "TRIO105: Trio async function {} must be immediately awaited" +TRIO105 = "TRIO105: trio async function {} must be immediately awaited" TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work" TRIO107 = "TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised" TRIO108 = "TRIO108: Early return from async function must have at least one checkpoint on every code path before it." diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 3a43b96b..78d86cf2 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -39,55 +39,97 @@ def test_eval(test: str, path: str): major, minor = version_str[0], version_str[1:] v_i = sys.version_info if (v_i.major, v_i.minor) < (int(major), int(minor)): - return + raise unittest.SkipTest("v_i, major, minor") test = test.split("_")[0] error_msg = getattr(flake8_trio, test) expected: List[Error] = [] with open(os.path.join("tests", path)) as file: - for lineno, line in enumerate(file): - # get text between `error: ` and newline - k = re.search(r"(?<=error: ).*(?=\n)", line) - if not k: + for lineno, line in enumerate(file, start=1): + # get text between `error:` and end of line + k = re.search(r"(?<=error:).*$", line) + if not k or line.strip()[0] == "#": continue - # Append a bunch of 0's so string formatting gives garbage instead + # Append a bunch of empty strings so string formatting gives garbage instead # of throwing an exception - args = [m.strip() for m in k.group().split(",")] + ["0"] * 5 + args = [m.strip() for m in k.group().split(",")] + [""] * 5 col, *args = args - expected.append(make_error(error_msg, lineno + 1, int(col), *args)) + for i, arg in enumerate(args): + if "$lineno" in arg: + args[i] = eval(arg.replace("$", ""), {"lineno": lineno}) + assert col.isdigit(), f'invalid column "{col}" @L{lineno}, in "{line}"' + expected.append(make_error(error_msg, lineno, int(col), *args)) + assert expected, "failed to parse any errors in file" assert_expected_errors(path, test, *expected) -# This function is also a mess now, but I keep slowly iterating on getting it to -# print actually helpful error messages in all cases - which is a struggle. -# It'll likely continue to be a mess for the foreseeable future -def assert_expected_errors(test_file: str, include: str, *expected: Error) -> None: - def trim_messages(messages: Iterable[Error]): - return tuple(((line, col, int(msg[4:7])) for line, col, msg, _ in messages)) - +def assert_expected_errors(test_file: str, include: str, *expected: Error): filename = Path(__file__).absolute().parent / test_file plugin = Plugin.from_filename(str(filename)) - errors = tuple(e for e in plugin.run() if include in e[2]) - - # start with a check with trimmed errors that will make for smaller diff messages - trim_errors = trim_messages(errors) - trim_expected = trim_messages(expected) - - cls = unittest.TestCase() - unexpected = sorted(set(trim_errors) - set(trim_expected)) - missing = sorted(set(trim_expected) - set(trim_errors)) - cls.assertEqual((unexpected, missing), ([], []), msg="(unexpected, missing)") + errors = tuple(sorted(e for e in plugin.run() if include in e[2])) + expected = tuple(sorted(expected)) - unexpected = sorted(set(errors) - set(expected)) - missing = sorted(set(expected) - set(errors)) - if unexpected and missing: - cls.assertEqual(unexpected[0], missing[0]) - cls.assertEqual((unexpected, missing), ([], []), msg="(unexpected, missing)") + assert_correct_lines(errors, expected) + assert_correct_columns(errors, expected) + assert_correct_messages(errors, expected) # full check - cls.assertSequenceEqual(sorted(errors), sorted(expected)) + unittest.TestCase().assertSequenceEqual(sorted(errors), sorted(expected)) + + +def assert_correct_lines(errors: Iterable[Error], expected: Iterable[Error]): + # Check that errors are on correct lines + error_lines = {line for line, *_ in errors} + expected_lines = {line for line, *_ in expected} + unexpected_lines = sorted(error_lines - expected_lines) + missing_lines = sorted(expected_lines - error_lines) + unittest.TestCase().assertEqual( + unexpected_lines, + missing_lines, + msg="Lines with unexpected errors; missing errors", + ) + + +def assert_correct_columns(errors: Iterable[Error], expected: Iterable[Error]): + # check errors have correct columns + col_error = False + for (line, error_col, *_), (_, expected_col, *_) in zip(errors, expected): + if error_col != expected_col: + if not col_error: + print("Errors with same line but different columns:", file=sys.stderr) + print("| line | actual | expected |", file=sys.stderr) + col_error = True + print( + f"| {line:4} | {error_col:6} | {expected_col:8} |", + file=sys.stderr, + ) + assert not col_error + + +def assert_correct_messages(errors: Iterable[Error], expected: Iterable[Error]): + # check errors have correct messages + msg_error = False + for (line, _, error_msg, *_), (_, _, expected_msg, *_) in zip(errors, expected): + if error_msg != expected_msg: + if not msg_error: + print( + "Errors with different messages:", + "-" * 20, + sep="\n", + file=sys.stderr, + ) + msg_error = True + print( + f"* line: {line:3}", + f" actual: {error_msg}", + f"expected: {expected_msg}", + "-" * 20, + sep="\n", + file=sys.stderr, + ) + assert not msg_error @pytest.mark.fuzz diff --git a/tests/trio100_py39.py b/tests/trio100_py39.py index 29ea1930..b8cccb6f 100644 --- a/tests/trio100_py39.py +++ b/tests/trio100_py39.py @@ -3,14 +3,14 @@ async def function_name(): with ( - open("veryverylongfilenamesoshedsplitsthisintotwolines") as _, + open("") as _, trio.fail_after(10), # error: 8, trio.fail_after ): pass with ( trio.fail_after(5), # error: 8, trio.fail_after - open("veryverylongfilenamesoshedsplitsthisintotwolines") as _, + open("") as _, trio.move_on_after(5), # error: 8, trio.move_on_after ): pass diff --git a/tests/trio101.py b/tests/trio101.py index bc6760e1..0815d93b 100644 --- a/tests/trio101.py +++ b/tests/trio101.py @@ -23,7 +23,6 @@ def foo2(): async def foo3(): async with trio.CancelScope() as _: - await trio.sleep(1) # so trio100 doesn't complain yield 1 # error: 8 diff --git a/tests/trio102.py b/tests/trio102.py index e8dafbbc..f1a51581 100644 --- a/tests/trio102.py +++ b/tests/trio102.py @@ -122,7 +122,7 @@ async def foo3(): await foo() # safe with trio.fail_after(5), trio.move_on_after(30) as s: s.shield = True - await foo() # safe in theory, error: 12, 115, 4, try/finally + await foo() # safe in theory, error: 12, $lineno-10, 4, try/finally # New: except cancelled/baseexception are also critical @@ -132,11 +132,11 @@ async def foo4(): except ValueError: await foo() # safe except trio.Cancelled: - await foo() # error: 8, 134, 11, trio.Cancelled + await foo() # error: 8, $lineno-1, 11, trio.Cancelled except BaseException: - await foo() # error: 8, 136, 11, BaseException + await foo() # error: 8, $lineno-1, 11, BaseException except: - await foo() # error: 8, 138, 4, bare except + await foo() # error: 8, $lineno-1, 4, bare except async def foo5(): diff --git a/tox.ini b/tox.ini index 92161f76..62a8ee4d 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,9 @@ deps = hypothesmith pytest trio +setenv = + # Make sure pyright is always up to date + PYRIGHT_PYTHON_FORCE_VERSION = latest skip_install = # don't install the plugin, which would register it with flake8 # and potentially stop the linter from functioning.