diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs index 9619c80904426..1737335410054 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -57,6 +57,9 @@ impl TypeTree { } Self(ints) } + pub fn add_indirection(self) -> Self { + Self(vec![Type { offset: 0, size: 1, kind: Kind::Pointer, child: self }]) + } } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)] diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index c3bf566b35880..53104fe3a0ecc 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -34,6 +34,8 @@ trait ArgAttributesExt { callsite: &Value, ); } + use crate::abi::ty::print::with_no_trimmed_paths; +use rustc_codegen_ssa::mir::operand::scalar_pair_component_field_ty; const ABI_AFFECTING_ATTRIBUTES: [(ArgAttribute, llvm::AttributeKind); 1] = [(ArgAttribute::InReg, llvm::AttributeKind::InReg)]; @@ -262,6 +264,16 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { let llscratch = bx.alloca(scratch_size, scratch_align); bx.lifetime_start(llscratch, scratch_size); // ...store the value... + + let f0 = scalar_pair_component_field_ty(bx, dst.layout, 0); + let f1 = scalar_pair_component_field_ty(bx, dst.layout, 1); + + if f1.is_some() && f0.is_some() { + with_no_trimmed_paths!({ + eprintln!("Cast of extractvalue 0 field = {:?}", f0.map(|f| f0.unwrap())); + eprintln!("Cast of extractvalue 1 field = {:?}", f1.map(|f| f1.unwrap())); + }); + } rustc_codegen_ssa::mir::store_cast(bx, cast, val, llscratch, scratch_align); // ... and then memcpy it to the intended destination. bx.memcpy( diff --git a/compiler/rustc_codegen_llvm/src/asm.rs b/compiler/rustc_codegen_llvm/src/asm.rs index c5ab9fc2336eb..1f053300e8963 100644 --- a/compiler/rustc_codegen_llvm/src/asm.rs +++ b/compiler/rustc_codegen_llvm/src/asm.rs @@ -363,7 +363,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { let value = if output_types.len() == 1 { result } else { - self.extract_value(result, op_idx[&idx] as u64) + self.extract_value(result, op_idx[&idx] as u64, None) }; let value = llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance); diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 134bc5006dd00..9291cdb12ca51 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -2,7 +2,7 @@ use std::borrow::{Borrow, Cow}; use std::iter; use std::ops::Deref; -use rustc_ast::expand::typetree::FncTree; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; pub(crate) mod autodiff; pub(crate) mod gpu_offload; @@ -38,6 +38,7 @@ use crate::llvm::{ ToLlvmBool, Type, Value, }; use crate::type_of::LayoutLlvmExt; +use rustc_middle::ty::type_tree::typetree_from_ty; #[must_use] pub(crate) struct GenericBuilder<'a, 'll, CX: Borrow>> { @@ -181,11 +182,12 @@ impl<'a, 'll, CX: Borrow>> GenericBuilder<'a, 'll, CX> { } pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value { - unsafe { + let load = unsafe { let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED); llvm::LLVMSetAlignment(load, align.bytes() as c_uint); load - } + }; + load } } @@ -585,7 +587,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let name = format!("llvm.{}{oop_str}.with.overflow", if signed { 's' } else { 'u' }); let res = self.call_intrinsic(name, &[self.type_ix(width)], &[lhs, rhs]); - (self.extract_value(res, 0), self.extract_value(res, 1)) + (self.extract_value(res, 0, None), self.extract_value(res, 1, None)) } fn from_immediate(&mut self, val: Self::Value) -> Self::Value { @@ -627,13 +629,17 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value { - unsafe { + fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align, tt: Option) -> &'ll Value { + let load = unsafe { let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED); let align = align.min(self.cx().tcx.sess.target.max_reliable_alignment()); llvm::LLVMSetAlignment(load, align.bytes() as c_uint); load + }; + if let Some(tt) = tt { + //crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, tt); } + load } fn volatile_load(&mut self, ty: &'ll Type, ptr: &'ll Value) -> &'ll Value { @@ -734,6 +740,70 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let llval = const_llval.unwrap_or_else(|| { let load = self.load(llty, place.val.llval, place.val.align); + //let layout = place.layout.ty_and_layout_pointee_info_at(self.cx(), Size::ZERO).unwrap(); + let ty = place.layout.ty; + let tt = typetree_from_ty(self.tcx, ty); + if tt != rustc_ast::expand::typetree::TypeTree::new() { + use rustc_middle::ty::print::with_no_trimmed_paths; + //dbg!("add_tt start!"); + //dbg!(&load); + //dbg!(&tt); + //eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); + let fnc_tree = FncTree { + args: vec![TypeTree::new(), TypeTree::new()], + ret: tt, + }; + // TODO: re-enable? + //crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, fnc_tree); + //dbg!("add_tt done!"); + } + //eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); + // 25 general load of place = PlaceRef { + // 24 val: PlaceValue { + // 23 llval: (ptr: %3 = alloca [8 x i8], align 8), + // 22 llextra: None, + // 21 align: Align(8 bytes), + // 20 }, + // 19 layout: TyAndLayout { + // 18 ty: &([f64; 3], [f64; 3]), + // 17 layout: Layout { + // 16 size: Size(8 bytes), + // 15 align: AbiAlign { + // 14 abi: Align(8 bytes), + // 13 }, + // 12 backend_repr: Scalar( + // 11 Initialized { + // 10 value: Pointer( + // 9 AddressSpace( + // 8 0, + // 7 ), + // 6 ), + // 5 valid_range: 1..=18446744073709551615, + // 4 }, + // 3 ), + // 2 fields: Primitive, + // 1 largest_niche: Some( + // 259 Niche { + // 1 offset: Size(0 bytes), + // 2 value: Pointer( + // 3 AddressSpace( + // 4 0, + // 5 ), + // 6 ), + // 7 valid_range: 1..=18446744073709551615, + // 8 }, + // 9 ), + // 10 uninhabited: false, + // 11 variants: Single { + // 12 index: 0, + // 13 }, + // 14 max_repr_align: None, + // 15 unadjusted_abi_align: Align(8 bytes), + // 16 randomization_seed: 281492156579847, + // 17 }, + // 18 }, + // 19 } + if let abi::BackendRepr::Scalar(scalar) = place.layout.backend_repr { scalar_load_metadata(self, load, scalar, place.layout, Size::ZERO); self.to_immediate_scalar(load, scalar) @@ -1113,7 +1183,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { // vs. copying a struct with mixed types requires different derivative handling. // The TypeTree tells Enzyme exactly what memory layout to expect. if let Some(tt) = tt { - crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, self.tcx, tt); } } @@ -1125,11 +1195,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memmove = unsafe { llvm::LLVMRustBuildMemMove( self.llbuilder, dst, @@ -1138,7 +1209,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memmove, self.tcx, tt); } } @@ -1149,10 +1223,11 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { size: &'ll Value, align: Align, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memset not supported"); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memset = unsafe { llvm::LLVMRustBuildMemSet( self.llbuilder, ptr, @@ -1160,7 +1235,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fill_byte, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memset, self.tcx, tt); } } @@ -1191,9 +1269,15 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn extract_value(&mut self, agg_val: &'ll Value, idx: u64) -> &'ll Value { + fn extract_value(&mut self, agg_val: &'ll Value, idx: u64, tt: Option) -> &'ll Value { assert_eq!(idx as c_uint as u64, idx); - unsafe { llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) } + let ev = unsafe { + llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) + }; + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, ev, self.tcx, tt); + } + ev } fn insert_value(&mut self, agg_val: &'ll Value, elt: &'ll Value, idx: u64) -> &'ll Value { @@ -1213,7 +1297,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { unsafe { llvm::LLVMSetCleanup(landing_pad, llvm::TRUE); } - (self.extract_value(landing_pad, 0), self.extract_value(landing_pad, 1)) + (self.extract_value(landing_pad, 0, None), self.extract_value(landing_pad, 1, None)) } fn filter_landing_pad(&mut self, pers_fn: &'ll Value) { @@ -1314,8 +1398,8 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { llvm::FALSE, // SingleThreaded ); llvm::LLVMSetWeak(value, weak.to_llvm_bool()); - let val = self.extract_value(value, 0); - let success = self.extract_value(value, 1); + let val = self.extract_value(value, 0, None); + let success = self.extract_value(value, 1, None); (val, success) } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 1cefdaae5ebde..ed07bc8af823e 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -290,6 +290,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_name: &str, @@ -375,7 +376,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( ); if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() { - crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree); + crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, tcx, fnc_tree); } let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 0b009321802cf..39e3ac8eb484f 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -162,9 +162,9 @@ impl<'ll> OffloadKernelDims<'ll> { builder: &mut Builder<'_, 'll, 'tcx>, arr: &'ll Value, ) -> &'ll Value { - let x = builder.extract_value(arr, 0); - let y = builder.extract_value(arr, 1); - let z = builder.extract_value(arr, 2); + let x = builder.extract_value(arr, 0, None); + let y = builder.extract_value(arr, 1, None); + let z = builder.extract_value(arr, 2, None); let xy = builder.mul(x, y); builder.mul(xy, z) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 1c7b415fd04c7..3952f2a143e1b 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -6,6 +6,7 @@ use rustc_abi::{ AddressSpace, Align, BackendRepr, CVariadicStatus, Float, HasDataLayout, Integer, NumScalableVectors, Primitive, Size, WrappingRange, }; +use rustc_ast::expand::typetree::{FncTree, TypeTree}; use rustc_codegen_ssa::RetagInfo; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; @@ -46,6 +47,8 @@ use crate::errors::{ use crate::llvm::{self, Type, Value}; use crate::type_of::LayoutLlvmExt; use crate::va_arg::emit_va_arg; +use rustc_middle::ty::type_tree::typetree_from_ty; +use rustc_middle::ty::type_tree::fnc_typetrees; fn call_simple_intrinsic<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, @@ -780,6 +783,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { self.extract_value( args[0].immediate(), fn_args.const_at(2).to_leaf().to_i32() as u64, + None, ) } @@ -1063,7 +1067,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[], &[llvtable, vtable_byte_offset, typeid], ); - self.extract_value(type_checked_load, 0) + self.extract_value(type_checked_load, 0, None) } fn va_start(&mut self, va_list: &'ll Value) { @@ -1179,7 +1183,7 @@ fn autocast<'ll>( iter::zip(bx.struct_element_types(src_ty), bx.struct_element_types(dest_ty)) .enumerate() { - let elt = bx.extract_value(val, idx as u64); + let elt = bx.extract_value(val, idx as u64, None); let casted_elt = autocast(bx, elt, src_element_ty, dest_element_ty); ret = bx.insert_value(ret, casted_elt, idx as u64); } @@ -1642,7 +1646,7 @@ fn codegen_gnu_try<'ll, 'tcx>( let vals = bx.landing_pad(lpad_ty, bx.eh_personality(), 1); let tydesc = bx.const_null(bx.type_ptr()); bx.add_clause(vals, tydesc); - let ptr = bx.extract_value(vals, 0); + let ptr = bx.extract_value(vals, 0, None); let catch_ty = bx.type_func(&[bx.type_ptr(), bx.type_ptr()], bx.type_void()); bx.call(catch_ty, None, None, catch_func, &[data, ptr], None, None); bx.ret(bx.const_bool(true)); @@ -1704,8 +1708,8 @@ fn codegen_emcc_try<'ll, 'tcx>( let vals = bx.landing_pad(lpad_ty, bx.eh_personality(), 2); bx.add_clause(vals, tydesc); bx.add_clause(vals, bx.const_null(bx.type_ptr())); - let ptr = bx.extract_value(vals, 0); - let selector = bx.extract_value(vals, 1); + let ptr = bx.extract_value(vals, 0, None); + let selector = bx.extract_value(vals, 1, None); // Check if the typeid we got is the one for a Rust panic. let rust_typeid = bx.call_intrinsic("llvm.eh.typeid.for", &[bx.val_ty(tydesc)], &[tydesc]); @@ -1878,11 +1882,12 @@ fn codegen_autodiff<'ll, 'tcx>( &mut diff_attrs.input_activity, ); - let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty); + let fnc_tree = fnc_typetrees(tcx, source_fn_ptr_ty); // Build body generate_enzyme_call( bx, + tcx, bx.cx, fn_to_diff, &diff_symbol, @@ -2005,8 +2010,18 @@ fn get_args_from_tuple<'ll, 'tcx>( let field = tuple_place.project_field(bx, tuple_index); let llvm_ty = field.layout.llvm_type(bx.cx); let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align); - result.push(bx.extract_value(pair_val, 0)); - result.push(bx.extract_value(pair_val, 1)); + let extract_ty = field.layout.ty; + let tt = typetree_from_ty(bx.tcx(), extract_ty); + dbg!("intrinsic pair"); + dbg!(&tt); + let fnc = FncTree { + args: vec![],//TypeTree::new() + ret: tt, + }; + //let tt0 = enzyme_type_from_ty /* TypeTree for extracted element 0 */; + //let tt1 = /* TypeTree for extracted element 1 */; + result.push(bx.extract_value(pair_val, 0, Some(fnc.clone()))); + result.push(bx.extract_value(pair_val, 1, Some(fnc))); tuple_index += 1; } PassMode::Indirect { .. } => { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 195e050a9b651..3430bdbf37b22 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -65,7 +65,12 @@ unsafe extern "C" { name: *const c_char, NameLen: libc::size_t, ) -> Option<&Value>; - + //pub(crate) fn LLVMRustIsPtrLoad(v: &Value) -> bool; + pub(crate) fn LLVMRustIsLoadOrExtractValue(v: &Value) -> bool; + pub(crate) fn LLVMRustSetEnzymeTypeMetadata( + v: &Value, + md: &Value, + ); } unsafe extern "C" { @@ -90,6 +95,7 @@ pub(crate) use self::Enzyme_AD::*; pub(crate) mod Enzyme_AD { use std::ffi::{c_char, c_void}; use std::sync::{Mutex, MutexGuard, OnceLock}; + use super::Value; use rustc_middle::bug; use rustc_session::config::{Sysroot, host_tuple}; @@ -114,6 +120,7 @@ pub(crate) mod Enzyme_AD { unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context); type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char; type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char); + type EnzymeTypeTreeToMDFn = unsafe extern "C" fn(CTypeTreeRef, &Context) -> Option<&Value>; #[allow(non_snake_case)] pub(crate) struct EnzymeWrapper { @@ -127,6 +134,7 @@ pub(crate) mod Enzyme_AD { EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn, EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn, EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn, + EnzymeTypeTreeToMD: EnzymeTypeTreeToMDFn, EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, EnzymePrintPerf: *mut c_void, @@ -292,6 +300,10 @@ pub(crate) mod Enzyme_AD { unsafe { (self.EnzymeTypeTreeToString)(tree) } } + pub(crate) fn tree_to_md<'a>(&'a self, tree: *mut EnzymeTypeTree, ctx: &'a Context) -> Option<&'a Value> { + unsafe { (self.EnzymeTypeTreeToMD)(tree, ctx) } + } + pub(crate) fn tree_to_string_free(&self, ch: *const c_char) { unsafe { (self.EnzymeTypeTreeToStringFree)(ch) } } @@ -381,6 +393,7 @@ pub(crate) mod Enzyme_AD { EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, EnzymeSetCLBool: EnzymeSetCLBoolFn, EnzymeSetCLString: EnzymeSetCLStringFn, + EnzymeTypeTreeToMD: EnzymeTypeTreeToMDFn, ); load_ptrs_by_symbols_mut_void!( @@ -422,6 +435,7 @@ pub(crate) mod Enzyme_AD { looseTypeAnalysis, EnzymeSetCLBool, EnzymeSetCLString, + EnzymeTypeTreeToMD, registerEnzymeAndPassPipeline, lib, }) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 1fde5866f5dca..50c13d09a7a90 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2453,6 +2453,7 @@ unsafe extern "C" { FileType: FileType, VerifyIR: bool, ) -> LLVMRustResult; + pub(crate) fn LLVMRustIsIntrinsicCall(val: &Value) -> bool; pub(crate) fn LLVMRustOptimize<'a>( M: &'a Module, TM: &'a TargetMachine, diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 4f433f273c8cc..a89d5d18c3ba9 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,9 +1,10 @@ use std::ffi::{CString, c_char, c_uint}; use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; - +use crate::llvm::LLVMRustSetEnzymeTypeMetadata; +use crate::llvm::LLVMRustIsLoadOrExtractValue; use crate::attributes; -use crate::llvm::{self, EnzymeWrapper, Value}; +use crate::llvm::{self, EnzymeWrapper, TypeTree, Value}; fn to_enzyme_typetree( rust_typetree: RustTypeTree, @@ -56,8 +57,15 @@ pub(crate) fn add_tt<'ll>( llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, + tcx: rustc_middle::ty::TyCtxt<'_>, tt: FncTree, ) { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return; + } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return; + } // TypeTree processing uses functions from Enzyme, which we might not have available if we did // not build this compiler with `llvm_enzyme`. This feature is not strictly necessary, but // skipping this function increases the chance that Enzyme fails to compile some code. @@ -69,6 +77,7 @@ pub(crate) fn add_tt<'ll>( let inputs = tt.args; let ret_tt: RustTypeTree = tt.ret; + //dbg!("getting DataLayout"); let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; let llvm_data_layout = std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) @@ -76,9 +85,14 @@ pub(crate) fn add_tt<'ll>( let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); + //dbg!("going to iter over inputs"); for (i, input) in inputs.iter().enumerate() { unsafe { + if *input == rustc_ast::expand::typetree::TypeTree::new() { + //dbg!("skipping empty input tt"); + continue; + } let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); let enzyme_wrapper = EnzymeWrapper::get_instance(); let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); @@ -91,16 +105,51 @@ pub(crate) fn add_tt<'ll>( c_str.as_ptr(), c_str.to_bytes().len() as c_uint, ); - - attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + //dbg!("adding attribute for argument {}", i); + //dbg!("attribute string: {:?}", c_str); + //dbg!(&fn_def); + + if llvm::LLVMRustIsIntrinsicCall(fn_def) { + //dbg!("intrinsic"); + attributes::apply_to_callsite(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + //} else if LLVMRustIsPtrLoad(fn_def) { + } else if LLVMRustIsLoadOrExtractValue(fn_def) { + //dbg!("skipping input args for instr"); + } else { + //dbg!("fn call"); + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } + //dbg!("finished to iter over inputs"); unsafe { + if ret_tt == rustc_ast::expand::typetree::TypeTree::new() { + //dbg!("skipping empty return tt"); + return; + } let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); let enzyme_wrapper = EnzymeWrapper::get_instance(); let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); + // just printing + //let ptr = wrapper.tree_to_string(self.inner); + let cstr = unsafe { std::ffi::CStr::from_ptr(c_str) }; + use std::io::Write as _; + let mut stderr = std::io::stderr().lock(); + + match cstr.to_str() { + Ok(x) => { + writeln!(stderr, "parsed: {:?}", x).ok(); + } + Err(err) => { + writeln!(stderr, "could not parse: {}", err).ok(); + } + } + + // delete C string pointer + //wrapper.tree_to_string_free(ptr); + // done printing let c_str = std::ffi::CStr::from_ptr(c_str); let ret_attr = llvm::LLVMCreateStringAttribute( @@ -111,7 +160,20 @@ pub(crate) fn add_tt<'ll>( c_str.to_bytes().len() as c_uint, ); - attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + dbg!(&fn_def); + + if llvm::LLVMRustIsIntrinsicCall(fn_def) { + dbg!("intrinsic call"); + attributes::apply_to_callsite(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + //} else if LLVMRustIsPtrLoad(fn_def) { + } else if LLVMRustIsLoadOrExtractValue(fn_def) { + let val = enzyme_wrapper.tree_to_md(enzyme_tt.inner, llcx); + LLVMRustSetEnzymeTypeMetadata(fn_def, val.unwrap()); + } else { + dbg!("fn call"); + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } + dbg!("finished to add return attribute"); } diff --git a/compiler/rustc_codegen_ssa/src/meth.rs b/compiler/rustc_codegen_ssa/src/meth.rs index b87034f9b33b7..08cd29289b7ca 100644 --- a/compiler/rustc_codegen_ssa/src/meth.rs +++ b/compiler/rustc_codegen_ssa/src/meth.rs @@ -146,7 +146,7 @@ pub(crate) fn load_vtable<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( } let gep = bx.inbounds_ptradd(llvtable, bx.const_usize(vtable_byte_offset)); - let ptr = bx.load(llty, gep, ptr_align); + let ptr = bx.load(llty, gep, ptr_align, None); // VTable loads are invariant. bx.set_invariant_load(ptr); if nonnull { diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index b6b95c5f12aae..c6a14e2e1ae08 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -1815,7 +1815,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // used for this call is passing it by-value. In that case, // the load would just produce `OperandValue::Ref` instead // of the `OperandValue::Immediate` we need for the call. - llval = bx.load(bx.backend_type(arg.layout), llval, align); + llval = bx.load(bx.backend_type(arg.layout), llval, align, None); if let BackendRepr::Scalar(scalar) = arg.layout.backend_repr { if scalar.is_bool() { bx.range_metadata(llval, WrappingRange { start: 0, end: 1 }); @@ -2237,14 +2237,14 @@ fn load_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( assert_eq!(cast.rest.unit.size, cast.rest.total); let first_ty = bx.reg_backend_type(&cast.prefix[0].unwrap()); let second_ty = bx.reg_backend_type(&cast.rest.unit); - let first = bx.load(first_ty, ptr, align); + let first = bx.load(first_ty, ptr, align, None); let second_ptr = bx.inbounds_ptradd(ptr, bx.const_usize(offset_from_start.bytes())); - let second = bx.load(second_ty, second_ptr, align.restrict_for_offset(offset_from_start)); + let second = bx.load(second_ty, second_ptr, align.restrict_for_offset(offset_from_start), None); let res = bx.cx().const_poison(cast_ty); let res = bx.insert_value(res, first, 0); bx.insert_value(res, second, 1) } else { - bx.load(cast_ty, ptr, align) + bx.load(cast_ty, ptr, align, None) } } @@ -2256,11 +2256,19 @@ pub fn store_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( align: Align, ) { if let Some(offset_from_start) = cast.rest_offset { + dbg!("here we are!!"); + //let f0 = field_for_scalar_pair_component(bx, layout, 0); + //let f1 = field_for_scalar_pair_component(bx, layout, 1); + + //with_no_trimmed_paths!({ + // eprintln!("extractvalue 0 field = {:?}", f0.map(|f| f.ty)); + // eprintln!("extractvalue 1 field = {:?}", f1.map(|f| f.ty)); + //}); assert!(cast.prefix[1..].iter().all(|p| p.is_none())); assert_eq!(cast.rest.unit.size, cast.rest.total); assert!(cast.prefix[0].is_some()); - let first = bx.extract_value(value, 0); - let second = bx.extract_value(value, 1); + let first = bx.extract_value(value, 0, None); + let second = bx.extract_value(value, 1, None); bx.store(first, ptr, align); let second_ptr = bx.inbounds_ptradd(ptr, bx.const_usize(offset_from_start.bytes())); bx.store(second, second_ptr, align.restrict_for_offset(offset_from_start)); diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index ac6c6a0a52efa..1d7cb0694d06d 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -1,4 +1,5 @@ use rustc_abi::{Align, FieldIdx, WrappingRange}; +use rustc_ast::expand::typetree::FncTree; use rustc_middle::mir::SourceInfo; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; @@ -14,6 +15,7 @@ use crate::errors::InvalidMonomorphization; use crate::mir::operand::OperandRefBuilder; use crate::traits::*; use crate::{MemFlags, meth, size_of_val}; +use rustc_middle::ty::type_tree::typetree_from_ty; fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( bx: &mut Bx, @@ -29,10 +31,16 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; + let tcx = bx.tcx(); + let tt = typetree_from_ty(tcx, ty); + let fnc_tree = FncTree { + args: vec![tt.clone()], + ret: tt, + }; if allow_overlap { - bx.memmove(dst, align, src, align, size, flags); + bx.memmove(dst, align, src, align, size, flags, Some(fnc_tree)); } else { - bx.memcpy(dst, align, src, align, size, flags, None); + bx.memcpy(dst, align, src, align, size, flags, Some(fnc_tree)); } } @@ -49,7 +57,7 @@ fn memset_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; - bx.memset(dst, val, size, align, flags); + bx.memset(dst, val, size, align, flags, None); } impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index c0c71edd4d905..4a563ed3ec528 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -21,6 +21,91 @@ use crate::MemFlags; use crate::common::IntPredicate; use crate::traits::*; +use rustc_ast::expand::typetree::TypeTree; +use rustc_middle::ty::type_tree::typetree_from_ty; +use crate::TyCtxt; +use rustc_span::sym; +use rustc_ast::expand::typetree::FncTree; + +fn option_ptr_like_scalar_pair_tts<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, +) -> Option { + let ty::Adt(def, args) = ty.kind() else { + return None; + }; + + if !tcx.is_lang_item(def.did(), LangItem::Option) { + return None; + } + + let inner = args.type_at(0); + if !(inner.is_ref() || inner.is_box() || nonnull_inner_ty(tcx, inner).is_some()) { + return None; + } + + let tt = typetree_from_ty(tcx, inner); + //let some_layout = layout.for_variant(bx.cx(), VariantIdx::from_u32(1)); + //let payload_layout = some_layout.field(bx.cx(), 0); + // this will be a slice + //let payload_ty = payload_layout.ty; + //let tt = rustc_middle::ty::typetree_from_ty(bx.tcx(), field0_ty.unwrap()); + if tt == TypeTree::new() { + return None; + } + let fnc_tree = FncTree { + args: vec![], + ret: tt, + }; + Some(fnc_tree) +} + +fn nonnull_inner_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Option> { + if let ty::Adt(def, args) = ty.kind() + && tcx.is_diagnostic_item(sym::NonNull, def.did()) + { + return Some(args.type_at(0)); + } + + None +} + pub fn scalar_pair_component_field_ty<'a, 'tcx, Bx, V>( + bx: &Bx, + layout: TyAndLayout<'tcx>, + idx: u64, + ) -> Option> + where + Bx: BuilderMethods<'a, 'tcx, Value = V>, + { + let BackendRepr::ScalarPair(a, b) = layout.backend_repr else { + return None; + }; + + let (want_offset, want_size) = match idx { + 0 => (Size::ZERO, a.size(bx.cx())), + 1 => { + let off = a.size(bx.cx()).align_to(b.align(bx.cx()).abi); + (off, b.size(bx.cx())) + } + _ => bug!("bad scalar-pair index {idx}"), + }; + + for i in 0..layout.fields.count() { + let field_layout = layout.field(bx.cx(), i); + let field_offset = layout.fields.offset(i); + + if field_layout.is_zst() { + continue; + } + + if field_offset == want_offset && field_layout.size == want_size { + return Some(field_layout.ty); + } + } + + None + } + /// The representation of a Rust value. The enum variant is in fact /// uniquely determined by the value's type, but is kept as a /// safety check. @@ -330,11 +415,78 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { layout: TyAndLayout<'tcx>, ) -> Self { let val = if let BackendRepr::ScalarPair(..) = layout.backend_repr { - debug!("Operand::from_immediate_or_packed_pair: unpacking {:?} @ {:?}", llval, layout); + use rustc_middle::ty::print::with_no_trimmed_paths; + let f1 = option_ptr_like_scalar_pair_tts(bx.tcx(), layout.ty); + let f2 = if f1.is_none() { + None + } else { + Some(FncTree { args: vec![], ret: TypeTree::int(8) }) + }; + //{ + // dbg!("new option ptr-like scalar pair"); + // tt + //} else { + // //let field0_ty = scalar_pair_component_field_ty(bx, layout, 0); + // //let field1_ty = scalar_pair_component_field_ty(bx, layout, 1); + + // //if field0_ty.is_none() || field1_ty.is_none() { + // //dbg!("from_immediate_or_packed_pair: missing field for layout"); + // //with_no_trimmed_paths!({ + // // eprintln!( + // // "from_immediate_or_packed_pair layout {:?}", + // // layout + // // ); + // //}); + // None + //}; + + //let field0_ty = scalar_pair_component_field_ty(bx, layout, 0); + //let field1_ty = scalar_pair_component_field_ty(bx, layout, 1); + //let (f1, f2) = if field0_ty.is_none() || field1_ty.is_none() { + // dbg!("from_immediate_or_packed_pair: missing field for layout"); + // with_no_trimmed_paths!({ + // eprintln!( + // "from_immediate_or_packed_pair layout {:?}", + // layout + // ); + // }); + // (None, None) + //} else { + // let tt1 = rustc_middle::ty::typetree_from_ty(bx.tcx(), field0_ty.unwrap()); + // let tt2 = rustc_middle::ty::typetree_from_ty(bx.tcx(), field1_ty.unwrap()); + // with_no_trimmed_paths!({ + // eprintln!( + // "from_immediate_or_packed_pair layout {:?}", + // layout + // ); + // eprintln!( + // "from_immediate_or_packed_pair layout0 {:?}", + // field0_ty + // ); + // eprintln!( + // "from_immediate_or_packed_pair layout1 {:?}", + // field1_ty + // ); + // }); + // dbg!(&tt1); + // dbg!(&tt2); + // //dbg!(&tt); + // use rustc_ast::expand::typetree::FncTree; + // let fnc1 = FncTree { + // args: vec![],//TypeTree::new() + // ret: tt1, + // }; + // let fnc2 = FncTree { + // args: vec![],//TypeTree::new() + // ret: tt2, + // }; + // (Some(fnc1), Some(fnc2)) + //}; // Deconstruct the immediate aggregate. - let a_llval = bx.extract_value(llval, 0); - let b_llval = bx.extract_value(llval, 1); + let a_llval = bx.extract_value(llval, 0, f1); + // TODO: above/below should take the actual subtree, not full tree like now + let b_llval = bx.extract_value(llval, 1, f2); OperandValue::Pair(a_llval, b_llval) } else { OperandValue::Immediate(llval) @@ -342,12 +494,21 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { OperandRef { val, layout, move_annotation: None } } + pub(crate) fn extract_field>( &self, fx: &mut FunctionCx<'a, 'tcx, Bx>, bx: &mut Bx, i: usize, ) -> Self { + use rustc_middle::ty::print::with_no_trimmed_paths; + + with_no_trimmed_paths!({ + eprintln!( + "from_immediate_or_packed_pair single extract_field {:?}", + self.layout + ); + }); let field = self.layout.field(bx.cx(), i); let offset = self.layout.fields.offset(i); @@ -925,7 +1086,8 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { dest: PlaceRef<'tcx, V>, flags: MemFlags, ) { - debug!("OperandRef::store: operand={:?}, dest={:?}", self, dest); + //dbg!("store"); + //debug!("OperandRef::store: operand={:?}, dest={:?}", self, dest); match self { OperandValue::ZeroSized => { // Avoid generating stores of zero-sized values, because the only way to have a diff --git a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs index 90ac8c89ba9ad..cdc6f2d57d697 100644 --- a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs +++ b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs @@ -119,6 +119,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { size, dest.val.align, MemFlags::empty(), + None, ); return; } @@ -136,14 +137,14 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { && let Ok(&byte) = bytes.iter().all_equal_value() { let fill = bx.cx().const_u8(byte); - bx.memset(start, fill, size, dest.val.align, MemFlags::empty()); + bx.memset(start, fill, size, dest.val.align, MemFlags::empty(), None); return true; } // Use llvm.memset.p0i8.* to initialize byte arrays let v = bx.from_immediate(v); if bx.cx().val_ty(v) == bx.cx().type_i8() { - bx.memset(start, v, size, dest.val.align, MemFlags::empty()); + bx.memset(start, v, size, dest.val.align, MemFlags::empty(), None); return true; } false diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 5092f28a33f7b..e933d3826692e 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -2,9 +2,11 @@ use std::assert_matches; use std::ops::Deref; use rustc_abi::{Align, Scalar, Size, WrappingRange}; +use rustc_ast::expand::typetree::{FncTree, TypeTree}; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::mir; use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout}; +use rustc_middle::ty::type_tree::typetree_from_ty; use rustc_middle::ty::{AtomicOrdering, Instance, Ty}; use rustc_session::config::OptLevel; use rustc_span::Span; @@ -237,7 +239,7 @@ pub trait BuilderMethods<'a, 'tcx>: fn alloca(&mut self, size: Size, align: Align) -> Self::Value; fn alloca_with_ty(&mut self, layout: TyAndLayout<'tcx>) -> Self::Value; - fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align) -> Self::Value; + fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align, tt: Option) -> Self::Value; fn volatile_load(&mut self, ty: Self::Type, ptr: Self::Value) -> Self::Value; fn atomic_load( &mut self, @@ -248,7 +250,7 @@ pub trait BuilderMethods<'a, 'tcx>: ) -> Self::Value; fn load_from_place(&mut self, ty: Self::Type, place: PlaceValue) -> Self::Value { assert_eq!(place.llextra, None); - self.load(ty, place.llval, place.align) + self.load(ty, place.llval, place.align, None) } fn load_operand(&mut self, place: PlaceRef<'tcx, Self::Value>) -> OperandRef<'tcx, Self::Value>; @@ -462,6 +464,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memset( &mut self, @@ -470,6 +473,7 @@ pub trait BuilderMethods<'a, 'tcx>: size: Self::Value, align: Align, flags: MemFlags, + tt: Option, ); /// *Typed* copy for non-overlapping places. @@ -502,14 +506,28 @@ pub trait BuilderMethods<'a, 'tcx>: let ty = self.backend_type(layout); let val = self.load_from_place(ty, src); self.store_to_place_with_flags(val, dst, flags); + dbg!("typed copy, branch1, nontemporal"); } else if self.sess().opts.optimize == OptLevel::No && self.is_backend_immediate(layout) { // If we're not optimizing, the aliasing information from `memcpy` // isn't useful, so just load-store the value for smaller code. let temp = self.load_operand(src.with_type(layout)); + dbg!("typed copy, branch2, immediate"); temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); + //dbg!("typed copy, branch3"); + //let ty = self.backend_type(layout); + let ty = layout.ty; + //dbg!(&ty); + let tt: TypeTree = typetree_from_ty(self.tcx(), ty); + let tt = tt.add_indirection(); + let fnc_tree = FncTree { + args: vec![tt.clone(), tt], + ret: TypeTree::new(), + }; + //dbg!(&fnc_tree); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, Some(fnc_tree)); + //dbg!("done"); } } @@ -547,7 +565,7 @@ pub trait BuilderMethods<'a, 'tcx>: fn va_arg(&mut self, list: Self::Value, ty: Self::Type) -> Self::Value; fn extract_element(&mut self, vec: Self::Value, idx: Self::Value) -> Self::Value; fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value; - fn extract_value(&mut self, agg_val: Self::Value, idx: u64) -> Self::Value; + fn extract_value(&mut self, agg_val: Self::Value, idx: u64, tt: Option) -> Self::Value; fn insert_value(&mut self, agg_val: Self::Value, elt: Self::Value, idx: u64) -> Self::Value; fn set_personality_fn(&mut self, personality: Self::Function); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index ce38ba8338338..7211cba82786a 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -759,6 +759,28 @@ extern "C" bool LLVMRustInlineAsmVerify(LLVMTypeRef Ty, char *Constraints, unwrap(Ty), StringRef(Constraints, ConstraintsLen))); } +//extern "C" bool LLVMRustIsPtrLoad(LLVMValueRef V) { +// auto *LI = llvm::dyn_cast(llvm::unwrap(V)); +// return LI; +//} +extern "C" bool LLVMRustIsLoadOrExtractValue(LLVMValueRef V) { + auto *I = llvm::dyn_cast(llvm::unwrap(V)); + return I && (llvm::isa(I) || llvm::isa(I)); +} + +extern "C" void LLVMRustSetEnzymeTypeMetadata(LLVMValueRef V, LLVMValueRef MDV) { + auto *I = llvm::dyn_cast(llvm::unwrap(V)); + assert(I && "expected instruction for !enzyme_type metadata"); + + auto *MAV = llvm::dyn_cast(llvm::unwrap(MDV)); + assert(MAV && "expected MetadataAsValue"); + + auto *MD = llvm::dyn_cast(MAV->getMetadata()); + assert(MD && "expected MDNode"); + + I->setMetadata("enzyme_type", MD); +} + template DIT *unwrapDIPtr(LLVMMetadataRef Ref) { return (DIT *)(Ref ? unwrap(Ref) : nullptr); } @@ -907,6 +929,14 @@ enum class LLVMRustDebugEmissionKind { DebugDirectivesOnly, }; + +extern "C" bool LLVMRustIsIntrinsicCall(LLVMValueRef V) { + if (auto *CB = llvm::dyn_cast(llvm::unwrap(V))) { + return CB->getIntrinsicID() != llvm::Intrinsic::not_intrinsic; + } + return false; +} + static DICompileUnit::DebugEmissionKind fromRust(LLVMRustDebugEmissionKind Kind) { switch (Kind) { diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 6df1ed82d260a..5a71175835827 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -70,6 +70,8 @@ use rustc_type_ir::{InferCtxtLike, Interner}; use tracing::{debug, instrument, trace}; pub use vtable::*; +pub mod type_tree; + pub use self::closure::{ BorrowKind, CAPTURE_STRUCT_LOCAL, CaptureInfo, CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap, MinCaptureList, RootVariableMinCaptureList, UpvarCapture, UpvarId, @@ -2341,228 +2343,3 @@ pub struct DestructuredAdtConst<'tcx> { pub variant: VariantIdx, pub fields: &'tcx [ty::Const<'tcx>], } - -/// Generate TypeTree information for autodiff. -/// This function creates TypeTree metadata that describes the memory layout -/// of function parameters and return types for Enzyme autodiff. -pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { - // Check if TypeTrees are disabled via NoTT flag - if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { - return FncTree { args: vec![], ret: TypeTree::new() }; - } - - // Check if this is actually a function type - if !fn_ty.is_fn() { - return FncTree { args: vec![], ret: TypeTree::new() }; - } - - // Get the function signature - let fn_sig = fn_ty.fn_sig(tcx); - let sig = tcx.instantiate_bound_regions_with_erased(fn_sig); - - // Create TypeTrees for each input parameter - let mut args = vec![]; - for ty in sig.inputs().iter() { - let type_tree = typetree_from_ty(tcx, *ty); - args.push(type_tree); - } - - // Create TypeTree for return type - let ret = typetree_from_ty(tcx, sig.output()); - - FncTree { args, ret } -} - -/// Generate TypeTree for a specific type. -/// This function analyzes a Rust type and creates appropriate TypeTree metadata. -pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { - let mut visited = Vec::new(); - typetree_from_ty_inner(tcx, ty, 0, &mut visited) -} - -/// Maximum recursion depth for TypeTree generation to prevent stack overflow -/// from pathological deeply nested types. Combined with cycle detection. -const MAX_TYPETREE_DEPTH: usize = 6; - -/// Internal recursive function for TypeTree generation with cycle detection and depth limiting. -fn typetree_from_ty_inner<'tcx>( - tcx: TyCtxt<'tcx>, - ty: Ty<'tcx>, - depth: usize, - visited: &mut Vec>, -) -> TypeTree { - if depth >= MAX_TYPETREE_DEPTH { - trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty); - return TypeTree::new(); - } - - if visited.contains(&ty) { - return TypeTree::new(); - } - - visited.push(ty); - let result = typetree_from_ty_impl(tcx, ty, depth, visited); - visited.pop(); - result -} - -/// Implementation of TypeTree generation logic. -fn typetree_from_ty_impl<'tcx>( - tcx: TyCtxt<'tcx>, - ty: Ty<'tcx>, - depth: usize, - visited: &mut Vec>, -) -> TypeTree { - typetree_from_ty_impl_inner(tcx, ty, depth, visited, false) -} - -/// Internal implementation with context about whether this is for a reference target. -fn typetree_from_ty_impl_inner<'tcx>( - tcx: TyCtxt<'tcx>, - ty: Ty<'tcx>, - depth: usize, - visited: &mut Vec>, - is_reference_target: bool, -) -> TypeTree { - if ty.is_scalar() { - let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { - (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) - } else if ty.is_floating_point() { - match ty { - x if x == tcx.types.f16 => (Kind::Half, 2), - x if x == tcx.types.f32 => (Kind::Float, 4), - x if x == tcx.types.f64 => (Kind::Double, 8), - x if x == tcx.types.f128 => (Kind::F128, 16), - _ => (Kind::Integer, 0), - } - } else { - (Kind::Integer, 0) - }; - - // Use offset 0 for scalars that are direct targets of references (like &f64) - // Use offset -1 for scalars used directly (like function return types) - let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; - return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); - } - - if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { - let Some(inner_ty) = ty.builtin_deref(true) else { - return TypeTree::new(); - }; - - let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); - return TypeTree(vec![Type { - offset: -1, - size: tcx.data_layout.pointer_size().bytes_usize(), - kind: Kind::Pointer, - child, - }]); - } - - if ty.is_array() { - if let ty::Array(element_ty, len_const) = ty.kind() { - let len = len_const.try_to_target_usize(tcx).unwrap_or(0); - if len == 0 { - return TypeTree::new(); - } - let element_tree = - typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); - let mut types = Vec::new(); - for elem_type in &element_tree.0 { - types.push(Type { - offset: -1, - size: elem_type.size, - kind: elem_type.kind, - child: elem_type.child.clone(), - }); - } - - return TypeTree(types); - } - } - - if ty.is_slice() { - if let ty::Slice(element_ty) = ty.kind() { - let element_tree = - typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); - return element_tree; - } - } - - if let ty::Tuple(tuple_types) = ty.kind() { - if tuple_types.is_empty() { - return TypeTree::new(); - } - - let mut types = Vec::new(); - let mut current_offset = 0; - - for tuple_ty in tuple_types.iter() { - let element_tree = - typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false); - - let element_layout = tcx - .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty)) - .ok() - .map(|layout| layout.size.bytes_usize()) - .unwrap_or(0); - - for elem_type in &element_tree.0 { - types.push(Type { - offset: if elem_type.offset == -1 { - current_offset as isize - } else { - current_offset as isize + elem_type.offset - }, - size: elem_type.size, - kind: elem_type.kind, - child: elem_type.child.clone(), - }); - } - - current_offset += element_layout; - } - - return TypeTree(types); - } - - if let ty::Adt(adt_def, args) = ty.kind() { - if adt_def.is_struct() { - let struct_layout = - tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty)); - if let Ok(layout) = struct_layout { - let mut types = Vec::new(); - - for (field_idx, field_def) in adt_def.all_fields().enumerate() { - let field_ty = field_def.ty(tcx, args); - let field_tree = typetree_from_ty_impl_inner( - tcx, - field_ty.skip_norm_wip(), - depth + 1, - visited, - false, - ); - - let field_offset = layout.fields.offset(field_idx).bytes_usize(); - - for elem_type in &field_tree.0 { - types.push(Type { - offset: if elem_type.offset == -1 { - field_offset as isize - } else { - field_offset as isize + elem_type.offset - }, - size: elem_type.size, - kind: elem_type.kind, - child: elem_type.child.clone(), - }); - } - } - - return TypeTree(types); - } - } - } - - TypeTree::new() -} diff --git a/compiler/rustc_middle/src/ty/type_tree.rs b/compiler/rustc_middle/src/ty/type_tree.rs new file mode 100644 index 0000000000000..89cd1872e7f44 --- /dev/null +++ b/compiler/rustc_middle/src/ty/type_tree.rs @@ -0,0 +1,302 @@ + +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_hir::LangItem; +use tracing::trace; +use rustc_ast::expand::typetree::*; + + +/// Generate TypeTree information for autodiff. +/// This function creates TypeTree metadata that describes the memory layout +/// of function parameters and return types for Enzyme autodiff. +pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { + // Check if TypeTrees are disabled via NoTT flag + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Check if this is actually a function type + if !fn_ty.is_fn() { + dbg!("not a function type: {}", fn_ty); + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Get the function signature + let fn_sig = fn_ty.fn_sig(tcx); + let sig = tcx.instantiate_bound_regions_with_erased(fn_sig); + + // Create TypeTrees for each input parameter + let mut args = vec![]; + for ty in sig.inputs().iter() { + let type_tree = typetree_from_ty(tcx, *ty); + args.push(type_tree); + } + + // Create TypeTree for return type + let ret = typetree_from_ty(tcx, sig.output()); + + FncTree { args, ret } +} + +/// Generate TypeTree for a specific type. +/// This function analyzes a Rust type and creates appropriate TypeTree metadata. +pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return TypeTree::new(); + } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return TypeTree::new(); + } + let mut visited = Vec::new(); + typetree_from_ty_inner(tcx, ty, 0, &mut visited) +} + +/// Maximum recursion depth for TypeTree generation to prevent stack overflow +/// from pathological deeply nested types. Combined with cycle detection. +const MAX_TYPETREE_DEPTH: usize = 6; + +/// Internal recursive function for TypeTree generation with cycle detection and depth limiting. +fn typetree_from_ty_inner<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, +) -> TypeTree { + if depth >= MAX_TYPETREE_DEPTH { + trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty); + return TypeTree::new(); + } + + if visited.contains(&ty) { + return TypeTree::new(); + } + + visited.push(ty); + let result = typetree_from_ty_impl(tcx, ty, depth, visited); + visited.pop(); + result +} + +/// Implementation of TypeTree generation logic. +fn typetree_from_ty_impl<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, +) -> TypeTree { + typetree_from_ty_impl_inner(tcx, ty, depth, visited, false) +} + +/// Internal implementation with context about whether this is for a reference target. +fn typetree_from_ty_impl_inner<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, + is_reference_target: bool, +) -> TypeTree { + + if ty.is_slice() { + if let ty::Slice(element_ty) = ty.kind() { + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); + return element_tree; + } + } + + if let Some(inner) = unwrap_option_nonnull_ptr_like(tcx, ty) { + dbg!("option of niche!"); + dbg!(&ty); + // Option> and similar types use a niche to encode the Option/None variant, + // so we can ignore it and return the inner tt directly. + return typetree_from_ty_impl_inner(tcx, inner, depth + 1, visited, true); + }; + + + if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { + let Some(inner_ty) = ty.builtin_deref(true) else { + bug!("expected reference or pointer type to have a dereferenceable inner type: {}", ty); + return TypeTree::new(); + }; + + let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); + return TypeTree(vec![Type { + offset: -1, + size: tcx.data_layout.pointer_size().bytes_usize(), + kind: Kind::Pointer, + child, + }]); + } + if let Some(inner) = is_nonnull(tcx, ty) { + let child = typetree_from_ty_impl_inner(tcx, inner, depth + 1, visited, true); + return TypeTree(vec![Type { + offset: -1, + size: tcx.data_layout.pointer_size().bytes_usize(), + kind: Kind::Pointer, + child, + }]); + } + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f16 => (Kind::Half, 2), + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + x if x == tcx.types.f128 => (Kind::F128, 16), + _ => (Kind::Integer, 0), + } + } else { + (Kind::Integer, 0) + }; + + // Use offset 0 for scalars that are direct targets of references (like &f64) + // Use offset -1 for scalars used directly (like function return types) + let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; + return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); + } + + if ty.is_array() { + if let ty::Array(element_ty, len_const) = ty.kind() { + let len = len_const.try_to_target_usize(tcx).unwrap_or(0); + if len == 0 { + return TypeTree::new(); + } + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); + let mut types = Vec::new(); + for elem_type in &element_tree.0 { + types.push(Type { + offset: -1, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + + return TypeTree(types); + } + } + + if let ty::Tuple(tuple_types) = ty.kind() { + if tuple_types.is_empty() { + return TypeTree::new(); + } + + let mut types = Vec::new(); + let mut current_offset = 0; + + for tuple_ty in tuple_types.iter() { + let element_tree = + typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false); + + let element_layout = tcx + .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty)) + .ok() + .map(|layout| layout.size.bytes_usize()) + .unwrap_or(0); + + for elem_type in &element_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + current_offset as isize + } else { + current_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + + current_offset += element_layout; + } + + return TypeTree(types); + } + + if let ty::Adt(adt_def, args) = ty.kind() { + if adt_def.is_struct() { + let struct_layout = + tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty)); + if let Ok(layout) = struct_layout { + let mut types = Vec::new(); + + for (field_idx, field_def) in adt_def.all_fields().enumerate() { + let field_ty = field_def.ty(tcx, args); + let field_tree = typetree_from_ty_impl_inner( + tcx, + field_ty.skip_norm_wip(), + depth + 1, + visited, + false, + ); + + let field_offset = layout.fields.offset(field_idx).bytes_usize(); + + for elem_type in &field_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + field_offset as isize + } else { + field_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + } + + return TypeTree(types); + } + } + } + + TypeTree::new() +} +use rustc_span::sym; +fn is_nonnull<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + ) -> Option> { + if let ty::Adt(def, args) = ty.kind() { + if tcx.is_diagnostic_item(sym::NonNull, def.did()) { + let inner_ty = args.type_at(0); + return Some(inner_ty); + } + } + None +} + +fn unwrap_option_nonnull_ptr_like<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, +) -> Option> { + let ty::Adt(def, args) = ty.kind() else { + return None; + }; + + if !tcx.is_lang_item(def.did(), LangItem::Option) { + dbg!("not an Option: {}", ty); + return None; + } + + let inner = args.type_at(0); + + // Accepted: Option<&T>, Option<&mut T>, Option> + // Rejected: Option<*const T>, Option<*mut T> + if inner.is_ref() || inner.is_box() { + return inner.builtin_deref(true); + } + + // Accepted: Option> + if let Some(inner_ty) = is_nonnull(tcx, inner) { + return Some(inner_ty); + } + dbg!("the wrong Option: {}", ty); + + None +} + +