xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/jit/xla_device_context.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
17 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
18 
19 #include <memory>
20 
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/compiler/jit/xla_tensor.h"
23 #include "tensorflow/compiler/tf2xla/layout_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
25 #include "tensorflow/compiler/xla/client/global_data.h"
26 #include "tensorflow/compiler/xla/client/local_client.h"
27 #include "tensorflow/core/framework/allocator.h"
28 #include "tensorflow/core/framework/device_base.h"
29 #include "tensorflow/core/lib/core/status.h"
30 
31 namespace tensorflow {
32 
33 // The allocator used for Tensors assigned to the XLA device. The allocator
34 // ignores the alignment and size of the request and always returns a new,
35 // empty, XlaTensor.
36 class XlaDeviceAllocator : public Allocator {
37  public:
38   XlaDeviceAllocator(se::StreamExecutor* stream_executor);
39   ~XlaDeviceAllocator() override;
40 
41   string Name() override;
42 
43   void* AllocateRaw(size_t alignment, size_t num_bytes) override;
44   void DeallocateRaw(void* ptr) override;
45   std::optional<AllocatorStats> GetStats() override;
46   bool ClearStats() override;
47 
48  private:
49   // The stream executor of the device.
50   se::StreamExecutor* stream_executor_;
51 };
52 
// Helper class for managing data transfers between host and XLA devices.
class XlaDeviceContext : public DeviceContext {
 public:
  // Streams are shared with (not owned exclusively by) this context; see the
  // member comments below for which may be null or aliased to compute_stream.
  explicit XlaDeviceContext(
      std::shared_ptr<se::Stream> compute_stream,
      std::shared_ptr<se::Stream> host_to_device_stream,
      std::shared_ptr<se::Stream> device_to_host_stream,
      std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
      xla::LocalClient* client,
      XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
      thread::ThreadPool* thread_pool);

  // DeviceContext transfer hooks. Each takes a StatusCallback `done` that the
  // implementation is expected to invoke when the (possibly asynchronous)
  // copy completes — bodies live in the .cc file; verify there.
  void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                             Tensor* device_tensor, StatusCallback done,
                             bool sync_dst_compute) const override;
  void CopyDeviceTensorToCPU(const Tensor* device_tensor,
                             absl::string_view tensor_name, Device* device,
                             Tensor* cpu_tensor, StatusCallback done) override;
  void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device,
                              Tensor* output_tensor,
                              StatusCallback done) const override;

  // Accessors for the XLA client and the streams this context transfers on.
  xla::LocalClient* client() const { return client_; }
  se::Stream* stream() const override { return stream_.get(); }
  se::Stream* host_to_device_stream() const {
    return host_to_device_stream_.get();
  }
  // Throws std::out_of_range (via vector::at) if `index` is not a valid
  // device-to-device stream index.
  se::Stream* device_to_device_stream(int index) const {
    return device_to_device_streams_.at(index).get();
  }
  xla::TransferManager* transfer_manager() const { return transfer_manager_; }
  const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns()
      const {
    return shape_determination_fns_;
  }

  // Returns a device-to-device stream, in round-robin fashion (cursor is
  // next_stream_, guarded by mu_).
  se::Stream* GetDeviceToDeviceStream();

  Status ThenExecute(Device* device, stream_executor::Stream* stream,
                     std::function<void()> func) override;

 private:
  // True when transfers use a dedicated stream distinct from the compute
  // stream (shared_ptr comparison, i.e. pointer identity of the streams).
  bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }

  // The main compute stream of the device, used to synchronize the transfer
  // streams if they are set.
  std::shared_ptr<se::Stream> stream_;
  // The stream to use for transferring data from host to device. Can be
  // identical to stream_, but must not be nullptr.
  std::shared_ptr<se::Stream> host_to_device_stream_;
  // The stream to use for transferring data from device to host. Can be
  // identical to stream_. If nullptr, borrow a stream from backend for each
  // transfer request to support out-of-order requests.
  std::shared_ptr<se::Stream> device_to_host_stream_;
  // Streams to use for transferring data directly between different devices,
  // e.g., over NVLINK.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;

  // For the underlying memory allocator and XLA's TransferManager.
  xla::LocalClient* client_;
  // Transfer manager, for marshalling data to and from the device.
  xla::TransferManager* transfer_manager_;

  XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns_;

  // Thread pool used for running closures
  thread::ThreadPool* thread_pool_;

  // Guards next_stream_, the round-robin cursor used by
  // GetDeviceToDeviceStream().
  absl::Mutex mu_;
  int next_stream_ TF_GUARDED_BY(mu_) = 0;
};
125 
126 }  // namespace tensorflow
127 
128 #endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
129