1 // This class exists only to do SFINAE on abstract types `T` that are really 2 // `ModuleHolder<ModuleType>`, because there's no good way to say that `T` is a 3 // `ModuleHolder` over some unknown type `ModuleType`. With this, you can do 4 // `enable_if_t<is_base_of_v<ModuleHolderIndicator, T>>`. 5 struct ModuleHolderIndicator {}; 6 7 // A type trait that is true for types that are `ModuleHolder`s. 8 template <typename T> 9 using is_module_holder = 10 std::is_base_of<ModuleHolderIndicator, std::decay_t<T>>; 11 12 template <typename T> 13 using disable_if_module_holder_t = 14 std::enable_if_t<!is_module_holder<T>::value>; 15 16 // A collection of templates that answer the question whether a type `T` is a 17 // `ModuleHolder`, and if so whether its contained type is of type `C`. This is 18 // tricky because it is hard to short circuit in template metaprogramming. A 19 // naive and incorrect solution to this problem would be something like 20 // `disable_if<is_module_holder<T>::value && typename T::ContainedType == C>`. 21 // This would disable all types that are not `ModuleHolder`s, because even 22 // though the `is_module_holder<T>::value` may be `false` for such types the 23 // `T::ContainedType` access would be ill-formed and thus fail the whole 24 // expression by the rules of SFINAE. Instead we have to use template 25 // specialization to statically branch on the first condition 26 // (`is_module_holder<T>`) and are only then allowed to query 27 // `T::ContainedType` in the branch for which the condition was true. 28 29 // Base template. 30 template <bool is_module_holder_value, typename T, typename C> 31 struct is_module_holder_of_impl; 32 33 // False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with 34 // contained type `C`. 35 template <typename T, typename C> 36 struct is_module_holder_of_impl<false, T, C> : std::false_type {}; 37 38 // True branch. `T` is a `ModuleHolder` and thus we can legit access its 39 // `ContainedType` and compare it against `C`. 40 template <typename T, typename C> 41 struct is_module_holder_of_impl<true, T, C> 42 : std::is_same<typename T::ContainedType, C> {}; 43 44 // Helper template. 45 template <typename T, typename C> 46 struct is_module_holder_of : is_module_holder_of_impl< 47 is_module_holder<T>::value, 48 std::decay_t<T>, 49 std::decay_t<C>> {}; 50 51 // A collection of templates that allow deducing the return type of the 52 // `forward()` method, but only if a module actually has a `forward()` method, 53 // and otherwise deduces to the type `void`. 54 55 template <bool has_forward_value, typename C, typename... Args> 56 struct return_type_of_forward_impl; 57 58 template <typename C, typename... Args> 59 struct return_type_of_forward_impl<true, C, Args...> { 60 using type = decltype(::std::declval<C>().forward(::std::declval<Args>()...)); 61 }; 62 63 template <typename C, typename... Args> 64 struct return_type_of_forward_impl<false, C, Args...> { 65 using type = void; 66 }; 67 68 template <typename C, typename... Args> 69 using return_type_of_forward = return_type_of_forward_impl< 70 torch::detail::has_forward<C>::value, 71 C, 72 Args...>; 73 74 template <typename C, typename... Args> 75 using return_type_of_forward_t = 76 typename return_type_of_forward<C, Args...>::type; 77