#pragma once

#include <ATen/core/function_schema.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/flat_hash_map.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/PyHandleCache.h>
#include <c10/core/SafePyObject.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/DispatchKeyExtractor.h>

#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <ATen/core/enum_tag.h>

#include <optional>
#include <array>
#include <list>

#ifdef C10_MOBILE
#define C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
#endif

namespace c10 {

class Dispatcher;

namespace impl {

// This data structure represents a kernel that was registered to us from a
// user.  Unlike KernelFunction, AnnotatedKernel contains some extra metadata
// about the kernel that isn't necessary for actual dispatching (this is why
// we don't put AnnotatedKernel in the actual DispatchTable), but is useful for
// giving good error messages.
struct AnnotatedKernel final {
  AnnotatedKernel(KernelFunction k, std::unique_ptr<FunctionSchema> s, std::string d)
    : kernel(std::move(k))
    , inferred_function_schema(std::move(s))
    , debug(std::move(d))
    {}
  AnnotatedKernel() = default;
  KernelFunction kernel;
  std::unique_ptr<FunctionSchema> inferred_function_schema;
  // A little debug string to help us identify the kernel in question.
  // Most importantly it records the TORCH_LIBRARY block that did the
  // registration.
  std::string debug;
};
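
// A minimal sketch of what a registration might put into an AnnotatedKernel
// (illustrative only; `my_boxed_fn` and the debug string are hypothetical,
// and the actual construction happens inside OperatorEntry::registerKernel):
//
//   AnnotatedKernel ak(
//       KernelFunction::makeFromBoxedFunction<&my_boxed_fn>(),
//       /*inferred_function_schema=*/nullptr,
//       /*debug=*/"registered at my_ops.cpp; TORCH_LIBRARY_IMPL(my_ns, CPU, m)");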

// This data structure represents an operator schema, together with metadata
// recording where the registration of this schema occurred.
struct AnnotatedSchema final {
  AnnotatedSchema(FunctionSchema s, std::string d)
    : schema(std::move(s))
    , debug(std::move(d))
    {}
  FunctionSchema schema;
  std::string debug;
};

// Internal data structure that records information about a specific operator.
// It's not part of the public API; typically, users will interact with
// OperatorHandle instead.
//
// Concurrent writes to OperatorEntry are protected by the GLOBAL Dispatcher
// lock (this is important because some methods in OperatorEntry access
// dispatcher state).
class TORCH_API OperatorEntry final {
public:
  explicit OperatorEntry(OperatorName&& operator_name);

  OperatorEntry(const OperatorEntry&) = delete;
  OperatorEntry(OperatorEntry&&) noexcept = delete;
  OperatorEntry& operator=(const OperatorEntry&) = delete;
  OperatorEntry& operator=(OperatorEntry&&) noexcept = delete;

  const FunctionSchema& schema() const {
    TORCH_INTERNAL_ASSERT(schema_.has_value(), "Tried to access the schema for ", name_, " which doesn't have a schema registered yet");
    return schema_->schema;
  }
  const std::string& debug() const {
    TORCH_INTERNAL_ASSERT(schema_.has_value());
    return schema_->debug;
  }
  bool hasSchema() const {
    return schema_.has_value();
  }

  bool isObserved() const {
    return is_observed_;
  }

  // We may allocate an OperatorEntry for an operator even when we don't
  // have a schema.  When we receive the schema registration, we post
  // facto register a schema.
  //
  // NB: registerSchema/deregisterSchema are not idempotent; attempting to
  // register a schema when one is already present, or vice versa, is an
  // error.  (Refcounting for the registrations is handled in the
  // OperatorHandle in Dispatcher.)
  void registerSchema(FunctionSchema&&, std::string&& debug, std::vector<at::Tag> tags = {});
  void deregisterSchema();
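
  // A minimal sketch of the intended pairing, assuming an OperatorEntry `entry`
  // that has no schema yet (hypothetical call site; in practice the Dispatcher
  // drives these calls while holding its lock):
  //
  //   entry.registerSchema(std::move(schema), /*debug=*/"TORCH_LIBRARY(my_ns, m)");
  //   TORCH_INTERNAL_ASSERT(entry.hasSchema());
  //   // Registering a second schema here would be an error (not idempotent).
  //   entry.deregisterSchema();  // called exactly once, when the registration is dropped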

  const OperatorName& operator_name() const {
    return name_;
  }

#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
  using AnnotatedKernelContainer = std::array<AnnotatedKernel, 1>;
#else
  using AnnotatedKernelContainer = std::list<AnnotatedKernel>;
#endif
  using AnnotatedKernelContainerIterator = AnnotatedKernelContainer::iterator;

  // Why are kernels and fallback asymmetric?  It has to do with ownership.
  // Kernels and the computed dispatch tables for them are canonically
  // owned by OperatorEntry, but backend fallbacks are specified once
  // and apply for all operators, so they should be owned by Dispatcher.
  // However, the registration of a backend fallback affects the
  // state of the computed dispatch table, so when a backend fallback
  // is updated, we need to update the operator tables too.  Thus,
  // registerKernel is the mechanism by which we give kernels to
  // operator entry to own (and update dispatch table), but we only
  // need a non-owning mechanism to update fallback.

  // Precondition: Dispatcher::mutex_ is held
  // Postcondition: caller is responsible for disposing of the kernel
  AnnotatedKernelContainerIterator registerKernel(
    const Dispatcher& dispatcher,
    std::optional<DispatchKey> dispatch_key,
    KernelFunction kernel,
    std::optional<CppSignature> cpp_signature,
    std::unique_ptr<FunctionSchema> inferred_function_schema,
    std::string debug
  );

  // Precondition: Dispatcher::mutex_ is held
  void deregisterKernel_(
    const Dispatcher& dispatcher,
    std::optional<DispatchKey> dispatch_key,
    AnnotatedKernelContainerIterator kernel
  );
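
  // A minimal sketch of the register/deregister pairing, assuming the caller
  // is the Dispatcher and already holds Dispatcher::mutex_ (hypothetical call
  // site; `dispatcher`, `kernel`, and the debug string are stand-ins, and the
  // pairing is normally managed through a RegistrationHandleRAII):
  //
  //   auto it = entry.registerKernel(
  //       dispatcher, DispatchKey::CPU, std::move(kernel),
  //       CppSignature::make<at::Tensor(const at::Tensor&)>(),
  //       /*inferred_function_schema=*/nullptr, /*debug=*/"my_ops.cpp");
  //   ...
  //   entry.deregisterKernel_(dispatcher, DispatchKey::CPU, it);  // same key, same iterator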

  // Precondition: Dispatcher::mutex_ is held
  void updateFallback(
    const Dispatcher& dispatcher,
    DispatchKey dispatch_key
  );

  // Precondition: Dispatcher::mutex_ is held
  void updateSchemaAliasAnalysis(AliasAnalysisKind a) {
    TORCH_INTERNAL_ASSERT(schema_.has_value());
    schema_->schema.setAliasAnalysis(a);
  }

  std::string dumpComputedTable() const;
  std::string dumpState() const;
  void checkInvariants() const;

  const DispatchKeyExtractor& dispatchKeyExtractor() const { return dispatchKeyExtractor_; }

  // Asserts that the given FuncType is correct for calling this operator in an unboxed way.
  template<class FuncType>
  inline void assertSignatureIsCorrect() {
    assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
  }

  void assertSignatureIsCorrect(const CppSignature& call_signature, bool has_symint) const;
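
  // A minimal usage sketch, assuming a call site that knows the unboxed C++
  // signature at compile time (hypothetical signature shown):
  //
  //   // Fails (via reportSignatureError) if a kernel for this operator was
  //   // registered with a mismatching unboxed signature.
  //   entry.assertSignatureIsCorrect<at::Tensor(const at::Tensor&, int64_t)>();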

  [[noreturn]] void reportError(DispatchKey dispatchKey) const;

  const KernelFunction& lookup(DispatchKeySet ks) const {
    const auto idx = ks.getDispatchTableIndexForDispatchKeySet();
    if (C10_UNLIKELY(idx == -1)) {
      reportError(ks.highestPriorityTypeId());
    }
    const auto& kernel = dispatchTable_[idx];
    // A valid kernel *always* has a boxed kernel and *may* have an
    // unboxed kernel. However, we typically do unboxed calls in at::
    // APIs, where the kernel 1) will very likely be valid and 2)
    // should have an unboxed kernel. Checking the unboxed kernel
    // first will allow us to avoid touching the boxed kernel at all
    // in the common case.
    if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
      if (!kernel.isValid()) {
        reportError(ks.highestPriorityTypeId());
      }
    }
    return kernel;
  }
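
  // A minimal sketch of a lookup() call site, assuming `op` is the operator's
  // OperatorHandle and `ks` was produced by the DispatchKeyExtractor
  // (hypothetical call site; the real fast path lives in the Dispatcher):
  //
  //   DispatchKeySet ks = /* extracted from the arguments */;
  //   const KernelFunction& k = entry.lookup(ks);
  //   // Uses the unboxed kernel if one is available, otherwise boxes the call.
  //   return k.call<at::Tensor, const at::Tensor&>(op, ks, self);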

  std::string listAllDispatchKeys() const;

  // Returns true if kernel_ has an entry for any key in ks.
  //
  // Invariant: There are no alias keys in the passed-in dispatch key set.
  // Note [No Alias Keys in DispatchKeySet]
  // Alias keys should be checked using `hasKernelForDispatchKey`.
  // Alias keys shouldn't go inside of a DispatchKeySet, since they can technically
  // have a value > 63 (causing overflow).
  bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
  // Returns true if kernel_ has an entry for a particular key.
  bool hasKernelForDispatchKey(DispatchKey k) const;
  // Retrieves the kernel entry at a particular key.  Symmetric with
  // hasKernelForDispatchKey.  To get the AnnotatedKernel, see
  // getKernelForDispatchKey (private)
  const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
  // Returns true if the "computed table" has an entry for a particular key.
  bool hasComputedKernelForDispatchKey(DispatchKey k) const;
  // Returns all the operator tags added at the time of registration
  const std::vector<at::Tag>& getTags() const;
  void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);

  template <typename F>
  PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) const {
    return py_cache_.ptr_or(self_interpreter, slow_accessor);
  }

private:

  OperatorName name_;
  std::optional<AnnotatedSchema> schema_;
  #ifndef C10_MOBILE
    std::vector<at::Tag> tags_;
  #endif
  std::array<KernelFunction, c10::num_runtime_entries> dispatchTable_;
  DispatchKeyExtractor dispatchKeyExtractor_;
  // Pointer to the torch.ops.ns.op.overload object for speed
  c10::PyHandleCache py_cache_;

  // kernels_ stores all registered kernels for the corresponding dispatch key
  // and catchAllKernels_ stores the catch-all kernels.
  // If an operator library gets loaded that overwrites an already existing kernel,
  // both kernels will be in that list, but only the newer one will be in
  // dispatchTable. If any of the kernels go away (say the library gets
  // unloaded), we remove the kernel from this list and update the
  // dispatchTable if necessary.
  // Kernels in the list are ordered by registration time, descending:
  // newer registrations come before older ones.
  // We do not combine dispatchTable and kernels into one hash map because
  // kernels is a larger data structure that is accessed quite infrequently,
  // while dispatchTable is accessed often and should be kept small to fit
  // into CPU caches.
  // Invariants:
  //  - dispatchTable[dispatch_key] == kernels_[dispatch_key].front()
  //  - dispatchTable[dispatch_key] does not exist if and only if
  //    kernels_[dispatch_key] does not exist
  //  - If kernels_[dispatch_key] exists, then it has elements.
  //    It is never an empty list.
  //
  // Why do we do that?
  // -----
  // We mostly do this to enable Jupyter notebooks, where a cell registering
  // a kernel could be executed multiple times and the later execution
  // should overwrite the earlier one. Note that this still fails if the
  // function schema changed between the executions, but it works as long
  // as the function schema didn't change. A better solution would be to
  // unload the old extension library from the Jupyter cell when the cell is
  // re-executed and then only allow one kernel here, i.e. error if a kernel
  // is already registered, but that's a lot of effort to implement and
  // currently not high-pri.
  ska::flat_hash_map<DispatchKey,
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
                     // On mobile, we needn't worry about Jupyter notebooks.
                     std::array<AnnotatedKernel, 1>
#else
                     std::list<AnnotatedKernel>
#endif
                     > kernels_;
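
  // A concrete illustration of the invariants above (illustrative only):
  //   Library A registers a CPU kernel, then library B registers another:
  //     kernels_[DispatchKey::CPU] == [B, A]       (newest first)
  //     dispatchTable_[CPU slot]   == B's kernel   (front of the list)
  //   Library B is then unloaded and its kernel deregistered:
  //     kernels_[DispatchKey::CPU] == [A]
  //     dispatchTable_[CPU slot]   == A's kernel   (invariant re-established)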

  const AnnotatedKernel& missingKernel() const;
  const AnnotatedKernel& ambiguousAutogradOtherKernel() const;

  // cpp_signature_ stores the function signature if any of the kernels was
  // created in a way that allowed us to know the function signature
  // (i.e. by supplying an unboxed C++ kernel function).
  // If this is set, it will be used to check that future kernel
  // registrations match, and it will be used in unboxed function calls
  // to verify their arguments against the known function signature.
  struct CppSignatureWithDebug {
    CppSignature signature;
    std::string debug;
    std::optional<DispatchKey> dispatch_key;
  };
  std::optional<CppSignatureWithDebug> cpp_signature_;
  std::optional<CppSignatureWithDebug> sym_cpp_signature_;

  // A Python custom error handler for OperatorEntry::reportError
  std::unique_ptr<c10::SafePyObject> report_error_callback_;

  // Whether this operator needs to be observed with RecordFunction
  const bool is_observed_;

  [[noreturn]] void reportSignatureError(const CppSignature& call_signature, const CppSignatureWithDebug& saved_signature) const;
  const KernelFunction& computeDispatchTableEntry(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const;
  std::pair<const AnnotatedKernel&, const char*> computeDispatchTableEntryWithDebug(
    const c10::Dispatcher& dispatcher, DispatchKey dispatch_key
  ) const;
  // This function re-establishes the invariant that dispatchTable
  // contains the front element from the kernels list for a given runtime dispatch key.
  void updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
  // Like above, but also handles alias dispatch keys.
  void updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key);
  // Like above, but for ALL entries in the dispatch table.
  void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);
  // Retrieves a pointer to AnnotatedKernel at kernels_.at(dispatch_key).front().
  const AnnotatedKernel* getKernelForDispatchKey(DispatchKey dispatch_key) const;
};

} // namespace impl
} // namespace c10