1 use crate::ast::{Enum, Field, Input, Struct};
2 use crate::attr::Trait;
3 use crate::generics::InferredBounds;
4 use proc_macro2::TokenStream;
5 use quote::{format_ident, quote, quote_spanned, ToTokens};
6 use std::collections::BTreeSet as Set;
7 use syn::spanned::Spanned;
8 use syn::{
9     Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
10 };
11 
derive(node: &DeriveInput) -> Result<TokenStream>12 pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
13     let input = Input::from_syn(node)?;
14     input.validate()?;
15     Ok(match input {
16         Input::Struct(input) => impl_struct(input),
17         Input::Enum(input) => impl_enum(input),
18     })
19 }
20 
/// Generates the trait impls for a struct annotated with `#[derive(Error)]`.
///
/// The output always contains a `std::error::Error` impl. That impl overrides
/// `source` when the struct is `#[error(transparent)]` or has a source field,
/// and overrides `provide` when the struct has a backtrace field. The output
/// also contains a `Display` impl when the struct is transparent or carries a
/// display attribute, and a `From` impl when the struct has a `from` field.
///
/// Where-clause bounds are inferred separately for each impl. A bound is added
/// only for fields whose `contains_generic` flag is set.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra bounds the `Error` impl needs. They are collected while building
    // `source` below and applied via `augment_where_clause` near the end.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`:
    // - transparent: delegate to the single field's own `source()`;
    // - source field: return that field as `&dyn Error`. For an `Option`
    //   field, `.as_ref()?` returns `None` early when the field is empty;
    // - neither: `None`, so no `source` method is emitted.
    let source_body = if input.attrs.transparent.is_some() {
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
        }
        let member = &only_field.member;
        Some(quote! {
            std::error::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // Bound the inner `T` of an `Option<T>` source, not the `Option` itself.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
        }
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.span()=> .as_ref()?))
        } else {
            None
        };
        // These tokens carry the source field's span, so type errors in them
        // are reported at that field.
        let dyn_error = quote_spanned!(source.span()=> self.#source #asref.as_dyn_error());
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError;
                #body
            }
        }
    });

    // `Error::provide` is emitted only when the struct has a backtrace field.
    // If there is also a source field, the source's `thiserror_provide` is
    // called first. The struct's own backtrace is then offered, unless the
    // backtrace field and the source field are the same member.
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    // Body of `Display::fmt`, plus the `(field index, trait)` pairs it relies on:
    // - transparent: forward to the single field's `Display`;
    // - display attribute: destructure `self` into field bindings, then
    //   splice in the attribute's tokens;
    // - neither: no `Display` impl is generated.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds = display.implied_bounds.clone();
        let use_as_display = use_as_display(display.has_bonus_display);
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        // Turn the implied `(index, trait)` pairs into where-clause bounds.
        // Only fields that mention a generic parameter get a bound.
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` for the `from` field, where `T` is the field type with one
    // outer `Option` removed. A distinct backtrace field, if any, is filled
    // in by `from_initializer`.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    let error_trait = spanned_error_trait(input.original);
    // `Error` has `Debug` and `Display` as supertraits. When the struct has
    // type parameters, require them on `Self` explicitly.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
187 
/// Generates the trait impls for an enum annotated with `#[derive(Error)]`.
///
/// The output always contains a `std::error::Error` impl. It has a `source`
/// method when `input.has_source()` and a `provide` method when
/// `input.has_backtrace()`. Each method dispatches with a `match` that has
/// one arm per variant. The output also contains a `Display` impl when
/// `input.has_display()`, and one `From` impl per variant that has a `from`
/// field.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra bounds for the `Error` impl, filled in by the `source` arms below.
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        // `arms` is a lazy iterator. Its closure, including the inserts into
        // `error_inferred_bounds`, runs when `quote!` expands `#(#arms)*` below.
        // That happens before the bounds are read at the end of this function.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if variant.attrs.transparent.is_some() {
                // Transparent variant: delegate to the wrapped error's own `source()`.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
                }
                let member = &only_field.member;
                let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                // Variant with a source field: return it as `&dyn Error`. For an
                // `Option` field, `.as_ref()?` returns `None` early when empty.
                let source = &source_field.member;
                if source_field.contains_generic {
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
                }
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned!(source.span()=> #varsource #asref.as_dyn_error());
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                // Variant without a source.
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                // Backtrace field without a `#[backtrace]` attribute, alongside a
                // source field: let the source provide first, then offer this
                // variant's own backtrace.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror::__private::ThiserrorProvide;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                // The source field is itself the backtrace field: delegate to
                // the source only.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror::__private::ThiserrorProvide;
                            #source_provide
                        }
                    }
                }
                // Any other variant with a backtrace field (no source field, or a
                // `#[backtrace]` field distinct from the source): offer only the
                // variant's own backtrace.
                // NOTE(review): unlike `impl_struct`, a distinct source is not asked
                // to provide here. Confirm this asymmetry is intended.
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                // Variants without a backtrace provide nothing.
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // Emit a single `AsDisplay` import for the whole match if any variant's
        // display attribute needs it.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // A variant-less enum is uninhabited. `match *self {}` with no arms
        // compiles, whereas an empty match on the reference `self` is not
        // accepted as exhaustive.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds = display.implied_bounds.clone();
                    display.to_token_stream()
                }
                None => {
                    // No display attribute: forward to the first field's `Display`.
                    // This is presumably only reached for transparent variants, as
                    // guaranteed by validation (TODO confirm). Tuple fields are bound
                    // as `_0`, `_1`, ... by `fields_pat`, hence the `_` prefix here.
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly, so that the closure above has populated
        // `display_inferred_bounds` before the where clause is built from it.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From` impl per variant that has a `from` field. `filter_map` skips
    // the other variants.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    let error_trait = spanned_error_trait(input.original);
    // `Error` has `Debug` and `Display` as supertraits. When the enum has type
    // parameters, require them on `Self` explicitly.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
440 
fields_pat(fields: &[Field]) -> TokenStream441 fn fields_pat(fields: &[Field]) -> TokenStream {
442     let mut members = fields.iter().map(|field| &field.member).peekable();
443     match members.peek() {
444         Some(Member::Named(_)) => quote!({ #(#members),* }),
445         Some(Member::Unnamed(_)) => {
446             let vars = members.map(|member| match member {
447                 Member::Unnamed(member) => format_ident!("_{}", member),
448                 Member::Named(_) => unreachable!(),
449             });
450             quote!((#(#vars),*))
451         }
452         None => quote!({}),
453     }
454 }
455 
use_as_display(needs_as_display: bool) -> Option<TokenStream>456 fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
457     if needs_as_display {
458         Some(quote! {
459             use thiserror::__private::AsDisplay as _;
460         })
461     } else {
462         None
463     }
464 }
465 
from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream466 fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
467     let from_member = &from_field.member;
468     let some_source = if type_is_option(from_field.ty) {
469         quote!(::core::option::Option::Some(source))
470     } else {
471         quote!(source)
472     };
473     let backtrace = backtrace_field.map(|backtrace_field| {
474         let backtrace_member = &backtrace_field.member;
475         if type_is_option(backtrace_field.ty) {
476             quote! {
477                 #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
478             }
479         } else {
480             quote! {
481                 #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
482             }
483         }
484     });
485     quote!({
486         #from_member: #some_source,
487         #backtrace
488     })
489 }
490 
type_is_option(ty: &Type) -> bool491 fn type_is_option(ty: &Type) -> bool {
492     type_parameter_of_option(ty).is_some()
493 }
494 
unoptional_type(ty: &Type) -> TokenStream495 fn unoptional_type(ty: &Type) -> TokenStream {
496     let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
497     quote!(#unoptional)
498 }
499 
type_parameter_of_option(ty: &Type) -> Option<&Type>500 fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
501     let path = match ty {
502         Type::Path(ty) => &ty.path,
503         _ => return None,
504     };
505 
506     let last = path.segments.last().unwrap();
507     if last.ident != "Option" {
508         return None;
509     }
510 
511     let bracketed = match &last.arguments {
512         PathArguments::AngleBracketed(bracketed) => bracketed,
513         _ => return None,
514     };
515 
516     if bracketed.args.len() != 1 {
517         return None;
518     }
519 
520     match &bracketed.args[0] {
521         GenericArgument::Type(arg) => Some(arg),
522         _ => None,
523     }
524 }
525 
spanned_error_trait(input: &DeriveInput) -> TokenStream526 fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
527     let vis_span = match &input.vis {
528         Visibility::Public(vis) => Some(vis.span),
529         Visibility::Restricted(vis) => Some(vis.pub_token.span),
530         Visibility::Inherited => None,
531     };
532     let data_span = match &input.data {
533         Data::Struct(data) => data.struct_token.span,
534         Data::Enum(data) => data.enum_token.span,
535         Data::Union(data) => data.union_token.span,
536     };
537     let first_span = vis_span.unwrap_or(data_span);
538     let last_span = input.ident.span();
539     let path = quote_spanned!(first_span=> std::error::);
540     let error = quote_spanned!(last_span=> Error);
541     quote!(#path #error)
542 }
543