Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6583,6 +6583,9 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa

partial_type_maps = []
for operator, expr_indices in simplified_operator_list:
if_map: TypeMap
else_map: TypeMap

if operator in {"is", "is not", "==", "!="}:
if_map, else_map = self.equality_type_narrowing_helper(
node,
Expand All @@ -6598,14 +6601,24 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
item_type = operand_types[left_index]
iterable_type = operand_types[right_index]

if_map, else_map = {}, {}
if_map = {}
else_map = {}

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_overlapping_none(item_type):
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
if collection_item_type is not None:
if_map, else_map = self.narrow_type_by_equality(
"==",
operands=[operands[left_index], operands[right_index]],
operand_types=[item_type, collection_item_type],
expr_indices=[left_index, right_index],
narrowable_indices={0},
)

# We only try and narrow away 'None' for now
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is actually not really needed anymore. I will clean it up, there is one test case affected, a later PR in my stack changes it

if (
collection_item_type is not None
if_map is not None
and is_overlapping_none(item_type)
and not is_overlapping_none(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
Expand All @@ -6622,11 +6635,11 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
expr = operands[right_index]
if if_type is None:
if_map = None
else:
elif if_map is not None:
if_map[expr] = if_type
if else_type is None:
else_map = None
else:
elif else_map is not None:
else_map[expr] = else_type

else:
Expand Down
2 changes: 1 addition & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def infer_constraints_for_callable(
param_spec = callee.param_spec()
param_spec_arg_types = []
param_spec_arg_names = []
param_spec_arg_kinds = []
param_spec_arg_kinds: list[ArgKind] = []

incomplete_star_mapping = False
for i, actuals in enumerate(formal_to_actual): # TODO: isn't this `enumerate(arg_types)`?
Expand Down
3 changes: 1 addition & 2 deletions mypyc/irbuild/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Final, Literal, TypedDict, cast
from typing import Any, Final, Literal, TypedDict
from typing_extensions import NotRequired

from mypy.nodes import (
Expand Down Expand Up @@ -138,7 +138,6 @@ def get_mypyc_attrs(

def set_mypyc_attr(key: str, value: Any, line: int) -> None:
if key in MYPYC_ATTRS:
key = cast(MypycAttr, key)
attrs[key] = value
lines[key] = line
else:
Expand Down
136 changes: 129 additions & 7 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1373,13 +1373,13 @@ else:
reveal_type(val) # N: Revealed type is "None"

if val in (None,):
reveal_type(val) # N: Revealed type is "__main__.A | None"
reveal_type(val) # N: Revealed type is "None"
else:
reveal_type(val) # N: Revealed type is "__main__.A | None"
reveal_type(val) # N: Revealed type is "__main__.A"
if val not in (None,):
reveal_type(val) # N: Revealed type is "__main__.A | None"
reveal_type(val) # N: Revealed type is "__main__.A"
else:
reveal_type(val) # N: Revealed type is "__main__.A | None"
reveal_type(val) # N: Revealed type is "None"

class Hmm:
def __eq__(self, other) -> bool: ...
Expand Down Expand Up @@ -2294,9 +2294,8 @@ def f(x: str | int) -> None:
y = x

if x in ["x"]:
# TODO: we should fix this reveal https://github.com/python/mypy/issues/3229
reveal_type(x) # N: Revealed type is "builtins.str | builtins.int"
y = x # E: Incompatible types in assignment (expression has type "str | int", variable has type "str")
reveal_type(x) # N: Revealed type is "builtins.str"
y = x
z = x
z = y
[builtins fixtures/primitives.pyi]
Expand Down Expand Up @@ -2806,3 +2805,126 @@ class X:
reveal_type(self.y) # N: Revealed type is "builtins.list[builtins.str]"
self.y[0].does_not_exist # E: "str" has no attribute "does_not_exist"
[builtins fixtures/dict.pyi]


[case testTypeNarrowingStringInLiteralUnion]
from typing import Literal, Tuple
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInLiteralUnionSubset]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b')
strIn: str = "b"
strOut: str = "c"
if strIn in typeAlpha:
reveal_type(strIn) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
else:
reveal_type(strIn) # N: Revealed type is "builtins.str"
if strOut in typeAlpha:
reveal_type(strOut) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowingStringNotInLiteralUnion]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c')
strIn: str = "c"
strOut: str = "d"
if strIn not in typeAlpha:
reveal_type(strIn) # N: Revealed type is "builtins.str"
else:
reveal_type(strIn) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
if strOut in typeAlpha:
reveal_type(strOut) # N: Revealed type is "Literal['a'] | Literal['b'] | Literal['c']"
else:
reveal_type(strOut) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowingStringInLiteralUnionDontExpand]
from typing import Literal, Tuple
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c')
strIn: Literal['c'] = "c"
reveal_type(strIn) # N: Revealed type is "Literal['c']"
#Check we don't expand a Literal into the Union type
if strIn not in typeAlpha:
reveal_type(strIn) # N: Revealed type is "Literal['c']"
else:
reveal_type(strIn) # N: Revealed type is "Literal['c']"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInMixedUnion]
from typing import Literal, Tuple
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInSet]
from typing import Literal, Set
typ: Set[Literal['a', 'b']] = {'a', 'b'}
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
if x not in typ:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
[builtins fixtures/narrowing.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingStringInList]
from typing import Literal, List
typ: List[Literal['a', 'b']] = ['a', 'b']
x: str = "hi!"
if x in typ:
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
else:
reveal_type(x) # N: Revealed type is "builtins.str"
if x not in typ:
reveal_type(x) # N: Revealed type is "builtins.str"
else:
reveal_type(x) # N: Revealed type is "Literal['a'] | Literal['b']"
[builtins fixtures/narrowing.pyi]
[typing fixtures/typing-medium.pyi]

[case testTypeNarrowingUnionStringFloat]
from typing import Union
def foobar(foo: Union[str, float]):
if foo in ['a', 'b']:
reveal_type(foo) # N: Revealed type is "builtins.str"
else:
reveal_type(foo) # N: Revealed type is "builtins.str | builtins.float"
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testNarrowAnyWithEqualityOrContainment]
# https://github.com/python/mypy/issues/17841
from typing import Any

def f1(x: Any) -> None:
if x is not None and x not in ["x"]:
return
reveal_type(x) # N: Revealed type is "Any"

def f2(x: Any) -> None:
if x is not None and x != "x":
return
reveal_type(x) # N: Revealed type is "Any"
[builtins fixtures/tuple.pyi]
9 changes: 8 additions & 1 deletion test-data/unit/fixtures/narrowing.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Builtins stub used in check-narrowing test cases.
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable


Tco = TypeVar('Tco', covariant=True)
Expand All @@ -15,6 +15,13 @@ class function: pass
class ellipsis: pass
class int: pass
class str: pass
class float: pass
class dict(Generic[KT, VT]): pass

def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass

class list(Sequence[Tco]):
def __contains__(self, other: object) -> bool: pass
class set(Iterable[Tco], Generic[Tco]):
def __init__(self, iterable: Iterable[Tco] = ...) -> None: ...
def __contains__(self, item: object) -> bool: pass