Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Next Next commit
Add basic support for struct DSTs
  • Loading branch information
Hentropy committed Mar 19, 2021
commit 194777738742f69673d72b7cc3815e0748ea054d
114 changes: 84 additions & 30 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::CodegenCx;
use crate::abi::ConvSpirvType;
use crate::builder_spirv::SpirvValue;
use crate::spirv_type::SpirvType;
use crate::symbols::{parse_attrs, Entry, SpirvAttribute};
Expand All @@ -9,7 +8,10 @@ use rustc_hir as hir;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{Instance, Ty};
use rustc_span::Span;
use rustc_target::abi::call::{FnAbi, PassMode};
use rustc_target::abi::{
call::{ArgAbi, ArgAttribute, ArgAttributes, FnAbi, PassMode},
Size,
};
use std::collections::HashMap;

impl<'tcx> CodegenCx<'tcx> {
Expand All @@ -36,8 +38,27 @@ impl<'tcx> CodegenCx<'tcx> {
};
let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id);
let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id));
const EMPTY: ArgAttribute = ArgAttribute::empty();
for (abi, arg) in fn_abi.args.iter().zip(body.params) {
if let PassMode::Direct(_) = abi.mode {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this should probably be turned into a match, the reason it's an if let is because there's only one case

} else if let PassMode::Pair(
// plain DST/RTA/VLA
ArgAttributes {
pointee_size: Size::ZERO,
..
},
ArgAttributes { regular: EMPTY, .. },
) = abi.mode
{
} else if let PassMode::Pair(
// DST struct with fields before the DST member
ArgAttributes { .. },
ArgAttributes {
pointee_size: Size::ZERO,
..
},
) = abi.mode
{
} else {
self.tcx.sess.span_err(
arg.span,
Expand All @@ -62,7 +83,7 @@ impl<'tcx> CodegenCx<'tcx> {
self.shader_entry_stub(
self.tcx.def_span(instance.def_id()),
entry_func,
fn_abi,
&fn_abi.args,
body.params,
name,
execution_model,
Expand All @@ -81,7 +102,7 @@ impl<'tcx> CodegenCx<'tcx> {
&self,
span: Span,
entry_func: SpirvValue,
entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
arg_abis: &[ArgAbi<'tcx, Ty<'tcx>>],
Comment on lines -85 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is a good idea, especially if we want to do something with the return type at some point.

hir_params: &[hir::Param<'tcx>],
name: String,
execution_model: ExecutionModel,
Expand All @@ -92,52 +113,85 @@ impl<'tcx> CodegenCx<'tcx> {
arguments: vec![],
}
.def(span, self);
let entry_func_return_type = match self.lookup_type(entry_func.ty) {
let (entry_func_return_type, entry_func_arg_types) = match self.lookup_type(entry_func.ty) {
SpirvType::Function {
return_type,
arguments: _,
} => return_type,
arguments,
} => (return_type, arguments),
other => self.tcx.sess.fatal(&format!(
"Invalid entry_stub type: {}",
other.debug(entry_func.ty, self)
)),
};
let mut decoration_locations = HashMap::new();
// Create OpVariables before OpFunction so they're global instead of local vars.
let arguments = entry_fn_abi
.args
.iter()
.zip(hir_params)
.map(|(entry_fn_arg, hir_param)| {
self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations)
})
.collect::<Vec<_>>();
let new_spirv = self.emit_global().version().unwrap() > (1, 3);
let arg_len = arg_abis.len();
let mut arguments = Vec::with_capacity(arg_len);
let mut interface = Vec::with_capacity(arg_len);
let mut rta_lens = Vec::with_capacity(arg_len / 2);
let mut arg_types = entry_func_arg_types.iter();
for (hir_param, arg_abi) in hir_params.iter().zip(arg_abis) {
// explicit next because there are two args for scalar pairs, but only one param & abi
let arg_t = *arg_types.next().unwrap_or_else(|| {
self.tcx.sess.span_fatal(
hir_param.span,
&format!(
"Invalid function arguments: Param {:?} Abi {:?} missing type",
hir_param, arg_abi.layout.ty
),
)
});
let (argument, storage_class) =
self.declare_parameter(arg_abi.layout, hir_param, arg_t, &mut decoration_locations);
// SPIR-V <= v1.3 only includes Input and Output in the interface.
if new_spirv
|| storage_class == StorageClass::Input
|| storage_class == StorageClass::Output
{
interface.push(argument);
}
arguments.push(argument);
if let SpirvType::Pointer { pointee } = self.lookup_type(arg_t) {
if let SpirvType::Adt {
size: None,
field_types,
..
} = self.lookup_type(pointee)
{
let len_t = *arg_types.next().unwrap_or_else(|| {
self.tcx.sess.span_fatal(
hir_param.span,
&format!(
"Invalid function arguments: Param {:?} Abi {:?} fat pointer missing length",
hir_param, arg_abi.layout.ty
),
)
});
rta_lens.push((arguments.len() as u32, len_t, field_types.len() as u32 - 1));
arguments.push(u32::MAX);
}
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems pretty ad-hoc and confusing to follow (splitting it up into functions would probably help), and has some issues with it. For example, this compiletest passes, but it shouldn't (no way to get the length of the slice):

// build-pass

use spirv_std::storage_class::{Output, Uniform};

#[spirv(fragment)]
pub fn shader(
    #[spirv(descriptor_set = 0, binding = 0)] slice: Uniform<&[f32]>,
    mut out: Output<f32>,
) {
    let x = slice[5];
    *out = x;
}

let mut emit = self.emit_global();
let fn_id = emit
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
.unwrap();
emit.begin_block(None).unwrap();
rta_lens.iter().for_each(|&(len_idx, len_t, member_idx)| {
arguments[len_idx as usize] = emit
.array_length(len_t, None, arguments[len_idx as usize - 1], member_idx)
.unwrap()
});
emit.function_call(
entry_func_return_type,
None,
entry_func.def_cx(self),
arguments.iter().map(|&(a, _)| a),
arguments,
)
.unwrap();
emit.ret().unwrap();
emit.end_function().unwrap();

let interface: Vec<_> = if emit.version().unwrap() > (1, 3) {
// SPIR-V >= v1.4 includes all OpVariables in the interface.
arguments.into_iter().map(|(a, _)| a).collect()
} else {
// SPIR-V <= v1.3 only includes Input and Output in the interface.
arguments
.into_iter()
.filter(|&(_, s)| s == StorageClass::Input || s == StorageClass::Output)
.map(|(a, _)| a)
.collect()
};
emit.entry_point(execution_model, fn_id, name, interface);
fn_id
}
Expand All @@ -146,6 +200,7 @@ impl<'tcx> CodegenCx<'tcx> {
&self,
layout: TyAndLayout<'tcx>,
hir_param: &hir::Param<'tcx>,
arg_t: Word,
decoration_locations: &mut HashMap<StorageClass, u32>,
) -> (Word, StorageClass) {
let storage_class = crate::abi::get_storage_class(self, layout).unwrap_or_else(|| {
Expand All @@ -159,10 +214,9 @@ impl<'tcx> CodegenCx<'tcx> {
StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant
);
// Note: this *declares* the variable too.
let spirv_type = layout.spirv_type(hir_param.span, self);
let variable = self
.emit_global()
.variable(spirv_type, None, storage_class, None);
.variable(arg_t, None, storage_class, None);
if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind {
self.emit_global().name(variable, ident.to_string());
}
Expand Down