/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device_context.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace tensorflow {
// The allocator used for Tensors assigned to the XLA device.
XlaDeviceAllocator::XlaDeviceAllocator(
    stream_executor::StreamExecutor* stream_executor)
    : stream_executor_(stream_executor) {}

XlaDeviceAllocator::~XlaDeviceAllocator() = default;

string XlaDeviceAllocator::Name() { return "xla"; }

void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  // We always return an empty XlaTensor object, encoded as an opaque tagged
  // pointer. We can return an empty object and ignore num_bytes here because
  // we have control over all uses of this device tensor and can lazily
  // allocate memory when it is used. This also lets us know the shape of the
  // allocated Tensor, which is useful if the device's tensor representation
  // differs from that of the host.
  return XlaTensor::ToOpaquePointer(new XlaTensor());
}

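// Recovers the XlaTensor encoded by AllocateRaw from the opaque tagged
// pointer and deletes it, releasing any device buffers the tensor owns.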
void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
  delete XlaTensor::FromOpaquePointer(ptr);
}

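// Translates the underlying StreamExecutor's allocator statistics into
// TensorFlow's AllocatorStats, or returns std::nullopt if none are available.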
std::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
  std::optional<stream_executor::AllocatorStats> se_stats =
      stream_executor_->GetAllocatorStats();
  if (!se_stats) {
    return std::nullopt;
  }

  tensorflow::AllocatorStats tf_stats;
  tf_stats.num_allocs = se_stats->num_allocs;
  tf_stats.bytes_in_use = se_stats->bytes_in_use;
  tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use;
  tf_stats.largest_alloc_size = se_stats->largest_alloc_size;
  tf_stats.bytes_limit = se_stats->bytes_limit;
  tf_stats.bytes_reserved = se_stats->bytes_reserved;
  tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
  tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
  tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes;
  return tf_stats;
}

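// Synchronizes all outstanding device activity before resetting the
// StreamExecutor's allocator statistics; returns false if synchronization
// fails.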
bool XlaDeviceAllocator::ClearStats() {
  if (!stream_executor_->SynchronizeAllActivity()) {
    return false;
  }
  return stream_executor_->ClearAllocatorStats();
}

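// The compute and host-to-device streams are required (CHECKed below); the
// device-to-host stream is optional and, when absent, CopyDeviceTensorToCPU
// borrows a stream from the backend's stream pool instead.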
XlaDeviceContext::XlaDeviceContext(
    std::shared_ptr<se::Stream> compute_stream,
    std::shared_ptr<se::Stream> host_to_device_stream,
    std::shared_ptr<se::Stream> device_to_host_stream,
    std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
    xla::LocalClient* client,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    thread::ThreadPool* thread_pool)
    : stream_(std::move(compute_stream)),
      host_to_device_stream_(std::move(host_to_device_stream)),
      device_to_host_stream_(std::move(device_to_host_stream)),
      device_to_device_streams_(std::move(device_to_device_streams)),
      client_(client),
      transfer_manager_(client->backend().transfer_manager()),
      shape_determination_fns_(std::move(shape_determination_fns)),
      thread_pool_(thread_pool) {
  CHECK(host_to_device_stream_ != nullptr);
  CHECK(stream_ != nullptr);
}

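// Same-device (XLA->XLA) copies are not supported; the callback is invoked
// immediately with an Unimplemented error.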
void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
                                              Device* device,
                                              Tensor* output_tensor,
                                              StatusCallback done) const {
  done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
}

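// Copies a host tensor into the lazily allocated device buffer backing
// device_tensor. The transfer is enqueued on the host-to-device stream, and
// done is invoked once the copy completes (or, in multi-stream mode, once it
// has been enqueued and a definition event recorded).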
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                             Device* device,
                                             Tensor* device_tensor,
                                             StatusCallback done,
                                             bool sync_dst_compute) const {
  if (cpu_tensor->NumElements() == 0) {
    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
    done(OkStatus());
    return;
  }

  VLOG(2) << "CopyCPUTensorToDevice " << this << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " " << cpu_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  CHECK(xla_tensor);

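  // Choose the device layout preference for this tensor; the full on-device
  // shape is computed below via shape_representation_fn.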
  XlaLayoutPreference layout_preference =
      shape_determination_fns_.layout_preference_fn(
          device_tensor->shape(), device_tensor->dtype(), std::nullopt);
  Status status = [&]() -> Status {
    TF_ASSIGN_OR_RETURN(xla::Shape shape,
                        shape_determination_fns_.shape_representation_fn(
                            device_tensor->shape(), device_tensor->dtype(),
                            /*fast_mem=*/false, layout_preference));

    // The device tensor should always be fresh.
    TF_RET_CHECK(!xla_tensor->has_shaped_buffer());

    TF_RETURN_IF_ERROR(
        xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
                                         stream_->parent()->device_ordinal()));

    // The cpu_tensor and the literal created here hold the host tensor's data
    // in descending layout. This layout may differ from the layout of
    // device_tensor, but the logical shape must be the same. The
    // transfer_manager is responsible for performing the corresponding
    // transposition when transferring the data to the device.
    xla::BorrowingLiteral literal(
        static_cast<const char*>(DMAHelper::base(cpu_tensor)),
        xla::ShapeUtil::MakeShape(shape.element_type(), shape.dimensions()));

    VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " "
            << xla_tensor->shaped_buffer().ToString();
    if (UseMultipleStreams() &&
        !transfer_manager_->CanShapedBufferBeAccessedNow(
            stream_->parent(), xla_tensor->shaped_buffer())) {
      // Initially wait for the compute stream so that memory allocations are
      // synchronized.
      host_to_device_stream_->ThenWaitFor(stream_.get());
    }

    TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
        host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer()));

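    // In multi-stream mode, record a definition event on the host-to-device
    // stream so that consumers on other streams can wait for this transfer to
    // complete before reading the buffer.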
    if (UseMultipleStreams()) {
      auto event = std::make_shared<se::Event>(stream_->parent());
      TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
      host_to_device_stream_->ThenRecordEvent(event.get());
      xla_tensor->ResetDefinitionEvent(std::move(event),
                                       host_to_device_stream_.get());
    }

    return OkStatus();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }

  // Create a reference to hold onto cpu_tensor until after the literal has
  // been transferred.
  TensorReference ref(*cpu_tensor);
  if (UseMultipleStreams()) {
    // Unref the host tensor when the transfer completes.
    // We don't defer the call to done() onto the stream here, and the reasons
    // why this is correct are subtle. We assume that:
    // a) all consumers of the device tensor will wait for its definition event.
    // b) if the tensor is destroyed, then the memory allocator will not hand
    //    out the same buffers until the transfer has completed.
    host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); });
    done(status);
  } else {
    host_to_device_stream_->ThenDoHostCallback([ref, done]() {
      ref.Unref();
      done(OkStatus());
    });
  }
}

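// Copies a device tensor back to the host. If no dedicated device-to-host
// stream was provided, a stream is borrowed from the backend's stream pool
// for the duration of the transfer.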
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                             absl::string_view tensor_name,
                                             Device* device, Tensor* cpu_tensor,
                                             StatusCallback done) {
  if (device_tensor->NumElements() == 0) {
    VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
    done(OkStatus());
    return;
  }
  VLOG(2) << "CopyDeviceTensorToCPU "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " " << device_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

  std::shared_ptr<se::Stream> device_to_host_stream;
  if (device_to_host_stream_) {
    device_to_host_stream = device_to_host_stream_;
  } else {
    stream_executor::port::StatusOr<xla::StreamPool::Ptr> ptr_or_status =
        client_->mutable_backend()->BorrowStream(
            stream_->parent()->device_ordinal());
    if (!ptr_or_status.status().ok()) {
      done(ptr_or_status.status());
      return;
    }
    device_to_host_stream =
        std::shared_ptr<se::Stream>(std::move(ptr_or_status.ValueOrDie()));
  }

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream.get());

  // The transfer manager requires the shape of the shaped buffer to match the
  // literal's shape, except for the layout. Set the literal to use
  // xla_tensor's shape, as it is derived from the cpu_tensor's shape using
  // shape_representation_fn_.
  xla::MutableBorrowingLiteral literal;
  TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(
      xla::LayoutUtil::GetWithDefaultLayout(
          xla_tensor->shaped_buffer().on_host_shape()),
      cpu_tensor, &literal));

  TensorReference ref(*device_tensor);
  const bool device_allows_sync_on_completion =
      device->AllowsSyncOnCompletion();
  // Explicitly capture device_to_host_stream to make sure the stream stays
  // alive until the transfer finishes.
  transfer_manager_->TransferLiteralFromDevice(
      device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
      [this, ref, xla_tensor, done, device_to_host_stream,
       device_allows_sync_on_completion](xla::Status status) {
        Status done_status = status;
        VLOG(2) << "Transfer from device as literal: "
                << xla_tensor->shaped_buffer().ToString();
        // For devices that don't allow sync on completion, the device
        // execution is deferred. We check the execution stream status here to
        // avoid wrong results from a failed stream being propagated to
        // following host-side ops.
        if (!device_allows_sync_on_completion) {
          done_status.Update(xla_tensor->RefreshStatusOfStreams());
        }
        done(done_status);
        ref.Unref();
        // If a stream is in a bad state, it gets deleted when it's returned to
        // the stream pool, i.e. when it leaves this scope. However, a stream
        // deleting itself in a host callback on itself can cause bad behavior
        // on some platforms. Release it on another stream to avoid that.
        if (!device_allows_sync_on_completion &&
            !device_to_host_stream->RefreshStatus().ok()) {
          auto status_or_new_stream = client_->mutable_backend()->BorrowStream(
              stream_->parent()->device_ordinal());
          if (status_or_new_stream.ok()) {
            status_or_new_stream.ValueOrDie()->ThenDoHostCallback(
                [device_to_host_stream] {});
          }
        }
      });
}

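// Returns the next device-to-device stream in round-robin order. The index is
// guarded by mu_ so that concurrent callers advance it consistently.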
se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
  DCHECK_GT(device_to_device_streams_.size(), 0);
  absl::MutexLock lock(&mu_);
  int stream = next_stream_;
  next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_stream(stream);
}

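// Enqueues `func` as a host callback on `stream`; it runs after all work
// previously enqueued on that stream has completed.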
Status XlaDeviceContext::ThenExecute(Device* device,
                                     stream_executor::Stream* stream,
                                     std::function<void()> func) {
  VLOG(2) << "XlaDeviceContext::ThenExecute";
  stream->ThenDoHostCallback(std::move(func));
  return OkStatus();
}

}  // namespace tensorflow