Skip to content

Commit cd5b323

Browse files
committed
mGCA: Validate const literal against expected type
1 parent 56aaf58 commit cd5b323

File tree

18 files changed

+245
-75
lines changed

18 files changed

+245
-75
lines changed

compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2798,8 +2798,17 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
27982798
span: Span,
27992799
) -> Const<'tcx> {
28002800
let tcx = self.tcx();
2801-
let input = LitToConstInput { lit: *kind, ty, neg };
2802-
tcx.at(span).lit_to_const(input)
2801+
if let LitKind::Err(guar) = *kind {
2802+
return ty::Const::new_error(tcx, guar);
2803+
}
2804+
let input = LitToConstInput { lit: *kind, ty, neg: false };
2805+
match tcx.at(span).lit_to_const(input) {
2806+
Some(value) => ty::Const::new_value(tcx, value.valtree, value.ty),
2807+
None => {
2808+
let e = tcx.dcx().span_err(span, "type annotations needed for the literal");
2809+
ty::Const::new_error(tcx, e)
2810+
}
2811+
}
28032812
}
28042813

28052814
#[instrument(skip(self), level = "debug")]
@@ -2832,7 +2841,11 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
28322841
// Allow the `ty` to be an alias type, though we cannot handle it here, we just go through
28332842
// the more expensive anon const code path.
28342843
.filter(|l| !l.ty.has_aliases())
2835-
.map(|l| tcx.at(expr.span).lit_to_const(l))
2844+
.and_then(|l| {
2845+
tcx.at(expr.span)
2846+
.lit_to_const(l)
2847+
.map(|value| ty::Const::new_value(tcx, value.valtree, value.ty))
2848+
})
28362849
}
28372850

28382851
fn require_type_const_attribute(

compiler/rustc_middle/src/queries.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1412,7 +1412,7 @@ rustc_queries! {
14121412
// FIXME get rid of this with valtrees
14131413
query lit_to_const(
14141414
key: LitToConstInput<'tcx>
1415-
) -> ty::Const<'tcx> {
1415+
) -> Option<ty::Value<'tcx>> {
14161416
desc { "converting literal to const" }
14171417
}
14181418

compiler/rustc_middle/src/query/erase.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,10 @@ impl Erasable for Option<ty::EarlyBinder<'_, Ty<'_>>> {
256256
type Storage = [u8; size_of::<Option<ty::EarlyBinder<'static, Ty<'static>>>>()];
257257
}
258258

259+
impl Erasable for Option<ty::Value<'_>> {
260+
type Storage = [u8; size_of::<Option<ty::Value<'static>>>()];
261+
}
262+
259263
impl Erasable for rustc_hir::MaybeOwner<'_> {
260264
type Storage = [u8; size_of::<rustc_hir::MaybeOwner<'static>>()];
261265
}

compiler/rustc_mir_build/src/thir/constant.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use rustc_abi::Size;
22
use rustc_ast::{self as ast, UintTy};
33
use rustc_hir::LangItem;
4-
use rustc_middle::bug;
54
use rustc_middle::mir::interpret::LitToConstInput;
65
use rustc_middle::ty::{self, ScalarInt, TyCtxt, TypeVisitableExt as _};
76
use tracing::trace;
@@ -11,11 +10,11 @@ use crate::builder::parse_float_into_scalar;
1110
pub(crate) fn lit_to_const<'tcx>(
1211
tcx: TyCtxt<'tcx>,
1312
lit_input: LitToConstInput<'tcx>,
14-
) -> ty::Const<'tcx> {
13+
) -> Option<ty::Value<'tcx>> {
1514
let LitToConstInput { lit, ty, neg } = lit_input;
1615

17-
if let Err(guar) = ty.error_reported() {
18-
return ty::Const::new_error(tcx, guar);
16+
if ty.error_reported().is_err() {
17+
return None;
1918
}
2019

2120
let trunc = |n, width: ty::UintTy| {
@@ -29,7 +28,6 @@ pub(crate) fn lit_to_const<'tcx>(
2928
trace!("trunc result: {}", result);
3029

3130
ScalarInt::try_from_uint(result, width)
32-
.unwrap_or_else(|| bug!("expected to create ScalarInt from uint {:?}", result))
3331
};
3432

3533
let valtree = match (lit, ty.kind()) {
@@ -68,27 +66,25 @@ pub(crate) fn lit_to_const<'tcx>(
6866
ty::ValTree::from_branches(tcx, [ty::Const::new_value(tcx, bytes, *inner_ty)])
6967
}
7068
(ast::LitKind::Int(n, _), ty::Uint(ui)) if !neg => {
71-
let scalar_int = trunc(n.get(), *ui);
69+
let scalar_int = trunc(n.get(), *ui)?;
7270
ty::ValTree::from_scalar_int(tcx, scalar_int)
7371
}
7472
(ast::LitKind::Int(n, _), ty::Int(i)) => {
7573
// Unsigned "negation" has the same bitwise effect as signed negation,
7674
// which gets the result we want without additional casts.
7775
let scalar_int =
78-
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned());
76+
trunc(if neg { u128::wrapping_neg(n.get()) } else { n.get() }, i.to_unsigned())?;
7977
ty::ValTree::from_scalar_int(tcx, scalar_int)
8078
}
8179
(ast::LitKind::Bool(b), ty::Bool) => ty::ValTree::from_scalar_int(tcx, b.into()),
8280
(ast::LitKind::Float(n, _), ty::Float(fty)) => {
83-
let bits = parse_float_into_scalar(n, *fty, neg).unwrap_or_else(|| {
84-
tcx.dcx().bug(format!("couldn't parse float literal: {:?}", lit_input.lit))
85-
});
81+
let bits = parse_float_into_scalar(n, *fty, neg)?;
8682
ty::ValTree::from_scalar_int(tcx, bits)
8783
}
8884
(ast::LitKind::Char(c), ty::Char) => ty::ValTree::from_scalar_int(tcx, c.into()),
89-
(ast::LitKind::Err(guar), _) => return ty::Const::new_error(tcx, guar),
90-
_ => return ty::Const::new_misc_error(tcx),
85+
(ast::LitKind::Err(_), _) => return None,
86+
_ => return None,
9187
};
9288

93-
ty::Const::new_value(tcx, valtree, ty)
89+
Some(ty::Value { ty, valtree })
9490
}

compiler/rustc_mir_build/src/thir/pattern/mod.rs

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ use std::cmp::Ordering;
88
use std::sync::Arc;
99

1010
use rustc_abi::{FieldIdx, Integer};
11+
use rustc_ast::LitKind;
1112
use rustc_data_structures::assert_matches;
1213
use rustc_errors::codes::*;
1314
use rustc_hir::def::{CtorOf, DefKind, Res};
1415
use rustc_hir::pat_util::EnumerateAndAdjustIterator;
15-
use rustc_hir::{self as hir, RangeEnd};
16+
use rustc_hir::{self as hir, LangItem, RangeEnd};
1617
use rustc_index::Idx;
1718
use rustc_middle::mir::interpret::LitToConstInput;
1819
use rustc_middle::thir::{
@@ -197,8 +198,6 @@ impl<'tcx> PatCtxt<'tcx> {
197198
expr: Option<&'tcx hir::PatExpr<'tcx>>,
198199
ty: Ty<'tcx>,
199200
) -> Result<(), ErrorGuaranteed> {
200-
use rustc_ast::ast::LitKind;
201-
202201
let Some(expr) = expr else {
203202
return Ok(());
204203
};
@@ -696,9 +695,58 @@ impl<'tcx> PatCtxt<'tcx> {
696695

697696
let pat_ty = self.typeck_results.node_type(pat.hir_id);
698697
let lit_input = LitToConstInput { lit: lit.node, ty: pat_ty, neg: *negated };
699-
let constant = self.tcx.at(expr.span).lit_to_const(lit_input);
698+
let error_const = || {
699+
if let Some(guar) = self.typeck_results.tainted_by_errors {
700+
ty::Const::new_error(self.tcx, guar)
701+
} else {
702+
ty::Const::new_error_with_message(
703+
self.tcx,
704+
expr.span,
705+
"literal does not match expected type",
706+
)
707+
}
708+
};
709+
let constant = if self.const_lit_matches_ty(&lit.node, pat_ty, *negated) {
710+
match self.tcx.at(expr.span).lit_to_const(lit_input) {
711+
Some(value) => ty::Const::new_value(self.tcx, value.valtree, value.ty),
712+
None => error_const(),
713+
}
714+
} else {
715+
error_const()
716+
};
700717
self.const_to_pat(constant, pat_ty, expr.hir_id, lit.span)
701718
}
702719
}
703720
}
721+
722+
fn const_lit_matches_ty(&self, kind: &LitKind, ty: Ty<'tcx>, neg: bool) -> bool {
723+
let tcx = self.tcx;
724+
match (*kind, ty.kind()) {
725+
(LitKind::Str(..), ty::Ref(_, inner_ty, _)) if inner_ty.is_str() => true,
726+
(LitKind::Str(..), ty::Str) if tcx.features().deref_patterns() => true,
727+
(LitKind::ByteStr(..), ty::Ref(_, inner_ty, _))
728+
if let ty::Slice(ty) | ty::Array(ty, _) = inner_ty.kind()
729+
&& matches!(ty.kind(), ty::Uint(ty::UintTy::U8)) =>
730+
{
731+
true
732+
}
733+
(LitKind::ByteStr(..), ty::Slice(inner_ty) | ty::Array(inner_ty, _))
734+
if tcx.features().deref_patterns()
735+
&& matches!(inner_ty.kind(), ty::Uint(ty::UintTy::U8)) =>
736+
{
737+
true
738+
}
739+
(LitKind::Byte(..), ty::Uint(ty::UintTy::U8)) => true,
740+
(LitKind::CStr(..), ty::Ref(_, inner_ty, _)) if matches!(inner_ty.kind(), ty::Adt(def, _) if tcx.is_lang_item(def.did(), LangItem::CStr)) => {
741+
true
742+
}
743+
(LitKind::Int(..), ty::Uint(_)) if !neg => true,
744+
(LitKind::Int(..), ty::Int(_)) => true,
745+
(LitKind::Bool(..), ty::Bool) => true,
746+
(LitKind::Float(..), ty::Float(_)) => true,
747+
(LitKind::Char(..), ty::Char) => true,
748+
(LitKind::Err(..), _) => true,
749+
_ => false,
750+
}
751+
}
704752
}

compiler/rustc_ty_utils/src/consts.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ fn recurse_build<'tcx>(
5959
}
6060
&ExprKind::Literal { lit, neg } => {
6161
let sp = node.span;
62-
tcx.at(sp).lit_to_const(LitToConstInput { lit: lit.node, ty: node.ty, neg })
62+
match tcx.at(sp).lit_to_const(LitToConstInput { lit: lit.node, ty: node.ty, neg }) {
63+
Some(value) => ty::Const::new_value(tcx, value.valtree, value.ty),
64+
None => ty::Const::new_misc_error(tcx),
65+
}
6366
}
6467
&ExprKind::NonHirLiteral { lit, user_ty: _ } => {
6568
let val = ty::ValTree::from_scalar_int(tcx, lit);

tests/ui/const-generics/adt_const_params/byte-string-u8-validation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
struct ConstBytes<const T: &'static [*mut u8; 3]>
99
//~^ ERROR rustc_dump_predicates
1010
//~| NOTE Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
11-
//~| NOTE Binder { value: TraitPredicate(<ConstBytes<{const error}> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
11+
//~| NOTE Binder { value: TraitPredicate(<ConstBytes<b"AAA"> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
1212
where
1313
ConstBytes<b"AAA">: Sized;
1414
//~^ ERROR mismatched types
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
error: rustc_dump_predicates
2-
--> $DIR/byte-string-u8-validation.rs:8:1
3-
|
4-
LL | struct ConstBytes<const T: &'static [*mut u8; 3]>
5-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6-
|
7-
= note: Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
8-
= note: Binder { value: TraitPredicate(<ConstBytes<{const error}> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
9-
101
error[E0308]: mismatched types
112
--> $DIR/byte-string-u8-validation.rs:13:16
123
|
@@ -16,6 +7,15 @@ LL | ConstBytes<b"AAA">: Sized;
167
= note: expected reference `&'static [*mut u8; 3]`
178
found reference `&'static [u8; 3]`
189

10+
error: rustc_dump_predicates
11+
--> $DIR/byte-string-u8-validation.rs:8:1
12+
|
13+
LL | struct ConstBytes<const T: &'static [*mut u8; 3]>
14+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15+
|
16+
= note: Binder { value: ConstArgHasType(T/#0, &'static [*mut u8; 3_usize]), bound_vars: [] }
17+
= note: Binder { value: TraitPredicate(<ConstBytes<b"AAA"> as std::marker::Sized>, polarity:Positive), bound_vars: [] }
18+
1919
error: aborting due to 2 previous errors
2020

2121
For more information about this error, try `rustc --explain E0308`.

tests/ui/const-generics/adt_const_params/mismatch-raw-ptr-in-adt.stderr

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ LL | struct ConstBytes<const T: &'static [*mut u8; 3]>;
88
= note: `[*mut u8; 3]` must implement `ConstParamTy_`, but it does not
99

1010
error[E0308]: mismatched types
11-
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:46
11+
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:23
1212
|
1313
LL | let _: ConstBytes<b"AAA"> = ConstBytes::<b"BBB">;
14-
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
14+
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
1515
|
1616
= note: expected reference `&'static [*mut u8; 3]`
1717
found reference `&'static [u8; 3]`
1818

1919
error[E0308]: mismatched types
20-
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:23
20+
--> $DIR/mismatch-raw-ptr-in-adt.rs:9:46
2121
|
2222
LL | let _: ConstBytes<b"AAA"> = ConstBytes::<b"BBB">;
23-
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
23+
| ^^^^^^ expected `&[*mut u8; 3]`, found `&[u8; 3]`
2424
|
2525
= note: expected reference `&'static [*mut u8; 3]`
2626
found reference `&'static [u8; 3]`
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//! Regression test for <https://github.com/rust-lang/rust/issues/150983>
2+
#![expect(incomplete_features)]
3+
#![feature(
4+
generic_const_items,
5+
generic_const_parameter_types,
6+
min_generic_const_args,
7+
unsized_const_params
8+
)]
9+
use std::marker::ConstParamTy_;
10+
11+
struct Foo<T> {
12+
field: T,
13+
}
14+
15+
#[type_const]
16+
const WRAP<T : ConstParamTy_> : T = {
17+
Foo::<T>{field : 1} //~ ERROR: type annotations needed for the literal
18+
};
19+
20+
fn main() {}

0 commit comments

Comments
 (0)