/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_execute_op.h"

#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {
namespace {
using ::tensorflow::tpu::CompilationCacheEntryRef;
using ::tensorflow::tpu::TpuCompilationCacheLookup;
using ::tensorflow::tpu::TpuNodeContext;

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
    OpKernelContext* context, string* rendezvous_key_base,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  const Tensor* key;
  TF_RETURN_IF_ERROR(context->input("key", &key));
  profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
  if (!TensorShapeUtils::IsVector(key->shape()) ||
      key->shape().dim_size(0) != 3) {
    return errors::InvalidArgument(
        "Key argument to TPUExecute must be a 3-element vector");
  }

  ResourceMgr* rmgr = GetTPUConfigResourceMgr();
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kCompiledProtoCacheResourceName,
                                  &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key->vec<tstring>()(0), entry));
  *rendezvous_key_base = key->vec<tstring>()(1);
  return OkStatus();
}
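
// Note on the key layout, inferred from the lookup above rather than from any
// contract documented here: element 0 of the 3-element key vector is the
// compilation cache key passed to `proto_lookup->Lookup`, element 1 is the
// rendezvous key base, and element 2 is not read by this kernel. A minimal
// sketch of a conforming key tensor, with hypothetical placeholder values:
//
//   Tensor key(DT_STRING, TensorShape({3}));
//   key.vec<tstring>()(0) = "<compilation-cache-key>";  // cache lookup key
//   key.vec<tstring>()(1) = "<rendezvous-key-base>";    // rendezvous base
//   key.vec<tstring>()(2) = "";                         // unused here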

struct VariableUpdateMap {
  // Maps input index to the updated output index. If the variable doesn't have
  // an updated output, the corresponding output is set to -1.
  absl::flat_hash_map<int, int> input_to_output;
  // Maps output index to (the input index, whether the update is generated
  // from compilation).
  absl::flat_hash_map<int, std::pair<int, bool>> output_to_input;
  // The subset of input indices that come from the compilation, in compiled
  // order.
  std::vector<int> input_in_compiled_update_order;
};

// Creates a VariableUpdateMap from both the compilation and the fused variable
// reads/updates.
xla::StatusOr<VariableUpdateMap> BuildVariableUpdateMap(
    absl::Span<const TPUExecutableInfoProto::UpdateIndexPair* const>
        compiled_variable_updates,
    absl::Span<int const> fused_device_var_reads_in_computation_inputs,
    const std::vector<int>& fused_device_var_updates_in_computation_outputs,
    int64_t computation_output_count) {
  VariableUpdateMap map;
  auto add_pair = [&](int input, int output, bool from_compilation) -> Status {
    TF_RET_CHECK(map.input_to_output.emplace(input, output).second)
        << "Duplicate variable input index: " << input;
    if (output >= 0) {
      TF_RET_CHECK(map.output_to_input
                       .emplace(output, std::make_pair(input, from_compilation))
                       .second)
          << "Duplicate variable output index: " << output;
    }
    return OkStatus();
  };

  // First add the updates produced by the compilation. Not all variables are
  // updated, and if not, they do not have an output in the XLA computation.
  // The update output indices in the XLA computation start after the
  // non-variable outputs.
  int num_updated_variables = 0;
  for (int i = 0; i < compiled_variable_updates.size(); ++i) {
    const bool updated = compiled_variable_updates[i]->updated();
    if (updated) ++num_updated_variables;
  }
  TF_RET_CHECK(num_updated_variables <= computation_output_count)
      << num_updated_variables << " <= " << computation_output_count;
  int64_t compiled_variable_output_index =
      computation_output_count - num_updated_variables;
  for (auto update : compiled_variable_updates) {
    map.input_in_compiled_update_order.push_back(update->index());
    if (!update->updated()) {
      TF_RETURN_IF_ERROR(add_pair(update->index(), -1, true));
      continue;
    }
    TF_RETURN_IF_ERROR(
        add_pair(update->index(), compiled_variable_output_index, true));
    ++compiled_variable_output_index;
  }

  // Now add the updates from the attributes.
  TF_RET_CHECK(fused_device_var_reads_in_computation_inputs.size() ==
               fused_device_var_updates_in_computation_outputs.size());
  for (int64_t i = 0; i < fused_device_var_reads_in_computation_inputs.size();
       ++i) {
    TF_RETURN_IF_ERROR(
        add_pair(fused_device_var_reads_in_computation_inputs[i],
                 fused_device_var_updates_in_computation_outputs[i], false));
  }
  return map;
}
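
// An illustrative, hypothetical example of the mapping built above (not taken
// from any real program): suppose the computation has 5 outputs and the
// compiled variable updates are {index=2, updated} and {index=4, not updated}.
// The single updated variable is assigned output 4 (= 5 - 1), so the map holds
// input_to_output = {2 -> 4, 4 -> -1}, output_to_input = {4 -> (2, true)}, and
// input_in_compiled_update_order = {2, 4}. A fused read/update pair from the
// attributes, say (read input 7, updated output 1), would additionally add
// input_to_output[7] = 1 and output_to_input[1] = (7, false).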

// Buffers representing the inputs to a computation.
struct InputBuffers {
  explicit InputBuffers(xla::Shape device_shape)
      : buffers(std::move(device_shape)) {}

  InputBuffers(const InputBuffers&) = delete;
  InputBuffers& operator=(const InputBuffers&) = delete;

  ~InputBuffers() = default;

  xla::ShapedBuffer ToShapedBuffer(xla::Shape host_shape,
                                   se::DeviceMemoryAllocator* allocator,
                                   int device_ordinal) {
    CHECK_NE(allocator, nullptr);
    xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
                                    device_ordinal);
    shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
        [](const xla::MaybeOwningDeviceMemory& buffer) {
          return buffer.AsDeviceMemoryBase();
        }));
    return shaped_buffer;
  }

  // Describes the buffer tree.
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> buffers;

  // Information about resource variables passed directly to TPUExecute.
  std::vector<VariableInfo> variables;

  // Mapping from input index to offsets in 'variables'. < 0 if the input does
  // not correspond to a variable in 'variables'.
  std::vector<int> variable_index;
};
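
// A minimal usage sketch, mirroring what BuildComputationInputs and DoWork
// below do with this struct (variable names here are illustrative only):
//
//   auto input_buffers = absl::make_unique<InputBuffers>(
//       transfer_manager->HostShapeToDeviceShape(host_shape));
//   // ... fill input_buffers->buffers with owned or borrowed device memory.
//   xla::ShapedBuffer shaped_buffer = input_buffers->ToShapedBuffer(
//       std::move(host_shape), backend->memory_allocator(), device_ordinal);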

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
    OpKernelContext* context, const xla::Shape& input_host_shape,
    const VariableUpdateMap& variable_updates, xla::Backend* backend,
    int device_ordinal, se::Stream* stream) {
  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  if (arg_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
    return errors::InvalidArgument(
        "Number of parameters (", arg_list.size(),
        ") does not match input shape: ",
        xla::ShapeUtil::TupleElementCount(input_host_shape));
  }

  auto validate_shape = [&](int i, const Tensor& tensor) {
    const xla::Shape& expected =
        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    if (xla_tensor == nullptr) {
      // FromTensor failed; tensor must be empty.
      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(),
            "; got empty tensor. If you are running "
            "with TF2 TPU, make sure you set `drop_remainder=False` when "
            "calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
            "size can be handled");
      }
    } else {
      // Compare host shapes, easier than getting the expected device shape.
      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got ", xla_shape.DebugString());
      }
    }

    return OkStatus();
  };

  // Iterate over the inputs, validating the shapes of non-variable inputs,
  // and creating a VariableInfo object for each variable. We consider variable
  // inputs in a separate phase because we must acquire variable locks in
  // order.
  std::vector<VariableInfo> variables;
  std::vector<int> variable_index(arg_list.size(), -1);
  variables.reserve(arg_list.size());
  for (int i = 0; i < arg_list.size(); ++i) {
    // Arguments are assumed to be variables if they have a resource type.
    // (Non-variable resources are not supported.)
    if (context->input_dtype(i) == DT_RESOURCE) {
      variable_index[i] = variables.size();
      // TODO(phawkins): we may be looking up many variables here; it would be
      // better if we did not repeatedly acquire the resource manager's lock.
      const ResourceHandle& handle = HandleFromInput(context, i);
      Var* variable;
      TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
      variables.push_back(VariableInfo(i, handle.name(), variable));
    } else {
      TF_RETURN_IF_ERROR(validate_shape(i, arg_list[i]));
    }
  }

  // Lock the variables, and validate their shapes. We hold the variable locks
  // for the duration of the TPU execution so we can donate the variable
  // buffers to the computation. If we copied the variable's Tensor instead,
  // its reference count would be greater than one due to the reference the Var
  // object holds, and we would never be able to reuse variable buffers.
  // TODO(phawkins): add a 'reuse_buffers' attribute to TPUExecute that allows
  // the user to elect to copy the buffers and permit concurrent access
  // instead.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(
        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
  }

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  auto input_buffers = absl::make_unique<InputBuffers>(
      transfer_manager->HostShapeToDeviceShape(input_host_shape));

  // Allocates a buffer for the root tuple.
  const int64_t root_size =
      transfer_manager->GetByteSizeRequirement(input_buffers->buffers.shape());
  TF_ASSIGN_OR_RETURN(*input_buffers->buffers.mutable_element({}),
                      allocator->Allocate(device_ordinal, root_size));

  // Helper function that sets the input buffers for 'arg_index' to 'buffers'.
  // If 'donate_buffers' is true, donates ownership of the buffers in 'buffers'
  // to the computation and overwrites the entries in 'buffers' with nulls.
  auto set_input_buffers_helper = [&](int arg_index, bool donate_buffers,
                                      xla::ShapedBuffer* buffers) {
    buffers->buffers().ForEachMutableElement([&](const xla::ShapeIndex& index,
                                                 se::DeviceMemoryBase* buffer) {
      xla::ShapeIndex in_index = {arg_index};
      for (int64_t j : index) {
        in_index.push_back(j);
      }
      auto* in_buffer = input_buffers->buffers.mutable_element(in_index);
      if (donate_buffers) {
        *in_buffer = se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
        *buffer = se::DeviceMemoryBase();
      } else {
        *in_buffer = *buffer;
      }
    });
  };
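
  // An illustrative example with hypothetical shapes: if argument 3 has leaf
  // buffers at ShapeIndex {} and {0}, the helper copies them into the root
  // input tree at {3} and {3, 0}. With donate_buffers=true the destination
  // entries become owning (se::OwningDeviceMemory) and the source ShapedBuffer
  // is left holding null DeviceMemoryBase entries, so the caller cannot
  // accidentally free memory that the computation now owns.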

  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
  // buffers for zero-element tensors where required.
  auto assign_input = [&](int i, const Tensor& tensor,
                          bool may_reuse) -> xla::Status {
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    // Size 0 tensors have no backing XlaTensor, but may still need to have
    // tuple buffers allocated.
    if (xla_tensor == nullptr) {
      CHECK_EQ(tensor.NumElements(), 0);
      const xla::Shape& host_shape =
          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
                          transfer_manager->AllocateScopedShapedBuffer(
                              host_shape, allocator, device_ordinal));
      set_input_buffers_helper(/*arg_index=*/i, /*donate_buffers=*/true,
                               &buffers);
    } else {
      bool can_reuse_buffers = tensor.RefCountIsOne() && may_reuse;
      set_input_buffers_helper(/*arg_index=*/i,
                               /*donate_buffers=*/can_reuse_buffers,
                               &xla_tensor->shaped_buffer());
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }
    return OkStatus();
  };

  for (int i = 0; i < arg_list.size(); ++i) {
    auto it = variable_updates.input_to_output.find(i);
    if (it == variable_updates.input_to_output.end()) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], /*may_reuse=*/true));
      continue;
    }
    // Input i is a variable.
    bool updated = it->second >= 0;
    if (arg_list[i].dtype() != DT_RESOURCE) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], updated));
    } else {
      int vi = variable_index[i];
      TF_RETURN_IF_ERROR(
          assign_input(i, *variables[vi].var()->tensor(), updated));
    }
  }

  input_buffers->variables = std::move(variables);
  input_buffers->variable_index = std::move(variable_index);

  return std::move(input_buffers);
}

struct OutputBuffers {
  OutputBuffers(xla::ScopedShapedBuffer b, se::DeviceMemoryAllocator* allocator)
      : owned_buffers(b.on_device_shape(), true),
        buffers(b.release()),
        memory_allocator(allocator) {}

  ~OutputBuffers() {
    buffers.buffers().ForEachElement(
        [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
          if (owned_buffers.element(index) && !buffer.is_null()) {
            Status status =
                memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
            if (!status.ok()) {
              LOG(ERROR) << "Error deallocating buffer " << status;
            }
          }
        });
  }

  // Which of the buffers do we own?
  xla::ShapeTree<bool> owned_buffers;

  xla::ShapedBuffer buffers;

  se::DeviceMemoryAllocator* const memory_allocator;
};

// Allocates Tensors for the outputs of the computation. Ownership of most
// output buffers is passed to the output Tensors. Returns an OutputBuffers
// object that owns the root buffer that should be passed to the XLA
// computation, as well as any output buffers that do not have corresponding
// output tensors. The latter may happen for zero-element tensors of type int64
// or complex64 which still require a tuple buffer but do not have a
// corresponding XlaTensor.
xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
    OpKernelContext* context, xla::ScopedShapedBuffer scoped_buffers,
    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
    se::Stream* stream, int device_ordinal, InputBuffers* input_buffers,
    const std::shared_ptr<se::Event>& definition_event) {
  VLOG(4) << "Output buffers: " << scoped_buffers.ToString();

  profiler::TraceMe trace_me("AllocateOutputTensors", /*level=*/2);
  // Shapes of the outputs, in TensorShape form.
  const int64_t sub_elements =
      xla::ShapeUtil::TupleElementCount(scoped_buffers.on_host_shape());
  if (sub_elements != output_tensor_shape_protos.size()) {
    return errors::InvalidArgument(
        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
        output_tensor_shape_protos.size());
  }

  xla::TransferManager* const transfer_manager =
      node_context->backend()->transfer_manager();

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(sub_elements);
  for (int64_t i = 0; i < sub_elements; ++i) {
    TF_RETURN_IF_ERROR(
        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
    TensorShape shape(*output_tensor_shape_protos[i]);
    const xla::Shape& xla_shape =
        xla::ShapeUtil::GetSubshape(scoped_buffers.on_host_shape(), {i});
    if (!xla_shape.IsArray() ||
        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
      return errors::InvalidArgument(
          "Mismatched number of elements in output shape: ",
          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
    }
    output_tensor_shapes.push_back(shape);
  }

  // Builds a shaped buffer for the outputs.
  TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));

  se::DeviceMemoryAllocator* const allocator =
      node_context->backend()->memory_allocator();

  auto output_buffers =
      absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);

  xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();

  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output_buffers->buffers, &output_device_shape));
    for (int64_t i = 0; i < sub_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes[i] = shape;
    }
  }

  // Transfers ownership of the buffers that back XLA computation output 'i'
  // to 'output_tensor'.
  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
    const xla::Shape& device_shape =
        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);

    // Transfers ownership of the output buffers to the output Tensor if the
    // tensor is backed by an XlaTensor. Tensors of size 0 have no backing
    // XlaTensor, so we let 'output_buffers' retain ownership of any buffers in
    // that case.
    if (output_tensor->NumElements() > 0) {
      xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
                                            device_ordinal);
      shaped_buffer.buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            xla::ShapeIndex out_index = {i};
            for (int64_t j : index) {
              out_index.push_back(j);
            }
            *buffer = output_buffers->buffers.buffers().element(out_index);
            *output_buffers->owned_buffers.mutable_element(out_index) = false;
          });

      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
      xla_tensor->ResetDefinitionEvent(definition_event, stream);
    }
  };

  const int num_updated_variables = variable_updates.output_to_input.size();
  TF_RET_CHECK(num_updated_variables <= output_tensor_shapes.size())
      << num_updated_variables << " <= " << output_tensor_shapes.size();

  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  // The TPU program outputs the updated variables, both DT_RESOURCE and
  // non-DT_RESOURCE. The TPUExecuteOp needs to output all non-DT_RESOURCE
  // variables (updated or not).
  //
  //                        updated            not_updated
  //                   |------------------|------------------|
  //   DT_RESOURCE     | allocate persist | do nothing       |
  //                   |------------------|------------------|
  //                   | allocate         | forward Op input |
  //   not DT_RESOURCE | output           | to Op output     | Op output
  //                   |------------------|------------------|
  //                     program output
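  //
  // A concrete walk-through of the four cases handled by the loop below: a
  // DT_RESOURCE input whose variable is updated gets a freshly allocated temp
  // tensor that persists in the Var; a DT_RESOURCE input that is not updated
  // produces no op output at all; a non-DT_RESOURCE variable value that is
  // updated gets a freshly allocated op output; and a non-DT_RESOURCE variable
  // value that is not updated is forwarded unchanged from op input to op
  // output.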

  // Allocates a fresh tensor for each updated variable. While the variable
  // inputs need not come in any particular order, the variable values are
  // always added last by the XlaCompiler class, in the same order as the
  // corresponding input variables.
  int op_output_index = 0;
  int compiled_update_index = 0;
  auto process_non_updated_variable = [&](int input_index) {
    const int variable_index = input_buffers->variable_index.at(input_index);
    // If a DT_RESOURCE input is not updated, nothing needs to be done
    // because there is no corresponding output. If a non-resource input
    // is not updated, forward the input to the output.
    if (variable_index < 0) {
      context->set_output(op_output_index, arg_list[input_index]);
      ++op_output_index;
    }
  };
  for (int i = 0; i < output_tensor_shapes.size(); ++i) {
    auto it = variable_updates.output_to_input.find(i);
    if (it == variable_updates.output_to_input.end()) {
      // Not a variable update.
      // Allocates a fresh tensor for each output of the operator. We always
      // allocate a new host-side tensor, but the on-device buffers that back
      // that tensor may be aliases of input buffers.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      transfer_buffers(i, output_tensor);
      ++op_output_index;
      continue;
    }
    const int input_index = it->second.first;
    // We must process the compiled updates in order, which includes the
    // non-updated variables, i.e., those without an XLA output.
    const bool from_compilation = it->second.second;
    while (from_compilation &&
           variable_updates
                   .input_in_compiled_update_order[compiled_update_index] !=
               input_index) {
      process_non_updated_variable(
          variable_updates
              .input_in_compiled_update_order[compiled_update_index]);
      ++compiled_update_index;
    }
    ++compiled_update_index;
    const int variable_index = input_buffers->variable_index.at(input_index);
    if (variable_index >= 0) {
      // This output corresponds to a DT_RESOURCE input to the TPUExecute
      // operator. Update the corresponding variable.
      VariableInfo& var = input_buffers->variables[variable_index];
      TF_RETURN_IF_ERROR(context->allocate_temp(var.var()->tensor()->dtype(),
                                                output_tensor_shapes[i],
                                                var.var()->tensor()));
      transfer_buffers(i, var.var()->tensor());
    } else {
      // This output corresponds to a non-resource input to the TPUExecute
      // operator. This case occurs for the distributed TPU rewrite which
      // adds variable values as inputs and outputs rather than passing the
      // variables themselves; reading and writing the variable is handled
      // outside the op.
      // TODO(phawkins): remove this case when placement of variables on TPU
      // devices is well supported and we no longer need to place "remote"
      // variables on CPU devices.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      ++op_output_index;
      transfer_buffers(i, output_tensor);
    }
  }

  // Process any remaining non-updated variables.
  for (; compiled_update_index <
         variable_updates.input_in_compiled_update_order.size();
       ++compiled_update_index) {
    process_non_updated_variable(
        variable_updates.input_in_compiled_update_order[compiled_update_index]);
  }
  return std::move(output_buffers);
}

}  // namespace

// TPUExecuteOp

TPUExecuteOp::TPUExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

AsyncOpKernel* TPUExecuteOp::AsAsync() {
  // If TPU launches are asynchronous, we can perform the launch without
  // blocking the calling thread, and so the executor may treat this kernel as
  // a regular (synchronous) OpKernel.
  return nullptr;
}

void TPUExecuteOp::Compute(OpKernelContext* context) {
  Status s = DoWork(context);
  // NOTE: We can't use `OP_REQUIRES_OK()` here because that macro includes
  // a dynamic check that we are not in an AsyncOpKernel.
  if (TF_PREDICT_FALSE(!s.ok())) {
    context->SetStatus(s);
  }
}

void TPUExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this
  // thread to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status TPUExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
                      TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [device_ordinal, context] {
        return profiler::TraceMeEncode(
            "TpuExecuteOp", {{"device_ordinal", device_ordinal},
                             {"id", context->step_id()},
                             {"iter_num", context->frame_iter().iter_id}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);

  string rendezvous_key_base;
  std::unique_ptr<CompilationCacheEntryRef> entry_ref;
  TF_RETURN_IF_ERROR(
      GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));

  // Shapes of the inputs and outputs, in xla::Shape form.
  tpu::TpuCompilationCacheEntry entry = entry_ref->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          entry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  xla::Backend* const backend = node_context->backend();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();
  TF_RET_CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

  TF_RET_CHECK(executable.input_shapes_size() == 1);

  xla::Shape host_shape(executable.input_shapes(0));

  TF_ASSIGN_OR_RETURN(
      auto variable_update_map,
      BuildVariableUpdateMap(executable.variable_indices(),
                             fused_device_var_reads_in_computation_inputs_,
                             fused_device_var_updates_in_computation_outputs_,
                             executable.output_tensor_shapes().size()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<InputBuffers> input_buffers,
      BuildComputationInputs(context, host_shape, variable_update_map, backend,
                             device_ordinal, stream));

  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (why we need a separate
  // transfer stream) is between the executable writing tuple tables and
  // TPUExecute()'s deregister_stream; if they come from the same stream pool
  // antidependencies will occur. XlaBackend has a different pool of streams
  // from the stream->GetOrCreateSubStream() that TPUExecute() uses, so these
  // will never refer to the same stream.
  //
  // TODO(jmolloy): Add the necessary plumbing to obtain the proper
  // host-to-device stream here.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  auto shaped_buffer = input_buffers->ToShapedBuffer(std::move(host_shape),
                                                     allocator, device_ordinal);
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  // Snapshot the inputs, if a snapshot was requested.
  std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
  if (executable.has_session_module()) {
    hlo_snapshot =
        std::make_shared<xla::HloSnapshot>(executable.session_module());
    auto literal =
        std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, shaped_buffer, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *hlo_snapshot->add_arguments() = literal->ToProto();
        });
  }

  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  const uint32 rng_seed = GetXLARandomSeed();

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
                                               executable.device_assignment()));
  }

  VLOG(4) << "Input buffers after alias resolution: "
          << shaped_buffer.ToString();

  std::vector<xla::ExecutionInput> input;
  input.emplace_back(xla::ExecutionInput(std::move(input_buffers->buffers),
                                         shaped_buffer.on_host_shape()));

  // The buffers to be freed are in `output` and will be automatically freed
  // when it goes out of scope. In async mode, this means the buffers will be
  // freed before anyone calls "BlockHostUntilDone", which indicates that some
  // of the (input) buffers will be freed while the program is running and
  // looks scary. However, this turns out not to be a problem: although we free
  // memory and reassign it to other users while a program is running, all
  // subsequent writes to the program that could possibly clobber the memory
  // will depend on the program to finish.
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index), std::move(input),
                 rendezvous_key_base, rng_seed, node_context.get(),
                 device_assignment.get(), context->cancellation_manager(),
                 context, stream, transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));
  stream->ThenRecordEvent(definition_event.get());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<OutputBuffers> output_buffers,
      AllocateOutputTensors(
          context, output.ConsumeResult(), executable.output_tensor_shapes(),
          variable_update_map, node_context.get(), stream, device_ordinal,
          input_buffers.get(), definition_event));

  // Transfer the outputs and save the snapshot to disk.
  if (hlo_snapshot) {
    auto literal =
        std::make_shared<xla::Literal>(output_buffers->buffers.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, output_buffers->buffers, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (status.ok()) {
            *hlo_snapshot->mutable_result() = literal->ToProto();
          } else {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot "
                          "outputs failed: "
                       << status;
          }
          DumpHloSnapshotIfEnabled(*hlo_snapshot,
                                   xla::GetDebugOptionsFromFlags());
        });
  }
  return OkStatus();
}

TPUExecuteOp::~TPUExecuteOp() = default;

TPUExecuteAndUpdateVariablesOp::TPUExecuteAndUpdateVariablesOp(
    OpKernelConstruction* context)
    : TPUExecuteOp(context) {
  OP_REQUIRES_OK(context, context->GetAttr(
                              "device_var_reads_indices",
                              &fused_device_var_reads_in_computation_inputs_));
  OP_REQUIRES_OK(
      context,
      context->GetAttr("device_var_updates_indices",
                       &fused_device_var_updates_in_computation_outputs_));
}

REGISTER_KERNEL_BUILDER(
    Name("TPUExecute").Device(DEVICE_TPU_NODE).HostMemory("key"), TPUExecuteOp);

REGISTER_KERNEL_BUILDER(Name("TPUExecuteAndUpdateVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("key"),
                        TPUExecuteAndUpdateVariablesOp);
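
// Note on the two kernels: TPUExecute takes no fused variable attributes, so
// its fused_device_var_reads_in_computation_inputs_ and
// fused_device_var_updates_in_computation_outputs_ vectors stay empty, while
// TPUExecuteAndUpdateVariables populates them from its
// "device_var_reads_indices" and "device_var_updates_indices" attributes in
// the constructor above. As a purely illustrative (hypothetical) setting,
// device_var_reads_indices = [3, 4] with device_var_updates_indices = [1, -1]
// would mean computation input 3 feeds computation output 1, while input 4's
// variable is read but never updated; BuildVariableUpdateMap treats a negative
// update index as "not updated".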

}  // namespace tensorflow