xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/tpu_execute_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_execute_op.h"
16 
17 #include <utility>
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/memory/memory.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/jit/xla_device.h"
23 #include "tensorflow/compiler/jit/xla_launch_util.h"
24 #include "tensorflow/compiler/jit/xla_tensor.h"
25 #include "tensorflow/compiler/tf2xla/shape_util.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
27 #include "tensorflow/compiler/xla/debug_options_flags.h"
28 #include "tensorflow/compiler/xla/service/dump.h"
29 #include "tensorflow/compiler/xla/service/executable.h"
30 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/framework/allocator.h"
35 #include "tensorflow/core/framework/node_def_util.h"
36 #include "tensorflow/core/framework/op.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/resource_mgr.h"
39 #include "tensorflow/core/framework/resource_var.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/framework/types.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/platform/casts.h"
44 #include "tensorflow/core/platform/tracing.h"
45 #include "tensorflow/core/profiler/lib/traceme.h"
46 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
47 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
48 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
49 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
50 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
51 #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
52 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
53 #include "tensorflow/core/tpu/tpu_configuration.h"
54 #include "tensorflow/core/tpu/tpu_defs.h"
55 #include "tensorflow/core/tpu/tpu_execute.h"
56 #include "tensorflow/core/util/stream_executor_util.h"
57 #include "tensorflow/stream_executor/device_memory_allocator.h"
58 #include "tensorflow/stream_executor/tpu/tpu_node_context.h"
59 
60 namespace tensorflow {
61 namespace {
62 using ::tensorflow::tpu::CompilationCacheEntryRef;
63 using ::tensorflow::tpu::TpuCompilationCacheLookup;
64 using ::tensorflow::tpu::TpuNodeContext;
65 
66 // Looks up the input `key` in the compilation cache, populating
67 // `*rendezvous_key_base` and `*entry`.
68 Status GetComputationCacheEntry(
69     OpKernelContext* context, string* rendezvous_key_base,
70     std::unique_ptr<CompilationCacheEntryRef>* entry) {
71   const Tensor* key;
72   TF_RETURN_IF_ERROR(context->input("key", &key));
73   profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
74   if (!TensorShapeUtils::IsVector(key->shape()) ||
75       key->shape().dim_size(0) != 3) {
76     return errors::InvalidArgument(
77         "Key argument to TPUExecute must be a 3-element vector");
78   }
79 
80   ResourceMgr* rmgr = GetTPUConfigResourceMgr();
81   TpuCompilationCacheLookup* proto_lookup;
82   TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
83                                   tpu::kCompiledProtoCacheResourceName,
84                                   &proto_lookup));
85   core::ScopedUnref lookup_unref(proto_lookup);
86   TF_RETURN_IF_ERROR(proto_lookup->Lookup(key->vec<tstring>()(0), entry));
87   *rendezvous_key_base = key->vec<tstring>()(1);
88   return OkStatus();
89 }
90 
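// Illustrative example (hypothetical indices, not from the original source):
// if computation inputs 2 and 5 are variables coming from the compilation,
// only input 5 is updated, and the computation has 4 outputs, then
//   input_to_output                = {2: -1, 5: 3}
//   output_to_input                = {3: (5, /*from_compilation=*/true)}
//   input_in_compiled_update_order = {2, 5}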
91 struct VariableUpdateMap {
92   // Maps input index to the updated output index. If the variable doesn't have
93   // an updated output, the corresponding output is set to -1.
94   absl::flat_hash_map<int, int> input_to_output;
95   // Maps output index to (the input index, whether the update is generated from
96   // compilation).
97   absl::flat_hash_map<int, std::pair<int, bool>> output_to_input;
98   // The subset of input indices that come from the compilation, in compiled
99   // order.
100   std::vector<int> input_in_compiled_update_order;
101 };
102 
103 // Creates a VariableUpdateMap from both the compilation and the fused variable
104 // reads/updates.
105 xla::StatusOr<VariableUpdateMap> BuildVariableUpdateMap(
106     absl::Span<const TPUExecutableInfoProto::UpdateIndexPair* const>
107         compiled_variable_updates,
108     absl::Span<int const> fused_device_var_reads_in_computation_inputs,
109     const std::vector<int>& fused_device_var_updates_in_computation_outputs,
110     int64_t computation_output_count) {
111   VariableUpdateMap map;
112   auto add_pair = [&](int input, int output, bool from_compilation) -> Status {
113     TF_RET_CHECK(map.input_to_output.emplace(input, output).second)
114         << "Duplicate variable input index: " << input;
115     if (output >= 0) {
116       TF_RET_CHECK(map.output_to_input
117                        .emplace(output, std::make_pair(input, from_compilation))
118                        .second)
119           << "Duplicate variable output index: " << output;
120     }
121     return OkStatus();
122   };
123 
124   // First add the updates produced by the compilation. Not all variables are
125   // updated, and if not, they do not have an output in the XLA computation. The
126   // update output indices in the XLA computation start after the non-variable
127   // outputs.
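  // For example (hypothetical counts): with computation_output_count == 5 and
  // two updated variables, the updated values occupy output indices 3 and 4,
  // immediately after the three non-variable outputs.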
128   int num_updated_variables = 0;
129   for (int i = 0; i < compiled_variable_updates.size(); ++i) {
130     const bool updated = compiled_variable_updates[i]->updated();
131     if (updated) ++num_updated_variables;
132   }
133   TF_RET_CHECK(num_updated_variables <= computation_output_count)
134       << num_updated_variables << " <= " << computation_output_count;
135   int64_t compiled_variable_output_index =
136       computation_output_count - num_updated_variables;
137   for (auto update : compiled_variable_updates) {
138     map.input_in_compiled_update_order.push_back(update->index());
139     if (!update->updated()) {
140       TF_RETURN_IF_ERROR(add_pair(update->index(), -1, true));
141       continue;
142     }
143     TF_RETURN_IF_ERROR(
144         add_pair(update->index(), compiled_variable_output_index, true));
145     ++compiled_variable_output_index;
146   }
147 
148   // Now add the updates from the attributes.
149   TF_RET_CHECK(fused_device_var_reads_in_computation_inputs.size() ==
150                fused_device_var_updates_in_computation_outputs.size());
151   for (int64_t i = 0; i < fused_device_var_reads_in_computation_inputs.size();
152        ++i) {
153     TF_RETURN_IF_ERROR(
154         add_pair(fused_device_var_reads_in_computation_inputs[i],
155                  fused_device_var_updates_in_computation_outputs[i], false));
156   }
157   return map;
158 }
159 
160 // Buffers representing the inputs to a computation.
161 struct InputBuffers {
162   explicit InputBuffers(xla::Shape device_shape)
163       : buffers(std::move(device_shape)) {}
164 
165   InputBuffers(const InputBuffers&) = delete;
166   InputBuffers& operator=(const InputBuffers&) = delete;
167 
168   ~InputBuffers() = default;
169 
170   xla::ShapedBuffer ToShapedBuffer(xla::Shape host_shape,
171                                    se::DeviceMemoryAllocator* allocator,
172                                    int device_ordinal) {
173     CHECK_NE(allocator, nullptr);
174     xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
175                                     device_ordinal);
176     shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
177         [](const xla::MaybeOwningDeviceMemory& buffer) {
178           return buffer.AsDeviceMemoryBase();
179         }));
180     return shaped_buffer;
181   }
182 
183   // Describes the buffer tree.
184   xla::ShapeTree<xla::MaybeOwningDeviceMemory> buffers;
185 
186   // Information about resource variables passed directly to TPUExecute.
187   std::vector<VariableInfo> variables;
188 
189   // Mapping from input index to offsets in 'variables'. < 0 if the input does
190   // not correspond to a variable in 'variables'.
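  // Illustrative example (hypothetical argument list): for args
  // (tensor, var, tensor, var), 'variables' holds the two vars and
  // variable_index == {-1, 0, -1, 1}.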
191   std::vector<int> variable_index;
192 };
193 
194 // Builds an InputBuffers object that describes the inputs to the computation.
195 xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
196     OpKernelContext* context, const xla::Shape& input_host_shape,
197     const VariableUpdateMap& variable_updates, xla::Backend* backend,
198     int device_ordinal, se::Stream* stream) {
199   profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
200   OpInputList arg_list;
201   TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));
202 
203   if (arg_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
204     return errors::InvalidArgument(
205         "Number of parameters (", arg_list.size(),
206         ") does not match input shape: ",
207         xla::ShapeUtil::TupleElementCount(input_host_shape));
208   }
209 
210   auto validate_shape = [&](int i, const Tensor& tensor) {
211     const xla::Shape& expected =
212         xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
213     VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
214     XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);
215 
216     if (xla_tensor == nullptr) {
217       // FromTensor failed; tensor must be empty.
218       if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
219         return errors::InvalidArgument(
220             "Run-time shape mismatch for TPUExecute argument[", i, "] (",
221             context->op_kernel().requested_input(i), "). Expected ",
222             expected.DebugString(),
223             "; got empty tensor. If you are running "
224             "with TF2 TPU, make sure you set `drop_remainder=False` when "
225             "calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
226             "size can be handled");
227       }
228     } else {
229       // Compare host shapes, easier than getting the expected device shape.
230       const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
231       if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
232         return errors::InvalidArgument(
233             "Run-time shape mismatch for TPUExecute argument[", i, "] (",
234             context->op_kernel().requested_input(i), "). Expected ",
235             expected.DebugString(), "; got ", xla_shape.DebugString());
236       }
237     }
238 
239     return OkStatus();
240   };
241 
242   // Iterate over the inputs, validating the shapes of non-variable inputs,
243   // and creating a VariableInfo object for each variable. We consider variable
244   // inputs in a separate phase because we must acquire variable locks in order.
245   std::vector<VariableInfo> variables;
246   std::vector<int> variable_index(arg_list.size(), -1);
247   variables.reserve(arg_list.size());
248   for (int i = 0; i < arg_list.size(); ++i) {
249     // Arguments are assumed to be variables if they have a resource type.
250     // (Non-variable resources are not supported.)
251     if (context->input_dtype(i) == DT_RESOURCE) {
252       variable_index[i] = variables.size();
253       // TODO(phawkins): we may be looking up many variables here; it would be
254       // better if we did not repeatedly acquire the resource manager's lock.
255       const ResourceHandle& handle = HandleFromInput(context, i);
256       Var* variable;
257       TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
258       variables.push_back(VariableInfo(i, handle.name(), variable));
259     } else {
260       TF_RETURN_IF_ERROR(validate_shape(i, arg_list[i]));
261     }
262   }
263 
264   // Lock the variables, and validate their shapes. We hold the variable locks
265   // for the duration of the TPU execution so we can donate the variable buffers
266   // to the computation. If we copied the variable's Tensor instead, its
267   // reference count would be greater than one due to the reference the Var
268   // object holds, and we would never be able to reuse variable buffers.
269   // TODO(phawkins): add a 'reuse_buffers' attribute to TPUExecute that allows
270   // the user to elect to copy the buffers and permit concurrent access instead.
271   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
272   for (int i = 0; i < variables.size(); ++i) {
273     TF_RETURN_IF_ERROR(
274         validate_shape(variables[i].index(), *variables[i].var()->tensor()));
275   }
276 
277   se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
278   xla::TransferManager* const transfer_manager = backend->transfer_manager();
279 
280   auto input_buffers = absl::make_unique<InputBuffers>(
281       transfer_manager->HostShapeToDeviceShape(input_host_shape));
282 
283   // Allocates a buffer for the root tuple.
284   const int64_t root_size =
285       transfer_manager->GetByteSizeRequirement(input_buffers->buffers.shape());
286   TF_ASSIGN_OR_RETURN(*input_buffers->buffers.mutable_element({}),
287                       allocator->Allocate(device_ordinal, root_size));
288 
289   // Helper function that sets the input buffers for 'arg_index' to 'buffers'.
290   // If 'donate_buffers' is true, donates ownership of the buffers in 'buffers'
291   // to the computation and overwrites the entries in 'buffers' with nulls.
292   auto set_input_buffers_helper = [&](int arg_index, bool donate_buffers,
293                                       xla::ShapedBuffer* buffers) {
294     buffers->buffers().ForEachMutableElement([&](const xla::ShapeIndex& index,
295                                                  se::DeviceMemoryBase* buffer) {
296       xla::ShapeIndex in_index = {arg_index};
297       for (int64_t j : index) {
298         in_index.push_back(j);
299       }
300       auto* in_buffer = input_buffers->buffers.mutable_element(in_index);
301       if (donate_buffers) {
302         *in_buffer = se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
303         *buffer = se::DeviceMemoryBase();
304       } else {
305         *in_buffer = *buffer;
306       }
307     });
308   };
309 
310   // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
311   // buffers for zero-element tensors where required.
312   auto assign_input = [&](int i, const Tensor& tensor,
313                           bool may_reuse) -> xla::Status {
314     XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);
315 
316     // Size 0 tensors have no backing XlaTensor, but may still need to have
317     // tuple buffers allocated.
318     if (xla_tensor == nullptr) {
319       CHECK_EQ(tensor.NumElements(), 0);
320       const xla::Shape& host_shape =
321           xla::ShapeUtil::GetSubshape(input_host_shape, {i});
322       TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
323                           transfer_manager->AllocateScopedShapedBuffer(
324                               host_shape, allocator, device_ordinal));
325       set_input_buffers_helper(/*arg_index=*/i, /*donate_buffers=*/true,
326                                &buffers);
327     } else {
328       bool can_reuse_buffers = tensor.RefCountIsOne() && may_reuse;
329       set_input_buffers_helper(/*arg_index=*/i,
330                                /*donate_buffers=*/can_reuse_buffers,
331                                &xla_tensor->shaped_buffer());
332       xla_tensor->WaitForDefinitionEventOnStream(stream);
333     }
334     return OkStatus();
335   };
336 
337   for (int i = 0; i < arg_list.size(); ++i) {
338     auto it = variable_updates.input_to_output.find(i);
339     if (it == variable_updates.input_to_output.end()) {
340       TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], /*may_reuse=*/true));
341       continue;
342     }
343     // Input i is a variable.
344     bool updated = it->second >= 0;
345     if (arg_list[i].dtype() != DT_RESOURCE) {
346       TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], updated));
347     } else {
348       int vi = variable_index[i];
349       TF_RETURN_IF_ERROR(
350           assign_input(i, *variables[vi].var()->tensor(), updated));
351     }
352   }
353 
354   input_buffers->variables = std::move(variables);
355   input_buffers->variable_index = std::move(variable_index);
356 
357   return std::move(input_buffers);
358 }
359 
360 struct OutputBuffers {
361   OutputBuffers(xla::ScopedShapedBuffer b, se::DeviceMemoryAllocator* allocator)
362       : owned_buffers(b.on_device_shape(), true),
363         buffers(b.release()),
364         memory_allocator(allocator) {}
365 
366   ~OutputBuffers() {
367     buffers.buffers().ForEachElement(
368         [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
369           if (owned_buffers.element(index) && !buffer.is_null()) {
370             Status status =
371                 memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
372             if (!status.ok()) {
373               LOG(ERROR) << "Error deallocating buffer " << status;
374             }
375           }
376         });
377   }
378 
379   // Which of the buffers do we own?
380   xla::ShapeTree<bool> owned_buffers;
381 
382   xla::ShapedBuffer buffers;
383 
384   se::DeviceMemoryAllocator* const memory_allocator;
385 };
386 
387 // Allocates Tensors for the outputs of the computation. Ownership of most
388 // output buffers is passed to the output Tensors. Returns an OutputBuffers that
389 // owns the root buffer that should be passed to the XLA computation, as well as
390 // any output buffers that do not have corresponding output tensors. The latter
391 // may happen for zero-element tensors of type int64 or complex64 which still
392 // require a tuple buffer but do not have a corresponding XlaTensor.
393 xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
394     OpKernelContext* context, xla::ScopedShapedBuffer scoped_buffers,
395     absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
396     const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
397     se::Stream* stream, int device_ordinal, InputBuffers* input_buffers,
398     const std::shared_ptr<se::Event>& definition_event) {
399   VLOG(4) << "Output buffers: " << scoped_buffers.ToString();
400 
401   profiler::TraceMe trace_me("AllocateOutputTensors", /*level=*/2);
402   // Shapes of the outputs, in TensorShape form.
403   const int64_t sub_elements =
404       xla::ShapeUtil::TupleElementCount(scoped_buffers.on_host_shape());
405   if (sub_elements != output_tensor_shape_protos.size()) {
406     return errors::InvalidArgument(
407         "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
408         output_tensor_shape_protos.size());
409   }
410 
411   xla::TransferManager* const transfer_manager =
412       node_context->backend()->transfer_manager();
413 
414   std::vector<TensorShape> output_tensor_shapes;
415   output_tensor_shapes.reserve(sub_elements);
416   for (int64_t i = 0; i < sub_elements; ++i) {
417     TF_RETURN_IF_ERROR(
418         TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
419     TensorShape shape(*output_tensor_shape_protos[i]);
420     const xla::Shape& xla_shape =
421         xla::ShapeUtil::GetSubshape(scoped_buffers.on_host_shape(), {i});
422     if (!xla_shape.IsArray() ||
423         xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
424       return errors::InvalidArgument(
425           "Mismatched number of elements in output shape: ",
426           xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
427     }
428     output_tensor_shapes.push_back(shape);
429   }
430 
431   // Builds a shaped buffer for the outputs.
432   TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
433   TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));
434 
435   se::DeviceMemoryAllocator* const allocator =
436       node_context->backend()->memory_allocator();
437 
438   auto output_buffers =
439       absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);
440 
441   xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();
442 
443   if (!output_device_shape.is_static()) {
444     TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
445         stream, &output_buffers->buffers, &output_device_shape));
446     for (int64_t i = 0; i < sub_elements; ++i) {
447       const xla::Shape& subshape =
448           xla::ShapeUtil::GetSubshape(output_device_shape, {i});
449       TensorShape shape;
450       TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
451       output_tensor_shapes[i] = shape;
452     }
453   }
454 
455   // Transfers ownership of the buffers that back XLA computation output 'i'
456   // to 'output_tensor'.
457   auto transfer_buffers = [&](int i, Tensor* output_tensor) {
458     const xla::Shape& device_shape =
459         xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
460 
461     // Transfers ownership of the output buffers to the output Tensor, if
462     // the tensor is backed by an XlaTensor. Tensors of size 0 have no
463     // backing XlaTensor, so we let 'output_buffers' retain ownership of any
464     // buffers in that case.
465     if (output_tensor->NumElements() > 0) {
466       xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
467                                             device_ordinal);
468       shaped_buffer.buffers().ForEachMutableElement(
469           [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
470             xla::ShapeIndex out_index = {i};
471             for (int64_t j : index) {
472               out_index.push_back(j);
473             }
474             *buffer = output_buffers->buffers.buffers().element(out_index);
475             *output_buffers->owned_buffers.mutable_element(out_index) = false;
476           });
477 
478       XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
479       xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
480       xla_tensor->ResetDefinitionEvent(definition_event, stream);
481     }
482   };
483 
484   const int num_updated_variables = variable_updates.output_to_input.size();
485   TF_RET_CHECK(num_updated_variables <= output_tensor_shapes.size())
486       << num_updated_variables << " <= " << output_tensor_shapes.size();
487 
488   OpInputList arg_list;
489   TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));
490 
491   // The TPU program outputs the updated variable values for both DT_RESOURCE
492   // and non-DT_RESOURCE inputs. The TPUExecuteOp needs to output all
493   // non-DT_RESOURCE variables (updated or not).
494   //
495   //                       updated          not_updated
496   //                 |------------------|------------------|
497   // DT_RESOURCE     | allocate persist |    do nothing    |
498   //                 |------------------|------------------|
499   //                 |     allocate     | forward Op input |
500   // not DT_RESOURCE |      output      |   to Op output   | Op output
501   //                 |------------------|------------------|
502   //                    program output
503 
504   // Allocates a fresh tensor for each updated variable. While the variable
505   // inputs may arrive in no particular order, the updated variable values are
506   // always added last by the XlaCompiler class, in the same order as the
507   // corresponding input variables.
508   int op_output_index = 0;
509   int compiled_update_index = 0;
510   auto process_non_updated_variable = [&](int input_index) {
511     const int variable_index = input_buffers->variable_index.at(input_index);
512     // If a DT_RESOURCE input is not updated, nothing needs to be done
513     // because there is no corresponding output. If a non-resource input
514     // is not updated, forward the input to the output.
515     if (variable_index < 0) {
516       context->set_output(op_output_index, arg_list[input_index]);
517       ++op_output_index;
518     }
519   };
520   for (int i = 0; i < output_tensor_shapes.size(); ++i) {
521     auto it = variable_updates.output_to_input.find(i);
522     if (it == variable_updates.output_to_input.end()) {
523       // Not a variable update.
524       // Allocates a fresh tensor for each output of the operator. We always
525       // allocate a new host-side tensor, but the on-device buffers that back
526       // that tensor may be aliases of input buffers.
527       Tensor* output_tensor;
528       TF_RETURN_IF_ERROR(context->allocate_output(
529           op_output_index, output_tensor_shapes[i], &output_tensor));
530       transfer_buffers(i, output_tensor);
531       ++op_output_index;
532       continue;
533     }
534     const int input_index = it->second.first;
535     // We must process the compiled updates in order, which includes the
536     // non-updated variables, i.e., those without an XLA output.
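    // Illustrative example (hypothetical indices): if the compiled update
    // order is {4, 7, 9} and only inputs 4 and 9 have updated outputs, then
    // when the output for input 9 is reached, input 7 is first handled here
    // as a non-updated variable.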
537     const bool from_compilation = it->second.second;
538     while (from_compilation &&
539            variable_updates
540                    .input_in_compiled_update_order[compiled_update_index] !=
541                input_index) {
542       process_non_updated_variable(
543           variable_updates
544               .input_in_compiled_update_order[compiled_update_index]);
545       ++compiled_update_index;
546     }
547     ++compiled_update_index;
548     const int variable_index = input_buffers->variable_index.at(input_index);
549     if (variable_index >= 0) {
550       // This output corresponds to a DT_RESOURCE input to the TPUExecute
551       // operator. Update the corresponding variable.
552       VariableInfo& var = input_buffers->variables[variable_index];
553       TF_RETURN_IF_ERROR(context->allocate_temp(var.var()->tensor()->dtype(),
554                                                 output_tensor_shapes[i],
555                                                 var.var()->tensor()));
556       transfer_buffers(i, var.var()->tensor());
557     } else {
558       // This output corresponds to a non-resource input to the TPUExecute
559       // operator. This case occurs for the distributed TPU rewrite which
560       // adds variable values as inputs and outputs rather than passing the
561       // variables themselves; reading and writing the variable is handled
562       // outside the op.
563       // TODO(phawkins): remove this case when placement of variables on TPU
564       // devices is well supported and we no longer need to place "remote"
565       // variables on CPU devices.
566       Tensor* output_tensor;
567       TF_RETURN_IF_ERROR(context->allocate_output(
568           op_output_index, output_tensor_shapes[i], &output_tensor));
569       ++op_output_index;
570       transfer_buffers(i, output_tensor);
571     }
572   }
573 
574   // Process any remaining non-updated variables.
575   for (; compiled_update_index <
576          variable_updates.input_in_compiled_update_order.size();
577        ++compiled_update_index) {
578     process_non_updated_variable(
579         variable_updates.input_in_compiled_update_order[compiled_update_index]);
580   }
581   return std::move(output_buffers);
582 }
583 
584 }  // namespace
585 
586 // TPUExecuteOp
587 
588 TPUExecuteOp::TPUExecuteOp(OpKernelConstruction* context)
589     : AsyncOpKernel(context, /* is_deferred = */ true) {}
590 
591 AsyncOpKernel* TPUExecuteOp::AsAsync() {
592   // If TPU launches are asynchronous, we can perform the launch without
593   // blocking the calling thread, and so the executor may treat this kernel as
594   // a regular (synchronous) OpKernel.
595   return nullptr;
596 }
597 
598 void TPUExecuteOp::Compute(OpKernelContext* context) {
599   Status s = DoWork(context);
600   // NOTE: We can't use `OP_REQUIRES_OK()` here because that macro includes
601   // a dynamic check that we are not in an AsyncOpKernel.
602   if (TF_PREDICT_FALSE(!s.ok())) {
603     context->SetStatus(s);
604   }
605 }
606 
607 void TPUExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
608   // If TPU launches are asynchronous, then perform the launch on this
609   // thread to avoid a thread hop, which has an observable latency cost.
610   OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
611   done();
612 }
613 
614 Status TPUExecuteOp::DoWork(OpKernelContext* context) {
615   VLOG(1) << "Cloud TPU: TPUExecuteOp::Compute";
616 
617   const XlaDevice::Metadata* metadata;
618   TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
619   const int device_ordinal = metadata->device_ordinal();
620 
621   // We are guaranteed that the object underlying TpuNodeContext won't be
622   // deleted out from under us while node_context is alive.
623   TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
624                       TpuNodeContext::Create(device_ordinal));
625 
626   profiler::TraceMe trace_me(
627       [device_ordinal, context] {
628         return profiler::TraceMeEncode(
629             "TpuExecuteOp", {{"device_ordinal", device_ordinal},
630                              {"id", context->step_id()},
631                              {"iter_num", context->frame_iter().iter_id}});
632       },
633       /*level=*/2);
634   profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);
635 
636   string rendezvous_key_base;
637   std::unique_ptr<CompilationCacheEntryRef> entry_ref;
638   TF_RETURN_IF_ERROR(
639       GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));
640 
641   // Shapes of the inputs and outputs, in xla::Shape form.
642   tpu::TpuCompilationCacheEntry entry = entry_ref->get();
643   const tpu::TpuProgramGroup* tpu_program_group =
644       tensorflow::down_cast<const tpu::TpuProgramGroup*>(
645           entry.tpu_program_group());
646   CHECK_NE(tpu_program_group, nullptr);
647   const int core_index = entry.core_index();
648   const TPUExecutableInfoProto& executable =
649       tpu_program_group->executable_info(core_index);
650 
651   xla::Backend* const backend = node_context->backend();
652   xla::TransferManager* const transfer_manager = backend->transfer_manager();
653   TF_RET_CHECK(context->op_device_context());
654   se::Stream* stream = context->op_device_context()->stream();
655 
656   TF_RET_CHECK(executable.input_shapes_size() == 1);
657 
658   xla::Shape host_shape(executable.input_shapes(0));
659 
660   TF_ASSIGN_OR_RETURN(
661       auto variable_update_map,
662       BuildVariableUpdateMap(executable.variable_indices(),
663                              fused_device_var_reads_in_computation_inputs_,
664                              fused_device_var_updates_in_computation_outputs_,
665                              executable.output_tensor_shapes().size()));
666   TF_ASSIGN_OR_RETURN(
667       std::unique_ptr<InputBuffers> input_buffers,
668       BuildComputationInputs(context, host_shape, variable_update_map, backend,
669                              device_ordinal, stream));
670 
671   // Ideally this should be the host-to-device stream from XlaDeviceContext.
672   // The particular anti-dependency this is avoiding (why we need a separate
673   // transfer stream) is between the executable writing tuple tables and
674   // TPUExecute()'s deregister_stream; if they come from the same stream pool,
675   // anti-dependencies will occur. XlaBackend draws from a different stream pool
676   // than the stream->GetOrCreateSubStream() that TPUExecute() uses, so the two
677   // will never refer to the same stream.
678   //
679   // TODO(jmolloy): Add the necessary plumbing to obtain the proper
680   // host-to-device stream here.
681   TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
682                       backend->BorrowStream(device_ordinal));
683 
684   se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
685   auto shaped_buffer = input_buffers->ToShapedBuffer(std::move(host_shape),
686                                                      allocator, device_ordinal);
687   if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
688                                                      shaped_buffer)) {
689     TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
690         transfer_stream_ptr.get(), shaped_buffer));
691     stream->ThenWaitFor(transfer_stream_ptr.get());
692   } else {
693     TF_RETURN_IF_ERROR(
694         transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
695   }
696   VLOG(4) << "Input buffers: " << shaped_buffer.ToString();
697 
698   // Snapshot the inputs, if a snapshot was requested.
699   std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
700   if (executable.has_session_module()) {
701     hlo_snapshot =
702         std::make_shared<xla::HloSnapshot>(executable.session_module());
703     auto literal =
704         std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
705     transfer_manager->TransferLiteralFromDevice(
706         stream, shaped_buffer, literal.get(),
707         [hlo_snapshot, literal](Status status) {
708           if (!status.ok()) {
709             LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
710                           "failed: "
711                        << status;
712             return;
713           }
714           *hlo_snapshot->add_arguments() = literal->ToProto();
715         });
716   }
717 
718   auto definition_event = std::make_shared<se::Event>(stream->parent());
719   TF_RET_CHECK(definition_event->Init())
720       << "TPU definition event initialization failed";
721 
722   trace_me_init.Stop();
723 
724   const uint32 rng_seed = GetXLARandomSeed();
725 
726   std::unique_ptr<xla::DeviceAssignment> device_assignment;
727   if (executable.has_device_assignment()) {
728     TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
729                                                executable.device_assignment()));
730   }
731 
732   VLOG(4) << "Input buffers after alias resolution: "
733           << shaped_buffer.ToString();
734 
735   std::vector<xla::ExecutionInput> input;
736   input.emplace_back(xla::ExecutionInput(std::move(input_buffers->buffers),
737                                          shaped_buffer.on_host_shape()));
738 
739   // The buffers to be freed are held in `output` and are automatically freed
740   // when it goes out of scope. In async mode, this means the buffers will be
741   // freed before anyone calls "BlockHostUntilDone", i.e., some of the (input)
742   // buffers will be freed while the program is still running, which looks
743   // scary. However, this turns out not to be a problem: although we free
744   // memory and reassign it to other users while a program is running, any
745   // subsequent write that could possibly clobber that memory will depend on
746   // the program finishing first.
747   const TPUHostTransferInfoProto& host_transfer_info =
748       tpu_program_group->host_transfer_info(core_index);
749   TF_ASSIGN_OR_RETURN(
750       xla::ExecutionOutput output,
751       TPUExecute(executable, host_transfer_info,
752                  *tpu_program_group->hlo_metadata(core_index), std::move(input),
753                  rendezvous_key_base, rng_seed, node_context.get(),
754                  device_assignment.get(), context->cancellation_manager(),
755                  context, stream, transfer_stream_ptr.get(),
756                  tpu_program_group->tpu_program(core_index)));
757   stream->ThenRecordEvent(definition_event.get());
758 
759   TF_ASSIGN_OR_RETURN(
760       std::unique_ptr<OutputBuffers> output_buffers,
761       AllocateOutputTensors(
762           context, output.ConsumeResult(), executable.output_tensor_shapes(),
763           variable_update_map, node_context.get(), stream, device_ordinal,
764           input_buffers.get(), definition_event));
765 
766   // Transfer the outputs and save the snapshot to disk.
767   if (hlo_snapshot) {
768     auto literal =
769         std::make_shared<xla::Literal>(output_buffers->buffers.on_host_shape());
770     transfer_manager->TransferLiteralFromDevice(
771         stream, output_buffers->buffers, literal.get(),
772         [hlo_snapshot, literal](Status status) {
773           if (status.ok()) {
774             *hlo_snapshot->mutable_result() = literal->ToProto();
775           } else {
776             LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot "
777                           "outputs failed: "
778                        << status;
779           }
780           DumpHloSnapshotIfEnabled(*hlo_snapshot,
781                                    xla::GetDebugOptionsFromFlags());
782         });
783   }
784   return OkStatus();
785 }
786 
787 TPUExecuteOp::~TPUExecuteOp() = default;
788 
789 TPUExecuteAndUpdateVariablesOp::TPUExecuteAndUpdateVariablesOp(
790     OpKernelConstruction* context)
791     : TPUExecuteOp(context) {
792   OP_REQUIRES_OK(context, context->GetAttr(
793                               "device_var_reads_indices",
794                               &fused_device_var_reads_in_computation_inputs_));
795   OP_REQUIRES_OK(
796       context,
797       context->GetAttr("device_var_updates_indices",
798                        &fused_device_var_updates_in_computation_outputs_));
799 }
800 
801 REGISTER_KERNEL_BUILDER(
802     Name("TPUExecute").Device(DEVICE_TPU_NODE).HostMemory("key"), TPUExecuteOp);
803 
804 REGISTER_KERNEL_BUILDER(Name("TPUExecuteAndUpdateVariables")
805                             .Device(DEVICE_TPU_NODE)
806                             .HostMemory("key"),
807                         TPUExecuteAndUpdateVariablesOp);
808 
809 }  // namespace tensorflow
810