@@ -72,35 +72,80 @@ const SCOPE_FUNCTIONS: &[&str] = &["scope", "scope_fifo"];
7272/// SpanContext stays on the parent thread.
7373const FORK_TRIGGER_FUNCTIONS : & [ & str ] = & [ "scope" , "scope_fifo" , "join" ] ;
7474
75- /// Check whether a statement's expression tree contains an `.await`.
76- fn contains_await ( stmt : & syn:: Stmt ) -> bool {
75+ /// Check whether a statement contains `.await` at its own expression level,
76+ /// excluding awaits inside sub-blocks of control-flow expressions (if, match,
77+ /// loop, while, for). Those sub-blocks are handled by the recursive visitor.
78+ ///
79+ /// Also excludes closures, async blocks, and nested fn items since those are
80+ /// separate async contexts.
81+ fn stmt_has_direct_await ( stmt : & syn:: Stmt ) -> bool {
7782 struct AwaitFinder {
7883 found : bool ,
7984 }
8085 impl < ' ast > syn:: visit:: Visit < ' ast > for AwaitFinder {
8186 fn visit_expr_await ( & mut self , _: & ' ast syn:: ExprAwait ) {
8287 self . found = true ;
8388 }
89+ // Don't descend into closures -- separate async context.
90+ fn visit_expr_closure ( & mut self , _: & ' ast syn:: ExprClosure ) { }
91+ // Don't descend into async blocks -- separate async context.
92+ fn visit_expr_async ( & mut self , _: & ' ast syn:: ExprAsync ) { }
93+ // Don't descend into nested fn items -- they have their own guards.
94+ fn visit_item_fn ( & mut self , _: & ' ast syn:: ItemFn ) { }
95+ // Don't descend into sub-blocks of control-flow expressions.
96+ // The recursive VisitMut will process those blocks separately.
97+ fn visit_block ( & mut self , _: & ' ast syn:: Block ) { }
8498 }
8599 let mut finder = AwaitFinder { found : false } ;
86100 syn:: visit:: Visit :: visit_stmt ( & mut finder, stmt) ;
87101 finder. found
88102}
89103
90- /// Inject `_piano_guard.check();` after each statement containing `.await`.
104+ /// Recursively inject `_piano_guard.check();` after each statement that
105+ /// directly contains `.await`, at every nesting depth.
106+ ///
107+ /// Uses `VisitMut` to walk into if/match/loop/while/for/block bodies so
108+ /// that `check()` is placed immediately after the `.await`, not after the
109+ /// enclosing top-level statement.
91110fn inject_await_checks ( block : & mut syn:: Block ) {
92- let mut new_stmts = Vec :: with_capacity ( block. stmts . len ( ) * 2 ) ;
93- for stmt in block. stmts . drain ( ..) {
94- let has_await = contains_await ( & stmt) ;
95- new_stmts. push ( stmt) ;
96- if has_await {
97- let check_stmt: syn:: Stmt = syn:: parse_quote! {
98- _piano_guard. check( ) ;
99- } ;
100- new_stmts. push ( check_stmt) ;
111+ struct AwaitCheckInjector ;
112+
113+ impl AwaitCheckInjector {
114+ /// Process a block's statements: for each statement that directly
115+ /// contains `.await`, insert a `check()` call after it. Then recurse
116+ /// into nested blocks within each statement.
117+ fn process_block ( & mut self , block : & mut syn:: Block ) {
118+ let mut new_stmts = Vec :: with_capacity ( block. stmts . len ( ) * 2 ) ;
119+ for mut stmt in block. stmts . drain ( ..) {
120+ let has_await = stmt_has_direct_await ( & stmt) ;
121+ // Recurse into nested blocks within this statement first.
122+ self . visit_stmt_mut ( & mut stmt) ;
123+ new_stmts. push ( stmt) ;
124+ if has_await {
125+ let check_stmt: syn:: Stmt = syn:: parse_quote! {
126+ _piano_guard. check( ) ;
127+ } ;
128+ new_stmts. push ( check_stmt) ;
129+ }
130+ }
131+ block. stmts = new_stmts;
132+ }
133+ }
134+
135+ impl VisitMut for AwaitCheckInjector {
136+ fn visit_block_mut ( & mut self , block : & mut syn:: Block ) {
137+ self . process_block ( block) ;
101138 }
139+
140+ // Don't descend into closures -- separate async context.
141+ fn visit_expr_closure_mut ( & mut self , _: & mut syn:: ExprClosure ) { }
142+ // Don't descend into async blocks -- separate async context.
143+ fn visit_expr_async_mut ( & mut self , _: & mut syn:: ExprAsync ) { }
144+ // Don't descend into nested fn items -- they have their own guards.
145+ fn visit_item_fn_mut ( & mut self , _: & mut syn:: ItemFn ) { }
102146 }
103- block. stmts = new_stmts;
147+
148+ AwaitCheckInjector . process_block ( block) ;
104149}
105150
106151struct Instrumenter {
@@ -2089,4 +2134,258 @@ fn main() {}
20892134 "fn inside impl block in macro should be instrumented. Got:\n {result}"
20902135 ) ;
20912136 }
2137+
2138+ // -- Nested .await check() injection tests --
2139+ //
2140+ // These tests verify check() is injected INSIDE nested blocks, directly
2141+ // after the .await, not just after the enclosing top-level statement.
2142+
2143+ /// Assert that `needle_a` appears before `needle_b` in `haystack`.
2144+ fn assert_appears_before ( haystack : & str , needle_a : & str , needle_b : & str , context : & str ) {
2145+ let pos_a = haystack
2146+ . find ( needle_a)
2147+ . unwrap_or_else ( || panic ! ( "{needle_a} not found in output. {context}:\n {haystack}" ) ) ;
2148+ let pos_b = haystack
2149+ . find ( needle_b)
2150+ . unwrap_or_else ( || panic ! ( "{needle_b} not found in output. {context}:\n {haystack}" ) ) ;
2151+ assert ! (
2152+ pos_a < pos_b,
2153+ "{needle_a} should appear before {needle_b} ({context}). Got:\n {haystack}" ,
2154+ ) ;
2155+ }
2156+
2157+ #[ test]
2158+ fn injects_check_inside_if_block_not_after_it ( ) {
2159+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2160+ // The .await is inside the if block. check() must appear between
2161+ // fetch().await and process(), not after the closing brace of `if`.
2162+ let source = r#"
2163+ async fn example() {
2164+ if condition {
2165+ fetch().await;
2166+ process();
2167+ }
2168+ }
2169+ "# ;
2170+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2171+ // check() must appear before process() (i.e. inside the if block),
2172+ // not after the if statement.
2173+ assert_appears_before (
2174+ & result. source ,
2175+ "_piano_guard.check()" ,
2176+ "process()" ,
2177+ "inside if block" ,
2178+ ) ;
2179+ }
2180+
2181+ #[ test]
2182+ fn injects_check_inside_match_arm ( ) {
2183+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2184+ let source = r#"
2185+ async fn example(x: u32) {
2186+ match x {
2187+ 0 => {
2188+ fetch().await;
2189+ process();
2190+ }
2191+ _ => {}
2192+ }
2193+ }
2194+ "# ;
2195+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2196+ assert_appears_before (
2197+ & result. source ,
2198+ "_piano_guard.check()" ,
2199+ "process()" ,
2200+ "inside match arm" ,
2201+ ) ;
2202+ }
2203+
2204+ #[ test]
2205+ fn injects_check_inside_loop ( ) {
2206+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2207+ let source = r#"
2208+ async fn example() {
2209+ loop {
2210+ fetch().await;
2211+ process();
2212+ }
2213+ }
2214+ "# ;
2215+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2216+ assert_appears_before (
2217+ & result. source ,
2218+ "_piano_guard.check()" ,
2219+ "process()" ,
2220+ "inside loop" ,
2221+ ) ;
2222+ }
2223+
2224+ #[ test]
2225+ fn injects_check_at_every_nesting_depth ( ) {
2226+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2227+ let source = r#"
2228+ async fn example() {
2229+ if condition {
2230+ fetch().await;
2231+ if other {
2232+ save().await;
2233+ cleanup();
2234+ }
2235+ }
2236+ send().await;
2237+ }
2238+ "# ;
2239+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2240+ let check_count = result. source . matches ( "_piano_guard.check()" ) . count ( ) ;
2241+ assert_eq ! (
2242+ check_count, 3 ,
2243+ "should inject check() after every .await at any depth. Got:\n {}" ,
2244+ result. source
2245+ ) ;
2246+ // Verify the innermost check() is between save().await and cleanup()
2247+ let save_pos = result. source . find ( "save()" ) . unwrap ( ) ;
2248+ let cleanup_pos = result. source . find ( "cleanup()" ) . unwrap ( ) ;
2249+ let check_between = result. source [ save_pos..cleanup_pos] . contains ( "_piano_guard.check()" ) ;
2250+ assert ! (
2251+ check_between,
2252+ "check() should appear between save().await and cleanup(). Got:\n {}" ,
2253+ result. source,
2254+ ) ;
2255+ }
2256+
2257+ #[ test]
2258+ fn injects_check_inside_else_block ( ) {
2259+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2260+ let source = r#"
2261+ async fn example() {
2262+ if condition {
2263+ sync_work();
2264+ } else {
2265+ fetch().await;
2266+ process();
2267+ }
2268+ }
2269+ "# ;
2270+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2271+ assert_appears_before (
2272+ & result. source ,
2273+ "_piano_guard.check()" ,
2274+ "process()" ,
2275+ "inside else block" ,
2276+ ) ;
2277+ }
2278+
2279+ #[ test]
2280+ fn injects_check_inside_while_let_body ( ) {
2281+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2282+ let source = r#"
2283+ async fn example() {
2284+ while let Some(item) = iter.next() {
2285+ process(item).await;
2286+ log();
2287+ }
2288+ }
2289+ "# ;
2290+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2291+ assert_appears_before (
2292+ & result. source ,
2293+ "_piano_guard.check()" ,
2294+ "log()" ,
2295+ "inside while let body" ,
2296+ ) ;
2297+ }
2298+
2299+ #[ test]
2300+ fn injects_check_inside_for_loop_body ( ) {
2301+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2302+ let source = r#"
2303+ async fn example() {
2304+ for item in items {
2305+ process(item).await;
2306+ log();
2307+ }
2308+ }
2309+ "# ;
2310+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2311+ assert_appears_before (
2312+ & result. source ,
2313+ "_piano_guard.check()" ,
2314+ "log()" ,
2315+ "inside for loop body" ,
2316+ ) ;
2317+ }
2318+
2319+ #[ test]
2320+ fn injects_check_inside_bare_block ( ) {
2321+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2322+ let source = r#"
2323+ async fn example() {
2324+ {
2325+ fetch().await;
2326+ process();
2327+ }
2328+ }
2329+ "# ;
2330+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2331+ assert_appears_before (
2332+ & result. source ,
2333+ "_piano_guard.check()" ,
2334+ "process()" ,
2335+ "inside bare block" ,
2336+ ) ;
2337+ }
2338+
2339+ #[ test]
2340+ fn injects_check_after_await_in_condition_expression ( ) {
2341+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2342+ let source = r#"
2343+ async fn example() {
2344+ if stream.next().await.is_some() {
2345+ process();
2346+ }
2347+ cleanup();
2348+ }
2349+ "# ;
2350+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2351+ // The .await is in the condition, so check() goes after the whole if statement
2352+ assert_eq ! (
2353+ result. source. matches( "_piano_guard.check()" ) . count( ) ,
2354+ 1 ,
2355+ "expected exactly 1 check() call" ,
2356+ ) ;
2357+ assert_appears_before (
2358+ & result. source ,
2359+ "process()" ,
2360+ "_piano_guard.check()" ,
2361+ "check() should be after the if block containing process()" ,
2362+ ) ;
2363+ assert_appears_before (
2364+ & result. source ,
2365+ "_piano_guard.check()" ,
2366+ "cleanup()" ,
2367+ "after if with .await in condition" ,
2368+ ) ;
2369+ }
2370+
2371+ #[ test]
2372+ fn no_check_injection_inside_closure_or_async_block ( ) {
2373+ let targets: HashSet < String > = [ "example" . to_string ( ) ] . into_iter ( ) . collect ( ) ;
2374+ let source = r#"
2375+ async fn example() {
2376+ let f = || async { fetch().await; };
2377+ let g = async { save().await; };
2378+ send().await;
2379+ }
2380+ "# ;
2381+ let result = instrument_source ( source, & targets, false ) . unwrap ( ) ;
2382+ // Only one check() -- after send().await. The closure and async block
2383+ // have separate async contexts with their own guards.
2384+ let check_count = result. source . matches ( "_piano_guard.check()" ) . count ( ) ;
2385+ assert_eq ! (
2386+ check_count, 1 ,
2387+ "expected 1 check() (after send().await only), got {check_count}:\n {}" ,
2388+ result. source,
2389+ ) ;
2390+ }
20922391}
0 commit comments