/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"

#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/platform/error_payloads.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

namespace {
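// Tags `error` with an ErrorSourceProto payload identifying the eager
// RemoteMgr as the source of the failure.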
Status WithErrorSourcePayload(Status error) {
  core::platform::ErrorSourceProto error_source_proto;
  error_source_proto.set_error_source(
      core::platform::ErrorSourceProto::EAGER_REMOTE_MGR);
  error.SetPayload(tensorflow::kErrorSource,
                   error_source_proto.SerializeAsString());
  return error;
}
}  // namespace

namespace eager {

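// Records every output handle of a remote operation under the key
// (operation_id, output index) so later requests can look them up.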
void RemoteMgr::AddOperationOutputs(
    const gtl::ArraySlice<tensorflow::TensorHandle*> handles,
    int64_t operation_id) {
  mutex_lock l(remote_tensor_handle_mu_);
  for (int i = 0, end = handles.size(); i < end; i++) {
    // TODO(nareshmodi): Correctly handle operation_id not being unique.
    remote_tensor_handle_map_.emplace(
        RemoteTensorHandleInternal(operation_id, i), handles[i]);
  }
}

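// Records a single operation output handle under (operation_id, output_num).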
void RemoteMgr::AddOperationOutput(tensorflow::TensorHandle* handle,
                                   int64_t operation_id, int32_t output_num) {
  mutex_lock l(remote_tensor_handle_mu_);
  remote_tensor_handle_map_.emplace(
      RemoteTensorHandleInternal(operation_id, output_num), handle);
}

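// Looks up the handle stored for (op_id, output_num); expects the caller to
// hold remote_tensor_handle_mu_.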
Status RemoteMgr::GetTensorHandleImpl(
    const RemoteTensorHandleInternal& remote_handle,
    tensorflow::TensorHandle** handle) {
  auto iter = remote_tensor_handle_map_.find(remote_handle);
  if (iter == remote_tensor_handle_map_.end()) {
    // TODO(b/217820532): Fix the tensor deallocation order issue.
    return WithErrorSourcePayload(errors::InvalidArgument(
        "Unable to find the relevant tensor remote_handle: Op ID: ",
        remote_handle.op_id, ", Output num: ", remote_handle.output_num,
        ". One possible cause is that the tensor was accessed after "
        "deallocation in a distributed worker setup. Try setting "
        "`os.environ['TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE']='False'` in "
        "your client to disable async streaming behavior to see if it fixes "
        "the problem."));
  }

  *handle = iter->second;

  return OkStatus();
}

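// Thread-safe wrapper around GetTensorHandleImpl.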
Status RemoteMgr::GetTensorHandle(
    const RemoteTensorHandleInternal& remote_handle,
    tensorflow::TensorHandle** handle) {
  tf_shared_lock l(remote_tensor_handle_mu_);
  return GetTensorHandleImpl(remote_handle, handle);
}

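// Returns the cached resource dtypes and shapes recorded for a mirrored
// remote handle, if any.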
Status RemoteMgr::GetMirroredResourceShape(
    const RemoteTensorHandleInternal& remote_handle,
    std::vector<DtypeAndPartialTensorShape>* handle) {
  tf_shared_lock l(mirrored_resource_shape_mu_);
  auto iter = mirrored_resource_shape_map_.find(remote_handle);
  if (iter == mirrored_resource_shape_map_.end()) {
    // TODO(b/217820532): Fix the tensor deallocation order issue.
    return WithErrorSourcePayload(errors::InvalidArgument(
        "Unable to find the relevant tensor remote_handle: Op ID: ",
        remote_handle.op_id, ", Output num: ", remote_handle.output_num,
        ". One possible cause is that the tensor was accessed after "
        "deallocation in a distributed worker setup. Try setting "
        "`os.environ['TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE']='False'` in "
        "your client to disable async streaming behavior to see if it fixes "
        "the problem."));
  }

  *handle = iter->second;

  return OkStatus();
}

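// Resolves the remote address (op_id, output_num) of `handle` and checks that
// this RemoteMgr tracks the same handle under that key.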
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
                                        const bool wait_until_ready,
                                        int64_t* op_id, int32* output_num) {
  TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready,
                                           op_id, output_num));
  tensorflow::TensorHandle* h;
  TF_RETURN_IF_ERROR(
      GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
  if (handle != h) {
    return WithErrorSourcePayload(errors::Internal(
        "Found two different tensor handles with the same op_id:", *op_id,
        " and output_num:", *output_num));
  }
  return OkStatus();
}

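// Drops the bookkeeping for `remote_handle`: unrefs and erases the stored
// operation output if present, otherwise erases any cached mirrored resource
// shape; fails if the handle is unknown.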
Status RemoteMgr::DeleteTensorHandle(
    const RemoteTensorHandleInternal& remote_handle) {
  {
    mutex_lock l(remote_tensor_handle_mu_);
    auto iter = remote_tensor_handle_map_.find(remote_handle);
    if (iter != remote_tensor_handle_map_.end()) {
      iter->second->Unref();
      remote_tensor_handle_map_.erase(iter);
      return OkStatus();
    }
  }
  {
    mutex_lock l(mirrored_resource_shape_mu_);
    auto iter = mirrored_resource_shape_map_.find(remote_handle);
    if (iter != mirrored_resource_shape_map_.end()) {
      mirrored_resource_shape_map_.erase(iter);
      return OkStatus();
    }
  }
  return WithErrorSourcePayload(errors::InvalidArgument(
      "Unable to find the relevant tensor remote_handle: Op ID: ",
      remote_handle.op_id, ", Output num: ", remote_handle.output_num));
}

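// Fills `out` with the wire form of `in`: its (op_id, output_num) address,
// op/target device names, dtype, and optionally the resource dtypes/shapes.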
Status RemoteMgr::SerializeRemoteTensorHandle(
    TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
    Device* device, const string& device_name,
    const bool serialize_resource_dtype_and_shape) {
  int64_t op_id;
  int32_t output_num;
  if (!in->RemoteAddress(device, wait_until_ready, &op_id, &output_num).ok()) {
    tf_shared_lock l(remote_tensor_handle_mu_);
    TF_RETURN_IF_ERROR(
        GetRemoteTensorHandle(in, wait_until_ready, &op_id, &output_num));
  }
  out->Clear();
  out->set_op_id(op_id);
  out->set_output_num(output_num);
  out->set_op_device(in->op_device() ? in->op_device()->name() : "");
  out->set_device(device_name);
  out->set_dtype(in->dtype);
  if (serialize_resource_dtype_and_shape) {
    std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
    TF_RETURN_IF_ERROR(
        in->GetResourceHandleDtypesAndShapes(&resource_dtypes_and_shapes));
    for (const auto& dtype_and_shape : resource_dtypes_and_shapes) {
      ResourceDtypeAndShape* dtype_and_shape_proto =
          out->add_resource_dtypes_and_shapes();
      dtype_and_shape_proto->set_dtype(dtype_and_shape.dtype);
      dtype_and_shape.shape.AsProto(dtype_and_shape_proto->mutable_shape());
    }
  }
  return OkStatus();
}

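// Rebuilds a TensorHandle from its wire form. Handles that already live on a
// local device are looked up and ref'd; otherwise a lazy remote handle is
// created and any resource dtype/shape info from the proto is cached.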
Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
                                                TensorHandle** out) {
  Device* device;
  if (parent_->local_device_mgr()->LookupDevice(in.op_device(), &device).ok() ||
      parent_->local_device_mgr()->LookupDevice(in.device(), &device).ok()) {
    TF_RETURN_IF_ERROR(GetTensorHandle(RemoteTensorHandleInternal(in), out));
    (*out)->Ref();
  } else {
    // Create a remote TensorHandle for remote tensors which have not been
    // copied to the local worker yet (e.g. remote function inputs).
    const string& device_name =
        in.op_device().empty() ? in.device() : in.op_device();
    TF_RETURN_IF_ERROR(
        parent_->FindDeviceFromName(device_name.c_str(), &device));
    *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
                                                in.dtype(), device,
                                                /*is_ready=*/true, parent_);
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
    if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                  &dtypes_and_shapes)
             .ok()) {
      for (const auto& dtype_and_shape_proto :
           in.resource_dtypes_and_shapes()) {
        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
            dtype_and_shape_proto.dtype(),
            TensorShape(dtype_and_shape_proto.shape())});
      }
      mutex_lock l(mirrored_resource_shape_mu_);
      mirrored_resource_shape_map_.emplace(
          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
          dtypes_and_shapes);
    }
    (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
  }

  return OkStatus();
}

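// Returns the async EagerExecutor for `stream_id`, creating it on first use.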
EagerExecutor& RemoteMgr::GetOrCreateExecutorForStream(uint64 stream_id) {
  mutex_lock l(executor_map_mu_);
  auto it = executor_map_.find(stream_id);
  if (it == executor_map_.end()) {
    auto it_and_bool = executor_map_.emplace(
        std::piecewise_construct, std::forward_as_tuple(stream_id),
        std::forward_as_tuple(/*async=*/true));
    DCHECK(it_and_bool.second);
    it = it_and_bool.first;
  }
  return it->second;
}

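// Shuts down and removes the executor for `stream_id`, if one exists.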
void RemoteMgr::DeleteExecutorForStream(uint64 stream_id) {
  mutex_lock l(executor_map_mu_);
  auto it = executor_map_.find(stream_id);
  if (it == executor_map_.end()) {
    return;
  }
  Status s = it->second.ShutDown();
  if (!s.ok()) {
    LOG(ERROR) << "EagerExecutor shutdown with error " << s.error_message();
  }
  executor_map_.erase(it);
}

}  // namespace eager
}  // namespace tensorflow