1 #pragma once 2 3 #include <torch/detail/static.h> 4 #include <torch/nn/module.h> 5 #include <torch/nn/modules/container/any.h> 6 #include <torch/nn/pimpl.h> 7 #include <torch/types.h> 8 9 #include <torch/csrc/autograd/variable.h> 10 #include <torch/csrc/utils/variadic.h> 11 12 #include <ATen/Device.h> 13 14 #include <initializer_list> 15 #include <memory> 16 #include <type_traits> 17 #include <typeinfo> 18 #include <utility> 19 #include <vector> 20 21 namespace torch { 22 namespace nn { 23 24 /// Stores a type erased `Module` with name. 25 /// 26 /// The `NamedAnyModule` class enables the following API for constructing 27 /// `nn::Sequential` with named submodules: 28 /// \rst 29 /// .. code-block:: cpp 30 /// 31 /// struct M : torch::nn::Module { 32 /// explicit M(int value_) : value(value_) {} 33 /// int value; 34 /// int forward() { 35 /// return value; 36 /// } 37 /// }; 38 /// 39 /// Sequential sequential({ 40 /// {"m1", std::make_shared<M>(1)}, // shared pointer to `Module` is 41 /// supported {std::string("m2"), M(2)}, // `Module` is supported 42 /// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported 43 /// }); 44 /// \endrst 45 class NamedAnyModule { 46 public: 47 /// Creates a `NamedAnyModule` from a (boxed) `Module`. 48 template <typename ModuleType> NamedAnyModule(std::string name,std::shared_ptr<ModuleType> module_ptr)49 NamedAnyModule(std::string name, std::shared_ptr<ModuleType> module_ptr) 50 : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {} 51 52 /// Creates a `NamedAnyModule` from a `Module`, moving or copying it 53 /// into a `shared_ptr` internally. 54 // NOTE: We need to use `std::remove_reference<M>::type` to get rid of 55 // any reference components for make_unique. 56 template <typename M, typename = torch::detail::enable_if_module_t<M>> NamedAnyModule(std::string name,M && module)57 NamedAnyModule(std::string name, M&& module) 58 : NamedAnyModule( 59 std::move(name), 60 std::make_shared<typename std::remove_reference<M>::type>( 61 std::forward<M>(module))) {} 62 63 /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from 64 /// a `ModuleHolder`. 65 template <typename M> NamedAnyModule(std::string name,const ModuleHolder<M> & module_holder)66 NamedAnyModule(std::string name, const ModuleHolder<M>& module_holder) 67 : NamedAnyModule(std::move(name), module_holder.ptr()) {} 68 69 /// Creates a `NamedAnyModule` from a type-erased `AnyModule`. NamedAnyModule(std::string name,AnyModule any_module)70 NamedAnyModule(std::string name, AnyModule any_module) 71 : name_(std::move(name)), module_(std::move(any_module)) {} 72 73 /// Returns a reference to the name. name()74 const std::string& name() const noexcept { 75 return name_; 76 } 77 78 /// Returns a reference to the module. module()79 AnyModule& module() noexcept { 80 return module_; 81 } 82 83 /// Returns a const reference to the module. module()84 const AnyModule& module() const noexcept { 85 return module_; 86 } 87 88 private: 89 std::string name_; 90 AnyModule module_; 91 }; 92 93 } // namespace nn 94 } // namespace torch 95