// xref: /aosp_15_r20/external/pytorch/c10/util/WaitCounter.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <chrono>
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include <c10/macros/Macros.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/SmallVector.h>
11*da0073e9SAndroid Build Coastguard Worker 
namespace c10::monitor {
namespace detail {
class WaitCounterImpl;
15*da0073e9SAndroid Build Coastguard Worker 
// Abstract interface for one wait-counter backend instance.
//
// Both hooks are noexcept and receive a steady-clock timestamp chosen by the
// caller.
// NOTE(review): the code that drives this interface is not visible in this
// header; presumably the value returned by start() is handed back as `ctx` to
// the matching stop() call -- confirm against the WaitCounterHandle
// implementation.
class WaitCounterBackendIf {
 public:
  virtual ~WaitCounterBackendIf() = default;

  // `now`: steady-clock timestamp supplied by the caller.
  // Returns an integer context value chosen by the backend.
  virtual intptr_t start(
      std::chrono::steady_clock::time_point now) noexcept = 0;

  // `now`: steady-clock timestamp supplied by the caller.
  // `ctx`: integer context value (see the class note above).
  virtual void stop(
      std::chrono::steady_clock::time_point now,
      intptr_t ctx) noexcept = 0;
};
26*da0073e9SAndroid Build Coastguard Worker 
27*da0073e9SAndroid Build Coastguard Worker class WaitCounterBackendFactoryIf {
28*da0073e9SAndroid Build Coastguard Worker  public:
29*da0073e9SAndroid Build Coastguard Worker   virtual ~WaitCounterBackendFactoryIf() = default;
30*da0073e9SAndroid Build Coastguard Worker 
31*da0073e9SAndroid Build Coastguard Worker   // May return nullptr.
32*da0073e9SAndroid Build Coastguard Worker   // In this case the counter will be ignored by the given backend.
33*da0073e9SAndroid Build Coastguard Worker   virtual std::unique_ptr<WaitCounterBackendIf> create(
34*da0073e9SAndroid Build Coastguard Worker       std::string_view key) noexcept = 0;
35*da0073e9SAndroid Build Coastguard Worker };
36*da0073e9SAndroid Build Coastguard Worker 
37*da0073e9SAndroid Build Coastguard Worker C10_API void registerWaitCounterBackend(
38*da0073e9SAndroid Build Coastguard Worker     std::unique_ptr<WaitCounterBackendFactoryIf>);
} // namespace detail
40*da0073e9SAndroid Build Coastguard Worker 
41*da0073e9SAndroid Build Coastguard Worker // A handle to a wait counter.
42*da0073e9SAndroid Build Coastguard Worker class C10_API WaitCounterHandle {
43*da0073e9SAndroid Build Coastguard Worker  public:
44*da0073e9SAndroid Build Coastguard Worker   explicit WaitCounterHandle(std::string_view key);
45*da0073e9SAndroid Build Coastguard Worker 
46*da0073e9SAndroid Build Coastguard Worker   class WaitGuard {
47*da0073e9SAndroid Build Coastguard Worker    public:
WaitGuard(WaitGuard && other)48*da0073e9SAndroid Build Coastguard Worker     WaitGuard(WaitGuard&& other) noexcept
49*da0073e9SAndroid Build Coastguard Worker         : handle_{std::exchange(other.handle_, {})},
50*da0073e9SAndroid Build Coastguard Worker           ctxs_{std::move(other.ctxs_)} {}
51*da0073e9SAndroid Build Coastguard Worker     WaitGuard(const WaitGuard&) = delete;
52*da0073e9SAndroid Build Coastguard Worker     WaitGuard& operator=(const WaitGuard&) = delete;
53*da0073e9SAndroid Build Coastguard Worker     WaitGuard& operator=(WaitGuard&&) = delete;
54*da0073e9SAndroid Build Coastguard Worker 
~WaitGuard()55*da0073e9SAndroid Build Coastguard Worker     ~WaitGuard() {
56*da0073e9SAndroid Build Coastguard Worker       stop();
57*da0073e9SAndroid Build Coastguard Worker     }
58*da0073e9SAndroid Build Coastguard Worker 
stop()59*da0073e9SAndroid Build Coastguard Worker     void stop() {
60*da0073e9SAndroid Build Coastguard Worker       if (auto handle = std::exchange(handle_, nullptr)) {
61*da0073e9SAndroid Build Coastguard Worker         handle->stop(std::move(ctxs_));
62*da0073e9SAndroid Build Coastguard Worker       }
63*da0073e9SAndroid Build Coastguard Worker     }
64*da0073e9SAndroid Build Coastguard Worker 
65*da0073e9SAndroid Build Coastguard Worker    private:
WaitGuard(WaitCounterHandle & handle,SmallVector<intptr_t> && ctxs)66*da0073e9SAndroid Build Coastguard Worker     WaitGuard(WaitCounterHandle& handle, SmallVector<intptr_t>&& ctxs)
67*da0073e9SAndroid Build Coastguard Worker         : handle_{&handle}, ctxs_{std::move(ctxs)} {}
68*da0073e9SAndroid Build Coastguard Worker 
69*da0073e9SAndroid Build Coastguard Worker     friend class WaitCounterHandle;
70*da0073e9SAndroid Build Coastguard Worker 
71*da0073e9SAndroid Build Coastguard Worker     WaitCounterHandle* handle_;
72*da0073e9SAndroid Build Coastguard Worker     SmallVector<intptr_t> ctxs_;
73*da0073e9SAndroid Build Coastguard Worker   };
74*da0073e9SAndroid Build Coastguard Worker 
75*da0073e9SAndroid Build Coastguard Worker   // Starts a waiter
76*da0073e9SAndroid Build Coastguard Worker   WaitGuard start();
77*da0073e9SAndroid Build Coastguard Worker 
78*da0073e9SAndroid Build Coastguard Worker  private:
79*da0073e9SAndroid Build Coastguard Worker   // Stops the waiter. Each start() call should be matched by exactly one stop()
80*da0073e9SAndroid Build Coastguard Worker   // call.
81*da0073e9SAndroid Build Coastguard Worker   void stop(SmallVector<intptr_t>&& ctxs);
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker   detail::WaitCounterImpl& impl_;
84*da0073e9SAndroid Build Coastguard Worker };
} // namespace c10::monitor
86*da0073e9SAndroid Build Coastguard Worker 
// Evaluates to a reference to a function-local static WaitCounterHandle.
// `_key` is stringized with `#`, so pass it as bare tokens rather than as a
// quoted string. The handle lives inside an immediately-invoked lambda; every
// expansion of the macro is a distinct lambda and therefore owns a distinct
// static handle, constructed the first time that expansion is evaluated.
#define STATIC_WAIT_COUNTER(_key)                           \
  []() -> ::c10::monitor::WaitCounterHandle& {              \
    static ::c10::monitor::WaitCounterHandle handle(#_key); \
    return handle;                                          \
  }()
92*da0073e9SAndroid Build Coastguard Worker 
// Declares a variable, named via C10_ANONYMOUS_VARIABLE, that holds the guard
// returned by STATIC_WAIT_COUNTER(_name).start(); the guard is destroyed when
// the enclosing scope ends. The expansion already ends with a semicolon, so a
// semicolon written after the macro call only adds an empty statement.
#define STATIC_SCOPED_WAIT_COUNTER(_name) \
  auto C10_ANONYMOUS_VARIABLE(SCOPE_GUARD) = STATIC_WAIT_COUNTER(_name).start();
95