1 #pragma once 2 3 #include <torch/nn/modules/container/any_value.h> 4 5 namespace torch { 6 namespace nn { 7 8 class Module; 9 10 // ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModulePlaceholder ~~~~~~~~~~~~~~~~~~~~~~~~~~ 11 12 /// The static type of the object we store in the `AnyModule`, which erases 13 /// the actual type, but allows us to call `forward()` on the underlying 14 /// module. 15 struct AnyModulePlaceholder : public AnyValue::Placeholder { 16 using AnyValue::Placeholder::Placeholder; 17 18 /// The "erased" `forward()` method. 19 virtual AnyValue forward(std::vector<AnyValue>&& arguments) = 0; 20 21 /// Returns std::shared_ptr<Module> pointing to the erased module. 22 virtual std::shared_ptr<Module> ptr() = 0; 23 24 /// Returns a `AnyModulePlaceholder` with a shallow copy of this `AnyModule`. 25 virtual std::unique_ptr<AnyModulePlaceholder> copy() const = 0; 26 27 /// Returns a `AnyModulePlaceholder` with a deep copy of this `AnyModule`. 28 virtual std::unique_ptr<AnyModulePlaceholder> clone_module( 29 std::optional<Device> device) const = 0; 30 }; 31 32 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 33 34 /// The dynamic type of the object stored in the `AnyModule`. It contains the 35 /// concrete instance to which all calls are forwarded. It is parameterized 36 /// over the concrete type of the module, and the types of the arguments the 37 /// module takes in its `forward()` method. 38 template <typename ModuleType, typename... ArgumentTypes> 39 struct AnyModuleHolder : public AnyModulePlaceholder { 40 /// \internal 41 struct CheckedGetter { 42 template <typename T> operatorAnyModuleHolder::CheckedGetter43 std::decay_t<T>&& operator()(size_t index) { 44 AT_ASSERT(index < arguments_.size()); 45 auto& value = arguments_[index]; 46 if (auto* maybe_value = value.template try_get<std::decay_t<T>>()) { 47 return std::move(*maybe_value); 48 } 49 AT_ERROR( 50 "Expected argument #", 51 index, 52 " to be of type ", 53 c10::demangle(typeid(T).name()), 54 ", but received value of type ", 55 c10::demangle(value.type_info().name())); 56 } 57 std::vector<AnyValue>& arguments_; 58 }; 59 60 /// \internal 61 struct InvokeForward { 62 template <typename... Ts> operatorAnyModuleHolder::InvokeForward63 AnyValue operator()(Ts&&... ts) { 64 return AnyValue(module_->forward(std::forward<Ts>(ts)...)); 65 } 66 std::shared_ptr<ModuleType>& module_; 67 }; 68 69 /// Constructs the `AnyModuleHolder` from a concrete module. AnyModuleHolderAnyModuleHolder70 explicit AnyModuleHolder(std::shared_ptr<ModuleType>&& module_) 71 : AnyModulePlaceholder(typeid(ModuleType)), module(std::move(module_)) {} 72 73 /// Calls `forward()` on the underlying module, casting each `AnyValue` in the 74 /// argument vector to a concrete value. forwardAnyModuleHolder75 AnyValue forward(std::vector<AnyValue>&& arguments) override { 76 if (module->_forward_has_default_args()) { 77 TORCH_CHECK( 78 arguments.size() >= module->_forward_num_required_args() && 79 arguments.size() <= sizeof...(ArgumentTypes), 80 c10::demangle(type_info.name()), 81 "'s forward() method expects at least ", 82 module->_forward_num_required_args(), 83 " argument(s) and at most ", 84 sizeof...(ArgumentTypes), 85 " argument(s), but received ", 86 arguments.size(), 87 "."); 88 arguments = std::move( 89 module->_forward_populate_default_args(std::move(arguments))); 90 } else { 91 std::string use_default_args_macro_prompt = " If " + 92 c10::demangle(type_info.name()) + 93 "'s forward() method has default arguments, " + 94 "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."; 95 TORCH_CHECK( 96 arguments.size() == sizeof...(ArgumentTypes), 97 c10::demangle(type_info.name()), 98 "'s forward() method expects ", 99 sizeof...(ArgumentTypes), 100 " argument(s), but received ", 101 arguments.size(), 102 ".", 103 (arguments.size() < sizeof...(ArgumentTypes)) 104 ? use_default_args_macro_prompt 105 : ""); 106 } 107 108 // FYI: During invocation of a module's `forward()` method, the values live 109 // in the `arguments` vector inside this function. 110 return torch::unpack<AnyValue, ArgumentTypes...>( 111 InvokeForward{module}, CheckedGetter{arguments}); 112 } 113 ptrAnyModuleHolder114 std::shared_ptr<Module> ptr() override { 115 return module; 116 } 117 copyAnyModuleHolder118 std::unique_ptr<AnyModulePlaceholder> copy() const override { 119 return std::make_unique<AnyModuleHolder>(*this); 120 } 121 clone_moduleAnyModuleHolder122 std::unique_ptr<AnyModulePlaceholder> clone_module( 123 std::optional<Device> device) const override { 124 return std::make_unique<AnyModuleHolder>( 125 std::dynamic_pointer_cast<ModuleType>(module->clone(device))); 126 } 127 128 /// The actual concrete module instance. 129 std::shared_ptr<ModuleType> module; 130 }; 131 132 } // namespace nn 133 } // namespace torch 134