/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Prefix for send/recv key used for transferring compilation program key.
constexpr char kSendRecvKeyPrefix[] = "compilation_send_recv_key_";

// Identifies all StatefulPartitionedCallOps that execute the computation for
// each mesh cluster and validates that at most one TPU computation exists.
mlir::LogicalResult IdentifyAndValidateMeshComputations(
    mlir::func::FuncOp function,
    std::map<Mesh, mlir::TF::StatefulPartitionedCallOp>* function_map) {
  for (auto dtensor_function :
       function.getOps<mlir::TF::StatefulPartitionedCallOp>()) {
    auto mesh_or = ExtractDeviceMeshFromOp(dtensor_function);
    if (!mesh_or.ok() || !mesh_or->has_value())
      return dtensor_function.emitOpError(
          "StatefulPartitionedCall op must have `_mesh` attribute specified.");

    const Mesh& computation_mesh = mesh_or->value();
    if (function_map->count(computation_mesh))
      return dtensor_function.emitOpError(
          "Found DTensor function with duplicate mesh specification. There "
          "should be exactly 1 function for each mesh in the computation "
          "cluster.");

    (*function_map)[computation_mesh] = dtensor_function;
  }

  int num_xla_meshes = 0;
  for (const auto& it : *function_map) {
    if (it.first.is_tpu_mesh()) num_xla_meshes += 1;
  }

  if (num_xla_meshes > 1)
    return function.emitOpError(
        "Multiple XLA computation clusters found. Only 1 XLA cluster for "
        "DTensor computation is supported for now.");

  return mlir::success();
}

// Creates Send/Recv ops to transfer TPUCompile program key from host
// computation to XLA computation.
mlir::LogicalResult CreateSendRecvOpsToTransferProgramKey(
    const Mesh& mesh, mlir::ModuleOp module, mlir::func::FuncOp function,
    mlir::OpBuilder::InsertPoint insertpoint,
    mlir::TF::_TPUCompileMlirOp compile_op,
    mlir::tf_device::LaunchOp compile_op_launch, int* num_send_recv,
    mlir::Value* program_key_output) {
  mlir::OpBuilder builder(module.getContext());
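  // The first `program` result of the TPUCompile op is the compilation
  // program key that will be transferred to the receiving devices.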
  mlir::Value compilation_key = *compile_op.program().begin();
  absl::Span<const std::string> local_devices = mesh.local_devices();

  // Create tensor name mapping for each send/recv pair.
  llvm::SmallDenseMap<int, std::string> device_key_map;
  const int num_tpu_devices = local_devices.size();
  device_key_map.reserve(num_tpu_devices);
  for (int i = 0; i < num_tpu_devices; ++i) {
    std::string tensor_name = absl::StrCat(kSendRecvKeyPrefix, *num_send_recv);
    *num_send_recv += 1;
    device_key_map.try_emplace(i, std::move(tensor_name));
  }

  // Create send ops to send the TPU program key from the host computation to
  // the XLA computation.
  builder.setInsertionPointAfter(compile_op);
  for (int i = 0; i < num_tpu_devices; ++i) {
    const std::string& tensor_name = device_key_map[i];
    auto send = builder.create<mlir::TF::_HostSendOp>(
        compile_op->getLoc(), compilation_key, tensor_name,
        compile_op_launch.device(),
        /*send_device_incarnation=*/0, local_devices[i]);
    send->setAttr("device", compile_op_launch.deviceAttr());
  }

  // Create Recv ops so that each XLA device computation receives the program
  // key from the host.
  llvm::SmallVector<mlir::func::FuncOp, 4> compilation_key_functions;
  compilation_key_functions.reserve(num_tpu_devices);
  mlir::SymbolTable symbol_table(module);

  // When receiving the TPU program key from the host, the `recv_device`
  // attribute depends on the `device_id` argument and therefore cannot be
  // known statically. We therefore use a tf.Case op to select the correct
  // receive op depending on the device id value.
  for (int i = 0; i < num_tpu_devices; ++i) {
    auto func_type = mlir::FunctionType::get(
        builder.getContext(), llvm::ArrayRef<mlir::Type>{},
        llvm::ArrayRef<mlir::Type>{compilation_key.getType()});

    mlir::func::FuncOp recv_select_fn = mlir::func::FuncOp::create(
        compile_op.getLoc(),
        llvm::formatv("recv_compile_key_{0}_{1}", i, *num_send_recv).str(),
        func_type, llvm::ArrayRef<mlir::NamedAttribute>{});
    symbol_table.insert(recv_select_fn);
    *num_send_recv += 1;

    mlir::Block* fn_block = recv_select_fn.addEntryBlock();
    mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockEnd(fn_block);
    auto recv = fn_builder.create<mlir::TF::_HostRecvOp>(
        compile_op->getLoc(),
        compilation_key.getType().cast<mlir::TensorType>(), device_key_map[i],
        compile_op_launch.device(), /*send_device_incarnation=*/0,
        local_devices[i]);
    recv->setAttr("device", builder.getStringAttr(local_devices[i]));

    fn_builder.create<mlir::func::ReturnOp>(recv_select_fn.getLoc(),
                                            recv.tensor());

    compilation_key_functions.emplace_back(recv_select_fn);
  }

  // Create logic that receives program key.
  builder.restoreInsertionPoint(insertpoint);
  auto device_id = GetDeviceOrdinal(mesh, function.getLoc(), function, &builder,
                                    /*return_int64_type=*/false);
  if (!device_id.ok()) return function->emitOpError("Cannot get device id");

  llvm::SmallVector<mlir::Attribute, 4> symbols;
  for (auto& func : compilation_key_functions)
    symbols.push_back(mlir::SymbolRefAttr::get(func));

  // Create a TF::Case op that selects the correct recv function based on the
  // device id.
  auto program_key = builder.create<mlir::TF::CaseOp>(
      compile_op.getLoc(),
      /*output=*/llvm::SmallVector<mlir::Type, 4>{compilation_key.getType()},
      /*branch_index=*/*device_id,
      /*input=*/llvm::ArrayRef<mlir::Value>{},
      /*branches=*/builder.getArrayAttr(symbols),
      /*is_stateless=*/builder.getBoolAttr(false));
  *program_key_output = program_key.getResult(0);
  return mlir::success();
}

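// Bookkeeping for one mesh that needs to receive the compilation program key:
// the mesh, the function it runs in, where the recv logic should be inserted,
// and the resulting program key value once the send/recv ops are created.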
struct CompilationKeyRecvInfo {
  const Mesh& receiving_function_mesh;
  mlir::func::FuncOp receiving_function;
  mlir::OpBuilder::InsertPoint recv_insertion_point;
  mlir::Value program_key;
};

// Broadcasts the compilation key to the meshes specified by `recv_info`. The
// broadcast compilation key is stored in the `program_key` field of each
// element of `recv_info`.
mlir::LogicalResult SendRecvCompilationKey(
    const Mesh& host_mesh, mlir::ModuleOp module,
    mlir::TF::_TPUCompileMlirOp compile_op,
    mlir::tf_device::LaunchOp compile_launch_op,
    mlir::Operation* compilation_move_before, int* num_send_recv,
    llvm::SmallVectorImpl<CompilationKeyRecvInfo>* recv_info) {
  for (int i = 0; i < recv_info->size(); ++i) {
    CompilationKeyRecvInfo& info = (*recv_info)[i];
    // Create send/recv ops to transfer the compilation key to this mesh.
    mlir::Value program_key;
    if (mlir::failed(CreateSendRecvOpsToTransferProgramKey(
            info.receiving_function_mesh, module, info.receiving_function,
            info.recv_insertion_point, compile_op, compile_launch_op,
            num_send_recv, &program_key)))
      return mlir::failure();

    info.program_key = program_key;
  }

  return mlir::success();
}

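// Moves the TPUCompile/TPUCompileSucceededAssert ops into the host mesh
// function and threads the resulting program key to both the TPU computation
// and the placeholder program key ops on the host.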
mlir::LogicalResult HandleCompilationOps(
    const llvm::SmallVectorImpl<
        mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp>& compilation_key_ops,
    std::map<Mesh, mlir::TF::StatefulPartitionedCallOp>& computation_map,
    mlir::ModuleOp module, int* num_send_recv) {
  // Identify the XLA function and the corresponding CPU function to which
  // compilation should be moved.
  const auto xla_mesh = llvm::find_if(
      computation_map, [](const auto& it) { return it.first.is_tpu_mesh(); });

  if (xla_mesh == computation_map.end()) {
    return module.emitOpError(
        "Found TPUCompilationKey op but XLA computation does not exist.");
  }

  mlir::func::FuncOp tpu_function = xla_mesh->second.func();
  mlir::func::FuncOp host_function;
  Mesh host_mesh;
  for (auto compilation_key : compilation_key_ops) {
    auto parent_function =
        compilation_key->getParentOfType<mlir::func::FuncOp>();

    if (!host_function) {
      host_function = parent_function;
      auto mesh_it = llvm::find_if(computation_map, [&](auto& it) {
        return it.second.f() == host_function.getSymName();
      });
      if (mesh_it == computation_map.end())
        return compilation_key.emitOpError(
            "Cannot find host mesh for TPU computation.");

      host_mesh = mesh_it->first;

    } else {
      // TODO(hongjunchoi): Handle the case when CopyToMesh is used with
      // special topology approach. In this case there will be 2 host
      // meshes/functions.
      if (host_function != parent_function)
        return compilation_key.emitOpError(
            "Found multiple TPU host mesh functions. There must be at most one "
            "TPU host function.");
    }
  }

  // Identify the TPUCompile op within the TPU computation.
  llvm::SmallVector<mlir::TF::_TPUCompileMlirOp, 4> compile_ops;
  tpu_function.walk(
      [&](mlir::TF::_TPUCompileMlirOp op) { compile_ops.emplace_back(op); });

  const int num_compilations = compile_ops.size();
  if (num_compilations != 1)
    return tpu_function.emitOpError(llvm::formatv(
        "Expected exactly 1 compilation op for TPU computation. Found {0}",
        num_compilations));

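  // By default, use the first op of the host function as the point before
  // which the compilation ops will be moved.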
  mlir::TF::_TPUCompileMlirOp compile_op = *compile_ops.begin();
  mlir::Operation& first_host_op = host_function.getBody().front().front();
  mlir::OpBuilder builder(&first_host_op);
  mlir::OpBuilder::InsertPoint host_insertion_point =
      builder.saveInsertionPoint();
  mlir::Operation* compilation_move_before = &first_host_op;

  // If the host mesh has multiple local devices, only run compilation on the
  // first host device by creating an If op that compiles only on the host
  // with device ordinal 0.
  if (host_mesh.local_device_ids().size() > 1) {
    auto device_ordinal_host = GetDeviceOrdinal(
        host_mesh, compile_op.getLoc(),
        first_host_op.getParentOfType<mlir::func::FuncOp>(), &builder);
    if (!device_ordinal_host.ok())
      return compile_op.emitOpError(
          llvm::formatv("error while creating TPU compilation logic. {0}",
                        device_ordinal_host.status().error_message()));

    mlir::Value predicate_host = builder.create<mlir::TF::EqualOp>(
        compile_op.getLoc(), *device_ordinal_host,
        CreateIntScalarConst(0, builder, compile_op.getLoc()),
        /*incompatible_shape_error=*/builder.getBoolAttr(true));

    // The If op created here contains send/recv and TPUCompile ops that must
    // not be pruned away, so we explicitly mark the op as stateful.
    auto if_host = builder.create<mlir::TF::IfRegionOp>(
        compile_op.getLoc(), llvm::SmallVector<mlir::Type, 4>{}, predicate_host,
        /*is_stateless=*/builder.getBoolAttr(false),
        GetUniqueControlflowFnName("compilation_host_then", builder),
        GetUniqueControlflowFnName("compilation_host_else", builder));

    // Create empty else branch region.
    auto& host_else_branch = if_host.else_branch();
    host_else_branch.push_back(new mlir::Block);
    builder.setInsertionPointToEnd(&host_else_branch.front());
    builder.create<mlir::TF::YieldOp>(
        compile_op.getLoc(),
        /*operands=*/llvm::ArrayRef<mlir::Value>{});

    // Create then branch region with logic to compile TPU program and send
    // program key to all TPU devices.
    auto& host_then_branch = if_host.then_branch();
    host_then_branch.push_back(new mlir::Block);
    builder.setInsertionPointToEnd(&host_then_branch.front());
    auto yield = builder.create<mlir::TF::YieldOp>(
        compile_op.getLoc(),
        /*operands=*/llvm::ArrayRef<mlir::Value>{});
    compilation_move_before = yield;

    builder.setInsertionPointAfter(if_host);
    host_insertion_point = builder.saveInsertionPoint();
  }

  auto compile_launch_op =
      compile_op->getParentOfType<mlir::tf_device::LaunchOp>();

  // Move Compile op and compile succeeded assert ops to host function.
  compile_launch_op->moveBefore(compilation_move_before);

  for (mlir::Operation* user : compile_launch_op.getResult(0).getUsers())
    user->getParentOfType<mlir::tf_device::LaunchOp>()->moveBefore(
        compilation_move_before);

  // Send and receive compilation key across meshes.
  llvm::SmallVector<CompilationKeyRecvInfo, 4> compilation_key_recv_info;
  builder.setInsertionPointToStart(&tpu_function.front());
  auto device_insertion_point = builder.saveInsertionPoint();
  compilation_key_recv_info.emplace_back(CompilationKeyRecvInfo{
      xla_mesh->first, tpu_function, device_insertion_point, nullptr});

  compilation_key_recv_info.emplace_back(CompilationKeyRecvInfo{
      host_mesh, host_function, host_insertion_point, nullptr});

  if (mlir::failed(SendRecvCompilationKey(
          host_mesh, module, compile_op, compile_launch_op,
          compilation_move_before, num_send_recv, &compilation_key_recv_info)))
    return mlir::failure();

  // Replace usages of the TPU program key in host and device meshes.
  mlir::Value device_program_key = compilation_key_recv_info[0].program_key;
  tpu_function.walk([&](mlir::Operation* op) {
    if (llvm::isa<mlir::TF::TPUExecuteOp,
                  mlir::TF::TPUExecuteAndUpdateVariablesOp>(op))
      op->setOperand(op->getNumOperands() - 1, device_program_key);
  });

  // Remove placeholder CompilationKey ops and replace their usages with the
  // output of the TPUCompile op.
  mlir::Value host_program_key = compilation_key_recv_info[1].program_key;
  for (auto compilation_key_op : compilation_key_ops) {
    compilation_key_op.replaceAllUsesWith(host_program_key);
    compilation_key_op.erase();
  }
  return mlir::success();
}

// Pass that moves TPUCompile/TPUCompileSucceededAssert ops to the host mesh
// computation and adds the necessary send/recv ops to transfer the TPU
// program key to the TPU device computation.
struct DTensorMoveCompilationToHost
    : public DTensorMoveCompilationToHostBase<DTensorMoveCompilationToHost> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);
    auto module = getOperation();

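    // Collect all placeholder program key ops; if none exist, there is no
    // compilation to move to the host.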
    llvm::SmallVector<mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp, 4>
        compilation_key_ops;
    module.walk([&](mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp op) {
      compilation_key_ops.emplace_back(op);
    });

    if (compilation_key_ops.empty()) return;

    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    std::map<Mesh, mlir::TF::StatefulPartitionedCallOp> computation_map;
    if (mlir::failed(
            IdentifyAndValidateMeshComputations(main_func, &computation_map)))
      return signalPassFailure();

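    // Counter used to generate unique tensor names for the send/recv ops that
    // transfer the compilation program key.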
    int num_send_recv = 0;
    if (mlir::failed(HandleCompilationOps(compilation_key_ops, computation_map,
                                          module, &num_send_recv)))
      return signalPassFailure();
  };
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorMoveCompilationToHost() {
  return std::make_unique<DTensorMoveCompilationToHost>();
}

}  // namespace dtensor
}  // namespace tensorflow