xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/pimpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/detail/static.h>
5 #include <torch/serialize/archive.h>
6 #include <torch/types.h>
7 
8 #include <torch/csrc/utils/variadic.h>
9 
10 #include <memory>
11 #include <type_traits>
12 #include <utility>
13 
14 namespace torch {
15 namespace detail {
16 // Dump all the template metaprogramming in this file.
17 #include <torch/csrc/api/include/torch/nn/pimpl-inl.h>
18 } // namespace detail
19 
20 namespace nn {
21 
22 /// A `ModuleHolder` is essentially a wrapper around `std::shared_ptr<M>` where
23 /// `M` is an `nn::Module` subclass, with convenient constructors defined for
24 /// the kind of constructions we want to allow for our modules.
25 template <typename Contained>
26 class ModuleHolder : torch::detail::ModuleHolderIndicator {
27  protected:
28   /// The module pointer this class wraps.
29   /// NOTE: Must be placed at the top of the class so that we can use it with
30   /// trailing return types below.
31   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
32   std::shared_ptr<Contained> impl_;
33 
34  public:
35   using ContainedType = Contained;
36 
37   /// Default constructs the contained module if if has a default constructor,
38   /// else produces a static error.
39   ///
40   /// NOTE: This uses the behavior of template
41   /// classes in C++ that constructors (or any methods) are only compiled when
42   /// actually used.
ModuleHolder()43   ModuleHolder() : impl_(default_construct()) {
44     static_assert(
45         std::is_default_constructible<Contained>::value,
46         "You are trying to default construct a module which has "
47         "no default constructor. Use = nullptr to give it the empty state "
48         "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`).");
49   }
50 
51   /// Constructs the `ModuleHolder` with an empty contained value. Access to
52   /// the underlying module is not permitted and will throw an exception, until
53   /// a value is assigned.
ModuleHolder(std::nullptr_t)54   /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {}
55 
56   /// Constructs the `ModuleHolder` with a contained module, forwarding all
57   /// arguments to its constructor.
58   template <
59       typename Head,
60       typename... Tail,
61       typename = typename std::enable_if<
62           !(torch::detail::is_module_holder_of<Head, ContainedType>::value &&
63             (sizeof...(Tail) == 0))>::type>
ModuleHolder(Head && head,Tail &&...tail)64   explicit ModuleHolder(Head&& head, Tail&&... tail)
65       : impl_(new Contained(
66             std::forward<Head>(head),
67             std::forward<Tail>(tail)...)) {}
68 
69   /// Constructs the `ModuleHolder` from a pointer to the contained type.
70   /// Example: `Linear(std::make_shared<LinearImpl>(...))`.
ModuleHolder(std::shared_ptr<Contained> module)71   /* implicit */ ModuleHolder(std::shared_ptr<Contained> module)
72       : impl_(std::move(module)) {}
73 
74   /// Returns true if the `ModuleHolder` contains a module, or false if it is
75   /// `nullptr`.
76   explicit operator bool() const noexcept {
77     return !is_empty();
78   }
79 
80   /// Forwards to the contained module.
81   Contained* operator->() {
82     return get();
83   }
84 
85   /// Forwards to the contained module.
86   const Contained* operator->() const {
87     return get();
88   }
89 
90   /// Returns a reference to the contained module.
91   Contained& operator*() {
92     return *get();
93   }
94 
95   /// Returns a const reference to the contained module.
96   const Contained& operator*() const {
97     return *get();
98   }
99 
100   /// Returns a shared pointer to the underlying module.
ptr()101   const std::shared_ptr<Contained>& ptr() const {
102     TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
103     return impl_;
104   }
105 
106   /// Returns a pointer to the underlying module.
get()107   Contained* get() {
108     TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
109     return impl_.get();
110   }
111 
112   /// Returns a const pointer to the underlying module.
get()113   const Contained* get() const {
114     TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder");
115     return impl_.get();
116   }
117 
118   /// Calls the `forward()` method of the contained module.
119   template <typename... Args>
120   auto operator()(Args&&... args)
121       -> torch::detail::return_type_of_forward_t<Contained, Args...> {
122     // This will not compile if the module does not have a `forward()` method
123     // (as expected).
124     // NOTE: `std::forward` is qualified to prevent VS2017 emitting
125     // error C2872: 'std': ambiguous symbol
126     return impl_->forward(::std::forward<Args>(args)...);
127   }
128 
129   /// Forwards to the subscript operator of the contained module.
130   /// NOTE: std::forward is qualified to prevent VS2017 emitting
131   ///       error C2872: 'std': ambiguous symbol
132   template <typename Arg>
decltype(auto)133   decltype(auto) operator[](Arg&& arg) {
134     return (*impl_)[::std::forward<Arg>(arg)];
135   }
136 
137   /// Returns true if the `ModuleHolder` does not contain a module.
is_empty()138   bool is_empty() const noexcept {
139     return impl_ == nullptr;
140   }
141 
142  private:
143   template <typename T = Contained>
default_construct()144   std::shared_ptr<Contained> default_construct() {
145     if constexpr (std::is_default_constructible_v<T>) {
146       return std::make_shared<Contained>();
147     } else {
148       return nullptr;
149     }
150   }
151 };
152 
153 /// Pretty prints the given `Module` into the `ostream`.
154 template <typename ModuleType>
155 std::ostream& operator<<(
156     std::ostream& stream,
157     const nn::ModuleHolder<ModuleType>& module) {
158   return stream << *module;
159 }
160 
161 /// Serializes a `ModuleHolder` into an `OutputArchive`.
162 template <typename ModuleType>
163 serialize::OutputArchive& operator<<(
164     serialize::OutputArchive& archive,
165     const nn::ModuleHolder<ModuleType>& module) {
166   return archive << module.ptr();
167 }
168 
169 /// Deserializes a `ModuleHolder` from an `InputArchive`.
170 template <typename ModuleType>
171 serialize::InputArchive& operator>>(
172     serialize::InputArchive& archive,
173     nn::ModuleHolder<ModuleType>& module) {
174   return archive >> module.ptr();
175 }
176 
177 } // namespace nn
178 } // namespace torch
179 
// Attribute marker for the `Impl` alias emitted by `TORCH_MODULE_IMPL`.
// CUDA 10.2 and below reject attribute unused on using declarations, so when
// `__CUDACC__` is defined the marker expands to nothing; everywhere else it
// expands to `C10_UNUSED`.
#if !defined(__CUDACC__)
#define TORCH_UNUSED_EXCEPT_CUDA C10_UNUSED
#else
#define TORCH_UNUSED_EXCEPT_CUDA
#endif
187 
/// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a
/// wrapper over a `std::shared_ptr<ImplType>`.
/// `Impl` is a type alias for `ImplType` which provides a way to call static
/// methods of `ImplType`.
///
/// All `ModuleHolder` constructors are inherited via a using-declaration. The
/// expansion ends with the class's closing brace and no semicolon, so the use
/// site supplies it: `TORCH_MODULE_IMPL(Linear, LinearImpl);`.
///
/// `ModuleHolder` is named with a leading `::` because the macro is expanded
/// in the caller's namespace: an unqualified `torch::` would resolve to a
/// nested `torch` namespace there, if one exists, instead of the global one.
#define TORCH_MODULE_IMPL(Name, ImplType)                                \
  class Name : public ::torch::nn::ModuleHolder<ImplType> { /* NOLINT */ \
   public:                                                               \
    using ::torch::nn::ModuleHolder<ImplType>::ModuleHolder;             \
    using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType;                      \
  }
198 
/// Shorthand for `TORCH_MODULE_IMPL` that derives the implementation type by
/// pasting `Impl` onto `Name`: `TORCH_MODULE(Linear);` wraps `LinearImpl`.
#define TORCH_MODULE(Name) \
  TORCH_MODULE_IMPL(Name, Name##Impl)
201