xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/detail/static.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/utils/variadic.h>
4 #include <torch/types.h>
5 
6 #include <cstdint>
7 #include <type_traits>
8 
// Forward declaration only: the traits in torch::detail just need to name
// `torch::nn::Module` (it is the base class tested by `is_module`).
namespace torch { namespace nn {
class Module;
}} // namespace torch::nn
14 
15 namespace torch {
16 namespace detail {
/// Compile-time check for whether a type `T` exposes a member named `forward`.
///
/// `has_forward<T>::value` is true when the expression `&T::forward` is
/// well-formed, and false otherwise (including for non-class types).
///
/// NOTE(review): if `forward` is overloaded or is a member template,
/// `&T::forward` has no single type, so this trait presumably reports false
/// for such types — confirm that is the intended behavior.
template <typename T>
struct has_forward {
  // Two distinct marker types; the return type of the `test` overload that
  // wins overload resolution records which candidate was selected.
  using yes = int8_t;
  using no = int16_t;

  // Fallback candidate. An ellipsis conversion ranks below every standard or
  // user-defined conversion sequence, so this overload is picked only when
  // the one below has been discarded.
  template <typename U>
  static no test(...);

  // Preferred candidate. Its parameter type is the type of `&U::forward`; if
  // that expression is ill-formed, template argument substitution fails and
  // the candidate is silently dropped (SFINAE). Otherwise `nullptr` converts
  // to the pointer / pointer-to-member type that `&U::forward` yields, so the
  // candidate is viable and beats the ellipsis overload.
  template <typename U>
  static yes test(decltype(&U::forward));

  // Ask which overload `test<T>(nullptr)` resolves to, without evaluating it.
  static constexpr bool value =
      std::is_same<decltype(test<T>(nullptr)), yes>::value;
};
43 
/// Returns true unless some type in the pack is a *non-const* lvalue
/// reference. Plain value types, rvalue references and `const` lvalue
/// references all pass: `int&` fails, while `int`, `int&&` and `const int&`
/// are accepted.
///
/// The pack is consumed one type at a time. `Head` defaults to `void`, so an
/// empty pack (`check_not_lvalue_references<>()`) resolves to the `void`
/// specialization below, which ends the recursion.
template <typename Head = void, typename... Tail>
constexpr bool check_not_lvalue_references() {
  // Kept as a single return expression (C++11-style constexpr body).
  // Reject `Head` only when it is an lvalue reference to a non-const type.
  return !(std::is_lvalue_reference<Head>::value &&
           !std::is_const<typename std::remove_reference<Head>::type>::value) &&
      check_not_lvalue_references<Tail...>();
}

/// Recursion terminator: every type in the pack has been checked (the empty
/// pack defaults `Head` to `void`), so the check trivially succeeds.
template <>
inline constexpr bool check_not_lvalue_references<void>() {
  return true;
}
55 
56 /// A type trait whose `value` member is true if `M` derives from `Module`.
57 template <typename M>
58 using is_module =
59     std::is_base_of<torch::nn::Module, typename std::decay<M>::type>;
60 
61 template <typename M, typename T = void>
62 using enable_if_module_t =
63     typename std::enable_if<is_module<M>::value, T>::type;
64 } // namespace detail
65 } // namespace torch
66