/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
#include "tfrt/compiler/stream_analysis.h"  // from @tf_runtime
21
22 namespace tensorflow {
23 namespace tfrt_compiler {
24 namespace {
25
26 // This pass inserts copy kernels for fallback tensors when they are passed to
27 // multiple threads, to avoid atomic contention on their refcounts.
28 class InsertFallbackTensorCopy
29 : public mlir::PassWrapper<InsertFallbackTensorCopy,
30 mlir::OperationPass<mlir::func::FuncOp>> {
getDependentDialects(mlir::DialectRegistry & registry) const31 void getDependentDialects(mlir::DialectRegistry& registry) const override {
32 registry.insert<tfrt::fallback_async::FallbackAsyncDialect>();
33 }
34
getArgument() const35 llvm::StringRef getArgument() const final {
36 return "tfrt-insert-fallback-tensor-copy";
37 }
38
getDescription() const39 llvm::StringRef getDescription() const final {
40 return "Inserts copy kernels for fallback tensors when they are passed to "
41 "multiple threads, to avoid atomic contention on refcounts.";
42 }
43
44 public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertFallbackTensorCopy)45 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertFallbackTensorCopy)
46
47 void runOnOperation() override {
48 mlir::func::FuncOp func_op = getOperation();
49
50 // Use stream analysis to know whether a value is passed to different
51 // threads.
52 tfrt::compiler::StreamAnalysis stream_analysis(func_op);
53
54 auto builder = mlir::OpBuilder::atBlockBegin(&func_op.front());
55
56 // Process function arguments first.
57 for (auto arg : func_op.getArguments()) {
58 if (!arg.getType().isa<tfrt::fallback::TFTensorType>()) continue;
59 InsertFallbackTensorCopyForValue(arg, func_op->getLoc(), builder,
60 stream_analysis);
61 }
62
63 // Then process each operations in the block.
64 for (mlir::Operation& op : llvm::make_early_inc_range(func_op.front())) {
65 if (llvm::isa<tfrt::fallback_async::ExecuteOp,
66 tfrt::fallback_async::ExecuteOpSeq>(&op)) {
67 InsertFallbackTensorCopyForFallbackOp(&op, builder, stream_analysis);
68 }
69 }
70 }
71
72 private:
InsertFallbackTensorCopyForFallbackOp(mlir::Operation * op,mlir::OpBuilder & builder,const tfrt::compiler::StreamAnalysis & stream_analysis)73 void InsertFallbackTensorCopyForFallbackOp(
74 mlir::Operation* op, mlir::OpBuilder& builder,
75 const tfrt::compiler::StreamAnalysis& stream_analysis) {
76 builder.setInsertionPointAfter(op);
77
78 // Process each result value.
79 for (auto result : op->getResults()) {
80 if (!result.getType().isa<tfrt::fallback::TFTensorType>()) continue;
81 InsertFallbackTensorCopyForValue(result, op->getLoc(), builder,
82 stream_analysis);
83 }
84 }
85
86 // Insert copy kernels to copy the result, and allocate new atomic refcount
87 // if the value is going to be used by different streams/threads, in order to
88 // avoid contention on the atomic counter.
InsertFallbackTensorCopyForValue(mlir::Value value,mlir::Location loc,mlir::OpBuilder & builder,const tfrt::compiler::StreamAnalysis & stream_analysis)89 void InsertFallbackTensorCopyForValue(
90 mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder,
91 const tfrt::compiler::StreamAnalysis& stream_analysis) {
92 llvm::DenseMap<int, llvm::SmallVector<mlir::OpOperand*, 4>> stream_map;
93
94 // Find out streams that use this value and the corresponding uses.
95 for (mlir::OpOperand& use : value.getUses()) {
96 // Skip return op as there should not be atomic contention on the return
97 // op.
98 if (llvm::isa<tfrt::compiler::ReturnOp>(use.getOwner())) continue;
99
100 int stream_id = stream_analysis.GetStream(use.getOwner()).id();
101 stream_map[stream_id].push_back(&use);
102 }
103
104 // Organize these uses into groups. If a stream has many uses of this value,
105 // put these uses into one stream. Otherwise, streams with small number
106 // of uses are grouped with each other to form groups with enough uses.
107 constexpr int kCopyGroupThreshold = 16;
108 llvm::SmallVector<llvm::SmallVector<mlir::OpOperand*, 4>, 4> small_copies;
109 llvm::SmallVector<llvm::SmallVector<mlir::OpOperand*, 4>, 4> copies;
110 for (const auto& iter : stream_map) {
111 if (iter.second.size() >= kCopyGroupThreshold) {
112 copies.push_back(iter.second);
113 } else {
114 if (small_copies.empty() ||
115 small_copies.back().size() >= kCopyGroupThreshold) {
116 small_copies.push_back(iter.second);
117 } else {
118 small_copies.back().append(iter.second.begin(), iter.second.end());
119 }
120 }
121 }
122
123 if (!small_copies.empty())
124 copies.append(small_copies.begin(), small_copies.end());
125
126 // If it is only used by one group, then we don't need to copy.
127 if (copies.size() <= 1) return;
128
129 // Remove one group from the candidates, as we can just use the original
130 // value for this group.
131 copies.pop_back();
132
133 // For each stream, we will create one new value that replaces the uses in
134 // that stream.
135
136 assert(value.getType().isa<tfrt::fallback::TFTensorType>());
137
138 // The number of results is the number candidate streams.
139 llvm::SmallVector<mlir::Type, 4> result_types(copies.size(),
140 value.getType());
141 assert(!result_types.empty());
142
143 // Create the tfrt_fallback_async.copy_if_small kernel.
144 auto copy_op = builder.create<tfrt::fallback_async::CopyIfSmallOp>(
145 loc, result_types, value);
146
147 // Finally, replaces all uses with the new value.
148 for (int i = 0; i < copies.size(); ++i) {
149 const auto& uses = copies[i];
150 auto new_value = copy_op.getResult(i);
151 for (auto* use : uses) {
152 use->set(new_value);
153 }
154 }
155 }
156 };
157
158 } // namespace
159
160 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateInsertFallbackTensorCopyPass()161 CreateInsertFallbackTensorCopyPass() {
162 return std::make_unique<InsertFallbackTensorCopy>();
163 }
164
165 static mlir::PassRegistration<InsertFallbackTensorCopy> register_pass(
166 CreateInsertFallbackTensorCopyPass);
167
168 } // namespace tfrt_compiler
169 } // namespace tensorflow
170