/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"

#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

namespace {

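// Enqueues a request on the remote task's eager service to drop the remote
// tensor identified by (op_id, output_num). The request is skipped if the
// eager context has been recreated (context id mismatch) or if the target
// task is no longer reachable.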
void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
                               uint64 context_id, uint64 op_id, int output_num,
                               bool ready) {
  if (ctx->GetContextId() != context_id) {
    // This means that this tensor was pointing to a remote device, which
    // has been changed out from under us. Simply return since there is
    // nothing we can do.
    return;
  }

  core::RefCountPtr<eager::EagerClient> eager_client;
  Status status = ctx->GetClient(remote_task, &eager_client);
  if (!status.ok()) {
    LOG_EVERY_N_SEC(INFO, 60)
        << "Unable to destroy remote tensor handle because the target "
        << remote_task << " is no longer available.";
    return;
  }

  std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
  request->set_context_id(context_id);

  auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
  handle_to_decref->set_op_id(op_id);
  handle_to_decref->set_output_num(output_num);

  VLOG(3) << "Sending request to delete " << request->DebugString();
  std::unique_ptr<EagerNode> node(
      std::make_unique<eager::DestroyTensorHandleNode>(
          std::move(request), std::move(eager_client), ready));
  auto& executor = ctx->Executor();
  if (executor.Async()) {
    Status status = executor.AddOrExecute(std::move(node));
    if (!status.ok()) {
      LOG_EVERY_N_SEC(WARNING, 60)
          << "Unable to destroy remote tensor handles. If you are "
             "running a tf.function, it usually indicates some op in "
             "the graph gets an error: "
          << status.error_message();
    }
  } else {
    // This thread may still hold tensorflow::StreamingRPCState::mu_. We need
    // to send out the destroy request in a new thread to avoid deadlock.
    auto* released_node = node.release();
    (*ctx->runner())([ctx, released_node] {
      Status status =
          ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
      if (!status.ok()) {
        LOG_EVERY_N_SEC(WARNING, 60)
            << "Unable to destroy remote tensor handles. If you are "
               "running a tf.function, it usually indicates some op in "
               "the graph gets an error: "
            << status.error_message();
      }
    });
  }
}
}  // namespace

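// Constructs handle data that does not hold a reference to the EagerContext;
// because `ctx_` is null, the destructor does not send a destroy request to
// the remote task.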
RemoteTensorHandleData::RemoteTensorHandleData(int64_t op_id, int output_num,
                                               uint64 context_view_id,
                                               bool is_ready)
    : is_ready_(is_ready),
      op_id_(op_id),
      output_num_(output_num),
      context_view_id_(context_view_id),
      ctx_(nullptr) {
  DCHECK(op_id_ >= 0 && output_num_ >= 0)
      << "Op ID and output num should be >= 0. Op ID: " << op_id
      << ", Output num: " << output_num;
}

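// Constructs non-ready handle data for a tensor produced on `remote_task`.
// Holds a reference on `ctx` so that a destroy request can be sent to the
// remote task when the handle is deleted.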
RemoteTensorHandleData::RemoteTensorHandleData(int64_t op_id, int output_num,
                                               const string& remote_task,
                                               EagerContext* ctx)
    : is_ready_(false),
      op_id_(op_id),
      output_num_(output_num),
      remote_task_(remote_task),
      context_id_(ctx->GetContextId()),
      context_view_id_(ctx->GetContextViewId()),
      ctx_(ctx) {
  DCHECK(op_id_ >= 0 && output_num_ >= 0)
      << "Op ID and output num should be >= 0. Op ID: " << op_id
      << ", Output num: " << output_num;
  ctx_->Ref();
}

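// If this handle owns a context reference, asks the remote task to release
// the tensor and then drops the reference.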
RemoteTensorHandleData::~RemoteTensorHandleData() {
  if (ctx_) {
    DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
                              output_num_, /*ready=*/true);
    ctx_->Unref();
  }
}

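// Shape, NumDims, Dim and NumElements block until the handle becomes ready
// (or poisoned) and then read the cached shape under a shared lock.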
Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
  TF_RETURN_IF_ERROR(WaitReady("Shape"));

  tf_shared_lock l(mu_);
  *shape = shape_;

  return OkStatus();
}

Status RemoteTensorHandleData::NumDims(int* num_dims) const {
  TF_RETURN_IF_ERROR(WaitReady("NumDims"));

  tf_shared_lock l(mu_);
  *num_dims = shape_.dims();

  return OkStatus();
}

Status RemoteTensorHandleData::Dim(int dim_index, int64_t* dim) const {
  TF_RETURN_IF_ERROR(WaitReady("Dim"));

  tf_shared_lock l(mu_);
  *dim = shape_.dim_size(dim_index);

  return OkStatus();
}

Status RemoteTensorHandleData::NumElements(int64_t* num_elements) const {
  TF_RETURN_IF_ERROR(WaitReady("NumElements"));

  tf_shared_lock l(mu_);
  *num_elements = shape_.num_elements();

  return OkStatus();
}

bool RemoteTensorHandleData::IsReady() const {
  tf_shared_lock l(mu_);
  return is_ready_;
}

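// Marks the handle as ready with an error status; subsequent WaitReady calls
// return `status` instead of blocking.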
void RemoteTensorHandleData::Poison(Status status) {
  mutex_lock l(mu_);
  is_poisoned_ = status;
  is_ready_ = true;
}

Status RemoteTensorHandleData::IsPoisoned() const {
  tf_shared_lock l(mu_);
  return is_poisoned_;
}

Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
  return SetShapeAndRemoteTask(shape, /*remote_task=*/"");
}

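// Records the shape received from the remote task, optionally updating the
// task name, and marks the handle as ready. May only be called on a handle
// that is not yet ready.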
Status RemoteTensorHandleData::SetShapeAndRemoteTask(
    const TensorShape& shape, const string& remote_task) {
  // If `is_ready_` was previously set due to poisoning, return the original
  // error that poisoned this tensor.
  TF_RETURN_IF_ERROR(IsPoisoned());

  mutex_lock l(mu_);
  if (is_ready_) {
    return errors::Internal("SetShape is only called on non-ready handles.");
  }

  shape_ = shape;
  if (!remote_task.empty()) {
    remote_task_ = remote_task;
  }
  is_poisoned_ = OkStatus();
  is_ready_ = true;

  return OkStatus();
}

string RemoteTensorHandleData::DebugString() const {
  return absl::StrCat("RemoteTensorHandleData:", " op_id: ", op_id_,
                      " output_num: ", output_num_);
}

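// Returns the producing op's id and output index. If `wait_until_ready` is
// true, blocks until the handle becomes ready (or returns the poison status).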
Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_until_ready,
                                                int64_t* op_id,
                                                int32* output_num) const {
  if (wait_until_ready) {
    TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
  }
  *op_id = op_id_;
  *output_num = output_num_;
  return OkStatus();
}

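// Blocks the calling thread until `is_ready_` becomes true and then returns
// the poison status (OkStatus() if the handle was never poisoned).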
Status RemoteTensorHandleData::WaitReady(const char* caller) const {
  tf_shared_lock l(mu_);
  if (!is_ready_) {
    profiler::TraceMe activity(
        [caller] { return absl::StrCat(caller, " WaitReady"); },
        profiler::TraceMeLevel::kInfo);
    DVLOG(3) << "WaitReady: " << caller << " " << this;
    mu_.Await(Condition(&is_ready_));
  }
  return is_poisoned_;
}

}  // namespace tensorflow