1 use std::iter;
2 
3 use proc_macro2::TokenStream;
4 use quote::{quote, quote_spanned, ToTokens};
5 use syn::visit_mut::VisitMut;
6 use syn::{
7     punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg,
8     Ident, Item, ItemFn, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType,
9     Path, ReturnType, Signature, Stmt, Token, Type, TypePath,
10 };
11 
12 use crate::{
13     attr::{Field, Fields, FormatMode, InstrumentArgs, Level},
14     MaybeItemFn, MaybeItemFnRef,
15 };
16 
17 /// Given an existing function, generate an instrumented version of that function
gen_function<'a, B: ToTokens + 'a>( input: MaybeItemFnRef<'a, B>, args: InstrumentArgs, instrumented_function_name: &str, self_type: Option<&TypePath>, ) -> proc_macro2::TokenStream18 pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
19     input: MaybeItemFnRef<'a, B>,
20     args: InstrumentArgs,
21     instrumented_function_name: &str,
22     self_type: Option<&TypePath>,
23 ) -> proc_macro2::TokenStream {
24     // these are needed ahead of time, as ItemFn contains the function body _and_
25     // isn't representable inside a quote!/quote_spanned! macro
26     // (Syn's ToTokens isn't implemented for ItemFn)
27     let MaybeItemFnRef {
28         outer_attrs,
29         inner_attrs,
30         vis,
31         sig,
32         block,
33     } = input;
34 
35     let Signature {
36         output,
37         inputs: params,
38         unsafety,
39         asyncness,
40         constness,
41         abi,
42         ident,
43         generics:
44             syn::Generics {
45                 params: gen_params,
46                 where_clause,
47                 ..
48             },
49         ..
50     } = sig;
51 
52     let warnings = args.warnings();
53 
54     let (return_type, return_span) = if let ReturnType::Type(_, return_type) = &output {
55         (erase_impl_trait(return_type), return_type.span())
56     } else {
57         // Point at function name if we don't have an explicit return type
58         (syn::parse_quote! { () }, ident.span())
59     };
60     // Install a fake return statement as the first thing in the function
61     // body, so that we eagerly infer that the return type is what we
62     // declared in the async fn signature.
63     // The `#[allow(..)]` is given because the return statement is
64     // unreachable, but does affect inference, so it needs to be written
65     // exactly that way for it to do its magic.
66     let fake_return_edge = quote_spanned! {return_span=>
67         #[allow(
68             unknown_lints, unreachable_code, clippy::diverging_sub_expression,
69             clippy::let_unit_value, clippy::unreachable, clippy::let_with_type_underscore,
70             clippy::empty_loop
71         )]
72         if false {
73             let __tracing_attr_fake_return: #return_type = loop {};
74             return __tracing_attr_fake_return;
75         }
76     };
77     let block = quote! {
78         {
79             #fake_return_edge
80             #block
81         }
82     };
83 
84     let body = gen_block(
85         &block,
86         params,
87         asyncness.is_some(),
88         args,
89         instrumented_function_name,
90         self_type,
91     );
92 
93     quote!(
94         #(#outer_attrs) *
95         #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #output
96         #where_clause
97         {
98             #(#inner_attrs) *
99             #warnings
100             #body
101         }
102     )
103 }
104 
105 /// Instrument a block
gen_block<B: ToTokens>( block: &B, params: &Punctuated<FnArg, Token![,]>, async_context: bool, mut args: InstrumentArgs, instrumented_function_name: &str, self_type: Option<&TypePath>, ) -> proc_macro2::TokenStream106 fn gen_block<B: ToTokens>(
107     block: &B,
108     params: &Punctuated<FnArg, Token![,]>,
109     async_context: bool,
110     mut args: InstrumentArgs,
111     instrumented_function_name: &str,
112     self_type: Option<&TypePath>,
113 ) -> proc_macro2::TokenStream {
114     // generate the span's name
115     let span_name = args
116         // did the user override the span's name?
117         .name
118         .as_ref()
119         .map(|name| quote!(#name))
120         .unwrap_or_else(|| quote!(#instrumented_function_name));
121 
122     let args_level = args.level();
123     let level = args_level.clone();
124 
125     let follows_from = args.follows_from.iter();
126     let follows_from = quote! {
127         #(for cause in #follows_from {
128             __tracing_attr_span.follows_from(cause);
129         })*
130     };
131 
132     // generate this inside a closure, so we can return early on errors.
133     let span = (|| {
134         // Pull out the arguments-to-be-skipped first, so we can filter results
135         // below.
136         let param_names: Vec<(Ident, (Ident, RecordType))> = params
137             .clone()
138             .into_iter()
139             .flat_map(|param| match param {
140                 FnArg::Typed(PatType { pat, ty, .. }) => {
141                     param_names(*pat, RecordType::parse_from_ty(&ty))
142                 }
143                 FnArg::Receiver(_) => Box::new(iter::once((
144                     Ident::new("self", param.span()),
145                     RecordType::Debug,
146                 ))),
147             })
148             // Little dance with new (user-exposed) names and old (internal)
149             // names of identifiers. That way, we could do the following
150             // even though async_trait (<=0.1.43) rewrites "self" as "_self":
151             // ```
152             // #[async_trait]
153             // impl Foo for FooImpl {
154             //     #[instrument(skip(self))]
155             //     async fn foo(&self, v: usize) {}
156             // }
157             // ```
158             .map(|(x, record_type)| {
159                 // if we are inside a function generated by async-trait <=0.1.43, we need to
160                 // take care to rewrite "_self" as "self" for 'user convenience'
161                 if self_type.is_some() && x == "_self" {
162                     (Ident::new("self", x.span()), (x, record_type))
163                 } else {
164                     (x.clone(), (x, record_type))
165                 }
166             })
167             .collect();
168 
169         for skip in &args.skips {
170             if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) {
171                 return quote_spanned! {skip.span()=>
172                     compile_error!("attempting to skip non-existent parameter")
173                 };
174             }
175         }
176 
177         let target = args.target();
178 
179         let parent = args.parent.iter();
180 
181         // filter out skipped fields
182         let quoted_fields: Vec<_> = param_names
183             .iter()
184             .filter(|(param, _)| {
185                 if args.skip_all || args.skips.contains(param) {
186                     return false;
187                 }
188 
189                 // If any parameters have the same name as a custom field, skip
190                 // and allow them to be formatted by the custom field.
191                 if let Some(ref fields) = args.fields {
192                     fields.0.iter().all(|Field { ref name, .. }| {
193                         let first = name.first();
194                         first != name.last() || !first.iter().any(|name| name == &param)
195                     })
196                 } else {
197                     true
198                 }
199             })
200             .map(|(user_name, (real_name, record_type))| match record_type {
201                 RecordType::Value => quote!(#user_name = #real_name),
202                 RecordType::Debug => quote!(#user_name = tracing::field::debug(&#real_name)),
203             })
204             .collect();
205 
206         // replace every use of a variable with its original name
207         if let Some(Fields(ref mut fields)) = args.fields {
208             let mut replacer = IdentAndTypesRenamer {
209                 idents: param_names.into_iter().map(|(a, (b, _))| (a, b)).collect(),
210                 types: Vec::new(),
211             };
212 
213             // when async-trait <=0.1.43 is in use, replace instances
214             // of the "Self" type inside the fields values
215             if let Some(self_type) = self_type {
216                 replacer.types.push(("Self", self_type.clone()));
217             }
218 
219             for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) {
220                 syn::visit_mut::visit_expr_mut(&mut replacer, e);
221             }
222         }
223 
224         let custom_fields = &args.fields;
225 
226         quote!(tracing::span!(
227             target: #target,
228             #(parent: #parent,)*
229             #level,
230             #span_name,
231             #(#quoted_fields,)*
232             #custom_fields
233 
234         ))
235     })();
236 
237     let target = args.target();
238 
239     let err_event = match args.err_args {
240         Some(event_args) => {
241             let level_tokens = event_args.level(Level::Error);
242             match event_args.mode {
243                 FormatMode::Default | FormatMode::Display => Some(quote!(
244                     tracing::event!(target: #target, #level_tokens, error = %e)
245                 )),
246                 FormatMode::Debug => Some(quote!(
247                     tracing::event!(target: #target, #level_tokens, error = ?e)
248                 )),
249             }
250         }
251         _ => None,
252     };
253 
254     let ret_event = match args.ret_args {
255         Some(event_args) => {
256             let level_tokens = event_args.level(args_level);
257             match event_args.mode {
258                 FormatMode::Display => Some(quote!(
259                     tracing::event!(target: #target, #level_tokens, return = %x)
260                 )),
261                 FormatMode::Default | FormatMode::Debug => Some(quote!(
262                     tracing::event!(target: #target, #level_tokens, return = ?x)
263                 )),
264             }
265         }
266         _ => None,
267     };
268 
269     // Generate the instrumented function body.
270     // If the function is an `async fn`, this will wrap it in an async block,
271     // which is `instrument`ed using `tracing-futures`. Otherwise, this will
272     // enter the span and then perform the rest of the body.
273     // If `err` is in args, instrument any resulting `Err`s.
274     // If `ret` is in args, instrument any resulting `Ok`s when the function
275     // returns `Result`s, otherwise instrument any resulting values.
276     if async_context {
277         let mk_fut = match (err_event, ret_event) {
278             (Some(err_event), Some(ret_event)) => quote_spanned!(block.span()=>
279                 async move {
280                     match async move #block.await {
281                         #[allow(clippy::unit_arg)]
282                         Ok(x) => {
283                             #ret_event;
284                             Ok(x)
285                         },
286                         Err(e) => {
287                             #err_event;
288                             Err(e)
289                         }
290                     }
291                 }
292             ),
293             (Some(err_event), None) => quote_spanned!(block.span()=>
294                 async move {
295                     match async move #block.await {
296                         #[allow(clippy::unit_arg)]
297                         Ok(x) => Ok(x),
298                         Err(e) => {
299                             #err_event;
300                             Err(e)
301                         }
302                     }
303                 }
304             ),
305             (None, Some(ret_event)) => quote_spanned!(block.span()=>
306                 async move {
307                     let x = async move #block.await;
308                     #ret_event;
309                     x
310                 }
311             ),
312             (None, None) => quote_spanned!(block.span()=>
313                 async move #block
314             ),
315         };
316 
317         return quote!(
318             let __tracing_attr_span = #span;
319             let __tracing_instrument_future = #mk_fut;
320             if !__tracing_attr_span.is_disabled() {
321                 #follows_from
322                 tracing::Instrument::instrument(
323                     __tracing_instrument_future,
324                     __tracing_attr_span
325                 )
326                 .await
327             } else {
328                 __tracing_instrument_future.await
329             }
330         );
331     }
332 
333     let span = quote!(
334         // These variables are left uninitialized and initialized only
335         // if the tracing level is statically enabled at this point.
336         // While the tracing level is also checked at span creation
337         // time, that will still create a dummy span, and a dummy guard
338         // and drop the dummy guard later. By lazily initializing these
339         // variables, Rust will generate a drop flag for them and thus
340         // only drop the guard if it was created. This creates code that
341         // is very straightforward for LLVM to optimize out if the tracing
342         // level is statically disabled, while not causing any performance
343         // regression in case the level is enabled.
344         let __tracing_attr_span;
345         let __tracing_attr_guard;
346         if tracing::level_enabled!(#level) || tracing::if_log_enabled!(#level, {true} else {false}) {
347             __tracing_attr_span = #span;
348             #follows_from
349             __tracing_attr_guard = __tracing_attr_span.enter();
350         }
351     );
352 
353     match (err_event, ret_event) {
354         (Some(err_event), Some(ret_event)) => quote_spanned! {block.span()=>
355             #span
356             #[allow(clippy::redundant_closure_call)]
357             match (move || #block)() {
358                 #[allow(clippy::unit_arg)]
359                 Ok(x) => {
360                     #ret_event;
361                     Ok(x)
362                 },
363                 Err(e) => {
364                     #err_event;
365                     Err(e)
366                 }
367             }
368         },
369         (Some(err_event), None) => quote_spanned!(block.span()=>
370             #span
371             #[allow(clippy::redundant_closure_call)]
372             match (move || #block)() {
373                 #[allow(clippy::unit_arg)]
374                 Ok(x) => Ok(x),
375                 Err(e) => {
376                     #err_event;
377                     Err(e)
378                 }
379             }
380         ),
381         (None, Some(ret_event)) => quote_spanned!(block.span()=>
382             #span
383             #[allow(clippy::redundant_closure_call)]
384             let x = (move || #block)();
385             #ret_event;
386             x
387         ),
388         (None, None) => quote_spanned!(block.span() =>
389             // Because `quote` produces a stream of tokens _without_ whitespace, the
390             // `if` and the block will appear directly next to each other. This
391             // generates a clippy lint about suspicious `if/else` formatting.
392             // Therefore, suppress the lint inside the generated code...
393             #[allow(clippy::suspicious_else_formatting)]
394             {
395                 #span
396                 // ...but turn the lint back on inside the function body.
397                 #[warn(clippy::suspicious_else_formatting)]
398                 #block
399             }
400         ),
401     }
402 }
403 
404 /// Indicates whether a field should be recorded as `Value` or `Debug`.
/// Indicates whether a field should be recorded as `Value` or `Debug`.
///
/// This is a plain fieldless tag, so it derives the usual value-type traits.
/// That lets it be copied freely and compared in code and tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RecordType {
    /// The field should be recorded using its `Value` implementation.
    Value,
    /// The field should be recorded using `tracing::field::debug()`.
    Debug,
}
411 
412 impl RecordType {
413     /// Array of primitive types which should be recorded as [RecordType::Value].
414     const TYPES_FOR_VALUE: &'static [&'static str] = &[
415         "bool",
416         "str",
417         "u8",
418         "i8",
419         "u16",
420         "i16",
421         "u32",
422         "i32",
423         "u64",
424         "i64",
425         "f32",
426         "f64",
427         "usize",
428         "isize",
429         "NonZeroU8",
430         "NonZeroI8",
431         "NonZeroU16",
432         "NonZeroI16",
433         "NonZeroU32",
434         "NonZeroI32",
435         "NonZeroU64",
436         "NonZeroI64",
437         "NonZeroUsize",
438         "NonZeroIsize",
439         "Wrapping",
440     ];
441 
442     /// Parse `RecordType` from [Type] by looking up
443     /// the [RecordType::TYPES_FOR_VALUE] array.
parse_from_ty(ty: &Type) -> Self444     fn parse_from_ty(ty: &Type) -> Self {
445         match ty {
446             Type::Path(TypePath { path, .. })
447                 if path
448                     .segments
449                     .iter()
450                     .last()
451                     .map(|path_segment| {
452                         let ident = path_segment.ident.to_string();
453                         Self::TYPES_FOR_VALUE.iter().any(|&t| t == ident)
454                     })
455                     .unwrap_or(false) =>
456             {
457                 RecordType::Value
458             }
459             Type::Reference(syn::TypeReference { elem, .. }) => RecordType::parse_from_ty(elem),
460             _ => RecordType::Debug,
461         }
462     }
463 }
464 
param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Ident, RecordType)>>465 fn param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Ident, RecordType)>> {
466     match pat {
467         Pat::Ident(PatIdent { ident, .. }) => Box::new(iter::once((ident, record_type))),
468         Pat::Reference(PatReference { pat, .. }) => param_names(*pat, record_type),
469         // We can't get the concrete type of fields in the struct/tuple
470         // patterns by using `syn`. e.g. `fn foo(Foo { x, y }: Foo) {}`.
471         // Therefore, the struct/tuple patterns in the arguments will just
472         // always be recorded as `RecordType::Debug`.
473         Pat::Struct(PatStruct { fields, .. }) => Box::new(
474             fields
475                 .into_iter()
476                 .flat_map(|FieldPat { pat, .. }| param_names(*pat, RecordType::Debug)),
477         ),
478         Pat::Tuple(PatTuple { elems, .. }) => Box::new(
479             elems
480                 .into_iter()
481                 .flat_map(|p| param_names(p, RecordType::Debug)),
482         ),
483         Pat::TupleStruct(PatTupleStruct { elems, .. }) => Box::new(
484             elems
485                 .into_iter()
486                 .flat_map(|p| param_names(p, RecordType::Debug)),
487         ),
488 
489         // The above *should* cover all cases of irrefutable patterns,
490         // but we purposefully don't do any funny business here
491         // (such as panicking) because that would obscure rustc's
492         // much more informative error message.
493         _ => Box::new(iter::empty()),
494     }
495 }
496 
497 /// The specific async code pattern that was detected
498 enum AsyncKind<'a> {
499     /// Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
500     /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
501     Function(&'a ItemFn),
502     /// A function returning an async (move) block, optionally `Box::pin`-ed,
503     /// as generated by `async-trait >= 0.1.44`:
504     /// `Box::pin(async move { ... })`
505     Async {
506         async_expr: &'a ExprAsync,
507         pinned_box: bool,
508     },
509 }
510 
511 pub(crate) struct AsyncInfo<'block> {
512     // statement that must be patched
513     source_stmt: &'block Stmt,
514     kind: AsyncKind<'block>,
515     self_type: Option<TypePath>,
516     input: &'block ItemFn,
517 }
518 
519 impl<'block> AsyncInfo<'block> {
520     /// Get the AST of the inner function we need to hook, if it looks like a
521     /// manual future implementation.
522     ///
523     /// When we are given a function that returns a (pinned) future containing the
524     /// user logic, it is that (pinned) future that needs to be instrumented.
525     /// Were we to instrument its parent, we would only collect information
526     /// regarding the allocation of that future, and not its own span of execution.
527     ///
528     /// We inspect the block of the function to find if it matches any of the
529     /// following patterns:
530     ///
531     /// - Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
532     ///   `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
533     ///
534     /// - A function returning an async (move) block, optionally `Box::pin`-ed,
535     ///   as generated by `async-trait >= 0.1.44`:
536     ///   `Box::pin(async move { ... })`
537     ///
538     /// We the return the statement that must be instrumented, along with some
539     /// other information.
540     /// 'gen_body' will then be able to use that information to instrument the
541     /// proper function/future.
542     ///
543     /// (this follows the approach suggested in
544     /// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673)
from_fn(input: &'block ItemFn) -> Option<Self>545     pub(crate) fn from_fn(input: &'block ItemFn) -> Option<Self> {
546         // are we in an async context? If yes, this isn't a manual async-like pattern
547         if input.sig.asyncness.is_some() {
548             return None;
549         }
550 
551         let block = &input.block;
552 
553         // list of async functions declared inside the block
554         let inside_funs = block.stmts.iter().filter_map(|stmt| {
555             if let Stmt::Item(Item::Fn(fun)) = &stmt {
556                 // If the function is async, this is a candidate
557                 if fun.sig.asyncness.is_some() {
558                     return Some((stmt, fun));
559                 }
560             }
561             None
562         });
563 
564         // last expression of the block: it determines the return value of the
565         // block, this is quite likely a `Box::pin` statement or an async block
566         let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| {
567             if let Stmt::Expr(expr, _semi) = stmt {
568                 Some((stmt, expr))
569             } else {
570                 None
571             }
572         })?;
573 
574         // is the last expression an async block?
575         if let Expr::Async(async_expr) = last_expr {
576             return Some(AsyncInfo {
577                 source_stmt: last_expr_stmt,
578                 kind: AsyncKind::Async {
579                     async_expr,
580                     pinned_box: false,
581                 },
582                 self_type: None,
583                 input,
584             });
585         }
586 
587         // is the last expression a function call?
588         let (outside_func, outside_args) = match last_expr {
589             Expr::Call(ExprCall { func, args, .. }) => (func, args),
590             _ => return None,
591         };
592 
593         // is it a call to `Box::pin()`?
594         let path = match outside_func.as_ref() {
595             Expr::Path(path) => &path.path,
596             _ => return None,
597         };
598         if !path_to_string(path).ends_with("Box::pin") {
599             return None;
600         }
601 
602         // Does the call take an argument? If it doesn't,
603         // it's not gonna compile anyway, but that's no reason
604         // to (try to) perform an out of bounds access
605         if outside_args.is_empty() {
606             return None;
607         }
608 
609         // Is the argument to Box::pin an async block that
610         // captures its arguments?
611         if let Expr::Async(async_expr) = &outside_args[0] {
612             return Some(AsyncInfo {
613                 source_stmt: last_expr_stmt,
614                 kind: AsyncKind::Async {
615                     async_expr,
616                     pinned_box: true,
617                 },
618                 self_type: None,
619                 input,
620             });
621         }
622 
623         // Is the argument to Box::pin a function call itself?
624         let func = match &outside_args[0] {
625             Expr::Call(ExprCall { func, .. }) => func,
626             _ => return None,
627         };
628 
629         // "stringify" the path of the function called
630         let func_name = match **func {
631             Expr::Path(ref func_path) => path_to_string(&func_path.path),
632             _ => return None,
633         };
634 
635         // Was that function defined inside of the current block?
636         // If so, retrieve the statement where it was declared and the function itself
637         let (stmt_func_declaration, func) = inside_funs
638             .into_iter()
639             .find(|(_, fun)| fun.sig.ident == func_name)?;
640 
641         // If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the
642         // parameter type) with the type of "_self"
643         let mut self_type = None;
644         for arg in &func.sig.inputs {
645             if let FnArg::Typed(ty) = arg {
646                 if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat {
647                     if ident == "_self" {
648                         let mut ty = *ty.ty.clone();
649                         // extract the inner type if the argument is "&self" or "&mut self"
650                         if let Type::Reference(syn::TypeReference { elem, .. }) = ty {
651                             ty = *elem;
652                         }
653 
654                         if let Type::Path(tp) = ty {
655                             self_type = Some(tp);
656                             break;
657                         }
658                     }
659                 }
660             }
661         }
662 
663         Some(AsyncInfo {
664             source_stmt: stmt_func_declaration,
665             kind: AsyncKind::Function(func),
666             self_type,
667             input,
668         })
669     }
670 
gen_async( self, args: InstrumentArgs, instrumented_function_name: &str, ) -> Result<proc_macro::TokenStream, syn::Error>671     pub(crate) fn gen_async(
672         self,
673         args: InstrumentArgs,
674         instrumented_function_name: &str,
675     ) -> Result<proc_macro::TokenStream, syn::Error> {
676         // let's rewrite some statements!
677         let mut out_stmts: Vec<TokenStream> = self
678             .input
679             .block
680             .stmts
681             .iter()
682             .map(|stmt| stmt.to_token_stream())
683             .collect();
684 
685         if let Some((iter, _stmt)) = self
686             .input
687             .block
688             .stmts
689             .iter()
690             .enumerate()
691             .find(|(_iter, stmt)| *stmt == self.source_stmt)
692         {
693             // instrument the future by rewriting the corresponding statement
694             out_stmts[iter] = match self.kind {
695                 // `Box::pin(immediately_invoked_async_fn())`
696                 AsyncKind::Function(fun) => {
697                     let fun = MaybeItemFn::from(fun.clone());
698                     gen_function(
699                         fun.as_ref(),
700                         args,
701                         instrumented_function_name,
702                         self.self_type.as_ref(),
703                     )
704                 }
705                 // `async move { ... }`, optionally pinned
706                 AsyncKind::Async {
707                     async_expr,
708                     pinned_box,
709                 } => {
710                     let instrumented_block = gen_block(
711                         &async_expr.block,
712                         &self.input.sig.inputs,
713                         true,
714                         args,
715                         instrumented_function_name,
716                         None,
717                     );
718                     let async_attrs = &async_expr.attrs;
719                     if pinned_box {
720                         quote! {
721                             Box::pin(#(#async_attrs) * async move { #instrumented_block })
722                         }
723                     } else {
724                         quote! {
725                             #(#async_attrs) * async move { #instrumented_block }
726                         }
727                     }
728                 }
729             };
730         }
731 
732         let vis = &self.input.vis;
733         let sig = &self.input.sig;
734         let attrs = &self.input.attrs;
735         Ok(quote!(
736             #(#attrs) *
737             #vis #sig {
738                 #(#out_stmts) *
739             }
740         )
741         .into())
742     }
743 }
744 
745 // Return a path as a String
path_to_string(path: &Path) -> String746 fn path_to_string(path: &Path) -> String {
747     use std::fmt::Write;
748     // some heuristic to prevent too many allocations
749     let mut res = String::with_capacity(path.segments.len() * 5);
750     for i in 0..path.segments.len() {
751         write!(&mut res, "{}", path.segments[i].ident)
752             .expect("writing to a String should never fail");
753         if i < path.segments.len() - 1 {
754             res.push_str("::");
755         }
756     }
757     res
758 }
759 
760 /// A visitor struct to replace idents and types in some piece
761 /// of code (e.g. the "self" and "Self" tokens in user-supplied
762 /// fields expressions when the function is generated by an old
763 /// version of async-trait).
764 struct IdentAndTypesRenamer<'a> {
765     types: Vec<(&'a str, TypePath)>,
766     idents: Vec<(Ident, Ident)>,
767 }
768 
769 impl<'a> VisitMut for IdentAndTypesRenamer<'a> {
770     // we deliberately compare strings because we want to ignore the spans
771     // If we apply clippy's lint, the behavior changes
772     #[allow(clippy::cmp_owned)]
visit_ident_mut(&mut self, id: &mut Ident)773     fn visit_ident_mut(&mut self, id: &mut Ident) {
774         for (old_ident, new_ident) in &self.idents {
775             if id.to_string() == old_ident.to_string() {
776                 *id = new_ident.clone();
777             }
778         }
779     }
780 
visit_type_mut(&mut self, ty: &mut Type)781     fn visit_type_mut(&mut self, ty: &mut Type) {
782         for (type_name, new_type) in &self.types {
783             if let Type::Path(TypePath { path, .. }) = ty {
784                 if path_to_string(path) == *type_name {
785                     *ty = Type::Path(new_type.clone());
786                 }
787             }
788         }
789     }
790 }
791 
792 // A visitor struct that replace an async block by its patched version
793 struct AsyncTraitBlockReplacer<'a> {
794     block: &'a Block,
795     patched_block: Block,
796 }
797 
798 impl<'a> VisitMut for AsyncTraitBlockReplacer<'a> {
visit_block_mut(&mut self, i: &mut Block)799     fn visit_block_mut(&mut self, i: &mut Block) {
800         if i == self.block {
801             *i = self.patched_block.clone();
802         }
803     }
804 }
805 
806 // Replaces any `impl Trait` with `_` so it can be used as the type in
807 // a `let` statement's LHS.
808 struct ImplTraitEraser;
809 
810 impl VisitMut for ImplTraitEraser {
visit_type_mut(&mut self, t: &mut Type)811     fn visit_type_mut(&mut self, t: &mut Type) {
812         if let Type::ImplTrait(..) = t {
813             *t = syn::TypeInfer {
814                 underscore_token: Token![_](t.span()),
815             }
816             .into();
817         } else {
818             syn::visit_mut::visit_type_mut(self, t);
819         }
820     }
821 }
822 
erase_impl_trait(ty: &Type) -> Type823 fn erase_impl_trait(ty: &Type) -> Type {
824     let mut ty = ty.clone();
825     ImplTraitEraser.visit_type_mut(&mut ty);
826     ty
827 }
828