Skip to content

Commit 8da69a9

Browse files
authored
Refactor proc-macro attribute parsing (#1369)
* Refactor proc-macro attribute parsing * Remove `#[allow(warnings)]` which was accidentally committed * Change span for "cannot use `rejection` without `via`" error for enums * fix test
1 parent 54d8439 commit 8da69a9

File tree

5 files changed

+150
-171
lines changed

5 files changed

+150
-171
lines changed

axum-macros/src/attr_parsing.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
use quote::ToTokens;
2+
use syn::parse::{Parse, ParseStream};
3+
4+
pub(crate) fn parse_parenthesized_attribute<K, T>(
5+
input: ParseStream,
6+
out: &mut Option<(K, T)>,
7+
) -> syn::Result<()>
8+
where
9+
K: Parse + ToTokens,
10+
T: Parse,
11+
{
12+
let kw = input.parse()?;
13+
14+
let content;
15+
syn::parenthesized!(content in input);
16+
let inner = content.parse()?;
17+
18+
if out.is_some() {
19+
let kw_name = std::any::type_name::<K>().split("::").last().unwrap();
20+
let msg = format!("`{}` specified more than once", kw_name);
21+
return Err(syn::Error::new_spanned(kw, msg));
22+
}
23+
24+
*out = Some((kw, inner));
25+
26+
Ok(())
27+
}
28+
29+
pub(crate) trait Combine: Sized {
30+
fn combine(self, other: Self) -> syn::Result<Self>;
31+
}
32+
33+
pub(crate) fn parse_attrs<T>(ident: &str, attrs: &[syn::Attribute]) -> syn::Result<T>
34+
where
35+
T: Combine + Default + Parse,
36+
{
37+
attrs
38+
.iter()
39+
.filter(|attr| attr.path.is_ident(ident))
40+
.map(|attr| attr.parse_args::<T>())
41+
.try_fold(T::default(), |out, next| out.combine(next?))
42+
}
43+
44+
pub(crate) fn combine_attribute<K, T>(a: &mut Option<(K, T)>, b: Option<(K, T)>) -> syn::Result<()>
45+
where
46+
K: ToTokens,
47+
{
48+
if let Some((kw, inner)) = b {
49+
if a.is_some() {
50+
let kw_name = std::any::type_name::<K>().split("::").last().unwrap();
51+
let msg = format!("`{}` specified more than once", kw_name);
52+
return Err(syn::Error::new_spanned(kw, msg));
53+
}
54+
*a = Some((kw, inner));
55+
}
56+
Ok(())
57+
}
58+
59+
pub(crate) fn second<T, K>(tuple: (T, K)) -> K {
60+
tuple.1
61+
}

axum-macros/src/from_request.rs

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
use self::attr::{
2-
parse_container_attrs, parse_field_attrs, FromRequestContainerAttr, FromRequestFieldAttr,
1+
use self::attr::FromRequestContainerAttrs;
2+
use crate::{
3+
attr_parsing::{parse_attrs, second},
4+
from_request::attr::FromRequestFieldAttrs,
35
};
46
use proc_macro2::{Span, TokenStream};
57
use quote::{quote, quote_spanned};
@@ -38,26 +40,20 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
3840

3941
let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?;
4042

41-
match parse_container_attrs(&attrs)? {
42-
FromRequestContainerAttr::Via { path, rejection } => {
43-
impl_struct_by_extracting_all_at_once(
44-
ident,
45-
fields,
46-
path,
47-
rejection,
48-
generic_ident,
49-
tr,
50-
)
51-
}
52-
FromRequestContainerAttr::Rejection(rejection) => {
53-
error_on_generic_ident(generic_ident, tr)?;
43+
let FromRequestContainerAttrs { via, rejection } = parse_attrs("from_request", &attrs)?;
5444

55-
impl_struct_by_extracting_each_field(ident, fields, Some(rejection), tr)
56-
}
57-
FromRequestContainerAttr::None => {
45+
match (via.map(second), rejection.map(second)) {
46+
(Some(via), rejection) => impl_struct_by_extracting_all_at_once(
47+
ident,
48+
fields,
49+
via,
50+
rejection,
51+
generic_ident,
52+
tr,
53+
),
54+
(None, rejection) => {
5855
error_on_generic_ident(generic_ident, tr)?;
59-
60-
impl_struct_by_extracting_each_field(ident, fields, None, tr)
56+
impl_struct_by_extracting_each_field(ident, fields, rejection, tr)
6157
}
6258
}
6359
}
@@ -82,15 +78,21 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
8278
return Err(syn::Error::new_spanned(where_clause, generics_error));
8379
}
8480

85-
match parse_container_attrs(&attrs)? {
86-
FromRequestContainerAttr::Via { path, rejection } => {
87-
impl_enum_by_extracting_all_at_once(ident, variants, path, rejection, tr)
88-
}
89-
FromRequestContainerAttr::Rejection(rejection) => Err(syn::Error::new_spanned(
90-
rejection,
81+
let FromRequestContainerAttrs { via, rejection } = parse_attrs("from_request", &attrs)?;
82+
83+
match (via.map(second), rejection) {
84+
(Some(via), rejection) => impl_enum_by_extracting_all_at_once(
85+
ident,
86+
variants,
87+
via,
88+
rejection.map(second),
89+
tr,
90+
),
91+
(None, Some((rejection_kw, _))) => Err(syn::Error::new_spanned(
92+
rejection_kw,
9193
"cannot use `rejection` without `via`",
9294
)),
93-
FromRequestContainerAttr::None => Err(syn::Error::new(
95+
(None, _) => Err(syn::Error::new(
9496
Span::call_site(),
9597
"missing `#[from_request(via(...))]`",
9698
)),
@@ -316,7 +318,7 @@ fn extract_fields(
316318
let mut res: Vec<_> = fields_iter
317319
.enumerate()
318320
.map(|(index, field)| {
319-
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
321+
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
320322

321323
let member = member(field, index);
322324
let ty_span = field.ty.span();
@@ -434,7 +436,7 @@ fn extract_fields(
434436

435437
// Handle the last element, if deriving FromRequest
436438
if let Some(field) = last {
437-
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
439+
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
438440

439441
let member = member(field, fields.len() - 1);
440442
let ty_span = field.ty.span();
@@ -557,7 +559,8 @@ fn impl_struct_by_extracting_all_at_once(
557559
};
558560

559561
for field in fields {
560-
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
562+
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
563+
561564
if let Some((via, _)) = via {
562565
return Err(syn::Error::new_spanned(
563566
via,
@@ -695,7 +698,8 @@ fn impl_enum_by_extracting_all_at_once(
695698
tr: Trait,
696699
) -> syn::Result<TokenStream> {
697700
for variant in variants {
698-
let FromRequestFieldAttr { via } = parse_field_attrs(&variant.attrs)?;
701+
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &variant.attrs)?;
702+
699703
if let Some((via, _)) = via {
700704
return Err(syn::Error::new_spanned(
701705
via,
@@ -710,7 +714,7 @@ fn impl_enum_by_extracting_all_at_once(
710714
};
711715

712716
for field in fields {
713-
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
717+
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
714718
if let Some((via, _)) = via {
715719
return Err(syn::Error::new_spanned(
716720
via,

0 commit comments

Comments
 (0)