xref: /aosp_15_r20/external/pytorch/c10/util/WaitCounter.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <chrono>
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include <c10/macros/Macros.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/SmallVector.h>
11 
12 namespace c10::monitor {
13 namespace detail {
14 class WaitCounterImpl;
15 
// Abstract interface for one wait-counter backend instance.
//
// Both hooks are pure virtual and noexcept, and both take a
// std::chrono::steady_clock time point supplied by the caller.
class WaitCounterBackendIf {
 public:
  virtual ~WaitCounterBackendIf() = default;

  // Start hook: receives the time point `now` and returns an intptr_t.
  // NOTE(review): presumably the returned value is the `ctx` later passed to
  // stop(); the pairing is not visible in this header -- confirm in the
  // implementation.
  virtual intptr_t start(
      std::chrono::steady_clock::time_point now) noexcept = 0;

  // Stop hook: receives the time point `now` and a context value `ctx`.
  virtual void stop(
      std::chrono::steady_clock::time_point now,
      intptr_t ctx) noexcept = 0;
};
26 
27 class WaitCounterBackendFactoryIf {
28  public:
29   virtual ~WaitCounterBackendFactoryIf() = default;
30 
31   // May return nullptr.
32   // In this case the counter will be ignored by the given backend.
33   virtual std::unique_ptr<WaitCounterBackendIf> create(
34       std::string_view key) noexcept = 0;
35 };
36 
37 C10_API void registerWaitCounterBackend(
38     std::unique_ptr<WaitCounterBackendFactoryIf>);
39 } // namespace detail
40 
41 // A handle to a wait counter.
42 class C10_API WaitCounterHandle {
43  public:
44   explicit WaitCounterHandle(std::string_view key);
45 
46   class WaitGuard {
47    public:
WaitGuard(WaitGuard && other)48     WaitGuard(WaitGuard&& other) noexcept
49         : handle_{std::exchange(other.handle_, {})},
50           ctxs_{std::move(other.ctxs_)} {}
51     WaitGuard(const WaitGuard&) = delete;
52     WaitGuard& operator=(const WaitGuard&) = delete;
53     WaitGuard& operator=(WaitGuard&&) = delete;
54 
~WaitGuard()55     ~WaitGuard() {
56       stop();
57     }
58 
stop()59     void stop() {
60       if (auto handle = std::exchange(handle_, nullptr)) {
61         handle->stop(std::move(ctxs_));
62       }
63     }
64 
65    private:
WaitGuard(WaitCounterHandle & handle,SmallVector<intptr_t> && ctxs)66     WaitGuard(WaitCounterHandle& handle, SmallVector<intptr_t>&& ctxs)
67         : handle_{&handle}, ctxs_{std::move(ctxs)} {}
68 
69     friend class WaitCounterHandle;
70 
71     WaitCounterHandle* handle_;
72     SmallVector<intptr_t> ctxs_;
73   };
74 
75   // Starts a waiter
76   WaitGuard start();
77 
78  private:
79   // Stops the waiter. Each start() call should be matched by exactly one stop()
80   // call.
81   void stop(SmallVector<intptr_t>&& ctxs);
82 
83   detail::WaitCounterImpl& impl_;
84 };
85 } // namespace c10::monitor
86 
// Evaluates to a reference to a WaitCounterHandle held as a function-local
// static inside an immediately-invoked lambda. The handle's key is the
// stringified macro argument (#_key). Every expansion site is its own lambda
// and therefore owns its own static handle.
#define STATIC_WAIT_COUNTER(_key)                            \
  ([]() -> ::c10::monitor::WaitCounterHandle& {              \
    static ::c10::monitor::WaitCounterHandle counter(#_key); \
    return counter;                                          \
  }())
92 
// Declares a local variable, named through C10_ANONYMOUS_VARIABLE(SCOPE_GUARD),
// that holds the WaitGuard returned by STATIC_WAIT_COUNTER(_name).start().
// The guard is destroyed at the end of the enclosing scope. The expansion
// already ends with a semicolon.
#define STATIC_SCOPED_WAIT_COUNTER(_name)    \
  auto C10_ANONYMOUS_VARIABLE(SCOPE_GUARD) = \
      STATIC_WAIT_COUNTER(_name).start();
95