#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <chrono>
#include <list>
#include <sstream>
#include <utility>

#ifdef FBCODE_CAFFE2
#include <c10/util/static_tracepoint.h>
#endif

namespace c10 {

#ifdef FBCODE_CAFFE2
TORCH_SDT_DEFINE_SEMAPHORE(operator_start)
TORCH_SDT_DEFINE_SEMAPHORE(operator_end)
#endif

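// Returns true if the TORCH_SHOW_DISPATCH_TRACE environment variable is set.
// Note that getenv() is only consulted once (the result is cached in a
// function-local static), so toggling the variable after the first call has
// no effect for the remainder of the process.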
bool show_dispatch_trace() {
  static char const* temp = getenv("TORCH_SHOW_DISPATCH_TRACE");
  return temp != nullptr;
}

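// Tracks how deeply nested the current dispatch call is, so that
// _print_dispatch_trace() below can indent re-dispatches under the call that
// triggered them. thread_local because dispatch nesting is a per-thread
// property.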
static thread_local int64_t dispatch_trace_nesting_value_;

void dispatch_trace_nesting_incr() { ++dispatch_trace_nesting_value_; }
void dispatch_trace_nesting_decr() { --dispatch_trace_nesting_value_; }
int64_t dispatch_trace_nesting_value() { return dispatch_trace_nesting_value_; }

namespace detail {

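// Listeners are kept in a std::list (rather than a vector) so that the
// iterator captured by the deregistration callback returned from
// addListener() stays valid as other listeners come and go, making the
// eventual erase() O(1).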
class RegistrationListenerList final {
 public:
  std::function<void()> addListener(std::unique_ptr<OpRegistrationListener> listener) {
    listeners_.push_back(std::move(listener));
    auto delete_it = --listeners_.end();
    return [this, delete_it] {
      listeners_.erase(delete_it);
    };
  }

  void callOnOperatorRegistered(const OperatorHandle& op) {
    for (auto& listener : listeners_) {
      listener->onOperatorRegistered(op);
    }
  }

  void callOnOperatorDeregistered(const OperatorHandle& op) {
    for (auto& listener : listeners_) {
      listener->onOperatorDeregistered(op);
    }
  }

 private:
  std::list<std::unique_ptr<OpRegistrationListener>> listeners_;
};

void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
  auto nesting_value = dispatch_trace_nesting_value();
  for (int64_t i = 0; i < nesting_value; ++i) {
    std::cerr << " ";
  }
  std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
}

} // namespace detail

OpRegistrationListener::~OpRegistrationListener() = default;

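// Note on guard_: every RegistrationHandleRAII returned from this file
// captures a shared_ptr to the Guard. The deregistration lambdas lock
// guard->mutex and bail out if guard->alive is false, so they remain safe to
// run even after the Dispatcher singleton itself has been destroyed (the
// destructor below flips alive to false under that same mutex).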
Dispatcher::Dispatcher()
  : operators_()
  , operatorLookupTable_()
  , backendFallbackKernels_()
  , listeners_(std::make_unique<detail::RegistrationListenerList>())
  , cond_var_()
  , guard_(std::make_shared<Guard>())
{}

Dispatcher::~Dispatcher() {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  guard_->alive.store(false);
}

C10_EXPORT Dispatcher& Dispatcher::realSingleton() {
  static Dispatcher _singleton;
  return _singleton;
}

std::optional<OperatorHandle> Dispatcher::findOp(const OperatorName& overload_name) {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::optional<OperatorHandle> {
    auto found = operatorLookupTable.find(overload_name);
    if (found == operatorLookupTable.end()) {
      return std::nullopt;
    }
    return found->second;
  });
}

// NB: If you add more waitFor* implementations, you also have to add
// appropriate notify_all() calls to the relevant register calls

void Dispatcher::waitForDef(const FunctionSchema& schema) {
  using namespace std::chrono_literals;
  std::unique_lock<std::mutex> lock(guard_->mutex);
  bool r = cond_var_.wait_for(lock, 2s, [&]{
    return findOp(schema.operator_name()) != std::nullopt;
  });
  TORCH_INTERNAL_ASSERT(r,
    "Expected main interpreter to define ", schema.operator_name(),
    ", but this didn't happen within timeout. Are you trying to load "
    "different models in the same torchdeploy/multipy instance? You "
    "must warmup each interpreter identically, e.g., import all "
    "the same dependencies.");
}

void Dispatcher::waitForImpl(const OperatorName& op_name, std::optional<c10::DispatchKey> maybe_dk) {
  using namespace std::chrono_literals;
  std::unique_lock<std::mutex> lock(guard_->mutex);
  auto dk = maybe_dk.value_or(DispatchKey::CompositeImplicitAutograd);
  auto op = findOrRegisterName_(op_name);
  bool r = cond_var_.wait_for(lock, 2s, [&]{
    // NB: this is slightly unsound for overrides, but overrides are
    // funny business anyway
    return op.hasKernelForDispatchKey(dk);
  });
  TORCH_INTERNAL_ASSERT(r,
    "Expected main interpreter to implement ", dk, " for ", op_name,
    ", but this didn't happen within timeout. Are you trying to load "
    "different models in the same torchdeploy/multipy instance? You "
    "must warmup each interpreter identically, e.g., import all "
    "the same dependencies.");
}

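// Unlike findOp() above, findSchema() only reports operators that have
// actually been def()'ed: an entry that exists solely because of an impl()
// (i.e. one that has no schema yet) is treated as not found.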
std::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overload_name) {
  auto it = findOp(overload_name);
  if (it.has_value()) {
    if (it->hasSchema()) {
      return it;
    } else {
      return std::nullopt;
    }
  } else {
    return it;
  }
}

OperatorHandle Dispatcher::findSchemaOrThrow(const char* name, const char* overload_name) {
  auto it = findSchema({name, overload_name});
  if (!it.has_value()) {
    // Check if we have ANYTHING; if that's the case, that means you're
    // missing schema
    auto it2 = findOp({name, overload_name});
    if (!it2.has_value()) {
      TORCH_CHECK(false, "Could not find schema for ", name, ".", overload_name);
    } else {
      TORCH_CHECK(false, "Could not find schema for ", name, ".", overload_name,
        " but we found an implementation; did you forget to def() the operator?");
    }
  }
  return it.value();
}

const std::vector<OperatorName> Dispatcher::getAllOpNames() {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
    std::vector<OperatorName> allOpNames;
    for (const auto& op : operatorLookupTable) {
      allOpNames.push_back(op.first);
    }
    return allOpNames;
  });
}

// Postcondition: caller is responsible for disposing of the registration when
// they are done
OperatorHandle Dispatcher::findOrRegisterName_(const OperatorName& op_name) {
  const auto found = findOp(op_name);
  if (found != std::nullopt) {
    return *found;
  }

  operators_.emplace_back(OperatorName(op_name));
  OperatorHandle handle(--operators_.end());
  operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
    operatorLookupTable.emplace(op_name, handle);
  });

  return handle;
}

// Adding an explicit destructor definition in the cpp to work around a linker
// error in Windows builds.
// Windows builds don't produce the destructor symbol in PyTorch libs,
// causing a linker failure in downstream projects.
// x-ref https://github.com/pytorch/pytorch/issues/70032
OperatorHandle::~OperatorHandle() = default;

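// Typically reached via the TORCH_LIBRARY macro. For illustration, a minimal
// sketch (the "myops" namespace and schema here are hypothetical):
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("myop(Tensor self) -> Tensor");
//   }
//
// A second TORCH_LIBRARY block for the same namespace trips the check below;
// use TORCH_LIBRARY_FRAGMENT for additional definitions and
// TORCH_LIBRARY_IMPL for implementations.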
RegistrationHandleRAII Dispatcher::registerLibrary(std::string ns, std::string debug) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto found = libraries_.find(ns);
  TORCH_CHECK(
    found == libraries_.end(),
    "Only a single TORCH_LIBRARY can be used to register the namespace ", ns,
    "; please put all of your definitions in a single TORCH_LIBRARY block. "
    "If you were trying to specify implementations, consider using TORCH_LIBRARY_IMPL "
    "(which can be duplicated). If you really intended to define operators for a "
    "single namespace in a distributed way, you can use TORCH_LIBRARY_FRAGMENT to "
    "explicitly indicate this. "
    "Previous registration of TORCH_LIBRARY was ",
    found->second, "; latest registration was ", debug
  );
  libraries_.emplace(ns, std::move(debug));
  return RegistrationHandleRAII([guard = this->guard_, this, ns] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterLibrary_(ns);
  });
}

void Dispatcher::deregisterLibrary_(const std::string& ns) {
  // the caller (the RegistrationHandleRAII lambda above) already holds
  // guard_->mutex, which protects this erase against concurrent writes
  libraries_.erase(ns);
}

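// Bookkeeping convention used below: def_count counts def() registrations for
// the operator (at most one is expected in practice, enforced by the check in
// registerDef), while def_and_impl_count counts def() and impl()
// registrations together; cleanup() removes the operator entry entirely once
// def_and_impl_count drops to zero.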
RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
  // we need a lock to avoid concurrent writes
  std::lock_guard<std::mutex> lock(guard_->mutex);

  OperatorName op_name = schema.operator_name();
  auto op = findOrRegisterName_(op_name);

  TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
    " Each overload's schema should only be registered with a single call to def().",
    " Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
  op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), std::move(tags));
  listeners_->callOnOperatorRegistered(op);

  // NB: do not increment the counts until AFTER error checking
  ++op.operatorDef_->def_count;
  ++op.operatorDef_->def_and_impl_count;

  cond_var_.notify_all();

  return RegistrationHandleRAII([guard = this->guard_, this, op, op_name] {
    // we need a lock to avoid concurrent writes
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterDef_(op, op_name);
  });
}

void Dispatcher::deregisterDef_(
    const OperatorHandle& op,
    const OperatorName& op_name) {
  TORCH_INTERNAL_ASSERT(op.schema().operator_name() == op_name);

  // reduce def_count and actually deregister if no references left
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_count > 0);
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);

  --op.operatorDef_->def_count;
  --op.operatorDef_->def_and_impl_count;
  if (0 == op.operatorDef_->def_count) {
    // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
    // TODO: check that listeners are not relying on prepareForDeregistration()
    // invariant
    listeners_->callOnOperatorDeregistered(op);
    op.operatorDef_->op.deregisterSchema();
  }

  cleanup(op, op_name);
}

namespace {

// Maps OperatorName to (python module name, description) tuple.
using PythonModuleMapType = std::unordered_map<at::OperatorName, std::pair<const char*, const char*>>;
PythonModuleMapType& pythonModulesSingleton() {
  static PythonModuleMapType _data;
  return _data;
}

} // anonymous namespace

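// A "pystub" records the Python module that must be imported for an operator
// to gain its Python-side (e.g. abstract/meta) implementation, plus a
// free-form context string used in error messages.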
std::optional<std::pair<const char*, const char*>> Dispatcher::getPyStub(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto found = pythonModulesSingleton().find(op_name);
  if (found == pythonModulesSingleton().end()) {
    return std::nullopt;
  }
  return found->second;
}

RegistrationHandleRAII Dispatcher::registerPythonModule(
  const OperatorName& op_name,
  const char* pymodule,
  const char* context
) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  // If there are duplicates, we just let it through and warn about it.
  // Throwing an error during static initialization causes a crash that
  // doesn't give any sign of what happened.
  auto found = pythonModulesSingleton().find(op_name);
  if (found != pythonModulesSingleton().end()) {
    TORCH_WARN(
      "Tried to register a python registration stub (pystub) for ", op_name, " ",
      "that specifies the Python module ", pymodule, " "
      "but there already was a pystub that specifies the Python module ",
      found->second.first, ". We will override the existing pystub.");
  }
  pythonModulesSingleton()[op_name] = std::make_pair(pymodule, context);
  return RegistrationHandleRAII([guard = this->guard_, op_name] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    pythonModulesSingleton().erase(op_name);
  });
}

void Dispatcher::throwIfHasPythonModule(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto elt = pythonModulesSingleton().find(op_name);
  if (elt == pythonModulesSingleton().end()) {
    return;
  }
  const char* pymodule = elt->second.first;
  const char* context = elt->second.second;
  auto* interpreter = at::impl::PythonOpRegistrationTrampoline::getInterpreter();
  TORCH_CHECK(
    interpreter != nullptr,
    op_name,
    ": while attempting to run this operator with Meta Tensors: "
    "Either there is no meta kernel for this operator, or it is located "
    "in the python module ", pymodule, " which is not available "
    "because Python isn't available.");
  (*interpreter)->throw_abstract_impl_not_imported_error(toString(op_name), pymodule, context);
}

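// Typically reached via the TORCH_LIBRARY_IMPL macro. For illustration, a
// minimal sketch (the names here are hypothetical):
//
//   TORCH_LIBRARY_IMPL(myops, CPU, m) {
//     m.impl("myop", &myop_cpu);
//   }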
RegistrationHandleRAII Dispatcher::registerImpl(
  OperatorName op_name,
  std::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  std::optional<impl::CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  auto op = findOrRegisterName_(op_name);

  auto handle = op.operatorDef_->op.registerKernel(
    *this,
    dispatch_key,
    std::move(kernel),
    std::move(cpp_signature),
    std::move(inferred_function_schema),
    std::move(debug)
  );

  ++op.operatorDef_->def_and_impl_count;

  cond_var_.notify_all();

  return RegistrationHandleRAII([guard = this->guard_, this, op, op_name, dispatch_key, handle] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterImpl_(op, op_name, dispatch_key, handle);
  });
}

void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, std::optional<DispatchKey> dispatch_key, impl::OperatorEntry::AnnotatedKernelContainerIterator handle) {
  op.operatorDef_->op.deregisterKernel_(*this, dispatch_key, handle);

  TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);

  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
  --op.operatorDef_->def_and_impl_count;

  cleanup(op, op_name);
}

RegistrationHandleRAII Dispatcher::registerName(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto op = findOrRegisterName_(op_name);
  ++op.operatorDef_->def_and_impl_count;

  return RegistrationHandleRAII(
    [guard = this->guard_, this, op, op_name] {
      std::lock_guard<std::mutex> lock(guard->mutex);
      if (!guard->alive.load()) {
        return;
      }
      deregisterName_(op, op_name);
    }
  );
}

void Dispatcher::deregisterName_(
    const OperatorHandle& op,
    const OperatorName& op_name) {
  TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
  --op.operatorDef_->def_and_impl_count;
  cleanup(op, op_name);
}

// Test if the operator entry is completely dead, and if so remove it completely
void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) {
  if (0 == op.operatorDef_->def_and_impl_count) {
    // NOTE: Making this call fast is the only reason OperatorHandle
    // stores operatorIterator_!
    operators_.erase(op.operatorIterator_);
    operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
      operatorLookupTable.erase(op_name);
    });
  }
}

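// A backend fallback is a boxed kernel that runs for every operator on a
// given dispatch key unless a more specific kernel exists. For illustration,
// a minimal sketch using the wildcard namespace (fallback_fn is
// hypothetical):
//
//   TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
//     m.fallback(torch::CppFunction::makeFromBoxedFunction<&fallback_fn>());
//   }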
RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
  TORCH_CHECK(
    !backendFallbackKernels_[idx].kernel.isValid(),
    "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
    backendFallbackKernels_[idx].debug, ", new registration ", debug
  );
  // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
  // cannot be unboxed
  backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);
  }

  return RegistrationHandleRAII([guard = this->guard_, this, dispatchKey] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterFallback_(dispatchKey);
  });
}

void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  backendFallbackKernels_[idx] = {};

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);
  }
}

RegistrationHandleRAII Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
    if (iter->def_count > 0) {
      listener->onOperatorRegistered(OperatorHandle(iter));
    }
  }

  auto removeListener = listeners_->addListener(std::move(listener));
  return RegistrationHandleRAII([guard = this->guard_, this, removeListener] {
    // use the captured guard, not guard_ via this: `this` may already be
    // destroyed by the time this deregistration callback runs
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    removeListener();
  });
}

void Dispatcher::checkInvariants() const {
  for (const auto& op : operators_) {
    op.op.checkInvariants();
  }
}

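// A "dangling impl" is an operator entry that has at least one impl()
// registration but no def() (and therefore no schema); see
// findSchemaOrThrow() above for the error users hit in that state.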
std::vector<OperatorHandle> Dispatcher::findDanglingImpls() const {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorHandle> {
    std::vector<OperatorHandle> opsWithDanglingImpls;
    for (const auto& op : operatorLookupTable) {
      if (!op.second.hasSchema()) {
        opsWithDanglingImpls.push_back(op.second);
      }
    }
    return opsWithDanglingImpls;
  });
}

std::vector<OperatorName> Dispatcher::getRegistrationsForDispatchKey(std::optional<DispatchKey> k) const {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
    std::vector<OperatorName> op_names;
    for (const auto& op : operatorLookupTable) {
      // If no DispatchKey is specified, print all of the operators.
      if (!k || op.second.hasKernelForDispatchKey(*k)) {
        op_names.push_back(op.first);
      }
    }
    return op_names;
  });
}

int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey, DispatchKeySet dispatchKeySet) {
  int64_t seq_num = -1;
  // Setting sequence number in the Autograd case to associate
  // the forward range with the corresponding Autograd's node

  // Note: this records a sequence number for both Autograd keys, and for
  // non-Autograd keys where the dispatchKeySet still contains an autograd key.
  // This means that we might collect the same sequence number for two different
  // events if they both occurred above Autograd and still had the Autograd
  // dispatch key in the dispatch key set.
  // However, this usually doesn't happen: normally the first call will
  // go through the call() or callBoxed() path in the dispatcher, while
  // subsequent redispatches go through redispatch() or redispatchBoxed().
  // `call` has profiler instrumentation, whereas `redispatch` doesn't.
  // So usually, we'll collect a sequence number on the first call() if the
  // dispatch keys contain autograd, and not on subsequent redispatches.
  bool dispatchHasAutograd = !(dispatchKeySet & autograd_dispatch_keyset).empty();

  if (dispatchHasAutograd && at::GradMode::is_enabled()) {
    seq_num = at::sequence_number::peek();
  }
  return seq_num;
}

void Dispatcher::runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet, c10::ArrayRef<const c10::IValue> args) {
  guard.before(schema_ref, args, sequenceNumberForRunningRecordFunction(dispatchKey, dispatchKeySet));
}

void Dispatcher::runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet) {
  // Setting sequence number in the Autograd case to associate
  // the forward range with the corresponding Autograd's node
  guard.before(schema_ref, sequenceNumberForRunningRecordFunction(dispatchKey, dispatchKeySet));
}

#ifdef FBCODE_CAFFE2
bool Dispatcher::profilingOperatorEvents() {
  return TORCH_SDT_IS_ENABLED(operator_start) || TORCH_SDT_IS_ENABLED(operator_end);
}

C10_NOINLINE void Dispatcher::fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref) {
  if (TORCH_SDT_IS_ENABLED(operator_start)) {
    TORCH_SDT_WITH_SEMAPHORE(operator_start, schema_ref.get().name().c_str());
  }
}

C10_NOINLINE void Dispatcher::fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref) {
  if (TORCH_SDT_IS_ENABLED(operator_end)) {
    TORCH_SDT_WITH_SEMAPHORE(operator_end, schema_ref.get().name().c_str());
  }
}
#endif

} // namespace c10