xref: /aosp_15_r20/external/pytorch/aten/src/ATen/record_function.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/core/operator_name.h>
#include <c10/macros/Export.h>
#include <c10/util/SmallVector.h>

#include <array>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <variant>
13 
// Forward declaration of c10::OperatorHandle (exported with TORCH_API).
// NOTE(review): the name is not referenced elsewhere in this header; it may be
// kept for code that includes this header — confirm before removing.
namespace c10 {
class TORCH_API OperatorHandle;
}
17 
18 namespace at {
19 
// Function name used to record NCCL metadata. Declared `extern` here; the
// definition lives in a separate translation unit.
extern TORCH_API const std::string kParamCommsCallName;
22 
// Kind of RecordFunction scope.
// The enumerator values are used as indices into per-scope arrays (see
// RecordFunctionCallback) and as hash values (see std::hash<RecordScope>),
// so NUM_SCOPES must remain the last enumerator.
enum class C10_API_ENUM RecordScope : uint8_t {
  // c10/ATen ops, autograd nodes
  FUNCTION = 0,
  // Functions/nodes called from the autograd
  BACKWARD_FUNCTION,
  // TorchScript functions, methods
  TORCHSCRIPT_FUNCTION,
  // Kernel Function dtype Tag
  KERNEL_FUNCTION_DTYPE,
  // Torchbind custom class
  CUSTOM_CLASS,
  // Generic Build Feature
  BUILD_FEATURE,
  // Lite interpreter events (see RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS)
  LITE_INTERPRETER,
  // User defined scope (e.g. with record_function())
  USER_SCOPE,
  // Scopes for static runtime, a specialized TorchScript interpreter
  STATIC_RUNTIME_OP,
  STATIC_RUNTIME_MODEL,
  NUM_SCOPES, // must be the last in the list
};
46 
47 } // namespace at
48 
49 namespace std {
50 template <>
51 struct hash<at::RecordScope> {
52   size_t operator()(const at::RecordScope& sc) const {
53     return static_cast<std::size_t>(sc);
54   }
55 };
56 } // namespace std
57 
58 namespace at {
59 
60 struct TORCH_API StringView {
61   StringView() : StringView(nullptr) {}
62   explicit StringView(const char* str_ptr)
63       : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}
64   explicit StringView(std::string str)
65       : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
66         str_ptr_(owned_str_ptr_->c_str()) {}
67 
68   const char* str() const {
69     return str_ptr_;
70   }
71 
72   friend std::ostream& operator<<(std::ostream& os, const StringView& dt) {
73     os << dt.str();
74     return os;
75   }
76 
77   friend bool operator==(const StringView& lhs, const StringView& rhs) {
78     return strcmp(lhs.str(), rhs.str()) == 0;
79   }
80 
81   friend bool operator!=(const StringView& lhs, const StringView& rhs) {
82     return !(lhs == rhs);
83   }
84 
85  private:
86   std::shared_ptr<std::string> owned_str_ptr_;
87   const char* str_ptr_;
88 };
89 
// Soft limit on the number of callbacks: the inline capacity of the
// SmallVectors that hold per-RecordFunction callback data (handles, observer
// contexts, selected start/end pairs). Exceeding it is allowed.
constexpr std::size_t kSoftLimitCallbacks{4};
92 
// Base class for observer-specific state attached to a RecordFunction.
// A start callback may return an ObserverContext, which is later handed to the
// matching end callback. The default constructor is protected, so only derived
// contexts can be created; the virtual destructor makes deleting through a
// base pointer (e.g. std::unique_ptr<ObserverContext>) safe.
struct ObserverContext {
 protected:
  ObserverContext() = default;

 public:
  virtual ~ObserverContext() = default;
};
101 
102 typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
103 typedef c10::SmallVector<std::unique_ptr<ObserverContext>, kSoftLimitCallbacks>
104     ObserverContextList;
105 typedef uint64_t RecordFunctionHandle;
106 struct RecordFunction;
107 
108 //
109 // PyTorch callbacks/observers API:
110 //
111 
112 /**
113  * RecordFunctionCallback represents a pair of callbacks to be used with
114  * RecordFunction, members:
115  *   start, end - the callbacks to run when entering and exiting the scope;
116  *     optionally, the start callback may return an ObserverContext which will
117  *     be passed to the end callback, use appropriate constructor accordingly.
118  *   needs_inputs - whether the callbacks need the inputs passed from the
119  * observed function/range; NOTE: passing the inputs incurs an additional
120  * overhead; sampling_probability - if not 1.0, then the callback is
121  * probabilistically sampled to run; NOTE: start and end callbacks always run as
122  * a pair and are sampled together; scopes - types of scopes to execute the
123  * callbacks on (see RecordScope); passing empty set means the callbacks will be
124  * executed for all possible scope types should_run - optional function that
125  * returns whether this callback should run; overwrites the effect of setting
126  * sampling_probability
127  */
128 class TORCH_API RecordFunctionCallback {
129  public:
130   using StartCallback =
131       std::unique_ptr<ObserverContext> (*)(const RecordFunction&);
132   using EndCallback = void (*)(const RecordFunction&, ObserverContext*);
133 
134   // This interface supports observers that require passing an ObserverContext
135   // between start and end callbacks.
136   explicit RecordFunctionCallback(
137       StartCallback start,
138       EndCallback end = nullptr)
139       : start_(start), end_(end) {
140     scopes_.fill(true);
141   }
142 
143   RecordFunctionCallback& needsInputs(bool needs_inputs) {
144     needs_inputs_ = needs_inputs;
145     return *this;
146   }
147 
148   RecordFunctionCallback& needsOutputs(bool needs_outputs) {
149     needs_outputs_ = needs_outputs;
150     return *this;
151   }
152 
153   RecordFunctionCallback& needsIds(bool needs_ids) {
154     needs_ids_ = needs_ids;
155     return *this;
156   }
157 
158   RecordFunctionCallback& samplingProb(double sampling_prob) {
159     TORCH_CHECK(
160         sampling_prob >= 0.0 && sampling_prob <= 1.0,
161         "Invalid sampling probability");
162     sampling_prob_ = sampling_prob;
163     return *this;
164   }
165 
166   RecordFunctionCallback& scopes(
167       const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) {
168     if (!scopes.empty()) {
169       scopes_.fill(false);
170       for (auto sc : scopes) {
171         scopes_[static_cast<size_t>(sc)] = true;
172       }
173     } else {
174       scopes_.fill(true);
175     }
176     return *this;
177   }
178 
179   bool needsInputs() const {
180     return needs_inputs_;
181   }
182 
183   bool needsOutputs() const {
184     return needs_outputs_;
185   }
186 
187   bool needsIds() const {
188     return needs_ids_;
189   }
190 
191   double samplingProb() const {
192     return sampling_prob_;
193   }
194 
195   bool checkScope(RecordScope sc) const {
196     return scopes_[(size_t)sc];
197   }
198 
199   StartCallback start() const {
200     return start_;
201   }
202 
203   EndCallback end() const {
204     return end_;
205   }
206 
207  private:
208   StartCallback start_;
209   EndCallback end_;
210   double sampling_prob_ = 1.0;
211   std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
212   bool needs_inputs_ = false;
213   bool needs_outputs_ = false;
214   bool needs_ids_ = false;
215 };
216 
217 // Notes:
218 //  - two types of callbacks are provided: thread local and global
219 //     - thread local callbacks are added/removed only for the given thread
220 //       and are stored locally for each thread and separately from the list
221 //       of the global callbacks
222 //     - global callbacks are stored in a single per process list and are
223 //       invoked by every RecordFunction, in addition to the thread local
224 //       callbacks specific to the given thread
225 //  - we allow the added callbacks to be sampled, by specifying a sampling
226 //    probability for each callback pair, if the start callback is
227 //    not picked to run, the corresponding end callback won't be called
228 //  - a typical use case for the global callbacks is passive monitoring
229 //    in the background (e.g. fleet-wide monitoring), without focusing on
230 //    the specific piece of code
231 //  - in contrast, thread local callbacks are enabled locally, on demand,
232 //    for the specific piece of code (range) and are not sampled
233 //  - a typical use case for thread local callbacks is profiler and code
234 //    execution tracer
235 //  - note, thread local callbacks are automatically propagated with
236 //    ThreadLocalState across JIT continuations and async tasks (at::launch)
237 
// Handle identifying a registered callback: returned by addThreadLocalCallback
// and addGlobalCallback, accepted by removeCallback, disableCallback and
// reenableCallback.
using CallbackHandle = uint64_t;

// Sentinel handle value (zero). NOTE(review): presumably never assigned to a
// registered callback; confirm against the callback manager implementation.
constexpr CallbackHandle INVALID_CALLBACK_HANDLE = 0;
241 
242 // It is unnecessary to use atomic operations for enabling
243 // thread-local function callbacks. Moreover, it prevents saving to
244 // ThreadLocalState because std::atomic is non-copyable.
245 struct RecordFunctionCallbacksEntry {
246   RecordFunctionCallbacksEntry(RecordFunctionCallback cb, CallbackHandle h)
247       : callback_(cb), handle_(h) {}
248 
249   RecordFunctionCallback callback_;
250   bool enabled_{true};
251   CallbackHandle handle_;
252 };
253 
254 // Holds pairs (callbacks, unique_id)
255 using RecordFunctionCallbacks = std::vector<RecordFunctionCallbacksEntry>;
256 
257 // Generated by the callback managers to determine which functions to run.
258 struct StepCallbacks {
259   StepCallbacks() = default;
260   StepCallbacks(uint64_t thread_id, RecordScope scope)
261       : thread_id_{thread_id}, scope_{scope} {}
262 
263   bool empty() const {
264     return callbacks_.empty();
265   }
266 
267   struct StartEndPair {
268     RecordFunctionCallback::StartCallback start_;
269     RecordFunctionCallback::EndCallback end_;
270   };
271 
272   using StartEndPairs = c10::SmallVector<StartEndPair, kSoftLimitCallbacks>;
273 
274   StartEndPairs callbacks_;
275   uint64_t thread_id_{0};
276   RecordScope scope_{RecordScope::FUNCTION};
277   bool needs_inputs_{false};
278   bool needs_outputs_{false};
279   bool needs_ids_{false};
280 };
281 
282 struct TORCH_API RecordFunction {
283   // Default constructor is used with before function called afterwards:
284   //  scope - record scope that this function tracks
285   //  pre_sampled - whether this RecordFunction was already pre-sampled with
286   //    kLowProb probability
287   explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION);
288   explicit RecordFunction(StepCallbacks&& step_callbacks);
289 
290   template <typename F>
291   void before(
292       F fn,
293       c10::ArrayRef<const c10::IValue> args,
294       int64_t current_sequence_nr = -1) {
295     if (!isActive()) {
296       return;
297     }
298     inputs_ = args;
299     before(fn, current_sequence_nr);
300   }
301 
302   template <typename F>
303   void before(
304       F fn,
305       const std::vector<IValue>* args,
306       int64_t current_sequence_nr = -1) {
307     before(
308         std::move(fn),
309         c10::ArrayRef<const c10::IValue>(args->data(), args->size()),
310         current_sequence_nr);
311   }
312 
313   template <typename F>
314   void before(
315       F fn,
316       const std::vector<IValue>* args,
317       const std::unordered_map<std::string, IValue>* kwargs,
318       int64_t current_sequence_nr = -1) {
319     if (!isActive()) {
320       return;
321     }
322     kwinputs_ = std::unordered_map<std::string, IValue>(*kwargs);
323     before(std::move(fn), args, current_sequence_nr);
324   }
325 
326   // Destructor calls end callbacks
327   virtual ~RecordFunction();
328 
329   RecordFunction(const RecordFunction&) = delete;
330   RecordFunction& operator=(const RecordFunction&) = delete;
331 
332   const char* name() const;
333 
334   int64_t seqNr() const {
335     return sequence_nr_;
336   }
337 
338   c10::ArrayRef<const IValue> inputs() const {
339 #ifndef NDEBUG
340     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
341         inputs_valid_, "Called inputs() outside RecordFunction start callback");
342 #endif
343     return inputs_;
344   }
345 
346   std::unordered_map<std::string, IValue> kwinputs() const {
347 #ifndef NDEBUG
348     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
349         inputs_valid_,
350         "Called kwinputs() outside RecordFunction start callback");
351 #endif
352     return kwinputs_;
353   }
354 
355   const std::vector<c10::IValue>& outputs() const {
356     return outputs_;
357   }
358 
359   void setOutputs(std::vector<c10::IValue>&& outputs) {
360     outputs_ = std::move(outputs);
361   }
362 
363   void setOutputs(c10::ArrayRef<c10::IValue> outputs) {
364     outputs_ = outputs.vec();
365   }
366 
367   size_t num_inputs() const;
368   size_t num_outputs() const;
369 
370   // Retrieves the thread_id that this RecordFunction ran start callbacks with.
371   // Useful for writing thread safe end callbacks that may be potentially
372   // executed in a different thread (async ops)
373   uint64_t threadId() const {
374     return step_callbacks_.thread_id_;
375   }
376 
377   // For backward functions - thread id of the corresponding forward function,
378   // or zero otherwise;
379   // used alongside with sequence number to correlate backward functions with
380   // the forward ones
381   uint64_t forwardThreadId() const {
382     return fwd_thread_id_;
383   }
384 
385   void setForwardThreadId(uint64_t thread_id) {
386     fwd_thread_id_ = thread_id;
387   }
388 
389   RecordScope scope() const {
390     return step_callbacks_.scope_;
391   }
392 
393   // Returns logical thread_id for the current thread
394   static uint64_t currentThreadId();
395 
396   // Internal functions, do not use directly;
397   // used in python's context manager
398 
399   // before functions initialize RecordFunction members and call
400   // start callbacks
401   using schema_ref_t = std::reference_wrapper<const c10::FunctionSchema>;
402   void before(const char* name, int64_t sequence_nr = -1);
403   void before(std::string name, int64_t sequence_nr = -1);
404   void before(schema_ref_t schema, int64_t sequence_nr = -1);
405 
406   // Sets node ID for distributed profiling
407   static void setDefaultNodeId(int64_t defaultNodeId);
408   // Gets node ID for distributed profiling
409   static int64_t getDefaultNodeId();
410 
411   // Calls end callbacks. After end(), accessors will no longer provide useful
412   // results.
413   void end();
414 
415   // Internal-only, used only force async event for distributed events
416   // profiling.
417   void _setAsync();
418 
419   // Returns whether this RecordFunction corresponds to an async event or not.
420   bool isAsync() const;
421 
422   // Returns whether this RecordFunction corresponds to NCCL metadata collection
423   // or not.
424   bool isNcclMeta() const {
425     return is_nccl_meta_;
426   }
427 
428   // Internal-only, used to denote out variant used for Static Runtime execution
429   void _setStaticRuntimeOutVariant();
430   bool isStaticRuntimeOutVariant() const;
431 
432   RecordFunctionHandle handle() const {
433     return handle_;
434   }
435 
436   std::optional<OperatorName> operator_name() const;
437 
438   // This method returns a copy of the FunctionSchema and can be expensive.
439   std::optional<FunctionSchema> operator_schema() const;
440 
441   void setHandle(RecordFunctionHandle handle) {
442     handle_ = handle;
443   }
444 
445   // Whether this RecordFunction runs any callbacks.
446   bool isActive() const {
447     return !step_callbacks_.empty();
448   }
449 
450   bool needsInputs() const {
451     return step_callbacks_.needs_inputs_;
452   }
453 
454   bool needsOutputs() const {
455     return step_callbacks_.needs_outputs_;
456   }
457 
458   int64_t debugHandle() const {
459     return debug_handle_;
460   }
461 
462   void setDebugHandle(int64_t debug_handle) {
463     debug_handle_ = debug_handle;
464   }
465 
466   void invalidateInputs() {
467 #ifndef NDEBUG
468     inputs_valid_ = false;
469 #endif
470   }
471 
472  private:
473   void runStartCallbacks();
474 
475   StepCallbacks step_callbacks_;
476 
477   // In cases when RecordFunction might be active but we chose not to
478   // use the observers (e.g. operator is not observed), this boolean
479   // flag is used to check whether the start callbacks were called
480   bool called_start_callbacks_ = false;
481 
482 #ifndef NDEBUG
483   bool inputs_valid_ = false;
484 #endif
485 
486   // Stores various ObserverContext objects with event metadata for callbacks.
487   ObserverContextList ctx_;
488 
489   std::variant<std::string, schema_ref_t> fn_;
490 
491   int64_t sequence_nr_ = -1;
492   c10::ArrayRef<const IValue> inputs_;
493   std::unordered_map<std::string, IValue> kwinputs_;
494   std::vector<c10::IValue> outputs_;
495 
496   // For backward functions - thread id of the forward function
497   uint64_t fwd_thread_id_ = 0;
498 
499   // Unique id for this RecordFunction, used in callbacks to track start
500   // and end of ranges
501   RecordFunctionHandle handle_{0};
502 
503   // Whether this record_function corresponds to an async event or not. Async
504   // events can complete in different threads or follow a future-like pattern
505   // of use.
506   bool is_async_{false};
507 
508   // Debug handles are used for lazy annotation of module hierarchy
509   // and callstack.
510   // This is specifically is useful for mobile runtime, where generated
511   // debug handles can be lazily symbolicated using debug information
512   int64_t debug_handle_{-1};
513 
514   // Whether this RecordFunction is used for an out variant run with
515   // Static Runtime
516   bool is_static_runtime_out_variant_{false};
517 
518   // Whether this RecordFunction is used for NCCL metadata collection
519   bool is_nccl_meta_{false};
520 };
521 
// Returns the callbacks selected to run for a RecordFunction of the given
// scope (see StepCallbacks); the result may be empty.
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);

// NOTE(review): by its name, presumably like getStepCallbacks() but returning
// std::nullopt instead of an empty StepCallbacks; confirm in the implementation.
TORCH_API std::optional<StepCallbacks> getStepCallbacksUnlessEmpty(
    RecordScope scope);
526 
527 namespace detail {
528 template <typename Inputs, typename F, typename... Args>
529 void record_function_with_scope(
530     RecordFunction& guard,
531     F fn,
532     const Inputs& inputs,
533     Args&&... args) {
534   if (guard.needsInputs()) {
535     guard.before(
536         fn,
537         c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()),
538         std::forward<Args>(args)...);
539   } else {
540     guard.before(fn, std::forward<Args>(args)...);
541   }
542 }
543 
544 template <typename Inputs, typename F, typename... Args>
545 void record_function_with_scope_and_debug_handle(
546     RecordFunction& guard,
547     F fn,
548     int64_t debug_handle,
549     const Inputs& inputs,
550     Args&&... args) {
551   guard.setDebugHandle(debug_handle);
552   if (guard.needsInputs()) {
553     guard.before(
554         fn,
555         c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()),
556         std::forward<Args>(args)...);
557   } else {
558     guard.before(fn, std::forward<Args>(args)...);
559   }
560 }
561 
562 template <typename F, typename... Args>
563 void record_function_with_scope(
564     RecordFunction& guard,
565     F fn,
566     c10::ArrayRef<const c10::IValue> inputs,
567     Args&&... args) {
568   return record_function_with_scope<
569       c10::ArrayRef<const c10::IValue>,
570       F,
571       Args...>(guard, std::move(fn), inputs, std::forward<Args>(args)...);
572 }
573 
574 template <typename F, typename... Args>
575 void record_function_with_scope_and_debug_handle(
576     RecordFunction& guard,
577     F fn,
578     int64_t debug_handle,
579     c10::ArrayRef<const c10::IValue> inputs,
580     Args&&... args) {
581   return record_function_with_scope_and_debug_handle<
582       c10::ArrayRef<const c10::IValue>,
583       F,
584       Args...>(
585       guard, std::move(fn), debug_handle, inputs, std::forward<Args>(args)...);
586 }
587 
588 } // namespace detail
589 
// Declares a local `at::RecordFunction guard` for `scope`. If any callback is
// active, it names the guard `fn` and runs the start callbacks; `inputs` is
// forwarded only when a callback needs it. The end callbacks run when `guard`
// is destroyed. Because it declares `guard`, use it at most once per scope.
// optional argument - function's seq_no
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
  at::RecordFunction guard(scope);                         \
  if (guard.isActive()) {                                  \
    ::at::detail::record_function_with_scope(              \
        guard, fn, inputs, ##__VA_ARGS__);                 \
  }
597 
// Like RECORD_FUNCTION_WITH_SCOPE, and additionally stores `outputs` on the
// guard when a callback needs them. `inputs` is passed directly to
// RecordFunction::before(), so it may be any input form an overload accepts
// (e.g. an ArrayRef or a `const std::vector<IValue>*`). Declares `guard`.
#define RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \
    scope, fn, inputs, outputs, ...)               \
  at::RecordFunction guard(scope);                 \
  if (guard.isActive()) {                          \
    if (guard.needsInputs()) {                     \
      guard.before(fn, inputs, ##__VA_ARGS__);     \
    } else {                                       \
      guard.before(fn, ##__VA_ARGS__);             \
    }                                              \
    if (guard.needsOutputs()) {                    \
      guard.setOutputs(outputs);                   \
    }                                              \
  }
611 
// RECORD_FUNCTION_WITH_SCOPE for the FUNCTION scope (c10/ATen ops).
#define RECORD_FUNCTION(fn, inputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE(            \
      at::RecordScope::FUNCTION, fn, inputs, ##__VA_ARGS__)

// RECORD_FUNCTION_WITH_SCOPE for TorchScript functions/methods.
#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs)

// RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS for the FUNCTION scope.
#define RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(fn, inputs, outputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS(                          \
      at::RecordScope::FUNCTION, fn, inputs, outputs, ##__VA_ARGS__)
622 
// Custom user scopes in C++; similar to Python's 'with record_function("..."):'
// Records no inputs (passes an empty ArrayRef).
#define RECORD_USER_SCOPE(fn) \
  RECORD_FUNCTION_WITH_SCOPE( \
      at::RecordScope::USER_SCOPE, fn, c10::ArrayRef<const c10::IValue>{})

// RECORD_USER_SCOPE with inputs
#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs)
631 
// Helper macro to pass in a debug handle that is used to post-process events.
// Like RECORD_FUNCTION_WITH_SCOPE (it declares `guard`), but first stores
// `debug_handle` on the guard when it is active.
#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(             \
    scope, fn, debug_handle, inputs, ...)                      \
  at::RecordFunction guard(scope);                             \
  if (guard.isActive()) {                                      \
    ::at::detail::record_function_with_scope_and_debug_handle( \
        guard, fn, debug_handle, inputs, ##__VA_ARGS__);       \
  }
641 
// Helper macro to record LITE INTERPRETER scope events with debug handles.
#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \
    fn, debug_handle, inputs)                           \
  RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(            \
      at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs)
647 
// Bookend to the RECORD_FUNCTION macros. Use this after the kernel launch to
// let the profiler bind the outputs to the op that produced them. `guard` is
// declared by RECORD_FUNCTION, so this macro must be used in the same scope as
// RECORD_FUNCTION. `outputs` must be a range whose elements convert to
// c10::IValue. It is evaluated twice (for begin() and end()), and only when a
// callback needs outputs. The argument is parenthesized so that expressions
// such as `*ptr` or `cond ? a : b` expand correctly.
#define RECORD_OUTPUTS(outputs)                                        \
  if (guard.needsOutputs()) {                                          \
    guard.setOutputs(                                                  \
        std::vector<c10::IValue>((outputs).begin(), (outputs).end())); \
  }
657 
/**
 * addThreadLocalCallback adds a thread local callback to run with
 * RecordFunction; returns a handle to use with removeCallback (or
 * disableCallback / reenableCallback).
 */
TORCH_API CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb);

/**
 * hasThreadLocalCallbacks returns whether there are callbacks registered
 * with addThreadLocalCallback
 */
TORCH_API bool hasThreadLocalCallbacks();

/**
 * clearThreadLocalCallbacks removes all thread local callbacks
 */
TORCH_API void clearThreadLocalCallbacks();
674 
/**
 * addGlobalCallback adds a global callback to run with RecordFunction.
 * Intended to be called only during program initialization. Returns a
 * handle to use with removeCallback.
 */
TORCH_API CallbackHandle addGlobalCallback(RecordFunctionCallback cb);

/**
 * removeCallback removes a callback given the handle returned by
 * addThreadLocalCallback or addGlobalCallback.
 *
 * The caller must ensure that no other code runs simultaneously.
 */
TORCH_API void removeCallback(CallbackHandle handle);
689 
/**
 * Prevent the given callback from executing. If handle is invalid,
 * does nothing.
 */
TORCH_API void disableCallback(CallbackHandle handle);

/**
 * Allow the given callback, previously disabled with disableCallback, to
 * execute again. If handle is invalid, does nothing.
 */
TORCH_API void reenableCallback(CallbackHandle handle);

/**
 * hasGlobalCallbacks returns whether there are global callbacks
 * registered with addGlobalCallback
 */
TORCH_API bool hasGlobalCallbacks();

/**
 * clearGlobalCallbacks removes all global callbacks
 */
TORCH_API void clearGlobalCallbacks();

// for both thread local and global callbacks
TORCH_API bool hasCallbacks();
TORCH_API void clearCallbacks();
716 
/**
 * enableRecordFunction enables (enable = true, the default) or disables
 * RecordFunction thread locally, i.e. for the current thread only.
 */
TORCH_API void enableRecordFunction(bool enable = true);

/**
 * isRecordFunctionEnabled returns whether RecordFunction
 * is enabled thread locally
 */
TORCH_API bool isRecordFunctionEnabled();
727 
728 class TORCH_API RecordFunctionGuard {
729  public:
730   explicit RecordFunctionGuard(bool is_enabled = true)
731       : prev_value_(isRecordFunctionEnabled()) {
732     enableRecordFunction(is_enabled);
733   }
734 
735   virtual ~RecordFunctionGuard() {
736     enableRecordFunction(prev_value_);
737   }
738 
739  private:
740   bool prev_value_ = false;
741 };
742 
743 class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
744  public:
745   DisableRecordFunctionGuard() : RecordFunctionGuard(false) {}
746   ~DisableRecordFunctionGuard() override = default;
747 };
748 
749 struct TORCH_API RecordFunctionTLS {
750   // Thread local vector of callbacks, holds pairs (callbacks, unique_id);
751   // must be sorted in increasing handles order
752   RecordFunctionCallbacks sorted_tls_callbacks_;
753 
754   bool tls_record_function_enabled_ = true;
755 };
756 
// Returns the thread-local RecordFunction state (callbacks and enabled flag).
TORCH_API const RecordFunctionTLS& get_record_function_tls_();

// Replaces the thread-local RecordFunction state with `tls`.
// NOTE(review): presumably used when propagating state via ThreadLocalState to
// continuations/async tasks; confirm in the implementation.
TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls);

// Testing only: sets a seed for RecordFunction's randomness. NOTE(review):
// presumably the RNG used for callback sampling; confirm.
TORCH_API void set_record_function_seed_for_testing(uint32_t seed);
762 
763 } // namespace at
764