xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/container/sequential.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/detail/static.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/modules/container/named_any.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <c10/util/Exception.h>

#include <cstdint>
#include <initializer_list>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
21 namespace torch {
22 namespace nn {
23 
/// A list of `Module`s that acts as a `Module` itself.
///
/// A `Sequential` is fundamentally a list of `Module`s, each with a `forward()`
/// method. `Sequential` provides a `forward()` method of its own, which accepts
/// any input and forwards it to the first module it stores. It then "chains"
/// outputs to inputs sequentially for each subsequent module, finally returning
/// the output of the last module. For example:
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::Sequential seq(
///     torch::nn::Linear(3, 4),
///     torch::nn::BatchNorm1d(4),
///     torch::nn::Dropout(0.5)
///   );
///
///   auto output = seq->forward(torch::ones(3));
///
/// \endrst
///
/// This can conceptually be thought of as the following loop (using Python as
/// pseudocode):
///
/// \rst
/// .. code-block:: python
///
///   def forward(sequential, input):
///     for module in sequential:
///       input = module(input)
///     return input
///
/// \endrst
///
/// Why should you use `Sequential` instead of a simple `std::vector`? The value
/// a `Sequential` provides over manually calling a sequence of modules is that
/// it allows treating the whole container *as a single module*, such that
/// performing a transformation on the `Sequential` applies to each of the
/// modules it stores (which are each a registered submodule of the
/// `Sequential`). For example, calling
/// `.to(torch::kCUDA)` on a `Sequential` will move each module in the list to
/// CUDA memory. For example:
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::Sequential seq(
///     torch::nn::Linear(3, 4),
///     torch::nn::BatchNorm1d(4),
///     torch::nn::Dropout(0.5)
///   );
///
///   // Convert all modules to CUDA.
///   seq->to(torch::kCUDA);
///
/// \endrst
///
/// Finally, `Sequential` provides a lightweight container API, such as allowing
/// iteration over submodules, positional access, adding a new module after
/// construction via `push_back`, as well as joining two `Sequential`s via
/// `extend`.
///
/// \rst
/// .. attention::
///   One current limitation of `Sequential` is that all except the first module
///   must accept a single argument. If your modules need to take multiple
///   arguments, you should define them to take and return tuples.
/// \endrst
92 class SequentialImpl : public Cloneable<SequentialImpl> {
93  public:
94   using Iterator = std::vector<AnyModule>::iterator;
95   using ConstIterator = std::vector<AnyModule>::const_iterator;
96 
97   SequentialImpl() = default;
98 
99   /// Constructs the `Sequential` from a variadic list of modules.
100   template <typename... Modules>
SequentialImpl(Modules &&...modules)101   explicit SequentialImpl(Modules&&... modules) {
102     modules_.reserve(sizeof...(Modules));
103     push_back(std::forward<Modules>(modules)...);
104   }
105 
106   /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s.
SequentialImpl(torch::OrderedDict<std::string,AnyModule> && ordered_dict)107   explicit SequentialImpl(
108       torch::OrderedDict<std::string, AnyModule>&& ordered_dict) {
109     modules_.reserve(ordered_dict.size());
110     for (auto& item : ordered_dict) {
111       push_back(item.key(), std::move(item.value()));
112     }
113   }
114 
115   /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s.
116   /// It enables the following use case:
117   /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})`
SequentialImpl(std::initializer_list<NamedAnyModule> named_modules)118   explicit SequentialImpl(std::initializer_list<NamedAnyModule> named_modules) {
119     modules_.reserve(named_modules.size());
120     for (const auto& named_module : named_modules) {
121       push_back(named_module.name(), named_module.module());
122     }
123   }
124 
125   /// Special cloning function for `Sequential` because it does not use
126   /// `reset()`.
127   std::shared_ptr<Module> clone(
128       const std::optional<Device>& device = std::nullopt) const override {
129     auto clone = std::make_shared<SequentialImpl>();
130     for (const auto& module : modules_) {
131       clone->push_back(module.clone(device));
132     }
133     return clone;
134   }
135 
136   /// `reset()` is empty for `Sequential`, since it does not have parameters of
137   /// its own.
reset()138   void reset() override {}
139 
140   /// Pretty prints the `Sequential` module into the given `stream`.
pretty_print(std::ostream & stream)141   void pretty_print(std::ostream& stream) const override {
142     stream << "torch::nn::Sequential";
143   }
144 
145   /// Feeds `inputs` to the first module and then chains outputs to inputs,
146   /// returning the last output.
147   ///
148   /// Conceptually the following loop in Python:
149   ///
150   /// \rst
151   /// .. code-block:: python
152   ///
153   ///   def forward(sequential, input):
154   ///     for module in sequential:
155   ///       input = module(input)
156   ///     return input
157   ///
158   /// \endrst
159   ///
160   /// The return type is taken as the first template parameter. It defaults to
161   /// `Tensor`. If the last module in the `Sequential` returns another type `T`,
162   /// you should call `forward<T>(inputs)` instead of just `forward(inputs)`:
163   ///
164   /// \rst
165   /// .. code-block:: cpp
166   ///
167   ///   torch::Tensor tensor = sequential1->forward(inputs);
168   ///   int integer = sequential2->forward<int>(inputs);
169   ///   float value = sequential3->forward<float>(inputs);
170   ///
171   /// \endrst
172   template <typename ReturnType = Tensor, typename... InputTypes>
forward(InputTypes &&...inputs)173   ReturnType forward(InputTypes&&... inputs) {
174     TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");
175 
176     auto iterator = modules_.begin();
177     auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);
178 
179     for (++iterator; iterator != modules_.end(); ++iterator) {
180       input = iterator->any_forward(std::move(input));
181     }
182 
183     // Check the return value and give a nice error message if the requested
184     // return type was incorrect.
185     if (auto* return_value = input.template try_get<ReturnType>()) {
186       return std::move(*return_value);
187     }
188     AT_ERROR(
189         "The type of the return value is ",
190         c10::demangle(input.type_info().name()),
191         ", but you asked for type ",
192         c10::demangle(typeid(ReturnType).name()));
193   }
194 
195   /// Adds a new (boxed) `Module` to the `Sequential` container.
196   template <typename ModuleType>
push_back(std::shared_ptr<ModuleType> module_ptr)197   void push_back(std::shared_ptr<ModuleType> module_ptr) {
198     push_back(std::to_string(modules_.size()), std::move(module_ptr));
199   }
200 
201   /// Adds a new named (boxed) `Module` to the `Sequential` container.
202   template <typename ModuleType>
push_back(std::string name,std::shared_ptr<ModuleType> module_ptr)203   void push_back(std::string name, std::shared_ptr<ModuleType> module_ptr) {
204     push_back(std::move(name), AnyModule(std::move(module_ptr)));
205   }
206 
207   /// Adds a new `Module` to the `Sequential` container, moving or copying it
208   /// into a `shared_ptr` internally. This method allows passing value types,
209   /// and letting the container deal with the boxing. This means you can write
210   /// `Sequential(Module(3, 4))` instead of
211   /// `Sequential(std::make_shared<Module>(3, 4))`.
212   template <typename M, typename = torch::detail::enable_if_module_t<M>>
push_back(M && module)213   void push_back(M&& module) {
214     push_back(std::to_string(modules_.size()), std::forward<M>(module));
215   }
216 
217   /// Adds a new named `Module` to the `Sequential` container, moving or copying
218   /// it into a `shared_ptr` internally. This method allows passing value types,
219   /// and letting the container deal with the boxing.
220   template <typename M, typename = torch::detail::enable_if_module_t<M>>
push_back(std::string name,M && module)221   void push_back(std::string name, M&& module) {
222     using Type = typename std::remove_reference_t<M>;
223     push_back(std::move(name), std::make_shared<Type>(std::forward<M>(module)));
224   }
225 
226   /// Unwraps the contained module of a `ModuleHolder` and adds it to the
227   /// `Sequential`.
228   template <typename M>
push_back(const ModuleHolder<M> & module_holder)229   void push_back(const ModuleHolder<M>& module_holder) {
230     push_back(std::to_string(modules_.size()), module_holder);
231   }
232 
233   /// Unwraps the contained named module of a `ModuleHolder` and adds it to the
234   /// `Sequential`.
235   template <typename M>
push_back(std::string name,const ModuleHolder<M> & module_holder)236   void push_back(std::string name, const ModuleHolder<M>& module_holder) {
237     push_back(std::move(name), module_holder.ptr());
238   }
239 
240   /// Iterates over the container and calls `push_back()` on each value.
241   template <typename Container>
extend(const Container & container)242   void extend(const Container& container) {
243     for (const auto& module : container) {
244       push_back(module);
245     }
246   }
247 
248   /// Adds a type-erased `AnyModule` to the `Sequential`.
push_back(AnyModule any_module)249   void push_back(AnyModule any_module) {
250     push_back(std::to_string(modules_.size()), std::move(any_module));
251   }
252 
push_back(std::string name,AnyModule any_module)253   void push_back(std::string name, AnyModule any_module) {
254     modules_.push_back(std::move(any_module));
255     const auto index = modules_.size() - 1;
256     register_module(std::move(name), modules_[index].ptr());
257   }
258 
259   /// Returns an iterator to the start of the `Sequential`.
begin()260   Iterator begin() {
261     return modules_.begin();
262   }
263 
264   /// Returns a const iterator to the start of the `Sequential`.
begin()265   ConstIterator begin() const {
266     return modules_.begin();
267   }
268 
269   /// Returns an iterator to the end of the `Sequential`.
end()270   Iterator end() {
271     return modules_.end();
272   }
273 
274   /// Returns a const iterator to the end of the `Sequential`.
end()275   ConstIterator end() const {
276     return modules_.end();
277   }
278 
279   /// Attempts to return the module at the given index as the requested type.
280   /// Throws an exception if the index is out of bounds or the types do not
281   /// match.
282   template <typename T>
at(size_t index)283   T& at(size_t index) {
284     static_assert(
285         torch::detail::is_module<T>::value,
286         "Can only call Sequential::at with an nn::Module type");
287     TORCH_CHECK(index < size(), "Index out of range");
288     return modules_[index].get<T>();
289   }
290 
291   /// Attempts to return the module at the given index as the requested type.
292   /// Throws an exception if the index is out of bounds or the types do not
293   /// match.
294   template <typename T>
at(size_t index)295   const T& at(size_t index) const {
296     static_assert(
297         torch::detail::is_module<T>::value,
298         "Can only call Sequential::at with an nn::Module type");
299     TORCH_CHECK(index < size(), "Index out of range");
300     return modules_[index].get<T>();
301   }
302 
303   /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the
304   /// underlying module at the given index. Throws an exception if the index is
305   /// out of bounds.
ptr(size_t index)306   std::shared_ptr<Module> ptr(size_t index) const {
307     TORCH_CHECK(index < size(), "Index out of range");
308     return modules_[index].ptr();
309   }
310 
311   /// Attempts to return a `std::shared_ptr` whose type is the one provided.
312   /// Throws an exception if the index is out of bounds or the types do not
313   /// match.
314   template <typename T>
ptr(size_t index)315   std::shared_ptr<T> ptr(size_t index) const {
316     static_assert(
317         torch::detail::is_module<T>::value,
318         "Can only call Sequential::ptr with an nn::Module type");
319     TORCH_CHECK(index < size(), "Index out of range");
320     return modules_[index].ptr<T>();
321   }
322 
323   /// Like `ptr(index)`.
324   std::shared_ptr<Module> operator[](size_t index) const {
325     // This is the only method we can call without a type.
326     return ptr(index);
327   }
328 
329   /// The current size of the `Sequential` container.
size()330   size_t size() const noexcept {
331     return modules_.size();
332   }
333 
334   /// True if there are no modules in the `Sequential`.
is_empty()335   bool is_empty() const noexcept {
336     return size() == 0;
337   }
338 
339  private:
340   /// Takes a First *and* Second parameter, to avoid ambiguity when a parameter
341   /// pack has only one type, in which case the template would be preferred,
342   /// even if the other `push_back` functions are better fits (e.g. `unique_ptr`
343   /// -> `shared_ptr` overload).
344   /// NOTE: We explicitly avoid matching this template with
345   /// `push_back(std::string("name"), module)` or `push_back("name", module)`,
346   /// since they should be handled by their respective `push_back` functions.
347   template <
348       typename First,
349       typename Second,
350       typename... Rest,
351       typename = std::enable_if_t<
352           !std::is_same_v<First, std::string> &&
353           // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
354           !std::is_same_v<std::decay_t<First>, std::decay_t<const char (&)[]>>>>
push_back(First && first,Second && second,Rest &&...rest)355   void push_back(First&& first, Second&& second, Rest&&... rest) {
356     push_back(std::forward<First>(first));
357     // Recursively calls this method, until the parameter pack only thas this
358     // entry left. Then calls `push_back()` a final time (above).
359     push_back(std::forward<Second>(second), std::forward<Rest>(rest)...);
360   }
361 
362   /// The base case, when the list of modules is empty.
push_back()363   void push_back() {}
364 
365   // Box the AnyModules to give Sequential reference semantics, like the rest of
366   // the API. Note that this is not required otherwise, this could just be a
367   // `vector<AnyModule>`.
368   std::vector<AnyModule> modules_;
369 };
370 
/// A `ModuleHolder` subclass for `SequentialImpl`.
/// See the documentation for `SequentialImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
375 class Sequential : public torch::nn::ModuleHolder<SequentialImpl> {
376  public:
377   using torch::nn::ModuleHolder<SequentialImpl>::ModuleHolder;
378 
Sequential()379   Sequential() : ModuleHolder() {}
380 
381   /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s.
382   /// It enables the following use case:
383   /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})`
Sequential(std::initializer_list<NamedAnyModule> named_modules)384   Sequential(std::initializer_list<NamedAnyModule> named_modules)
385       : ModuleHolder(std::make_shared<SequentialImpl>(named_modules)) {}
386 };
387 } // namespace nn
388 } // namespace torch
389