1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
17
18 #include <memory>
19 #include <tuple>
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
24 #include "tensorflow/core/platform/error_payloads.h"
25 #include "tensorflow/core/platform/errors.h"
26 #include "tensorflow/core/platform/status.h"
27
28 namespace tensorflow {
29
30 namespace {
WithErrorSourcePayload(Status error)31 Status WithErrorSourcePayload(Status error) {
32 core::platform::ErrorSourceProto error_source_proto;
33 error_source_proto.set_error_source(
34 core::platform::ErrorSourceProto::EAGER_REMOTE_MGR);
35 error.SetPayload(tensorflow::kErrorSource,
36 error_source_proto.SerializeAsString());
37 return error;
38 }
39 } // namespace
40
41 namespace eager {
42
AddOperationOutputs(const gtl::ArraySlice<tensorflow::TensorHandle * > handles,int64_t operation_id)43 void RemoteMgr::AddOperationOutputs(
44 const gtl::ArraySlice<tensorflow::TensorHandle*> handles,
45 int64_t operation_id) {
46 mutex_lock l(remote_tensor_handle_mu_);
47 for (int i = 0, end = handles.size(); i < end; i++) {
48 // TODO(nareshmodi): Correctly handle operation_id not being unique.
49 remote_tensor_handle_map_.emplace(
50 RemoteTensorHandleInternal(operation_id, i), handles[i]);
51 }
52 }
53
AddOperationOutput(tensorflow::TensorHandle * handle,int64_t operation_id,int32_t output_num)54 void RemoteMgr::AddOperationOutput(tensorflow::TensorHandle* handle,
55 int64_t operation_id, int32_t output_num) {
56 mutex_lock l(remote_tensor_handle_mu_);
57 remote_tensor_handle_map_.emplace(
58 RemoteTensorHandleInternal(operation_id, output_num), handle);
59 }
60
GetTensorHandleImpl(const RemoteTensorHandleInternal & remote_handle,tensorflow::TensorHandle ** handle)61 Status RemoteMgr::GetTensorHandleImpl(
62 const RemoteTensorHandleInternal& remote_handle,
63 tensorflow::TensorHandle** handle) {
64 auto iter = remote_tensor_handle_map_.find(remote_handle);
65 if (iter == remote_tensor_handle_map_.end()) {
66 // TODO(b/217820532): Fix the tensor deallocation order issue.
67 return WithErrorSourcePayload(errors::InvalidArgument(
68 "Unable to find the relevant tensor remote_handle: Op ID: ",
69 remote_handle.op_id, ", Output num: ", remote_handle.output_num,
70 ". One possible cause is that the tensor was accessed after "
71 "deallocation in a distributed worker setup. Try setting "
72 "`os.environ['TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE']='False'` in "
73 "your client to disable async streaming behavior to see if it fixes "
74 "the problem."));
75 }
76
77 *handle = iter->second;
78
79 return OkStatus();
80 }
81
GetTensorHandle(const RemoteTensorHandleInternal & remote_handle,tensorflow::TensorHandle ** handle)82 Status RemoteMgr::GetTensorHandle(
83 const RemoteTensorHandleInternal& remote_handle,
84 tensorflow::TensorHandle** handle) {
85 tf_shared_lock l(remote_tensor_handle_mu_);
86 return GetTensorHandleImpl(remote_handle, handle);
87 }
88
GetMirroredResourceShape(const RemoteTensorHandleInternal & remote_handle,std::vector<DtypeAndPartialTensorShape> * handle)89 Status RemoteMgr::GetMirroredResourceShape(
90 const RemoteTensorHandleInternal& remote_handle,
91 std::vector<DtypeAndPartialTensorShape>* handle) {
92 tf_shared_lock l(mirrored_resource_shape_mu_);
93 auto iter = mirrored_resource_shape_map_.find(remote_handle);
94 if (iter == mirrored_resource_shape_map_.end()) {
95 // TODO(b/217820532): Fix the tensor deallocation order issue.
96 return WithErrorSourcePayload(errors::InvalidArgument(
97 "Unable to find the relevant tensor remote_handle: Op ID: ",
98 remote_handle.op_id, ", Output num: ", remote_handle.output_num,
99 ". One possible cause is that the tensor was accessed after "
100 "deallocation in a distributed worker setup. Try setting "
101 "`os.environ['TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE']='False'` in "
102 "your client to disable async streaming behavior to see if it fixes "
103 "the problem."));
104 }
105
106 *handle = iter->second;
107
108 return OkStatus();
109 }
110
GetRemoteTensorHandle(const tensorflow::TensorHandle * handle,const bool wait_until_ready,int64_t * op_id,int32 * output_num)111 Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
112 const bool wait_until_ready,
113 int64_t* op_id, int32* output_num) {
114 TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready,
115 op_id, output_num));
116 tensorflow::TensorHandle* h;
117 TF_RETURN_IF_ERROR(
118 GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
119 if (handle != h) {
120 return WithErrorSourcePayload(errors::Internal(
121 "Found two different tensor handles with the same op_id:", *op_id,
122 " and output_num:", *output_num));
123 }
124 return OkStatus();
125 }
126
DeleteTensorHandle(const RemoteTensorHandleInternal & remote_handle)127 Status RemoteMgr::DeleteTensorHandle(
128 const RemoteTensorHandleInternal& remote_handle) {
129 {
130 mutex_lock l(remote_tensor_handle_mu_);
131 auto iter = remote_tensor_handle_map_.find(remote_handle);
132 if (iter != remote_tensor_handle_map_.end()) {
133 iter->second->Unref();
134 remote_tensor_handle_map_.erase(iter);
135 return OkStatus();
136 }
137 }
138 {
139 mutex_lock l(mirrored_resource_shape_mu_);
140 auto iter = mirrored_resource_shape_map_.find(remote_handle);
141 if (iter != mirrored_resource_shape_map_.end()) {
142 mirrored_resource_shape_map_.erase(iter);
143 return OkStatus();
144 }
145 }
146 return WithErrorSourcePayload(errors::InvalidArgument(
147 "Unable to find the relevant tensor remote_handle: Op ID: ",
148 remote_handle.op_id, ", Output num: ", remote_handle.output_num));
149 }
150
SerializeRemoteTensorHandle(TensorHandle * in,const bool wait_until_ready,RemoteTensorHandle * out,Device * device,const string & device_name,const bool serialize_resource_dtype_and_shape)151 Status RemoteMgr::SerializeRemoteTensorHandle(
152 TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
153 Device* device, const string& device_name,
154 const bool serialize_resource_dtype_and_shape) {
155 int64_t op_id;
156 int32_t output_num;
157 if (!in->RemoteAddress(device, wait_until_ready, &op_id, &output_num).ok()) {
158 tf_shared_lock l(remote_tensor_handle_mu_);
159 TF_RETURN_IF_ERROR(
160 GetRemoteTensorHandle(in, wait_until_ready, &op_id, &output_num));
161 }
162 out->Clear();
163 out->set_op_id(op_id);
164 out->set_output_num(output_num);
165 out->set_op_device(in->op_device() ? in->op_device()->name() : "");
166 out->set_device(device_name);
167 out->set_dtype(in->dtype);
168 if (serialize_resource_dtype_and_shape) {
169 std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
170 TF_RETURN_IF_ERROR(
171 in->GetResourceHandleDtypesAndShapes(&resource_dtypes_and_shapes));
172 for (const auto& dtype_and_shape : resource_dtypes_and_shapes) {
173 ResourceDtypeAndShape* dtype_and_shape_proto =
174 out->add_resource_dtypes_and_shapes();
175 dtype_and_shape_proto->set_dtype(dtype_and_shape.dtype);
176 dtype_and_shape.shape.AsProto(dtype_and_shape_proto->mutable_shape());
177 }
178 }
179 return OkStatus();
180 }
181
// Reconstructs a TensorHandle from its serialized remote form `in`.
// If the handle's device is local to this worker, the already-registered
// handle is returned (with an extra reference for the caller). Otherwise a
// lazy remote handle is created and its resource dtypes/shapes are restored,
// preferring the local mirrored-shape cache over the proto contents.
// On success `*out` holds a reference owned by the caller.
Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
                                                TensorHandle** out) {
  Device* device;
  if (parent_->local_device_mgr()->LookupDevice(in.op_device(), &device).ok() ||
      parent_->local_device_mgr()->LookupDevice(in.device(), &device).ok()) {
    // The tensor lives on a local device: reuse the handle registered in the
    // table rather than creating a new one.
    TF_RETURN_IF_ERROR(GetTensorHandle(RemoteTensorHandleInternal(in), out));
    // Hand the caller its own reference on the shared handle.
    (*out)->Ref();
  } else {
    // Create a remote TensorHandle for remote tensors which have not been
    // copied to the local worker yet (e.g. remote function inputs).
    const string& device_name =
        in.op_device().empty() ? in.device() : in.op_device();
    TF_RETURN_IF_ERROR(
        parent_->FindDeviceFromName(device_name.c_str(), &device));
    *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
                                                in.dtype(), device,
                                                /*is_ready=*/true, parent_);
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
    // Prefer shapes cached by a previous deserialization; only on a cache
    // miss do we read them out of the proto and populate the cache.
    if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                  &dtypes_and_shapes)
             .ok()) {
      for (const auto& dtype_and_shape_proto :
           in.resource_dtypes_and_shapes()) {
        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
            dtype_and_shape_proto.dtype(),
            TensorShape(dtype_and_shape_proto.shape())});
      }
      // The map stores its own copy; the local vector is still valid (and
      // moved from) below.
      mutex_lock l(mirrored_resource_shape_mu_);
      mirrored_resource_shape_map_.emplace(
          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
          dtypes_and_shapes);
    }
    (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
  }

  return OkStatus();
}
219
GetOrCreateExecutorForStream(uint64 stream_id)220 EagerExecutor& RemoteMgr::GetOrCreateExecutorForStream(uint64 stream_id) {
221 mutex_lock l(executor_map_mu_);
222 auto it = executor_map_.find(stream_id);
223 if (it == executor_map_.end()) {
224 auto it_and_bool = executor_map_.emplace(
225 std::piecewise_construct, std::forward_as_tuple(stream_id),
226 std::forward_as_tuple(/*async=*/true));
227 DCHECK(it_and_bool.second);
228 it = it_and_bool.first;
229 }
230 return it->second;
231 }
232
DeleteExecutorForStream(uint64 stream_id)233 void RemoteMgr::DeleteExecutorForStream(uint64 stream_id) {
234 mutex_lock l(executor_map_mu_);
235 auto it = executor_map_.find(stream_id);
236 if (it == executor_map_.end()) {
237 return;
238 }
239 Status s = it->second.ShutDown();
240 if (!s.ok()) {
241 LOG(ERROR) << "EagerExecutor shutdown with error " << s.error_message();
242 }
243 executor_map_.erase(it);
244 }
245
246 } // namespace eager
247 } // namespace tensorflow
248