1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/detail/static.h> 5 #include <torch/serialize/archive.h> 6 #include <torch/types.h> 7 8 #include <torch/csrc/utils/variadic.h> 9 10 #include <memory> 11 #include <type_traits> 12 #include <utility> 13 14 namespace torch { 15 namespace detail { 16 // Dump all the template metaprogramming in this file. 17 #include <torch/csrc/api/include/torch/nn/pimpl-inl.h> 18 } // namespace detail 19 20 namespace nn { 21 22 /// A `ModuleHolder` is essentially a wrapper around `std::shared_ptr<M>` where 23 /// `M` is an `nn::Module` subclass, with convenient constructors defined for 24 /// the kind of constructions we want to allow for our modules. 25 template <typename Contained> 26 class ModuleHolder : torch::detail::ModuleHolderIndicator { 27 protected: 28 /// The module pointer this class wraps. 29 /// NOTE: Must be placed at the top of the class so that we can use it with 30 /// trailing return types below. 31 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 32 std::shared_ptr<Contained> impl_; 33 34 public: 35 using ContainedType = Contained; 36 37 /// Default constructs the contained module if if has a default constructor, 38 /// else produces a static error. 39 /// 40 /// NOTE: This uses the behavior of template 41 /// classes in C++ that constructors (or any methods) are only compiled when 42 /// actually used. ModuleHolder()43 ModuleHolder() : impl_(default_construct()) { 44 static_assert( 45 std::is_default_constructible<Contained>::value, 46 "You are trying to default construct a module which has " 47 "no default constructor. Use = nullptr to give it the empty state " 48 "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`)."); 49 } 50 51 /// Constructs the `ModuleHolder` with an empty contained value. Access to 52 /// the underlying module is not permitted and will throw an exception, until 53 /// a value is assigned. ModuleHolder(std::nullptr_t)54 /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {} 55 56 /// Constructs the `ModuleHolder` with a contained module, forwarding all 57 /// arguments to its constructor. 58 template < 59 typename Head, 60 typename... Tail, 61 typename = typename std::enable_if< 62 !(torch::detail::is_module_holder_of<Head, ContainedType>::value && 63 (sizeof...(Tail) == 0))>::type> ModuleHolder(Head && head,Tail &&...tail)64 explicit ModuleHolder(Head&& head, Tail&&... tail) 65 : impl_(new Contained( 66 std::forward<Head>(head), 67 std::forward<Tail>(tail)...)) {} 68 69 /// Constructs the `ModuleHolder` from a pointer to the contained type. 70 /// Example: `Linear(std::make_shared<LinearImpl>(...))`. ModuleHolder(std::shared_ptr<Contained> module)71 /* implicit */ ModuleHolder(std::shared_ptr<Contained> module) 72 : impl_(std::move(module)) {} 73 74 /// Returns true if the `ModuleHolder` contains a module, or false if it is 75 /// `nullptr`. 76 explicit operator bool() const noexcept { 77 return !is_empty(); 78 } 79 80 /// Forwards to the contained module. 81 Contained* operator->() { 82 return get(); 83 } 84 85 /// Forwards to the contained module. 86 const Contained* operator->() const { 87 return get(); 88 } 89 90 /// Returns a reference to the contained module. 91 Contained& operator*() { 92 return *get(); 93 } 94 95 /// Returns a const reference to the contained module. 96 const Contained& operator*() const { 97 return *get(); 98 } 99 100 /// Returns a shared pointer to the underlying module. ptr()101 const std::shared_ptr<Contained>& ptr() const { 102 TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); 103 return impl_; 104 } 105 106 /// Returns a pointer to the underlying module. get()107 Contained* get() { 108 TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); 109 return impl_.get(); 110 } 111 112 /// Returns a const pointer to the underlying module. get()113 const Contained* get() const { 114 TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); 115 return impl_.get(); 116 } 117 118 /// Calls the `forward()` method of the contained module. 119 template <typename... Args> 120 auto operator()(Args&&... args) 121 -> torch::detail::return_type_of_forward_t<Contained, Args...> { 122 // This will not compile if the module does not have a `forward()` method 123 // (as expected). 124 // NOTE: `std::forward` is qualified to prevent VS2017 emitting 125 // error C2872: 'std': ambiguous symbol 126 return impl_->forward(::std::forward<Args>(args)...); 127 } 128 129 /// Forwards to the subscript operator of the contained module. 130 /// NOTE: std::forward is qualified to prevent VS2017 emitting 131 /// error C2872: 'std': ambiguous symbol 132 template <typename Arg> decltype(auto)133 decltype(auto) operator[](Arg&& arg) { 134 return (*impl_)[::std::forward<Arg>(arg)]; 135 } 136 137 /// Returns true if the `ModuleHolder` does not contain a module. is_empty()138 bool is_empty() const noexcept { 139 return impl_ == nullptr; 140 } 141 142 private: 143 template <typename T = Contained> default_construct()144 std::shared_ptr<Contained> default_construct() { 145 if constexpr (std::is_default_constructible_v<T>) { 146 return std::make_shared<Contained>(); 147 } else { 148 return nullptr; 149 } 150 } 151 }; 152 153 /// Pretty prints the given `Module` into the `ostream`. 154 template <typename ModuleType> 155 std::ostream& operator<<( 156 std::ostream& stream, 157 const nn::ModuleHolder<ModuleType>& module) { 158 return stream << *module; 159 } 160 161 /// Serializes a `ModuleHolder` into an `OutputArchive`. 162 template <typename ModuleType> 163 serialize::OutputArchive& operator<<( 164 serialize::OutputArchive& archive, 165 const nn::ModuleHolder<ModuleType>& module) { 166 return archive << module.ptr(); 167 } 168 169 /// Deserializes a `ModuleHolder` from an `InputArchive`. 170 template <typename ModuleType> 171 serialize::InputArchive& operator>>( 172 serialize::InputArchive& archive, 173 nn::ModuleHolder<ModuleType>& module) { 174 return archive >> module.ptr(); 175 } 176 177 } // namespace nn 178 } // namespace torch 179 180 // Workaround for CUDA 10.2 and below not allowing attribute unused on 181 // using declarations. 182 #ifdef __CUDACC__ 183 #define TORCH_UNUSED_EXCEPT_CUDA 184 #else 185 #define TORCH_UNUSED_EXCEPT_CUDA C10_UNUSED 186 #endif 187 188 /// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a 189 /// wrapper over a `std::shared_ptr<ImplType>`. 190 /// `Impl` is a type alias for `ImplType` which provides a way to call static 191 /// method of `ImplType`. 192 #define TORCH_MODULE_IMPL(Name, ImplType) \ 193 class Name : public torch::nn::ModuleHolder<ImplType> { /* NOLINT */ \ 194 public: \ 195 using torch::nn::ModuleHolder<ImplType>::ModuleHolder; \ 196 using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType; \ 197 } 198 199 /// Like `TORCH_MODULE_IMPL`, but defaults the `ImplType` name to `<Name>Impl`. 200 #define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl) 201