1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/ar_crs_combiner.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
22
23 namespace xla {
24 namespace {
25
26 namespace op = xla::testing::opcode_matchers;
27
// Test fixture for the ArCrsCombiner tests in this file. It declares no
// members of its own; everything the tests call on the fixture is inherited
// from HloTestBase.
class ArCrsCombinerTest : public HloTestBase {};
29
// Two distinct constant instructions holding the same literal are expected to
// compare as "same value", while a constant and a parameter are not.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}})
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* entry = module->entry_computation();
  const auto& root_operands = entry->root_instruction()->operands();
  HloInstruction* first_constant = root_operands[0];
  HloInstruction* second_constant = root_operands[1];

  // Constant vs. parameter: must not be reported as equal.
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(
      first_constant, entry->parameter_instruction(0)));
  // Constant vs. identical constant: must be reported as equal.
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(
      first_constant, second_constant));
}
51
// The very same parameter instruction used twice is expected to compare as
// "same value" with itself.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) {
  %x = f32[] parameter(0)
  ROOT %tuple = (f32[], f32[]) tuple(%x, %x)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& root_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = root_operands[0];
  HloInstruction* rhs = root_operands[1];
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
69
// Two different parameter instructions are expected NOT to compare as
// "same value".
TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %tuple = (f32[], f32[]) tuple(%x, %y)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& root_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = root_operands[0];
  HloInstruction* rhs = root_operands[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
88
// Tuples with a different number of operands are expected NOT to compare as
// "same value", even though every operand is the same constant.
TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple1 = (f32[2,2]) tuple(%constant.f32)
  %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& root_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* one_element_tuple = root_operands[0];
  HloInstruction* two_element_tuple = root_operands[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(
      one_element_tuple, two_element_tuple));
}
109
// Two slices of the same operand with identical slice bounds are expected to
// compare as "same value".
TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
  %p = f32[2] parameter(0)
  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
  %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]}
  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& root_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* first_slice = root_operands[0];
  HloInstruction* second_slice = root_operands[1];
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(first_slice,
                                                              second_slice));
}
129
// Two slices of the same operand with different slice bounds are expected NOT
// to compare as "same value".
TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
  %p = f32[2] parameter(0)
  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
  %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]}
  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& root_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* first_slice = root_operands[0];
  HloInstruction* second_slice = root_operands[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(first_slice,
                                                               second_slice));
}
149
// Two get-tuple-element instructions reading the same index of the same tuple
// are expected to compare as "same value".
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& root_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* first_gte = root_operands[0];
  HloInstruction* second_gte = root_operands[1];
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_gte, second_gte));
}
171
// get-tuple-element instructions reading different indices are still expected
// to compare as "same value" when both tuple slots hold the same constant.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex1) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& root_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* first_gte = root_operands[0];
  HloInstruction* second_gte = root_operands[1];
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_gte, second_gte));
}
193
// get-tuple-element instructions reading different indices are expected NOT to
// compare as "same value" when the tuple slots hold different constants.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex2) {
  const char* const hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& root_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* first_gte = root_operands[0];
  HloInstruction* second_gte = root_operands[1];
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_gte, second_gte));
}
216
// Both elements of the while-loop state start from the same constant and have
// the same constant added in the body, so the two body results are expected to
// compare as "same value".
TEST_F(ArCrsCombinerTest, SameValueTestWhile1) {
  const char* const hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* while_instr = module->entry_computation()->root_instruction();
  const auto& body_results =
      while_instr->while_body()->root_instruction()->operands();
  HloInstruction* first_add = body_results[0];
  HloInstruction* second_add = body_results[1];
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_add, second_add));
}
253
// The two elements of the while-loop state are initialized from different
// constants, so the two body results are expected NOT to compare as
// "same value" even though the body adds the same constant to both.
TEST_F(ArCrsCombinerTest, SameValueTestWhile2) {
  const char* const hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}})
  %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* while_instr = module->entry_computation()->root_instruction();
  const auto& body_results =
      while_instr->while_body()->root_instruction()->operands();
  HloInstruction* first_add = body_results[0];
  HloInstruction* second_add = body_results[1];
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_add, second_add));
}
291
// The loop state starts from a single constant, but the body adds a different
// constant to each element. The get-tuple-element instructions reading the
// loop parameter are therefore expected NOT to compare as "same value".
TEST_F(ArCrsCombinerTest, SameValueTestWhile3) {
  const char* const hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32.2)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* while_instr = module->entry_computation()->root_instruction();
  const auto& body_results =
      while_instr->while_body()->root_instruction()->operands();
  // Step through each add to its first operand.
  HloInstruction* first_gte =
      body_results[0]->operands()[0];  // %get-tuple-element.1
  HloInstruction* second_gte =
      body_results[1]->operands()[0];  // %get-tuple-element.2
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_gte, second_gte));
}
329
// Same-value analysis through two nested while loops: the outer body forwards
// its state into an inner while whose body adds one constant to both elements.
TEST_F(ArCrsCombinerTest, SameValueTestNestedWhile) {
  const char* const hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  ROOT %t = pred[] constant(true)
}

%body_inner (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %gte.1 = f32[2,2] get-tuple-element(%x), index=0
  %gte.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%gte.1, %constant.f32)
  %add.2 = f32[2,2] add(%gte.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

%body_outer (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %gte.1 = f32[2,2] get-tuple-element(%x), index=0
  %gte.2 = f32[2,2] get-tuple-element(%x), index=1
  %init = (f32[2,2], f32[2,2]) tuple(%gte.1, %gte.2)
  ROOT %while.1 = (f32[2,2], f32[2,2]) while(%init), condition=%condition,
    body=%body_inner
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition,
    body=%body_outer
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloInstruction* outer_while = module->entry_computation()->root_instruction();
  HloInstruction* inner_while = outer_while->while_body()->root_instruction();
  HloInstruction* first_add =
      inner_while->while_body()->root_instruction()->operands()[0];
  HloInstruction* second_add =
      inner_while->while_body()->root_instruction()->operands()[1];
  // Expected to match: the single constant {{3, 4}, {5, 6}} feeds both tuple
  // elements, and every inner-loop iteration adds the same {{1, 2}, {3, 4}}
  // to each of them.
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_add, second_add));
}
378
CompareReplicaGroups(absl::Span<const ReplicaGroup> groups_before,absl::Span<const ReplicaGroup> groups_after)379 void CompareReplicaGroups(absl::Span<const ReplicaGroup> groups_before,
380 absl::Span<const ReplicaGroup> groups_after) {
381 ASSERT_EQ(groups_before.size(), groups_after.size());
382 for (int i = 0; i < groups_before.size(); ++i) {
383 // Somewhat verbose way to compare the replica_ids, because EqualsProto
384 // is not available in the open-source build.
385 auto group_before = groups_before[i];
386 std::vector<int64_t> ids_before(group_before.replica_ids().begin(),
387 group_before.replica_ids().end());
388 auto group_after = groups_after[i];
389 std::vector<int64_t> ids_after(group_after.replica_ids().begin(),
390 group_after.replica_ids().end());
391 EXPECT_EQ(ids_before, ids_after);
392 }
393 }
394
// AR -> convert -> CRS chains on two devices (non-SPMD). The pass is expected
// to report a change, leave all-reduce(convert(...)) under the root tuple, and
// keep the replica groups of the first cross-replica all-reduce unchanged.
TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%convert.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  // Snapshot (by copy) the replica groups of the first CRS before the pass.
  HloInstruction* root_before = module->entry_computation()->root_instruction();
  const std::vector<ReplicaGroup> groups_before =
      root_before->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2,
                     /*spmd_partition=*/false);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after,
              op::Tuple(op::AllReduce(op::Convert(op::Parameter())),
                        op::AllReduce(op::Convert(op::Constant()))));

  const std::vector<ReplicaGroup> groups_after =
      root_after->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
469
// SPMD variant of the AR -> convert -> CRS rewrite, using a single chain. The
// pass is expected to report a change, leave all-reduce(convert(parameter))
// under the root tuple, and keep the CRS replica groups unchanged.
TEST_F(ArCrsCombinerTest, RewriteArConvertCrsSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
  %p = bf16[] parameter(0)
  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16
  %convert.1 = f32[] convert(%all-reduce.ar.1)
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Copy the groups so they can be compared after the pass has run.
  auto replica_groups_before = crs_before->replica_groups();
  // Name the boolean argument, matching the other tests in this file.
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  auto changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Convert(op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
518
// AR -> bitcast -> CRS chains on two devices (non-SPMD). The pass is expected
// to report a change, leave all-reduce(bitcast(parameter)) under the root
// tuple, and keep the replica groups of the first CRS unchanged.
TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] {
  %a = f32[2,1] parameter(0)
  %b = f32[2,1] parameter(1)
  ROOT %add = f32[2,1] add(%a, %b)
}

%sum.2 (x: f32[2], y: f32[2]) -> f32[2] {
  %x = f32[2] parameter(0)
  %y = f32[2] parameter(1)
  ROOT %add = f32[2] add(%x, %y)
}

ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
  %p = f32[2,1] parameter(0)

  %all-reduce.ar.1 = f32[2,1]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.1,
      sharding={maximal device=0}
  %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1)
  %all-reduce.1 = f32[2]
      all-reduce(%bitcast.1),
      replica_groups={{0,1}},
      to_apply=%sum.2,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[2,1]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.1,
      sharding={maximal device=1}
  %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2)
  %all-reduce.2 = f32[2]
      all-reduce(%bitcast.2),
      replica_groups={{0,1}},
      to_apply=%sum.2,
      sharding={maximal device=1}

  ROOT %tuple = (f32[2], f32[2])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  // Snapshot (by copy) the replica groups of the first CRS before the pass.
  HloInstruction* root_before = module->entry_computation()->root_instruction();
  const std::vector<ReplicaGroup> groups_before =
      root_before->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2,
                     /*spmd_partition=*/false);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after,
              op::Tuple(op::AllReduce(op::Bitcast(op::Parameter())),
                        op::AllReduce(op::Bitcast(op::Parameter()))));

  const std::vector<ReplicaGroup> groups_after =
      root_after->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
588
// AR -> multiply-by-constant -> CRS chains on two devices (non-SPMD). The pass
// is expected to report a change, leave all-reduce(multiply(parameter,
// constant)) under the root tuple, and keep the replica groups of the first
// CRS unchanged.
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=0}
  %multiply.1 = f32[]
      multiply(%all-reduce.ar.1, %constant.f32),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%multiply.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=1}
  %multiply.2 = f32[]
      multiply(%all-reduce.ar.2, %constant.f32),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%multiply.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  // Snapshot (by copy) the replica groups of the first CRS before the pass.
  HloInstruction* root_before = module->entry_computation()->root_instruction();
  const std::vector<ReplicaGroup> groups_before =
      root_before->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2,
                     /*spmd_partition=*/false);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root_after,
      op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant())),
                op::AllReduce(op::Multiply(op::Parameter(), op::Constant()))));

  const std::vector<ReplicaGroup> groups_after =
      root_after->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
658
// SPMD variant of the AR -> multiply -> CRS rewrite, using a single chain. The
// pass is expected to report a change, leave all-reduce(multiply(parameter,
// constant)) under the root tuple, and keep the CRS replica groups unchanged.
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrsSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
    channel_id=1, to_apply=%sum.f32
  %multiply.1 = f32[] multiply(%all-reduce.ar.1, %constant.f32)
  %all-reduce.1 = f32[] all-reduce(%multiply.1), replica_groups={{0,1}},
    to_apply=%sum.f32, sharding={maximal device=0}
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  // Snapshot (by copy) the replica groups of the CRS before the pass.
  HloInstruction* root_before = module->entry_computation()->root_instruction();
  const std::vector<ReplicaGroup> groups_before =
      root_before->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2,
                     /*spmd_partition=*/true);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root_after,
      op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant()))));

  const std::vector<ReplicaGroup> groups_after =
      root_after->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
700
// AR -> convert -> add(constant, ...) -> CRS chains on two devices (non-SPMD),
// with the same constant summand on both devices. The pass is expected to
// report a change and leave all-reduce(add(divide(constant, constant),
// convert)) under the root tuple, with the first CRS's replica groups
// unchanged.
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32 = f32[] constant(2)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %add.1 = f32[]
      add(%constant.f32, %convert.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %add.2 = f32[]
      add(%constant.f32, %convert.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  // Snapshot (by copy) the replica groups of the first CRS before the pass.
  HloInstruction* root_before = module->entry_computation()->root_instruction();
  const std::vector<ReplicaGroup> groups_before =
      root_before->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2,
                     /*spmd_partition=*/false);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root_after,
      op::Tuple(
          op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()),
                                op::Convert())),
          op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()),
                                op::Convert()))));

  const std::vector<ReplicaGroup> groups_after =
      root_after->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
786
// SPMD variant of the AR -> convert -> add -> CRS rewrite, using a single
// chain. The pass is expected to report a change and leave
// all-reduce(add(divide(constant, constant), convert)) under the root tuple,
// with the CRS replica groups unchanged.
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrsSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32 = f32[] constant(2)

  %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
    channel_id=1, to_apply=%sum.bf16
  %convert.1 = f32[] convert(%all-reduce.ar.1), sharding={maximal device=0}
  %add.1 = f32[] add(%constant.f32, %convert.1)
  %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}},
    to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  // Snapshot (by copy) the replica groups of the CRS before the pass.
  HloInstruction* root_before = module->entry_computation()->root_instruction();
  const std::vector<ReplicaGroup> groups_before =
      root_before->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2,
                     /*spmd_partition=*/true);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after,
              op::Tuple(op::AllReduce(op::Add(
                  op::Divide(op::Constant(), op::Constant()), op::Convert()))));

  const std::vector<ReplicaGroup> groups_after =
      root_after->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
836
// Same shape as the AR -> convert -> add -> CRS case, but the two devices add
// different constants (2 vs. 3). The pass is expected to report no change.
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32.1 = f32[] constant(2)
  %constant.f32.2 = f32[] constant(3)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %add.1 = f32[]
      add(%constant.f32.1, %convert.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %add.2 = f32[]
      add(%constant.f32.2, %convert.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  ArCrsCombiner pass(/*num_spatial_partitions=*/2,
                     /*spmd_partition=*/false);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
909
// SPMD variant. The value added to the converted all-reduce result is the
// entry parameter %p rather than a constant; the pass is expected to report
// that it changed nothing.
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewriteSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32.1 = f32[] constant(2)

  %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
      channel_id=1, to_apply=%sum.bf16
  %convert.1 = f32[] convert(%all-reduce.ar.1)
  %add.1 = f32[] add(%p, %convert.1)
  %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/true);
  EXPECT_FALSE(pass.Run(hlo_module.get()).ValueOrDie());
}
948
// Non-SPMD variant. A channel-id all-reduce feeds the cross-replica
// all-reduce directly, and the latter has an extra user (a multiply) besides
// the root tuple. The pass must run to completion, report a change, and leave
// each tuple element as an all-reduce of the parameter.
TEST_F(ArCrsCombinerTest, ArThenCrsDontCrash) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.1 (a: f32[], b: f32[]) -> f32[] {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.1,
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%all-reduce.ar.1),
      replica_groups={{0,1}},
      to_apply=%sum.1,
      sharding={maximal device=0}
  %multiply.1 = f32[]
      multiply(%all-reduce.1, %constant.f32),
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.1,
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%all-reduce.ar.2),
      replica_groups={{0,1}},
      to_apply=%sum.1,
      sharding={maximal device=1}
  %multiply.2 = f32[]
      multiply(%all-reduce.2, %constant.f32),
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass gets a chance to rewrite it.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/false);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  const auto ar_of_param = op::AllReduce(op::Parameter());
  EXPECT_THAT(entry->root_instruction(), op::Tuple(ar_of_param, ar_of_param));

  // Re-read the first operand after the pass and compare its replica groups
  // with the snapshot.
  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1017
// Non-SPMD variant. Two adds with constants sit between the channel-id
// all-reduce and the cross-replica all-reduce. After the pass, each tuple
// element is expected to be all-reduce(add(C/C, add(C/C, param))), i.e. both
// constants have been wrapped in a divide.
//
// NOTE(review): every instruction of the second chain below is annotated
// `maximal device=0`, whereas the root tuple places its second element on
// device 1 and the sibling tests use device=1 for the second chain. Left
// untouched here -- confirm whether this is intentional.
TEST_F(ArCrsCombinerTest, RewriteMultipleAdds) {
  const char* const hlo_text = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add.11 = f32[]
      add(%constant.1, %all-reduce.ar.1),
      sharding={maximal device=0}
  %add.12 = f32[]
      add(%constant.2, %add.11),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add.21 = f32[]
      add(%constant.1, %all-reduce.ar.2),
      sharding={maximal device=0}
  %add.22 = f32[]
      add(%constant.2, %add.21),
      sharding={maximal device=0}
  %all-reduce.2 = f32[]
      all-reduce(%add.22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass runs.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/false);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  // Both tuple elements must have the same rewritten shape.
  const auto divided_const = op::Divide(op::Constant(), op::Constant());
  const auto rewritten_crs = op::AllReduce(
      op::Add(divided_const, op::Add(divided_const, op::Parameter())));
  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(rewritten_crs, rewritten_crs));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1099
// SPMD variant of RewriteMultipleAdds: a single chain with two constant adds
// between the channel-id all-reduce and the cross-replica all-reduce. The
// rewritten root is expected to be tuple(all-reduce(add(C/C, add(C/C, p)))).
TEST_F(ArCrsCombinerTest, RewriteMultipleAddsSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)

  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
      channel_id=1, to_apply=%sum
  %add.11 = f32[] add(%constant.1, %all-reduce.ar.1)
  %add.12 = f32[] add(%constant.2, %add.11)
  %all-reduce.1 = f32[] all-reduce(%add.12), replica_groups={{0,1}}, to_apply=%sum
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass runs.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/true);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  const auto divided_const = op::Divide(op::Constant(), op::Constant());
  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                  divided_const, op::Add(divided_const, op::Parameter())))));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1144
// Non-SPMD variant. A subtract (constant minus all-reduce result) sits
// between the channel-id all-reduce and the cross-replica all-reduce. After
// the pass, each tuple element is expected to be
// all-reduce(subtract(C/C, param)).
TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=0}
  %sub.1 = f32[]
      subtract(%constant.f32, %all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%sub.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=1}
  %sub.2 = f32[]
      subtract(%constant.f32, %all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%sub.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass runs.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/false);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  // Both tuple elements must have the same rewritten shape.
  const auto rewritten_crs = op::AllReduce(op::Subtract(
      op::Divide(op::Constant(), op::Constant()), op::Parameter()));
  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(rewritten_crs, rewritten_crs));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1217
// SPMD variant of RewriteArSubtractCrs: a single chain whose subtract takes a
// constant and the channel-id all-reduce result. The rewritten root is
// expected to be tuple(all-reduce(subtract(C/C, param))).
TEST_F(ArCrsCombinerTest, RewriteArSubtractCrsSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)
  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
      channel_id=1, to_apply=%sum.f32
  %sub.1 = f32[] subtract(%constant.f32, %all-reduce.ar.1)
  %all-reduce.1 = f32[] all-reduce(%sub.1), replica_groups={{0,1}},
      to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass runs.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/true);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(op::AllReduce(op::Subtract(
                  op::Divide(op::Constant(), op::Constant()),
                  op::Parameter()))));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1259
// Non-SPMD variant. Two channel-id all-reduces (channel 1 and channel 2) feed
// one cross-replica all-reduce, with the add-with-constant on the *left*
// operand of the final add. After the pass, each tuple element is expected to
// be all-reduce(add(add(param, C/C), param)).
TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) {
  const char* const hlo_text = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add11 = f32[]
      add(%ar11, %const1),
      sharding={maximal device=0}
  %ar12 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=2,
      to_apply=%sum,
      sharding={maximal device=0}
  %add12 = f32[]
      add(%add11, %ar12),
      sharding={maximal device=0}
  %crs1 = f32[]
      all-reduce(%add12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %ar21 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %add21 = f32[]
      add(%ar21, %const1),
      sharding={maximal device=1}
  %ar22 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=2,
      to_apply=%sum,
      sharding={maximal device=1}
  %add22 = f32[]
      add(%add21, %ar22),
      sharding={maximal device=1}
  %crs2 = f32[]
      all-reduce(%add22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%crs1, %crs2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass runs.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/false);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  // Both tuple elements must have the same rewritten shape.
  const auto rewritten_crs = op::AllReduce(op::Add(
      op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())),
      op::Parameter()));
  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(rewritten_crs, rewritten_crs));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1353
// SPMD variant of RewriteMultipleARsLeft: one chain with two channel-id
// all-reduces, the add-with-constant being the left operand of the final add.
// The rewritten root is expected to be
// tuple(all-reduce(add(add(param, C/C), param))).
TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeftSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1,
      to_apply=%sum
  %add11 = f32[] add(%ar11, %const1)
  %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2,
      to_apply=%sum
  %add12 = f32[] add(%add11, %ar12)
  %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}},
      to_apply=%sum
  ROOT %tuple = (f32[]) tuple(%crs1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass runs.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/true);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                  op::Add(op::Parameter(),
                          op::Divide(op::Constant(), op::Constant())),
                  op::Parameter()))));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1401
// Non-SPMD variant. Two channel-id all-reduces (channel 1 and channel 2) feed
// one cross-replica all-reduce, with the add-with-constant on the *right*
// operand of the final add. After the pass, each tuple element is expected to
// be all-reduce(add(param, add(param, C/C))).
TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) {
  const char* const hlo_text = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %ar12 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=2,
      to_apply=%sum,
      sharding={maximal device=0}
  %add11 = f32[]
      add(%ar12, %const1),
      sharding={maximal device=0}
  %add12 = f32[]
      add(%ar11, %add11),
      sharding={maximal device=0}
  %crs1 = f32[]
      all-reduce(%add12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %ar21 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %ar22 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=2,
      to_apply=%sum,
      sharding={maximal device=1}
  %add21 = f32[]
      add(%ar22, %const1),
      sharding={maximal device=1}
  %add22 = f32[]
      add(%ar21, %add21),
      sharding={maximal device=1}
  %crs2 = f32[]
      all-reduce(%add22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%crs1, %crs2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass runs.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/false);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  // Both tuple elements must have the same rewritten shape.
  const auto rewritten_crs = op::AllReduce(op::Add(
      op::Parameter(),
      op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant()))));
  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(rewritten_crs, rewritten_crs));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1497
// SPMD variant of RewriteMultipleARsRight: one chain with two channel-id
// all-reduces, the add-with-constant being the right operand of the final
// add. The rewritten root is expected to be
// tuple(all-reduce(add(param, add(param, C/C)))).
TEST_F(ArCrsCombinerTest, RewriteMultipleARsRightSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, to_apply=%sum
  %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, to_apply=%sum
  %add11 = f32[] add(%ar12, %const1)
  %add12 = f32[] add(%ar11, %add11)
  %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, to_apply=%sum
  ROOT %tuple = (f32[]) tuple(%crs1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));
  auto* entry = hlo_module->entry_computation();

  // Snapshot the replica groups of the root tuple's first operand before the
  // pass runs.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/true);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                  op::Parameter(),
                  op::Add(op::Parameter(),
                          op::Divide(op::Constant(), op::Constant()))))));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1543
// Non-SPMD variant. The module is parsed with replica_count=1 and every
// all-reduce uses replica_groups={{0}}; the pass is expected to report that
// it changed nothing.
TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%convert.2),
      replica_groups={{0}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/1));

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/false);
  EXPECT_FALSE(pass.Run(hlo_module.get()).ValueOrDie());
}
1608
// SPMD variant of OneReplicaDontRewrite: replica_count=1 and
// replica_groups={{0}} on both all-reduces; the pass is expected to report
// that it changed nothing.
TEST_F(ArCrsCombinerTest, OneReplicaDontRewriteSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0}},
      channel_id=1, to_apply=%sum.bf16
  %convert.1 = f32[] convert(%all-reduce.ar.1)
  %all-reduce.1 = f32[] all-reduce(%convert.1),
      replica_groups={{0}}, to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/1));

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/true);
  EXPECT_FALSE(pass.Run(hlo_module.get()).ValueOrDie());
}
1646
// Exercises TestInstructionsComputeSameValue on the roots of a conditional's
// branch computations. Both branches receive tuple(p, p):
//  - branch 0 returns (gte.1, gte.0)          -> operands expected to match;
//  - branch 1 returns (gte.0, gte.1 + gte.1)  -> operands expected to differ.
TEST_F(ArCrsCombinerTest, SameValueTestConditional) {
  const char* const hlo_text = R"(
HloModule foobar

branch_true {
  pt = (f32[2,4], f32[2,4]) parameter(0)
  gte.0 = f32[2,4] get-tuple-element(pt), index=0
  gte.1 = f32[2,4] get-tuple-element(pt), index=1
  ROOT tuple.t = (f32[2,4], f32[2,4]) tuple(gte.1, gte.0)
}

branch_false {
  pf = (f32[2,4], f32[2,4]) parameter(0)
  gte.0 = f32[2,4] get-tuple-element(pf), index=0
  gte.1 = f32[2,4] get-tuple-element(pf), index=1
  add = f32[2,4] add(gte.1, gte.1)
  ROOT tuple.f = (f32[2,4], f32[2,4]) tuple(gte.0, add)
}

ENTRY Parameters1.v4 {
  constant = pred[] constant(true)
  p = f32[2,4] parameter(0)
  tuple = (f32[2,4], f32[2,4]) tuple(p, p)
  ROOT conditional = (f32[2,4], f32[2,4]) conditional(constant, tuple, tuple), true_computation=branch_true, false_computation=branch_false
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* conditional = hlo_module->entry_computation()->root_instruction();

  // Branch 0: the two root operands are swapped GTEs of tuple(p, p).
  auto* true_root = conditional->branch_computation(0)->root_instruction();
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(
      true_root->mutable_operand(0), true_root->mutable_operand(1)));

  // Branch 1: the second root operand is an add, the first a plain GTE.
  auto* false_root = conditional->branch_computation(1)->root_instruction();
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(
      false_root->mutable_operand(0), false_root->mutable_operand(1)));
}
1688
// Non-SPMD variant. Here the channel-id all-reduces themselves use
// replica_groups={{0,1}} (i.e. they already span both replicas); the pass is
// expected to report that it changed nothing.
TEST_F(ArCrsCombinerTest, AllReduceWithReplicas) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
      to_apply=%sum.f32, sharding={maximal device=0}
  %all-reduce.1 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
      to_apply=%sum.f32, sharding={maximal device=1}
  %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
      to_apply=%sum.f32, sharding={maximal device=0}
  %all-reduce.3 = f32[] all-reduce(%all-reduce.1), replica_groups={{0,1}},
      to_apply=%sum.f32, sharding={maximal device=1}
  ROOT %tuple = (f32[], f32[]) tuple(%all-reduce.2, %all-reduce.3),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/false);
  EXPECT_FALSE(pass.Run(hlo_module.get()).ValueOrDie());
}
1722
// SPMD variant. Both the channel-id all-reduce and the one consuming it use
// replica_groups={{0},{1}}; the pass is expected to report that it changed
// nothing.
TEST_F(ArCrsCombinerTest, AllReduceWithReplicasSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
  %p = bf16[] parameter(0)
  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0},{1}},
      to_apply=%sum.f32
  %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0},{1}},
      to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/true);
  EXPECT_FALSE(pass.Run(hlo_module.get()).ValueOrDie());
}
1751
// SPMD mode: a cross-replica all-reduce over 32 replicas whose operand is a
// replicated parameter. The pass is expected to replace the root with
// divide(all-reduce(param), broadcast(constant)), where the new all-reduce
// carries a channel id and the constant equals 2 (the value passed as
// num_spatial_partitions below).
TEST_F(ArCrsCombinerTest, ReplaceReplicatedAllReduceSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
  %p = f32[2,4] parameter(0), sharding={replicated}
  ROOT %all-reduce = f32[2,4] all-reduce(%p), to_apply=%sum.f32,
    replica_groups={{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}}
}
)";

  // Replacing replicated all-reduce is only triggered when there are enough
  // replicas (currently > num_partitions * 8).
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/32));
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  auto changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  auto root = module->entry_computation()->root_instruction();
  // Fatal assertion on purpose: the operand accesses below are only valid if
  // the root really is divide(all-reduce(...), broadcast(constant)). With a
  // non-fatal EXPECT_THAT a mismatch would be followed by out-of-range operand
  // lookups / a literal() call on a non-constant instead of a clean failure.
  ASSERT_THAT(root, op::Divide(op::AllReduce(op::Parameter()),
                               op::Broadcast(op::Constant())));

  auto ar = root->operand(0);
  auto divisor = root->operand(1)->operand(0);
  EXPECT_TRUE(ar->channel_id());
  EXPECT_TRUE(divisor->literal().IsAllFloat(2));
}
1788
// SPMD mode with 2 replicas x 4 partitions. The channel-id all-reduce sets
// use_global_device_ids=true and lists global ids in its replica groups
// ({{0,1,2,3},{4,5,6,7}}); the pass is expected to report a change.
TEST_F(ArCrsCombinerTest, AllReduceWithGlobalIdReplicaGroups) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
  %p = bf16[] parameter(0)
  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1,
      replica_groups={{0,1,2,3},{4,5,6,7}}, use_global_device_ids=true,
      to_apply=%sum.f32
  %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
      to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2,
                                   /*num_partitions=*/4));

  ArCrsCombiner pass(/*num_spatial_partitions=*/4, /*spmd_partition=*/true);
  EXPECT_TRUE(pass.Run(hlo_module.get()).ValueOrDie());
}
1819
1820 } // namespace
1821 } // namespace xla
1822