1 #pragma once 2 3 #include <torch/detail/static.h> 4 #include <torch/nn/cloneable.h> 5 #include <torch/nn/module.h> 6 #include <torch/nn/modules/container/any.h> 7 #include <torch/nn/modules/container/named_any.h> 8 #include <torch/nn/pimpl.h> 9 #include <torch/types.h> 10 11 #include <c10/util/Exception.h> 12 13 #include <cstdint> 14 #include <memory> 15 #include <ostream> 16 #include <string> 17 #include <type_traits> 18 #include <utility> 19 #include <vector> 20 21 namespace torch { 22 namespace nn { 23 24 /// A list of `Module`s that acts as a `Module` itself. 25 /// 26 /// A `Sequential` is fundamentally a list of `Module`s, each with a `forward()` 27 /// method. `Sequential` provides a `forward()` method of its own, which accepts 28 /// any input and forwards it to the first module it stores. It then "chains" 29 /// outputs to inputs sequentially for each subsequent module, finally returning 30 /// the output of the last module. For example: 31 /// 32 /// \rst 33 /// .. code-block:: cpp 34 /// 35 /// torch::nn::Sequential seq( 36 /// torch::nn::Linear(3, 4), 37 /// torch::nn::BatchNorm1d(4), 38 /// torch::nn::Dropout(0.5) 39 /// ); 40 /// 41 /// auto output = seq->forward(torch::ones(3)); 42 /// 43 /// \endrst 44 /// 45 /// This can conceptually be thought of as the following loop (using Python as 46 /// pseudocode): 47 /// 48 /// \rst 49 /// .. code-block:: python 50 /// 51 /// def forward(sequential, input): 52 /// for module in sequential: 53 /// input = module(input) 54 /// return input 55 /// 56 /// \endrst 57 /// 58 /// Why should you use `Sequential` instead of a simple `std::vector`? The value 59 /// a `Sequential` provides over manually calling a sequence of modules is that 60 /// it allows treating the whole container *as a single module*, such that 61 /// performing a transformation on the `Sequential` applies to each of the 62 /// modules it stores (which are each a registered submodule of the 63 /// `Sequential`). For example, calling 64 /// `.to(torch::kCUDA)` on a `Sequential` will move each module in the list to 65 /// CUDA memory. For example: 66 /// 67 /// \rst 68 /// .. code-block:: cpp 69 /// 70 /// torch::nn::Sequential seq( 71 /// torch::nn::Linear(3, 4), 72 /// torch::nn::BatchNorm1d(4), 73 /// torch::nn::Dropout(0.5) 74 /// ); 75 /// 76 /// // Convert all modules to CUDA. 77 /// seq->to(torch::kCUDA); 78 /// 79 /// \endrst 80 /// 81 /// Finally, `Sequential` provides a lightweight container API, such as allowing 82 /// iteration over submodules, positional access, adding a new module after 83 /// construction via `push_back`, as well as joining two `Sequential`s via 84 /// `extend`. 85 /// 86 /// \rst 87 /// .. attention:: 88 /// One current limitation of `Sequential` is that all except the first module 89 /// must accept a single argument. If your modules need to take multiple 90 /// arguments, you should define them to take and return tuples. 91 /// \endrst 92 class SequentialImpl : public Cloneable<SequentialImpl> { 93 public: 94 using Iterator = std::vector<AnyModule>::iterator; 95 using ConstIterator = std::vector<AnyModule>::const_iterator; 96 97 SequentialImpl() = default; 98 99 /// Constructs the `Sequential` from a variadic list of modules. 100 template <typename... Modules> SequentialImpl(Modules &&...modules)101 explicit SequentialImpl(Modules&&... modules) { 102 modules_.reserve(sizeof...(Modules)); 103 push_back(std::forward<Modules>(modules)...); 104 } 105 106 /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s. SequentialImpl(torch::OrderedDict<std::string,AnyModule> && ordered_dict)107 explicit SequentialImpl( 108 torch::OrderedDict<std::string, AnyModule>&& ordered_dict) { 109 modules_.reserve(ordered_dict.size()); 110 for (auto& item : ordered_dict) { 111 push_back(item.key(), std::move(item.value())); 112 } 113 } 114 115 /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s. 116 /// It enables the following use case: 117 /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})` SequentialImpl(std::initializer_list<NamedAnyModule> named_modules)118 explicit SequentialImpl(std::initializer_list<NamedAnyModule> named_modules) { 119 modules_.reserve(named_modules.size()); 120 for (const auto& named_module : named_modules) { 121 push_back(named_module.name(), named_module.module()); 122 } 123 } 124 125 /// Special cloning function for `Sequential` because it does not use 126 /// `reset()`. 127 std::shared_ptr<Module> clone( 128 const std::optional<Device>& device = std::nullopt) const override { 129 auto clone = std::make_shared<SequentialImpl>(); 130 for (const auto& module : modules_) { 131 clone->push_back(module.clone(device)); 132 } 133 return clone; 134 } 135 136 /// `reset()` is empty for `Sequential`, since it does not have parameters of 137 /// its own. reset()138 void reset() override {} 139 140 /// Pretty prints the `Sequential` module into the given `stream`. pretty_print(std::ostream & stream)141 void pretty_print(std::ostream& stream) const override { 142 stream << "torch::nn::Sequential"; 143 } 144 145 /// Feeds `inputs` to the first module and then chains outputs to inputs, 146 /// returning the last output. 147 /// 148 /// Conceptually the following loop in Python: 149 /// 150 /// \rst 151 /// .. code-block:: python 152 /// 153 /// def forward(sequential, input): 154 /// for module in sequential: 155 /// input = module(input) 156 /// return input 157 /// 158 /// \endrst 159 /// 160 /// The return type is taken as the first template parameter. It defaults to 161 /// `Tensor`. If the last module in the `Sequential` returns another type `T`, 162 /// you should call `forward<T>(inputs)` instead of just `forward(inputs)`: 163 /// 164 /// \rst 165 /// .. code-block:: cpp 166 /// 167 /// torch::Tensor tensor = sequential1->forward(inputs); 168 /// int integer = sequential2->forward<int>(inputs); 169 /// float value = sequential3->forward<float>(inputs); 170 /// 171 /// \endrst 172 template <typename ReturnType = Tensor, typename... InputTypes> forward(InputTypes &&...inputs)173 ReturnType forward(InputTypes&&... inputs) { 174 TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential"); 175 176 auto iterator = modules_.begin(); 177 auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...); 178 179 for (++iterator; iterator != modules_.end(); ++iterator) { 180 input = iterator->any_forward(std::move(input)); 181 } 182 183 // Check the return value and give a nice error message if the requested 184 // return type was incorrect. 185 if (auto* return_value = input.template try_get<ReturnType>()) { 186 return std::move(*return_value); 187 } 188 AT_ERROR( 189 "The type of the return value is ", 190 c10::demangle(input.type_info().name()), 191 ", but you asked for type ", 192 c10::demangle(typeid(ReturnType).name())); 193 } 194 195 /// Adds a new (boxed) `Module` to the `Sequential` container. 196 template <typename ModuleType> push_back(std::shared_ptr<ModuleType> module_ptr)197 void push_back(std::shared_ptr<ModuleType> module_ptr) { 198 push_back(std::to_string(modules_.size()), std::move(module_ptr)); 199 } 200 201 /// Adds a new named (boxed) `Module` to the `Sequential` container. 202 template <typename ModuleType> push_back(std::string name,std::shared_ptr<ModuleType> module_ptr)203 void push_back(std::string name, std::shared_ptr<ModuleType> module_ptr) { 204 push_back(std::move(name), AnyModule(std::move(module_ptr))); 205 } 206 207 /// Adds a new `Module` to the `Sequential` container, moving or copying it 208 /// into a `shared_ptr` internally. This method allows passing value types, 209 /// and letting the container deal with the boxing. This means you can write 210 /// `Sequential(Module(3, 4))` instead of 211 /// `Sequential(std::make_shared<Module>(3, 4))`. 212 template <typename M, typename = torch::detail::enable_if_module_t<M>> push_back(M && module)213 void push_back(M&& module) { 214 push_back(std::to_string(modules_.size()), std::forward<M>(module)); 215 } 216 217 /// Adds a new named `Module` to the `Sequential` container, moving or copying 218 /// it into a `shared_ptr` internally. This method allows passing value types, 219 /// and letting the container deal with the boxing. 220 template <typename M, typename = torch::detail::enable_if_module_t<M>> push_back(std::string name,M && module)221 void push_back(std::string name, M&& module) { 222 using Type = typename std::remove_reference_t<M>; 223 push_back(std::move(name), std::make_shared<Type>(std::forward<M>(module))); 224 } 225 226 /// Unwraps the contained module of a `ModuleHolder` and adds it to the 227 /// `Sequential`. 228 template <typename M> push_back(const ModuleHolder<M> & module_holder)229 void push_back(const ModuleHolder<M>& module_holder) { 230 push_back(std::to_string(modules_.size()), module_holder); 231 } 232 233 /// Unwraps the contained named module of a `ModuleHolder` and adds it to the 234 /// `Sequential`. 235 template <typename M> push_back(std::string name,const ModuleHolder<M> & module_holder)236 void push_back(std::string name, const ModuleHolder<M>& module_holder) { 237 push_back(std::move(name), module_holder.ptr()); 238 } 239 240 /// Iterates over the container and calls `push_back()` on each value. 241 template <typename Container> extend(const Container & container)242 void extend(const Container& container) { 243 for (const auto& module : container) { 244 push_back(module); 245 } 246 } 247 248 /// Adds a type-erased `AnyModule` to the `Sequential`. push_back(AnyModule any_module)249 void push_back(AnyModule any_module) { 250 push_back(std::to_string(modules_.size()), std::move(any_module)); 251 } 252 push_back(std::string name,AnyModule any_module)253 void push_back(std::string name, AnyModule any_module) { 254 modules_.push_back(std::move(any_module)); 255 const auto index = modules_.size() - 1; 256 register_module(std::move(name), modules_[index].ptr()); 257 } 258 259 /// Returns an iterator to the start of the `Sequential`. begin()260 Iterator begin() { 261 return modules_.begin(); 262 } 263 264 /// Returns a const iterator to the start of the `Sequential`. begin()265 ConstIterator begin() const { 266 return modules_.begin(); 267 } 268 269 /// Returns an iterator to the end of the `Sequential`. end()270 Iterator end() { 271 return modules_.end(); 272 } 273 274 /// Returns a const iterator to the end of the `Sequential`. end()275 ConstIterator end() const { 276 return modules_.end(); 277 } 278 279 /// Attempts to return the module at the given index as the requested type. 280 /// Throws an exception if the index is out of bounds or the types do not 281 /// match. 282 template <typename T> at(size_t index)283 T& at(size_t index) { 284 static_assert( 285 torch::detail::is_module<T>::value, 286 "Can only call Sequential::at with an nn::Module type"); 287 TORCH_CHECK(index < size(), "Index out of range"); 288 return modules_[index].get<T>(); 289 } 290 291 /// Attempts to return the module at the given index as the requested type. 292 /// Throws an exception if the index is out of bounds or the types do not 293 /// match. 294 template <typename T> at(size_t index)295 const T& at(size_t index) const { 296 static_assert( 297 torch::detail::is_module<T>::value, 298 "Can only call Sequential::at with an nn::Module type"); 299 TORCH_CHECK(index < size(), "Index out of range"); 300 return modules_[index].get<T>(); 301 } 302 303 /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the 304 /// underlying module at the given index. Throws an exception if the index is 305 /// out of bounds. ptr(size_t index)306 std::shared_ptr<Module> ptr(size_t index) const { 307 TORCH_CHECK(index < size(), "Index out of range"); 308 return modules_[index].ptr(); 309 } 310 311 /// Attempts to return a `std::shared_ptr` whose type is the one provided. 312 /// Throws an exception if the index is out of bounds or the types do not 313 /// match. 314 template <typename T> ptr(size_t index)315 std::shared_ptr<T> ptr(size_t index) const { 316 static_assert( 317 torch::detail::is_module<T>::value, 318 "Can only call Sequential::ptr with an nn::Module type"); 319 TORCH_CHECK(index < size(), "Index out of range"); 320 return modules_[index].ptr<T>(); 321 } 322 323 /// Like `ptr(index)`. 324 std::shared_ptr<Module> operator[](size_t index) const { 325 // This is the only method we can call without a type. 326 return ptr(index); 327 } 328 329 /// The current size of the `Sequential` container. size()330 size_t size() const noexcept { 331 return modules_.size(); 332 } 333 334 /// True if there are no modules in the `Sequential`. is_empty()335 bool is_empty() const noexcept { 336 return size() == 0; 337 } 338 339 private: 340 /// Takes a First *and* Second parameter, to avoid ambiguity when a parameter 341 /// pack has only one type, in which case the template would be preferred, 342 /// even if the other `push_back` functions are better fits (e.g. `unique_ptr` 343 /// -> `shared_ptr` overload). 344 /// NOTE: We explicitly avoid matching this template with 345 /// `push_back(std::string("name"), module)` or `push_back("name", module)`, 346 /// since they should be handled by their respective `push_back` functions. 347 template < 348 typename First, 349 typename Second, 350 typename... Rest, 351 typename = std::enable_if_t< 352 !std::is_same_v<First, std::string> && 353 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) 354 !std::is_same_v<std::decay_t<First>, std::decay_t<const char (&)[]>>>> push_back(First && first,Second && second,Rest &&...rest)355 void push_back(First&& first, Second&& second, Rest&&... rest) { 356 push_back(std::forward<First>(first)); 357 // Recursively calls this method, until the parameter pack only thas this 358 // entry left. Then calls `push_back()` a final time (above). 359 push_back(std::forward<Second>(second), std::forward<Rest>(rest)...); 360 } 361 362 /// The base case, when the list of modules is empty. push_back()363 void push_back() {} 364 365 // Box the AnyModules to give Sequential reference semantics, like the rest of 366 // the API. Note that this is not required otherwise, this could just be a 367 // `vector<AnyModule>`. 368 std::vector<AnyModule> modules_; 369 }; 370 371 /// A `ModuleHolder` subclass for `SequentialImpl`. 372 /// See the documentation for `SequentialImpl` class to learn what methods it 373 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 374 /// module storage semantics. 375 class Sequential : public torch::nn::ModuleHolder<SequentialImpl> { 376 public: 377 using torch::nn::ModuleHolder<SequentialImpl>::ModuleHolder; 378 Sequential()379 Sequential() : ModuleHolder() {} 380 381 /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s. 382 /// It enables the following use case: 383 /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})` Sequential(std::initializer_list<NamedAnyModule> named_modules)384 Sequential(std::initializer_list<NamedAnyModule> named_modules) 385 : ModuleHolder(std::make_shared<SequentialImpl>(named_modules)) {} 386 }; 387 } // namespace nn 388 } // namespace torch 389