1 use proc_macro2::TokenStream;
2 use quote::quote;
3 use syn::{
4     visit_mut::{self, visit_item_mut, visit_path_segment_mut, VisitMut},
5     Expr, ExprBlock, File, GenericArgument, GenericParam, Item, PathArguments, PathSegment, Stmt,
6     Type, TypeParamBound, WherePredicate,
7 };
8 
9 pub struct ReplaceGenericType<'a> {
10     generic_type: &'a str,
11     arg_type: &'a PathSegment,
12 }
13 
14 impl<'a> ReplaceGenericType<'a> {
new(generic_type: &'a str, arg_type: &'a PathSegment) -> Self15     pub fn new(generic_type: &'a str, arg_type: &'a PathSegment) -> Self {
16         Self {
17             generic_type,
18             arg_type,
19         }
20     }
21 
replace_generic_type(item: &mut Item, generic_type: &'a str, arg_type: &'a PathSegment)22     pub fn replace_generic_type(item: &mut Item, generic_type: &'a str, arg_type: &'a PathSegment) {
23         let mut s = Self::new(generic_type, arg_type);
24         s.visit_item_mut(item);
25     }
26 }
27 
28 impl<'a> VisitMut for ReplaceGenericType<'a> {
visit_item_mut(&mut self, i: &mut Item)29     fn visit_item_mut(&mut self, i: &mut Item) {
30         if let Item::Fn(item_fn) = i {
31             // remove generic type from generics <T, F>
32             let args = item_fn
33                 .sig
34                 .generics
35                 .params
36                 .iter()
37                 .filter_map(|param| {
38                     if let GenericParam::Type(type_param) = &param {
39                         if type_param.ident.to_string().eq(self.generic_type) {
40                             None
41                         } else {
42                             Some(param)
43                         }
44                     } else {
45                         Some(param)
46                     }
47                 })
48                 .collect::<Vec<_>>();
49             item_fn.sig.generics.params = args.into_iter().cloned().collect();
50 
51             // remove generic type from where clause
52             if let Some(where_clause) = &mut item_fn.sig.generics.where_clause {
53                 let new_where_clause = where_clause
54                     .predicates
55                     .iter()
56                     .filter_map(|predicate| {
57                         if let WherePredicate::Type(predicate_type) = predicate {
58                             if let Type::Path(p) = &predicate_type.bounded_ty {
59                                 if p.path.segments[0].ident.to_string().eq(self.generic_type) {
60                                     None
61                                 } else {
62                                     Some(predicate)
63                                 }
64                             } else {
65                                 Some(predicate)
66                             }
67                         } else {
68                             Some(predicate)
69                         }
70                     })
71                     .collect::<Vec<_>>();
72 
73                 where_clause.predicates = new_where_clause.into_iter().cloned().collect();
74             };
75         }
76         visit_item_mut(self, i)
77     }
visit_path_segment_mut(&mut self, i: &mut PathSegment)78     fn visit_path_segment_mut(&mut self, i: &mut PathSegment) {
79         // replace generic type with target type
80         if i.ident.to_string().eq(&self.generic_type) {
81             *i = self.arg_type.clone();
82         }
83         visit_path_segment_mut(self, i);
84     }
85 }
86 
/// Syntax-tree pass that removes `.await` expressions and `async` blocks, and
/// substitutes `X` for generic type parameters bounded by `Future<Output = X>`.
///
/// The type is stateless, so it is cheap to copy and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncAwaitRemoval;
88 
89 impl AsyncAwaitRemoval {
remove_async_await(&mut self, item: TokenStream) -> TokenStream90     pub fn remove_async_await(&mut self, item: TokenStream) -> TokenStream {
91         let mut syntax_tree: File = syn::parse(item.into()).unwrap();
92         self.visit_file_mut(&mut syntax_tree);
93         quote!(#syntax_tree)
94     }
95 }
96 
97 impl VisitMut for AsyncAwaitRemoval {
visit_expr_mut(&mut self, node: &mut Expr)98     fn visit_expr_mut(&mut self, node: &mut Expr) {
99         // Delegate to the default impl to visit nested expressions.
100         visit_mut::visit_expr_mut(self, node);
101 
102         match node {
103             Expr::Await(expr) => *node = (*expr.base).clone(),
104 
105             Expr::Async(expr) => {
106                 let inner = &expr.block;
107                 let sync_expr = if let [Stmt::Expr(expr, None)] = inner.stmts.as_slice() {
108                     // remove useless braces when there is only one statement
109                     expr.clone()
110                 } else {
111                     Expr::Block(ExprBlock {
112                         attrs: expr.attrs.clone(),
113                         block: inner.clone(),
114                         label: None,
115                     })
116                 };
117                 *node = sync_expr;
118             }
119             _ => {}
120         }
121     }
122 
visit_item_mut(&mut self, i: &mut Item)123     fn visit_item_mut(&mut self, i: &mut Item) {
124         // find generic parameter of Future and replace it with its Output type
125         if let Item::Fn(item_fn) = i {
126             let mut inputs: Vec<(String, PathSegment)> = vec![];
127 
128             // generic params: <T:Future<Output=()>, F>
129             for param in &item_fn.sig.generics.params {
130                 // generic param: T:Future<Output=()>
131                 if let GenericParam::Type(type_param) = param {
132                     let generic_type_name = type_param.ident.to_string();
133 
134                     // bound: Future<Output=()>
135                     for bound in &type_param.bounds {
136                         inputs.extend(search_trait_bound(&generic_type_name, bound));
137                     }
138                 }
139             }
140 
141             if let Some(where_clause) = &item_fn.sig.generics.where_clause {
142                 for predicate in &where_clause.predicates {
143                     if let WherePredicate::Type(predicate_type) = predicate {
144                         let generic_type_name = if let Type::Path(p) = &predicate_type.bounded_ty {
145                             p.path.segments[0].ident.to_string()
146                         } else {
147                             panic!("Please submit an issue");
148                         };
149 
150                         for bound in &predicate_type.bounds {
151                             inputs.extend(search_trait_bound(&generic_type_name, bound));
152                         }
153                     }
154                 }
155             }
156 
157             for (generic_type_name, path_seg) in &inputs {
158                 ReplaceGenericType::replace_generic_type(i, generic_type_name, path_seg);
159             }
160         }
161         visit_item_mut(self, i);
162     }
163 }
164 
search_trait_bound( generic_type_name: &str, bound: &TypeParamBound, ) -> Vec<(String, PathSegment)>165 fn search_trait_bound(
166     generic_type_name: &str,
167     bound: &TypeParamBound,
168 ) -> Vec<(String, PathSegment)> {
169     let mut inputs = vec![];
170 
171     if let TypeParamBound::Trait(trait_bound) = bound {
172         let segment = &trait_bound.path.segments[trait_bound.path.segments.len() - 1];
173         let name = segment.ident.to_string();
174         if name.eq("Future") {
175             // match Future<Output=Type>
176             if let PathArguments::AngleBracketed(args) = &segment.arguments {
177                 // binding: Output=Type
178                 if let GenericArgument::AssocType(binding) = &args.args[0] {
179                     if let Type::Path(p) = &binding.ty {
180                         inputs.push((generic_type_name.to_owned(), p.path.segments[0].clone()));
181                     }
182                 }
183             }
184         }
185     }
186     inputs
187 }
188