/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {
namespace tpu {
namespace reshard_variables {

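// Unloads all programs loaded on TPU device `device_ordinal`, freeing the
// device's program memory.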
Status FlushProgramMemory(se::Platform* platform, int device_ordinal) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
                      tpu::TpuNodeContext::Create(device_ordinal));

  auto* executor = tensorflow::down_cast<tpu::TpuExecutorInterface*>(
      node_interfaces->stream_executor()->implementation());
  return executor->UnloadAllPrograms();
}

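// Validates that `key` is the 3-element DT_STRING vector that
// TPUReshardVariables expects as its `new_format_key` input.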
Status CheckIsValidKey(const Tensor& key) {
  if (!TensorShapeUtils::IsVector(key.shape()) ||
      key.shape().dim_size(0) != 3) {
    return errors::InvalidArgument(
        "new_format_key argument to TPUReshardVariables must be a 3-element "
        "vector");
  }
  if (key.dtype() != DT_STRING) {
    return errors::InvalidArgument(
        "new_format_key argument to TPUReshardVariables must be DT_STRING "
        "type");
  }
  return OkStatus();
}

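// Returns true if `key` is the default key, i.e. its compilation cache key
// (element 0) is empty.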
bool IsDefaultKey(const Tensor& key) { return key.vec<tstring>()(0).empty(); }

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
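// Element 0 of `key` is the compilation cache key; element 1 is the
// rendezvous key base.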
Status GetComputationCacheEntry(
    const Tensor& key, string* rendezvous_key_base,
    std::unique_ptr<tpu::CompilationCacheEntryRef>* entry,
    tpu::CompilationCacheFetchTarget fetch_target) {
  profiler::TraceMe trace_me("TPUReshardVariablesOpKernel::LookupProto",
                             /*level=*/2);
  TF_RETURN_IF_ERROR(CheckIsValidKey(key));
  auto* rmgr = GetTPUConfigResourceMgr();
  tpu::TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kCompiledProtoCacheResourceName,
                                  &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
  TF_RETURN_IF_ERROR(
      proto_lookup->Lookup(key.vec<tstring>()(0), entry, fetch_target));
  *rendezvous_key_base = key.vec<tstring>()(1);
  return OkStatus();
}

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>> BuildInputBuffers(
    OpKernelContext* context, const std::vector<VariableInfo>& variables,
    const xla::Shape& input_host_shape, xla::Backend* backend,
    int device_ordinal, se::Stream* stream) {
  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
  OpInputList var_list;
  TF_RETURN_IF_ERROR(context->input_list("vars", &var_list));

  if (var_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
    return errors::InvalidArgument(
        "Number of variables (", var_list.size(),
        ") does not match input shape: ",
        xla::ShapeUtil::TupleElementCount(input_host_shape));
  }

  auto validate_shape = [&](int i, const Tensor& tensor) {
    const xla::Shape& expected =
        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    if (xla_tensor == nullptr) {
      // FromTensor failed; tensor must be empty.
      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(),
            "; got empty tensor. If you are running "
            "with TF2 TPU, make sure you set `drop_remainder=False` when "
            "calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
            "size can be handled");
      }
    } else {
      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUReshardVariables argument[", i,
            "] (", context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got ", xla_shape.DebugString());
      }
    }

    return OkStatus();
  };

  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(
        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
  }

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  xla::ShapeTree<xla::MaybeOwningDeviceMemory> input_buffers(
      transfer_manager->HostShapeToDeviceShape(input_host_shape));

  // Allocates a buffer for the root tuple.
  const int64_t root_size =
      transfer_manager->GetByteSizeRequirement(input_buffers.shape());
  TF_ASSIGN_OR_RETURN(*input_buffers.mutable_element({}),
                      allocator->Allocate(device_ordinal, root_size));

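  // Copies the buffers of `buffers` into `input_buffers` at tuple index
  // {arg_index, ...}. When `owning` is true, ownership of each buffer is
  // transferred and the source buffer is cleared so it is not freed twice.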
  auto set_input_buffers_helper = [&](int arg_index, xla::ShapedBuffer* buffers,
                                      bool owning = false) {
    buffers->buffers().ForEachMutableElement(
        [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
          xla::ShapeIndex in_index = {arg_index};
          for (int64_t j : index) {
            in_index.push_back(j);
          }
          if (owning) {
            *input_buffers.mutable_element(in_index) =
                se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
            *buffer = se::DeviceMemoryBase();
          } else {
            *input_buffers.mutable_element(in_index) = *buffer;
          }
        });
  };

  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
  // buffers for zero-element tensors where required.
  auto assign_input = [&](int i, const Tensor& tensor) -> xla::Status {
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    // Size 0 tensors have no backing XlaTensor, but may still need to have
    // tuple buffers allocated.
    if (xla_tensor == nullptr) {
      CHECK_EQ(tensor.NumElements(), 0);
      const xla::Shape& host_shape =
          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
                          transfer_manager->AllocateScopedShapedBuffer(
                              host_shape, allocator, device_ordinal));
      set_input_buffers_helper(/*arg_index=*/i, &buffers);
    } else {
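      // Donate the tensor's buffers to the computation only when this is the
      // last reference to them; otherwise alias them without taking ownership.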
      set_input_buffers_helper(/*arg_index=*/i, &xla_tensor->shaped_buffer(),
                               tensor.RefCountIsOne());
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }
    return OkStatus();
  };

  for (int i = 0; i < var_list.size(); ++i) {
    TF_RET_CHECK(var_list[i].dtype() == DT_RESOURCE);
    TF_RETURN_IF_ERROR(assign_input(i, *variables[i].var()->tensor()));
  }

  return std::move(input_buffers);
}

// Perform a compaction to reduce fragmentation.
Status PerformCompaction(stream_executor::Stream* stream) {
  profiler::TraceMe trace_me("PerformCompaction", /*level=*/2);
  auto* ds_executor =
      down_cast<tpu::TpuExecutorInterface*>(stream->parent()->implementation());
  TF_RETURN_IF_ERROR(ds_executor->EnqueueCompactionOnStreamForHbm(stream));
  // LoadProgram and GetOrCreateConstantHandle are not managed by stream
  // dependencies but they write to shared memory, so we need to block here to
  // prevent those operations from racing.
  return stream->BlockHostUntilDone();
}

// Updates the variables to the execution result's buffers, and deallocates the
// root tuple buffer.
Status UpdateOutputVariables(
    OpKernelContext* context, xla::ScopedShapedBuffer result_buffers,
    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
    xla::Backend* backend, se::Stream* stream, int device_ordinal,
    const std::vector<VariableInfo>& variables,
    const std::shared_ptr<se::Event>& definition_event) {
  profiler::TraceMe trace_me("UpdateOutputVariables", /*level=*/2);
  // Shapes of the outputs, in TensorShape form.
  const int64_t sub_elements =
      xla::ShapeUtil::TupleElementCount(result_buffers.on_host_shape());
  if (sub_elements != output_tensor_shape_protos.size()) {
    return errors::InvalidArgument(
        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
        output_tensor_shape_protos.size());
  }

  if (sub_elements != variables.size()) {
    return errors::InvalidArgument(
        "Output count does not equal variable count: ", sub_elements, " vs. ",
        variables.size());
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(sub_elements);
  for (int64_t i = 0; i < sub_elements; ++i) {
    TF_RETURN_IF_ERROR(
        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
    TensorShape shape(*output_tensor_shape_protos[i]);
    const xla::Shape& xla_shape =
        xla::ShapeUtil::GetSubshape(result_buffers.on_host_shape(), {i});
    if (!xla_shape.IsArray() ||
        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
      return errors::InvalidArgument(
          "Mismatched number of elements in output shape: ",
          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
    }
    output_tensor_shapes.push_back(shape);
    VLOG(2) << "Output " << i << " shape " << shape.DebugString();
  }

  // Build a shaped buffer for the outputs.
  TF_RET_CHECK(result_buffers.on_host_shape().IsTuple());
  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(result_buffers.on_host_shape()));

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();

  auto output_buffers = result_buffers.release();
  const xla::Shape& output_host_shape = output_buffers.on_host_shape();
  const xla::Shape& output_device_shape = output_buffers.on_device_shape();

  // Transfers ownership of the buffers that back XLA computation output 'i'
  // to 'output_tensor'.
  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
    const xla::Shape& host_shape =
        xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
    const xla::Shape& device_shape =
        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
    if (output_tensor->NumElements() > 0) {
      xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
                                            device_ordinal);
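      // Take over the child buffers of output 'i' from 'output_buffers'; no
      // device data is copied.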
      shaped_buffer.buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            xla::ShapeIndex out_index = {i};
            for (int64_t j : index) {
              out_index.push_back(j);
            }
            *buffer = output_buffers.buffers().element(out_index);
          });

      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
      xla_tensor->ResetDefinitionEvent(definition_event, stream);
    }
  };

  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(context->allocate_temp(
        variables[i].var()->tensor()->dtype(), output_tensor_shapes[i],
        variables[i].var()->tensor()));
    transfer_buffers(i, variables[i].var()->tensor());
  }

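  // The non-empty outputs' child buffers were handed to the variables'
  // XlaTensors above; only the root tuple buffer is freed here.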
  return allocator->Deallocate(output_buffers.device_ordinal(),
                               output_buffers.buffer({}));
}

}  // namespace reshard_variables
}  // namespace tpu
}  // namespace tensorflow