Skip to content
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Changelog
*[CalVer, YY.month.patch](https://calver.org/)*

## Future
- add TRIO112, nursery body with only a call to `nursery.start[_soon]` and not passing itself as a parameter can be replaced with a regular function call.
## 22.8.5
- Add TRIO111: Variable, from context manager opened inside nursery, passed to `start[_soon]` might be invalidly accesed while in use, due to context manager closing before the nursery. This is usually a bug, and nurseries should generally be the inner-most context manager.
- Add TRIO112: this single-task nursery could be replaced by awaiting the function call directly.

## 22.8.4
- Fix TRIO108 raising errors on yields in some sync code.
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ pip install flake8-trio
Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit).
- **TRIO109**: Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead
- **TRIO110**: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.
- **TRIO111**: Variable, from context manager opened inside nursery, passed to `start[_soon]` might be invalidly accesed while in use, due to context manager closing before the nursery. This is usually a bug, and nurseries should generally be the inner-most context manager.
- **TRIO112**: nursery body with only a call to `nursery.start[_soon]` and not passing itself as a parameter can be replaced with a regular function call.
146 changes: 124 additions & 22 deletions flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "22.8.4"
__version__ = "22.8.5"


Error_codes = {
Expand Down Expand Up @@ -55,7 +55,12 @@
"`trio.[fail/move_on]_[after/at]` instead"
),
"TRIO110": "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.",
"TRIO112": "Redundant nursery {}, consider replacing with a regular function call",
"TRIO111": (
"variable {2} is usable within the context manager on line {0}, but that "
"will close before nursery opened on line {1} - this is usually a bug. "
"Nurseries should generally be the inner-most context manager."
),
"TRIO112": "Redundant nursery {}, consider replacing with directly awaiting the function call",
}


Expand Down Expand Up @@ -162,10 +167,18 @@ def error(self, error: str, node: HasLineCol, *args: object):
if not self.suppress_errors:
self._problems.append(Error(error, node.lineno, node.col_offset, *args))

def get_state(self, *attrs: str) -> Dict[str, Any]:
def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]:
if not attrs:
attrs = tuple(self.__dict__.keys())
return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"}
res: Dict[str, Any] = {}
for attr in attrs:
if attr == "_problems":
continue
value = getattr(self, attr)
if copy and hasattr(value, "copy"):
value = value.copy()
res[attr] = value
return res

def set_state(self, attrs: Dict[str, Any], copy: bool = False):
for attr, value in attrs.items():
Expand All @@ -187,37 +200,68 @@ def has_decorator(decorator_list: List[ast.expr], *names: str):
return False


# handles 100, 101, 106, 109, 110
# handles 100, 101, 106, 109, 110, 111, 112
class VisitorMiscChecks(Flake8TrioVisitor):
class NurseryCall(NamedTuple):
stack_index: int
name: str

class TrioContextManager(NamedTuple):
lineno: int
name: str
is_nursery: bool

def __init__(self):
super().__init__()

# variables only used for 101
# 101
self._yield_is_error = False
self._safe_decorator = False

# ---- 100, 101 ----
# 111
self._context_managers: List[VisitorMiscChecks.TrioContextManager] = []
self._nursery_call: Optional[VisitorMiscChecks.NurseryCall] = None

self.defaults = self.get_state(copy=True)

# ---- 100, 101, 111, 112 ----
def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
# 100
self.check_for_trio100(node)
self.check_for_trio112(node)

# 101 for rest of function
outer = self.get_state("_yield_is_error")
outer = self.get_state("_yield_is_error", "_context_managers", copy=True)

# Check for a `with trio.<scope_creater>`
if not self._safe_decorator:
for item in (i.context_expr for i in node.items):
if (
get_matching_call(item, "open_nursery", *cancel_scope_names)
is not None
):
self._yield_is_error = True
break
for item in node.items:
# 101
# if there's no safe decorator,
# and it's not yet been determined that yield is error
# and this withitem opens a cancelscope:
# then yielding is unsafe
if (
not self._safe_decorator
and not self._yield_is_error
and get_matching_call(
item.context_expr, "open_nursery", *cancel_scope_names
)
is not None
):
self._yield_is_error = True

self.generic_visit(node)
# 111
# if a withitem is saved in a variable,
# push its line, variable, and whether it's a trio nursery
# to the _context_managers stack,
if isinstance(item.optional_vars, ast.Name):
self._context_managers.append(
self.TrioContextManager(
item.context_expr.lineno,
item.optional_vars.id,
get_matching_call(item.context_expr, "open_nursery")
is not None,
)
)

# reset yield_is_error
self.generic_visit(node)
self.set_state(outer)

visit_AsyncWith = visit_With
Expand All @@ -236,7 +280,7 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
# ---- 101 ----
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
outer = self.get_state()
self._yield_is_error = False
self.set_state(self.defaults, copy=True)

# check for @<context_manager_name> and @<library>.<context_manager_name>
if has_decorator(node.decorator_list, *context_manager_names):
Expand All @@ -251,6 +295,12 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.check_for_trio109(node)
self.visit_FunctionDef(node)

def visit_Lambda(self, node: ast.Lambda):
outer = self.get_state()
self.set_state(self.defaults, copy=True)
self.generic_visit(node)
self.set_state(outer)

# ---- 101 ----
def visit_Yield(self, node: ast.Yield):
if self._yield_is_error:
Expand All @@ -260,8 +310,11 @@ def visit_Yield(self, node: ast.Yield):

# ---- 109 ----
def check_for_trio109(self, node: ast.AsyncFunctionDef):
# pending configuration or a more sophisticated check, ignore
# all functions with a decorator
if node.decorator_list:
return

args = node.args
for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs):
if arg.arg == "timeout":
Expand All @@ -277,6 +330,7 @@ def visit_Import(self, node: ast.Import):
for name in node.names:
if name.name == "trio" and name.asname is not None:
self.error("TRIO106", node)
self.generic_visit(node)

# ---- 110 ----
def visit_While(self, node: ast.While):
Expand All @@ -292,6 +346,53 @@ def check_for_trio110(self, node: ast.While):
):
self.error("TRIO110", node)

# ---- 111 ----
# if it's a <X>.start[_soon] call
# and <X> is a nursery listed in self._context_managers:
# Save <X>'s index in self._context_managers to guard against cm's higher in the
# stack being passed as parameters to it. (and save <X> for the error message)
def visit_Call(self, node: ast.Call):
outer = self.get_state("_nursery_call")

if (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.attr in ("start", "start_soon")
):
self._nursery_call = None
for i, cm in enumerate(self._context_managers):
if node.func.value.id == cm.name:
# don't break upon finding a nursery in case there's multiple cm's
# on the stack with the same name
if cm.is_nursery:
self._nursery_call = self.NurseryCall(i, node.func.attr)
else:
self._nursery_call = None

self.generic_visit(node)
self.set_state(outer)

# If we're inside a <X>.start[_soon] call (where <X> is a nursery),
# and we're accessing a variable cm that's on the self._context_managers stack,
# with a higher index than <X>:
# Raise error since the scope of cm may close before the function passed to the
# nursery finishes.
def visit_Name(self, node: ast.Name):
self.generic_visit(node)
if self._nursery_call is None:
return

for i, cm in enumerate(self._context_managers):
if cm.name == node.id and i > self._nursery_call.stack_index:
self.error(
"TRIO111",
node,
cm.lineno,
self._context_managers[self._nursery_call.stack_index].lineno,
node.id,
self._nursery_call.name,
)

# if with has a withitem `trio.open_nursery() as <X>`,
# and the body is only a single expression <X>.start[_soon](),
# and does not pass <X> as a parameter to the expression
Expand Down Expand Up @@ -323,6 +424,7 @@ def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]):
self.error("TRIO112", item.context_expr, var_name)


# used in 102, 103 and 104
def critical_except(node: ast.ExceptHandler) -> Optional[Statement]:
def has_exception(node: Optional[ast.expr]) -> str:
if isinstance(node, ast.Name) and node.id == "BaseException":
Expand Down
41 changes: 25 additions & 16 deletions tests/test_flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,27 @@ def test_eval(test: str, path: str):
try:
# Append a bunch of empty strings so string formatting gives garbage
# instead of throwing an exception
args = eval(
f"[{reg_match}]",
{
"lineno": lineno,
"line": lineno,
"Statement": Statement,
"Stmt": Statement,
},
)
try:
args = eval(
f"[{reg_match}]",
{
"lineno": lineno,
"line": lineno,
"Statement": Statement,
"Stmt": Statement,
},
)
except NameError:
print(f"failed to eval on line {lineno}", file=sys.stderr)
raise

except Exception as e:
print(f"lineno: {lineno}, line: {line}", file=sys.stderr)
raise e
col, *args = args
if args:
col, *args = args
else:
col = 0
assert isinstance(
col, int
), f'invalid column "{col}" @L{lineno}, in "{line}"'
Expand Down Expand Up @@ -163,13 +170,15 @@ def assert_expected_errors(plugin: Plugin, include: Iterable[str], *expected: Er

def print_first_diff(errors: Sequence[Error], expected: Sequence[Error]):
first_error_line: List[Error] = []
for e in errors:
if e.line == errors[0].line:
first_error_line.append(e)
first_expected_line: List[Error] = []
for e in expected:
if e.line == expected[0].line:
first_expected_line.append(e)
for err, exp in zip(errors, expected):
if err == exp:
continue
if not first_error_line or err.line == first_error_line[0]:
first_error_line.append(err)
if not first_expected_line or exp.line == first_expected_line[0]:
first_expected_line.append(exp)

if first_expected_line != first_error_line:
print(
"First lines with different errors",
Expand Down
Loading