xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/rendezvous_mgr.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
17 
18 #include <unordered_set>
19 
20 #include "tensorflow/core/common_runtime/copy_tensor.h"
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/device_factory.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/notification.h"
28 #include "tensorflow/core/lib/strings/numbers.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
34 
35 namespace tensorflow {
36 
37 namespace {
// Finishes a same-worker receive: materializes the sent tensor `in` into
// `out` — aliasing the buffer when both endpoints are host memory, otherwise
// copying via DMA — and invokes `done` exactly once with the final status.
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
                        const Rendezvous::ParsedKey& parsed,
                        const Rendezvous::Args& send_args,
                        const Rendezvous::Args& recv_args, const Tensor& in,
                        Tensor* out, StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    if (VLOG_IS_ON(3)) {
      // "Override" = allocator attributes forced host placement even though
      // the device type is not CPU; logged only for debugging.
      bool src_override =
          send_args.alloc_attrs.on_host() && !(parsed.src.type == "CPU");
      bool dst_override =
          recv_args.alloc_attrs.on_host() && !(parsed.dst.type == "CPU");
      if (src_override || dst_override) {
        VLOG(3) << "Shortcut to keep tensor on host (src_override "
                << src_override << " and dst_override " << dst_override
                << ") tensor dtype:" << DataTypeString(in.dtype()) << " "
                << parsed.FullKey();
      }
    }
    // Tensor assignment shares the underlying buffer — no data copy here.
    *out = in;
    done(OkStatus());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU).  Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  // Resolve both device objects; a failed lookup aborts the receive with the
  // lookup status.
  Device* src_device;
  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  profiler::ScopedMemoryDebugAnnotation op_annotation(
      "SameWorkerRecvDone", 0, "dynamic", in.dtype(),
      [&in]() { return in.shape().DebugString(); });
  // The destination allocation honors the receiver's attributes, but is
  // forced gpu-compatible if either side requested it.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  bool sync_dst_compute = true;
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    AllocationAttributes aa;
    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
    std::function<uint64()> freed_by_func = [dst_device,
                                             &safe_alloc_frontier]() {
      safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
      return safe_alloc_frontier;
    };
    if ((parsed.dst.type == "GPU" ||
         DeviceFactory::IsPluggableDevice(parsed.dst.type)) &&
        safe_alloc_frontier > 0) {
      // There's a timestamped allocator at work, so use it instead
      // of sync_dst_compute.
      // NOTE(review): `aa` holds a pointer to the stack-local
      // `freed_by_func`; the Tensor constructor below consumes it within
      // this scope, so the pointer does not escape.
      aa.freed_by_func = &freed_by_func;
      sync_dst_compute = false;
    }
    Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
    *out = copy;
    // A null backing buffer for a non-empty shape means the allocation
    // failed; surface that as ResourceExhausted instead of letting the DMA
    // copy dereference it.
    if (in.shape().num_elements() > 0 && out->data() == nullptr) {
      done(tensorflow::errors::ResourceExhausted(
          "SameWorkerRecvDone unable to allocate output tensor. Key: ",
          parsed.FullKey()));
      return;
    }
  }

  // Cross-device copy; `done` is invoked (possibly asynchronously) when the
  // DMA completes.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
131 
// Shared RecvAsync implementation for both intra-process rendezvous flavors:
// receives the tensor for `parsed` from the process-local table `local` and,
// when a tensor was actually produced, finishes the (possibly cross-device)
// transfer via SameWorkerRecvDone before invoking `done`.
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
                               LocalRendezvous* local,
                               const RendezvousInterface::ParsedKey& parsed,
                               const Rendezvous::Args& recv_args,
                               RendezvousInterface::DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();

  profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync");
  // Recv the tensor from local_.
  local->RecvAsync(
      parsed, recv_args,
      [device_mgr, parsed, done = std::move(done)](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& in,
          bool is_dead) mutable {
        // If "in" is an uninitialized tensor, do copy-construction to
        // preserve the uninitialized state, along with data type and shape
        // info, which is useful for debugger purposes.
        // (An initialized tensor instead gets a fresh empty Tensor that
        // SameWorkerRecvDone fills in below.)
        Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

        // Owns `out`: forwards the result to `done`, then deletes it.
        auto final_callback = [send_args, recv_args, out, is_dead,
                               done = std::move(done)](const Status& s) {
          done(s, send_args, recv_args, *out, is_dead);
          delete out;
        };

        if (status.ok() && in.IsInitialized()) {
          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                             std::move(final_callback));
        } else {
          // Error, or dead/uninitialized tensor: report as-is, no copy.
          final_callback(status);
        }
      });
}
166 
167 }  // namespace
168 
// Constructs the ref-counted rendezvous over `device_mgr` (not owned); the
// embedded LocalRendezvous is parented to this instance.
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(this) {}
172 
~RefCountedIntraProcessRendezvous()173 RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}
174 
Send(const ParsedKey & key,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)175 Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
176                                               const Rendezvous::Args& args,
177                                               const Tensor& val,
178                                               const bool is_dead) {
179   VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
180   return local_.Send(key, args, val, is_dead);
181 }
182 
RecvAsync(const ParsedKey & key,const Rendezvous::Args & args,DoneCallback done)183 void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
184                                                  const Rendezvous::Args& args,
185                                                  DoneCallback done) {
186   VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
187   IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
188 }
189 
// Aborts all pending and future Send/Recv operations with status `s`.
void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}
193 
// Returns the current status of the underlying LocalRendezvous (e.g. the
// abort status after StartAbort).
Status RefCountedIntraProcessRendezvous::GetLocalRendezvousStatus() {
  return local_.status();
}
197 
// Constructs the private (non-ref-counted) rendezvous over `device_mgr`
// (not owned); the embedded LocalRendezvous has no owning Rendezvous, hence
// the nullptr parent.
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(nullptr) {}
201 
~PrivateIntraProcessRendezvous()202 PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}
203 
Send(const ParsedKey & key,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)204 Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
205                                            const Rendezvous::Args& args,
206                                            const Tensor& val,
207                                            const bool is_dead) {
208   DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
209   return local_.Send(key, args, val, is_dead);
210 }
211 
RecvAsync(const ParsedKey & key,const Rendezvous::Args & args,DoneCallback done)212 void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
213                                               const Rendezvous::Args& args,
214                                               DoneCallback done) {
215   DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " "
216            << key.FullKey();
217   IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
218 }
219 
// Aborts all pending and future Send/Recv operations with status `s`.
void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}
223 
224 }  // end namespace tensorflow
225