1 #pragma once 2 3 #include <torch/csrc/utils/variadic.h> 4 #include <torch/types.h> 5 6 #include <cstdint> 7 #include <type_traits> 8 9 namespace torch { 10 namespace nn { 11 class Module; 12 } // namespace nn 13 } // namespace torch 14 15 namespace torch { 16 namespace detail { 17 /// Detects if a type T has a forward() method. 18 template <typename T> 19 struct has_forward { 20 // Declare two types with differing size. 21 using yes = int8_t; 22 using no = int16_t; 23 24 // Here we declare two functions. The first is only enabled if `&U::forward` 25 // is well-formed and returns the `yes` type. In C++, the ellipsis parameter 26 // type (`...`) always puts the function at the bottom of overload resolution. 27 // This is specified in the standard as: 1) A standard conversion sequence is 28 // always better than a user-defined conversion sequence or an ellipsis 29 // conversion sequence. 2) A user-defined conversion sequence is always better 30 // than an ellipsis conversion sequence This means that if the first overload 31 // is viable, it will be preferred over the second as long as we pass any 32 // convertible type. The type of `&U::forward` is a pointer type, so we can 33 // pass e.g. 0. 34 template <typename U> 35 static yes test(decltype(&U::forward)); 36 template <typename U> 37 static no test(...); 38 39 // Finally we test statically whether the size of the type returned by the 40 // selected overload is the size of the `yes` type. 41 static constexpr bool value = (sizeof(test<T>(nullptr)) == sizeof(yes)); 42 }; 43 44 template <typename Head = void, typename... Tail> check_not_lvalue_references()45constexpr bool check_not_lvalue_references() { 46 return (!std::is_lvalue_reference<Head>::value || 47 std::is_const<typename std::remove_reference<Head>::type>::value) && 48 check_not_lvalue_references<Tail...>(); 49 } 50 51 template <> 52 inline constexpr bool check_not_lvalue_references<void>() { 53 return true; 54 } 55 56 /// A type trait whose `value` member is true if `M` derives from `Module`. 57 template <typename M> 58 using is_module = 59 std::is_base_of<torch::nn::Module, typename std::decay<M>::type>; 60 61 template <typename M, typename T = void> 62 using enable_if_module_t = 63 typename std::enable_if<is_module<M>::value, T>::type; 64 } // namespace detail 65 } // namespace torch 66