1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <memory>
18 
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
26 #include "mlir/Support/LLVM.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30 
31 namespace mlir {
32 namespace TFTPU {
33 
34 namespace {
35 
36 // A pass that moves `tf.AssignVariableOp` into a `tf_device.parallel_execute`
37 // region if the `tf.AssignVariableOp` is the only consumer of a
38 // `tf_device.parallel_execute` result. This will allow
39 // TPUMergeVariablesWithExecute to merge resource writes without special
40 // handling for `tf_device.parallel_execute`.
41 struct TPUParallelExecuteSinkResourceWrite
42     : public TF::TPUParallelExecuteSinkResourceWritePassBase<
43           TPUParallelExecuteSinkResourceWrite> {
44   void runOnOperation() override;
45 };
46 
47 // Finds an AssignVariableOp that can be moved into the parallel_execute region.
48 // These AssignVariableOps must be the only consumer of the respective
49 // parallel_execute result, and the resource handle producer must be from an op
50 // before or above the parallel_execute.
GetSingleUseResourceWrite(tf_device::ParallelExecuteOp parallel_execute,Value result)51 TF::AssignVariableOp GetSingleUseResourceWrite(
52     tf_device::ParallelExecuteOp parallel_execute, Value result) {
53   if (!result.hasOneUse()) return nullptr;
54 
55   OpOperand& use = *result.getUses().begin();
56   auto assign_var = dyn_cast<TF::AssignVariableOp>(use.getOwner());
57   if (!assign_var) return nullptr;
58 
59   if (use.get() != assign_var.value()) return nullptr;
60 
61   auto* resource_handle_op = assign_var.resource().getDefiningOp();
62   if (resource_handle_op == parallel_execute) return nullptr;
63 
64   if (resource_handle_op &&
65       resource_handle_op->getBlock() ==
66           parallel_execute.getOperation()->getBlock() &&
67       parallel_execute.getOperation()->isBeforeInBlock(resource_handle_op))
68     return nullptr;
69 
70   return assign_var;
71 }
72 
73 // Finds AssignVariableOps that can be moved into a parallel_execute region and
74 // moves them. Leftover parallel_execute results that were used by the
75 // such AssignVariableOp are also pruned.
SinkResourceWritesIntoParallelExecute(tf_device::ParallelExecuteOp parallel_execute)76 void SinkResourceWritesIntoParallelExecute(
77     tf_device::ParallelExecuteOp parallel_execute) {
78   bool rewrite = false;
79   const int num_regions = parallel_execute.getNumRegions();
80   llvm::SmallVector<Value, 4> results_to_remap;
81 
82   // Go through each region and find AssignVariableOps that can be moved into
83   // the parallel_execute region. Result indices by region index are collected,
84   // so they can be removed afterwards.
85   llvm::SmallVector<llvm::SmallVector<int, 4>, 4> results_to_remove_by_region;
86   results_to_remove_by_region.resize(num_regions);
87   for (int i = 0; i < num_regions; ++i) {
88     Block& block = parallel_execute.GetRegionBlockWithIndex(i);
89     auto results = parallel_execute.GetRegionOutputs(i);
90     auto& results_to_remove = results_to_remove_by_region[i];
91     results_to_remove.reserve(results.size());
92     Operation* terminator = block.getTerminator();
93     for (auto result : llvm::enumerate(results)) {
94       TF::AssignVariableOp assign_var =
95           GetSingleUseResourceWrite(parallel_execute, result.value());
96       if (!assign_var) {
97         results_to_remap.push_back(result.value());
98         continue;
99       }
100 
101       // Move AssignVariableOp and update the value to be written to the
102       // resource variable to be the non forwarded value from within the
103       // parallel_execute region.
104       assign_var.getOperation()->moveBefore(terminator);
105       assign_var.valueMutable().assign(terminator->getOperand(result.index()));
106       results_to_remove.push_back(result.index());
107     }
108 
109     rewrite |= !results_to_remove.empty();
110   }
111 
112   if (!rewrite) return;
113 
114   // Remove leftover unused results (terminator operands) from moving
115   // AssignVariabeOps into the parallel_execute region.
116   for (auto results_to_remove : llvm::enumerate(results_to_remove_by_region)) {
117     Block& block =
118         parallel_execute.GetRegionBlockWithIndex(results_to_remove.index());
119     Operation* terminator = block.getTerminator();
120     for (int index_to_remove : llvm::reverse(results_to_remove.value()))
121       terminator->eraseOperand(index_to_remove);
122   }
123 
124   // Replace old parallel_execute with new parallel_execute by moving the
125   // regions to a new parallel_execute and remapping the results.
126   llvm::SmallVector<Type, 4> new_result_types;
127   new_result_types.reserve(results_to_remap.size());
128   for (Value old_result : results_to_remap)
129     new_result_types.push_back(old_result.getType());
130 
131   OpBuilder builder(parallel_execute);
132   auto new_parallel_execute = builder.create<tf_device::ParallelExecuteOp>(
133       parallel_execute.getLoc(), num_regions, new_result_types);
134 
135   for (auto region : llvm::zip(new_parallel_execute.getRegions(),
136                                parallel_execute.getRegions()))
137     std::get<0>(region)->takeBody(*std::get<1>(region));
138 
139   for (auto result :
140        llvm::zip(results_to_remap, new_parallel_execute.getResults()))
141     std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
142 
143   parallel_execute.erase();
144 }
145 
runOnOperation()146 void TPUParallelExecuteSinkResourceWrite::runOnOperation() {
147   llvm::SmallVector<tf_device::ParallelExecuteOp, 4> parallel_executes;
148   getOperation().walk([&](tf_device::ParallelExecuteOp parallel_execute) {
149     parallel_executes.push_back(parallel_execute);
150   });
151 
152   for (tf_device::ParallelExecuteOp parallel_execute : parallel_executes)
153     SinkResourceWritesIntoParallelExecute(parallel_execute);
154 }
155 
156 }  // anonymous namespace
157 
158 std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUParallelExecuteSinkResourceWritePass()159 CreateTPUParallelExecuteSinkResourceWritePass() {
160   return std::make_unique<TPUParallelExecuteSinkResourceWrite>();
161 }
162 
163 }  // namespace TFTPU
164 }  // namespace mlir
165