1 #include <ATen/record_function.h>
2 #include <ATen/core/dispatch/Dispatcher.h>
3 #include <c10/macros/Macros.h>
4 #include <c10/util/ThreadLocal.h>
5 #include <c10/util/overloaded.h>
6
7 #include <algorithm>
8 #include <cstdlib>
9 #include <random>
10
11 namespace at {
12
// Well-known record name for parameter-communication ops;
// RecordFunction::before() compares incoming names against this string to
// set `is_nccl_meta_`.
extern const std::string kParamCommsCallName = "record_param_comms";
14
15 namespace {
16
17 // Used to generate unique callback handles
// Used to generate unique callback handles.
// Monotonically increasing, starting at 1; the atomic makes handle
// generation safe to call from any thread.
CallbackHandle next_unique_callback_handle() {
  static std::atomic<uint64_t> unique_cb_id {1};
  return CallbackHandle(unique_cb_id++);
}

// Generates unique handles for RecordFunction instances; only consumed when
// some active callback sets `needs_ids_` (see the StepCallbacks ctor below).
RecordFunctionHandle next_unique_record_function_handle() {
  static std::atomic<uint64_t> unique_rf_id {1};
  return RecordFunctionHandle(unique_rf_id++);
}
27
// Backing store for RecordFunction::{set,get}DefaultNodeId; -1 until
// setDefaultNodeId() is called.
std::atomic<int64_t> defaultNodeId(-1);

// Enumerates thread ids logically;
// note: std::this_thread::get_id may return potentially
// reused thread id
std::atomic<uint64_t> next_thread_id_ {0};
thread_local uint64_t current_thread_id_ = 0; // 0 == not yet assigned

// Number of distinct RecordScope values; used to size per-scope arrays.
static constexpr size_t NumRecordScopes =
    static_cast<size_t>(RecordScope::NUM_SCOPES);
38
findCallback(RecordFunctionCallbacks & entries,CallbackHandle handle)39 RecordFunctionCallbacks::iterator findCallback(
40 RecordFunctionCallbacks& entries,
41 CallbackHandle handle) {
42 auto match_handle = [handle](const auto& el) { return el.handle_ == handle; };
43 return std::find_if(entries.begin(), entries.end(), match_handle);
44 }
45
extractCallback(RecordFunctionCallbacks & entries,CallbackHandle handle)46 std::optional<RecordFunctionCallback> extractCallback(
47 RecordFunctionCallbacks& entries,
48 CallbackHandle handle) {
49 auto it = findCallback(entries, handle);
50 if (it == entries.end()) {
51 return std::nullopt;
52 }
53 auto out = it->callback_;
54 entries.erase(it);
55 return out;
56 }
57
58 // ============================================================================
59 // == Callback manager ========================================================
60 // ============================================================================
61 // The high level idea of the RecordFunction callback machinery is based on the
62 // observation that the set of callbacks to be run changes infrequently.
63 // However, in order to reuse the active set we have to be able to invalidate
64 // when the active set changes. There are three events that can change which
65 // callbacks should be run:
66 // 1) The set of global callbacks changes
67 // 2) The set of local callbacks changes
68 // 3) A sampling callback is present, and should run on this iteration
69 //
70 // Global callbacks rely on thread local replication and an atomic version
71 // counter to maintain consistency. Whenever we change the set of active global
72 // callbacks (add / remove / enable / disable) the `GlobalCallbackManager`
73 // increments the version number and updates the global state while holding
74 // a mutex. The local callback manager snapshots the global callbacks and
// lazily rebuilds by comparing `GlobalCallbackManager::version()` (which is
76 // a simple atomic read) to the version of the last rebuild. In the
77 // overwhelmingly common case that they match it can reuse the existing
78 // snapshot. Otherwise it must call the much more expensive (and locked)
79 // `GlobalCallbackManager::getSnapshot()`.
80 //
81 // Handling changes to the thread local callbacks is trivial; functions that
82 // change them can simply force a cache rebuild for that thread after the
83 // changes are made.
84 //
85 // Sampling is by far the most challenging to handle efficiently. In general
86 // sampling callbacks are expected to have very low frequency. (e.g. 1 per
87 // million) Random number generation is rather expensive, so flipping a coin on
88 // every call for every sampling callback is wasteful. We can significantly
89 // reduce this cost by noting that the number of failures of a Bernoulli random
90 // variable is a geometric distribution, and thus we can sample the geometric
91 // distribution to determine the next time a callback should run. This reduces
92 // the cost from a random sample to a simple integer decrement.
93 //
94 // We can further note that Bernoulli samples are independent. (In contrast to,
95 // say, sampling without replacement.) This means that we can generate a
96 // counter for each scope that a given callback supports and then decrement the
97 // counter corresponding to the RecordScope being called. Conceptually, this is
98 // analogous to flipping different coins with the same probability. By sharding
99 // on RecordScope, we can consolidate the decrement to a single shared counter
100 // and update individual counters during rebuild.
101
// Process-wide registry of global RecordFunction callbacks.  Every mutation
// bumps `version_` so per-thread caches (LocalCallbackManager) can detect
// staleness with a single relaxed atomic load instead of taking the mutex.
class GlobalCallbackManager {
 public:
  static GlobalCallbackManager& get(); // Singleton

 private:
  GlobalCallbackManager() = default;

 public:
  static constexpr size_t NoVersion = 0;
  using snapshot_t = std::pair<size_t, RecordFunctionCallbacks>;

  // Locking?
  size_t version() const; // No
  snapshot_t getSnapshot() const; // Yes
  CallbackHandle addCallback(RecordFunctionCallback cb); // Yes
  void setCallbackEnabled(CallbackHandle handle, bool enabled); // Yes
  void removeCallback(CallbackHandle handle); // Yes
  void clearCallbacks(); // Yes

 private:
  // Starts at NoVersion + 1 so a freshly constructed LocalCallbackManager
  // (whose cached version is NoVersion) always mismatches and rebuilds on
  // first use.
  std::atomic<size_t> version_{NoVersion + 1};
  RecordFunctionCallbacks global_callbacks_; // Source of truth.
  mutable std::mutex update_mutex_;
};
126
// Per-(thread, scope) cache of the callbacks that should run, plus the
// bookkeeping needed to support sampled callbacks via geometric countdowns
// (see the design comment above).
class CacheEntry {
 public:
  CacheEntry() = default;
  CacheEntry(std::mt19937* generator, RecordScope scope);

  // The caller is expected to check `GlobalCallbackManager::get().version()'
  // and call CacheEntry::update() if necessary.
  StepCallbacks getActiveCallbacks();
  std::optional<StepCallbacks> getActiveCallbacksUnlessEmpty();

  // Full rebuild. (E.g. during registration)
  void update(const std::vector<RecordFunctionCallback>& callbacks);

 private:
  struct CallbackAndCounter {
    RecordFunctionCallback callback_;

    // `-1` indicates that a callback is not sampled.
    int tries_left_{-1};
  };

  // Hot-path bookkeeping shared by both getActiveCallbacks variants.
  C10_ALWAYS_INLINE void getActiveCallbacksImpl();

  void rebuildActiveCallbacks();
  int sampleTries(double p) const;

  // std::mt19937 is quite large, so all scopes share the same generator.
  std::mt19937* generator_{nullptr};

  // Includes sampling callbacks which are waiting to run.
  c10::SmallVector<CallbackAndCounter, kSoftLimitCallbacks> callbacks_;
  RecordScope scope_{RecordScope::FUNCTION};

  // Prebuilt result returned on the hot path; refreshed by
  // rebuildActiveCallbacks().
  StepCallbacks active_callbacks_;

  // For managing sampling callbacks
  int sampling_countdown_{0};
  int steps_for_this_update_{0};
};
166
// Thread-local manager that merges a cached copy of the global callbacks
// with this thread's own registrations and serves the per-scope active sets.
class LocalCallbackManager {
 public:
  static LocalCallbackManager& get(); // Singleton

 private:
  LocalCallbackManager();

 public:
  const RecordFunctionTLS& getTLS() const;
  StepCallbacks getActiveCallbacks(const RecordScope scope);
  std::optional<StepCallbacks> getActiveCallbacksUnlessEmpty(const RecordScope scope);

  void setTLS(const RecordFunctionTLS& tls);
  void seed(uint32_t seed);
  CallbackHandle addCallback(RecordFunctionCallback callback);
  bool setCallbackEnabled(CallbackHandle handle, bool enabled);
  bool removeCallback(CallbackHandle handle);
  void clearCallbacks();

 private:
  // Cheap staleness check against GlobalCallbackManager::version().
  void rebuildActiveCallbacksIfNeeded();

  // Full rebuild of every scope from a fresh global snapshot.
  void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot);

  // Rebuild only the scopes `callback` participates in (falls back to
  // rebuild_all when the snapshot is newer than our cached version).
  void rebuild_callback_scopes(
      const GlobalCallbackManager::snapshot_t& global_snapshot,
      const RecordFunctionCallback& callback);

  void rebuild_scope(
      const GlobalCallbackManager::snapshot_t& global_snapshot,
      const RecordScope scope);

  // Source of truth.
  RecordFunctionTLS registered_callbacks_;

  // Runtime cache.
  size_t global_version_{GlobalCallbackManager::NoVersion};
  std::array<CacheEntry, NumRecordScopes> active_callbacks_;
  std::mt19937 generator_{};
};
207
208 // ============================================================================
209 // == GlobalCallbackManager: Implementation ===================================
210 // ============================================================================
// Meyers singleton; function-local static initialization is thread safe.
GlobalCallbackManager& GlobalCallbackManager::get() {
  static GlobalCallbackManager manager;
  return manager;
}
215
// Lock-free version probe.  Relaxed ordering suffices: a stale read only
// means the caller falls through to getSnapshot(), which takes the mutex.
size_t GlobalCallbackManager::version() const {
  return version_.load(std::memory_order_relaxed);
}

// Copies the version and callback list together under the mutex so the
// returned pair is internally consistent.
std::pair<size_t, RecordFunctionCallbacks> GlobalCallbackManager::getSnapshot() const {
  std::lock_guard<std::mutex> guard(update_mutex_);
  return {version_.load(std::memory_order_seq_cst), global_callbacks_};
}
224
addCallback(RecordFunctionCallback cb)225 CallbackHandle GlobalCallbackManager::addCallback(RecordFunctionCallback cb) {
226 std::lock_guard<std::mutex> guard(update_mutex_);
227 ++version_;
228 auto handle = next_unique_callback_handle();
229 global_callbacks_.emplace_back(cb, handle);
230 return handle;
231 }
232
setCallbackEnabled(CallbackHandle handle,bool enabled)233 void GlobalCallbackManager::setCallbackEnabled(
234 CallbackHandle handle,
235 bool enabled) {
236 std::lock_guard<std::mutex> guard(update_mutex_);
237 auto it = findCallback(global_callbacks_, handle);
238 if (it != global_callbacks_.end()) {
239 if (it->enabled_ != enabled) {
240 ++version_;
241 it->enabled_ = enabled;
242 }
243 } else {
244 LOG(WARNING) << "Requested callback is not found";
245 }
246 }
247
removeCallback(CallbackHandle handle)248 void GlobalCallbackManager::removeCallback(CallbackHandle handle) {
249 std::lock_guard<std::mutex> guard(update_mutex_);
250 if (extractCallback(global_callbacks_, handle).has_value()) {
251 ++version_;
252 } else {
253 LOG(WARNING) << "Requested callback is not found";
254 }
255 }
256
// Drops every global callback.  The version bump forces all thread-local
// caches to rebuild on their next use.
void GlobalCallbackManager::clearCallbacks() {
  std::lock_guard<std::mutex> guard(update_mutex_);
  ++version_;
  global_callbacks_.clear();
}
262
263 // ============================================================================
264 // == CacheEntry: Implementation ==============================================
265 // ============================================================================
// `generator` is owned by the enclosing LocalCallbackManager and shared by
// the CacheEntries of all scopes (std::mt19937 is large).
CacheEntry::CacheEntry(std::mt19937* generator, RecordScope scope)
    : generator_{generator}, scope_{scope} {
  rebuildActiveCallbacks();
}
270
update(const std::vector<RecordFunctionCallback> & callbacks)271 void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) {
272 callbacks_.clear();
273 callbacks_.reserve(callbacks.size());
274 for (const auto& callback : callbacks) {
275 const auto p = callback.samplingProb();
276 callbacks_.push_back({callback, p < 1.0 ? sampleTries(p) : -1});
277 }
278
279 rebuildActiveCallbacks();
280 }
281
// Hot-path bookkeeping shared by both getActiveCallbacks variants: decrement
// the shared countdown and, when it reaches zero, charge the elapsed steps
// to each pending sampled callback, rebuild the active set, and resample the
// callbacks that just fired.
void CacheEntry::getActiveCallbacksImpl() {
  // We rebuild the active set when `sampling_countdown_` reaches zero, so if it
  // reaches zero at the start of this function something has gone wrong.
  TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_);

  if (C10_UNLIKELY(!(--sampling_countdown_))) {
    // Use inferred steps to update sampled callbacks.
    for (auto& i : callbacks_) {
      if (i.tries_left_ > 0) {
        TORCH_INTERNAL_ASSERT(i.tries_left_ >= steps_for_this_update_);
        i.tries_left_ -= steps_for_this_update_;
      }
    }

    // Determine which callbacks to run and for how long.
    rebuildActiveCallbacks();

    // Resample any sampled callbacks that ran this call.
    for (auto& i : callbacks_) {
      if (!i.tries_left_) {
        i.tries_left_ = sampleTries(i.callback_.samplingProb());
      }
    }
  }
}
307
// Returns a copy of the prebuilt active set after advancing the sampling
// bookkeeping.
StepCallbacks CacheEntry::getActiveCallbacks() {
  getActiveCallbacksImpl();
  return active_callbacks_;
}

// Variant that lets hot callers bail early: the common case of no active
// callbacks returns nullopt instead of an empty StepCallbacks.
std::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
  getActiveCallbacksImpl();
  if (C10_LIKELY(active_callbacks_.empty())) {
    return std::nullopt;
  }
  return active_callbacks_;
}
320
// Recomputes `active_callbacks_` from `callbacks_` and derives the next
// `sampling_countdown_`: the number of calls until some sampled callback is
// due (INT_MAX when no callback is sampled).
void CacheEntry::rebuildActiveCallbacks() {
  // We could store thread ID in CacheEntry, but rebuilds are infrequent and
  // this saves us from having to plumb it through.
  const auto thread_id = RecordFunction::currentThreadId();
  active_callbacks_ = StepCallbacks(thread_id, scope_);

  sampling_countdown_ = std::numeric_limits<int>::max();
  for (const auto& i : callbacks_) {
    if (i.tries_left_ < 0) {
      // Callback is not sampled. Unconditionally push.
      active_callbacks_.callbacks_.push_back(
          {i.callback_.start(), i.callback_.end()});

    } else if (i.tries_left_ == 0) {
      // Callback is sampled and we have reached a sampling event. Push and
      // set `sampling_countdown_` to one so we trigger a rebuild after one call.
      active_callbacks_.callbacks_.push_back(
          {i.callback_.start(), i.callback_.end()});
      sampling_countdown_ = 1;

    } else {
      // Callback is sampled and we have not reached sampling event. Set
      // `sampling_countdown_` to rebuild when it is time for this callback to
      // execute.
      sampling_countdown_ = std::min(sampling_countdown_, i.tries_left_);
    }
    // Aggregate feature flags across all callbacks in this scope (including
    // pending sampled ones).
    active_callbacks_.needs_inputs_ |= i.callback_.needsInputs();
    active_callbacks_.needs_outputs_ |= i.callback_.needsOutputs();
    active_callbacks_.needs_ids_ |= i.callback_.needsIds();
  }
  steps_for_this_update_ = sampling_countdown_;
}
353
// Draws the number of calls until a callback with sampling probability `p`
// should next fire.
int CacheEntry::sampleTries(double p) const {
  TORCH_INTERNAL_ASSERT(generator_ != nullptr);
  TORCH_INTERNAL_ASSERT(p > 0.0 && p <= 1.0);

  // The geometric distribution returns the number of failures. We add one to
  // also account for the call where we succeed.
  return std::geometric_distribution<int>(p)(*generator_) + 1;
}
362
363 // ============================================================================
364 // == LocalCallbackManager: Implementation ====================================
365 // ============================================================================
// One LocalCallbackManager per thread; the storage mechanism depends on the
// platform's thread-local support.
LocalCallbackManager& LocalCallbackManager::get() {
#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
  static c10::ThreadLocal<LocalCallbackManager> manager;
  return manager.get();
#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
  static thread_local LocalCallbackManager manager;
  return manager;
#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
}
375
// Wires each scope's CacheEntry to the shared RNG, then syncs with the
// current global callback state.
LocalCallbackManager::LocalCallbackManager() {
  for (auto i : c10::irange(NumRecordScopes)) {
    active_callbacks_[i] = CacheEntry(&generator_, static_cast<RecordScope>(i));
  }
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}
382
// Read-only view of this thread's registration state.
const RecordFunctionTLS& LocalCallbackManager::getTLS() const {
  return registered_callbacks_;
}

// Cheap staleness check: one relaxed atomic load in the common case; a
// locked snapshot plus full rebuild only when the global registry changed.
void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() {
  const auto global_version = GlobalCallbackManager::get().version();
  if (C10_UNLIKELY(global_version != global_version_)) {
    rebuild_all(GlobalCallbackManager::get().getSnapshot());
  }
}
393
// Fetches the cached active set for `scope`, refreshing first if the global
// callback registry changed.
StepCallbacks LocalCallbackManager::getActiveCallbacks(
    const RecordScope scope) {
  rebuildActiveCallbacksIfNeeded();
  return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks();
}

// As above, but returns nullopt when the active set is empty.
std::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty(
    const RecordScope scope) {
  rebuildActiveCallbacksIfNeeded();
  return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacksUnlessEmpty();
}
405
// Replaces the thread-local registration state wholesale and rebuilds every
// scope's cache from a fresh global snapshot.
void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) {
  registered_callbacks_ = tls;
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}

// Reseeds the RNG used for sampled-callback countdowns (testing hook; see
// set_record_function_seed_for_testing).
void LocalCallbackManager::seed(uint32_t seed) {
  generator_.seed(seed);
}
414
addCallback(RecordFunctionCallback callback)415 CallbackHandle LocalCallbackManager::addCallback(
416 RecordFunctionCallback callback) {
417 auto handle = next_unique_callback_handle();
418 auto& callbacks = registered_callbacks_.sorted_tls_callbacks_;
419 callbacks.emplace_back(callback, handle);
420 rebuild_callback_scopes(
421 GlobalCallbackManager::get().getSnapshot(), callbacks.back().callback_);
422 return handle;
423 }
424
setCallbackEnabled(CallbackHandle handle,bool enabled)425 bool LocalCallbackManager::setCallbackEnabled(
426 CallbackHandle handle,
427 bool enabled) {
428 auto it = findCallback(registered_callbacks_.sorted_tls_callbacks_, handle);
429 auto found = (it != registered_callbacks_.sorted_tls_callbacks_.end());
430 if (found && it->enabled_ != enabled) {
431 it->enabled_ = enabled;
432 rebuild_callback_scopes(
433 GlobalCallbackManager::get().getSnapshot(), it->callback_);
434 }
435 return found;
436 }
437
removeCallback(CallbackHandle handle)438 bool LocalCallbackManager::removeCallback(CallbackHandle handle) {
439 auto& callbacks = registered_callbacks_.sorted_tls_callbacks_;
440 auto callback = extractCallback(callbacks, handle);
441 if (callback.has_value()) {
442 rebuild_callback_scopes(
443 GlobalCallbackManager::get().getSnapshot(), *callback);
444 }
445 return callback.has_value();
446 }
447
// Drops all thread-local callbacks and rebuilds from the global snapshot.
void LocalCallbackManager::clearCallbacks() {
  registered_callbacks_.sorted_tls_callbacks_.clear();
  rebuild_all(GlobalCallbackManager::get().getSnapshot());
}

// Adopts `global_snapshot` as the cached global state and rebuilds the
// active set of every scope.
void LocalCallbackManager::rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot) {
  global_version_ = global_snapshot.first;
  for (auto i : c10::irange(NumRecordScopes)) {
    rebuild_scope(global_snapshot, static_cast<RecordScope>(i));
  }
}
459
rebuild_callback_scopes(const GlobalCallbackManager::snapshot_t & global_snapshot,const RecordFunctionCallback & callback)460 void LocalCallbackManager::rebuild_callback_scopes(
461 const GlobalCallbackManager::snapshot_t& global_snapshot,
462 const RecordFunctionCallback& callback) {
463 if (global_snapshot.first == global_version_) {
464 // Only rebuild scopes associated with `callback`
465 for (auto i : c10::irange(NumRecordScopes)) {
466 if (callback.checkScope(static_cast<RecordScope>(i))) {
467 rebuild_scope(global_snapshot, static_cast<RecordScope>(i));
468 }
469 }
470 } else {
471 rebuild_all(global_snapshot);
472 }
473 }
474
// Recomputes the flat callback list for one scope: merges enabled global and
// thread-local callbacks that apply to `scope`, skipping zero-probability
// ones, honoring the thread-local master switch, then hands the result to
// the scope's CacheEntry.  Filtering here keeps the hot path free of
// disabled/out-of-scope entries.
void LocalCallbackManager::rebuild_scope(
    const GlobalCallbackManager::snapshot_t& global_snapshot,
    const RecordScope scope) {
  std::vector<RecordFunctionCallback> callbacks;
  if (registered_callbacks_.tls_record_function_enabled_) {
    auto populate_callbacks =
        [&](const RecordFunctionCallbacks& raw_callbacks) {
          for (const auto& i : raw_callbacks) {
            if (i.enabled_ && i.callback_.checkScope(scope) &&
                i.callback_.samplingProb() > 0) {
              callbacks.push_back(i.callback_);
            }
          }
        };
    // Global callbacks first, then this thread's own registrations.
    populate_callbacks(global_snapshot.second);
    populate_callbacks(registered_callbacks_.sorted_tls_callbacks_);
  }
  active_callbacks_[static_cast<size_t>(scope)].update(callbacks);
}
494
495 // ============================================================================
496 // == Callback execution ======================================================
497 // ============================================================================
// Shared warning formatter for exceptions escaping observer callbacks.
void logTryRunCallbackError(const char* what, const char* name) {
  LOG(WARNING) << "Exception in RecordFunction callback: " << what
               << " , for the range " << name;
}
502
// Invokes the start or end half of a callback pair (chosen at compile time
// via `is_start`), producing/consuming the observer context in `ctx`.  Any
// exception is logged and swallowed so a faulty observer cannot break the
// profiled code; returns false when the callback threw.
template <bool is_start>
C10_ALWAYS_INLINE bool tryRunCallback(
    const StepCallbacks::StartEndPair callback_ptrs,
    const RecordFunction& rf,
    std::unique_ptr<ObserverContext>& ctx) {
  try {
    if (is_start && callback_ptrs.start_) {
      ctx = callback_ptrs.start_(rf);
    }

    if (!is_start && callback_ptrs.end_) {
      callback_ptrs.end_(rf, ctx.get());
    }

    return true;
  } catch (const std::exception& e) {
    logTryRunCallbackError(e.what(), rf.name());
    return false;
  } catch (...) {
    logTryRunCallbackError("unknown", rf.name());
    return false;
  }
}
526
527 } // namespace
528
// Convenience constructor: look up the active callbacks for `scope` and
// delegate to the StepCallbacks constructor.
RecordFunction::RecordFunction(RecordScope scope)
    : RecordFunction(getStepCallbacks(scope)) {}

RecordFunction::RecordFunction(StepCallbacks&& step_callbacks)
    : step_callbacks_{std::move(step_callbacks)} {
  // One ObserverContext slot per callback: filled by runStartCallbacks(),
  // consumed by end().
  ctx_.resize(step_callbacks_.callbacks_.size());
  if (step_callbacks_.needs_ids_) {
    // Only pay for id generation when some callback asked for ids.
    setHandle(next_unique_record_function_handle());
  }
}
539
// Runs every start callback, pairing each with its ObserverContext slot.
// Exceptions are caught and logged inside tryRunCallback.
void RecordFunction::runStartCallbacks() {
  for (const auto i : c10::irange(step_callbacks_.callbacks_.size())) {
    tryRunCallback</*is_start=*/true>(
        step_callbacks_.callbacks_[i], *this, ctx_[i]);
  }
  called_start_callbacks_ = true;
}

// Runs the end callbacks, but only if the start callbacks actually ran.
// Clearing the callback list afterwards makes end() idempotent (it is also
// invoked from the destructor).
void RecordFunction::end() {
  if (called_start_callbacks_) {
    for (const auto i : c10::irange(step_callbacks_.callbacks_.size())) {
      tryRunCallback</*is_start=*/false>(
          step_callbacks_.callbacks_[i], *this, ctx_[i]);
    }
    step_callbacks_.callbacks_.clear();
  }
}
557
// Name of the profiled range: the stored string, or the schema's operator
// name for schema-based records.  The returned pointer aliases storage owned
// by `fn_` (or the referenced schema) and is only valid while that storage
// lives.
const char* RecordFunction::name() const {
  return std::visit(
      c10::overloaded(
          [](const std::string& name) { return name.c_str(); },
          [](const schema_ref_t schema) {
            return schema.get().name().c_str();
          }),
      fn_);
}

// Input count: the recorded inputs for string-based records, the declared
// arguments for schema-based records.
size_t RecordFunction::num_inputs() const {
  return std::visit(
      c10::overloaded(
          [&](const std::string&) { return inputs_.size(); },
          [](const schema_ref_t schema) {
            return schema.get().arguments().size();
          }),
      fn_);
}

// Output count: the recorded outputs for string-based records, the declared
// returns for schema-based records.
size_t RecordFunction::num_outputs() const {
  return std::visit(
      c10::overloaded(
          [&](const std::string&) { return outputs_.size(); },
          [](const schema_ref_t schema) {
            return schema.get().returns().size();
          }),
      fn_);
}

// Operator name is only available for schema-based records.
std::optional<OperatorName> RecordFunction::operator_name() const {
  return std::visit(
      c10::overloaded(
          [&](const std::string&) -> std::optional<OperatorName> {
            return std::nullopt;
          },
          [](const schema_ref_t schema) -> std::optional<OperatorName> {
            return schema.get().operator_name();
          }),
      fn_);
}

// Full schema (copied out) is only available for schema-based records.
std::optional<c10::FunctionSchema> RecordFunction::operator_schema() const {
  return std::visit(
      c10::overloaded(
          [&](const std::string&) -> std::optional<c10::FunctionSchema> {
            return std::nullopt;
          },
          [](const schema_ref_t schema) -> std::optional<c10::FunctionSchema> {
            return schema.get();
          }),
      fn_);
}
611
// Public entry point: thin wrapper over the thread-local manager.
StepCallbacks getStepCallbacks(RecordScope scope) {
  return LocalCallbackManager::get().getActiveCallbacks(scope);
}

// Like getStepCallbacks, but returns nullopt when no callback is active so
// callers can skip constructing a RecordFunction entirely.
std::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope) {
  return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope);
}

// Accessors for this thread's RecordFunction TLS state.
const RecordFunctionTLS& get_record_function_tls_() {
  return LocalCallbackManager::get().getTLS();
}

void set_record_function_tls_(const RecordFunctionTLS& tls) {
  LocalCallbackManager::get().setTLS(tls);
}
627
628 namespace {
anyEnabled(const RecordFunctionCallbacks & callbacks)629 bool anyEnabled(const RecordFunctionCallbacks& callbacks) {
630 return std::any_of(callbacks.begin(), callbacks.end(), [](const auto& cb) {
631 return cb.enabled_;
632 });
633 }
634 } // namespace
635
// True if any callback (thread-local or global) is currently enabled.
bool hasCallbacks() {
  return hasThreadLocalCallbacks() || hasGlobalCallbacks();
}

// Note: inspects a locked snapshot, so this takes the global mutex.
bool hasGlobalCallbacks() {
  return anyEnabled(GlobalCallbackManager::get().getSnapshot().second);
}

bool hasThreadLocalCallbacks() {
  return anyEnabled(get_record_function_tls_().sorted_tls_callbacks_);
}

// Registers `cb` for the current thread only.
CallbackHandle addThreadLocalCallback(
    RecordFunctionCallback cb) {
  return LocalCallbackManager::get().addCallback(cb);
}

// Registers `cb` for all threads.
CallbackHandle addGlobalCallback(
    RecordFunctionCallback cb) {
  return GlobalCallbackManager::get().addCallback(cb);
}
657
// Handles are unique across both registries, so each operation tries the
// thread-local registry first and falls back to the global one.
void removeCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().removeCallback(handle)) {
    GlobalCallbackManager::get().removeCallback(handle);
  }
}

void disableCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().setCallbackEnabled(handle, false)) {
    GlobalCallbackManager::get().setCallbackEnabled(handle, false);
  }
}

void reenableCallback(CallbackHandle handle) {
  if (!LocalCallbackManager::get().setCallbackEnabled(handle, true)) {
    GlobalCallbackManager::get().setCallbackEnabled(handle, true);
  }
}
675
void clearGlobalCallbacks() {
  GlobalCallbackManager::get().clearCallbacks();
}

void clearThreadLocalCallbacks() {
  LocalCallbackManager::get().clearCallbacks();
}

// Clears both registries.
void clearCallbacks() {
  clearGlobalCallbacks();
  clearThreadLocalCallbacks();
}

bool isRecordFunctionEnabled() {
  return LocalCallbackManager::get().getTLS().tls_record_function_enabled_;
}

// Toggles the thread-local master switch.  Copies the TLS struct, flips the
// flag, and writes it back; setTLS triggers a full cache rebuild, so the
// no-op case is filtered out first.
void enableRecordFunction(bool enable) {
  auto tls = LocalCallbackManager::get().getTLS();
  if (tls.tls_record_function_enabled_ != enable) {
    tls.tls_record_function_enabled_ = enable;
    LocalCallbackManager::get().setTLS(tls);
  }
}

// Reseeds this thread's sampling RNG (testing hook).
void set_record_function_seed_for_testing(uint32_t seed) {
  LocalCallbackManager::get().seed(seed);
}
704
/* static */
// Lazily assigns this thread a process-unique logical id.  Ids start at 1;
// 0 means "not yet assigned".
uint64_t RecordFunction::currentThreadId() {
  if (!current_thread_id_) {
    // happens only once per thread
    current_thread_id_ = ++next_thread_id_;
  }
  return current_thread_id_;
}
713
// Starts a profiled range identified by a C string; runs the start
// callbacks immediately.
void RecordFunction::before(const char* name, int64_t sequence_nr) {
  fn_ = name;
  sequence_nr_ = sequence_nr;
  // Flag param-comms records (see kParamCommsCallName at the top of file).
  is_nccl_meta_ = (std::strcmp(name, kParamCommsCallName.c_str()) == 0);

#ifndef NDEBUG
  inputs_valid_ = true;
#endif
  runStartCallbacks();
  // NOTE(review): inputs are invalidated right after the start callbacks
  // run, so they appear to be inspectable only from within those callbacks.
  invalidateInputs();
}

// Starts a profiled range identified by an owned string.
void RecordFunction::before(std::string name, int64_t sequence_nr) {
  // Compare before `name` is moved into the variant.
  is_nccl_meta_ = (name == kParamCommsCallName);
  fn_ = std::move(name);
  sequence_nr_ = sequence_nr;

#ifndef NDEBUG
  inputs_valid_ = true;
#endif
  runStartCallbacks();
  invalidateInputs();
}

// Starts a profiled range identified by an operator schema.  The schema is
// held by reference and must outlive this RecordFunction.
void RecordFunction::before(
    RecordFunction::schema_ref_t schema,
    int64_t sequence_nr) {
  sequence_nr_ = sequence_nr;
  fn_ = schema;
  is_nccl_meta_ = (schema.get().name() == kParamCommsCallName);

#ifndef NDEBUG
  inputs_valid_ = true;
#endif
  runStartCallbacks();
  invalidateInputs();
}
751
// Sets the process-wide default node id reported alongside records.
// Rejects negative ids (the internal -1 means "unset").
/* static */ void RecordFunction::setDefaultNodeId(int64_t newDefaultNodeId) {
  TORCH_CHECK(newDefaultNodeId >= 0, "setDefaultNodeId expects an id >= 0.");
  defaultNodeId = newDefaultNodeId;
}

/* static */ int64_t RecordFunction::getDefaultNodeId() {
  return defaultNodeId;
}
760
// RAII: fires the end callbacks (if the start callbacks ran) on scope exit.
RecordFunction::~RecordFunction() {
  end();
}

// Marks this record as asynchronous.  The flag is simply stored here;
// consumers read it via isAsync().
void RecordFunction::_setAsync() {
  is_async_ = true;
}

bool RecordFunction::isAsync() const {
  return is_async_;
}

// The static-runtime out-variant flag is only tracked while the record is
// active.
void RecordFunction::_setStaticRuntimeOutVariant() {
  if (isActive()) {
    is_static_runtime_out_variant_ = true;
  }
}

bool RecordFunction::isStaticRuntimeOutVariant() const {
  if (isActive()) {
    return is_static_runtime_out_variant_;
  }
  return false;
}
785 } // namespace at
786