xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/device_mesh_cluster_coarsening.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/IR/Types.h"  // from @llvm-project
26 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
27 #include "mlir/Transforms/Passes.h"  // from @llvm-project
28 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
30 #include "tensorflow/dtensor/cc/constants.h"
31 #include "tensorflow/dtensor/cc/tensor_layout.h"
32 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
33 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
34 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
35 #include "tensorflow/dtensor/mlir/layout_parsing.h"
36 
37 namespace tensorflow {
38 namespace dtensor {
39 namespace {
40 
// Diagnostic text emitted on a tf_device.cluster op for which no mesh could
// be found while deciding on, or performing, a cluster merge.
constexpr char kMissingMeshAttributeErrorMessage[] =
    "failed to merge mesh cluster as cluster does not have mesh attribute. "
    "This is likely due to problem in mesh propagation.";
44 
45 // Determines whether two adjoining clusters should be merged.
ShouldMergeClusters(mlir::tf_device::ClusterOp cluster_a,mlir::tf_device::ClusterOp cluster_b,bool * should_merge)46 mlir::LogicalResult ShouldMergeClusters(mlir::tf_device::ClusterOp cluster_a,
47                                         mlir::tf_device::ClusterOp cluster_b,
48                                         bool* should_merge) {
49   if (cluster_a->getParentRegion() != cluster_b->getParentRegion()) {
50     *should_merge = false;
51     return mlir::success();
52   }
53 
54   auto mesh_a_or_status = ExtractDeviceMeshFromOp(cluster_a.getOperation());
55   if (!mesh_a_or_status.ok())
56     return cluster_a.emitOpError(mesh_a_or_status.status().error_message());
57 
58   auto mesh_b_or_status = ExtractDeviceMeshFromOp(cluster_b.getOperation());
59   if (!mesh_b_or_status.ok())
60     return cluster_b.emitOpError(mesh_b_or_status.status().error_message());
61 
62   auto mesh_a = mesh_a_or_status.ValueOrDie();
63   auto mesh_b = mesh_b_or_status.ValueOrDie();
64   if (!mesh_a || !mesh_b) {
65     return !mesh_a ? cluster_a.emitOpError(kMissingMeshAttributeErrorMessage)
66                    : cluster_b.emitOpError(kMissingMeshAttributeErrorMessage);
67   }
68 
69   *should_merge = mesh_a == mesh_b;
70   return mlir::success();
71 }
72 
73 // Moves all ops (except tf_device.return op) inside `src_cluster` to
74 // block inside `target_cluster`. Ops are moved before the `exit_op`
75 // inside the `target_cluster`.
MoveOpsInsideCluster(mlir::tf_device::ClusterOp src_cluster,mlir::tf_device::ClusterOp target_cluster,mlir::Operation * exit_op)76 void MoveOpsInsideCluster(mlir::tf_device::ClusterOp src_cluster,
77                           mlir::tf_device::ClusterOp target_cluster,
78                           mlir::Operation* exit_op) {
79   auto& cluster_body = src_cluster.GetBody().getOperations();
80   target_cluster.GetBody().getOperations().splice(
81       exit_op->getIterator(), cluster_body, cluster_body.begin(),
82       std::prev(cluster_body.end()));
83 }
84 
85 // Returns a list of pair of mlir Values that represent <return values of ops
86 // inside the merged_cluster, output values of merged cluster>.
87 //
88 // If outputs of `current_cluster` is used as operands to ops in
89 // `merging_cluster`, then make sure to replace operands such that
90 // results values from the inner ops of `current_cluster` is used instead.
91 //
92 // For example,
93 //    %0 = "tf_device.cluster"() ({
94 //      %1 = "tf.A"() : () -> tensor<i32>
95 //      "tf_device.return"(%1) : (tensor<i32>) -> ()
96 //    }) { mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<i32>)
97 //
98 //    %2 = "tf_device.cluster"() ({
99 //      %3 = "tf.B"(%0) : (tenosr<i32>) -> tensor<f32>
100 //      "tf_device.return"(%3) : (tensor<f32>) -> ()
101 //    }) { mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<f32>)
102 //
103 // will be:
104 //    %0 = "tf_device.cluster"() ({
105 //      %1 = "tf.A"() : () -> tensor<i32>
106 //
107 //      # NOTE: `tf.B` op now takes operand directly from
108 //      # `tf.A` instead of `tf_dtensor.cluster op.
109 //      %2 = "tf.B"(%1) : (tenosr<i32>) -> tensor<f32>
110 //      "tf_device.return"(%1, %2) : (tensor<i32>, tensor<f32>)) -> ()
111 //    }) {mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<i32>, tensor<f32>)
112 llvm::SmallVector<std::pair<mlir::Value, mlir::Value>, 8>
GetMergedMeshClusterResults(mlir::tf_device::ClusterOp current_cluster,mlir::tf_device::ClusterOp merging_cluster)113 GetMergedMeshClusterResults(mlir::tf_device::ClusterOp current_cluster,
114                             mlir::tf_device::ClusterOp merging_cluster) {
115   llvm::SmallVector<std::pair<mlir::Value, mlir::Value>, 8>
116       merged_cluster_results;
117   merged_cluster_results.reserve(current_cluster.getNumResults() +
118                                  merging_cluster.getNumResults());
119 
120   auto current_cluster_return_op = current_cluster.GetBody().getTerminator();
121   for (auto result : llvm::zip(current_cluster_return_op->getOpOperands(),
122                                current_cluster.getResults())) {
123     mlir::Value inner_op_result = std::get<0>(result).get();
124     mlir::Value outer_op_result = std::get<1>(result);
125 
126     // If the output value of `current_cluster` is only used by ops
127     // inside the `merged_cluster`, do not add the value as a return
128     // value for newly created tf_device.cluster op.
129     bool result_only_used_by_merging_cluster = true;
130     for (auto& use : llvm::make_early_inc_range(outer_op_result.getUses())) {
131       if (merging_cluster.GetBody().findAncestorOpInBlock(*use.getOwner())) {
132         use.set(inner_op_result);
133       } else {
134         result_only_used_by_merging_cluster = false;
135       }
136     }
137 
138     if (!result_only_used_by_merging_cluster) {
139       merged_cluster_results.emplace_back(inner_op_result, outer_op_result);
140     }
141   }
142 
143   auto merging_cluster_return_op = merging_cluster.GetBody().getTerminator();
144   for (auto result : llvm::zip(merging_cluster_return_op->getOpOperands(),
145                                merging_cluster.getResults())) {
146     mlir::Value inner_op_result = std::get<0>(result).get();
147     mlir::Value outer_op_result = std::get<1>(result);
148 
149     if (!outer_op_result.getUses().empty())
150       merged_cluster_results.emplace_back(inner_op_result, outer_op_result);
151   }
152 
153   return merged_cluster_results;
154 }
155 
156 // Updates the users of `merging_cluster` so that they use values
157 // from `merged_cluster` instead.
ReplaceOperandUsagesWithMergedClusterOutputs(const llvm::SmallVectorImpl<mlir::Value> & values_to_replace,mlir::tf_device::ClusterOp merged_cluster)158 void ReplaceOperandUsagesWithMergedClusterOutputs(
159     const llvm::SmallVectorImpl<mlir::Value>& values_to_replace,
160     mlir::tf_device::ClusterOp merged_cluster) {
161   for (auto result :
162        llvm::zip(values_to_replace, merged_cluster.getResults())) {
163     std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
164   }
165 }
166 
167 // Creates a new tf_device.cluster op that merges
168 // `current_cluster` and `merging_cluster`.
CreateMergedMeshCluster(mlir::OpBuilder * builder,mlir::tf_device::ClusterOp current_cluster,mlir::tf_device::ClusterOp merging_cluster,mlir::tf_device::ClusterOp * merged_cluster)169 mlir::LogicalResult CreateMergedMeshCluster(
170     mlir::OpBuilder* builder, mlir::tf_device::ClusterOp current_cluster,
171     mlir::tf_device::ClusterOp merging_cluster,
172     mlir::tf_device::ClusterOp* merged_cluster) {
173   auto return_values =
174       GetMergedMeshClusterResults(current_cluster, merging_cluster);
175 
176   llvm::SmallVector<mlir::Type, 8> merged_cluster_output_types;
177   llvm::SmallVector<mlir::Value, 8> merged_cluster_output_values;
178   llvm::SmallVector<mlir::Value, 8> output_values_to_replace;
179   merged_cluster_output_types.reserve(return_values.size());
180   merged_cluster_output_values.reserve(return_values.size());
181   output_values_to_replace.reserve(return_values.size());
182   for (auto cluster_return_value : return_values) {
183     auto inner_op_return_value = std::get<0>(cluster_return_value);
184     merged_cluster_output_types.emplace_back(inner_op_return_value.getType());
185     merged_cluster_output_values.emplace_back(inner_op_return_value);
186     output_values_to_replace.emplace_back(std::get<1>(cluster_return_value));
187   }
188 
189   *merged_cluster = builder->create<mlir::tf_device::ClusterOp>(
190       current_cluster.getLoc(), merged_cluster_output_types);
191   auto mesh_attr = current_cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
192   if (!mesh_attr)
193     return current_cluster.emitOpError(kMissingMeshAttributeErrorMessage);
194 
195   (*merged_cluster)->setAttr(kMeshAttr, mesh_attr);
196 
197   // Create a terminator op that returns all return values from
198   // `current_cluster` and `merging_cluster`.
199   merged_cluster->body().push_back(new mlir::Block);
200   builder->setInsertionPointToEnd(&merged_cluster->GetBody());
201   builder->create<mlir::tf_device::ReturnOp>(merged_cluster->getLoc(),
202                                              merged_cluster_output_values);
203 
204   // Make sure to replace usages of tf_device.cluster ops to be merged-away with
205   // newly created tf_device.cluster op.
206   ReplaceOperandUsagesWithMergedClusterOutputs(output_values_to_replace,
207                                                *merged_cluster);
208 
209   return mlir::success();
210 }
211 
212 // Merges `current_cluster` and `merging_cluster` and returns a new merged
213 // tf_device.cluster.
MergeClusters(mlir::OpBuilder * builder,mlir::tf_device::ClusterOp current_cluster,mlir::tf_device::ClusterOp merging_cluster,mlir::tf_device::ClusterOp * merged_cluster)214 mlir::LogicalResult MergeClusters(mlir::OpBuilder* builder,
215                                   mlir::tf_device::ClusterOp current_cluster,
216                                   mlir::tf_device::ClusterOp merging_cluster,
217                                   mlir::tf_device::ClusterOp* merged_cluster) {
218   builder->setInsertionPoint(current_cluster);
219 
220   // Create new tf_device.cluster op that outputs results of both
221   // `current_cluster` and `merging_cluster`.
222   if (mlir::failed(CreateMergedMeshCluster(builder, current_cluster,
223                                            merging_cluster, merged_cluster)))
224     return mlir::failure();
225 
226   // Move all ops to newly created merged cluster.
227   auto exit_op = merged_cluster->GetBody().getTerminator();
228   MoveOpsInsideCluster(current_cluster, *merged_cluster, exit_op);
229   MoveOpsInsideCluster(merging_cluster, *merged_cluster, exit_op);
230 
231   // Remove mesh clusters as they are now merged to a new cluster.
232   current_cluster.erase();
233   merging_cluster.erase();
234   return mlir::success();
235 }
236 
237 // Loops through tf_device.Cluster ops and merge clusters with same execution
238 // device set.
ClusterDeviceClusterOpsInBlock(mlir::OpBuilder * builder,mlir::Block * block)239 mlir::LogicalResult ClusterDeviceClusterOpsInBlock(mlir::OpBuilder* builder,
240                                                    mlir::Block* block) {
241   llvm::SmallVector<mlir::tf_device::ClusterOp, 4> block_ops;
242   block->walk([&](mlir::Operation* op) {
243     if (auto cluster = llvm::dyn_cast<mlir::tf_device::ClusterOp>(op))
244       block_ops.emplace_back(cluster);
245   });
246 
247   llvm::Optional<mlir::tf_device::ClusterOp> current_cluster;
248   for (mlir::tf_device::ClusterOp cluster :
249        llvm::make_early_inc_range(block_ops)) {
250     if (!current_cluster.has_value()) {
251       current_cluster = cluster;
252       continue;
253     }
254     bool should_merge;
255     if (failed(ShouldMergeClusters(*current_cluster, cluster, &should_merge)))
256       return mlir::failure();
257 
258     if (should_merge) {
259       mlir::tf_device::ClusterOp new_cluster;
260       if (mlir::failed(
261               MergeClusters(builder, *current_cluster, cluster, &new_cluster)))
262         return mlir::failure();
263 
264       current_cluster.emplace(new_cluster);
265     } else {
266       current_cluster.emplace(cluster);
267     }
268   }
269   return mlir::success();
270 }
271 
272 }  // namespace
273 
274 // MLIR pass that merges cluster ops with the same mesh attribute.
275 struct DTensorDeviceMeshClusterCoarsening
276     : public DTensorDeviceMeshClusterCoarseningBase<
277           DTensorDeviceMeshClusterCoarsening> {
runOnOperationtensorflow::dtensor::DTensorDeviceMeshClusterCoarsening278   void runOnOperation() override {
279     mlir::MLIRContext& context = getContext();
280     mlir::OpBuilder builder(&context);
281     for (mlir::Block& block : getOperation())
282       if (mlir::failed(ClusterDeviceClusterOpsInBlock(&builder, &block)))
283         return signalPassFailure();
284   }
285 };
286 
287 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorDeviceMeshClusterCoarsening()288 CreateDTensorDeviceMeshClusterCoarsening() {
289   return std::make_unique<DTensorDeviceMeshClusterCoarsening>();
290 }
291 
292 }  // namespace dtensor
293 }  // namespace tensorflow
294