xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/multi_wait.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/lazy/core/multi_wait.h>
2 
3 #include <chrono>
4 #include <exception>
5 #include <stdexcept>
6 
7 namespace torch {
8 namespace lazy {
9 
Done()10 void MultiWait::Done() {
11   bool notify = false;
12   {
13     std::lock_guard<std::mutex> lock(mutex_);
14     completed_count_ += 1;
15     notify = completed_count_ == count_;
16   }
17   if (notify) {
18     cv_.notify_all();
19   }
20 }
21 
Wait()22 void MultiWait::Wait() {
23   std::unique_lock<std::mutex> lock(mutex_);
24   cv_.wait(lock, [this] { return completed_count_ >= count_; });
25   if (exptr_ != nullptr) {
26     std::rethrow_exception(exptr_);
27   }
28 }
29 
Wait(double wait_seconds)30 void MultiWait::Wait(double wait_seconds) {
31   std::unique_lock<std::mutex> lock(mutex_);
32   if (!cv_.wait_for(lock, std::chrono::duration<double>(wait_seconds), [this] {
33         return completed_count_ >= count_;
34       })) {
35     throw std::runtime_error("Timeout");
36   }
37   if (exptr_ != nullptr) {
38     std::rethrow_exception(exptr_);
39   }
40 }
41 
Reset(size_t count)42 void MultiWait::Reset(size_t count) {
43   std::lock_guard<std::mutex> lock(mutex_);
44   count_ = count;
45   completed_count_ = 0;
46   exptr_ = nullptr;
47 }
48 
Completer(std::function<void ()> func)49 std::function<void()> MultiWait::Completer(std::function<void()> func) {
50   auto completer = [this, func = std::move(func)]() { Complete(func); };
51   return completer;
52 }
53 
Completer(std::shared_ptr<MultiWait> mwait,std::function<void ()> func)54 std::function<void()> MultiWait::Completer(
55     std::shared_ptr<MultiWait> mwait,
56     std::function<void()> func) {
57   auto completer = [mwait = std::move(mwait), func = std::move(func)]() {
58     mwait->Complete(func);
59   };
60   return completer;
61 }
62 
Complete(const std::function<void ()> & func)63 void MultiWait::Complete(const std::function<void()>& func) {
64   try {
65     func();
66   } catch (...) {
67     std::lock_guard<std::mutex> lock(mutex_);
68     exptr_ = std::current_exception();
69   }
70   Done();
71 }
72 
73 } // namespace lazy
74 } // namespace torch
75