xref: /aosp_15_r20/external/pytorch/c10/util/CallOnce.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
#include <atomic>
#include <functional>
#include <mutex>
#include <utility>

#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker namespace c10 {
11*da0073e9SAndroid Build Coastguard Worker 
// c10::call_once: a hand-rolled replacement for std::call_once, written to
// avoid a deadlock in std::call_once. It is a simplified version of folly's
// implementation and likely has a much higher memory footprint than the
// folly original.
//
// If `flag.test_once()` already reports completion, this returns without
// doing anything else. Otherwise `f` and `args` are forwarded to
// `flag.call_once_slow`, which (for c10::once_flag) serializes callers on a
// mutex, re-checks the flag, and performs the call.
template <typename Flag, typename F, typename... Args>
inline void call_once(Flag& flag, F&& f, Args&&... args) {
  // Completion is the expected steady state, so the hint stays on the
  // "already done" outcome and the slow path is the rare branch.
  if (!C10_LIKELY(flag.test_once())) {
    flag.call_once_slow(std::forward<F>(f), std::forward<Args>(args)...);
  }
}
22*da0073e9SAndroid Build Coastguard Worker 
// Flag object consumed by c10::call_once.
//
// Completion state lives in an std::atomic<bool>, so the call_once fast path
// can check it without locking. A std::mutex serializes the callers that find
// it unset. The copy constructor and copy assignment are deleted. Because they
// are user-declared, no move operations are generated either, so the flag is
// neither copyable nor movable.
class once_flag {
 public:
#ifndef _WIN32
  // constexpr is omitted on Windows because an MSVC build failed with:
  //
  //   C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error:
  //   defaulted default constructor cannot be constexpr because the
  //   corresponding implicitly declared default constructor would not be
  //   constexpr
  //   1 error detected in the compilation of
  //   "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu".
  //
  // The failure could not be reproduced locally, so constexpr is simply
  // avoided there.
  constexpr
#endif
      once_flag() noexcept = default;
  once_flag(const once_flag&) = delete;
  once_flag& operator=(const once_flag&) = delete;

 private:
  // call_once is the only friend. It drives the flag through test_once()
  // and call_once_slow().
  template <typename Flag, typename F, typename... Args>
  friend void call_once(Flag& flag, F&& f, Args&&... args);

  // Slow path: lock, re-check, and run f(args...) if no earlier caller has
  // completed it.
  //
  // The relaxed re-check is enough here. A completed run stores init_ while
  // holding mutex_, and acquiring mutex_ orders this load after that store.
  //
  // If f throws, the exception propagates, lock_guard releases mutex_, and
  // init_ stays false, so a later call will try again.
  template <typename F, typename... Args>
  void call_once_slow(F&& f, Args&&... args) {
    std::lock_guard<std::mutex> guard(mutex_);
    if (init_.load(std::memory_order_relaxed)) {
      return;
    }
    // std::invoke (C++17) has the same INVOKE semantics as the former
    // c10::guts::invoke polyfill, including support for member pointers.
    std::invoke(std::forward<F>(f), std::forward<Args>(args)...);
    // This release store pairs with the acquire load in test_once(). A thread
    // that observes `true` also observes every write f performed.
    init_.store(true, std::memory_order_release);
  }

  // Lock-free completion check used by call_once's fast path.
  bool test_once() const noexcept {
    return init_.load(std::memory_order_acquire);
  }

  // Clears the completion state. No code in this header calls it.
  void reset_once() noexcept {
    init_.store(false, std::memory_order_release);
  }

 private:
  std::mutex mutex_;
  std::atomic<bool> init_{false};
};
66*da0073e9SAndroid Build Coastguard Worker 
67*da0073e9SAndroid Build Coastguard Worker } // namespace c10
68