/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_launch_util.h"

#include <memory>

#include "absl/algorithm/container.h"
#include "absl/cleanup/cleanup.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;

// Fetch the platform Id from device.
se::Platform::Id XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  auto device = static_cast<Device*>(device_base);
  se::Platform::Id platform_id = nullptr;
  if (device->device_type() == DEVICE_CPU) {
    platform_id = se::host::kHostPlatformId;
  }

  return platform_id;
}

}  // anonymous namespace

VariableInfo::VariableInfo(
    int index, absl::string_view name, Var* var,
    const std::optional<ManagedStackTrace>& definition_stack_trace)
    : index_(index),
      name_(name),
      var_(var),
      definition_stack_trace_(definition_stack_trace) {}

VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_),
      var_(other.var_),
      definition_stack_trace_(other.definition_stack_trace_),
      lock_held_(other.lock_held_) {
  other.index_ = -1;
  other.var_ = nullptr;
}

VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
  index_ = other.index_;
  var_ = other.var_;
  lock_held_ = other.lock_held_;
  definition_stack_trace_ = other.definition_stack_trace_;

  other.index_ = -1;
  other.var_ = nullptr;

  return *this;
}

VariableInfo::~VariableInfo() {
  // Release the variable's lock if we hold it. Ensures that the lock is
  // released even on error. It does not matter in what order we release the
  // locks.
  if (var()) {
    if (lock_held()) {
      var()->mu()->unlock();
    }

    // Unref the variable so it can be released by ResourceManager.
    var()->Unref();
  }
}

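// Looks up (or creates as uninitialized) the resource variables referenced by
// the DT_RESOURCE inputs at `variable_indices`, appending one VariableInfo per
// index to `result`. Returns an error if a handle refers to a variable that
// lives on a different device than `dev`.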
Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
                                  absl::Span<const Tensor* const> inputs,
                                  absl::Span<const int> variable_indices,
                                  std::vector<VariableInfo>* result) {
  result->clear();
  result->reserve(variable_indices.size());
  for (int var_idx : variable_indices) {
    Var* variable = nullptr;
    ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
    if (handle.device() != dev->attributes().name()) {
      std::string definition_location =
          DefinitionLocationMsg(handle.definition_stack_trace());
      return errors::InvalidArgument(
          "Trying to access resource ", handle.name(), definition_location,
          " located in device ", handle.device(), " from device ",
          dev->attributes().name(),
          "\n Cf. "
          "https://www.tensorflow.org/xla/"
          "known_issues#tfvariable_on_a_different_device");
    }
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
        handle.container(), handle.name(), &variable, [](Var** ptr) {
          // This var is uninitialized for now.
          *ptr = new Var(DT_INVALID);
          return OkStatus();
        }));
    result->emplace_back(var_idx, handle.name(), variable,
                         handle.definition_stack_trace());
  }
  return OkStatus();
}

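// Returns pointers to all of `ctx`'s input tensors, in input-index order.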
std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
  std::vector<const Tensor*> inputs;
  inputs.reserve(ctx->num_inputs());
  for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
    inputs.push_back(&ctx->input(input_idx));
  }
  return inputs;
}

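// Acquires the mutexes of all non-empty VariableInfo instances in increasing
// mutex-address order, so that concurrent launches touching the same set of
// variables cannot deadlock. Typical call sequence (illustrative sketch only;
// `rm`, `device`, `ctx`, and `variable_indices` are placeholders):
//
//   std::vector<VariableInfo> variable_infos;
//   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
//       rm, device, InputsFromContext(ctx), variable_indices,
//       &variable_infos));
//   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
//
// The locks are released when the VariableInfo objects are destroyed.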
Status LockVariables(absl::Span<VariableInfo*> variables) {
  std::vector<int> lock_order(variables.size());
  std::iota(lock_order.begin(), lock_order.end(), 0);

  // VariableInfoComparator orders all empty VariableInfo instances as
  // equivalent, so it looks like we may want to stable-sort these to maintain
  // a deterministic order between the empty VariableInfo instances. However,
  // since we're sorting by pointer value the sort is pretty non-deterministic
  // anyway, so we don't bother using std::stable_sort for now.
  absl::c_sort(lock_order, [&](int a, int b) {
    if (variables[a]->var() && variables[b]->var()) {
      return variables[a]->var()->mu() < variables[b]->var()->mu();
    }

    // Move all the empty VariableInfo instances to the end.
    return variables[a]->var() != nullptr;
  });

  mutex* prev = nullptr;
  for (int i : lock_order) {
    Var* variable = variables[i]->var();
    if (variable == nullptr) {
      // All empty VariableInfo instances are at the end of the order, so we're
      // done.
      break;
    }
    mutex* mu = variable->mu();
    if (prev == mu) {
      // It is an error to pass the same variable handle twice to the same XLA
      // cluster because we would not handle variable updates correctly. Any
      // locks we have already acquired will be released when the VariableInfo
      // objects are destroyed.
      // TODO(b/128495870) Add support for passing aliased resource variables.
      return errors::Unimplemented("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
    mu->lock();
    variables[i]->set_lock_held();
    prev = mu;
  }
  VLOG(4) << "Finished acquiring variable locks.";
  return OkStatus();
}

Status LockVariables(absl::Span<VariableInfo> variables) {
  std::vector<VariableInfo*> variable_ptrs;
  variable_ptrs.reserve(variables.size());
  for (auto& var : variables) {
    variable_ptrs.push_back(&var);
  }
  return LockVariables(absl::MakeSpan(variable_ptrs));
}

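// Stores a shallow copy of each variable's current tensor into `result`, keyed
// by input index; uninitialized variables are recorded as std::nullopt.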
Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 absl::Span<VariableInfo const> variable_infos,
                                 ResourceVarsSnapshot* result) {
  for (int i = 0, end = variable_indices.size(); i < end; i++) {
    Var* var = variable_infos[i].var();
    (*result)[variable_indices[i]] =
        var ? absl::make_optional(*var->tensor()) : std::nullopt;
  }
  return OkStatus();
}

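// Note: `use_multiple_streams` requires `allocate_xla_tensors`, because the
// multi-stream synchronization relies on the definition events carried by
// XlaTensor (see PopulateInputs/PopulateOutputs below).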
XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams),
      device_ordinal_(device_ordinal) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
                                         xla::ShapeIndex index,
                                         se::DeviceMemoryBase buffer,
                                         bool donate_buffer, int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator) {
  xla::MaybeOwningDeviceMemory* in_buffer =
      execution_input.MutableBuffer(index);
  if (donate_buffer) {
    // Here we pass ownership of the buffer to execution_input without releasing
    // ownership from the caller of PopulateExecutionInputBuffer. If execution
    // succeeds, we'll take back that duplicate ownership in
    // GetOrCreateTensorForOutput. If execution fails, the ExecutionInput will
    // release that duplicate ownership automatically.
    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
  } else {
    *in_buffer = buffer;
  }
}

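// Builds one xla::ExecutionInput per compiled XLA parameter. Resource-variable
// arguments are read from `resource_vars`; all other arguments come from the
// kernel context. A buffer is donated to XLA only when the tensor holds the
// sole reference and the parameter is aliased with a modified resource output.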
StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    const std::map<int, const Tensor*>& resource_vars,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  std::vector<xla::ExecutionInput> arguments;
  arguments.reserve(compilation_result->xla_input_shapes.size());

  for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end;
       ++i) {
    int arg_num = compilation_result->input_mapping[i];
    CHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& device_shape = compilation_result->xla_input_shapes[i];
    const xla::Shape& host_shape =
        xla::ShapeUtil::DeviceShapeToHostShape(device_shape);

    bool is_resource_variable = resource_vars.count(arg_num);
    bool is_updated_resource_variable =
        is_resource_variable &&
        absl::c_any_of(compilation_result->resource_updates,
                       [&](const XlaCompiler::ResourceUpdate& update) {
                         // XlaCompiler records `arg_num` (instead of kernel
                         // parameters) in `resource_updates`.
                         return update.input_index == arg_num &&
                                update.modified;
                       });

    const Tensor* t = is_resource_variable
                          ? resource_vars.at(arg_num)
                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
    CHECK(t);
    bool donate_buffer =
        t->RefCountIsOne() && is_updated_resource_variable &&
        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
    VLOG(3) << "Processing input: " << i
            << "; is_resource_variable=" << is_resource_variable
            << "; is_updated_resource_variable=" << is_updated_resource_variable
            << "; donate_buffer=" << donate_buffer;

    if (use_multiple_streams_) {
      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
          << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(
          ctx->op_device_context()->stream());
    }

    arguments.emplace_back(device_shape, host_shape);
    xla::ExecutionInput& execution_input = arguments.back();
    se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
    PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
                                 donate_buffer, device_ordinal_,
                                 xla_allocator_);
  }
  return std::move(arguments);
}

// Construct the tensor for the given type and buffer.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                         se::DeviceMemoryBase buffer, Allocator* allocator) {
  size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
  auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
                                            buffer.size(), allocator);
  Tensor t(dtype, shape, tensor_buffer);
  tensor_buffer->Unref();
  return t;
}

// Get aliased tensor from output, or make a new one for the corresponding
// output operation. Transfers ownership of the buffer from output to the
// returned tensor.
static StatusOr<Tensor> GetOrCreateTensorForOutput(
    xla::ScopedShapedBuffer& output, int output_num, OpKernelContext* ctx,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, const Tensor*>& resource_vars_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    Allocator* output_allocator, bool allocate_xla_tensors, se::Stream* stream,
    bool use_multiple_streams, std::shared_ptr<se::Event> definition_event) {
  xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                     ? xla::ShapeIndex({output_num})
                                     : xla::ShapeIndex({});
  CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
  if (std::optional<xla::HloInputOutputAliasConfig::Alias> alias =
          input_output_alias.GetAliasedParameter(output_index)) {
    VLOG(3) << "Found alias: " << alias->ToString();
    int tf_param =
        input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
    const Tensor input_tensor =
        ctx->input(tf_param).dtype() != DT_RESOURCE
            ? ctx->input(tf_param)
            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
    se::DeviceMemoryBase input_buffer =
        XlaTensor::DeviceMemoryFromTensor(input_tensor);
    se::DeviceMemoryBase output_buffer = output.buffer({output_num});
    if (input_buffer.opaque() == output_buffer.opaque()) {
      // In the case of a donated buffer, both input_tensor and output think
      // they have ownership of the buffer (see comment in
      // PopulateExecutionInputBuffer). Release ownership from output to avoid
      // double free.
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      return input_tensor;
    }
  }

  if (allocate_xla_tensors) {
    Tensor output_tensor;
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(output_dtype, output_shape, &output_tensor));
    if (output_tensor.TotalBytes() > 0) {
      XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
      TF_RET_CHECK(xla_tensor);
      xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
      if (use_multiple_streams) {
        xla_tensor->ResetDefinitionEvent(definition_event, stream);
      }
    }
    return output_tensor;
  }

  se::DeviceMemoryBase output_buffer = output.buffer({output_num});
  Tensor output_tensor =
      MakeTensor(output_dtype, output_shape, output_buffer, output_allocator);
  output.set_buffer(se::OwningDeviceMemory(), {output_num});
  return output_tensor;
}

// Sets output `output_num` for `ctx` provided it is known at compile time.
static Status SetOutputForConstant(
    OpKernelContext* ctx, se::Stream* stream,
    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
  CHECK(compilation_result->outputs[output_num].is_constant);
  const Tensor& const_tensor =
      compilation_result->outputs[output_num].constant_value;
  Tensor* output_tensor;
  if (stream && const_tensor.TotalBytes() > 0) {
    // Copy host -> device. (Empty tensors don't have backing buffers.)
    // Manually allocate memory using an XlaTensorBuffer so we can allocate
    // as much memory as the device requires (as given by
    // GetByteSizeRequirement). This avoids XlaTransferManager having to
    // reallocate the device buffer later.
    VLOG(1) << "Constant output tensor on device";

    TF_RETURN_IF_ERROR(
        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
    Device* device = dynamic_cast<Device*>(ctx->device());
    if (device == nullptr) {
      return errors::Internal("DeviceBase was not a Device.");
    }
    ctx->op_device_context()->CopyCPUTensorToDevice(
        &const_tensor, device, output_tensor,
        [&](Status status) { TF_CHECK_OK(status); });

    if (device->device_type() == DEVICE_GPU) {
      // The GPUDeviceContext enqueues the host->device transfer in a
      // separate stream from the main compute stream. We must ensure the
      // compute stream is synchronized with the host->device transfer
      // stream now, otherwise we will create a race condition.
      auto* gpu_device_context =
          static_cast<GPUDeviceContext*>(ctx->op_device_context());
      gpu_device_context->stream()->ThenWaitFor(
          gpu_device_context->host_to_device_stream());
    }
  } else {
    // No copy required.
    ctx->set_output(output_num, const_tensor);
    output_tensor = ctx->mutable_output(output_num);
  }
  return OkStatus();
}

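// Looks up the resource variable behind `handle`, creating it with the dtype
// recorded in `write` if it does not yet exist.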
static StatusOr<Var*> GetOrCreateResourceVar(
    OpKernelContext* ctx, const ResourceHandle& handle,
    const XlaCompiler::ResourceUpdate& write) {
  Var* variable = nullptr;
  TF_RETURN_IF_ERROR(
      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
        *ptr = new Var(write.type);
        return OkStatus();
      }));
  return variable;
}

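// Builds a VariableInfo for every resource update in `compilation_result`,
// translating the compiler's input indices back into kernel input indices via
// `missing_ctx_input_prefix`.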
StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult& compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> out;
  out.reserve(compilation_result.resource_updates.size());
  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result.resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
    TF_ASSIGN_OR_RETURN(Var * variable,
                        GetOrCreateResourceVar(ctx, handle, write));
    out.emplace_back(actual_input_index, handle.name(), variable,
                     handle.definition_stack_trace());
  }
  return std::move(out);
}

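// Copies the XLA results into `ctx`: compile-time constant outputs are
// materialized directly, DT_RESOURCE outputs are forwarded from the matching
// inputs, regular outputs take ownership of the corresponding XLA result
// buffers, and resource updates are written back through `variable_infos`
// (whose locks are expected to be held by the caller).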
Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    absl::Span<VariableInfo> variable_infos,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    const std::map<int, const Tensor*>& resource_vars) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  Allocator* allocator = ctx->device()->GetAllocator({});

  // Computation output should always be a tuple.
  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
  VLOG(2) << "Result tuple shape (on device): "
          << output.on_device_shape().DebugString();
  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to treat
  // output as a tuple unconditionally.
  if (!output.on_host_shape().IsTuple()) {
    ShapedBuffer nontuple_buffer = output.release();
    ShapedBuffer buffer(
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
        output.device_ordinal());
    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                     /*source_base_index=*/{},
                                     /*target_base_index=*/{0});
    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
  }

  std::shared_ptr<se::Event> definition_event;
  if (use_multiple_streams_) {
    definition_event = std::make_shared<se::Event>(stream->parent());
    if (!definition_event->Init()) {
      return errors::Internal("Failed to initialize tensor definition event.");
    }
    stream->ThenRecordEvent(definition_event.get());
  }

  for (const XlaOutputDescription& descr : compilation_result->outputs) {
    if (descr.type == DT_VARIANT) {
      return errors::Unimplemented(
          "Support for TensorList crossing the XLA/TF boundary "
          "is not implemented");
    }
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(ctx->num_outputs());
  if (output.on_host_shape().is_dynamic()) {
    const se::Platform* platform = nullptr;
    if (stream != nullptr) {
      platform = stream->parent()->platform();
    } else {
      // Stream is not set for the host platform.
      TF_ASSIGN_OR_RETURN(platform,
                          se::MultiPlatformManager::PlatformWithId(
                              XlaPlatformInfoFromDevice(ctx->device())));
    }
    TF_ASSIGN_OR_RETURN(auto transfer_manager,
                        xla::TransferManager::GetForPlatform(platform));

    xla::Shape output_device_shape = output.on_device_shape();
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output, &output_device_shape));

    output.set_shapes(output_device_shape, output_device_shape);
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes.push_back(shape);
    }
  } else {
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
    }
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0, end = ctx->num_outputs(); i < end; ++i) {
    const TensorShape& shape = output_tensor_shapes[i];
    const DataType& type = compilation_result->outputs[i].type;
    VLOG(2) << "Populating output for retval " << i << " shape "
            << shape.DebugString() << " type " << DataTypeString(type);

    if (compilation_result->outputs[i].is_constant) {
      TF_RETURN_IF_ERROR(
          SetOutputForConstant(ctx, stream, compilation_result, i));
    } else if (type == DT_RESOURCE) {
      int input_index =
          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      TF_ASSIGN_OR_RETURN(
          Tensor output_tensor,
          GetOrCreateTensorForOutput(
              output, output_num, ctx, missing_ctx_input_prefix,
              input_output_alias, compilation_result->input_mapping,
              resource_vars, ctx->expected_output_dtype(i), shape, allocator,
              allocate_xla_tensors_, stream, use_multiple_streams_,
              definition_event));
      ctx->set_output(i, output_tensor);
      ++output_num;
    }
  }

  // input_index -> index into variable_infos.
  absl::flat_hash_map<int, int> variable_info_lookup;
  for (int i = 0; i < variable_infos.size(); i++) {
    variable_info_lookup.emplace(variable_infos[i].index(), i);
  }

  // Apply variable updates, if any.
  for (int i = 0, end = compilation_result->resource_updates.size(); i < end;
       ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    CHECK_GE(actual_input_index, 0);
    CHECK_LT(actual_input_index, ctx->num_inputs());
    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
    CHECK(var);

    VLOG(2) << "Updating variable #" << i
            << " at input index: " << actual_input_index << " with shape "
            << write.shape.DebugString() << "; variable tensor has shape: "
            << var->tensor()->shape().DebugString();

    if (var->is_initialized && var->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

    TF_ASSIGN_OR_RETURN(
        Tensor output_tensor,
        GetOrCreateTensorForOutput(output, output_num, ctx,
                                   missing_ctx_input_prefix, input_output_alias,
                                   compilation_result->input_mapping,
                                   resource_vars, write.type, write.shape,
                                   allocator, allocate_xla_tensors_, stream,
                                   use_multiple_streams_, definition_event));
    var->is_initialized |= write.modified;
    *var->tensor() = output_tensor;
    ++output_num;
  }
  return OkStatus();
}

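// Translates the kernel inputs (and locked resource variables) into
// XlaCompiler::Argument descriptors. Inputs listed in `must_be_constant_idxs`
// are materialized as compile-time constants, which may require copying a
// variable's value from device to host.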
StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
    absl::Span<int const> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_args, Device* device) {
  CHECK(absl::c_is_sorted(must_be_constant_idxs));
  VLOG(2) << "Must be const args: {"
          << absl::StrJoin(must_be_constant_idxs, ",") << "} out of "
          << inputs.size() << " args";
  std::vector<XlaCompiler::Argument> out;
  out.resize(inputs.size());

  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
  DeviceContext* device_context = nullptr;
  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
  bool using_default_context = false;
  auto cleanup = absl::MakeCleanup([&] {
    if (device_context != nullptr && !using_default_context) {
      device_context->Unref();
    }
  });
  if (device_context == nullptr) {
    using_default_context = true;
    auto* dev_info = device->tensorflow_accelerator_device_info();
    if (dev_info) device_context = dev_info->default_context;
  }

  absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
  for (const VariableInfo& info : variable_args) {
    CHECK(!info.var() || info.lock_held())
        << "Need to hold the lock on resource variables "
           "before calling BuildXlaCompilerArguments";
    variable_info_lookup.emplace(info.index(), &info);
  }

  for (int64_t input_num = 0; input_num < inputs.size(); ++input_num) {
    const Tensor* input = inputs[input_num];

    XlaCompiler::Argument& arg = out[input_num];
    if (variable_info_lookup.count(input_num)) {
      // Handles resource variables.
      TF_RET_CHECK(input->dtype() == DT_RESOURCE);
      const VariableInfo& variable = *variable_info_lookup[input_num];
      arg.name = std::string(variable.name());
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      arg.definition_stack_trace = variable.definition_stack_trace();
      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }

      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
        const Tensor* value = variable.var()->tensor();
        Tensor value_on_host(value->dtype(), value->shape());
        if (!device_context) {
          value_on_host = *value;
        } else {
          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
              value, "", device, &value_on_host));
        }
        arg.kind = XlaCompiler::Argument::kConstantResource;
        arg.constant_value = value_on_host;
      }
    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input->dtype();
      arg.shape = input->shape();
      arg.constant_value = *input;
    } else {
      // Normal inputs.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
      if (input->NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *input;
      }
      arg.type = input->dtype();
      arg.shape = input->shape();
    }
  }

  return out;
}

}  // namespace tensorflow