// torch/csrc/lazy/core/thread_pool.cpp
#include <torch/csrc/lazy/core/thread_pool.h>

#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <c10/util/thread_name.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/metrics.h>

#include <condition_variable>
#include <deque>
#include <exception>
#include <mutex>
#include <thread>

namespace torch {
namespace lazy {
namespace {

class ThreadPool {
 public:
  explicit ThreadPool(size_t num_threads) {
    threads_.reserve(num_threads);
    for (const auto i : c10::irange(num_threads)) {
      (void)i; // Suppress unused variable warning
      threads_.emplace_back([this]() {
        c10::setThreadName("pt_thread_pool");
        Worker();
      });
    }
  }

  ~ThreadPool() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      exiting_ = true;
      cv_.notify_all();
    }
    for (auto& thread : threads_) {
      thread.join();
    }
  }

  void Schedule(std::function<void()> closure) {
    // If more work is scheduled than there are waiting worker threads, run the
    // closure on a dedicated thread instead of queuing it. This prevents
    // tricky sizing deadlocks where an undersized pool is fully occupied by
    // closures that synchronously wait on other work queued to the same pool.
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (work_.size() < waiting_) {
        work_.emplace_back(std::move(closure));
        lock.unlock();
        cv_.notify_one();
        return;
      }
    }
    ScheduleOnThread(std::move(closure));
  }
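
  // Illustration (hypothetical, not part of the original file): the deadlock
  // the fallback above avoids. With a pool of size 1, the outer closure would
  // occupy the only worker while it waits on the inner one; because the inner
  // closure gets a fresh thread instead of a queue slot, it can still run:
  //
  //   pool.Schedule([&pool] {
  //     std::promise<void> done;
  //     pool.Schedule([&done] { done.set_value(); });  // spawns a new thread
  //     done.get_future().wait();  // would deadlock if the inner were queued
  //   });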

 private:
  void Worker() {
    while (true) {
      std::function<void()> closure = GetWork();
      if (closure == nullptr) {
        // An empty closure is the shutdown signal from GetWork().
        break;
      }
      try {
        closure();
      } catch (const std::exception& ex) {
        TORCH_LAZY_COUNTER("ThreadPoolException", 1);
        LOG(ERROR) << "Exception from running thread pool closure: "
                   << ex.what();
      }
    }
  }

  void ScheduleOnThread(std::function<void()> closure) {
    std::thread thread(std::move(closure));
    thread.detach();
  }

  std::function<void()> GetWork() {
    std::unique_lock<std::mutex> lock(mutex_);
    ++waiting_;
    cv_.wait(lock, [this] { return exiting_ || !work_.empty(); });
    --waiting_;
    if (work_.empty()) {
      // Only reachable when exiting_ is set: return an empty closure so the
      // worker breaks out of its loop.
      return nullptr;
    }
    std::function<void()> closure(std::move(work_.front()));
    work_.pop_front();
    return closure;
  }

  std::vector<std::thread> threads_;
  std::mutex mutex_;
  std::condition_variable cv_;
  bool exiting_ = false;
  std::deque<std::function<void()>> work_;
  size_t waiting_ = 0;
};
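
// A minimal usage sketch (illustrative only; ThreadPool lives in this
// anonymous namespace and is only reachable via the Schedule* helpers below):
//
//   {
//     ThreadPool pool(/*num_threads=*/2);
//     pool.Schedule([] { /* I/O-bound work */ });
//   }  // ~ThreadPool sets exiting_, wakes every worker, and joins them all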

ThreadPool* GetIoThreadPool() {
  static ThreadPool* pool =
      new ThreadPool(FLAGS_torch_lazy_io_thread_pool_size);
  return pool;
}
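
// Note: the singleton above is never deleted; leaking it avoids tearing down
// worker threads during static destruction at process exit (presumably the
// intent, as this is a common C++ singleton idiom).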

} // namespace

class Completion::Data {
 public:
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return completed_; });
    if (exptr_ != nullptr) {
      // Rethrow on the waiting thread any exception captured by the closure.
      std::rethrow_exception(exptr_);
    }
  }

  static std::function<void()> GetCompleter(
      const std::shared_ptr<Data>& data,
      std::function<void()> closure) {
    // Wrap the closure so that any exception it throws is captured and handed
    // to Complete() together with the completion signal.
    auto closure_wrapper = [closure = std::move(closure), data]() {
      std::exception_ptr exptr;
      try {
        closure();
      } catch (...) {
        exptr = std::current_exception();
      }
      data->Complete(exptr);
    };
    return closure_wrapper;
  }

 private:
  void Complete(std::exception_ptr exptr) {
    std::lock_guard<std::mutex> lock(mutex_);
    exptr_ = std::move(exptr);
    completed_ = true;
    cv_.notify_all();
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  bool completed_ = false;
  std::exception_ptr exptr_;
};
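
// Sketch (illustrative, not part of the original file): how the completer
// propagates an exception from the executing thread to the waiting one:
//
//   auto data = std::make_shared<Completion::Data>();
//   auto wrapped = Completion::Data::GetCompleter(
//       data, [] { throw std::runtime_error("boom"); });
//   std::thread(wrapped).detach();
//   Completion(data).Wait();  // rethrows std::runtime_error here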

Completion::Completion(std::shared_ptr<Data> data) : data_(std::move(data)) {}

void Completion::Wait() {
  data_->Wait();
}

void ScheduleIoClosure(std::function<void()> closure) {
  GetIoThreadPool()->Schedule(std::move(closure));
}

Completion ScheduleIoClosureWithCompletion(std::function<void()> closure) {
  auto data = std::make_shared<Completion::Data>();
  GetIoThreadPool()->Schedule(
      Completion::Data::GetCompleter(data, std::move(closure)));
  return Completion(std::move(data));
}
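
// Example (a minimal sketch of the public API above):
//
//   // Fire-and-forget:
//   torch::lazy::ScheduleIoClosure([] { /* async I/O */ });
//
//   // With completion tracking; Wait() blocks until the closure finishes and
//   // rethrows any exception it raised:
//   auto completion = torch::lazy::ScheduleIoClosureWithCompletion(
//       [] { /* async I/O */ });
//   completion.Wait();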

} // namespace lazy
} // namespace torch