xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/Work.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ThreadLocalState.h>
2 
3 #include <torch/csrc/distributed/c10d/Work.hpp>
4 #include <utility>
5 
6 namespace c10d {
7 
Work(int rank,OpType opType,const char * profilingTitle,const std::optional<std::vector<at::Tensor>> & inputTensors)8 Work::Work(
9     int rank,
10     OpType opType,
11     const char* profilingTitle,
12     const std::optional<std::vector<at::Tensor>>& inputTensors)
13     : rank_(rank), opType_(opType) {
14   if (profilingTitle != nullptr) {
15     auto recordingFunction =
16         std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
17     if (recordingFunction->isActive()) {
18       // Work events follow a future like pattern and can potentially be marked
19       // as complete by different threads, so explicitly set as async event.
20       recordingFunction->_setAsync();
21       // Passing input tensor to recordFunction allows for shape information in
22       // profiling output.
23       std::vector<c10::IValue> inputs;
24       if (inputTensors) {
25         inputs.reserve(inputTensors->size());
26         for (const auto& tensor : *inputTensors) {
27           inputs.emplace_back(tensor);
28         }
29       }
30       recordingFunction->before(
31           profilingTitle,
32           c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
33       std::function<void()> end_handler = [recordingFunction]() {
34         recordingFunction->end();
35       };
36       recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler);
37     }
38   }
39 }
40 
retrieveOpType() const41 OpType Work::retrieveOpType() const {
42   return opType_;
43 }
44 
// Defaulted destructor, defined out of line in this translation unit.
Work::~Work() = default;
46 
isCompleted()47 bool Work::isCompleted() {
48   std::lock_guard<std::mutex> lock(mutex_);
49   return completed_;
50 }
51 
isSuccess() const52 bool Work::isSuccess() const {
53   std::lock_guard<std::mutex> lock(mutex_);
54   return !exception_;
55 }
56 
exception() const57 std::exception_ptr Work::exception() const {
58   std::lock_guard<std::mutex> lock(mutex_);
59   return exception_;
60 }
61 
sourceRank() const62 int Work::sourceRank() const {
63   TORCH_CHECK(
64       false,
65       "sourceRank() may only be called on work objects "
66       "that correspond to a recv or recv-from-any call.");
67 }
68 
result()69 std::vector<at::Tensor> Work::result() {
70   TORCH_CHECK(false, "result() not implemented.");
71 }
72 
synchronize()73 void Work::synchronize() {}
74 
wait(std::chrono::milliseconds timeout)75 bool Work::wait(std::chrono::milliseconds timeout) {
76   std::unique_lock<std::mutex> lock(mutex_);
77   if (timeout == kNoTimeout) {
78     // This waits without a timeout.
79     cv_.wait(lock, [&] { return completed_; });
80   } else {
81     // Waits for the user-provided timeout.
82     cv_.wait_for(lock, timeout, [&] { return completed_; });
83     if (!completed_) {
84       // Throw exception if the wait operation timed out and the work was not
85       // completed.
86       TORCH_CHECK(false, "Operation timed out!");
87     }
88   }
89   if (exception_) {
90     std::rethrow_exception(exception_);
91   }
92   synchronize();
93   // Always return true, because abort API is not implemented.
94   return true;
95 }
96 
abort()97 void Work::abort() {
98   TORCH_CHECK(false, "Work::abort not implemented.");
99 }
100 
getFuture()101 c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
102   TORCH_CHECK(false, "Work::getFuture not implemented.")
103 }
104 
finish(std::exception_ptr exception)105 void Work::finish(std::exception_ptr exception) {
106   std::unique_lock<std::mutex> lock(mutex_);
107   completed_ = true;
108   exception_ = std::move(exception);
109   if (recordFunctionEndCallback_) {
110     recordFunctionEndCallback_();
111     recordFunctionEndCallback_ = nullptr;
112   }
113   lock.unlock();
114   cv_.notify_all();
115 }
116 
finishAndThrow(std::exception_ptr exception)117 void Work::finishAndThrow(std::exception_ptr exception) {
118   std::unique_lock<std::mutex> lock(mutex_);
119   completed_ = true;
120   exception_ = std::move(exception);
121   if (recordFunctionEndCallback_) {
122     recordFunctionEndCallback_();
123     recordFunctionEndCallback_ = nullptr;
124   }
125   if (exception_) {
126     std::rethrow_exception(exception_);
127   }
128 }
129 
getDuration() const130 float Work::getDuration() const {
131   TORCH_CHECK(false, "This Backend doesn't support getDuration.");
132 }
133 
getSequencenumber() const134 uint64_t Work::getSequencenumber() const {
135   TORCH_CHECK(false, "This Backend doesn't support getSequencenumber.");
136 }
137 
138 class FutureWrappingWork : public Work {
139  public:
FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)140   FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)
141       : Work(), _fut(std::move(fut)) {}
142 
143   ~FutureWrappingWork() override = default;
144 
isCompleted()145   bool isCompleted() override {
146     return _fut->completed();
147   }
148 
isSuccess() const149   bool isSuccess() const override {
150     return _fut->hasValue();
151   }
152 
exception() const153   std::exception_ptr exception() const override {
154     return _fut->exception_ptr();
155   }
156 
sourceRank() const157   int sourceRank() const override {
158     TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented");
159   }
160 
result()161   std::vector<at::Tensor> result() override {
162     return _fut->value().toPyObjectHolder()->extractTensors();
163   }
164 
wait(std::chrono::milliseconds timeout)165   bool wait(std::chrono::milliseconds timeout) override {
166     // FIXME
167     TORCH_CHECK(
168         timeout == kNoTimeout,
169         "FutureWrappingWork::wait() with finite timeout not implemented");
170     _fut->wait();
171     return true;
172   }
173 
abort()174   void abort() override {
175     TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented");
176   }
177 
getFuture()178   c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
179     return _fut;
180   }
181 
182  private:
183   c10::intrusive_ptr<c10::ivalue::Future> _fut;
184 };
185 
// Wraps `future` in a FutureWrappingWork and returns it as a generic Work
// handle.
c10::intrusive_ptr<Work> Work::create_from_future(
    const c10::intrusive_ptr<c10::ivalue::Future>& future) {
  c10::intrusive_ptr<Work> wrapped =
      c10::make_intrusive<FutureWrappingWork>(future);
  return wrapped;
}
190 
191 } // namespace c10d
192