xref: /aosp_15_r20/external/pytorch/c10/util/ThreadLocalDebugInfo.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Export.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <cstdint>
6*da0073e9SAndroid Build Coastguard Worker #include <memory>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker namespace c10 {
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker enum class C10_API_ENUM DebugInfoKind : uint8_t {
11*da0073e9SAndroid Build Coastguard Worker   PRODUCER_INFO = 0,
12*da0073e9SAndroid Build Coastguard Worker   MOBILE_RUNTIME_INFO,
13*da0073e9SAndroid Build Coastguard Worker   PROFILER_STATE,
14*da0073e9SAndroid Build Coastguard Worker   INFERENCE_CONTEXT, // for inference usage
15*da0073e9SAndroid Build Coastguard Worker   PARAM_COMMS_INFO,
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker   TEST_INFO, // used only in tests
18*da0073e9SAndroid Build Coastguard Worker   TEST_INFO_2, // used only in tests
19*da0073e9SAndroid Build Coastguard Worker };
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker class C10_API DebugInfoBase {
22*da0073e9SAndroid Build Coastguard Worker  public:
23*da0073e9SAndroid Build Coastguard Worker   DebugInfoBase() = default;
24*da0073e9SAndroid Build Coastguard Worker   virtual ~DebugInfoBase() = default;
25*da0073e9SAndroid Build Coastguard Worker };
26*da0073e9SAndroid Build Coastguard Worker 
27*da0073e9SAndroid Build Coastguard Worker // Thread local debug information is propagated across the forward
28*da0073e9SAndroid Build Coastguard Worker // (including async fork tasks) and backward passes and is supposed
29*da0073e9SAndroid Build Coastguard Worker // to be utilized by the user's code to pass extra information from
30*da0073e9SAndroid Build Coastguard Worker // the higher layers (e.g. model id) down to the lower levels
31*da0073e9SAndroid Build Coastguard Worker // (e.g. to the operator observers used for debugging, logging,
32*da0073e9SAndroid Build Coastguard Worker // profiling, etc)
33*da0073e9SAndroid Build Coastguard Worker class C10_API ThreadLocalDebugInfo {
34*da0073e9SAndroid Build Coastguard Worker  public:
35*da0073e9SAndroid Build Coastguard Worker   static DebugInfoBase* get(DebugInfoKind kind);
36*da0073e9SAndroid Build Coastguard Worker 
37*da0073e9SAndroid Build Coastguard Worker   // Get current ThreadLocalDebugInfo
38*da0073e9SAndroid Build Coastguard Worker   static std::shared_ptr<ThreadLocalDebugInfo> current();
39*da0073e9SAndroid Build Coastguard Worker 
40*da0073e9SAndroid Build Coastguard Worker   // Internal, use DebugInfoGuard/ThreadLocalStateGuard
41*da0073e9SAndroid Build Coastguard Worker   static void _forceCurrentDebugInfo(
42*da0073e9SAndroid Build Coastguard Worker       std::shared_ptr<ThreadLocalDebugInfo> info);
43*da0073e9SAndroid Build Coastguard Worker 
44*da0073e9SAndroid Build Coastguard Worker   // Push debug info struct of a given kind
45*da0073e9SAndroid Build Coastguard Worker   static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
46*da0073e9SAndroid Build Coastguard Worker   // Pop debug info, throws in case the last pushed
47*da0073e9SAndroid Build Coastguard Worker   // debug info is not of a given kind
48*da0073e9SAndroid Build Coastguard Worker   static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
49*da0073e9SAndroid Build Coastguard Worker   // Peek debug info, throws in case the last pushed debug info is not of the
50*da0073e9SAndroid Build Coastguard Worker   // given kind
51*da0073e9SAndroid Build Coastguard Worker   static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
52*da0073e9SAndroid Build Coastguard Worker 
53*da0073e9SAndroid Build Coastguard Worker  private:
54*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<DebugInfoBase> info_;
55*da0073e9SAndroid Build Coastguard Worker   DebugInfoKind kind_;
56*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
57*da0073e9SAndroid Build Coastguard Worker 
58*da0073e9SAndroid Build Coastguard Worker   friend class DebugInfoGuard;
59*da0073e9SAndroid Build Coastguard Worker };
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker // DebugInfoGuard is used to set debug information,
62*da0073e9SAndroid Build Coastguard Worker // ThreadLocalDebugInfo is semantically immutable, the values are set
63*da0073e9SAndroid Build Coastguard Worker // through the scope-based guard object.
64*da0073e9SAndroid Build Coastguard Worker // Nested DebugInfoGuard adds/overrides existing values in the scope,
65*da0073e9SAndroid Build Coastguard Worker // restoring the original values after exiting the scope.
66*da0073e9SAndroid Build Coastguard Worker // Users can access the values through the ThreadLocalDebugInfo::get() call;
67*da0073e9SAndroid Build Coastguard Worker class C10_API DebugInfoGuard {
68*da0073e9SAndroid Build Coastguard Worker  public:
69*da0073e9SAndroid Build Coastguard Worker   DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
70*da0073e9SAndroid Build Coastguard Worker 
71*da0073e9SAndroid Build Coastguard Worker   explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);
72*da0073e9SAndroid Build Coastguard Worker 
73*da0073e9SAndroid Build Coastguard Worker   ~DebugInfoGuard();
74*da0073e9SAndroid Build Coastguard Worker 
75*da0073e9SAndroid Build Coastguard Worker   DebugInfoGuard(const DebugInfoGuard&) = delete;
76*da0073e9SAndroid Build Coastguard Worker   DebugInfoGuard(DebugInfoGuard&&) = delete;
77*da0073e9SAndroid Build Coastguard Worker 
78*da0073e9SAndroid Build Coastguard Worker  private:
79*da0073e9SAndroid Build Coastguard Worker   bool active_ = false;
80*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
81*da0073e9SAndroid Build Coastguard Worker };
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker } // namespace c10
84