1 #include <torch/csrc/lazy/core/thread_pool.h> 2 3 #include <c10/util/Logging.h> 4 #include <c10/util/irange.h> 5 #include <c10/util/thread_name.h> 6 #include <torch/csrc/lazy/core/config.h> 7 #include <torch/csrc/lazy/core/metrics.h> 8 9 #include <condition_variable> 10 #include <deque> 11 #include <exception> 12 #include <mutex> 13 #include <thread> 14 15 namespace torch { 16 namespace lazy { 17 namespace { 18 19 class ThreadPool { 20 public: ThreadPool(size_t num_threads)21 explicit ThreadPool(size_t num_threads) { 22 threads_.reserve(num_threads); 23 for (const auto i : c10::irange(num_threads)) { 24 (void)i; // Suppress unused variable warning 25 threads_.emplace_back([this]() { 26 c10::setThreadName("pt_thread_pool"); 27 Worker(); 28 }); 29 } 30 } 31 ~ThreadPool()32 ~ThreadPool() { 33 { 34 std::lock_guard<std::mutex> lock(mutex_); 35 exiting_ = true; 36 cv_.notify_all(); 37 } 38 for (auto& thread : threads_) { 39 thread.join(); 40 } 41 } 42 Schedule(std::function<void ()> closure)43 void Schedule(std::function<void()> closure) { 44 // If we have more work scheduled than waiting worker threads, just schedule 45 // it on a separate thread. This prevents tricky thread-pool-size-deadlocks 46 // caused by an undersized thread pool and closures that end up doing sync 47 // waits on the pool threads. 48 { 49 std::unique_lock<std::mutex> lock(mutex_); 50 if (work_.size() < waiting_) { 51 work_.emplace_back(std::move(closure)); 52 lock.unlock(); 53 cv_.notify_one(); 54 return; 55 } 56 } 57 ScheduleOnThread(std::move(closure)); 58 } 59 60 private: Worker()61 void Worker() { 62 while (true) { 63 std::function<void()> closure = GetWork(); 64 if (closure == nullptr) { 65 break; 66 } 67 try { 68 closure(); 69 } catch (const std::exception& ex) { 70 TORCH_LAZY_COUNTER("ThreadPoolException", 1); 71 LOG(ERROR) << "Exception from running thread pool closure: " 72 << ex.what(); 73 } 74 } 75 } 76 ScheduleOnThread(std::function<void ()> closure)77 void ScheduleOnThread(std::function<void()> closure) { 78 std::thread thread(std::move(closure)); 79 thread.detach(); 80 } 81 GetWork()82 std::function<void()> GetWork() { 83 std::unique_lock<std::mutex> lock(mutex_); 84 ++waiting_; 85 cv_.wait(lock, [this] { return exiting_ || !work_.empty(); }); 86 --waiting_; 87 if (work_.empty()) { 88 return nullptr; 89 } 90 std::function<void()> closure(std::move(work_.front())); 91 work_.pop_front(); 92 return closure; 93 } 94 95 std::vector<std::thread> threads_; 96 std::mutex mutex_; 97 std::condition_variable cv_; 98 bool exiting_ = false; 99 std::deque<std::function<void()>> work_; 100 size_t waiting_ = 0; 101 }; 102 GetIoThreadPool()103ThreadPool* GetIoThreadPool() { 104 static ThreadPool* pool = 105 new ThreadPool(FLAGS_torch_lazy_io_thread_pool_size); 106 return pool; 107 } 108 109 } // namespace 110 111 class Completion::Data { 112 public: Wait()113 void Wait() { 114 std::unique_lock<std::mutex> lock(mutex_); 115 cv_.wait(lock, [this] { return completed_; }); 116 if (exptr_ != nullptr) { 117 std::rethrow_exception(exptr_); 118 } 119 } 120 GetCompleter(const std::shared_ptr<Data> & data,std::function<void ()> closure)121 static std::function<void()> GetCompleter( 122 const std::shared_ptr<Data>& data, 123 std::function<void()> closure) { 124 auto closure_wrapper = [closure = std::move(closure), data]() { 125 std::exception_ptr exptr; 126 try { 127 closure(); 128 } catch (...) { 129 exptr = std::current_exception(); 130 } 131 data->Complete(exptr); 132 }; 133 return closure_wrapper; 134 } 135 136 private: Complete(std::exception_ptr exptr)137 void Complete(std::exception_ptr exptr) { 138 std::lock_guard<std::mutex> lock(mutex_); 139 exptr_ = std::move(exptr); 140 completed_ = true; 141 cv_.notify_all(); 142 } 143 144 std::mutex mutex_; 145 std::condition_variable cv_; 146 bool completed_ = false; 147 std::exception_ptr exptr_; 148 }; 149 Completion(std::shared_ptr<Data> data)150Completion::Completion(std::shared_ptr<Data> data) : data_(std::move(data)) {} 151 Wait()152void Completion::Wait() { 153 data_->Wait(); 154 } 155 ScheduleIoClosure(std::function<void ()> closure)156void ScheduleIoClosure(std::function<void()> closure) { 157 GetIoThreadPool()->Schedule(std::move(closure)); 158 } 159 ScheduleIoClosureWithCompletion(std::function<void ()> closure)160Completion ScheduleIoClosureWithCompletion(std::function<void()> closure) { 161 auto data = std::make_shared<Completion::Data>(); 162 GetIoThreadPool()->Schedule( 163 Completion::Data::GetCompleter(data, std::move(closure))); 164 return Completion(std::move(data)); 165 } 166 167 } // namespace lazy 168 } // namespace torch 169