From 7c0c61cf539ca312fcb9c0302837758a6a78e75f Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 5 Aug 2022 21:20:45 +0200
Subject: [PATCH] various small fixes and improved test output
---
flake8_trio.py | 83 +++++++++++++++---------------
tests/test_flake8_trio.py | 104 ++++++++++++++++++++++++++------------
tests/trio100_py39.py | 4 +-
tests/trio101.py | 1 -
tests/trio102.py | 8 +--
tox.ini | 3 ++
6 files changed, 125 insertions(+), 78 deletions(-)
diff --git a/flake8_trio.py b/flake8_trio.py
index 49deb493..7bd27fca 100644
--- a/flake8_trio.py
+++ b/flake8_trio.py
@@ -11,13 +11,14 @@
import ast
import tokenize
-from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "22.7.6"
-
Error = Tuple[int, int, str, Type[Any]]
+
+
checkpoint_node_types = (ast.Await, ast.AsyncFor, ast.AsyncWith)
cancel_scope_names = (
"fail_after",
@@ -39,13 +40,13 @@ def make_error(error: str, lineno: int, col: int, *args: Any, **kwargs: Any) ->
class Flake8TrioVisitor(ast.NodeVisitor):
def __init__(self):
super().__init__()
- self.problems: List[Error] = []
+ self._problems: List[Error] = []
@classmethod
- def run(cls, tree: ast.AST) -> Generator[Error, None, None]:
+ def run(cls, tree: ast.AST) -> Iterable[Error]:
visitor = cls()
visitor.visit(tree)
- yield from visitor.problems
+ yield from visitor._problems
def visit_nodes(
self, *nodes: Union[ast.AST, Iterable[ast.AST]], generic: bool = False
@@ -62,7 +63,16 @@ def visit_nodes(
visit(node)
def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any):
- self.problems.append(make_error(error, lineno, col, *args, **kwargs))
+ self._problems.append(make_error(error, lineno, col, *args, **kwargs))
+
+ def get_state(self, *attrs: str) -> Dict[str, Any]:
+ if not attrs:
+ attrs = tuple(self.__dict__.keys())
+ return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"}
+
+ def set_state(self, attrs: Dict[str, Any]):
+ for attr, value in attrs.items():
+ setattr(self, attr, value)
class TrioScope:
@@ -87,8 +97,6 @@ def __init__(self, node: ast.Call, funcname: str, packagename: str):
def __str__(self):
# Not supporting other ways of importing trio
- # if self.packagename is None:
- # return self.funcname
return f"{self.packagename}.{self.funcname}"
@@ -100,7 +108,6 @@ def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]:
and node.func.value.id == "trio"
and node.func.attr in names
):
- # return "trio." + node.func.attr
return TrioScope(node, node.func.attr, node.func.value.id)
return None
@@ -124,7 +131,7 @@ def __init__(self):
def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
self.check_for_trio100(node)
- outer_yie = self._yield_is_error
+ outer = self.get_state("_yield_is_error")
# Check for a `with trio.`
if not self._safe_decorator:
@@ -139,13 +146,13 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
self.generic_visit(node)
# reset yield_is_error
- self._yield_is_error = outer_yie
+ self.set_state(outer)
def visit_AsyncWith(self, node: ast.AsyncWith):
self.visit_With(node)
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
- outer = self._safe_decorator, self._yield_is_error
+ outer = self.get_state()
self._yield_is_error = False
# check for @ and @.
@@ -154,14 +161,14 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
self.generic_visit(node)
- self._safe_decorator, self._yield_is_error = outer
+ self.set_state(outer)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.visit_FunctionDef(node)
def visit_Yield(self, node: ast.Yield):
if self._yield_is_error:
- self.problems.append(make_error(TRIO101, node.lineno, node.col_offset))
+ self.error(TRIO101, node.lineno, node.col_offset)
self.generic_visit(node)
@@ -173,19 +180,17 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
isinstance(x, checkpoint_node_types) and x != node
for x in ast.walk(node)
):
- self.problems.append(
- make_error(TRIO100, item.lineno, item.col_offset, call)
- )
+ self.error(TRIO100, item.lineno, item.col_offset, call)
def visit_ImportFrom(self, node: ast.ImportFrom):
if node.module == "trio":
- self.problems.append(make_error(TRIO106, node.lineno, node.col_offset))
+ self.error(TRIO106, node.lineno, node.col_offset)
self.generic_visit(node)
def visit_Import(self, node: ast.Import):
for name in node.names:
if name.name == "trio" and name.asname is not None:
- self.problems.append(make_error(TRIO106, node.lineno, node.col_offset))
+ self.error(TRIO106, node.lineno, node.col_offset)
def critical_except(node: ast.ExceptHandler) -> Optional[Tuple[int, int, str]]:
@@ -239,9 +244,7 @@ def visit_Await(
cm.has_timeout and cm.shielded for cm in self._trio_context_managers
)
):
- self.problems.append(
- make_error(TRIO102, node.lineno, node.col_offset, *self._critical_scope)
- )
+ self.error(TRIO102, node.lineno, node.col_offset, *self._critical_scope)
if visit_children:
self.generic_visit(node)
@@ -275,14 +278,15 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
self.visit_With(node)
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
- outer_cm = self._safe_decorator
+ outer = self.get_state("_safe_decorator")
# check for @ and @.
if has_decorator(node.decorator_list, *context_manager_names):
self._safe_decorator = True
self.generic_visit(node)
- self._safe_decorator = outer_cm
+
+ self.set_state(outer)
visit_AsyncFunctionDef = visit_FunctionDef
@@ -292,13 +296,13 @@ def critical_visit(
block: Tuple[int, int, str],
generic: bool = False,
):
- outer = self._critical_scope, self._trio_context_managers
+ outer = self.get_state("_critical_scope", "_trio_context_managers")
self._trio_context_managers = []
self._critical_scope = block
self.visit_nodes(node, generic=generic)
- self._critical_scope, self._trio_context_managers = outer
+ self.set_state(outer)
def visit_Try(self, node: ast.Try):
# There's no visit_Finally, so we need to manually visit the Try fields.
@@ -345,7 +349,7 @@ def __init__(self):
# then there might be a code path that doesn't re-raise.
def visit_ExceptHandler(self, node: ast.ExceptHandler):
- outer = (self.unraised, self.except_name, self.loop_depth)
+ outer = self.get_state()
marker = critical_except(node)
# we need to *not* unset self.unraised if this is non-critical, to still
@@ -362,10 +366,9 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler):
self.generic_visit(node)
if self.unraised and marker is not None:
- # print(marker)
- self.problems.append(make_error(TRIO103, *marker))
+ self.error(TRIO103, *marker)
- (self.unraised, self.except_name, self.loop_depth) = outer
+ self.set_state(outer)
def visit_Raise(self, node: ast.Raise):
# if there's an unraised critical exception, the raise isn't bare,
@@ -375,7 +378,7 @@ def visit_Raise(self, node: ast.Raise):
and node.exc is not None
and not (isinstance(node.exc, ast.Name) and node.exc.id == self.except_name)
):
- self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
+ self.error(TRIO104, node.lineno, node.col_offset)
# treat it as safe regardless, to avoid unnecessary error messages.
self.unraised = False
@@ -385,7 +388,7 @@ def visit_Raise(self, node: ast.Raise):
def visit_Return(self, node: Union[ast.Return, ast.Yield]):
if self.unraised:
# Error: must re-raise
- self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
+ self.error(TRIO104, node.lineno, node.col_offset)
self.generic_visit(node)
visit_Yield = visit_Return
@@ -434,20 +437,22 @@ def visit_If(self, node: ast.If):
# we completely disregard them when checking coverage by resetting the
# effects of them afterwards
def visit_For(self, node: Union[ast.For, ast.While]):
- outer_unraised = self.unraised
+ outer = self.get_state("unraised")
+
self.loop_depth += 1
for n in node.body:
self.visit(n)
self.loop_depth -= 1
for n in node.orelse:
self.visit(n)
- self.unraised = outer_unraised
+
+ self.set_state(outer)
visit_While = visit_For
def visit_Break(self, node: Union[ast.Break, ast.Continue]):
if self.unraised and self.loop_depth == 0:
- self.problems.append(make_error(TRIO104, node.lineno, node.col_offset))
+ self.error(TRIO104, node.lineno, node.col_offset)
self.generic_visit(node)
visit_Continue = visit_Break
@@ -492,9 +497,7 @@ def visit_Call(self, node: ast.Call):
or not isinstance(self.node_stack[-2], ast.Await)
)
):
- self.problems.append(
- make_error(TRIO105, node.lineno, node.col_offset, node.func.attr)
- )
+ self.error(TRIO105, node.lineno, node.col_offset, node.func.attr)
self.generic_visit(node)
@@ -615,7 +618,7 @@ def from_filename(cls, filename: str) -> "Plugin":
source = f.read()
return cls(ast.parse(source))
- def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
+ def run(self) -> Iterable[Error]:
for v in Flake8TrioVisitor.__subclasses__():
yield from v.run(self._tree)
@@ -625,7 +628,7 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
TRIO102 = "TRIO102: await inside {2} on line {0} must have shielded cancel scope with a timeout"
TRIO103 = "TRIO103: {} block with a code path that doesn't re-raise the error"
TRIO104 = "TRIO104: Cancelled (and therefore BaseException) must be re-raised"
-TRIO105 = "TRIO105: Trio async function {} must be immediately awaited"
+TRIO105 = "TRIO105: trio async function {} must be immediately awaited"
TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work"
TRIO107 = "TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised"
TRIO108 = "TRIO108: Early return from async function must have at least one checkpoint on every code path before it."
diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py
index 3a43b96b..78d86cf2 100644
--- a/tests/test_flake8_trio.py
+++ b/tests/test_flake8_trio.py
@@ -39,55 +39,97 @@ def test_eval(test: str, path: str):
major, minor = version_str[0], version_str[1:]
v_i = sys.version_info
if (v_i.major, v_i.minor) < (int(major), int(minor)):
- return
+ raise unittest.SkipTest("v_i, major, minor")
test = test.split("_")[0]
error_msg = getattr(flake8_trio, test)
expected: List[Error] = []
with open(os.path.join("tests", path)) as file:
- for lineno, line in enumerate(file):
- # get text between `error: ` and newline
- k = re.search(r"(?<=error: ).*(?=\n)", line)
- if not k:
+ for lineno, line in enumerate(file, start=1):
+ # get text between `error:` and end of line
+ k = re.search(r"(?<=error:).*$", line)
+ if not k or line.strip()[0] == "#":
continue
- # Append a bunch of 0's so string formatting gives garbage instead
+ # Append a bunch of empty strings so string formatting gives garbage instead
# of throwing an exception
- args = [m.strip() for m in k.group().split(",")] + ["0"] * 5
+ args = [m.strip() for m in k.group().split(",")] + [""] * 5
col, *args = args
- expected.append(make_error(error_msg, lineno + 1, int(col), *args))
+ for i, arg in enumerate(args):
+ if "$lineno" in arg:
+ args[i] = eval(arg.replace("$", ""), {"lineno": lineno})
+ assert col.isdigit(), f'invalid column "{col}" @L{lineno}, in "{line}"'
+ expected.append(make_error(error_msg, lineno, int(col), *args))
+ assert expected, "failed to parse any errors in file"
assert_expected_errors(path, test, *expected)
-# This function is also a mess now, but I keep slowly iterating on getting it to
-# print actually helpful error messages in all cases - which is a struggle.
-# It'll likely continue to be a mess for the foreseeable future
-def assert_expected_errors(test_file: str, include: str, *expected: Error) -> None:
- def trim_messages(messages: Iterable[Error]):
- return tuple(((line, col, int(msg[4:7])) for line, col, msg, _ in messages))
-
+def assert_expected_errors(test_file: str, include: str, *expected: Error):
filename = Path(__file__).absolute().parent / test_file
plugin = Plugin.from_filename(str(filename))
- errors = tuple(e for e in plugin.run() if include in e[2])
-
- # start with a check with trimmed errors that will make for smaller diff messages
- trim_errors = trim_messages(errors)
- trim_expected = trim_messages(expected)
-
- cls = unittest.TestCase()
- unexpected = sorted(set(trim_errors) - set(trim_expected))
- missing = sorted(set(trim_expected) - set(trim_errors))
- cls.assertEqual((unexpected, missing), ([], []), msg="(unexpected, missing)")
+ errors = tuple(sorted(e for e in plugin.run() if include in e[2]))
+ expected = tuple(sorted(expected))
- unexpected = sorted(set(errors) - set(expected))
- missing = sorted(set(expected) - set(errors))
- if unexpected and missing:
- cls.assertEqual(unexpected[0], missing[0])
- cls.assertEqual((unexpected, missing), ([], []), msg="(unexpected, missing)")
+ assert_correct_lines(errors, expected)
+ assert_correct_columns(errors, expected)
+ assert_correct_messages(errors, expected)
# full check
- cls.assertSequenceEqual(sorted(errors), sorted(expected))
+ unittest.TestCase().assertSequenceEqual(sorted(errors), sorted(expected))
+
+
+def assert_correct_lines(errors: Iterable[Error], expected: Iterable[Error]):
+ # Check that errors are on correct lines
+ error_lines = {line for line, *_ in errors}
+ expected_lines = {line for line, *_ in expected}
+ unexpected_lines = sorted(error_lines - expected_lines)
+ missing_lines = sorted(expected_lines - error_lines)
+ unittest.TestCase().assertEqual(
+ unexpected_lines,
+ missing_lines,
+ msg="Lines with unexpected errors; missing errors",
+ )
+
+
+def assert_correct_columns(errors: Iterable[Error], expected: Iterable[Error]):
+ # check errors have correct columns
+ col_error = False
+ for (line, error_col, *_), (_, expected_col, *_) in zip(errors, expected):
+ if error_col != expected_col:
+ if not col_error:
+ print("Errors with same line but different columns:", file=sys.stderr)
+ print("| line | actual | expected |", file=sys.stderr)
+ col_error = True
+ print(
+ f"| {line:4} | {error_col:6} | {expected_col:8} |",
+ file=sys.stderr,
+ )
+ assert not col_error
+
+
+def assert_correct_messages(errors: Iterable[Error], expected: Iterable[Error]):
+ # check errors have correct messages
+ msg_error = False
+ for (line, _, error_msg, *_), (_, _, expected_msg, *_) in zip(errors, expected):
+ if error_msg != expected_msg:
+ if not msg_error:
+ print(
+ "Errors with different messages:",
+ "-" * 20,
+ sep="\n",
+ file=sys.stderr,
+ )
+ msg_error = True
+ print(
+ f"* line: {line:3}",
+ f" actual: {error_msg}",
+ f"expected: {expected_msg}",
+ "-" * 20,
+ sep="\n",
+ file=sys.stderr,
+ )
+ assert not msg_error
@pytest.mark.fuzz
diff --git a/tests/trio100_py39.py b/tests/trio100_py39.py
index 29ea1930..b8cccb6f 100644
--- a/tests/trio100_py39.py
+++ b/tests/trio100_py39.py
@@ -3,14 +3,14 @@
async def function_name():
with (
- open("veryverylongfilenamesoshedsplitsthisintotwolines") as _,
+ open("") as _,
trio.fail_after(10), # error: 8, trio.fail_after
):
pass
with (
trio.fail_after(5), # error: 8, trio.fail_after
- open("veryverylongfilenamesoshedsplitsthisintotwolines") as _,
+ open("") as _,
trio.move_on_after(5), # error: 8, trio.move_on_after
):
pass
diff --git a/tests/trio101.py b/tests/trio101.py
index bc6760e1..0815d93b 100644
--- a/tests/trio101.py
+++ b/tests/trio101.py
@@ -23,7 +23,6 @@ def foo2():
async def foo3():
async with trio.CancelScope() as _:
- await trio.sleep(1) # so trio100 doesn't complain
yield 1 # error: 8
diff --git a/tests/trio102.py b/tests/trio102.py
index e8dafbbc..f1a51581 100644
--- a/tests/trio102.py
+++ b/tests/trio102.py
@@ -122,7 +122,7 @@ async def foo3():
await foo() # safe
with trio.fail_after(5), trio.move_on_after(30) as s:
s.shield = True
- await foo() # safe in theory, error: 12, 115, 4, try/finally
+ await foo() # safe in theory, error: 12, $lineno-10, 4, try/finally
# New: except cancelled/baseexception are also critical
@@ -132,11 +132,11 @@ async def foo4():
except ValueError:
await foo() # safe
except trio.Cancelled:
- await foo() # error: 8, 134, 11, trio.Cancelled
+ await foo() # error: 8, $lineno-1, 11, trio.Cancelled
except BaseException:
- await foo() # error: 8, 136, 11, BaseException
+ await foo() # error: 8, $lineno-1, 11, BaseException
except:
- await foo() # error: 8, 138, 4, bare except
+ await foo() # error: 8, $lineno-1, 4, bare except
async def foo5():
diff --git a/tox.ini b/tox.ini
index 92161f76..62a8ee4d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -18,6 +18,9 @@ deps =
hypothesmith
pytest
trio
+setenv =
+ # Make sure pyright is always up to date
+ PYRIGHT_PYTHON_FORCE_VERSION = latest
skip_install =
# don't install the plugin, which would register it with flake8
# and potentially stop the linter from functioning.