xref: /aosp_15_r20/external/pytorch/aten/src/ATen/record_function.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/record_function.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/macros/Macros.h>
#include <c10/util/ThreadLocal.h>
#include <c10/util/overloaded.h>

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <random>
10 
11 namespace at {
12 
// Name used for collective-communication ("param comms") profiling ranges;
// declared extern so other translation units compare against this exact string.
extern const std::string kParamCommsCallName = "record_param_comms";
14 
15 namespace {
16 
// Used to generate unique callback handles.
// Starts at 1 so 0 can serve as a "no handle" sentinel.
CallbackHandle next_unique_callback_handle() {
  static std::atomic<uint64_t> unique_cb_id {1};
  return CallbackHandle(unique_cb_id++);
}

// Same scheme for per-RecordFunction ids (assigned only when a callback
// declares needsIds()).
RecordFunctionHandle next_unique_record_function_handle() {
  static std::atomic<uint64_t> unique_rf_id {1};
  return RecordFunctionHandle(unique_rf_id++);
}
27 
// Process-wide node id reported by RecordFunction::getDefaultNodeId();
// stays -1 until RecordFunction::setDefaultNodeId() is called.
std::atomic<int64_t> defaultNodeId(-1);

// Enumerates thread ids logically;
// note: std::this_thread::get_id may return potentially
// reused thread id
std::atomic<uint64_t> next_thread_id_ {0};
thread_local uint64_t current_thread_id_ = 0;

// Number of distinct RecordScope values; sizes the per-scope cache arrays.
static constexpr size_t NumRecordScopes =
    static_cast<size_t>(RecordScope::NUM_SCOPES);
38 
findCallback(RecordFunctionCallbacks & entries,CallbackHandle handle)39 RecordFunctionCallbacks::iterator findCallback(
40     RecordFunctionCallbacks& entries,
41     CallbackHandle handle) {
42   auto match_handle = [handle](const auto& el) { return el.handle_ == handle; };
43   return std::find_if(entries.begin(), entries.end(), match_handle);
44 }
45 
extractCallback(RecordFunctionCallbacks & entries,CallbackHandle handle)46 std::optional<RecordFunctionCallback> extractCallback(
47     RecordFunctionCallbacks& entries,
48     CallbackHandle handle) {
49   auto it = findCallback(entries, handle);
50   if (it == entries.end()) {
51     return std::nullopt;
52   }
53   auto out = it->callback_;
54   entries.erase(it);
55   return out;
56 }
57 
58 // ============================================================================
59 // == Callback manager ========================================================
60 // ============================================================================
61 // The high level idea of the RecordFunction callback machinery is based on the
62 // observation that the set of callbacks to be run changes infrequently.
63 // However, in order to reuse the active set we have to be able to invalidate
64 // when the active set changes. There are three events that can change which
65 // callbacks should be run:
66 //  1) The set of global callbacks changes
67 //  2) The set of local callbacks changes
68 //  3) A sampling callback is present, and should run on this iteration
69 //
70 // Global callbacks rely on thread local replication and an atomic version
71 // counter to maintain consistency. Whenever we change the set of active global
72 // callbacks (add / remove / enable / disable) the `GlobalCallbackManager`
73 // increments the version number and updates the global state while holding
74 // a mutex. The local callback manager snapshots the global callbacks and
// lazily rebuilds by comparing `GlobalCallbackManager::version()` (which is
76 // a simple atomic read) to the version of the last rebuild. In the
77 // overwhelmingly common case that they match it can reuse the existing
78 // snapshot. Otherwise it must call the much more expensive (and locked)
79 // `GlobalCallbackManager::getSnapshot()`.
80 //
81 // Handling changes to the thread local callbacks is trivial; functions that
82 // change them can simply force a cache rebuild for that thread after the
83 // changes are made.
84 //
85 // Sampling is by far the most challenging to handle efficiently. In general
86 // sampling callbacks are expected to have very low frequency. (e.g. 1 per
87 // million) Random number generation is rather expensive, so flipping a coin on
88 // every call for every sampling callback is wasteful. We can significantly
89 // reduce this cost by noting that the number of failures of a Bernoulli random
90 // variable is a geometric distribution, and thus we can sample the geometric
91 // distribution to determine the next time a callback should run. This reduces
92 // the cost from a random sample to a simple integer decrement.
93 //
94 // We can further note that Bernoulli samples are independent. (In contrast to,
95 // say, sampling without replacement.) This means that we can generate a
96 // counter for each scope that a given callback supports and then decrement the
97 // counter corresponding to the RecordScope being called. Conceptually, this is
98 // analogous to flipping different coins with the same probability. By sharding
99 // on RecordScope, we can consolidate the decrement to a single shared counter
100 // and update individual counters during rebuild.
101 
// Process-global registry of callbacks. Every mutation bumps `version_` so
// per-thread caches (LocalCallbackManager) can detect staleness with a single
// relaxed atomic load instead of taking the mutex.
class GlobalCallbackManager {
 public:
  static GlobalCallbackManager& get(); // Singleton

 private:
  GlobalCallbackManager() = default;

 public:
  // Sentinel that no live snapshot ever carries (version_ starts at 1).
  static constexpr size_t NoVersion = 0;
  // (version, callbacks) pair captured consistently under the mutex.
  using snapshot_t = std::pair<size_t, RecordFunctionCallbacks>;

  //                                                                Locking?
  size_t version() const; //                                        No
  snapshot_t getSnapshot() const; //                                Yes
  CallbackHandle addCallback(RecordFunctionCallback cb); //         Yes
  void setCallbackEnabled(CallbackHandle handle, bool enabled); //  Yes
  void removeCallback(CallbackHandle handle); //                    Yes
  void clearCallbacks(); //                                         Yes

 private:
  std::atomic<size_t> version_{NoVersion + 1};
  RecordFunctionCallbacks global_callbacks_; // Source of truth.
  mutable std::mutex update_mutex_;
};
126 
// Per-(thread, scope) cache of the callbacks that should run on the next
// RecordFunction of that scope, including the geometric-sampling bookkeeping
// described in the design comment above.
class CacheEntry {
 public:
  CacheEntry() = default;
  CacheEntry(std::mt19937* generator, RecordScope scope);

  // The caller is expected to check `GlobalCallbackManager::get().version()'
  // and call CacheEntry::update() if necessary.
  StepCallbacks getActiveCallbacks();
  std::optional<StepCallbacks> getActiveCallbacksUnlessEmpty();

  // Full rebuild. (E.g. during registration)
  void update(const std::vector<RecordFunctionCallback>& callbacks);

 private:
  struct CallbackAndCounter {
    RecordFunctionCallback callback_;

    // `-1` indicates that a callback is not sampled.
    int tries_left_{-1};
  };

  // Shared by both getActiveCallbacks variants: decrements the sampling
  // countdown and rebuilds `active_callbacks_` when it reaches zero.
  C10_ALWAYS_INLINE void getActiveCallbacksImpl();

  void rebuildActiveCallbacks();
  int sampleTries(double p) const;

  // std::mt19937 is quite large, so all scopes share the same generator.
  std::mt19937* generator_{nullptr};

  // Includes sampling callbacks which are waiting to run.
  c10::SmallVector<CallbackAndCounter, kSoftLimitCallbacks> callbacks_;
  RecordScope scope_{RecordScope::FUNCTION};

  // Cached set handed out to RecordFunction; rebuilt lazily.
  StepCallbacks active_callbacks_;

  // For managing sampling callbacks
  int sampling_countdown_{0};
  int steps_for_this_update_{0};
};
166 
// Per-thread manager: owns the thread-local callbacks and a per-scope cache
// (CacheEntry) that mirrors the global set, rebuilt lazily on version change.
class LocalCallbackManager {
 public:
  static LocalCallbackManager& get(); // Singleton

 private:
  LocalCallbackManager();

 public:
  const RecordFunctionTLS& getTLS() const;
  StepCallbacks getActiveCallbacks(const RecordScope scope);
  std::optional<StepCallbacks> getActiveCallbacksUnlessEmpty(const RecordScope scope);

  void setTLS(const RecordFunctionTLS& tls);
  void seed(uint32_t seed);
  CallbackHandle addCallback(RecordFunctionCallback callback);
  bool setCallbackEnabled(CallbackHandle handle, bool enabled);
  bool removeCallback(CallbackHandle handle);
  void clearCallbacks();

 private:
  // Rebuilds everything iff the global version moved since the last rebuild.
  void rebuildActiveCallbacksIfNeeded();

  void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot);

  // Fast path after a thread-local registration change: rebuild only the
  // scopes `callback` participates in.
  void rebuild_callback_scopes(
      const GlobalCallbackManager::snapshot_t& global_snapshot,
      const RecordFunctionCallback& callback);

  void rebuild_scope(
      const GlobalCallbackManager::snapshot_t& global_snapshot,
      const RecordScope scope);

  // Source of truth.
  RecordFunctionTLS registered_callbacks_;

  // Runtime cache.
  size_t global_version_{GlobalCallbackManager::NoVersion};
  std::array<CacheEntry, NumRecordScopes> active_callbacks_;
  std::mt19937 generator_{};
};
207 
// ============================================================================
// == GlobalCallbackManager: Implementation ===================================
// ============================================================================
GlobalCallbackManager& GlobalCallbackManager::get() {
  static GlobalCallbackManager manager;
  return manager;
}

// Lock-free staleness probe; callers that see a mismatch take the locked
// getSnapshot() path.
size_t GlobalCallbackManager::version() const {
  return version_.load(std::memory_order_relaxed);
}

// Consistent (version, callbacks) copy; the mutex orders it against writers.
std::pair<size_t, RecordFunctionCallbacks> GlobalCallbackManager::getSnapshot() const {
  std::lock_guard<std::mutex> guard(update_mutex_);
  return {version_.load(std::memory_order_seq_cst), global_callbacks_};
}
224 
addCallback(RecordFunctionCallback cb)225 CallbackHandle GlobalCallbackManager::addCallback(RecordFunctionCallback cb) {
226   std::lock_guard<std::mutex> guard(update_mutex_);
227   ++version_;
228   auto handle = next_unique_callback_handle();
229   global_callbacks_.emplace_back(cb, handle);
230   return handle;
231 }
232 
setCallbackEnabled(CallbackHandle handle,bool enabled)233 void GlobalCallbackManager::setCallbackEnabled(
234     CallbackHandle handle,
235     bool enabled) {
236   std::lock_guard<std::mutex> guard(update_mutex_);
237   auto it = findCallback(global_callbacks_, handle);
238   if (it != global_callbacks_.end()) {
239     if (it->enabled_ != enabled) {
240       ++version_;
241       it->enabled_ = enabled;
242     }
243   } else {
244     LOG(WARNING) << "Requested callback is not found";
245   }
246 }
247 
// Unregister the callback behind `handle`; the version is only bumped when an
// entry was actually removed, otherwise a warning is logged.
void GlobalCallbackManager::removeCallback(CallbackHandle handle) {
  std::lock_guard<std::mutex> guard(update_mutex_);
  if (extractCallback(global_callbacks_, handle).has_value()) {
    ++version_;
  } else {
    LOG(WARNING) << "Requested callback is not found";
  }
}

// Drop every global callback in one shot (e.g. profiler teardown).
void GlobalCallbackManager::clearCallbacks() {
  std::lock_guard<std::mutex> guard(update_mutex_);
  ++version_;
  global_callbacks_.clear();
}
262 
// ============================================================================
// == CacheEntry: Implementation ==============================================
// ============================================================================
CacheEntry::CacheEntry(std::mt19937* generator, RecordScope scope)
    : generator_{generator}, scope_{scope} {
  rebuildActiveCallbacks();
}

// Replace this scope's callback list. Sampled callbacks (samplingProb < 1)
// get a fresh geometric countdown; -1 marks "always run" callbacks.
void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) {
  callbacks_.clear();
  callbacks_.reserve(callbacks.size());
  for (const auto& callback : callbacks) {
    const auto p = callback.samplingProb();
    callbacks_.push_back({callback, p < 1.0 ? sampleTries(p) : -1});
  }

  rebuildActiveCallbacks();
}
281 
// Hot path: decrement the shared sampling countdown; when it expires, advance
// every sampled callback's counter by the elapsed steps and rebuild the
// active set.
void CacheEntry::getActiveCallbacksImpl() {
  // We rebuild the active set when `sampling_countdown_` reaches zero, so if it
  // reaches zero at the start of this function something has gone wrong.
  TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_);

  if (C10_UNLIKELY(!(--sampling_countdown_))) {
    // Use inferred steps to update sampled callbacks.
    for (auto& i : callbacks_) {
      if (i.tries_left_ > 0) {
        // steps_for_this_update_ was the min of all positive counters, so
        // none can go negative here.
        TORCH_INTERNAL_ASSERT(i.tries_left_ >= steps_for_this_update_);
        i.tries_left_ -= steps_for_this_update_;
      }
    }

    // Determine which callbacks to run and for how long.
    rebuildActiveCallbacks();

    // Resample any sampled callbacks that ran this call.
    for (auto& i : callbacks_) {
      if (!i.tries_left_) {
        i.tries_left_ = sampleTries(i.callback_.samplingProb());
      }
    }
  }
}
307 
StepCallbacks CacheEntry::getActiveCallbacks() {
  getActiveCallbacksImpl();
  return active_callbacks_;
}

// Variant that lets RecordFunction skip all work in the (expected-common)
// case that no callback is active for this scope.
std::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
  getActiveCallbacksImpl();
  if (C10_LIKELY(active_callbacks_.empty())) {
    return std::nullopt;
  }
  return active_callbacks_;
}
320 
// Recompute `active_callbacks_` from `callbacks_` and derive how many calls
// (`sampling_countdown_`) may elapse before the set must be rebuilt again.
void CacheEntry::rebuildActiveCallbacks() {
  // We could store thread ID in CacheEntry, but rebuilds are infrequent and
  // this saves us from having to plumb it through.
  const auto thread_id = RecordFunction::currentThreadId();
  active_callbacks_ = StepCallbacks(thread_id, scope_);

  // Start from "never rebuild" and tighten as sampled callbacks are seen.
  sampling_countdown_ = std::numeric_limits<int>::max();
  for (const auto& i : callbacks_) {
    if (i.tries_left_ < 0) {
      // Callback is not sampled. Unconditionally push.
      active_callbacks_.callbacks_.push_back(
          {i.callback_.start(), i.callback_.end()});

    } else if (i.tries_left_ == 0) {
      // Callback is sampled and we have reached a sampling event. Push and
      // set `sampling_countdown_` to one so we trigger a rebuild after one call.
      active_callbacks_.callbacks_.push_back(
          {i.callback_.start(), i.callback_.end()});
      sampling_countdown_ = 1;

    } else {
      // Callback is sampled and we have not reached sampling event. Set
      // `sampling_countdown_` to rebuild when it is time for this callback to
      // execute.
      sampling_countdown_ = std::min(sampling_countdown_, i.tries_left_);
    }
    // Aggregate the feature flags of every registered callback, active or not.
    active_callbacks_.needs_inputs_ |= i.callback_.needsInputs();
    active_callbacks_.needs_outputs_ |= i.callback_.needsOutputs();
    active_callbacks_.needs_ids_ |= i.callback_.needsIds();
  }
  // Remember how many steps this rebuild covers so counters can be advanced
  // in bulk next time.
  steps_for_this_update_ = sampling_countdown_;
}
353 
sampleTries(double p) const354 int CacheEntry::sampleTries(double p) const {
355   TORCH_INTERNAL_ASSERT(generator_ != nullptr);
356   TORCH_INTERNAL_ASSERT(p > 0.0 && p <= 1.0);
357 
358   // The geometric distribution returns the number of failures. We add one to
359   // also account for the call where we succeed.
360   return std::geometric_distribution<int>(p)(*generator_) + 1;
361 }
362 
// ============================================================================
// == LocalCallbackManager: Implementation ====================================
// ============================================================================
LocalCallbackManager& LocalCallbackManager::get() {
#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
  static c10::ThreadLocal<LocalCallbackManager> manager;
  return manager.get();
#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
  static thread_local LocalCallbackManager manager;
  return manager;
#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
}

// Seed one CacheEntry per scope with the shared RNG, then sync with the
// current global callback set.
LocalCallbackManager::LocalCallbackManager() {
  for (auto i : c10::irange(NumRecordScopes)) {
    active_callbacks_[i] = CacheEntry(&generator_, static_cast<RecordScope>(i));
  }
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}
382 
const RecordFunctionTLS& LocalCallbackManager::getTLS() const {
  return registered_callbacks_;
}

// Cheap staleness check: a relaxed atomic read of the global version. Only
// pay for the locked snapshot when the global set actually changed.
void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() {
  const auto global_version = GlobalCallbackManager::get().version();
  if (C10_UNLIKELY(global_version != global_version_)) {
    rebuild_all(GlobalCallbackManager::get().getSnapshot());
  }
}
393 
// Fetch the active callback set for `scope` from this thread's cache,
// refreshing it first if the global registry moved.
StepCallbacks LocalCallbackManager::getActiveCallbacks(
    const RecordScope scope) {
  rebuildActiveCallbacksIfNeeded();
  return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks();
}

// As above, but returns std::nullopt when nothing would run.
std::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty(
    const RecordScope scope) {
  rebuildActiveCallbacksIfNeeded();
  return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacksUnlessEmpty();
}
405 
// Replace the thread-local registration state wholesale and rebuild every
// scope's cache.
void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) {
  registered_callbacks_ = tls;
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}

// Reseed the shared sampling RNG (testing hook).
void LocalCallbackManager::seed(uint32_t seed) {
  generator_.seed(seed);
}
414 
// Register a thread-local callback and refresh only the scopes it affects.
CallbackHandle LocalCallbackManager::addCallback(
    RecordFunctionCallback callback) {
  auto handle = next_unique_callback_handle();
  auto& callbacks = registered_callbacks_.sorted_tls_callbacks_;
  callbacks.emplace_back(callback, handle);
  rebuild_callback_scopes(
      GlobalCallbackManager::get().getSnapshot(), callbacks.back().callback_);
  return handle;
}
424 
setCallbackEnabled(CallbackHandle handle,bool enabled)425 bool LocalCallbackManager::setCallbackEnabled(
426     CallbackHandle handle,
427     bool enabled) {
428   auto it = findCallback(registered_callbacks_.sorted_tls_callbacks_, handle);
429   auto found = (it != registered_callbacks_.sorted_tls_callbacks_.end());
430   if (found && it->enabled_ != enabled) {
431     it->enabled_ = enabled;
432     rebuild_callback_scopes(
433         GlobalCallbackManager::get().getSnapshot(), it->callback_);
434   }
435   return found;
436 }
437 
// Unregister a thread-local callback; returns false if `handle` is unknown
// on this thread (the caller then tries the global registry).
bool LocalCallbackManager::removeCallback(CallbackHandle handle) {
  auto& callbacks = registered_callbacks_.sorted_tls_callbacks_;
  auto callback = extractCallback(callbacks, handle);
  if (callback.has_value()) {
    rebuild_callback_scopes(
        GlobalCallbackManager::get().getSnapshot(), *callback);
  }
  return callback.has_value();
}

// Drop all thread-local callbacks and rebuild every scope's cache.
void LocalCallbackManager::clearCallbacks() {
  registered_callbacks_.sorted_tls_callbacks_.clear();
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}
452 
// Rebuild every scope's cache from a fresh global snapshot and record the
// snapshot version so rebuildActiveCallbacksIfNeeded can skip future work.
void LocalCallbackManager::rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot) {
  global_version_ = global_snapshot.first;
  for (auto i : c10::irange(NumRecordScopes)) {
    rebuild_scope(global_snapshot, static_cast<RecordScope>(i));
  }
}

// After a TLS-callback change, rebuild only the scopes `callback` touches —
// unless the global set also moved, in which case rebuild everything.
void LocalCallbackManager::rebuild_callback_scopes(
    const GlobalCallbackManager::snapshot_t& global_snapshot,
    const RecordFunctionCallback& callback) {
  if (global_snapshot.first == global_version_) {
    // Only rebuild scopes associated with `callback`
    for (auto i : c10::irange(NumRecordScopes)) {
      if (callback.checkScope(static_cast<RecordScope>(i))) {
        rebuild_scope(global_snapshot, static_cast<RecordScope>(i));
      }
    }
  } else {
    rebuild_all(global_snapshot);
  }
}
474 
// Gather the enabled, in-scope, non-zero-probability callbacks (global then
// thread-local) for `scope` and push them into that scope's CacheEntry.
// When TLS recording is disabled the scope gets an empty set.
void LocalCallbackManager::rebuild_scope(
    const GlobalCallbackManager::snapshot_t& global_snapshot,
    const RecordScope scope) {
  std::vector<RecordFunctionCallback> callbacks;
  if (registered_callbacks_.tls_record_function_enabled_) {
    auto populate_callbacks =
        [&](const RecordFunctionCallbacks& raw_callbacks) {
          for (const auto& i : raw_callbacks) {
            if (i.enabled_ && i.callback_.checkScope(scope) &&
                i.callback_.samplingProb() > 0) {
              callbacks.push_back(i.callback_);
            }
          }
        };
    populate_callbacks(global_snapshot.second);
    populate_callbacks(registered_callbacks_.sorted_tls_callbacks_);
  }
  active_callbacks_[static_cast<size_t>(scope)].update(callbacks);
}
494 
// ============================================================================
// == Callback execution ======================================================
// ============================================================================
// Shared warning for exceptions escaping observer callbacks.
void logTryRunCallbackError(const char* what, const char* name) {
  LOG(WARNING) << "Exception in RecordFunction callback: " << what
               << " , for the range " << name;
}
502 
503 template <bool is_start>
tryRunCallback(const StepCallbacks::StartEndPair callback_ptrs,const RecordFunction & rf,std::unique_ptr<ObserverContext> & ctx)504 C10_ALWAYS_INLINE bool tryRunCallback(
505     const StepCallbacks::StartEndPair callback_ptrs,
506     const RecordFunction& rf,
507     std::unique_ptr<ObserverContext>& ctx) {
508   try {
509     if (is_start && callback_ptrs.start_) {
510       ctx = callback_ptrs.start_(rf);
511     }
512 
513     if (!is_start && callback_ptrs.end_) {
514       callback_ptrs.end_(rf, ctx.get());
515     }
516 
517     return true;
518   } catch (const std::exception& e) {
519     logTryRunCallbackError(e.what(), rf.name());
520     return false;
521   } catch (...) {
522     logTryRunCallbackError("unknown", rf.name());
523     return false;
524   }
525 }
526 
527 } // namespace
528 
RecordFunction::RecordFunction(RecordScope scope)
    : RecordFunction(getStepCallbacks(scope)) {}

// Takes an already-computed callback set, letting callers that used
// getStepCallbacksUnlessEmpty avoid recomputing it.
RecordFunction::RecordFunction(StepCallbacks&& step_callbacks)
    : step_callbacks_{std::move(step_callbacks)} {
  // One ObserverContext slot per callback; paired up again in end().
  ctx_.resize(step_callbacks_.callbacks_.size());
  if (step_callbacks_.needs_ids_) {
    setHandle(next_unique_record_function_handle());
  }
}
539 
// Run every start callback, storing each returned ObserverContext so end()
// can hand it back to the matching end callback.
void RecordFunction::runStartCallbacks() {
  for (const auto i : c10::irange(step_callbacks_.callbacks_.size())) {
    tryRunCallback</*is_start=*/true>(
        step_callbacks_.callbacks_[i], *this, ctx_[i]);
  }
  called_start_callbacks_ = true;
}

// Run end callbacks (only if the start callbacks ran) and clear the list so
// a second end() — e.g. explicit end() then the destructor — is a no-op.
void RecordFunction::end() {
  if (called_start_callbacks_) {
    for (const auto i : c10::irange(step_callbacks_.callbacks_.size())) {
      tryRunCallback</*is_start=*/false>(
        step_callbacks_.callbacks_[i], *this, ctx_[i]);
    }
    step_callbacks_.callbacks_.clear();
  }
}
557 
// Name of the profiled range: the stored string, or the operator schema's
// name. The returned pointer is only valid while `fn_` holds its current
// value.
const char* RecordFunction::name() const {
  return std::visit(
      c10::overloaded(
          [](const std::string& name) { return name.c_str(); },
          [](const schema_ref_t schema) {
            return schema.get().name().c_str();
          }),
      fn_);
}

// Input count: recorded inputs for string-named records, declared arguments
// for schema-backed records.
size_t RecordFunction::num_inputs() const {
  return std::visit(
      c10::overloaded(
          [&](const std::string&) { return inputs_.size(); },
          [](const schema_ref_t schema) {
            return schema.get().arguments().size();
          }),
      fn_);
}

// Output count, symmetric with num_inputs().
size_t RecordFunction::num_outputs() const {
  return std::visit(
      c10::overloaded(
          [&](const std::string&) { return outputs_.size(); },
          [](const schema_ref_t schema) {
            return schema.get().returns().size();
          }),
      fn_);
}
587 
// Operator name when schema-backed; std::nullopt for string-named records.
std::optional<OperatorName> RecordFunction::operator_name() const {
  return std::visit(
      c10::overloaded(
          [&](const std::string&) -> std::optional<OperatorName> {
            return std::nullopt;
          },
          [](const schema_ref_t schema) -> std::optional<OperatorName> {
            return schema.get().operator_name();
          }),
      fn_);
}

// Copy of the full schema when schema-backed; std::nullopt otherwise.
std::optional<c10::FunctionSchema> RecordFunction::operator_schema() const {
  return std::visit(
      c10::overloaded(
          [&](const std::string&) -> std::optional<c10::FunctionSchema> {
            return std::nullopt;
          },
          [](const schema_ref_t schema) -> std::optional<c10::FunctionSchema> {
            return schema.get();
          }),
      fn_);
}
611 
// Public entry points: fetch the active callback set for a scope from the
// calling thread's cache.
StepCallbacks getStepCallbacks(RecordScope scope) {
  return LocalCallbackManager::get().getActiveCallbacks(scope);
}

// std::nullopt when no callback would run, so callers can skip construction.
std::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope) {
  return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope);
}

// Accessors for the thread-local RecordFunction registration state.
const RecordFunctionTLS& get_record_function_tls_() {
  return LocalCallbackManager::get().getTLS();
}

void set_record_function_tls_(const RecordFunctionTLS& tls) {
  LocalCallbackManager::get().setTLS(tls);
}
627 
namespace {
// True if any callback in the list is currently enabled.
bool anyEnabled(const RecordFunctionCallbacks& callbacks) {
  return std::any_of(callbacks.begin(), callbacks.end(), [](const auto& cb) {
    return cb.enabled_;
  });
}
} // namespace
635 
bool hasCallbacks() {
  return hasThreadLocalCallbacks() || hasGlobalCallbacks();
}

// NOTE: takes the locked global snapshot — not intended as a hot-path query.
bool hasGlobalCallbacks() {
  return anyEnabled(GlobalCallbackManager::get().getSnapshot().second);
}

bool hasThreadLocalCallbacks() {
  return anyEnabled(get_record_function_tls_().sorted_tls_callbacks_);
}

// Registration entry points; the returned handle works with removeCallback /
// disableCallback / reenableCallback below.
CallbackHandle addThreadLocalCallback(
    RecordFunctionCallback cb) {
  return LocalCallbackManager::get().addCallback(cb);
}

CallbackHandle addGlobalCallback(
    RecordFunctionCallback cb) {
  return GlobalCallbackManager::get().addCallback(cb);
}
657 
// Handles are unique across both registries, so each operation tries the
// thread-local registry first and falls through to the global one.
void removeCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().removeCallback(handle)) {
    GlobalCallbackManager::get().removeCallback(handle);
  }
}

void disableCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().setCallbackEnabled(handle, false)) {
    GlobalCallbackManager::get().setCallbackEnabled(handle, false);
  }
}

void reenableCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().setCallbackEnabled(handle, true)) {
    GlobalCallbackManager::get().setCallbackEnabled(handle, true);
  }
}
675 
void clearGlobalCallbacks() {
  GlobalCallbackManager::get().clearCallbacks();
}

void clearThreadLocalCallbacks() {
  LocalCallbackManager::get().clearCallbacks();
}

// Clears both registries (global first, then this thread's).
void clearCallbacks() {
  clearGlobalCallbacks();
  clearThreadLocalCallbacks();
}

// Whether RecordFunction observation is enabled for the calling thread.
bool isRecordFunctionEnabled() {
  return LocalCallbackManager::get().getTLS().tls_record_function_enabled_;
}
692 
// Toggle per-thread RecordFunction observation; only writes TLS (and pays
// for a cache rebuild) when the value actually changes.
void enableRecordFunction(bool enable) {
  auto tls = LocalCallbackManager::get().getTLS();
  if (tls.tls_record_function_enabled_ != enable) {
    tls.tls_record_function_enabled_ = enable;
    LocalCallbackManager::get().setTLS(tls);
  }
}

// Makes sampling deterministic for tests.
void set_record_function_seed_for_testing(uint32_t seed) {
  LocalCallbackManager::get().seed(seed);
}
704 
/* static */
// Logical thread id: lazily assigned from a global counter, so ids are never
// reused (unlike std::this_thread::get_id — see comment at next_thread_id_).
uint64_t RecordFunction::currentThreadId() {
  if (!current_thread_id_) {
    // happens only once per thread
    current_thread_id_ = ++next_thread_id_;
  }
  return current_thread_id_;
}
713 
before(const char * name,int64_t sequence_nr)714 void RecordFunction::before(const char* name, int64_t sequence_nr) {
715   fn_ = name;
716   sequence_nr_ = sequence_nr;
717   is_nccl_meta_ = (std::strcmp(name, kParamCommsCallName.c_str()) == 0);
718 
719 #ifndef NDEBUG
720     inputs_valid_ = true;
721 #endif
722   runStartCallbacks();
723   invalidateInputs();
724 }
725 
// Start this RecordFunction for an owned-string op name.
void RecordFunction::before(std::string name, int64_t sequence_nr) {
  // Compare before the move below empties `name`.
  is_nccl_meta_ = (name == kParamCommsCallName);
  fn_ = std::move(name);
  sequence_nr_ = sequence_nr;

#ifndef NDEBUG
    inputs_valid_ = true;
#endif
  runStartCallbacks();
  invalidateInputs();
}
737 
// Start this RecordFunction for a schema-backed op; name and arity queries
// are answered from the schema.
void RecordFunction::before(
    RecordFunction::schema_ref_t schema,
    int64_t sequence_nr) {
  sequence_nr_ = sequence_nr;
  fn_ = schema;
  is_nccl_meta_ = (schema.get().name() == kParamCommsCallName);

#ifndef NDEBUG
    inputs_valid_ = true;
#endif
  runStartCallbacks();
  invalidateInputs();
}
751 
// Validate and publish the process-wide node id (stored in `defaultNodeId`).
/* static */ void RecordFunction::setDefaultNodeId(int64_t newDefaultNodeId) {
  TORCH_CHECK(newDefaultNodeId >= 0, "setDefaultNodeId expects an id >= 0.");
  defaultNodeId = newDefaultNodeId;
}

/* static */ int64_t RecordFunction::getDefaultNodeId() {
  return defaultNodeId;
}

// Runs any pending end callbacks; a no-op if end() was already called.
RecordFunction::~RecordFunction() {
  end();
}
764 
// Flags the range as asynchronous; consumers read it back via isAsync().
void RecordFunction::_setAsync() {
  is_async_ = true;
}

bool RecordFunction::isAsync() const {
  return is_async_;
}

// Static-runtime out-variant flag: set and reported only while this
// RecordFunction is active.
void RecordFunction::_setStaticRuntimeOutVariant() {
  if (isActive()) {
    is_static_runtime_out_variant_ = true;
  }
}

bool RecordFunction::isStaticRuntimeOutVariant() const {
  if (isActive()) {
    return is_static_runtime_out_variant_;
  }
  return false;
}
785 } // namespace at
786