Skip to content

Commit 0c6736d

Browse files
fix(rewrite): inject Guard::check() after nested .await expressions
Replace flat top-level iteration in inject_await_checks() with recursive VisitMut-based AwaitCheckInjector that processes blocks at every nesting depth. check() is now injected immediately after each .await inside if, match, loop, while, for, and bare blocks rather than after the enclosing top-level statement. Closes #142
1 parent 5768d20 commit 0c6736d

File tree

1 file changed

+312
-13
lines changed

1 file changed

+312
-13
lines changed

src/rewrite.rs

Lines changed: 312 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,35 +72,80 @@ const SCOPE_FUNCTIONS: &[&str] = &["scope", "scope_fifo"];
7272
/// SpanContext stays on the parent thread.
7373
const FORK_TRIGGER_FUNCTIONS: &[&str] = &["scope", "scope_fifo", "join"];
7474

75-
/// Check whether a statement's expression tree contains an `.await`.
76-
fn contains_await(stmt: &syn::Stmt) -> bool {
75+
/// Check whether a statement contains `.await` at its own expression level,
76+
/// excluding awaits inside sub-blocks of control-flow expressions (if, match,
77+
/// loop, while, for). Those sub-blocks are handled by the recursive visitor.
78+
///
79+
/// Also excludes closures, async blocks, and nested fn items since those are
80+
/// separate async contexts.
81+
fn stmt_has_direct_await(stmt: &syn::Stmt) -> bool {
7782
struct AwaitFinder {
7883
found: bool,
7984
}
8085
impl<'ast> syn::visit::Visit<'ast> for AwaitFinder {
8186
fn visit_expr_await(&mut self, _: &'ast syn::ExprAwait) {
8287
self.found = true;
8388
}
89+
// Don't descend into closures -- separate async context.
90+
fn visit_expr_closure(&mut self, _: &'ast syn::ExprClosure) {}
91+
// Don't descend into async blocks -- separate async context.
92+
fn visit_expr_async(&mut self, _: &'ast syn::ExprAsync) {}
93+
// Don't descend into nested fn items -- they have their own guards.
94+
fn visit_item_fn(&mut self, _: &'ast syn::ItemFn) {}
95+
// Don't descend into sub-blocks of control-flow expressions.
96+
// The recursive VisitMut will process those blocks separately.
97+
fn visit_block(&mut self, _: &'ast syn::Block) {}
8498
}
8599
let mut finder = AwaitFinder { found: false };
86100
syn::visit::Visit::visit_stmt(&mut finder, stmt);
87101
finder.found
88102
}
89103

90-
/// Inject `_piano_guard.check();` after each statement containing `.await`.
104+
/// Recursively inject `_piano_guard.check();` after each statement that
105+
/// directly contains `.await`, at every nesting depth.
106+
///
107+
/// Uses `VisitMut` to walk into if/match/loop/while/for/block bodies so
108+
/// that `check()` is placed immediately after the `.await`, not after the
109+
/// enclosing top-level statement.
91110
fn inject_await_checks(block: &mut syn::Block) {
92-
let mut new_stmts = Vec::with_capacity(block.stmts.len() * 2);
93-
for stmt in block.stmts.drain(..) {
94-
let has_await = contains_await(&stmt);
95-
new_stmts.push(stmt);
96-
if has_await {
97-
let check_stmt: syn::Stmt = syn::parse_quote! {
98-
_piano_guard.check();
99-
};
100-
new_stmts.push(check_stmt);
111+
struct AwaitCheckInjector;
112+
113+
impl AwaitCheckInjector {
114+
/// Process a block's statements: for each statement that directly
115+
/// contains `.await`, insert a `check()` call after it. Then recurse
116+
/// into nested blocks within each statement.
117+
fn process_block(&mut self, block: &mut syn::Block) {
118+
let mut new_stmts = Vec::with_capacity(block.stmts.len() * 2);
119+
for mut stmt in block.stmts.drain(..) {
120+
let has_await = stmt_has_direct_await(&stmt);
121+
// Recurse into nested blocks within this statement first.
122+
self.visit_stmt_mut(&mut stmt);
123+
new_stmts.push(stmt);
124+
if has_await {
125+
let check_stmt: syn::Stmt = syn::parse_quote! {
126+
_piano_guard.check();
127+
};
128+
new_stmts.push(check_stmt);
129+
}
130+
}
131+
block.stmts = new_stmts;
132+
}
133+
}
134+
135+
impl VisitMut for AwaitCheckInjector {
136+
fn visit_block_mut(&mut self, block: &mut syn::Block) {
137+
self.process_block(block);
101138
}
139+
140+
// Don't descend into closures -- separate async context.
141+
fn visit_expr_closure_mut(&mut self, _: &mut syn::ExprClosure) {}
142+
// Don't descend into async blocks -- separate async context.
143+
fn visit_expr_async_mut(&mut self, _: &mut syn::ExprAsync) {}
144+
// Don't descend into nested fn items -- they have their own guards.
145+
fn visit_item_fn_mut(&mut self, _: &mut syn::ItemFn) {}
102146
}
103-
block.stmts = new_stmts;
147+
148+
AwaitCheckInjector.process_block(block);
104149
}
105150

106151
struct Instrumenter {
@@ -2089,4 +2134,258 @@ fn main() {}
20892134
"fn inside impl block in macro should be instrumented. Got:\n{result}"
20902135
);
20912136
}
2137+
2138+
// -- Nested .await check() injection tests --
2139+
//
2140+
// These tests verify check() is injected INSIDE nested blocks, directly
2141+
// after the .await, not just after the enclosing top-level statement.
2142+
2143+
/// Assert that `needle_a` appears before `needle_b` in `haystack`.
2144+
fn assert_appears_before(haystack: &str, needle_a: &str, needle_b: &str, context: &str) {
2145+
let pos_a = haystack
2146+
.find(needle_a)
2147+
.unwrap_or_else(|| panic!("{needle_a} not found in output. {context}:\n{haystack}"));
2148+
let pos_b = haystack
2149+
.find(needle_b)
2150+
.unwrap_or_else(|| panic!("{needle_b} not found in output. {context}:\n{haystack}"));
2151+
assert!(
2152+
pos_a < pos_b,
2153+
"{needle_a} should appear before {needle_b} ({context}). Got:\n{haystack}",
2154+
);
2155+
}
2156+
2157+
#[test]
2158+
fn injects_check_inside_if_block_not_after_it() {
2159+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2160+
// The .await is inside the if block. check() must appear between
2161+
// fetch().await and process(), not after the closing brace of `if`.
2162+
let source = r#"
2163+
async fn example() {
2164+
if condition {
2165+
fetch().await;
2166+
process();
2167+
}
2168+
}
2169+
"#;
2170+
let result = instrument_source(source, &targets, false).unwrap();
2171+
// check() must appear before process() (i.e. inside the if block),
2172+
// not after the if statement.
2173+
assert_appears_before(
2174+
&result.source,
2175+
"_piano_guard.check()",
2176+
"process()",
2177+
"inside if block",
2178+
);
2179+
}
2180+
2181+
#[test]
2182+
fn injects_check_inside_match_arm() {
2183+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2184+
let source = r#"
2185+
async fn example(x: u32) {
2186+
match x {
2187+
0 => {
2188+
fetch().await;
2189+
process();
2190+
}
2191+
_ => {}
2192+
}
2193+
}
2194+
"#;
2195+
let result = instrument_source(source, &targets, false).unwrap();
2196+
assert_appears_before(
2197+
&result.source,
2198+
"_piano_guard.check()",
2199+
"process()",
2200+
"inside match arm",
2201+
);
2202+
}
2203+
2204+
#[test]
2205+
fn injects_check_inside_loop() {
2206+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2207+
let source = r#"
2208+
async fn example() {
2209+
loop {
2210+
fetch().await;
2211+
process();
2212+
}
2213+
}
2214+
"#;
2215+
let result = instrument_source(source, &targets, false).unwrap();
2216+
assert_appears_before(
2217+
&result.source,
2218+
"_piano_guard.check()",
2219+
"process()",
2220+
"inside loop",
2221+
);
2222+
}
2223+
2224+
#[test]
2225+
fn injects_check_at_every_nesting_depth() {
2226+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2227+
let source = r#"
2228+
async fn example() {
2229+
if condition {
2230+
fetch().await;
2231+
if other {
2232+
save().await;
2233+
cleanup();
2234+
}
2235+
}
2236+
send().await;
2237+
}
2238+
"#;
2239+
let result = instrument_source(source, &targets, false).unwrap();
2240+
let check_count = result.source.matches("_piano_guard.check()").count();
2241+
assert_eq!(
2242+
check_count, 3,
2243+
"should inject check() after every .await at any depth. Got:\n{}",
2244+
result.source
2245+
);
2246+
// Verify the innermost check() is between save().await and cleanup()
2247+
let save_pos = result.source.find("save()").unwrap();
2248+
let cleanup_pos = result.source.find("cleanup()").unwrap();
2249+
let check_between = result.source[save_pos..cleanup_pos].contains("_piano_guard.check()");
2250+
assert!(
2251+
check_between,
2252+
"check() should appear between save().await and cleanup(). Got:\n{}",
2253+
result.source,
2254+
);
2255+
}
2256+
2257+
#[test]
2258+
fn injects_check_inside_else_block() {
2259+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2260+
let source = r#"
2261+
async fn example() {
2262+
if condition {
2263+
sync_work();
2264+
} else {
2265+
fetch().await;
2266+
process();
2267+
}
2268+
}
2269+
"#;
2270+
let result = instrument_source(source, &targets, false).unwrap();
2271+
assert_appears_before(
2272+
&result.source,
2273+
"_piano_guard.check()",
2274+
"process()",
2275+
"inside else block",
2276+
);
2277+
}
2278+
2279+
#[test]
2280+
fn injects_check_inside_while_let_body() {
2281+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2282+
let source = r#"
2283+
async fn example() {
2284+
while let Some(item) = iter.next() {
2285+
process(item).await;
2286+
log();
2287+
}
2288+
}
2289+
"#;
2290+
let result = instrument_source(source, &targets, false).unwrap();
2291+
assert_appears_before(
2292+
&result.source,
2293+
"_piano_guard.check()",
2294+
"log()",
2295+
"inside while let body",
2296+
);
2297+
}
2298+
2299+
#[test]
2300+
fn injects_check_inside_for_loop_body() {
2301+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2302+
let source = r#"
2303+
async fn example() {
2304+
for item in items {
2305+
process(item).await;
2306+
log();
2307+
}
2308+
}
2309+
"#;
2310+
let result = instrument_source(source, &targets, false).unwrap();
2311+
assert_appears_before(
2312+
&result.source,
2313+
"_piano_guard.check()",
2314+
"log()",
2315+
"inside for loop body",
2316+
);
2317+
}
2318+
2319+
#[test]
2320+
fn injects_check_inside_bare_block() {
2321+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2322+
let source = r#"
2323+
async fn example() {
2324+
{
2325+
fetch().await;
2326+
process();
2327+
}
2328+
}
2329+
"#;
2330+
let result = instrument_source(source, &targets, false).unwrap();
2331+
assert_appears_before(
2332+
&result.source,
2333+
"_piano_guard.check()",
2334+
"process()",
2335+
"inside bare block",
2336+
);
2337+
}
2338+
2339+
#[test]
2340+
fn injects_check_after_await_in_condition_expression() {
2341+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2342+
let source = r#"
2343+
async fn example() {
2344+
if stream.next().await.is_some() {
2345+
process();
2346+
}
2347+
cleanup();
2348+
}
2349+
"#;
2350+
let result = instrument_source(source, &targets, false).unwrap();
2351+
// The .await is in the condition, so check() goes after the whole if statement
2352+
assert_eq!(
2353+
result.source.matches("_piano_guard.check()").count(),
2354+
1,
2355+
"expected exactly 1 check() call",
2356+
);
2357+
assert_appears_before(
2358+
&result.source,
2359+
"process()",
2360+
"_piano_guard.check()",
2361+
"check() should be after the if block containing process()",
2362+
);
2363+
assert_appears_before(
2364+
&result.source,
2365+
"_piano_guard.check()",
2366+
"cleanup()",
2367+
"after if with .await in condition",
2368+
);
2369+
}
2370+
2371+
#[test]
2372+
fn no_check_injection_inside_closure_or_async_block() {
2373+
let targets: HashSet<String> = ["example".to_string()].into_iter().collect();
2374+
let source = r#"
2375+
async fn example() {
2376+
let f = || async { fetch().await; };
2377+
let g = async { save().await; };
2378+
send().await;
2379+
}
2380+
"#;
2381+
let result = instrument_source(source, &targets, false).unwrap();
2382+
// Only one check() -- after send().await. The closure and async block
2383+
// have separate async contexts with their own guards.
2384+
let check_count = result.source.matches("_piano_guard.check()").count();
2385+
assert_eq!(
2386+
check_count, 1,
2387+
"expected 1 check() (after send().await only), got {check_count}:\n{}",
2388+
result.source,
2389+
);
2390+
}
20922391
}

0 commit comments

Comments
 (0)