/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Rewrite/PatternApplicator.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/lib/monitoring/gauge.h"

namespace mlir {
namespace TFDevice {

namespace {

constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement";

auto* auto_outside_compilation_gauge =
    tensorflow::monitoring::Gauge<bool, 0>::New(
        "/tensorflow/core/use_auto_outside_compilation",
        "Tracks if auto outside compilation is enabled");

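// Pass that marks unsupported ops inside tf_device.cluster regions for
// outside compilation (execution on the host) when soft placement is enabled.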
struct MarkOpsForOutsideCompilation
    : public TF::MarkOpsForOutsideCompilationPassBase<
          MarkOpsForOutsideCompilation> {
  void runOnOperation() override;
};

// Adds any canonicalization patterns to the list of supported `patterns`.
// TODO(b/161726307): Move or import the relevant patterns to LowerTF pass and
// remove this.
void AddCanonicalizationPatterns(MLIRContext* context,
                                 RewritePatternSet* patterns) {
  for (auto op : context->getRegisteredOperations())
    op.getCanonicalizationPatterns(*patterns, context);
}

// Adds to `supported_ops` the ops that are supported on TPU through constant
// folding, which may depend on input shapes that are not known at this point.
// Such ops may not have any legalization or canonicalization patterns but
// shouldn't be marked for outside compilation.
//
// TODO(b/177523289): Remove manual handling once we support constant folding
// and shape inference through the computation on the host side.
void AddSupportedOpsUsingFolding(MLIRContext* context,
                                 llvm::DenseSet<OperationName>* supported_ops) {
  llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
      OperationName(TF::BroadcastArgsOp::getOperationName(), context),
      OperationName(TF::BroadcastGradientArgsOp::getOperationName(), context),
      OperationName(TF::ConcatOffsetOp::getOperationName(), context),
      OperationName(TF::EmptyOp::getOperationName(), context),
      OperationName(TF::ListDiffOp::getOperationName(), context),
      OperationName(TF::RankOp::getOperationName(), context),
      OperationName(TF::RangeOp::getOperationName(), context),
      OperationName(TF::ShapeOp::getOperationName(), context),
      OperationName(TF::ShapeNOp::getOperationName(), context),
      OperationName(TF::SizeOp::getOperationName(), context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

// Adds to `supported_ops` the ops that are supported through the dynamic
// padder using op-by-op fallback to the TF2XLA bridge.
// TODO(b/168036682): Remove this once ops are supported using dynamic padder
// on MLIR bridge.
void AddSupportedOpsUsingDynamicPadder(
    MLIRContext* context, llvm::DenseSet<OperationName>* supported_ops) {
  llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
      OperationName(TF::WhereOp::getOperationName(), context),
      OperationName(TF::UniqueOp::getOperationName(), context),
      OperationName(TF::XlaSetDynamicDimensionSizeOp::getOperationName(),
                    context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

// TODO(b/159128666): Check the control flow legalization passes instead once
// added.
void AddSupportedFunctionalOps(MLIRContext* context,
                               llvm::DenseSet<OperationName>* supported_ops) {
  supported_ops->insert(
      OperationName(TF::CaseRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::IfRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::InplaceAddOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::WhileRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaCallModuleOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReduceOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReduceWindowOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaRngBitGeneratorOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaScatterOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaSelectAndScatterOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::SymbolicGradientOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicReduceOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicReduceV2Op::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicSortOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReplicaIdOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::YieldOp::getOperationName(), context));
}

// These embedding ops are rewritten when running TPUCompileOp.
void AddRewrittenEmbeddingOps(MLIRContext* context,
                              llvm::DenseSet<OperationName>* supported_ops) {
  supported_ops->insert(OperationName(
      TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context));
  supported_ops->insert(OperationName(
      TF::SendTPUEmbeddingGradientsOp::getOperationName(), context));
}

// Stack, TensorList and TensorArray ops are rewritten during the second phase
// of the bridge (compilation of TPUCompile op). They would not match any
// legalization/canonicalization pattern and have to be manually added to the
// list of supported ops.
void AddRewrittenCompositeOps(MLIRContext* context,
                              llvm::DenseSet<OperationName>* supported_ops) {
#define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context)
  llvm::SmallDenseSet<OperationName, 32> allowlist_ops = {
      // Stack ops.
      GET_OPERATION_NAME(TF::StackV2Op),
      GET_OPERATION_NAME(TF::StackPushV2Op),
      GET_OPERATION_NAME(TF::StackPopV2Op),
      // Tensor Array ops.
      GET_OPERATION_NAME(TF::TensorArrayV3Op),
      GET_OPERATION_NAME(TF::TensorArrayReadV3Op),
      GET_OPERATION_NAME(TF::TensorArrayWriteV3Op),
      GET_OPERATION_NAME(TF::TensorArrayConcatV3Op),
      GET_OPERATION_NAME(TF::TensorArraySplitV3Op),
      GET_OPERATION_NAME(TF::TensorArraySizeV3Op),
      GET_OPERATION_NAME(TF::TensorArrayGradV3Op),
      GET_OPERATION_NAME(TF::TensorArrayGatherV3Op),
      GET_OPERATION_NAME(TF::TensorArrayScatterV3Op),
      // Tensor List ops.
      GET_OPERATION_NAME(TF::EmptyTensorListOp),
      GET_OPERATION_NAME(TF::TensorListReserveOp),
      GET_OPERATION_NAME(TF::TensorListFromTensorOp),
      GET_OPERATION_NAME(TF::TensorListPushBackOp),
      GET_OPERATION_NAME(TF::TensorListPopBackOp),
      GET_OPERATION_NAME(TF::TensorListGetItemOp),
      GET_OPERATION_NAME(TF::TensorListSetItemOp),
      GET_OPERATION_NAME(TF::TensorListLengthOp),
      GET_OPERATION_NAME(TF::TensorListElementShapeOp),
      GET_OPERATION_NAME(TF::TensorListGatherOp),
      GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp),
      GET_OPERATION_NAME(TF::TensorListStackOp),
  };
#undef GET_OPERATION_NAME

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

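// Returns true if `type` is a string type, or a TF type with subtypes (e.g.
// resource or variant) where any subtype has a string element type.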
bool IsStringType(Type type) {
  if (type.isa<TF::StringType>()) return true;

  auto sub_type = type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
  if (!sub_type) return false;

  bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) {
    return type.getElementType().isa<TF::StringType>();
  });
  return has_string;
}

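// Returns true if any operand of `op` has a string element type.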
bool HasStringOperand(Operation& op) {
  for (auto operand : op.getOperands()) {
    auto operand_type = getElementTypeOrSelf(operand);
    if (IsStringType(operand_type)) return true;
  }
  return false;
}

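// Returns true if any result of `op` has a string element type.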
bool HasStringResult(Operation& op) {
  for (auto result : op.getResults()) {
    auto result_type = getElementTypeOrSelf(result);
    if (IsStringType(result_type)) return true;
  }
  return false;
}

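// Returns true if the name of `op` appears in `supported_ops`.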
bool MatchesPattern(Operation& op,
                    const llvm::DenseSet<OperationName>& supported_ops) {
  return supported_ops.contains(op.getName());
}

// Checks if the op is supported inside of a device cluster. Ops not in
// `tf_dialect` are considered supported.
bool IsSupportedOp(Operation& op,
                   const llvm::DenseSet<OperationName>& supported_ops,
                   const Dialect* tf_dialect) {
  if (op.getDialect() != tf_dialect) return true;
  // tf.Assert has a legalization that later removes it, so it should never be
  // outside compiled, for performance reasons.
  if (llvm::isa<TF::AssertOp>(op)) return true;
  return !HasStringOperand(op) && !HasStringResult(op) &&
         (MatchesPattern(op, supported_ops) ||
          mhlo::IsOpAllowedTf2XlaFallback(&op));
}

// Checks all regions of `op` for captured string operands.
bool HasCapturedStringOperand(Operation* op) {
  bool string_operand = false;
  for (auto& region : op->getRegions()) {
    mlir::visitUsedValuesDefinedAbove(
        region, region, [&](mlir::OpOperand* operand) {
          if (getElementTypeOrSelf(operand->get()).isa<TF::StringType>())
            string_operand = true;
        });
    if (string_operand) return string_operand;
  }
  return string_operand;
}

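// Returns true if the element type of `value` is a TF variant type.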
bool IsVariant(Value value) {
  return getElementTypeOrSelf(value.getType()).isa<TF::VariantType>();
}

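// Returns true if any ancestor of `op` is already marked for outside
// compilation.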
bool HasOutsideCompiledAncestor(Operation* op) {
  Operation* parent = op->getParentOp();
  while (parent) {
    if (parent->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      return true;
    parent = parent->getParentOp();
  }
  return false;
}

// If any tf.variants are inputs/outputs to another outside compiled
// operation, marks them for outside compilation unless they are already
// marked with the outside compilation attribute.
void MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster) {
  std::queue<Operation*> outside_compiled_ops;
  tpu_cluster.walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      outside_compiled_ops.push(op);
  });

  while (!outside_compiled_ops.empty()) {
    Operation* host_op = outside_compiled_ops.front();
    outside_compiled_ops.pop();
    host_op->walk([&](Operation* op) {
      // Add any operations that provide variant inputs to the cluster.
      for (auto value : op->getOperands()) {
        Operation* input_defining_op = value.getDefiningOp();
        if (IsVariant(value) && input_defining_op &&
            !HasOutsideCompiledAncestor(input_defining_op) &&
            !input_defining_op->hasAttrOfType<StringAttr>(
                kXlaOutsideCompilationAttr)) {
          input_defining_op->setAttr(
              kXlaOutsideCompilationAttr,
              StringAttr::get(input_defining_op->getContext(), "auto"));
          outside_compiled_ops.push(input_defining_op);
        }
      }
      // Mark for outside compilation any operations that consume variant
      // outputs from an outside compiled operation.
      for (auto value : op->getResults()) {
        if (IsVariant(value)) {
          for (auto user : value.getUsers()) {
            if (!user->hasTrait<OpTrait::IsTerminator>() &&
                !HasOutsideCompiledAncestor(user) &&
                !user->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
              user->setAttr(kXlaOutsideCompilationAttr,
                            StringAttr::get(user->getContext(), "auto"));
              outside_compiled_ops.push(user);
            }
          }
        }
      }
    });
  }
}

// Marks uncompilable ops that are in `tf_dialect` for outside compilation.
LogicalResult MarkUncompilableOps(
    const Dialect* tf_dialect, Block* block,
    llvm::DenseSet<OperationName>& supported_ops) {
  // Automatically marked ops get an `_xla_outside_compilation` attribute
  // value of "auto" plus an increasing counter. Manually marked ops only have
  // an increasing counter for the attribute value, so there is no collision
  // in the `_xla_outside_compilation` attribute between automatically and
  // manually marked ops.
  int outside_compiled_cluster_counter = 0;
  block->walk([&](Operation* op) {
    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      VLOG(3) << "Cloud TPU: Op " << op->getName().getStringRef().str()
              << " isn't compilable, adding outside_compilation attr. "
                 "This op will automatically be placed on CPU.";
      op->setAttr(kXlaOutsideCompilationAttr,
                  StringAttr::get(
                      op->getContext(),
                      llvm::formatv("auto{0}", outside_compiled_cluster_counter)
                          .str()));
      outside_compiled_cluster_counter++;
    }
  });
  if (outside_compiled_cluster_counter > 0) {
    auto_outside_compilation_gauge->GetCell()->Set(true);
  }
  return success();
}

// Checks for uncompilable ops that are in `tf_dialect` and are not already
// marked for outside compilation.
bool ContainsUncompilableOps(const Dialect* tf_dialect, Block* block,
                             llvm::DenseSet<OperationName>& supported_ops) {
  int uncompilable_op_count = 0;
  // Skip ops that are themselves, or whose parents are, already marked for
  // outside compilation.
  block->walk([&](Operation* op) {
    Operation* iter_op = op;
    while (iter_op && !llvm::isa<tf_device::ClusterOp>(iter_op)) {
      if (iter_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
        return;
      }
      iter_op = iter_op->getParentOp();
    }

    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      op->emitOpError() << "isn't compilable for TPU device. Enable "
                           "soft_device_placement option to run on CPU";
      ++uncompilable_op_count;
    }
  });
  return uncompilable_op_count > 0;
}

// Unmarks outside compilation for any op that has a parent already marked for
// outside compilation, since the child will be extracted anyway.
void UnmarkChildren(Block* block) {
  block->walk([&](Operation* op) {
    if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) return;
    Operation* iter_op = op;
    bool remove_attr = false;
    while (auto* parent_op = iter_op->getParentOp()) {
      if (parent_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
        remove_attr = true;
        break;
      }
      iter_op = parent_op;
    }
    if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr);
  });
}

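// Pass entry point: computes the set of ops that can be lowered on device,
// then, for each cluster, either marks unsupported ops for outside
// compilation (when soft placement is allowed) or emits errors for them.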
void MarkOpsForOutsideCompilation::runOnOperation() {
  auto module = getOperation();
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }
  RewritePatternSet patterns(&getContext());
  mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
  TF::PopulateTFLoweringBeforeHLOPatterns(module.getContext(), &patterns);
  TF::PopulateLoweringQuantizedPatterns(module.getContext(), &patterns);
  AddCanonicalizationPatterns(module.getContext(), &patterns);

  // `supported_ops` contains the names of all ops that can potentially be
  // lowered into HLO on the device. Inclusion in this set does not guarantee
  // that a later pass will lower the op, but an op not in this set cannot be
  // lowered in a subsequent pass.
  llvm::DenseSet<OperationName> supported_ops;
  PatternApplicator(std::move(patterns))
      .walkAllPatterns([&](const Pattern& pattern) {
        Optional<OperationName> root_kind = pattern.getRootKind();
        if (root_kind.has_value()) supported_ops.insert(root_kind.getValue());
      });
  AddSupportedFunctionalOps(module.getContext(), &supported_ops);
  AddSupportedOpsUsingFolding(module.getContext(), &supported_ops);
  AddSupportedOpsUsingDynamicPadder(module.getContext(), &supported_ops);
  AddRewrittenEmbeddingOps(module.getContext(), &supported_ops);
  AddRewrittenCompositeOps(module.getContext(), &supported_ops);

  auto result = module.walk([&](tf_device::ClusterOp cluster) {
    // Only mark ops for outside compilation if the `allow_soft_placement`
    // attribute is true.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (soft_placement_attr && soft_placement_attr.getValue()) {
      if (failed(MarkUncompilableOps(tf_dialect, &cluster.GetBody(),
                                     supported_ops)))
        return WalkResult::interrupt();
    } else {
      if (ContainsUncompilableOps(tf_dialect, &cluster.GetBody(),
                                  supported_ops))
        return WalkResult::interrupt();
    }
    MarkVariantInputsOutputs(cluster);

    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return signalPassFailure();

  module.walk([&](tf_device::ClusterOp cluster) {
    // Only unmark ops for outside compilation if the `allow_soft_placement`
    // attribute is true.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (!(soft_placement_attr && soft_placement_attr.getValue())) {
      return;
    }
    UnmarkChildren(&cluster.GetBody());
  });
}

}  // namespace

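// Creates a pass that marks unsupported ops in device clusters for outside
// compilation.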
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass() {
  return std::make_unique<MarkOpsForOutsideCompilation>();
}

}  // namespace TFDevice
}  // namespace mlir