#pragma once

#include <ATen/SequenceNumber.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/OperatorEntry.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/LeftRight.h>
#include <list>
#include <mutex>
#include <condition_variable>
#include <type_traits>
#include <c10/core/SafePyObject.h>

#include <ATen/core/grad_mode.h>
#include <ATen/core/enum_tag.h>

#ifndef NDEBUG
#include <iostream>
#endif

namespace c10 {

TORCH_API bool show_dispatch_trace();
TORCH_API void dispatch_trace_nesting_incr();
TORCH_API void dispatch_trace_nesting_decr();
TORCH_API int64_t dispatch_trace_nesting_value();

struct DispatchTraceNestingGuard {
  DispatchTraceNestingGuard() { dispatch_trace_nesting_incr(); }
  ~DispatchTraceNestingGuard() { dispatch_trace_nesting_decr(); }
};

class TORCH_API OperatorHandle;
template<class FuncType> class TypedOperatorHandle;

/**
 * Implement this interface and register your instance with the dispatcher
 * to get notified when operators are registered or deregistered with
 * the dispatcher.
 *
 * NB: registration events only occur when a 'def' occurs; we don't trigger
 * on 'impl' or 'fallback' calls.
 */
class TORCH_API OpRegistrationListener {
public:
  virtual ~OpRegistrationListener();

  virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
  virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
};

namespace detail {
class RegistrationListenerList;
}
class SchemaRegistrationHandleRAII;

/**
 * Top-level dispatch interface for dispatching via the dynamic dispatcher.
 * Most end users shouldn't use this directly; if you're trying to register
 * ops, look in op_registration.
 */
class TORCH_API Dispatcher final {
private:
  // For direct access to backend fallback information
  friend class impl::OperatorEntry;

  struct OperatorDef final {
    explicit OperatorDef(OperatorName&& op_name)
      : op(std::move(op_name)) {}

    impl::OperatorEntry op;

    // These refer to the number of outstanding RegistrationHandleRAII
    // for this operator. def_count reflects only def() registrations
    // (in the new world, this should only ever be 1, but old-style
    // registrations may register the schema multiple times, which
    // will increase this count). def_and_impl_count reflects the number
    // of combined def() and impl() registrations. When the last def() gets
    // unregistered, we must immediately call the deregistration listeners,
    // but we must not actually delete the handle, as other outstanding RAII
    // destructors will still try to destruct it and they had better still
    // have a working operator handle when they do.
    size_t def_count = 0;
    size_t def_and_impl_count = 0;
  };
  friend class OperatorHandle;
  template<class> friend class TypedOperatorHandle;

  struct Guard final {
    Guard() : alive(true), mutex() {}
    std::atomic<bool> alive;
    std::mutex mutex;
  };

public:
  ~Dispatcher();

  // Implementation note: this class abstracts over the fact that we have per-operator
  // dispatch tables. This could be easily adjusted to have a single global hash
  // table.
  static Dispatcher& realSingleton();

  C10_ALWAYS_INLINE static Dispatcher& singleton() {
#if !defined C10_MOBILE
    // Implemented inline so that steady-state code needn't incur
    // function-call overhead. We can't just inline `realSingleton`
    // because the function-local static would get duplicated across
    // all DSOs that include & use this header, leading to multiple
    // singleton instances.
    static Dispatcher& s = realSingleton();
    return s;
#else
    // For C10_MOBILE, we should never inline a static function that
    // has a static member, since the generated code calls
    // __cxa_guard_acquire and __cxa_guard_release which help
    // implement exactly once semantics for the initialization of the
    // static Dispatcher& s above (for the non-mobile case). That
    // additional code when duplicated across all operator stubs
    // for every backend results in a lot of additional code
    // being generated by the compiler.
    return realSingleton();
#endif
  }

  // ------------------------------------------------------------------------
  //
  // Accessing operators by schema
  //
  // ------------------------------------------------------------------------

  /**
   * Looks for an operator schema with the given name and overload name
   * and returns it if it is registered WITH A SCHEMA.
   * Returns nullopt otherwise.
   */
  std::optional<OperatorHandle> findSchema(const OperatorName& operator_name);

  /**
   * Variant of findSchema that results in less code generated at the call site.
   * It (1) takes a const char* pointer rather than an OperatorName (so we skip
   * generating std::string constructor calls at the call site), and (2)
   * it raises an exception if the operator is not found (so we skip
   * generating exception-raising code at the call site).
   *
   * Irritatingly, we still have to generate the handful of instructions
   * for dealing with an exception being thrown during static initialization
   * (e.g. __cxa_guard_abort). If we could annotate this method noexcept we
   * could avoid this code too, but as the name of the function suggests,
   * it does throw exceptions.
   */
  OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
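  //
  // For example, given at::Tensor a, b (illustrative sketch only; it assumes
  // the "aten::add.Tensor" overload and its usual signature are registered):
  //
  //   auto op = c10::Dispatcher::singleton()
  //       .findSchemaOrThrow("aten::add", "Tensor")
  //       .typed<at::Tensor(const at::Tensor&, const at::Tensor&, const at::Scalar&)>();
  //   at::Tensor out = op.call(a, b, /*alpha=*/1);
  //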

  // Like findSchema, but also returns OperatorHandle even if there is no schema
  std::optional<OperatorHandle> findOp(const OperatorName& operator_name);

  // Returns a list of all operator names present in the operatorLookupTable_
  const std::vector<OperatorName> getAllOpNames();

  // ------------------------------------------------------------------------
  //
  // Invoking operators
  //
  // ------------------------------------------------------------------------

  template<class Return, class... Args>
  Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;

  template<class Return, class... Args>
  static Return callWithDispatchKeySlowPath(const TypedOperatorHandle<Return (Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args);

  // Like call, but intended for use in a redispatch in kernels that have explicitly performed the DispatchKey update calculation.
  // This will take the DispatchKeySet completely as is and dispatch to the kernel of the corresponding highest priority key in the set.
  // Note that this version of redispatch treats the given DispatchKeySet *as is*, and does NOT mask out the highest priority key.
  // See Note [Plumbing Keys Through The Dispatcher]
  template<class Return, class... Args>
  Return redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const;

  // Invoke an operator via the boxed calling convention using an IValue stack
  void callBoxed(const OperatorHandle& op, Stack* stack) const;
  void callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const;

  // TODO: This will only be useful if we write a backend fallback that plumbs dispatch keys (currently there are none)
  // See Note [Plumbing Keys Through The Dispatcher]
  void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const;

  bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
    auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
    if (dispatch_ix < 0) return false;
    return backendFallbackKernels_[dispatch_ix].kernel.isValid();
  }
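  //
  // For example, an operator can be invoked through the boxed API above by
  // pushing IValue arguments onto a Stack (illustrative sketch only, given
  // at::Tensor a, b; it assumes the "aten::mul.Tensor" overload is registered):
  //
  //   auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::mul", "Tensor");
  //   torch::jit::Stack stack;
  //   torch::jit::push(stack, a, b);
  //   c10::Dispatcher::singleton().callBoxed(op, &stack);
  //   at::Tensor out = torch::jit::pop(stack).toTensor();
  //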

  // Used by torchdeploy/multipy for multiple interpreters racing.
  void waitForDef(const FunctionSchema& schema);
  void waitForImpl(const OperatorName& op_name, std::optional<DispatchKey> dispatch_key);

  // ------------------------------------------------------------------------
  //
  // Performing registrations (NON user public; use op_registration)
  //
  // ------------------------------------------------------------------------

  /**
   * Register a new operator schema.
   *
   * If a schema with the same operator name and overload name already exists,
   * this function will check that both schemas are exactly identical.
   */
  RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags = {});

  /**
   * Register a kernel to the dispatch table for an operator.
   * If dispatch_key is nullopt, then this registers a fallback kernel.
   *
   * @return A RAII object that manages the lifetime of the registration.
   *         Once that object is destructed, the kernel will be deregistered.
   */
  // NB: steals the inferred function schema, as we may need to hold on to
  // it for a bit until the real schema turns up
  RegistrationHandleRAII registerImpl(OperatorName op_name, std::optional<DispatchKey> dispatch_key, KernelFunction kernel, std::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
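  //
  // registerDef/registerImpl are the low-level entry points; the public way to
  // reach them is through torch/library.h. A minimal sketch (the "myops"
  // namespace, schema string, and kernel below are purely illustrative):
  //
  //   TORCH_LIBRARY(myops, m) {
  //     m.def("mul_add(Tensor a, Tensor b, Tensor c) -> Tensor");
  //   }
  //   TORCH_LIBRARY_IMPL(myops, CPU, m) {
  //     m.impl("mul_add", [](const at::Tensor& a, const at::Tensor& b, const at::Tensor& c) {
  //       return a * b + c;
  //     });
  //   }
  //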

  /**
   * Given an operator, tells the Dispatcher that we have implemented a fake impl
   * for this op in the given Python module. Call this a "pystub".
   */
  RegistrationHandleRAII registerPythonModule(const OperatorName& op_name, const char* pymodule, const char* context);

  /**
   * Given an operator, throws if we have a pystub.
   */
  void throwIfHasPythonModule(OperatorName op_name);

  std::optional<std::pair<const char*, const char*>> getPyStub(OperatorName op_name);

  /**
   * Register a new operator by name.
   */
  RegistrationHandleRAII registerName(OperatorName op_name);

  /**
   * Register a fallback kernel for a backend.
   * If an operator is called but there is no concrete kernel for the dispatch
   * key of the given operator arguments, it will check if there is such a
   * fallback kernel for the given dispatch key and, if yes, call that one.
   */
  RegistrationHandleRAII registerFallback(DispatchKey dispatch_key, KernelFunction kernel, std::string debug);
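  //
  // The public way to install a backend fallback is through torch/library.h,
  // e.g. (a minimal sketch; the choice of dispatch key here is just an example):
  //
  //   TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
  //     m.fallback(torch::CppFunction::makeFallthrough());
  //   }
  //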

  /**
   * Used to register a namespace whenever there is a TORCH_LIBRARY declaration
   * in the frontend API. These invocations are only permitted once per program,
   * so we raise an error if this is called again for the same namespace.
   */
  RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);

  // ------------------------------------------------------------------------
  //
  // Listeners on registrations
  //
  // ------------------------------------------------------------------------

  /**
   * Add a listener that gets called whenever a new op is registered or an existing
   * op is deregistered. Immediately after registering, this listener gets called
   * for all previously registered ops, so it can be used to keep track of ops
   * registered with this dispatcher.
   */
  RegistrationHandleRAII addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener);
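  //
  // A minimal sketch of a listener ("LoggingListener" is a made-up name; the
  // bodies just print the operator name):
  //
  //   struct LoggingListener final : public c10::OpRegistrationListener {
  //     void onOperatorRegistered(const c10::OperatorHandle& op) override {
  //       std::cout << "registered: " << toString(op.operator_name()) << "\n";
  //     }
  //     void onOperatorDeregistered(const c10::OperatorHandle& op) override {
  //       std::cout << "deregistered: " << toString(op.operator_name()) << "\n";
  //     }
  //   };
  //   auto handle = c10::Dispatcher::singleton().addRegistrationListener(
  //       std::make_unique<LoggingListener>());
  //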

  void checkInvariants() const;

  //
  // ------------------------------------------------------------------------
  //
  // Assertions
  //
  // ------------------------------------------------------------------------

  /**
   * For testing purposes.
   * Returns a list of all operators that were created through calls to registerImpl(),
   * without any corresponding calls to registerDef(). After static initialization
   * is done this is almost certainly a bug, as the created OperatorHandle won't have
   * any schema associated with it and users calling the op through the dispatcher
   * won't be able to access it.
   *
   * Note that we cannot enforce this invariant "as we go" during static initialization,
   * due to undefined static initialization order; we have no guarantees over the order
   * in which .def() and .impl() calls are registered in the dispatcher at static
   * initialization time. So this function should only be called after static initialization.
   */
  std::vector<OperatorHandle> findDanglingImpls() const;

  /**
   * Useful for inspecting global Dispatcher registration state.
   * Returns the names of all operators with a kernel registered for the specified DispatchKey.
   * If no DispatchKey is specified, it returns all registered operators.
   */
  std::vector<OperatorName> getRegistrationsForDispatchKey(std::optional<DispatchKey> k) const;
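  //
  // For example (illustrative sketch only):
  //
  //   // All operators that have a CPU kernel registered.
  //   auto cpu_ops = c10::Dispatcher::singleton()
  //       .getRegistrationsForDispatchKey(c10::DispatchKey::CPU);
  //   // All registered operator names.
  //   auto all_ops = c10::Dispatcher::singleton()
  //       .getRegistrationsForDispatchKey(std::nullopt);
  //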

private:
  Dispatcher();

  static int64_t sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey, DispatchKeySet dispatchKeySet);
  static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet);
  static void runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet, c10::ArrayRef<const c10::IValue> args);

#ifdef FBCODE_CAFFE2
  static bool profilingOperatorEvents();
  static void fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref);
  static void fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref);
#endif // FBCODE_CAFFE2

  OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
  OperatorHandle findOrRegisterName_(const OperatorName& op_name);

  void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name);
  void deregisterImpl_(
    const OperatorHandle& op,
    const OperatorName& op_name,
    std::optional<DispatchKey> dispatch_key,
    impl::OperatorEntry::AnnotatedKernelContainerIterator kernel_handle);
  void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
  void deregisterFallback_(DispatchKey dispatchKey);
  void deregisterLibrary_(const std::string& ns);
  void cleanup(const OperatorHandle& op, const OperatorName& op_name);
  void checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug);

  std::list<OperatorDef> operators_;
#if !defined(C10_MOBILE)
  LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
#else
  RWSafeLeftRightWrapper<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
#endif
  // Map from namespace to debug string (saying, e.g., where the library was defined)
  ska::flat_hash_map<std::string, std::string> libraries_;

  std::array<impl::AnnotatedKernel, num_runtime_entries> backendFallbackKernels_;

  std::unique_ptr<detail::RegistrationListenerList> listeners_;

  // This condition variable gets notified whenever we add a new def/impl to the
  // dispatch table. This is primarily used by multipy/torchdeploy, when
  // we have multiple interpreters trying to register to the dispatch table.
  // In this situation, whenever the non-primary interpreter would have tried
  // to register to the dispatch table, instead it will check to see if the
  // expected registration has already been made, and if it hasn't, wait on
  // this condition variable to see if it was just racing with the primary
  // interpreter.
  //
  // We expect it to be rare for there to be any waiters on this condition
  // variable. This is mostly just to help give better diagnostics if
  // something goes horribly wrong.
  std::condition_variable cond_var_;

  // Protect concurrent access to the dispatcher. We store this in a
  // `shared_ptr` as we return callbacks that call back into dispatcher methods,
  // and we need to be able to handle and guard against the event when the
  // `Dispatcher` has been destroyed before the callbacks fire.
  std::shared_ptr<Guard> guard_;
};

/**
 * This is a handle to an operator schema registered with the dispatcher.
 * This handle can be used to register kernels with the dispatcher or
 * to look up a kernel for a certain set of arguments.
 */
class TORCH_API OperatorHandle {
  template <typename T> friend struct std::hash;

public:
  OperatorHandle(OperatorHandle&&) noexcept = default;
  OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
  OperatorHandle(const OperatorHandle&) = default;
  OperatorHandle& operator=(const OperatorHandle&) = default;
  // NOLINTNEXTLINE(performance-trivially-destructible)
  ~OperatorHandle();

  const OperatorName& operator_name() const {
    return operatorDef_->op.operator_name();
  }

  bool hasSchema() const {
    return operatorDef_->op.hasSchema();
  }

  const FunctionSchema& schema() const {
    return operatorDef_->op.schema();
  }

  const std::string& debug() const {
    return operatorDef_->op.debug();
  }

  std::string dumpState() const {
    return operatorDef_->op.dumpState();
  }

  bool hasKernelForDispatchKey(DispatchKey k) const {
    return operatorDef_->op.hasKernelForDispatchKey(k);
  }

  bool isKernelFallthroughKernel(DispatchKey k) const {
    return operatorDef_->op.kernelForDispatchKey(k).isFallthrough();
  }

  bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
    return operatorDef_->op.hasKernelForAnyDispatchKey(k);
  }

  bool hasComputedKernelForDispatchKey(DispatchKey k) const {
    return operatorDef_->op.hasComputedKernelForDispatchKey(k);
  }

  std::string dumpComputedTable() const {
    return operatorDef_->op.dumpComputedTable();
  }

  void checkInvariants() const {
    return operatorDef_->op.checkInvariants();
  }

  c10::ArrayRef<at::Tag> getTags() const {
    return operatorDef_->op.getTags();
  }

  void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
    operatorDef_->op.setReportErrorCallback_(std::move(callback));
  }

  bool hasTag(const at::Tag& tag) const {
    for (const auto& tag_ : getTags()) {
      if (tag == tag_) {
        return true;
      }
    }
    return false;
  }
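  //
  // For example, given an OperatorHandle op, one can check whether it carries a
  // particular tag from <ATen/core/enum_tag.h> (illustrative sketch;
  // at::Tag::inplace_view is assumed to be one of the generated tags):
  //
  //   bool is_inplace_view = op.hasTag(at::Tag::inplace_view);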

  template<class FuncType>
  TypedOperatorHandle<FuncType> typed() const {
    // NB: This assert is not 100% sound: you can retrieve a typed() operator
    // handle prior to ANY C++ signature being registered on the operator
    // and the check will say everything is OK (at which point you can then
    // smuggle in a kernel that is typed incorrectly). For everything
    // in core library this won't happen, because all the static registrations
    // will be done by the time a typed() handle is acquired.
#if !defined C10_MOBILE
    operatorDef_->op.assertSignatureIsCorrect<FuncType>();
    if (fn_has_symint<FuncType>::value) {
      operatorDef_->op.assertSignatureIsCorrect<typename fn_remove_symint<FuncType>::type>();
    }
#endif
    return TypedOperatorHandle<FuncType>(operatorIterator_);
  }

  void callBoxed(Stack* stack) const {
    c10::Dispatcher::singleton().callBoxed(*this, stack);
  }

  void callBoxed(Stack& stack) const {
    callBoxed(&stack);
  }

  void callBoxedForDispatchKey(DispatchKey dk, Stack& stack) const {
    c10::Dispatcher::singleton().callBoxedForDispatchKey(*this, dk, &stack);
  }

  void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
    c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
  }

  template <typename F>
  PyObject* getPythonOp(c10::impl::PyInterpreter* self_interpreter, F slow_accessor) const {
    return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor);
  }

  bool operator==(const OperatorHandle& other) const {
    return operatorDef_ == other.operatorDef_;
  }

  bool operator!=(const OperatorHandle& other) const {
    return operatorDef_ != other.operatorDef_;
  }

private:
  explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
    : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
  friend class Dispatcher;
  template<class> friend class TypedOperatorHandle;

  // Storing a direct pointer to the OperatorDef even though we
  // already have the iterator saves an instruction in the critical
  // dispatch path. The iterator is effectively a
  // pointer-to-std::list-node, and (at least in libstdc++'s
  // implementation) the element is at an offset 16 bytes from that,
  // because the prev/next pointers come first in the list node
  // struct. So, an add instruction would be necessary to convert from the
  // iterator to an OperatorDef*.
  Dispatcher::OperatorDef* operatorDef_;

  // We need to store this iterator in order to make
  // Dispatcher::cleanup() fast -- it runs a lot on program
  // termination (and presumably library unloading).
  std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};

/**
 * This is a handle to an operator schema registered with the dispatcher.
 * It holds the same information as an OperatorHandle, but it is templated
 * on the operator arguments and allows calling the operator in an
 * unboxed way.
 */
template<class FuncType>
class TypedOperatorHandle final {
  static_assert(guts::false_t<FuncType>(), "FuncType in OperatorHandle::typed<FuncType> was not a valid function type");
};
template<class Return, class... Args>
class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
public:
  TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
  TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
  TypedOperatorHandle(const TypedOperatorHandle&) = default;
  TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;

  // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
  C10_ALWAYS_INLINE Return call(Args... args) const {
    return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
  }

  // See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
  C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
    return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
  }

private:
  explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
    : OperatorHandle(operatorIterator) {}
  friend class OperatorHandle;
};

namespace detail {
template <class... Args> inline void unused_arg_(const Args&...) {}

// CaptureKernelCall is intended to capture return values from Dispatcher
// unboxed kernel calls. A record function may request to get outputs from the
// kernel calls. For boxed kernels this is straightforward: the returned values
// are in the stack object, and the stack can be passed to record functions. For
// unboxed kernels, we need to handle different kinds of return values, cache
// them temporarily, then release the values for the actual function call
// return.
template <typename ReturnType>
struct CaptureKernelCall {
  template <typename F, typename... Args>
  CaptureKernelCall(
      const F& kernel,
      const TypedOperatorHandle<ReturnType(Args...)>& op,
      const DispatchKeySet& dispatchKeySet,
      Args&&... args)
      // Calls the kernel and captures the result in output_.
      : output_{kernel.template call<ReturnType, Args...>(
            op,
            dispatchKeySet,
            std::forward<Args>(args)...)} {}
  // Wraps the return values in a Stack.
  Stack getOutputs() {
    Stack stack;
    impl::push_outputs<ReturnType, false>::copy(output_, &stack);
    return stack;
  }
  // Since we are returning output_, we don't expect it to be used afterward.
  // Copy elision and RVO do not apply to class data members, so we use move
  // semantics to avoid copies when possible.
  ReturnType release() && {
    return std::move(output_);
  }

private:
  ReturnType output_;
};

// Handle the lvalue reference differently since it should not be moved.
template <>
inline at::Tensor& CaptureKernelCall<at::Tensor&>::release() && {
  return output_;
}

// Handle the case where the kernel returns void.
template <>
struct CaptureKernelCall<void> {
  template <typename F, typename... Args>
  CaptureKernelCall(
      const F& kernel,
      const TypedOperatorHandle<void(Args...)>& op,
      const DispatchKeySet& dispatchKeySet,
      Args&&... args) {
    // Call the kernel; there is no return value to capture.
    kernel.template call<void, Args...>(
        op, dispatchKeySet, std::forward<Args>(args)...);
  }
  Stack getOutputs() {
    return Stack();
  }
  void release() && {}
};

TORCH_API void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet);

} // namespace detail

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<Return(Args...)>& op, at::StepCallbacks& stepCallbacks, DispatchKeySet dispatchKeySet, const KernelFunction& kernel, Args... args) {
  // If callbacks need inputs, we box the arguments and pass them to the guard.
  // Note: For perf reasons we wouldn't want to prematurely box the arguments.
  at::RecordFunction guard(std::move(stepCallbacks));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.operatorDef_->op.isObserved());
  auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
  auto& schema = op.schema();
  auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
  constexpr auto num_boxed_args = impl::boxed_size<Args...>();
  if constexpr (num_boxed_args != 0) {
    if (guard.needsInputs()) {
      // If we used std::array<IValue, num_boxed_args> here, we would
      // have to spend time default constructing the IValues in
      // boxedArgs. aligned_storage has no such requirement.
      impl::IValueAlignedStorage boxedArgs[num_boxed_args];
      // For debugging only; could be removed (but the compiler will do
      // that for us and it's nice to have the extra assurance of
      // correctness from our debug builds).
      int lastArgIdx = 0;
      impl::boxArgsToStack(boxedArgs, lastArgIdx, args...);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(lastArgIdx == num_boxed_args);
      // I don't *think* we need std::launder here, because IValue has
      // no subclasses and no const or reference fields.
      runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet, c10::ArrayRef<const c10::IValue>(reinterpret_cast<IValue *>(boxedArgs), num_boxed_args));
      for (size_t ii = 0; ii < num_boxed_args; ++ii) {
        reinterpret_cast<IValue *>(&boxedArgs[ii])->~IValue();
      }
    } else {
      runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
    }
  } else {
    runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);
  }

  if (C10_UNLIKELY(guard.needsOutputs())) {
    // Call the kernel and capture the output temporarily, so it can be passed
    // to the RecordFunction.
    detail::CaptureKernelCall<Return> captureKernelCall(
        kernel, op, dispatchKeySet, std::forward<Args>(args)...);
    guard.setOutputs(captureKernelCall.getOutputs());
    // Releases the captured output to return to caller.
    return std::move(captureKernelCall).release();
  }

  // keeping the guard alive while executing the kernel
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
    .template getDispatchKeySetUnboxed<Args...>(args...);
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    detail::_print_dispatch_trace("[call]", toString(op.operator_name()), dispatchKeySet);
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
  if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
    return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
  }
#endif  // PYTORCH_DISABLE_PER_OP_PROFILING

#ifdef FBCODE_CAFFE2
  if (profilingOperatorEvents()) {
    struct FireOpRAII {
      FireOpRAII(at::RecordFunction::schema_ref_t schema_ref) : schema_ref_(schema_ref) {
        fireOpStartUSDT(schema_ref);
      }
      ~FireOpRAII() { fireOpEndUSDT(schema_ref_); }
      at::RecordFunction::schema_ref_t schema_ref_;
    } event(op.schema());
    return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
  } else {
    return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
  }
#else
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
#endif  // FBCODE_CAFFE2
}

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    detail::_print_dispatch_trace("[redispatch]", toString(op.operator_name()), currentDispatchKeySet);
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}

inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    detail::_print_dispatch_trace("[callBoxed]", toString(op.operator_name()), dispatchKeySet);
  }
#endif
  const auto& kernel = entry.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
  if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
    at::RecordFunction guard(std::move(*step_callbacks));
    auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
    auto& schema = op.schema();
    auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);
    guard.needsInputs() ? runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet, c10::ArrayRef<const c10::IValue>(stack->data(), stack->size()))
                        : runRecordFunction(guard, schema_ref, dispatchKey, dispatchKeySet);

    // keeping the guard alive while executing the kernel
    kernel.callBoxed(op, dispatchKeySet, stack);

    if (C10_UNLIKELY(guard.needsOutputs())) {
      guard.setOutputs(*stack);
    }
    return;
  }
#endif  // PYTORCH_DISABLE_PER_OP_PROFILING
  kernel.callBoxed(op, dispatchKeySet, stack);
}

// NB: this doesn't count as a "true" dispatcher jump, so no instrumentation
inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
  // We still compute this as we're obligated to pass it on to the internal
  // kernel, if it is a boxed fallback.
  auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
  const auto& kernel = ([&]() {
    if (op.hasKernelForDispatchKey(dk)) {
      return entry.kernelForDispatchKey(dk);
    } else {
      auto idx = getDispatchTableIndexForDispatchKey(dk);
      TORCH_INTERNAL_ASSERT(idx >= 0);
      return backendFallbackKernels_[idx].kernel;
    }
  })();
  kernel.callBoxed(op, dispatchKeySet, stack);
}

inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
  const auto& entry = op.operatorDef_->op;
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    detail::_print_dispatch_trace("[redispatchBoxed]", toString(op.operator_name()), dispatchKeySet);
  }
#endif
  const auto& kernel = entry.lookup(dispatchKeySet);
  return kernel.callBoxed(op, dispatchKeySet, stack);
}

} // namespace c10

namespace std {

template <>
struct hash<c10::OperatorHandle> {
  size_t operator()(const c10::OperatorHandle& op) const noexcept {
    return std::hash<void*>{}(static_cast<void*>(op.operatorDef_));
  }
};

} // namespace std