Skip to content

Commit 51492d0

Browse files
Rollup merge of #152978 - JonathanBrouwer:autodiff_attrs, r=jdonszelmann
Port `#[rustc_autodiff]` to the attribute parser infrastructure For #131229 r? @jdonszelmann cc @ZuseZ4 `autodiff_forward` and `autodiff_reverse` can be ported in a seperate PR, but these are expanded in the AST and don't exist anymore in the HIR so this is a bit more of a challenge.
2 parents 81aa532 + d5b6474 commit 51492d0

File tree

16 files changed

+282
-263
lines changed

16 files changed

+282
-263
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 22 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
2-
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
2+
//! we create an `RustcAutodiff` which contains the source and target function names. The source
33
//! is the function to which the autodiff attribute is applied, and the target is the function
44
//! getting generated by us (with a name given by the user as the first autodiff arg).
55
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9-
use crate::expand::typetree::TypeTree;
9+
use rustc_span::{Symbol, sym};
10+
1011
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1112
use crate::{Ty, TyKind};
1213

@@ -31,6 +32,12 @@ pub enum DiffMode {
3132
Reverse,
3233
}
3334

35+
impl DiffMode {
36+
pub fn all_modes() -> &'static [Symbol] {
37+
&[sym::Source, sym::Forward, sym::Reverse]
38+
}
39+
}
40+
3441
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
3542
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
3643
/// we add to the previous shadow value. To not surprise users, we picked different names.
@@ -76,43 +83,20 @@ impl DiffActivity {
7683
use DiffActivity::*;
7784
matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
7885
}
79-
}
80-
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
81-
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
82-
pub struct AutoDiffItem {
83-
/// The name of the function getting differentiated
84-
pub source: String,
85-
/// The name of the function being generated
86-
pub target: String,
87-
pub attrs: AutoDiffAttrs,
88-
pub inputs: Vec<TypeTree>,
89-
pub output: TypeTree,
90-
}
9186

92-
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
93-
pub struct AutoDiffAttrs {
94-
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
95-
/// e.g. in the [JAX
96-
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
97-
pub mode: DiffMode,
98-
/// A user-provided, batching width. If not given, we will default to 1 (no batching).
99-
/// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
100-
/// - Calling the function 50 times with a batch size of 2
101-
/// - Calling the function 25 times with a batch size of 4,
102-
/// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
103-
/// cache locality, better re-usal of primal values, and other optimizations.
104-
/// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
105-
/// times, so this massively increases code size. As such, values like 1024 are unlikely to
106-
/// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
107-
/// experiments for now and focus on documenting the implications of a large width.
108-
pub width: u32,
109-
pub ret_activity: DiffActivity,
110-
pub input_activity: Vec<DiffActivity>,
111-
}
112-
113-
impl AutoDiffAttrs {
114-
pub fn has_primal_ret(&self) -> bool {
115-
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
87+
pub fn all_activities() -> &'static [Symbol] {
88+
&[
89+
sym::None,
90+
sym::Active,
91+
sym::ActiveOnly,
92+
sym::Const,
93+
sym::Dual,
94+
sym::Dualv,
95+
sym::DualOnly,
96+
sym::DualvOnly,
97+
sym::Duplicated,
98+
sym::DuplicatedOnly,
99+
]
116100
}
117101
}
118102

@@ -241,59 +225,3 @@ impl FromStr for DiffActivity {
241225
}
242226
}
243227
}
244-
245-
impl AutoDiffAttrs {
246-
pub fn has_ret_activity(&self) -> bool {
247-
self.ret_activity != DiffActivity::None
248-
}
249-
pub fn has_active_only_ret(&self) -> bool {
250-
self.ret_activity == DiffActivity::ActiveOnly
251-
}
252-
253-
pub const fn error() -> Self {
254-
AutoDiffAttrs {
255-
mode: DiffMode::Error,
256-
width: 0,
257-
ret_activity: DiffActivity::None,
258-
input_activity: Vec::new(),
259-
}
260-
}
261-
pub fn source() -> Self {
262-
AutoDiffAttrs {
263-
mode: DiffMode::Source,
264-
width: 0,
265-
ret_activity: DiffActivity::None,
266-
input_activity: Vec::new(),
267-
}
268-
}
269-
270-
pub fn is_active(&self) -> bool {
271-
self.mode != DiffMode::Error
272-
}
273-
274-
pub fn is_source(&self) -> bool {
275-
self.mode == DiffMode::Source
276-
}
277-
pub fn apply_autodiff(&self) -> bool {
278-
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
279-
}
280-
281-
pub fn into_item(
282-
self,
283-
source: String,
284-
target: String,
285-
inputs: Vec<TypeTree>,
286-
output: TypeTree,
287-
) -> AutoDiffItem {
288-
AutoDiffItem { source, target, inputs, output, attrs: self }
289-
}
290-
}
291-
292-
impl fmt::Display for AutoDiffItem {
293-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294-
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
295-
write!(f, " with attributes: {:?}", self.attrs)?;
296-
write!(f, " with inputs: {:?}", self.inputs)?;
297-
write!(f, " with output: {:?}", self.output)
298-
}
299-
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
use std::str::FromStr;
2+
3+
use rustc_ast::LitKind;
4+
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
5+
use rustc_feature::{AttributeTemplate, template};
6+
use rustc_hir::attrs::{AttributeKind, RustcAutodiff};
7+
use rustc_hir::{MethodKind, Target};
8+
use rustc_span::{Symbol, sym};
9+
use thin_vec::ThinVec;
10+
11+
use crate::attributes::prelude::Allow;
12+
use crate::attributes::{AttributeOrder, OnDuplicate, SingleAttributeParser};
13+
use crate::context::{AcceptContext, Stage};
14+
use crate::parser::{ArgParser, MetaItemOrLitParser};
15+
use crate::target_checking::AllowedTargets;
16+
17+
pub(crate) struct RustcAutodiffParser;
18+
19+
impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
20+
const PATH: &[Symbol] = &[sym::rustc_autodiff];
21+
const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepInnermost;
22+
const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Error;
23+
const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[
24+
Allow(Target::Fn),
25+
Allow(Target::Method(MethodKind::Inherent)),
26+
Allow(Target::Method(MethodKind::Trait { body: true })),
27+
Allow(Target::Method(MethodKind::TraitImpl)),
28+
]);
29+
const TEMPLATE: AttributeTemplate = template!(
30+
List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"],
31+
"https://doc.rust-lang.org/std/autodiff/index.html"
32+
);
33+
34+
fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option<AttributeKind> {
35+
let list = match args {
36+
ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)),
37+
ArgParser::List(list) => list,
38+
ArgParser::NameValue(_) => {
39+
cx.expected_list_or_no_args(cx.attr_span);
40+
return None;
41+
}
42+
};
43+
44+
let mut items = list.mixed().peekable();
45+
46+
// Parse name
47+
let Some(mode) = items.next() else {
48+
cx.expected_at_least_one_argument(list.span);
49+
return None;
50+
};
51+
let Some(mode) = mode.meta_item() else {
52+
cx.expected_identifier(mode.span());
53+
return None;
54+
};
55+
let Ok(()) = mode.args().no_args() else {
56+
cx.expected_identifier(mode.span());
57+
return None;
58+
};
59+
let Some(mode) = mode.path().word() else {
60+
cx.expected_identifier(mode.span());
61+
return None;
62+
};
63+
let Ok(mode) = DiffMode::from_str(mode.as_str()) else {
64+
cx.expected_specific_argument(mode.span, DiffMode::all_modes());
65+
return None;
66+
};
67+
68+
// Parse width
69+
let width = if let Some(width) = items.peek()
70+
&& let MetaItemOrLitParser::Lit(width) = width
71+
&& let LitKind::Int(width, _) = width.kind
72+
&& let Ok(width) = width.0.try_into()
73+
{
74+
_ = items.next();
75+
width
76+
} else {
77+
1
78+
};
79+
80+
// Parse activities
81+
let mut activities = ThinVec::new();
82+
for activity in items {
83+
let MetaItemOrLitParser::MetaItemParser(activity) = activity else {
84+
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
85+
return None;
86+
};
87+
let Ok(()) = activity.args().no_args() else {
88+
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
89+
return None;
90+
};
91+
let Some(activity) = activity.path().word() else {
92+
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
93+
return None;
94+
};
95+
let Ok(activity) = DiffActivity::from_str(activity.as_str()) else {
96+
cx.expected_specific_argument(activity.span, DiffActivity::all_activities());
97+
return None;
98+
};
99+
100+
activities.push(activity);
101+
}
102+
let Some(ret_activity) = activities.pop() else {
103+
cx.expected_specific_argument(
104+
list.span.with_lo(list.span.hi()),
105+
DiffActivity::all_activities(),
106+
);
107+
return None;
108+
};
109+
110+
Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff {
111+
mode,
112+
width,
113+
input_activity: activities,
114+
ret_activity,
115+
}))))
116+
}
117+
}

compiler/rustc_attr_parsing/src/attributes/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use crate::target_checking::AllowedTargets;
3030
mod prelude;
3131

3232
pub(crate) mod allow_unstable;
33+
pub(crate) mod autodiff;
3334
pub(crate) mod body;
3435
pub(crate) mod cfg;
3536
pub(crate) mod cfg_select;

compiler/rustc_attr_parsing/src/context.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use rustc_span::{ErrorGuaranteed, Span, Symbol};
1919
use crate::AttributeParser;
2020
// Glob imports to avoid big, bitrotty import lists
2121
use crate::attributes::allow_unstable::*;
22+
use crate::attributes::autodiff::*;
2223
use crate::attributes::body::*;
2324
use crate::attributes::cfi_encoding::*;
2425
use crate::attributes::codegen_attrs::*;
@@ -204,6 +205,7 @@ attribute_parsers!(
204205
Single<ReexportTestHarnessMainParser>,
205206
Single<RustcAbiParser>,
206207
Single<RustcAllocatorZeroedVariantParser>,
208+
Single<RustcAutodiffParser>,
207209
Single<RustcBuiltinMacroParser>,
208210
Single<RustcDefPathParser>,
209211
Single<RustcDeprecatedSafe2024Parser>,

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ mod llvm_enzyme {
88
use std::string::String;
99

1010
use rustc_ast::expand::autodiff_attrs::{
11-
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12-
valid_ty_for_activity,
11+
DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity,
1312
};
1413
use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
1514
use rustc_ast::tokenstream::*;
@@ -20,6 +19,7 @@ mod llvm_enzyme {
2019
MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,
2120
};
2221
use rustc_expand::base::{Annotatable, ExtCtxt};
22+
use rustc_hir::attrs::RustcAutodiff;
2323
use rustc_span::{Ident, Span, Symbol, sym};
2424
use thin_vec::{ThinVec, thin_vec};
2525
use tracing::{debug, trace};
@@ -87,7 +87,7 @@ mod llvm_enzyme {
8787
meta_item: &ThinVec<MetaItemInner>,
8888
has_ret: bool,
8989
mode: DiffMode,
90-
) -> AutoDiffAttrs {
90+
) -> RustcAutodiff {
9191
let dcx = ecx.sess.dcx();
9292

9393
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
@@ -105,7 +105,7 @@ mod llvm_enzyme {
105105
span: meta_item[1].span(),
106106
width: x,
107107
});
108-
return AutoDiffAttrs::error();
108+
return RustcAutodiff::error();
109109
}
110110
}
111111
} else {
@@ -129,7 +129,7 @@ mod llvm_enzyme {
129129
};
130130
}
131131
if errors {
132-
return AutoDiffAttrs::error();
132+
return RustcAutodiff::error();
133133
}
134134

135135
// If a return type exist, we need to split the last activity,
@@ -145,11 +145,11 @@ mod llvm_enzyme {
145145
(&DiffActivity::None, activities.as_slice())
146146
};
147147

148-
AutoDiffAttrs {
148+
RustcAutodiff {
149149
mode,
150150
width,
151151
ret_activity: *ret_activity,
152-
input_activity: input_activity.to_vec(),
152+
input_activity: input_activity.iter().cloned().collect(),
153153
}
154154
}
155155

@@ -309,7 +309,7 @@ mod llvm_enzyme {
309309
ts.pop();
310310
let ts: TokenStream = TokenStream::from_iter(ts);
311311

312-
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
312+
let x: RustcAutodiff = from_ast(ecx, &meta_item_vec, has_ret, mode);
313313
if !x.is_active() {
314314
// We encountered an error, so we return the original item.
315315
// This allows us to potentially parse other attributes.
@@ -603,7 +603,7 @@ mod llvm_enzyme {
603603
fn gen_enzyme_decl(
604604
ecx: &ExtCtxt<'_>,
605605
sig: &ast::FnSig,
606-
x: &AutoDiffAttrs,
606+
x: &RustcAutodiff,
607607
span: Span,
608608
) -> ast::FnSig {
609609
let dcx = ecx.sess.dcx();

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use std::ptr;
22

3-
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
3+
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
44
use rustc_ast::expand::typetree::FncTree;
55
use rustc_codegen_ssa::common::TypeKind;
66
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
7+
use rustc_data_structures::thin_vec::ThinVec;
8+
use rustc_hir::attrs::RustcAutodiff;
79
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
810
use rustc_middle::{bug, ty};
911
use rustc_target::callconv::PassMode;
@@ -18,7 +20,7 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
1820
tcx: TyCtxt<'tcx>,
1921
instance: Instance<'tcx>,
2022
typing_env: TypingEnv<'tcx>,
21-
da: &mut Vec<DiffActivity>,
23+
da: &mut ThinVec<DiffActivity>,
2224
) {
2325
let fn_ty = instance.ty(tcx, typing_env);
2426

@@ -295,7 +297,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
295297
outer_name: &str,
296298
ret_ty: &'ll Type,
297299
fn_args: &[&'ll Value],
298-
attrs: AutoDiffAttrs,
300+
attrs: &RustcAutodiff,
299301
dest: PlaceRef<'tcx, &'ll Value>,
300302
fnc_tree: FncTree,
301303
) {

0 commit comments

Comments
 (0)