/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {

// A pass that finds TPU clusters with write-only resource accesses and adds an
// associated resource read, so the written resource can later be fused into
// TPUExecute.
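//
// For example (an illustrative sketch; types, attributes, and surrounding ops
// are elided, and the exact IR may differ):
//
//   %0 = "tf_device.cluster_func"(%arg0) {func = @write_func} : ...
//   "tf.AssignVariableOp"(%resource, %0) : ...
//
// is rewritten to:
//
//   %read = "tf.ReadVariableOp"(%resource) : ...
//   %0 = "tf_device.cluster_func"(%arg0, %read) {func = @write_func} : ...
//   "tf.AssignVariableOp"(%resource, %0) : ...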
namespace {
struct TPUResourceReadForWritePass
    : public TF::TPUResourceReadForWritePassBase<TPUResourceReadForWritePass> {
  void runOnOperation() override;
};

// Helper struct holding a resource value and its associated type.
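// A default-constructed instance (null `resource`) signals that no resource
// write was found.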
struct ResourceValueAndSubtype {
  Value resource;
  Type subtype;
};

// Finds the resource handle and subtype for `result` if `result` is written
// to a resource variable.
ResourceValueAndSubtype GetResourceWriteResult(
    tf_device::ClusterFuncOp cluster_func, Value result) {
  ResourceValueAndSubtype resource;
  if (!result.hasOneUse()) return resource;
  Operation* result_user = *result.getUsers().begin();
  auto assign_var = dyn_cast<TF::AssignVariableOp>(result_user);
  if (!assign_var) return resource;

  auto handle = assign_var.resource();
  // Skip result if cluster writes to the same variable via multiple results.
  for (Operation* handle_user : handle.getUsers()) {
    if (handle_user == assign_var) continue;
    auto assign_var_user = dyn_cast<TF::AssignVariableOp>(handle_user);
    if (!assign_var_user) continue;
    if (assign_var_user.value().getDefiningOp() == cluster_func)
      return resource;
  }

  resource.resource = assign_var.resource();
  resource.subtype = assign_var.value().getType();
  return resource;
}

// Checks if resource is read by TPU cluster.
bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
                                Value resource) {
  for (Operation* resource_user : resource.getUsers())
    if (auto read = dyn_cast<TF::ReadVariableOp>(resource_user))
      for (Operation* read_user : read.value().getUsers())
        if (read_user == cluster_func) return true;

  return false;
}

void TPUResourceReadForWritePass::runOnOperation() {
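  // Collect cluster_func ops up front so the rewrites below do not mutate the
  // IR while it is being walked.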
  SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
  getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
    cluster_funcs.push_back(cluster_func);
  });

  OpBuilder builder(&getContext());
  // Add a resource read for each resource that the TPU cluster writes to but
  // does not read from.
  for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) {
    builder.setInsertionPoint(cluster_func);

    SmallVector<Value, 4> read_operands;
    for (Value result : cluster_func.getResults()) {
      // TODO(lyandy): Update pass to use resource alias analysis.
      auto resource_and_type = GetResourceWriteResult(cluster_func, result);
      if (!resource_and_type.resource) continue;
      if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource))
        continue;
      auto new_read = builder.create<TF::ReadVariableOp>(
          resource_and_type.resource.getLoc(), resource_and_type.subtype,
          resource_and_type.resource);
      read_operands.push_back(new_read.value());
    }

    if (read_operands.empty()) continue;

    // Update caller and function types with new read operands.
    auto operands = llvm::to_vector<4>(cluster_func.getOperands());
    operands.append(read_operands.begin(), read_operands.end());

    auto loc = cluster_func.getLoc();
    auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
        loc, cluster_func.getResultTypes(), operands, cluster_func->getAttrs());
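    // Point all users at the new cluster_func, then update the callee
    // signature to match the extra read operands before erasing the
    // now-unused original op.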
    cluster_func.replaceAllUsesWith(new_cluster_func);
    func::FuncOp func = cluster_func.getFunc();
    Block& block = func.front();
    for (Value read_operand : read_operands)
      block.addArgument(read_operand.getType(), loc);

    func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
                                   func.getCallableResults()));
    cluster_func.erase();
  }
}

}  // namespace

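// Creates the pass. A minimal usage sketch, assuming a standard MLIR pass
// pipeline (the PassManager setup shown here is illustrative, not part of
// this file):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::TFTPU::CreateTPUResourceReadForWritePass());
//   if (failed(pm.run(module))) { /* handle failure */ }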
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
  return std::make_unique<TPUResourceReadForWritePass>();
}

}  // namespace TFTPU
}  // namespace mlir