#include <ATen/ThreadLocalState.h>

#include <torch/csrc/distributed/c10d/Work.hpp>
#include <utility>

namespace c10d {

Work::Work(
    int rank,
    OpType opType,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    : rank_(rank), opType_(opType) {
  if (profilingTitle != nullptr) {
    auto recordingFunction =
        std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
    if (recordingFunction->isActive()) {
      // Work events follow a future-like pattern and can potentially be marked
      // as complete by different threads, so explicitly set this as an async
      // event.
      recordingFunction->_setAsync();
      // Passing the input tensors to the RecordFunction allows for shape
      // information in the profiling output.
      std::vector<c10::IValue> inputs;
      if (inputTensors) {
        inputs.reserve(inputTensors->size());
        for (const auto& tensor : *inputTensors) {
          inputs.emplace_back(tensor);
        }
      }
      recordingFunction->before(
          profilingTitle,
          c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
      std::function<void()> end_handler = [recordingFunction]() {
        recordingFunction->end();
      };
      recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler);
    }
  }
}
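
// Illustrative sketch (not part of this translation unit): a backend-specific
// subclass would typically reuse this constructor for profiling setup and then
// drive the base-class state from its completion path via the protected
// finish() helper. The class name and markCompleted() below are hypothetical;
// real backends define richer Work subclasses.
//
//   class MyBackendWork : public Work {
//    public:
//     using Work::Work;
//     // Called from the backend's completion thread once the operation is
//     // done; wakes up any thread blocked in Work::wait().
//     void markCompleted() {
//       finish(/*exception=*/nullptr);
//     }
//   };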

OpType Work::retrieveOpType() const {
  return opType_;
}

Work::~Work() = default;

bool Work::isCompleted() {
  std::lock_guard<std::mutex> lock(mutex_);
  return completed_;
}

bool Work::isSuccess() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return !exception_;
}

std::exception_ptr Work::exception() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return exception_;
}

int Work::sourceRank() const {
  TORCH_CHECK(
      false,
      "sourceRank() may only be called on work objects "
      "that correspond to a recv or recv-from-any call.");
}

std::vector<at::Tensor> Work::result() {
  TORCH_CHECK(false, "result() not implemented.");
}

void Work::synchronize() {}

bool Work::wait(std::chrono::milliseconds timeout) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (timeout == kNoTimeout) {
    // This waits without a timeout.
    cv_.wait(lock, [&] { return completed_; });
  } else {
    // Waits for the user-provided timeout.
    cv_.wait_for(lock, timeout, [&] { return completed_; });
    if (!completed_) {
      // Throw an exception if the wait timed out and the work was not
      // completed.
      TORCH_CHECK(false, "Operation timed out!");
    }
  }
  if (exception_) {
    std::rethrow_exception(exception_);
  }
  synchronize();
  // Always return true, because the abort API is not implemented.
  return true;
}
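
// Illustrative usage (a sketch, assuming a process group `pg` and a tensor
// list `tensors` exist in the caller's context):
//
//   c10::intrusive_ptr<Work> work = pg->allreduce(tensors);
//   work->wait();  // blocks until completion and rethrows any stored error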

void Work::abort() {
  TORCH_CHECK(false, "Work::abort not implemented.");
}

c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
  TORCH_CHECK(false, "Work::getFuture not implemented.");
}

void Work::finish(std::exception_ptr exception) {
  std::unique_lock<std::mutex> lock(mutex_);
  completed_ = true;
  exception_ = std::move(exception);
  if (recordFunctionEndCallback_) {
    recordFunctionEndCallback_();
    recordFunctionEndCallback_ = nullptr;
  }
  lock.unlock();
  cv_.notify_all();
}

void Work::finishAndThrow(std::exception_ptr exception) {
  std::unique_lock<std::mutex> lock(mutex_);
  completed_ = true;
  exception_ = std::move(exception);
  if (recordFunctionEndCallback_) {
    recordFunctionEndCallback_();
    recordFunctionEndCallback_ = nullptr;
  }
  if (exception_) {
    std::rethrow_exception(exception_);
  }
}
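
// Illustrative error path (a sketch): from inside a Work subclass, a backend
// that catches a failure on its worker thread can record the exception so
// that a later wait() rethrows it on the calling thread. runCollective() is a
// hypothetical backend call.
//
//   try {
//     runCollective();
//   } catch (...) {
//     finish(std::current_exception());
//   }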

float Work::getDuration() const {
  TORCH_CHECK(false, "This Backend doesn't support getDuration.");
}

uint64_t Work::getSequencenumber() const {
  TORCH_CHECK(false, "This Backend doesn't support getSequencenumber.");
}

// Adapter that exposes a c10::ivalue::Future through the Work interface, so
// that future-based code paths can hand out Work handles.
class FutureWrappingWork : public Work {
 public:
  FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)
      : Work(), _fut(std::move(fut)) {}

  ~FutureWrappingWork() override = default;

  bool isCompleted() override {
    return _fut->completed();
  }

  bool isSuccess() const override {
    return _fut->hasValue();
  }

  std::exception_ptr exception() const override {
    return _fut->exception_ptr();
  }

  int sourceRank() const override {
    TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented");
  }

  std::vector<at::Tensor> result() override {
    return _fut->value().toPyObjectHolder()->extractTensors();
  }

  bool wait(std::chrono::milliseconds timeout) override {
    // FIXME
    TORCH_CHECK(
        timeout == kNoTimeout,
        "FutureWrappingWork::wait() with finite timeout not implemented");
    _fut->wait();
    return true;
  }

  void abort() override {
    TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented");
  }

  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
    return _fut;
  }

 private:
  c10::intrusive_ptr<c10::ivalue::Future> _fut;
};

c10::intrusive_ptr<Work> Work::create_from_future(
    const c10::intrusive_ptr<c10::ivalue::Future>& future) {
  return c10::make_intrusive<FutureWrappingWork>(future);
}
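
// Illustrative usage (a sketch): adapt an already-existing Future into a Work
// handle for code paths that only know how to consume Work objects.
// getFutureFromSomewhere() is a placeholder for wherever the future is
// produced.
//
//   c10::intrusive_ptr<c10::ivalue::Future> fut = getFutureFromSomewhere();
//   c10::intrusive_ptr<Work> work = Work::create_from_future(fut);
//   work->wait();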

} // namespace c10d