/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_launch_util.h"

#include <map>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;

// Fetch the platform Id from device.
se::Platform::Id XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  auto device = static_cast<Device*>(device_base);
  se::Platform::Id platform_id = nullptr;
  if (device->device_type() == DEVICE_CPU) {
    platform_id = se::host::kHostPlatformId;
  }

  return platform_id;
}

}  // anonymous namespace

VariableInfo::VariableInfo(
    int index, absl::string_view name, Var* var,
    const std::optional<ManagedStackTrace>& definition_stack_trace)
    : index_(index),
      name_(name),
      var_(var),
      definition_stack_trace_(definition_stack_trace) {}

VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_),
      var_(other.var_),
      definition_stack_trace_(other.definition_stack_trace_),
      lock_held_(other.lock_held_) {
  other.index_ = -1;
  other.var_ = nullptr;
}

VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
  index_ = other.index_;
  var_ = other.var_;
  lock_held_ = other.lock_held_;
  definition_stack_trace_ = other.definition_stack_trace_;

  other.index_ = -1;
  other.var_ = nullptr;

  return *this;
}

VariableInfo::~VariableInfo() {
  // Release the variable's lock if we hold it. Ensures that the lock is
  // released even on error.  It does not matter in what order we release the
  // locks.
  if (var()) {
    if (lock_held()) {
      var()->mu()->unlock();
    }

    // Unref the variable so it can be released by ResourceManager.
    var()->Unref();
  }
}

Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
                                  absl::Span<const Tensor* const> inputs,
                                  absl::Span<const int> variable_indices,
                                  std::vector<VariableInfo>* result) {
  result->clear();
  result->reserve(variable_indices.size());
  for (int var_idx : variable_indices) {
    Var* variable = nullptr;
    ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
    if (handle.device() != dev->attributes().name()) {
      std::string definition_location =
          DefinitionLocationMsg(handle.definition_stack_trace());
      return errors::InvalidArgument(
          "Trying to access resource ", handle.name(), definition_location,
          " located in device ", handle.device(), " from device ",
          dev->attributes().name(),
          "\n Cf. "
          "https://www.tensorflow.org/xla/"
          "known_issues#tfvariable_on_a_different_device");
    }
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
        handle.container(), handle.name(), &variable, [](Var** ptr) {
          // This var is uninitialized for now.
          *ptr = new Var(DT_INVALID);
          return OkStatus();
        }));
    result->emplace_back(var_idx, handle.name(), variable,
                         handle.definition_stack_trace());
  }
  return OkStatus();
}
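
// A hedged usage sketch (illustrative only, not compiled here): a calling
// kernel typically gathers its inputs, resolves the resource variables, and
// then locks them before compiling or executing the cluster.  `ctx` and
// `variable_indices` are assumed to be supplied by that kernel.
//
//   std::vector<const Tensor*> inputs = InputsFromContext(ctx);
//   std::vector<VariableInfo> variable_infos;
//   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
//       ctx->resource_manager(), ctx->device(), inputs, variable_indices,
//       &variable_infos));
//   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));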

std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
  std::vector<const Tensor*> inputs;
  inputs.reserve(ctx->num_inputs());
  for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
    inputs.push_back(&ctx->input(input_idx));
  }
  return inputs;
}

Status LockVariables(absl::Span<VariableInfo*> variables) {
  std::vector<int> lock_order(variables.size());
  std::iota(lock_order.begin(), lock_order.end(), 0);

  // VariableInfoComparator orders all empty VariableInfo instances as
  // equivalent so it looks like we may want to stable sort these to maintain a
  // deterministic order between the empty VariableInfo instances.  However
  // since we're sorting by pointer value the sort is pretty non-deterministic
  // anyway so we don't bother using std::stable_sort for now.
  absl::c_sort(lock_order, [&](int a, int b) {
    if (variables[a]->var() && variables[b]->var()) {
      return variables[a]->var()->mu() < variables[b]->var()->mu();
    }

    // Move all the empty VariableInfo instances to the end.
    return variables[a]->var() != nullptr;
  });

  mutex* prev = nullptr;
  for (int i : lock_order) {
    Var* variable = variables[i]->var();
    if (variable == nullptr) {
      // All empty VariableInfo instances are at the end of the order
      // so we're done.
      break;
    }
    mutex* mu = variable->mu();
    if (prev == mu) {
      // It is an error to pass the same variable handle twice to the same XLA
      // cluster because we would not handle variable updates correctly.  Any
      // locks we have already acquired will be released when the VariableInfo
      // objects are destroyed.
      // TODO(b/128495870) Add support for passing aliased resource variables.
      return errors::Unimplemented("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
    mu->lock();
    variables[i]->set_lock_held();
    prev = mu;
  }
  VLOG(4) << "Finished acquiring variable locks.";
  return OkStatus();
}

Status LockVariables(absl::Span<VariableInfo> variables) {
  std::vector<VariableInfo*> variable_ptrs;
  variable_ptrs.reserve(variables.size());
  for (auto& var : variables) {
    variable_ptrs.push_back(&var);
  }
  return LockVariables(absl::MakeSpan(variable_ptrs));
}
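
// Informal note on lock lifetime: the locks acquired above are released by
// the VariableInfo destructors, so a caller keeps the VariableInfo objects
// in scope for as long as the variables need to stay locked, e.g.:
//
//   {
//     std::vector<VariableInfo> variable_infos = ...;
//     TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
//     // ... run the XLA computation and apply variable updates ...
//   }  // Locks are released here, when the VariableInfos are destroyed.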

Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 absl::Span<VariableInfo const> variable_infos,
                                 ResourceVarsSnapshot* result) {
  for (int i = 0, end = variable_indices.size(); i < end; i++) {
    Var* var = variable_infos[i].var();
    (*result)[variable_indices[i]] =
        var ? absl::make_optional(*var->tensor()) : std::nullopt;
  }
  return OkStatus();
}
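
// Illustrative example (hypothetical values): with variable_indices = {2, 5},
// the snapshot maps input index 2 and input index 5 to std::optional<Tensor>
// values holding those variables' current tensors, or to std::nullopt when
// the corresponding VariableInfo holds no variable.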

XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams),
      device_ordinal_(device_ordinal) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
                                         xla::ShapeIndex index,
                                         se::DeviceMemoryBase buffer,
                                         bool donate_buffer, int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator) {
  xla::MaybeOwningDeviceMemory* in_buffer =
      execution_input.MutableBuffer(index);
  if (donate_buffer) {
    // Here we pass ownership of the buffer to execution_input without releasing
    // ownership from the caller of PopulateExecutionInputBuffer. If execution
    // succeeds, we'll take back that duplicate ownership in
    // GetOrCreateTensorForOutput. If execution fails, the ExecutionInput will
    // release that duplicate ownership automatically.
    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
  } else {
    *in_buffer = buffer;
  }
}

StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    const std::map<int, const Tensor*>& resource_vars,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  std::vector<xla::ExecutionInput> arguments;
  arguments.reserve(compilation_result->xla_input_shapes.size());

  for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end;
       ++i) {
    int arg_num = compilation_result->input_mapping[i];
    CHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& device_shape = compilation_result->xla_input_shapes[i];
    const xla::Shape& host_shape =
        xla::ShapeUtil::DeviceShapeToHostShape(device_shape);

    bool is_resource_variable = resource_vars.count(arg_num);
    bool is_updated_resource_variable =
        is_resource_variable &&
        absl::c_any_of(compilation_result->resource_updates,
                       [&](const XlaCompiler::ResourceUpdate& update) {
                         // XlaCompiler records `arg_num` (instead of kernel
                         // parameters) in `resource_updates`.
                         return update.input_index == arg_num &&
                                update.modified;
                       });

    const Tensor* t = is_resource_variable
                          ? resource_vars.at(arg_num)
                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
    CHECK(t);
    bool donate_buffer =
        t->RefCountIsOne() && is_updated_resource_variable &&
        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
    VLOG(3) << "Processing input: " << i
            << "; is_resource_variable=" << is_resource_variable
            << "; is_updated_resource_variable=" << is_updated_resource_variable
            << "; donate_buffer=" << donate_buffer;

    if (use_multiple_streams_) {
      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
          << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(
          ctx->op_device_context()->stream());
    }

    arguments.emplace_back(device_shape, host_shape);
    xla::ExecutionInput& execution_input = arguments.back();
    se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
    PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
                                 donate_buffer, device_ordinal_,
                                 xla_allocator_);
  }
  return std::move(arguments);
}

// Construct the tensor for the given type and buffer.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                         se::DeviceMemoryBase buffer, Allocator* allocator) {
  size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
  auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
                                            buffer.size(), allocator);
  Tensor t(dtype, shape, tensor_buffer);
  tensor_buffer->Unref();
  return t;
}

// Get aliased tensor from output, or make a new one for the corresponding
// output operation. Transfers ownership of the buffer from output to the
// returned tensor.
static StatusOr<Tensor> GetOrCreateTensorForOutput(
    xla::ScopedShapedBuffer& output, int output_num, OpKernelContext* ctx,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, const Tensor*>& resource_vars_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    Allocator* output_allocator, bool allocate_xla_tensors, se::Stream* stream,
    bool use_multiple_streams, std::shared_ptr<se::Event> definition_event) {
  xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                     ? xla::ShapeIndex({output_num})
                                     : xla::ShapeIndex({});
  CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
  if (std::optional<xla::HloInputOutputAliasConfig::Alias> alias =
          input_output_alias.GetAliasedParameter(output_index)) {
    VLOG(3) << "Found alias: " << alias->ToString();
    int tf_param =
        input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
    const Tensor input_tensor =
        ctx->input(tf_param).dtype() != DT_RESOURCE
            ? ctx->input(tf_param)
            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
    se::DeviceMemoryBase input_buffer =
        XlaTensor::DeviceMemoryFromTensor(input_tensor);
    se::DeviceMemoryBase output_buffer = output.buffer({output_num});
    if (input_buffer.opaque() == output_buffer.opaque()) {
      // In the case of a donated buffer, both input_tensor and output think
      // they have ownership of the buffer (see comment in
      // PopulateExecutionInputBuffer). Release ownership from output to avoid
      // double free.
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      return input_tensor;
    }
  }

  if (allocate_xla_tensors) {
    Tensor output_tensor;
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(output_dtype, output_shape, &output_tensor));
    if (output_tensor.TotalBytes() > 0) {
      XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
      TF_RET_CHECK(xla_tensor);
      xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
      if (use_multiple_streams) {
        xla_tensor->ResetDefinitionEvent(definition_event, stream);
      }
    }
    return output_tensor;
  }

  se::DeviceMemoryBase output_buffer = output.buffer({output_num});
  Tensor output_tensor =
      MakeTensor(output_dtype, output_shape, output_buffer, output_allocator);
  output.set_buffer(se::OwningDeviceMemory(), {output_num});
  return output_tensor;
}

// Sets output `output_num` for `ctx` provided it is known at compile time.
static Status SetOutputForConstant(
    OpKernelContext* ctx, se::Stream* stream,
    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
  CHECK(compilation_result->outputs[output_num].is_constant);
  const Tensor& const_tensor =
      compilation_result->outputs[output_num].constant_value;
  Tensor* output_tensor;
  if (stream && const_tensor.TotalBytes() > 0) {
    // Copy host -> device. (Empty tensors don't have backing buffers.)
    // Manually allocate memory using an XlaTensorBuffer so we can allocate
    // as much memory as the device requires (as given by
    // GetByteSizeRequirement). This avoids XlaTransferManager having to
    // reallocate the device buffer later.
    VLOG(1) << "Constant output tensor on device";

    TF_RETURN_IF_ERROR(
        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
    Device* device = dynamic_cast<Device*>(ctx->device());
    if (device == nullptr) {
      return errors::Internal("DeviceBase was not a Device.");
    }
    ctx->op_device_context()->CopyCPUTensorToDevice(
        &const_tensor, device, output_tensor,
        [&](Status status) { TF_CHECK_OK(status); });

    if (device->device_type() == DEVICE_GPU) {
      // The GPUDeviceContext enqueues the host->device transfer in a
      // separate stream from the main compute stream. We must ensure the
      // compute stream is synchronized with the host->device transfer
      // stream now, otherwise we will create a race condition.
      auto* gpu_device_context =
          static_cast<GPUDeviceContext*>(ctx->op_device_context());
      gpu_device_context->stream()->ThenWaitFor(
          gpu_device_context->host_to_device_stream());
    }
  } else {
    // No copy required.
    ctx->set_output(output_num, const_tensor);
    output_tensor = ctx->mutable_output(output_num);
  }
  return OkStatus();
}

static StatusOr<Var*> GetOrCreateResourceVar(
    OpKernelContext* ctx, const ResourceHandle& handle,
    const XlaCompiler::ResourceUpdate& write) {
  Var* variable = nullptr;
  TF_RETURN_IF_ERROR(
      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
        *ptr = new Var(write.type);
        return OkStatus();
      }));
  return variable;
}

StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult& compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> out;
  out.reserve(compilation_result.resource_updates.size());
  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result.resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
    TF_ASSIGN_OR_RETURN(Var * variable,
                        GetOrCreateResourceVar(ctx, handle, write));
    out.emplace_back(actual_input_index, handle.name(), variable,
                     handle.definition_stack_trace());
  }
  return std::move(out);
}
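
// Informal note: callers typically lock the VariableInfos returned here (via
// LockVariables above) before passing them to PopulateOutputs below, so that
// the variable writes applied there happen under the variables' locks.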

Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    absl::Span<VariableInfo> variable_infos,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    const std::map<int, const Tensor*>& resource_vars) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  Allocator* allocator = ctx->device()->GetAllocator({});

  // Computation output should always be a tuple.
  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
  VLOG(2) << "Result tuple shape (on device): "
          << output.on_device_shape().DebugString();
  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to treat
  // output as a tuple unconditionally.
  if (!output.on_host_shape().IsTuple()) {
    ShapedBuffer nontuple_buffer = output.release();
    ShapedBuffer buffer(
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
        output.device_ordinal());
    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                     /*source_base_index=*/{},
                                     /*target_base_index=*/{0});
    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
  }

  std::shared_ptr<se::Event> definition_event;
  if (use_multiple_streams_) {
    definition_event = std::make_shared<se::Event>(stream->parent());
    if (!definition_event->Init()) {
      return errors::Internal("Failed to initialize tensor definition event.");
    }
    stream->ThenRecordEvent(definition_event.get());
  }

  for (const XlaOutputDescription& descr : compilation_result->outputs) {
    if (descr.type == DT_VARIANT) {
      return errors::Unimplemented(
          "Support for TensorList crossing the XLA/TF boundary "
          "is not implemented");
    }
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(ctx->num_outputs());
  if (output.on_host_shape().is_dynamic()) {
    const se::Platform* platform = nullptr;
    if (stream != nullptr) {
      platform = stream->parent()->platform();
    } else {
      // Stream is not set for the host platform.
      TF_ASSIGN_OR_RETURN(platform,
                          se::MultiPlatformManager::PlatformWithId(
                              XlaPlatformInfoFromDevice(ctx->device())));
    }
    TF_ASSIGN_OR_RETURN(auto transfer_manager,
                        xla::TransferManager::GetForPlatform(platform));

    xla::Shape output_device_shape = output.on_device_shape();
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output, &output_device_shape));

    output.set_shapes(output_device_shape, output_device_shape);
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes.push_back(shape);
    }
  } else {
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
    }
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0, end = ctx->num_outputs(); i < end; ++i) {
    const TensorShape& shape = output_tensor_shapes[i];
    const DataType& type = compilation_result->outputs[i].type;
    VLOG(2) << "Populating output for retval " << i << " shape "
            << shape.DebugString() << " type " << DataTypeString(type);

    if (compilation_result->outputs[i].is_constant) {
      TF_RETURN_IF_ERROR(
          SetOutputForConstant(ctx, stream, compilation_result, i));
    } else if (type == DT_RESOURCE) {
      int input_index =
          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      TF_ASSIGN_OR_RETURN(
          Tensor output_tensor,
          GetOrCreateTensorForOutput(
              output, output_num, ctx, missing_ctx_input_prefix,
              input_output_alias, compilation_result->input_mapping,
              resource_vars, ctx->expected_output_dtype(i), shape, allocator,
              allocate_xla_tensors_, stream, use_multiple_streams_,
              definition_event));
      ctx->set_output(i, output_tensor);
      ++output_num;
    }
  }

  // input_index -> index into variable_infos.
  absl::flat_hash_map<int, int> variable_info_lookup;
  for (int i = 0; i < variable_infos.size(); i++) {
    variable_info_lookup.emplace(variable_infos[i].index(), i);
  }

  // Apply variable updates, if any.
  for (int i = 0, end = compilation_result->resource_updates.size(); i < end;
       ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    CHECK_GE(actual_input_index, 0);
    CHECK_LT(actual_input_index, ctx->num_inputs());
    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
    CHECK(var);

    VLOG(2) << "Updating variable #" << i
            << " at input index: " << actual_input_index << " with shape "
            << write.shape.DebugString() << "; variable tensor has shape: "
            << var->tensor()->shape().DebugString();

    if (var->is_initialized && var->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

    TF_ASSIGN_OR_RETURN(
        Tensor output_tensor,
        GetOrCreateTensorForOutput(output, output_num, ctx,
                                   missing_ctx_input_prefix, input_output_alias,
                                   compilation_result->input_mapping,
                                   resource_vars, write.type, write.shape,
                                   allocator, allocate_xla_tensors_, stream,
                                   use_multiple_streams_, definition_event));
    var->is_initialized |= write.modified;
    *var->tensor() = output_tensor;
    ++output_num;
  }
  return OkStatus();
}
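
// A hedged end-to-end sketch (illustrative only; `executable` is assumed to be
// an xla::LocalExecutable compiled from `compilation_result`, and the run
// options are elided) of how a launch kernel wires PopulateInputs and
// PopulateOutputs around the XLA execution:
//
//   XlaComputationLaunchContext launch_context(
//       client, allocator, device_ordinal, /*allocate_xla_tensors=*/true,
//       /*use_multiple_streams=*/true);
//   const xla::HloInputOutputAliasConfig& alias_config =
//       executable->executable()->module().input_output_alias_config();
//   TF_ASSIGN_OR_RETURN(std::vector<xla::ExecutionInput> inputs,
//                       launch_context.PopulateInputs(
//                           ctx, compilation_result, resource_var_ptrs,
//                           /*missing_ctx_input_prefix=*/0, alias_config));
//   TF_ASSIGN_OR_RETURN(xla::ExecutionOutput run_result,
//                       executable->RunAsync(std::move(inputs), run_options));
//   TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
//       ctx, compilation_result, run_result.ConsumeResult(),
//       /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
//       alias_config, resource_var_ptrs));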

StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
    absl::Span<int const> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_args, Device* device) {
  CHECK(absl::c_is_sorted(must_be_constant_idxs));
  VLOG(2) << "Must be const args: {"
          << absl::StrJoin(must_be_constant_idxs, ",") << "} out of "
          << inputs.size() << " args";
  std::vector<XlaCompiler::Argument> out;
  out.resize(inputs.size());

  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
  DeviceContext* device_context = nullptr;
  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
  bool using_default_context = false;
  auto cleanup = absl::MakeCleanup([&] {
    if (device_context != nullptr && !using_default_context) {
      device_context->Unref();
    }
  });
  if (device_context == nullptr) {
    using_default_context = true;
    auto* dev_info = device->tensorflow_accelerator_device_info();
    if (dev_info) device_context = dev_info->default_context;
  }

  absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
  for (const VariableInfo& info : variable_args) {
    CHECK(!info.var() || info.lock_held())
        << "Need to hold the lock on resource variables "
           "before calling BuildXlaCompilerArguments";
    variable_info_lookup.emplace(info.index(), &info);
  }

  for (int64_t input_num = 0; input_num < inputs.size(); ++input_num) {
    const Tensor* input = inputs[input_num];

    XlaCompiler::Argument& arg = out[input_num];
    if (variable_info_lookup.count(input_num)) {
      // Handles resource variables.
      TF_RET_CHECK(input->dtype() == DT_RESOURCE);
      const VariableInfo& variable = *variable_info_lookup[input_num];
      arg.name = std::string(variable.name());
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      arg.definition_stack_trace = variable.definition_stack_trace();
      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }

      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
        const Tensor* value = variable.var()->tensor();
        Tensor value_on_host(value->dtype(), value->shape());
        if (!device_context) {
          value_on_host = *value;
        } else {
          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
              value, "", device, &value_on_host));
        }
        arg.kind = XlaCompiler::Argument::kConstantResource;
        arg.constant_value = value_on_host;
      }
    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input->dtype();
      arg.shape = input->shape();
      arg.constant_value = *input;
    } else {
      // Normal inputs.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
      if (input->NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *input;
      }
      arg.type = input->dtype();
      arg.shape = input->shape();
    }
  }

  return out;
}
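
// A hedged sketch of the expected call order (illustrative only; the compile
// step and its options are assumptions of this example, not defined in this
// file):
//
//   std::vector<VariableInfo> variable_infos = ...;  // locks already held
//   TF_ASSIGN_OR_RETURN(
//       std::vector<XlaCompiler::Argument> args,
//       XlaComputationLaunchContext::BuildXlaCompilerArguments(
//           constant_input_indices, InputsFromContext(ctx),
//           variable_infos, static_cast<Device*>(ctx->device())));
//   // `args` is then handed to the XLA compiler (or a compilation cache
//   // keyed on these arguments) to produce the executable that is launched
//   // via PopulateInputs/PopulateOutputs above.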

}  // namespace tensorflow