xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/SequenceNumber.h>
4 #include <ATen/core/boxing/KernelFunction.h>
5 #include <ATen/core/boxing/impl/boxing.h>
6 #include <ATen/core/dispatch/OperatorEntry.h>
7 #include <ATen/core/dispatch/CppSignature.h>
8 #include <ATen/core/dispatch/RegistrationHandleRAII.h>
9 #include <ATen/record_function.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/LeftRight.h>
12 #include <list>
13 #include <mutex>
14 #include <condition_variable>
15 #include <type_traits>
16 #include <c10/core/SafePyObject.h>
17 
18 #include <ATen/core/grad_mode.h>
19 #include <ATen/core/enum_tag.h>
20 
21 #ifndef NDEBUG
22 #include <iostream>
23 #endif
24 
25 namespace c10 {
26 
27 TORCH_API bool show_dispatch_trace();
28 TORCH_API void dispatch_trace_nesting_incr();
29 TORCH_API void dispatch_trace_nesting_decr();
30 TORCH_API int64_t dispatch_trace_nesting_value();
31 
32 struct DispatchTraceNestingGuard {
33   DispatchTraceNestingGuard() { dispatch_trace_nesting_incr(); }
34   ~DispatchTraceNestingGuard() { dispatch_trace_nesting_decr(); }
35 };
36 
37 class TORCH_API OperatorHandle;
38 template<class FuncType> class TypedOperatorHandle;
39 
40 /**
41  * Implement this interface and register your instance with the dispatcher
42  * to get notified when operators are registered or deregistered with
43  * the dispatcher.
44  *
45  * NB: registration events only occur when a 'def' occurs; we don't trigger
46  * on 'impl' or 'fallback' calls.
47  */
48 class TORCH_API OpRegistrationListener {
49 public:
50   virtual ~OpRegistrationListener();
51 
52   virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
53   virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
54 };
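// Illustrative sketch (not part of this header): a minimal listener that logs
// registration events. The class name MyListener and its logging bodies are
// assumptions for the example; it is installed via
// Dispatcher::addRegistrationListener(), declared further below.
//
//   class MyListener final : public c10::OpRegistrationListener {
//    public:
//     void onOperatorRegistered(const c10::OperatorHandle& op) override {
//       std::cout << "registered: " << toString(op.operator_name()) << '\n';
//     }
//     void onOperatorDeregistered(const c10::OperatorHandle& op) override {
//       std::cout << "deregistered: " << toString(op.operator_name()) << '\n';
//     }
//   };
//
//   auto handle = c10::Dispatcher::singleton().addRegistrationListener(
//       std::make_unique<MyListener>());
//   // The listener is removed when `handle` is destroyed.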
55 
56 namespace detail {
57 class RegistrationListenerList;
58 }
59 class SchemaRegistrationHandleRAII;
60 
61 /**
62  * Top-level dispatch interface for dispatching via the dynamic dispatcher.
63  * Most end users shouldn't use this directly; if you're trying to register
64  * ops, look in op_registration.
65  */
66 class TORCH_API Dispatcher final {
67 private:
68   // For direct access to backend fallback information
69   friend class impl::OperatorEntry;
70 
71   struct OperatorDef final {
72     explicit OperatorDef(OperatorName&& op_name)
73     : op(std::move(op_name)) {}
74 
75     impl::OperatorEntry op;
76 
77     // These refer to the number of outstanding RegistrationHandleRAII
78     // for this operator.  def_count reflects only def() registrations
79     // (in the new world, this should only ever be 1, but old style
80     // registrations may register the schema multiple times, which
81     // will increase this count).  def_and_impl_count reflects the number
82     // of combined def() and impl() registrations.  When the last def() gets
83     // unregistered, we must immediately call the Deregistered listeners, but we
84     // must not actually delete the handle as there are other outstanding RAII
85     // destructors which will try to destruct and they had better still have a
86     // working operator handle in this case
87     size_t def_count = 0;
88     size_t def_and_impl_count = 0;
89   };
90   friend class OperatorHandle;
91   template<class> friend class TypedOperatorHandle;
92 
93   struct Guard final {
94     Guard() : alive(true), mutex() {}
95     std::atomic<bool> alive;
96     std::mutex mutex;
97   };
98 
99 public:
100   ~Dispatcher();
101 
102   // Implementation note: this class abstracts over the fact that we have per-operator
103   // dispatch tables.  This could be easily adjusted to have a single global hash
104   // table.
105   static Dispatcher& realSingleton();
106 
107   C10_ALWAYS_INLINE static Dispatcher& singleton() {
108 #if !defined C10_MOBILE
109     // Implemented inline so that steady-state code needn't incur
110     // function-call overhead. We can't just inline `realSingleton`
111     // because the function-local static would get duplicated across
112     // all DSOs that include & use this header, leading to multiple
113     // singleton instances.
114     static Dispatcher& s = realSingleton();
115     return s;
116 #else
117     // For C10_MOBILE, we should never inline a static function that
118     // has a static member, since the generated code calls
119     // __cxa_guard_acquire and __cxa_guard_release which help
120     // implement exactly once semantics for the initialization of the
121     // static Dispatcher& s above (for the non-mobile case). That
122     // additional code when duplicated across all operator stubs
123     // for every backend results in a lot of additional code
124     // being generated by the compiler.
125     return realSingleton();
126 #endif
127   }
128 
129   // ------------------------------------------------------------------------
130   //
131   // Accessing operators by schema
132   //
133   // ------------------------------------------------------------------------
134 
135   /**
136    * Looks for an operator schema with the given name and overload name
137    * and returns it if it is registered WITH A SCHEMA.
138    * Returns nullopt otherwise.
139    */
140   std::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
141 
142   /**
143    * Variant of findSchema that results in less code generated at the call site.
144    * It (1) takes const char* pointer rather than OperatorName (so we skip
145    * generating std::string constructor calls at the call site), and (2)
146    * it raises an exception if the operator is not found (so we skip
147    * generating exception raising code at the call site)
148    *
149    * Irritatingly, we still have to generate the handful of instructions
150    * for dealing with an exception being thrown during static initialization
151    * (e.g. __cxa_guard_abort).  If we could annotate this method noexcept we
152    * could avoid this code too, but as the name of the function suggests,
153    * it does throw exceptions.
154    */
155   OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
156 
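// Illustrative sketch (the operator name "aten::add" with overload "Tensor"
// is an example; any registered name/overload pair works the same way):
//
//   auto& d = c10::Dispatcher::singleton();
//   std::optional<c10::OperatorHandle> maybe_op =
//       d.findSchema({"aten::add", "Tensor"});
//   if (maybe_op.has_value() && maybe_op->hasSchema()) { /* ... */ }
//   // Or, when the operator is known to exist:
//   c10::OperatorHandle op = d.findSchemaOrThrow("aten::add", "Tensor");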
157   // Like findSchema, but also returns OperatorHandle even if there is no schema
158   std::optional<OperatorHandle> findOp(const OperatorName& operator_name);
159 
160   // Returns a list of all operator names present in the operatorLookupTable_
161   const std::vector<OperatorName> getAllOpNames();
162 
163   // ------------------------------------------------------------------------
164   //
165   // Invoking operators
166   //
167   // ------------------------------------------------------------------------
168 
169   template<class Return, class... Args>
170   Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;
171 
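// Illustrative sketch of an unboxed call through a typed handle; the
// signature below matches the conventional "aten::add.Tensor" schema and the
// tensors self/other are assumed to exist in the caller:
//
//   auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("aten::add", "Tensor")
//       .typed<at::Tensor(const at::Tensor&, const at::Tensor&, const at::Scalar&)>();
//   at::Tensor out = op.call(self, other, /*alpha=*/1);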
172 
173   template<class Return, class... Args>
174   static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);
175 
176   // Like call, but intended for use in a redispatch in kernels that have explicitly performed the DispatchKey update calculation.
177   // This will take the DispatchKeySet completely as is and dispatch to the kernel of the corresponding highest priority key in the set.
178   // Note that this version of redispatch treats the inputted DispatchKeySet *as is*, and does NOT mask out the highest priority key.
179   // See Note [Plumbing Keys Through The Dispatcher]
180   template<class Return, class... Args>
181   Return redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const;
182 
183   // Invoke an operator via the boxed calling convention using an IValue stack
184   void callBoxed(const OperatorHandle& op, Stack* stack) const;
185   void callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const;
186 
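// Illustrative sketch of a boxed call (a Stack is just a vector of IValues;
// the operator name and arguments are assumptions for the example). The kernel
// pops its inputs off the stack and pushes its outputs back on:
//
//   c10::OperatorHandle op =
//       c10::Dispatcher::singleton().findSchemaOrThrow("aten::add", "Tensor");
//   std::vector<c10::IValue> stack;       // aka Stack
//   stack.push_back(self);                // Tensor
//   stack.push_back(other);               // Tensor
//   stack.push_back(at::Scalar(1));       // alpha
//   c10::Dispatcher::singleton().callBoxed(op, &stack);
//   at::Tensor out = stack.back().toTensor();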
187   // TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
188   // See Note [Plumbing Keys Through The Dispatcher]
189   void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const;
190 
191   bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
192     auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
193     if (dispatch_ix < 0) return false;
194     return backendFallbackKernels_[dispatch_ix].kernel.isValid();
195   }
196 
197   // Used by torchdeploy/multipy for multiple interpreters racing.
198   void waitForDef(const FunctionSchema& schema);
199   void waitForImpl(const OperatorName& op_name, std::optional<DispatchKey> dispatch_key);
200 
201   // ------------------------------------------------------------------------
202   //
203   // Performing registrations (NON user public; use op_registration)
204   //
205   // ------------------------------------------------------------------------
206 
207   /**
208    * Register a new operator schema.
209    *
210    * If a schema with the same operator name and overload name already exists,
211    * this function will check that both schemas are exactly identical.
212    */
213   RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});
214 
215   /**
216    * Register a kernel to the dispatch table for an operator.
217    * If dispatch_key is nullopt, then this registers a fallback kernel.
218    *
219    * @return A RAII object that manages the lifetime of the registration.
220    *         Once that object is destructed, the kernel will be deregistered.
221    */
222   // NB: steals the inferred function schema, as we may need to hold on to
223   // it for a bit until the real schema turns up
224   RegistrationHandleRAII registerImpl(OperatorName op_name, std::optional<DispatchKey> dispatch_key, KernelFunction kernel, std::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
225 
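// Illustrative sketch: end-user code normally reaches registerDef()/registerImpl()
// indirectly through the TORCH_LIBRARY macros rather than calling them here.
// The namespace "myops" and the kernel my_op_cpu are assumptions for the example:
//
//   at::Tensor my_op_cpu(const at::Tensor& self);   // some CPU kernel
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("my_op(Tensor self) -> Tensor");        // ends up in registerDef()
//   }
//   TORCH_LIBRARY_IMPL(myops, CPU, m) {
//     m.impl("my_op", TORCH_FN(my_op_cpu));         // ends up in registerImpl()
//   }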
226   /**
227    * Given an operator, tells the Dispatcher that we have implemented a fake impl
228    * for this op in the given Python module. Call this a "pystub".
229    */
230   RegistrationHandleRAII registerPythonModule(const OperatorName& op_name, const char* pymodule, const char* context);
231 
232   /**
233    * Given an operator, throws if we have a pystub.
234    */
235   void throwIfHasPythonModule(OperatorName op_name);
236 
237   std::optional<std::pair<const char*, const char*>> getPyStub(OperatorName op_name);
238 
239   /**
240    * Register a new operator by name.
241    */
242   RegistrationHandleRAII registerName(OperatorName op_name);
243 
244   /**
245    * Register a fallback kernel for a backend.
246    * If an operator is called but there is no concrete kernel for the dispatch
247    * key of the given operator arguments, it will check if there is such a
248    * fallback kernel for the given dispatch key and, if yes, call that one.
249    */
250   RegistrationHandleRAII registerFallback(DispatchKey dispatch_key, KernelFunction kernel, std::string debug);
251 
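// Illustrative sketch: a backend fallback is usually installed through the
// TORCH_LIBRARY_IMPL macro with the wildcard namespace. Registering a
// fallthrough for a key (AutogradPrivateUse1 is an arbitrary choice here)
// makes the dispatcher skip that key and keep dispatching:
//
//   TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
//     m.fallback(torch::CppFunction::makeFallthrough());  // ends up in registerFallback()
//   }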
252   /**
253    * Used to register whenever there is a TORCH_LIBRARY declaration in the frontend
254    * API.  These invocations are only permitted once per program, so we raise
255    * an error if this is called again for the same namespace.
256    */
257   RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);
258 
259   // ------------------------------------------------------------------------
260   //
261   // Listeners on registrations
262   //
263   // ------------------------------------------------------------------------
264 
265   /**
266    * Add a listener that gets called whenever a new op is registered or an existing
267    * op is deregistered. Immediately after registering, this listener gets called
268    * for all previously registered ops, so it can be used to keep track of ops
269    * registered with this dispatcher.
270    */
271   RegistrationHandleRAII addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener);
272 
273   void checkInvariants() const;
274 
275   //
276   // ------------------------------------------------------------------------
277   //
278   // Assertions
279   //
280   // ------------------------------------------------------------------------
281 
282   /**
283    * For testing purposes.
284    * Returns a list of all operators that were created through calls to registerImpl(),
285    * without any corresponding calls to registerDef(). After static initialization
286    * is done this is almost certainly a bug, as the created OperatorHandle won't have
287    * any schema associated with it and users calling the op through the dispatcher
288    * won't be able to access it.
289    *
290    * Note that we cannot enforce this invariant "as we go" during static initialization,
291    * due to undefined static initialization order; we have no guarantees over the order
292    * in which .def() and .impl() calls are registered in the dispatcher at static
293    * initialization time. So this function should only be called after static initialization.
294    */
295   std::vector<OperatorHandle> findDanglingImpls() const;
296 
297   /**
298    * Useful for inspecting global Dispatcher registration state.
299    * Returns the names of all operators with a kernel registered for the specified DispatchKey.
300    * If no DispatchKey is specified, it returns all registered operators.
301    */
302   std::vector<OperatorName> getRegistrationsForDispatchKey(std::optional<DispatchKey> k) const;
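// Illustrative sketch (DispatchKey::CPU is an arbitrary choice here):
//
//   auto cpu_ops = c10::Dispatcher::singleton()
//       .getRegistrationsForDispatchKey(c10::DispatchKey::CPU);
//   auto all_ops = c10::Dispatcher::singleton()
//       .getRegistrationsForDispatchKey(std::nullopt);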
303 
304 private:
305   Dispatcher();
306 
307   static int64_t sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey, DispatchKeySet dispatchKeySet);
308   static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet);
309   static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet, c10::ArrayRef<const c10::IValue> args);
310 
311   #ifdef FBCODE_CAFFE2
312   static bool profilingOperatorEvents();
313   static void fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref);
314   static void fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref);
315   #endif // FBCODE_CAFFE2
316 
317   OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
318   OperatorHandle findOrRegisterName_(const OperatorName& op_name);
319 
320   void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name);
321   void deregisterImpl_(
322     const OperatorHandle& op,
323     const OperatorName& op_name,
324     std::optional<DispatchKey> dispatch_key,
325     impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
326   void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
327   void deregisterFallback_(DispatchKey dispatchKey);
328   void deregisterLibrary_(const std::string& ns);
329   void cleanup(const OperatorHandle& op, const OperatorName& op_name);
330   void checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug);
331 
332   std::list<OperatorDef> operators_;
333 #if !defined(C10_MOBILE)
334   LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
335 #else
336   RWSafeLeftRightWrapper<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
337 #endif
338   // Map from namespace to debug string (saying, e.g., where the library was defined)
339   ska::flat_hash_map<std::string, std::string> libraries_;
340 
341   std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;
342 
343   std::unique_ptr<detail::RegistrationListenerList> listeners_;
344 
345   // This condition variable gets notified whenever we add a new def/impl to the
346   // dispatch table.  This is primarily used by multipy/torchdeploy, when
347   // we have multiple interpreters trying to register to the dispatch table.
348   // In this situation, whenever the non-primary interpreter would have tried
349   // to register to the dispatch table, instead it will check to see if the
350   // expected registration has already been made, and if it hasn't, wait on
351   // this condition variable to see if it was just racing with the primary
352   // interpreter.
353   //
354   // We expect it to be rare for there to be any waiters on this condition
355   // variable.  This is mostly just to help give better diagnostics if
356   // something goes horribly wrong
357   std::condition_variable cond_var_;
358 
359   // Protect concurrent access to the dispatcher.  We store this in a
360   // `shared_ptr` as we return callbacks that call back into dispatcher methods,
361   // and we need to be able to handle and guard against the event when the
362   // `Dispatcher` has been destroyed before the callbacks fire.
363   std::shared_ptr<Guard> guard_;
364 };
365 
366 /**
367  * This is a handle to an operator schema registered with the dispatcher.
368  * This handle can be used to register kernels with the dispatcher or
369  * to lookup a kernel for a certain set of arguments.
370  */
371 class TORCH_API OperatorHandle {
372   template <typename T> friend struct std::hash;
373 
374 public:
375   OperatorHandle(OperatorHandle&&) noexcept = default;
376   OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
377   OperatorHandle(const OperatorHandle&) = default;
378   OperatorHandle& operator=(const OperatorHandle&) = default;
379   // NOLINTNEXTLINE(performance-trivially-destructible)
380   ~OperatorHandle();
381 
382   const OperatorName& operator_name() const {
383     return operatorDef_->op.operator_name();
384   }
385 
386   bool hasSchema() const {
387     return operatorDef_->op.hasSchema();
388   }
389 
390   const FunctionSchema& schema() const {
391     return operatorDef_->op.schema();
392   }
393 
394   const std::string& debug() const {
395     return operatorDef_->op.debug();
396   }
397 
398   std::string dumpState() const {
399     return operatorDef_->op.dumpState();
400   }
401 
402   bool hasKernelForDispatchKey(DispatchKey k) const {
403     return operatorDef_->op.hasKernelForDispatchKey(k);
404   }
405 
406   bool isKernelFallthroughKernel(DispatchKey k) const {
407     return operatorDef_->op.kernelForDispatchKey(k).isFallthrough();
408   }
409 
410   bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
411     return operatorDef_->op.hasKernelForAnyDispatchKey(k);
412   }
413 
414   bool hasComputedKernelForDispatchKey(DispatchKey k) const {
415     return operatorDef_->op.hasComputedKernelForDispatchKey(k);
416   }
417 
418   std::string dumpComputedTable() const {
419     return operatorDef_->op.dumpComputedTable();
420   }
421 
422   void checkInvariants() const {
423     return operatorDef_->op.checkInvariants();
424   }
425 
426   c10::ArrayRef<at::Tag> getTags() const {
427     return operatorDef_->op.getTags();
428   }
429 
430   void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
431     operatorDef_->op.setReportErrorCallback_(std::move(callback));
432   }
433 
434   bool hasTag(const at::Tag& tag) const {
435     for(const auto& tag_: getTags()) {
436       if (tag == tag_) {
437         return true;
438       }
439     }
440     return false;
441   }
442 
443   template<class FuncType>
444   TypedOperatorHandle<FuncType> typed() const {
445     // NB: This assert is not 100% sound: you can retrieve a typed() operator
446     // handle prior to ANY C++ signature being registered on the operator
447     // and the check will say everything is OK (at which point you can then
448     // smuggle in a kernel that is typed incorrectly).  For everything
449     // in core library this won't happen, because all the static registrations
450     // will be done by the time a typed() handle is acquired.
451 #if !defined C10_MOBILE
452     operatorDef_->op.assertSignatureIsCorrect<FuncType>();
453     if (fn_has_symint<FuncType>::value) {
454       operatorDef_->op.assertSignatureIsCorrect<typename fn_remove_symint<FuncType>::type>();
455     }
456 #endif
457     return TypedOperatorHandle<FuncType>(operatorIterator_);
458   }
459 
460   void callBoxed(Stack* stack) const {
461     c10::Dispatcher::singleton().callBoxed(*this, stack);
462   }
463 
464   void callBoxed(Stack& stack) const {
465     callBoxed(&stack);
466   }
467 
468   void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const {
469     c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack);
470   }
471 
472   void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
473     c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
474   }
475 
476   template <typename F>
477   PyObject* getPythonOp(c10::impl::PyInterpreter* self_interpreter, F slow_accessor) const {
478     return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
479   }
480 
481   bool operator==(const OperatorHandle& other) const {
482     return operatorDef_ == other.operatorDef_;
483   }
484 
485   bool operator!=(const OperatorHandle& other) const {
486     return operatorDef_ != other.operatorDef_;
487   }
488 
489 private:
490   explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
491   : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator)  {}
492   friend class Dispatcher;
493   template<class> friend class TypedOperatorHandle;
494 
495   // Storing a direct pointer to the OperatorDef even though we
496   // already have the iterator saves an instruction in the critical
497   // dispatch path. The iterator is effectively a
498   // pointer-to-std::list-node, and (at least in libstdc++'s
499   // implementation) the element is at an offset 16 bytes from that,
500   // because the prev/next pointers come first in the list node
501   // struct. So, an add instruction would be necessary to convert from the
502   // iterator to an OperatorDef*.
503   Dispatcher::OperatorDef* operatorDef_;
504 
505   // We need to store this iterator in order to make
506   // Dispatcher::cleanup() fast -- it runs a lot on program
507   // termination (and presumably library unloading).
508   std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
509 };
510 
511 /**
512  * This is a handle to an operator schema registered with the dispatcher.
513  * It holds the same information as an OperatorHandle, but it is templated
514  * on the operator arguments and allows calling the operator in an
515  * unboxed way.
516  */
517 template<class FuncType>
518 class TypedOperatorHandle final {
519   static_assert(guts::false_t<FuncType>(), "FuncType in OperatorHandle::typed<FuncType> was not a valid function type");
520 };
521 template<class Return, class... Args>
522 class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
523 public:
524   TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
525   TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
526   TypedOperatorHandle(const TypedOperatorHandle&) = default;
527   TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;
528 
529   // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
530   C10_ALWAYS_INLINE Return call(Args... args) const {
531     return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
532   }
533 
534   // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
535   C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
536     return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
537   }
538 
539 private:
540   explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
541   : OperatorHandle(operatorIterator) {}
542   friend class OperatorHandle;
543 };
544 
545 namespace detail {
546 template <class... Args> inline void unused_arg_(const Args&...) {}
547 
548 // CaptureKernelCall is intended to capture return values from Dispatcher
549 // unboxed kernel calls. A record function may request to get outputs from the
550   // kernel calls. For boxed kernels, it's straightforward: the returned values
551 // are in the stack object. The stack can be passed to record functions. For
552 // unboxed kernels, we need to handle different kinds of return values, cache
553 // them temporarily, then release the values for the actual function call
554 // return.
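// Illustrative sketch of the intended usage pattern (this mirrors
// callWithDispatchKeySlowPath() further below; kernel, op, ks, args and guard
// stand in for the values available at that call site):
//
//   detail::CaptureKernelCall<Return> capture(kernel, op, ks, std::forward<Args>(args)...);
//   guard.setOutputs(capture.getOutputs());   // hand the outputs to RecordFunction
//   return std::move(capture).release();      // then release them to the caller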
555 template <typename ReturnType>
556 struct CaptureKernelCall {
557   template <typename F, typename... Args>
558   CaptureKernelCall(
559       const F& kernel,
560       const TypedOperatorHandle<ReturnType(Args...)>& op,
561       const DispatchKeySet& dispatchKeySet,
562       Args&&... args)
563       // Calls the kernel and captures the result in output_.
564       : output_{kernel.template call<ReturnType, Args...>(
565             op,
566             dispatchKeySet,
567             std::forward<Args>(args)...)} {}
568   // Wraps the return values in a Stack.
569   Stack getOutputs() {
570     Stack stack;
571     impl::push_outputs<ReturnType, false>::copy(output_, &stack);
572     return stack;
573   }
574   // Since we are returning the output_, we don't expect the output_ to be used
575   // afterward. Copy elision and RVO do not apply to class data members. Using
576   // move semantics to avoid copies when possible.
577   ReturnType release() && {
578     return std::move(output_);
579   }
580 
581  private:
582   ReturnType output_;
583 };
584 
585 // Handle the lvalue reference differently since it should not be moved.
586 template <>
587 inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
588   return output_;
589 }
590 
591 // Handle case where the kernel returns void.
592 template <>
593 struct CaptureKernelCall<void> {
594   template <typename F, typename... Args>
595   CaptureKernelCall(
596       const F& kernel,
597       const TypedOperatorHandle<void(Args...)>& op,
598       const DispatchKeySet& dispatchKeySet,
599       Args&&... args) {
600     // Call the kernel; there is nothing to capture for a void return.
601     kernel.template call<void, Args...>(
602         op, dispatchKeySet, std::forward<Args>(args)...);
603   }
604   Stack getOutputs() {
605     return Stack();
606   }
607   void release() && {}
608 };
609 
610 TORCH_API void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet);
611 
612 } // namespace detail
613 
614 // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
615 template<class Return, class... Args>
616 inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
617   // If callbacks need inputs, we box the arguments and pass them to the guard.
618   // Note: For perf reasons we wouldn't want to prematurely box the arguments.
619   at::RecordFunction guard(std::move(stepCallbacks));
620   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved());
621   auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
622   auto& schema = op.schema();
623   auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
624   constexpr auto num_boxed_args = impl::boxed_size<Args...>();
625   if constexpr (num_boxed_args != 0) {
626     if (guard.needsInputs()) {
627       // If we used std::array<IValue, num_boxed_args> here, we would
628       // have to spend time default constructing the IValues in
629       // boxedArgs. aligned_storage has no such requirement.
630       impl::IValueAlignedStorage boxedArgs[num_boxed_args];
631       // For debugging only; could be removed (but the compiler will do
632       // that for us and it's nice to have the extra assurance of
633       // correctness from our debug builds).
634       int lastArgIdx = 0;
635       impl::boxArgsToStack(boxedArgs, lastArgIdx, args...);
636       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(lastArgIdx == num_boxed_args);
637       // I don't *think* we need std::launder here, because IValue has
638       // no subclasses and no const or reference fields.
639       runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet, c10::ArrayRef<const c10::IValue>(reinterpret_cast<IValue *>(boxedArgs), num_boxed_args));
640       for (size_t ii = 0; ii < num_boxed_args; ++ii) {
641         reinterpret_cast<IValue *>(&boxedArgs[ii])->~IValue();
642       }
643     } else {
644       runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
645     }
646   } else {
647     runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
648   }
649 
650   if (C10_UNLIKELY(guard.needsOutputs())) {
651     // Calls the kernel and captures the output temporarily to pass to
652     // RecordFunction.
653     detail::CaptureKernelCall<Return> captureKernelCall(
654         kernel, op, dispatchKeySet, std::forward<Args>(args)...);
655     guard.setOutputs(captureKernelCall.getOutputs());
656     // Releases the captured output to return to caller.
657     return std::move(captureKernelCall).release();
658   }
659 
660   // keeping the guard alive while executing the kernel
661   return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
662 }
663 
664 // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
665 template<class Return, class... Args>
666 C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
667   detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
668   auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
669     .template getDispatchKeySetUnboxed<Args...>(args...);
670 #ifndef NDEBUG
671   DispatchTraceNestingGuard debug_guard;
672   if (show_dispatch_trace()) {
673     detail::_print_dispatch_trace("[call]", toString(op.operator_name()), dispatchKeySet);
674   }
675 #endif
676   const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
677 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
678   auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
679   if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
680     return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
681   }
682 #endif  // PYTORCH_DISABLE_PER_OP_PROFILING
683 
684 #ifdef FBCODE_CAFFE2
685   if(profilingOperatorEvents()) {
686     struct FireOpRAII {
687        FireOpRAII(at::RecordFunction::schema_ref_t schema_ref) : schema_ref_(schema_ref) {
688            fireOpStartUSDT(schema_ref);
689         }
690        ~FireOpRAII() { fireOpEndUSDT(schema_ref_); }
691        at::RecordFunction::schema_ref_t schema_ref_;
692     } event(op.schema());
693     return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
694   } else {
695     return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
696   }
697 #else
698     return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
699 #endif // FBCODE_CAFFE2
700 }
701 
702 // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
703 template<class Return, class... Args>
704 inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
705   detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
706   // do not use RecordFunction on redispatch
707 #ifndef NDEBUG
708   DispatchTraceNestingGuard debug_guard;
709   if (show_dispatch_trace()) {
710     detail::_print_dispatch_trace("[redispatch]", toString(op.operator_name()), currentDispatchKeySet);
711   }
712 #endif
713   const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
714   return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
715 }
716 
717 inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
718   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
719   const auto& entry = op.operatorDef_->op;
720   auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
721 #ifndef NDEBUG
722   DispatchTraceNestingGuard debug_guard;
723   if (show_dispatch_trace()) {
724     detail::_print_dispatch_trace("[callBoxed]", toString(op.operator_name()), dispatchKeySet);
725   }
726 #endif
727   const auto& kernel = entry.lookup(dispatchKeySet);
728 #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
729   auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
730   if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
731     at::RecordFunction guard(std::move(*step_callbacks));
732     auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
733     auto& schema = op.schema();
734     auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
735     guard.needsInputs() ? runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet, c10::ArrayRef<const c10::IValue>(stack->data(), stack->size()))
736                         : runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
737 
738     // keeping the guard alive while executing the kernel
739     kernel.callBoxed(op, dispatchKeySet, stack);
740 
741     if (C10_UNLIKELY(guard.needsOutputs())) {
742       guard.setOutputs(*stack);
743     }
744     return;
745   }
746 #endif  // PYTORCH_DISABLE_PER_OP_PROFILING
747   kernel.callBoxed(op, dispatchKeySet, stack);
748 }
749 
750 // NB: this doesn't count as a "true" dispatcher jump, so no instrumentation
751 inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const {
752   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
753   const auto& entry = op.operatorDef_->op;
754   // We still compute this as we're obligated to pass it on to the internal
755   // kernel, if it is a boxed fallback
756   auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
757   const auto& kernel = ([&]() {
758     if (op.hasKernelForDispatchKey(dk)) {
759       return entry.kernelForDispatchKey(dk);
760     } else {
761       auto idx = getDispatchTableIndexForDispatchKey(dk);
762       TORCH_INTERNAL_ASSERT(idx >= 0);
763       return backendFallbackKernels_[idx].kernel;
764     }
765   })();
766   kernel.callBoxed(op, dispatchKeySet, stack);
767 }
768 
769 inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
770   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
771   const auto& entry = op.operatorDef_->op;
772 #ifndef NDEBUG
773   DispatchTraceNestingGuard debug_guard;
774   if (show_dispatch_trace()) {
775     detail::_print_dispatch_trace("[redispatchBoxed]", toString(op.operator_name()), dispatchKeySet);
776   }
777 #endif
778   const auto& kernel = entry.lookup(dispatchKeySet);
779   return kernel.callBoxed(op, dispatchKeySet, stack);
780 }
781 
782 } // namespace c10
783 
784 namespace std {
785 
786 template <>
787 struct hash<c10::OperatorHandle> {
788   size_t operator()(const c10::OperatorHandle& op) const noexcept {
789     return std::hash<void*>{}(static_cast<void*>(op.operatorDef_));
790   }
791 };
792 
793 } // namespace std
794