/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {
namespace tpu {
namespace reshard_variables {

Status FlushProgramMemory(se::Platform* platform, int device_ordinal) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
                      tpu::TpuNodeContext::Create(device_ordinal));

  auto* executor = tensorflow::down_cast<tpu::TpuExecutorInterface*>(
      node_interfaces->stream_executor()->implementation());
  return executor->UnloadAllPrograms();
}

Status CheckIsValidKey(const Tensor& key) {
  if (!TensorShapeUtils::IsVector(key.shape()) ||
      key.shape().dim_size(0) != 3) {
    return errors::InvalidArgument(
        "new_format_key argument to TPUReshardVariables must be a 3-element "
        "vector");
  }
  if (key.dtype() != DT_STRING) {
    return errors::InvalidArgument(
        "new_format_key argument to TPUReshardVariables must be DT_STRING "
        "type");
  }
  return OkStatus();
}

bool IsDefaultKey(const Tensor& key) { return key.vec<tstring>()(0).empty(); }

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
    const Tensor& key, string* rendezvous_key_base,
    std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
    tpu::CompilationCacheFetchTarget fetch_target) {
  profiler::TraceMe trace_me("TPUReshardVariablesOpKernel::LookupProto",
                             /*level=*/2);
  TF_RETURN_IF_ERROR(CheckIsValidKey(key));
  auto* rmgr = GetTPUConfigResourceMgr();
  tpu::TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kCompiledProtoCacheResourceName,
                                  &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  TF_RETURN_IF_ERROR(
      proto_lookup->Lookup(key.vec<tstring>()(0), entry, fetch_target));
  *rendezvous_key_base = key.vec<tstring>()(1);
  return OkStatus();
}
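
// Sketch of a hypothetical caller (the variable names and the fetch target
// below are assumptions for illustration, not taken from this file):
//
//   string rendezvous_key_base;
//   std::unique_ptr<tpu::CompilationCacheEntryRef> entry;
//   TF_RETURN_IF_ERROR(GetComputationCacheEntry(
//       new_format_key, &rendezvous_key_base, &entry,
//       /*fetch_target=*/tpu::CompilationCacheFetchTarget::SHARDING));
//   // `entry` keeps the cached program alive while it is being used.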

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>> BuildInputBuffers(
    OpKernelContext* context, const std::vector<VariableInfo>& variables,
    const xla::Shape& input_host_shape, xla::Backend* backend,
    int device_ordinal, se::Stream* stream) {
  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
  OpInputList var_list;
  TF_RETURN_IF_ERROR(context->input_list("vars", &var_list));

  if (var_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
    return errors::InvalidArgument(
        "Number of variables (", var_list.size(),
        ") does not match input shape: ",
        xla::ShapeUtil::TupleElementCount(input_host_shape));
  }

  auto validate_shape = [&](int i, const Tensor& tensor) {
    const xla::Shape& expected =
        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    if (xla_tensor == nullptr) {
      // FromTensor failed; tensor must be empty.
      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUReshardVariables argument[", i,
            "] (", context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(),
            "; got empty tensor. If you are running "
            "with TF2 TPU, make sure you set `drop_remainder=False` when "
            "calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
            "size can be handled");
      }
    } else {
      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUReshardVariables argument[", i,
            "] (", context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got ", xla_shape.DebugString());
      }
    }

    return OkStatus();
  };

  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(
        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
  }

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  xla::ShapeTree<xla::MaybeOwningDeviceMemory> input_buffers(
      transfer_manager->HostShapeToDeviceShape(input_host_shape));

  // Allocates a buffer for the root tuple.
  const int64_t root_size =
      transfer_manager->GetByteSizeRequirement(input_buffers.shape());
  TF_ASSIGN_OR_RETURN(*input_buffers.mutable_element({}),
                      allocator->Allocate(device_ordinal, root_size));

  auto set_input_buffers_helper = [&](int arg_index, xla::ShapedBuffer* buffers,
                                      bool owning = false) {
    buffers->buffers().ForEachMutableElement(
        [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
          xla::ShapeIndex in_index = {arg_index};
          for (int64_t j : index) {
            in_index.push_back(j);
          }
          if (owning) {
            *input_buffers.mutable_element(in_index) =
                se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
            *buffer = se::DeviceMemoryBase();
          } else {
            *input_buffers.mutable_element(in_index) = *buffer;
          }
        });
  };
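  // For example (illustration only): a buffer of `buffers` at ShapeIndex {}
  // is stored into `input_buffers` at ShapeIndex {arg_index}. When `owning` is
  // true, it is wrapped in an se::OwningDeviceMemory so the ShapeTree takes
  // ownership, and the corresponding entry in the source ShapedBuffer is
  // cleared so it will not be freed twice.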

  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
  // buffers for zero-element tensors where required.
  auto assign_input = [&](int i, const Tensor& tensor) -> xla::Status {
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    // Size 0 tensors have no backing XlaTensor, but may still need to have
    // tuple buffers allocated.
    if (xla_tensor == nullptr) {
      CHECK_EQ(tensor.NumElements(), 0);
      const xla::Shape& host_shape =
          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
                          transfer_manager->AllocateScopedShapedBuffer(
                              host_shape, allocator, device_ordinal));
      set_input_buffers_helper(/*arg_index=*/i, &buffers);
    } else {
      set_input_buffers_helper(/*arg_index=*/i, &xla_tensor->shaped_buffer(),
                               tensor.RefCountIsOne());
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }
    return OkStatus();
  };

  for (int i = 0; i < var_list.size(); ++i) {
    TF_RET_CHECK(var_list[i].dtype() == DT_RESOURCE);
    TF_RETURN_IF_ERROR(assign_input(i, *variables[i].var()->tensor()));
  }

  return std::move(input_buffers);
}
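
// Note on the returned ShapeTree (descriptive, derived from the code above):
// the root buffer at ShapeIndex {} is the freshly allocated tuple table, and
// the buffers backing variable i live under ShapeIndex {i, ...}. Entries whose
// source Tensor has a reference count of one are owning (they wrap an
// se::OwningDeviceMemory); all other entries alias the variable's existing
// device memory.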

// Performs a compaction to reduce fragmentation.
Status PerformCompaction(stream_executor::Stream* stream) {
  profiler::TraceMe trace_me("PerformCompaction", /*level=*/2);
  auto* ds_executor =
      down_cast<tpu::TpuExecutorInterface*>(stream->parent()->implementation());
  TF_RETURN_IF_ERROR(ds_executor->EnqueueCompactionOnStreamForHbm(stream));
  // LoadProgram and GetOrCreateConstantHandle are not managed by stream
  // dependencies but they write to shared memory, so we need to block here to
  // prevent those operations from racing.
  return stream->BlockHostUntilDone();
}

// Updates the variables to the execution result's buffers, and deallocates the
// root tuple buffer.
Status UpdateOutputVariables(
    OpKernelContext* context, xla::ScopedShapedBuffer result_buffers,
    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
    xla::Backend* backend, se::Stream* stream, int device_ordinal,
    const std::vector<VariableInfo>& variables,
    const std::shared_ptr<se::Event>& definition_event) {
  profiler::TraceMe trace_me("UpdateOutputVariables", /*level=*/2);
  // Shapes of the outputs, in TensorShape form.
  const int64_t sub_elements =
      xla::ShapeUtil::TupleElementCount(result_buffers.on_host_shape());
  if (sub_elements != output_tensor_shape_protos.size()) {
    return errors::InvalidArgument(
        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
        output_tensor_shape_protos.size());
  }

  if (sub_elements != variables.size()) {
    return errors::InvalidArgument(
        "Output count does not equal variable count: ", sub_elements, " vs. ",
        variables.size());
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(sub_elements);
  for (int64_t i = 0; i < sub_elements; ++i) {
    TF_RETURN_IF_ERROR(
        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
    TensorShape shape(*output_tensor_shape_protos[i]);
    const xla::Shape& xla_shape =
        xla::ShapeUtil::GetSubshape(result_buffers.on_host_shape(), {i});
    if (!xla_shape.IsArray() ||
        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
      return errors::InvalidArgument(
          "Mismatched number of elements in output shape: ",
          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
    }
    output_tensor_shapes.push_back(shape);
    VLOG(2) << "Output " << i << " shape " << shape.DebugString();
  }

  // Build a shaped buffer for the outputs.
  TF_RET_CHECK(result_buffers.on_host_shape().IsTuple());
  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(result_buffers.on_host_shape()));

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();

  auto output_buffers = result_buffers.release();
  const xla::Shape& output_host_shape = output_buffers.on_host_shape();
  const xla::Shape& output_device_shape = output_buffers.on_device_shape();

  // Transfers ownership of the buffers that back XLA computation output 'i'
  // to 'output_tensor'.
  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
    const xla::Shape& host_shape =
        xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
    const xla::Shape& device_shape =
        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
    if (output_tensor->NumElements() > 0) {
      xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
                                            device_ordinal);
      shaped_buffer.buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            xla::ShapeIndex out_index = {i};
            for (int64_t j : index) {
              out_index.push_back(j);
            }
            *buffer = output_buffers.buffers().element(out_index);
          });

      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
      xla_tensor->ResetDefinitionEvent(definition_event, stream);
    }
  };
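  // For example (illustration only): for output i == 1, the single entry at
  // ShapeIndex {} of the per-variable ScopedShapedBuffer is pointed at the
  // buffer stored at ShapeIndex {1} of `output_buffers`, and the resulting
  // ScopedShapedBuffer is handed to the XlaTensor backing the variable.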

  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(context->allocate_temp(
        variables[i].var()->tensor()->dtype(), output_tensor_shapes[i],
        variables[i].var()->tensor()));
    transfer_buffers(i, variables[i].var()->tensor());
  }
  return allocator->Deallocate(output_buffers.device_ordinal(),
                               output_buffers.buffer({}));
}
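
// A rough sketch of how these helpers could compose in a reshard path. The
// real op kernel lives elsewhere; the ordering and all variable names below
// are illustrative assumptions based only on the functions in this file:
//
//   TF_RETURN_IF_ERROR(CheckIsValidKey(new_format_key));
//   if (!IsDefaultKey(new_format_key)) {
//     string rendezvous_key_base;
//     std::unique_ptr<tpu::CompilationCacheEntryRef> entry;
//     TF_RETURN_IF_ERROR(GetComputationCacheEntry(
//         new_format_key, &rendezvous_key_base, &entry, fetch_target));
//     TF_ASSIGN_OR_RETURN(
//         auto input_buffers,
//         BuildInputBuffers(context, variables, input_host_shape, backend,
//                           device_ordinal, stream));
//     // ... run the resharding program, producing `result_buffers` ...
//     TF_RETURN_IF_ERROR(UpdateOutputVariables(
//         context, std::move(result_buffers), output_shape_protos, backend,
//         stream, device_ordinal, variables, definition_event));
//     TF_RETURN_IF_ERROR(PerformCompaction(stream));
//   }
//   // FlushProgramMemory(platform, device_ordinal) can then be used to
//   // unload all programs loaded on the device once they are no longer
//   // needed.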

}  // namespace reshard_variables
}  // namespace tpu
}  // namespace tensorflow