From 1de39b92b6c759fd877b6f57cb07d6122630fe16 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 16 Jun 2026 14:45:09 -0400 Subject: [PATCH 1/9] compiles --- compiler/rustc_codegen_llvm/src/builder.rs | 31 +++++++++++++------ .../src/builder/autodiff.rs | 3 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 1 + compiler/rustc_codegen_llvm/src/typetree.rs | 4 +++ compiler/rustc_codegen_ssa/src/meth.rs | 2 +- compiler/rustc_codegen_ssa/src/mir/block.rs | 8 ++--- .../rustc_codegen_ssa/src/mir/intrinsic.rs | 13 ++++++-- compiler/rustc_codegen_ssa/src/mir/rvalue.rs | 5 +-- .../rustc_codegen_ssa/src/traits/builder.rs | 6 ++-- compiler/rustc_middle/src/ty/mod.rs | 7 +++++ 10 files changed, 58 insertions(+), 22 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 134bc5006dd00..acbbe189e910e 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -181,11 +181,12 @@ impl<'a, 'll, CX: Borrow>> GenericBuilder<'a, 'll, CX> { } pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value { - unsafe { + let load = unsafe { let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED); llvm::LLVMSetAlignment(load, align.bytes() as c_uint); load - } + }; + load } } @@ -627,13 +628,17 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value { - unsafe { + fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align, tt: Option) -> &'ll Value { + let load = unsafe { let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED); let align = align.min(self.cx().tcx.sess.target.max_reliable_alignment()); llvm::LLVMSetAlignment(load, align.bytes() as c_uint); load + }; + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, tt); } + load } fn volatile_load(&mut self, ty: &'ll Type, ptr: &'ll Value) -> &'ll Value { @@ -1113,7 +1118,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { // vs. copying a struct with mixed types requires different derivative handling. // The TypeTree tells Enzyme exactly what memory layout to expect. if let Some(tt) = tt { - crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, self.tcx, tt); } } @@ -1125,11 +1130,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memmove not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memmove = unsafe { llvm::LLVMRustBuildMemMove( self.llbuilder, dst, @@ -1138,7 +1144,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memmove, self.tcx, tt); } } @@ -1149,10 +1158,11 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { size: &'ll Value, align: Align, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memset not supported"); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memset = unsafe { llvm::LLVMRustBuildMemSet( self.llbuilder, ptr, @@ -1160,7 +1170,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fill_byte, size, is_volatile, - ); + ) + }; + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memset, self.tcx, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 1cefdaae5ebde..ed07bc8af823e 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -290,6 +290,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_name: &str, @@ -375,7 +376,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( ); if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() { - crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree); + crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, tcx, fnc_tree); } let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 1c7b415fd04c7..59971cf8e15e5 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1883,6 +1883,7 @@ fn codegen_autodiff<'ll, 'tcx>( // Build body generate_enzyme_call( bx, + tcx, bx.cx, fn_to_diff, &diff_symbol, diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 4f433f273c8cc..ddf1db492e1fd 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -56,8 +56,12 @@ pub(crate) fn add_tt<'ll>( llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, + tcx: rustc_middle::ty::TyCtxt<'_>, tt: FncTree, ) { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return; + } // TypeTree processing uses functions from Enzyme, which we might not have available if we did // not build this compiler with `llvm_enzyme`. This feature is not strictly necessary, but // skipping this function increases the chance that Enzyme fails to compile some code. diff --git a/compiler/rustc_codegen_ssa/src/meth.rs b/compiler/rustc_codegen_ssa/src/meth.rs index b87034f9b33b7..08cd29289b7ca 100644 --- a/compiler/rustc_codegen_ssa/src/meth.rs +++ b/compiler/rustc_codegen_ssa/src/meth.rs @@ -146,7 +146,7 @@ pub(crate) fn load_vtable<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( } let gep = bx.inbounds_ptradd(llvtable, bx.const_usize(vtable_byte_offset)); - let ptr = bx.load(llty, gep, ptr_align); + let ptr = bx.load(llty, gep, ptr_align, None); // VTable loads are invariant. bx.set_invariant_load(ptr); if nonnull { diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index b6b95c5f12aae..6d41c31b36a5e 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -1815,7 +1815,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // used for this call is passing it by-value. In that case, // the load would just produce `OperandValue::Ref` instead // of the `OperandValue::Immediate` we need for the call. - llval = bx.load(bx.backend_type(arg.layout), llval, align); + llval = bx.load(bx.backend_type(arg.layout), llval, align, None); if let BackendRepr::Scalar(scalar) = arg.layout.backend_repr { if scalar.is_bool() { bx.range_metadata(llval, WrappingRange { start: 0, end: 1 }); @@ -2237,14 +2237,14 @@ fn load_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( assert_eq!(cast.rest.unit.size, cast.rest.total); let first_ty = bx.reg_backend_type(&cast.prefix[0].unwrap()); let second_ty = bx.reg_backend_type(&cast.rest.unit); - let first = bx.load(first_ty, ptr, align); + let first = bx.load(first_ty, ptr, align, None); let second_ptr = bx.inbounds_ptradd(ptr, bx.const_usize(offset_from_start.bytes())); - let second = bx.load(second_ty, second_ptr, align.restrict_for_offset(offset_from_start)); + let second = bx.load(second_ty, second_ptr, align.restrict_for_offset(offset_from_start), None); let res = bx.cx().const_poison(cast_ty); let res = bx.insert_value(res, first, 0); bx.insert_value(res, second, 1) } else { - bx.load(cast_ty, ptr, align) + bx.load(cast_ty, ptr, align, None) } } diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index ac6c6a0a52efa..21107898e2962 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -1,4 +1,5 @@ use rustc_abi::{Align, FieldIdx, WrappingRange}; +use rustc_ast::expand::typetree::FncTree; use rustc_middle::mir::SourceInfo; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; @@ -29,10 +30,16 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; + let tcx = bx.tcx(); + let tt = rustc_middle::ty::typetree_from_ty(tcx, ty); + let fnc_tree = FncTree { + args: vec![tt.clone()], + ret: tt, + }; if allow_overlap { - bx.memmove(dst, align, src, align, size, flags); + bx.memmove(dst, align, src, align, size, flags, Some(fnc_tree)); } else { - bx.memcpy(dst, align, src, align, size, flags, None); + bx.memcpy(dst, align, src, align, size, flags, Some(fnc_tree)); } } @@ -49,7 +56,7 @@ fn memset_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let align = layout.align.abi; let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; - bx.memset(dst, val, size, align, flags); + bx.memset(dst, val, size, align, flags, None); } impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { diff --git a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs index 90ac8c89ba9ad..cdc6f2d57d697 100644 --- a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs +++ b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs @@ -119,6 +119,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { size, dest.val.align, MemFlags::empty(), + None, ); return; } @@ -136,14 +137,14 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { && let Ok(&byte) = bytes.iter().all_equal_value() { let fill = bx.cx().const_u8(byte); - bx.memset(start, fill, size, dest.val.align, MemFlags::empty()); + bx.memset(start, fill, size, dest.val.align, MemFlags::empty(), None); return true; } // Use llvm.memset.p0i8.* to initialize byte arrays let v = bx.from_immediate(v); if bx.cx().val_ty(v) == bx.cx().type_i8() { - bx.memset(start, v, size, dest.val.align, MemFlags::empty()); + bx.memset(start, v, size, dest.val.align, MemFlags::empty(), None); return true; } false diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 5092f28a33f7b..60334ab20833e 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -237,7 +237,7 @@ pub trait BuilderMethods<'a, 'tcx>: fn alloca(&mut self, size: Size, align: Align) -> Self::Value; fn alloca_with_ty(&mut self, layout: TyAndLayout<'tcx>) -> Self::Value; - fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align) -> Self::Value; + fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align, tt: Option) -> Self::Value; fn volatile_load(&mut self, ty: Self::Type, ptr: Self::Value) -> Self::Value; fn atomic_load( &mut self, @@ -248,7 +248,7 @@ pub trait BuilderMethods<'a, 'tcx>: ) -> Self::Value; fn load_from_place(&mut self, ty: Self::Type, place: PlaceValue) -> Self::Value { assert_eq!(place.llextra, None); - self.load(ty, place.llval, place.align) + self.load(ty, place.llval, place.align, None) } fn load_operand(&mut self, place: PlaceRef<'tcx, Self::Value>) -> OperandRef<'tcx, Self::Value>; @@ -462,6 +462,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memset( &mut self, @@ -470,6 +471,7 @@ pub trait BuilderMethods<'a, 'tcx>: size: Self::Value, align: Align, flags: MemFlags, + tt: Option, ); /// *Typed* copy for non-overlapping places. diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 6df1ed82d260a..deeb26c7e5120 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2353,6 +2353,7 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { // Check if this is actually a function type if !fn_ty.is_fn() { + dbg!("not a function type: {}", fn_ty); return FncTree { args: vec![], ret: TypeTree::new() }; } @@ -2376,6 +2377,12 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { /// Generate TypeTree for a specific type. /// This function analyzes a Rust type and creates appropriate TypeTree metadata. pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return TypeTree::new(); + } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return TypeTree::new(); + } let mut visited = Vec::new(); typetree_from_ty_inner(tcx, ty, 0, &mut visited) } From e721b4b2bbc96afefdd69ba0e062aadc6939ab06 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 16 Jun 2026 15:43:02 -0400 Subject: [PATCH 2/9] more tt and dbg. todo: add tt to intrinsic, not just fn calls --- compiler/rustc_codegen_llvm/src/typetree.rs | 12 ++++++++++++ .../rustc_codegen_ssa/src/traits/builder.rs | 19 +++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index ddf1db492e1fd..5d972773a9099 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -62,6 +62,9 @@ pub(crate) fn add_tt<'ll>( if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { return; } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return; + } // TypeTree processing uses functions from Enzyme, which we might not have available if we did // not build this compiler with `llvm_enzyme`. This feature is not strictly necessary, but // skipping this function increases the chance that Enzyme fails to compile some code. @@ -73,6 +76,7 @@ pub(crate) fn add_tt<'ll>( let inputs = tt.args; let ret_tt: RustTypeTree = tt.ret; + dbg!("getting DataLayout"); let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; let llvm_data_layout = std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) @@ -80,6 +84,7 @@ pub(crate) fn add_tt<'ll>( let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); + dbg!("going to iter over inputs"); for (i, input) in inputs.iter().enumerate() { unsafe { @@ -95,11 +100,15 @@ pub(crate) fn add_tt<'ll>( c_str.as_ptr(), c_str.to_bytes().len() as c_uint, ); + dbg!("adding attribute for argument {}", i); + dbg!("attribute string: {:?}", c_str); + dbg!(&fn_def); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } + dbg!("finished to iter over inputs"); unsafe { let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); @@ -115,7 +124,10 @@ pub(crate) fn add_tt<'ll>( c_str.to_bytes().len() as c_uint, ); + dbg!(&fn_def); + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } + dbg!("finished to add return attribute"); } diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 60334ab20833e..473d0fdc970cb 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -2,10 +2,11 @@ use std::assert_matches; use std::ops::Deref; use rustc_abi::{Align, Scalar, Size, WrappingRange}; +use rustc_ast::expand::typetree::FncTree; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::mir; use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout}; -use rustc_middle::ty::{AtomicOrdering, Instance, Ty}; +use rustc_middle::ty::{typetree_from_ty, AtomicOrdering, Instance, Ty}; use rustc_session::config::OptLevel; use rustc_span::Span; use rustc_target::callconv::FnAbi; @@ -504,14 +505,28 @@ pub trait BuilderMethods<'a, 'tcx>: let ty = self.backend_type(layout); let val = self.load_from_place(ty, src); self.store_to_place_with_flags(val, dst, flags); + dbg!("typed copy, branch1, nontemporal"); } else if self.sess().opts.optimize == OptLevel::No && self.is_backend_immediate(layout) { // If we're not optimizing, the aliasing information from `memcpy` // isn't useful, so just load-store the value for smaller code. let temp = self.load_operand(src.with_type(layout)); + dbg!("typed copy, branch2, immediate"); temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); + dbg!("typed copy, branch3"); + //let ty = self.backend_type(layout); + let ty = layout.ty; + dbg!(&ty); + let tt = typetree_from_ty(self.tcx(), ty); + dbg!("got tt"); + let fnc_tree = FncTree { + args: vec![tt.clone()], + ret: tt, + }; + dbg!(&fnc_tree); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, Some(fnc_tree)); + dbg!("done"); } } From 922ddfacfa2908c048cafec53ded8f37c234f4d8 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 16 Jun 2026 16:26:47 -0400 Subject: [PATCH 3/9] update tt to be added to callsites of intrinsics --- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 1 + compiler/rustc_codegen_llvm/src/typetree.rs | 18 +++++++++++++++--- .../rustc_codegen_ssa/src/traits/builder.rs | 6 +++--- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 8 ++++++++ 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 1fde5866f5dca..50c13d09a7a90 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2453,6 +2453,7 @@ unsafe extern "C" { FileType: FileType, VerifyIR: bool, ) -> LLVMRustResult; + pub(crate) fn LLVMRustIsIntrinsicCall(val: &Value) -> bool; pub(crate) fn LLVMRustOptimize<'a>( M: &'a Module, TM: &'a TargetMachine, diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 5d972773a9099..b6f686b07d6bc 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -3,7 +3,7 @@ use std::ffi::{CString, c_char, c_uint}; use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; use crate::attributes; -use crate::llvm::{self, EnzymeWrapper, Value}; +use crate::llvm::{self, EnzymeWrapper, TypeTree, Value}; fn to_enzyme_typetree( rust_typetree: RustTypeTree, @@ -104,13 +104,21 @@ pub(crate) fn add_tt<'ll>( dbg!("attribute string: {:?}", c_str); dbg!(&fn_def); - attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + if llvm::LLVMRustIsIntrinsicCall(fn_def) { + attributes::apply_to_callsite(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + } else { + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } dbg!("finished to iter over inputs"); unsafe { + if ret_tt == rustc_ast::expand::typetree::TypeTree::new() { + dbg!("skipping empty return tt"); + return; + } let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); let enzyme_wrapper = EnzymeWrapper::get_instance(); let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); @@ -126,7 +134,11 @@ pub(crate) fn add_tt<'ll>( dbg!(&fn_def); - attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + if llvm::LLVMRustIsIntrinsicCall(fn_def) { + attributes::apply_to_callsite(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + } else { + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } dbg!("finished to add return attribute"); diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 473d0fdc970cb..d02d8c5afb372 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -2,7 +2,7 @@ use std::assert_matches; use std::ops::Deref; use rustc_abi::{Align, Scalar, Size, WrappingRange}; -use rustc_ast::expand::typetree::FncTree; +use rustc_ast::expand::typetree::{FncTree, TypeTree}; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::mir; use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout}; @@ -521,8 +521,8 @@ pub trait BuilderMethods<'a, 'tcx>: let tt = typetree_from_ty(self.tcx(), ty); dbg!("got tt"); let fnc_tree = FncTree { - args: vec![tt.clone()], - ret: tt, + args: vec![tt.clone(), tt], + ret: TypeTree::new(), }; dbg!(&fnc_tree); self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, Some(fnc_tree)); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index ce38ba8338338..67a0cd9acb9c6 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -907,6 +907,14 @@ enum class LLVMRustDebugEmissionKind { DebugDirectivesOnly, }; + +extern "C" bool LLVMRustIsIntrinsicCall(LLVMValueRef V) { + if (auto *CB = llvm::dyn_cast(llvm::unwrap(V))) { + return CB->getIntrinsicID() != llvm::Intrinsic::not_intrinsic; + } + return false; +} + static DICompileUnit::DebugEmissionKind fromRust(LLVMRustDebugEmissionKind Kind) { switch (Kind) { From 6130cd073429b4e74dfd1601eda373f0385b6f32 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 16 Jun 2026 17:32:45 -0400 Subject: [PATCH 4/9] local enzyme fix, add indirection for the memcpy, not sure if correct? --- compiler/rustc_ast/src/expand/typetree.rs | 3 +++ compiler/rustc_codegen_ssa/src/traits/builder.rs | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs index 9619c80904426..1737335410054 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -57,6 +57,9 @@ impl TypeTree { } Self(ints) } + pub fn add_indirection(self) -> Self { + Self(vec![Type { offset: 0, size: 1, kind: Kind::Pointer, child: self }]) + } } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)] diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index d02d8c5afb372..7e88ab451f404 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -518,7 +518,8 @@ pub trait BuilderMethods<'a, 'tcx>: //let ty = self.backend_type(layout); let ty = layout.ty; dbg!(&ty); - let tt = typetree_from_ty(self.tcx(), ty); + let tt: TypeTree = typetree_from_ty(self.tcx(), ty); + let tt = tt.add_indirection(); dbg!("got tt"); let fnc_tree = FncTree { args: vec![tt.clone(), tt], From 69a8e282b25057697c39a401c7ced5817480ed9e Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 16 Jun 2026 21:07:54 -0400 Subject: [PATCH 5/9] compiles, but md on laod instr seems wrong --- compiler/rustc_codegen_llvm/src/asm.rs | 2 +- compiler/rustc_codegen_llvm/src/builder.rs | 72 +++++++++++++++++-- .../src/builder/gpu_offload.rs | 6 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 15 ++-- .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 15 +++- compiler/rustc_codegen_llvm/src/typetree.rs | 30 ++++++-- compiler/rustc_codegen_ssa/src/mir/block.rs | 4 +- compiler/rustc_codegen_ssa/src/mir/operand.rs | 4 +- .../rustc_codegen_ssa/src/traits/builder.rs | 2 +- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 20 ++++++ 10 files changed, 140 insertions(+), 30 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/asm.rs b/compiler/rustc_codegen_llvm/src/asm.rs index c5ab9fc2336eb..1f053300e8963 100644 --- a/compiler/rustc_codegen_llvm/src/asm.rs +++ b/compiler/rustc_codegen_llvm/src/asm.rs @@ -363,7 +363,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { let value = if output_types.len() == 1 { result } else { - self.extract_value(result, op_idx[&idx] as u64) + self.extract_value(result, op_idx[&idx] as u64, None) }; let value = llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance); diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index acbbe189e910e..1cfbea4964e5c 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -2,7 +2,7 @@ use std::borrow::{Borrow, Cow}; use std::iter; use std::ops::Deref; -use rustc_ast::expand::typetree::FncTree; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; pub(crate) mod autodiff; pub(crate) mod gpu_offload; @@ -586,7 +586,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let name = format!("llvm.{}{oop_str}.with.overflow", if signed { 's' } else { 'u' }); let res = self.call_intrinsic(name, &[self.type_ix(width)], &[lhs, rhs]); - (self.extract_value(res, 0), self.extract_value(res, 1)) + (self.extract_value(res, 0, None), self.extract_value(res, 1, None)) } fn from_immediate(&mut self, val: Self::Value) -> Self::Value { @@ -668,6 +668,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { #[instrument(level = "trace", skip(self))] fn load_operand(&mut self, place: PlaceRef<'tcx, &'ll Value>) -> OperandRef<'tcx, &'ll Value> { + //dbg!("load_operand"); + //use rustc_middle::ty::print::with_no_trimmed_paths; + //eprintln!("place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); + if place.layout.is_unsized() { let tail = self.tcx.struct_tail_for_codegen(place.layout.ty, self.typing_env()); if matches!(tail.kind(), ty::Foreign(..)) { @@ -739,6 +743,62 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let llval = const_llval.unwrap_or_else(|| { let load = self.load(llty, place.val.llval, place.val.align); + let ty = place.layout.ty; + let tt = rustc_middle::ty::typetree_from_ty(self.tcx, ty); + dbg!(&tt); + let fnc_tree = FncTree { + args: vec![TypeTree::new(), TypeTree::new()], + ret: tt, + }; + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, fnc_tree); + //eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); + // 25 general load of place = PlaceRef { + // 24 val: PlaceValue { + // 23 llval: (ptr: %3 = alloca [8 x i8], align 8), + // 22 llextra: None, + // 21 align: Align(8 bytes), + // 20 }, + // 19 layout: TyAndLayout { + // 18 ty: &([f64; 3], [f64; 3]), + // 17 layout: Layout { + // 16 size: Size(8 bytes), + // 15 align: AbiAlign { + // 14 abi: Align(8 bytes), + // 13 }, + // 12 backend_repr: Scalar( + // 11 Initialized { + // 10 value: Pointer( + // 9 AddressSpace( + // 8 0, + // 7 ), + // 6 ), + // 5 valid_range: 1..=18446744073709551615, + // 4 }, + // 3 ), + // 2 fields: Primitive, + // 1 largest_niche: Some( + // 259 Niche { + // 1 offset: Size(0 bytes), + // 2 value: Pointer( + // 3 AddressSpace( + // 4 0, + // 5 ), + // 6 ), + // 7 valid_range: 1..=18446744073709551615, + // 8 }, + // 9 ), + // 10 uninhabited: false, + // 11 variants: Single { + // 12 index: 0, + // 13 }, + // 14 max_repr_align: None, + // 15 unadjusted_abi_align: Align(8 bytes), + // 16 randomization_seed: 281492156579847, + // 17 }, + // 18 }, + // 19 } + + dbg!(&load); if let abi::BackendRepr::Scalar(scalar) = place.layout.backend_repr { scalar_load_metadata(self, load, scalar, place.layout, Size::ZERO); self.to_immediate_scalar(load, scalar) @@ -1204,7 +1264,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } - fn extract_value(&mut self, agg_val: &'ll Value, idx: u64) -> &'ll Value { + fn extract_value(&mut self, agg_val: &'ll Value, idx: u64, tt: Option) -> &'ll Value { assert_eq!(idx as c_uint as u64, idx); unsafe { llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) } } @@ -1226,7 +1286,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { unsafe { llvm::LLVMSetCleanup(landing_pad, llvm::TRUE); } - (self.extract_value(landing_pad, 0), self.extract_value(landing_pad, 1)) + (self.extract_value(landing_pad, 0, None), self.extract_value(landing_pad, 1, None)) } fn filter_landing_pad(&mut self, pers_fn: &'ll Value) { @@ -1327,8 +1387,8 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { llvm::FALSE, // SingleThreaded ); llvm::LLVMSetWeak(value, weak.to_llvm_bool()); - let val = self.extract_value(value, 0); - let success = self.extract_value(value, 1); + let val = self.extract_value(value, 0, None); + let success = self.extract_value(value, 1, None); (val, success) } } diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 0b009321802cf..39e3ac8eb484f 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -162,9 +162,9 @@ impl<'ll> OffloadKernelDims<'ll> { builder: &mut Builder<'_, 'll, 'tcx>, arr: &'ll Value, ) -> &'ll Value { - let x = builder.extract_value(arr, 0); - let y = builder.extract_value(arr, 1); - let z = builder.extract_value(arr, 2); + let x = builder.extract_value(arr, 0, None); + let y = builder.extract_value(arr, 1, None); + let z = builder.extract_value(arr, 2, None); let xy = builder.mul(x, y); builder.mul(xy, z) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 59971cf8e15e5..a91515952aa75 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -780,6 +780,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { self.extract_value( args[0].immediate(), fn_args.const_at(2).to_leaf().to_i32() as u64, + None, ) } @@ -1063,7 +1064,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[], &[llvtable, vtable_byte_offset, typeid], ); - self.extract_value(type_checked_load, 0) + self.extract_value(type_checked_load, 0, None) } fn va_start(&mut self, va_list: &'ll Value) { @@ -1179,7 +1180,7 @@ fn autocast<'ll>( iter::zip(bx.struct_element_types(src_ty), bx.struct_element_types(dest_ty)) .enumerate() { - let elt = bx.extract_value(val, idx as u64); + let elt = bx.extract_value(val, idx as u64, None); let casted_elt = autocast(bx, elt, src_element_ty, dest_element_ty); ret = bx.insert_value(ret, casted_elt, idx as u64); } @@ -1642,7 +1643,7 @@ fn codegen_gnu_try<'ll, 'tcx>( let vals = bx.landing_pad(lpad_ty, bx.eh_personality(), 1); let tydesc = bx.const_null(bx.type_ptr()); bx.add_clause(vals, tydesc); - let ptr = bx.extract_value(vals, 0); + let ptr = bx.extract_value(vals, 0, None); let catch_ty = bx.type_func(&[bx.type_ptr(), bx.type_ptr()], bx.type_void()); bx.call(catch_ty, None, None, catch_func, &[data, ptr], None, None); bx.ret(bx.const_bool(true)); @@ -1704,8 +1705,8 @@ fn codegen_emcc_try<'ll, 'tcx>( let vals = bx.landing_pad(lpad_ty, bx.eh_personality(), 2); bx.add_clause(vals, tydesc); bx.add_clause(vals, bx.const_null(bx.type_ptr())); - let ptr = bx.extract_value(vals, 0); - let selector = bx.extract_value(vals, 1); + let ptr = bx.extract_value(vals, 0, None); + let selector = bx.extract_value(vals, 1, None); // Check if the typeid we got is the one for a Rust panic. let rust_typeid = bx.call_intrinsic("llvm.eh.typeid.for", &[bx.val_ty(tydesc)], &[tydesc]); @@ -2006,8 +2007,8 @@ fn get_args_from_tuple<'ll, 'tcx>( let field = tuple_place.project_field(bx, tuple_index); let llvm_ty = field.layout.llvm_type(bx.cx); let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align); - result.push(bx.extract_value(pair_val, 0)); - result.push(bx.extract_value(pair_val, 1)); + result.push(bx.extract_value(pair_val, 0, None)); + result.push(bx.extract_value(pair_val, 1, None)); tuple_index += 1; } PassMode::Indirect { .. } => { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 195e050a9b651..fd37c70d9d6f3 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -65,7 +65,11 @@ unsafe extern "C" { name: *const c_char, NameLen: libc::size_t, ) -> Option<&Value>; - + pub(crate) fn LLVMRustIsPtrLoad(v: &Value) -> bool; + pub(crate) fn LLVMRustSetEnzymeTypeMetadata( + v: &Value, + md: &Value, + ); } unsafe extern "C" { @@ -90,6 +94,7 @@ pub(crate) use self::Enzyme_AD::*; pub(crate) mod Enzyme_AD { use std::ffi::{c_char, c_void}; use std::sync::{Mutex, MutexGuard, OnceLock}; + use super::Value; use rustc_middle::bug; use rustc_session::config::{Sysroot, host_tuple}; @@ -114,6 +119,7 @@ pub(crate) mod Enzyme_AD { unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context); type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char; type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char); + type EnzymeTypeTreeToMDFn = unsafe extern "C" fn(CTypeTreeRef, &Context) -> Option<&Value>; #[allow(non_snake_case)] pub(crate) struct EnzymeWrapper { @@ -127,6 +133,7 @@ pub(crate) mod Enzyme_AD { EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn, EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn, EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn, + EnzymeTypeTreeToMD: EnzymeTypeTreeToMDFn, EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, EnzymePrintPerf: *mut c_void, @@ -292,6 +299,10 @@ pub(crate) mod Enzyme_AD { unsafe { (self.EnzymeTypeTreeToString)(tree) } } + pub(crate) fn tree_to_md<'a>(&'a self, tree: *mut EnzymeTypeTree, ctx: &'a Context) -> Option<&'a Value> { + unsafe { (self.EnzymeTypeTreeToMD)(tree, ctx) } + } + pub(crate) fn tree_to_string_free(&self, ch: *const c_char) { unsafe { (self.EnzymeTypeTreeToStringFree)(ch) } } @@ -381,6 +392,7 @@ pub(crate) mod Enzyme_AD { EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, EnzymeSetCLBool: EnzymeSetCLBoolFn, EnzymeSetCLString: EnzymeSetCLStringFn, + EnzymeTypeTreeToMD: EnzymeTypeTreeToMDFn, ); load_ptrs_by_symbols_mut_void!( @@ -422,6 +434,7 @@ pub(crate) mod Enzyme_AD { looseTypeAnalysis, EnzymeSetCLBool, EnzymeSetCLString, + EnzymeTypeTreeToMD, registerEnzymeAndPassPipeline, lib, }) diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index b6f686b07d6bc..9a0be0191b388 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,7 +1,8 @@ use std::ffi::{CString, c_char, c_uint}; use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; - +use crate::llvm::LLVMRustSetEnzymeTypeMetadata; +use crate::llvm::LLVMRustIsPtrLoad; use crate::attributes; use crate::llvm::{self, EnzymeWrapper, TypeTree, Value}; @@ -76,7 +77,7 @@ pub(crate) fn add_tt<'ll>( let inputs = tt.args; let ret_tt: RustTypeTree = tt.ret; - dbg!("getting DataLayout"); + //dbg!("getting DataLayout"); let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; let llvm_data_layout = std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) @@ -84,10 +85,14 @@ pub(crate) fn add_tt<'ll>( let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); - dbg!("going to iter over inputs"); + //dbg!("going to iter over inputs"); for (i, input) in inputs.iter().enumerate() { unsafe { + if *input == rustc_ast::expand::typetree::TypeTree::new() { + dbg!("skipping empty input tt"); + continue; + } let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); let enzyme_wrapper = EnzymeWrapper::get_instance(); let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); @@ -100,19 +105,23 @@ pub(crate) fn add_tt<'ll>( c_str.as_ptr(), c_str.to_bytes().len() as c_uint, ); - dbg!("adding attribute for argument {}", i); - dbg!("attribute string: {:?}", c_str); - dbg!(&fn_def); + //dbg!("adding attribute for argument {}", i); + //dbg!("attribute string: {:?}", c_str); + //dbg!(&fn_def); if llvm::LLVMRustIsIntrinsicCall(fn_def) { + //dbg!("intrinsic"); attributes::apply_to_callsite(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + } else if LLVMRustIsPtrLoad(fn_def) { + //dbg!("skipping input args for instr"); } else { + //dbg!("fn call"); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } - dbg!("finished to iter over inputs"); + //dbg!("finished to iter over inputs"); unsafe { if ret_tt == rustc_ast::expand::typetree::TypeTree::new() { @@ -135,8 +144,15 @@ pub(crate) fn add_tt<'ll>( dbg!(&fn_def); if llvm::LLVMRustIsIntrinsicCall(fn_def) { + dbg!("intrinsic call"); attributes::apply_to_callsite(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + } else if LLVMRustIsPtrLoad(fn_def) { + dbg!("hiii"); + //dbg!(&enzyme_tt); + let val = enzyme_wrapper.tree_to_md(enzyme_tt.inner, llcx); + LLVMRustSetEnzymeTypeMetadata(fn_def, val.unwrap()); } else { + dbg!("fn call"); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 6d41c31b36a5e..c45e31200f4a0 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -2259,8 +2259,8 @@ pub fn store_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( assert!(cast.prefix[1..].iter().all(|p| p.is_none())); assert_eq!(cast.rest.unit.size, cast.rest.total); assert!(cast.prefix[0].is_some()); - let first = bx.extract_value(value, 0); - let second = bx.extract_value(value, 1); + let first = bx.extract_value(value, 0, None); + let second = bx.extract_value(value, 1, None); bx.store(first, ptr, align); let second_ptr = bx.inbounds_ptradd(ptr, bx.const_usize(offset_from_start.bytes())); bx.store(second, second_ptr, align.restrict_for_offset(offset_from_start)); diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index c0c71edd4d905..ef75293194f15 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -333,8 +333,8 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { debug!("Operand::from_immediate_or_packed_pair: unpacking {:?} @ {:?}", llval, layout); // Deconstruct the immediate aggregate. - let a_llval = bx.extract_value(llval, 0); - let b_llval = bx.extract_value(llval, 1); + let a_llval = bx.extract_value(llval, 0, None); + let b_llval = bx.extract_value(llval, 1, None); OperandValue::Pair(a_llval, b_llval) } else { OperandValue::Immediate(llval) diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 7e88ab451f404..8a7d6df5c390d 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -565,7 +565,7 @@ pub trait BuilderMethods<'a, 'tcx>: fn va_arg(&mut self, list: Self::Value, ty: Self::Type) -> Self::Value; fn extract_element(&mut self, vec: Self::Value, idx: Self::Value) -> Self::Value; fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value; - fn extract_value(&mut self, agg_val: Self::Value, idx: u64) -> Self::Value; + fn extract_value(&mut self, agg_val: Self::Value, idx: u64, tt: Option) -> Self::Value; fn insert_value(&mut self, agg_val: Self::Value, elt: Self::Value, idx: u64) -> Self::Value; fn set_personality_fn(&mut self, personality: Self::Function); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 67a0cd9acb9c6..40f7cb114c946 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -759,6 +759,26 @@ extern "C" bool LLVMRustInlineAsmVerify(LLVMTypeRef Ty, char *Constraints, unwrap(Ty), StringRef(Constraints, ConstraintsLen))); } + +extern "C" bool LLVMRustIsPtrLoad(LLVMValueRef V) { + auto *LI = llvm::dyn_cast(llvm::unwrap(V)); + return LI; + //return LI && LI->getType()->isPointerTy(); +} + +extern "C" void LLVMRustSetEnzymeTypeMetadata(LLVMValueRef V, LLVMValueRef MDV) { + auto *I = llvm::dyn_cast(llvm::unwrap(V)); + assert(I && "expected instruction for !enzyme_type metadata"); + + auto *MAV = llvm::dyn_cast(llvm::unwrap(MDV)); + assert(MAV && "expected MetadataAsValue"); + + auto *MD = llvm::dyn_cast(MAV->getMetadata()); + assert(MD && "expected MDNode"); + + I->setMetadata("enzyme_type", MD); +} + template DIT *unwrapDIPtr(LLVMMetadataRef Ref) { return (DIT *)(Ref ? unwrap(Ref) : nullptr); } From 90d4d3f8804f2bb1700ccd4433251666cf723378 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 16 Jun 2026 22:44:59 -0400 Subject: [PATCH 6/9] no runtime error iff we update all PtrIntSame in the enzyme submodule to true. Still to weak to handle iterators --- compiler/rustc_codegen_llvm/src/builder.rs | 27 ++++++++++++--------- compiler/rustc_codegen_llvm/src/typetree.rs | 23 +++++++++++++++--- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 1cfbea4964e5c..090b428a89b43 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -636,7 +636,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { load }; if let Some(tt) = tt { - crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, tt); + //crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, tt); } load } @@ -668,10 +668,6 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { #[instrument(level = "trace", skip(self))] fn load_operand(&mut self, place: PlaceRef<'tcx, &'ll Value>) -> OperandRef<'tcx, &'ll Value> { - //dbg!("load_operand"); - //use rustc_middle::ty::print::with_no_trimmed_paths; - //eprintln!("place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); - if place.layout.is_unsized() { let tail = self.tcx.struct_tail_for_codegen(place.layout.ty, self.typing_env()); if matches!(tail.kind(), ty::Foreign(..)) { @@ -743,14 +739,22 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let llval = const_llval.unwrap_or_else(|| { let load = self.load(llty, place.val.llval, place.val.align); + //let layout = place.layout.ty_and_layout_pointee_info_at(self.cx(), Size::ZERO).unwrap(); let ty = place.layout.ty; let tt = rustc_middle::ty::typetree_from_ty(self.tcx, ty); - dbg!(&tt); - let fnc_tree = FncTree { - args: vec![TypeTree::new(), TypeTree::new()], - ret: tt, - }; - crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, fnc_tree); + if tt != rustc_ast::expand::typetree::TypeTree::new() { + use rustc_middle::ty::print::with_no_trimmed_paths; + dbg!("add_tt start!"); + dbg!(&load); + dbg!(&tt); + eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); + let fnc_tree = FncTree { + args: vec![TypeTree::new(), TypeTree::new()], + ret: tt, + }; + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, fnc_tree); + dbg!("add_tt done!"); + } //eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); // 25 general load of place = PlaceRef { // 24 val: PlaceValue { @@ -798,7 +802,6 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { // 18 }, // 19 } - dbg!(&load); if let abi::BackendRepr::Scalar(scalar) = place.layout.backend_repr { scalar_load_metadata(self, load, scalar, place.layout, Size::ZERO); self.to_immediate_scalar(load, scalar) diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 9a0be0191b388..c3f96a3c9e889 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -90,7 +90,7 @@ pub(crate) fn add_tt<'ll>( for (i, input) in inputs.iter().enumerate() { unsafe { if *input == rustc_ast::expand::typetree::TypeTree::new() { - dbg!("skipping empty input tt"); + //dbg!("skipping empty input tt"); continue; } let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); @@ -125,12 +125,30 @@ pub(crate) fn add_tt<'ll>( unsafe { if ret_tt == rustc_ast::expand::typetree::TypeTree::new() { - dbg!("skipping empty return tt"); + //dbg!("skipping empty return tt"); return; } let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); let enzyme_wrapper = EnzymeWrapper::get_instance(); let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); + // just printing + //let ptr = wrapper.tree_to_string(self.inner); + let cstr = unsafe { std::ffi::CStr::from_ptr(c_str) }; + use std::io::Write as _; + let mut stderr = std::io::stderr().lock(); + + match cstr.to_str() { + Ok(x) => { + writeln!(stderr, "parsed: {:?}", x).ok(); + } + Err(err) => { + writeln!(stderr, "could not parse: {}", err).ok(); + } + } + + // delete C string pointer + //wrapper.tree_to_string_free(ptr); + // done printing let c_str = std::ffi::CStr::from_ptr(c_str); let ret_attr = llvm::LLVMCreateStringAttribute( @@ -148,7 +166,6 @@ pub(crate) fn add_tt<'ll>( attributes::apply_to_callsite(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); } else if LLVMRustIsPtrLoad(fn_def) { dbg!("hiii"); - //dbg!(&enzyme_tt); let val = enzyme_wrapper.tree_to_md(enzyme_tt.inner, llcx); LLVMRustSetEnzymeTypeMetadata(fn_def, val.unwrap()); } else { From 67e4944ef2220fc6ec149077580a2910cbe63b11 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 17 Jun 2026 17:48:39 -0400 Subject: [PATCH 7/9] Be more conservative about integer values in TT --- compiler/rustc_codegen_llvm/src/typetree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index c3f96a3c9e889..119981712bb6c 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -24,7 +24,7 @@ fn process_typetree_recursive( for rust_type in &rust_typetree.0 { let concrete_type = match rust_type.kind { rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, - rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, + rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Unknown, rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, From 64a76852b126444aec4f82d9d5fbd2129f8ae0e7 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 18 Jun 2026 16:55:55 -0400 Subject: [PATCH 8/9] fix tt creation, handle one path towards extractvalue, big comptime increase --- compiler/rustc_codegen_llvm/src/builder.rs | 18 ++-- compiler/rustc_codegen_llvm/src/intrinsic.rs | 15 ++- .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 3 +- compiler/rustc_codegen_llvm/src/typetree.rs | 9 +- compiler/rustc_codegen_ssa/src/mir/operand.rs | 92 ++++++++++++++++++- .../rustc_codegen_ssa/src/traits/builder.rs | 9 +- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 12 ++- compiler/rustc_middle/src/ty/mod.rs | 54 ++++++++--- 8 files changed, 171 insertions(+), 41 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 090b428a89b43..1b1b04d3ea4a5 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -744,16 +744,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let tt = rustc_middle::ty::typetree_from_ty(self.tcx, ty); if tt != rustc_ast::expand::typetree::TypeTree::new() { use rustc_middle::ty::print::with_no_trimmed_paths; - dbg!("add_tt start!"); - dbg!(&load); - dbg!(&tt); - eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); + //dbg!("add_tt start!"); + //dbg!(&load); + //dbg!(&tt); + //eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); let fnc_tree = FncTree { args: vec![TypeTree::new(), TypeTree::new()], ret: tt, }; crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, fnc_tree); - dbg!("add_tt done!"); + //dbg!("add_tt done!"); } //eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); // 25 general load of place = PlaceRef { @@ -1269,7 +1269,13 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { fn extract_value(&mut self, agg_val: &'ll Value, idx: u64, tt: Option) -> &'ll Value { assert_eq!(idx as c_uint as u64, idx); - unsafe { llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) } + let ev = unsafe { + llvm::LLVMBuildExtractValue(self.llbuilder, agg_val, idx as c_uint, UNNAMED) + }; + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, ev, self.tcx, tt); + } + ev } fn insert_value(&mut self, agg_val: &'ll Value, elt: &'ll Value, idx: u64) -> &'ll Value { diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index a91515952aa75..a2dbdd5acd867 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -6,6 +6,7 @@ use rustc_abi::{ AddressSpace, Align, BackendRepr, CVariadicStatus, Float, HasDataLayout, Integer, NumScalableVectors, Primitive, Size, WrappingRange, }; +use rustc_ast::expand::typetree::{FncTree, TypeTree}; use rustc_codegen_ssa::RetagInfo; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; @@ -2007,8 +2008,18 @@ fn get_args_from_tuple<'ll, 'tcx>( let field = tuple_place.project_field(bx, tuple_index); let llvm_ty = field.layout.llvm_type(bx.cx); let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align); - result.push(bx.extract_value(pair_val, 0, None)); - result.push(bx.extract_value(pair_val, 1, None)); + let extract_ty = field.layout.ty; + let tt = rustc_middle::ty::typetree_from_ty(bx.tcx(), extract_ty); + dbg!("intrinsic pair"); + dbg!(&tt); + let fnc = FncTree { + args: vec![],//TypeTree::new() + ret: tt, + }; + //let tt0 = enzyme_type_from_ty /* TypeTree for extracted element 0 */; + //let tt1 = /* TypeTree for extracted element 1 */; + result.push(bx.extract_value(pair_val, 0, Some(fnc.clone()))); + result.push(bx.extract_value(pair_val, 1, Some(fnc))); tuple_index += 1; } PassMode::Indirect { .. } => { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index fd37c70d9d6f3..3430bdbf37b22 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -65,7 +65,8 @@ unsafe extern "C" { name: *const c_char, NameLen: libc::size_t, ) -> Option<&Value>; - pub(crate) fn LLVMRustIsPtrLoad(v: &Value) -> bool; + //pub(crate) fn LLVMRustIsPtrLoad(v: &Value) -> bool; + pub(crate) fn LLVMRustIsLoadOrExtractValue(v: &Value) -> bool; pub(crate) fn LLVMRustSetEnzymeTypeMetadata( v: &Value, md: &Value, diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 119981712bb6c..ee38820801a99 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -2,7 +2,7 @@ use std::ffi::{CString, c_char, c_uint}; use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; use crate::llvm::LLVMRustSetEnzymeTypeMetadata; -use crate::llvm::LLVMRustIsPtrLoad; +use crate::llvm::LLVMRustIsLoadOrExtractValue; use crate::attributes; use crate::llvm::{self, EnzymeWrapper, TypeTree, Value}; @@ -112,7 +112,8 @@ pub(crate) fn add_tt<'ll>( if llvm::LLVMRustIsIntrinsicCall(fn_def) { //dbg!("intrinsic"); attributes::apply_to_callsite(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); - } else if LLVMRustIsPtrLoad(fn_def) { + //} else if LLVMRustIsPtrLoad(fn_def) { + } else if LLVMRustIsLoadOrExtractValue(fn_def) { //dbg!("skipping input args for instr"); } else { //dbg!("fn call"); @@ -164,8 +165,8 @@ pub(crate) fn add_tt<'ll>( if llvm::LLVMRustIsIntrinsicCall(fn_def) { dbg!("intrinsic call"); attributes::apply_to_callsite(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); - } else if LLVMRustIsPtrLoad(fn_def) { - dbg!("hiii"); + //} else if LLVMRustIsPtrLoad(fn_def) { + } else if LLVMRustIsLoadOrExtractValue(fn_def) { let val = enzyme_wrapper.tree_to_md(enzyme_tt.inner, llcx); LLVMRustSetEnzymeTypeMetadata(fn_def, val.unwrap()); } else { diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index ef75293194f15..03695e73afe62 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -21,6 +21,44 @@ use crate::MemFlags; use crate::common::IntPredicate; use crate::traits::*; + fn scalar_pair_component_field_ty<'a, 'tcx, Bx, V>( + bx: &Bx, + layout: TyAndLayout<'tcx>, + idx: u64, + ) -> Option> + where + Bx: BuilderMethods<'a, 'tcx, Value = V>, + //Bx: BuilderMethods<'_, 'tcx, Value = V>, + { + let BackendRepr::ScalarPair(a, b) = layout.backend_repr else { + return None; + }; + + let (want_offset, want_size) = match idx { + 0 => (Size::ZERO, a.size(bx.cx())), + 1 => { + let off = a.size(bx.cx()).align_to(b.align(bx.cx()).abi); + (off, b.size(bx.cx())) + } + _ => bug!("bad scalar-pair index {idx}"), + }; + + for i in 0..layout.fields.count() { + let field_layout = layout.field(bx.cx(), i); + let field_offset = layout.fields.offset(i); + + if field_layout.is_zst() { + continue; + } + + if field_offset == want_offset && field_layout.size == want_size { + return Some(field_layout.ty); + } + } + + None + } + /// The representation of a Rust value. The enum variant is in fact /// uniquely determined by the value's type, but is kept as a /// safety check. @@ -330,11 +368,47 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { layout: TyAndLayout<'tcx>, ) -> Self { let val = if let BackendRepr::ScalarPair(..) = layout.backend_repr { - debug!("Operand::from_immediate_or_packed_pair: unpacking {:?} @ {:?}", llval, layout); + use rustc_middle::ty::print::with_no_trimmed_paths; + let field0_ty = scalar_pair_component_field_ty(bx, layout, 0); + let field1_ty = scalar_pair_component_field_ty(bx, layout, 1); + let (f1, f2) = if field0_ty.is_none() || field1_ty.is_none() { + (None, None) + } else { + let tt1 = rustc_middle::ty::typetree_from_ty(bx.tcx(), field0_ty.unwrap()); + let tt2 = rustc_middle::ty::typetree_from_ty(bx.tcx(), field1_ty.unwrap()); + with_no_trimmed_paths!({ + eprintln!( + "from_immediate_or_packed_pair layout {:?}", + layout + ); + eprintln!( + "from_immediate_or_packed_pair layout0 {:?}", + field0_ty + ); + eprintln!( + "from_immediate_or_packed_pair layout1 {:?}", + field1_ty + ); + }); + dbg!(&tt1); + dbg!(&tt2); + //dbg!(&tt); + use rustc_ast::expand::typetree::FncTree; + let fnc1 = FncTree { + args: vec![],//TypeTree::new() + ret: tt1, + }; + let fnc2 = FncTree { + args: vec![],//TypeTree::new() + ret: tt2, + }; + (Some(fnc1), Some(fnc2)) + }; // Deconstruct the immediate aggregate. - let a_llval = bx.extract_value(llval, 0, None); - let b_llval = bx.extract_value(llval, 1, None); + let a_llval = bx.extract_value(llval, 0, f1); + // TODO: above/below should take the actual subtree, not full tree like now + let b_llval = bx.extract_value(llval, 1, f2); OperandValue::Pair(a_llval, b_llval) } else { OperandValue::Immediate(llval) @@ -342,12 +416,21 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { OperandRef { val, layout, move_annotation: None } } + pub(crate) fn extract_field>( &self, fx: &mut FunctionCx<'a, 'tcx, Bx>, bx: &mut Bx, i: usize, ) -> Self { + use rustc_middle::ty::print::with_no_trimmed_paths; + + with_no_trimmed_paths!({ + eprintln!( + "from_immediate_or_packed_pair layout {:?}", + self.layout + ); + }); let field = self.layout.field(bx.cx(), i); let offset = self.layout.fields.offset(i); @@ -925,7 +1008,8 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue { dest: PlaceRef<'tcx, V>, flags: MemFlags, ) { - debug!("OperandRef::store: operand={:?}, dest={:?}", self, dest); + //dbg!("store"); + //debug!("OperandRef::store: operand={:?}, dest={:?}", self, dest); match self { OperandValue::ZeroSized => { // Avoid generating stores of zero-sized values, because the only way to have a diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 8a7d6df5c390d..afc791c396bea 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -514,20 +514,19 @@ pub trait BuilderMethods<'a, 'tcx>: temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); - dbg!("typed copy, branch3"); + //dbg!("typed copy, branch3"); //let ty = self.backend_type(layout); let ty = layout.ty; - dbg!(&ty); + //dbg!(&ty); let tt: TypeTree = typetree_from_ty(self.tcx(), ty); let tt = tt.add_indirection(); - dbg!("got tt"); let fnc_tree = FncTree { args: vec![tt.clone(), tt], ret: TypeTree::new(), }; - dbg!(&fnc_tree); + //dbg!(&fnc_tree); self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, Some(fnc_tree)); - dbg!("done"); + //dbg!("done"); } } diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 40f7cb114c946..7211cba82786a 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -759,11 +759,13 @@ extern "C" bool LLVMRustInlineAsmVerify(LLVMTypeRef Ty, char *Constraints, unwrap(Ty), StringRef(Constraints, ConstraintsLen))); } - -extern "C" bool LLVMRustIsPtrLoad(LLVMValueRef V) { - auto *LI = llvm::dyn_cast(llvm::unwrap(V)); - return LI; - //return LI && LI->getType()->isPointerTy(); +//extern "C" bool LLVMRustIsPtrLoad(LLVMValueRef V) { +// auto *LI = llvm::dyn_cast(llvm::unwrap(V)); +// return LI; +//} +extern "C" bool LLVMRustIsLoadOrExtractValue(LLVMValueRef V) { + auto *I = llvm::dyn_cast(llvm::unwrap(V)); + return I && (llvm::isa(I) || llvm::isa(I)); } extern "C" void LLVMRustSetEnzymeTypeMetadata(LLVMValueRef V, LLVMValueRef MDV) { diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index deeb26c7e5120..13660f8997860 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2431,6 +2431,31 @@ fn typetree_from_ty_impl_inner<'tcx>( visited: &mut Vec>, is_reference_target: bool, ) -> TypeTree { + + if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { + let Some(inner_ty) = ty.builtin_deref(true) else { + bug!("expected reference or pointer type to have a dereferenceable inner type: {}", ty); + return TypeTree::new(); + }; + + let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); + return TypeTree(vec![Type { + offset: -1, + size: tcx.data_layout.pointer_size().bytes_usize(), + kind: Kind::Pointer, + child, + }]); + } + if let Some(inner) = is_ptr_like(tcx, ty) { + let child = typetree_from_ty_impl_inner(tcx, inner, depth + 1, visited, true); + return TypeTree(vec![Type { + offset: -1, + size: tcx.data_layout.pointer_size().bytes_usize(), + kind: Kind::Pointer, + child, + }]); + } + if ty.is_scalar() { let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) @@ -2452,20 +2477,6 @@ fn typetree_from_ty_impl_inner<'tcx>( return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); } - if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { - let Some(inner_ty) = ty.builtin_deref(true) else { - return TypeTree::new(); - }; - - let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); - return TypeTree(vec![Type { - offset: -1, - size: tcx.data_layout.pointer_size().bytes_usize(), - kind: Kind::Pointer, - child, - }]); - } - if ty.is_array() { if let ty::Array(element_ty, len_const) = ty.kind() { let len = len_const.try_to_target_usize(tcx).unwrap_or(0); @@ -2573,3 +2584,18 @@ fn typetree_from_ty_impl_inner<'tcx>( TypeTree::new() } +use rustc_span::sym; +fn is_ptr_like<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + ) -> Option> { + if let ty::Adt(def, args) = ty.kind() { + if tcx.is_diagnostic_item(sym::NonNull, def.did()) { + let inner_ty = args.type_at(0); + return Some(inner_ty); + } + } + None +} + + From f5f9741c554e0c580f3fa2267e4a33815e8deb5b Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 19 Jun 2026 03:05:55 -0400 Subject: [PATCH 9/9] fixes the map/iter case in debug mode, with no enzyme changes --- compiler/rustc_codegen_llvm/src/abi.rs | 12 + compiler/rustc_codegen_llvm/src/builder.rs | 6 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 6 +- compiler/rustc_codegen_llvm/src/typetree.rs | 2 +- compiler/rustc_codegen_ssa/src/mir/block.rs | 8 + .../rustc_codegen_ssa/src/mir/intrinsic.rs | 3 +- compiler/rustc_codegen_ssa/src/mir/operand.rs | 150 ++++++--- .../rustc_codegen_ssa/src/traits/builder.rs | 3 +- compiler/rustc_middle/src/ty/mod.rs | 260 +-------------- compiler/rustc_middle/src/ty/type_tree.rs | 302 ++++++++++++++++++ 10 files changed, 451 insertions(+), 301 deletions(-) create mode 100644 compiler/rustc_middle/src/ty/type_tree.rs diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index c3bf566b35880..53104fe3a0ecc 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -34,6 +34,8 @@ trait ArgAttributesExt { callsite: &Value, ); } + use crate::abi::ty::print::with_no_trimmed_paths; +use rustc_codegen_ssa::mir::operand::scalar_pair_component_field_ty; const ABI_AFFECTING_ATTRIBUTES: [(ArgAttribute, llvm::AttributeKind); 1] = [(ArgAttribute::InReg, llvm::AttributeKind::InReg)]; @@ -262,6 +264,16 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { let llscratch = bx.alloca(scratch_size, scratch_align); bx.lifetime_start(llscratch, scratch_size); // ...store the value... + + let f0 = scalar_pair_component_field_ty(bx, dst.layout, 0); + let f1 = scalar_pair_component_field_ty(bx, dst.layout, 1); + + if f1.is_some() && f0.is_some() { + with_no_trimmed_paths!({ + eprintln!("Cast of extractvalue 0 field = {:?}", f0.map(|f| f0.unwrap())); + eprintln!("Cast of extractvalue 1 field = {:?}", f1.map(|f| f1.unwrap())); + }); + } rustc_codegen_ssa::mir::store_cast(bx, cast, val, llscratch, scratch_align); // ... and then memcpy it to the intended destination. bx.memcpy( diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 1b1b04d3ea4a5..9291cdb12ca51 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -38,6 +38,7 @@ use crate::llvm::{ ToLlvmBool, Type, Value, }; use crate::type_of::LayoutLlvmExt; +use rustc_middle::ty::type_tree::typetree_from_ty; #[must_use] pub(crate) struct GenericBuilder<'a, 'll, CX: Borrow>> { @@ -741,7 +742,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { let load = self.load(llty, place.val.llval, place.val.align); //let layout = place.layout.ty_and_layout_pointee_info_at(self.cx(), Size::ZERO).unwrap(); let ty = place.layout.ty; - let tt = rustc_middle::ty::typetree_from_ty(self.tcx, ty); + let tt = typetree_from_ty(self.tcx, ty); if tt != rustc_ast::expand::typetree::TypeTree::new() { use rustc_middle::ty::print::with_no_trimmed_paths; //dbg!("add_tt start!"); @@ -752,7 +753,8 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { args: vec![TypeTree::new(), TypeTree::new()], ret: tt, }; - crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, fnc_tree); + // TODO: re-enable? + //crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, load, self.tcx, fnc_tree); //dbg!("add_tt done!"); } //eprintln!("general load of place = {}", with_no_trimmed_paths!(format!("{place:#?}"))); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index a2dbdd5acd867..3952f2a143e1b 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -47,6 +47,8 @@ use crate::errors::{ use crate::llvm::{self, Type, Value}; use crate::type_of::LayoutLlvmExt; use crate::va_arg::emit_va_arg; +use rustc_middle::ty::type_tree::typetree_from_ty; +use rustc_middle::ty::type_tree::fnc_typetrees; fn call_simple_intrinsic<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, @@ -1880,7 +1882,7 @@ fn codegen_autodiff<'ll, 'tcx>( &mut diff_attrs.input_activity, ); - let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty); + let fnc_tree = fnc_typetrees(tcx, source_fn_ptr_ty); // Build body generate_enzyme_call( @@ -2009,7 +2011,7 @@ fn get_args_from_tuple<'ll, 'tcx>( let llvm_ty = field.layout.llvm_type(bx.cx); let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align); let extract_ty = field.layout.ty; - let tt = rustc_middle::ty::typetree_from_ty(bx.tcx(), extract_ty); + let tt = typetree_from_ty(bx.tcx(), extract_ty); dbg!("intrinsic pair"); dbg!(&tt); let fnc = FncTree { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index ee38820801a99..a89d5d18c3ba9 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -24,7 +24,7 @@ fn process_typetree_recursive( for rust_type in &rust_typetree.0 { let concrete_type = match rust_type.kind { rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, - rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Unknown, + rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index c45e31200f4a0..c6a14e2e1ae08 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -2256,6 +2256,14 @@ pub fn store_cast<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( align: Align, ) { if let Some(offset_from_start) = cast.rest_offset { + dbg!("here we are!!"); + //let f0 = field_for_scalar_pair_component(bx, layout, 0); + //let f1 = field_for_scalar_pair_component(bx, layout, 1); + + //with_no_trimmed_paths!({ + // eprintln!("extractvalue 0 field = {:?}", f0.map(|f| f.ty)); + // eprintln!("extractvalue 1 field = {:?}", f1.map(|f| f.ty)); + //}); assert!(cast.prefix[1..].iter().all(|p| p.is_none())); assert_eq!(cast.rest.unit.size, cast.rest.total); assert!(cast.prefix[0].is_some()); diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index 21107898e2962..1d7cb0694d06d 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -15,6 +15,7 @@ use crate::errors::InvalidMonomorphization; use crate::mir::operand::OperandRefBuilder; use crate::traits::*; use crate::{MemFlags, meth, size_of_val}; +use rustc_middle::ty::type_tree::typetree_from_ty; fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( bx: &mut Bx, @@ -31,7 +32,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( let size = bx.mul(bx.const_usize(size.bytes()), count); let flags = if volatile { MemFlags::VOLATILE } else { MemFlags::empty() }; let tcx = bx.tcx(); - let tt = rustc_middle::ty::typetree_from_ty(tcx, ty); + let tt = typetree_from_ty(tcx, ty); let fnc_tree = FncTree { args: vec![tt.clone()], ret: tt, diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index 03695e73afe62..4a563ed3ec528 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -21,14 +21,61 @@ use crate::MemFlags; use crate::common::IntPredicate; use crate::traits::*; - fn scalar_pair_component_field_ty<'a, 'tcx, Bx, V>( +use rustc_ast::expand::typetree::TypeTree; +use rustc_middle::ty::type_tree::typetree_from_ty; +use crate::TyCtxt; +use rustc_span::sym; +use rustc_ast::expand::typetree::FncTree; + +fn option_ptr_like_scalar_pair_tts<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, +) -> Option { + let ty::Adt(def, args) = ty.kind() else { + return None; + }; + + if !tcx.is_lang_item(def.did(), LangItem::Option) { + return None; + } + + let inner = args.type_at(0); + if !(inner.is_ref() || inner.is_box() || nonnull_inner_ty(tcx, inner).is_some()) { + return None; + } + + let tt = typetree_from_ty(tcx, inner); + //let some_layout = layout.for_variant(bx.cx(), VariantIdx::from_u32(1)); + //let payload_layout = some_layout.field(bx.cx(), 0); + // this will be a slice + //let payload_ty = payload_layout.ty; + //let tt = rustc_middle::ty::typetree_from_ty(bx.tcx(), field0_ty.unwrap()); + if tt == TypeTree::new() { + return None; + } + let fnc_tree = FncTree { + args: vec![], + ret: tt, + }; + Some(fnc_tree) +} + +fn nonnull_inner_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Option> { + if let ty::Adt(def, args) = ty.kind() + && tcx.is_diagnostic_item(sym::NonNull, def.did()) + { + return Some(args.type_at(0)); + } + + None +} + pub fn scalar_pair_component_field_ty<'a, 'tcx, Bx, V>( bx: &Bx, layout: TyAndLayout<'tcx>, idx: u64, ) -> Option> where Bx: BuilderMethods<'a, 'tcx, Value = V>, - //Bx: BuilderMethods<'_, 'tcx, Value = V>, { let BackendRepr::ScalarPair(a, b) = layout.backend_repr else { return None; @@ -369,41 +416,72 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { ) -> Self { let val = if let BackendRepr::ScalarPair(..) = layout.backend_repr { use rustc_middle::ty::print::with_no_trimmed_paths; - let field0_ty = scalar_pair_component_field_ty(bx, layout, 0); - let field1_ty = scalar_pair_component_field_ty(bx, layout, 1); - let (f1, f2) = if field0_ty.is_none() || field1_ty.is_none() { - (None, None) + let f1 = option_ptr_like_scalar_pair_tts(bx.tcx(), layout.ty); + let f2 = if f1.is_none() { + None } else { - let tt1 = rustc_middle::ty::typetree_from_ty(bx.tcx(), field0_ty.unwrap()); - let tt2 = rustc_middle::ty::typetree_from_ty(bx.tcx(), field1_ty.unwrap()); - with_no_trimmed_paths!({ - eprintln!( - "from_immediate_or_packed_pair layout {:?}", - layout - ); - eprintln!( - "from_immediate_or_packed_pair layout0 {:?}", - field0_ty - ); - eprintln!( - "from_immediate_or_packed_pair layout1 {:?}", - field1_ty - ); - }); - dbg!(&tt1); - dbg!(&tt2); - //dbg!(&tt); - use rustc_ast::expand::typetree::FncTree; - let fnc1 = FncTree { - args: vec![],//TypeTree::new() - ret: tt1, - }; - let fnc2 = FncTree { - args: vec![],//TypeTree::new() - ret: tt2, - }; - (Some(fnc1), Some(fnc2)) + Some(FncTree { args: vec![], ret: TypeTree::int(8) }) }; + //{ + // dbg!("new option ptr-like scalar pair"); + // tt + //} else { + // //let field0_ty = scalar_pair_component_field_ty(bx, layout, 0); + // //let field1_ty = scalar_pair_component_field_ty(bx, layout, 1); + + // //if field0_ty.is_none() || field1_ty.is_none() { + // //dbg!("from_immediate_or_packed_pair: missing field for layout"); + // //with_no_trimmed_paths!({ + // // eprintln!( + // // "from_immediate_or_packed_pair layout {:?}", + // // layout + // // ); + // //}); + // None + //}; + + //let field0_ty = scalar_pair_component_field_ty(bx, layout, 0); + //let field1_ty = scalar_pair_component_field_ty(bx, layout, 1); + //let (f1, f2) = if field0_ty.is_none() || field1_ty.is_none() { + // dbg!("from_immediate_or_packed_pair: missing field for layout"); + // with_no_trimmed_paths!({ + // eprintln!( + // "from_immediate_or_packed_pair layout {:?}", + // layout + // ); + // }); + // (None, None) + //} else { + // let tt1 = rustc_middle::ty::typetree_from_ty(bx.tcx(), field0_ty.unwrap()); + // let tt2 = rustc_middle::ty::typetree_from_ty(bx.tcx(), field1_ty.unwrap()); + // with_no_trimmed_paths!({ + // eprintln!( + // "from_immediate_or_packed_pair layout {:?}", + // layout + // ); + // eprintln!( + // "from_immediate_or_packed_pair layout0 {:?}", + // field0_ty + // ); + // eprintln!( + // "from_immediate_or_packed_pair layout1 {:?}", + // field1_ty + // ); + // }); + // dbg!(&tt1); + // dbg!(&tt2); + // //dbg!(&tt); + // use rustc_ast::expand::typetree::FncTree; + // let fnc1 = FncTree { + // args: vec![],//TypeTree::new() + // ret: tt1, + // }; + // let fnc2 = FncTree { + // args: vec![],//TypeTree::new() + // ret: tt2, + // }; + // (Some(fnc1), Some(fnc2)) + //}; // Deconstruct the immediate aggregate. let a_llval = bx.extract_value(llval, 0, f1); @@ -427,7 +505,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { with_no_trimmed_paths!({ eprintln!( - "from_immediate_or_packed_pair layout {:?}", + "from_immediate_or_packed_pair single extract_field {:?}", self.layout ); }); diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index afc791c396bea..e933d3826692e 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -6,7 +6,8 @@ use rustc_ast::expand::typetree::{FncTree, TypeTree}; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::mir; use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout}; -use rustc_middle::ty::{typetree_from_ty, AtomicOrdering, Instance, Ty}; +use rustc_middle::ty::type_tree::typetree_from_ty; +use rustc_middle::ty::{AtomicOrdering, Instance, Ty}; use rustc_session::config::OptLevel; use rustc_span::Span; use rustc_target::callconv::FnAbi; diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 13660f8997860..5a71175835827 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -70,6 +70,8 @@ use rustc_type_ir::{InferCtxtLike, Interner}; use tracing::{debug, instrument, trace}; pub use vtable::*; +pub mod type_tree; + pub use self::closure::{ BorrowKind, CAPTURE_STRUCT_LOCAL, CaptureInfo, CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap, MinCaptureList, RootVariableMinCaptureList, UpvarCapture, UpvarId, @@ -2341,261 +2343,3 @@ pub struct DestructuredAdtConst<'tcx> { pub variant: VariantIdx, pub fields: &'tcx [ty::Const<'tcx>], } - -/// Generate TypeTree information for autodiff. -/// This function creates TypeTree metadata that describes the memory layout -/// of function parameters and return types for Enzyme autodiff. -pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { - // Check if TypeTrees are disabled via NoTT flag - if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { - return FncTree { args: vec![], ret: TypeTree::new() }; - } - - // Check if this is actually a function type - if !fn_ty.is_fn() { - dbg!("not a function type: {}", fn_ty); - return FncTree { args: vec![], ret: TypeTree::new() }; - } - - // Get the function signature - let fn_sig = fn_ty.fn_sig(tcx); - let sig = tcx.instantiate_bound_regions_with_erased(fn_sig); - - // Create TypeTrees for each input parameter - let mut args = vec![]; - for ty in sig.inputs().iter() { - let type_tree = typetree_from_ty(tcx, *ty); - args.push(type_tree); - } - - // Create TypeTree for return type - let ret = typetree_from_ty(tcx, sig.output()); - - FncTree { args, ret } -} - -/// Generate TypeTree for a specific type. -/// This function analyzes a Rust type and creates appropriate TypeTree metadata. -pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { - if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { - return TypeTree::new(); - } - if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { - return TypeTree::new(); - } - let mut visited = Vec::new(); - typetree_from_ty_inner(tcx, ty, 0, &mut visited) -} - -/// Maximum recursion depth for TypeTree generation to prevent stack overflow -/// from pathological deeply nested types. Combined with cycle detection. -const MAX_TYPETREE_DEPTH: usize = 6; - -/// Internal recursive function for TypeTree generation with cycle detection and depth limiting. -fn typetree_from_ty_inner<'tcx>( - tcx: TyCtxt<'tcx>, - ty: Ty<'tcx>, - depth: usize, - visited: &mut Vec>, -) -> TypeTree { - if depth >= MAX_TYPETREE_DEPTH { - trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty); - return TypeTree::new(); - } - - if visited.contains(&ty) { - return TypeTree::new(); - } - - visited.push(ty); - let result = typetree_from_ty_impl(tcx, ty, depth, visited); - visited.pop(); - result -} - -/// Implementation of TypeTree generation logic. -fn typetree_from_ty_impl<'tcx>( - tcx: TyCtxt<'tcx>, - ty: Ty<'tcx>, - depth: usize, - visited: &mut Vec>, -) -> TypeTree { - typetree_from_ty_impl_inner(tcx, ty, depth, visited, false) -} - -/// Internal implementation with context about whether this is for a reference target. -fn typetree_from_ty_impl_inner<'tcx>( - tcx: TyCtxt<'tcx>, - ty: Ty<'tcx>, - depth: usize, - visited: &mut Vec>, - is_reference_target: bool, -) -> TypeTree { - - if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { - let Some(inner_ty) = ty.builtin_deref(true) else { - bug!("expected reference or pointer type to have a dereferenceable inner type: {}", ty); - return TypeTree::new(); - }; - - let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); - return TypeTree(vec![Type { - offset: -1, - size: tcx.data_layout.pointer_size().bytes_usize(), - kind: Kind::Pointer, - child, - }]); - } - if let Some(inner) = is_ptr_like(tcx, ty) { - let child = typetree_from_ty_impl_inner(tcx, inner, depth + 1, visited, true); - return TypeTree(vec![Type { - offset: -1, - size: tcx.data_layout.pointer_size().bytes_usize(), - kind: Kind::Pointer, - child, - }]); - } - - if ty.is_scalar() { - let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { - (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) - } else if ty.is_floating_point() { - match ty { - x if x == tcx.types.f16 => (Kind::Half, 2), - x if x == tcx.types.f32 => (Kind::Float, 4), - x if x == tcx.types.f64 => (Kind::Double, 8), - x if x == tcx.types.f128 => (Kind::F128, 16), - _ => (Kind::Integer, 0), - } - } else { - (Kind::Integer, 0) - }; - - // Use offset 0 for scalars that are direct targets of references (like &f64) - // Use offset -1 for scalars used directly (like function return types) - let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; - return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); - } - - if ty.is_array() { - if let ty::Array(element_ty, len_const) = ty.kind() { - let len = len_const.try_to_target_usize(tcx).unwrap_or(0); - if len == 0 { - return TypeTree::new(); - } - let element_tree = - typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); - let mut types = Vec::new(); - for elem_type in &element_tree.0 { - types.push(Type { - offset: -1, - size: elem_type.size, - kind: elem_type.kind, - child: elem_type.child.clone(), - }); - } - - return TypeTree(types); - } - } - - if ty.is_slice() { - if let ty::Slice(element_ty) = ty.kind() { - let element_tree = - typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); - return element_tree; - } - } - - if let ty::Tuple(tuple_types) = ty.kind() { - if tuple_types.is_empty() { - return TypeTree::new(); - } - - let mut types = Vec::new(); - let mut current_offset = 0; - - for tuple_ty in tuple_types.iter() { - let element_tree = - typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false); - - let element_layout = tcx - .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty)) - .ok() - .map(|layout| layout.size.bytes_usize()) - .unwrap_or(0); - - for elem_type in &element_tree.0 { - types.push(Type { - offset: if elem_type.offset == -1 { - current_offset as isize - } else { - current_offset as isize + elem_type.offset - }, - size: elem_type.size, - kind: elem_type.kind, - child: elem_type.child.clone(), - }); - } - - current_offset += element_layout; - } - - return TypeTree(types); - } - - if let ty::Adt(adt_def, args) = ty.kind() { - if adt_def.is_struct() { - let struct_layout = - tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty)); - if let Ok(layout) = struct_layout { - let mut types = Vec::new(); - - for (field_idx, field_def) in adt_def.all_fields().enumerate() { - let field_ty = field_def.ty(tcx, args); - let field_tree = typetree_from_ty_impl_inner( - tcx, - field_ty.skip_norm_wip(), - depth + 1, - visited, - false, - ); - - let field_offset = layout.fields.offset(field_idx).bytes_usize(); - - for elem_type in &field_tree.0 { - types.push(Type { - offset: if elem_type.offset == -1 { - field_offset as isize - } else { - field_offset as isize + elem_type.offset - }, - size: elem_type.size, - kind: elem_type.kind, - child: elem_type.child.clone(), - }); - } - } - - return TypeTree(types); - } - } - } - - TypeTree::new() -} -use rustc_span::sym; -fn is_ptr_like<'tcx>( - tcx: TyCtxt<'tcx>, - ty: Ty<'tcx>, - ) -> Option> { - if let ty::Adt(def, args) = ty.kind() { - if tcx.is_diagnostic_item(sym::NonNull, def.did()) { - let inner_ty = args.type_at(0); - return Some(inner_ty); - } - } - None -} - - diff --git a/compiler/rustc_middle/src/ty/type_tree.rs b/compiler/rustc_middle/src/ty/type_tree.rs new file mode 100644 index 0000000000000..89cd1872e7f44 --- /dev/null +++ b/compiler/rustc_middle/src/ty/type_tree.rs @@ -0,0 +1,302 @@ + +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_hir::LangItem; +use tracing::trace; +use rustc_ast::expand::typetree::*; + + +/// Generate TypeTree information for autodiff. +/// This function creates TypeTree metadata that describes the memory layout +/// of function parameters and return types for Enzyme autodiff. +pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { + // Check if TypeTrees are disabled via NoTT flag + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Check if this is actually a function type + if !fn_ty.is_fn() { + dbg!("not a function type: {}", fn_ty); + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Get the function signature + let fn_sig = fn_ty.fn_sig(tcx); + let sig = tcx.instantiate_bound_regions_with_erased(fn_sig); + + // Create TypeTrees for each input parameter + let mut args = vec![]; + for ty in sig.inputs().iter() { + let type_tree = typetree_from_ty(tcx, *ty); + args.push(type_tree); + } + + // Create TypeTree for return type + let ret = typetree_from_ty(tcx, sig.output()); + + FncTree { args, ret } +} + +/// Generate TypeTree for a specific type. +/// This function analyzes a Rust type and creates appropriate TypeTree metadata. +pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return TypeTree::new(); + } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return TypeTree::new(); + } + let mut visited = Vec::new(); + typetree_from_ty_inner(tcx, ty, 0, &mut visited) +} + +/// Maximum recursion depth for TypeTree generation to prevent stack overflow +/// from pathological deeply nested types. Combined with cycle detection. +const MAX_TYPETREE_DEPTH: usize = 6; + +/// Internal recursive function for TypeTree generation with cycle detection and depth limiting. +fn typetree_from_ty_inner<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, +) -> TypeTree { + if depth >= MAX_TYPETREE_DEPTH { + trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty); + return TypeTree::new(); + } + + if visited.contains(&ty) { + return TypeTree::new(); + } + + visited.push(ty); + let result = typetree_from_ty_impl(tcx, ty, depth, visited); + visited.pop(); + result +} + +/// Implementation of TypeTree generation logic. +fn typetree_from_ty_impl<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, +) -> TypeTree { + typetree_from_ty_impl_inner(tcx, ty, depth, visited, false) +} + +/// Internal implementation with context about whether this is for a reference target. +fn typetree_from_ty_impl_inner<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, + is_reference_target: bool, +) -> TypeTree { + + if ty.is_slice() { + if let ty::Slice(element_ty) = ty.kind() { + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); + return element_tree; + } + } + + if let Some(inner) = unwrap_option_nonnull_ptr_like(tcx, ty) { + dbg!("option of niche!"); + dbg!(&ty); + // Option> and similar types use a niche to encode the Option/None variant, + // so we can ignore it and return the inner tt directly. + return typetree_from_ty_impl_inner(tcx, inner, depth + 1, visited, true); + }; + + + if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { + let Some(inner_ty) = ty.builtin_deref(true) else { + bug!("expected reference or pointer type to have a dereferenceable inner type: {}", ty); + return TypeTree::new(); + }; + + let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); + return TypeTree(vec![Type { + offset: -1, + size: tcx.data_layout.pointer_size().bytes_usize(), + kind: Kind::Pointer, + child, + }]); + } + if let Some(inner) = is_nonnull(tcx, ty) { + let child = typetree_from_ty_impl_inner(tcx, inner, depth + 1, visited, true); + return TypeTree(vec![Type { + offset: -1, + size: tcx.data_layout.pointer_size().bytes_usize(), + kind: Kind::Pointer, + child, + }]); + } + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f16 => (Kind::Half, 2), + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + x if x == tcx.types.f128 => (Kind::F128, 16), + _ => (Kind::Integer, 0), + } + } else { + (Kind::Integer, 0) + }; + + // Use offset 0 for scalars that are direct targets of references (like &f64) + // Use offset -1 for scalars used directly (like function return types) + let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; + return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); + } + + if ty.is_array() { + if let ty::Array(element_ty, len_const) = ty.kind() { + let len = len_const.try_to_target_usize(tcx).unwrap_or(0); + if len == 0 { + return TypeTree::new(); + } + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); + let mut types = Vec::new(); + for elem_type in &element_tree.0 { + types.push(Type { + offset: -1, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + + return TypeTree(types); + } + } + + if let ty::Tuple(tuple_types) = ty.kind() { + if tuple_types.is_empty() { + return TypeTree::new(); + } + + let mut types = Vec::new(); + let mut current_offset = 0; + + for tuple_ty in tuple_types.iter() { + let element_tree = + typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false); + + let element_layout = tcx + .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty)) + .ok() + .map(|layout| layout.size.bytes_usize()) + .unwrap_or(0); + + for elem_type in &element_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + current_offset as isize + } else { + current_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + + current_offset += element_layout; + } + + return TypeTree(types); + } + + if let ty::Adt(adt_def, args) = ty.kind() { + if adt_def.is_struct() { + let struct_layout = + tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty)); + if let Ok(layout) = struct_layout { + let mut types = Vec::new(); + + for (field_idx, field_def) in adt_def.all_fields().enumerate() { + let field_ty = field_def.ty(tcx, args); + let field_tree = typetree_from_ty_impl_inner( + tcx, + field_ty.skip_norm_wip(), + depth + 1, + visited, + false, + ); + + let field_offset = layout.fields.offset(field_idx).bytes_usize(); + + for elem_type in &field_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + field_offset as isize + } else { + field_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + } + + return TypeTree(types); + } + } + } + + TypeTree::new() +} +use rustc_span::sym; +fn is_nonnull<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + ) -> Option> { + if let ty::Adt(def, args) = ty.kind() { + if tcx.is_diagnostic_item(sym::NonNull, def.did()) { + let inner_ty = args.type_at(0); + return Some(inner_ty); + } + } + None +} + +fn unwrap_option_nonnull_ptr_like<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, +) -> Option> { + let ty::Adt(def, args) = ty.kind() else { + return None; + }; + + if !tcx.is_lang_item(def.did(), LangItem::Option) { + dbg!("not an Option: {}", ty); + return None; + } + + let inner = args.type_at(0); + + // Accepted: Option<&T>, Option<&mut T>, Option> + // Rejected: Option<*const T>, Option<*mut T> + if inner.is_ref() || inner.is_box() { + return inner.builtin_deref(true); + } + + // Accepted: Option> + if let Some(inner_ty) = is_nonnull(tcx, inner) { + return Some(inner_ty); + } + dbg!("the wrong Option: {}", ty); + + None +} + +