/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or
// GPU), under different names (e.g., XLA_CPU or XLA_GPU).

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

#include <set>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Given a tensor, sets `*shape` to the shape of the tensor's
  // representation on device, fully padded. On error, the contents of
  // `*shape` are undefined.
  typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
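  // Illustrative sketch of a PaddedShapeFn that reports the logical
  // on-device shape with no padding (the documented behavior of
  // DefaultPaddedShapeFn, declared at the bottom of this header). The use
  // of DataTypeToPrimitiveType and TensorShapeToXLAShape here is an
  // assumption made for the example, not a requirement of the interface:
  //
  //   Status NoPaddingShapeFn(const Tensor& tensor, xla::Shape* shape) {
  //     xla::PrimitiveType dtype;
  //     TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tensor.dtype(), &dtype));
  //     *shape = TensorShapeToXLAShape(dtype, tensor.shape());
  //     return OkStatus();
  //   }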
  // Wrapper class to store metadata about the XlaDevice, where it can be
  // retrieved, e.g., when lazily creating the XlaCompilationCache device.
  class Metadata {
   public:
    Metadata(int device_ordinal, se::Platform* platform,
             const DeviceType& device_type,
             std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
                 shape_determination_fns,
             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

    // The index of the device on this host.
    int device_ordinal() const;

    se::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;
    const XlaShapeLayoutHelpers::ShapeDeterminationFns&
    default_shape_determination_fns() const {
      return shape_determination_fns_.at(0);
    }
    const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

    bool UseMultipleStreams() const { return use_multiple_streams_; }

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    se::Platform* platform_;  // Not owned.
    std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
        shape_determination_fns_;
    PaddedShapeFn padded_shape_fn_;
    const bool use_multiple_streams_;

    TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
  };

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `ctx`.
  static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `ctx`.
  static Status GetMetadata(OpKernelConstruction* ctx,
                            const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata of the XLA device `device`.
  static Status GetMetadataFromDevice(DeviceBase* device,
                                      const XlaDevice::Metadata** metadata);

  struct Options {
    // The StreamExecutor platform. Not owned. Must be non-null.
    se::Platform* platform = nullptr;

    // The device name's prefix (e.g., "/task:7").
    string device_name_prefix;

    // The name of the XLA device (e.g., "XLA_CPU").
    string device_name;

    // The number of the device.
    int device_ordinal = -1;

    // The name of the compilation device (e.g., "XLA_CPU_JIT").
    string compilation_device_name;

    // If `use_multiple_streams` is true, separate streams are created for
    // compute, host-to-device, and device-to-host communication.
    bool use_multiple_streams = false;

    // If true, XLA devices with the same device ordinal will share the same
    // compute stream. Otherwise, each XLA device has its own compute
    // stream.
    bool use_global_compute_stream = false;

    // A vector of ShapeDeterminationFns (each a bundle of a
    // LayoutSelectionFn and a ShapeRepresentationFn). Each bundle describes
    // how the on-host shapes of a) arguments and return values, for entry
    // computations, and b) variables, for all computations, should be
    // represented in XLA. Parameters/return values will be shaped according
    // to the function pair, and reshaped back to/from their declared shapes
    // for computations. Must be non-empty.
    std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
        shape_determination_fns;

    // If padded_shape_fn is empty, a default implementation that returns
    // the logical on-device shape without padding is used.
    PaddedShapeFn padded_shape_fn;

    // Set of devices to use. This controls which of the devices on the
    // given platform will have resources allocated. For GPUs this will be
    // filled from the visible_gpu_devices list in the session
    // configuration.
    std::optional<std::set<int>> allowed_devices;
  };
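  // Illustrative sketch of populating Options for a CPU-backed XLA device.
  // The platform lookup, the name strings, and the default-constructed
  // shape determination bundle are assumptions made for the example, not
  // values mandated by this header:
  //
  //   XlaDevice::Options options;
  //   options.platform =
  //       se::MultiPlatformManager::PlatformWithName("Host").value();
  //   options.device_name_prefix = "/job:localhost/replica:0/task:0";
  //   options.device_name = "XLA_CPU";
  //   options.device_ordinal = 0;
  //   options.compilation_device_name = "XLA_CPU_JIT";
  //   options.use_multiple_streams = false;
  //   options.shape_determination_fns = {
  //       XlaShapeLayoutHelpers::ShapeDeterminationFns{}};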
  // Creates a new XLA Device.
  XlaDevice(const SessionOptions& session_options, const Options& options);
  ~XlaDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override
      TF_LOCKS_EXCLUDED(mu_);
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override;
  void Sync(const DoneCallback& done) override;

  Status TryGetDeviceContext(DeviceContext** out_context) override
      TF_LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(XlaDeviceContext* device_context,
                             const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor);

  const Metadata& metadata() { return xla_metadata_; }

  // Ensures the DeviceContext associated with this XlaDevice is created and
  // valid (i.e., all streams are ok). If any state is not valid, a new
  // DeviceContext will be created.
  //
  // TODO(b/111859745): The Eager context needs to call this method to
  // recover from failures.
  Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);

  // Two convenience methods to get the underlying device context.
  // Gets the default device context, created from the first shape
  // determination functions.
  StatusOr<XlaDeviceContext*> GetDeviceContextDefault();
  // Gets the device context at the given index.
  StatusOr<XlaDeviceContext*> GetDeviceContextWithIndex(int index);

  // Instructs this XlaDevice to set an AcceleratorDeviceInfo, which holds
  // extra information for GPU and TPU devices.
  Status UseAcceleratorDeviceInfo() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to return `sync_on_completion` from
  // AllowsSyncOnCompletion().
  void SetAllowsSyncOnCompletion(bool sync_on_completion)
      TF_LOCKS_EXCLUDED(mu_);
  bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);

  // Installs an error-handling callback that is invoked when RefreshStatus
  // sees !status.ok().
  void SetHandleDeviceErrorCallback(std::function<Status()> callback);

  Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);
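  // Illustrative sketch of querying the device from inside an op kernel
  // running on an XLA device. `ctx` stands for the OpKernelContext passed
  // to Compute(); the sketch reflects typical usage assumed for the
  // example, not a prescribed pattern:
  //
  //   const XlaDevice::Metadata* metadata = nullptr;
  //   OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  //   const int ordinal = metadata->device_ordinal();
  //   if (metadata->UseMultipleStreams()) {
  //     // Transfers and compute run on separate streams.
  //   }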
 private:
  StatusOr<xla::LocalClient*> GetOrCreateClient() const;
  Allocator* GetAllocatorLocked(AllocatorAttributes attr)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
                              std::shared_ptr<se::Stream>* stream,
                              bool* stream_was_changed)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a vector of device contexts, ordered to match the shape
  // determination functions given in XlaDevice::Options.
  StatusOr<std::vector<XlaDeviceContext*>> GetDeviceContextLocked()
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Handles errors when RefreshStatus sees !status.ok().
  Status HandleDeviceError();

  mutable mutex mu_;
  // The metadata of this XlaDevice.
  const Metadata xla_metadata_;
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType jit_device_name_;
  // The platform for this device.
  se::Platform* const platform_;  // Not owned.
  // Intra-op threads to spawn (from SessionOptions).
  const int intra_op_parallelism_threads_;
  // Memory allocator associated with this device.
  Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr;  // Not owned.

  // Stream associated with this device. Operations enqueued on this
  // stream are executed on the device. Operations include data
  // copying back and forth between CPU and the device, and
  // computations enqueued by XLA.
  std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
  // If false, only stream_ is valid and all computation and transfers use
  // stream_. If true, computation is performed by stream_, transfers are
  // performed by the host_to_device/device_to_device streams, and a stream
  // is borrowed for each device-to-host transfer.
  const bool use_multiple_streams_;
  // If use_multiple_streams_, host-to-device transfers are performed using
  // this stream.
  std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
  // If use_multiple_streams_, transfers between different devices are
  // performed using these streams.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
      TF_GUARDED_BY(mu_);

  // See the comments in Options.
  std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
      shape_determination_fns_;

  // The device contexts accessed by all users of the XlaDevice, set by
  // calls to EnsureDeviceContextOk. The number of device contexts is based
  // on the number of shape determination functions in XlaDevice::Options.
  // If accelerator_device_info_ is non-null, this pointer is also filled in
  // to that struct. XlaDeviceContext is a ref-counted object.
  std::vector<XlaDeviceContext*> device_contexts_ TF_GUARDED_BY(mu_);

  // Holds extra information for GPU and TPU devices, e.g., the device
  // context.
  bool use_accelerator_device_info_ TF_GUARDED_BY(mu_) = false;
  std::unique_ptr<DeviceBase::AcceleratorDeviceInfo> accelerator_device_info_
      TF_GUARDED_BY(mu_);

  // Thread pool used for running closures.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // True if the device allows XlaDevice::Sync to be called on completion
  // regardless of status.
  bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;

  // A callback that will be invoked when RefreshStatus sees a status error.
  std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);

  // Set of devices to use. This controls which of the devices on the given
  // platform will have resources allocated. For GPUs this will be filled
  // from the visible_gpu_devices list in the session configuration.
  std::optional<std::set<int>> allowed_devices_;

  const bool use_global_compute_stream_;

  // A static vector indexed by device_ordinal, describing the global
  // compute stream used by each XLA device. It is only used if
  // `use_global_compute_stream` in `XlaDevice::Options` is set to true.
  static mutex global_mu_;
  static std::vector<std::shared_ptr<se::Stream>>* global_compute_streams_
      TF_GUARDED_BY(global_mu_);
};

// Builds OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of an
// XlaDeviceOpRegistrations object that encapsulates the kernel
// registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);
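// Illustrative sketch (not part of this header's API) of how a backend
// might register kernels and create an XlaDevice from a DeviceFactory. The
// factory class name and the elided option setup are assumptions made for
// the example:
//
//   class ExampleXlaDeviceFactory : public DeviceFactory {
//    public:
//     Status CreateDevices(
//         const SessionOptions& session_options, const string& name_prefix,
//         std::vector<std::unique_ptr<Device>>* devices) override {
//       static XlaDeviceOpRegistrations* registrations =
//           RegisterXlaDeviceKernels("XLA_CPU", "XLA_CPU_JIT");
//       (void)registrations;
//
//       XlaDevice::Options options;
//       ...  // Populate options as sketched near the Options struct above.
//       devices->push_back(
//           std::make_unique<XlaDevice>(session_options, options));
//       return OkStatus();
//     }
//   };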
// The default PaddedShapeFn: sets `*shape` to the logical on-device shape
// of `tensor` without padding.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_