/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {

namespace reshard_util = ::tensorflow::tpu::reshard_variables;

TPUReshardVariablesOpKernel::TPUReshardVariablesOpKernel(
    OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {
  OP_REQUIRES_OK(context, context->GetAttr("N", &num_vars_));
}

void TPUReshardVariablesOpKernel::ComputeAsync(OpKernelContext* context,
                                               DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this thread
  // to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status TPUReshardVariablesOpKernel::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUReshardVariablesOpKernel::DoWork";
  TF_RET_CHECK(context->input_dtype(num_vars_) == DT_STRING);
  const Tensor* new_format_key;
  TF_RETURN_IF_ERROR(context->input("new_format_key", &new_format_key));
  TF_RETURN_IF_ERROR(reshard_util::CheckIsValidKey(*new_format_key));

  TF_RET_CHECK(context->input_dtype(num_vars_ + 1) == DT_RESOURCE);
  const ResourceHandle& handle = HandleFromInput(context, num_vars_ + 1);
  core::RefCountPtr<Var> format_state_var;
  TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
      context, handle, &format_state_var, [new_format_key](Var** ptr) {
        *ptr = new Var(new_format_key->dtype());
        return OkStatus();
      }));
  mutex_lock ml(*format_state_var->mu());
  const bool initialized = format_state_var->is_initialized;
  if (initialized) {
    TF_RETURN_IF_ERROR(
        reshard_util::CheckIsValidKey(*format_state_var->tensor()));
  }

  const bool state_is_default =
      !initialized || reshard_util::IsDefaultKey(*format_state_var->tensor());
  const bool new_format_is_default =
      reshard_util::IsDefaultKey(*new_format_key);

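  // No resharding is needed when both the stored state and the requested
  // format are the default (unsharded) layout, or when the stored key already
  // matches the new one (element 2 of the key tensor is compared).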
  if ((state_is_default && new_format_is_default) ||
      (initialized && format_state_var->tensor()->vec<tstring>()(2) ==
                          new_format_key->vec<tstring>()(2))) {
    VLOG(1) << "Sharding unchanged, nothing to do.";
    return OkStatus();
  }

  if (!state_is_default) {
    // Convert the current format to default (unsharded).
    VLOG(1) << "Unsharding with key: "
            << format_state_var->tensor()->vec<tstring>()(2);
    TF_RETURN_IF_ERROR(
        DoTpuExecute(context, *format_state_var->tensor(),
                     tpu::CompilationCacheFetchTarget::UNSHARDING));
  }

  if (!new_format_is_default) {
    // Convert to the new format.
    VLOG(1) << "Sharding with key: " << new_format_key->vec<tstring>()(2);
    TF_RETURN_IF_ERROR(DoTpuExecute(
        context, *new_format_key, tpu::CompilationCacheFetchTarget::SHARDING));
  }

  // Change the state.
  *format_state_var->tensor() = *new_format_key;
  format_state_var->is_initialized = true;
  return OkStatus();
}

Status TPUReshardVariablesOpKernel::DoTpuExecute(
    OpKernelContext* context, const Tensor& format_key,
    tpu::CompilationCacheFetchTarget fetch_target) {
  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the underlying object won't be deleted out from
  // under us.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
                      tpu::TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [device_ordinal] {
        return profiler::TraceMeEncode("TPUReshardVariablesOpKernel",
                                       {{"device_ordinal", device_ordinal}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUReshardVariablesOpKernel::Init",
                                  /*level=*/2);

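  // Look up the sharding/unsharding program for this format key in the TPU
  // compilation cache.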
  string rendezvous_key_base;
  std::unique_ptr<tpu::CompilationCacheEntryRef> entry_ref;
  TF_RETURN_IF_ERROR(reshard_util::GetComputationCacheEntry(
      format_key, &rendezvous_key_base, &entry_ref, fetch_target));
  tpu::TpuCompilationCacheEntry entry = entry_ref->get();
  if (entry.tpu_program_group() == nullptr) {
    VLOG(2) << "Sharding/unsharding program does not exist, so this is default "
               "sharding.";
    return OkStatus();
  }

  const tpu::TpuProgramGroupInterface* tpu_program_group =
      entry.tpu_program_group();
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable_info_proto =
      tpu_program_group->executable_info(core_index);
  const TPUExecutableInfoProto* executable = &executable_info_proto;

  xla::Backend* const backend = node_interfaces->backend();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

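  // The reshard program expects a single (tuple) input; gather the resource
  // variables whose buffers will populate it.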
  TF_RET_CHECK(executable->input_shapes_size() == 1);
  xla::Shape host_shape(executable->input_shapes(0));
  std::vector<VariableInfo> variables;
  for (int i = 0; i < num_vars_; ++i) {
    TF_RET_CHECK(context->input_dtype(i) == DT_RESOURCE);
    const ResourceHandle& handle = HandleFromInput(context, i);
    Var* variable;
    TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
    variables.push_back(VariableInfo(i, handle.name(), variable));
  }

  // Block for previous TPUExecute ops so that the memory used for them can be
  // freed.
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  // Lock variables to prevent concurrent access.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));

  // Build input buffers.
  TF_ASSIGN_OR_RETURN(auto input_buffers, reshard_util::BuildInputBuffers(
                                              context, variables, host_shape,
                                              backend, device_ordinal, stream));
  xla::ShapedBuffer shaped_buffer(std::move(host_shape), input_buffers.shape(),
                                  device_ordinal);
  shaped_buffer.set_buffers(input_buffers.Map<se::DeviceMemoryBase>(
      [](const xla::MaybeOwningDeviceMemory& buffer) {
        return buffer.AsDeviceMemoryBase();
      }));

  // Write input root tuple.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  TF_RET_CHECK(!executable->has_session_module())
      << "session module not supported in sharding/unsharding program.";

  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  // Execute the program.
  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable->has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(
        device_assignment,
        xla::DeviceAssignment::Deserialize(executable->device_assignment()));
  }
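  // Package the input buffers as a single ExecutionInput for TPUExecute.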
  std::vector<xla::ExecutionInput> input;
  input.emplace_back(xla::ExecutionInput(std::move(input_buffers),
                                         shaped_buffer.on_host_shape()));

  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);

  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(*executable, host_transfer_info,
                 *tpu_program_group->hlo_metadatas()[core_index],
                 std::move(input), rendezvous_key_base, GetXLARandomSeed(),
                 node_interfaces.get(), device_assignment.get(),
                 context->cancellation_manager(), context, stream,
                 transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));

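  // Record the event on the main stream so that consumers of the result
  // buffers can wait until the outputs have been written.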
  stream->ThenRecordEvent(definition_event.get());

  // Assign the new buffers to the variables.
  xla::ScopedShapedBuffer result = output.ConsumeResult();

  // Only perform compaction when sharding.
  // NOTE: Compaction is not supported on some TPUs; see b/168322060 for
  // details.
  if (node_interfaces->CompactionSupported(device_ordinal) &&
      fetch_target == tpu::CompilationCacheFetchTarget::SHARDING) {
    // Block until program execution is done so that input, output, and program
    // cache memory can actually be released.
    TF_RETURN_IF_ERROR(transfer_stream_ptr->BlockHostUntilDone());
    TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
    {
      // Explicitly release any RAII objects owning on-device allocations.
      auto unused = output.ConsumeToBeReleased();
    }
    // Release variables holding inputs.
    for (int i = 0; i < variables.size(); ++i) {
      *variables[i].var()->tensor() =
          Tensor(variables[i].var()->tensor()->dtype());
    }
    // Flush the on-device program memory cache.
    TF_RETURN_IF_ERROR(
        reshard_util::FlushProgramMemory(backend->platform(), device_ordinal));
    TF_RETURN_IF_ERROR(reshard_util::PerformCompaction(stream));
  }
  return reshard_util::UpdateOutputVariables(
      context, std::move(result), executable->output_tensor_shapes(), backend,
      stream, device_ordinal, variables, definition_event);
}

TPUReshardVariablesOpKernel::~TPUReshardVariablesOpKernel() = default;

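// Registers the kernel for TPU devices in non-Google-internal builds; the
// format state variable and the new format key inputs are kept in host memory.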
#if !defined(PLATFORM_GOOGLE)
REGISTER_KERNEL_BUILDER(Name("TPUReshardVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("format_state_var")
                            .HostMemory("new_format_key"),
                        TPUReshardVariablesOpKernel);
#endif

}  // namespace tensorflow