xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/all_reduce_reassociate_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/all_reduce_reassociate.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 
23 namespace xla {
24 namespace {
25 
26 namespace m = xla::testing::opcode_matchers;
27 
28 class AllReduceSimplifierTest : public HloTestBase {
29  public:
RunPass(absl::string_view hlo_module,bool expect_change)30   StatusOr<std::unique_ptr<HloModule>> RunPass(absl::string_view hlo_module,
31                                                bool expect_change) {
32     TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module));
33     auto changed = AllReduceReassociate().Run(module.get());
34     if (!changed.ok()) {
35       return changed.status();
36     }
37     EXPECT_EQ(changed.ValueOrDie(), expect_change);
38     return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
39   }
40 
AllReduceCount(std::unique_ptr<HloModule> & module)41   size_t AllReduceCount(std::unique_ptr<HloModule>& module) {
42     return absl::c_count_if(module->entry_computation()->instructions(),
43                             [](const HloInstruction* inst) {
44                               return inst->opcode() == HloOpcode::kAllReduce;
45                             });
46   }
47 };
48 
// add(all-reduce(p0), all-reduce(p1)) with matching attributes is expected to
// become all-reduce(add(p0, p1)), leaving exactly one all-reduce.
TEST_F(AllReduceSimplifierTest, Simple) {
  const absl::string_view hlo_text = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*expect_change=*/true));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, m::AllReduce(m::Add(m::Parameter(0), m::Parameter(1))));
  EXPECT_EQ(AllReduceCount(module), 1);
}
73 
// Same pattern as Simple, but both all-reduces carry channel_id=1; the
// reassociation is still expected to fire.
TEST_F(AllReduceSimplifierTest, SimpleWithChannelId) {
  const absl::string_view hlo_text = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), channel_id=1, replica_groups={}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), channel_id=1, replica_groups={}, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*expect_change=*/true));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, m::AllReduce(m::Add(m::Parameter(0), m::Parameter(1))));
  EXPECT_EQ(AllReduceCount(module), 1);
}
98 
99 // Checks whether a linear chain of adds of ARs is reassociated iin a single
100 // pass.
TEST_F(AllReduceSimplifierTest,SimpleChain)101 TEST_F(AllReduceSimplifierTest, SimpleChain) {
102   absl::string_view hlo_string = R"(
103 HloModule m
104 
105 sum {
106   a = f32[] parameter(0)
107   b = f32[] parameter(1)
108   ROOT add.2 = f32[] add(a, b)
109 }
110 
111 ENTRY main {
112   p0 = f32[8] parameter(0)
113   p1 = f32[8] parameter(1)
114   p2 = f32[8] parameter(2)
115   p3 = f32[8] parameter(3)
116   ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
117   ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
118   ar2 = f32[8] all-reduce(p2), replica_groups={}, to_apply=sum
119   ar3 = f32[8] all-reduce(p3), replica_groups={}, to_apply=sum
120   add0 = f32[8] add(ar0, ar1)
121   add1 = f32[8] add(add0, ar2)
122   ROOT add2 = f32[8] add(add1, ar3)
123 }
124 )";
125   TF_ASSERT_OK_AND_ASSIGN(auto module,
126                           RunPass(hlo_string, /*expect_change=*/true));
127   EXPECT_THAT(
128       module->entry_computation()->root_instruction(),
129       m::AllReduce(m::Add(
130           m::Add(m::Add(m::Parameter(0), m::Parameter(1)), m::Parameter(2)),
131           m::Parameter(3))));
132   EXPECT_EQ(AllReduceCount(module), 1);
133 }
134 
135 // Checks whether a tree of add of ARs is reassociated in a single pass.
TEST_F(AllReduceSimplifierTest,SimpleTree)136 TEST_F(AllReduceSimplifierTest, SimpleTree) {
137   absl::string_view hlo_string = R"(
138 HloModule m
139 
140 sum {
141   a = f32[] parameter(0)
142   b = f32[] parameter(1)
143   ROOT add.2 = f32[] add(a, b)
144 }
145 
146 ENTRY main {
147   p0 = f32[8] parameter(0)
148   p1 = f32[8] parameter(1)
149   p2 = f32[8] parameter(2)
150   p3 = f32[8] parameter(3)
151   ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
152   ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
153   ar2 = f32[8] all-reduce(p2), replica_groups={}, to_apply=sum
154   ar3 = f32[8] all-reduce(p3), replica_groups={}, to_apply=sum
155   add0 = f32[8] add(ar0, ar1)
156   add1 = f32[8] add(ar2, ar3)
157   ROOT add2 = f32[8] add(add0, add1)
158 }
159 )";
160   TF_ASSERT_OK_AND_ASSIGN(auto module,
161                           RunPass(hlo_string, /*expect_change=*/true));
162   EXPECT_THAT(module->entry_computation()->root_instruction(),
163               m::AllReduce(m::Add(m::Add(m::Parameter(0), m::Parameter(1)),
164                                   m::Add(m::Parameter(2), m::Parameter(3)))));
165   EXPECT_EQ(AllReduceCount(module), 1);
166 }
167 
// ar0 reduces with `sum` while ar1 reduces with `max`; the test expects the
// pass to leave the add untouched, so both all-reduces must remain.
TEST_F(AllReduceSimplifierTest, MismatchOp0) {
  absl::string_view hlo_string = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

max {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT r = f32[] maximum(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=max
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          RunPass(hlo_string, /*expect_change=*/false));
  // Checking the graph, not just the pass's return flag, catches an
  // accidental rewrite that is reported as "no change".
  EXPECT_EQ(AllReduceCount(module), 2);
}
195 
// Both all-reduces reduce with `max`, but their results are combined with
// `add`; the test expects no reassociation, so both all-reduces must remain.
TEST_F(AllReduceSimplifierTest, MismatchOp1) {
  absl::string_view hlo_string = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

max {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT r = f32[] maximum(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=max
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=max
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          RunPass(hlo_string, /*expect_change=*/false));
  EXPECT_EQ(AllReduceCount(module), 2);
}
223 
// The two all-reduces use different replica_groups ({{0}} vs {}); the test
// expects no reassociation, so both all-reduces must remain.
TEST_F(AllReduceSimplifierTest, MismatchReplicaGroups) {
  absl::string_view hlo_string = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={{0}}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          RunPass(hlo_string, /*expect_change=*/false));
  EXPECT_EQ(AllReduceCount(module), 2);
}
245 
// Only ar0 has a channel_id; the test expects no reassociation, so both
// all-reduces must remain.
TEST_F(AllReduceSimplifierTest, MismatchHasChannelId) {
  absl::string_view hlo_string = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, channel_id=3, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          RunPass(hlo_string, /*expect_change=*/false));
  EXPECT_EQ(AllReduceCount(module), 2);
}
267 
// Only ar0 sets use_global_device_ids=true; the test expects no
// reassociation, so both all-reduces must remain.
//
// NOTE(review): the channel_id values also differ (3 vs 4). This test assumes
// the pass does not compare channel_id values (SimpleWithChannelId suggests
// only equal ids were exercised there); if it did, this case would not
// isolate the use_global_device_ids mismatch — confirm against the pass.
TEST_F(AllReduceSimplifierTest, MismatchUseGlobalDeviceId) {
  absl::string_view hlo_string = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={{0, 1}}, channel_id=3, use_global_device_ids=true, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={{0, 1}}, channel_id=4, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          RunPass(hlo_string, /*expect_change=*/false));
  EXPECT_EQ(AllReduceCount(module), 2);
}
289 
// ar0 feeds both the add and the root tuple, so it has more than one user;
// the test expects no reassociation, so both all-reduces must remain.
TEST_F(AllReduceSimplifierTest, NotSingleUser) {
  absl::string_view hlo_string = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
  add = f32[8] add(ar0, ar1)
  ROOT t = (f32[8], f32[8]) tuple(ar0, add)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          RunPass(hlo_string, /*expect_change=*/false));
  EXPECT_EQ(AllReduceCount(module), 2);
}
312 
// The add consumes the same all-reduce (ar0) for both operands. The test
// expects the pass to rewrite it into all-reduce(add(p0, p0)) feeding the
// copy, with ar0 removed so that exactly one all-reduce remains.
TEST_F(AllReduceSimplifierTest, DoubleUse) {
  absl::string_view hlo_string = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
  add = f32[8] add(ar0, ar0)
  ROOT c = f32[8] copy(add)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          RunPass(hlo_string, /*expect_change=*/true));
  // Previously the test only checked the "changed" flag; pin the rewritten
  // shape so a malformed rewrite is caught.
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      m::Copy(m::AllReduce(m::Add(m::Parameter(0), m::Parameter(0)))));
  EXPECT_EQ(AllReduceCount(module), 1);
}
334 
335 }  // namespace
336 }  // namespace xla
337