xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/ar_crs_combiner.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/statusor.h"
20 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 
23 namespace xla {
24 namespace {
25 
26 namespace op = xla::testing::opcode_matchers;
27 
28 class ArCrsCombinerTest : public HloTestBase {};
29 
// Two separate constants holding the same literal are expected to compare as
// the same value; a constant and the entry parameter are not.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}})
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* entry = module->entry_computation();
  const auto& tuple_operands = entry->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(
      lhs, entry->parameter_instruction(0)));
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
51 
// The same parameter used for both tuple elements is expected to compare as
// the same value.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) {
  %x = f32[] parameter(0)
  ROOT %tuple = (f32[], f32[]) tuple(%x, %x)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
69 
// Two different parameters are expected not to compare as the same value.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %tuple = (f32[], f32[]) tuple(%x, %y)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
88 
// A one-element tuple and a two-element tuple built from the same constant are
// expected not to compare as the same value.
TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple1 = (f32[2,2]) tuple(%constant.f32)
  %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
109 
// Two slices of the same parameter with identical bounds are expected to
// compare as the same value.
TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
  %p = f32[2] parameter(0)
  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
  %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]}
  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
129 
// Two slices of the same parameter with different bounds are expected not to
// compare as the same value.
TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
  %p = f32[2] parameter(0)
  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
  %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]}
  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
149 
// Two get-tuple-elements reading the same index of the same tuple are expected
// to compare as the same value.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
171 
// Get-tuple-elements reading different indices are still expected to compare
// as the same value when both tuple slots hold the same constant.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex1) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
193 
// Get-tuple-elements reading different indices are expected not to compare as
// the same value when the tuple slots hold different constants.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex2) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  const auto& tuple_operands =
      module->entry_computation()->root_instruction()->operands();
  HloInstruction* lhs = tuple_operands[0];
  HloInstruction* rhs = tuple_operands[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
216 
// Both loop-carried elements start from the same constant and the body adds
// the same constant to each, so the two body results are expected to compare
// as the same value.
TEST_F(ArCrsCombinerTest, SameValueTestWhile1) {
  const char* hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* loop = module->entry_computation()->root_instruction();
  const auto& body_results =
      loop->while_body()->root_instruction()->operands();
  HloInstruction* lhs = body_results[0];
  HloInstruction* rhs = body_results[1];
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
253 
// The loop-carried elements start from different constants, so the two body
// results are expected not to compare as the same value even though the body
// adds the same constant to each.
TEST_F(ArCrsCombinerTest, SameValueTestWhile2) {
  const char* hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}})
  %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* loop = module->entry_computation()->root_instruction();
  const auto& body_results =
      loop->while_body()->root_instruction()->operands();
  HloInstruction* lhs = body_results[0];
  HloInstruction* rhs = body_results[1];
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
291 
// The loop starts from identical constants but the body adds a different
// constant to each element, so the two get-tuple-elements feeding the adds are
// expected not to compare as the same value.
TEST_F(ArCrsCombinerTest, SameValueTestWhile3) {
  const char* hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32.2)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* loop = module->entry_computation()->root_instruction();
  const auto& body_results =
      loop->while_body()->root_instruction()->operands();
  // First operand of each add in the body root tuple.
  HloInstruction* gte_1 = body_results[0]->operands()[0];  // %get-tuple-element.1
  HloInstruction* gte_2 = body_results[1]->operands()[0];  // %get-tuple-element.2
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(gte_1, gte_2));
}
329 
// Checks the same-value analysis on the results of a while loop nested inside
// the body of another while loop.
TEST_F(ArCrsCombinerTest, SameValueTestNestedWhile) {
  const char* hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  ROOT %t = pred[] constant(true)
}

%body_inner (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %gte.1 = f32[2,2] get-tuple-element(%x), index=0
  %gte.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%gte.1, %constant.f32)
  %add.2 = f32[2,2] add(%gte.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

%body_outer (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %gte.1 = f32[2,2] get-tuple-element(%x), index=0
  %gte.2 = f32[2,2] get-tuple-element(%x), index=1
  %init = (f32[2,2], f32[2,2]) tuple(%gte.1, %gte.2)
  ROOT %while.1 = (f32[2,2], f32[2,2]) while(%init), condition=%condition,
    body=%body_inner
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition,
    body=%body_outer
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloInstruction* outer_while =
      module->entry_computation()->root_instruction();
  HloInstruction* inner_while = outer_while->while_body()->root_instruction();
  HloInstruction* lhs =
      inner_while->while_body()->root_instruction()->operands()[0];
  HloInstruction* rhs =
      inner_while->while_body()->root_instruction()->operands()[1];
  // The two results match: the constant {{3, 4}, {5, 6}} seeds both elements,
  // and every iteration of the inner while adds the same {{1, 2}, {3, 4}} to
  // each of them.
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
378 
CompareReplicaGroups(absl::Span<const ReplicaGroup> groups_before,absl::Span<const ReplicaGroup> groups_after)379 void CompareReplicaGroups(absl::Span<const ReplicaGroup> groups_before,
380                           absl::Span<const ReplicaGroup> groups_after) {
381   ASSERT_EQ(groups_before.size(), groups_after.size());
382   for (int i = 0; i < groups_before.size(); ++i) {
383     // Somewhat verbose way to compare the replica_ids, because EqualsProto
384     // is not available in the open-source build.
385     auto group_before = groups_before[i];
386     std::vector<int64_t> ids_before(group_before.replica_ids().begin(),
387                                     group_before.replica_ids().end());
388     auto group_after = groups_after[i];
389     std::vector<int64_t> ids_after(group_after.replica_ids().begin(),
390                                    group_after.replica_ids().end());
391     EXPECT_EQ(ids_before, ids_after);
392   }
393 }
394 
// Runs the combiner (non-SPMD, two spatial partitions) on two
// all-reduce -> convert -> all-reduce chains and expects each chain to become
// all-reduce(convert(operand)) with the replica groups of the remaining
// all-reduce unchanged.
TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%convert.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value: a snapshot that stays valid whatever the pass does to the
  // instruction.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // An error status fails this test instead of aborting the whole binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Convert(op::Parameter())),
                        op::AllReduce(op::Convert(op::Constant()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
469 
// SPMD variant: runs the combiner with spmd_partition enabled on a single
// all-reduce -> convert -> all-reduce chain and expects
// all-reduce(convert(parameter)) with unchanged replica groups.
TEST_F(ArCrsCombinerTest, RewriteArConvertCrsSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
  %p = bf16[] parameter(0)
  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16
  %convert.1 = f32[] convert(%all-reduce.ar.1)
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value: a snapshot that stays valid whatever the pass does to the
  // instruction.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // An error status fails this test instead of aborting the whole binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Convert(op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
518 
// Runs the combiner (non-SPMD, two spatial partitions) on two
// all-reduce -> bitcast -> all-reduce chains and expects each chain to become
// all-reduce(bitcast(parameter)) with unchanged replica groups.
TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
  const char* module_str = R"(
HloModule foobar

%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] {
  %a = f32[2,1] parameter(0)
  %b = f32[2,1] parameter(1)
  ROOT %add = f32[2,1] add(%a, %b)
}

%sum.2 (x: f32[2], y: f32[2]) -> f32[2] {
  %x = f32[2] parameter(0)
  %y = f32[2] parameter(1)
  ROOT %add = f32[2] add(%x, %y)
}

ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
  %p = f32[2,1] parameter(0)

  %all-reduce.ar.1 = f32[2,1]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.1,
      sharding={maximal device=0}
  %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1)
  %all-reduce.1 = f32[2]
      all-reduce(%bitcast.1),
      replica_groups={{0,1}},
      to_apply=%sum.2,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[2,1]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.1,
      sharding={maximal device=1}
  %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2)
  %all-reduce.2 = f32[2]
      all-reduce(%bitcast.2),
      replica_groups={{0,1}},
      to_apply=%sum.2,
      sharding={maximal device=1}

  ROOT %tuple = (f32[2], f32[2])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value: a snapshot that stays valid whatever the pass does to the
  // instruction.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // An error status fails this test instead of aborting the whole binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Bitcast(op::Parameter())),
                        op::AllReduce(op::Bitcast(op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
588 
// Runs the combiner (non-SPMD, two spatial partitions) on two
// all-reduce -> multiply-by-constant -> all-reduce chains and expects each
// chain to become all-reduce(multiply(parameter, constant)) with unchanged
// replica groups.
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) {
  const char* module_str = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=0}
  %multiply.1 = f32[]
      multiply(%all-reduce.ar.1, %constant.f32),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%multiply.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=1}
  %multiply.2 = f32[]
      multiply(%all-reduce.ar.2, %constant.f32),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%multiply.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value: a snapshot that stays valid whatever the pass does to the
  // instruction.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // An error status fails this test instead of aborting the whole binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant())),
                op::AllReduce(op::Multiply(op::Parameter(), op::Constant()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
658 
// SPMD variant: runs the combiner with spmd_partition enabled on a single
// all-reduce -> multiply-by-constant -> all-reduce chain and expects
// all-reduce(multiply(parameter, constant)) with unchanged replica groups.
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrsSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
      channel_id=1, to_apply=%sum.f32
  %multiply.1 = f32[] multiply(%all-reduce.ar.1, %constant.f32)
  %all-reduce.1 = f32[] all-reduce(%multiply.1), replica_groups={{0,1}},
      to_apply=%sum.f32, sharding={maximal device=0}
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value: a snapshot that stays valid whatever the pass does to the
  // instruction.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // An error status fails this test instead of aborting the whole binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
700 
// Runs the combiner (non-SPMD, two spatial partitions) on two
// all-reduce -> convert -> add-constant -> all-reduce chains and expects each
// chain to become all-reduce(add(divide(constant, constant), convert(...)))
// with unchanged replica groups.
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32 = f32[] constant(2)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %add.1 = f32[]
      add(%constant.f32, %convert.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %add.2 = f32[]
      add(%constant.f32, %convert.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value: a snapshot that stays valid whatever the pass does to the
  // instruction.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // An error status fails this test instead of aborting the whole binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(
          op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()),
                                op::Convert())),
          op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()),
                                op::Convert()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
786 
// SPMD variant of the AR -> convert -> add -> CRS rewrite.
//
// The entry computation holds a single chain: a bf16 all-reduce with
// channel_id=1 over a constant, a convert to f32, an add with an f32 constant,
// and a final all-reduce over replica group {0,1}. The test expects the pass
// to report a change and the root to become
// tuple(all-reduce(add(divide(constant, constant), convert))). The replica
// groups of the root's first operand, captured before and after the pass, are
// handed to CompareReplicaGroups (defined elsewhere in this file).
//
// NOTE(review): %convert.1 is the only instruction in this module that carries
// a sharding annotation; it may be a leftover from the non-SPMD variant --
// confirm whether it is intentional.
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrsSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32 = f32[] constant(2)

  %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
      channel_id=1, to_apply=%sum.bf16
  %convert.1 = f32[] convert(%all-reduce.ar.1), sharding={maximal device=0}
  %add.1 = f32[] add(%constant.f32, %convert.1)
  %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}},
      to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                  op::Divide(op::Constant(), op::Constant()), op::Convert()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
836 
// Negative test (non-SPMD): the two per-device chains add *different*
// constants to the converted all-reduce result (%constant.f32.1 = 2 on the
// device-0 chain, %constant.f32.2 = 3 on the device-1 chain). The test expects
// the pass to report that nothing changed.
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32.1 = f32[] constant(2)
  %constant.f32.2 = f32[] constant(3)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %add.1 = f32[]
      add(%constant.f32.1, %convert.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %add.2 = f32[]
      add(%constant.f32.2, %convert.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_FALSE(changed);
}
909 
// Negative test (SPMD): the value added to the converted all-reduce result is
// the entry parameter %p rather than a constant. The test expects the pass to
// report that nothing changed -- presumably because a parameter is not known
// to hold the same value on every partition (confirm against the pass).
// %constant.f32.1 is declared in the module but not used by any instruction.
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewriteSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32.1 = f32[] constant(2)

  %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}},
      channel_id=1, to_apply=%sum.bf16
  %convert.1 = f32[] convert(%all-reduce.ar.1)
  %add.1 = f32[] add(%p, %convert.1)
  %all-reduce.1 = f32[] all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_FALSE(changed);
}
948 
// The channel_id all-reduce feeds the cross-replica all-reduce directly, with
// no instruction in between. The multiplies consume the cross-replica
// all-reduce results but are not reachable from the root tuple. The test
// expects the pass to report a change and each root operand to become an
// all-reduce applied straight to the parameter. The replica groups of the
// root's first operand, captured before and after the pass, are handed to
// CompareReplicaGroups (defined elsewhere in this file).
TEST_F(ArCrsCombinerTest, ArThenCrsDontCrash) {
  const char* module_str = R"(
HloModule foobar

%sum.1 (a: f32[], b: f32[]) -> f32[] {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.1,
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%all-reduce.ar.1),
      replica_groups={{0,1}},
      to_apply=%sum.1,
      sharding={maximal device=0}
  %multiply.1 = f32[]
      multiply(%all-reduce.1, %constant.f32),
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.1,
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%all-reduce.ar.2),
      replica_groups={{0,1}},
      to_apply=%sum.1,
      sharding={maximal device=1}
  %multiply.2 = f32[]
      multiply(%all-reduce.2, %constant.f32),
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Parameter()),
                        op::AllReduce(op::Parameter())));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1017 
// Non-SPMD rewrite where two adds sit between the channel_id all-reduce and
// the cross-replica all-reduce: add(%constant.2, add(%constant.1, AR)).
//
// The test expects the pass to report a change and each root operand to become
// all-reduce(add(divide(constant, constant),
//                add(divide(constant, constant), parameter))).
// The replica groups of the root's first operand, captured before and after
// the pass, are handed to CompareReplicaGroups (defined elsewhere in this
// file).
//
// The second chain is annotated with `maximal device=1` so that it agrees with
// the root tuple's sharding ({device=0}, {device=1}) and with the layout used
// by the other two-chain tests in this file.
TEST_F(ArCrsCombinerTest, RewriteMultipleAdds) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add.11 = f32[]
      add(%constant.1, %all-reduce.ar.1),
      sharding={maximal device=0}
  %add.12 = f32[]
      add(%constant.2, %add.11),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %add.21 = f32[]
      add(%constant.1, %all-reduce.ar.2),
      sharding={maximal device=1}
  %add.22 = f32[]
      add(%constant.2, %add.21),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                            op::Divide(op::Constant(), op::Constant()),
                            op::Add(op::Divide(op::Constant(), op::Constant()),
                                    op::Parameter()))),
                        op::AllReduce(op::Add(
                            op::Divide(op::Constant(), op::Constant()),
                            op::Add(op::Divide(op::Constant(), op::Constant()),
                                    op::Parameter())))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1099 
// SPMD variant of RewriteMultipleAdds: a single chain
// all-reduce(channel_id=1) -> add(%constant.1, .) -> add(%constant.2, .) ->
// all-reduce over replica group {0,1}.
//
// The test expects the pass to report a change and the root to become
// tuple(all-reduce(add(divide(constant, constant),
//                      add(divide(constant, constant), parameter)))).
// The replica groups of the root's first operand, captured before and after
// the pass, are handed to CompareReplicaGroups (defined elsewhere in this
// file).
TEST_F(ArCrsCombinerTest, RewriteMultipleAddsSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)

  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
      channel_id=1, to_apply=%sum
  %add.11 = f32[] add(%constant.1, %all-reduce.ar.1)
  %add.12 = f32[] add(%constant.2, %add.11)
  %all-reduce.1 = f32[] all-reduce(%add.12), replica_groups={{0,1}}, to_apply=%sum
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(
                  op::Add(op::Divide(op::Constant(), op::Constant()),
                          op::Add(op::Divide(op::Constant(), op::Constant()),
                                  op::Parameter())))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1144 
// Non-SPMD rewrite where a subtract sits between the channel_id all-reduce and
// the cross-replica all-reduce: subtract(%constant.f32, AR), once per device.
//
// The test expects the pass to report a change and each root operand to become
// all-reduce(subtract(divide(constant, constant), parameter)). The replica
// groups of the root's first operand, captured before and after the pass, are
// handed to CompareReplicaGroups (defined elsewhere in this file).
TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) {
  const char* module_str = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=0}
  %sub.1 = f32[]
      subtract(%constant.f32, %all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%sub.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=1}
  %sub.2 = f32[]
      subtract(%constant.f32, %all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%sub.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(
          op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()),
                                     op::Parameter())),
          op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()),
                                     op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1217 
// SPMD variant of RewriteArSubtractCrs: a single chain
// all-reduce(channel_id=1) -> subtract(%constant.f32, .) -> all-reduce over
// replica group {0,1}.
//
// The test expects the pass to report a change and the root to become
// tuple(all-reduce(subtract(divide(constant, constant), parameter))). The
// replica groups of the root's first operand, captured before and after the
// pass, are handed to CompareReplicaGroups (defined elsewhere in this file).
TEST_F(ArCrsCombinerTest, RewriteArSubtractCrsSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)
  %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}},
      channel_id=1, to_apply=%sum.f32
  %sub.1 = f32[] subtract(%constant.f32, %all-reduce.ar.1)
  %all-reduce.1 = f32[] all-reduce(%sub.1), replica_groups={{0,1}},
      to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::AllReduce(op::Subtract(
          op::Divide(op::Constant(), op::Constant()), op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1259 
// Non-SPMD rewrite with two channel_id all-reduces (channel_id=1 and 2) per
// device feeding one cross-replica all-reduce. The nested add is the *left*
// operand of the outer add: add(add(ar_1, %const1), ar_2).
//
// The test expects the pass to report a change and each root operand to become
// all-reduce(add(add(parameter, divide(constant, constant)), parameter)).
// The replica groups of the root's first operand, captured before and after
// the pass, are handed to CompareReplicaGroups (defined elsewhere in this
// file). %const2 is declared in the module but not used by any instruction.
TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add11 = f32[]
      add(%ar11, %const1),
      sharding={maximal device=0}
  %ar12 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=2,
      to_apply=%sum,
      sharding={maximal device=0}
  %add12 = f32[]
      add(%add11, %ar12),
      sharding={maximal device=0}
  %crs1 = f32[]
      all-reduce(%add12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %ar21 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %add21 = f32[]
      add(%ar21, %const1),
      sharding={maximal device=1}
  %ar22 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=2,
      to_apply=%sum,
      sharding={maximal device=1}
  %add22 = f32[]
      add(%add21, %ar22),
      sharding={maximal device=1}
  %crs2 = f32[]
      all-reduce(%add22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%crs1, %crs2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                            op::Add(op::Parameter(),
                                    op::Divide(op::Constant(), op::Constant())),
                            op::Parameter())),
                        op::AllReduce(op::Add(
                            op::Add(op::Parameter(),
                                    op::Divide(op::Constant(), op::Constant())),
                            op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1353 
// SPMD variant of RewriteMultipleARsLeft: one chain with two channel_id
// all-reduces (channel_id=1 and 2) combined as add(add(ar_1, %const1), ar_2)
// and fed to an all-reduce over replica group {0,1}.
//
// The test expects the pass to report a change and the root to become
// tuple(all-reduce(add(add(parameter, divide(constant, constant)),
//                      parameter))).
// The replica groups of the root's first operand, captured before and after
// the pass, are handed to CompareReplicaGroups (defined elsewhere in this
// file). %const2 is declared in the module but not used by any instruction.
TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeftSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1,
      to_apply=%sum
  %add11 = f32[] add(%ar11, %const1)
  %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2,
      to_apply=%sum
  %add12 = f32[] add(%add11, %ar12)
  %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}},
      to_apply=%sum
  ROOT %tuple = (f32[]) tuple(%crs1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::AllReduce(op::Add(
          op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())),
          op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1401 
// Non-SPMD rewrite with two channel_id all-reduces (channel_id=1 and 2) per
// device feeding one cross-replica all-reduce. The nested add is the *right*
// operand of the outer add: add(ar_1, add(ar_2, %const1)).
//
// The test expects the pass to report a change and each root operand to become
// all-reduce(add(parameter, add(parameter, divide(constant, constant)))).
// The replica groups of the root's first operand, captured before and after
// the pass, are handed to CompareReplicaGroups (defined elsewhere in this
// file). %const2 is declared in the module but not used by any instruction.
TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %ar12 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=2,
      to_apply=%sum,
      sharding={maximal device=0}
  %add11 = f32[]
      add(%ar12, %const1),
      sharding={maximal device=0}
  %add12 = f32[]
      add(%ar11, %add11),
      sharding={maximal device=0}
  %crs1 = f32[]
      all-reduce(%add12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %ar21 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %ar22 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      channel_id=2,
      to_apply=%sum,
      sharding={maximal device=1}
  %add21 = f32[]
      add(%ar22, %const1),
      sharding={maximal device=1}
  %add22 = f32[]
      add(%ar21, %add21),
      sharding={maximal device=1}
  %crs2 = f32[]
      all-reduce(%add22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%crs1, %crs2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::AllReduce(op::Add(
                    op::Parameter(),
                    op::Add(op::Parameter(),
                            op::Divide(op::Constant(), op::Constant())))),
                op::AllReduce(op::Add(
                    op::Parameter(),
                    op::Add(op::Parameter(),
                            op::Divide(op::Constant(), op::Constant()))))));

  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1497 
// SPMD variant of RewriteMultipleARsRight: one chain with two channel_id
// all-reduces (channel_id=1 and 2) combined as add(ar_1, add(ar_2, %const1))
// and fed to an all-reduce over replica group {0,1}.
//
// The test expects the pass to report a change and the root to become
// tuple(all-reduce(add(parameter,
//                      add(parameter, divide(constant, constant))))).
// The replica groups of the root's first operand, captured before and after
// the pass, are handed to CompareReplicaGroups (defined elsewhere in this
// file). %const2 is declared in the module but not used by any instruction.
TEST_F(ArCrsCombinerTest, RewriteMultipleARsRightSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=1, to_apply=%sum
  %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, channel_id=2, to_apply=%sum
  %add11 = f32[] add(%ar12, %const1)
  %add12 = f32[] add(%ar11, %add11)
  %crs1 = f32[] all-reduce(%add12), replica_groups={{0,1}}, to_apply=%sum
  ROOT %tuple = (f32[]) tuple(%crs1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Taken by value before the pass runs so it can be compared afterwards.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                  op::Parameter(),
                  op::Add(op::Parameter(),
                          op::Divide(op::Constant(), op::Constant()))))));

  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
1543 
// Negative test (non-SPMD): the module is parsed with replica_count=1 and
// every all-reduce uses replica_groups={{0}}. Each device has a chain
// all-reduce(channel_id=1) -> convert -> all-reduce; the device-0 chain reduces
// the parameter %p while the device-1 chain reduces %constant.bf16. The test
// expects the pass to report that nothing changed.
TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0}},
      channel_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%convert.2),
      replica_groups={{0}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1));
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/false);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_FALSE(changed);
}
1608 
// Negative test (SPMD): the module is parsed with replica_count=1 and every
// all-reduce uses replica_groups={{0}}. The single chain is
// all-reduce(channel_id=1) over %p -> convert -> all-reduce. The test expects
// the pass to report that nothing changed. %constant.bf16 is declared in the
// module but not used by any instruction.
TEST_F(ArCrsCombinerTest, OneReplicaDontRewriteSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0}},
      channel_id=1, to_apply=%sum.bf16
  %convert.1 = f32[] convert(%all-reduce.ar.1)
  %all-reduce.1 = f32[] all-reduce(%convert.1),
      replica_groups={{0}}, to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1));
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  // Report a failing pass as a test failure carrying its status instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_FALSE(changed);
}
1646 
// Exercises TestInstructionsComputeSameValue on the operands of the root
// tuples of both branches of a conditional whose two branch arguments are the
// same tuple(p, p).
TEST_F(ArCrsCombinerTest, SameValueTestConditional) {
  const char* const hlo_text = R"(
HloModule foobar

branch_true {
  pt = (f32[2,4], f32[2,4]) parameter(0)
  gte.0 = f32[2,4] get-tuple-element(pt), index=0
  gte.1 = f32[2,4] get-tuple-element(pt), index=1
  ROOT tuple.t = (f32[2,4], f32[2,4]) tuple(gte.1, gte.0)
}

branch_false {
  pf = (f32[2,4], f32[2,4]) parameter(0)
  gte.0 = f32[2,4] get-tuple-element(pf), index=0
  gte.1 = f32[2,4] get-tuple-element(pf), index=1
  add = f32[2,4] add(gte.1, gte.1)
  ROOT tuple.f = (f32[2,4], f32[2,4]) tuple(gte.0, add)
}

ENTRY Parameters1.v4 {
  constant = pred[] constant(true)
  p = f32[2,4] parameter(0)
  tuple = (f32[2,4], f32[2,4]) tuple(p, p)
  ROOT conditional = (f32[2,4], f32[2,4]) conditional(constant, tuple, tuple), true_computation=branch_true, false_computation=branch_false
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* conditional = hlo_module->entry_computation()->root_instruction();

  // Root tuples of branch 0 (branch_true) and branch 1 (branch_false).
  auto* true_root = conditional->branch_computation(0)->root_instruction();
  auto* false_root = conditional->branch_computation(1)->root_instruction();

  // branch_true returns tuple(gte.1, gte.0); both elements are read from the
  // tuple(p, p) argument, so the two operands are expected to match.
  auto* true_first = true_root->mutable_operand(0);
  auto* true_second = true_root->mutable_operand(1);
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(true_first, true_second));

  // branch_false returns tuple(gte.0, add) where add = gte.1 + gte.1, so the
  // two operands are expected to differ.
  auto* false_first = false_root->mutable_operand(0);
  auto* false_second = false_root->mutable_operand(1);
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(false_first,
                                                               false_second));
}
1688 
// Runs the combiner (non-SPMD mode) on a module parsed with two replicas in
// which the channel_id all-reduces use replica_groups={{0,1}}, and expects the
// pass to report no change.
TEST_F(ArCrsCombinerTest, AllReduceWithReplicas) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
    to_apply=%sum.f32, sharding={maximal device=0}
  %all-reduce.1 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
    to_apply=%sum.f32, sharding={maximal device=1}
  %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
    to_apply=%sum.f32, sharding={maximal device=0}
  %all-reduce.3 = f32[] all-reduce(%all-reduce.1), replica_groups={{0,1}},
    to_apply=%sum.f32, sharding={maximal device=1}
  ROOT %tuple = (f32[], f32[]) tuple(%all-reduce.2, %all-reduce.3),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  // Parse and verify the module with two replicas.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/false);
  const bool module_changed = pass.Run(hlo_module.get()).ValueOrDie();

  // The pass must report that it left the module untouched.
  EXPECT_FALSE(module_changed);
}
1722 
// SPMD-mode counterpart of the two-replica test: both all-reduces use
// replica_groups={{0},{1}} and the pass is expected to report no change.
TEST_F(ArCrsCombinerTest, AllReduceWithReplicasSPMD) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
  %p = bf16[] parameter(0)
  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0},{1}},
    to_apply=%sum.f32
  %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0},{1}},
    to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.2)
}
)";

  // Parse and verify the module with two replicas.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  ArCrsCombiner pass(/*num_spatial_partitions=*/2, /*spmd_partition=*/true);
  const bool module_changed = pass.Run(hlo_module.get()).ValueOrDie();

  // The pass must report that it left the module untouched.
  EXPECT_FALSE(module_changed);
}
1751 
// Runs the combiner in SPMD mode on a module whose root is an all-reduce of a
// replicated parameter across 32 replicas, and checks that the root becomes
// divide(all-reduce(parameter), broadcast(constant)), where the new all-reduce
// carries a channel id and the constant is all 2s.
TEST_F(ArCrsCombinerTest, ReplaceReplicatedAllReduceSPMD) {
  const char* module_str = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] {
  %p = f32[2,4] parameter(0), sharding={replicated}
  ROOT %all-reduce = f32[2,4] all-reduce(%p), to_apply=%sum.f32,
    replica_groups={{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}}
}
)";

  // Replacing replicated all-reduce is only triggered when there are enough
  // replicas (currently > num_partitions * 8).
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/32));
  ArCrsCombiner combiner(/*num_spatial_partitions=*/2,
                         /*spmd_partition=*/true);
  auto changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Fatal assertion: the operand accesses below are only valid when the root
  // really has the divide(all-reduce, broadcast(constant)) shape. With a
  // non-fatal check, a regression that leaves the original one-operand
  // all-reduce in place would go on to read operand(1) out of range.
  auto root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Divide(op::AllReduce(op::Parameter()),
                               op::Broadcast(op::Constant())));

  auto ar = root->operand(0);
  auto divisor = root->operand(1)->operand(0);
  // The rewritten all-reduce must have a channel id set.
  EXPECT_TRUE(ar->channel_id());
  // The divisor literal must be 2 everywhere (the same value that was passed
  // as num_spatial_partitions above).
  EXPECT_TRUE(divisor->literal().IsAllFloat(2));
}
1788 
// Runs the combiner in SPMD mode on a module parsed with 2 replicas and 4
// partitions, where the channel_id all-reduce sets use_global_device_ids=true
// with replica_groups={{0,1,2,3},{4,5,6,7}}, and expects the pass to report a
// change.
TEST_F(ArCrsCombinerTest, AllReduceWithGlobalIdReplicaGroups) {
  const char* const hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
  %p = bf16[] parameter(0)
  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1,
    replica_groups={{0,1,2,3},{4,5,6,7}}, use_global_device_ids=true,
    to_apply=%sum.f32
  %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
    to_apply=%sum.f32
  ROOT %tuple = (f32[]) tuple(%all-reduce.2)
}
)";

  // Parse and verify the module with two replicas and four partitions.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2,
                                   /*num_partitions=*/4));

  ArCrsCombiner pass(/*num_spatial_partitions=*/4, /*spmd_partition=*/true);
  const bool module_changed = pass.Run(hlo_module.get()).ValueOrDie();

  // The pass must report that it rewrote the module.
  EXPECT_TRUE(module_changed);
}
1819 
1820 }  // namespace
1821 }  // namespace xla
1822