xref: /aosp_15_r20/external/pytorch/c10/util/CallOnce.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <atomic>
#include <functional>
#include <mutex>
#include <utility>

#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
9 
10 namespace c10 {
11 
// Custom c10 call_once implementation that avoids the deadlock in
// std::call_once. It is a simplified version of folly's implementation and
// likely has a much higher memory footprint.
/// Runs `f` with `args` at most once successfully per `flag`; a c10
/// counterpart to std::call_once.
///
/// The check itself is delegated to the flag: `flag.test_once()` decides
/// whether work is still pending, and `flag.call_once_slow(...)` performs it.
/// For c10::once_flag, the slow path holds a mutex, re-checks the flag, and
/// only marks the flag done after `f` returns normally.
///
/// @param flag  Flag object exposing test_once() / call_once_slow()
///              (c10::once_flag in this header).
/// @param f     Callable to run on the first (successful) call.
/// @param args  Arguments perfectly forwarded to `f`.
template <typename Flag, typename F, typename... Args>
inline void call_once(Flag& flag, F&& f, Args&&... args) {
  // Expected case after initialization: the flag is already set, so the whole
  // call is a single flag check. Only fall into the slow path otherwise.
  if (!C10_LIKELY(flag.test_once())) {
    flag.call_once_slow(std::forward<F>(f), std::forward<Args>(args)...);
  }
}
21 }
22 
/// Flag type used with c10::call_once, playing the role of std::once_flag.
///
/// State is a single atomic boolean (`init_`) plus a mutex (`mutex_`) that
/// serializes the slow path. The flag is neither copyable nor assignable.
/// All operations other than construction are private and reachable only
/// through the friend c10::call_once.
class once_flag {
 public:
#ifndef _WIN32
  // running into build error on MSVC. Can't seem to get a repro locally so I'm
  // just avoiding constexpr
  //
  //   C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error:
  //   defaulted default constructor cannot be constexpr because the
  //   corresponding implicitly declared default constructor would not be
  //   constexpr 1 error detected in the compilation of
  //   "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu".
  constexpr
#endif
      once_flag() noexcept = default;
  once_flag(const once_flag&) = delete;
  once_flag& operator=(const once_flag&) = delete;

 private:
  template <typename Flag, typename F, typename... Args>
  friend void call_once(Flag& flag, F&& f, Args&&... args);

  /// Slow path: runs `f(args...)` under `mutex_` unless another caller has
  /// already completed it.
  ///
  /// If `f` throws, `init_` is left false and `guard` releases `mutex_`
  /// during unwinding, so a subsequent call will try again.
  template <typename F, typename... Args>
  void call_once_slow(F&& f, Args&&... args) {
    std::lock_guard<std::mutex> guard(mutex_);
    // Relaxed is enough here: the only store of `true` (below) happens while
    // mutex_ is held, so acquiring mutex_ already makes that store visible.
    if (init_.load(std::memory_order_relaxed)) {
      return;
    }
    std::invoke(std::forward<F>(f), std::forward<Args>(args)...);
    // Release pairs with the acquire load in test_once(): a thread that sees
    // `true` on the fast path also sees everything `f` wrote.
    init_.store(true, std::memory_order_release);
  }

  /// Fast-path check: true once call_once_slow has completed successfully.
  bool test_once() const noexcept {
    return init_.load(std::memory_order_acquire);
  }

  /// Clears the flag so a later call_once would run again.
  /// NOTE(review): private, and the only friend (call_once) never calls it,
  /// so it appears unused. It also stores without taking mutex_, so calling
  /// it concurrently with call_once_slow would not be serialized. Confirm
  /// before relying on it.
  void reset_once() noexcept {
    init_.store(false, std::memory_order_release);
  }

 private:
  std::mutex mutex_;
  std::atomic<bool> init_{false};
};
66 
67 } // namespace c10
68