/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {

namespace reshard_util = ::tensorflow::tpu::reshard_variables;

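// The "N" attribute is the number of resource variables to reshard. The op's
// first N inputs are those variables, followed by the new format key and a
// handle to the variable that stores the current format key.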
TPUReshardVariablesOpKernel::TPUReshardVariablesOpKernel(
    OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {
  OP_REQUIRES_OK(context, context->GetAttr("N", &num_vars_));
}

void TPUReshardVariablesOpKernel::ComputeAsync(OpKernelContext* context,
                                               DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this thread
  // to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

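// Reshards the variables if the requested format differs from the current
// one: the current non-default format (if any) is first unsharded, the new
// non-default format (if any) is then sharded, and the new format key is
// recorded in the format state variable.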
Status TPUReshardVariablesOpKernel::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUReshardVariablesOpKernel::DoWork";
  TF_RET_CHECK(context->input_dtype(num_vars_) == DT_STRING);
  const Tensor* new_format_key;
  TF_RETURN_IF_ERROR(context->input("new_format_key", &new_format_key));
  TF_RETURN_IF_ERROR(reshard_util::CheckIsValidKey(*new_format_key));

  TF_RET_CHECK(context->input_dtype(num_vars_ + 1) == DT_RESOURCE);
  const ResourceHandle& handle = HandleFromInput(context, num_vars_ + 1);
  core::RefCountPtr<Var> format_state_var;
  TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
      context, handle, &format_state_var, [new_format_key](Var** ptr) {
        *ptr = new Var(new_format_key->dtype());
        return OkStatus();
      }));
  mutex_lock ml(*format_state_var->mu());
  const bool initialized = format_state_var->is_initialized;
  if (initialized) {
    TF_RETURN_IF_ERROR(
        reshard_util::CheckIsValidKey(*format_state_var->tensor()));
  }

  const bool state_is_default =
      !initialized || reshard_util::IsDefaultKey(*format_state_var->tensor());
  const bool new_format_is_default =
      reshard_util::IsDefaultKey(*new_format_key);

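  // Nothing to do if both the current and the requested formats are the
  // default, or if the stored key already matches the new key.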
  if ((state_is_default && new_format_is_default) ||
      (initialized && format_state_var->tensor()->vec<tstring>()(2) ==
                          new_format_key->vec<tstring>()(2))) {
    VLOG(1) << "Sharding unchanged, nothing to do.";
    return OkStatus();
  }

  if (!state_is_default) {
    // Convert the current format to default (unsharded).
    VLOG(1) << "Unsharding with key: "
            << format_state_var->tensor()->vec<tstring>()(2);
    TF_RETURN_IF_ERROR(
        DoTpuExecute(context, *format_state_var->tensor(),
                     tpu::CompilationCacheFetchTarget::UNSHARDING));
  }

  if (!new_format_is_default) {
    // Convert the variables to the new format.
    VLOG(1) << "Sharding with key: " << new_format_key->vec<tstring>()(2);
    TF_RETURN_IF_ERROR(DoTpuExecute(
        context, *new_format_key, tpu::CompilationCacheFetchTarget::SHARDING));
  }

  // Change the state.
  *format_state_var->tensor() = *new_format_key;
  format_state_var->is_initialized = true;
  return OkStatus();
}

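// Runs the sharding or unsharding program associated with `format_key` on the
// local TPU device, passing the resource variables as a single tuple input
// and writing the program's outputs back into those variables.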
Status TPUReshardVariablesOpKernel::DoTpuExecute(
    OpKernelContext* context, const Tensor& format_key,
    tpu::CompilationCacheFetchTarget fetch_target) {
  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the underlying object won't be deleted out from
  // under us.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
                      tpu::TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [device_ordinal] {
        return profiler::TraceMeEncode("TPUReshardVariablesOpKernel",
                                       {{"device_ordinal", device_ordinal}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUReshardVariablesOpKernel::Init",
                                  /*level=*/2);

  string rendezvous_key_base;
  std::unique_ptr<tpu::CompilationCacheEntryRef> entry_ref;
  TF_RETURN_IF_ERROR(reshard_util::GetComputationCacheEntry(
      format_key, &rendezvous_key_base, &entry_ref, fetch_target));
  tpu::TpuCompilationCacheEntry entry = entry_ref->get();
  if (entry.tpu_program_group() == nullptr) {
    VLOG(2) << "Sharding/unsharding program does not exist, so this is default "
               "sharding.";
    return OkStatus();
  }

  const tpu::TpuProgramGroupInterface* tpu_program_group =
      entry.tpu_program_group();
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable_info_proto =
      tpu_program_group->executable_info(core_index);
  const TPUExecutableInfoProto* executable = &executable_info_proto;

  xla::Backend* const backend = node_interfaces->backend();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

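  // The executable expects exactly one (tuple) input. Gather the resource
  // variables that make up that tuple.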
  TF_RET_CHECK(executable->input_shapes_size() == 1);
  xla::Shape host_shape(executable->input_shapes(0));
  std::vector<VariableInfo> variables;
  for (int i = 0; i < num_vars_; ++i) {
    TF_RET_CHECK(context->input_dtype(i) == DT_RESOURCE);
    const ResourceHandle& handle = HandleFromInput(context, i);
    Var* variable;
    TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
    variables.push_back(VariableInfo(i, handle.name(), variable));
  }

  // Block on previous TPUExecute ops so that the memory they use can be
  // freed.
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  // Lock variables to prevent concurrent access.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));

  // Build input buffers.
  TF_ASSIGN_OR_RETURN(auto input_buffers, reshard_util::BuildInputBuffers(
                                              context, variables, host_shape,
                                              backend, device_ordinal, stream));
  xla::ShapedBuffer shaped_buffer(std::move(host_shape), input_buffers.shape(),
                                  device_ordinal);
  shaped_buffer.set_buffers(input_buffers.Map<se::DeviceMemoryBase>(
      [](const xla::MaybeOwningDeviceMemory& buffer) {
        return buffer.AsDeviceMemoryBase();
      }));

  // Write input root tuple.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));
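  // If the buffers are already safe to access, write the tuple index table on
  // the borrowed transfer stream and have the main stream wait for it;
  // otherwise write it on the main stream to preserve ordering.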
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  TF_RET_CHECK(!executable->has_session_module())
      << "session module not supported in sharding/unsharding program.";

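  // The definition event is recorded on the main stream after execution so
  // that consumers of the output buffers can wait until they are defined.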
  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  // Execute the program.
  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable->has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(
        device_assignment,
        xla::DeviceAssignment::Deserialize(executable->device_assignment()));
  }
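  // Package the tuple of variable buffers as the executable's single input.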
  std::vector<xla::ExecutionInput> input;
  input.emplace_back(xla::ExecutionInput(std::move(input_buffers),
                                         shaped_buffer.on_host_shape()));

  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);

  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(*executable, host_transfer_info,
                 *tpu_program_group->hlo_metadatas()[core_index],
                 std::move(input), rendezvous_key_base, GetXLARandomSeed(),
                 node_interfaces.get(), device_assignment.get(),
                 context->cancellation_manager(), context, stream,
                 transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));

  stream->ThenRecordEvent(definition_event.get());

  // Assign the new buffers to the variables.
  xla::ScopedShapedBuffer result = output.ConsumeResult();

  // Only perform compaction when sharding.
  // NOTE: Compaction is not supported on some TPUs, see b/168322060 for details.
  if (node_interfaces->CompactionSupported(device_ordinal) &&
      fetch_target == tpu::CompilationCacheFetchTarget::SHARDING) {
    // Block until program execution is done so that input, output, and program
    // cache memory can actually be released.
    TF_RETURN_IF_ERROR(transfer_stream_ptr->BlockHostUntilDone());
    TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
    {
      // Explicitly release any RAII objects owning on-device allocations.
      auto unused = output.ConsumeToBeReleased();
    }
    // Release variables holding inputs.
    for (int i = 0; i < variables.size(); ++i) {
      *variables[i].var()->tensor() =
          Tensor(variables[i].var()->tensor()->dtype());
    }
    // Flush on-device program memory cache.
    TF_RETURN_IF_ERROR(
        reshard_util::FlushProgramMemory(backend->platform(), device_ordinal));
    TF_RETURN_IF_ERROR(reshard_util::PerformCompaction(stream));
  }
  return reshard_util::UpdateOutputVariables(
      context, std::move(result), executable->output_tensor_shapes(), backend,
      stream, device_ordinal, variables, definition_event);
}

TPUReshardVariablesOpKernel::~TPUReshardVariablesOpKernel() = default;

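// Register the kernel for TPU devices, keeping the format state variable and
// the new format key in host memory.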
#if !defined(PLATFORM_GOOGLE)
REGISTER_KERNEL_BUILDER(Name("TPUReshardVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("format_state_var")
                            .HostMemory("new_format_key"),
                        TPUReshardVariablesOpKernel);
#endif

}  // namespace tensorflow