1 #include <torch/csrc/lazy/core/multi_wait.h> 2 3 #include <chrono> 4 #include <exception> 5 #include <stdexcept> 6 7 namespace torch { 8 namespace lazy { 9 Done()10void MultiWait::Done() { 11 bool notify = false; 12 { 13 std::lock_guard<std::mutex> lock(mutex_); 14 completed_count_ += 1; 15 notify = completed_count_ == count_; 16 } 17 if (notify) { 18 cv_.notify_all(); 19 } 20 } 21 Wait()22void MultiWait::Wait() { 23 std::unique_lock<std::mutex> lock(mutex_); 24 cv_.wait(lock, [this] { return completed_count_ >= count_; }); 25 if (exptr_ != nullptr) { 26 std::rethrow_exception(exptr_); 27 } 28 } 29 Wait(double wait_seconds)30void MultiWait::Wait(double wait_seconds) { 31 std::unique_lock<std::mutex> lock(mutex_); 32 if (!cv_.wait_for(lock, std::chrono::duration<double>(wait_seconds), [this] { 33 return completed_count_ >= count_; 34 })) { 35 throw std::runtime_error("Timeout"); 36 } 37 if (exptr_ != nullptr) { 38 std::rethrow_exception(exptr_); 39 } 40 } 41 Reset(size_t count)42void MultiWait::Reset(size_t count) { 43 std::lock_guard<std::mutex> lock(mutex_); 44 count_ = count; 45 completed_count_ = 0; 46 exptr_ = nullptr; 47 } 48 Completer(std::function<void ()> func)49std::function<void()> MultiWait::Completer(std::function<void()> func) { 50 auto completer = [this, func = std::move(func)]() { Complete(func); }; 51 return completer; 52 } 53 Completer(std::shared_ptr<MultiWait> mwait,std::function<void ()> func)54std::function<void()> MultiWait::Completer( 55 std::shared_ptr<MultiWait> mwait, 56 std::function<void()> func) { 57 auto completer = [mwait = std::move(mwait), func = std::move(func)]() { 58 mwait->Complete(func); 59 }; 60 return completer; 61 } 62 Complete(const std::function<void ()> & func)63void MultiWait::Complete(const std::function<void()>& func) { 64 try { 65 func(); 66 } catch (...) { 67 std::lock_guard<std::mutex> lock(mutex_); 68 exptr_ = std::current_exception(); 69 } 70 Done(); 71 } 72 73 } // namespace lazy 74 } // namespace torch 75