/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>

#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  }

  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
  *shape = shaped_buffer.on_device_shape();
  return OkStatus();
}
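
// Note: backends whose on-device layout pads or transposes tensors can supply
// their own policy through XlaDevice::Options::padded_shape_fn (the XlaDevice
// constructor below falls back to DefaultPaddedShapeFn when none is given).
// A minimal sketch, assuming a hypothetical backend helper
// SomeBackendPaddedShape that computes the padded on-device shape:
//
//   options.padded_shape_fn = [](const Tensor& tensor, xla::Shape* shape) {
//     // Hypothetical: derive the padded shape from backend-specific metadata.
//     return SomeBackendPaddedShape(tensor, shape);
//   };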

// Caches an XlaDeviceAllocator per <backend, device ordinal> pair. An
// XlaDeviceAllocator is created on demand and is associated with an
// XlaDevice. It outlives the device itself (for instance, the buffer
// backing a tensor holds a pointer to the allocator for book-keeping,
// and this buffer can outlast the device).
class XlaDeviceAllocatorState {
 public:
  // Creates or returns a cached XlaDeviceAllocator for a given
  // backend and device_ordinal.
  static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
      const xla::Backend* backend, int device_ordinal);

 private:
  // Returns the singleton instance of XlaDeviceAllocatorState.
  static XlaDeviceAllocatorState& Singleton();
  XlaDeviceAllocatorState();
  ~XlaDeviceAllocatorState();

  mutex allocator_mutex_;  // Guards the singleton allocator state.
  std::unordered_map<std::pair<const xla::Backend*, int>,
                     std::unique_ptr<XlaDeviceAllocator>,
                     hash<std::pair<const xla::Backend*, int>>>
      allocators_ TF_GUARDED_BY(allocator_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

/* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
  static auto a = new XlaDeviceAllocatorState;
  return *a;
}

XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;

XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
    const xla::Backend* backend, int device_ordinal) {
  XlaDeviceAllocatorState& state = Singleton();
  mutex_lock lock(state.allocator_mutex_);

  auto it = state.allocators_.find({backend, device_ordinal});
  if (it != state.allocators_.end()) {
    return it->second.get();
  }

  std::unique_ptr<XlaDeviceAllocator> alloc =
      std::make_unique<XlaDeviceAllocator>(
          backend->stream_executors()[device_ordinal]);
  XlaDeviceAllocator* alloc_ptr = alloc.get();
  state.allocators_[{backend, device_ordinal}] = std::move(alloc);
  return alloc_ptr;
}

namespace {
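
// For example, with name_prefix "/job:localhost/replica:0/task:0",
// device_name "XLA_GPU", and device_ordinal 0, this produces a device named
// "/job:localhost/replica:0/task:0/device:XLA_GPU:0" that advertises a 16GiB
// memory limit and the physical-device description "device: XLA_GPU device".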
static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
                                                 const string& device_name,
                                                 int device_ordinal) {
  return Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      absl::StrCat("device: ", device_name, " device"));
}

}  // namespace

XlaDevice::Metadata::Metadata(
    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
    std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
        shape_determination_fns,
    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform),
      shape_determination_fns_(std::move(shape_determination_fns)),
      padded_shape_fn_(std::move(padded_shape_fn)),
      use_multiple_streams_(use_multiple_streams) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

se::Platform* XlaDevice::Metadata::platform() const { return platform_; }

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

/*static*/ Status XlaDevice::GetMetadataFromDevice(
    DeviceBase* device, const XlaDevice::Metadata** metadata) {
  *metadata = nullptr;
  XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
  if (xla_device == nullptr) {
    return errors::Internal(
        "Cannot get XLA metadata from non-XLA device \"", device->name(),
        "\". GetMetadata must only be called on an XLA device. Either an "
        "internal bug has been triggered, or an XLA-specific op has been "
        "placed on the wrong device.");
  }
  *metadata = &(xla_device->xla_metadata_);
  return OkStatus();
}

/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}
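
// Example (sketch): an XLA-specific kernel can retrieve the per-device
// metadata from its OpKernelContext, e.g.:
//
//   const XlaDevice::Metadata* metadata = nullptr;
//   OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
//   const int ordinal = metadata->device_ordinal();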

/* static */ mutex XlaDevice::global_mu_(LINKER_INITIALIZED);
/* static */ std::vector<std::shared_ptr<se::Stream>>*
    XlaDevice::global_compute_streams_ =
        new std::vector<std::shared_ptr<se::Stream>>;

XlaDevice::XlaDevice(const SessionOptions& session_options,
                     const Options& options)
    : LocalDevice(session_options,
                  BuildXlaDeviceAttributes(options.device_name_prefix,
                                           options.device_name,
                                           options.device_ordinal)),
      xla_metadata_(options.device_ordinal, options.platform,
                    DeviceType(options.compilation_device_name),
                    options.shape_determination_fns,
                    options.padded_shape_fn ? options.padded_shape_fn
                                            : DefaultPaddedShapeFn,
                    options.use_multiple_streams),
      device_ordinal_(options.device_ordinal),
      jit_device_name_(options.compilation_device_name),
      platform_(options.platform),
      intra_op_parallelism_threads_(
          session_options.config.intra_op_parallelism_threads()),
      use_multiple_streams_(options.use_multiple_streams),
      shape_determination_fns_(options.shape_determination_fns),
      allowed_devices_(options.allowed_devices),
      use_global_compute_stream_(options.use_global_compute_stream) {
  if (options.shape_determination_fns.empty()) {
    LOG(ERROR) << "shape_determination_fns must be non-empty.";
  }
  VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
          << options.device_ordinal << " " << this;
  VLOG(1) << "XlaDevice options: use_multiple_streams: "
          << options.use_multiple_streams << " use_global_compute_stream: "
          << options.use_global_compute_stream;
  thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
                                            /*num_threads=*/1));

  // We have multiple device-to-device streams to allow for some concurrency
  // between transfers. The particular value of '4' is chosen fairly
  // arbitrarily. It may be necessary to make this tunable via
  // XlaDevice::Options.
  static constexpr int kNumDeviceToDeviceStreams = 4;
  device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}

XlaDevice::~XlaDevice() {
  VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
  mutex_lock lock(mu_);
  for (const auto& iter : device_contexts_) {
    iter->Unref();
  }
}

StatusOr<xla::LocalClient*> XlaDevice::GetOrCreateClient() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  xla::LocalClientOptions options;
  options.set_platform(platform_)
      .set_allowed_devices(allowed_devices_)
      .set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
  return xla::ClientLibrary::GetOrCreateLocalClient(options);
}
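
// Note: xla::ClientLibrary caches the LocalClient it creates (keyed on the
// platform), so repeated calls here are expected to return the same client
// rather than construct a new one each time.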

Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  mutex_lock lock(mu_);
  return GetAllocatorLocked(attr);
}

Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  }

  if (xla_allocator_ == nullptr) {
    // TODO(b/78468222): This can fail, at least when the backend is GPU and
    // there is no GPU on the host.
    xla::Backend* backend = GetOrCreateClient().ValueOrDie()->mutable_backend();
    xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
        backend, device_ordinal_);
  }
  return xla_allocator_;
}

Status XlaDevice::EnsureDeviceContextOk() {
  mutex_lock lock(mu_);
  return GetDeviceContextLocked().status();
}

Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
                                       const string& name,
                                       std::shared_ptr<se::Stream>* stream,
                                       bool* stream_was_changed) {
  if (!(*stream) || !(*stream)->ok()) {
    xla::StreamPool::Ptr ptr;
    TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
    *stream = std::shared_ptr<se::Stream>(std::move(ptr));
    VLOG(1) << "XlaDevice " << this << " new " << name << " "
            << (*stream)->DebugStreamPointers();
    *stream_was_changed = true;
  }
  return OkStatus();
}
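
// Note: xla::StreamPool::Ptr carries a deleter that returns the stream to the
// pool, and that deleter survives the conversion to std::shared_ptr above, so
// a borrowed stream is expected to go back to the pool once the last reference
// to it is dropped.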

StatusOr<std::vector<XlaDeviceContext*>> XlaDevice::GetDeviceContextLocked() {
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client, GetOrCreateClient());
  xla::Backend* backend = client->mutable_backend();

  // Ensure all our streams are valid, borrowing new streams if necessary.
  bool need_new_device_context = device_contexts_.empty();
  if (use_global_compute_stream_) {
    mutex_lock lock(global_mu_);
    if (global_compute_streams_->size() <= device_ordinal_) {
      global_compute_streams_->resize(device_ordinal_ + 1, nullptr);
    }

    auto& global_stream = global_compute_streams_->at(device_ordinal_);
    if (global_stream != nullptr && global_stream->ok()) {
      stream_ = global_stream;
    } else {
      // Directly create the stream here instead of borrowing from the stream
      // pool to avoid potential lifetime issues.
      stream_ = std::make_unique<se::Stream>(
          backend->stream_executors()[device_ordinal_]);
      stream_->Init();
      TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                              &need_new_device_context));
      (*global_compute_streams_)[device_ordinal_] = stream_;
    }
  } else {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                            &need_new_device_context));
  }

  std::shared_ptr<se::Stream> host_to_device_stream;
  std::shared_ptr<se::Stream> device_to_host_stream;
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
  if (use_multiple_streams_) {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                            &host_to_device_stream_,
                                            &need_new_device_context));
    for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
      TF_RETURN_IF_ERROR(
          EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
                               &need_new_device_context));
    }
    host_to_device_stream = host_to_device_stream_;
    device_to_device_streams = device_to_device_streams_;
    // Data transfer requests from device to host can arrive out of order, so
    // a single stream could cause a deadlock. For this case, the
    // XlaDeviceContext borrows a stream for each device-to-host transfer
    // request.
    device_to_host_stream = nullptr;
  } else {
    host_to_device_stream = stream_;
    device_to_host_stream = stream_;
    device_to_device_streams = {stream_};
  }

  if (!need_new_device_context) {
    return device_contexts_;
  }

  // At this point we know we need a new device context.
  // Call GetAllocator for the side-effect of ensuring the allocator is created.
  GetAllocatorLocked({});
  for (const auto& iter : device_contexts_) {
    iter->Unref();
  }
  // The XlaDeviceContext holds references to the streams, and the
  // XlaDeviceContext remains live for the duration of an Executor run. This
  // ensures that the streams remain live for the duration of a run, even if
  // an error is encountered and the streams are replaced with new ones.
  for (const auto& iter : shape_determination_fns_) {
    auto device_context = new XlaDeviceContext(
        stream_, host_to_device_stream, device_to_host_stream,
        device_to_device_streams, client, iter, thread_pool_.get());
    VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
            << device_context;
    device_contexts_.emplace_back(device_context);
  }

  // Create and set a new AcceleratorDeviceInfo, if necessary.
  //
  // TODO(b/78232898): This isn't thread-safe; there is a race between the call
  // to set_tensorflow_accelerator_device_info() and ops that call the getter
  // tensorflow_accelerator_device_info(). This isn't trivially fixed by adding
  // locking to those methods; see the bug for details. Our only saving grace
  // at the moment is that this race doesn't seem to occur in practice.
  if (use_accelerator_device_info_) {
    auto accelerator_device_info =
        std::make_unique<DeviceBase::AcceleratorDeviceInfo>();
    accelerator_device_info->stream = stream_.get();
    accelerator_device_info->default_context = device_contexts_.at(0);
    set_tensorflow_accelerator_device_info(accelerator_device_info.get());
    accelerator_device_info_ = std::move(accelerator_device_info);
    VLOG(1) << "XlaDevice " << this << " new AcceleratorDeviceInfo "
            << accelerator_device_info_.get();
  }

  return device_contexts_;
}
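
// Note: one XlaDeviceContext is created per entry in shape_determination_fns_,
// so the index passed to GetDeviceContextWithIndex below selects the context
// built with the corresponding shape-determination functions.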

StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextWithIndex(int index) {
  mutex_lock lock(mu_);
  TF_ASSIGN_OR_RETURN(auto device_contexts, GetDeviceContextLocked());
  return device_contexts.at(index);
}

StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextDefault() {
  return GetDeviceContextWithIndex(0);
}

Status XlaDevice::UseAcceleratorDeviceInfo() {
  mutex_lock lock(mu_);
  use_accelerator_device_info_ = true;
  return GetDeviceContextLocked().status();
}

Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
  TF_ASSIGN_OR_RETURN(auto device_context, GetDeviceContextDefault());
  device_context->Ref();
  *out_context = device_context;
  return OkStatus();
}
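
// Example (sketch): TryGetDeviceContext hands the caller an extra reference,
// so callers are expected to release it when done, e.g.:
//
//   DeviceContext* device_context = nullptr;
//   TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
//   core::ScopedUnref unref(device_context);
//   // ... use device_context ...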

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
    absl::string_view compilation_device_name) {
  static absl::once_flag once;
  if (absl::StrContains(compilation_device_name, "CPU") ||
      absl::StrContains(compilation_device_name, "GPU")) {
    absl::call_once(once, [] {
      LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
                   "removed in subsequent releases. Instead, use either "
                   "@tf.function(jit_compile=True) for must-compile "
                   "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
                   "for auto-clustering best-effort compilation.";
    });
  }
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);
}

Status XlaDevice::Sync() {
  VLOG(1) << "XlaDevice::Sync";
  profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo);
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) return OkStatus();

  Status status = stream->BlockHostUntilDone();
  TF_RETURN_IF_ERROR(status);
  if (!stream->ok()) {
    return errors::Internal("XlaDevice::Sync() failed.");
  }
  VLOG(1) << "XlaDevice::Sync completed";
  return OkStatus();
}

// TODO(b/112409994): This is no longer necessary. Consolidate it with the
// synchronous version.
void XlaDevice::Sync(const DoneCallback& done) {
  VLOG(1) << "XlaDevice::Sync (asynchronous)";
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    done(OkStatus());
    return;
  }

  // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at
  // the end of the stream, after everything that has already been enqueued
  // there at this moment. When the host callback is called, everything before
  // it must have already finished, and the host callback will then place the
  // task below onto a background thread. (See the implementation of
  // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done
  // callback is finally called from that background thread, we know for sure
  // that everything enqueued onto the stream (i.e., the device) at this very
  // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
  // This achieves a device-wide sync.
  stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) {
    profiler::TraceMe activity("XlaDevice::Sync::Callback",
                               profiler::TraceMeLevel::kInfo);
    done(stream->ok() ? OkStatus()
                      : errors::Internal("XlaDevice::Sync() failed."));
  });
}

Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
                                      const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    Allocator* allocator;
    {
      mutex_lock lock(mu_);
      allocator = GetAllocatorLocked(alloc_attrs);
    }
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    TF_RETURN_IF_ERROR(
        device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";
  XlaDeviceContext* device_context;
  TF_ASSIGN_OR_RETURN(device_context, GetDeviceContextDefault());
  return MakeTensorFromProto(device_context, tensor_proto, alloc_attrs, tensor);
}

void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) {
  mutex_lock lock(mu_);
  sync_on_completion_ = sync_on_completion;
}

bool XlaDevice::AllowsSyncOnCompletion() const {
  mutex_lock lock(mu_);
  return sync_on_completion_;
}

void XlaDevice::SetHandleDeviceErrorCallback(std::function<Status()> callback) {
  mutex_lock lock(mu_);
  device_error_callback_ = callback;
}

Status XlaDevice::HandleDeviceError() {
  std::function<Status()> local_device_error_callback;
  {
    mutex_lock lock(mu_);
    local_device_error_callback = device_error_callback_;
  }
  if (local_device_error_callback != nullptr) {
    return local_device_error_callback();
  }
  return OkStatus();
}

Status XlaDevice::RefreshStatus() {
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    return OkStatus();
  }
  Status status = stream->RefreshStatus();
  if (!status.ok()) {
    // Ignore errors from HandleDeviceError, since by definition the status is
    // already non-ok, so there's nothing extra to report if HandleDeviceError
    // itself returns an error.
    HandleDeviceError().IgnoreError();
  }
  return status;
}

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaCompileOnDemandOp(context);
  };
  XlaOpRegistry::RegisterCompilationKernels();
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
           jit_device,
           /*include_compilation_only_kernels=*/false)) {
    KernelDef* def = new KernelDef(*jit_def);
    const std::unordered_set<std::string>* constant_inputs =
        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());

    for (const std::string& arg_name : *constant_inputs) {
      def->add_host_memory_arg(arg_name);
    }

    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
                                              factory));
  }
  return registrations;
}
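
// Example (sketch): a backend's device factory typically pairs this with its
// XLA JIT device name, e.g. (assuming the usual constants from
// xla_op_registry.h):
//
//   XlaDeviceOpRegistrations* registrations =
//       RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT);
//
// The returned registrations object keeps the OpKernelRegistrar instances
// alive for as long as the kernels should remain registered.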

}  // namespace tensorflow