1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
22 #include "mlir/IR/Operation.h"  // from @llvm-project
23 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
24 #include "mlir/IR/Value.h"  // from @llvm-project
25 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
26 #include "mlir/Transforms/Passes.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/dtensor/mlir/collectives_common.h"
29 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
30 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
31 #include "tensorflow/dtensor/mlir/group_assignment.h"
32 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
33 #include "tensorflow/dtensor/mlir/layout_parsing.h"
34 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
35 
36 namespace tensorflow {
37 namespace dtensor {
38 namespace {
39 
40 // Returns true if both group assignments are constant and equal.
same_group_assignments(mlir::DenseIntElementsAttr attr_a,mlir::DenseIntElementsAttr attr_b)41 bool same_group_assignments(mlir::DenseIntElementsAttr attr_a,
42                             mlir::DenseIntElementsAttr attr_b) {
43   if (attr_a.getType().getShape() != attr_b.getType().getShape()) {
44     return false;
45   }
46   return std::equal(attr_a.begin(), attr_a.end(), attr_b.begin(), attr_b.end());
47 }
48 
GetScatterGroupAssignment(mlir::TF::DTensorAllScatterOp all_scatter,int scatter_dim)49 mlir::DenseIntElementsAttr GetScatterGroupAssignment(
50     mlir::TF::DTensorAllScatterOp all_scatter, int scatter_dim) {
51   const Layout original_layout = all_scatter.input_layout();
52   const Layout desired_layout = all_scatter.output_layout();
53   absl::flat_hash_set<std::string> scattered_dims;
54   scattered_dims.insert(desired_layout.sharding_spec(scatter_dim));
55 
56   auto partitions =
57       GetAllReducePartitionsFromReducedDims(original_layout, scattered_dims)
58           .ValueOrDie();
59   const int32 num_partitions = partitions.size();
60 
61   // Construct a flattened list of scatter partitions.
62   std::vector<int32> partitions_flat;
63   for (auto& p : partitions) {
64     partitions_flat.insert(partitions_flat.end(), p.second.begin(),
65                            p.second.end());
66   }
67 
68   int32 partition_size = partitions.begin()->second.size();
69   mlir::OpBuilder builder(all_scatter);
70   auto group_shaped_type = mlir::RankedTensorType::get(
71       {num_partitions, partition_size},
72       mlir::IntegerType::get(builder.getContext(), 32));
73 
74   return mlir::DenseIntElementsAttr::get(group_shaped_type, partitions_flat);
75 }
76 
// Walks `function` looking for DTensorAllReduce ops whose single user is a
// DTensorAllScatter op that scatters exactly one dimension with a group
// assignment equal to the all-reduce's, and fuses each such pair into one
// DTensorReduceScatter op. Returns failure only when an all-reduce has a
// non-constant group assignment.
mlir::LogicalResult ApplyOptimization(mlir::func::FuncOp function) {
  // Fused ops are erased after the walk finishes so the traversal never
  // touches freed operations.
  std::vector<mlir::Operation*> ops_to_delete;
  function.walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
    // Only fuse when the all-scatter is the sole consumer; otherwise the
    // all-reduce result is still needed by other ops.
    if (all_reduce->hasOneUse()) {
      if (auto all_scatter = mlir::dyn_cast<mlir::TF::DTensorAllScatterOp>(
              *all_reduce->getUsers().begin())) {
        VLOG(2) << "Found potential AllReduce+AllScatter to fuse.";
        if (VLOG_IS_ON(2)) all_reduce.dump();
        if (VLOG_IS_ON(2)) all_scatter.dump();

        const Layout original_layout = all_scatter.input_layout();
        const Layout desired_layout = all_scatter.output_layout();

        // A dimension is scattered iff its sharding spec changes between the
        // all-scatter's input and output layouts.
        std::vector<int> scatter_dims;
        for (int i = 0; i < original_layout.rank(); ++i) {
          if (original_layout.sharding_spec(i) !=
              desired_layout.sharding_spec(i)) {
            scatter_dims.push_back(i);
          }
        }

        // Nothing to fuse if no dimension changes; multi-dim scatter is
        // unsupported, so skip the pair and keep walking in either case.
        if (scatter_dims.empty()) return mlir::WalkResult::advance();
        if (scatter_dims.size() > 1) {
          VLOG(2) << "Multiple dimensions are scatter.  This is unsupported "
                     "for AllReduce+Scatter fusion.";
          return mlir::WalkResult::advance();
        }

        int scatter_dim = scatter_dims[0];
        VLOG(2) << "Scatter_dim: " << scatter_dim;

        // Check that the all-reduce and all-scatter group assignments are the
        // same. A non-constant all-reduce assignment is a hard error that
        // aborts the walk.
        mlir::DenseIntElementsAttr all_reduce_group_assignment_attr;
        if (!matchPattern(all_reduce.group_assignment(),
                          m_Constant(&all_reduce_group_assignment_attr))) {
          all_reduce.emitOpError("group_assignment should be a constant");
          return mlir::WalkResult::interrupt();
        }

        // Derive the scatter's group assignment from its layouts.
        mlir::DenseIntElementsAttr all_scatter_group_assignment_attr =
            GetScatterGroupAssignment(all_scatter, scatter_dim);

        VLOG(2) << "All scatter group assignment: ";
        if (VLOG_IS_ON(2)) all_scatter_group_assignment_attr.dump();

        bool same_group =
            same_group_assignments(all_reduce_group_assignment_attr,
                                   all_scatter_group_assignment_attr);

        // Fusion is only valid when both collectives act over the same
        // device groups.
        if (!same_group) return mlir::WalkResult::advance();
        VLOG(2) << "Fuse reduce scatter with scatter_dim: " << scatter_dim;

        // Materialize the scatter dimension as a scalar i32 constant operand
        // for the fused op.
        mlir::OpBuilder builder(all_reduce);
        auto scatter_dim_const_op = builder.create<mlir::TF::ConstOp>(
            all_reduce.getLoc(),
            mlir::DenseIntElementsAttr::get(
                mlir::RankedTensorType::get({}, builder.getI32Type()),
                {scatter_dim}));

        // Build the fused op with the all-scatter's result types and the
        // all-reduce's input, group assignment, reduction and device type.
        auto reduce_scatter = builder.create<mlir::TF::DTensorReduceScatterOp>(
            all_reduce.getLoc(), all_scatter->getResultTypes(),
            all_reduce.getOperand(0), all_reduce.group_assignment(),
            scatter_dim_const_op, all_reduce.reduce_op(),
            all_reduce.device_type());
        SetSingleLayoutOnOp(reduce_scatter, desired_layout);

        // Reroute all consumers of the all-scatter to the fused op, then
        // defer erasure of the now-dead pair until after the walk.
        all_scatter->replaceAllUsesWith(reduce_scatter);

        ops_to_delete.push_back(all_scatter);
        ops_to_delete.push_back(all_reduce);
      }
    }
    return mlir::WalkResult::advance();
  });

  // Safe to erase now: the walk is done and all uses were rerouted.
  for (mlir::Operation* op : ops_to_delete) {
    op->erase();
  }
  return mlir::success();
}
159 
160 // MLIR pass that combines AllReduce and AllScatter to ReduceScatter.
161 struct DTensorAllReduceScatterOptimization
162     : public DTensorAllReduceScatterOptimizationBase<
163           DTensorAllReduceScatterOptimization> {
runOnOperationtensorflow::dtensor::__anondf6720da0111::DTensorAllReduceScatterOptimization164   void runOnOperation() override {
165     mlir::func::FuncOp function = getOperation();
166 
167     if (mlir::failed(ApplyOptimization(function))) return signalPassFailure();
168   }
169 };
170 
171 }  // namespace
172 
173 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceScatterOptimization()174 CreateDTensorAllReduceScatterOptimization() {
175   return std::make_unique<DTensorAllReduceScatterOptimization>();
176 }
177 
178 }  // namespace dtensor
179 }  // namespace tensorflow
180