Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Commit 199a453

Browse files
committed
Refactor Image to accept const generic parameters
1 parent b8cea94 commit 199a453

File tree

11 files changed

+357
-226
lines changed

11 files changed

+357
-226
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/rustc_codegen_spirv/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use-compiled-tools = ["spirv-tools/use-compiled-tools"]
2929
[dependencies]
3030
bimap = "0.5"
3131
indexmap = "1.6.0"
32+
num-traits = "0.2.14"
3233
rspirv = { git = "https://github.com/gfx-rs/rspirv.git", rev = "01ca0d2e5b667a0e4ff1bc1804511e38f9a08759" }
3334
rustc-demangle = "0.1.18"
3435
serde = { version = "1.0", features = ["derive"] }

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rspirv::spirv::{Capability, StorageClass, Word};
88
use rustc_middle::bug;
99
use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout};
1010
use rustc_middle::ty::subst::SubstsRef;
11-
use rustc_middle::ty::{GeneratorSubsts, PolyFnSig, Ty, TyKind, TypeAndMut};
11+
use rustc_middle::ty::{GeneratorSubsts, ParamEnv, PolyFnSig, Ty, TyKind, TypeAndMut};
1212
use rustc_span::Span;
1313
use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg, RegKind};
1414
use rustc_target::abi::{
@@ -20,6 +20,9 @@ use std::collections::HashMap;
2020
use std::fmt;
2121
use std::fmt::Write;
2222

23+
use num_traits::FromPrimitive;
24+
use rspirv::spirv;
25+
2326
/// If a struct contains a pointer to itself, even indirectly, then doing a naiive recursive walk
2427
/// of the fields will result in an infinite loop. Because pointers are the only thing that are
2528
/// allowed to be recursive, keep track of what pointers we've translated, or are currently in the
@@ -797,28 +800,77 @@ fn trans_image<'tcx>(
797800
attr: SpirvAttribute,
798801
) -> Option<Word> {
799802
match attr {
800-
SpirvAttribute::Image {
801-
dim,
802-
depth,
803-
arrayed,
804-
multisampled,
805-
sampled,
806-
image_format,
807-
access_qualifier,
808-
} => {
803+
SpirvAttribute::Image => {
809804
// see SpirvType::sizeof
810805
if ty.size != Size::from_bytes(4) {
811806
cx.tcx.sess.err("#[spirv(image)] type must have size 4");
812807
return None;
813808
}
809+
814810
// Hardcode to float for now
815811
let sampled_type = SpirvType::Float(32).def(span, cx);
812+
813+
macro_rules! type_from_variant_index {
814+
($(let $name:ident : $ty:ty = $exp:expr);+ ;) => {
815+
$(
816+
let $name = <$ty>::from_u64(
817+
cx.tcx
818+
.destructure_const(ParamEnv::reveal_all().and($exp))
819+
.variant
820+
.unwrap()
821+
.as_u32() as u64,
822+
)
823+
.unwrap();
824+
)+
825+
}
826+
}
827+
828+
type_from_variant_index! {
829+
let dim: spirv::Dim = substs.const_at(0);
830+
let depth: u32 = substs.const_at(1);
831+
let sampled: u32 = substs.const_at(4);
832+
let image_format: spirv::ImageFormat = substs.const_at(5);
833+
}
834+
835+
let arrayed: bool = substs
836+
.const_at(2)
837+
.val
838+
.try_to_value()
839+
.unwrap()
840+
.try_to_bool()
841+
.unwrap();
842+
let multisampled: bool = substs
843+
.const_at(3)
844+
.val
845+
.try_to_value()
846+
.unwrap()
847+
.try_to_bool()
848+
.unwrap();
849+
let access_qualifier = {
850+
let option = cx
851+
.tcx
852+
.destructure_const(ParamEnv::reveal_all().and(substs.const_at(6)));
853+
854+
match option.variant.map(|i| i.as_u32()).unwrap_or(0) {
855+
0 => None,
856+
1 => spirv::AccessQualifier::from_u64(
857+
option.fields[0]
858+
.val
859+
.try_to_scalar()
860+
.unwrap()
861+
.to_u64()
862+
.unwrap(),
863+
),
864+
_ => unreachable!(),
865+
}
866+
};
867+
816868
let ty = SpirvType::Image {
817869
sampled_type,
818870
dim,
819871
depth,
820-
arrayed,
821-
multisampled,
872+
arrayed: arrayed as u32,
873+
multisampled: multisampled as u32,
822874
sampled,
823875
image_format,
824876
access_qualifier,

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 3 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
use crate::builder::libm_intrinsics;
22
use crate::codegen_cx::CodegenCx;
3-
use rspirv::spirv::{
4-
AccessQualifier, BuiltIn, Dim, ExecutionMode, ExecutionModel, ImageFormat, StorageClass,
5-
};
3+
use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
64
use rustc_ast::ast::{AttrKind, Attribute, Lit, LitIntType, LitKind, NestedMetaItem};
75
use rustc_span::symbol::{Ident, Symbol};
86
use std::collections::HashMap;
@@ -33,13 +31,6 @@ pub struct Symbols {
3331
descriptor_set: Symbol,
3432
binding: Symbol,
3533
image: Symbol,
36-
dim: Symbol,
37-
depth: Symbol,
38-
arrayed: Symbol,
39-
multisampled: Symbol,
40-
sampled: Symbol,
41-
image_format: Symbol,
42-
access_qualifier: Symbol,
4334
attributes: HashMap<Symbol, SpirvAttribute>,
4435
execution_modes: HashMap<Symbol, (ExecutionMode, ExecutionModeExtraDim)>,
4536
pub libm_intrinsics: HashMap<Symbol, libm_intrinsics::LibmIntrinsic>,
@@ -386,13 +377,6 @@ impl Symbols {
386377
descriptor_set: Symbol::intern("descriptor_set"),
387378
binding: Symbol::intern("binding"),
388379
image: Symbol::intern("image"),
389-
dim: Symbol::intern("dim"),
390-
depth: Symbol::intern("depth"),
391-
arrayed: Symbol::intern("arrayed"),
392-
multisampled: Symbol::intern("multisampled"),
393-
sampled: Symbol::intern("sampled"),
394-
image_format: Symbol::intern("image_format"),
395-
access_qualifier: Symbol::intern("access_qualifier"),
396380
attributes,
397381
execution_modes,
398382
libm_intrinsics,
@@ -445,15 +429,7 @@ pub enum SpirvAttribute {
445429
DescriptorSet(u32),
446430
Binding(u32),
447431
ReallyUnsafeIgnoreBitcasts,
448-
Image {
449-
dim: Dim,
450-
depth: u32,
451-
arrayed: u32,
452-
multisampled: u32,
453-
sampled: u32,
454-
image_format: ImageFormat,
455-
access_qualifier: Option<AccessQualifier>,
456-
},
432+
Image,
457433
Sampler,
458434
SampledImage,
459435
Block,
@@ -493,7 +469,7 @@ pub fn parse_attrs(
493469
};
494470
args.into_iter().filter_map(move |ref arg| {
495471
if arg.has_name(cx.sym.image) {
496-
parse_image(cx, arg)
472+
Some(SpirvAttribute::Image)
497473
} else if arg.has_name(cx.sym.descriptor_set) {
498474
match parse_attr_int_value(cx, arg) {
499475
Some(x) => Some(SpirvAttribute::DescriptorSet(x)),
@@ -537,136 +513,6 @@ pub fn parse_attrs(
537513
result.collect::<Vec<_>>().into_iter()
538514
}
539515

540-
fn parse_image(cx: &CodegenCx<'_>, attr: &NestedMetaItem) -> Option<SpirvAttribute> {
541-
let args = match attr.meta_item_list() {
542-
Some(args) => args,
543-
None => {
544-
cx.tcx
545-
.sess
546-
.span_err(attr.span(), "image attribute must have arguments");
547-
return None;
548-
}
549-
};
550-
if args.len() != 6 && args.len() != 7 {
551-
cx.tcx
552-
.sess
553-
.span_err(attr.span(), "image attribute must have 6 or 7 arguments");
554-
return None;
555-
}
556-
let check = |idx: usize, sym: Symbol| -> bool {
557-
if args[idx].has_name(sym) {
558-
false
559-
} else {
560-
cx.tcx.sess.span_err(
561-
args[idx].span(),
562-
&format!("image attribute argument {} must be {}=...", idx + 1, sym),
563-
);
564-
true
565-
}
566-
};
567-
if check(0, cx.sym.dim)
568-
| check(1, cx.sym.depth)
569-
| check(2, cx.sym.arrayed)
570-
| check(3, cx.sym.multisampled)
571-
| check(4, cx.sym.sampled)
572-
| check(5, cx.sym.image_format)
573-
| (args.len() == 7 && check(6, cx.sym.access_qualifier))
574-
{
575-
return None;
576-
}
577-
let arg_values = args
578-
.iter()
579-
.map(
580-
|arg| match arg.meta_item().and_then(|arg| arg.name_value_literal()) {
581-
Some(arg) => Some(arg),
582-
None => {
583-
cx.tcx
584-
.sess
585-
.span_err(arg.span(), "image attribute must be name=value");
586-
None
587-
}
588-
},
589-
)
590-
.collect::<Option<Vec<_>>>()?;
591-
let dim = match arg_values[0].kind {
592-
LitKind::Str(dim, _) => match dim.with(|s| s.parse()) {
593-
Ok(dim) => dim,
594-
Err(()) => {
595-
cx.tcx.sess.span_err(args[0].span(), "invalid dim value");
596-
return None;
597-
}
598-
},
599-
_ => {
600-
cx.tcx
601-
.sess
602-
.span_err(args[0].span(), "dim value must be str");
603-
return None;
604-
}
605-
};
606-
let parse_lit = |idx: usize, name: &str| -> Option<u32> {
607-
match arg_values[idx].kind {
608-
LitKind::Int(v, _) => Some(v as u32),
609-
_ => {
610-
cx.tcx
611-
.sess
612-
.span_err(args[idx].span(), &format!("{} value must be int", name));
613-
None
614-
}
615-
}
616-
};
617-
let depth = parse_lit(1, "depth")?;
618-
let arrayed = parse_lit(2, "arrayed")?;
619-
let multisampled = parse_lit(3, "multisampled")?;
620-
let sampled = parse_lit(4, "sampled")?;
621-
let image_format = match arg_values[5].kind {
622-
LitKind::Str(dim, _) => match dim.with(|s| s.parse()) {
623-
Ok(dim) => dim,
624-
Err(()) => {
625-
cx.tcx
626-
.sess
627-
.span_err(args[5].span(), "invalid image_format value");
628-
return None;
629-
}
630-
},
631-
_ => {
632-
cx.tcx
633-
.sess
634-
.span_err(args[5].span(), "image_format value must be str");
635-
return None;
636-
}
637-
};
638-
let access_qualifier = if args.len() == 7 {
639-
Some(match arg_values[6].kind {
640-
LitKind::Str(dim, _) => match dim.with(|s| s.parse()) {
641-
Ok(dim) => dim,
642-
Err(()) => {
643-
cx.tcx
644-
.sess
645-
.span_err(args[6].span(), "invalid access_qualifier value");
646-
return None;
647-
}
648-
},
649-
_ => {
650-
cx.tcx
651-
.sess
652-
.span_err(args[6].span(), "access_qualifier value must be str");
653-
return None;
654-
}
655-
})
656-
} else {
657-
None
658-
};
659-
Some(SpirvAttribute::Image {
660-
dim,
661-
depth,
662-
arrayed,
663-
multisampled,
664-
sampled,
665-
image_format,
666-
access_qualifier,
667-
})
668-
}
669-
670516
fn parse_attr_int_value(cx: &CodegenCx<'_>, arg: &NestedMetaItem) -> Option<u32> {
671517
let arg = match arg.meta_item() {
672518
Some(arg) => arg,

crates/spirv-builder/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ fn invoke_rustc(builder: &SpirvBuilder) -> Result<PathBuf, SpirvBuilderError> {
143143
format!(" -C target-feature={}", target_features.join(","))
144144
};
145145
let rustflags = format!(
146-
"-Z codegen-backend={} -Z symbol-mangling-version=v0{}",
146+
"-Z codegen-backend={} {}",
147147
rustc_codegen_spirv.display(),
148148
feature_flag,
149149
);

crates/spirv-std/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ license = "MIT OR Apache-2.0"
77
repository = "https://github.com/EmbarkStudios/rust-gpu"
88
description = "Standard functions and types for SPIR-V"
99

10+
[profile.dev]
11+
incremental = false
12+
13+
[profile.release]
14+
incremental = false
15+
1016
[dependencies]
1117
glam = { version = "0.11.3", default-features = false, features = ["libm", "scalar-math"] }
1218
num-traits = { version = "0.2.14", default-features = false }

crates/spirv-std/src/lib.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,31 @@
3838
future_incompatible,
3939
nonstandard_style
4040
)]
41+
#![feature(const_generics)]
42+
#![allow(incomplete_features)]
4143

4244
#[cfg(not(target_arch = "spirv"))]
4345
#[macro_use]
4446
pub extern crate spirv_std_macros;
4547

48+
pub(crate) mod sealed;
4649
pub mod storage_class;
4750
mod textures;
4851

4952
pub use glam;
5053
pub use num_traits;
5154
pub use textures::*;
5255

56+
/// Marker trait for arguments that accept single scalar values or vectors
57+
/// of scalars.
58+
pub trait ScalarOrVector: sealed::Sealed {}
59+
60+
impl ScalarOrVector for f32 {}
61+
impl ScalarOrVector for glam::Vec2 {}
62+
impl ScalarOrVector for glam::Vec3 {}
63+
impl ScalarOrVector for glam::Vec3A {}
64+
impl ScalarOrVector for glam::Vec4 {}
65+
5366
#[cfg(all(not(test), target_arch = "spirv"))]
5467
#[panic_handler]
5568
fn panic(_: &core::panic::PanicInfo<'_>) -> ! {

crates/spirv-std/src/sealed.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
/// Sealed trait to ensure certain traits can't be implemented outside
2+
/// of `spirv-std`.
3+
pub trait Sealed {}
4+
5+
impl Sealed for f32 {}
6+
impl Sealed for glam::Vec2 {}
7+
impl Sealed for glam::Vec3 {}
8+
impl Sealed for glam::Vec3A {}
9+
impl Sealed for glam::Vec4 {}

0 commit comments

Comments
 (0)