xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/container/modulelist.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/irange.h>
4 #include <torch/nn/cloneable.h>
5 #include <torch/nn/module.h>
6 
7 #include <utility>
8 #include <vector>
9 
10 namespace torch {
11 namespace nn {
12 
13 /// A list of `Module`s that registers its elements.
14 ///
15 /// \rst
16 /// .. code-block:: cpp
17 ///
18 ///   torch::nn::ModuleList mlist(
19 ///     torch::nn::Linear(3, 4),
20 ///     torch::nn::BatchNorm1d(4),
21 ///     torch::nn::Dropout(0.5)
22 ///   );
23 ///
24 ///   for (const auto &module : *mlist) {
25 ///     module->pretty_print(std::cout);
26 ///   }
27 ///
28 /// \endrst
29 ///
30 /// Why should you use `ModuleList` instead of a simple `std::vector`? The value
31 /// a `ModuleList` provides over manually calling a sequence of modules is that
32 /// it allows treating the whole container *as a single module*, such that
33 /// performing a transformation on the `ModuleList` applies to each of the
34 /// modules it stores (which are each a registered submodule of the
35 /// `ModuleList`). For example, calling
36 /// `.to(torch::kCUDA)` on a `ModuleList` will move each module in the list to
37 /// CUDA memory:
38 ///
39 /// \rst
40 /// .. code-block:: cpp
41 ///
42 ///   torch::nn::ModuleList mlist(
43 ///     torch::nn::Linear(3, 4),
44 ///     torch::nn::BatchNorm1d(4),
45 ///     torch::nn::Dropout(0.5)
46 ///   );
47 ///
48 ///   // Convert all modules to CUDA.
49 ///   mlist->to(torch::kCUDA);
50 ///
51 /// \endrst
52 ///
53 /// Finally, `ModuleList` provides a lightweight container API, such as allowing
54 /// iteration over submodules, positional access, adding a new module after
55 /// construction via `push_back`, as well as joining two `ModuleList`s via
56 /// `extend`.
57 class ModuleListImpl : public Cloneable<ModuleListImpl> {
58  public:
59   using Iterator = std::vector<std::shared_ptr<Module>>::iterator;
60   using ConstIterator = std::vector<std::shared_ptr<Module>>::const_iterator;
61 
62   ModuleListImpl() = default;
63 
64   /// Constructs the `ModuleList` from a variadic list of modules.
65   template <typename... Modules>
ModuleListImpl(Modules &&...modules)66   explicit ModuleListImpl(Modules&&... modules) {
67     modules_.reserve(sizeof...(Modules));
68     push_back_var(std::forward<Modules>(modules)...);
69   }
70 
71   /// Special cloning function for `ModuleList` because it does not use
72   /// `reset()`.
73   std::shared_ptr<Module> clone(
74       const std::optional<Device>& device = std::nullopt) const override {
75     auto clone = std::make_shared<ModuleListImpl>();
76     for (const auto& module : modules_) {
77       clone->push_back(module->clone(device));
78     }
79     return clone;
80   }
81 
82   /// `reset()` is empty for `ModuleList`, since it does not have parameters of
83   /// its own.
reset()84   void reset() override {}
85 
86   /// Pretty prints the `ModuleList` module into the given `stream`.
pretty_print(std::ostream & stream)87   void pretty_print(std::ostream& stream) const override {
88     stream << "torch::nn::ModuleList";
89   }
90 
push_back(std::shared_ptr<Module> module)91   void push_back(std::shared_ptr<Module> module) {
92     modules_.push_back(std::move(module));
93     const auto index = modules_.size() - 1;
94     register_module(std::to_string(index), modules_[index]);
95   }
96 
97   /// Adds a new `Module` to the `ModuleList` container, moving or copying
98   /// it into a `shared_ptr` internally. This method allows passing value types,
99   /// and letting the container deal with the boxing.
100   template <typename M, typename = torch::detail::enable_if_module_t<M>>
push_back(M && module)101   void push_back(M&& module) {
102     using Type = typename std::remove_reference<M>::type;
103     push_back(std::make_shared<Type>(std::forward<M>(module)));
104   }
105 
106   /// Unwraps the contained module of a `ModuleHolder` and adds it to the
107   /// `ModuleList`.
108   template <typename M>
push_back(const ModuleHolder<M> & module_holder)109   void push_back(const ModuleHolder<M>& module_holder) {
110     push_back(module_holder.ptr());
111   }
112 
113   /// Iterates over the container and calls `push_back()` on each value.
114   template <typename Container>
extend(const Container & container)115   void extend(const Container& container) {
116     for (const auto& module : container) {
117       push_back(module);
118     }
119   }
120 
121   /// Returns an iterator to the start of the `ModuleList`.
begin()122   Iterator begin() {
123     return modules_.begin();
124   }
125 
126   /// Returns a const iterator to the start of the `ModuleList`.
begin()127   ConstIterator begin() const {
128     return modules_.begin();
129   }
130 
131   /// Returns an iterator to the end of the `ModuleList`.
end()132   Iterator end() {
133     return modules_.end();
134   }
135 
136   /// Returns a const iterator to the end of the `ModuleList`.
end()137   ConstIterator end() const {
138     return modules_.end();
139   }
140 
141   /// Attempts to return the module at the given index as the requested type.
142   /// Throws an exception if the index is out of bounds or the types do not
143   /// match.
144   template <typename T>
at(size_t index)145   T& at(size_t index) {
146     static_assert(
147         torch::detail::is_module<T>::value,
148         "Can only call ModuleList::at with an nn::Module type");
149     TORCH_CHECK(index < size(), "Index out of range");
150     auto module = modules_[index]->as<T>();
151     TORCH_CHECK(
152         module,
153         "Unable to cast module[",
154         index,
155         "] to ",
156         c10::demangle(typeid(T).name()));
157     return *module;
158   }
159 
160   /// Attempts to return the module at the given index as the requested type.
161   /// Throws an exception if the index is out of bounds or the types do not
162   /// match.
163   template <typename T>
at(size_t index)164   const T& at(size_t index) const {
165     static_assert(
166         torch::detail::is_module<T>::value,
167         "Can only call ModuleList::at with an nn::Module type");
168     TORCH_CHECK(index < size(), "Index out of range");
169     const auto module = modules_[index]->as<T>();
170     TORCH_CHECK(
171         module,
172         "Unable to cast module[",
173         index,
174         "] to ",
175         c10::demangle(typeid(T).name()));
176     return *module;
177   }
178 
179   /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the
180   /// underlying module at the given index. Throws an exception if the index is
181   /// out of bounds.
ptr(size_t index)182   std::shared_ptr<Module> ptr(size_t index) const {
183     TORCH_CHECK(index < size(), "Index out of range");
184     return modules_[index];
185   }
186 
187   /// Attempts to return a `std::shared_ptr` whose type is the one provided.
188   /// Throws an exception if the index is out of bounds or the types do not
189   /// match.
190   template <typename T>
ptr(size_t index)191   std::shared_ptr<T> ptr(size_t index) const {
192     static_assert(
193         torch::detail::is_module<T>::value,
194         "Can only call ModuleList::ptr with an nn::Module type");
195     TORCH_CHECK(index < size(), "Index out of range");
196     return std::dynamic_pointer_cast<T>(modules_[index]);
197   }
198 
199   /// Like `ptr(index)`.
200   std::shared_ptr<Module> operator[](size_t index) const {
201     // This is the only method we can call without a type.
202     return ptr(index);
203   }
204 
205   /// The current size of the `ModuleList` container.
size()206   size_t size() const noexcept {
207     return modules_.size();
208   }
209 
210   /// True if there are no modules in the `ModuleList`.
is_empty()211   bool is_empty() const noexcept {
212     return size() == 0;
213   }
214 
insert(size_t index,std::shared_ptr<Module> module)215   void insert(size_t index, std::shared_ptr<Module> module) {
216     TORCH_CHECK(index <= size(), "Index out of range");
217 
218     if (index == size())
219       push_back(std::move(module));
220     else {
221       modules_.insert(
222           modules_.begin() + Iterator::difference_type(index),
223           std::move(module));
224 
225       for (const auto i : c10::irange(index, size() - 1)) {
226         (void)i; // Suppress unused variable warning
227         replace_module(std::to_string(index), modules_[index]);
228       }
229       register_module(std::to_string(size() - 1), modules_.back());
230     }
231   }
232 
233   /// Unwraps the contained module of a `ModuleHolder` and inserts it in the
234   /// `ModuleList`.
235   template <typename M>
insert(size_t index,const ModuleHolder<M> & module_holder)236   void insert(size_t index, const ModuleHolder<M>& module_holder) {
237     insert(index, module_holder.ptr());
238   }
239 
240   /// inserts a new `Module` to the `ModuleList` container, moving or copying
241   /// it into a `shared_ptr` internally. This method allows passing value types,
242   /// and letting the container deal with the boxing.
243   template <typename M, typename = torch::detail::enable_if_module_t<M>>
insert(size_t index,M && module)244   void insert(size_t index, M&& module) {
245     using Type = typename std::remove_reference<M>::type;
246     insert(index, std::make_shared<Type>(std::forward<M>(module)));
247   }
248 
249  private:
250   template <typename Head, typename... Tail>
push_back_var(Head && head,Tail &&...tail)251   void push_back_var(Head&& head, Tail&&... tail) {
252     push_back(std::forward<Head>(head));
253     // Recursively calls this method, until the parameter pack only thas this
254     // entry left. Then calls `push_back()` a final time (above).
255     push_back_var(std::forward<Tail>(tail)...);
256   }
257 
258   /// The base case, when the list of modules is empty.
push_back_var()259   void push_back_var() {}
260 
261   // Box the AnyModules to give ModuleList reference semantics, like the rest of
262   // the API. Note that this is not required otherwise, this could just be a
263   // `vector<AnyModule>`.
264   std::vector<std::shared_ptr<Module>> modules_;
265 };
266 
267 /// A `ModuleHolder` subclass for `ModuleListImpl`; this is the `ModuleList`
268 /// type that the usage examples in the `ModuleListImpl` documentation build.
269 /// See the `ModuleListImpl` documentation to learn what methods it provides,
270 /// or the `ModuleHolder` documentation for PyTorch's module storage semantics.
271 TORCH_MODULE(ModuleList);
272 
273 } // namespace nn
274 } // namespace torch
275