/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Rewrite/PatternApplicator.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/lib/monitoring/gauge.h"

namespace mlir {
namespace TFDevice {

namespace {

constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement";

auto* auto_outside_compilation_gauge =
    tensorflow::monitoring::Gauge<bool, 0>::New(
        "/tensorflow/core/use_auto_outside_compilation",
        "Tracks if auto outside compilation is enabled");

struct MarkOpsForOutsideCompilation
    : public TF::MarkOpsForOutsideCompilationPassBase<
          MarkOpsForOutsideCompilation> {
  void runOnOperation() override;
};

// Adds any canonicalization patterns to the list of supported `patterns`.
// TODO(b/161726307): Move or import the relevant patterns to LowerTF pass and
// remove this.
void AddCanonicalizationPatterns(MLIRContext* context,
                                 RewritePatternSet* patterns) {
  for (auto op : context->getRegisteredOperations())
    op.getCanonicalizationPatterns(*patterns, context);
}

// Adds the list of ops that are supported on TPU through constant folding,
// which may depend on input shapes that are not known at this point. Such ops
// may not have any legalization or canonicalization patterns but shouldn't be
// marked for outside compilation.
//
// TODO(b/177523289): Remove manual handling once we support constant folding
// and shape inference through the computation on the host side.
void AddSupportedOpsUsingFolding(MLIRContext* context,
                                 llvm::DenseSet<OperationName>* supported_ops) {
  llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
      OperationName(TF::BroadcastArgsOp::getOperationName(), context),
      OperationName(TF::BroadcastGradientArgsOp::getOperationName(), context),
      OperationName(TF::ConcatOffsetOp::getOperationName(), context),
      OperationName(TF::EmptyOp::getOperationName(), context),
      OperationName(TF::ListDiffOp::getOperationName(), context),
      OperationName(TF::RankOp::getOperationName(), context),
      OperationName(TF::RangeOp::getOperationName(), context),
      OperationName(TF::ShapeOp::getOperationName(), context),
      OperationName(TF::ShapeNOp::getOperationName(), context),
      OperationName(TF::SizeOp::getOperationName(), context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

// Adds the list of ops that are supported through dynamic padder using
// op-by-op fallback to the TF2XLA bridge.
// TODO(b/168036682): Remove this once ops are supported using dynamic padder
// on MLIR bridge.
void AddSupportedOpsUsingDynamicPadder(
    MLIRContext* context, llvm::DenseSet<OperationName>* supported_ops) {
  llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
      OperationName(TF::WhereOp::getOperationName(), context),
      OperationName(TF::UniqueOp::getOperationName(), context),
      OperationName(TF::XlaSetDynamicDimensionSizeOp::getOperationName(),
                    context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

// TODO(b/159128666): Check the control flow legalization passes instead once
// added.
void AddSupportedFunctionalOps(MLIRContext* context,
                               llvm::DenseSet<OperationName>* supported_ops) {
  supported_ops->insert(
      OperationName(TF::CaseRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::IfRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::InplaceAddOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::WhileRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaCallModuleOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReduceOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReduceWindowOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaRngBitGeneratorOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaScatterOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaSelectAndScatterOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::SymbolicGradientOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicReduceOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicReduceV2Op::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicSortOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReplicaIdOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::YieldOp::getOperationName(), context));
}

// These embedding ops are rewritten when running TPUCompileOp.
void AddRewrittenEmbeddingOps(MLIRContext* context,
                              llvm::DenseSet<OperationName>* supported_ops) {
  supported_ops->insert(OperationName(
      TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context));
  supported_ops->insert(OperationName(
      TF::SendTPUEmbeddingGradientsOp::getOperationName(), context));
}

// Stack, TensorList and TensorArray ops are rewritten during the second phase
// of the bridge (compilation of TPUCompile op). They would not match any
// legalization/canonicalization pattern and have to be manually added to the
// list of supported ops.
void AddRewrittenCompositeOps(MLIRContext* context,
                              llvm::DenseSet<OperationName>* supported_ops) {
#define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context)
  llvm::SmallDenseSet<OperationName, 32> allowlist_ops = {
      // Stack ops.
      GET_OPERATION_NAME(TF::StackV2Op),
      GET_OPERATION_NAME(TF::StackPushV2Op),
      GET_OPERATION_NAME(TF::StackPopV2Op),
      // Tensor Array ops.
      GET_OPERATION_NAME(TF::TensorArrayV3Op),
      GET_OPERATION_NAME(TF::TensorArrayReadV3Op),
      GET_OPERATION_NAME(TF::TensorArrayWriteV3Op),
      GET_OPERATION_NAME(TF::TensorArrayConcatV3Op),
      GET_OPERATION_NAME(TF::TensorArraySplitV3Op),
      GET_OPERATION_NAME(TF::TensorArraySizeV3Op),
      GET_OPERATION_NAME(TF::TensorArrayGradV3Op),
      GET_OPERATION_NAME(TF::TensorArrayGatherV3Op),
      GET_OPERATION_NAME(TF::TensorArrayScatterV3Op),
      // Tensor List Ops.
      GET_OPERATION_NAME(TF::EmptyTensorListOp),
      GET_OPERATION_NAME(TF::TensorListReserveOp),
      GET_OPERATION_NAME(TF::TensorListFromTensorOp),
      GET_OPERATION_NAME(TF::TensorListPushBackOp),
      GET_OPERATION_NAME(TF::TensorListPopBackOp),
      GET_OPERATION_NAME(TF::TensorListGetItemOp),
      GET_OPERATION_NAME(TF::TensorListSetItemOp),
      GET_OPERATION_NAME(TF::TensorListLengthOp),
      GET_OPERATION_NAME(TF::TensorListElementShapeOp),
      GET_OPERATION_NAME(TF::TensorListGatherOp),
      GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp),
      GET_OPERATION_NAME(TF::TensorListStackOp),
  };
#undef GET_OPERATION_NAME

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

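// Returns true if `type` is tf.string, or is a TensorFlow type with subtypes
// (e.g. resource or variant) where any subtype has a tf.string element type.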
bool IsStringType(Type type) {
  if (type.isa<TF::StringType>()) return true;

  auto sub_type = type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
  if (!sub_type) return false;

  bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) {
    return type.getElementType().isa<TF::StringType>();
  });
  return has_string;
}

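// Returns true if any operand of `op` has a string element type.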
bool HasStringOperand(Operation& op) {
  for (auto operand : op.getOperands()) {
    auto operand_type = getElementTypeOrSelf(operand);
    if (IsStringType(operand_type)) return true;
  }
  return false;
}

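// Returns true if any result of `op` has a string element type.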
bool HasStringResult(Operation& op) {
  for (auto result : op.getResults()) {
    auto result_type = getElementTypeOrSelf(result);
    if (IsStringType(result_type)) return true;
  }
  return false;
}

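// Returns true if `op`'s name appears in `supported_ops`.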
bool MatchesPattern(Operation& op,
                    const llvm::DenseSet<OperationName>& supported_ops) {
  return supported_ops.contains(op.getName());
}

// Checks if the op is supported inside of a device cluster. Ops not
// in `tf_dialect` are considered supported.
bool IsSupportedOp(Operation& op,
                   const llvm::DenseSet<OperationName>& supported_ops,
                   const Dialect* tf_dialect) {
  if (op.getDialect() != tf_dialect)
    return true;
  // Assert has a legalization that later removes it, so for performance
  // reasons we never want to outside compile it.
  if (llvm::isa<TF::AssertOp>(op)) return true;
  return !HasStringOperand(op) && !HasStringResult(op) &&
         (MatchesPattern(op, supported_ops) ||
          mhlo::IsOpAllowedTf2XlaFallback(&op));
}

// Checks all regions of `op` for captured string operands.
bool HasCapturedStringOperand(Operation* op) {
  bool string_operand = false;
  for (auto& region : op->getRegions()) {
    mlir::visitUsedValuesDefinedAbove(
        region, region, [&](mlir::OpOperand* operand) {
          if (getElementTypeOrSelf(operand->get()).isa<TF::StringType>())
            string_operand = true;
        });
    if (string_operand) return string_operand;
  }
  return string_operand;
}

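// Returns true if `value` has a variant element type.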
bool IsVariant(Value value) {
  return getElementTypeOrSelf(value.getType()).isa<TF::VariantType>();
}

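// Returns true if some ancestor of `op` is already marked for outside
// compilation.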
bool HasOutsideCompiledAncestor(Operation* op) {
  Operation* parent = op->getParentOp();
  while (parent) {
    if (parent->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      return true;
    parent = parent->getParentOp();
  }
  return false;
}

// If any tf.variants are inputs/outputs to another outside compiled operation,
// marks them for outside compilation unless they are already marked with the
// outside compilation attribute.
void MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster) {
  std::queue<Operation*> outside_compiled_ops;
  tpu_cluster.walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      outside_compiled_ops.push(op);
  });

  while (!outside_compiled_ops.empty()) {
    Operation* host_op = outside_compiled_ops.front();
    outside_compiled_ops.pop();
    host_op->walk([&](Operation* op) {
      // Add any operations that provide variant inputs to the cluster.
      for (auto value : op->getOperands()) {
        Operation* input_defining_op = value.getDefiningOp();
        if (IsVariant(value) && input_defining_op &&
            !HasOutsideCompiledAncestor(input_defining_op) &&
            !input_defining_op->hasAttrOfType<StringAttr>(
                kXlaOutsideCompilationAttr)) {
          input_defining_op->setAttr(
              kXlaOutsideCompilationAttr,
              StringAttr::get(input_defining_op->getContext(), "auto"));
          outside_compiled_ops.push(input_defining_op);
        }
      }
      // Mark for outside compilation any operations that consume variant
      // outputs from an outside compiled operation.
      for (auto value : op->getResults()) {
        if (IsVariant(value)) {
          for (auto user : value.getUsers()) {
            if (!user->hasTrait<OpTrait::IsTerminator>() &&
                !HasOutsideCompiledAncestor(user) &&
                !user->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
              user->setAttr(kXlaOutsideCompilationAttr,
                            StringAttr::get(user->getContext(), "auto"));
              outside_compiled_ops.push(user);
            }
          }
        }
      }
    });
  }
}

// Marks uncompilable ops that are in `tf_dialect` for outside compilation.
LogicalResult MarkUncompilableOps(
    const Dialect* tf_dialect, Block* block,
    llvm::DenseSet<OperationName>& supported_ops) {
  // Automatically marked ops for outside compilation have an
  // `_xla_outside_compilation` attribute value of "auto" plus an increasing
  // counter (e.g. "auto0"), while manually marked ops only have an increasing
  // counter for the attribute value (e.g. "0"). Therefore there is no
  // collision in the `_xla_outside_compilation` attribute between
  // automatically and manually marked ops.
  int outside_compiled_cluster_counter = 0;
  block->walk([&](Operation* op) {
    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      VLOG(3) << "Cloud TPU: Op " << op->getName().getStringRef().str()
              << " isn't compilable, adding outside_compilation attr. "
                 "This op will automatically be placed on CPU.";
      op->setAttr(kXlaOutsideCompilationAttr,
                  StringAttr::get(
                      op->getContext(),
                      llvm::formatv("auto{0}", outside_compiled_cluster_counter)
                          .str()));
      outside_compiled_cluster_counter++;
    }
  });
  if (outside_compiled_cluster_counter > 0) {
    auto_outside_compilation_gauge->GetCell()->Set(true);
  }
  return success();
}

// Checks for uncompilable ops that are in `tf_dialect` and are not already
// marked for outside compilation.
bool ContainsUncompilableOps(const Dialect* tf_dialect, Block* block,
                             llvm::DenseSet<OperationName>& supported_ops) {
  int uncompilable_op_count = 0;
  // Check if op or any parent is already marked for outside compilation.
  block->walk([&](Operation* op) {
    Operation* iter_op = op;
    while (iter_op && !llvm::isa<tf_device::ClusterOp>(iter_op)) {
      if (iter_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
        return;
      }
      iter_op = iter_op->getParentOp();
    }

    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      op->emitOpError() << "isn't compilable for TPU device. enable "
                           "soft_device_placement option to run on CPU";
      ++uncompilable_op_count;
    }
  });
  return uncompilable_op_count > 0;
}

// Unmarks outside compilation for any op that has parents already
// marked for outside compilation since the child will be extracted
// anyway.
void UnmarkChildren(Block* block) {
  block->walk([&](Operation* op) {
    if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) return;
    Operation* iter_op = op;
    bool remove_attr = false;
    while (auto* parent_op = iter_op->getParentOp()) {
      if (parent_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
        remove_attr = true;
        break;
      }
      iter_op = parent_op;
    }
    if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr);
  });
}

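// Entry point of the pass: collects the set of TF ops that later passes can
// lower to HLO, then walks each `tf_device.cluster`. With soft placement
// enabled, unsupported ops are marked for outside compilation; otherwise an
// error is emitted for any unsupported op.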
void MarkOpsForOutsideCompilation::runOnOperation() {
  auto module = getOperation();
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }
  RewritePatternSet patterns(&getContext());
  mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
  TF::PopulateTFLoweringBeforeHLOPatterns(module.getContext(), &patterns);
  TF::PopulateLoweringQuantizedPatterns(module.getContext(), &patterns);
  AddCanonicalizationPatterns(module.getContext(), &patterns);

  // `supported_ops` contains the names of all ops that can potentially be
  // lowered into HLO on the device. An op in this set is not guaranteed to be
  // lowered by a future pass, but an op that is not in this set can't be
  // lowered in any subsequent pass.
  llvm::DenseSet<OperationName> supported_ops;
  PatternApplicator(std::move(patterns))
      .walkAllPatterns([&](const Pattern& pattern) {
        Optional<OperationName> root_kind = pattern.getRootKind();
        if (root_kind.has_value()) supported_ops.insert(root_kind.getValue());
      });
  AddSupportedFunctionalOps(module.getContext(), &supported_ops);
  AddSupportedOpsUsingFolding(module.getContext(), &supported_ops);
  AddSupportedOpsUsingDynamicPadder(module.getContext(), &supported_ops);
  AddRewrittenEmbeddingOps(module.getContext(), &supported_ops);
  AddRewrittenCompositeOps(module.getContext(), &supported_ops);

  auto result = module.walk([&](tf_device::ClusterOp cluster) {
    // Only if the `allow_soft_placement` attribute is true should we mark ops
    // for outside compilation.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (soft_placement_attr && soft_placement_attr.getValue()) {
      if (failed(MarkUncompilableOps(tf_dialect, &cluster.GetBody(),
                                     supported_ops)))
        return WalkResult::interrupt();
    } else {
      if (ContainsUncompilableOps(tf_dialect, &cluster.GetBody(),
                                  supported_ops))
        return WalkResult::interrupt();
    }
    MarkVariantInputsOutputs(cluster);

    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return signalPassFailure();

  module.walk([&](tf_device::ClusterOp cluster) {
    // Only if the `allow_soft_placement` attribute is true should we unmark
    // ops for outside compilation.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (!(soft_placement_attr && soft_placement_attr.getValue())) {
      return;
    }
    UnmarkChildren(&cluster.GetBody());
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass() {
  return std::make_unique<MarkOpsForOutsideCompilation>();
}

}  // namespace TFDevice
}  // namespace mlir