#pragma once

#include <c10/util/irange.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>

#include <utility>
#include <vector>

namespace torch {
namespace nn {

/// A list of `Module`s that registers its elements.
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::ModuleList mlist(
///     torch::nn::Linear(3, 4),
///     torch::nn::BatchNorm1d(4),
///     torch::nn::Dropout(0.5)
///   );
///
///   for (const auto& module : *mlist) {
///     module->pretty_print(std::cout);
///   }
///
/// \endrst
///
/// Why should you use `ModuleList` instead of a simple `std::vector`? The value
/// a `ModuleList` provides over manually calling a sequence of modules is that
/// it allows treating the whole container *as a single module*, such that
/// performing a transformation on the `ModuleList` applies to each of the
/// modules it stores (which are each a registered submodule of the
/// `ModuleList`). For example, calling `.to(torch::kCUDA)` on a `ModuleList`
/// moves every module in the list to CUDA memory:
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::ModuleList mlist(
///     torch::nn::Linear(3, 4),
///     torch::nn::BatchNorm1d(4),
///     torch::nn::Dropout(0.5)
///   );
///
///   // Convert all modules to CUDA.
///   mlist->to(torch::kCUDA);
///
/// \endrst
///
/// Finally, `ModuleList` provides a lightweight container API, such as allowing
/// iteration over submodules, positional access, adding a new module after
/// construction via `push_back`, as well as joining two `ModuleList`s via
/// `extend`.
class ModuleListImpl : public Cloneable<ModuleListImpl> {
 public:
  using Iterator = std::vector<std::shared_ptr<Module>>::iterator;
  using ConstIterator = std::vector<std::shared_ptr<Module>>::const_iterator;

  ModuleListImpl() = default;

  /// Constructs the `ModuleList` from a variadic list of modules.
  template <typename... Modules>
  explicit ModuleListImpl(Modules&&... modules) {
    modules_.reserve(sizeof...(Modules));
    push_back_var(std::forward<Modules>(modules)...);
  }

  /// Special cloning function for `ModuleList` because it does not use
  /// `reset()`.
  std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const override {
    auto clone = std::make_shared<ModuleListImpl>();
    for (const auto& module : modules_) {
      clone->push_back(module->clone(device));
    }
    return clone;
  }

  /// `reset()` is empty for `ModuleList`, since it does not have parameters of
  /// its own.
  void reset() override {}

  /// Pretty prints the `ModuleList` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ModuleList";
  }

  /// Adds an already type-erased `Module` to the `ModuleList` and registers it
  /// as a submodule under its index.
  void push_back(std::shared_ptr<Module> module) {
    modules_.push_back(std::move(module));
    const auto index = modules_.size() - 1;
    register_module(std::to_string(index), modules_[index]);
  }
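
  // A brief usage sketch (not part of the API itself): modules can also be
  // appended after construction, either as module holders such as
  // `torch::nn::Linear` or as a plain `std::shared_ptr<Module>`. The module
  // types used here are the standard ones shown in the class docs above.
  //
  //   torch::nn::ModuleList mlist;
  //   mlist->push_back(torch::nn::Linear(3, 4));
  //   mlist->push_back(torch::nn::Dropout(0.5));
  //   // mlist now holds two registered submodules, named "0" and "1".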

  /// Adds a new `Module` to the `ModuleList` container, moving or copying
  /// it into a `shared_ptr` internally. This method allows passing value types,
  /// and letting the container deal with the boxing.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void push_back(M&& module) {
    using Type = typename std::remove_reference<M>::type;
    push_back(std::make_shared<Type>(std::forward<M>(module)));
  }

  /// Unwraps the contained module of a `ModuleHolder` and adds it to the
  /// `ModuleList`.
  template <typename M>
  void push_back(const ModuleHolder<M>& module_holder) {
    push_back(module_holder.ptr());
  }

  /// Iterates over the container and calls `push_back()` on each value.
  template <typename Container>
  void extend(const Container& container) {
    for (const auto& module : container) {
      push_back(module);
    }
  }

  /// Returns an iterator to the start of the `ModuleList`.
  Iterator begin() {
    return modules_.begin();
  }

  /// Returns a const iterator to the start of the `ModuleList`.
  ConstIterator begin() const {
    return modules_.begin();
  }

  /// Returns an iterator to the end of the `ModuleList`.
  Iterator end() {
    return modules_.end();
  }

  /// Returns a const iterator to the end of the `ModuleList`.
  ConstIterator end() const {
    return modules_.end();
  }

  /// Attempts to return the module at the given index as the requested type.
  /// Throws an exception if the index is out of bounds or the types do not
  /// match.
  template <typename T>
  T& at(size_t index) {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    auto module = modules_[index]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        index,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Attempts to return the module at the given index as the requested type.
  /// Throws an exception if the index is out of bounds or the types do not
  /// match.
  template <typename T>
  const T& at(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::at with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    const auto module = modules_[index]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        index,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the
  /// underlying module at the given index. Throws an exception if the index is
  /// out of bounds.
  std::shared_ptr<Module> ptr(size_t index) const {
    TORCH_CHECK(index < size(), "Index out of range");
    return modules_[index];
  }

  /// Attempts to return a `std::shared_ptr` whose type is the one provided.
  /// Throws an exception if the index is out of bounds, and returns a null
  /// pointer if the stored module is not of type `T`.
  template <typename T>
  std::shared_ptr<T> ptr(size_t index) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleList::ptr with an nn::Module type");
    TORCH_CHECK(index < size(), "Index out of range");
    return std::dynamic_pointer_cast<T>(modules_[index]);
  }
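
  // A brief access sketch (not part of the API itself), reusing the module
  // types from the class docs above: `at<T>()` yields a typed reference and
  // throws on a type mismatch, while `ptr<T>()` performs a
  // `dynamic_pointer_cast` and yields a typed `shared_ptr`.
  //
  //   torch::nn::ModuleList mlist(torch::nn::Linear(3, 4));
  //   auto& linear = mlist->at<torch::nn::LinearImpl>(0);
  //   std::shared_ptr<torch::nn::LinearImpl> holder =
  //       mlist->ptr<torch::nn::LinearImpl>(0);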

  /// Like `ptr(index)`.
  std::shared_ptr<Module> operator[](size_t index) const {
    // This is the only method we can call without a type.
    return ptr(index);
  }

  /// The current size of the `ModuleList` container.
  size_t size() const noexcept {
    return modules_.size();
  }

  /// True if there are no modules in the `ModuleList`.
  bool is_empty() const noexcept {
    return size() == 0;
  }

  /// Inserts the given type-erased `Module` at the given index and
  /// re-registers the modules that follow it under their shifted indices.
  /// Throws an exception if the index is out of bounds.
  void insert(size_t index, std::shared_ptr<Module> module) {
    TORCH_CHECK(index <= size(), "Index out of range");

    if (index == size()) {
      push_back(std::move(module));
    } else {
      modules_.insert(
          modules_.begin() + Iterator::difference_type(index),
          std::move(module));

      for (const auto i : c10::irange(index, size() - 1)) {
        replace_module(std::to_string(i), modules_[i]);
      }
      register_module(std::to_string(size() - 1), modules_.back());
    }
  }

  /// Unwraps the contained module of a `ModuleHolder` and inserts it in the
  /// `ModuleList`.
  template <typename M>
  void insert(size_t index, const ModuleHolder<M>& module_holder) {
    insert(index, module_holder.ptr());
  }

  /// Inserts a new `Module` into the `ModuleList` container, moving or copying
  /// it into a `shared_ptr` internally. This method allows passing value
  /// types, and letting the container deal with the boxing.
  template <typename M, typename = torch::detail::enable_if_module_t<M>>
  void insert(size_t index, M&& module) {
    using Type = typename std::remove_reference<M>::type;
    insert(index, std::make_shared<Type>(std::forward<M>(module)));
  }

 private:
  template <typename Head, typename... Tail>
  void push_back_var(Head&& head, Tail&&... tail) {
    push_back(std::forward<Head>(head));
    // Recursively calls this method until the parameter pack has only one
    // entry left, at which point `push_back()` is called a final time (above).
    push_back_var(std::forward<Tail>(tail)...);
  }

  /// The base case, when the list of modules is empty.
  void push_back_var() {}

  // Box the contained modules into `shared_ptr`s so that `ModuleList` has
  // reference semantics, like the rest of the API.
  std::vector<std::shared_ptr<Module>> modules_;
};

/// A `ModuleHolder` subclass for `ModuleListImpl`.
/// See the documentation for `ModuleListImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(ModuleList);

} // namespace nn
} // namespace torch
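
// A brief sketch of positional insertion and of joining two lists (reusing the
// module types from the class docs above): `insert` shifts later entries and
// keeps their registered names in sync, while `extend` simply `push_back`s
// every element of another container.
//
//   torch::nn::ModuleList a(torch::nn::Linear(3, 4), torch::nn::Dropout(0.5));
//   torch::nn::ModuleList b(torch::nn::Linear(4, 2));
//   a->insert(1, torch::nn::BatchNorm1d(4)); // Linear, BatchNorm1d, Dropout
//   a->extend(*b);                           // ..., Linear(4, 2)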