1 use crate::ast::{Enum, Field, Input, Struct};
2 use crate::attr::Trait;
3 use crate::generics::InferredBounds;
4 use proc_macro2::TokenStream;
5 use quote::{format_ident, quote, quote_spanned, ToTokens};
6 use std::collections::BTreeSet as Set;
7 use syn::spanned::Spanned;
8 use syn::{
9 Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
10 };
11
derive(node: &DeriveInput) -> Result<TokenStream>12 pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
13 let input = Input::from_syn(node)?;
14 input.validate()?;
15 Ok(match input {
16 Input::Struct(input) => impl_struct(input),
17 Input::Enum(input) => impl_enum(input),
18 })
19 }
20
/// Expands `#[derive(Error)]` for a struct.
///
/// Emits:
/// - an `std::error::Error` impl, with a `source` method when the struct is
///   transparent or has a source field, and a `provide` method when it has a
///   backtrace field;
/// - a `Display` impl when the struct is transparent or has a display attribute;
/// - a `From` impl when `input.from_field()` finds a field.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds on generic field types needed by the `Error` impl; merged into its
    // where clause near the end via `augment_where_clause`.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`, if the struct needs one.
    let source_body = if input.attrs.transparent.is_some() {
        // Transparent: delegate to the single field's own `source()`.
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
        }
        let member = &only_field.member;
        Some(quote! {
            std::error::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        // The source field itself is returned. For an `Option<_>` field,
        // `.as_ref()?` makes `source()` return `None` when the field is `None`.
        let source = &source_field.member;
        if source_field.contains_generic {
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
        }
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned!(source.span()=> self.#source #asref.as_dyn_error());
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    // Wrap the body into an `Error::source` method; `AsDynError` is imported
    // for the `.as_dyn_error()` calls in the body.
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError;
                #body
            }
        }
    });

    // `Error::provide`, generated only when the struct has a backtrace field.
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            // First forward the request to the source (skipped when an
            // optional source is `None`)...
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // ...then provide this struct's own backtrace, unless the
            // backtrace field *is* the source field, in which case only the
            // forwarding above is emitted.
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            // No source: provide the backtrace only when it is present.
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    // `(field index, trait)` pairs the Display body relies on; generic fields
    // among them become where-clause bounds on the Display impl.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: display exactly as the only field does.
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds = display.implied_bounds.clone();
        let use_as_display = use_as_display(display.has_bonus_display);
        // Destructure `self` so that the display tokens can refer to fields by
        // their binding names (field names, or `_0`, `_1`, ... for tuple
        // structs; see `fields_pat`).
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` impl for the from field, where `T` is the field type with any
    // `Option<...>` wrapper removed. It uses the type's declared where clause,
    // not the inferred bounds.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    let error_trait = spanned_error_trait(input.original);
    // `Error` has `Debug + Display` supertraits; with type parameters present,
    // require them on `Self` explicitly in the Error impl's where clause.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
187
/// Expands `#[derive(Error)]` for an enum.
///
/// Emits an `std::error::Error` impl whose `source`/`provide` methods (each
/// generated only when `input.has_source()` / `input.has_backtrace()`) match
/// on every variant. It also emits a `Display` impl when `input.has_display()`,
/// and one `From` impl per variant for which `variant.from_field()` finds a
/// field.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds on generic field types needed by the `Error` impl; filled in while
    // building the `source` arms and merged into the where clause at the end.
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        // One match arm per variant. Patterns use brace syntax with the field
        // member (a name or a tuple index), which Rust accepts for both struct
        // and tuple variants.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if variant.attrs.transparent.is_some() {
                // Transparent variant: delegate to the single field's `source()`.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
                }
                let member = &only_field.member;
                let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                // Return the source field itself; `.as_ref()?` maps an optional
                // source that is `None` to `None`.
                let source = &source_field.member;
                if source_field.contains_generic {
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
                }
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned!(source.span()=> #varsource #asref.as_dyn_error());
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                // Variant without a source.
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                // Backtrace field without an explicit `backtrace` attribute
                // plus a source: forward to the source, then provide the
                // variant's own backtrace.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror::__private::ThiserrorProvide;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                // The backtrace field is the source field: only forward the
                // request to the source.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror::__private::ThiserrorProvide;
                            #source_provide
                        }
                    }
                }
                // Any other variant with a backtrace: provide only the variant's
                // own backtrace.
                // NOTE(review): this arm is also reached when the variant has an
                // explicitly attributed backtrace field *and* a distinct source
                // field; unlike `impl_struct`, the source's `provide` is not
                // forwarded there. Confirm this is intended.
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                // No backtrace field: nothing to provide.
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // Import `AsDisplay` once for the whole `fmt` body if any variant's
        // display attribute needs it.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // With no variants the match below has no arms. An empty match on the
        // reference `self` is not exhaustive, whereas one on `*self` is.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            // `(field index, trait)` pairs this variant's display relies on.
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds = display.implied_bounds.clone();
                    display.to_token_stream()
                }
                None => {
                    // No display attribute: display the variant's first field,
                    // bound under the same name `fields_pat` gives it.
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly: the closure above mutably borrows
        // `display_inferred_bounds`, which must be finished before it is
        // consumed on the next line.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<T>` impl per variant with a from field, where `T` is the field
    // type with any `Option<...>` wrapper removed. These use the type's declared
    // where clause, not the inferred bounds.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    let error_trait = spanned_error_trait(input.original);
    // `Error` has `Debug + Display` supertraits; with type parameters present,
    // require them on `Self` explicitly in the Error impl's where clause.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
440
fields_pat(fields: &[Field]) -> TokenStream441 fn fields_pat(fields: &[Field]) -> TokenStream {
442 let mut members = fields.iter().map(|field| &field.member).peekable();
443 match members.peek() {
444 Some(Member::Named(_)) => quote!({ #(#members),* }),
445 Some(Member::Unnamed(_)) => {
446 let vars = members.map(|member| match member {
447 Member::Unnamed(member) => format_ident!("_{}", member),
448 Member::Named(_) => unreachable!(),
449 });
450 quote!((#(#vars),*))
451 }
452 None => quote!({}),
453 }
454 }
455
use_as_display(needs_as_display: bool) -> Option<TokenStream>456 fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
457 if needs_as_display {
458 Some(quote! {
459 use thiserror::__private::AsDisplay as _;
460 })
461 } else {
462 None
463 }
464 }
465
from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream466 fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
467 let from_member = &from_field.member;
468 let some_source = if type_is_option(from_field.ty) {
469 quote!(::core::option::Option::Some(source))
470 } else {
471 quote!(source)
472 };
473 let backtrace = backtrace_field.map(|backtrace_field| {
474 let backtrace_member = &backtrace_field.member;
475 if type_is_option(backtrace_field.ty) {
476 quote! {
477 #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
478 }
479 } else {
480 quote! {
481 #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
482 }
483 }
484 });
485 quote!({
486 #from_member: #some_source,
487 #backtrace
488 })
489 }
490
type_is_option(ty: &Type) -> bool491 fn type_is_option(ty: &Type) -> bool {
492 type_parameter_of_option(ty).is_some()
493 }
494
unoptional_type(ty: &Type) -> TokenStream495 fn unoptional_type(ty: &Type) -> TokenStream {
496 let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
497 quote!(#unoptional)
498 }
499
type_parameter_of_option(ty: &Type) -> Option<&Type>500 fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
501 let path = match ty {
502 Type::Path(ty) => &ty.path,
503 _ => return None,
504 };
505
506 let last = path.segments.last().unwrap();
507 if last.ident != "Option" {
508 return None;
509 }
510
511 let bracketed = match &last.arguments {
512 PathArguments::AngleBracketed(bracketed) => bracketed,
513 _ => return None,
514 };
515
516 if bracketed.args.len() != 1 {
517 return None;
518 }
519
520 match &bracketed.args[0] {
521 GenericArgument::Type(arg) => Some(arg),
522 _ => None,
523 }
524 }
525
spanned_error_trait(input: &DeriveInput) -> TokenStream526 fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
527 let vis_span = match &input.vis {
528 Visibility::Public(vis) => Some(vis.span),
529 Visibility::Restricted(vis) => Some(vis.pub_token.span),
530 Visibility::Inherited => None,
531 };
532 let data_span = match &input.data {
533 Data::Struct(data) => data.struct_token.span,
534 Data::Enum(data) => data.enum_token.span,
535 Data::Union(data) => data.union_token.span,
536 };
537 let first_span = vis_span.unwrap_or(data_span);
538 let last_span = input.ident.span();
539 let path = quote_spanned!(first_span=> std::error::);
540 let error = quote_spanned!(last_span=> Error);
541 quote!(#path #error)
542 }
543