/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).
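//
// As an illustrative sketch only (the op name and device string below are
// assumptions, not part of this header): a graph node can be assigned to an
// XLA device by setting its device field when building the NodeDef, e.g.
//
//   NodeDef node;
//   TF_CHECK_OK(NodeDefBuilder("add_op", "Add")
//                   .Input("a", 0, DT_FLOAT)
//                   .Input("b", 0, DT_FLOAT)
//                   .Device("/device:XLA_CPU:0")
//                   .Finalize(&node));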

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include <set>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Given a tensor, sets `*shape` to the shape of the tensor's representation
  // on device, fully padded. On error, the contents of `*shape` are
  // undefined.
  typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
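
  // As an illustrative sketch only: a PaddedShapeFn that defers to the default
  // implementation declared at the bottom of this header could be written as
  //
  //   PaddedShapeFn padded_shape_fn = [](const Tensor& tensor,
  //                                      xla::Shape* shape) -> Status {
  //     return DefaultPaddedShapeFn(tensor, shape);
  //   };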

  // Wrapper class to store metadata about the XlaDevice, where it can be
  // retrieved e.g., when lazily creating the XlaCompilationCache device.
  class Metadata {
   public:
    Metadata(int device_ordinal, se::Platform* platform,
             const DeviceType& device_type,
             std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
                 shape_determination_fns,
             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

    // The index of the device on this host.
    int device_ordinal() const;

    se::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;
    const XlaShapeLayoutHelpers::ShapeDeterminationFns&
    default_shape_determination_fns() const {
      return shape_determination_fns_.at(0);
    }
    const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

    bool UseMultipleStreams() const { return use_multiple_streams_; }

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    se::Platform* platform_;  // Not owned.
    std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
        shape_determination_fns_;
    PaddedShapeFn padded_shape_fn_;
    const bool use_multiple_streams_;

    TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
  };

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
  static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
  static Status GetMetadata(OpKernelConstruction* ctx,
                            const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `device`.
  static Status GetMetadataFromDevice(DeviceBase* device,
                                      const XlaDevice::Metadata** metadata);
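
  // As an illustrative sketch only (MyKernel is a hypothetical op kernel;
  // error handling beyond OP_REQUIRES_OK is elided), a kernel running on an
  // XLA device might look up the device metadata like this:
  //
  //   void MyKernel::Compute(OpKernelContext* ctx) {
  //     const XlaDevice::Metadata* metadata = nullptr;
  //     OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  //     se::Platform* platform = metadata->platform();  // Not owned.
  //     // ... use the platform, device ordinal, etc.
  //   }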

  struct Options {
    // The StreamExecutor platform. Not owned. Must be non-null.
    se::Platform* platform = nullptr;

    // The device name's prefix (e.g., "/task:7")
    string device_name_prefix;

    // The name of the XLA device (e.g., "XLA_CPU")
    string device_name;

    // The number of the device.
    int device_ordinal = -1;

    // The name of the compilation device (e.g., "XLA_CPU_JIT").
    string compilation_device_name;

    // If 'use_multiple_streams' is true, we create separate streams for
    // compute, host-to-device, and device-to-host communication.
    bool use_multiple_streams = false;

    // If true, XLA devices with the same device ordinal will share the same
    // compute stream. Otherwise each XLA device will have its own compute
    // stream.
    bool use_global_compute_stream = false;

    // A vector of ShapeDeterminationFns (i.e., a bundle of LayoutSelectionFn,
    // ShapeRepresentationFn). Each bundle describes how the on-host shapes of
    // a) arguments and return values, for entry computations, and b) variables,
    // for all computations, should be represented in XLA. Parameters/return
    // values will be shaped according to the function pair, and reshaped back
    // to/from their declared shapes for computations. Must be non-empty.
    std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
        shape_determination_fns;

    // If padded_shape_fn is empty, a default implementation that returns
    // the logical on-device shape without padding is used.
    PaddedShapeFn padded_shape_fn;

    // Set of devices to use. This controls which of the devices on the given
    // platform will have resources allocated. For GPUs this will be
    // filled from the visible_gpu_devices list in the session configuration.
    std::optional<std::set<int>> allowed_devices;
  };
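
  // As an illustrative sketch only (`platform`, `shape_determination_fns`, and
  // `session_options` are assumed to be obtained elsewhere, and the name
  // prefix is hypothetical), a device factory might populate Options and
  // construct an XlaDevice like this:
  //
  //   XlaDevice::Options options;
  //   options.platform = platform;  // se::Platform*, not owned.
  //   options.device_name_prefix = "/job:localhost/replica:0/task:0";
  //   options.device_name = "XLA_CPU";
  //   options.device_ordinal = 0;
  //   options.compilation_device_name = "XLA_CPU_JIT";
  //   options.shape_determination_fns = shape_determination_fns;  // Non-empty.
  //   auto device = std::make_unique<XlaDevice>(session_options, options);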

  // Creates a new XLA Device.
  XlaDevice(const SessionOptions& session_options, const Options& options);
  ~XlaDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override
      TF_LOCKS_EXCLUDED(mu_);
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override;
  void Sync(const DoneCallback& done) override;

  Status TryGetDeviceContext(DeviceContext** out_context) override
      TF_LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(XlaDeviceContext* device_context,
                             const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor);

  const Metadata& metadata() { return xla_metadata_; }

  // Ensures the DeviceContext associated with this XlaDevice is created and
  // valid (i.e. all streams are ok). If any state is not valid, a new
  // DeviceContext will be created.
  //
  // TODO(b/111859745): The Eager context needs to call this method to recover
  // from failures.
  Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);

  // Two convenience methods to get the underlying device context.
  // Get the default device context, created by the first
  // shape_representation_fn.
  StatusOr<XlaDeviceContext*> GetDeviceContextDefault();
  // Get the device context at the given index.
  StatusOr<XlaDeviceContext*> GetDeviceContextWithIndex(int index);

  // Instructs this XlaDevice to set an AcceleratorDeviceInfo, which holds
  // extra information for GPU and TPU devices.
  Status UseAcceleratorDeviceInfo() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to return 'sync_on_completion' for
  // AllowsSyncOnCompletion().
  void SetAllowsSyncOnCompletion(bool sync_on_completion)
      TF_LOCKS_EXCLUDED(mu_);
  bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);

  // Installs an error handling callback invoked when RefreshStatus sees
  // !status.ok().
  void SetHandleDeviceErrorCallback(std::function<Status()> callback);

  Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);

 private:
  StatusOr<xla::LocalClient*> GetOrCreateClient() const;
  Allocator* GetAllocatorLocked(AllocatorAttributes attr)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
                              std::shared_ptr<se::Stream>* stream,
                              bool* stream_was_changed)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a vector of device contexts, ordered by the sequence of the given
  // shape_representation_fns.
  StatusOr<std::vector<XlaDeviceContext*>> GetDeviceContextLocked()
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Handles errors when RefreshStatus sees !status.ok().
  Status HandleDeviceError();

  mutable mutex mu_;
  // The metadata of this XlaDevice.
  const Metadata xla_metadata_;
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType jit_device_name_;
  // The platform for this device.
  se::Platform* const platform_;  // Not owned.
  // Intra-op threads to spawn (from SessionOptions).
  const int intra_op_parallelism_threads_;
  // Memory allocator associated with this device.
  Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr;  // Not owned.

  // Stream associated with this device. Operations enqueued on this
  // stream are executed on the device. Operations include data
  // copying back and forth between CPU and the device, and
  // computations enqueued by XLA.
  std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
  // If false, only stream_ is valid and all computation and transfers use
  // stream_. If true, computation is performed by stream_, transfers are
  // performed by the host_to_device/device_to_device streams, and a stream is
  // borrowed for each device-to-host transfer.
  const bool use_multiple_streams_;
  // If use_multiple_streams_, host-to-device transfers are performed using
  // this stream.
  std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
  // If use_multiple_streams_, transfers between different devices are
  // performed using these streams.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
      TF_GUARDED_BY(mu_);

  // See comments in Options.
  std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
      shape_determination_fns_;

  // A list of the device contexts accessed by all users of the XlaDevice, set
  // by calls to EnsureDeviceContextOk. The number of device contexts is based
  // on the number of shape representation functions in XlaDevice::Options. If
  // accelerator_device_info_ is non-null, this pointer is also filled in to
  // that struct. XlaDeviceContext is a ref-counted object.
  std::vector<XlaDeviceContext*> device_contexts_ TF_GUARDED_BY(mu_);

  // Holds extra information for GPU and TPU devices, e.g. the device context.
  bool use_accelerator_device_info_ TF_GUARDED_BY(mu_) = false;
  std::unique_ptr<DeviceBase::AcceleratorDeviceInfo> accelerator_device_info_
      TF_GUARDED_BY(mu_);

  // Thread pool used for running closures.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // True if the device allows XlaDevice::Sync to be called on completion
  // regardless of status.
  bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;

  // A callback that will be invoked when RefreshStatus sees a status error.
  std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);

  // Set of devices to use. This controls which of the devices on the given
  // platform will have resources allocated. For GPUs this will be
  // filled from the visible_gpu_devices list in the session configuration.
  std::optional<std::set<int>> allowed_devices_;

  const bool use_global_compute_stream_;

  // A static vector indexed by device_ordinal, describing the global
  // compute stream used by each XLA device. It is only used if
  // `use_global_compute_stream` in `XlaDevice::Options` is set to true.
  static mutex global_mu_;
  static std::vector<std::shared_ptr<se::Stream>>* global_compute_streams_
      TF_GUARDED_BY(global_mu_);
};

// Builds OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of an XlaDeviceOpRegistrations
// object that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);
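
// As an illustrative sketch only (the device names here come from the
// examples earlier in this header; the caller is hypothetical), a device
// factory that registers the XLA_CPU device would typically also register its
// JIT kernels:
//
//   XlaDeviceOpRegistrations* registrations =
//       RegisterXlaDeviceKernels("XLA_CPU", "XLA_CPU_JIT");
//   // The caller owns `registrations` and must keep it alive for as long as
//   // the kernel registrations are needed.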

Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_