1 use proc_macro2::TokenStream;
2 use quote::quote;
3 use syn::{
4 visit_mut::{self, visit_item_mut, visit_path_segment_mut, VisitMut},
5 Expr, ExprBlock, File, GenericArgument, GenericParam, Item, PathArguments, PathSegment, Stmt,
6 Type, TypeParamBound, WherePredicate,
7 };
8
9 pub struct ReplaceGenericType<'a> {
10 generic_type: &'a str,
11 arg_type: &'a PathSegment,
12 }
13
14 impl<'a> ReplaceGenericType<'a> {
new(generic_type: &'a str, arg_type: &'a PathSegment) -> Self15 pub fn new(generic_type: &'a str, arg_type: &'a PathSegment) -> Self {
16 Self {
17 generic_type,
18 arg_type,
19 }
20 }
21
replace_generic_type(item: &mut Item, generic_type: &'a str, arg_type: &'a PathSegment)22 pub fn replace_generic_type(item: &mut Item, generic_type: &'a str, arg_type: &'a PathSegment) {
23 let mut s = Self::new(generic_type, arg_type);
24 s.visit_item_mut(item);
25 }
26 }
27
28 impl<'a> VisitMut for ReplaceGenericType<'a> {
visit_item_mut(&mut self, i: &mut Item)29 fn visit_item_mut(&mut self, i: &mut Item) {
30 if let Item::Fn(item_fn) = i {
31 // remove generic type from generics <T, F>
32 let args = item_fn
33 .sig
34 .generics
35 .params
36 .iter()
37 .filter_map(|param| {
38 if let GenericParam::Type(type_param) = ¶m {
39 if type_param.ident.to_string().eq(self.generic_type) {
40 None
41 } else {
42 Some(param)
43 }
44 } else {
45 Some(param)
46 }
47 })
48 .collect::<Vec<_>>();
49 item_fn.sig.generics.params = args.into_iter().cloned().collect();
50
51 // remove generic type from where clause
52 if let Some(where_clause) = &mut item_fn.sig.generics.where_clause {
53 let new_where_clause = where_clause
54 .predicates
55 .iter()
56 .filter_map(|predicate| {
57 if let WherePredicate::Type(predicate_type) = predicate {
58 if let Type::Path(p) = &predicate_type.bounded_ty {
59 if p.path.segments[0].ident.to_string().eq(self.generic_type) {
60 None
61 } else {
62 Some(predicate)
63 }
64 } else {
65 Some(predicate)
66 }
67 } else {
68 Some(predicate)
69 }
70 })
71 .collect::<Vec<_>>();
72
73 where_clause.predicates = new_where_clause.into_iter().cloned().collect();
74 };
75 }
76 visit_item_mut(self, i)
77 }
visit_path_segment_mut(&mut self, i: &mut PathSegment)78 fn visit_path_segment_mut(&mut self, i: &mut PathSegment) {
79 // replace generic type with target type
80 if i.ident.to_string().eq(&self.generic_type) {
81 *i = self.arg_type.clone();
82 }
83 visit_path_segment_mut(self, i);
84 }
85 }
86
/// Stateless syntax-tree visitor that rewrites asynchronous code into its
/// synchronous form by dropping `async` blocks and `.await` expressions.
pub struct AsyncAwaitRemoval;
88
89 impl AsyncAwaitRemoval {
remove_async_await(&mut self, item: TokenStream) -> TokenStream90 pub fn remove_async_await(&mut self, item: TokenStream) -> TokenStream {
91 let mut syntax_tree: File = syn::parse(item.into()).unwrap();
92 self.visit_file_mut(&mut syntax_tree);
93 quote!(#syntax_tree)
94 }
95 }
96
97 impl VisitMut for AsyncAwaitRemoval {
visit_expr_mut(&mut self, node: &mut Expr)98 fn visit_expr_mut(&mut self, node: &mut Expr) {
99 // Delegate to the default impl to visit nested expressions.
100 visit_mut::visit_expr_mut(self, node);
101
102 match node {
103 Expr::Await(expr) => *node = (*expr.base).clone(),
104
105 Expr::Async(expr) => {
106 let inner = &expr.block;
107 let sync_expr = if let [Stmt::Expr(expr, None)] = inner.stmts.as_slice() {
108 // remove useless braces when there is only one statement
109 expr.clone()
110 } else {
111 Expr::Block(ExprBlock {
112 attrs: expr.attrs.clone(),
113 block: inner.clone(),
114 label: None,
115 })
116 };
117 *node = sync_expr;
118 }
119 _ => {}
120 }
121 }
122
visit_item_mut(&mut self, i: &mut Item)123 fn visit_item_mut(&mut self, i: &mut Item) {
124 // find generic parameter of Future and replace it with its Output type
125 if let Item::Fn(item_fn) = i {
126 let mut inputs: Vec<(String, PathSegment)> = vec![];
127
128 // generic params: <T:Future<Output=()>, F>
129 for param in &item_fn.sig.generics.params {
130 // generic param: T:Future<Output=()>
131 if let GenericParam::Type(type_param) = param {
132 let generic_type_name = type_param.ident.to_string();
133
134 // bound: Future<Output=()>
135 for bound in &type_param.bounds {
136 inputs.extend(search_trait_bound(&generic_type_name, bound));
137 }
138 }
139 }
140
141 if let Some(where_clause) = &item_fn.sig.generics.where_clause {
142 for predicate in &where_clause.predicates {
143 if let WherePredicate::Type(predicate_type) = predicate {
144 let generic_type_name = if let Type::Path(p) = &predicate_type.bounded_ty {
145 p.path.segments[0].ident.to_string()
146 } else {
147 panic!("Please submit an issue");
148 };
149
150 for bound in &predicate_type.bounds {
151 inputs.extend(search_trait_bound(&generic_type_name, bound));
152 }
153 }
154 }
155 }
156
157 for (generic_type_name, path_seg) in &inputs {
158 ReplaceGenericType::replace_generic_type(i, generic_type_name, path_seg);
159 }
160 }
161 visit_item_mut(self, i);
162 }
163 }
164
search_trait_bound( generic_type_name: &str, bound: &TypeParamBound, ) -> Vec<(String, PathSegment)>165 fn search_trait_bound(
166 generic_type_name: &str,
167 bound: &TypeParamBound,
168 ) -> Vec<(String, PathSegment)> {
169 let mut inputs = vec![];
170
171 if let TypeParamBound::Trait(trait_bound) = bound {
172 let segment = &trait_bound.path.segments[trait_bound.path.segments.len() - 1];
173 let name = segment.ident.to_string();
174 if name.eq("Future") {
175 // match Future<Output=Type>
176 if let PathArguments::AngleBracketed(args) = &segment.arguments {
177 // binding: Output=Type
178 if let GenericArgument::AssocType(binding) = &args.args[0] {
179 if let Type::Path(p) = &binding.ty {
180 inputs.push((generic_type_name.to_owned(), p.path.segments[0].clone()));
181 }
182 }
183 }
184 }
185 }
186 inputs
187 }
188