/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kHostFunctionAttr[] = "host_func";
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";

struct TPUExtractOutsideCompilation
    : public TF::TPUExtractOutsideCompilationPassBase<
          TPUExtractOutsideCompilation> {
  void runOnOperation() override;
};

// Builds a function containing `ops`, with `inputs` and `outputs`, using
// `builder`. The `ops` are cloned and modified to use the function arguments
// as inputs.
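// For illustration, with one cloned op, one input, and one output, the
// outlined function would look roughly like this (a sketch with a
// hypothetical op and types, not output from an actual run):
//
//   func.func @host_func(%arg0: tensor<i32>) -> tensor<i32> {
//     %0 = "tf.OpA"(%arg0) : (tensor<i32>) -> tensor<i32>
//     return %0 : tensor<i32>
//   }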
func::FuncOp BuildFunction(llvm::ArrayRef<Operation*> ops,
                           llvm::ArrayRef<Value> inputs,
                           llvm::ArrayRef<Value> outputs, OpBuilder* builder) {
  llvm::SmallVector<Type, 4> operand_types;
  operand_types.reserve(inputs.size());
  for (Value v : inputs) operand_types.emplace_back(v.getType());
  llvm::SmallVector<Type, 4> output_types;
  output_types.reserve(outputs.size());
  for (Value v : outputs) output_types.emplace_back(v.getType());

  auto func_type = builder->getFunctionType(operand_types, output_types);

  func::FuncOp outlined_func =
      func::FuncOp::create(ops.front()->getLoc(), kHostFunctionAttr, func_type);

  // Create function body.
  Block* outlined_func_block = outlined_func.addEntryBlock();

  // Clone the operations and remap the inputs to use the function arguments.
  BlockAndValueMapping mapping;
  mapping.map(inputs, outlined_func.getArguments());
  builder->setInsertionPoint(outlined_func_block, outlined_func_block->begin());
  for (Operation* op : ops) {
    builder->clone(*op, mapping);
  }

  // Set the returned values to the cloned ops' results, using the mapping.
  llvm::SmallVector<Value, 4> results_after_mapping;
  for (Value result : outputs) {
    results_after_mapping.push_back(mapping.lookupOrDefault(result));
  }

  builder->create<func::ReturnOp>(ops.front()->getLoc(), results_after_mapping);
  return outlined_func;
}

// Encapsulates `func` in a module and serializes that module.
// `serialized_func_module` is set to the serialized module.
void EncapsulateFuncAndSerialize(func::FuncOp func,
                                 std::string* serialized_func_module) {
  // Create a new module to hold func and all referenced functions.
  OwningOpRef<mlir::ModuleOp> module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(func.getContext()));
  SymbolTable symbol_table(module_for_func.get());

  symbol_table.insert(func);
  *serialized_func_module =
      tensorflow::SerializeMlirModule(module_for_func.get());
}

// Returns whether `op` or ops nested in `op` are outside compiled.
bool HasOutsideCompilationNested(Operation* op) {
  return op
      ->walk([&](Operation* walked_op) {
        if (op == walked_op) return WalkResult::advance();
        if (walked_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      })
      .wasInterrupted();
}

// Returns whether `op` or any ancestor of `op` is outside compiled.
bool HasOutsideCompilationAncestor(Operation* op) {
  while (op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
      return true;
    }
    op = op->getParentOp();
  }
  return false;
}

// Returns whether any ancestor of `op` is outside compiled.
bool HasOutsideCompilationAncestorExclusive(Operation* op) {
  Operation* parent_op = op->getParentOp();
  if (!parent_op) return false;
  return HasOutsideCompilationAncestor(parent_op);
}

// Marks `op` as having a host transfer and returns it.
Operation* ApplyXlaHostTransferAttr(Operation* op, OpBuilder& builder) {
  op->setAttr("_xla_has_host_transfer", builder.getBoolAttr(true));
  return op;
}

// Creates a tf._XlaSendFromHost op, or a tf._XlaSendFromHostV2 op if a device
// ordinal is present.
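// A rough sketch of the V2 form this creates (the operand values and tensor
// types here are hypothetical):
//
//   "tf._XlaSendFromHostV2"(%output, %program_key, %device_ordinal)
//       {key = "host_compute_channel_0_retvals",
//        _xla_has_host_transfer = true}
//       : (tensor<i32>, tensor<3x!tf_type.string>, tensor<i64>) -> ()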
Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc,
                                ValueRange inputs, Value compilation_key,
                                Value device_ordinal,
                                llvm::StringRef communication_key) {
  if (device_ordinal)
    return ApplyXlaHostTransferAttr(
        builder.create<TF::_XlaSendFromHostV2Op>(
            loc, inputs,
            /*dynamic_key=*/compilation_key, device_ordinal,
            builder.getStringAttr(communication_key)),
        builder);

  return ApplyXlaHostTransferAttr(
      builder.create<TF::_XlaSendFromHostOp>(
          loc, inputs,
          /*dynamic_key=*/compilation_key,
          builder.getStringAttr(communication_key),
          /*device_ordinal=*/builder.getI64IntegerAttr(0)),
      builder);
}

// Creates a tf._XlaRecvAtHost op, or a tf._XlaRecvAtHostV2 op if a device
// ordinal is present.
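// A rough sketch of the V2 form this creates (the result type is
// hypothetical):
//
//   %0 = "tf._XlaRecvAtHostV2"(%program_key, %device_ordinal)
//       {key = "host_compute_channel_0_args", _xla_has_host_transfer = true}
//       : (tensor<3x!tf_type.string>, tensor<i64>) -> tensor<i32>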
Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc,
                              TypeRange output_types, Value compilation_key,
                              Value device_ordinal,
                              llvm::StringRef communication_key) {
  if (device_ordinal)
    return ApplyXlaHostTransferAttr(
        builder.create<TF::_XlaRecvAtHostV2Op>(
            loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal,
            builder.getStringAttr(communication_key)),
        builder);

  return ApplyXlaHostTransferAttr(
      builder.create<TF::_XlaRecvAtHostOp>(
          loc, output_types, /*dynamic_key=*/compilation_key,
          builder.getStringAttr(communication_key),
          /*device_ordinal=*/builder.getI64IntegerAttr(0)),
      builder);
}

// Clones an IfRegionOp `if_region`, keeping its predicate and attributes, and
// gives the clone then/else regions that each contain only an empty block
// terminated by a yield op.
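// The clone looks roughly like this (a sketch; the predicate %pred and the
// attributes are taken from the original op):
//
//   "tf.IfRegion"(%pred) ({
//     "tf.Yield"() : () -> ()
//   }, {
//     "tf.Yield"() : () -> ()
//   }) {is_stateless = true} : (tensor<i1>) -> ()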
TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,
                                         OpBuilder& builder) {
  auto host_side_if = builder.create<TF::IfRegionOp>(
      if_region.getLoc(), llvm::SmallVector<Type, 4>{}, if_region.cond(),
      if_region.is_stateless(), if_region._then_func_nameAttr(),
      if_region._else_func_nameAttr());

  // Create empty then branch region.
  auto& then_branch = host_side_if.then_branch();
  then_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  builder.create<TF::YieldOp>(if_region.getLoc(),
                              /*operands=*/ArrayRef<Value>{});

  // Create empty else branch region.
  auto& else_branch = host_side_if.else_branch();
  else_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<TF::YieldOp>(if_region.getLoc(),
                              /*operands=*/ArrayRef<Value>{});
  return host_side_if;
}

// Creates a WhileRegionOp whose body region contains only an empty block
// terminated by a yield op; the cond region is left for the caller to
// populate.
TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
                                  uint64_t parallel_iterations, Location loc,
                                  OpBuilder& builder) {
  auto host_side_while = builder.create<TF::WhileRegionOp>(
      loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
      parallel_iterations, is_stateless, /*shape_invariant=*/false);

  // Create empty body region.
  auto& body = host_side_while.body();
  body.push_back(new Block);
  builder.setInsertionPointToEnd(&body.front());
  builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
  return host_side_while;
}

// TODO(b/157054714): Use a better abstraction instead of
// _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
// Creates a compilation key placeholder. A placeholder compilation cache key
// is created because it is a required input to _XlaRecvAtHost and
// _XlaSendFromHost, but the _TPUCompileMlir op has not yet been created for
// the TPU cluster that contains the outside compiled ops. This placeholder is
// replaced by the TPU cluster's _TPUCompileMlir op in a subsequent pass.
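// A sketch of the placeholder op this creates:
//
//   %program_key = "tf._TPUCompileMlirPlaceholderProgramKey"()
//       : () -> tensor<3x!tf_type.string>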
TF::_TPUCompileMlirPlaceholderProgramKeyOp CreateCompilationKeyPlaceholder(
    Location loc, OpBuilder& builder) {
  auto result_type =
      RankedTensorType::get({3}, builder.getType<TF::StringType>());
  return builder.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
      loc, /*program=*/result_type, llvm::ArrayRef<Value>{});
}

// Creates a `tf_device.launch` to wrap cluster ops.
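// The created launch op is an empty shell on `host_device`, roughly (the
// device string here is hypothetical):
//
//   "tf_device.launch"() ({
//     tf_device.return
//   }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()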
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
    OpBuilder& builder, Operation* loc_op, llvm::StringRef host_device) {
  // The launch op is created directly with `host_device`, which was derived
  // from the devices of the associated TPU cluster.
  auto launch_op = builder.create<tf_device::LaunchOp>(
      loc_op->getLoc(), builder.getStringAttr(host_device),
      /*result_types=*/ArrayRef<Type>{});

  launch_op.body().push_back(new Block);
  builder.setInsertionPointToEnd(&launch_op.GetBody());
  builder.create<tf_device::ReturnOp>(loc_op->getLoc(),
                                      llvm::ArrayRef<Value>{});

  return launch_op;
}

// Returns true if `op` has non-static shaped outputs.
bool HasDynamicOutputs(Operation* op) {
  for (Value v : op->getResults()) {
    if (TF::CanBeRefined(v.getType())) return true;
  }
  return false;
}

// Returns true if any op in `cluster_ops` has an output with a non-static
// shape that is consumed by an op not in `cluster_ops`.
bool HasDynamicOutputs(const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  for (Operation* op : cluster_ops) {
    for (const OpOperand& use : op->getUses()) {
      if (cluster_ops.count(use.getOwner())) {
        continue;
      }
      if (TF::CanBeRefined(use.get().getType())) return true;
    }
  }
  return false;
}

// Returns true if `op` or any op nested in `op` has an operand with a
// non-static shape.
bool HasDynamicExternalValues(Operation* op) {
  return op
      ->walk([](Operation* walked_op) {
        for (Value v : walked_op->getOperands()) {
          if (TF::CanBeRefined(v.getType())) {
            return WalkResult::interrupt();
          }
        }
        return WalkResult::advance();
      })
      .wasInterrupted();
}

// Returns operands of `cluster_ops` that need to be communicated from
// device->host. This is for the case when all operands have a static shape.
llvm::SmallSetVector<Value, 4> GetStaticExternalOperands(
    tf_device::ClusterOp tpu_cluster,
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_values;
  for (Operation* op : cluster_ops) {
    op->walk([&](Operation* walked_op) {
      if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(
              walked_op))
        return WalkResult::advance();
      for (Value v : walked_op->getOperands()) {
        if (auto* defining_op = v.getDefiningOp()) {
          if (!op->isAncestor(defining_op) &&
              tpu_cluster->isAncestor(defining_op) &&
              !HasOutsideCompilationAncestor(defining_op) &&
              !llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
            external_values.insert(v);
          }
          continue;
        }
        auto block_arg = v.cast<BlockArgument>();
        if (block_arg.getParentRegion() == op->getParentRegion())
          external_values.insert(v);
      }
      return WalkResult::advance();
    });
  }
  return external_values;
}

// Returns every operand of `cluster_ops` that does not come from an op in
// `cluster_ops`.
llvm::SmallSetVector<Value, 4> GetAllExternalOperands(
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_values;
  for (Operation* op : cluster_ops) {
    op->walk([&](Operation* walked_op) {
      for (Value v : walked_op->getOperands()) {
        Operation* defining_op = v.getDefiningOp();
        if (!defining_op || !cluster_ops.count(defining_op)) {
          external_values.insert(v);
        }
      }
    });
  }
  return external_values;
}

// Returns a SmallSetVector containing all of the operands that need to be
// communicated from device->host.
llvm::SmallSetVector<Value, 4> GetExternalOperands(
    tf_device::ClusterOp tpu_cluster,
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  // If there are any dynamic outputs, get all of the operands which are
  // defined external to `cluster_ops`.
  bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
  if (has_dynamic_outputs) {
    return GetAllExternalOperands(cluster_ops);
  } else {
    return GetStaticExternalOperands(tpu_cluster, cluster_ops);
  }
}

// Gets all outputs that need to be communicated from host->device.
llvm::SmallSetVector<Value, 4> GetExternalOutputs(
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_outputs;
  bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
  for (Operation* op : cluster_ops) {
    for (Operation* user : op->getUsers()) {
      // Skip users in the same outside compilation cluster, since both the
      // defining op and the user op will be moved to the host at the same
      // time.
      if (cluster_ops.count(user)) {
        continue;
      }
      // This is pessimistic and in some cases will add extra communication.
      if (!HasOutsideCompilationAncestor(user) || has_dynamic_outputs ||
          HasDynamicOutputs(user)) {
        for (Value v : user->getOperands()) {
          if (v.getDefiningOp() == op) external_outputs.insert(v);
        }
      }
    }
  }
  return external_outputs;
}

// Creates the HostCompute op with `inputs` and `outputs`, using the given
// communication keys and the serialized function module.
TF::_XlaHostComputeMlirOp CreateHostCompute(
    OpBuilder& builder, Location loc,
    const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
    llvm::StringRef args_communication_key,
    llvm::StringRef retvals_communication_key,
    llvm::StringRef serialized_func_module) {
  llvm::SmallVector<Type, 4> device_output_types;
  for (const auto& output : outputs)
    device_output_types.push_back(output.getType());
  auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
      loc, device_output_types, inputs.getArrayRef(),
      builder.getStringAttr(args_communication_key),
      builder.getStringAttr(retvals_communication_key),
      /*host_mlir_module=*/builder.getStringAttr(serialized_func_module));
  return host_compute;
}

// Marks `op` as outside compiled with a temporary attribute value.
void MarkOutsideCompiled(Operation* op) {
  op->setAttr(kXlaOutsideCompilationAttr,
              StringAttr::get(op->getContext(), "temp"));
}

// Returns whether an outside compilation cluster should be closed. True when:
// 1. There is a dynamically shaped output consumed by a non-outside compiled
//    op.
// 2. There is no dynamically shaped output.
bool ShouldCloseCluster(llvm::ArrayRef<Value> outputs) {
  bool has_dynamic_output = false;
  for (Value v : outputs) {
    if (TF::CanBeRefined(v.getType())) {
      has_dynamic_output = true;
      for (Operation* user : v.getUsers()) {
        if (!HasOutsideCompilationAncestor(user)) return true;
      }
    }
  }
  return !has_dynamic_output;
}

// Replaces `external_operands` with the results from `recv_at_host`.
// For non-static shapes, an operand usage is only replaced if the op is in the
// same region as the insertion point.
// For static shapes, an operand usage is replaced if the op is in the same
// region as the insertion point or if the op is outside compiled and will be
// moved to the host later.
void ReplaceExternalOperandUsage(
    const llvm::SmallSetVector<Value, 4>& external_operands,
    Operation* recv_at_host, Operation* insertion_point,
    Block* original_op_block) {
  auto replace_operand_usage = [&](OpOperand& operand) {
    if (TF::CanBeRefined(operand.get().getType()) ||
        HasDynamicOutputs(operand.getOwner())) {
      return insertion_point->getParentRegion()->isAncestor(
          operand.getOwner()->getParentRegion());
    }
    return insertion_point->getParentRegion()->isAncestor(
               operand.getOwner()->getParentRegion()) ||
           (HasOutsideCompilationAncestor(operand.getOwner()) &&
            original_op_block == operand.getOwner()->getBlock());
  };
  for (auto result : llvm::zip(external_operands, recv_at_host->getResults())) {
    Value external_operand = std::get<0>(result);
    external_operand.replaceUsesWithIf(std::get<1>(result),
                                       replace_operand_usage);
  }
}

// Returns true if any of `outputs` has a non-static shape.
bool HasDynamicOutputs(llvm::ArrayRef<Value> outputs) {
  for (Value v : outputs) {
    if (TF::CanBeRefined(v.getType())) {
      return true;
    }
  }
  return false;
}

// Replaces usages of `external_outputs`, which are values returned by outside
// compilation, with the corresponding outputs from `host_compute`.
void ReplaceExternalOutputUsage(
    const llvm::SmallSetVector<Value, 4>& external_outputs,
    TF::_XlaHostComputeMlirOp host_compute) {
  bool has_dynamic_outputs = HasDynamicOutputs(external_outputs.getArrayRef());

  auto replace_output_usage = [&](OpOperand& operand) {
    // Don't replace output usages if in host computation (defining op and
    // user in same region).
    bool in_same_region =
        operand.get().getDefiningOp()->getParentRegion()->isAncestor(
            operand.getOwner()->getParentRegion());
    if (has_dynamic_outputs || HasDynamicOutputs(operand.getOwner())) {
      return !in_same_region;
    } else {
      // Don't replace output usages in host computation or for outside
      // compiled ops.
      return !in_same_region &&
             !HasOutsideCompilationAncestor(operand.getOwner());
    }
  };
  for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
    Value external_output = std::get<0>(result);
    external_output.replaceUsesWithIf(std::get<1>(result),
                                      replace_output_usage);
  }
}

// Moves `clustered_ops` to run on the host and adds communication ops to
// transfer `external_operands` and `external_outputs` between device and host.
// Inserts ops at `insertion_point` and uses `compilation_key` and
// `device_ordinal` when creating the communication ops.
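// As a rough sketch (op name, values, and types are hypothetical), moving one
// outside compiled op with one operand and one output leaves the device side
// with:
//
//   %0 = "tf._XlaHostComputeMlir"(%operand)
//       {send_key = "host_compute_channel_0_args",
//        recv_key = "host_compute_channel_0_retvals", ...}
//       : (tensor<i32>) -> tensor<i32>
//
// and inserts on the host side:
//
//   %1 = "tf._XlaRecvAtHostV2"(%program_key, %device_ordinal)
//       {key = "host_compute_channel_0_args"} : (...) -> tensor<i32>
//   %2 = "tf.OpA"(%1) : (tensor<i32>) -> tensor<i32>
//   "tf._XlaSendFromHostV2"(%2, %program_key, %device_ordinal)
//       {key = "host_compute_channel_0_retvals"} : (tensor<i32>, ...) -> ()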
void MoveOpsToHost(const llvm::SmallSetVector<Operation*, 4>& clustered_ops,
                   const llvm::SmallSetVector<Value, 4>& external_operands,
                   const llvm::SmallSetVector<Value, 4>& external_outputs,
                   Operation* insertion_point, Value compilation_key,
                   Value device_ordinal, int& communication_key_index) {
  OpBuilder builder(insertion_point);
  Operation& op = *clustered_ops.back();
  std::string args_communication_key =
      llvm::formatv("host_compute_channel_{0}_args", (communication_key_index))
          .str();
  std::string retvals_communication_key =
      llvm::formatv("host_compute_channel_{0}_retvals",
                    (communication_key_index))
          .str();

  // Use a unique name when sending just the IfRegion predicate. This is for
  // readability and to match the key in the TF2XLA bridge.
  if (clustered_ops.size() == 1 && llvm::isa<TF::IfRegionOp>(op) &&
      external_operands.size() == 1) {
    args_communication_key =
        llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
            .str();
  }

  std::string serialized_func_module;
  if (HasDynamicOutputs(external_outputs.getArrayRef())) {
    func::FuncOp shape_op = BuildFunction(
        clustered_ops.getArrayRef(), external_operands.getArrayRef(),
        external_outputs.getArrayRef(), &builder);
    EncapsulateFuncAndSerialize(shape_op, &serialized_func_module);
  }

  builder.setInsertionPoint(&op);
  auto host_compute =
      CreateHostCompute(builder, op.getLoc(), external_operands,
                        external_outputs.getArrayRef(), args_communication_key,
                        retvals_communication_key, serialized_func_module);
  // Insert ops on the host side computation to receive data from device.
  builder.setInsertionPoint(insertion_point);
  llvm::SmallVector<Type, 4> host_operand_types;
  for (const auto& operand : external_operands)
    host_operand_types.push_back(operand.getType());

  Operation* recv_at_host = CreateRecvAtHostOp(
      builder, op.getLoc(), host_operand_types, compilation_key,
      device_ordinal, args_communication_key);
  Block* original_op_block = op.getBlock();
  Operation* after_op = recv_at_host;
  for (Operation* cluster_op : clustered_ops) {
    cluster_op->moveAfter(after_op);
    cluster_op->removeAttr(StringAttr::get(op.getContext(), kDeviceAttr));
    after_op = cluster_op;
  }

  if (!external_outputs.empty()) {
    CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
                         compilation_key, device_ordinal,
                         retvals_communication_key);
  }

  if (external_operands.empty()) {
    recv_at_host->erase();
  } else {
    ReplaceExternalOperandUsage(external_operands,
                                /*recv_at_host=*/recv_at_host,
                                /*insertion_point=*/insertion_point,
                                /*original_op_block=*/original_op_block);
  }

  ReplaceExternalOutputUsage(external_outputs, host_compute);

  if (external_operands.empty() && external_outputs.empty()) {
    host_compute.erase();
  } else {
    ++communication_key_index;
  }
}

// Moves outside compiled ops in `src` to `insertion_point` in the host
// computation (the insertion point may temporarily be within `tpu_cluster`,
// with the ops moved out in a subsequent call to this method). Communication
// ops are added both in `src` and at `insertion_point` using
// `compilation_key`, `device_ordinal`, and `communication_key_index`, which is
// incremented when used. Communication ops are added only when needed and at
// the location needed. There are checks to ensure that duplicate communication
// between device and host is not added.
LogicalResult MoveOpsToHost(tf_device::ClusterOp tpu_cluster, Block* src,
                            Operation* insertion_point, Value compilation_key,
                            Value device_ordinal,
                            int& communication_key_index) {
  // Contains all of the outside compiled operations that should be moved to
  // the host using a single `_XlaHostComputeMlir` op. This should only contain
  // a single op except in the case where some of the input/output shapes are
  // non-static.
  llvm::SmallSetVector<Operation*, 4> clustered_ops;

  for (Operation& op : llvm::make_early_inc_range(*src)) {
    if (HasOutsideCompilationAncestorExclusive(&op) ||
        !op.hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      continue;

    // We want to move the clustered_ops if the op to be added has all
    // statically shaped operands, since we can't ensure that the static shapes
    // have been sent back to the host in all cases. See the
    // @static_shapes_sandwiched_outside_compilation MLIR test for an example.
    if (!HasDynamicExternalValues(&op) && !clustered_ops.empty()) {
      llvm::SmallSetVector<Value, 4> external_operands =
          GetExternalOperands(tpu_cluster, clustered_ops);
      llvm::SmallSetVector<Value, 4> external_outputs =
          GetExternalOutputs(clustered_ops);
      MoveOpsToHost(clustered_ops, external_operands, external_outputs,
                    insertion_point, compilation_key, device_ordinal,
                    communication_key_index);
      clustered_ops.clear();
    }

    clustered_ops.insert(&op);

    // Get the outputs that need to be communicated from host -> device.
    llvm::SmallSetVector<Value, 4> external_outputs =
        GetExternalOutputs(clustered_ops);

    if (ShouldCloseCluster(external_outputs.getArrayRef())) {
      // Get the operands that need to be communicated from device -> host.
      llvm::SmallSetVector<Value, 4> external_operands =
          GetExternalOperands(tpu_cluster, clustered_ops);
      MoveOpsToHost(clustered_ops, external_operands, external_outputs,
                    insertion_point, compilation_key, device_ordinal,
                    communication_key_index);
      clustered_ops.clear();
    }
  }
  return success();
}

// Decomposes control flow in `tpu_cluster` into a device computation and a
// host (outside compiled) computation, as two separate control flow ops with
// communication between device and host for data dependencies. Both device
// and host control flow initially remain within `tpu_cluster`; a subsequent
// call to MoveOpsToHost moves the host side control flow to the host launch
// in tf_device.parallel_execute. Uses `compilation_key`, `device_ordinal`, and
// `communication_key_index` when creating communication ops.
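// For example (a sketch with hypothetical values), the condition of a while
// loop containing outside compiled ops is forwarded to the host side while
// over a dedicated channel:
//
//   // Device side, at the end of the while condition:
//   "tf.XlaSendToHost"(%cond) {key = "while_condition_channel_0"}
//       : (tensor<i1>) -> ()
//
//   // Host side while condition:
//   %cond_host = "tf._XlaRecvAtHostV2"(%program_key, %device_ordinal)
//       {key = "while_condition_channel_0"} : (...) -> tensor<i1>
//   "tf.Yield"(%cond_host) : (tensor<i1>) -> ()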
LogicalResult DecomposeControlFlow(tf_device::ClusterOp tpu_cluster,
                                   Value compilation_key, Value device_ordinal,
                                   int& communication_key_index) {
  auto result = tpu_cluster.GetBody().walk([&](Operation* op) {
    if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
      OpBuilder builder(if_op);
      auto host_if = CloneEmptyIfWithPredicate(if_op, builder);
      if (failed(MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(),
                               host_if.then_branch().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      if (failed(MoveOpsToHost(tpu_cluster, &if_op.else_branch().front(),
                               host_if.else_branch().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      MarkOutsideCompiled(host_if.getOperation());
    }
    if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) {
      if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
      OpBuilder builder(while_op);
      auto host_while = CloneEmptyWhile(while_op.is_stateless(),
                                        while_op.parallel_iterations(),
                                        while_op.getLoc(), builder);
      const auto condition_send_recv_key =
          llvm::formatv("while_condition_channel_{0}",
                        communication_key_index++)
              .str();
      auto& cond = host_while.cond();
      cond.push_back(new Block);
      auto condition = while_op.cond().front().getTerminator()->getOperand(0);
      builder.setInsertionPoint(while_op.cond().front().getTerminator());
      builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition,
                                          condition_send_recv_key);
      builder.setInsertionPointToEnd(&cond.front());
      auto recv_condition_at_host = CreateRecvAtHostOp(
          builder, while_op.getLoc(), TypeRange{condition.getType()},
          compilation_key, device_ordinal, condition_send_recv_key);
      builder.create<TF::YieldOp>(while_op.getLoc(),
                                  recv_condition_at_host->getResults());

      if (failed(MoveOpsToHost(tpu_cluster, &while_op.cond().front(),
                               recv_condition_at_host, compilation_key,
                               device_ordinal, communication_key_index)))
        return WalkResult::interrupt();
      if (failed(MoveOpsToHost(tpu_cluster, &while_op.body().front(),
                               host_while.body().front().getTerminator(),
                               compilation_key, device_ordinal,
                               communication_key_index)))
        return WalkResult::interrupt();
      MarkOutsideCompiled(host_while.getOperation());
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted()) return failure();
  return success();
}

// Removes outside compilation from all ops inside `host_launch_op`. Should
// only be run after all outside compiled ops have been moved to
// `host_launch_op`.
void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) {
  host_launch_op.GetBody().walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
      op->removeAttr(
          StringAttr::get(op->getContext(), kXlaOutsideCompilationAttr));
    }
  });
}

// Creates a `parallel_execute` op with a region for host computation and
// a region for `tpu_cluster` computation by extracting outside compiled ops to
// host computation.
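// The resulting structure is roughly (a sketch; result types depend on the
// original cluster):
//
//   "tf_device.parallel_execute"() ({
//     "tf_device.launch"() ({
//       ...  // Host computation, with paired recv/send ops.
//       tf_device.return
//     }) {device = "..."} : () -> ()
//     tf_device.return
//   }, {
//     %0 = "tf_device.cluster"() ({
//       ...  // Device computation, with _XlaHostComputeMlir ops.
//     }) : () -> (...)
//     tf_device.return %0 : ...
//   }) : () -> (...)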
LogicalResult CreateParallelExecuteForOutsideCompilation(
    ModuleOp module, tf_device::ClusterOp tpu_cluster,
    llvm::StringRef host_device) {
  OpBuilder builder(tpu_cluster);
  // Create parallel_execute regions, one for the host computation for outside
  // compilation and one for the original TPU cluster computation.
  const int num_regions = 2;
  auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
      tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
  Block& host_computation_block =
      parallel_execute_op.GetRegionBlockWithIndex(0);
  builder.setInsertionPointToEnd(&host_computation_block);

  // Create a single launch op for all outside compiled ops.
  tf_device::LaunchOp host_launch_op =
      CreateLaunchOpForOutsideCluster(builder, tpu_cluster, host_device);
  builder.setInsertionPoint(host_launch_op.GetBody().getTerminator());
  auto compilation_key_op =
      CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), builder);
  Value compilation_key = compilation_key_op.program();
  auto device_ordinal_op = builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
      tpu_cluster.getLoc(), RankedTensorType::get({}, builder.getI64Type()));
  Value device_ordinal = nullptr;
  if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
    device_ordinal = device_ordinal_op.device_ordinal();
  }

  int communication_key_index = 0;
  // Decompose control flow into device and host control flow when outside
  // compilation is included.
  if (failed(DecomposeControlFlow(tpu_cluster, compilation_key, device_ordinal,
                                  communication_key_index)))
    return failure();

  // Move all outside compiled ops including control flow to host launch.
  if (failed(MoveOpsToHost(tpu_cluster, &tpu_cluster.GetBody(),
                           host_launch_op.GetBody().getTerminator(),
                           compilation_key, device_ordinal,
                           communication_key_index)))
    return failure();

  if (communication_key_index == 0) compilation_key_op.erase();
  if (communication_key_index == 0 || device_ordinal == nullptr)
    device_ordinal_op.erase();

  RemoveOutsideCompilation(host_launch_op);

  builder.setInsertionPointToEnd(&host_computation_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(), ArrayRef<Value>{});

  // Move the TPU cluster into the last parallel_execute block.
  Block& parallel_execute_tpu_block =
      parallel_execute_op.GetRegionBlockWithIndex(1);
  builder.setInsertionPointToEnd(&parallel_execute_tpu_block);
  builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
                                      tpu_cluster.getResults());
  tpu_cluster.getOperation()->moveBefore(
      parallel_execute_tpu_block.getTerminator());

  // Remap cluster results with parallel_execute results if user is outside of
  // parallel_execute.
  for (auto result :
       llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
    Value tpu_cluster_result = std::get<0>(result);
    Value parallel_execute_result = std::get<1>(result);
    for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
      if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
        use.set(parallel_execute_result);
  }
  return success();
}

void TPUExtractOutsideCompilation::runOnOperation() {
  // Get runtime devices information from the closest parent module.
  auto module = getOperation();
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
    return signalPassFailure();

  module.walk([&](tf_device::ClusterOp tpu_cluster) {
    if (HasOutsideCompilationNested(tpu_cluster.getOperation())) {
      std::string host_device;
      if (failed(tensorflow::GetHostDeviceOutsideComputation(
              devices, tpu_cluster, &host_device)))
        return signalPassFailure();
      if (failed(CreateParallelExecuteForOutsideCompilation(module, tpu_cluster,
                                                            host_device)))
        return signalPassFailure();
    }
  });
  // Remove the `_xla_outside_compilation` attribute from all remaining ops.
  // These ops are outside of the device cluster; the attribute has no meaning
  // on ops outside of a tf_device.cluster and can lead to errors later on.
  // Such ops were likely lifted out of the tf_device.cluster in an earlier
  // pass.
  module.walk(
      [](Operation* op) { op->removeAttr("_xla_outside_compilation"); });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass() {
  return std::make_unique<TPUExtractOutsideCompilation>();
}

}  // namespace TFTPU
}  // namespace mlir