1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
23 #include "mlir/IR/Diagnostics.h" // from @llvm-project
24 #include "mlir/IR/Operation.h" // from @llvm-project
25 #include "mlir/IR/Types.h" // from @llvm-project
26 #include "mlir/Support/LogicalResult.h" // from @llvm-project
27 #include "mlir/Transforms/Passes.h" // from @llvm-project
28 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
30 #include "tensorflow/dtensor/cc/constants.h"
31 #include "tensorflow/dtensor/cc/tensor_layout.h"
32 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
33 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
34 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
35 #include "tensorflow/dtensor/mlir/layout_parsing.h"
36
37 namespace tensorflow {
38 namespace dtensor {
39 namespace {
40
// Diagnostic text reported on a tf_device.cluster op that is about to be
// merged but carries no mesh.
constexpr char kMissingMeshAttributeErrorMessage[] =
    "failed to merge mesh cluster as cluster does not have mesh "
    "attribute. This is likely due to problem in mesh "
    "propagation.";
44
45 // Determines whether two adjoining clusters should be merged.
ShouldMergeClusters(mlir::tf_device::ClusterOp cluster_a,mlir::tf_device::ClusterOp cluster_b,bool * should_merge)46 mlir::LogicalResult ShouldMergeClusters(mlir::tf_device::ClusterOp cluster_a,
47 mlir::tf_device::ClusterOp cluster_b,
48 bool* should_merge) {
49 if (cluster_a->getParentRegion() != cluster_b->getParentRegion()) {
50 *should_merge = false;
51 return mlir::success();
52 }
53
54 auto mesh_a_or_status = ExtractDeviceMeshFromOp(cluster_a.getOperation());
55 if (!mesh_a_or_status.ok())
56 return cluster_a.emitOpError(mesh_a_or_status.status().error_message());
57
58 auto mesh_b_or_status = ExtractDeviceMeshFromOp(cluster_b.getOperation());
59 if (!mesh_b_or_status.ok())
60 return cluster_b.emitOpError(mesh_b_or_status.status().error_message());
61
62 auto mesh_a = mesh_a_or_status.ValueOrDie();
63 auto mesh_b = mesh_b_or_status.ValueOrDie();
64 if (!mesh_a || !mesh_b) {
65 return !mesh_a ? cluster_a.emitOpError(kMissingMeshAttributeErrorMessage)
66 : cluster_b.emitOpError(kMissingMeshAttributeErrorMessage);
67 }
68
69 *should_merge = mesh_a == mesh_b;
70 return mlir::success();
71 }
72
73 // Moves all ops (except tf_device.return op) inside `src_cluster` to
74 // block inside `target_cluster`. Ops are moved before the `exit_op`
75 // inside the `target_cluster`.
MoveOpsInsideCluster(mlir::tf_device::ClusterOp src_cluster,mlir::tf_device::ClusterOp target_cluster,mlir::Operation * exit_op)76 void MoveOpsInsideCluster(mlir::tf_device::ClusterOp src_cluster,
77 mlir::tf_device::ClusterOp target_cluster,
78 mlir::Operation* exit_op) {
79 auto& cluster_body = src_cluster.GetBody().getOperations();
80 target_cluster.GetBody().getOperations().splice(
81 exit_op->getIterator(), cluster_body, cluster_body.begin(),
82 std::prev(cluster_body.end()));
83 }
84
85 // Returns a list of pair of mlir Values that represent <return values of ops
86 // inside the merged_cluster, output values of merged cluster>.
87 //
88 // If outputs of `current_cluster` is used as operands to ops in
89 // `merging_cluster`, then make sure to replace operands such that
90 // results values from the inner ops of `current_cluster` is used instead.
91 //
92 // For example,
93 // %0 = "tf_device.cluster"() ({
94 // %1 = "tf.A"() : () -> tensor<i32>
95 // "tf_device.return"(%1) : (tensor<i32>) -> ()
96 // }) { mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<i32>)
97 //
98 // %2 = "tf_device.cluster"() ({
99 // %3 = "tf.B"(%0) : (tenosr<i32>) -> tensor<f32>
100 // "tf_device.return"(%3) : (tensor<f32>) -> ()
101 // }) { mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<f32>)
102 //
103 // will be:
104 // %0 = "tf_device.cluster"() ({
105 // %1 = "tf.A"() : () -> tensor<i32>
106 //
107 // # NOTE: `tf.B` op now takes operand directly from
108 // # `tf.A` instead of `tf_dtensor.cluster op.
109 // %2 = "tf.B"(%1) : (tenosr<i32>) -> tensor<f32>
110 // "tf_device.return"(%1, %2) : (tensor<i32>, tensor<f32>)) -> ()
111 // }) {mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<i32>, tensor<f32>)
112 llvm::SmallVector<std::pair<mlir::Value, mlir::Value>, 8>
GetMergedMeshClusterResults(mlir::tf_device::ClusterOp current_cluster,mlir::tf_device::ClusterOp merging_cluster)113 GetMergedMeshClusterResults(mlir::tf_device::ClusterOp current_cluster,
114 mlir::tf_device::ClusterOp merging_cluster) {
115 llvm::SmallVector<std::pair<mlir::Value, mlir::Value>, 8>
116 merged_cluster_results;
117 merged_cluster_results.reserve(current_cluster.getNumResults() +
118 merging_cluster.getNumResults());
119
120 auto current_cluster_return_op = current_cluster.GetBody().getTerminator();
121 for (auto result : llvm::zip(current_cluster_return_op->getOpOperands(),
122 current_cluster.getResults())) {
123 mlir::Value inner_op_result = std::get<0>(result).get();
124 mlir::Value outer_op_result = std::get<1>(result);
125
126 // If the output value of `current_cluster` is only used by ops
127 // inside the `merged_cluster`, do not add the value as a return
128 // value for newly created tf_device.cluster op.
129 bool result_only_used_by_merging_cluster = true;
130 for (auto& use : llvm::make_early_inc_range(outer_op_result.getUses())) {
131 if (merging_cluster.GetBody().findAncestorOpInBlock(*use.getOwner())) {
132 use.set(inner_op_result);
133 } else {
134 result_only_used_by_merging_cluster = false;
135 }
136 }
137
138 if (!result_only_used_by_merging_cluster) {
139 merged_cluster_results.emplace_back(inner_op_result, outer_op_result);
140 }
141 }
142
143 auto merging_cluster_return_op = merging_cluster.GetBody().getTerminator();
144 for (auto result : llvm::zip(merging_cluster_return_op->getOpOperands(),
145 merging_cluster.getResults())) {
146 mlir::Value inner_op_result = std::get<0>(result).get();
147 mlir::Value outer_op_result = std::get<1>(result);
148
149 if (!outer_op_result.getUses().empty())
150 merged_cluster_results.emplace_back(inner_op_result, outer_op_result);
151 }
152
153 return merged_cluster_results;
154 }
155
156 // Updates the users of `merging_cluster` so that they use values
157 // from `merged_cluster` instead.
ReplaceOperandUsagesWithMergedClusterOutputs(const llvm::SmallVectorImpl<mlir::Value> & values_to_replace,mlir::tf_device::ClusterOp merged_cluster)158 void ReplaceOperandUsagesWithMergedClusterOutputs(
159 const llvm::SmallVectorImpl<mlir::Value>& values_to_replace,
160 mlir::tf_device::ClusterOp merged_cluster) {
161 for (auto result :
162 llvm::zip(values_to_replace, merged_cluster.getResults())) {
163 std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
164 }
165 }
166
167 // Creates a new tf_device.cluster op that merges
168 // `current_cluster` and `merging_cluster`.
CreateMergedMeshCluster(mlir::OpBuilder * builder,mlir::tf_device::ClusterOp current_cluster,mlir::tf_device::ClusterOp merging_cluster,mlir::tf_device::ClusterOp * merged_cluster)169 mlir::LogicalResult CreateMergedMeshCluster(
170 mlir::OpBuilder* builder, mlir::tf_device::ClusterOp current_cluster,
171 mlir::tf_device::ClusterOp merging_cluster,
172 mlir::tf_device::ClusterOp* merged_cluster) {
173 auto return_values =
174 GetMergedMeshClusterResults(current_cluster, merging_cluster);
175
176 llvm::SmallVector<mlir::Type, 8> merged_cluster_output_types;
177 llvm::SmallVector<mlir::Value, 8> merged_cluster_output_values;
178 llvm::SmallVector<mlir::Value, 8> output_values_to_replace;
179 merged_cluster_output_types.reserve(return_values.size());
180 merged_cluster_output_values.reserve(return_values.size());
181 output_values_to_replace.reserve(return_values.size());
182 for (auto cluster_return_value : return_values) {
183 auto inner_op_return_value = std::get<0>(cluster_return_value);
184 merged_cluster_output_types.emplace_back(inner_op_return_value.getType());
185 merged_cluster_output_values.emplace_back(inner_op_return_value);
186 output_values_to_replace.emplace_back(std::get<1>(cluster_return_value));
187 }
188
189 *merged_cluster = builder->create<mlir::tf_device::ClusterOp>(
190 current_cluster.getLoc(), merged_cluster_output_types);
191 auto mesh_attr = current_cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
192 if (!mesh_attr)
193 return current_cluster.emitOpError(kMissingMeshAttributeErrorMessage);
194
195 (*merged_cluster)->setAttr(kMeshAttr, mesh_attr);
196
197 // Create a terminator op that returns all return values from
198 // `current_cluster` and `merging_cluster`.
199 merged_cluster->body().push_back(new mlir::Block);
200 builder->setInsertionPointToEnd(&merged_cluster->GetBody());
201 builder->create<mlir::tf_device::ReturnOp>(merged_cluster->getLoc(),
202 merged_cluster_output_values);
203
204 // Make sure to replace usages of tf_device.cluster ops to be merged-away with
205 // newly created tf_device.cluster op.
206 ReplaceOperandUsagesWithMergedClusterOutputs(output_values_to_replace,
207 *merged_cluster);
208
209 return mlir::success();
210 }
211
212 // Merges `current_cluster` and `merging_cluster` and returns a new merged
213 // tf_device.cluster.
MergeClusters(mlir::OpBuilder * builder,mlir::tf_device::ClusterOp current_cluster,mlir::tf_device::ClusterOp merging_cluster,mlir::tf_device::ClusterOp * merged_cluster)214 mlir::LogicalResult MergeClusters(mlir::OpBuilder* builder,
215 mlir::tf_device::ClusterOp current_cluster,
216 mlir::tf_device::ClusterOp merging_cluster,
217 mlir::tf_device::ClusterOp* merged_cluster) {
218 builder->setInsertionPoint(current_cluster);
219
220 // Create new tf_device.cluster op that outputs results of both
221 // `current_cluster` and `merging_cluster`.
222 if (mlir::failed(CreateMergedMeshCluster(builder, current_cluster,
223 merging_cluster, merged_cluster)))
224 return mlir::failure();
225
226 // Move all ops to newly created merged cluster.
227 auto exit_op = merged_cluster->GetBody().getTerminator();
228 MoveOpsInsideCluster(current_cluster, *merged_cluster, exit_op);
229 MoveOpsInsideCluster(merging_cluster, *merged_cluster, exit_op);
230
231 // Remove mesh clusters as they are now merged to a new cluster.
232 current_cluster.erase();
233 merging_cluster.erase();
234 return mlir::success();
235 }
236
237 // Loops through tf_device.Cluster ops and merge clusters with same execution
238 // device set.
ClusterDeviceClusterOpsInBlock(mlir::OpBuilder * builder,mlir::Block * block)239 mlir::LogicalResult ClusterDeviceClusterOpsInBlock(mlir::OpBuilder* builder,
240 mlir::Block* block) {
241 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> block_ops;
242 block->walk([&](mlir::Operation* op) {
243 if (auto cluster = llvm::dyn_cast<mlir::tf_device::ClusterOp>(op))
244 block_ops.emplace_back(cluster);
245 });
246
247 llvm::Optional<mlir::tf_device::ClusterOp> current_cluster;
248 for (mlir::tf_device::ClusterOp cluster :
249 llvm::make_early_inc_range(block_ops)) {
250 if (!current_cluster.has_value()) {
251 current_cluster = cluster;
252 continue;
253 }
254 bool should_merge;
255 if (failed(ShouldMergeClusters(*current_cluster, cluster, &should_merge)))
256 return mlir::failure();
257
258 if (should_merge) {
259 mlir::tf_device::ClusterOp new_cluster;
260 if (mlir::failed(
261 MergeClusters(builder, *current_cluster, cluster, &new_cluster)))
262 return mlir::failure();
263
264 current_cluster.emplace(new_cluster);
265 } else {
266 current_cluster.emplace(cluster);
267 }
268 }
269 return mlir::success();
270 }
271
272 } // namespace
273
274 // MLIR pass that merges cluster ops with the same mesh attribute.
275 struct DTensorDeviceMeshClusterCoarsening
276 : public DTensorDeviceMeshClusterCoarseningBase<
277 DTensorDeviceMeshClusterCoarsening> {
runOnOperationtensorflow::dtensor::DTensorDeviceMeshClusterCoarsening278 void runOnOperation() override {
279 mlir::MLIRContext& context = getContext();
280 mlir::OpBuilder builder(&context);
281 for (mlir::Block& block : getOperation())
282 if (mlir::failed(ClusterDeviceClusterOpsInBlock(&builder, &block)))
283 return signalPassFailure();
284 }
285 };
286
287 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorDeviceMeshClusterCoarsening()288 CreateDTensorDeviceMeshClusterCoarsening() {
289 return std::make_unique<DTensorDeviceMeshClusterCoarsening>();
290 }
291
292 } // namespace dtensor
293 } // namespace tensorflow
294