/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
#include "tfrt/compiler/stream_analysis.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_compiler {
namespace {

// This pass inserts copy kernels for fallback tensors when they are passed to
// multiple threads, to avoid atomic contention on their refcounts.
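//
// For example (an illustrative sketch; operands and attributes are elided and
// the textual IR syntax is only approximate), a fallback tensor used from two
// different streams
//
//   %t = tfrt_fallback_async.executeop ...
//   // ... uses of %t on stream 0 and on stream 1 ...
//
// is rewritten so that one stream keeps the original value while the other
// uses a fresh copy that carries its own refcount:
//
//   %t = tfrt_fallback_async.executeop ...
//   %t_copy = tfrt_fallback_async.copy_if_small %t ...
//   // ... uses of %t on stream 0, uses of %t_copy on stream 1 ...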
class InsertFallbackTensorCopy
    : public mlir::PassWrapper<InsertFallbackTensorCopy,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<tfrt::fallback_async::FallbackAsyncDialect>();
  }

  llvm::StringRef getArgument() const final {
    return "tfrt-insert-fallback-tensor-copy";
  }

  llvm::StringRef getDescription() const final {
    return "Inserts copy kernels for fallback tensors when they are passed to "
           "multiple threads, to avoid atomic contention on refcounts.";
  }

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertFallbackTensorCopy)

  void runOnOperation() override {
    mlir::func::FuncOp func_op = getOperation();

    // Use stream analysis to determine whether a value is passed to different
    // threads.
    tfrt::compiler::StreamAnalysis stream_analysis(func_op);

    auto builder = mlir::OpBuilder::atBlockBegin(&func_op.front());

    // Process function arguments first.
    for (auto arg : func_op.getArguments()) {
      if (!arg.getType().isa<tfrt::fallback::TFTensorType>()) continue;
      InsertFallbackTensorCopyForValue(arg, func_op->getLoc(), builder,
                                       stream_analysis);
    }

    // Then process each operation in the block.
    for (mlir::Operation& op : llvm::make_early_inc_range(func_op.front())) {
      if (llvm::isa<tfrt::fallback_async::ExecuteOp,
                    tfrt::fallback_async::ExecuteOpSeq>(&op)) {
        InsertFallbackTensorCopyForFallbackOp(&op, builder, stream_analysis);
      }
    }
  }

 private:
  void InsertFallbackTensorCopyForFallbackOp(
      mlir::Operation* op, mlir::OpBuilder& builder,
      const tfrt::compiler::StreamAnalysis& stream_analysis) {
    builder.setInsertionPointAfter(op);

    // Process each result value.
    for (auto result : op->getResults()) {
      if (!result.getType().isa<tfrt::fallback::TFTensorType>()) continue;
      InsertFallbackTensorCopyForValue(result, op->getLoc(), builder,
                                       stream_analysis);
    }
  }

  // Insert copy kernels to copy the result, and allocate a new atomic refcount
  // if the value is going to be used by different streams/threads, in order to
  // avoid contention on the atomic counter.
  void InsertFallbackTensorCopyForValue(
      mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder,
      const tfrt::compiler::StreamAnalysis& stream_analysis) {
    llvm::DenseMap<int, llvm::SmallVector<mlir::OpOperand*, 4>> stream_map;

    // Find the streams that use this value, and the corresponding uses in
    // each stream.
    for (mlir::OpOperand& use : value.getUses()) {
      // Skip the return op, as there should not be atomic contention on it.
      if (llvm::isa<tfrt::compiler::ReturnOp>(use.getOwner())) continue;

      int stream_id = stream_analysis.GetStream(use.getOwner()).id();
      stream_map[stream_id].push_back(&use);
    }

    // Organize these uses into groups. If a stream has many uses of this
    // value, those uses form their own group. Otherwise, streams with a small
    // number of uses are merged with each other to form groups with enough
    // uses.
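    // For example, with a threshold of 16, streams with 20, 7, 5, and 4 uses
    // of the value produce one group of 20 uses and one merged group of 16
    // uses (7 + 5 + 4).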
    constexpr int kCopyGroupThreshold = 16;
    llvm::SmallVector<llvm::SmallVector<mlir::OpOperand*, 4>, 4> small_copies;
    llvm::SmallVector<llvm::SmallVector<mlir::OpOperand*, 4>, 4> copies;
    for (const auto& iter : stream_map) {
      if (iter.second.size() >= kCopyGroupThreshold) {
        copies.push_back(iter.second);
      } else {
        if (small_copies.empty() ||
            small_copies.back().size() >= kCopyGroupThreshold) {
          small_copies.push_back(iter.second);
        } else {
          small_copies.back().append(iter.second.begin(), iter.second.end());
        }
      }
    }

    if (!small_copies.empty())
      copies.append(small_copies.begin(), small_copies.end());

    // If the value is only used by one group, then we don't need to copy.
    if (copies.size() <= 1) return;

    // Remove one group from the candidates, as that group can keep using the
    // original value.
    copies.pop_back();

    // For each remaining group, we will create one new value that replaces
    // the uses in that group.

    assert(value.getType().isa<tfrt::fallback::TFTensorType>());

    // The number of results is the number of candidate groups.
    llvm::SmallVector<mlir::Type, 4> result_types(copies.size(),
                                                  value.getType());
    assert(!result_types.empty());

    // Create the tfrt_fallback_async.copy_if_small kernel.
    auto copy_op = builder.create<tfrt::fallback_async::CopyIfSmallOp>(
        loc, result_types, value);

    // Finally, replace the uses in each group with the corresponding new
    // value.
    for (int i = 0; i < copies.size(); ++i) {
      const auto& uses = copies[i];
      auto new_value = copy_op.getResult(i);
      for (auto* use : uses) {
        use->set(new_value);
      }
    }
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateInsertFallbackTensorCopyPass() {
  return std::make_unique<InsertFallbackTensorCopy>();
}

static mlir::PassRegistration<InsertFallbackTensorCopy> register_pass(
    CreateInsertFallbackTensorCopyPass);
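
// Example of adding this pass to a pipeline (an illustrative sketch; it
// assumes a caller that already owns an mlir::PassManager and a module, which
// is outside the scope of this file):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       CreateInsertFallbackTensorCopyPass());
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }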

}  // namespace tfrt_compiler
}  // namespace tensorflow