/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>

#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  }

  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
  *shape = shaped_buffer.on_device_shape();
  return OkStatus();
}

// Caches an XlaDeviceAllocator per <backend, device ordinal> pair. An
// XlaDeviceAllocator is created on demand and is associated with an
// XlaDevice. It outlives the device itself (for instance, the buffer
// backing a tensor holds a pointer to the allocator for book-keeping,
// and this buffer can outlast the device).
class XlaDeviceAllocatorState {
 public:
  // Creates or returns a cached XlaDeviceAllocator for a given
  // backend and device_ordinal.
  static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
      const xla::Backend* backend, int device_ordinal);

 private:
  // Returns the singleton instance of XlaDeviceAllocatorState.
  static XlaDeviceAllocatorState& Singleton();
  XlaDeviceAllocatorState();
  ~XlaDeviceAllocatorState();

  mutex allocator_mutex_;  // Guards the singleton allocator state.
  std::unordered_map<std::pair<const xla::Backend*, int>,
                     std::unique_ptr<XlaDeviceAllocator>,
                     hash<std::pair<const xla::Backend*, int>>>
      allocators_ TF_GUARDED_BY(allocator_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

/* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
  static auto a = new XlaDeviceAllocatorState;
  return *a;
}

XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;

XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
    const xla::Backend* backend, int device_ordinal) {
  XlaDeviceAllocatorState& state = Singleton();
  mutex_lock lock(state.allocator_mutex_);

  auto it = state.allocators_.find({backend, device_ordinal});
  if (it != state.allocators_.end()) {
    return it->second.get();
  }

  std::unique_ptr<XlaDeviceAllocator> alloc =
      std::make_unique<XlaDeviceAllocator>(
          backend->stream_executors()[device_ordinal]);
  XlaDeviceAllocator* alloc_ptr = alloc.get();
  state.allocators_[{backend, device_ordinal}] = std::move(alloc);
  return alloc_ptr;
}

namespace {

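// Builds the DeviceAttributes advertised for an XLA device: the fully
// qualified device name, the device type, a nominal 16GiB memory limit
// (16ULL << 30 bytes), and a default DeviceLocality.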
static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
                                                 const string& device_name,
                                                 int device_ordinal) {
  return Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      absl::StrCat("device: ", device_name, " device"));
}

}  // namespace

XlaDevice::Metadata::Metadata(
    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
    std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
        shape_determination_fns,
    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform),
      shape_determination_fns_(std::move(shape_determination_fns)),
      padded_shape_fn_(std::move(padded_shape_fn)),
      use_multiple_streams_(use_multiple_streams) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

se::Platform* XlaDevice::Metadata::platform() const { return platform_; }

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

/*static*/ Status XlaDevice::GetMetadataFromDevice(
    DeviceBase* device, const XlaDevice::Metadata** metadata) {
  *metadata = nullptr;
  XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
  if (xla_device == nullptr) {
    return errors::Internal(
        "Cannot get XLA metadata from non-XLA device \"", device->name(),
        "\". GetMetadata must only be called on an XLA device. Either an "
        "internal bug has been triggered, or an XLA-specific op has been "
        "placed on the wrong device.");
  }
  *metadata = &(xla_device->xla_metadata_);
  return OkStatus();
}

/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

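// Backing storage for the optional global compute stream: when
// use_global_compute_stream is set in XlaDevice::Options, XlaDevice instances
// configured with the same device ordinal share the stream stored at that
// index in global_compute_streams_, guarded by global_mu_.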
/* static */ mutex XlaDevice::global_mu_(LINKER_INITIALIZED);
/* static */ std::vector<std::shared_ptr<se::Stream>>*
    XlaDevice::global_compute_streams_ =
        new std::vector<std::shared_ptr<se::Stream>>;

XlaDevice::XlaDevice(const SessionOptions& session_options,
                     const Options& options)
    : LocalDevice(session_options,
                  BuildXlaDeviceAttributes(options.device_name_prefix,
                                           options.device_name,
                                           options.device_ordinal)),
      xla_metadata_(options.device_ordinal, options.platform,
                    DeviceType(options.compilation_device_name),
                    options.shape_determination_fns,
                    options.padded_shape_fn ? options.padded_shape_fn
                                            : DefaultPaddedShapeFn,
                    options.use_multiple_streams),
      device_ordinal_(options.device_ordinal),
      jit_device_name_(options.compilation_device_name),
      platform_(options.platform),
      intra_op_parallelism_threads_(
          session_options.config.intra_op_parallelism_threads()),
      use_multiple_streams_(options.use_multiple_streams),
      shape_determination_fns_(options.shape_determination_fns),
      allowed_devices_(options.allowed_devices),
      use_global_compute_stream_(options.use_global_compute_stream) {
  if (options.shape_determination_fns.empty()) {
    LOG(ERROR) << "shape_determination_fns must be non-empty.";
  }
  VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
          << options.device_ordinal << " " << this;
  VLOG(1) << "XlaDevice options: use_multiple_streams: "
          << options.use_multiple_streams << " use_global_compute_stream: "
          << options.use_global_compute_stream;
  thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
                                            /*num_threads=*/1));

  // We have multiple device to device streams to allow for some concurrency
  // between transfers. The particular value of '4' is chosen fairly
  // arbitrarily. It may be necessary to make this tunable via
  // XlaDevice::Options.
  static constexpr int kNumDeviceToDeviceStreams = 4;
  device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}

XlaDevice::~XlaDevice() {
  VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
  mutex_lock lock(mu_);
  for (const auto& iter : device_contexts_) {
    iter->Unref();
  }
}

StatusOr<xla::LocalClient*> XlaDevice::GetOrCreateClient() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  xla::LocalClientOptions options;
  options.set_platform(platform_)
      .set_allowed_devices(allowed_devices_)
      .set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
  return xla::ClientLibrary::GetOrCreateLocalClient(options);
}

Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  mutex_lock lock(mu_);
  return GetAllocatorLocked(attr);
}

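// Returns the process-wide CPU allocator for host-resident tensors; otherwise
// lazily obtains the XlaDeviceAllocator shared for this <backend, ordinal>
// pair from XlaDeviceAllocatorState.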
Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  }

  if (xla_allocator_ == nullptr) {
    // TODO(b/78468222): This can fail, at least when the backend is GPU and
    // there is no GPU on the host.
    xla::Backend* backend = GetOrCreateClient().ValueOrDie()->mutable_backend();
    xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
        backend, device_ordinal_);
  }
  return xla_allocator_;
}

Status XlaDevice::EnsureDeviceContextOk() {
  mutex_lock lock(mu_);
  return GetDeviceContextLocked().status();
}

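// If *stream is missing or has entered an error state, borrows a fresh stream
// from the backend's stream pool and sets *stream_was_changed so callers know
// to rebuild any state that caches the old stream.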
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
                                       const string& name,
                                       std::shared_ptr<se::Stream>* stream,
                                       bool* stream_was_changed) {
  if (!(*stream) || !(*stream)->ok()) {
    xla::StreamPool::Ptr ptr;
    TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
    *stream = std::shared_ptr<se::Stream>(std::move(ptr));
    VLOG(1) << "XlaDevice " << this << " new " << name << " "
            << (*stream)->DebugStreamPointers();
    *stream_was_changed = true;
  }
  return OkStatus();
}

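// Ensures the compute and transfer streams are healthy (creating or borrowing
// new ones as needed) and returns the cached XlaDeviceContexts. The contexts
// are rebuilt, one per shape_determination_fns_ entry, whenever a stream was
// replaced or no contexts exist yet.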
StatusOr<std::vector<XlaDeviceContext*>> XlaDevice::GetDeviceContextLocked() {
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client, GetOrCreateClient());
  xla::Backend* backend = client->mutable_backend();

  // Ensure all our streams are valid, borrowing new streams if necessary.
  bool need_new_device_context = device_contexts_.empty();
  if (use_global_compute_stream_) {
    mutex_lock lock(global_mu_);
    if (global_compute_streams_->size() <= device_ordinal_) {
      global_compute_streams_->resize(device_ordinal_ + 1, nullptr);
    }

    auto& global_stream = global_compute_streams_->at(device_ordinal_);
    if (global_stream != nullptr && global_stream->ok()) {
      stream_ = global_stream;
    } else {
      // Directly create the stream here instead of borrowing from the stream
      // pool to avoid potential lifetime issues.
      stream_ = std::make_unique<se::Stream>(
          backend->stream_executors()[device_ordinal_]);
      stream_->Init();
      TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                              &need_new_device_context));
      (*global_compute_streams_)[device_ordinal_] = stream_;
    }
  } else {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                            &need_new_device_context));
  }

  std::shared_ptr<se::Stream> host_to_device_stream;
  std::shared_ptr<se::Stream> device_to_host_stream;
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
  if (use_multiple_streams_) {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                            &host_to_device_stream_,
                                            &need_new_device_context));
    for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
      TF_RETURN_IF_ERROR(
          EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
                               &need_new_device_context));
    }
    host_to_device_stream = host_to_device_stream_;
    device_to_device_streams = device_to_device_streams_;
    // Device-to-host transfer requests can arrive out of order, so a single
    // stream could deadlock. Instead, leave this null and let the
    // XlaDeviceContext borrow a stream for each transfer request.
    device_to_host_stream = nullptr;
  } else {
    host_to_device_stream = stream_;
    device_to_host_stream = stream_;
    device_to_device_streams = {stream_};
  }

  if (!need_new_device_context) {
    return device_contexts_;
  }

  // At this point we know we need a new device context.
  // Call GetAllocator for the side-effect of ensuring the allocator is created.
  GetAllocatorLocked({});
  for (const auto& iter : device_contexts_) {
    iter->Unref();
  }
  // The XlaDeviceContext keeps a reference count to the streams, and the
  // XlaDeviceContext remains live for the duration of an Executor run. This
  // ensures that the streams remain live for the duration of a run, even if
  // an error is encountered and the streams are replaced with new ones.
  for (const auto& iter : shape_determination_fns_) {
    auto device_context = new XlaDeviceContext(
        stream_, host_to_device_stream, device_to_host_stream,
        device_to_device_streams, client, iter, thread_pool_.get());
    VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
            << device_context;
    device_contexts_.emplace_back(device_context);
  }

379   //
380   // TODO(b/78232898): This isn't thread-safe; there is a race between the call
381   // to set_tensorflow_accelerator_device_info() with ops that call the getter
382   // tensorflow_accelerator_device_info(). This isn't trivially fixed by adding
383   // locking to those methods; see the bug for details. Our only saving grace at
384   // the moment is that this race doesn't seem to occur in practice.
385   if (use_accelerator_device_info_) {
386     auto accelerator_device_info =
387         std::make_unique<DeviceBase::AcceleratorDeviceInfo>();
388     accelerator_device_info->stream = stream_.get();
389     accelerator_device_info->default_context = device_contexts_.at(0);
390     set_tensorflow_accelerator_device_info(accelerator_device_info.get());
391     accelerator_device_info_ = std::move(accelerator_device_info);
392     VLOG(1) << "XlaDevice " << this << " new AcceleratorDeviceInfo "
393             << accelerator_device_info_.get();
394   }
395 
396   return device_contexts_;
397 }
398 
StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextWithIndex(int index) {
  mutex_lock lock(mu_);
  TF_ASSIGN_OR_RETURN(auto device_contexts, GetDeviceContextLocked());
  return device_contexts.at(index);
}

StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextDefault() {
  return GetDeviceContextWithIndex(0);
}

Status XlaDevice::UseAcceleratorDeviceInfo() {
  mutex_lock lock(mu_);
  use_accelerator_device_info_ = true;
  return GetDeviceContextLocked().status();
}

Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
  TF_ASSIGN_OR_RETURN(auto device_context, GetDeviceContextDefault());
  device_context->Ref();
  *out_context = device_context;
  return OkStatus();
}

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
    absl::string_view compilation_device_name) {
  static absl::once_flag once;
  if (absl::StrContains(compilation_device_name, "CPU") ||
      absl::StrContains(compilation_device_name, "GPU")) {
    absl::call_once(once, [] {
      LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
                   "removed in subsequent releases. Instead, use either "
                   "@tf.function(jit_compile=True) for must-compile "
                   "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
                   "for auto-clustering best-effort compilation.";
    });
  }
}

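// Compute and ComputeAsync add no device-specific launch logic: after logging
// and emitting the one-time deprecation warning, they delegate directly to the
// kernel's Compute/ComputeAsync.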
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);
}

Status XlaDevice::Sync() {
  VLOG(1) << "XlaDevice::Sync";
  profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo);
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) return OkStatus();

  Status status = stream->BlockHostUntilDone();
  TF_RETURN_IF_ERROR(status);
  if (!stream->ok()) {
    return errors::Internal("XlaDevice::Sync() failed.");
  }
  VLOG(1) << "XlaDevice::Sync completed";
  return OkStatus();
}

// TODO(b/112409994): This is no longer necessary. Consolidate it with the
// synchronous version.
void XlaDevice::Sync(const DoneCallback& done) {
  VLOG(1) << "XlaDevice::Sync (asynchronous)";
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    done(OkStatus());
    return;
  }

  // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at
  // the end of the stream, after everything that has already been enqueued
  // there at this moment. When the host callback is called, everything before
  // it must have already finished, and the host callback will then place the
  // task below onto a background thread. (See the implementation of
  // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done
  // callback is finally called from that background thread, we know for sure
  // that everything enqueued onto the stream (i.e., the device) at this very
  // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
  // This achieves a device-wide sync.
  stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) {
    profiler::TraceMe activity("XlaDevice::Sync::Callback",
                               profiler::TraceMeLevel::kInfo);
    done(stream->ok() ? OkStatus()
                      : errors::Internal("XlaDevice::Sync() failed."));
  });
}

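// Parses tensor_proto into a host tensor and, unless host memory was requested
// via alloc_attrs, copies the result to the device through the given
// XlaDeviceContext.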
Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
                                      const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    Allocator* allocator;
    {
      mutex_lock lock(mu_);
      allocator = GetAllocatorLocked(alloc_attrs);
    }
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    TF_RETURN_IF_ERROR(
        device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";
  XlaDeviceContext* device_context;
  TF_ASSIGN_OR_RETURN(device_context, GetDeviceContextDefault());
  return MakeTensorFromProto(device_context, tensor_proto, alloc_attrs, tensor);
}

void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) {
  mutex_lock lock(mu_);
  sync_on_completion_ = sync_on_completion;
}

bool XlaDevice::AllowsSyncOnCompletion() const {
  mutex_lock lock(mu_);
  return sync_on_completion_;
}

void XlaDevice::SetHandleDeviceErrorCallback(std::function<Status()> callback) {
  mutex_lock lock(mu_);
  device_error_callback_ = callback;
}

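// Runs the callback registered via SetHandleDeviceErrorCallback, if any. The
// callback is copied under mu_ and then invoked without holding the lock.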
Status XlaDevice::HandleDeviceError() {
  std::function<Status()> local_device_error_callback;
  {
    mutex_lock lock(mu_);
    local_device_error_callback = device_error_callback_;
  }
  if (local_device_error_callback != nullptr) {
    return local_device_error_callback();
  }
  return OkStatus();
}

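// Returns the current status of the compute stream. On error, gives the
// registered device-error callback a chance to run before propagating the
// stream's status to the caller.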
Status XlaDevice::RefreshStatus() {
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    return OkStatus();
  }
  Status status = stream->RefreshStatus();
  if (!status.ok()) {
    // Ignore errors from HandleDeviceError, since by definition the status is
    // already non-ok, so there's nothing extra to report if HandleDeviceError
    // itself returns an error.
    HandleDeviceError().IgnoreError();
  }
  return status;
}

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaCompileOnDemandOp(context);
  };
  XlaOpRegistry::RegisterCompilationKernels();
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
           jit_device,
           /*include_compilation_only_kernels=*/false)) {
    KernelDef* def = new KernelDef(*jit_def);
    const std::unordered_set<std::string>* constant_inputs =
        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());

    for (const std::string& arg_name : *constant_inputs) {
      def->add_host_memory_arg(arg_name);
    }

    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
                                              factory));
  }
  return registrations;
}

}  // namespace tensorflow