Skip to content

Commit b9a6129

Browse files
authored
[ty] Improve support for kwarg splats in dictionary literals (#22781)
Resolves astral-sh/ty#1332.
1 parent f516d47 commit b9a6129

4 files changed

Lines changed: 106 additions & 67 deletions

File tree

crates/ty_python_semantic/resources/mdtest/literal/collections/dictionary.md

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,32 @@ reveal_type({1: (1, 2), 2: (3, 4)}) # revealed: dict[Unknown | int, Unknown | t
2121
## Unpacked dict
2222

2323
```py
24+
from typing import Mapping, KeysView
25+
2426
a = {"a": 1, "b": 2}
2527
b = {"c": 3, "d": 4}
28+
c = {**a, **b}
29+
reveal_type(c) # revealed: dict[Unknown | str, Unknown | int]
30+
31+
# revealed: list[int | str]
32+
# revealed: list[int | str]
33+
d: dict[str, list[int | str]] = {"a": reveal_type([1, 2]), **{"b": reveal_type([3, 4])}}
34+
reveal_type(d) # revealed: dict[str, list[int | str]]
35+
36+
class HasKeysAndGetItem:
37+
def keys(self) -> KeysView[str]:
38+
return {}.keys()
39+
40+
def __getitem__(self, arg: str) -> int:
41+
return 42
42+
43+
def _(a: dict[str, int], b: Mapping[str, int], c: HasKeysAndGetItem, d: object):
44+
reveal_type({**a}) # revealed: dict[Unknown | str, Unknown | int]
45+
reveal_type({**b}) # revealed: dict[Unknown | str, Unknown | int]
46+
reveal_type({**c}) # revealed: dict[Unknown | str, Unknown | int]
2647

27-
d = {**a, **b}
28-
reveal_type(d) # revealed: dict[Unknown | str, Unknown | int]
48+
# error: [invalid-argument-type] "Argument expression after ** must be a mapping type: Found `object`"
49+
reveal_type({**d}) # revealed: dict[Unknown, Unknown]
2950
```
3051

3152
## Dict of functions

crates/ty_python_semantic/src/types.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3801,6 +3801,60 @@ impl<'db> Type<'db> {
38013801
non_negative_int_literal(db, return_ty)
38023802
}
38033803

3804+
/// Returns the key and value types of this object if it was unpacked using `**`,
3805+
/// or `None` if the object does not support unpacking.
3806+
fn unpack_keys_and_items(self, db: &'db dyn Db) -> Option<(Type<'db>, Type<'db>)> {
3807+
let key_ty = match self
3808+
.member_lookup_with_policy(
3809+
db,
3810+
Name::new_static("keys"),
3811+
MemberLookupPolicy::NO_INSTANCE_FALLBACK,
3812+
)
3813+
.place
3814+
{
3815+
Place::Defined(DefinedPlace {
3816+
ty: keys_method,
3817+
definedness: Definedness::AlwaysDefined,
3818+
..
3819+
}) => keys_method
3820+
.try_call(db, &CallArguments::none())
3821+
.ok()
3822+
.and_then(|bindings| {
3823+
Some(
3824+
bindings
3825+
.return_type(db)
3826+
.try_iterate(db)
3827+
.ok()?
3828+
.homogeneous_element_type(db),
3829+
)
3830+
})?,
3831+
3832+
_ => return None,
3833+
};
3834+
3835+
let value_ty = match self
3836+
.member_lookup_with_policy(
3837+
db,
3838+
Name::new_static("__getitem__"),
3839+
MemberLookupPolicy::NO_INSTANCE_FALLBACK,
3840+
)
3841+
.place
3842+
{
3843+
Place::Defined(DefinedPlace {
3844+
ty: getitem_method,
3845+
definedness: Definedness::AlwaysDefined,
3846+
..
3847+
}) => getitem_method
3848+
.try_call(db, &CallArguments::positional([Type::unknown()]))
3849+
.ok()
3850+
.map_or_else(Type::unknown, |bindings| bindings.return_type(db)),
3851+
3852+
_ => Type::unknown(),
3853+
};
3854+
3855+
Some((key_ty, value_ty))
3856+
}
3857+
38043858
/// Returns a [`Bindings`] that can be used to analyze a call to this type. You must call
38053859
/// [`match_parameters`][Bindings::match_parameters] and [`check_types`][Bindings::check_types]
38063860
/// to fully analyze a particular call site.

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 4 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3588,41 +3588,13 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
35883588
}
35893589
} else {
35903590
let mut value_type_fallback = |argument_type: Type<'db>| {
3591-
// TODO: Instead of calling the `keys` and `__getitem__` methods, we should
3592-
// instead get the constraints which satisfies the `SupportsKeysAndGetItem`
3593-
// protocol i.e., the key and value type.
3594-
let key_type = match argument_type
3595-
.member_lookup_with_policy(
3596-
self.db,
3597-
Name::new_static("keys"),
3598-
MemberLookupPolicy::NO_INSTANCE_FALLBACK,
3599-
)
3600-
.place
3601-
{
3602-
Place::Defined(DefinedPlace {
3603-
ty: keys_method,
3604-
definedness: Definedness::AlwaysDefined,
3605-
..
3606-
}) => keys_method
3607-
.try_call(self.db, &CallArguments::none())
3608-
.ok()
3609-
.and_then(|bindings| {
3610-
Some(
3611-
bindings
3612-
.return_type(self.db)
3613-
.try_iterate(self.db)
3614-
.ok()?
3615-
.homogeneous_element_type(self.db),
3616-
)
3617-
}),
3618-
_ => None,
3619-
};
3620-
3621-
let Some(key_type) = key_type else {
3591+
let Some((key_type, value_type)) = argument_type.unpack_keys_and_items(self.db)
3592+
else {
36223593
self.errors.push(BindingError::KeywordsNotAMapping {
36233594
argument_index: adjusted_argument_index,
36243595
provided_ty: argument_type,
36253596
});
3597+
36263598
return None;
36273599
};
36283600

@@ -3640,26 +3612,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
36403612
});
36413613
}
36423614

3643-
Some(
3644-
match argument_type
3645-
.member_lookup_with_policy(
3646-
self.db,
3647-
Name::new_static("__getitem__"),
3648-
MemberLookupPolicy::NO_INSTANCE_FALLBACK,
3649-
)
3650-
.place
3651-
{
3652-
Place::Defined(DefinedPlace {
3653-
ty: getitem_method,
3654-
definedness: Definedness::AlwaysDefined,
3655-
..
3656-
}) => getitem_method
3657-
.try_call(self.db, &CallArguments::positional([Type::unknown()]))
3658-
.ok()
3659-
.map_or_else(Type::unknown, |bindings| bindings.return_type(self.db)),
3660-
_ => Type::unknown(),
3661-
},
3662-
)
3615+
Some(value_type)
36633616
};
36643617

36653618
let value_type = match argument_type {

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::iter;
2-
31
use itertools::{Either, EitherOrBoth, Itertools};
42
use ruff_db::diagnostic::{Annotation, Diagnostic, DiagnosticId, Severity, Span};
53
use ruff_db::files::File;
@@ -10126,21 +10124,34 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
1012610124

1012710125
for elts in elts {
1012810126
// An unpacking expression for a dictionary.
10129-
if let &[None, Some(value)] = elts.as_slice() {
10130-
let inferred_value_ty =
10131-
infer_elt_expression(self, (1, value, TypeContext::default()));
10127+
if let &[None, Some(value_expr)] = elts.as_slice() {
10128+
let unpack_ty = infer_elt_expression(self, (1, value_expr, tcx));
1013210129

10133-
// Merge the inferred type of the nested dictionary.
10134-
if let Some(specialization) =
10135-
inferred_value_ty.known_specialization(self.db(), KnownClass::Dict)
10136-
{
10137-
for (elt_ty, inferred_elt_ty) in
10138-
iter::zip(elt_tys.clone(), specialization.types(self.db()))
10130+
let Some((unpacked_key_ty, unpacked_value_ty)) =
10131+
unpack_ty.unpack_keys_and_items(self.db())
10132+
else {
10133+
if let Some(builder) =
10134+
self.context.report_lint(&INVALID_ARGUMENT_TYPE, value_expr)
1013910135
{
10140-
builder
10141-
.infer(Type::TypeVar(elt_ty), *inferred_elt_ty)
10142-
.ok()?;
10136+
let mut diag = builder
10137+
.into_diagnostic("Argument expression after ** must be a mapping type");
10138+
10139+
diag.set_primary_message(format_args!(
10140+
"Found `{}`",
10141+
unpack_ty.display(self.db())
10142+
));
1014310143
}
10144+
10145+
continue;
10146+
};
10147+
10148+
let mut elt_tys = elt_tys.clone();
10149+
if let Some((key_ty, value_ty)) = elt_tys.next_tuple() {
10150+
builder.infer(Type::TypeVar(key_ty), unpacked_key_ty).ok()?;
10151+
10152+
builder
10153+
.infer(Type::TypeVar(value_ty), unpacked_value_ty)
10154+
.ok()?;
1014410155
}
1014510156

1014610157
continue;

0 commit comments

Comments
 (0)