1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/transforms/cf_sink/pass.h"
17
18 #include <functional>
19 #include <memory>
20
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/Support/Debug.h"
23 #include "mlir/IR/Dominance.h" // from @llvm-project
24 #include "mlir/IR/OpDefinition.h" // from @llvm-project
25 #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
26 #include "mlir/Support/LogicalResult.h" // from @llvm-project
27 #include "mlir/Transforms/ControlFlowSinkUtils.h" // from @llvm-project
28 #include "tensorflow/core/ir/dialect.h"
29 #include "tensorflow/core/ir/interfaces.h"
30 #include "tensorflow/core/ir/ops.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/transforms/pass_detail.h"
33
34 namespace mlir {
35 namespace tfg {
36 namespace {
37
38 class ControlFlowSinkPass : public ControlFlowSinkBase<ControlFlowSinkPass> {
39 public:
40 // Initialize the pass by getting a cached identifier to the name attribute.
initialize(MLIRContext * context)41 LogicalResult initialize(MLIRContext *context) override {
42 name_id_ =
43 context->getOrLoadDialect<TFGraphDialect>()->getNameAttrIdentifier();
44 return success();
45 }
46
47 // Move the operation to the start of the entry block. Rename it if necessary.
48 void moveAndRename(Operation *op, Region *region);
49
50 void runOnOperation() override;
51
52 private:
53 // Cached name identifier.
54 StringAttr name_id_;
55 };
56 } // namespace
57
IsStateless(Operation * op)58 static bool IsStateless(Operation *op) {
59 if (auto registry = dyn_cast<TensorFlowRegistryInterface>(op))
60 return !registry.isStateful();
61 return false;
62 }
63
64 // Don't sink TPU-specific ops and ops with regions.
IsExcluded(Operation * op)65 static bool IsExcluded(Operation *op) {
66 // TODO(b/228618345) Ops with `i32` operands cannot be safely sunk due to a
67 // potential placement bug on GPU.
68
69 // Don't sink ops with regions as it can create nested regions so deep that
70 // the verifier is stack overflowed.
71 if (op->getNumRegions()) {
72 return true;
73 }
74
75 // TPU ops cannot be moved, even though they are marked as stateless.
76 // TODO(jeffniu): TPU ops should be marked in some other way.
77 StringRef op_name = op->getName().stripDialect();
78 return op_name == "TPUReplicateMetadata" || op_name == "TPUReplicatedInput" ||
79 op_name == "TPUReplicatedOutput" ||
80 op_name == "TPUCompilationResult" || op_name == "_TPUReplicate";
81 }
82
moveAndRename(Operation * op,Region * region)83 void ControlFlowSinkPass::moveAndRename(Operation *op, Region *region) {
84 op->moveBefore(®ion->front(), region->front().begin());
85 auto name = op->getAttrOfType<StringAttr>(name_id_);
86 auto parent_name = region->getParentOp()->getAttrOfType<StringAttr>(name_id_);
87 if (!name || !parent_name) return;
88 op->setAttr(name_id_, StringAttr::get(op->getContext(),
89 name.getValue() + "_tfg_cf_sunk_" +
90 parent_name.getValue()));
91 }
92
runOnOperation()93 void ControlFlowSinkPass::runOnOperation() {
94 auto &domInfo = getAnalysis<DominanceInfo>();
95 getOperation()->walk([&](RegionBranchOpInterface branch) {
96 SmallVector<Region *> regions;
97 getSinglyExecutedRegionsToSink(branch, regions);
98 num_sunk += controlFlowSink(
99 regions, domInfo,
100 /*shouldMoveIntoRegion=*/
101 [&](Operation *op, Region *) {
102 return IsStateless(op) && !IsExcluded(op);
103 },
104 /*moveIntoRegion=*/
105 [&](Operation *op, Region *region) { moveAndRename(op, region); });
106 });
107 VLOG(1) << "tfg-cf-sink num-sunk: " << num_sunk;
108 }
109
CreateControlFlowSinkPass()110 std::unique_ptr<Pass> CreateControlFlowSinkPass() {
111 return std::make_unique<ControlFlowSinkPass>();
112 }
113
114 } // namespace tfg
115 } // namespace mlir
116