/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
22 #include "mlir/IR/Operation.h" // from @llvm-project
23 #include "mlir/IR/UseDefLists.h" // from @llvm-project
24 #include "mlir/IR/Value.h" // from @llvm-project
25 #include "mlir/Support/LogicalResult.h" // from @llvm-project
26 #include "mlir/Transforms/Passes.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/dtensor/mlir/collectives_common.h"
29 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
30 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
31 #include "tensorflow/dtensor/mlir/group_assignment.h"
32 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
33 #include "tensorflow/dtensor/mlir/layout_parsing.h"
34 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
35
36 namespace tensorflow {
37 namespace dtensor {
38 namespace {
39
40 // Returns true if both group assignments are constant and equal.
same_group_assignments(mlir::DenseIntElementsAttr attr_a,mlir::DenseIntElementsAttr attr_b)41 bool same_group_assignments(mlir::DenseIntElementsAttr attr_a,
42 mlir::DenseIntElementsAttr attr_b) {
43 if (attr_a.getType().getShape() != attr_b.getType().getShape()) {
44 return false;
45 }
46 return std::equal(attr_a.begin(), attr_a.end(), attr_b.begin(), attr_b.end());
47 }
48
GetScatterGroupAssignment(mlir::TF::DTensorAllScatterOp all_scatter,int scatter_dim)49 mlir::DenseIntElementsAttr GetScatterGroupAssignment(
50 mlir::TF::DTensorAllScatterOp all_scatter, int scatter_dim) {
51 const Layout original_layout = all_scatter.input_layout();
52 const Layout desired_layout = all_scatter.output_layout();
53 absl::flat_hash_set<std::string> scattered_dims;
54 scattered_dims.insert(desired_layout.sharding_spec(scatter_dim));
55
56 auto partitions =
57 GetAllReducePartitionsFromReducedDims(original_layout, scattered_dims)
58 .ValueOrDie();
59 const int32 num_partitions = partitions.size();
60
61 // Construct a flattened list of scatter partitions.
62 std::vector<int32> partitions_flat;
63 for (auto& p : partitions) {
64 partitions_flat.insert(partitions_flat.end(), p.second.begin(),
65 p.second.end());
66 }
67
68 int32 partition_size = partitions.begin()->second.size();
69 mlir::OpBuilder builder(all_scatter);
70 auto group_shaped_type = mlir::RankedTensorType::get(
71 {num_partitions, partition_size},
72 mlir::IntegerType::get(builder.getContext(), 32));
73
74 return mlir::DenseIntElementsAttr::get(group_shaped_type, partitions_flat);
75 }
76
ApplyOptimization(mlir::func::FuncOp function)77 mlir::LogicalResult ApplyOptimization(mlir::func::FuncOp function) {
78 std::vector<mlir::Operation*> ops_to_delete;
79 function.walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
80 if (all_reduce->hasOneUse()) {
81 if (auto all_scatter = mlir::dyn_cast<mlir::TF::DTensorAllScatterOp>(
82 *all_reduce->getUsers().begin())) {
83 VLOG(2) << "Found potential AllReduce+AllScatter to fuse.";
84 if (VLOG_IS_ON(2)) all_reduce.dump();
85 if (VLOG_IS_ON(2)) all_scatter.dump();
86
87 const Layout original_layout = all_scatter.input_layout();
88 const Layout desired_layout = all_scatter.output_layout();
89
90 // Find all potential scatter dimensions.
91 std::vector<int> scatter_dims;
92 for (int i = 0; i < original_layout.rank(); ++i) {
93 if (original_layout.sharding_spec(i) !=
94 desired_layout.sharding_spec(i)) {
95 scatter_dims.push_back(i);
96 }
97 }
98
99 if (scatter_dims.empty()) return mlir::WalkResult::advance();
100 if (scatter_dims.size() > 1) {
101 VLOG(2) << "Multiple dimensions are scatter. This is unsupported "
102 "for AllReduce+Scatter fusion.";
103 return mlir::WalkResult::advance();
104 }
105
106 int scatter_dim = scatter_dims[0];
107 VLOG(2) << "Scatter_dim: " << scatter_dim;
108
109 // Check that the all-reduce and all-scatter group assignments are the
110 // same.
111 mlir::DenseIntElementsAttr all_reduce_group_assignment_attr;
112 if (!matchPattern(all_reduce.group_assignment(),
113 m_Constant(&all_reduce_group_assignment_attr))) {
114 all_reduce.emitOpError("group_assignment should be a constant");
115 return mlir::WalkResult::interrupt();
116 }
117
118 mlir::DenseIntElementsAttr all_scatter_group_assignment_attr =
119 GetScatterGroupAssignment(all_scatter, scatter_dim);
120
121 VLOG(2) << "All scatter group assignment: ";
122 if (VLOG_IS_ON(2)) all_scatter_group_assignment_attr.dump();
123
124 bool same_group =
125 same_group_assignments(all_reduce_group_assignment_attr,
126 all_scatter_group_assignment_attr);
127
128 if (!same_group) return mlir::WalkResult::advance();
129 VLOG(2) << "Fuse reduce scatter with scatter_dim: " << scatter_dim;
130
131 mlir::OpBuilder builder(all_reduce);
132 auto scatter_dim_const_op = builder.create<mlir::TF::ConstOp>(
133 all_reduce.getLoc(),
134 mlir::DenseIntElementsAttr::get(
135 mlir::RankedTensorType::get({}, builder.getI32Type()),
136 {scatter_dim}));
137
138 auto reduce_scatter = builder.create<mlir::TF::DTensorReduceScatterOp>(
139 all_reduce.getLoc(), all_scatter->getResultTypes(),
140 all_reduce.getOperand(0), all_reduce.group_assignment(),
141 scatter_dim_const_op, all_reduce.reduce_op(),
142 all_reduce.device_type());
143 SetSingleLayoutOnOp(reduce_scatter, desired_layout);
144
145 all_scatter->replaceAllUsesWith(reduce_scatter);
146
147 ops_to_delete.push_back(all_scatter);
148 ops_to_delete.push_back(all_reduce);
149 }
150 }
151 return mlir::WalkResult::advance();
152 });
153
154 for (mlir::Operation* op : ops_to_delete) {
155 op->erase();
156 }
157 return mlir::success();
158 }
159
160 // MLIR pass that combines AllReduce and AllScatter to ReduceScatter.
161 struct DTensorAllReduceScatterOptimization
162 : public DTensorAllReduceScatterOptimizationBase<
163 DTensorAllReduceScatterOptimization> {
runOnOperationtensorflow::dtensor::__anondf6720da0111::DTensorAllReduceScatterOptimization164 void runOnOperation() override {
165 mlir::func::FuncOp function = getOperation();
166
167 if (mlir::failed(ApplyOptimization(function))) return signalPassFailure();
168 }
169 };
170
171 } // namespace
172
173 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceScatterOptimization()174 CreateDTensorAllReduceScatterOptimization() {
175 return std::make_unique<DTensorAllReduceScatterOptimization>();
176 }
177
178 } // namespace dtensor
179 } // namespace tensorflow
180