1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <utility>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
31 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
32 #include "mlir/IR/Operation.h"  // from @llvm-project
33 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
34 #include "mlir/IR/TypeRange.h"  // from @llvm-project
35 #include "mlir/IR/Visitors.h"  // from @llvm-project
36 #include "mlir/Pass/Pass.h"  // from @llvm-project
37 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
38 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
39 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
44 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
48 
49 namespace mlir {
50 namespace TFTPU {
51 
52 namespace {
53 
54 constexpr char kDeviceAttr[] = "device";
55 constexpr char kHostFunctionAttr[] = "host_func";
56 constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
57 
// Pass that moves outside compiled ops (marked with
// `_xla_outside_compilation`) out of TPU clusters into host-side computation,
// inserting device<->host communication ops for data dependencies.
struct TPUExtractOutsideCompilation
    : public TF::TPUExtractOutsideCompilationPassBase<
          TPUExtractOutsideCompilation> {
  void runOnOperation() override;
};
63 
64 // Build a function containing `ops` with `inputs` and `outputs` using
65 // `builder`.  The `ops` are cloned and modified to use the function arguments
66 // as inputs.
BuildFunction(llvm::ArrayRef<Operation * > ops,llvm::ArrayRef<Value> inputs,llvm::ArrayRef<Value> outputs,OpBuilder * builder)67 func::FuncOp BuildFunction(llvm::ArrayRef<Operation*> ops,
68                            llvm::ArrayRef<Value> inputs,
69                            llvm::ArrayRef<Value> outputs, OpBuilder* builder) {
70   llvm::SmallVector<Type, 4> operand_types;
71   operand_types.reserve(inputs.size());
72   for (Value v : inputs) operand_types.emplace_back(v.getType());
73   llvm::SmallVector<Type, 4> output_types;
74   output_types.reserve(outputs.size());
75   for (Value v : outputs) output_types.emplace_back(v.getType());
76 
77   auto func_type = builder->getFunctionType(operand_types, output_types);
78 
79   func::FuncOp outlined_func =
80       func::FuncOp::create(ops.front()->getLoc(), kHostFunctionAttr, func_type);
81 
82   // Create function body.
83   Block* outlined_func_block = outlined_func.addEntryBlock();
84 
85   // Clone the operations and remap the inputs to use the function arguments.
86   BlockAndValueMapping mapping;
87   mapping.map(inputs, outlined_func.getArguments());
88   builder->setInsertionPoint(outlined_func_block, outlined_func_block->begin());
89   for (Operation* op : ops) {
90     builder->clone(*op, mapping);
91   }
92 
93   // Set the returned values to use cloned ops results using mapping.
94   llvm::SmallVector<Value, 4> results_after_mapping;
95   for (Value result : outputs) {
96     results_after_mapping.push_back(mapping.lookupOrDefault(result));
97   }
98 
99   builder->create<func::ReturnOp>(ops.front()->getLoc(), results_after_mapping);
100   return outlined_func;
101 }
102 
103 // Encapsulates `func` in a module and serializes that module.
104 // `serialized_func_module` is set to the serialized module.
EncapsulateFuncAndSerialize(func::FuncOp func,std::string * serialized_func_module)105 void EncapsulateFuncAndSerialize(func::FuncOp func,
106                                  std::string* serialized_func_module) {
107   // Create a new module to hold func and all referenced functions.
108   OwningOpRef<mlir::ModuleOp> module_for_func =
109       ModuleOp::create(mlir::UnknownLoc::get(func.getContext()));
110   SymbolTable symbol_table(module_for_func.get());
111 
112   symbol_table.insert(func);
113   *serialized_func_module =
114       tensorflow::SerializeMlirModule(module_for_func.get());
115 }
116 
117 // Returns whether `op` or ops nested in `op` are outside compiled.
HasOutsideCompilationNested(Operation * op)118 bool HasOutsideCompilationNested(Operation* op) {
119   return op
120       ->walk([&](Operation* walked_op) {
121         if (op == walked_op) return WalkResult::advance();
122         if (walked_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
123           return WalkResult::interrupt();
124         }
125         return WalkResult::advance();
126       })
127       .wasInterrupted();
128 }
129 
130 // Returns whether `op` or any ancestors of `op` are outside compiled.
HasOutsideCompilationAncestor(Operation * op)131 bool HasOutsideCompilationAncestor(Operation* op) {
132   while (op) {
133     if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
134       return true;
135     }
136     op = op->getParentOp();
137   }
138   return false;
139 }
140 
141 // Returns whether any ancestors of `op` are outside compiled.
HasOutsideCompilationAncestorExclusive(Operation * op)142 bool HasOutsideCompilationAncestorExclusive(Operation* op) {
143   Operation* parent_op = op->getParentOp();
144   if (!parent_op) return false;
145   return HasOutsideCompilationAncestor(parent_op);
146 }
147 
// Marks `op` as performing a host transfer by setting the
// `_xla_has_host_transfer` attribute, and returns `op` for call chaining.
Operation* ApplyXlaHostTransferAttr(Operation* op, OpBuilder& builder) {
  op->setAttr("_xla_has_host_transfer", builder.getBoolAttr(true));
  return op;
}
152 
153 // Creates a tf._XlaSendFromHost or tf._XlaSendFromHostV2 op. If device ordinal
154 // is present, a tf._XlaSendFromHostV2 op is created instead.
CreateSendFromHostOp(OpBuilder & builder,Location loc,ValueRange inputs,Value compilation_key,Value device_ordinal,llvm::StringRef communication_key)155 Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc,
156                                 ValueRange inputs, Value compilation_key,
157                                 Value device_ordinal,
158                                 llvm::StringRef communication_key) {
159   if (device_ordinal)
160     return ApplyXlaHostTransferAttr(
161         builder.create<TF::_XlaSendFromHostV2Op>(
162             loc, inputs,
163             /*dynamic_key=*/compilation_key, device_ordinal,
164             builder.getStringAttr(communication_key)),
165         builder);
166 
167   return ApplyXlaHostTransferAttr(
168       builder.create<TF::_XlaSendFromHostOp>(
169           loc, inputs,
170           /*dynamic_key=*/compilation_key,
171           builder.getStringAttr(communication_key),
172           /*device_ordinal=*/builder.getI64IntegerAttr(0)),
173       builder);
174 }
175 
176 // Creates a tf._XlaRecvAtHost or tf._XlaRecvAtHostV2 op. If device ordinal is
177 // present, a tf._XlaRecvAtHostV2 op is created instead.
CreateRecvAtHostOp(OpBuilder & builder,Location loc,TypeRange output_types,Value compilation_key,Value device_ordinal,llvm::StringRef communication_key)178 Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc,
179                               TypeRange output_types, Value compilation_key,
180                               Value device_ordinal,
181                               llvm::StringRef communication_key) {
182   if (device_ordinal)
183     return ApplyXlaHostTransferAttr(
184         builder.create<TF::_XlaRecvAtHostV2Op>(
185             loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal,
186             builder.getStringAttr(communication_key)),
187         builder);
188 
189   return ApplyXlaHostTransferAttr(
190       builder.create<TF::_XlaRecvAtHostOp>(
191           loc, output_types, /*dynamic_key=*/compilation_key,
192           builder.getStringAttr(communication_key),
193           /*device_ordinal=*/builder.getI64IntegerAttr(0)),
194       builder);
195 }
196 
197 // Clones an IfRegionOp 'if_region' and attributes and creates then/else regions
198 // with yield op and an empty block.
CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,OpBuilder & builder)199 TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,
200                                          OpBuilder& builder) {
201   auto host_side_if = builder.create<TF::IfRegionOp>(
202       if_region.getLoc(), llvm::SmallVector<Type, 4>{}, if_region.cond(),
203       if_region.is_stateless(), if_region._then_func_nameAttr(),
204       if_region._else_func_nameAttr());
205 
206   // Create empty then branch region.
207   auto& then_branch = host_side_if.then_branch();
208   then_branch.push_back(new Block);
209   builder.setInsertionPointToEnd(&then_branch.front());
210   builder.create<TF::YieldOp>(if_region.getLoc(),
211                               /*operands=*/ArrayRef<Value>{});
212 
213   // Create empty else branch region.
214   auto& else_branch = host_side_if.else_branch();
215   else_branch.push_back(new Block);
216   builder.setInsertionPointToEnd(&else_branch.front());
217   builder.create<TF::YieldOp>(if_region.getLoc(),
218                               /*operands=*/ArrayRef<Value>{});
219   return host_side_if;
220 }
221 // Creates a WhileRegionOp cond and body regions with yield op and
222 // an empty body.
CloneEmptyWhile(bool is_stateless,uint64_t parallel_iterations,Location loc,OpBuilder & builder)223 TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
224                                   uint64_t parallel_iterations, Location loc,
225                                   OpBuilder& builder) {
226   auto host_side_while = builder.create<TF::WhileRegionOp>(
227       loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
228       parallel_iterations, is_stateless, /*shape_invariant=*/false);
229 
230   // Create empty else branch region.
231   auto& body = host_side_while.body();
232   body.push_back(new Block);
233   builder.setInsertionPointToEnd(&body.front());
234   builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
235   return host_side_while;
236 }
237 
238 // TODO(b/157054714): Use a better abstraction instead of
239 // _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
240 // Creates a compilation key as placeholder. A placeholder compilation cache key
241 // is created because it is a required input to _XlaRecvAtHost and
242 // _XlaSendFromHost but the _TPUCompileMlir has not yet been created for the TPU
243 // cluster that contains the outside compiled ops. This placeholder should be
244 // replaced by the TPU cluster _TPUCompileMlir in a subsequent pass.
CreateCompilationKeyPlaceholder(Location loc,OpBuilder & builder)245 TF::_TPUCompileMlirPlaceholderProgramKeyOp CreateCompilationKeyPlaceholder(
246     Location loc, OpBuilder& builder) {
247   auto result_type =
248       RankedTensorType::get({3}, builder.getType<TF::StringType>());
249   return builder.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
250       loc, /*program=*/result_type, llvm::ArrayRef<Value>{});
251 }
252 
253 // Creates a `tf_device.launch` to wrap cluster ops.
CreateLaunchOpForOutsideCluster(OpBuilder & builder,Operation * loc_op,llvm::StringRef host_device)254 tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
255     OpBuilder& builder, Operation* loc_op, llvm::StringRef host_device) {
256   // An empty string placeholder is used for the device as that will be later
257   // populated with the device of the associated TPUReplicateMetadata op.
258   auto launch_op = builder.create<tf_device::LaunchOp>(
259       loc_op->getLoc(), builder.getStringAttr(host_device),
260       /*result_types=*/ArrayRef<Type>{});
261 
262   launch_op.body().push_back(new Block);
263   builder.setInsertionPointToEnd(&launch_op.GetBody());
264   builder.create<tf_device::ReturnOp>(loc_op->getLoc(),
265                                       llvm::ArrayRef<Value>{});
266 
267   return launch_op;
268 }
269 
270 // Returns true if `op` has non-static shaped outputs.
HasDynamicOutputs(Operation * op)271 bool HasDynamicOutputs(Operation* op) {
272   for (Value v : op->getResults()) {
273     if (TF::CanBeRefined(v.getType())) return true;
274   }
275   return false;
276 }
277 
278 // Returns true if any op in `cluster_ops` has outputs consumed by ops not
279 // `cluster_ops` with a non-static shape.
HasDynamicOutputs(const llvm::SmallSetVector<Operation *,4> & cluster_ops)280 bool HasDynamicOutputs(const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
281   for (Operation* op : cluster_ops) {
282     for (const OpOperand& use : op->getUses()) {
283       if (cluster_ops.count(use.getOwner())) {
284         continue;
285       }
286       if (TF::CanBeRefined(use.get().getType())) return true;
287     }
288   }
289   return false;
290 }
291 
HasDynamicExternalValues(Operation * op)292 bool HasDynamicExternalValues(Operation* op) {
293   return op
294       ->walk([](Operation* walked_op) {
295         for (Value v : walked_op->getOperands()) {
296           if (TF::CanBeRefined(v.getType())) {
297             return WalkResult::interrupt();
298           }
299         }
300         return WalkResult::advance();
301       })
302       .wasInterrupted();
303 }
304 
// Returns operands of `cluster_ops` that need to be
// communicated from device->host. This is for the case when all operands have a
// static shape.
llvm::SmallSetVector<Value, 4> GetStaticExternalOperands(
    tf_device::ClusterOp tpu_cluster,
    const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
  llvm::SmallSetVector<Value, 4> external_values;
  for (Operation* op : cluster_ops) {
    op->walk([&](Operation* walked_op) {
      // Communication ops already bridge device/host; their operands do not
      // need additional transfers.
      if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(
              walked_op))
        return WalkResult::advance();
      for (Value v : walked_op->getOperands()) {
        if (auto* defining_op = v.getDefiningOp()) {
          // Communicate the value when it is defined by a device-side op:
          // inside the TPU cluster, not inside `op` itself, not itself
          // outside compiled, and not already received at the host.
          if (!op->isAncestor(defining_op) &&
              tpu_cluster->isAncestor(defining_op) &&
              !HasOutsideCompilationAncestor(defining_op) &&
              !llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
            external_values.insert(v);
          }
          continue;
        }
        // Value with no defining op must be a block argument; communicate it
        // when it belongs to the same region as `op` (i.e. it is defined
        // outside of `op`'s own nested regions).
        auto block_arg = v.cast<BlockArgument>();
        if (block_arg.getParentRegion() == op->getParentRegion())
          external_values.insert(v);
      }
      return WalkResult::advance();
    });
  }
  return external_values;
}
336 
337 // Returns every operand of `cluster_ops` that does not come from an op in
338 // `cluster_ops`.
GetAllExternalOperands(const llvm::SmallSetVector<Operation *,4> & cluster_ops)339 llvm::SmallSetVector<Value, 4> GetAllExternalOperands(
340     const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
341   llvm::SmallSetVector<Value, 4> external_values;
342   for (Operation* op : cluster_ops) {
343     op->walk([&](Operation* walked_op) {
344       for (Value v : walked_op->getOperands()) {
345         Operation* defining_op = v.getDefiningOp();
346         if (!defining_op || !cluster_ops.count(defining_op)) {
347           external_values.insert(v);
348         }
349       }
350     });
351   }
352   return external_values;
353 }
354 
355 // Returns a SmallSetVector containing all of the operands that need to be
356 // communicated from device->host.
GetExternalOperands(tf_device::ClusterOp tpu_cluster,const llvm::SmallSetVector<Operation *,4> & cluster_ops)357 llvm::SmallSetVector<Value, 4> GetExternalOperands(
358     tf_device::ClusterOp tpu_cluster,
359     const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
360   // If there are any dynamic outputs, get all of the operands which are defined
361   // external to `cluster_ops`.
362   bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
363   if (has_dynamic_outputs) {
364     return GetAllExternalOperands(cluster_ops);
365   } else {
366     return GetStaticExternalOperands(tpu_cluster, cluster_ops);
367   }
368 }
369 
370 // Gets all outputs that need to be communicated from host->device.
GetExternalOutputs(const llvm::SmallSetVector<Operation *,4> & cluster_ops)371 llvm::SmallSetVector<Value, 4> GetExternalOutputs(
372     const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
373   llvm::SmallSetVector<Value, 4> external_outputs;
374   bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
375   for (Operation* op : cluster_ops) {
376     for (Operation* user : op->getUsers()) {
377       // We skip any operations that are in the same outside compilation
378       // cluster that will be moved to the host at the same time since both
379       // defining op and user op will be moved to host.
380       if (cluster_ops.count(user)) {
381         continue;
382       }
383       // This is pessimistic and in some cases will add extra communication.
384       if (!HasOutsideCompilationAncestor(user) || has_dynamic_outputs ||
385           HasDynamicOutputs(user)) {
386         for (Value v : user->getOperands()) {
387           if (v.getDefiningOp() == op) external_outputs.insert(v);
388         }
389       }
390     }
391   }
392   return external_outputs;
393 }
394 
395 // Creates the HostCompute with `inputs` and `outputs`
396 // using `communication_key`.
CreateHostCompute(OpBuilder & builder,Location loc,const llvm::SmallSetVector<Value,4> & inputs,llvm::ArrayRef<Value> outputs,llvm::StringRef args_communication_key,llvm::StringRef retvals_communication_key,llvm::StringRef serialized_func_module)397 TF::_XlaHostComputeMlirOp CreateHostCompute(
398     OpBuilder& builder, Location loc,
399     const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
400     llvm::StringRef args_communication_key,
401     llvm::StringRef retvals_communication_key,
402     llvm::StringRef serialized_func_module) {
403   llvm::SmallVector<Type, 4> device_output_types;
404   for (const auto& output : outputs)
405     device_output_types.push_back(output.getType());
406   auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
407       loc, device_output_types, inputs.getArrayRef(),
408       builder.getStringAttr(args_communication_key),
409       builder.getStringAttr(retvals_communication_key),
410       /*host_mlir_module=*/builder.getStringAttr(serialized_func_module));
411   return host_compute;
412 }
413 
// Tags `op` with the outside compilation attribute (placeholder value "temp")
// so it is treated as host computation by the rest of this pass.
void MarkOutsideCompiled(Operation* op) {
  op->setAttr(kXlaOutsideCompilationAttr,
              StringAttr::get(op->getContext(), "temp"));
}
418 
419 // Returns whether an outside compilation cluster should be closed.  True when:
420 // 1. There is a dynamically shaped output consumed by a non-outside compiled
421 // op.
422 // 2. There is no dynamically shaped output.
ShouldCloseCluster(llvm::ArrayRef<Value> outputs)423 bool ShouldCloseCluster(llvm::ArrayRef<Value> outputs) {
424   bool has_dynamic_output = false;
425   for (Value v : outputs) {
426     if (TF::CanBeRefined(v.getType())) {
427       has_dynamic_output = true;
428       for (Operation* user : v.getUsers()) {
429         if (!HasOutsideCompilationAncestor(user)) return true;
430       }
431     }
432   }
433   return !has_dynamic_output;
434 }
435 
// Replaces `external_operands` with the results from `recv_at_host`.
// For non-static shapes, only replace operand usage if op is in the same
// region as insertion.
// For static-shapes, replace operand usages if op is in the same region as
// insertion or if the op is outside compiled and will be moved to host later.
void ReplaceExternalOperandUsage(
    const llvm::SmallSetVector<Value, 4>& external_operands,
    Operation* recv_at_host, Operation* insertion_point,
    Block* original_op_block) {
  auto replace_operand_usage = [&](OpOperand& operand) {
    // Dynamically shaped operands (or users with dynamic outputs) are only
    // rewritten when the user already lives in (a region nested under) the
    // insertion point's region, i.e. on the host side.
    if (TF::CanBeRefined(operand.get().getType()) ||
        HasDynamicOutputs(operand.getOwner())) {
      return insertion_point->getParentRegion()->isAncestor(
          operand.getOwner()->getParentRegion());
    }
    // Statically shaped operands are additionally rewritten for outside
    // compiled users still in the original block; those users will be moved
    // to the host by a later call.
    return insertion_point->getParentRegion()->isAncestor(
               operand.getOwner()->getParentRegion()) ||
           (HasOutsideCompilationAncestor(operand.getOwner()) &&
            original_op_block == operand.getOwner()->getBlock());
  };
  // `recv_at_host` results are in the same order as `external_operands`;
  // rewrite qualifying uses pairwise.
  for (auto result : llvm::zip(external_operands, recv_at_host->getResults())) {
    Value external_operand = std::get<0>(result);
    external_operand.replaceUsesWithIf(std::get<1>(result),
                                       replace_operand_usage);
  }
}
462 
HasDynamicOutputs(llvm::ArrayRef<Value> outputs)463 bool HasDynamicOutputs(llvm::ArrayRef<Value> outputs) {
464   for (Value v : outputs) {
465     if (TF::CanBeRefined(v.getType())) {
466       return true;
467     }
468   }
469   return false;
470 }
471 
// Replaces usages of `external_outputs` which are values returned by outside
// compilation with the corresponding outputs from `host_compute`.
void ReplaceExternalOutputUsage(
    const llvm::SmallSetVector<Value, 4>& external_outputs,
    TF::_XlaHostComputeMlirOp host_compute) {
  bool has_dynamic_outputs = HasDynamicOutputs(external_outputs.getArrayRef());

  auto replace_output_usage = [&](OpOperand& operand) {
    // Don't replace output usages if in host computation (defining op and user
    // in same region).
    bool in_same_region =
        operand.get().getDefiningOp()->getParentRegion()->isAncestor(
            operand.getOwner()->getParentRegion());
    if (has_dynamic_outputs || HasDynamicOutputs(operand.getOwner())) {
      // With dynamic shapes, every device-side use must read from
      // host_compute.
      return !in_same_region;
    } else {
      // Don't replace output usages in host computation or for outside
      // compiled ops.
      return !in_same_region &&
             !HasOutsideCompilationAncestor(operand.getOwner());
    }
  };
  // `host_compute` results are in the same order as `external_outputs`;
  // rewrite qualifying uses pairwise.
  for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
    Value external_output = std::get<0>(result);
    external_output.replaceUsesWithIf(std::get<1>(result),
                                      replace_output_usage);
  }
}
500 
// Move `clustered_ops` to run on host and adds communication ops to transfer
// `external_operands` and `external_outputs` to/from device/host.  Inserts
// ops at `insertion_point` and uses `compilation_key` and `device_ordinal` when
// creating comm ops.
void MoveOpsToHost(const llvm::SmallSetVector<Operation*, 4>& clustered_ops,
                   const llvm::SmallSetVector<Value, 4>& external_operands,
                   const llvm::SmallSetVector<Value, 4>& external_outputs,
                   Operation* insertion_point, Value compilation_key,
                   Value device_ordinal, int& communication_key_index) {
  OpBuilder builder(insertion_point);
  Operation& op = *clustered_ops.back();
  // Channel names are derived from `communication_key_index` so each
  // host/device transfer pair gets a unique key.
  std::string args_communication_key =
      llvm::formatv("host_compute_channel_{0}_args", (communication_key_index))
          .str();
  std::string retvals_communication_key =
      llvm::formatv("host_compute_channel_{0}_retvals",
                    (communication_key_index))
          .str();

  // Use a unique name when sending just the IfRegion predicate.  This is
  // for readable and to match the key in the TF2XLA bridge.
  if (clustered_ops.size() == 1 && llvm::isa<TF::IfRegionOp>(op) &&
      external_operands.size() == 1) {
    args_communication_key =
        llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
            .str();
  }

  // With dynamic outputs, serialize the cluster into a standalone module so
  // the host-side function can be used for shape inference.
  std::string serialized_func_module;
  if (HasDynamicOutputs(external_outputs.getArrayRef())) {
    func::FuncOp shape_op = BuildFunction(
        clustered_ops.getArrayRef(), external_operands.getArrayRef(),
        external_outputs.getArrayRef(), &builder);
    EncapsulateFuncAndSerialize(shape_op, &serialized_func_module);
  }

  // Device side: replace the cluster with a single _XlaHostComputeMlir op.
  builder.setInsertionPoint(&op);
  auto host_compute =
      CreateHostCompute(builder, op.getLoc(), external_operands,
                        external_outputs.getArrayRef(), args_communication_key,
                        retvals_communication_key, serialized_func_module);
  // Insert ops on the host side computation to receive data from device.
  builder.setInsertionPoint(insertion_point);
  llvm::SmallVector<Type, 4> host_operand_types;
  for (const auto& operand : external_operands)
    host_operand_types.push_back(operand.getType());

  Operation* recv_at_host = CreateRecvAtHostOp(
      builder, op.getLoc(), host_operand_types, compilation_key, device_ordinal,
      args_communication_key);
  Block* original_op_block = op.getBlock();
  // Move the clustered ops to the host side, directly after the recv, while
  // preserving their relative order; drop any device placement attribute.
  Operation* after_op = recv_at_host;
  for (Operation* cluster_op : clustered_ops) {
    cluster_op->moveAfter(after_op);
    cluster_op->removeAttr(StringAttr::get(op.getContext(), kDeviceAttr));
    after_op = cluster_op;
  }

  // Send the host-produced outputs back to the device if there are any.
  if (!external_outputs.empty()) {
    CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
                         compilation_key, device_ordinal,
                         retvals_communication_key);
  }

  // With no operands the recv op is unused; otherwise rewire host-side uses
  // of the external operands to the recv results.
  if (external_operands.empty()) {
    recv_at_host->erase();
  } else {
    ReplaceExternalOperandUsage(external_operands,
                                /*recv_at_host=*/recv_at_host,
                                /*insertion_point=*/insertion_point,
                                /*original_op_block=*/original_op_block);
  }

  ReplaceExternalOutputUsage(external_outputs, host_compute);

  // A host compute with neither operands nor outputs requires no
  // communication at all; only consume a channel index when it is kept.
  if (external_operands.empty() && external_outputs.empty()) {
    host_compute.erase();
  } else {
    ++communication_key_index;
  }
}
582 
// Move outside compiled ops in `src` to `insertion_point` in host
// computation (may be temporarily with `tpu_cluster` but moved in subsequent
// call to this method).  Communication ops are added in both `src` and at
// `insertion_point` using `compilation_key`, `device_ordinal` and
// `communication_key_index` which is incremented when used. Communication ops
// are added only when needed and at the location need.  There are checks to
// ensure that duplicate communication between device and host is not added.
LogicalResult MoveOpsToHost(tf_device::ClusterOp tpu_cluster, Block* src,
                            Operation* insertion_point, Value compilation_key,
                            Value device_ordinal,
                            int& communication_key_index) {
  // Contains all of the outside compiled operations that should be moved to the
  // host using a single `_XlaHostComputeMlir` op.  This should only contain a
  // single op except in the case where some of the input/output shapes are
  // non-static.
  llvm::SmallSetVector<Operation*, 4> clustered_ops;

  // early_inc_range is required since ops are moved out of `src` while
  // iterating over it.
  for (Operation& op : llvm::make_early_inc_range(*src)) {
    // Only consider top-level outside compiled ops in this block; nested ones
    // move together with their ancestor.
    if (HasOutsideCompilationAncestorExclusive(&op) ||
        !op.hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      continue;

    // We want to move the clustered_ops if the op to be added has all
    // statically shaped operands since we can't ensure that the static shapes
    // has been sent back to host in all cases.  See
    // @static_shapes_sandwiched_outside_compilation MLIR test for an example.
    if (!HasDynamicExternalValues(&op) && !clustered_ops.empty()) {
      llvm::SmallSetVector<Value, 4> external_operands =
          GetExternalOperands(tpu_cluster, clustered_ops);
      llvm::SmallSetVector<Value, 4> external_outputs =
          GetExternalOutputs(clustered_ops);
      MoveOpsToHost(clustered_ops, external_operands, external_outputs,
                    insertion_point, compilation_key, device_ordinal,
                    communication_key_index);
      clustered_ops.clear();
    }

    clustered_ops.insert(&op);

    // Get the outputs that need to be communicated from host -> device.
    llvm::SmallSetVector<Value, 4> external_outputs =
        GetExternalOutputs(clustered_ops);

    if (ShouldCloseCluster(external_outputs.getArrayRef())) {
      // Get the operands that need to be communicated from device -> host.
      llvm::SmallSetVector<Value, 4> external_operands =
          GetExternalOperands(tpu_cluster, clustered_ops);
      MoveOpsToHost(clustered_ops, external_operands, external_outputs,
                    insertion_point, compilation_key, device_ordinal,
                    communication_key_index);
      clustered_ops.clear();
    }
  }
  return success();
}
638 
639 // Decompose control flow in `tpu_cluster` into device computation and host
640 // (outside compiled) computation into two separate control flow ops with
641 // communication between the device/host for data dependencies.  Both device and
642 // host control flow initially remain within `tpu_cluster` and a subsequency
643 // call to MoveOpsToHost moves the host side control flow to the host launch in
644 // tf_device.parallel_execute.  Uses `compilation_key, `device_ordinal` and
645 // `communication_key_index` when creating communication ops.
DecomposeControlFlow(tf_device::ClusterOp tpu_cluster,Value compilation_key,Value device_ordinal,int & communication_key_index)646 LogicalResult DecomposeControlFlow(tf_device::ClusterOp tpu_cluster,
647                                    Value compilation_key, Value device_ordinal,
648                                    int& communication_key_index) {
649   auto result = tpu_cluster.GetBody().walk([&](Operation* op) {
650     if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(op)) {
651       if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
652       OpBuilder builder(if_op);
653       auto host_if = CloneEmptyIfWithPredicate(if_op, builder);
654       if (failed(MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(),
655                                host_if.then_branch().front().getTerminator(),
656                                compilation_key, device_ordinal,
657                                communication_key_index)))
658         return WalkResult::interrupt();
659       if (failed(MoveOpsToHost(tpu_cluster, &if_op.else_branch().front(),
660                                host_if.else_branch().front().getTerminator(),
661                                compilation_key, device_ordinal,
662                                communication_key_index)))
663         return WalkResult::interrupt();
664       MarkOutsideCompiled(host_if.getOperation());
665     }
666     if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) {
667       if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
668       OpBuilder builder(while_op);
669       auto host_while = CloneEmptyWhile(while_op.is_stateless(),
670                                         while_op.parallel_iterations(),
671                                         while_op.getLoc(), builder);
672       const auto condition_send_recv_key =
673           llvm::formatv("while_condition_channel_{0}",
674                         communication_key_index++)
675               .str();
676       auto& cond = host_while.cond();
677       cond.push_back(new Block);
678       auto condition = while_op.cond().front().getTerminator()->getOperand(0);
679       builder.setInsertionPoint(while_op.cond().front().getTerminator());
680       builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition,
681                                           condition_send_recv_key);
682       builder.setInsertionPointToEnd(&cond.front());
683       auto recv_condition_at_host = CreateRecvAtHostOp(
684           builder, while_op.getLoc(), TypeRange{condition.getType()},
685           compilation_key, device_ordinal, condition_send_recv_key);
686       builder.create<TF::YieldOp>(while_op.getLoc(),
687                                   recv_condition_at_host->getResults());
688 
689       if (failed(MoveOpsToHost(tpu_cluster, &while_op.cond().front(),
690                                recv_condition_at_host, compilation_key,
691                                device_ordinal, communication_key_index)))
692         return WalkResult::interrupt();
693       if (failed(MoveOpsToHost(tpu_cluster, &while_op.body().front(),
694                                host_while.body().front().getTerminator(),
695                                compilation_key, device_ordinal,
696                                communication_key_index)))
697         return WalkResult::interrupt();
698       MarkOutsideCompiled(host_while.getOperation());
699     }
700     return WalkResult::advance();
701   });
702   if (result.wasInterrupted()) return failure();
703   return success();
704 }
705 
706 // Removes outside compilation from all ops inside `host_launch_op`.  Should
707 // only be run after all outside compiled ops have been moved to
708 // `host_launch_op`.
RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op)709 void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) {
710   host_launch_op.GetBody().walk([&](Operation* op) {
711     if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
712       op->removeAttr(
713           StringAttr::get(op->getContext(), kXlaOutsideCompilationAttr));
714     }
715   });
716 }
717 
718 // Creates a `parallel_execute` op with a region for host computation and
719 // a region for `tpu_cluster` computation by extracting outside compiled ops to
720 // host computation.
CreateParallelExecuteForOutsideCompilation(ModuleOp module,tf_device::ClusterOp tpu_cluster,llvm::StringRef host_device)721 LogicalResult CreateParallelExecuteForOutsideCompilation(
722     ModuleOp module, tf_device::ClusterOp tpu_cluster,
723     llvm::StringRef host_device) {
724   OpBuilder builder(tpu_cluster);
725   // Create parallel_execute regions, one for the host computation for outside
726   // compilation and the second for the original TPU cluster computation.
727   const int num_regions = 2;
728   auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
729       tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
730   Block& host_computation_block =
731       parallel_execute_op.GetRegionBlockWithIndex(0);
732   builder.setInsertionPointToEnd(&host_computation_block);
733 
734   // Create a single launch op for all outside compiled ops.
735   tf_device::LaunchOp host_launch_op =
736       CreateLaunchOpForOutsideCluster(builder, tpu_cluster, host_device);
737   builder.setInsertionPoint(host_launch_op.GetBody().getTerminator());
738   auto compilation_key_op =
739       CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), builder);
740   Value compilation_key = compilation_key_op.program();
741   auto device_ordinal_op = builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
742       tpu_cluster.getLoc(), RankedTensorType::get({}, builder.getI64Type()));
743   Value device_ordinal = nullptr;
744   if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
745     device_ordinal = device_ordinal_op.device_ordinal();
746   }
747 
748   int communication_key_index = 0;
749   // Decompose control flow into device and host control flow when outside
750   // compilation is included.
751   if (failed(DecomposeControlFlow(tpu_cluster, compilation_key, device_ordinal,
752                                   communication_key_index)))
753     return failure();
754 
755   // Move all outside compiled ops including control flow to host launch.
756   if (failed(MoveOpsToHost(tpu_cluster, &tpu_cluster.GetBody(),
757                            host_launch_op.GetBody().getTerminator(),
758                            compilation_key, device_ordinal,
759                            communication_key_index)))
760     return failure();
761 
762   if (communication_key_index == 0) compilation_key_op.erase();
763   if (communication_key_index == 0 || device_ordinal == nullptr)
764     device_ordinal_op.erase();
765 
766   RemoveOutsideCompilation(host_launch_op);
767 
768   builder.setInsertionPointToEnd(&host_computation_block);
769   builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(), ArrayRef<Value>{});
770 
771   // Move the launch body to last parallel_execute block.
772   Block& parallel_execute_tpu_block =
773       parallel_execute_op.GetRegionBlockWithIndex(1);
774   builder.setInsertionPointToEnd(&parallel_execute_tpu_block);
775   builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
776                                       tpu_cluster.getResults());
777   tpu_cluster.getOperation()->moveBefore(
778       parallel_execute_tpu_block.getTerminator());
779 
780   // Remap cluster results with parallel_execute results if user is outside of
781   // parallel_execute.
782   for (auto result :
783        llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
784     Value tpu_cluster_result = std::get<0>(result);
785     Value parallel_execute_result = std::get<1>(result);
786     for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
787       if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
788         use.set(parallel_execute_result);
789   }
790   return success();
791 }
792 
runOnOperation()793 void TPUExtractOutsideCompilation::runOnOperation() {
794   // Get runtime devices information from the closest parent module.
795   auto module = getOperation();
796   mlir::TF::RuntimeDevices devices;
797   if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
798     return signalPassFailure();
799 
800   module.walk([&](tf_device::ClusterOp tpu_cluster) {
801     if (HasOutsideCompilationNested(tpu_cluster.getOperation())) {
802       std::string host_device;
803       if (failed(tensorflow::GetHostDeviceOutsideComputation(
804               devices, tpu_cluster, &host_device)))
805         return signalPassFailure();
806       if (failed(CreateParallelExecuteForOutsideCompilation(module, tpu_cluster,
807                                                             host_device)))
808         return signalPassFailure();
809     }
810   });
811   // Remove `_xla_outside_compilation` attribute from all ops.  These ops will
812   // be outside of the device cluster. The `_xla_outside_compilation` attribute
813   // on ops outside of tf_device.cluster don't have any meaning and can lead to
814   // errors later on.  These ops were likely lifted out of the
815   // tf_device.cluster in an earlier pass.
816   module.walk(
817       [](Operation* op) { op->removeAttr("_xla_outside_compilation"); });
818 }
819 
820 }  // namespace
821 
822 std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass()823 CreateTPUExtractOutsideCompilationPass() {
824   return std::make_unique<TPUExtractOutsideCompilation>();
825 }
826 
827 }  // namespace TFTPU
828 }  // namespace mlir
829