/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"

#include <memory>
#include <utility>

#include "tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

namespace {

DestroyRemoteTensorHandle(EagerContext * ctx,const string & remote_task,uint64 context_id,uint64 op_id,int output_num,bool ready)29 void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
30                                uint64 context_id, uint64 op_id, int output_num,
31                                bool ready) {
32   if (ctx->GetContextId() != context_id) {
33     // This means that this tensor was pointing to a remote device, which
34     // has been changed out from under us. Simply return since there is
35     // nothing we can do.
36     return;
37   }
38 
39   core::RefCountPtr<eager::EagerClient> eager_client;
40   Status status = ctx->GetClient(remote_task, &eager_client);
41   if (!status.ok()) {
42     LOG_EVERY_N_SEC(INFO, 60)
43         << "Unable to destroy remote tensor handle because the target "
44         << remote_task << " is no longer available.";
45     return;
46   }
47 
48   std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
49   request->set_context_id(context_id);
50 
51   auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
52   handle_to_decref->set_op_id(op_id);
53   handle_to_decref->set_output_num(output_num);
54 
55   VLOG(3) << "Sending request to delete " << request->DebugString();
56   std::unique_ptr<EagerNode> node(
57       std::make_unique<eager::DestroyTensorHandleNode>(
58           std::move(request), std::move(eager_client), ready));
59   auto& executor = ctx->Executor();
60   if (executor.Async()) {
61     Status status = executor.AddOrExecute(std::move(node));
62     if (!status.ok()) {
63       LOG_EVERY_N_SEC(WARNING, 60)
64           << "Unable to destroy remote tensor handles. If you are "
65              "running a tf.function, it usually indicates some op in "
66              "the graph gets an error: "
67           << status.error_message();
68     }
69   } else {
70     // This thread may still hold tensorflow::StreamingRPCState::mu_. We need
71     // to send out the destroy request in a new thread to avoid deadlock.
72     auto* released_node = node.release();
73     (*ctx->runner())([ctx, released_node] {
74       Status status =
75           ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
76       if (!status.ok()) {
77         LOG_EVERY_N_SEC(WARNING, 60)
78             << "Unable to destroy remote tensor handles. If you are "
79                "running a tf.function, it usually indicates some op in "
80                "the graph gets an error: "
81             << status.error_message();
82       }
83     });
84   }
85 }
}  // namespace

RemoteTensorHandleData(int64_t op_id,int output_num,uint64 context_view_id,bool is_ready)88 RemoteTensorHandleData::RemoteTensorHandleData(int64_t op_id, int output_num,
89                                                uint64 context_view_id,
90                                                bool is_ready)
91     : is_ready_(is_ready),
92       op_id_(op_id),
93       output_num_(output_num),
94       context_view_id_(context_view_id),
95       ctx_(nullptr) {
96   DCHECK(op_id_ >= 0 && output_num_ >= 0)
97       << "Op ID and output num should be >= 0. Op ID: " << op_id
98       << ", Output num: " << output_num;
99 }
100 
RemoteTensorHandleData(int64_t op_id,int output_num,const string & remote_task,EagerContext * ctx)101 RemoteTensorHandleData::RemoteTensorHandleData(int64_t op_id, int output_num,
102                                                const string& remote_task,
103                                                EagerContext* ctx)
104     : is_ready_(false),
105       op_id_(op_id),
106       output_num_(output_num),
107       remote_task_(remote_task),
108       context_id_(ctx->GetContextId()),
109       context_view_id_(ctx->GetContextViewId()),
110       ctx_(ctx) {
111   DCHECK(op_id_ >= 0 && output_num_ >= 0)
112       << "Op ID and output num should be >= 0. Op ID: " << op_id
113       << ", Output num: " << output_num;
114   ctx_->Ref();
115 }
116 
~RemoteTensorHandleData()117 RemoteTensorHandleData::~RemoteTensorHandleData() {
118   if (ctx_) {
119     DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
120                               output_num_, /*ready=*/true);
121     ctx_->Unref();
122   }
123 }
124 
Shape(TensorShape * shape) const125 Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
126   TF_RETURN_IF_ERROR(WaitReady("Shape"));
127 
128   tf_shared_lock l(mu_);
129   *shape = shape_;
130 
131   return OkStatus();
132 }
133 
NumDims(int * num_dims) const134 Status RemoteTensorHandleData::NumDims(int* num_dims) const {
135   TF_RETURN_IF_ERROR(WaitReady("NumDims"));
136 
137   tf_shared_lock l(mu_);
138   *num_dims = shape_.dims();
139 
140   return OkStatus();
141 }
142 
Dim(int dim_index,int64_t * dim) const143 Status RemoteTensorHandleData::Dim(int dim_index, int64_t* dim) const {
144   TF_RETURN_IF_ERROR(WaitReady("Dim"));
145 
146   tf_shared_lock l(mu_);
147   *dim = shape_.dim_size(dim_index);
148 
149   return OkStatus();
150 }
151 
NumElements(int64_t * num_elements) const152 Status RemoteTensorHandleData::NumElements(int64_t* num_elements) const {
153   TF_RETURN_IF_ERROR(WaitReady("NumElements"));
154 
155   tf_shared_lock l(mu_);
156   *num_elements = shape_.num_elements();
157 
158   return OkStatus();
159 }
160 
IsReady() const161 bool RemoteTensorHandleData::IsReady() const {
162   tf_shared_lock l(mu_);
163   return is_ready_;
164 }
165 
Poison(Status status)166 void RemoteTensorHandleData::Poison(Status status) {
167   mutex_lock l(mu_);
168   is_poisoned_ = status;
169   is_ready_ = true;
170 }
171 
IsPoisoned() const172 Status RemoteTensorHandleData::IsPoisoned() const {
173   tf_shared_lock l(mu_);
174   return is_poisoned_;
175 }
176 
SetShape(const TensorShape & shape)177 Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
178   return SetShapeAndRemoteTask(shape, /*remote_task=*/"");
179 }
180 
SetShapeAndRemoteTask(const TensorShape & shape,const string & remote_task)181 Status RemoteTensorHandleData::SetShapeAndRemoteTask(
182     const TensorShape& shape, const string& remote_task) {
183   // If `is_ready_` is set previously due to poisoning, return the original
184   // error that poisoned this tensor.
185   TF_RETURN_IF_ERROR(IsPoisoned());
186 
187   mutex_lock l(mu_);
188   if (is_ready_) {
189     return errors::Internal("SetShape is only called on non-ready handles.");
190   }
191 
192   shape_ = shape;
193   if (!remote_task.empty()) {
194     remote_task_ = remote_task;
195   }
196   is_poisoned_ = OkStatus();
197   is_ready_ = true;
198 
199   return OkStatus();
200 }
201 
DebugString() const202 string RemoteTensorHandleData::DebugString() const {
203   return absl::StrCat("RemoteTensorHandleData:", " op_id: ", op_id_,
204                       " output_num: ", output_num_);
205 }
206 
OpIdAndOutputNum(const bool wait_util_ready,int64_t * op_id,int32 * output_num) const207 Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_util_ready,
208                                                 int64_t* op_id,
209                                                 int32* output_num) const {
210   if (wait_util_ready) {
211     TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
212   }
213   *op_id = op_id_;
214   *output_num = output_num_;
215   return OkStatus();
216 }
217 
WaitReady(const char * caller) const218 Status RemoteTensorHandleData::WaitReady(const char* caller) const {
219   tf_shared_lock l(mu_);
220   if (!is_ready_) {
221     profiler::TraceMe activity(
222         [caller] { return absl::StrCat(caller, " WaitReady"); },
223         profiler::TraceMeLevel::kInfo);
224     DVLOG(3) << "WaitReady: " << caller << " " << this;
225     mu_.Await(Condition(&is_ready_));
226   }
227   return is_poisoned_;
228 }
229 
}  // namespace tensorflow