diff --git a/libcst/codemod/commands/convert_type_comments.py b/libcst/codemod/commands/convert_type_comments.py index a786f21a3..ef91687b6 100644 --- a/libcst/codemod/commands/convert_type_comments.py +++ b/libcst/codemod/commands/convert_type_comments.py @@ -7,7 +7,7 @@ import builtins import functools import sys -from typing import List, Optional, Set, Tuple, Union +from typing import List, Optional, Sequence, Set, Tuple, Union from typing_extensions import TypeAlias @@ -30,7 +30,7 @@ def _ast_for_node(node: cst.CSTNode) -> ast.Module: def _statement_type_comment( - node: Union[cst.SimpleStatementLine, cst.For], + node: Union[cst.SimpleStatementLine, cst.For, cst.With], ) -> Optional[str]: return _ast_for_node(node).body[-1].type_comment @@ -156,6 +156,30 @@ def type_declaration( value=None, ) + @staticmethod + def type_declaration_statements( + bindings: UnpackedBindings, + annotations: UnpackedAnnotations, + leading_lines: Sequence[cst.EmptyLine], + ) -> List[cst.SimpleStatementLine]: + return [ + cst.SimpleStatementLine( + body=[ + AnnotationSpreader.type_declaration( + binding=binding, + raw_annotation=raw_annotation, + ) + ], + leading_lines=leading_lines if i == 0 else [], + ) + for i, (binding, raw_annotation) in enumerate( + AnnotationSpreader.annotated_bindings( + bindings=bindings, + annotations=annotations, + ) + ) + ] + def convert_Assign( node: cst.Assign, @@ -315,3 +339,94 @@ def leave_SimpleStatementLine( ) else: raise RuntimeError(f"Unhandled value {converted}") + + def leave_For( + self, + original_node: cst.For, + updated_node: cst.For, + ) -> Union[cst.For, cst.FlattenSentinel]: + """ + Convert a For with a type hint on the bound variable(s) to + use type declarations. + """ + # Type comments are only possible when the body is an indented + # block, and we need this refinement to work with the header, + # so we check and only then extract the type comment. + body = updated_node.body + if not isinstance(body, cst.IndentedBlock): + return updated_node + type_comment = _statement_type_comment(original_node) + if type_comment is None: + return updated_node + # Zip up the type hint and the bindings. If we hit an arity + # error, abort. + try: + type_declarations = AnnotationSpreader.type_declaration_statements( + bindings=AnnotationSpreader.unpack_target(updated_node.target), + annotations=AnnotationSpreader.unpack_type_comment(type_comment), + leading_lines=updated_node.leading_lines, + ) + except _ArityError: + return updated_node + # There is no arity error, so we can add the type delaration(s) + return cst.FlattenSentinel( + [ + *type_declarations, + updated_node.with_changes( + body=body.with_changes( + header=self._strip_TrailingWhitespace(body.header) + ), + leading_lines=[], + ), + ] + ) + + def leave_With( + self, + original_node: cst.With, + updated_node: cst.With, + ) -> Union[cst.With, cst.FlattenSentinel]: + """ + Convert a With with a type hint on the bound variable(s) to + use type declarations. + """ + # Type comments are only possible when the body is an indented + # block, and we need this refinement to work with the header, + # so we check and only then extract the type comment. + body = updated_node.body + if not isinstance(body, cst.IndentedBlock): + return updated_node + type_comment = _statement_type_comment(original_node) + if type_comment is None: + return updated_node + # PEP 484 does not attempt to specify type comment semantics for + # multiple with bindings (there's more than one sensible way to + # do it), so we make no attempt to handle this + targets = [ + item.asname.name for item in updated_node.items if item.asname is not None + ] + if len(targets) != 1: + return updated_node + target = targets[0] + # Zip up the type hint and the bindings. If we hit an arity + # error, abort. + try: + type_declarations = AnnotationSpreader.type_declaration_statements( + bindings=AnnotationSpreader.unpack_target(target), + annotations=AnnotationSpreader.unpack_type_comment(type_comment), + leading_lines=updated_node.leading_lines, + ) + except _ArityError: + return updated_node + # There is no arity error, so we can add the type delaration(s) + return cst.FlattenSentinel( + [ + *type_declarations, + updated_node.with_changes( + body=body.with_changes( + header=self._strip_TrailingWhitespace(body.header) + ), + leading_lines=[], + ), + ] + ) diff --git a/libcst/codemod/commands/tests/test_convert_type_comments.py b/libcst/codemod/commands/tests/test_convert_type_comments.py index e1e3da8bf..8b0b44de5 100644 --- a/libcst/codemod/commands/tests/test_convert_type_comments.py +++ b/libcst/codemod/commands/tests/test_convert_type_comments.py @@ -132,6 +132,65 @@ def test_semicolons_with_assignment(self) -> None: """ self.assertCodemod39Plus(before, after) + def test_converting_for_statements(self) -> None: + before = """ + # simple binding + for x in foo(): # type: int + pass + + # nested binding + for (a, (b, c)) in bar(): # type: int, (str, float) + pass + """ + after = """ + # simple binding + x: int + for x in foo(): + pass + + # nested binding + a: int + b: str + c: float + for (a, (b, c)) in bar(): + pass + """ + self.assertCodemod39Plus(before, after) + + def test_converting_with_statements(self) -> None: + before = """ + # simple binding + with open('file') as f: # type: File + pass + + # simple binding, with extra items + with foo(), open('file') as f, bar(): # type: File + pass + + # nested binding + with bar() as (a, (b, c)): # type: int, (str, float) + pass + """ + after = """ + # simple binding + f: "File" + with open('file') as f: + pass + + # simple binding, with extra items + f: "File" + with foo(), open('file') as f, bar(): + pass + + # nested binding + a: int + b: str + c: float + with bar() as (a, (b, c)): + pass + """ + self.assertCodemod39Plus(before, after) + def test_no_change_when_type_comment_unused(self) -> None: before = """ # type-ignores are not type comments @@ -150,6 +209,18 @@ def test_no_change_when_type_comment_unused(self) -> None: # Multiple assigns with mismatched LHS arities always result in arity # errors, and we only codemod if each target is error-free v = v0, v1 = (3, 5) # type: int, int + + # Ignore for statements with arity mismatches + for x in []: # type: int, int + pass + + # Ignore with statements with arity mismatches + with open('file') as (f0, f1): # type: File + pass + + # Ignore with statements that have multiple item bindings + with open('file') as f0, open('file') as f1: # type: File + pass """ after = before self.assertCodemod39Plus(before, after)