#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/core/operator_name.h>
#include <c10/macros/Export.h>
#include <c10/util/SmallVector.h>
#include <optional>

#include <array>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <variant>
#include <vector>

namespace c10 {
class TORCH_API OperatorHandle;
}

namespace at {

// Function name to record NCCL metadata
extern TORCH_API const std::string kParamCommsCallName;

// Kind of record function scope
enum class C10_API_ENUM RecordScope : uint8_t {
  // c10/ATen ops, autograd nodes
  FUNCTION = 0,
  // Functions/nodes called from the autograd
  BACKWARD_FUNCTION,
  // TorchScript functions, methods
  TORCHSCRIPT_FUNCTION,
  // Kernel Function dtype Tag
  KERNEL_FUNCTION_DTYPE,
  // Torchbind custom class
  CUSTOM_CLASS,
  // Generic Build Feature
  BUILD_FEATURE,
  // Lite interpreter (mobile runtime) functions
  LITE_INTERPRETER,
  // User defined scope (e.g. with record_function())
  USER_SCOPE,
  // Scopes for static runtime, a specialized TorchScript interpreter
  STATIC_RUNTIME_OP,
  STATIC_RUNTIME_MODEL,
  NUM_SCOPES, // must be the last in the list
};

} // namespace at

namespace std {
template <>
struct hash<at::RecordScope> {
  size_t operator()(const at::RecordScope& sc) const {
    return static_cast<std::size_t>(sc);
  }
};
} // namespace std

namespace at {

struct TORCH_API StringView {
  StringView() : StringView(nullptr) {}
  explicit StringView(const char* str_ptr)
      : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {}
  explicit StringView(std::string str)
      : owned_str_ptr_(std::make_shared<std::string>(std::move(str))),
        str_ptr_(owned_str_ptr_->c_str()) {}

  const char* str() const {
    return str_ptr_;
  }

  friend std::ostream& operator<<(std::ostream& os, const StringView& dt) {
    os << dt.str();
    return os;
  }

  friend bool operator==(const StringView& lhs, const StringView& rhs) {
    return strcmp(lhs.str(), rhs.str()) == 0;
  }

  friend bool operator!=(const StringView& lhs, const StringView& rhs) {
    return !(lhs == rhs);
  }

 private:
  std::shared_ptr<std::string> owned_str_ptr_;
  const char* str_ptr_;
};
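
// Example (illustrative, not part of the API): StringView either borrows a
// caller-owned C string without copying, or takes ownership of a std::string,
// so observers can hold function names cheaply.
//
//   at::StringView borrowed("aten::add");       // no allocation, pointer only
//   at::StringView owned(std::string("step_") +
//                        std::to_string(3));    // owns a copy of the string
//   bool same = (borrowed == owned);            // compares with strcmp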

// Soft limit on the number of callbacks to use
constexpr std::size_t kSoftLimitCallbacks = 4;

// An abstract base class for various observer contexts that can be attached to
// the RecordFunction.
struct ObserverContext {
  virtual ~ObserverContext() = default;

 protected:
  ObserverContext() = default;
};

typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
typedef c10::SmallVector<std::unique_ptr<ObserverContext>, kSoftLimitCallbacks>
    ObserverContextList;
typedef uint64_t RecordFunctionHandle;
struct RecordFunction;

//
// PyTorch callbacks/observers API:
//

/**
 * RecordFunctionCallback represents a pair of callbacks to be used with
 * RecordFunction, members:
 *   start, end - the callbacks to run when entering and exiting the scope;
 *     optionally, the start callback may return an ObserverContext which will
 *     be passed to the end callback; use the appropriate constructor
 *     accordingly.
 *   needs_inputs - whether the callbacks need the inputs passed from the
 *     observed function/range; NOTE: passing the inputs incurs an additional
 *     overhead.
 *   sampling_probability - if not 1.0, the callback is probabilistically
 *     sampled to run; NOTE: start and end callbacks always run as a pair and
 *     are sampled together.
 *   scopes - types of scopes to execute the callbacks on (see RecordScope);
 *     passing an empty set means the callbacks will be executed for all
 *     possible scope types.
 */
class TORCH_API RecordFunctionCallback {
 public:
  using StartCallback =
      std::unique_ptr<ObserverContext> (*)(const RecordFunction&);
  using EndCallback = void (*)(const RecordFunction&, ObserverContext*);

  // This interface supports observers that require passing an ObserverContext
  // between start and end callbacks.
  explicit RecordFunctionCallback(
      StartCallback start,
      EndCallback end = nullptr)
      : start_(start), end_(end) {
    scopes_.fill(true);
  }

  RecordFunctionCallback& needsInputs(bool needs_inputs) {
    needs_inputs_ = needs_inputs;
    return *this;
  }

  RecordFunctionCallback& needsOutputs(bool needs_outputs) {
    needs_outputs_ = needs_outputs;
    return *this;
  }

  RecordFunctionCallback& needsIds(bool needs_ids) {
    needs_ids_ = needs_ids;
    return *this;
  }

  RecordFunctionCallback& samplingProb(double sampling_prob) {
    TORCH_CHECK(
        sampling_prob >= 0.0 && sampling_prob <= 1.0,
        "Invalid sampling probability");
    sampling_prob_ = sampling_prob;
    return *this;
  }

  RecordFunctionCallback& scopes(
      const std::unordered_set<RecordScope, std::hash<RecordScope>>& scopes) {
    if (!scopes.empty()) {
      scopes_.fill(false);
      for (auto sc : scopes) {
        scopes_[static_cast<size_t>(sc)] = true;
      }
    } else {
      scopes_.fill(true);
    }
    return *this;
  }

  bool needsInputs() const {
    return needs_inputs_;
  }

  bool needsOutputs() const {
    return needs_outputs_;
  }

  bool needsIds() const {
    return needs_ids_;
  }

  double samplingProb() const {
    return sampling_prob_;
  }

  bool checkScope(RecordScope sc) const {
    return scopes_[static_cast<size_t>(sc)];
  }

  StartCallback start() const {
    return start_;
  }

  EndCallback end() const {
    return end_;
  }

 private:
  StartCallback start_;
  EndCallback end_;
  double sampling_prob_ = 1.0;
  std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
  bool needs_inputs_ = false;
  bool needs_outputs_ = false;
  bool needs_ids_ = false;
};
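
// Example (illustrative sketch, not part of this header's API surface):
// building a callback pair that threads an ObserverContext from the start to
// the end callback. The context type `TimerCtx` is hypothetical; captureless
// lambdas are used because StartCallback/EndCallback are plain function
// pointers.
//
//   struct TimerCtx : public at::ObserverContext {
//     std::chrono::steady_clock::time_point t0;
//   };
//
//   auto cb = at::RecordFunctionCallback(
//                 [](const at::RecordFunction& fn)
//                     -> std::unique_ptr<at::ObserverContext> {
//                   auto ctx = std::make_unique<TimerCtx>();
//                   ctx->t0 = std::chrono::steady_clock::now();
//                   return ctx;
//                 },
//                 [](const at::RecordFunction& fn, at::ObserverContext* ctx) {
//                   auto* timer = static_cast<TimerCtx*>(ctx);
//                   auto elapsed = std::chrono::steady_clock::now() - timer->t0;
//                   (void)elapsed; // e.g. log fn.name() and the elapsed time
//                 })
//                 .samplingProb(0.01)
//                 .scopes({at::RecordScope::FUNCTION});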

// Notes:
//  - two types of callbacks are provided: thread local and global
//  - thread local callbacks are added/removed only for the given thread
//    and are stored locally for each thread and separately from the list
//    of the global callbacks
//  - global callbacks are stored in a single per-process list and are
//    invoked by every RecordFunction, in addition to the thread local
//    callbacks specific to the given thread
//  - we allow the added callbacks to be sampled, by specifying a sampling
//    probability for each callback pair; if the start callback is
//    not picked to run, the corresponding end callback won't be called
//  - a typical use case for the global callbacks is passive monitoring
//    in the background (e.g. fleet-wide monitoring), without focusing on
//    a specific piece of code
//  - in contrast, thread local callbacks are enabled locally, on demand,
//    for a specific piece of code (range) and are not sampled
//  - a typical use case for thread local callbacks is the profiler and
//    code execution tracers
//  - note, thread local callbacks are automatically propagated with
//    ThreadLocalState across JIT continuations and async tasks (at::launch)
//  (a registration sketch follows the StepCallbacks struct below)

typedef uint64_t CallbackHandle;

constexpr CallbackHandle INVALID_CALLBACK_HANDLE{0};

// Using atomic operations to enable thread-local function callbacks is
// unnecessary; moreover, std::atomic is non-copyable, which would prevent
// saving this state into ThreadLocalState.
struct RecordFunctionCallbacksEntry {
  RecordFunctionCallbacksEntry(RecordFunctionCallback cb, CallbackHandle h)
      : callback_(cb), handle_(h) {}

  RecordFunctionCallback callback_;
  bool enabled_{true};
  CallbackHandle handle_;
};

// Holds pairs (callbacks, unique_id)
using RecordFunctionCallbacks = std::vector<RecordFunctionCallbacksEntry>;

// Generated by the callback managers to determine which functions to run.
struct StepCallbacks {
  StepCallbacks() = default;
  StepCallbacks(uint64_t thread_id, RecordScope scope)
      : thread_id_{thread_id}, scope_{scope} {}

  bool empty() const {
    return callbacks_.empty();
  }

  struct StartEndPair {
    RecordFunctionCallback::StartCallback start_;
    RecordFunctionCallback::EndCallback end_;
  };

  using StartEndPairs = c10::SmallVector<StartEndPair, kSoftLimitCallbacks>;

  StartEndPairs callbacks_;
  uint64_t thread_id_{0};
  RecordScope scope_{RecordScope::FUNCTION};
  bool needs_inputs_{false};
  bool needs_outputs_{false};
  bool needs_ids_{false};
};
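
// Registration sketch (illustrative; `cb` is a RecordFunctionCallback built as
// in the example above): global callbacks see every RecordFunction in the
// process, thread local callbacks only those on the registering thread.
//
//   // Typically done once during program initialization:
//   at::CallbackHandle global_h = at::addGlobalCallback(cb);
//
//   // Enabled on demand for a region of code on the current thread:
//   at::CallbackHandle local_h = at::addThreadLocalCallback(cb);
//   // ... run the code to be observed ...
//   at::removeCallback(local_h);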

struct TORCH_API RecordFunction {
  // The default constructor is used with before() called afterwards:
  //  scope - record scope that this function tracks
  explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION);
  explicit RecordFunction(StepCallbacks&& step_callbacks);

  template <typename F>
  void before(
      F fn,
      c10::ArrayRef<const c10::IValue> args,
      int64_t current_sequence_nr = -1) {
    if (!isActive()) {
      return;
    }
    inputs_ = args;
    before(fn, current_sequence_nr);
  }

  template <typename F>
  void before(
      F fn,
      const std::vector<IValue>* args,
      int64_t current_sequence_nr = -1) {
    before(
        std::move(fn),
        c10::ArrayRef<const c10::IValue>(args->data(), args->size()),
        current_sequence_nr);
  }

  template <typename F>
  void before(
      F fn,
      const std::vector<IValue>* args,
      const std::unordered_map<std::string, IValue>* kwargs,
      int64_t current_sequence_nr = -1) {
    if (!isActive()) {
      return;
    }
    kwinputs_ = std::unordered_map<std::string, IValue>(*kwargs);
    before(std::move(fn), args, current_sequence_nr);
  }

  // Destructor calls end callbacks
  virtual ~RecordFunction();

  RecordFunction(const RecordFunction&) = delete;
  RecordFunction& operator=(const RecordFunction&) = delete;

  const char* name() const;

  int64_t seqNr() const {
    return sequence_nr_;
  }

  c10::ArrayRef<const IValue> inputs() const {
#ifndef NDEBUG
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        inputs_valid_, "Called inputs() outside RecordFunction start callback");
#endif
    return inputs_;
  }

  std::unordered_map<std::string, IValue> kwinputs() const {
#ifndef NDEBUG
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        inputs_valid_,
        "Called kwinputs() outside RecordFunction start callback");
#endif
    return kwinputs_;
  }

  const std::vector<c10::IValue>& outputs() const {
    return outputs_;
  }

  void setOutputs(std::vector<c10::IValue>&& outputs) {
    outputs_ = std::move(outputs);
  }

  void setOutputs(c10::ArrayRef<c10::IValue> outputs) {
    outputs_ = outputs.vec();
  }

  size_t num_inputs() const;
  size_t num_outputs() const;

  // Retrieves the thread_id that this RecordFunction ran start callbacks with.
  // Useful for writing thread safe end callbacks that may potentially be
  // executed in a different thread (async ops).
  uint64_t threadId() const {
    return step_callbacks_.thread_id_;
  }

  // For backward functions - thread id of the corresponding forward function,
  // or zero otherwise;
  // used together with the sequence number to correlate backward functions
  // with the forward ones.
  uint64_t forwardThreadId() const {
    return fwd_thread_id_;
  }

  void setForwardThreadId(uint64_t thread_id) {
    fwd_thread_id_ = thread_id;
  }

  RecordScope scope() const {
    return step_callbacks_.scope_;
  }

  // Returns logical thread_id for the current thread
  static uint64_t currentThreadId();

  // Internal functions, do not use directly;
  // used in Python's context manager

  // before functions initialize RecordFunction members and call
  // start callbacks
  using schema_ref_t = std::reference_wrapper<const c10::FunctionSchema>;
  void before(const char* name, int64_t sequence_nr = -1);
  void before(std::string name, int64_t sequence_nr = -1);
  void before(schema_ref_t schema, int64_t sequence_nr = -1);

  // Sets node ID for distributed profiling
  static void setDefaultNodeId(int64_t defaultNodeId);
  // Gets node ID for distributed profiling
  static int64_t getDefaultNodeId();

  // Calls end callbacks. After end(), accessors will no longer provide useful
  // results.
  void end();

  // Internal-only, used to force an async event for distributed events
  // profiling.
  void _setAsync();

  // Returns whether this RecordFunction corresponds to an async event or not.
  bool isAsync() const;

  // Returns whether this RecordFunction corresponds to NCCL metadata
  // collection or not.
  bool isNcclMeta() const {
    return is_nccl_meta_;
  }

  // Internal-only, used to denote the out variant used for Static Runtime
  // execution
  void _setStaticRuntimeOutVariant();
  bool isStaticRuntimeOutVariant() const;

  RecordFunctionHandle handle() const {
    return handle_;
  }

  std::optional<OperatorName> operator_name() const;

  // This method returns a copy of the FunctionSchema and can be expensive.
  std::optional<FunctionSchema> operator_schema() const;

  void setHandle(RecordFunctionHandle handle) {
    handle_ = handle;
  }

  // Whether this RecordFunction runs any callbacks.
  bool isActive() const {
    return !step_callbacks_.empty();
  }

  bool needsInputs() const {
    return step_callbacks_.needs_inputs_;
  }

  bool needsOutputs() const {
    return step_callbacks_.needs_outputs_;
  }

  int64_t debugHandle() const {
    return debug_handle_;
  }

  void setDebugHandle(int64_t debug_handle) {
    debug_handle_ = debug_handle;
  }

  void invalidateInputs() {
#ifndef NDEBUG
    inputs_valid_ = false;
#endif
  }

 private:
  void runStartCallbacks();

  StepCallbacks step_callbacks_;

  // In cases when RecordFunction might be active but we chose not to
  // use the observers (e.g. operator is not observed), this boolean
  // flag is used to check whether the start callbacks were called
  bool called_start_callbacks_ = false;

#ifndef NDEBUG
  bool inputs_valid_ = false;
#endif

  // Stores various ObserverContext objects with event metadata for callbacks.
  ObserverContextList ctx_;

  std::variant<std::string, schema_ref_t> fn_;

  int64_t sequence_nr_ = -1;
  c10::ArrayRef<const IValue> inputs_;
  std::unordered_map<std::string, IValue> kwinputs_;
  std::vector<c10::IValue> outputs_;

  // For backward functions - thread id of the forward function
  uint64_t fwd_thread_id_ = 0;

  // Unique id for this RecordFunction, used in callbacks to track start
  // and end of ranges
  RecordFunctionHandle handle_{0};

  // Whether this record_function corresponds to an async event or not. Async
  // events can complete in different threads or follow a future-like pattern
  // of use.
  bool is_async_{false};

  // Debug handles are used for lazy annotation of module hierarchy
  // and callstack.
  // This is specifically useful for the mobile runtime, where generated
  // debug handles can be lazily symbolicated using debug information.
  int64_t debug_handle_{-1};

  // Whether this RecordFunction is used for an out variant run with
  // Static Runtime
  bool is_static_runtime_out_variant_{false};

  // Whether this RecordFunction is used for NCCL metadata collection
  bool is_nccl_meta_{false};
};
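
// Example (illustrative sketch of the manual flow that the RECORD_FUNCTION
// macros below automate; "my_region" and `inputs` are placeholders):
//
//   std::vector<c10::IValue> inputs = {/* ... */};
//   at::RecordFunction guard(at::RecordScope::USER_SCOPE);
//   if (guard.isActive()) {
//     if (guard.needsInputs()) {
//       guard.before(
//           "my_region",
//           c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
//     } else {
//       guard.before("my_region");
//     }
//   }
//   // ... code being observed ...
//   // end callbacks run from ~RecordFunction (or call guard.end() explicitly)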

TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);

TORCH_API std::optional<StepCallbacks> getStepCallbacksUnlessEmpty(
    RecordScope scope);

namespace detail {
template <typename Inputs, typename F, typename... Args>
void record_function_with_scope(
    RecordFunction& guard,
    F fn,
    const Inputs& inputs,
    Args&&... args) {
  if (guard.needsInputs()) {
    guard.before(
        fn,
        c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()),
        std::forward<Args>(args)...);
  } else {
    guard.before(fn, std::forward<Args>(args)...);
  }
}

template <typename Inputs, typename F, typename... Args>
void record_function_with_scope_and_debug_handle(
    RecordFunction& guard,
    F fn,
    int64_t debug_handle,
    const Inputs& inputs,
    Args&&... args) {
  guard.setDebugHandle(debug_handle);
  if (guard.needsInputs()) {
    guard.before(
        fn,
        c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()),
        std::forward<Args>(args)...);
  } else {
    guard.before(fn, std::forward<Args>(args)...);
  }
}

template <typename F, typename... Args>
void record_function_with_scope(
    RecordFunction& guard,
    F fn,
    c10::ArrayRef<const c10::IValue> inputs,
    Args&&... args) {
  return record_function_with_scope<
      c10::ArrayRef<const c10::IValue>,
      F,
      Args...>(guard, std::move(fn), inputs, std::forward<Args>(args)...);
}

template <typename F, typename... Args>
void record_function_with_scope_and_debug_handle(
    RecordFunction& guard,
    F fn,
    int64_t debug_handle,
    c10::ArrayRef<const c10::IValue> inputs,
    Args&&... args) {
  return record_function_with_scope_and_debug_handle<
      c10::ArrayRef<const c10::IValue>,
      F,
      Args...>(
      guard, std::move(fn), debug_handle, inputs, std::forward<Args>(args)...);
}

} // namespace detail
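
// Example (illustrative fast-path sketch): callers that want to avoid
// constructing a RecordFunction when no callback is interested can pre-fetch
// the StepCallbacks first. "my_op" is a placeholder name.
//
//   auto step_callbacks =
//       at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
//   if (step_callbacks.has_value()) {
//     at::RecordFunction guard(std::move(*step_callbacks));
//     guard.before("my_op");
//     // ... run the op; end callbacks fire when guard is destroyed ...
//   }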

// optional argument - function's seq_no
#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \
  at::RecordFunction guard(scope);                         \
  if (guard.isActive()) {                                  \
    ::at::detail::record_function_with_scope(              \
        guard, fn, inputs, ##__VA_ARGS__);                 \
  }

#define RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \
    scope, fn, inputs, outputs, ...)               \
  at::RecordFunction guard(scope);                 \
  if (guard.isActive()) {                          \
    if (guard.needsInputs()) {                     \
      guard.before(fn, inputs, ##__VA_ARGS__);     \
    } else {                                       \
      guard.before(fn, ##__VA_ARGS__);             \
    }                                              \
    if (guard.needsOutputs()) {                    \
      guard.setOutputs(outputs);                   \
    }                                              \
  }

#define RECORD_FUNCTION(fn, inputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE(            \
      at::RecordScope::FUNCTION, fn, inputs, ##__VA_ARGS__)

#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs)

#define RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(fn, inputs, outputs, ...) \
  RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS(                          \
      at::RecordScope::FUNCTION, fn, inputs, outputs, ##__VA_ARGS__)

// Custom user scopes in C++; similar to Python's 'with record_function("..."):'
#define RECORD_USER_SCOPE(fn) \
  RECORD_FUNCTION_WITH_SCOPE( \
      at::RecordScope::USER_SCOPE, fn, c10::ArrayRef<const c10::IValue>{})

// RECORD_USER_SCOPE with inputs
#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \
  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs)

// Helper macro to pass in a debug handle that is used to
// post-process events
#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(             \
    scope, fn, debug_handle, inputs, ...)                      \
  at::RecordFunction guard(scope);                             \
  if (guard.isActive()) {                                      \
    ::at::detail::record_function_with_scope_and_debug_handle( \
        guard, fn, debug_handle, inputs, ##__VA_ARGS__);       \
  }

// Helper macro to record LITE_INTERPRETER scope events with debug handles
#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \
    fn, debug_handle, inputs)                           \
  RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS(            \
      at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs)

// Bookend to the RECORD_FUNCTION macros. Use this after the kernel
// launch to let the profiler bind the outputs to the op that produced
// them. Note that `guard` is declared by RECORD_FUNCTION, so this macro
// needs to be used in the same scope as RECORD_FUNCTION.
#define RECORD_OUTPUTS(outputs)                                    \
  if (guard.needsOutputs()) {                                      \
    guard.setOutputs(                                              \
        std::vector<c10::IValue>(outputs.begin(), outputs.end())); \
  }
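
// Example (illustrative; `my_kernel` and its tensors are placeholders):
//
//   void my_kernel(const at::Tensor& a, const at::Tensor& b) {
//     RECORD_FUNCTION("my_kernel", std::vector<c10::IValue>({a, b}));
//     std::vector<c10::IValue> outputs{/* ... results ... */};
//     RECORD_OUTPUTS(outputs);
//   }
//
//   // Or, for an ad-hoc user range without inputs:
//   {
//     RECORD_USER_SCOPE("data_loading");
//     // ... code measured under the user scope ...
//   }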

/**
 * addThreadLocalCallback adds a thread local callback to run with
 * RecordFunction; returns a handle to use with removeCallback
 */
TORCH_API CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb);

/**
 * hasThreadLocalCallbacks returns whether there are callbacks registered
 * with addThreadLocalCallback
 */
TORCH_API bool hasThreadLocalCallbacks();

/**
 * clearThreadLocalCallbacks removes all thread local callbacks
 */
TORCH_API void clearThreadLocalCallbacks();

/**
 * addGlobalCallback adds a global callback to run with RecordFunction;
 * it is not thread safe and should typically be called
 * only during the program initialization
 */
TORCH_API CallbackHandle addGlobalCallback(RecordFunctionCallback cb);

/**
 * removeCallback removes a callback given the handle returned by
 * addThreadLocalCallback or addGlobalCallback;
 * removing a global callback is not thread safe, so
 * no other code can run simultaneously
 */
TORCH_API void removeCallback(CallbackHandle handle);

/**
 * Prevent the given callback from executing. If handle is invalid,
 * does nothing.
 */
TORCH_API void disableCallback(CallbackHandle handle);

/**
 * Allow the given callback, previously disabled with disableCallback, to
 * execute again. If handle is invalid, does nothing.
 */
TORCH_API void reenableCallback(CallbackHandle handle);

/**
 * hasGlobalCallbacks returns whether there are global callbacks
 * registered with addGlobalCallback
 */
TORCH_API bool hasGlobalCallbacks();

/**
 * clearGlobalCallbacks removes all global callbacks
 */
TORCH_API void clearGlobalCallbacks();

// For both thread local and global callbacks
TORCH_API bool hasCallbacks();
TORCH_API void clearCallbacks();

/**
 * enableRecordFunction enables RecordFunction thread locally
 */
TORCH_API void enableRecordFunction(bool enable = true);

/**
 * isRecordFunctionEnabled returns whether RecordFunction
 * is enabled thread locally
 */
TORCH_API bool isRecordFunctionEnabled();

class TORCH_API RecordFunctionGuard {
 public:
  explicit RecordFunctionGuard(bool is_enabled = true)
      : prev_value_(isRecordFunctionEnabled()) {
    enableRecordFunction(is_enabled);
  }

  virtual ~RecordFunctionGuard() {
    enableRecordFunction(prev_value_);
  }

 private:
  bool prev_value_ = false;
};

class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard {
 public:
  DisableRecordFunctionGuard() : RecordFunctionGuard(false) {}
  ~DisableRecordFunctionGuard() override = default;
};

struct TORCH_API RecordFunctionTLS {
  // Thread local vector of callbacks, holds pairs (callbacks, unique_id);
  // must be sorted in increasing handle order
  RecordFunctionCallbacks sorted_tls_callbacks_;

  bool tls_record_function_enabled_ = true;
};

TORCH_API const RecordFunctionTLS& get_record_function_tls_();

TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls);

TORCH_API void set_record_function_seed_for_testing(uint32_t seed);

} // namespace at
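
// Example (illustrative): temporarily suppressing RecordFunction callbacks on
// the current thread, e.g. to keep internal helper ops out of a profile.
//
//   {
//     at::DisableRecordFunctionGuard no_observers;
//     // RecordFunction callbacks will not fire on this thread in this scope
//   }
//   // the previous thread-local enablement is restored here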