/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device_context.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace tensorflow {

// The allocator used for Tensors assigned to the XLA device.
XlaDeviceAllocator::XlaDeviceAllocator(
    stream_executor::StreamExecutor* stream_executor)
    : stream_executor_(stream_executor) {}

XlaDeviceAllocator::~XlaDeviceAllocator() = default;

string XlaDeviceAllocator::Name() { return "xla"; }

void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  // We always return an empty XlaTensor object, encoded as an opaque tagged
  // pointer. We can return an empty object and ignore num_bytes here because
  // we have control over all uses of this device tensor and can allocate
  // memory lazily on first use. This also lets us record the shape of the
  // allocated Tensor, which is useful when the device's tensor representation
  // differs from the host's.
  return XlaTensor::ToOpaquePointer(new XlaTensor());
}

void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
  delete XlaTensor::FromOpaquePointer(ptr);
}

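// Reports allocator statistics by forwarding the underlying StreamExecutor's
// allocator stats into TensorFlow's AllocatorStats. Returns std::nullopt if
// the StreamExecutor does not track allocator statistics.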
std::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
  std::optional<stream_executor::AllocatorStats> se_stats =
      stream_executor_->GetAllocatorStats();
  if (!se_stats) {
    return std::nullopt;
  }

  tensorflow::AllocatorStats tf_stats;
  tf_stats.num_allocs = se_stats->num_allocs;
  tf_stats.bytes_in_use = se_stats->bytes_in_use;
  tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use;
  tf_stats.largest_alloc_size = se_stats->largest_alloc_size;
  tf_stats.bytes_limit = se_stats->bytes_limit;
  tf_stats.bytes_reserved = se_stats->bytes_reserved;
  tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
  tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
  tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes;
  return tf_stats;
}

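// Clears the StreamExecutor's allocator statistics. All pending device
// activity is synchronized first so the reset does not race with in-flight
// work; returns false if synchronization fails.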
bool XlaDeviceAllocator::ClearStats() {
  if (!stream_executor_->SynchronizeAllActivity()) {
    return false;
  }
  return stream_executor_->ClearAllocatorStats();
}

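// A DeviceContext for XLA devices: it moves tensors between the host and the
// device through the client's TransferManager, using the compute,
// host-to-device, device-to-host, and device-to-device streams supplied here.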
XlaDeviceContext::XlaDeviceContext(
    std::shared_ptr<se::Stream> compute_stream,
    std::shared_ptr<se::Stream> host_to_device_stream,
    std::shared_ptr<se::Stream> device_to_host_stream,
    std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
    xla::LocalClient* client,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    thread::ThreadPool* thread_pool)
    : stream_(std::move(compute_stream)),
      host_to_device_stream_(std::move(host_to_device_stream)),
      device_to_host_stream_(std::move(device_to_host_stream)),
      device_to_device_streams_(std::move(device_to_device_streams)),
      client_(client),
      transfer_manager_(client->backend().transfer_manager()),
      shape_determination_fns_(std::move(shape_determination_fns)),
      thread_pool_(thread_pool) {
  CHECK(host_to_device_stream_ != nullptr);
  CHECK(stream_ != nullptr);
}

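// Same-device (XLA->XLA) copies are not supported; the callback is invoked
// immediately with an Unimplemented error.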
void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
                                              Device* device,
                                              Tensor* output_tensor,
                                              StatusCallback done) const {
  done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
}

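// Copies a host tensor to the XLA device: allocates the destination
// XlaTensor's shaped buffer, then transfers the host data as an XLA literal
// on the host-to-device stream.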
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                             Device* device,
                                             Tensor* device_tensor,
                                             StatusCallback done,
                                             bool sync_dst_compute) const {
  if (cpu_tensor->NumElements() == 0) {
    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
    done(OkStatus());
    return;
  }

  VLOG(2) << "CopyCPUTensorToDevice " << this << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " " << cpu_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  CHECK(xla_tensor);

  XlaLayoutPreference layout_preference =
      shape_determination_fns_.layout_preference_fn(
          device_tensor->shape(), device_tensor->dtype(), std::nullopt);
  Status status = [&]() -> Status {
    TF_ASSIGN_OR_RETURN(xla::Shape shape,
                        shape_determination_fns_.shape_representation_fn(
                            device_tensor->shape(), device_tensor->dtype(),
                            /*fast_mem=*/false, layout_preference));

    // The device tensor should always be fresh.
    TF_RET_CHECK(!xla_tensor->has_shaped_buffer());

    TF_RETURN_IF_ERROR(
        xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
                                         stream_->parent()->device_ordinal()));

    // The cpu_tensor and the literal created here hold the host tensor's data
    // in descending (major-to-minor) layout. That layout may differ from the
    // layout of device_tensor (though the logical shape must be the same); the
    // transfer_manager is responsible for transposing the data as needed when
    // transferring it to the device.
    xla::BorrowingLiteral literal(
        static_cast<const char*>(DMAHelper::base(cpu_tensor)),
        xla::ShapeUtil::MakeShape(shape.element_type(), shape.dimensions()));

    VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " "
            << xla_tensor->shaped_buffer().ToString();
    if (UseMultipleStreams() &&
        !transfer_manager_->CanShapedBufferBeAccessedNow(
            stream_->parent(), xla_tensor->shaped_buffer())) {
      // Initially wait for the compute stream so that memory allocations are
      // synchronized.
      host_to_device_stream_->ThenWaitFor(stream_.get());
    }

    TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
        host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer()));

    if (UseMultipleStreams()) {
      auto event = std::make_shared<se::Event>(stream_->parent());
      TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
      host_to_device_stream_->ThenRecordEvent(event.get());
      xla_tensor->ResetDefinitionEvent(std::move(event),
                                       host_to_device_stream_.get());
    }

    return OkStatus();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }

  // Create a reference to hold onto cpu_tensor until after the literal has
  // been transferred.
  TensorReference ref(*cpu_tensor);
  if (UseMultipleStreams()) {
    // Unref the host tensor when the transfer completes.
    // We don't defer the call to done() onto the stream here, and the reasons
    // why this is correct are subtle. We assume that:
    // a) all consumers of the device tensor will wait for its definition event.
    // b) if the tensor is destroyed, then the memory allocator will not hand
    //    out the same buffers until the transfer has completed.
    host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); });
    done(status);
  } else {
    host_to_device_stream_->ThenDoHostCallback([ref, done]() {
      ref.Unref();
      done(OkStatus());
    });
  }
}

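// Copies a device tensor back to the host by transferring its shaped buffer
// as a literal on the device-to-host stream. If no dedicated device-to-host
// stream was provided at construction, a stream is borrowed from the client's
// backend for the duration of the transfer.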
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                             absl::string_view tensor_name,
                                             Device* device, Tensor* cpu_tensor,
                                             StatusCallback done) {
  if (device_tensor->NumElements() == 0) {
    VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
    done(OkStatus());
    return;
  }
  VLOG(2) << "CopyDeviceTensorToCPU "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " " << device_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

  std::shared_ptr<se::Stream> device_to_host_stream;
  if (device_to_host_stream_) {
    device_to_host_stream = device_to_host_stream_;
  } else {
    stream_executor::port::StatusOr<xla::StreamPool::Ptr> ptr_or_status =
        client_->mutable_backend()->BorrowStream(
            stream_->parent()->device_ordinal());
    if (!ptr_or_status.status().ok()) {
      done(ptr_or_status.status());
      return;
    }
    device_to_host_stream =
        std::shared_ptr<se::Stream>(std::move(ptr_or_status.ValueOrDie()));
  }

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream.get());

  // The transfer manager requires the shaped buffer's shape to match the
  // literal's shape except for the layout. Set the literal to use xla_tensor's
  // shape, which is derived from the cpu_tensor's shape via the
  // shape_representation_fn in shape_determination_fns_.
  xla::MutableBorrowingLiteral literal;
  TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(
      xla::LayoutUtil::GetWithDefaultLayout(
          xla_tensor->shaped_buffer().on_host_shape()),
      cpu_tensor, &literal));
253 
254   TensorReference ref(*device_tensor);
255   const bool device_allows_sync_on_completion =
256       device->AllowsSyncOnCompletion();
257   // Explicitly capture device_to_host_stream to make sure the stream is alive
258   // before the transfer finishes.
259   transfer_manager_->TransferLiteralFromDevice(
260       device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
261       [this, ref, xla_tensor, done, device_to_host_stream,
262        device_allows_sync_on_completion](xla::Status status) {
263         Status done_status = status;
264         VLOG(2) << "Transfer from device as literal: "
265                 << xla_tensor->shaped_buffer().ToString();
266         // For devices don't allow sync on completion, the device execution is
267         // deferred. We check the execution stream status here to avoid wrong
268         // results from a failed stream being propagated to following
269         // host-side ops.
270         if (!device_allows_sync_on_completion) {
271           done_status.Update(xla_tensor->RefreshStatusOfStreams());
272         }
273         done(done_status);
274         ref.Unref();
275         // If a stream is in a bad state, it gets deleted when it's returned to
276         // the stream pool, i.e. when it leaves this scope. However, a stream
277         // deleting itself in a host callback on itself can cause bad behaviors
278         // on some platforms. Releasing it in another stream to avoid that.
279         if (!device_allows_sync_on_completion &&
280             !device_to_host_stream->RefreshStatus().ok()) {
281           auto status_or_new_stream = client_->mutable_backend()->BorrowStream(
282               stream_->parent()->device_ordinal());
283           if (status_or_new_stream.ok()) {
284             status_or_new_stream.ValueOrDie()->ThenDoHostCallback(
285                 [device_to_host_stream] {});
286           }
287         }
288       });
289 }
290 
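// Returns a device-to-device stream chosen in round-robin order; the
// round-robin index is protected by mu_.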
se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
  DCHECK_GT(device_to_device_streams_.size(), 0);
  absl::MutexLock lock(&mu_);
  int stream = next_stream_;
  next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_stream(stream);
}

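// Enqueues `func` as a host callback on `stream`; it runs after all work
// previously enqueued on that stream has completed.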
Status XlaDeviceContext::ThenExecute(Device* device,
                                     stream_executor::Stream* stream,
                                     std::function<void()> func) {
  VLOG(2) << "XlaDeviceContext::ThenExecute";
  stream->ThenDoHostCallback(std::move(func));
  return OkStatus();
}

}  // namespace tensorflow