xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/modules/container/any_value.h>
4 
5 namespace torch {
6 namespace nn {
7 
8 class Module;
9 
10 // ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModulePlaceholder ~~~~~~~~~~~~~~~~~~~~~~~~~~
11 
12 /// The static type of the object we store in the `AnyModule`, which erases
13 /// the actual type, but allows us to call `forward()` on the underlying
14 /// module.
15 struct AnyModulePlaceholder : public AnyValue::Placeholder {
16   using AnyValue::Placeholder::Placeholder;
17 
18   /// The "erased" `forward()` method.
19   virtual AnyValue forward(std::vector<AnyValue>&& arguments) = 0;
20 
21   /// Returns std::shared_ptr<Module> pointing to the erased module.
22   virtual std::shared_ptr<Module> ptr() = 0;
23 
24   /// Returns a `AnyModulePlaceholder` with a shallow copy of this `AnyModule`.
25   virtual std::unique_ptr<AnyModulePlaceholder> copy() const = 0;
26 
27   /// Returns a `AnyModulePlaceholder` with a deep copy of this `AnyModule`.
28   virtual std::unique_ptr<AnyModulePlaceholder> clone_module(
29       std::optional<Device> device) const = 0;
30 };
31 
32 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
33 
34 /// The dynamic type of the object stored in the `AnyModule`. It contains the
35 /// concrete instance to which all calls are forwarded. It is parameterized
36 /// over the concrete type of the module, and the types of the arguments the
37 /// module takes in its `forward()` method.
38 template <typename ModuleType, typename... ArgumentTypes>
39 struct AnyModuleHolder : public AnyModulePlaceholder {
40   /// \internal
41   struct CheckedGetter {
42     template <typename T>
operatorAnyModuleHolder::CheckedGetter43     std::decay_t<T>&& operator()(size_t index) {
44       AT_ASSERT(index < arguments_.size());
45       auto& value = arguments_[index];
46       if (auto* maybe_value = value.template try_get<std::decay_t<T>>()) {
47         return std::move(*maybe_value);
48       }
49       AT_ERROR(
50           "Expected argument #",
51           index,
52           " to be of type ",
53           c10::demangle(typeid(T).name()),
54           ", but received value of type ",
55           c10::demangle(value.type_info().name()));
56     }
57     std::vector<AnyValue>& arguments_;
58   };
59 
60   /// \internal
61   struct InvokeForward {
62     template <typename... Ts>
operatorAnyModuleHolder::InvokeForward63     AnyValue operator()(Ts&&... ts) {
64       return AnyValue(module_->forward(std::forward<Ts>(ts)...));
65     }
66     std::shared_ptr<ModuleType>& module_;
67   };
68 
69   /// Constructs the `AnyModuleHolder` from a concrete module.
AnyModuleHolderAnyModuleHolder70   explicit AnyModuleHolder(std::shared_ptr<ModuleType>&& module_)
71       : AnyModulePlaceholder(typeid(ModuleType)), module(std::move(module_)) {}
72 
73   /// Calls `forward()` on the underlying module, casting each `AnyValue` in the
74   /// argument vector to a concrete value.
forwardAnyModuleHolder75   AnyValue forward(std::vector<AnyValue>&& arguments) override {
76     if (module->_forward_has_default_args()) {
77       TORCH_CHECK(
78           arguments.size() >= module->_forward_num_required_args() &&
79               arguments.size() <= sizeof...(ArgumentTypes),
80           c10::demangle(type_info.name()),
81           "'s forward() method expects at least ",
82           module->_forward_num_required_args(),
83           " argument(s) and at most ",
84           sizeof...(ArgumentTypes),
85           " argument(s), but received ",
86           arguments.size(),
87           ".");
88       arguments = std::move(
89           module->_forward_populate_default_args(std::move(arguments)));
90     } else {
91       std::string use_default_args_macro_prompt = " If " +
92           c10::demangle(type_info.name()) +
93           "'s forward() method has default arguments, " +
94           "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";
95       TORCH_CHECK(
96           arguments.size() == sizeof...(ArgumentTypes),
97           c10::demangle(type_info.name()),
98           "'s forward() method expects ",
99           sizeof...(ArgumentTypes),
100           " argument(s), but received ",
101           arguments.size(),
102           ".",
103           (arguments.size() < sizeof...(ArgumentTypes))
104               ? use_default_args_macro_prompt
105               : "");
106     }
107 
108     // FYI: During invocation of a module's `forward()` method, the values live
109     // in the `arguments` vector inside this function.
110     return torch::unpack<AnyValue, ArgumentTypes...>(
111         InvokeForward{module}, CheckedGetter{arguments});
112   }
113 
ptrAnyModuleHolder114   std::shared_ptr<Module> ptr() override {
115     return module;
116   }
117 
copyAnyModuleHolder118   std::unique_ptr<AnyModulePlaceholder> copy() const override {
119     return std::make_unique<AnyModuleHolder>(*this);
120   }
121 
clone_moduleAnyModuleHolder122   std::unique_ptr<AnyModulePlaceholder> clone_module(
123       std::optional<Device> device) const override {
124     return std::make_unique<AnyModuleHolder>(
125         std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
126   }
127 
128   /// The actual concrete module instance.
129   std::shared_ptr<ModuleType> module;
130 };
131 
132 } // namespace nn
133 } // namespace torch
134