xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/pimpl-inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Empty tag type whose sole purpose is SFINAE. It is meant to be a base of
// `ModuleHolder<ModuleType>`, so templates can recognise "a `ModuleHolder` of
// some unknown `ModuleType`" without naming `ModuleType`, e.g. via
// `enable_if_t<is_base_of_v<ModuleHolderIndicator, T>>`.
struct ModuleHolderIndicator {};

// Trait: `value` is true iff `T`, after `std::decay` (references and
// cv-qualifiers stripped), has `ModuleHolderIndicator` as a base (or is that
// tag type itself), i.e. `T` is a `ModuleHolder`.
template <typename T>
using is_module_holder =
    std::is_base_of<ModuleHolderIndicator, typename std::decay<T>::type>;

// SFINAE helper: names `void` when `T` is not a `ModuleHolder`; otherwise it
// is ill-formed, which removes the enclosing template from overload sets.
template <typename T>
using disable_if_module_holder_t =
    typename std::enable_if<!is_module_holder<T>::value>::type;
15 
16 // A collection of templates that answer the question whether a type `T` is a
17 // `ModuleHolder`, and if so whether its contained type is of type `C`. This is
18 // tricky because it is hard to short circuit in template metaprogramming. A
19 // naive and incorrect solution to this problem would be something like
20 // `disable_if<is_module_holder<T>::value && typename T::ContainedType == C>`.
21 // This would disable all types that are not `ModuleHolder`s, because even
22 // though the `is_module_holder<T>::value` may be `false` for such types the
23 // `T::ContainedType` access would be ill-formed and thus fail the whole
24 // expression by the rules of SFINAE. Instead we have to use template
25 // specialization to statically branch on the first condition
26 // (`is_module_holder<T>`) and are only then allowed to query
27 // `T::ContainedType` in the branch for which the condition was true.
28 
29 // Base template.
30 template <bool is_module_holder_value, typename T, typename C>
31 struct is_module_holder_of_impl;
32 
33 // False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with
34 // contained type `C`.
35 template <typename T, typename C>
36 struct is_module_holder_of_impl<false, T, C> : std::false_type {};
37 
38 // True branch. `T` is a `ModuleHolder` and thus we can legit access its
39 // `ContainedType` and compare it against `C`.
40 template <typename T, typename C>
41 struct is_module_holder_of_impl<true, T, C>
42     : std::is_same<typename T::ContainedType, C> {};
43 
44 // Helper template.
45 template <typename T, typename C>
46 struct is_module_holder_of : is_module_holder_of_impl<
47                                  is_module_holder<T>::value,
48                                  std::decay_t<T>,
49                                  std::decay_t<C>> {};
50 
51 // A collection of templates that allow deducing the return type of the
52 // `forward()` method, but only if a module actually has a `forward()` method,
53 // and otherwise deduces to the type `void`.
54 
55 template <bool has_forward_value, typename C, typename... Args>
56 struct return_type_of_forward_impl;
57 
58 template <typename C, typename... Args>
59 struct return_type_of_forward_impl<true, C, Args...> {
60   using type = decltype(::std::declval<C>().forward(::std::declval<Args>()...));
61 };
62 
63 template <typename C, typename... Args>
64 struct return_type_of_forward_impl<false, C, Args...> {
65   using type = void;
66 };
67 
68 template <typename C, typename... Args>
69 using return_type_of_forward = return_type_of_forward_impl<
70     torch::detail::has_forward<C>::value,
71     C,
72     Args...>;
73 
74 template <typename C, typename... Args>
75 using return_type_of_forward_t =
76     typename return_type_of_forward<C, Args...>::type;
77