/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/rendezvous_mgr.h"

#include <unordered_set>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"

namespace tensorflow {

namespace {
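// Completion handler for the Recv half of a same-worker transfer: given the
// sent tensor "in", produces "out" on the receiver's device and invokes
// "done". Host-to-host transfers alias the underlying buffer; any transfer
// touching a non-CPU device goes through CopyTensor::ViaDMA.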
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
                        const Rendezvous::ParsedKey& parsed,
                        const Rendezvous::Args& send_args,
                        const Rendezvous::Args& recv_args, const Tensor& in,
                        Tensor* out, StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    if (VLOG_IS_ON(3)) {
      bool src_override =
          send_args.alloc_attrs.on_host() && !(parsed.src.type == "CPU");
      bool dst_override =
          recv_args.alloc_attrs.on_host() && !(parsed.dst.type == "CPU");
      if (src_override || dst_override) {
        VLOG(3) << "Shortcut to keep tensor on host (src_override "
                << src_override << " and dst_override " << dst_override
                << ") tensor dtype:" << DataTypeString(in.dtype()) << " "
                << parsed.FullKey();
      }
    }
    *out = in;
    done(OkStatus());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  Device* src_device;
  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  profiler::ScopedMemoryDebugAnnotation op_annotation(
      "SameWorkerRecvDone", 0, "dynamic", in.dtype(),
      [&in]() { return in.shape().DebugString(); });
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
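  // By default, ask CopyTensor::ViaDMA to synchronize the copy with the
  // destination device's compute stream; a timestamped allocator (detected
  // below) makes that synchronization unnecessary.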
  bool sync_dst_compute = true;
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    AllocationAttributes aa;
    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
    std::function<uint64()> freed_by_func = [dst_device,
                                             &safe_alloc_frontier]() {
      safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
      return safe_alloc_frontier;
    };
    if ((parsed.dst.type == "GPU" ||
         DeviceFactory::IsPluggableDevice(parsed.dst.type)) &&
        safe_alloc_frontier > 0) {
      // There's a timestamped allocator at work, so use it instead
      // of sync_dst_compute.
      aa.freed_by_func = &freed_by_func;
      sync_dst_compute = false;
    }
    Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
    *out = copy;
    if (in.shape().num_elements() > 0 && out->data() == nullptr) {
      done(tensorflow::errors::ResourceExhausted(
          "SameWorkerRecvDone unable to allocate output tensor. Key: ",
          parsed.FullKey()));
      return;
    }
  }

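  // Hand the actual transfer to the copy-tensor registry; it performs the
  // device-appropriate copy and invokes "done" once the transfer completes.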
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}

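// Shared Recv implementation for both intra-process rendezvous flavors:
// dequeues the tensor from "local" and hands it to SameWorkerRecvDone to
// land it on the receiver's device.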
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
                               LocalRendezvous* local,
                               const RendezvousInterface::ParsedKey& parsed,
                               const Rendezvous::Args& recv_args,
                               RendezvousInterface::DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();

  profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync");
  // Recv the tensor from local_.
  local->RecvAsync(
      parsed, recv_args,
      [device_mgr, parsed, done = std::move(done)](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& in,
          bool is_dead) mutable {
        // If "in" is an uninitialized tensor, do copy-construction to
        // preserve the uninitialized state, along with data type and shape
        // info, which is useful for debugger purposes.
        Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

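        // Wrap "done" so that the heap-allocated "out" is always deleted
        // after the result has been forwarded to the caller.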
        auto final_callback = [send_args, recv_args, out, is_dead,
                               done = std::move(done)](const Status& s) {
          done(s, send_args, recv_args, *out, is_dead);
          delete out;
        };

        if (status.ok() && in.IsInitialized()) {
          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                             std::move(final_callback));
        } else {
          final_callback(status);
        }
      });
}

}  // namespace

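// RefCountedIntraProcessRendezvous: Send forwards straight to the embedded
// LocalRendezvous; RecvAsync additionally moves the received tensor onto the
// receiving device via IntraProcessRecvAsyncImpl.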
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(this) {}

RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}

Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              const Tensor& val,
                                              const bool is_dead) {
  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}

void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                                 const Rendezvous::Args& args,
                                                 DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}

void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}

Status RefCountedIntraProcessRendezvous::GetLocalRendezvousStatus() {
  return local_.status();
}

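// PrivateIntraProcessRendezvous: the caller-owned (non-reference-counted)
// variant; its Send/Recv paths mirror the ref-counted class above.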
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(nullptr) {}

PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}

Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
                                           const Rendezvous::Args& args,
                                           const Tensor& val,
                                           const bool is_dead) {
  DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}

void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              DoneCallback done) {
  DVLOG(1) << "PrivateIntraProcessRendezvous Recv " << this << " "
           << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}

void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}

}  // end namespace tensorflow