xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/dispatch/Dispatcher.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/dispatch/Dispatcher.h>
2 #include <ATen/core/PythonOpRegistrationTrampoline.h>
3 #include <chrono>
4 #include <list>
5 #include <sstream>
6 #include <utility>
7 
8 #ifdef FBCODE_CAFFE2
9 #include <c10/util/static_tracepoint.h>
10 #endif
11 
12 namespace c10 {
13 
14 #ifdef FBCODE_CAFFE2
15 TORCH_SDT_DEFINE_SEMAPHORE(operator_start)
TORCH_SDT_DEFINE_SEMAPHORE(operator_end)16 TORCH_SDT_DEFINE_SEMAPHORE(operator_end)
17 #endif
18 
// Returns true iff the TORCH_SHOW_DISPATCH_TRACE environment variable is
// set (to any value, including empty). The environment is consulted once
// and the answer is cached for the lifetime of the process.
bool show_dispatch_trace() {
  static const bool enabled = getenv("TORCH_SHOW_DISPATCH_TRACE") != nullptr;
  return enabled;
}
23 
// Per-thread depth counter used to indent nested dispatch-trace output.
static thread_local int64_t dispatch_trace_nesting_value_;

// Increase the calling thread's dispatch-trace nesting depth by one.
void dispatch_trace_nesting_incr() {
  ++dispatch_trace_nesting_value_;
}

// Decrease the calling thread's dispatch-trace nesting depth by one.
void dispatch_trace_nesting_decr() {
  --dispatch_trace_nesting_value_;
}

// Read the calling thread's current dispatch-trace nesting depth.
int64_t dispatch_trace_nesting_value() {
  return dispatch_trace_nesting_value_;
}
29 
30 namespace detail {
31 
32 class RegistrationListenerList final {
33 public:
addListener(std::unique_ptr<OpRegistrationListener> listener)34   std::function<void()> addListener(std::unique_ptr<OpRegistrationListener> listener) {
35     listeners_.push_back(std::move(listener));
36     auto delete_it = --listeners_.end();
37     return [this, delete_it] {
38         listeners_.erase(delete_it);
39     };
40   }
41 
callOnOperatorRegistered(const OperatorHandle & op)42   void callOnOperatorRegistered(const OperatorHandle& op) {
43     for (auto& listener : listeners_) {
44       listener->onOperatorRegistered(op);
45     }
46   }
47 
callOnOperatorDeregistered(const OperatorHandle & op)48   void callOnOperatorDeregistered(const OperatorHandle& op) {
49     for (auto& listener : listeners_) {
50       listener->onOperatorDeregistered(op);
51     }
52   }
53 private:
54   std::list<std::unique_ptr<OpRegistrationListener>> listeners_;
55 };
56 
_print_dispatch_trace(const std::string & label,const std::string & op_name,const DispatchKeySet & dispatchKeySet)57 void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
58   auto nesting_value = dispatch_trace_nesting_value();
59   for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
60   std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
61 }
62 } // namespace detail
63 
64 OpRegistrationListener::~OpRegistrationListener()= default;
65 
// Construct an empty dispatcher: no operators, no backend fallbacks, a
// fresh listener list, and a shared Guard whose `alive` flag lets RAII
// deregistration callbacks detect dispatcher destruction.
Dispatcher::Dispatcher()
: operators_()
, operatorLookupTable_()
, backendFallbackKernels_()
, listeners_(std::make_unique<detail::RegistrationListenerList>())
, cond_var_()
, guard_(std::make_shared<Guard>())
{}
74 
// Flip the shared liveness flag under the lock so that any outstanding
// RegistrationHandleRAII callbacks (which hold guard_ via shared_ptr)
// become no-ops instead of touching a destroyed dispatcher.
Dispatcher::~Dispatcher() {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  guard_->alive.store(false);
}
79 
// Accessor for the process-wide Dispatcher instance (function-local
// static, i.e. thread-safe lazy initialization).
C10_EXPORT Dispatcher& Dispatcher::realSingleton() {
  static Dispatcher _singleton;
  return _singleton;
}
84 
findOp(const OperatorName & overload_name)85 std::optional<OperatorHandle> Dispatcher::findOp(const OperatorName& overload_name) {
86   return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::optional<OperatorHandle> {
87     auto found = operatorLookupTable.find(overload_name);
88     if (found == operatorLookupTable.end()) {
89       return std::nullopt;
90     }
91     return found->second;
92   });
93 }
94 
95 // NB: If you add more waitFor* implementations, you also have to add
96 // appropriate notify_all() calls to the relevant register calls
97 
waitForDef(const FunctionSchema & schema)98 void Dispatcher::waitForDef(const FunctionSchema& schema) {
99   using namespace std::chrono_literals;
100   std::unique_lock<std::mutex> lock(guard_->mutex);
101   bool r = cond_var_.wait_for(lock, 2s, [&]{
102     return findOp(schema.operator_name()) != std::nullopt;
103   });
104   TORCH_INTERNAL_ASSERT(r,
105     "Expected main interpreter to define ", schema.operator_name(),
106     ", but this didn't happen within timeout.  Are you trying to load "
107     "different models in the same torchdeploy/multipy instance?  You "
108     "must warmup each interpreter identically, e.g., import all "
109     "the same dependencies.");
110 }
111 
waitForImpl(const OperatorName & op_name,std::optional<c10::DispatchKey> maybe_dk)112 void Dispatcher::waitForImpl(const OperatorName& op_name, std::optional<c10::DispatchKey> maybe_dk) {
113   using namespace std::chrono_literals;
114   std::unique_lock<std::mutex> lock(guard_->mutex);
115   auto dk = maybe_dk.value_or(DispatchKey::CompositeImplicitAutograd);
116   auto op = findOrRegisterName_(op_name);
117   bool r = cond_var_.wait_for(lock, 2s, [&]{
118     // NB: this is slightly unsound for overrides, but overrides are
119     // funny business anyway
120     return op.hasKernelForDispatchKey(dk);
121   });
122   TORCH_INTERNAL_ASSERT(r,
123     "Expected main interpreter to implement ", dk, " for ", op_name,
124     ", but this didn't happen within timeout.  Are you trying to load "
125     "different models in the same torchdeploy/multipy instance?  You "
126     "must warmup each interpreter identically, e.g., import all "
127     "the same dependencies.");
128 }
129 
findSchema(const OperatorName & overload_name)130 std::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overload_name) {
131   auto it = findOp(overload_name);
132   if (it.has_value()) {
133     if (it->hasSchema()) {
134       return it;
135     } else {
136       return std::nullopt;
137     }
138   } else {
139     return it;
140   }
141 }
142 
findSchemaOrThrow(const char * name,const char * overload_name)143 OperatorHandle Dispatcher::findSchemaOrThrow(const char* name, const char* overload_name) {
144   auto it = findSchema({name, overload_name});
145   if (!it.has_value()) {
146     // Check if we have ANYTHING; if that's the case, that means you're
147     // missing schema
148     auto it2 = findOp({name, overload_name});
149     if (!it2.has_value()) {
150       TORCH_CHECK(false, "Could not find schema for ", name, ".", overload_name);
151     } else {
152       TORCH_CHECK(false, "Could not find schema for ", name, ".", overload_name,
153         " but we found an implementation; did you forget to def() the operator?");
154     }
155   }
156   return it.value();
157 }
158 
getAllOpNames()159 const std::vector<OperatorName> Dispatcher::getAllOpNames() {
160   return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
161     std::vector<OperatorName> allOpNames;
162     for (const auto& op : operatorLookupTable) {
163         allOpNames.push_back(op.first);
164     }
165     return allOpNames;
166   });
167 }
168 
169 // Postcondition: caller is responsible for disposing of registration when they
170 // are done
findOrRegisterName_(const OperatorName & op_name)171 OperatorHandle Dispatcher::findOrRegisterName_(const OperatorName& op_name) {
172   const auto found = findOp(op_name);
173   if (found != std::nullopt) {
174     return *found;
175   }
176 
177   operators_.emplace_back(OperatorName(op_name));
178   OperatorHandle handle(--operators_.end());
179   operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
180     operatorLookupTable.emplace(op_name, handle);
181   });
182 
183   return handle;
184 }
185 
186 
// Adding an explicit destructor definition in the cpp to avoid a linker
// error in Windows builds: the Windows build doesn't produce the
// destructor symbol in PyTorch libs, causing a linker failure in
// downstream projects.
// x-ref https://github.com/pytorch/pytorch/issues/70032
OperatorHandle::~OperatorHandle() = default;
192 
registerLibrary(std::string ns,std::string debug)193 RegistrationHandleRAII Dispatcher::registerLibrary(std::string ns, std::string debug) {
194   std::lock_guard<std::mutex> lock(guard_->mutex);
195   auto found = libraries_.find(ns);
196   TORCH_CHECK(
197     found == libraries_.end(),
198     "Only a single TORCH_LIBRARY can be used to register the namespace ", ns,
199     "; please put all of your definitions in a single TORCH_LIBRARY block.  "
200     "If you were trying to specify implementations, consider using TORCH_LIBRARY_IMPL "
201     "(which can be duplicated).  If you really intended to define operators for a "
202     "single namespace in a distributed way, you can use TORCH_LIBRARY_FRAGMENT to "
203     "explicitly indicate this.  "
204     "Previous registration of TORCH_LIBRARY was ",
205     found->second, "; latest registration was ", debug
206   );
207   libraries_.emplace(ns, std::move(debug));
208   return RegistrationHandleRAII([guard = this->guard_, this, ns] {
209     std::lock_guard<std::mutex> lock(guard->mutex);
210     if (!guard->alive.load()) {
211       return;
212     }
213     deregisterLibrary_(ns);
214   });
215 }
216 
// Remove the TORCH_LIBRARY record for `ns`.
// Precondition: caller holds guard_->mutex (we need a lock to avoid
// concurrent writes to libraries_).
void Dispatcher::deregisterLibrary_(const std::string& ns) {
  // we need a lock to avoid concurrent writes
  libraries_.erase(ns);
}
221 
registerDef(FunctionSchema schema,std::string debug,std::vector<at::Tag> tags)222 RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
223   // we need a lock to avoid concurrent writes
224   std::lock_guard<std::mutex> lock(guard_->mutex);
225 
226   OperatorName op_name = schema.operator_name();
227   auto op = findOrRegisterName_(op_name);
228 
229   TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
230                                                     " Each overload's schema should only be registered with a single call to def().",
231                                                     " Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
232   op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), std::move(tags));
233   listeners_->callOnOperatorRegistered(op);
234 
235   // NB: do not increment the counts until AFTER error checking
236   ++op.operatorDef_->def_count;
237   ++op.operatorDef_->def_and_impl_count;
238 
239   cond_var_.notify_all();
240 
241   return RegistrationHandleRAII([guard = this->guard_, this, op, op_name] {
242     // we need a lock to avoid concurrent writes
243     std::lock_guard<std::mutex> lock(guard->mutex);
244     if (!guard->alive.load()) {
245       return;
246     }
247     deregisterDef_(op, op_name);
248   });
249 }
250 
deregisterDef_(const OperatorHandle & op,const OperatorName & op_name)251 void Dispatcher::deregisterDef_(
252     const OperatorHandle& op,
253     const OperatorName& op_name) {
254   TORCH_INTERNAL_ASSERT(op.schema().operator_name() == op_name);
255 
256   // reduce def_count and actually deregister if no references left
257   TORCH_INTERNAL_ASSERT(op.operatorDef_->def_count > 0);
258   TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
259 
260   --op.operatorDef_->def_count;
261   --op.operatorDef_->def_and_impl_count;
262   if (0 == op.operatorDef_->def_count) {
263     // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
264     // TODO: check that listeners are not relying on prepareForDeregistration()
265     // invariant
266     listeners_->callOnOperatorDeregistered(op);
267     op.operatorDef_->op.deregisterSchema();
268   }
269 
270   cleanup(op, op_name);
271 }
272 
namespace {

// Maps OperatorName to (python module name, description) tuple.
using PythonModuleMapType = std::unordered_map<at::OperatorName, std::pair<const char*, const char*>>;
// Process-wide pystub registry; function-local static gives thread-safe
// lazy initialization.
PythonModuleMapType& pythonModulesSingleton() {
  static PythonModuleMapType _data;
  return _data;
}

}
283 
getPyStub(OperatorName op_name)284 std::optional<std::pair<const char*, const char*>> Dispatcher::getPyStub(OperatorName op_name) {
285   std::lock_guard<std::mutex> lock(guard_->mutex);
286   auto found = pythonModulesSingleton().find(op_name);
287   if (found == pythonModulesSingleton().end()) {
288     return std::nullopt;
289   }
290   return found->second;
291 }
292 
registerPythonModule(const OperatorName & op_name,const char * pymodule,const char * context)293 RegistrationHandleRAII Dispatcher::registerPythonModule(
294   const OperatorName& op_name,
295   const char* pymodule,
296   const char* context
297 ) {
298   std::lock_guard<std::mutex> lock(guard_->mutex);
299   // If there are duplicates, we just let it through and warn about it.
300   // Throwing an error during static initialization causes a crash that
301   // doesn't give any sign of what happened.
302   auto found = pythonModulesSingleton().find(op_name);
303   if (found != pythonModulesSingleton().end()) {
304     TORCH_WARN(
305         "Tried to register an python registration stub (pystub) for ", op_name, " ",
306         "that specifies the Python module ", pymodule, " "
307         "but there already was a pystub that specifies the Python module ",
308         found->second.first, ". We will override the existing pystub.");
309   }
310   pythonModulesSingleton()[op_name] = std::make_pair(pymodule, context);
311   return RegistrationHandleRAII([guard = this->guard_, op_name] {
312     std::lock_guard<std::mutex> lock(guard->mutex);
313     if (!guard->alive.load()) {
314       return;
315     }
316     pythonModulesSingleton().erase(op_name);
317   });
318 }
319 
throwIfHasPythonModule(OperatorName op_name)320 void Dispatcher::throwIfHasPythonModule(OperatorName op_name) {
321   std::lock_guard<std::mutex> lock(guard_->mutex);
322   auto elt = pythonModulesSingleton().find(op_name);
323   if (elt == pythonModulesSingleton().end()) {
324     return;
325   }
326   const char* pymodule = elt->second.first;
327   const char* context = elt->second.second;
328   auto* interpreter = at::impl::PythonOpRegistrationTrampoline::getInterpreter();
329   TORCH_CHECK(
330       interpreter != nullptr,
331       op_name,
332       ": while attempting to run this operator with Meta Tensors: "
333       "Either there is no meta kernel for this operator, or it is located "
334       "in the python module ", pymodule, " which is not available "
335       "because Python isn't available.")
336   (*interpreter)->throw_abstract_impl_not_imported_error(toString(op_name), pymodule, context);
337 }
338 
registerImpl(OperatorName op_name,std::optional<DispatchKey> dispatch_key,KernelFunction kernel,std::optional<impl::CppSignature> cpp_signature,std::unique_ptr<FunctionSchema> inferred_function_schema,std::string debug)339 RegistrationHandleRAII Dispatcher::registerImpl(
340   OperatorName op_name,
341   std::optional<DispatchKey> dispatch_key,
342   KernelFunction kernel,
343   std::optional<impl::CppSignature> cpp_signature,
344   std::unique_ptr<FunctionSchema> inferred_function_schema,
345   std::string debug
346 ) {
347   std::lock_guard<std::mutex> lock(guard_->mutex);
348 
349   auto op = findOrRegisterName_(op_name);
350 
351   auto handle = op.operatorDef_->op.registerKernel(
352     *this,
353     dispatch_key,
354     std::move(kernel),
355     std::move(cpp_signature),
356     std::move(inferred_function_schema),
357     std::move(debug)
358   );
359 
360   ++op.operatorDef_->def_and_impl_count;
361 
362   cond_var_.notify_all();
363 
364   return RegistrationHandleRAII([guard = this->guard_, this, op, op_name, dispatch_key, handle] {
365     std::lock_guard<std::mutex> lock(guard->mutex);
366     if (!guard->alive.load()) {
367       return;
368     }
369     deregisterImpl_(op, op_name, dispatch_key, handle);
370   });
371 }
372 
// Undo a registerImpl(): remove the kernel, drop one impl reference, and
// erase the operator entry entirely if nothing references it anymore.
void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, std::optional<DispatchKey> dispatch_key, impl::OperatorEntry::AnnotatedKernelContainerIterator handle) {
  op.operatorDef_->op.deregisterKernel_(*this, dispatch_key, handle);

  TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);

  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
  --op.operatorDef_->def_and_impl_count;

  cleanup(op, op_name);
}
383 
registerName(OperatorName op_name)384 RegistrationHandleRAII Dispatcher::registerName(OperatorName op_name) {
385   std::lock_guard<std::mutex> lock(guard_->mutex);
386   auto op = findOrRegisterName_(op_name);
387   ++op.operatorDef_->def_and_impl_count;
388 
389   return RegistrationHandleRAII(
390       [guard = this->guard_, this, op, op_name] {
391         std::lock_guard<std::mutex> lock(guard->mutex);
392         if (!guard->alive.load()) {
393           return;
394         }
395         deregisterName_(op, op_name);
396       }
397   );
398 }
399 
// Release the name-only liveness reference taken by registerName(),
// erasing the entry when nothing references it anymore.
void Dispatcher::deregisterName_(
    const OperatorHandle& op,
    const OperatorName& op_name) {
  TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
  --op.operatorDef_->def_and_impl_count;
  cleanup(op, op_name);
}
408 
409 // Test if the operator entry is completely dead, and if so remove it completely
cleanup(const OperatorHandle & op,const OperatorName & op_name)410 void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) {
411   if (0 == op.operatorDef_->def_and_impl_count) {
412     // NOTE: Making this call fast is the only reason OperatorHandle
413     // stores operatorIterator_!
414     operators_.erase(op.operatorIterator_);
415     operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
416       operatorLookupTable.erase(op_name);
417     });
418   }
419 }
420 
registerFallback(DispatchKey dispatchKey,KernelFunction kernel,std::string debug)421 RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
422   std::lock_guard<std::mutex> lock(guard_->mutex);
423 
424   auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
425   TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
426   TORCH_CHECK(
427     !backendFallbackKernels_[idx].kernel.isValid(),
428     "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
429     backendFallbackKernels_[idx].debug, ", new registration ", debug
430   );
431   // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
432   // cannot be unboxed
433   backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
434 
435   for (auto& op : operators_) {
436     op.op.updateFallback(*this, dispatchKey);
437   }
438 
439   return RegistrationHandleRAII([guard = this->guard_, this, dispatchKey] {
440     std::lock_guard<std::mutex> lock(guard->mutex);
441     if (!guard->alive.load()) {
442       return;
443     }
444     deregisterFallback_(dispatchKey);
445   });
446 }
447 
// Undo registerFallback(): clear the fallback slot for `dispatchKey` and
// rebuild every operator's dispatch entry for that key.
void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  backendFallbackKernels_[idx] = {};

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);
  }
}
456 
457 
addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener)458 RegistrationHandleRAII Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {
459   std::lock_guard<std::mutex> lock(guard_->mutex);
460 
461   for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
462     if (iter->def_count > 0) {
463       listener->onOperatorRegistered(OperatorHandle(iter));
464     }
465   }
466 
467   auto removeListener = listeners_->addListener(std::move(listener));
468   return RegistrationHandleRAII([guard = this->guard_, this, removeListener] {
469       std::lock_guard<std::mutex> lock(guard_->mutex);
470       if (!guard->alive.load()) {
471         return;
472       }
473       removeListener();
474   });
475 }
476 
// Debugging aid: have every registered operator entry verify its own
// internal invariants.
void Dispatcher::checkInvariants() const {
  for (const auto& op : operators_) {
    op.op.checkInvariants();
  }
}
482 
findDanglingImpls() const483 std::vector<OperatorHandle> Dispatcher::findDanglingImpls() const {
484   return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorHandle> {
485     std::vector<OperatorHandle> opsWithDanglingImpls;
486     for (const auto& op : operatorLookupTable) {
487       if (!op.second.hasSchema()) {
488         opsWithDanglingImpls.push_back(op.second);
489       }
490     }
491     return opsWithDanglingImpls;
492   });
493 }
494 
getRegistrationsForDispatchKey(std::optional<DispatchKey> k) const495 std::vector<OperatorName> Dispatcher::getRegistrationsForDispatchKey(std::optional<DispatchKey> k) const {
496   return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
497     std::vector<OperatorName> op_names;
498     for (const auto& op : operatorLookupTable) {
499       // If no DispatchKey is specified, print all of the operators.
500       if (!k || op.second.hasKernelForDispatchKey(*k)) {
501           op_names.push_back(op.first);
502       }
503     }
504     return op_names;
505   });
506 }
507 
sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey,DispatchKeySet dispatchKeySet)508 int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey, DispatchKeySet dispatchKeySet) {
509   int64_t seq_num = -1;
510   // Setting sequence number in the Autograd case to associate
511   // the forward range with the corresponding Autograd's node
512 
513   // Note: this records a sequence number for both Autograd keys, and for
514   // non-Autograd keys where the dispatchKeySet still contains an autograd key.
515   // This means that we might collect the same sequence nubmer two different
516   // events if they all occurred above Autograd and still had the Autograd
517   // dispatch key in the dispatch key set.
518   // However, this usually doesn't happen: normally the first call will
519   // go through the call() or callBoxed() path in the dispatcher, while
520   // subsequent redispatches go through redispatch() or redispatchBoxed().
521   // `call` has profiler instrumentation, whereas `redispatch` doesn't.
522   // So usually, we'll collect a sequence number on the first call() if the
523   // dispatch keys contain autograd, and not on subsequent redispatches.
524   bool dispatchHasAutograd = !(dispatchKeySet & autograd_dispatch_keyset).empty();
525 
526   if (dispatchHasAutograd && at::GradMode::is_enabled()) {
527     seq_num = at::sequence_number::peek();
528   }
529   return seq_num;
530 }
531 
// Start profiling for a dispatched op, forwarding the boxed arguments and
// the autograd-aware sequence number to the RecordFunction guard.
void Dispatcher::runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet, c10::ArrayRef<const c10::IValue> args) {
  guard.before(schema_ref, args, sequenceNumberForRunningRecordFunction(dispatchKey, dispatchKeySet));
}
535 
// Overload without boxed arguments: start the RecordFunction with only
// the schema and the autograd-aware sequence number.
void Dispatcher::runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet) {
  // Setting sequence number in the Autograd case to associate
  // the forward range with the corresponding Autograd's node
  guard.before(schema_ref, sequenceNumberForRunningRecordFunction(dispatchKey, dispatchKeySet));
}
541 #ifdef FBCODE_CAFFE2
// True when a tracer has armed either the operator_start or operator_end
// static tracepoint (USDT) semaphore.
bool Dispatcher::profilingOperatorEvents() {
  return TORCH_SDT_IS_ENABLED(operator_start) || TORCH_SDT_IS_ENABLED(operator_end);
}
545 
// Fire the operator_start USDT probe with the operator's name, guarded by
// the semaphore check so the untraced path stays cheap.
C10_NOINLINE void Dispatcher::fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref) {
  if (TORCH_SDT_IS_ENABLED(operator_start)) {
    TORCH_SDT_WITH_SEMAPHORE(operator_start, schema_ref.get().name().c_str());
  }
}
551 
// Fire the operator_end USDT probe with the operator's name, guarded by
// the semaphore check so the untraced path stays cheap.
C10_NOINLINE void Dispatcher::fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref) {
  if (TORCH_SDT_IS_ENABLED(operator_end)) {
    TORCH_SDT_WITH_SEMAPHORE(operator_end, schema_ref.get().name().c_str());
  }
}
557 #endif
558 
559 }
560