xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/container/named_any.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/detail/static.h>
4 #include <torch/nn/module.h>
5 #include <torch/nn/modules/container/any.h>
6 #include <torch/nn/pimpl.h>
7 #include <torch/types.h>
8 
9 #include <torch/csrc/autograd/variable.h>
10 #include <torch/csrc/utils/variadic.h>
11 
12 #include <ATen/Device.h>
13 
14 #include <initializer_list>
15 #include <memory>
16 #include <type_traits>
17 #include <typeinfo>
18 #include <utility>
19 #include <vector>
20 
21 namespace torch {
22 namespace nn {
23 
24 /// Stores a type erased `Module` with name.
25 ///
26 /// The `NamedAnyModule` class enables the following API for constructing
27 /// `nn::Sequential` with named submodules:
28 /// \rst
29 /// .. code-block:: cpp
30 ///
31 ///   struct M : torch::nn::Module {
32 ///     explicit M(int value_) : value(value_) {}
33 ///     int value;
34 ///     int forward() {
35 ///       return value;
36 ///     }
37 ///   };
38 ///
39 ///   Sequential sequential({
40 ///     {"m1", std::make_shared<M>(1)},  // shared pointer to `Module` is
41 ///     supported {std::string("m2"), M(2)},  // `Module` is supported
42 ///     {"linear1", Linear(10, 3)}  // `ModuleHolder` is supported
43 ///   });
44 /// \endrst
45 class NamedAnyModule {
46  public:
47   /// Creates a `NamedAnyModule` from a (boxed) `Module`.
48   template <typename ModuleType>
NamedAnyModule(std::string name,std::shared_ptr<ModuleType> module_ptr)49   NamedAnyModule(std::string name, std::shared_ptr<ModuleType> module_ptr)
50       : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {}
51 
52   /// Creates a `NamedAnyModule` from a `Module`, moving or copying it
53   /// into a `shared_ptr` internally.
54   // NOTE: We need to use `std::remove_reference<M>::type` to get rid of
55   // any reference components for make_unique.
56   template <typename M, typename = torch::detail::enable_if_module_t<M>>
NamedAnyModule(std::string name,M && module)57   NamedAnyModule(std::string name, M&& module)
58       : NamedAnyModule(
59             std::move(name),
60             std::make_shared<typename std::remove_reference<M>::type>(
61                 std::forward<M>(module))) {}
62 
63   /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from
64   /// a `ModuleHolder`.
65   template <typename M>
NamedAnyModule(std::string name,const ModuleHolder<M> & module_holder)66   NamedAnyModule(std::string name, const ModuleHolder<M>& module_holder)
67       : NamedAnyModule(std::move(name), module_holder.ptr()) {}
68 
69   /// Creates a `NamedAnyModule` from a type-erased `AnyModule`.
NamedAnyModule(std::string name,AnyModule any_module)70   NamedAnyModule(std::string name, AnyModule any_module)
71       : name_(std::move(name)), module_(std::move(any_module)) {}
72 
73   /// Returns a reference to the name.
name()74   const std::string& name() const noexcept {
75     return name_;
76   }
77 
78   /// Returns a reference to the module.
module()79   AnyModule& module() noexcept {
80     return module_;
81   }
82 
83   /// Returns a const reference to the module.
module()84   const AnyModule& module() const noexcept {
85     return module_;
86   }
87 
88  private:
89   std::string name_;
90   AnyModule module_;
91 };
92 
93 } // namespace nn
94 } // namespace torch
95