xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
19 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
23 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
24 #include "tensorflow/compiler/xla/service/sharding_propagation.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 
30 namespace xla {
31 namespace spmd {
32 namespace {
33 
34 using ::testing::_;
35 using ::testing::AllOf;
36 namespace op = xla::testing::opcode_matchers;
37 
38 class SpmdPartitioningTest : public HloTestBase {
39  public:
PartitionComputation(absl::string_view hlo_module,int64_t num_devices,bool conv_halo_exchange_always_on_lhs=true,bool choose_faster_windowed_einsum=false,bool unroll_windowed_einsum=false,bool bidirectional_windowed_einsum=false,int64_t threshold_for_windowed_einsum_mib=-1)40   StatusOr<std::unique_ptr<HloModule>> PartitionComputation(
41       absl::string_view hlo_module, int64_t num_devices,
42       bool conv_halo_exchange_always_on_lhs = true,
43       bool choose_faster_windowed_einsum = false,
44       bool unroll_windowed_einsum = false,
45       bool bidirectional_windowed_einsum = false,
46       int64_t threshold_for_windowed_einsum_mib = -1) {
47     // Some tests (BackpropFilter convs) set this flag false to test two
48     // different paths of the implementation.
49     SpmdPartitionerOptions options;
50     options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs;
51     options.allow_module_signature_change = true;
52     options.choose_faster_windowed_einsum_over_mem =
53         choose_faster_windowed_einsum;
54     options.unroll_windowed_einsum = unroll_windowed_einsum;
55     options.bidirectional_windowed_einsum = bidirectional_windowed_einsum;
56     if (threshold_for_windowed_einsum_mib >= 0) {
57       options.threshold_for_windowed_einsum_mib =
58           threshold_for_windowed_einsum_mib;
59     }
60     auto collective_ops_creator =
61         GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1);
62     // Do not use all-gather for pattern-matching purpose, as the partitioner
63     // might create reshape/transposes around it.
64     collective_ops_creator.create_cross_partition_all_gather = nullptr;
65 
66     TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(
67                                          hlo_module, GetModuleConfigForTest()));
68     HloPassPipeline pass("spmd-partitioning");
69     pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
70                               /*allow_mixed_precision=*/false);
71     pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options,
72                                   collective_ops_creator);
73     pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
74                               /*allow_mixed_precision=*/false);
75     TF_RETURN_IF_ERROR(pass.Run(module.get()).status());
76     return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
77   }
78 };
79 
TEST_F(SpmdPartitioningTest,InvalidSharding)80 TEST_F(SpmdPartitioningTest, InvalidSharding) {
81   absl::string_view hlo_string = R"(
82 HloModule module
83 
84 ENTRY entry {
85   token0 = token[] after-all(), sharding={maximal device=0}
86   infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
87     sharding={{devices=[2,1]0,1}, {maximal device=0}}
88   ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
89     sharding={maximal device=0}
90 })";
91   auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4);
92   EXPECT_FALSE(module_status.status().ok());
93   EXPECT_THAT(module_status.status().ToString(),
94               ::testing::HasSubstr(
95                   "only supports tile sharding that includes all partitions"));
96 }
97 
TEST_F(SpmdPartitioningTest,SingleDeviceToReplicated)98 TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) {
99   absl::string_view hlo_string = R"(
100 HloModule module
101 
102 ENTRY entry {
103   %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
104     sharding={maximal device=0}
105   ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
106 })";
107   TF_ASSERT_OK_AND_ASSIGN(auto module,
108                           PartitionComputation(hlo_string, /*num_devices=*/2));
109   VLOG(1) << module->ToString();
110   HloInstruction* root = module->entry_computation()->root_instruction();
111   EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(
112                               op::Select(op::Broadcast(op::Compare()),
113                                          op::Constant(), op::Broadcast()))),
114                           op::Shape("s32[2,3]")));
115 }
116 
TEST_F(SpmdPartitioningTest,SingleDeviceCustomCall)117 TEST_F(SpmdPartitioningTest, SingleDeviceCustomCall) {
118   absl::string_view hlo_string = R"(
119 HloModule module
120 
121 ENTRY entry {
122   %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
123     sharding={maximal device=0}
124   %cc = s32[2,3] custom-call(%constant), custom_call_target="SomeCustomCall",
125     sharding={maximal device=0}
126   ROOT %copy = s32[2,3]{1,0} copy(%cc), sharding={replicated}
127 })";
128   TF_ASSERT_OK_AND_ASSIGN(auto module,
129                           PartitionComputation(hlo_string, /*num_devices=*/2));
130   VLOG(1) << module->ToString();
131   HloInstruction* custom_call = FindInstruction(module.get(), "cc.1");
132   EXPECT_NE(custom_call, nullptr);
133   EXPECT_NE(custom_call->parent(), module->entry_computation());
134   HloInstruction* root = module->entry_computation()->root_instruction();
135   EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(
136                               op::Select(op::Broadcast(op::Compare()),
137                                          op::Conditional(), op::Broadcast()))),
138                           op::Shape("s32[2,3]")));
139 }
140 
TEST_F(SpmdPartitioningTest,SingleDeviceToSingleDevice)141 TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) {
142   absl::string_view hlo_string = R"(
143 HloModule module
144 
145 ENTRY entry {
146   %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
147     sharding={maximal device=0}
148   ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=1}
149 })";
150   TF_ASSERT_OK_AND_ASSIGN(auto module,
151                           PartitionComputation(hlo_string, /*num_devices=*/2));
152   HloInstruction* root = module->entry_computation()->root_instruction();
153   VLOG(1) << module->ToString();
154   EXPECT_THAT(root, op::Copy(AllOf(op::Copy(op::AllReduce(op::Select(
155                                        op::Broadcast(op::Compare()),
156                                        op::Constant(), op::Broadcast()))),
157                                    op::Shape("s32[2,3]"))));
158 }
159 
TEST_F(SpmdPartitioningTest,SingleDeviceToTiled)160 TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) {
161   absl::string_view hlo_string = R"(
162 HloModule module
163 
164 ENTRY entry {
165   %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
166     sharding={maximal device=0}
167   ROOT %copy = s32[2,3]{1,0} copy(%constant),
168     sharding={devices=[2,1]1,0}
169 })";
170   TF_ASSERT_OK_AND_ASSIGN(auto module,
171                           PartitionComputation(hlo_string, /*num_devices=*/2));
172   VLOG(1) << module->ToString();
173   HloInstruction* root = module->entry_computation()->root_instruction();
174   EXPECT_THAT(
175       root,
176       AllOf(
177           op::Copy(op::DynamicSlice(
178               op::AllReduce(op::Select(
179                   op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
180                   op::Constant(), op::Broadcast())),
181               op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
182               op::Constant())),
183           op::Shape("s32[1,3]")));
184 }
185 
TEST_F(SpmdPartitioningTest,TiledToReplicated)186 TEST_F(SpmdPartitioningTest, TiledToReplicated) {
187   absl::string_view hlo_string = R"(
188 HloModule module
189 
190 ENTRY entry {
191   %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
192     sharding={devices=[2,1]0,1}
193   ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
194 })";
195   TF_ASSERT_OK_AND_ASSIGN(auto module,
196                           PartitionComputation(hlo_string, /*num_devices=*/2));
197   HloInstruction* root = module->entry_computation()->root_instruction();
198   EXPECT_THAT(
199       root,
200       op::Copy(op::AllReduce(AllOf(
201           op::DynamicUpdateSlice(
202               op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")),
203               op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
204               op::Constant()),
205           op::Shape("s32[2,3]")))));
206 }
207 
TEST_F(SpmdPartitioningTest,TiledToSingleDevice)208 TEST_F(SpmdPartitioningTest, TiledToSingleDevice) {
209   absl::string_view hlo_string = R"(
210 HloModule module
211 
212 ENTRY entry {
213   %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
214     sharding={devices=[2,1]0,1}
215   ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0}
216 })";
217   TF_ASSERT_OK_AND_ASSIGN(auto module,
218                           PartitionComputation(hlo_string, /*num_devices=*/2));
219   HloInstruction* root = module->entry_computation()->root_instruction();
220   EXPECT_THAT(
221       root,
222       op::Copy(op::Copy(op::AllReduce(AllOf(
223           op::DynamicUpdateSlice(
224               op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")),
225               op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
226               op::Constant()),
227           op::Shape("s32[2,3]"))))));
228 }
229 
TEST_F(SpmdPartitioningTest,TiledToTiledEven)230 TEST_F(SpmdPartitioningTest, TiledToTiledEven) {
231   absl::string_view hlo_string = R"(
232 HloModule module
233 
234 ENTRY entry {
235   %param= s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1}
236   ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1}
237 })";
238   TF_ASSERT_OK_AND_ASSIGN(auto module,
239                           PartitionComputation(hlo_string, /*num_devices=*/2));
240   VLOG(1) << module->ToString();
241 
242   HloInstruction* root = module->entry_computation()->root_instruction();
243   EXPECT_THAT(
244       root,
245       AllOf(op::Copy(op::Reshape(op::Transpose(op::AllToAll(AllOf(
246                 op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]")))))),
247             op::Shape("s32[8,1]")));
248 }
249 
TEST_F(SpmdPartitioningTest,TiledToTiledUneven)250 TEST_F(SpmdPartitioningTest, TiledToTiledUneven) {
251   absl::string_view hlo_string = R"(
252 HloModule module
253 
254 ENTRY entry {
255   %param= f32[7,31,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1}
256   ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1}
257 })";
258   TF_ASSERT_OK_AND_ASSIGN(auto module,
259                           PartitionComputation(hlo_string, /*num_devices=*/2));
260   VLOG(1) << module->ToString();
261 
262   HloInstruction* root = module->entry_computation()->root_instruction();
263   EXPECT_THAT(
264       root,
265       AllOf(op::Copy(op::Slice(op::Reshape(AllOf(op::Transpose(op::AllToAll(
266           op::Reshape(AllOf(op::Pad(), op::Shape("f32[8,16,128]")))))))))));
267 }
268 
TEST_F(SpmdPartitioningTest,GetTupleElementSwapDevice)269 TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) {
270   absl::string_view hlo_string = R"(
271 HloModule module
272 
273 ENTRY entry {
274   %param.0 = (f32[2,3]{1,0}, u32[]) parameter(0),
275     sharding={{maximal device=1}, {maximal device=1}}
276   %gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0,
277     sharding={maximal device=0}
278   %gte.1 = u32[] get-tuple-element(%param.0), index=1,
279     sharding={maximal device=0}
280   ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1),
281     sharding={{maximal device=0},{maximal device=0}}
282 })";
283   TF_ASSERT_OK_AND_ASSIGN(auto module,
284                           PartitionComputation(hlo_string, /*num_devices=*/2));
285   VLOG(1) << module->ToString();
286   HloInstruction* root = module->entry_computation()->root_instruction();
287   ASSERT_THAT(root, op::Tuple());
288 
289   EXPECT_THAT(root->operand(0),
290               op::Copy(op::AllReduce(op::Select(
291                   op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
292                   op::GetTupleElement(op::Parameter()), op::Broadcast()))));
293   EXPECT_THAT(root->operand(1),
294               op::Copy(op::AllReduce(op::Select(
295                   op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
296                   op::GetTupleElement(op::Parameter()), op::Broadcast()))));
297 }
298 
TEST_F(SpmdPartitioningTest,GetTupleElementTiled)299 TEST_F(SpmdPartitioningTest, GetTupleElementTiled) {
300   absl::string_view hlo_string = R"(
301 HloModule module
302 
303 ENTRY entry {
304   param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0),
305     sharding={{replicated}, {replicated}}
306   gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0,
307     sharding={devices=[2,1]0,1}
308   gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1,
309     sharding={devices=[2,1]0,1}
310   ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1),
311     sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}}
312 })";
313   TF_ASSERT_OK_AND_ASSIGN(auto module,
314                           PartitionComputation(hlo_string, /*num_devices=*/2));
315   VLOG(1) << module->ToString();
316   HloInstruction* root = module->entry_computation()->root_instruction();
317   ASSERT_THAT(root, op::Tuple());
318 
319   auto offset =
320       op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
321 
322   EXPECT_THAT(root->operand(0),
323               op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset,
324                                op::Constant()));
325   EXPECT_THAT(root->operand(1),
326               op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset,
327                                op::Constant()));
328 }
329 
TEST_F(SpmdPartitioningTest,TiledInfeed)330 TEST_F(SpmdPartitioningTest, TiledInfeed) {
331   absl::string_view hlo_string = R"(
332 HloModule module
333 
334 ENTRY entry {
335   token0 = token[] after-all(), sharding={maximal device=0}
336   infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
337     sharding={{devices=[2,1]0,1}, {maximal device=0}}
338   ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
339     sharding={maximal device=0}
340 })";
341   TF_ASSERT_OK_AND_ASSIGN(auto module,
342                           PartitionComputation(hlo_string, /*num_devices=*/2));
343   HloInstruction* root = module->entry_computation()->root_instruction();
344   EXPECT_THAT(
345       root,
346       op::Copy(op::AllReduce(op::DynamicUpdateSlice(
347           op::Broadcast(),
348           op::GetTupleElement(
349               AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))),
350           op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
351           op::Constant()))));
352 }
353 
TEST_F(SpmdPartitioningTest,UnevenTiledInfeed)354 TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) {
355   absl::string_view hlo_string = R"(
356 HloModule module
357 
358 ENTRY entry {
359   token0 = token[] after-all(), sharding={maximal device=0}
360   infeed = (f32[9,2]{1,0}, token[]) infeed(token0),
361     sharding={{devices=[2,1]0,1}, {maximal device=0}}
362   ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0,
363     sharding={devices=[2,1]0,1}
364 })";
365   TF_ASSERT_OK_AND_ASSIGN(auto module,
366                           PartitionComputation(hlo_string, /*num_devices=*/2));
367   VLOG(1) << module->ToString();
368   HloInstruction* root = module->entry_computation()->root_instruction();
369   EXPECT_THAT(
370       root, AllOf(op::Shape("f32[5,2]"), op::GetTupleElement(op::Conditional(
371                                              op::Convert(op::PartitionId()),
372                                              op::AfterAll(), op::AfterAll()))));
373   EXPECT_THAT(
374       root->operand(0)->called_computations()[0]->root_instruction(),
375       AllOf(op::Shape("(f32[5,2], token[])"), op::Infeed(op::Parameter())));
376   auto second_infeed =
377       AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter()));
378   EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(),
379               AllOf(op::Shape("(f32[5,2], token[])"),
380                     op::Tuple(op::Pad(op::GetTupleElement(second_infeed),
381                                       op::Constant()),
382                               op::GetTupleElement(second_infeed))));
383 }
384 
TEST_F(SpmdPartitioningTest,UnevenTiledTupleInfeed)385 TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) {
386   absl::string_view hlo_string = R"(
387 HloModule module
388 
389 ENTRY entry {
390   token0 = token[] after-all(), sharding={maximal device=0}
391   infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
392     sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}}
393   ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
394     index=0, sharding={{devices=[2,1]0,1}, {replicated}}
395 })";
396   TF_ASSERT_OK_AND_ASSIGN(auto module,
397                           PartitionComputation(hlo_string, /*num_devices=*/2));
398   VLOG(1) << module->ToString();
399   HloInstruction* root = module->entry_computation()->root_instruction();
400   EXPECT_THAT(root, AllOf(op::Shape("(f32[5,2], f32[2])"),
401                           op::GetTupleElement(op::Conditional(
402                               op::Convert(op::PartitionId()), op::AfterAll(),
403                               op::AfterAll()))));
404   EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(),
405               AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
406                     op::Infeed(op::Parameter())));
407   auto second_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"),
408                              op::Infeed(op::Parameter()));
409   EXPECT_THAT(
410       root->operand(0)->called_computations()[1]->root_instruction(),
411       AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
412             op::Tuple(op::Tuple(op::Pad(op::GetTupleElement(
413                                             op::GetTupleElement(second_infeed)),
414                                         op::Constant()),
415                                 op::GetTupleElement(
416                                     op::GetTupleElement(second_infeed))),
417                       op::GetTupleElement(second_infeed))));
418 }
419 
TEST_F(SpmdPartitioningTest,MixedTupleInfeed)420 TEST_F(SpmdPartitioningTest, MixedTupleInfeed) {
421   absl::string_view hlo_string = R"(
422 HloModule module
423 
424 ENTRY entry {
425   token0 = token[] after-all(), sharding={maximal device=0}
426   infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
427     sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}}
428   ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
429     index=0, sharding={{maximal device=0}, {maximal device=1}}
430 })";
431   TF_ASSERT_OK_AND_ASSIGN(auto module,
432                           PartitionComputation(hlo_string, /*num_devices=*/2));
433   VLOG(1) << module->ToString();
434   HloInstruction* root = module->entry_computation()->root_instruction();
435   EXPECT_THAT(root, AllOf(op::Shape("(f32[9,2], f32[2])"),
436                           op::GetTupleElement(op::Conditional(
437                               op::Convert(op::PartitionId()), op::AfterAll(),
438                               op::AfterAll()))));
439   auto first_infeed = AllOf(op::Shape("((f32[9,2], ()), token[])"),
440                             op::Infeed(op::Parameter()));
441   EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(),
442               AllOf(op::Shape("((f32[9,2], f32[2]), token[])"),
443                     op::Tuple(op::Tuple(op::GetTupleElement(
444                                             op::GetTupleElement(first_infeed)),
445                                         op::Broadcast(op::Constant())),
446                               op::GetTupleElement(first_infeed))));
447   auto second_infeed =
448       AllOf(op::Shape("(((), f32[2]), token[])"), op::Infeed(op::Parameter()));
449   EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(),
450               AllOf(op::Shape("((f32[9,2], f32[2]), token[])"),
451                     op::Tuple(op::Tuple(op::Broadcast(op::Constant()),
452                                         op::GetTupleElement(op::GetTupleElement(
453                                             second_infeed))),
454                               op::GetTupleElement(second_infeed))));
455 }
456 
TEST_F(SpmdPartitioningTest,TiledToReplicatedReduce)457 TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) {
458   absl::string_view hlo_string = R"(
459 HloModule module
460 
461 sum {
462   a = f32[] parameter(0)
463   b = f32[] parameter(1)
464   ROOT add = f32[] add(a, b)
465 }
466 
467 ENTRY entry {
468   constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
469     sharding={devices=[2,1]0,1}
470   constant.1 = f32[] constant(0), sharding={replicated}
471   ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1},
472     to_apply=sum, sharding={replicated}
473 })";
474   TF_ASSERT_OK_AND_ASSIGN(auto module,
475                           PartitionComputation(hlo_string, /*num_devices=*/2));
476   VLOG(1) << module->ToString();
477   HloInstruction* root = module->entry_computation()->root_instruction();
478   EXPECT_THAT(
479       root,
480       op::AllReduce(op::Reduce(
481           op::Select(
482               op::Compare(op::Add(op::Iota(), op::Broadcast(op::Reshape())),
483                           op::Broadcast(op::Constant())),
484               AllOf(op::Shape("f32[2,3]{1,0}"),
485                     op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
486                                      op::Reshape(), op::Constant())),
487               op::Broadcast(op::Constant())),
488           op::Constant())));
489 }
490 
TEST_F(SpmdPartitioningTest,TiledElementwise)491 TEST_F(SpmdPartitioningTest, TiledElementwise) {
492   absl::string_view hlo_string = R"(
493 HloModule module
494 
495 ENTRY entry {
496   constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
497     sharding={devices=[2,1]0,1}
498   constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}),
499     sharding={replicated}
500   multiply = f32[3,3]{1,0} multiply(constant, constant.1),
501     sharding={devices=[2,1]0,1}
502   ROOT add = f32[3,3]{1,0} add(multiply, constant.1),
503     sharding={devices=[2,1]0,1}
504 })";
505   TF_ASSERT_OK_AND_ASSIGN(auto module,
506                           PartitionComputation(hlo_string, /*num_devices=*/2));
507   VLOG(1) << module->ToString();
508   HloInstruction* root = module->entry_computation()->root_instruction();
509   EXPECT_THAT(
510       root,
511       AllOf(
512           op::Shape("f32[2,3]{1,0}"),
513           op::Add(op::Multiply(
514                       op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
515                                        op::Reshape(), op::Constant()),
516                       op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
517                                        op::Reshape(), op::Constant())),
518                   op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
519                                    op::Reshape(), op::Constant()))));
520 }
521 
TEST_F(SpmdPartitioningTest,TiledAllReduce)522 TEST_F(SpmdPartitioningTest, TiledAllReduce) {
523   absl::string_view hlo_string = R"(
524 HloModule module
525 
526 sum {
527   a = f32[] parameter(0)
528   b = f32[] parameter(1)
529   ROOT add = f32[] add(a, b)
530 }
531 
532 ENTRY entry {
533   parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1}
534   ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum,
535     replica_groups={}, sharding={devices=[2,1]0,1}
536 })";
537   TF_ASSERT_OK_AND_ASSIGN(auto module,
538                           PartitionComputation(hlo_string, /*num_devices=*/2));
539   VLOG(1) << module->ToString();
540   HloInstruction* root = module->entry_computation()->root_instruction();
541   EXPECT_THAT(
542       root, AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0))));
543 }
544 
TEST_F(SpmdPartitioningTest,BroadcastOnlyNewDimsSharded)545 TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) {
546   absl::string_view hlo_string = R"(
547 HloModule module
548 
549 ENTRY entry {
550   constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
551     sharding={replicated}
552   ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
553     sharding={devices=[2,1,1]0,1}
554 })";
555   TF_ASSERT_OK_AND_ASSIGN(auto module,
556                           PartitionComputation(hlo_string, /*num_devices=*/2));
557   VLOG(1) << module->ToString();
558   HloInstruction* root = module->entry_computation()->root_instruction();
559   EXPECT_THAT(root, AllOf(op::Shape("f32[2,4,3]{2,1,0}"),
560                           op::Broadcast(op::Constant())));
561 }
562 
TEST_F(SpmdPartitioningTest,BroadcastOnlyOldDimsSharded)563 TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) {
564   absl::string_view hlo_string = R"(
565 HloModule module
566 
567 ENTRY entry {
568   constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
569     sharding={replicated}
570   ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
571     sharding={devices=[1,2,1]0,1}
572 })";
573   TF_ASSERT_OK_AND_ASSIGN(auto module,
574                           PartitionComputation(hlo_string, /*num_devices=*/2));
575   VLOG(1) << module->ToString();
576   HloInstruction* root = module->entry_computation()->root_instruction();
577   EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
578                           op::Broadcast(op::DynamicSlice(
579                               op::Constant(), op::Reshape(), op::Constant()))));
580 }
581 
TEST_F(SpmdPartitioningTest,BroadcastBothOldAndNewDimsSharded)582 TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) {
583   absl::string_view hlo_string = R"(
584 HloModule module
585 
586 ENTRY entry {
587   constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
588     sharding={replicated}
589   ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
590     sharding={devices=[2,2,1]0,1,2,3}
591 })";
592   TF_ASSERT_OK_AND_ASSIGN(auto module,
593                           PartitionComputation(hlo_string, /*num_devices=*/4));
594   VLOG(1) << module->ToString();
595   HloInstruction* root = module->entry_computation()->root_instruction();
596   EXPECT_THAT(
597       root,
598       AllOf(op::Shape("f32[2,2,3]{2,1,0}"),
599             op::Broadcast(AllOf(op::Shape("f32[2,3]{1,0}"),
600                                 op::DynamicSlice(op::Constant(), op::Reshape(),
601                                                  op::Constant())))));
602 }
603 
TEST_F(SpmdPartitioningTest,BroadcastBothOldAndNewDimsShardedPartiallySharded)604 TEST_F(SpmdPartitioningTest,
605        BroadcastBothOldAndNewDimsShardedPartiallySharded) {
606   absl::string_view hlo_string = R"(
607 HloModule module
608 
609 ENTRY entry {
610   param = f32[4,3] parameter(0),
611     sharding={devices=[1,2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
612   ROOT broadcast = f32[4,4,3] broadcast(param), dimensions={1,2},
613     sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
614 })";
615   TF_ASSERT_OK_AND_ASSIGN(auto module,
616                           PartitionComputation(hlo_string, /*num_devices=*/8));
617   VLOG(1) << module->ToString();
618   HloInstruction* root = module->entry_computation()->root_instruction();
619   EXPECT_THAT(
620       root,
621       AllOf(op::Shape("f32[2,4,2]"),
622             op::Broadcast(AllOf(op::Shape("f32[4,2]"), op::Parameter(0)))));
623 }
624 
TEST_F(SpmdPartitioningTest,ConvWithParallelDimAndNonParallelSpatialDimPartitioned)625 TEST_F(SpmdPartitioningTest,
626        ConvWithParallelDimAndNonParallelSpatialDimPartitioned) {
627   absl::string_view hlo_string = R"(
628 HloModule module
629 
630 ENTRY entry {
631   %lhs = f32[32,12,12,24,32] parameter(0)
632   %lhs.copy = f32[32,12,12,24,32] copy(%lhs),
633     sharding={devices=[2,2,1,1,1]0,1,2,3}
634   %rhs = f32[32,6,6,16,32] parameter(1)
635   %rhs.copy = f32[32,6,6,16,32] copy(%rhs),
636     sharding={devices=[2,2,1,1,1]0,1,2,3}
637   ROOT %conv = f32[32,7,7,24,16] convolution(%lhs.copy, %rhs.copy),
638     dim_labels=012bf_012oi->012bf,
639     window={size=32x6x6 stride=31x1x1 lhs_dilate=32x1x1},
640     sharding={devices=[2,2,1,1,1]0,1,2,3}
641 })";
642 
643   TF_ASSERT_OK_AND_ASSIGN(auto module,
644                           PartitionComputation(hlo_string, /*num_devices=*/4));
645   VLOG(1) << module->ToString();
646   const auto root = module->entry_computation()->root_instruction();
647   const auto lhs = AllOf(op::Copy(op::DynamicSlice(
648                              op::Parameter(), op::Reshape(), op::Reshape(),
649                              op::Constant(), op::Constant(), op::Constant())),
650                          op::Shape("f32[16,6,12,24,32]"));
651   const auto rhs = AllOf(op::Copy(op::DynamicSlice(
652                              op::Parameter(), op::Reshape(), op::Reshape(),
653                              op::Constant(), op::Constant(), op::Constant())),
654                          op::Shape("f32[16,3,6,16,32]"));
655   auto resharded_rhs =
656       AllOf(op::Shape("f32[16,6,6,16,32]"),
657             op::AllReduce(op::DynamicUpdateSlice(
658                 op::Broadcast(), rhs, op::Constant(), op::Reshape(),
659                 op::Constant(), op::Constant(), op::Constant())));
660 
661   auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
662                          op::Shape("f32[16,2,12,24,32]"));
663   auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
664                           op::Shape("f32[16,3,12,24,32]"));
665   EXPECT_THAT(
666       root,
667       AllOf(op::Convolution(
668                 op::Select(op::Compare(),
669                            op::DynamicSlice(
670                                op::Concatenate(left_halo, lhs, right_halo),
671                                op::Constant(), op::Add(), op::Constant(),
672                                op::Constant(), op::Constant()),
673                            op::Broadcast()),
674                 resharded_rhs),
675             op::Shape("f32[16,4,7,24,16]")));
676 }
677 
TEST_F(SpmdPartitioningTest,BroadcastPropagateTiledSharding)678 TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) {
679   absl::string_view hlo_string = R"(
680 HloModule module
681 
682 ENTRY entry {
683   constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}),
684     sharding={devices=[2,1]0,1}
685   ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
686     sharding={devices=[1,2,1]0,1}
687 })";
688   TF_ASSERT_OK_AND_ASSIGN(auto module,
689                           PartitionComputation(hlo_string, /*num_devices=*/2));
690   VLOG(1) << module->ToString();
691   HloInstruction* root = module->entry_computation()->root_instruction();
692   EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
693                           op::Broadcast(op::DynamicSlice(
694                               op::Constant(), op::Reshape(), op::Constant()))));
695 }
696 
TEST_F(SpmdPartitioningTest,OutfeedSingleDevice)697 TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) {
698   absl::string_view hlo_string = R"(
699 HloModule module
700 
701 ENTRY entry {
702   token.0 = token[] after-all()
703   data = f32[1024]{0} parameter(0), sharding={maximal device=0}
704   outfeed = token[] outfeed(data, token.0), sharding={maximal device=0}
705 })";
706   TF_ASSERT_OK_AND_ASSIGN(auto module,
707                           PartitionComputation(hlo_string, /*num_devices=*/2));
708   VLOG(1) << module->ToString();
709   HloInstruction* root = module->entry_computation()->root_instruction();
710   EXPECT_THAT(root, AllOf(op::Shape("token[]"),
711                           op::Conditional(
712                               op::Compare(op::PartitionId(), op::Constant()),
713                               op::Tuple(op::Parameter(0), op::AfterAll()),
714                               op::Tuple(op::Parameter(0), op::AfterAll()))));
715 
716   HloInstruction* root_b0 = root->branch_computation(0)->root_instruction();
717   EXPECT_THAT(root_b0,
718               AllOf(op::Shape("token[]"),
719                     op::Outfeed(op::GetTupleElement(op::Parameter(), 0),
720                                 op::GetTupleElement(op::Parameter(), 1))));
721 
722   HloInstruction* root_b1 = root->branch_computation(1)->root_instruction();
723   EXPECT_THAT(root_b1, AllOf(op::Shape("token[]"), op::AfterAll()));
724 }
725 
TEST_F(SpmdPartitioningTest,OutfeedEvenlyTiled)726 TEST_F(SpmdPartitioningTest, OutfeedEvenlyTiled) {
727   absl::string_view hlo_string = R"(
728 HloModule module
729 
730 ENTRY entry {
731   token.0 = token[] after-all()
732   data = f32[1024]{0} parameter(0), sharding={devices=[2]0,1}
733   ROOT outfeed = token[] outfeed(data, token.0), sharding={devices=[2]0,1}
734 })";
735   TF_ASSERT_OK_AND_ASSIGN(auto module,
736                           PartitionComputation(hlo_string, /*num_devices=*/2));
737   VLOG(1) << module->ToString();
738   HloInstruction* root = module->entry_computation()->root_instruction();
739   EXPECT_THAT(root, AllOf(op::Shape("token[]"),
740                           op::Outfeed(op::Parameter(), op::AfterAll())));
741 }
742 
TEST_F(SpmdPartitioningTest,OutfeedTupleEvenlyTiled)743 TEST_F(SpmdPartitioningTest, OutfeedTupleEvenlyTiled) {
744   absl::string_view hlo_string = R"(
745 HloModule module
746 
747 ENTRY entry {
748   token.0 = token[] after-all()
749   data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
750     {devices=[2]0,1}}
751   ROOT outfeed = token[] outfeed(data, token.0),
752     outfeed_shape=(f32[1024,2]{0,1}, f32[2]{0}), sharding={{devices=[2,1]0,1},
753     {devices=[2]0,1}}
754 })";
755   TF_ASSERT_OK_AND_ASSIGN(auto module,
756                           PartitionComputation(hlo_string, /*num_devices=*/2));
757   VLOG(1) << module->ToString();
758   HloInstruction* root = module->entry_computation()->root_instruction();
759   EXPECT_THAT(root, AllOf(op::Shape("token[]"),
760                           op::Outfeed(op::Parameter(), op::AfterAll())));
761   auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
762   auto expected_layout1 = LayoutUtil::MakeLayout({0});
763   EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(0).layout(),
764                                 expected_layout0));
765   EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(1).layout(),
766                                 expected_layout1));
767 }
768 
TEST_F(SpmdPartitioningTest,OutfeedReplicated)769 TEST_F(SpmdPartitioningTest, OutfeedReplicated) {
770   absl::string_view hlo_string = R"(
771 HloModule module
772 
773 ENTRY entry {
774   token.0 = token[] after-all()
775   data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
776     {replicated}}
777   ROOT outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
778     {replicated}}
779 })";
780   TF_ASSERT_OK_AND_ASSIGN(auto module,
781                           PartitionComputation(hlo_string, /*num_devices=*/2));
782   VLOG(1) << module->ToString();
783   HloInstruction* root = module->entry_computation()->root_instruction();
784   EXPECT_THAT(root, AllOf(op::Shape("token[]"),
785                           op::Outfeed(op::Parameter(), op::AfterAll())));
786 }
787 
TEST_F(SpmdPartitioningTest,OutfeedUnevenlyTiled)788 TEST_F(SpmdPartitioningTest, OutfeedUnevenlyTiled) {
789   absl::string_view hlo_string = R"(
790 HloModule module
791 
792 ENTRY entry {
793   token.0 = token[] after-all()
794   data = (f32[1023,2]{1,0}, f32[3]{0}) parameter(0), sharding={{devices=[2,1]0,1},
795     {devices=[2]0,1}}
796   outfeed = token[] outfeed(data, token.0),
797     outfeed_shape=(f32[1023,2]{0,1}, f32[3]{0}), sharding={{devices=[2,1]0,1},
798     {devices=[2]0,1}}
799 })";
800   TF_ASSERT_OK_AND_ASSIGN(auto module,
801                           PartitionComputation(hlo_string, /*num_devices=*/2));
802   VLOG(1) << module->ToString();
803 
804   HloInstruction* root = module->entry_computation()->root_instruction();
805   EXPECT_THAT(
806       root, AllOf(op::Shape("token[]"),
807                   op::Conditional(op::Convert(),
808                                   op::Tuple(op::Parameter(), op::AfterAll()),
809                                   op::Tuple(op::Parameter(), op::AfterAll()))));
810 
811   auto first_outfeed =
812       AllOf(op::Shape("(f32[512,2], f32[2])"), op::GetTupleElement());
813   EXPECT_THAT(root->called_computations()[0]->root_instruction(),
814               AllOf(op::Shape("token[]"),
815                     op::Outfeed(first_outfeed, op::GetTupleElement())));
816 
817   auto second_outfeed = AllOf(op::Shape("(f32[511,2], f32[1])"), op::Tuple());
818   EXPECT_THAT(root->called_computations()[1]->root_instruction(),
819               AllOf(op::Shape("token[]"),
820                     op::Outfeed(second_outfeed, op::GetTupleElement())));
821 
822   auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
823   auto expected_layout1 = LayoutUtil::MakeLayout({0});
824   auto first_outfeed_instr = root->called_computations()[0]->root_instruction();
825   auto second_outfeed_instr =
826       root->called_computations()[1]->root_instruction();
827   EXPECT_TRUE(LayoutUtil::Equal(
828       first_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
829       expected_layout0));
830   EXPECT_TRUE(LayoutUtil::Equal(
831       first_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
832       expected_layout1));
833   EXPECT_TRUE(LayoutUtil::Equal(
834       second_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
835       expected_layout0));
836   EXPECT_TRUE(LayoutUtil::Equal(
837       second_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
838       expected_layout1));
839 }
840 
TEST_F(SpmdPartitioningTest,ReduceWindowReplicatedInput)841 TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {
842   absl::string_view hlo_string = R"(
843 HloModule module
844 
845 sum {
846   a = f32[] parameter(0)
847   b = f32[] parameter(1)
848   ROOT add = f32[] add(a, b)
849 }
850 
851 ENTRY entry {
852   constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
853     sharding={replicated}
854   constant.1 = f32[] constant(0), sharding={replicated}
855   ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1),
856     window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum,
857     sharding={devices=[2,1]0,1}
858 })";
859   TF_ASSERT_OK_AND_ASSIGN(auto module,
860                           PartitionComputation(hlo_string, /*num_devices=*/2));
861   VLOG(1) << module->ToString();
862   HloInstruction* root = module->entry_computation()->root_instruction();
863   EXPECT_THAT(
864       root,
865       AllOf(op::Shape("f32[2,2]{1,0}"),
866             op::ReduceWindow(
867                 op::DynamicSlice(AllOf(op::Shape("f32[9,2]{1,0}"),
868                                        op::Pad(op::Constant(), op::Constant())),
869                                  op::Multiply(op::Reshape(), op::Constant()),
870                                  op::Constant()),
871                 op::Constant())));
872 }
873 
TEST_F(SpmdPartitioningTest,ReduceWindowTiledNegativeLeftHalo)874 TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) {
875   absl::string_view hlo_string = R"(
876 HloModule module
877 
878 sum {
879   a = f32[] parameter(0)
880   b = f32[] parameter(1)
881   ROOT add = f32[] add(a, b)
882 }
883 
884 ENTRY entry {
885   constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
886     sharding={devices=[2,1]0,1}
887   constant.1 = f32[] constant(0), sharding={replicated}
888   ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1),
889     window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum,
890     sharding={devices=[2,1]0,1}
891 })";
892   TF_ASSERT_OK_AND_ASSIGN(auto module,
893                           PartitionComputation(hlo_string, /*num_devices=*/2));
894   VLOG(1) << module->ToString();
895   HloInstruction* root = module->entry_computation()->root_instruction();
896 
897   auto sharded_input =
898       op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
899   auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"),
900                           op::CollectivePermute(op::Slice(sharded_input)));
901   auto pre_masking = op::DynamicSlice(
902       AllOf(
903           op::Shape("f32[6,2]{1,0}"),
904           op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())),
905       op::Reshape(), op::Constant());
906   auto index_in_padded = op::Add(
907       op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
908   auto masked =
909       op::Select(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
910                  pre_masking, op::Broadcast(op::Constant()));
911   EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
912                           op::ReduceWindow(masked, op::Constant())));
913 }
914 
TEST_F(SpmdPartitioningTest,ReduceWindowTiledOneSideHaloBeyondNeighbor)915 TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
916   absl::string_view hlo_string = R"(
917 HloModule module
918 
919 sum {
920   a = f32[] parameter(0)
921   b = f32[] parameter(1)
922   ROOT add = f32[] add(a, b)
923 }
924 
925 ENTRY entry {
926   param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
927   constant.1 = f32[] constant(0), sharding={replicated}
928   ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
929     window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
930     sharding={devices=[5,1]0,1,2,3,4}
931 })";
932   TF_ASSERT_OK_AND_ASSIGN(auto module,
933                           PartitionComputation(hlo_string, /*num_devices=*/5));
934   VLOG(1) << module->ToString();
935   auto halo0 = AllOf(op::Shape("f32[1,2]"),
936                      op::CollectivePermute(op::Slice(op::Parameter(0))));
937   auto halo1 =
938       AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0)));
939   auto pre_mask =
940       AllOf(op::Shape("f32[4,2]"),
941             op::Concatenate(halo0, halo1, op::Slice(op::Parameter(0))));
942   auto masked =
943       op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())),
944                              op::Broadcast(op::Constant())),
945                  pre_mask, op::Broadcast(op::Constant()));
946   HloInstruction* root = module->entry_computation()->root_instruction();
947   EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
948                           op::ReduceWindow(masked, op::Constant())));
949 }
950 
TEST_F(SpmdPartitioningTest,ReduceWindowTiledOneSideUnequalHalo)951 TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
952   absl::string_view hlo_string = R"(
953 HloModule module
954 
955 sum {
956   a = f32[] parameter(0)
957   b = f32[] parameter(1)
958   ROOT add = f32[] add(a, b)
959 }
960 
961 ENTRY entry {
962   constant = f32[9,2]{1,0} constant(
963     {{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}),
964     sharding={devices=[3,1]0,1,2}
965   constant.1 = f32[] constant(0), sharding={replicated}
966   ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1),
967     window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum,
968     sharding={devices=[3,1]0,1,2}
969 })";
970   TF_ASSERT_OK_AND_ASSIGN(auto module,
971                           PartitionComputation(hlo_string, /*num_devices=*/3));
972   VLOG(1) << module->ToString();
973   HloInstruction* root = module->entry_computation()->root_instruction();
974 
975   auto sharded_input =
976       op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
977   auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"),
978                           op::CollectivePermute(op::Slice(sharded_input)));
979   auto pre_masking = op::DynamicSlice(
980       AllOf(
981           op::Shape("f32[7,2]{1,0}"),
982           op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())),
983       op::Reshape(), op::Constant());
984   auto index_in_padded = op::Add(
985       op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
986   auto masked = op::Select(
987       op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
988               op::Compare(index_in_padded, op::Broadcast(op::Constant()))),
989       pre_masking, op::Broadcast(op::Constant()));
990   EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
991                           op::ReduceWindow(masked, op::Constant())));
992 }
993 
TEST_F(SpmdPartitioningTest,ReduceWindowTiledTwoSideHalo)994 TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) {
995   absl::string_view hlo_string = R"(
996 HloModule module
997 
998 sum {
999   a = f32[] parameter(0)
1000   b = f32[] parameter(1)
1001   ROOT add = f32[] add(a, b)
1002 }
1003 
1004 ENTRY entry {
1005   constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}),
1006     sharding={devices=[2,1]0,1}
1007   constant.1 = f32[] constant(0), sharding={replicated}
1008   ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1),
1009     window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum,
1010     sharding={devices=[2,1]0,1}
1011 })";
1012   TF_ASSERT_OK_AND_ASSIGN(auto module,
1013                           PartitionComputation(hlo_string, /*num_devices=*/2));
1014   VLOG(1) << module->ToString();
1015   HloInstruction* root = module->entry_computation()->root_instruction();
1016 
1017   auto sharded_input =
1018       op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
1019   auto left_halo = AllOf(op::Shape("f32[1,2]{1,0}"),
1020                          op::CollectivePermute(op::Slice(sharded_input)));
1021   auto right_halo = AllOf(op::Shape("f32[1,2]{1,0}"),
1022                           op::CollectivePermute(op::Slice(sharded_input)));
1023   auto pre_masking = AllOf(
1024       op::Shape("f32[5,2]{1,0}"),
1025       op::DynamicSlice(
1026           AllOf(op::Shape("f32[6,2]{1,0}"),
1027                 op::Pad(op::Concatenate(left_halo, sharded_input, right_halo),
1028                         op::Constant())),
1029           op::Reshape(), op::Constant()));
1030   auto index_in_padded = op::Add(
1031       op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
1032   auto masked = op::Select(
1033       op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
1034               op::Compare(index_in_padded, op::Broadcast(op::Constant()))),
1035       pre_masking, op::Broadcast(op::Constant()));
1036   EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
1037                           op::ReduceWindow(masked, op::Constant())));
1038 }
1039 
TEST_F(SpmdPartitioningTest,ReduceWindowTiled2D)1040 TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) {
1041   absl::string_view hlo_string = R"(
1042 HloModule module
1043 
1044 sum {
1045   a = f32[] parameter(0)
1046   b = f32[] parameter(1)
1047   ROOT add = f32[] add(a, b)
1048 }
1049 
1050 ENTRY entry {
1051   token0 = token[] after-all(), sharding={maximal device=0}
1052   infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0),
1053     sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}}
1054   infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0,
1055     sharding={devices=[2,2,1,1]0,1,2,3}
1056   constant = f32[] constant(0), sharding={replicated}
1057   ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant),
1058     window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum,
1059     sharding={devices=[2,2,1,1]0,1,2,3}
1060 })";
1061   TF_ASSERT_OK_AND_ASSIGN(auto module,
1062                           PartitionComputation(hlo_string, /*num_devices=*/4));
1063   VLOG(1) << module->ToString();
1064   HloInstruction* root = module->entry_computation()->root_instruction();
1065 
1066   auto sharded_input = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"),
1067                              op::GetTupleElement(op::Infeed()));
1068   auto dim0_left_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
1069                               op::CollectivePermute(op::Slice(sharded_input)));
1070   auto dim0_right_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
1071                                op::CollectivePermute(op::Slice(sharded_input)));
1072   auto dim0_pre_masking = op::DynamicSlice(
1073       AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"),
1074             op::Pad(
1075                 op::Concatenate(dim0_left_halo, sharded_input, dim0_right_halo),
1076                 op::Constant())),
1077       op::Reshape(), op::Constant(), op::Constant(), op::Constant());
1078   auto dim0_index_in_padded = op::Add(
1079       op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
1080   auto dim0_masked = op::Select(
1081       op::And(op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant())),
1082               op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant()))),
1083       dim0_pre_masking, op::Broadcast(op::Constant()));
1084   auto dim0_resharded = AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"), dim0_masked);
1085   auto dim1_left_halo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
1086                               op::CollectivePermute(op::Slice(dim0_resharded)));
1087   auto dim1_right_halo =
1088       AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
1089             op::CollectivePermute(op::Slice(dim0_resharded)));
1090   auto dim1_pre_masking = op::DynamicSlice(
1091       AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"),
1092             op::Pad(op::Concatenate(dim1_left_halo, dim0_resharded,
1093                                     dim1_right_halo),
1094                     op::Constant())),
1095       op::Constant(), op::Reshape(), op::Constant(), op::Constant());
1096   auto dim1_index_in_padded = op::Add(
1097       op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
1098   auto dim1_masked = op::Select(
1099       op::And(op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant())),
1100               op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant()))),
1101       dim1_pre_masking, op::Broadcast(op::Constant()));
1102   auto dim1_resharded = AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"), dim1_masked);
1103   EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"),
1104                           op::ReduceWindow(dim1_resharded, op::Constant())));
1105 }
1106 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsReplicated)1107 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) {
1108   absl::string_view hlo_string = R"(
1109 HloModule module
1110 
1111 ENTRY entry {
1112   %lhs = f32[128,224,224,3] parameter(0)
1113   %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
1114     sharding={devices=[1,2,1,1]0,1}
1115   %rhs = f32[7,7,3,64] parameter(1)
1116   %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
1117     sharding={replicated}
1118   ROOT %conv = f32[128,112,112,64] convolution(
1119     f32[128,224,224,3] %lhs.copy,
1120     f32[7,7,3,64] %rhs.copy),
1121     window={size=7x7 stride=2x2 pad=3_3x3_3},
1122     dim_labels=b01f_01io->b01f,
1123     sharding={devices=[1,2,1,1]0,1}
1124 })";
1125 
1126   TF_ASSERT_OK_AND_ASSIGN(auto module,
1127                           PartitionComputation(hlo_string, /*num_devices=*/2));
1128   VLOG(1) << module->ToString();
1129 
1130   const auto root = module->entry_computation()->root_instruction();
1131   const auto lhs = AllOf(
1132       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1133                                 op::Constant(), op::Constant())),
1134       op::Shape("f32[128,112,224,3]"));
1135   const auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));
1136 
1137   auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1138                          op::Shape("f32[128,3,224,3]"));
1139   auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1140                           op::Shape("f32[128,2,224,3]"));
1141   EXPECT_THAT(root,
1142               AllOf(op::Convolution(
1143                         op::Select(op::And(),
1144                                    op::Concatenate(left_halo, lhs, right_halo),
1145                                    op::Broadcast()),
1146                         rhs),
1147                     op::Shape("f32[128,56,112,64]")));
1148 }
1149 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsReplicatedNeedReshard)1150 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) {
1151   absl::string_view hlo_string = R"(
1152 HloModule module
1153 
1154 ENTRY entry {
1155   %lhs = f32[128,224,224,3] parameter(0)
1156   %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
1157     sharding={devices=[2,1,1,1]0,1}
1158   %rhs = f32[7,7,3,64] parameter(1)
1159   %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
1160     sharding={replicated}
1161   ROOT %conv = f32[128,112,112,64] convolution(
1162     f32[128,224,224,3] %lhs.copy,
1163     f32[7,7,3,64] %rhs.copy),
1164     window={size=7x7 stride=2x2 pad=3_3x3_3},
1165     dim_labels=b01f_01io->b01f,
1166     sharding={devices=[1,2,1,1]0,1}
1167 })";
1168 
1169   TF_ASSERT_OK_AND_ASSIGN(auto module,
1170                           PartitionComputation(hlo_string, /*num_devices=*/2));
1171   VLOG(1) << module->ToString();
1172 
1173   const auto root = module->entry_computation()->root_instruction();
1174   const auto lhs = AllOf(
1175       op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
1176                                 op::Constant(), op::Constant())),
1177       op::Shape("f32[64,224,224,3]"));
1178   auto all_to_all =
1179       AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[64,2,112,224,3]"));
1180   auto reshard_lhs = AllOf(op::Reshape(op::Transpose(all_to_all)),
1181                            op::Shape("f32[128,112,224,3]"));
1182 
1183   const auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));
1184 
1185   auto left_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)),
1186                          op::Shape("f32[128,3,224,3]"));
1187   auto right_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)),
1188                           op::Shape("f32[128,2,224,3]"));
1189   EXPECT_THAT(
1190       root,
1191       AllOf(op::Convolution(
1192                 op::Select(op::And(),
1193                            op::Concatenate(left_halo, reshard_lhs, right_halo),
1194                            op::Broadcast()),
1195                 rhs),
1196             op::Shape("f32[128,56,112,64]")));
1197 }
1198 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsReplicatedReordered)1199 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) {
1200   absl::string_view hlo_string = R"(
1201 HloModule module
1202 
1203 ENTRY entry {
1204   %lhs = f32[224,224,3,128] parameter(0)
1205   %lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1}
1206   %rhs = f32[7,7,3,64] parameter(1)
1207   %rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated}
1208   ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy),
1209     window={size=7x7 stride=2x2 pad=3_3x3_3},
1210     dim_labels=01fb_01io->b01f,
1211     sharding={devices=[1,2,1,1]0,1}
1212 })";
1213 
1214   TF_ASSERT_OK_AND_ASSIGN(auto module,
1215                           PartitionComputation(hlo_string, /*num_devices=*/2));
1216   VLOG(1) << module->ToString();
1217 
1218   const auto root = module->entry_computation()->root_instruction();
1219   const auto lhs = AllOf(
1220       op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
1221                                 op::Constant(), op::Constant())),
1222       op::Shape("f32[112,224,3,128]"));
1223   const auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));
1224 
1225   auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1226                          op::Shape("f32[3,224,3,128]"));
1227   auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1228                           op::Shape("f32[2,224,3,128]"));
1229   EXPECT_THAT(root,
1230               AllOf(op::Convolution(
1231                         op::Select(op::And(),
1232                                    op::Concatenate(left_halo, lhs, right_halo),
1233                                    op::Broadcast()),
1234                         rhs),
1235                     op::Shape("f32[128,56,112,64]")));
1236 }
1237 
1238 // (stride * per_shard_window_count) % dilation == 0
TEST_F(SpmdPartitioningTest,ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated)1239 TEST_F(SpmdPartitioningTest,
1240        ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) {
1241   absl::string_view hlo_string = R"(
1242 HloModule module
1243 
1244 ENTRY entry {
1245   %lhs = f32[128,7,7,512] parameter(0)
1246   %lhs.copy = f32[128,7,7,512] copy(%lhs),
1247     sharding={devices=[1,2,1,1]0,1}
1248   %rhs = f32[3,3,512,512] parameter(1)
1249   %rhs.copy = f32[3,3,512,512] copy(%rhs),
1250     sharding={replicated}
1251   ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy),
1252     window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1},
1253     dim_labels=b01f_01io->b01f,
1254     sharding={devices=[1,2,1,1]0,1}
1255 })";
1256 
1257   TF_ASSERT_OK_AND_ASSIGN(auto module,
1258                           PartitionComputation(hlo_string, /*num_devices=*/2));
1259   VLOG(1) << module->ToString();
1260 
1261   const auto root = module->entry_computation()->root_instruction();
1262   // There is no halo exchange, and because the last element in the shard is not
1263   // needed (stride == 4), the LHS will be just a slice.
1264   auto sliced_lhs =
1265       AllOf(op::Slice(op::Copy(op::DynamicSlice(
1266                 op::Pad(op::Parameter(), op::Constant()), op::Constant(),
1267                 op::Reshape(), op::Constant(), op::Constant()))),
1268             op::Shape("f32[128,3,7,512]"));
1269   const auto rhs =
1270       AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
1271   EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs),
1272                           op::Shape("f32[128,2,4,512]")));
1273   EXPECT_EQ(root->window().dimensions(0).padding_low(), 1);
1274   EXPECT_EQ(root->window().dimensions(0).padding_high(), 1);
1275 }
1276 
1277 // (stride * per_shard_window_count) % dilation != 0 but stride == 1
TEST_F(SpmdPartitioningTest,ConvolutionBaseDilationStride1LhsTiledRhsReplicated)1278 TEST_F(SpmdPartitioningTest,
1279        ConvolutionBaseDilationStride1LhsTiledRhsReplicated) {
1280   absl::string_view hlo_string = R"(
1281 HloModule module
1282 
1283 ENTRY entry {
1284   %lhs = f32[128,7,7,512] parameter(0)
1285   %lhs.copy = f32[128,7,7,512] copy(%lhs),
1286     sharding={devices=[1,2,1,1]0,1}
1287   %rhs = f32[3,3,512,512] parameter(1)
1288   %rhs.copy = f32[3,3,512,512] copy(%rhs),
1289     sharding={replicated}
1290   ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy),
1291     window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1},
1292     dim_labels=b01f_01io->b01f,
1293     sharding={devices=[1,2,1,1]0,1}
1294 })";
1295 
1296   TF_ASSERT_OK_AND_ASSIGN(auto module,
1297                           PartitionComputation(hlo_string, /*num_devices=*/2));
1298   VLOG(1) << module->ToString();
1299 
1300   const auto root = module->entry_computation()->root_instruction();
1301   const auto lhs =
1302       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
1303                                       op::Constant(), op::Reshape(),
1304                                       op::Constant(), op::Constant())),
1305             op::Shape("f32[128,4,7,512]"));
1306   const auto rhs =
1307       AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
1308 
1309   auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1310                          op::Shape("f32[128,1,7,512]"));
1311   auto start_window = op::Multiply(op::Reshape(), op::Constant());
1312   auto start_input_element = op::Divide(start_window, op::Constant());
1313   auto dynamic_offset_for_padded_concat = op::Subtract(
1314       op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
1315                                    start_input_element));
1316   auto pre_masking =
1317       AllOf(op::Shape("f32[128,5,7,512]"),
1318             op::DynamicSlice(
1319                 AllOf(op::Shape("f32[128,6,7,512]"),
1320                       op::Pad(op::Concatenate(left_halo, lhs), op::Constant())),
1321                 op::Constant(), dynamic_offset_for_padded_concat,
1322                 op::Constant(), op::Constant()));
1323   auto masked = op::Select(
1324       op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)),
1325                   op::Broadcast(op::Constant())),
1326       pre_masking, op::Broadcast(op::Constant()));
1327   auto dynamic_offset_on_output = op::Subtract(
1328       start_window, op::Multiply(start_input_element, op::Constant()));
1329   EXPECT_THAT(root,
1330               AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs),
1331                                            op::Shape("f32[128,8,14,512]")),
1332                                      op::Constant(), dynamic_offset_on_output,
1333                                      op::Constant(), op::Constant()),
1334                     op::Shape("f32[128,7,14,512]")));
1335   EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1);
1336   EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
1337 }
1338 
TEST_F(SpmdPartitioningTest,SelectAndScatterNoOverlap)1339 TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) {
1340   absl::string_view hlo_string = R"(
1341 HloModule module
1342 
1343 ge {
1344   a = f32[] parameter(0)
1345   b = f32[] parameter(1)
1346   ROOT compare = pred[] compare(a, b), direction=GE
1347 }
1348 
1349 sum {
1350   c = f32[] parameter(0)
1351   d = f32[] parameter(1)
1352   ROOT add = f32[] add(c, d)
1353 }
1354 
1355 ENTRY entry {
1356   %param = f32[11,4]{1,0} parameter(0)
1357   %param.copy = f32[11,4] copy(%param),
1358     sharding={devices=[4,1]0,1,2,3}
1359   constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
1360     sharding={devices=[4,1]0,1,2,3}
1361   constant.1 = f32[] constant(0), sharding={replicated}
1362   ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
1363     constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
1364     select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
1365 })";
1366   TF_ASSERT_OK_AND_ASSIGN(auto module,
1367                           PartitionComputation(hlo_string, /*num_devices=*/4));
1368   VLOG(1) << module->ToString();
1369   const auto root = module->entry_computation()->root_instruction();
1370   auto source =
1371       AllOf(op::Shape("f32[1,2]{1,0}"),
1372             op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
1373   auto masked_data = AllOf(
1374       op::Shape("f32[3,4]{1,0}"),
1375       op::Select(
1376           op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
1377                                               op::Reshape(), op::Constant()))),
1378                       op::Broadcast(op::Constant())),
1379           op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
1380                                     op::Reshape(), op::Constant())),
1381           op::Broadcast(op::Constant())));
1382 
1383   EXPECT_THAT(root,
1384               AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
1385                     op::Shape("f32[3,4]{1,0}")));
1386   EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
1387   EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
1388 }
1389 
TEST_F(SpmdPartitioningTest,SelectAndScatterNoOverlapReshard)1390 TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) {
1391   absl::string_view hlo_string = R"(
1392 HloModule module
1393 
1394 ge {
1395   a = f32[] parameter(0)
1396   b = f32[] parameter(1)
1397   ROOT compare = pred[] compare(a, b), direction=GE
1398 }
1399 
1400 sum {
1401   c = f32[] parameter(0)
1402   d = f32[] parameter(1)
1403   ROOT add = f32[] add(c, d)
1404 }
1405 
1406 ENTRY entry {
1407   %param = f32[11,4]{1,0} parameter(0)
1408   %param.copy = f32[11,4] copy(%param),
1409     sharding={devices=[1,4]0,1,2,3}
1410   constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
1411     sharding={devices=[4,1]0,1,2,3}
1412   constant.1 = f32[] constant(0), sharding={replicated}
1413   ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
1414     constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
1415     select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
1416 })";
1417   TF_ASSERT_OK_AND_ASSIGN(auto module,
1418                           PartitionComputation(hlo_string, /*num_devices=*/4));
1419   VLOG(1) << module->ToString();
1420   const auto root = module->entry_computation()->root_instruction();
1421   auto source =
1422       AllOf(op::Shape("f32[1,2]{1,0}"),
1423             op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
1424   auto operand = AllOf(op::Copy(op::DynamicSlice(
1425                            op::Parameter(0), op::Constant(), op::Reshape())),
1426                        op::Shape("f32[11,1]"));
1427   auto reshard_operand = op::Reshape(op::Transpose(
1428       op::AllToAll(op::Reshape(op::Pad(operand, op::Constant())))));
1429   auto masked_data = AllOf(
1430       op::Shape("f32[3,4]{1,0}"),
1431       op::Select(
1432           op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
1433                                               op::Reshape(), op::Constant()))),
1434                       op::Broadcast(op::Constant())),
1435           reshard_operand, op::Broadcast(op::Constant())));
1436 
1437   EXPECT_THAT(root,
1438               AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
1439                     op::Shape("f32[3,4]{1,0}")));
1440   EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
1441   EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
1442 }
1443 
TEST_F(SpmdPartitioningTest,SelectAndScatterWithOverlap)1444 TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) {
1445   absl::string_view hlo_string = R"(
1446 HloModule module
1447 
1448 ge {
1449   a = f32[] parameter(0)
1450   b = f32[] parameter(1)
1451   ROOT compare = pred[] compare(a, b), direction=GE
1452 }
1453 
1454 sum {
1455   c = f32[] parameter(0)
1456   d = f32[] parameter(1)
1457   ROOT add = f32[] add(c, d)
1458 }
1459 
1460 ENTRY entry {
1461   %param = f32[11,4]{1,0} parameter(0)
1462   %param.copy = f32[11,4] copy(%param),
1463     sharding={devices=[4,1]0,1,2,3}
1464   constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}),
1465     sharding={devices=[4,1]0,1,2,3}
1466   constant.1 = f32[] constant(0), sharding={replicated}
1467   ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
1468     constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0},
1469     select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
1470 })";
1471   TF_ASSERT_OK_AND_ASSIGN(auto module,
1472                           PartitionComputation(hlo_string, /*num_devices=*/4));
1473   VLOG(1) << module->ToString();
1474   const auto root = module->entry_computation()->root_instruction();
1475 
1476   auto source_shard =
1477       AllOf(op::Shape("f32[2,2]{1,0}"),
1478             op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant()));
1479   // Max halo size is the same as the shard size, so slice is not needed.
1480   auto source_left_halo = op::CollectivePermute(source_shard);
1481   auto required_source_shard_start =
1482       op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
1483   auto source_with_halo = op::DynamicSlice(
1484       AllOf(op::Shape("f32[5,2]{1,0}"),
1485             op::Pad(op::Concatenate(source_left_halo, source_shard),
1486                     op::Constant())),
1487       op::Subtract(op::Constant(),
1488                    op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
1489                                 required_source_shard_start)),
1490       op::Constant());
1491   auto masked_source_with_halo = AllOf(
1492       AllOf(op::Shape("f32[3,2]{1,0}")),
1493       op::Select(
1494           op::Compare(
1495               op::Add(op::Iota(), op::Broadcast(required_source_shard_start)),
1496               op::Broadcast(op::Constant())),
1497           source_with_halo, op::Broadcast(op::Constant())));
1498 
1499   auto data_shard =
1500       AllOf(op::Shape("f32[3,4]{1,0}"),
1501             op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
1502                                       op::Reshape(), op::Constant())));
1503   auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
1504                               op::CollectivePermute(op::Slice(data_shard)));
1505   auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
1506                                op::CollectivePermute(op::Slice(data_shard)));
1507   auto required_data_start_on_padded =
1508       op::Multiply(required_source_shard_start, op::Constant());
1509   auto left_halo_size = op::Subtract(
1510       op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()),
1511       required_data_start_on_padded);
1512   auto data_with_halo =
1513       AllOf(op::Shape("f32[7,4]{1,0}"),
1514             op::DynamicSlice(
1515                 AllOf(op::Shape("f32[8,4]{1,0}"),
1516                       op::Pad(op::Concatenate(data_left_halo, data_shard,
1517                                               data_right_halo),
1518                               op::Constant())),
1519                 op::Subtract(op::Constant(), left_halo_size), op::Constant()));
1520   auto index_on_padded =
1521       op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded));
1522   auto masked_data_with_halo = op::Select(
1523       op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())),
1524               op::Compare(index_on_padded, op::Broadcast(op::Constant()))),
1525       data_with_halo, op::Broadcast(op::Constant()));
1526 
1527   EXPECT_THAT(
1528       root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo,
1529                                                         masked_source_with_halo,
1530                                                         op::Constant()),
1531                                    left_halo_size, op::Constant()),
1532                   op::Shape("f32[3,4]{1,0}")));
1533   EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0);
1534   EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
1535 }
1536 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiled)1537 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) {
1538   absl::string_view hlo_string = R"(
1539 HloModule module
1540 
1541 ENTRY entry {
1542   %lhs = f32[128,56,56,64] parameter(0)
1543   %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1544   %rhs = f32[128,56,56,256] parameter(1)
1545   %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1546   ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
1547     window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1548 })";
1549 
1550   TF_ASSERT_OK_AND_ASSIGN(auto module,
1551                           PartitionComputation(hlo_string, /*num_devices=*/2));
1552   VLOG(1) << module->ToString();
1553 
1554   const auto root = module->entry_computation()->root_instruction();
1555   const auto lhs = AllOf(
1556       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1557                                 op::Constant(), op::Constant())),
1558       op::Shape("f32[128,28,56,64]"));
1559   const auto rhs = AllOf(
1560       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1561                                 op::Constant(), op::Constant())),
1562       op::Shape("f32[128,28,56,256]"));
1563 
1564   EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
1565                           op::Shape("f32[1,1,64,256]")));
1566 }
1567 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWindowReversal)1568 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowReversal) {
1569   absl::string_view hlo_string = R"(
1570 HloModule module
1571 
1572 ENTRY entry {
1573   %lhs = f32[5,128,64] parameter(0), sharding={devices=[2,1,1]0,1}
1574   %rhs = f32[5,128,256] parameter(1), sharding={devices=[2,1,1]1,0}
1575   ROOT %conv = f32[1,64,256] convolution(%lhs, %rhs),
1576     window={size=5 rhs_reversal=1}, dim_labels=0fb_0io->0bf,
1577     sharding={replicated}
1578 })";
1579 
1580   TF_ASSERT_OK_AND_ASSIGN(auto module,
1581                           PartitionComputation(hlo_string, /*num_devices=*/2));
1582   VLOG(1) << module->ToString();
1583 
1584   const auto lhs_masked =
1585       AllOf(op::Shape("f32[3,128,64]"), op::Select(_, op::Parameter(0), _));
1586   const auto rhs_left_padded =
1587       op::Concatenate(op::CollectivePermute(op::Slice(op::Parameter(1))),
1588                       op::Slice(op::Parameter(1)));
1589   const auto rhs_masked =
1590       AllOf(op::Shape("f32[3,128,256]"), op::Select(_, rhs_left_padded, _));
1591 
1592   const auto root = module->entry_computation()->root_instruction();
1593   EXPECT_THAT(root,
1594               AllOf(op::AllReduce(op::Convolution(lhs_masked, rhs_masked)),
1595                     op::Shape("f32[1,64,256]")));
1596 }
1597 
TEST_F(SpmdPartitioningTest,DotLhsTiledRhsTiledWithReshard)1598 TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) {
1599   absl::string_view hlo_string = R"(
1600 HloModule module
1601 
1602 ENTRY entry {
1603   %lhs = f32[128,56,56,64] parameter(0)
1604   %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1605   %rhs = f32[128,56,56,256] parameter(1)
1606   %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
1607   ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
1608     window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1609 })";
1610 
1611   TF_ASSERT_OK_AND_ASSIGN(auto module,
1612                           PartitionComputation(hlo_string, /*num_devices=*/2));
1613   VLOG(1) << module->ToString();
1614 
1615   const auto root = module->entry_computation()->root_instruction();
1616   const auto lhs = AllOf(
1617       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1618                                 op::Constant(), op::Constant())),
1619       op::Shape("f32[128,28,56,64]"));
1620   const auto rhs = AllOf(
1621       op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
1622                                 op::Constant(), op::Constant())),
1623       op::Shape("f32[64,56,56,256]"));
1624   auto all_to_all =
1625       AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]"));
1626   auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all)));
1627 
1628   EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)),
1629                           op::Shape("f32[1,1,64,256]")));
1630 }
1631 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWithReshard)1632 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) {
1633   absl::string_view hlo_string = R"(
1634 HloModule module
1635 
1636 ENTRY entry {
1637   %lhs = f32[128,56,56,512] parameter(0)
1638   %lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1639   %rhs = f32[128,28,28,64] parameter(1)
1640   %rhs.copy = f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
1641   ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy),
1642     window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2},
1643     dim_labels=f01b_i01o->01bf, sharding={replicated}
1644 })";
1645 
1646   TF_ASSERT_OK_AND_ASSIGN(auto module,
1647                           PartitionComputation(hlo_string, /*num_devices=*/2));
1648   VLOG(1) << module->ToString();
1649 
1650   const auto root = module->entry_computation()->root_instruction();
1651   const auto lhs = AllOf(
1652       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1653                                 op::Constant(), op::Constant())),
1654       op::Shape("f32[128,28,56,512]"));
1655   const auto rhs = AllOf(
1656       op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
1657                                 op::Constant(), op::Constant())),
1658       op::Shape("f32[64,28,28,64]"));
1659   auto all_to_all =
1660       AllOf(op::AllToAll(op::Reshape(rhs)), op::Shape("f32[64,2,14,28,64]"));
1661   auto reshard = op::Reshape(op::Transpose(all_to_all));
1662 
1663   EXPECT_THAT(root,
1664               AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), reshard)),
1665                     op::Shape("f32[1,1,512,64]")));
1666 }
1667 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiled_UnevenDilatedRHSPartitioned)1668 TEST_F(SpmdPartitioningTest,
1669        ConvolutionLhsTiledRhsTiled_UnevenDilatedRHSPartitioned) {
1670   absl::string_view hlo_string = R"(
1671 HloModule module
1672 
1673 ENTRY entry {
1674   %lhs = f32[8,28,28,8] parameter(0)
1675   %lhs.copy = f32[8,28,28,8] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
1676   %rhs = f32[8,14,14,64] parameter(1)
1677   %rhs.copy = f32[8,14,14,64] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
1678   ROOT %conv = f32[1,1,8,64] convolution(%lhs.copy, %rhs.copy),
1679     window={size=14x14 pad=0_-1x0_-1 rhs_dilate=2x2},
1680     dim_labels=f01b_i01o->01bf, sharding={replicated}
1681 })";
1682 
1683   TF_ASSERT_OK_AND_ASSIGN(auto module,
1684                           PartitionComputation(hlo_string, /*num_devices=*/4));
1685   VLOG(1) << module->ToString();
1686 
1687   const auto root = module->entry_computation()->root_instruction();
1688   const auto lhs = AllOf(
1689       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1690                                 op::Constant(), op::Constant())),
1691       op::Shape("f32[8,7,28,8]"));
1692   const auto rhs = AllOf(op::Pad(op::Parameter(), op::Constant()),
1693                          op::Shape("f32[8,16,14,64]"));
1694   auto selected_rhs = AllOf(
1695       op::Select(op::Compare(),
1696                  op::Copy(op::DynamicSlice(rhs, op::Constant(), op::Reshape(),
1697                                            op::Constant(), op::Constant())),
1698                  op::Broadcast()),
1699       op::Shape("f32[8,4,14,64]"));
1700   auto right_halo =
1701       AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,2,28,8]"));
1702   auto selected_lhs =
1703       AllOf(op::DynamicSlice(
1704                 op::Pad(op::Concatenate(lhs, right_halo), op::Constant()),
1705                 op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
1706             op::Shape("f32[8,7,28,8]"));
1707   EXPECT_THAT(root,
1708               AllOf(op::AllReduce(op::Convolution(selected_lhs, selected_rhs)),
1709                     op::Shape("f32[1,1,8,64]")));
1710 }
1711 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWithPadding)1712 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) {
1713   absl::string_view hlo_string = R"(
1714 HloModule module
1715 
1716 ENTRY entry {
1717   %lhs = f32[32,28,28,128] parameter(0)
1718   %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1719   %rhs = f32[32,28,28,64] parameter(1)
1720   %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1721   ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
1722     window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1723 })";
1724 
1725   TF_ASSERT_OK_AND_ASSIGN(
1726       auto module,
1727       PartitionComputation(hlo_string, /*num_devices=*/2,
1728                            /*conv_halo_exchange_always_on_lhs=*/false));
1729   VLOG(1) << module->ToString();
1730 
1731   const auto root = module->entry_computation()->root_instruction();
1732   const auto lhs = AllOf(
1733       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1734                                 op::Constant(), op::Constant())),
1735       op::Shape("f32[32,14,28,128]"));
1736   const auto rhs = AllOf(
1737       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1738                                 op::Constant(), op::Constant())),
1739       op::Shape("f32[32,14,28,64]"));
1740 
1741   auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1742                          op::Shape("f32[32,1,28,64]"));
1743   auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1744                           op::Shape("f32[32,1,28,64]"));
1745   EXPECT_THAT(root,
1746               AllOf(op::AllReduce(op::Convolution(
1747                         lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo),
1748                                    op::Shape("f32[32,16,28,64]")))),
1749                     op::Shape("f32[3,3,128,64]")));
1750 }
1751 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWindowDilate)1752 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) {
1753   absl::string_view hlo_string = R"(
1754 HloModule module
1755 
1756 ENTRY entry {
1757   %lhs = f32[128,224,224,3] parameter(0)
1758   %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1759   %rhs = f32[128,112,112,64] parameter(1)
1760   %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1761   ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
1762     window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1763 })";
1764 
1765   TF_ASSERT_OK_AND_ASSIGN(
1766       auto module,
1767       PartitionComputation(hlo_string, /*num_devices=*/2,
1768                            /*conv_halo_exchange_always_on_lhs=*/false));
1769   VLOG(1) << module->ToString();
1770 
1771   const auto root = module->entry_computation()->root_instruction();
1772   const auto lhs = AllOf(
1773       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1774                                 op::Constant(), op::Constant())),
1775       op::Shape("f32[128,112,224,3]"));
1776   const auto rhs = AllOf(
1777       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1778                                 op::Constant(), op::Constant())),
1779       op::Shape("f32[128,56,112,64]"));
1780 
1781   auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1782                          op::Shape("f32[128,2,112,64]"));
1783   auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1784                           op::Shape("f32[128,2,112,64]"));
1785   EXPECT_THAT(root,
1786               AllOf(op::AllReduce(op::Convolution(
1787                         lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo),
1788                                    op::Shape("f32[128,60,112,64]")))),
1789                     op::Shape("f32[7,7,3,64]")));
1790 }
1791 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding)1792 TEST_F(SpmdPartitioningTest,
1793        ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) {
1794   absl::string_view hlo_string = R"(
1795 HloModule module
1796 
1797 ENTRY entry {
1798   %lhs = f32[128,56,56,256] parameter(0)
1799   %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1800   %rhs = f32[128,28,28,512] parameter(1)
1801   %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1802   ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
1803     window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1804 })";
1805 
1806   TF_ASSERT_OK_AND_ASSIGN(
1807       auto module,
1808       PartitionComputation(hlo_string, /*num_devices=*/2,
1809                            /*conv_halo_exchange_always_on_lhs=*/false));
1810   VLOG(1) << module->ToString();
1811 
1812   const auto root = module->entry_computation()->root_instruction();
1813   const auto lhs = AllOf(
1814       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1815                                 op::Constant(), op::Constant())),
1816       op::Shape("f32[128,28,56,256]"));
1817   const auto rhs = AllOf(
1818       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1819                                 op::Constant(), op::Constant())),
1820       op::Shape("f32[128,14,28,512]"));
1821 
1822   EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
1823                           op::Shape("f32[1,1,256,512]")));
1824 }
1825 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWindowDilateUneven)1826 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) {
1827   absl::string_view hlo_string = R"(
1828 HloModule module
1829 
1830 ENTRY entry {
1831   %lhs = f32[128,14,14,512] parameter(0)
1832   %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1833   %rhs = f32[128,7,7,512] parameter(1)
1834   %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1835   ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
1836     window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1837 })";
1838 
1839   TF_ASSERT_OK_AND_ASSIGN(
1840       auto module,
1841       PartitionComputation(hlo_string, /*num_devices=*/2,
1842                            /*conv_halo_exchange_always_on_lhs=*/false));
1843   VLOG(1) << module->ToString();
1844 
1845   const auto root = module->entry_computation()->root_instruction();
1846   const auto lhs = AllOf(
1847       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1848                                 op::Constant(), op::Constant())),
1849       op::Shape("f32[128,7,14,512]"));
1850   const auto rhs = AllOf(
1851       op::Select(op::Compare(),
1852                  op::Copy(op::DynamicSlice(
1853                      op::Pad(op::Parameter(), op::Constant()), op::Constant(),
1854                      op::Reshape(), op::Constant(), op::Constant())),
1855                  op::Broadcast()),
1856       op::Shape("f32[128,4,7,512]"));
1857 
1858   auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1859                          op::Shape("f32[128,1,7,512]"));
1860   EXPECT_THAT(root,
1861               AllOf(op::AllReduce(op::Convolution(
1862                         AllOf(op::DynamicSlice(op::Pad(lhs, op::Constant()),
1863                                                op::Constant(), op::Subtract(),
1864                                                op::Constant(), op::Constant()),
1865                               op::Shape("f32[128,10,14,512]")),
1866                         AllOf(op::Concatenate(left_halo, rhs),
1867                               op::Shape("f32[128,5,7,512]")))),
1868                     op::Shape("f32[3,3,512,512]")));
1869 }
1870 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs)1871 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) {
1872   absl::string_view hlo_string = R"(
1873 HloModule module
1874 
1875 ENTRY entry {
1876   %lhs = f32[32,28,28,128] parameter(0)
1877   %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1878   %rhs = f32[32,28,28,64] parameter(1)
1879   %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1880   ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
1881     window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1882 })";
1883 
1884   TF_ASSERT_OK_AND_ASSIGN(auto module,
1885                           PartitionComputation(hlo_string, /*num_devices=*/2));
1886   VLOG(1) << module->ToString();
1887 
1888   const auto root = module->entry_computation()->root_instruction();
1889   const auto lhs = AllOf(
1890       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1891                                 op::Constant(), op::Constant())),
1892       op::Shape("f32[32,14,28,128]"));
1893   const auto rhs = AllOf(
1894       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1895                                 op::Constant(), op::Constant())),
1896       op::Shape("f32[32,14,28,64]"));
1897 
1898   auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1899                          op::Shape("f32[32,1,28,128]"));
1900   auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1901                           op::Shape("f32[32,1,28,128]"));
1902   EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
1903                               AllOf(op::Concatenate(left_halo, lhs, right_halo),
1904                                     op::Shape("f32[32,16,28,128]")),
1905                               rhs)),
1906                           op::Shape("f32[3,3,128,64]")));
1907 }
1908 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs)1909 TEST_F(SpmdPartitioningTest,
1910        ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) {
1911   absl::string_view hlo_string = R"(
1912 HloModule module
1913 
1914 ENTRY entry {
1915   %lhs = f32[128,224,224,3] parameter(0)
1916   %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1917   %rhs = f32[128,112,112,64] parameter(1)
1918   %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1919   ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
1920     window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1921 })";
1922 
1923   TF_ASSERT_OK_AND_ASSIGN(auto module,
1924                           PartitionComputation(hlo_string, /*num_devices=*/2));
1925   VLOG(1) << module->ToString();
1926 
1927   const auto root = module->entry_computation()->root_instruction();
1928   const auto lhs = AllOf(
1929       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1930                                 op::Constant(), op::Constant())),
1931       op::Shape("f32[128,112,224,3]"));
1932   const auto rhs = AllOf(
1933       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1934                                 op::Constant(), op::Constant())),
1935       op::Shape("f32[128,56,112,64]"));
1936 
1937   auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1938                          op::Shape("f32[128,3,224,3]"));
1939   auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1940                           op::Shape("f32[128,2,224,3]"));
1941   EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
1942                               AllOf(op::Concatenate(left_halo, lhs, right_halo),
1943                                     op::Shape("f32[128,117,224,3]")),
1944                               rhs)),
1945                           op::Shape("f32[7,7,3,64]")));
1946 }
1947 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs)1948 TEST_F(SpmdPartitioningTest,
1949        ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) {
1950   absl::string_view hlo_string = R"(
1951 HloModule module
1952 
1953 ENTRY entry {
1954   %lhs = f32[128,56,56,256] parameter(0)
1955   %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1956   %rhs = f32[128,28,28,512] parameter(1)
1957   %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1958   ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
1959     window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1960 })";
1961 
1962   TF_ASSERT_OK_AND_ASSIGN(auto module,
1963                           PartitionComputation(hlo_string, /*num_devices=*/2));
1964   VLOG(1) << module->ToString();
1965 
1966   const auto root = module->entry_computation()->root_instruction();
1967   const auto lhs = AllOf(
1968       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1969                                 op::Constant(), op::Constant())),
1970       op::Shape("f32[128,28,56,256]"));
1971   const auto rhs = AllOf(
1972       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1973                                 op::Constant(), op::Constant())),
1974       op::Shape("f32[128,14,28,512]"));
1975 
1976   EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), rhs)),
1977                           op::Shape("f32[1,1,256,512]")));
1978 }
1979 
TEST_F(SpmdPartitioningTest,ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs)1980 TEST_F(SpmdPartitioningTest,
1981        ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) {
1982   absl::string_view hlo_string = R"(
1983 HloModule module
1984 
1985 ENTRY entry {
1986   %lhs = f32[128,14,14,512] parameter(0)
1987   %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1988   %rhs = f32[128,7,7,512] parameter(1)
1989   %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1990   ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
1991     window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1992 })";
1993 
1994   TF_ASSERT_OK_AND_ASSIGN(auto module,
1995                           PartitionComputation(hlo_string, /*num_devices=*/2));
1996   VLOG(1) << module->ToString();
1997 
1998   const auto root = module->entry_computation()->root_instruction();
1999   const auto lhs = AllOf(
2000       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
2001                                 op::Constant(), op::Constant())),
2002       op::Shape("f32[128,7,14,512]"));
2003   const auto rhs = AllOf(
2004       op::Select(op::Compare(),
2005                  op::Copy(op::DynamicSlice(
2006                      op::Pad(op::Parameter(), op::Constant()), op::Constant(),
2007                      op::Reshape(), op::Constant(), op::Constant())),
2008                  op::Broadcast()),
2009       op::Shape("f32[128,4,7,512]"));
2010 
2011   auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
2012                           op::Shape("f32[128,1,14,512]"));
2013   EXPECT_THAT(
2014       root, AllOf(op::AllReduce(op::Convolution(
2015                       AllOf(op::DynamicSlice(
2016                                 AllOf(op::Pad(op::Concatenate(lhs, right_halo),
2017                                               op::Constant()),
2018                                       op::Shape("f32[128,10,14,512]")),
2019                                 op::Constant(), op::Reshape(), op::Constant(),
2020                                 op::Constant()),
2021                             op::Shape("f32[128,9,14,512]")),
2022                       rhs)),
2023                   op::Shape("f32[3,3,512,512]")));
2024 }
2025 
TEST_F(SpmdPartitioningTest,ConcatenateAlongNonPartitionedDimension)2026 TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) {
2027   absl::string_view hlo_string = R"(
2028 HloModule module
2029 
2030 ENTRY entry {
2031   %param0 = f32[14,257] parameter(0)
2032   %param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1}
2033   %param1 = f32[14,116] parameter(1)
2034   %param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1}
2035   ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
2036     dimensions={1}, sharding={devices=[2,1]0,1}
2037 })";
2038 
2039   TF_ASSERT_OK_AND_ASSIGN(auto module,
2040                           PartitionComputation(hlo_string, /*num_devices=*/2));
2041   VLOG(1) << module->ToString();
2042 
2043   const auto root = module->entry_computation()->root_instruction();
2044   auto param0 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
2045                                                 op::Constant())),
2046                       op::Shape("f32[7,257]"));
2047   auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
2048                                                 op::Constant())),
2049                       op::Shape("f32[7,116]"));
2050   EXPECT_THAT(root,
2051               AllOf(op::Concatenate(param0, param1), op::Shape("f32[7,373]")));
2052 }
2053 
TEST_F(SpmdPartitioningTest,ConcatenateAlongPartitionedDimension)2054 TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) {
2055   absl::string_view hlo_string = R"(
2056 HloModule module
2057 
2058 ENTRY entry {
2059   %param0 = f32[14,257] parameter(0)
2060   %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1}
2061   %param1 = f32[14,116] parameter(1)
2062   %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1}
2063   ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
2064     dimensions={1}, sharding={devices=[1,2]0,1}
2065 })";
2066 
2067   TF_ASSERT_OK_AND_ASSIGN(auto module,
2068                           PartitionComputation(hlo_string, /*num_devices=*/2));
2069   VLOG(1) << module->ToString();
2070 
2071   const auto root = module->entry_computation()->root_instruction();
2072   auto param0 =
2073       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
2074                                       op::Constant(), op::Reshape())),
2075             op::Shape("f32[14,129]"));
2076   auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
2077                                                 op::Reshape())),
2078                       op::Shape("f32[14,58]"));
2079   EXPECT_THAT(root, AllOf(op::DynamicSlice(
2080                               AllOf(op::AllReduce(op::DynamicUpdateSlice(
2081                                         op::DynamicUpdateSlice(
2082                                             op::Broadcast(), param0,
2083                                             op::Constant(), op::Multiply()),
2084                                         param1, op::Constant(), op::Add())),
2085                                     op::Shape("f32[14,374]")),
2086                               op::Constant(), op::Multiply()),
2087                           op::Shape("f32[14,187]")));
2088 }
2089 
TEST_F(SpmdPartitioningTest,ConcatenateAlongBothDimensions)2090 TEST_F(SpmdPartitioningTest, ConcatenateAlongBothDimensions) {
2091   const char* const hlo_string = R"(
2092 HloModule module
2093 
2094 ENTRY entry {
2095   %param0 = f32[14,257] parameter(0), sharding={devices=[2,2]0,1,2,3}
2096   %param1 = f32[14,116] parameter(1), sharding={devices=[2,2]0,1,2,3}
2097   ROOT %concatenate = f32[14,373] concatenate(%param0, %param1),
2098     dimensions={1}, sharding={devices=[2,2]0,1,2,3}
2099 })";
2100 
2101   TF_ASSERT_OK_AND_ASSIGN(auto module,
2102                           PartitionComputation(hlo_string, /*num_devices=*/4));
2103   VLOG(1) << module->ToString();
2104 
2105   const auto root = module->entry_computation()->root_instruction();
2106   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[7,129]"));
2107   auto param1 = AllOf(op::Parameter(1), op::Shape("f32[7,58]"));
2108   EXPECT_THAT(root, AllOf(op::DynamicSlice(
2109                               AllOf(op::AllReduce(op::DynamicUpdateSlice(
2110                                         op::DynamicUpdateSlice(
2111                                             op::Broadcast(), param0,
2112                                             op::Constant(), op::Multiply()),
2113                                         param1, op::Constant(), op::Add())),
2114                                     op::Shape("f32[7,374]")),
2115                               op::Constant(), op::Multiply()),
2116                           op::Shape("f32[7,187]")));
2117 }
2118 
TEST_F(SpmdPartitioningTest,PadAlongNonPartitionedDimension)2119 TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) {
2120   absl::string_view hlo_string = R"(
2121 HloModule module
2122 
2123 ENTRY entry {
2124   %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
2125   %const = f32[] constant(0)
2126   ROOT %pad = f32[128,17,257] pad(%param0, %const), padding=0_0x1_2x0_0,
2127     sharding={devices=[1,1,2]0,1}
2128 })";
2129 
2130   TF_ASSERT_OK_AND_ASSIGN(auto module,
2131                           PartitionComputation(hlo_string, /*num_devices=*/2));
2132   VLOG(1) << module->ToString();
2133 
2134   const auto root = module->entry_computation()->root_instruction();
2135   auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
2136   EXPECT_THAT(root, AllOf(op::Pad(param0, op::Constant()),
2137                           op::Shape("f32[128,17,129]")));
2138 }
2139 
TEST_F(SpmdPartitioningTest,PadAlongPartitionedDimension)2140 TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimension) {
2141   absl::string_view hlo_string = R"(
2142 HloModule module
2143 
2144 ENTRY entry {
2145   %param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1}
2146   %const = f32[] constant(0)
2147   ROOT %pad = f32[14,259] pad(%param0, %const), padding=0_0x0_2,
2148     sharding={devices=[1,2]0,1}
2149 })";
2150 
2151   TF_ASSERT_OK_AND_ASSIGN(auto module,
2152                           PartitionComputation(hlo_string, /*num_devices=*/2));
2153   VLOG(1) << module->ToString();
2154 
2155   const auto root = module->entry_computation()->root_instruction();
2156   auto param0 = AllOf(op::Parameter(), op::Shape("f32[14,129]"));
2157   auto after_halo_exchange =
2158       AllOf(op::Shape("f32[14,130]"),
2159             op::Concatenate(param0, op::CollectivePermute(op::Slice(param0))));
2160   auto pad = AllOf(op::Shape("f32[14,131]"),
2161                    op::Pad(after_halo_exchange, op::Constant()));
2162   EXPECT_THAT(root, op::Select(_, op::DynamicSlice(pad, op::Constant(), _), _));
2163 }
2164 
TEST_F(SpmdPartitioningTest,PadAlongPartitionedDimensionWithInteriorPadding)2165 TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimensionWithInteriorPadding) {
2166   absl::string_view hlo_string = R"(
2167 HloModule module
2168 
2169 ENTRY entry {
2170   %param0 = f32[7] parameter(0), sharding={devices=[2]0,1}
2171   %param1 = f32[] parameter(1), sharding={replicated}
2172   ROOT %pad = f32[22] pad(%param0, %param1), padding=2_1_2,
2173     sharding={devices=[2]0,1}
2174 })";
2175 
2176   TF_ASSERT_OK_AND_ASSIGN(auto module,
2177                           PartitionComputation(hlo_string, /*num_devices=*/2));
2178   VLOG(1) << module->ToString();
2179   const auto root = module->entry_computation()->root_instruction();
2180 
2181   auto param0 = AllOf(op::Parameter(), op::Shape("f32[4]"));
2182   auto after_halo_exchange = AllOf(
2183       op::Shape("f32[4]"),
2184       op::DynamicSlice(
2185           AllOf(op::Shape("f32[5]"),
2186                 op::Pad(AllOf(op::Shape("f32[4]"),
2187                               op::Concatenate(
2188                                   op::CollectivePermute(op::Slice(param0)),
2189                                   op::Slice(param0))),
2190                         op::Parameter(1))),
2191           _));
2192   auto pad = op::Pad(after_halo_exchange, op::Parameter(1));
2193   EXPECT_THAT(root, op::DynamicSlice(pad, _));
2194 }
2195 
TEST_F(SpmdPartitioningTest,PartialReplicatePad)2196 TEST_F(SpmdPartitioningTest, PartialReplicatePad) {
2197   absl::string_view hlo_string = R"(
2198 HloModule module
2199 
2200 ENTRY entry {
2201   %param0 = f32[11,7] parameter(0),
2202     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
2203   %param1 = f32[] parameter(1), sharding={replicated}
2204   ROOT %pad = f32[27,22] pad(%param0, %param1), padding=2_4_1x2_1_2,
2205     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
2206 })";
2207 
2208   TF_ASSERT_OK_AND_ASSIGN(auto module,
2209                           PartitionComputation(hlo_string, /*num_devices=*/4));
2210   VLOG(1) << module->ToString();
2211   const auto root = module->entry_computation()->root_instruction();
2212 
2213   auto param0 = AllOf(op::Parameter(), op::Shape("f32[11,4]"));
2214   auto after_halo_exchange = AllOf(
2215       op::Shape("f32[11,4]"),
2216       op::DynamicSlice(
2217           AllOf(op::Shape("f32[11,5]"),
2218                 op::Pad(AllOf(op::Shape("f32[11,4]"),
2219                               op::Concatenate(
2220                                   op::CollectivePermute(op::Slice(param0)),
2221                                   op::Slice(param0))),
2222                         op::Parameter(1))),
2223           op::Constant(), _));
2224   auto pad = op::Pad(after_halo_exchange, op::Parameter(1));
2225   EXPECT_THAT(root, AllOf(op::DynamicSlice(pad, op::Constant(), _),
2226                           op::Shape("f32[27,11]")));
2227 }
2228 
TEST_F(SpmdPartitioningTest,SliceAlongNonPartitionedDimension)2229 TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) {
2230   absl::string_view hlo_string = R"(
2231 HloModule module
2232 
2233 ENTRY entry {
2234   %param0 = f32[128,14,257] parameter(0)
2235   %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1}
2236   ROOT %slice = f32[128,11,257] slice(%param0.copy),
2237     slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1}
2238 })";
2239 
2240   TF_ASSERT_OK_AND_ASSIGN(auto module,
2241                           PartitionComputation(hlo_string, /*num_devices=*/2));
2242   VLOG(1) << module->ToString();
2243 
2244   const auto root = module->entry_computation()->root_instruction();
2245   auto param0 = AllOf(
2246       op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
2247                                 op::Constant(), op::Constant(), op::Reshape())),
2248       op::Shape("f32[128,14,129]"));
2249   EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]")));
2250 }
2251 
TEST_F(SpmdPartitioningTest,SliceAlongPartitionedDimension)2252 TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) {
2253   absl::string_view hlo_string = R"(
2254 HloModule module
2255 
2256 ENTRY entry {
2257   %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
2258   ROOT %slice = f32[63,14,251] slice(%param0),
2259     slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1}
2260 })";
2261 
2262   TF_ASSERT_OK_AND_ASSIGN(auto module,
2263                           PartitionComputation(hlo_string, /*num_devices=*/2));
2264   VLOG(1) << module->ToString();
2265 
2266   const auto root = module->entry_computation()->root_instruction();
2267   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[128,14,129]"));
2268   EXPECT_THAT(
2269       root,
2270       AllOf(op::Slice(AllOf(
2271                 op::DynamicSlice(
2272                     AllOf(op::Concatenate(
2273                               op::Slice(param0),
2274                               AllOf(op::CollectivePermute(op::Slice(param0)),
2275                                     op::Shape("f32[128,14,2]"))),
2276                           op::Shape("f32[128,14,129]")),
2277                     op::Constant(), op::Constant(), op::Add()),
2278                 op::Shape("f32[128,14,126]"))),
2279             op::Shape("f32[63,14,126]")));
2280 }
2281 
TEST_F(SpmdPartitioningTest,SliceAlongPartitionedDimension2)2282 TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension2) {
2283   absl::string_view hlo_string = R"(
2284 HloModule module
2285 
2286 ENTRY entry {
2287   %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
2288   ROOT %slice = f32[1] slice(%param0),
2289     slice={[3:4]}, sharding={devices=[4]0,1,2,3}
2290 })";
2291 
2292   TF_ASSERT_OK_AND_ASSIGN(auto module,
2293                           PartitionComputation(hlo_string, /*num_devices=*/4));
2294   VLOG(1) << module->ToString();
2295 
2296   const auto root = module->entry_computation()->root_instruction();
2297   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1]"));
2298   EXPECT_THAT(root, AllOf(op::Copy(op::CollectivePermute(param0)),
2299                           op::Shape("f32[1]")));
2300 }
2301 
TEST_F(SpmdPartitioningTest,MergedPadThenSliceShiftRight)2302 TEST_F(SpmdPartitioningTest, MergedPadThenSliceShiftRight) {
2303   absl::string_view hlo_string = R"(
2304 HloModule module
2305 
2306 ENTRY entry {
2307   %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
2308   %init = f32[] constant(2.0)
2309   %pad = f32[5] pad(%param0, %init), padding=1_0, sharding={devices=[4]0,1,2,3}
2310   %copy = f32[5] copy(%pad), sharding={devices=[4]0,1,2,3}
2311   %copy.1 = f32[5] copy(%copy), sharding={devices=[4]0,1,2,3}
2312   ROOT %slice = f32[4] slice(%copy.1), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2313 })";
2314 
2315   TF_ASSERT_OK_AND_ASSIGN(auto module,
2316                           PartitionComputation(hlo_string, /*num_devices=*/4));
2317   VLOG(1) << module->ToString();
2318 
2319   const auto root = module->entry_computation()->root_instruction();
2320   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1]"));
2321   EXPECT_THAT(root, AllOf(op::Select(_, op::CollectivePermute(param0), _),
2322                           op::Shape("f32[1]")));
2323 }
2324 
2325 // Same as above except that it uses zero padding, so there is no need for
2326 // masking.
TEST_F(SpmdPartitioningTest,MergedPadThenSliceShiftRightNoMasking)2327 TEST_F(SpmdPartitioningTest, MergedPadThenSliceShiftRightNoMasking) {
2328   absl::string_view hlo_string = R"(
2329 HloModule module
2330 
2331 ENTRY entry {
2332   %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
2333   %init = f32[] constant(0)
2334   %pad = f32[5] pad(%param0, %init), padding=1_0, sharding={devices=[4]0,1,2,3}
2335   %copy = f32[5] copy(%pad), sharding={devices=[4]0,1,2,3}
2336   %copy.1 = f32[5] copy(%copy), sharding={devices=[4]0,1,2,3}
2337   ROOT %slice = f32[4] slice(%copy.1), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2338 })";
2339 
2340   TF_ASSERT_OK_AND_ASSIGN(auto module,
2341                           PartitionComputation(hlo_string, /*num_devices=*/4));
2342   VLOG(1) << module->ToString();
2343 
2344   const auto root = module->entry_computation()->root_instruction();
2345   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1]"));
2346   EXPECT_THAT(root, AllOf(op::CollectivePermute(param0), op::Shape("f32[1]")));
2347 }
2348 
TEST_F(SpmdPartitioningTest,MergedSliceThenConcatRotateRight)2349 TEST_F(SpmdPartitioningTest, MergedSliceThenConcatRotateRight) {
2350   absl::string_view hlo_string = R"(
2351 HloModule module
2352 
2353 ENTRY entry {
2354   %param0 = f32[12] parameter(0), sharding={devices=[4]0,1,2,3}
2355   %slice0 = f32[2] slice(%param0), slice={[10:12]}, sharding={devices=[4]0,1,2,3}
2356   %slice1 = f32[10] slice(%param0), slice={[0:10]}, sharding={devices=[4]0,1,2,3}
2357   ROOT %concat = f32[12] concatenate(%slice0, %slice1), dimensions={0},
2358     sharding={devices=[4]0,1,2,3}
2359 })";
2360 
2361   TF_ASSERT_OK_AND_ASSIGN(auto module,
2362                           PartitionComputation(hlo_string, /*num_devices=*/4));
2363   VLOG(1) << module->ToString();
2364 
2365   const auto root = module->entry_computation()->root_instruction();
2366   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[3]"));
2367   auto rotate = op::Concatenate(op::CollectivePermute(op::Slice(param0)),
2368                                 op::Slice(param0));
2369   EXPECT_THAT(root, AllOf(rotate, op::Shape("f32[3]")));
2370 }
2371 
TEST_F(SpmdPartitioningTest,MergedSliceThenConcatRotateRightWithAlignedPadding)2372 TEST_F(SpmdPartitioningTest,
2373        MergedSliceThenConcatRotateRightWithAlignedPadding) {
2374   absl::string_view hlo_string = R"(
2375 HloModule module
2376 
2377 ENTRY entry {
2378   %param0 = f32[6] parameter(0), sharding={devices=[4]0,1,2,3}
2379   %slice0 = f32[2] slice(%param0), slice={[4:6]}, sharding={devices=[4]0,1,2,3}
2380   %slice1 = f32[4] slice(%param0), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2381   ROOT %concat = f32[6] concatenate(%slice0, %slice1), dimensions={0},
2382     sharding={devices=[4]0,1,2,3}
2383 })";
2384 
2385   TF_ASSERT_OK_AND_ASSIGN(auto module,
2386                           PartitionComputation(hlo_string, /*num_devices=*/4));
2387   VLOG(1) << module->ToString();
2388 
2389   const auto root = module->entry_computation()->root_instruction();
2390   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[2]"));
2391   EXPECT_THAT(root, op::CollectivePermute(param0));
2392 }
2393 
TEST_F(SpmdPartitioningTest,MergedSliceThenConcatRotateRightWithUnalignedPadding)2394 TEST_F(SpmdPartitioningTest,
2395        MergedSliceThenConcatRotateRightWithUnalignedPadding) {
2396   absl::string_view hlo_string = R"(
2397 HloModule module
2398 
2399 ENTRY entry {
2400   %param0 = f32[10] parameter(0), sharding={devices=[4]0,1,2,3}
2401   %slice0 = f32[6] slice(%param0), slice={[4:10]}, sharding={devices=[4]0,1,2,3}
2402   %slice1 = f32[4] slice(%param0), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2403   ROOT %concat = f32[10] concatenate(%slice0, %slice1), dimensions={0},
2404     sharding={devices=[4]0,1,2,3}
2405 })";
2406 
2407   TF_ASSERT_OK_AND_ASSIGN(auto module,
2408                           PartitionComputation(hlo_string, /*num_devices=*/4));
2409   VLOG(1) << module->ToString();
2410 
2411   const auto root = module->entry_computation()->root_instruction();
2412   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[3]"));
2413   auto rotate0 = op::CollectivePermute(param0);
2414   auto rotate1 = op::Concatenate(op::CollectivePermute(op::Slice(param0)),
2415                                  op::CollectivePermute(op::Slice(param0)));
2416   EXPECT_THAT(root,
2417               AllOf(op::Select(_, rotate1, rotate0), op::Shape("f32[3]")));
2418 }
2419 
TEST_F(SpmdPartitioningTest,PartialReplicateSliceAlongNonPartitionedDimension)2420 TEST_F(SpmdPartitioningTest,
2421        PartialReplicateSliceAlongNonPartitionedDimension) {
2422   absl::string_view hlo_string = R"(
2423 HloModule module
2424 
2425 ENTRY entry {
2426   %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
2427   ROOT %slice = f32[128,11,257] slice(%param0),
2428     slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
2429 })";
2430 
2431   TF_ASSERT_OK_AND_ASSIGN(auto module,
2432                           PartitionComputation(hlo_string, /*num_devices=*/4));
2433   VLOG(1) << module->ToString();
2434 
2435   const auto root = module->entry_computation()->root_instruction();
2436   auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
2437   EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]")));
2438 }
2439 
TEST_F(SpmdPartitioningTest,PartialReplicateSliceAlongPartitionedDimension)2440 TEST_F(SpmdPartitioningTest, PartialReplicateSliceAlongPartitionedDimension) {
2441   absl::string_view hlo_string = R"(
2442 HloModule module
2443 
2444 ENTRY entry {
2445   %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
2446   ROOT %slice = f32[63,14,251] slice(%param0),
2447     slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
2448 })";
2449 
2450   TF_ASSERT_OK_AND_ASSIGN(auto module,
2451                           PartitionComputation(hlo_string, /*num_devices=*/4));
2452   VLOG(1) << module->ToString();
2453 
2454   const auto root = module->entry_computation()->root_instruction();
2455   auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
2456   EXPECT_THAT(
2457       root,
2458       AllOf(
2459           op::Slice(AllOf(
2460               op::DynamicSlice(
2461                   AllOf(op::Concatenate(
2462                             op::Slice(param0),
2463                             AllOf(op::CollectivePermute(op::Slice(param0)),
2464                                   op::Shape("f32[128,14,2]"))),
2465                         op::Shape("f32[128,14,129]")),
2466                   op::Constant(), op::Constant(),
2467                   op::Add(op::Multiply(op::Reshape(op::DynamicSlice(
2468                                            op::Constant(), op::PartitionId())),
2469                                        op::Constant()),
2470                           op::Constant())),
2471               op::Shape("f32[128,14,126]"))),
2472           op::Shape("f32[63,14,126]")));
2473 }
2474 
TEST_F(SpmdPartitioningTest,SortAlongNonPartitionedDimension)2475 TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) {
2476   absl::string_view hlo_string = R"(
2477 HloModule module
2478 
2479 ge {
2480   p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated}
2481   bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
2482   constant = s32[]{:T(256)} constant(0), sharding={replicated}
2483   compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated}
2484   constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated}
2485   bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
2486   subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated}
2487   bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated}
2488   select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated}
2489   p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated}
2490   bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
2491   compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated}
2492   bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
2493   subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated}
2494   bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated}
2495   select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated}
2496   compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated}
2497   compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated}
2498   compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated}
2499   p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated}
2500   p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated}
2501   compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated}
2502   ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated}
2503 }
2504 
2505 ENTRY entry {
2506   %param0 = f32[128,14,257] parameter(0)
2507   %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1}
2508   %param1 = s32[128,14,257] parameter(1)
2509   %param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1}
2510   ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)})
2511     sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true,
2512     to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}}
2513 })";
2514 
2515   TF_ASSERT_OK_AND_ASSIGN(auto module,
2516                           PartitionComputation(hlo_string, /*num_devices=*/2));
2517   VLOG(1) << module->ToString();
2518 
2519   const auto root = module->entry_computation()->root_instruction();
2520   auto param0 =
2521       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
2522                                       op::Reshape(), op::Constant())),
2523             op::Shape("f32[128,7,257]"));
2524   auto param1 =
2525       AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
2526                                       op::Reshape(), op::Constant())),
2527             op::Shape("s32[128,7,257]"));
2528   EXPECT_THAT(root, AllOf(op::Sort(param0, param1),
2529                           op::Shape("(f32[128,7,257], s32[128,7,257])")));
2530 }
2531 
TEST_F(SpmdPartitioningTest,PartitionCustomCall)2532 TEST_F(SpmdPartitioningTest, PartitionCustomCall) {
2533   absl::string_view hlo_string = R"(
2534 HloModule cluster_2013453984438090939__.47
2535 
2536 ENTRY %cluster_2013453984438090939__.47
2537   (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
2538   %arg_tuple.1 = bf16[2,209664] parameter(0)
2539   %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2540   %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
2541     custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK"
2542   %get-tuple-element = bf16[2,2000]{1,0}
2543     get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call),
2544     index=0, sharding={replicated}
2545   %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0},
2546     s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated}
2547   ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
2548     tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0}
2549     %get-tuple-element.1), sharding={{replicated}, {replicated}},
2550     metadata={op_name="XLA_Retvals"}
2551 })";
2552 
2553   TF_ASSERT_OK_AND_ASSIGN(auto module,
2554                           PartitionComputation(hlo_string, /*num_devices=*/2));
2555   VLOG(1) << module->ToString();
2556   auto custom_call = FindInstruction(module.get(), "custom-call.1");
2557   EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832);
2558   auto sort = FindInstruction(module.get(), "sort");
2559   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000);
2560   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000);
2561 }
2562 
TEST_F(SpmdPartitioningTest,PartitionCustomCall_TwoPartitionedDims)2563 TEST_F(SpmdPartitioningTest, PartitionCustomCall_TwoPartitionedDims) {
2564   absl::string_view hlo_string = R"(
2565 HloModule module
2566 
2567 ENTRY entry {
2568   %param0 = f32[8,32128] parameter(0)
2569   %copy.0 = f32[8,32128] copy(%param0),
2570     sharding={devices=[4,2]0,1,2,3,4,5,6,7}
2571   %custom-call = (f32[8,2]{1,0}, s32[8,2]{1,0})
2572     custom-call(%copy.0), custom_call_target="TopK"
2573   %get-tuple-element = f32[8,2]{1,0}
2574     get-tuple-element((f32[8,2]{1,0}, s32[8,2]{1,0}) %custom-call), index=0,
2575     sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
2576   %get-tuple-element.1 = s32[8,2]{1,0}
2577     get-tuple-element((f32[8,2]{1,0}, s32[8,2]{1,0}) %custom-call), index=1,
2578     sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
2579   ROOT %tuple = (f32[8,2]{1,0}, s32[8,2]{1,0})
2580     tuple(%get-tuple-element, %get-tuple-element.1),
2581     sharding={{replicated}, {replicated}}
2582 })";
2583 
2584   TF_ASSERT_OK_AND_ASSIGN(auto module,
2585                           PartitionComputation(hlo_string, /*num_devices=*/8));
2586   VLOG(1) << module->ToString();
2587   auto custom_call = FindInstruction(module.get(), "custom-call.1");
2588   EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 16064);
2589   auto sort = FindInstruction(module.get(), "sort");
2590   EXPECT_EQ(sort->operand(0)->shape().dimensions(0), 2);
2591   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4);
2592   EXPECT_EQ(sort->operand(1)->shape().dimensions(0), 2);
2593   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4);
2594 }
2595 
TEST_F(SpmdPartitioningTest,PartitionSortInTopK)2596 TEST_F(SpmdPartitioningTest, PartitionSortInTopK) {
2597   absl::string_view hlo_string = R"(
2598 HloModule module
2599 
2600 %compare-greater-than.8 (p.0.lhs.9: bf16[], p.0.rhs.10: bf16[], p.1.lhs.11:
2601    s32[], p.1.rhs.12: s32[]) -> pred[] {
2602   %p.1.lhs.11 = s32[] parameter(2)
2603   %p.1.rhs.12 = s32[] parameter(3)
2604   %p.0.lhs.9 = bf16[] parameter(0)
2605   %convert.13 = f32[] convert(bf16[] %p.0.lhs.9)
2606   %bitcast-convert.16 = s32[] bitcast-convert(f32[] %convert.13)
2607   %constant.20 = s32[] constant(0)
2608   %compare.21 = pred[] compare(s32[] %bitcast-convert.16, s32[] %constant.20),
2609     direction=LT
2610   %constant.15 = u32[] constant(2147483647)
2611   %bitcast-convert.17 = u32[] bitcast-convert(f32[] %convert.13)
2612   %subtract.18 = u32[] subtract(u32[] %constant.15, u32[] %bitcast-convert.17)
2613   %bitcast-convert.19 = s32[] bitcast-convert(u32[] %subtract.18)
2614   %select.22 = s32[] select(pred[] %compare.21, s32[] %bitcast-convert.19, s32[]
2615     %bitcast-convert.16)
2616   %p.0.rhs.10 = bf16[] parameter(1)
2617   %convert.14 = f32[] convert(bf16[] %p.0.rhs.10)
2618   %bitcast-convert.24 = s32[] bitcast-convert(f32[] %convert.14)
2619   %constant.28 = s32[] constant(0)
2620   %compare.29 = pred[] compare(s32[] %bitcast-convert.24, s32[] %constant.28),
2621     direction=LT
2622   %constant.23 = u32[] constant(2147483647)
2623   %bitcast-convert.25 = u32[] bitcast-convert(f32[] %convert.14)
2624   %subtract.26 = u32[] subtract(u32[] %constant.23, u32[] %bitcast-convert.25)
2625   %bitcast-convert.27 = s32[] bitcast-convert(u32[] %subtract.26)
2626   %select.30 = s32[] select(pred[] %compare.29, s32[] %bitcast-convert.27, s32[]
2627     %bitcast-convert.24)
2628   ROOT %compare.31 = pred[] compare(s32[] %select.22, s32[] %select.30),
2629     direction=GT
2630 }
2631 
2632 ENTRY entry
2633   (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
2634   %arg_tuple.1 = bf16[2,209664] parameter(0)
2635   %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2636   %iota.7 = s32[2,209664] iota(), iota_dimension=1,
2637     metadata={op_type="TopKV2" op_name="TopKV2"}
2638   %sort.32 = (bf16[2,209664], s32[2,209664])
2639     sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
2640     dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2641     metadata={op_type="TopKV2" op_name="TopKV2"}
2642   %get-tuple-element.33 = bf16[2,209664]
2643     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2644     index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2645   %slice.34 = bf16[2,2000] slice(bf16[2,209664]
2646     %get-tuple-element.33), slice={[0:2], [0:2000]},
2647     metadata={op_type="TopKV2" op_name="TopKV2"}
2648   %get-tuple-element.35 = s32[2,209664]
2649     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2650     index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2651   %slice.36 = s32[2,2000] slice(s32[2,209664]
2652     %get-tuple-element.35), slice={[0:2], [0:2000]},
2653     metadata={op_type="TopKV2" op_name="TopKV2"}
2654   ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
2655     tuple(bf16[2,2000] %slice.34, s32[2,2000]
2656     %slice.36), sharding={{replicated}, {replicated}},
2657     metadata={op_name="XLA_Retvals"}
2658 })";
2659 
2660   TF_ASSERT_OK_AND_ASSIGN(auto module,
2661                           PartitionComputation(hlo_string, /*num_devices=*/2));
2662   VLOG(1) << module->ToString();
2663   auto sort = FindInstruction(module.get(), "sort.0");
2664   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
2665   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
2666   auto final_sort = FindInstruction(module.get(), "sort.1");
2667   EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
2668   EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
2669 }
2670 
TEST_F(SpmdPartitioningTest,PartitionSortInTopKWhenComparisonWithSelect)2671 TEST_F(SpmdPartitioningTest, PartitionSortInTopKWhenComparisonWithSelect) {
2672   absl::string_view hlo_string = R"(
2673 HloModule module
2674 
2675 %compare-greater-than.8 (p.0.lhs.2566: bf16[],
2676   p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
2677   %p.0.lhs.2566 = bf16[] parameter(0)
2678   %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
2679   %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
2680   %constant.285 = s32[] constant(0)
2681   %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
2682     direction=LT
2683   %constant.286 = u32[] constant(2147483647)
2684   %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
2685   %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
2686   %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
2687   %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
2688     s32[] %bitcast-convert.48)
2689   %p.0.rhs.2567 = bf16[] parameter(1)
2690   %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
2691   %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
2692   %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
2693     direction=LT
2694   %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
2695   %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
2696   %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
2697   %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
2698     s32[] %bitcast-convert.51)
2699   %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
2700   %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
2701   %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
2702     direction=EQ
2703   %p.1.lhs.2586 = s32[] parameter(2)
2704   %p.1.rhs.2587 = s32[] parameter(3)
2705   %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
2706     direction=LT
2707   ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
2708     pred[] %compare.86)
2709 }
2710 
2711 ENTRY entry
2712   (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
2713   %arg_tuple.1 = bf16[2,209664] parameter(0)
2714   %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2715   %iota.7 = s32[2,209664] iota(), iota_dimension=1,
2716     metadata={op_type="TopKV2" op_name="TopKV2"}
2717   %sort.32 = (bf16[2,209664], s32[2,209664])
2718     sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
2719     dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2720     metadata={op_type="TopKV2" op_name="TopKV2"}
2721   %get-tuple-element.33 = bf16[2,209664]
2722     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2723     index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2724   %slice.34 = bf16[2,2000] slice(bf16[2,209664]
2725     %get-tuple-element.33), slice={[0:2], [0:2000]},
2726     metadata={op_type="TopKV2" op_name="TopKV2"}
2727   %get-tuple-element.35 = s32[2,209664]
2728     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2729     index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2730   %slice.36 = s32[2,2000] slice(s32[2,209664]
2731     %get-tuple-element.35), slice={[0:2], [0:2000]},
2732     metadata={op_type="TopKV2" op_name="TopKV2"}
2733   ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
2734     tuple(bf16[2,2000] %slice.34, s32[2,2000]
2735     %slice.36), sharding={{replicated}, {replicated}},
2736     metadata={op_name="XLA_Retvals"}
2737 })";
2738 
2739   TF_ASSERT_OK_AND_ASSIGN(auto module,
2740                           PartitionComputation(hlo_string, /*num_devices=*/2));
2741   VLOG(1) << module->ToString();
2742   auto sort = FindInstruction(module.get(), "sort.0");
2743   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
2744   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
2745   auto final_sort = FindInstruction(module.get(), "sort.1");
2746   EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
2747   EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
2748 }
2749 
TEST_F(SpmdPartitioningTest,NoPartitionSortInTopKWhenSecondOperandIsNotIota)2750 TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSecondOperandIsNotIota) {
2751   absl::string_view hlo_string = R"(
2752 HloModule module
2753 
2754 %compare-greater-than.8 (p.0.lhs.2566: bf16[],
2755   p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
2756   %p.0.lhs.2566 = bf16[] parameter(0)
2757   %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
2758   %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
2759   %constant.285 = s32[] constant(0)
2760   %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
2761     direction=LT
2762   %constant.286 = u32[] constant(2147483647)
2763   %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
2764   %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
2765   %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
2766   %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
2767     s32[] %bitcast-convert.48)
2768   %p.0.rhs.2567 = bf16[] parameter(1)
2769   %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
2770   %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
2771   %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
2772     direction=LT
2773   %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
2774   %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
2775   %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
2776   %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
2777     s32[] %bitcast-convert.51)
2778   %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
2779   %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
2780   %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
2781     direction=EQ
2782   %p.1.lhs.2586 = s32[] parameter(2)
2783   %p.1.rhs.2587 = s32[] parameter(3)
2784   %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
2785     direction=LT
2786   ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
2787     pred[] %compare.86)
2788 }
2789 
2790 ENTRY entry {
2791   %arg_tuple.1 = bf16[2,209664] parameter(0)
2792   %arg_tuple.2 = s32[2,209664] parameter(1)
2793   %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2794   %sort.32 = (bf16[2,209664], s32[2,209664])
2795     sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %arg_tuple.2),
2796     dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2797     metadata={op_type="TopKV2" op_name="TopKV2"}
2798   %get-tuple-element.33 = bf16[2,209664]
2799     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2800     index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2801   %slice.34 = bf16[2,2000] slice(bf16[2,209664]
2802     %get-tuple-element.33), slice={[0:2], [0:2000]},
2803     metadata={op_type="TopKV2" op_name="TopKV2"}
2804   %get-tuple-element.35 = s32[2,209664]
2805     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2806     index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2807   %slice.36 = s32[2,2000] slice(s32[2,209664]
2808     %get-tuple-element.35), slice={[0:2], [0:2000]},
2809     metadata={op_type="TopKV2" op_name="TopKV2"}
2810   ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
2811     tuple(bf16[2,2000] %slice.34, s32[2,2000]
2812     %slice.36), sharding={{replicated}, {replicated}},
2813     metadata={op_name="XLA_Retvals"}
2814 })";
2815 
2816   TF_ASSERT_OK_AND_ASSIGN(auto module,
2817                           PartitionComputation(hlo_string, /*num_devices=*/2));
2818   VLOG(1) << module->ToString();
2819   auto sort = FindInstruction(module.get(), "sort.0");
2820   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
2821   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
2822 }
2823 
TEST_F(SpmdPartitioningTest,NoPartitionSortInTopKWhenNoPartitionInSortDim)2824 TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenNoPartitionInSortDim) {
2825   absl::string_view hlo_string = R"(
2826 HloModule module
2827 
2828 %compare-greater-than.8 (p.0.lhs.2566: bf16[],
2829   p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
2830   %p.0.lhs.2566 = bf16[] parameter(0)
2831   %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
2832   %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
2833   %constant.285 = s32[] constant(0)
2834   %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
2835     direction=LT
2836   %constant.286 = u32[] constant(2147483647)
2837   %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
2838   %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
2839   %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
2840   %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
2841     s32[] %bitcast-convert.48)
2842   %p.0.rhs.2567 = bf16[] parameter(1)
2843   %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
2844   %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
2845   %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
2846     direction=LT
2847   %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
2848   %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
2849   %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
2850   %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
2851     s32[] %bitcast-convert.51)
2852   %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
2853   %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
2854   %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
2855     direction=EQ
2856   %p.1.lhs.2586 = s32[] parameter(2)
2857   %p.1.rhs.2587 = s32[] parameter(3)
2858   %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
2859     direction=LT
2860   ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
2861     pred[] %compare.86)
2862 }
2863 
2864 ENTRY entry
2865   (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
2866   %arg_tuple.1 = bf16[2,209664] parameter(0)
2867   %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[2,1]0,1}
2868   %iota.7 = s32[2,209664] iota(), iota_dimension=1,
2869     metadata={op_type="TopKV2" op_name="TopKV2"}
2870   %sort.32 = (bf16[2,209664], s32[2,209664])
2871     sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
2872     dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2873     metadata={op_type="TopKV2" op_name="TopKV2"}
2874   %get-tuple-element.33 = bf16[2,209664]
2875     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2876     index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2877   %slice.34 = bf16[2,2000] slice(bf16[2,209664]
2878     %get-tuple-element.33), slice={[0:2], [0:2000]},
2879     metadata={op_type="TopKV2" op_name="TopKV2"}
2880   %get-tuple-element.35 = s32[2,209664]
2881     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2882     index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2883   %slice.36 = s32[2,2000] slice(s32[2,209664]
2884     %get-tuple-element.35), slice={[0:2], [0:2000]},
2885     metadata={op_type="TopKV2" op_name="TopKV2"}
2886   ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
2887     tuple(bf16[2,2000] %slice.34, s32[2,2000]
2888     %slice.36), sharding={{replicated}, {replicated}},
2889     metadata={op_name="XLA_Retvals"}
2890 })";
2891 
2892   TF_ASSERT_OK_AND_ASSIGN(auto module,
2893                           PartitionComputation(hlo_string, /*num_devices=*/2));
2894   VLOG(1) << module->ToString();
2895   auto sort = FindInstruction(module.get(), "sort.0");
2896   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
2897   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
2898 }
2899 
TEST_F(SpmdPartitioningTest,NoPartitionSortInTopKWhenSliceInOtherDim)2900 TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSliceInOtherDim) {
2901   absl::string_view hlo_string = R"(
2902 HloModule module
2903 
2904 %compare-greater-than.8 (p.0.lhs.2566: bf16[],
2905   p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
2906   %p.0.lhs.2566 = bf16[] parameter(0)
2907   %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
2908   %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
2909   %constant.285 = s32[] constant(0)
2910   %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
2911     direction=LT
2912   %constant.286 = u32[] constant(2147483647)
2913   %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
2914   %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
2915   %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
2916   %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
2917     s32[] %bitcast-convert.48)
2918   %p.0.rhs.2567 = bf16[] parameter(1)
2919   %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
2920   %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
2921   %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
2922     direction=LT
2923   %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
2924   %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
2925   %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
2926   %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
2927     s32[] %bitcast-convert.51)
2928   %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
2929   %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
2930   %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
2931     direction=EQ
2932   %p.1.lhs.2586 = s32[] parameter(2)
2933   %p.1.rhs.2587 = s32[] parameter(3)
2934   %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
2935     direction=LT
2936   ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
2937     pred[] %compare.86)
2938 }
2939 
2940 ENTRY entry {
2941   %arg_tuple.1 = bf16[2,209664] parameter(0)
2942   %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2943   %iota.7 = s32[2,209664] iota(), iota_dimension=1,
2944     metadata={op_type="TopKV2" op_name="TopKV2"}
2945   %sort.32 = (bf16[2,209664], s32[2,209664])
2946     sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
2947     dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2948     metadata={op_type="TopKV2" op_name="TopKV2"}
2949   %get-tuple-element.33 = bf16[2,209664]
2950     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2951     index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2952   %slice.34 = bf16[1,209664] slice(bf16[2,209664]
2953     %get-tuple-element.33), slice={[0:1], [0:209664]},
2954     metadata={op_type="TopKV2" op_name="TopKV2"}
2955   %get-tuple-element.35 = s32[2,209664]
2956     get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2957     index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2958   %slice.36 = s32[1,209664] slice(s32[2,209664]
2959     %get-tuple-element.35), slice={[0:1], [0:209664]},
2960     metadata={op_type="TopKV2" op_name="TopKV2"}
2961   ROOT %tuple.46 = (bf16[1,209664], s32[1,209664])
2962     tuple(bf16[1,209664] %slice.34, s32[1,209664]
2963     %slice.36), sharding={{replicated}, {replicated}},
2964     metadata={op_name="XLA_Retvals"}
2965 })";
2966 
2967   TF_ASSERT_OK_AND_ASSIGN(auto module,
2968                           PartitionComputation(hlo_string, /*num_devices=*/2));
2969   VLOG(1) << module->ToString();
2970   auto sort = FindInstruction(module.get(), "sort.0");
2971   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
2972   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
2973 }
2974 
TEST_F(SpmdPartitioningTest,ShardableTranspose)2975 TEST_F(SpmdPartitioningTest, ShardableTranspose) {
2976   absl::string_view hlo_string = R"(
2977 HloModule module
2978 
2979 ENTRY entry {
2980   %param0 = f32[16,38,38,4] parameter(0)
2981   %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
2982   ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
2983     dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1}
2984 })";
2985 
2986   TF_ASSERT_OK_AND_ASSIGN(auto module,
2987                           PartitionComputation(hlo_string, /*num_devices=*/2));
2988   VLOG(1) << module->ToString();
2989 
2990   const auto root = module->entry_computation()->root_instruction();
2991   auto param0 = AllOf(
2992       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
2993                                 op::Constant(), op::Constant())),
2994       op::Shape("f32[16,19,38,4]"));
2995   EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]")));
2996 }
2997 
TEST_F(SpmdPartitioningTest,MultiDimensionShardedTranspose)2998 TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) {
2999   absl::string_view hlo_string = R"(
3000 HloModule module
3001 
3002 ENTRY entry {
3003   %param0 = f32[16,38,38,4] parameter(0)
3004   %param0.copy = f32[16,38,38,4] copy(%param0),
3005     sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
3006   ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
3007     dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7}
3008 })";
3009 
3010   TF_ASSERT_OK_AND_ASSIGN(auto module,
3011                           PartitionComputation(hlo_string, /*num_devices=*/8));
3012   VLOG(1) << module->ToString();
3013 
3014   const auto root = module->entry_computation()->root_instruction();
3015   auto param0 = AllOf(
3016       op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
3017                                 op::Constant(), op::Constant())),
3018       op::Shape("f32[4,19,38,4]"));
3019   EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,4,38]")));
3020 }
3021 
TEST_F(SpmdPartitioningTest,NonShardableTranspose)3022 TEST_F(SpmdPartitioningTest, NonShardableTranspose) {
3023   absl::string_view hlo_string = R"(
3024 HloModule module
3025 
3026 ENTRY entry {
3027   %param0 = f32[16,38,38,4] parameter(0)
3028   %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
3029   ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
3030     dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1}
3031 })";
3032 
3033   TF_ASSERT_OK_AND_ASSIGN(auto module,
3034                           PartitionComputation(hlo_string, /*num_devices=*/2));
3035   VLOG(1) << module->ToString();
3036 
3037   const auto root = module->entry_computation()->root_instruction();
3038   auto resahrd = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))),
3039                        op::Shape("f32[16,38,38,2]"));
3040   EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]")));
3041 }
3042 
TEST_F(SpmdPartitioningTest,PartialReplicateShardableTranspose)3043 TEST_F(SpmdPartitioningTest, PartialReplicateShardableTranspose) {
3044   absl::string_view hlo_string = R"(
3045 HloModule module
3046 
3047 ENTRY entry {
3048   %param0 = f32[16,38,38,4] parameter(0)
3049   %param0.copy = f32[16,38,38,4] copy(%param0),
3050     sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
3051   ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
3052     dimensions={0,3,1,2},
3053     sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
3054 })";
3055 
3056   TF_ASSERT_OK_AND_ASSIGN(auto module,
3057                           PartitionComputation(hlo_string, /*num_devices=*/4));
3058   VLOG(1) << module->ToString();
3059 
3060   const auto root = module->entry_computation()->root_instruction();
3061   auto param0 = AllOf(
3062       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
3063                                 op::Constant(), op::Constant())),
3064       op::Shape("f32[16,19,38,4]"));
3065   EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]")));
3066 }
3067 
TEST_F(SpmdPartitioningTest,PartialReplicateNonShardableTranspose)3068 TEST_F(SpmdPartitioningTest, PartialReplicateNonShardableTranspose) {
3069   absl::string_view hlo_string = R"(
3070 HloModule module
3071 
3072 ENTRY entry {
3073   %param0 = f32[16,38,38,4] parameter(0)
3074   %param0.copy = f32[16,38,38,4] copy(%param0),
3075     sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
3076   ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
3077     dimensions={0,3,1,2},
3078     sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
3079 })";
3080 
3081   TF_ASSERT_OK_AND_ASSIGN(auto module,
3082                           PartitionComputation(hlo_string, /*num_devices=*/4));
3083   VLOG(1) << module->ToString();
3084 
3085   const auto root = module->entry_computation()->root_instruction();
3086   auto resahrd = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))),
3087                        op::Shape("f32[16,38,38,2]"));
3088   EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]")));
3089 }
3090 
TEST_F(SpmdPartitioningTest,PartialReplicateMultiDimensionShardedTranspose)3091 TEST_F(SpmdPartitioningTest, PartialReplicateMultiDimensionShardedTranspose) {
3092   absl::string_view hlo_string = R"(
3093 HloModule module
3094 
3095 ENTRY entry {
3096   %param0 = f32[16,38,38,4] parameter(0)
3097   %param0.copy = f32[16,38,38,4] copy(%param0),
3098     sharding={devices=[2,2,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
3099   ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
3100     dimensions={1,3,0,2},
3101     sharding={devices=[2,1,2,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
3102 })";
3103 
3104   TF_ASSERT_OK_AND_ASSIGN(auto module,
3105                           PartitionComputation(hlo_string, /*num_devices=*/8));
3106   VLOG(1) << module->ToString();
3107 
3108   const auto root = module->entry_computation()->root_instruction();
3109   auto param0 = AllOf(
3110       op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
3111                                 op::Constant(), op::Constant())),
3112       op::Shape("f32[8,19,38,4]"));
3113   EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,8,38]")));
3114 }
3115 
TEST_F(SpmdPartitioningTest,ShardableReshape)3116 TEST_F(SpmdPartitioningTest, ShardableReshape) {
3117   absl::string_view hlo_string = R"(
3118 HloModule module
3119 
3120 ENTRY entry {
3121   %param0 = f32[38,38,324] parameter(0)
3122   %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1}
3123   ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
3124     sharding={devices=[2,1,1,1]0,1}
3125 })";
3126 
3127   TF_ASSERT_OK_AND_ASSIGN(auto module,
3128                           PartitionComputation(hlo_string, /*num_devices=*/2));
3129   VLOG(1) << module->ToString();
3130 
3131   const auto root = module->entry_computation()->root_instruction();
3132   auto param0 =
3133       AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
3134                                       op::Constant(), op::Constant())),
3135             op::Shape("f32[19,38,324]"));
3136   EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
3137 }
3138 
TEST_F(SpmdPartitioningTest,ReshapeWithReshard)3139 TEST_F(SpmdPartitioningTest, ReshapeWithReshard) {
3140   absl::string_view hlo_string = R"(
3141 HloModule module
3142 
3143 ENTRY entry {
3144   %param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
3145   ROOT %reshape = f32[38,38,4,81] reshape(%param0),
3146     sharding={devices=[1,2,1,1]0,1}
3147 })";
3148 
3149   TF_ASSERT_OK_AND_ASSIGN(auto module,
3150                           PartitionComputation(hlo_string, /*num_devices=*/2));
3151   VLOG(1) << module->ToString();
3152 
3153   const auto root = module->entry_computation()->root_instruction();
3154   auto input_reshard =
3155       op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(0)))));
3156   EXPECT_THAT(root,
3157               AllOf(op::Reshape(input_reshard), op::Shape("f32[38,19,4,81]")));
3158 }
3159 
TEST_F(SpmdPartitioningTest,ReshapeWithReshard2)3160 TEST_F(SpmdPartitioningTest, ReshapeWithReshard2) {
3161   absl::string_view hlo_string = R"(
3162 HloModule module
3163 
3164 ENTRY entry {
3165   %param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
3166   ROOT %reshape = f32[38,38,2,162] reshape(%param0),
3167     sharding={devices=[1,1,1,2]0,1}
3168 })";
3169 
3170   TF_ASSERT_OK_AND_ASSIGN(auto module,
3171                           PartitionComputation(hlo_string, /*num_devices=*/2));
3172   VLOG(1) << module->ToString();
3173 
3174   const auto root = module->entry_computation()->root_instruction();
3175   auto local_reshape =
3176       AllOf(op::Reshape(op::Parameter(0)), op::Shape("f32[19,38,2,162]"));
3177   EXPECT_THAT(root, AllOf(op::Shape("f32[38,38,2,81]"),
3178                           op::Reshape(op::Transpose(
3179                               op::AllToAll(op::Reshape(local_reshape))))));
3180 }
3181 
TEST_F(SpmdPartitioningTest,PartialReplicateShardableReshape)3182 TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) {
3183   absl::string_view hlo_string = R"(
3184 HloModule module
3185 
3186 ENTRY entry {
3187   %param0 = f32[38,38,324] parameter(0)
3188   %param0.copy = f32[38,38,324] copy(%param0),
3189     sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
3190   ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
3191     sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate}
3192 })";
3193 
3194   TF_ASSERT_OK_AND_ASSIGN(auto module,
3195                           PartitionComputation(hlo_string, /*num_devices=*/4));
3196   VLOG(1) << module->ToString();
3197 
3198   const auto root = module->entry_computation()->root_instruction();
3199   auto param0 =
3200       AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
3201                                       op::Constant(), op::Constant())),
3202             op::Shape("f32[19,38,324]"));
3203   EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
3204 }
3205 
TEST_F(SpmdPartitioningTest,ReshapeMergeDimsWithHaloExchange)3206 TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
3207   absl::string_view hlo_string = R"(
3208 HloModule module
3209 
3210 ENTRY entry {
3211   %input = s32[2,3,7,10] parameter(0), sharding={devices=[1,1,2,1]0,1}
3212   ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
3213     sharding={devices=[1,1,1,2,1]0,1}
3214 })";
3215 
3216   TF_ASSERT_OK_AND_ASSIGN(auto module,
3217                           PartitionComputation(hlo_string, /*num_devices=*/2));
3218   VLOG(1) << module->ToString();
3219 
3220   auto reshape =
3221       AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]"));
3222   auto halo = op::CollectivePermute(op::Slice(reshape));
3223   auto exchanged = op::DynamicSlice(op::Concatenate(halo, op::Slice(reshape)),
3224                                     _, _, _, _, _);
3225   const auto root = module->entry_computation()->root_instruction();
3226   EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]")));
3227 }
3228 
TEST_F(SpmdPartitioningTest,PartialReplicateReshapeMergeDimsWithHaloExchange)3229 TEST_F(SpmdPartitioningTest, PartialReplicateReshapeMergeDimsWithHaloExchange) {
3230   absl::string_view hlo_string = R"(
3231 HloModule module
3232 
3233 ENTRY entry {
3234   %input = s32[2,3,7,10] parameter(0),
3235     sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
3236   ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
3237     sharding={devices=[1,1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
3238 })";
3239 
3240   TF_ASSERT_OK_AND_ASSIGN(auto module,
3241                           PartitionComputation(hlo_string, /*num_devices=*/4));
3242   VLOG(1) << module->ToString();
3243 
3244   auto reshape =
3245       AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]"));
3246   auto halo = op::CollectivePermute(op::Slice(reshape));
3247   auto exchanged = op::DynamicSlice(op::Concatenate(halo, op::Slice(reshape)),
3248                                     _, _, _, _, _);
3249   const auto root = module->entry_computation()->root_instruction();
3250   EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]")));
3251 }
3252 
3253 // Produces an invalid module after transformation.
TEST_F(SpmdPartitioningTest,InceptionV3_4_way_ReduceWindowDilated)3254 TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) {
3255   absl::string_view hlo_string = R"(
3256 HloModule module
3257 
3258 sum {
3259   a = f32[] parameter(0)
3260   b = f32[] parameter(1)
3261   ROOT add = f32[] add(a, b)
3262 }
3263 
3264 ENTRY entry {
3265   %param0 = f32[128,5,5,768] parameter(0)
3266   %param0.copy = f32[128,5,5,768] copy(%param0),
3267     sharding={devices=[1,4,1,1]0,1,2,3}
3268   %constant.1 = f32[] constant(0), sharding={replicated}
3269   ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1),
3270     window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1},
3271     to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3}
3272 })";
3273 
3274   TF_ASSERT_OK_AND_ASSIGN(auto module,
3275                           PartitionComputation(hlo_string, /*num_devices=*/4));
3276   VLOG(1) << module->ToString();
3277 
3278   auto input_shard = op::Copy(op::DynamicSlice(
3279       op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(),
3280       op::Constant(), op::Constant()));
3281   auto id_mul4_add1 =
3282       op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
3283   auto id_mul5 = op::Multiply(op::Reshape(), op::Constant());
3284   auto id_mul5_add1_div3 =
3285       op::Divide(op::Add(id_mul5, op::Constant()), op::Constant());
3286   auto before_masking = AllOf(
3287       op::Shape("f32[128,3,5,768]"),
3288       op::DynamicSlice(
3289           AllOf(
3290               op::Shape("f32[128,4,5,768]"),
3291               op::Concatenate(op::CollectivePermute(input_shard), input_shard)),
3292           op::Constant(),
3293           op::Subtract(op::Constant(),
3294                        op::Subtract(id_mul4_add1, id_mul5_add1_div3)),
3295           op::Constant(), op::Constant()));
3296   auto masked = op::Select(
3297       op::And(op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)),
3298                           op::Broadcast(op::Constant())),
3299               op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)),
3300                           op::Broadcast(op::Constant()))),
3301       before_masking, op::Broadcast(op::Constant()));
3302   auto rw = AllOf(op::Shape("f32[128,7,17,768]"),
3303                   op::ReduceWindow(masked, op::Constant()));
3304   auto final_slice_index = op::Subtract(
3305       id_mul5,
3306       op::Add(op::Multiply(id_mul5_add1_div3, op::Constant()), op::Constant()));
3307   const auto root = module->entry_computation()->root_instruction();
3308   EXPECT_THAT(root,
3309               AllOf(op::Shape("f32[128,5,17,768]"),
3310                     op::DynamicSlice(rw, op::Constant(), final_slice_index,
3311                                      op::Constant(), op::Constant())));
3312 }
3313 
TEST_F(SpmdPartitioningTest,TiledToTiledReduce)3314 TEST_F(SpmdPartitioningTest, TiledToTiledReduce) {
3315   absl::string_view hlo_string = R"(
3316 HloModule module
3317 
3318 sum {
3319   a = f32[] parameter(0)
3320   b = f32[] parameter(1)
3321   ROOT add = f32[] add(a, b)
3322 }
3323 
3324 ENTRY entry {
3325   %param0 = f32[4,32,32,128] parameter(0)
3326   %param0.copy = f32[4,32,32,128] copy(%param0),
3327     sharding={devices=[1,1,1,2]0,1}
3328   %constant.1 = f32[] constant(0), sharding={replicated}
3329   %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
3330     to_apply=%sum, sharding={devices=[2]0,1}
3331 })";
3332 
3333   TF_ASSERT_OK_AND_ASSIGN(auto module,
3334                           PartitionComputation(hlo_string, /*num_devices=*/2));
3335   VLOG(1) << module->ToString();
3336 
3337   const auto root = module->entry_computation()->root_instruction();
3338   auto param0 = AllOf(
3339       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
3340                                 op::Constant(), op::Reshape())),
3341       op::Shape("f32[4,32,32,64]"));
3342 
3343   EXPECT_THAT(root,
3344               AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]")));
3345 }
3346 
TEST_F(SpmdPartitioningTest,PartialTiledToPartialTiledReduce)3347 TEST_F(SpmdPartitioningTest, PartialTiledToPartialTiledReduce) {
3348   absl::string_view hlo_string = R"(
3349 HloModule module
3350 
3351 sum {
3352   a = f32[] parameter(0)
3353   b = f32[] parameter(1)
3354   ROOT add = f32[] add(a, b)
3355 }
3356 
3357 ENTRY entry {
3358   %param0 = f32[4,4] parameter(0),
3359     sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
3360   %constant.1 = f32[] constant(0), sharding={replicated}
3361   ROOT %reduce = f32[4] reduce(%param0, %constant.1), dimensions={0},
3362     to_apply=%sum,
3363     sharding={devices=[2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
3364 })";
3365 
3366   TF_ASSERT_OK_AND_ASSIGN(auto module,
3367                           PartitionComputation(hlo_string, /*num_devices=*/8));
3368   VLOG(1) << module->ToString();
3369 
3370   const auto root = module->entry_computation()->root_instruction();
3371   EXPECT_THAT(root,
3372               AllOf(op::AllReduce(op::Reduce(op::Parameter(0), op::Constant())),
3373                     op::Shape("f32[2]")));
3374 }
3375 
TEST_F(SpmdPartitioningTest,TiledToTiledTupleReduce)3376 TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) {
3377   absl::string_view hlo_string = R"(
3378 HloModule module
3379 
3380 %minmax_func {
3381   %lhs_value = f32[] parameter(0)
3382   %rhs_value = f32[] parameter(2)
3383   %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
3384   %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
3385   %lhs_index = s32[] parameter(1)
3386   %rhs_index = s32[] parameter(3)
3387   %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
3388   ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
3389 }
3390 
3391 ENTRY %main {
3392   %param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1}
3393   %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1}
3394   %init0 = f32[] parameter(2)
3395   %init1 = s32[] parameter(3)
3396   ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
3397     dimensions={1}, to_apply=%minmax_func,
3398     sharding={{devices=[2]0,1}, {devices=[2]0,1}}
3399 })";
3400 
3401   TF_ASSERT_OK_AND_ASSIGN(auto module,
3402                           PartitionComputation(hlo_string, /*num_devices=*/2));
3403   VLOG(1) << module->ToString();
3404 
3405   const auto root = module->entry_computation()->root_instruction();
3406   EXPECT_THAT(root, AllOf(op::Reduce(op::Parameter(0), op::Parameter(1),
3407                                      op::Parameter(2), op::Parameter(3)),
3408                           op::Shape("(f32[14], s32[14])")));
3409 }
3410 
TEST_F(SpmdPartitioningTest,TiledToPartiallyTiledTupleReduce)3411 TEST_F(SpmdPartitioningTest, TiledToPartiallyTiledTupleReduce) {
3412   absl::string_view hlo_string = R"(
3413 HloModule module
3414 
3415 %minmax_func {
3416   %lhs_value = f32[] parameter(0)
3417   %rhs_value = f32[] parameter(2)
3418   %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
3419   %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
3420   %lhs_index = s32[] parameter(1)
3421   %rhs_index = s32[] parameter(3)
3422   %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
3423   ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
3424 }
3425 
3426 ENTRY %main {
3427   %param0 = f32[28,12] parameter(0), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
3428   %param1 = s32[28,12] parameter(1), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
3429   %init0 = f32[] parameter(2)
3430   %init1 = s32[] parameter(3)
3431   ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
3432     dimensions={1}, to_apply=%minmax_func,
3433     sharding={{devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate},
3434               {devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}}
3435 })";
3436 
3437   TF_ASSERT_OK_AND_ASSIGN(auto module,
3438                           PartitionComputation(hlo_string, /*num_devices=*/8));
3439   VLOG(1) << module->ToString();
3440 
3441   const auto lhs = AllOf(op::Shape("f32[14,3]"), op::Parameter(0));
3442   const auto rhs = AllOf(op::Shape("s32[14,3]"), op::Parameter(1));
3443   auto local_reduce =
3444       AllOf(op::Reduce(lhs, rhs, op::Parameter(2), op::Parameter(3)),
3445             op::Shape("(f32[14], s32[14])"));
3446   auto reshape_l = AllOf(op::Reshape(op::GetTupleElement(local_reduce)),
3447                          op::Shape("f32[14,1]"));
3448   auto reshape_r = AllOf(op::Reshape(op::GetTupleElement(local_reduce)),
3449                          op::Shape("s32[14,1]"));
3450   auto broadcast_l =
3451       AllOf(op::AllReduce(op::DynamicUpdateSlice(_, reshape_l, _, _)),
3452             op::Shape("f32[14,4]"));
3453   auto broadcast_r =
3454       AllOf(op::AllReduce(op::DynamicUpdateSlice(_, reshape_r, _, _)),
3455             op::Shape("s32[14,4]"));
3456   const auto root = module->entry_computation()->root_instruction();
3457   EXPECT_THAT(root, AllOf(op::Reduce(broadcast_l, broadcast_r, op::Parameter(2),
3458                                      op::Parameter(3)),
3459                           op::Shape("(f32[14], s32[14])")));
3460 }
3461 
TEST_F(SpmdPartitioningTest,TupleReduceSubgroupManual)3462 TEST_F(SpmdPartitioningTest, TupleReduceSubgroupManual) {
3463   absl::string_view hlo_string = R"(
3464 HloModule module
3465 
3466 %minmax_func {
3467   %lhs_value = f32[] parameter(0)
3468   %rhs_value = f32[] parameter(2)
3469   %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
3470   %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
3471   %lhs_index = s32[] parameter(1)
3472   %rhs_index = s32[] parameter(3)
3473   %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
3474   ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
3475 }
3476 
3477 ENTRY %main {
3478   %param0 = f32[28,12] parameter(0),
3479     sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
3480   %param1 = s32[28,12] parameter(1),
3481     sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
3482   %init0 = f32[] parameter(2),
3483     sharding={devices=[2,2]0,1,2,3 last_tile_dims={replicated,manual}}
3484   %init1 = s32[] parameter(3),
3485     sharding={devices=[2,2]0,1,2,3 last_tile_dims={replicated,manual}}
3486   ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
3487     dimensions={1}, to_apply=%minmax_func,
3488     sharding={{devices=[1,2,2]0,1,2,3 last_tile_dims={replicated,manual}},
3489               {devices=[1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}}
3490 })";
3491 
3492   TF_ASSERT_OK_AND_ASSIGN(auto module,
3493                           PartitionComputation(hlo_string, /*num_devices=*/4));
3494   VLOG(1) << module->ToString();
3495 
3496   const auto lhs = AllOf(op::Shape("f32[28,6]"), op::Parameter(0));
3497   const auto rhs = AllOf(op::Shape("s32[28,6]"), op::Parameter(1));
3498   auto local_reduce =
3499       AllOf(op::Reduce(lhs, rhs, op::Parameter(2), op::Parameter(3)),
3500             op::Shape("(f32[28], s32[28])"));
3501   auto reshape_l = AllOf(op::Reshape(op::GetTupleElement(local_reduce)),
3502                          op::Shape("f32[28,1]"));
3503   auto reshape_r = AllOf(op::Reshape(op::GetTupleElement(local_reduce)),
3504                          op::Shape("s32[28,1]"));
3505   auto broadcast_l =
3506       AllOf(op::AllReduce(op::DynamicUpdateSlice(_, reshape_l, _, _)),
3507             op::Shape("f32[28,2]"));
3508   auto broadcast_r =
3509       AllOf(op::AllReduce(op::DynamicUpdateSlice(_, reshape_r, _, _)),
3510             op::Shape("s32[28,2]"));
3511   const auto root = module->entry_computation()->root_instruction();
3512   EXPECT_THAT(root, AllOf(op::Reduce(broadcast_l, broadcast_r, op::Parameter(2),
3513                                      op::Parameter(3)),
3514                           op::Shape("(f32[28], s32[28])")));
3515 }
3516 
TEST_F(SpmdPartitioningTest,TiledToTiledReduceOutputReshard)3517 TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) {
3518   absl::string_view hlo_string = R"(
3519 HloModule module
3520 
3521 sum {
3522   a = f32[] parameter(0)
3523   b = f32[] parameter(1)
3524   ROOT add = f32[] add(a, b)
3525 }
3526 
3527 ENTRY entry {
3528   %param0 = f32[4,32,32,128] parameter(0)
3529   %param0.copy = f32[4,32,32,128] copy(%param0),
3530     sharding={devices=[1,2,1,1]0,1}
3531   %constant.1 = f32[] constant(0), sharding={replicated}
3532   %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
3533     to_apply=%sum, sharding={devices=[2]0,1}
3534 })";
3535 
3536   TF_ASSERT_OK_AND_ASSIGN(auto module,
3537                           PartitionComputation(hlo_string, /*num_devices=*/2));
3538   VLOG(1) << module->ToString();
3539 
3540   const auto root = module->entry_computation()->root_instruction();
3541   auto param0 = AllOf(
3542       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
3543                                 op::Constant(), op::Constant())),
3544       op::Shape("f32[4,16,32,128]"));
3545 
3546   EXPECT_THAT(root,
3547               AllOf(op::DynamicSlice(
3548                         AllOf(op::AllReduce(op::Reduce(param0, op::Constant())),
3549                               op::Shape("f32[128]")),
3550                         op::Reshape()),
3551                     op::Shape("f32[64]")));
3552 }
3553 
TEST_F(SpmdPartitioningTest,IotaAlongNonTileDimension)3554 TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) {
3555   absl::string_view hlo_string = R"(
3556 HloModule module
3557 
3558 ENTRY entry {
3559   ROOT %iota = s32[16,80,91] iota(), iota_dimension=1,
3560     sharding={devices=[1,1,2]0,1}
3561 })";
3562 
3563   TF_ASSERT_OK_AND_ASSIGN(auto module,
3564                           PartitionComputation(hlo_string, /*num_devices=*/2));
3565   VLOG(1) << module->ToString();
3566 
3567   const auto root = module->entry_computation()->root_instruction();
3568   EXPECT_THAT(root, AllOf(op::Iota(), op::Shape("s32[16,80,46]")));
3569 }
3570 
TEST_F(SpmdPartitioningTest,IotaAlongTileDimension)3571 TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) {
3572   absl::string_view hlo_string = R"(
3573 HloModule module
3574 
3575 ENTRY entry {
3576   ROOT %iota = s32[16,80,91] iota(), iota_dimension=2,
3577     sharding={devices=[1,1,2]0,1}
3578 })";
3579 
3580   TF_ASSERT_OK_AND_ASSIGN(auto module,
3581                           PartitionComputation(hlo_string, /*num_devices=*/2));
3582   VLOG(1) << module->ToString();
3583 
3584   const auto root = module->entry_computation()->root_instruction();
3585   EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()),
3586                           op::Shape("s32[16,80,46]")));
3587 }
3588 
TEST_F(SpmdPartitioningTest,U32IotaAlongTileDimension)3589 TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) {
3590   absl::string_view hlo_string = R"(
3591 HloModule module
3592 
3593 ENTRY entry {
3594   ROOT %iota = u32[16,80,91] iota(), iota_dimension=2,
3595     sharding={devices=[1,1,2]0,1}
3596 })";
3597 
3598   TF_ASSERT_OK_AND_ASSIGN(auto module,
3599                           PartitionComputation(hlo_string, /*num_devices=*/2));
3600   VLOG(1) << module->ToString();
3601 
3602   const auto root = module->entry_computation()->root_instruction();
3603   EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()),
3604                           op::Shape("u32[16,80,46]")));
3605 }
3606 
TEST_F(SpmdPartitioningTest,Conditional)3607 TEST_F(SpmdPartitioningTest, Conditional) {
3608   absl::string_view hlo_string = R"(
3609 HloModule module
3610 
3611 Negate {
3612   x = f32[4,5] parameter(0), sharding={replicated}
3613   ROOT negate = f32[4,5] negate(x), sharding={replicated}
3614 }
3615 
3616 Identity {
3617   y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1}
3618   ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1}
3619 }
3620 
3621 ENTRY entry {
3622   %param.0 = pred[] parameter(0)
3623   %param.0.copy = pred[] copy(%param.0), sharding={maximal device=0}
3624   %param.1 = f32[4,5] parameter(1)
3625   %param.1.copy = f32[4,5] copy(%param.1), sharding={replicated}
3626   %param.2 = f32[4,5] parameter(2)
3627   %param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1}
3628   ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy),
3629     true_computation=Negate, false_computation=Identity,
3630     sharding={devices=[2,1]0,1}
3631 })";
3632 
3633   TF_ASSERT_OK_AND_ASSIGN(auto module,
3634                           PartitionComputation(hlo_string, /*num_devices=*/2));
3635   VLOG(1) << module->ToString();
3636 
3637   auto param0 = AllOf(op::Copy(op::Copy(op::Parameter()), op::Shape("pred[]")));
3638   auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]"));
3639   auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
3640                                                 op::Constant())),
3641                       op::Shape("f32[2,5]"));
3642 
3643   const auto root = module->entry_computation()->root_instruction();
3644   EXPECT_THAT(root, AllOf(op::Conditional(op::AllReduce(), param1, param2),
3645                           op::Shape("f32[2,5]")));
3646 
3647   auto then_branch_root = root->branch_computation(0)->root_instruction();
3648   EXPECT_THAT(then_branch_root,
3649               AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(),
3650                                      op::Constant()),
3651                     op::Shape("f32[2,5]")));
3652 
3653   auto else_branch_root = root->branch_computation(1)->root_instruction();
3654   EXPECT_THAT(else_branch_root,
3655               AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]")));
3656 }
3657 
TEST_F(SpmdPartitioningTest,ConditionalManual)3658 TEST_F(SpmdPartitioningTest, ConditionalManual) {
3659   absl::string_view hlo_string = R"(
3660 HloModule module
3661 
3662 Negate {
3663   x = f32[4,5] parameter(0), sharding={manual}
3664   ROOT negate = f32[4,5] negate(x), sharding={manual}
3665 }
3666 
3667 Identity {
3668   y = f32[4,5] parameter(0), sharding={manual}
3669   ROOT copy = f32[4,5] copy(y), sharding={manual}
3670 }
3671 
3672 ENTRY entry {
3673   %param.0 = pred[] parameter(0), sharding={manual}
3674   %param.1 = f32[4,5] parameter(1), sharding={manual}
3675   %param.2 = f32[4,5] parameter(2), sharding={manual}
3676   ROOT cond = f32[4,5] conditional(%param.0, %param.1, %param.2),
3677     true_computation=Negate, false_computation=Identity, sharding={manual}
3678 })";
3679 
3680   TF_ASSERT_OK_AND_ASSIGN(auto module,
3681                           PartitionComputation(hlo_string, /*num_devices=*/2));
3682   VLOG(1) << module->ToString();
3683 
3684   auto param0 = AllOf(op::Parameter(0), op::Shape("pred[]"));
3685   auto param1 = AllOf(op::Parameter(1), op::Shape("f32[4,5]"));
3686   auto param2 = AllOf(op::Parameter(2), op::Shape("f32[4,5]"));
3687 
3688   const auto root = module->entry_computation()->root_instruction();
3689   EXPECT_THAT(root, AllOf(op::Conditional(param0, param1, param2),
3690                           op::Shape("f32[4,5]")));
3691 }
3692 
TEST_F(SpmdPartitioningTest,WhileManual)3693 TEST_F(SpmdPartitioningTest, WhileManual) {
3694   absl::string_view hlo_string = R"(
3695 HloModule module
3696 
3697 LoopCond {
3698   x = s32[] parameter(0), sharding={manual}
3699   const = s32[] constant(5), sharding={manual}
3700   ROOT lt = pred[] compare(x, const), direction=LT, sharding={manual}
3701 }
3702 
3703 Inc {
3704   x = s32[] parameter(0), sharding={manual}
3705   const = s32[] constant(1), sharding={manual}
3706   ROOT add = s32[] add(x, const), sharding={manual}
3707 }
3708 
3709 ENTRY entry {
3710   zero = s32[] parameter(0), sharding={manual}
3711   ROOT while = s32[] while(zero), body=Inc, condition=LoopCond,
3712     sharding={manual}
3713 })";
3714 
3715   TF_ASSERT_OK_AND_ASSIGN(auto module,
3716                           PartitionComputation(hlo_string, /*num_devices=*/2));
3717   VLOG(1) << module->ToString();
3718 
3719   auto zero = AllOf(op::Parameter(0), op::Shape("s32[]"));
3720   const auto root = module->entry_computation()->root_instruction();
3721   EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]")));
3722 }
3723 
TEST_F(SpmdPartitioningTest,SelectAndScatter_RetinaNet)3724 TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) {
3725   absl::string_view hlo_string = R"(
3726 HloModule module
3727 
3728 ge {
3729   a = f32[] parameter(0)
3730   b = f32[] parameter(1)
3731   ROOT compare = pred[] compare(a, b), direction=GE
3732 }
3733 
3734 sum {
3735   c = f32[] parameter(0)
3736   d = f32[] parameter(1)
3737   ROOT add = f32[] add(c, d)
3738 }
3739 
3740 ENTRY entry {
3741   %param.0 = f32[32,128,384,64] parameter(0)
3742   %param.0.copy = f32[32,128,384,64] copy(%param.0),
3743     sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
3744   %param.1 = f32[32,64,192,64] parameter(1)
3745   %param.1.copy = f32[32,64,192,64] copy(%param.1),
3746     sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
3747   constant.1 = f32[] constant(0), sharding={replicated}
3748   ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy,
3749     %param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1},
3750     select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
3751 })";
3752   TF_ASSERT_OK_AND_ASSIGN(auto module,
3753                           PartitionComputation(hlo_string, /*num_devices=*/8));
3754   VLOG(1) << module->ToString();
3755 
3756   const auto root = module->entry_computation()->root_instruction();
3757   auto source = AllOf(
3758       op::Shape("f32[32,8,192,64]"),
3759       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
3760                                 op::Constant(), op::Constant())));
3761   auto data = AllOf(
3762       op::Shape("f32[32,16,384,64]"),
3763       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
3764                                 op::Constant(), op::Constant())));
3765 
3766   EXPECT_THAT(root, op::SelectAndScatter(data, source, op::Constant()));
3767   EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
3768   EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
3769 }
3770 
TEST_F(SpmdPartitioningTest,TiledDot)3771 TEST_F(SpmdPartitioningTest, TiledDot) {
3772   absl::string_view hlo_string = R"(
3773 HloModule module
3774 
3775 ENTRY entry {
3776   %lhs = f32[128,64] parameter(0)
3777   %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
3778   %rhs = f32[64,256] parameter(1)
3779   %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
3780   ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
3781     dim_labels=bf_io->bf, sharding={replicated}
3782 })";
3783 
3784   TF_ASSERT_OK_AND_ASSIGN(
3785       auto module,
3786       PartitionComputation(hlo_string, /*num_devices=*/2,
3787                            /*conv_halo_exchange_always_on_lhs=*/false));
3788   VLOG(1) << module->ToString();
3789 
3790   const auto root = module->entry_computation()->root_instruction();
3791   const auto lhs = AllOf(op::Copy(op::DynamicSlice(
3792                              op::Parameter(), op::Constant(), op::Reshape())),
3793                          op::Shape("f32[128,32]"));
3794   const auto rhs = AllOf(op::Copy(op::DynamicSlice(
3795                              op::Parameter(), op::Reshape(), op::Constant())),
3796                          op::Shape("f32[32,256]"));
3797   EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
3798                           op::Shape("f32[128,256]")));
3799 }
3800 
TEST_F(SpmdPartitioningTest,TiledDotOutputTiled)3801 TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) {
3802   absl::string_view hlo_string = R"(
3803 HloModule module
3804 
3805 ENTRY entry {
3806   %lhs = f32[128,64] parameter(0)
3807   %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
3808   %rhs = f32[64,256] parameter(1)
3809   %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
3810   ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
3811     dim_labels=bf_io->bf, sharding={devices=[1,2]0,1}
3812 })";
3813 
3814   TF_ASSERT_OK_AND_ASSIGN(auto module,
3815                           PartitionComputation(hlo_string, /*num_devices=*/2));
3816   VLOG(1) << module->ToString();
3817 
3818   const auto root = module->entry_computation()->root_instruction();
3819   const auto lhs = AllOf(op::Copy(op::DynamicSlice(
3820                              op::Parameter(), op::Constant(), op::Reshape())),
3821                          op::Shape("f32[128,32]"));
3822   const auto rhs = AllOf(op::Copy(op::DynamicSlice(
3823                              op::Parameter(), op::Reshape(), op::Constant())),
3824                          op::Shape("f32[32,256]"));
3825   EXPECT_THAT(root, AllOf(op::DynamicSlice(
3826                               AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
3827                                     op::Shape("f32[128,256]")),
3828                               op::Constant(), op::Reshape()),
3829                           op::Shape("f32[128,128]")));
3830 }
3831 
TEST_F(SpmdPartitioningTest,BatchPartitionedConvolution)3832 TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) {
3833   absl::string_view hlo_string = R"(
3834 HloModule module
3835 
3836 ENTRY entry {
3837   %lhs = f32[128,256,256] parameter(0)
3838   %lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1}
3839   %rhs = f32[256,8,1] parameter(1)
3840   %rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated}
3841   ROOT %conv = f32[128,256,8] convolution(%lhs.copy, %rhs.copy),
3842     window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1}
3843 })";
3844 
3845   TF_ASSERT_OK_AND_ASSIGN(auto module,
3846                           PartitionComputation(hlo_string, /*num_devices=*/2));
3847   VLOG(1) << module->ToString();
3848 
3849   const auto root = module->entry_computation()->root_instruction();
3850   const auto lhs =
3851       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
3852                                       op::Reshape(), op::Constant())),
3853             op::Shape("f32[128,128,256]"));
3854   const auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[256,8,1]"));
3855   EXPECT_THAT(root,
3856               AllOf(op::Convolution(lhs, rhs), op::Shape("f32[128,128,8]")));
3857 }
3858 
TEST_F(SpmdPartitioningTest,DotOutputFeaturePartitioned)3859 TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) {
3860   absl::string_view hlo_string = R"(
3861 HloModule module
3862 
3863 ENTRY entry {
3864   %lhs = f32[24,64] parameter(0)
3865   %lhs.copy = f32[24,64] copy(%lhs), sharding={replicated}
3866   %rhs = f32[39296,64] parameter(1)
3867   %rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1}
3868   ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy),
3869     lhs_batch_dims={}, rhs_batch_dims={},
3870     lhs_contracting_dims={1}, rhs_contracting_dims={1},
3871     sharding={devices=[1,2]0,1}
3872 })";
3873 
3874   TF_ASSERT_OK_AND_ASSIGN(auto module,
3875                           PartitionComputation(hlo_string, /*num_devices=*/2));
3876   VLOG(1) << module->ToString();
3877 
3878   const auto root = module->entry_computation()->root_instruction();
3879   const auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[24,64]"));
3880   const auto rhs = AllOf(op::Copy(op::DynamicSlice(
3881                              op::Parameter(1), op::Reshape(), op::Constant())),
3882                          op::Shape("f32[19648,64]"));
3883   EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]")));
3884 }
3885 
TEST_F(SpmdPartitioningTest,WindowedEinsumTwoContractingDimsLhsReshard)3886 TEST_F(SpmdPartitioningTest, WindowedEinsumTwoContractingDimsLhsReshard) {
3887   absl::string_view hlo_string = R"(
3888 HloModule module
3889 
3890 ENTRY entry {
3891   %p0 = f32[2048,2,3264]{2,1,0} parameter(0), sharding={devices=[1,1,2]0,1}
3892   %p1 = f32[2,3264,2176]{2,1,0} parameter(1), sharding={devices=[2,1,1]0,1}
3893   ROOT %dot.224 = f32[2048,2176]{1,0} dot(f32[2048,2,3264]{2,1,0} %p0, f32[2,3264,2176]{2,1,0} %p1), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, sharding={devices=[1,2]0,1}
3894 })";
3895 
3896   TF_ASSERT_OK_AND_ASSIGN(
3897       auto module,
3898       PartitionComputation(hlo_string, /*num_devices=*/2,
3899                            /*conv_halo_exchange_always_on_lhs=*/true,
3900                            /*choose_faster_windowed_einsum=*/false,
3901                            /*unroll_windowed_einsum=*/false,
3902                            /*bidirectional_windowed_einsum=*/false,
3903                            /*threshold_for_windowed_einsum_mib=*/0));
3904   VLOG(1) << module->ToString();
3905 
3906   // Check while op.
3907   const auto arg0 = AllOf(
3908       op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(0))))),
3909       op::Shape("f32[2048,1,3264]"));
3910   const auto arg1 = AllOf(op::Parameter(1), op::Shape("f32[1,3264,2176]"));
3911 
3912   const auto while_op =
3913       AllOf(op::While(op::Tuple(arg0, arg1, op::Broadcast(), op::Broadcast(),
3914                                 op::Constant())),
3915             op::Shape("(f32[2048,1,3264]{2,1,0}, f32[1,3264,2176]{2,1,0},"
3916                       " f32[2048,1088]{1,0}, f32[2048,1088]{1,0}, u32[])"));
3917   const auto root = module->entry_computation()->root_instruction();
3918   EXPECT_THAT(
3919       root, AllOf(op::GetTupleElement(while_op), op::Shape("f32[2048,1088]")));
3920 
3921   // Check while op body.
3922   const auto while_loop = root->operand(0);
3923   const auto next_i =
3924       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
3925   auto lhs = AllOf(op::GetTupleElement(op::Parameter(0)),
3926                    op::Shape("f32[2048,1,3264]"));
3927   auto rhs = AllOf(op::DynamicSlice(), op::Shape("f32[1,3264,1088]"));
3928   auto dot_op = op::Dot(lhs, rhs);
3929   auto add_op = op::Add(op::GetTupleElement(op::Parameter(0)), dot_op);
3930   auto cond_op =
3931       op::Conditional(op::Compare(next_i, op::Constant()), add_op, add_op);
3932   EXPECT_THAT(while_loop->while_body()->root_instruction(),
3933               op::Tuple(op::GetTupleElement(op::Parameter(0)),
3934                         op::GetTupleElement(op::Parameter(0)), cond_op,
3935                         op::GetTupleElement(op::Parameter(0)), next_i));
3936 }
3937 
TEST_F(SpmdPartitioningTest,WindowedEinsumTwoContractingDimsRhsReshard)3938 TEST_F(SpmdPartitioningTest, WindowedEinsumTwoContractingDimsRhsReshard) {
3939   absl::string_view hlo_string = R"(
3940 HloModule module
3941 
3942 ENTRY entry {
3943   %p0 = f32[4096,2,3264]{2,1,0} parameter(0), sharding={devices=[1,1,2]0,1}
3944   %p1 = f32[2,3264,2176]{2,1,0} parameter(1), sharding={devices=[2,1,1]0,1}
3945   ROOT %dot.224 = f32[4096,2176]{1,0} dot(f32[4096,2,3264]{2,1,0} %p0, f32[2,3264,2176]{2,1,0} %p1), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, sharding={devices=[1,2]0,1}
3946 })";
3947 
3948   TF_ASSERT_OK_AND_ASSIGN(
3949       auto module,
3950       PartitionComputation(hlo_string, /*num_devices=*/2,
3951                            /*conv_halo_exchange_always_on_lhs=*/true,
3952                            /*choose_faster_windowed_einsum=*/false,
3953                            /*unroll_windowed_einsum=*/false,
3954                            /*bidirectional_windowed_einsum=*/false,
3955                            /*threshold_for_windowed_einsum_mib=*/0));
3956   VLOG(1) << module->ToString();
3957 
3958   // Check while op.
3959   const auto arg0 = AllOf(op::Parameter(0), op::Shape("f32[4096,2,1632]"));
3960   const auto arg1 = AllOf(
3961       op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(1))))),
3962       op::Shape("f32[2,1632,2176]"));
3963 
3964   const auto while_op =
3965       AllOf(op::While(op::Tuple(arg0, arg1, op::Broadcast(), op::Broadcast(),
3966                                 op::Constant())),
3967             op::Shape("(f32[4096,2,1632]{2,1,0}, f32[2,1632,2176]{2,1,0},"
3968                       " f32[4096,1088]{1,0}, f32[4096,1088]{1,0}, u32[])"));
3969   const auto root = module->entry_computation()->root_instruction();
3970   EXPECT_THAT(
3971       root, AllOf(op::GetTupleElement(while_op), op::Shape("f32[4096,1088]")));
3972 
3973   // Check while op body.
3974   const auto while_loop = root->operand(0);
3975   const auto next_i =
3976       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
3977   auto lhs = AllOf(op::GetTupleElement(op::Parameter(0)),
3978                    op::Shape("f32[4096,2,1632]"));
3979   auto rhs = AllOf(op::DynamicSlice(), op::Shape("f32[2,1632,1088]"));
3980   auto dot_op = op::Dot(lhs, rhs);
3981   auto add_op = op::Add(op::GetTupleElement(op::Parameter(0)), dot_op);
3982   auto cond_op =
3983       op::Conditional(op::Compare(next_i, op::Constant()), add_op, add_op);
3984   EXPECT_THAT(while_loop->while_body()->root_instruction(),
3985               op::Tuple(op::GetTupleElement(op::Parameter(0)),
3986                         op::GetTupleElement(op::Parameter(0)), cond_op,
3987                         op::GetTupleElement(op::Parameter(0)), next_i));
3988 }
3989 
TEST_F(SpmdPartitioningTest,DotPartialDeviceOrder)3990 TEST_F(SpmdPartitioningTest, DotPartialDeviceOrder) {
3991   absl::string_view hlo_string = R"(
3992 HloModule module
3993 
3994 ENTRY entry {
3995   %lhs = f32[16,256,4096] parameter(0), sharding={devices=[1,1,2,2]1,3,0,2 last_tile_dim_replicate}
3996   %rhs = f32[4096,2048] parameter(1), sharding={devices=[2,2]3,1,2,0}
3997   ROOT %dot = f32[16,256,2048] dot(%lhs, %rhs),
3998     lhs_batch_dims={}, rhs_batch_dims={},
3999     lhs_contracting_dims={2}, rhs_contracting_dims={0},
4000     sharding={devices=[1,1,2,2]2,3,0,1 last_tile_dim_replicate}
4001 })";
4002 
4003   TF_ASSERT_OK_AND_ASSIGN(auto module,
4004                           PartitionComputation(hlo_string, /*num_devices=*/4));
4005   VLOG(1) << module->ToString();
4006 
4007   const auto root = module->entry_computation()->root_instruction();
4008   const auto lhs = AllOf(op::Parameter(0), op::Shape("f32[16,256,2048]"));
4009   const auto rhs = AllOf(op::Parameter(1), op::Shape("f32[2048,1024]"));
4010   EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)),
4011                           op::Shape("f32[16,256,1024]")));
4012 }
4013 
TEST_F(SpmdPartitioningTest,EinsumBatchPartitioned)4014 TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) {
4015   absl::string_view hlo_string = R"(
4016 HloModule module
4017 
4018 ENTRY entry {
4019   %lhs = f32[32,24,64] parameter(0)
4020   %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
4021   %rhs = f32[32,39296,64] parameter(1)
4022   %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
4023   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4024     lhs_batch_dims={0}, rhs_batch_dims={0},
4025     lhs_contracting_dims={2}, rhs_contracting_dims={2},
4026     sharding={devices=[2,1,1]0,1}
4027 })";
4028 
4029   TF_ASSERT_OK_AND_ASSIGN(auto module,
4030                           PartitionComputation(hlo_string, /*num_devices=*/2));
4031   VLOG(1) << module->ToString();
4032 
4033   const auto root = module->entry_computation()->root_instruction();
4034   const auto lhs =
4035       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4036                                       op::Constant(), op::Constant())),
4037             op::Shape("f32[16,24,64]"));
4038   const auto rhs =
4039       AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4040                                       op::Constant(), op::Constant())),
4041             op::Shape("f32[16,39296,64]"));
4042   EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[16,24,39296]")));
4043 }
4044 
TEST_F(SpmdPartitioningTest,EinsumLHSandOutputBatchPartitioned)4045 TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) {
4046   absl::string_view hlo_string = R"(
4047 HloModule module
4048 
4049 ENTRY entry {
4050   %lhs = f32[32,24,64] parameter(0)
4051   %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
4052   %rhs = f32[32,39296,64] parameter(1)
4053   %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
4054   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4055     lhs_batch_dims={0}, rhs_batch_dims={0},
4056     lhs_contracting_dims={2}, rhs_contracting_dims={2},
4057     sharding={devices=[2,1,1]0,1}
4058 })";
4059 
4060   TF_ASSERT_OK_AND_ASSIGN(auto module,
4061                           PartitionComputation(hlo_string, /*num_devices=*/2));
4062   VLOG(1) << module->ToString();
4063 
4064   const auto root = module->entry_computation()->root_instruction();
4065   const auto lhs =
4066       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4067                                       op::Constant(), op::Constant())),
4068             op::Shape("f32[16,24,64]"));
4069   const auto rhs =
4070       AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]"));
4071   EXPECT_THAT(root, AllOf(op::Dot(lhs, op::DynamicSlice(rhs, op::Reshape(),
4072                                                         op::Constant(),
4073                                                         op::Constant())),
4074                           op::Shape("f32[16,24,39296]")));
4075 }
4076 
TEST_F(SpmdPartitioningTest,EinsumRHSandOutputBatchPartitioned)4077 TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) {
4078   absl::string_view hlo_string = R"(
4079 HloModule module
4080 
4081 ENTRY entry {
4082   %lhs = f32[32,24,64] parameter(0)
4083   %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1}
4084   %rhs = f32[32,39296,64] parameter(1)
4085   %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
4086   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4087     lhs_batch_dims={0}, rhs_batch_dims={0},
4088     lhs_contracting_dims={2}, rhs_contracting_dims={2},
4089     sharding={devices=[2,1,1]0,1}
4090 })";
4091 
4092   TF_ASSERT_OK_AND_ASSIGN(auto module,
4093                           PartitionComputation(hlo_string, /*num_devices=*/2));
4094   VLOG(1) << module->ToString();
4095 
4096   const auto root = module->entry_computation()->root_instruction();
4097   const auto lhs =
4098       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4099                                       op::Reshape(), op::Constant())),
4100             op::Shape("f32[32,12,64]"));
4101   const auto rhs =
4102       AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4103                                       op::Constant(), op::Constant())),
4104             op::Shape("f32[16,39296,64]"));
4105   const auto lhs_reshard =
4106       op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))));
4107   EXPECT_THAT(root,
4108               AllOf(op::Dot(lhs_reshard, rhs), op::Shape("f32[16,24,39296]")));
4109 }
4110 
TEST_F(SpmdPartitioningTest,EinsumOutputBatchPartitioned)4111 TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) {
4112   absl::string_view hlo_string = R"(
4113 HloModule module
4114 
4115 ENTRY entry {
4116   %lhs = f32[32,24,64] parameter(0)
4117   %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
4118   %rhs = f32[32,39296,64] parameter(1)
4119   %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
4120   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4121     lhs_batch_dims={0}, rhs_batch_dims={0},
4122     lhs_contracting_dims={2}, rhs_contracting_dims={2},
4123     sharding={devices=[2,1,1]0,1}
4124 })";
4125 
4126   TF_ASSERT_OK_AND_ASSIGN(auto module,
4127                           PartitionComputation(hlo_string, /*num_devices=*/2));
4128   VLOG(1) << module->ToString();
4129 
4130   const auto root = module->entry_computation()->root_instruction();
4131   const auto lhs_slice =
4132       AllOf(op::DynamicSlice(op::Copy(op::Parameter(0)), op::Reshape(),
4133                              op::Constant(), op::Constant()),
4134             op::Shape("f32[16,24,64]"));
4135   const auto rhs_slice =
4136       AllOf(op::DynamicSlice(op::Copy(op::Parameter(1)), op::Reshape(),
4137                              op::Constant(), op::Constant()),
4138             op::Shape("f32[16,39296,64]"));
4139   EXPECT_THAT(root, AllOf(op::Dot(lhs_slice, rhs_slice),
4140                           op::Shape("f32[16,24,39296]")));
4141 }
4142 
TEST_F(SpmdPartitioningTest,EinsumContractingDimsPartitioned)4143 TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) {
4144   absl::string_view hlo_string = R"(
4145 HloModule module
4146 
4147 ENTRY entry {
4148   %lhs = f32[32,24,64,128] parameter(0)
4149   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3}
4150   %rhs = f32[32,39296,64,128] parameter(1)
4151   %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3}
4152   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4153     lhs_batch_dims={0}, rhs_batch_dims={0},
4154     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4155     sharding={replicated}
4156 })";
4157 
4158   TF_ASSERT_OK_AND_ASSIGN(auto module,
4159                           PartitionComputation(hlo_string, /*num_devices=*/4));
4160   VLOG(1) << module->ToString();
4161 
4162   const auto root = module->entry_computation()->root_instruction();
4163   const auto lhs = AllOf(
4164       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4165                                 op::Constant(), op::Reshape(), op::Reshape())),
4166       op::Shape("f32[32,24,32,64]"));
4167   const auto rhs = AllOf(
4168       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4169                                 op::Constant(), op::Reshape(), op::Reshape())),
4170       op::Shape("f32[32,39296,32,64]"));
4171   EXPECT_THAT(root, AllOf(op::AllReduce(op::AllReduce(op::Dot(lhs, rhs))),
4172                           op::Shape("f32[32,24,39296]")));
4173 }
4174 
TEST_F(SpmdPartitioningTest,EinsumContractingDimsPartitionedResultPartiallySliced)4175 TEST_F(SpmdPartitioningTest,
4176        EinsumContractingDimsPartitionedResultPartiallySliced) {
4177   absl::string_view hlo_string = R"(
4178 HloModule module
4179 
4180 ENTRY entry {
4181   %lhs = f32[32,64] parameter(0), sharding={devices=[1,4]0,1,2,3}
4182   %rhs = f32[64,128] parameter(1), sharding={devices=[4,1]0,1,2,3}
4183   ROOT %dot = f32[32,128] dot(%lhs, %rhs),
4184     lhs_contracting_dims={1}, rhs_contracting_dims={0},
4185     sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
4186 })";
4187 
4188   TF_ASSERT_OK_AND_ASSIGN(auto module,
4189                           PartitionComputation(hlo_string, /*num_devices=*/4));
4190   VLOG(1) << module->ToString();
4191 
4192   const auto root = module->entry_computation()->root_instruction();
4193   const auto lhs = AllOf(op::Parameter(0), op::Shape("f32[32,16]"));
4194   const auto rhs = AllOf(op::Parameter(1), op::Shape("f32[16,128]"));
4195   EXPECT_THAT(root, AllOf(op::AllReduce(op::DynamicSlice(
4196                               op::AllReduce(op::Dot(lhs, rhs)), _, _)),
4197                           op::Shape("f32[16,128]")));
4198 }
4199 
TEST_F(SpmdPartitioningTest,EinsumLHSNonContractingDimsPartitioned)4200 TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) {
4201   absl::string_view hlo_string = R"(
4202 HloModule module
4203 
4204 ENTRY entry {
4205   %lhs = f32[32,24,64,128] parameter(0)
4206   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3}
4207   %rhs = f32[32,39296,64] parameter(1)
4208   %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
4209   ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy),
4210     lhs_batch_dims={0}, rhs_batch_dims={0},
4211     lhs_contracting_dims={2}, rhs_contracting_dims={2},
4212     sharding={devices=[1,2,2,1]0,1,2,3}
4213 })";
4214 
4215   TF_ASSERT_OK_AND_ASSIGN(auto module,
4216                           PartitionComputation(hlo_string, /*num_devices=*/4));
4217   VLOG(1) << module->ToString();
4218 
4219   const auto root = module->entry_computation()->root_instruction();
4220   const auto lhs = AllOf(
4221       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
4222                                 op::Constant(), op::Reshape())),
4223       op::Shape("f32[32,12,64,64]"));
4224   const auto rhs =
4225       AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]"));
4226   EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,12,64,39296]")));
4227 }
4228 
TEST_F(SpmdPartitioningTest,EinsumRHSNonContractingDimsPartitioned)4229 TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) {
4230   absl::string_view hlo_string = R"(
4231 HloModule module
4232 
4233 ENTRY entry {
4234   %lhs = f32[32,24,64] parameter(0)
4235   %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
4236   %rhs = f32[32,39296,64,128] parameter(1)
4237   %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3}
4238   ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy),
4239     lhs_batch_dims={0}, rhs_batch_dims={0},
4240     lhs_contracting_dims={2}, rhs_contracting_dims={2},
4241     sharding={devices=[1,1,2,2]0,1,2,3}
4242 })";
4243 
4244   TF_ASSERT_OK_AND_ASSIGN(auto module,
4245                           PartitionComputation(hlo_string, /*num_devices=*/4));
4246   VLOG(1) << module->ToString();
4247 
4248   const auto root = module->entry_computation()->root_instruction();
4249   const auto lhs =
4250       AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]"));
4251   const auto rhs = AllOf(
4252       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
4253                                 op::Constant(), op::Reshape())),
4254       op::Shape("f32[32,19648,64,64]"));
4255   EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,24,19648,64]")));
4256 }
4257 
TEST_F(SpmdPartitioningTest,EinsumOutputLHSNonContractingDimPartitioned)4258 TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) {
4259   absl::string_view hlo_string = R"(
4260 HloModule module
4261 
4262 ENTRY entry {
4263   %lhs = f32[32,24,64,128] parameter(0)
4264   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
4265   %rhs = f32[32,39296,64,128] parameter(1)
4266   %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
4267   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4268     lhs_batch_dims={0}, rhs_batch_dims={0},
4269     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4270     sharding={devices=[1,2,1]0,1}
4271 })";
4272 
4273   TF_ASSERT_OK_AND_ASSIGN(auto module,
4274                           PartitionComputation(hlo_string, /*num_devices=*/2));
4275   VLOG(1) << module->ToString();
4276 
4277   const auto root = module->entry_computation()->root_instruction();
4278   const auto lhs =
4279       AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
4280   const auto rhs =
4281       AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
4282   EXPECT_THAT(
4283       root,
4284       AllOf(op::Dot(AllOf(op::DynamicSlice(lhs, op::Constant(), op::Reshape(),
4285                                            op::Constant(), op::Constant()),
4286                           op::Shape("f32[32,12,64,128]")),
4287                     rhs),
4288             op::Shape("f32[32,12,39296]")));
4289 }
4290 
TEST_F(SpmdPartitioningTest,EinsumOutputRHSNonContractingDimPartitioned)4291 TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) {
4292   absl::string_view hlo_string = R"(
4293 HloModule module
4294 
4295 ENTRY entry {
4296   %lhs = f32[32,24,64,128] parameter(0)
4297   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
4298   %rhs = f32[32,39296,64,128] parameter(1)
4299   %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
4300   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4301     lhs_batch_dims={0}, rhs_batch_dims={0},
4302     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4303     sharding={devices=[1,1,2]0,1}
4304 })";
4305 
4306   TF_ASSERT_OK_AND_ASSIGN(auto module,
4307                           PartitionComputation(hlo_string, /*num_devices=*/2));
4308   VLOG(1) << module->ToString();
4309 
4310   const auto root = module->entry_computation()->root_instruction();
4311   const auto lhs =
4312       AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
4313   const auto rhs =
4314       AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
4315   EXPECT_THAT(root,
4316               AllOf(op::Dot(lhs, AllOf(op::DynamicSlice(
4317                                            rhs, op::Constant(), op::Reshape(),
4318                                            op::Constant(), op::Constant()),
4319                                        op::Shape("f32[32,19648,64,128]"))),
4320                     op::Shape("f32[32,24,19648]")));
4321 }
4322 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedInContractingOutNonContractingPartitioned)4323 TEST_F(SpmdPartitioningTest,
4324        EinsumRHSWindowedInContractingOutNonContractingPartitioned) {
4325   absl::string_view hlo_string = R"(
4326 HloModule module
4327 
4328 ENTRY entry {
4329   %lhs = f32[320,25,64,128] parameter(0)
4330   %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
4331   %rhs = f32[320,39296,64,128] parameter(1)
4332   %rhs.copy = f32[320,39296,64,128] copy(%rhs),
4333     sharding={devices=[1,1,4,1]0,1,2,3}
4334   ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
4335     lhs_batch_dims={0}, rhs_batch_dims={0},
4336     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4337     sharding={devices=[1,4,1]0,1,2,3}
4338 })";
4339 
4340   TF_ASSERT_OK_AND_ASSIGN(auto module,
4341                           PartitionComputation(hlo_string, /*num_devices=*/4));
4342   VLOG(1) << module->ToString();
4343 
4344   const auto root = module->entry_computation()->root_instruction();
4345   const auto lhs = AllOf(
4346       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4347                                 op::Constant(), op::Reshape(), op::Constant())),
4348       op::Shape("f32[320,25,16,128]"));
4349   const auto rhs = AllOf(
4350       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4351                                 op::Constant(), op::Reshape(), op::Constant())),
4352       op::Shape("f32[320,39296,16,128]"));
4353   EXPECT_THAT(
4354       root,
4355       AllOf(op::GetTupleElement(op::While(op::Tuple(
4356                 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant()))),
4357             op::Shape("f32[320,7,39296]")));
4358 
4359   const auto while_loop = root->operand(0);
4360   // Check loop condition.
4361   EXPECT_THAT(
4362       while_loop->while_condition()->root_instruction(),
4363       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4364 
4365   // Check loop body.
4366   const auto next_i =
4367       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4368   auto ds =
4369       AllOf(op::DynamicSlice(
4370                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4371                 op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
4372             op::Shape("f32[320,7,16,128]"));
4373   auto partial_output =
4374       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
4375                     op::Dot(ds, op::GetTupleElement(op::Parameter(0)))),
4376             op::Shape("f32[320,7,39296]"));
4377   auto window = op::Conditional(op::Compare(next_i, op::Constant()),
4378                                 partial_output, partial_output);
4379   EXPECT_THAT(while_loop->while_body()->root_instruction(),
4380               op::Tuple(op::GetTupleElement(op::Parameter(0)),
4381                         op::GetTupleElement(op::Parameter(0)), window,
4382                         op::GetTupleElement(op::Parameter(0)), next_i));
4383 
4384   // Check the conditional that contains the collective permute.
4385   auto cp_conditional =
4386       while_loop->while_body()->root_instruction()->operand(2);
4387   EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
4388               op::CollectivePermute(op::Parameter(0)));
4389   EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
4390               op::Parameter(0));
4391 }
4392 
TEST_F(SpmdPartitioningTest,UnrolledEinsumRHSWindowedInContractingOutNonContractingPartitioned)4393 TEST_F(SpmdPartitioningTest,
4394        UnrolledEinsumRHSWindowedInContractingOutNonContractingPartitioned) {
4395   absl::string_view hlo_string = R"(
4396 HloModule module
4397 
4398 ENTRY entry {
4399   %lhs = f32[320,25,64,128] parameter(0)
4400   %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
4401   %rhs = f32[320,39296,64,128] parameter(1)
4402   %rhs.copy = f32[320,39296,64,128] copy(%rhs),
4403     sharding={devices=[1,1,4,1]0,1,2,3}
4404   ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
4405     lhs_batch_dims={0}, rhs_batch_dims={0},
4406     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4407     sharding={devices=[1,4,1]0,1,2,3}
4408 })";
4409 
4410   TF_ASSERT_OK_AND_ASSIGN(
4411       auto module,
4412       PartitionComputation(hlo_string, /*num_devices=*/4,
4413                            /*conv_halo_exchange_always_on_lhs =*/true,
4414                            /*choose_faster_windowed_einsum =*/false,
4415                            /*unroll_windowed_einsum =*/true));
4416   VLOG(1) << module->ToString();
4417 
4418   const auto lhs = AllOf(
4419       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4420                                 op::Constant(), op::Reshape(), op::Constant())),
4421       op::Shape("f32[320,25,16,128]"));
4422   const auto rhs = AllOf(
4423       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4424                                 op::Constant(), op::Reshape(), op::Constant())),
4425       op::Shape("f32[320,39296,16,128]"));
4426   const auto while_op = AllOf(
4427       op::While(op::Tuple(lhs, rhs, op::Broadcast(), op::Broadcast(),
4428                           op::Constant())),
4429       op::Shape("(f32[320,25,16,128], f32[320,39296,16,128], f32[320,7,39296],"
4430                 " f32[320,7,39296], u32[])"));
4431   const auto root = module->entry_computation()->root_instruction();
4432   EXPECT_THAT(
4433       root, AllOf(op::Add(op::CollectivePermute(op::GetTupleElement(while_op)),
4434                           op::GetTupleElement(while_op)),
4435                   op::Shape("f32[320,7,39296]")));
4436 
4437   const auto while_loop = root->operand(1)->operand(0);
4438   // Check loop condition.
4439   EXPECT_THAT(
4440       while_loop->while_condition()->root_instruction(),
4441       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4442 
4443   // Check loop body.
4444   const auto next_i =
4445       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4446   auto ds =
4447       AllOf(op::DynamicSlice(
4448                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4449                 op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
4450             op::Shape("f32[320,7,16,128]"));
4451   auto partial_output = AllOf(
4452       op::Add(op::CollectivePermute(op::GetTupleElement(op::Parameter(0))),
4453               op::Dot(ds, op::GetTupleElement(op::Parameter(0)))),
4454       op::Shape("f32[320,7,39296]"));
4455   auto partial_output2 =
4456       AllOf(op::CollectivePermute(
4457                 op::Add(op::GetTupleElement(op::Parameter(0)),
4458                         op::Dot(ds, op::GetTupleElement(op::Parameter(0))))),
4459             op::Shape("f32[320,7,39296]"));
4460 
4461   EXPECT_THAT(while_loop->while_body()->root_instruction(),
4462               op::Tuple(op::GetTupleElement(op::Parameter(0)),
4463                         op::GetTupleElement(op::Parameter(0)), partial_output,
4464                         partial_output2, next_i));
4465 }
4466 
TEST_F(SpmdPartitioningTest,BidirectionalEinsumRHSWindowedInContractingOutNonContractingPartitioned)4467 TEST_F(
4468     SpmdPartitioningTest,
4469     BidirectionalEinsumRHSWindowedInContractingOutNonContractingPartitioned) {
4470   absl::string_view hlo_string = R"(
4471 HloModule module
4472 
4473 ENTRY entry {
4474   %lhs = f32[320,25,64,128] parameter(0)
4475   %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
4476   %rhs = f32[320,39296,64,128] parameter(1)
4477   %rhs.copy = f32[320,39296,64,128] copy(%rhs),
4478     sharding={devices=[1,1,4,1]0,1,2,3}
4479   ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
4480     lhs_batch_dims={0}, rhs_batch_dims={0},
4481     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4482     sharding={devices=[1,4,1]0,1,2,3}
4483 })";
4484 
4485   TF_ASSERT_OK_AND_ASSIGN(
4486       auto module,
4487       PartitionComputation(hlo_string, /*num_devices=*/4,
4488                            /*conv_halo_exchange_always_on_lhs =*/true,
4489                            /*choose_faster_windowed_einsum =*/false,
4490                            /*unroll_windowed_einsum =*/false,
4491                            /*bidirectional_windowed_einsum =*/true));
4492   VLOG(1) << module->ToString();
4493   const auto lhs = AllOf(
4494       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4495                                 op::Constant(), op::Reshape(), op::Constant())),
4496       op::Shape("f32[320,25,16,128]"));
4497   const auto rhs = AllOf(
4498       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4499                                 op::Constant(), op::Reshape(), op::Constant())),
4500       op::Shape("f32[320,39296,16,128]"));
4501   const auto while_op = AllOf(
4502       op::While(op::Tuple(lhs, rhs, op::Broadcast(), op::Broadcast(),
4503                           op::Constant())),
4504       op::Shape("(f32[320,25,16,128], f32[320,39296,16,128], f32[320,7,39296],"
4505                 " f32[320,7,39296], u32[])"));
4506   const auto root = module->entry_computation()->root_instruction();
4507   EXPECT_THAT(
4508       root, AllOf(op::Add(op::GetTupleElement(while_op),
4509                           op::CollectivePermute(op::GetTupleElement(while_op))),
4510                   op::Shape("f32[320,7,39296]")));
4511 
4512   const auto while_loop = root->operand(0)->operand(0);
4513   // Check loop condition.
4514   EXPECT_THAT(
4515       while_loop->while_condition()->root_instruction(),
4516       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4517 
4518   // Check loop body.
4519   const auto next_i =
4520       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4521               op::Constant());
4522   const auto partial_dot_pattern =
4523       AllOf(op::Reshape(op::Slice(
4524                 op::Dot(op::Maximum(), op::GetTupleElement(op::Parameter(0))))),
4525             op::Shape("f32[320,7,39296]"));
4526   const auto partial_output_pattern = AllOf(
4527       op::Add(op::CollectivePermute(op::Add(
4528                   op::CollectivePermute(op::GetTupleElement(op::Parameter(0))),
4529                   partial_dot_pattern)),
4530               partial_dot_pattern),
4531       op::Shape("f32[320,7,39296]"));
4532 
4533   EXPECT_THAT(
4534       while_loop->while_body()->root_instruction(),
4535       op::Tuple(op::GetTupleElement(op::Parameter(0)),
4536                 op::GetTupleElement(op::Parameter(0)), partial_output_pattern,
4537                 partial_output_pattern, next_i));
4538 }
4539 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedInContractingOutNonContractingFromBroadcast)4540 TEST_F(SpmdPartitioningTest,
4541        EinsumRHSWindowedInContractingOutNonContractingFromBroadcast) {
4542   absl::string_view hlo_string = R"(
4543 HloModule module
4544 
4545 ENTRY entry {
4546   %constant.1 = f32[] constant(2)
4547   %broadcast = f32[32,25,64,128] broadcast(%constant.1), dimensions={},
4548     sharding={devices=[1,1,4,1]0,1,2,3}
4549   %add = f32[32,25,64,128] add(%broadcast, %broadcast),
4550     sharding={devices=[1,1,4,1]0,1,2,3}
4551   %rhs = f32[32,39296,64,128] parameter(0)
4552   %rhs.copy = f32[32,39296,64,128] copy(%rhs),
4553     sharding={devices=[1,1,4,1]0,1,2,3}
4554   ROOT %dot = f32[32,25,39296] dot(%add, %rhs.copy),
4555     lhs_batch_dims={0}, rhs_batch_dims={0},
4556     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4557     sharding={devices=[1,4,1]0,1,2,3}
4558 })";
4559 
4560   TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
4561                                                             /*num_devices=*/4));
4562   VLOG(1) << module->ToString();
4563   // Involves loop code motion, skips pattern matching.
4564 }
4565 
TEST_F(SpmdPartitioningTest,EinsumLHSWindowedInContractingOutNonContractingPartitioned)4566 TEST_F(SpmdPartitioningTest,
4567        EinsumLHSWindowedInContractingOutNonContractingPartitioned) {
4568   absl::string_view hlo_string = R"(
4569 HloModule module
4570 
4571 ENTRY entry {
4572   %lhs = f32[16,1024,16384] parameter(0)
4573   %lhs.copy = f32[16,1024,16384] copy(%lhs),
4574     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4575   %rhs = f32[16384,67,128] parameter(1)
4576   %rhs.copy = f32[16384,67,128] copy(%rhs),
4577     sharding={devices=[4,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4578   ROOT %dot = f32[16,1024,67,128] dot(%lhs.copy, %rhs.copy),
4579     lhs_batch_dims={}, rhs_batch_dims={},
4580     lhs_contracting_dims={2}, rhs_contracting_dims={0},
4581     sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
4582 })";
4583 
4584   TF_ASSERT_OK_AND_ASSIGN(auto module,
4585                           PartitionComputation(hlo_string, /*num_devices=*/8));
4586   VLOG(1) << module->ToString();
4587 
4588   const auto root = module->entry_computation()->root_instruction();
4589   const auto lhs =
4590       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4591                                       op::Constant(), op::Reshape())),
4592             op::Shape("f32[8,1024,4096]"));
4593   const auto rhs =
4594       AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4595                                       op::Constant(), op::Constant())),
4596             op::Shape("f32[4096,67,128]"));
4597   EXPECT_THAT(
4598       root,
4599       AllOf(op::GetTupleElement(op::While(op::Tuple(
4600                 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant()))),
4601             op::Shape("f32[8,1024,17,128]")));
4602 
4603   const auto while_loop = root->operand(0);
4604   // Check loop condition.
4605   EXPECT_THAT(
4606       while_loop->while_condition()->root_instruction(),
4607       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4608 
4609   // Check loop body.
4610   const auto next_i =
4611       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4612   auto ds =
4613       AllOf(op::DynamicSlice(
4614                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4615                 op::Constant(), op::Reshape(), op::Constant()),
4616             op::Shape("f32[4096,17,128]"));
4617   auto partial_output =
4618       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
4619                     op::Dot(op::GetTupleElement(op::Parameter(0)), ds)),
4620             op::Shape("f32[8,1024,17,128]"));
4621   auto window = op::Conditional(op::Compare(next_i, op::Constant()),
4622                                 partial_output, partial_output);
4623   EXPECT_THAT(while_loop->while_body()->root_instruction(),
4624               op::Tuple(op::GetTupleElement(op::Parameter(0)),
4625                         op::GetTupleElement(op::Parameter(0)), window,
4626                         op::GetTupleElement(op::Parameter(0)), next_i));
4627 
4628   // Check the conditional that contains the collective permute.
4629   auto cp_conditional =
4630       while_loop->while_body()->root_instruction()->operand(2);
4631   EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
4632               op::CollectivePermute(op::Parameter(0)));
4633   EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
4634               op::Parameter(0));
4635 }
4636 
TEST_F(SpmdPartitioningTest,UnrollEinsumLHSWindowedInContractingOutNonContractingPartitioned)4637 TEST_F(SpmdPartitioningTest,
4638        UnrollEinsumLHSWindowedInContractingOutNonContractingPartitioned) {
4639   absl::string_view hlo_string = R"(
4640 HloModule module
4641 
4642 ENTRY entry {
4643   %lhs = f32[16,1024,16384] parameter(0)
4644   %lhs.copy = f32[16,1024,16384] copy(%lhs),
4645     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4646   %rhs = f32[16384,67,128] parameter(1)
4647   %rhs.copy = f32[16384,67,128] copy(%rhs),
4648     sharding={devices=[4,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4649   ROOT %dot = f32[16,1024,67,128] dot(%lhs.copy, %rhs.copy),
4650     lhs_batch_dims={}, rhs_batch_dims={},
4651     lhs_contracting_dims={2}, rhs_contracting_dims={0},
4652     sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
4653 })";
4654 
4655   TF_ASSERT_OK_AND_ASSIGN(
4656       auto module,
4657       PartitionComputation(hlo_string, /*num_devices=*/8,
4658                            /*conv_halo_exchange_always_on_lhs =*/true,
4659                            /*choose_faster_windowed_einsum =*/false,
4660                            /*unroll_windowed_einsum =*/true));
4661   VLOG(1) << module->ToString();
4662 
4663   const auto lhs =
4664       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4665                                       op::Constant(), op::Reshape())),
4666             op::Shape("f32[8,1024,4096]"));
4667   const auto rhs =
4668       AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4669                                       op::Constant(), op::Constant())),
4670             op::Shape("f32[4096,67,128]"));
4671   const auto while_op =
4672       AllOf(op::While(op::Tuple(lhs, rhs, op::Broadcast(), op::Broadcast(),
4673                                 op::Constant())),
4674             op::Shape("(f32[8,1024,4096], f32[4096,67,128], f32[8,1024,17,128],"
4675                       " f32[8,1024,17,128], u32[])"));
4676   const auto root = module->entry_computation()->root_instruction();
4677   EXPECT_THAT(
4678       root, AllOf(op::Add(op::CollectivePermute(op::GetTupleElement(while_op)),
4679                           op::GetTupleElement(while_op)),
4680                   op::Shape("f32[8,1024,17,128]")));
4681 
4682   const auto while_loop = root->operand(1)->operand(0);
4683   // Check loop condition.
4684   EXPECT_THAT(
4685       while_loop->while_condition()->root_instruction(),
4686       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4687 
4688   // Check loop body.
4689   const auto next_i =
4690       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4691   auto ds =
4692       AllOf(op::DynamicSlice(
4693                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4694                 op::Constant(), op::Reshape(), op::Constant()),
4695             op::Shape("f32[4096,17,128]"));
4696   auto partial_output = AllOf(
4697       op::Add(op::CollectivePermute(op::GetTupleElement(op::Parameter(0))),
4698               op::Dot(op::GetTupleElement(op::Parameter(0)), ds)),
4699       op::Shape("f32[8,1024,17,128]"));
4700   auto partial_output2 =
4701       AllOf(op::CollectivePermute(
4702                 op::Add(op::GetTupleElement(op::Parameter(0)),
4703                         op::Dot(op::GetTupleElement(op::Parameter(0)), ds))),
4704             op::Shape("f32[8,1024,17,128]"));
4705 
4706   EXPECT_THAT(while_loop->while_body()->root_instruction(),
4707               op::Tuple(op::GetTupleElement(op::Parameter(0)),
4708                         op::GetTupleElement(op::Parameter(0)), partial_output,
4709                         partial_output2, next_i));
4710 }
4711 
TEST_F(SpmdPartitioningTest,BidirectionalEinsumLHSWindowedInContractingOutNonContractingPartitioned)4712 TEST_F(
4713     SpmdPartitioningTest,
4714     BidirectionalEinsumLHSWindowedInContractingOutNonContractingPartitioned) {
4715   absl::string_view hlo_string = R"(
4716 HloModule module
4717 
4718 ENTRY entry {
4719   %lhs = f32[16,1024,16384] parameter(0)
4720   %lhs.copy = f32[16,1024,16384] copy(%lhs),
4721     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4722   %rhs = f32[16384,67,128] parameter(1)
4723   %rhs.copy = f32[16384,67,128] copy(%rhs),
4724     sharding={devices=[4,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4725   ROOT %dot = f32[16,1024,67,128] dot(%lhs.copy, %rhs.copy),
4726     lhs_batch_dims={}, rhs_batch_dims={},
4727     lhs_contracting_dims={2}, rhs_contracting_dims={0},
4728     sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
4729 })";
4730 
4731   TF_ASSERT_OK_AND_ASSIGN(
4732       auto module,
4733       PartitionComputation(hlo_string, /*num_devices=*/8,
4734                            /*conv_halo_exchange_always_on_lhs =*/true,
4735                            /*choose_faster_windowed_einsum =*/false,
4736                            /*unroll_windowed_einsum =*/false,
4737                            /*bidirectional_windowed_einsum =*/true));
4738   VLOG(1) << module->ToString();
4739 
4740   const auto lhs =
4741       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4742                                       op::Constant(), op::Reshape())),
4743             op::Shape("f32[8,1024,4096]"));
4744   const auto rhs =
4745       AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4746                                       op::Constant(), op::Constant())),
4747             op::Shape("f32[4096,67,128]"));
4748   const auto while_op =
4749       AllOf(op::While(op::Tuple(lhs, rhs, op::Broadcast(), op::Broadcast(),
4750                                 op::Constant())),
4751             op::Shape("(f32[8,1024,4096], f32[4096,67,128], f32[8,1024,17,128],"
4752                       " f32[8,1024,17,128], u32[])"));
4753   const auto root = module->entry_computation()->root_instruction();
4754   EXPECT_THAT(
4755       root, AllOf(op::Add(op::GetTupleElement(while_op),
4756                           op::CollectivePermute(op::GetTupleElement(while_op))),
4757                   op::Shape("f32[8,1024,17,128]")));
4758 
4759   const auto while_loop = root->operand(0)->operand(0);
4760   // Check loop condition.
4761   EXPECT_THAT(
4762       while_loop->while_condition()->root_instruction(),
4763       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4764 
4765   // Check loop body.
4766   const auto next_i =
4767       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4768               op::Constant());
4769   const auto partial_dot_pattern =
4770       AllOf(op::Reshape(op::Slice(
4771                 op::Dot(op::GetTupleElement(op::Parameter(0)), op::Maximum()))),
4772             op::Shape("f32[8,1024,17,128]"));
4773   const auto partial_output_pattern = AllOf(
4774       op::Add(op::CollectivePermute(op::Add(
4775                   op::CollectivePermute(op::GetTupleElement(op::Parameter(0))),
4776                   partial_dot_pattern)),
4777               partial_dot_pattern),
4778       op::Shape("f32[8,1024,17,128]"));
4779 
4780   EXPECT_THAT(
4781       while_loop->while_body()->root_instruction(),
4782       op::Tuple(op::GetTupleElement(op::Parameter(0)),
4783                 op::GetTupleElement(op::Parameter(0)), partial_output_pattern,
4784                 partial_output_pattern, next_i));
4785 }
4786 
TEST_F(SpmdPartitioningTest,EinsumLHSWindowedInContractingOutNonContractingPartitioned2)4787 TEST_F(SpmdPartitioningTest,
4788        EinsumLHSWindowedInContractingOutNonContractingPartitioned2) {
4789   absl::string_view hlo_string = R"(
4790 HloModule module
4791 
4792 ENTRY entry {
4793   %lhs = f32[16,1024,16384] parameter(0)
4794   %lhs.copy = f32[16,1024,16384] copy(%lhs),
4795     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4796   %rhs = f32[16384,2,33,128] parameter(1)
4797   %rhs.copy = f32[16384,2,33,128] copy(%rhs),
4798     sharding={devices=[4,1,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4799   ROOT %dot = f32[16,1024,2,33,128] dot(%lhs.copy, %rhs.copy),
4800     lhs_batch_dims={}, rhs_batch_dims={},
4801     lhs_contracting_dims={2}, rhs_contracting_dims={0},
4802     sharding={devices=[2,1,2,2,1]0,1,2,3,4,5,6,7}
4803 })";
4804 
4805   TF_ASSERT_OK_AND_ASSIGN(auto module,
4806                           PartitionComputation(hlo_string, /*num_devices=*/8));
4807   VLOG(1) << module->ToString();
4808 
4809   const auto root = module->entry_computation()->root_instruction();
4810   const auto lhs =
4811       AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4812                                       op::Constant(), op::Reshape())),
4813             op::Shape("f32[8,1024,4096]"));
4814   const auto rhs = AllOf(
4815       op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
4816                                 op::Constant(), op::Constant())),
4817       op::Shape("f32[4096,2,33,128]"));
4818   EXPECT_THAT(
4819       root,
4820       AllOf(op::GetTupleElement(op::While(op::Tuple(
4821                 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant()))),
4822             op::Shape("f32[8,1024,1,17,128]")));
4823 
4824   const auto while_loop = root->operand(0);
4825   // Check loop condition.
4826   EXPECT_THAT(
4827       while_loop->while_condition()->root_instruction(),
4828       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4829 
4830   // Check loop body.
4831   const auto next_i =
4832       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4833   auto ds =
4834       AllOf(op::DynamicSlice(
4835                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4836                 op::Constant(), op::Reshape(), op::Reshape(), op::Constant()),
4837             op::Shape("f32[4096,1,17,128]"));
4838   auto partial_output =
4839       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
4840                     op::Dot(op::GetTupleElement(op::Parameter(0)), ds)),
4841             op::Shape("f32[8,1024,1,17,128]"));
4842   auto window = op::Conditional(op::Compare(next_i, op::Constant()),
4843                                 partial_output, partial_output);
4844   EXPECT_THAT(while_loop->while_body()->root_instruction(),
4845               op::Tuple(op::GetTupleElement(op::Parameter(0)),
4846                         op::GetTupleElement(op::Parameter(0)), window,
4847                         op::GetTupleElement(op::Parameter(0)), next_i));
4848 
4849   // Check the conditional that contains the collective permute.
4850   auto cp_conditional =
4851       while_loop->while_body()->root_instruction()->operand(2);
4852   EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
4853               op::CollectivePermute(op::Parameter(0)));
4854   EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
4855               op::Parameter(0));
4856 }
4857 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedNonContractingNoDoubleAG)4858 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingNoDoubleAG) {
4859   absl::string_view hlo_string = R"(
4860 HloModule module
4861 
4862 ENTRY entry {
4863   %lhs = f32[32,24,64,128] parameter(0)
4864   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
4865   %lhs2 = f32[32,24,64,128] parameter(2)
4866   %lhs2.copy = f32[32,24,64,128] copy(%lhs2), sharding={devices=[1,2,1,1]0,1}
4867   %rhs = f32[32,39295,64,128] parameter(1)
4868   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
4869   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
4870     lhs_batch_dims={0}, rhs_batch_dims={0},
4871     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4872     sharding={devices=[1,2,1]0,1}
4873   %dot2 = f32[32,24,39295] dot(%lhs2.copy, %rhs.copy),
4874     lhs_batch_dims={0}, rhs_batch_dims={0},
4875     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4876     sharding={devices=[1,2,1]0,1}
4877   ROOT %t = tuple(%dot, %dot2)
4878 })";
4879 
4880   TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
4881                                                             /*num_devices=*/2));
4882   VLOG(1) << module->ToString();
4883   const auto root = module->entry_computation()->root_instruction();
4884   EXPECT_THAT(root, op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
4885                                   _, op::Dot(_, op::Slice(_)), _, _, _)),
4886                               op::AllReduce(op::DynamicUpdateSlice(
4887                                   _, op::Dot(_, op::Slice(_)), _, _, _))));
4888 }
4889 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedNonContractingNoSharedSharding)4890 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingNoSharedSharding) {
4891   absl::string_view hlo_string = R"(
4892 HloModule module
4893 
4894 ENTRY entry {
4895   %lhs = f32[32,24,64,128] parameter(0)
4896   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
4897   %lhs2 = f32[32,24,64,128] parameter(2)
4898   %lhs2.copy = f32[32,24,64,128] copy(%lhs2), sharding={devices=[1,1,2,1]0,1}
4899   %rhs = f32[32,39295,64,128] parameter(1)
4900   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
4901   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
4902     lhs_batch_dims={0}, rhs_batch_dims={0},
4903     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4904     sharding={devices=[1,2,1]0,1}
4905   %dot2 = f32[32,24,39295] dot(%lhs2.copy, %rhs.copy),
4906     lhs_batch_dims={0}, rhs_batch_dims={0},
4907     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4908     sharding={devices=[2,1,1]0,1}
4909   ROOT %t = tuple(%dot, %dot2)
4910 })";
4911 
4912   TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
4913                                                             /*num_devices=*/2));
4914   VLOG(1) << module->ToString();
4915   const auto root = module->entry_computation()->root_instruction();
4916   EXPECT_THAT(
4917       root,
4918       op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
4919                     _, op::Slice(op::GetTupleElement(op::While(_))), _, _, _)),
4920                 op::AllReduce(op::DynamicUpdateSlice(
4921                     _, op::Dot(_, op::Slice(_)), _, _, _))));
4922 }
4923 
TEST_F(SpmdPartitioningTest,UnrollEinsumRHSWindowedNonContractingNoSharedSharding)4924 TEST_F(SpmdPartitioningTest,
4925        UnrollEinsumRHSWindowedNonContractingNoSharedSharding) {
4926   absl::string_view hlo_string = R"(
4927 HloModule module
4928 
4929 ENTRY entry {
4930   %lhs = f32[32,24,64,128] parameter(0)
4931   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
4932   %lhs2 = f32[32,24,64,128] parameter(2)
4933   %lhs2.copy = f32[32,24,64,128] copy(%lhs2), sharding={devices=[1,1,2,1]0,1}
4934   %rhs = f32[32,39295,64,128] parameter(1)
4935   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
4936   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
4937     lhs_batch_dims={0}, rhs_batch_dims={0},
4938     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4939     sharding={devices=[1,2,1]0,1}
4940   %dot2 = f32[32,24,39295] dot(%lhs2.copy, %rhs.copy),
4941     lhs_batch_dims={0}, rhs_batch_dims={0},
4942     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4943     sharding={devices=[2,1,1]0,1}
4944   ROOT %t = tuple(%dot, %dot2)
4945 })";
4946 
4947   TF_ASSERT_OK_AND_ASSIGN(
4948       auto module,
4949       PartitionComputation(hlo_string, /*num_devices=*/2,
4950                            /*conv_halo_exchange_always_on_lhs =*/true,
4951                            /*choose_faster_windowed_einsum =*/false,
4952                            /*unroll_windowed_einsum =*/true));
4953   VLOG(1) << module->ToString();
4954   const auto root = module->entry_computation()->root_instruction();
4955   EXPECT_THAT(
4956       root,
4957       op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
4958                     _, op::Slice(op::GetTupleElement(op::While(_))), _, _, _)),
4959                 op::AllReduce(op::DynamicUpdateSlice(
4960                     _, op::Dot(_, op::Slice(_)), _, _, _))));
4961 
4962   // Tuple<-AllReduce<-DynamicUpdateSlice<-Slice<-GetTupleElement<-While
4963   const auto while_loop =
4964       root->operand(0)->operand(0)->operand(1)->operand(0)->operand(0);
4965   // Check loop condition.
4966   EXPECT_THAT(
4967       while_loop->while_condition()->root_instruction(),
4968       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4969 
4970   // Check loop body.
4971   const auto next_i =
4972       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4973               op::Constant());
4974   auto intermediate_output = AllOf(
4975       op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
4976                              op::Dot(op::GetTupleElement(op::Parameter(0)),
4977                                      op::GetTupleElement(op::Parameter(0))),
4978                              op::Constant(), op::Constant(), op::Reshape()),
4979       op::Shape("f32[32,12,39296]"));
4980   auto output = AllOf(
4981       op::DynamicUpdateSlice(
4982           intermediate_output,
4983           op::Dot(op::GetTupleElement(op::Parameter(0)),
4984                   op::CollectivePermute(op::GetTupleElement(op::Parameter(0)))),
4985           op::Constant(), op::Constant(), op::Reshape()),
4986       op::Shape("f32[32,12,39296]"));
4987 
4988   EXPECT_THAT(while_loop->while_body()->root_instruction(),
4989               op::Tuple(op::GetTupleElement(op::Parameter(0)),
4990                         op::CollectivePermute(op::CollectivePermute(
4991                             op::GetTupleElement(op::Parameter(0)))),
4992                         output, op::GetTupleElement(op::Parameter(0)), next_i));
4993 }
4994 
TEST_F(SpmdPartitioningTest,BidirectionalEinsumRHSWindowedNonContractingNoSharedSharding)4995 TEST_F(SpmdPartitioningTest,
4996        BidirectionalEinsumRHSWindowedNonContractingNoSharedSharding) {
4997   absl::string_view hlo_string = R"(
4998 HloModule module
4999 
5000 ENTRY entry {
5001   %lhs = f32[32,24,64,128] parameter(0)
5002   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
5003   %lhs2 = f32[32,24,64,128] parameter(2)
5004   %lhs2.copy = f32[32,24,64,128] copy(%lhs2), sharding={devices=[1,1,4,1]0,1,2,3}
5005   %rhs = f32[32,39295,64,128] parameter(1)
5006   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
5007   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5008     lhs_batch_dims={0}, rhs_batch_dims={0},
5009     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5010     sharding={devices=[1,4,1]0,1,2,3}
5011   %dot2 = f32[32,24,39295] dot(%lhs2.copy, %rhs.copy),
5012     lhs_batch_dims={0}, rhs_batch_dims={0},
5013     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5014     sharding={devices=[4,1,1]0,1,2,3}
5015   ROOT %t = tuple(%dot, %dot2)
5016 })";
5017 
5018   TF_ASSERT_OK_AND_ASSIGN(
5019       auto module,
5020       PartitionComputation(hlo_string, /*num_devices=*/4,
5021                            /*conv_halo_exchange_always_on_lhs =*/true,
5022                            /*choose_faster_windowed_einsum =*/false,
5023                            /*unroll_windowed_einsum =*/false,
5024                            /*bidirectional_windowed_einsum =*/true));
5025   VLOG(1) << module->ToString();
5026   const auto root = module->entry_computation()->root_instruction();
5027   EXPECT_THAT(
5028       root,
5029       op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
5030                     _, op::Slice(op::GetTupleElement(op::While(_))), _, _, _)),
5031                 op::AllReduce(op::DynamicUpdateSlice(
5032                     _, op::Dot(_, op::Slice(_)), _, _, _))));
5033 
5034   // Tuple<-AllReduce<-DynamicUpdateSlice<-Slice<-GetTupleElement<-While
5035   const auto while_loop =
5036       root->operand(0)->operand(0)->operand(1)->operand(0)->operand(0);
5037   // Check loop condition.
5038   EXPECT_THAT(
5039       while_loop->while_condition()->root_instruction(),
5040       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5041 
5042   // Check loop body.
5043   const auto next_i =
5044       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5045               op::Constant());
5046   const auto partial_dot_pattern =
5047       AllOf(op::Reshape(op::Slice(op::Dot(op::GetTupleElement(op::Parameter(0)),
5048                                           op::Concatenate()))),
5049             op::Shape("f32[32,6,9824]"));
5050   auto intermediate_output1 =
5051       AllOf(op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
5052                                    partial_dot_pattern, op::Constant(),
5053                                    op::Constant(), op::Reshape()),
5054             op::Shape("f32[32,6,39296]"));
5055   auto intermediate_output2 = AllOf(
5056       op::DynamicUpdateSlice(intermediate_output1, partial_dot_pattern,
5057                              op::Constant(), op::Constant(), op::Reshape()),
5058       op::Shape("f32[32,6,39296]"));
5059   auto intermediate_output3 = AllOf(
5060       op::DynamicUpdateSlice(intermediate_output2, partial_dot_pattern,
5061                              op::Constant(), op::Constant(), op::Reshape()),
5062       op::Shape("f32[32,6,39296]"));
5063   auto partial_output = AllOf(
5064       op::DynamicUpdateSlice(intermediate_output3, partial_dot_pattern,
5065                              op::Constant(), op::Constant(), op::Reshape()),
5066       op::Shape("f32[32,6,39296]"));
5067 
5068   EXPECT_THAT(while_loop->while_body()->root_instruction(),
5069               op::Tuple(op::GetTupleElement(op::Parameter(0)),
5070                         op::CollectivePermute(op::CollectivePermute(
5071                             op::GetTupleElement(op::Parameter(0)))),
5072                         partial_output,
5073                         op::CollectivePermute(op::CollectivePermute(
5074                             op::GetTupleElement(op::Parameter(0)))),
5075                         next_i));
5076 }
5077 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedNonContracting)5078 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) {
5079   absl::string_view hlo_string = R"(
5080 HloModule module
5081 
5082 ENTRY entry {
5083   %lhs = f32[32,24,64,128] parameter(0)
5084   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5085   %rhs = f32[32,39295,64,128] parameter(1)
5086   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5087   ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5088     lhs_batch_dims={0}, rhs_batch_dims={0},
5089     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5090     sharding={devices=[1,2,1]0,1}
5091 })";
5092 
5093   TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
5094                                                             /*num_devices=*/2));
5095   VLOG(1) << module->ToString();
5096   const auto root = module->entry_computation()->root_instruction();
5097   const auto lhs = AllOf(
5098       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5099                                 op::Constant(), op::Constant())),
5100       op::Shape("f32[32,12,64,128]"));
5101   const auto rhs =
5102       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5103                                       op::Constant(), op::Reshape(),
5104                                       op::Constant(), op::Constant())),
5105             op::Shape("f32[32,19648,64,128]"));
5106   EXPECT_THAT(root,
5107               AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple(
5108                                         lhs, rhs, op::Broadcast(),
5109                                         op::Broadcast(), op::Constant()))),
5110                                     op::Shape("f32[32,12,39296]"))),
5111                     op::Shape("f32[32,12,39295]")));
5112   const auto while_loop = root->operand(0)->operand(0);
5113   // Check loop condition.
5114   EXPECT_THAT(
5115       while_loop->while_condition()->root_instruction(),
5116       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5117 
5118   // Check loop body.
5119   const auto next_i =
5120       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
5121   auto window = op::Conditional(op::Compare(next_i, op::Constant()),
5122                                 op::GetTupleElement(op::Parameter(0)),
5123                                 op::GetTupleElement(op::Parameter(0)));
5124   auto partial_output = op::Dot(op::GetTupleElement(op::Parameter(0)),
5125                                 op::GetTupleElement(op::Parameter(0)));
5126   EXPECT_THAT(
5127       while_loop->while_body()->root_instruction(),
5128       op::Tuple(op::GetTupleElement(op::Parameter(0)), window,
5129                 op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
5130                                        partial_output, op::Constant(),
5131                                        op::Constant(), op::Reshape()),
5132                 op::GetTupleElement(op::Parameter(0)), next_i));
5133 
5134   // Check the conditional that contains the collective permute.
5135   auto cp_conditional =
5136       while_loop->while_body()->root_instruction()->operand(1);
5137   EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
5138               op::CollectivePermute(op::Parameter(0)));
5139   EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
5140               op::Parameter(0));
5141 }
5142 
TEST_F(SpmdPartitioningTest,UnrollEinsumRHSWindowedNonContracting)5143 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedNonContracting) {
5144   absl::string_view hlo_string = R"(
5145 HloModule module
5146 
5147 ENTRY entry {
5148   %lhs = f32[32,24,64,128] parameter(0)
5149   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5150   %rhs = f32[32,39295,64,128] parameter(1)
5151   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5152   ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5153     lhs_batch_dims={0}, rhs_batch_dims={0},
5154     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5155     sharding={devices=[1,2,1]0,1}
5156 })";
5157 
5158   TF_ASSERT_OK_AND_ASSIGN(
5159       auto module,
5160       PartitionComputation(hlo_string, /*num_devices=*/2,
5161                            /*conv_halo_exchange_always_on_lhs =*/true,
5162                            /*choose_faster_windowed_einsum =*/false,
5163                            /*unroll_windowed_einsum =*/true));
5164   VLOG(1) << module->ToString();
5165 
5166   const auto root = module->entry_computation()->root_instruction();
5167   const auto lhs = AllOf(
5168       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5169                                 op::Constant(), op::Constant())),
5170       op::Shape("f32[32,12,64,128]"));
5171   const auto rhs =
5172       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5173                                       op::Constant(), op::Reshape(),
5174                                       op::Constant(), op::Constant())),
5175             op::Shape("f32[32,19648,64,128]"));
5176   EXPECT_THAT(root,
5177               AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple(
5178                                         lhs, rhs, op::Broadcast(),
5179                                         op::Broadcast(), op::Constant()))),
5180                                     op::Shape("f32[32,12,39296]"))),
5181                     op::Shape("f32[32,12,39295]")));
5182 
5183   const auto while_loop = root->operand(0)->operand(0);
5184   // Check loop condition.
5185   EXPECT_THAT(
5186       while_loop->while_condition()->root_instruction(),
5187       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5188 
5189   // Check loop body.
5190   const auto next_i =
5191       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5192               op::Constant());
5193   auto intermediate_output = AllOf(
5194       op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
5195                              op::Dot(op::GetTupleElement(op::Parameter(0)),
5196                                      op::GetTupleElement(op::Parameter(0))),
5197                              op::Constant(), op::Constant(), op::Reshape()),
5198       op::Shape("f32[32,12,39296]"));
5199   auto output = AllOf(
5200       op::DynamicUpdateSlice(
5201           intermediate_output,
5202           op::Dot(op::GetTupleElement(op::Parameter(0)),
5203                   op::CollectivePermute(op::GetTupleElement(op::Parameter(0)))),
5204           op::Constant(), op::Constant(), op::Reshape()),
5205       op::Shape("f32[32,12,39296]"));
5206 
5207   EXPECT_THAT(while_loop->while_body()->root_instruction(),
5208               op::Tuple(op::GetTupleElement(op::Parameter(0)),
5209                         op::CollectivePermute(op::CollectivePermute(
5210                             op::GetTupleElement(op::Parameter(0)))),
5211                         output, op::GetTupleElement(op::Parameter(0)), next_i));
5212 }
5213 
TEST_F(SpmdPartitioningTest,BidirectionalEinsumRHSWindowedNonContracting)5214 TEST_F(SpmdPartitioningTest, BidirectionalEinsumRHSWindowedNonContracting) {
5215   absl::string_view hlo_string = R"(
5216 HloModule module
5217 
5218 ENTRY entry {
5219   %lhs = f32[32,24,64,128] parameter(0)
5220   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
5221   %rhs = f32[32,39295,64,128] parameter(1)
5222   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
5223   ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5224     lhs_batch_dims={0}, rhs_batch_dims={0},
5225     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5226     sharding={devices=[1,4,1]0,1,2,3}
5227 })";
5228 
5229   TF_ASSERT_OK_AND_ASSIGN(
5230       auto module,
5231       PartitionComputation(hlo_string, /*num_devices=*/4,
5232                            /*conv_halo_exchange_always_on_lhs =*/true,
5233                            /*choose_faster_windowed_einsum =*/false,
5234                            /*unroll_windowed_einsum =*/false,
5235                            /*bidirectional_windowed_einsum =*/true));
5236   VLOG(1) << module->ToString();
5237 
5238   const auto root = module->entry_computation()->root_instruction();
5239   const auto lhs = AllOf(
5240       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5241                                 op::Constant(), op::Constant())),
5242       op::Shape("f32[32,6,64,128]"));
5243   const auto rhs =
5244       AllOf(op::Reshape(op::Copy(op::DynamicSlice(
5245                 op::Pad(op::Parameter(1), op::Constant()), op::Constant(),
5246                 op::Reshape(), op::Constant(), op::Constant()))),
5247             op::Shape("f32[32,1,9824,64,128]"));
5248   EXPECT_THAT(
5249       root,
5250       AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple(
5251                                 lhs, rhs, op::Broadcast(),
5252                                 op::CollectivePermute(rhs), op::Constant()))),
5253                             op::Shape("f32[32,6,39296]"))),
5254             op::Shape("f32[32,6,39295]")));
5255 
5256   const auto while_loop = root->operand(0)->operand(0);
5257   // Check loop condition.
5258   EXPECT_THAT(
5259       while_loop->while_condition()->root_instruction(),
5260       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5261 
5262   // Check loop body.
5263   const auto next_i =
5264       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5265               op::Constant());
5266   const auto partial_dot_pattern =
5267       AllOf(op::Reshape(op::Slice(op::Dot(op::GetTupleElement(op::Parameter(0)),
5268                                           op::Concatenate()))),
5269             op::Shape("f32[32,6,9824]"));
5270   auto intermediate_output1 =
5271       AllOf(op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
5272                                    partial_dot_pattern, op::Constant(),
5273                                    op::Constant(), op::Reshape()),
5274             op::Shape("f32[32,6,39296]"));
5275   auto intermediate_output2 = AllOf(
5276       op::DynamicUpdateSlice(intermediate_output1, partial_dot_pattern,
5277                              op::Constant(), op::Constant(), op::Reshape()),
5278       op::Shape("f32[32,6,39296]"));
5279   auto intermediate_output3 = AllOf(
5280       op::DynamicUpdateSlice(intermediate_output2, partial_dot_pattern,
5281                              op::Constant(), op::Constant(), op::Reshape()),
5282       op::Shape("f32[32,6,39296]"));
5283   auto partial_output = AllOf(
5284       op::DynamicUpdateSlice(intermediate_output3, partial_dot_pattern,
5285                              op::Constant(), op::Constant(), op::Reshape()),
5286       op::Shape("f32[32,6,39296]"));
5287 
5288   EXPECT_THAT(while_loop->while_body()->root_instruction(),
5289               op::Tuple(op::GetTupleElement(op::Parameter(0)),
5290                         op::CollectivePermute(op::CollectivePermute(
5291                             op::GetTupleElement(op::Parameter(0)))),
5292                         partial_output,
5293                         op::CollectivePermute(op::CollectivePermute(
5294                             op::GetTupleElement(op::Parameter(0)))),
5295                         next_i));
5296 }
5297 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedContracting)5298 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) {
5299   absl::string_view hlo_string = R"(
5300 HloModule module
5301 
5302 ENTRY entry {
5303   %lhs = f32[32,24,63,128] parameter(0)
5304   %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5305   %rhs = f32[32,39296,63,128] parameter(1)
5306   %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
5307   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
5308     lhs_batch_dims={0}, rhs_batch_dims={0},
5309     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5310     sharding={devices=[1,2,1]0,1}
5311 })";
5312 
5313   TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
5314                                                             /*num_devices=*/2));
5315   VLOG(1) << module->ToString();
5316   const auto root = module->entry_computation()->root_instruction();
5317   const auto lhs = AllOf(
5318       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5319                                 op::Constant(), op::Constant())),
5320       op::Shape("f32[32,12,63,128]"));
5321   const auto rhs =
5322       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5323                                       op::Constant(), op::Constant(),
5324                                       op::Reshape(), op::Constant())),
5325             op::Shape("f32[32,39296,32,128]"));
5326   auto masked_rhs =
5327       op::Select(op::Compare(), rhs, op::Broadcast(op::Constant()));
5328   EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(
5329                               op::Tuple(lhs, masked_rhs, op::Broadcast(),
5330                                         op::Broadcast(), op::Constant()))),
5331                           op::Shape("f32[32,12,39296]")));
5332   const auto while_loop = root->operand(0);
5333   // Check loop condition.
5334   EXPECT_THAT(
5335       while_loop->while_condition()->root_instruction(),
5336       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5337 
5338   // Check loop body.
5339   const auto next_i =
5340       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
5341   auto window = op::Conditional(op::Compare(next_i, op::Constant()),
5342                                 op::GetTupleElement(op::Parameter(0)),
5343                                 op::GetTupleElement(op::Parameter(0)));
5344   auto partial_output = op::Dot(
5345       op::DynamicSlice(
5346           op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5347           op::Constant(), op::Constant(), op::Reshape(), op::Constant()),
5348       op::GetTupleElement(op::Parameter(0)));
5349   EXPECT_THAT(
5350       while_loop->while_body()->root_instruction(),
5351       op::Tuple(op::GetTupleElement(op::Parameter(0)), window,
5352                 op::Add(op::GetTupleElement(op::Parameter(0)), partial_output),
5353                 op::GetTupleElement(op::Parameter(0)), next_i));
5354 
5355   // Check the conditional that contains the collective permute.
5356   auto cp_conditional =
5357       while_loop->while_body()->root_instruction()->operand(1);
5358   EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
5359               op::CollectivePermute(op::Parameter(0)));
5360   EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
5361               op::Parameter(0));
5362 }
5363 
TEST_F(SpmdPartitioningTest,UnrollEinsumRHSWindowedContracting)5364 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedContracting) {
5365   absl::string_view hlo_string = R"(
5366 HloModule module
5367 
5368 ENTRY entry {
5369   %lhs = f32[32,24,63,128] parameter(0)
5370   %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5371   %rhs = f32[32,39296,63,128] parameter(1)
5372   %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
5373   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
5374     lhs_batch_dims={0}, rhs_batch_dims={0},
5375     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5376     sharding={devices=[1,2,1]0,1}
5377 })";
5378 
5379   TF_ASSERT_OK_AND_ASSIGN(
5380       auto module,
5381       PartitionComputation(hlo_string, /*num_devices=*/2,
5382                            /*conv_halo_exchange_always_on_lhs =*/true,
5383                            /*choose_faster_windowed_einsum =*/false,
5384                            /*unroll_windowed_einsum =*/true));
5385   VLOG(1) << module->ToString();
5386 
5387   const auto root = module->entry_computation()->root_instruction();
5388   const auto lhs = AllOf(
5389       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5390                                 op::Constant(), op::Constant())),
5391       op::Shape("f32[32,12,63,128]"));
5392   const auto rhs =
5393       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5394                                       op::Constant(), op::Constant(),
5395                                       op::Reshape(), op::Constant())),
5396             op::Shape("f32[32,39296,32,128]"));
5397   auto masked_rhs =
5398       op::Select(op::Compare(), rhs, op::Broadcast(op::Constant()));
5399   EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(
5400                               op::Tuple(lhs, masked_rhs, op::Broadcast(),
5401                                         op::Broadcast(), op::Constant()))),
5402                           op::Shape("f32[32,12,39296]")));
5403   const auto while_loop = root->operand(0);
5404   // Check loop condition.
5405   EXPECT_THAT(
5406       while_loop->while_condition()->root_instruction(),
5407       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5408 
5409   // Check loop body.
5410   const auto next_i =
5411       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5412               op::Constant());
5413   auto intermediate_output =
5414       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
5415                     op::Dot(op::DynamicSlice(
5416                                 op::Pad(op::GetTupleElement(op::Parameter(0)),
5417                                         op::Constant()),
5418                                 op::Constant(), op::Constant(), op::Reshape(),
5419                                 op::Constant()),
5420                             op::GetTupleElement(op::Parameter(0)))),
5421             op::Shape("f32[32,12,39296]"));
5422   auto output = AllOf(
5423       op::Add(
5424           intermediate_output,
5425           op::Dot(
5426               op::DynamicSlice(op::Pad(op::GetTupleElement(op::Parameter(0)),
5427                                        op::Constant()),
5428                                op::Constant(), op::Constant(), op::Reshape(),
5429                                op::Constant()),
5430               op::CollectivePermute(op::GetTupleElement(op::Parameter(0))))),
5431       op::Shape("f32[32,12,39296]"));
5432 
5433   EXPECT_THAT(while_loop->while_body()->root_instruction(),
5434               op::Tuple(op::GetTupleElement(op::Parameter(0)),
5435                         op::CollectivePermute(op::CollectivePermute(
5436                             op::GetTupleElement(op::Parameter(0)))),
5437                         output, op::GetTupleElement(op::Parameter(0)), next_i));
5438 }
5439 
TEST_F(SpmdPartitioningTest,BidirectionalEinsumRHSWindowedContracting)5440 TEST_F(SpmdPartitioningTest, BidirectionalEinsumRHSWindowedContracting) {
5441   absl::string_view hlo_string = R"(
5442 HloModule module
5443 
5444 ENTRY entry {
5445   %lhs = f32[32,24,63,128] parameter(0)
5446   %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
5447   %rhs = f32[32,39296,63,128] parameter(1)
5448   %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,4,1]0,1,2,3}
5449   ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
5450     lhs_batch_dims={0}, rhs_batch_dims={0},
5451     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5452     sharding={devices=[1,4,1]0,1,2,3}
5453 })";
5454 
5455   TF_ASSERT_OK_AND_ASSIGN(
5456       auto module,
5457       PartitionComputation(hlo_string, /*num_devices=*/4,
5458                            /*conv_halo_exchange_always_on_lhs =*/true,
5459                            /*choose_faster_windowed_einsum =*/false,
5460                            /*unroll_windowed_einsum =*/false,
5461                            /*bidirectional_windowed_einsum =*/true));
5462   VLOG(1) << module->ToString();
5463 
5464   const auto root = module->entry_computation()->root_instruction();
5465   const auto lhs = AllOf(
5466       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5467                                 op::Constant(), op::Constant())),
5468       op::Shape("f32[32,6,63,128]"));
5469   const auto rhs =
5470       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5471                                       op::Constant(), op::Constant(),
5472                                       op::Reshape(), op::Constant())),
5473             op::Shape("f32[32,39296,16,128]"));
5474   auto masked_rhs = op::Reshape(
5475       op::Select(op::Compare(), rhs, op::Broadcast(op::Constant())));
5476   EXPECT_THAT(root,
5477               AllOf(op::GetTupleElement(op::While(op::Tuple(
5478                         lhs, masked_rhs, op::Broadcast(),
5479                         op::CollectivePermute(masked_rhs), op::Constant()))),
5480                     op::Shape("f32[32,6,39296]")));
5481   const auto while_loop = root->operand(0);
5482   // Check loop condition.
5483   EXPECT_THAT(
5484       while_loop->while_condition()->root_instruction(),
5485       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5486 
5487   // Check loop body.
5488   const auto next_i =
5489       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5490               op::Constant());
5491   auto partial_output =
5492       AllOf(op::Add(op::Add(op::GetTupleElement(op::Parameter(0)),
5493                             op::Dot(op::Maximum(), op::Concatenate())),
5494                     op::Dot(op::Maximum(), op::Concatenate())),
5495             op::Shape("f32[32,6,39296]"));
5496 
5497   EXPECT_THAT(while_loop->while_body()->root_instruction(),
5498               op::Tuple(op::GetTupleElement(op::Parameter(0)),
5499                         op::CollectivePermute(op::CollectivePermute(
5500                             op::GetTupleElement(op::Parameter(0)))),
5501                         partial_output,
5502                         op::CollectivePermute(op::CollectivePermute(
5503                             op::GetTupleElement(op::Parameter(0)))),
5504                         next_i));
5505 }
5506 
TEST_F(SpmdPartitioningTest,EinsumWindowedNonContractingDimensionsNoCodeMotionWithDependentNodes)5507 TEST_F(SpmdPartitioningTest,
5508        EinsumWindowedNonContractingDimensionsNoCodeMotionWithDependentNodes) {
5509   absl::string_view hlo_string = R"(
5510 HloModule module
5511 
5512 sum {
5513   a = f32[] parameter(0)
5514   b = f32[] parameter(1)
5515   ROOT add = f32[] add(a, b)
5516 }
5517 
5518 ENTRY entry {
5519   %lhs = f32[32,24,64,128] parameter(0)
5520   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5521   %rhs = f32[32,39295,64,128] parameter(1)
5522   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5523   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5524     lhs_batch_dims={0}, rhs_batch_dims={0},
5525     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5526     sharding={devices=[1,2,1]0,1}
5527   %constant = f32[] constant(0)
5528   %constant.1 = f32[] constant(2)
5529   %constant.2 = f32[] constant(4)
5530   %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5531     sharding={devices=[1,2,1]0,1}
5532   %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5533     sharding={devices=[1,2,1]0,1}
5534   %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
5535     to_apply=sum, sharding={devices=[1,2]0,1}
5536   %all-reduce = f32[32,24] all-reduce(%reduce),
5537     to_apply=sum, sharding={devices=[1,2]0,1}
5538   %broadcast.1 = f32[32,24,39295] broadcast(%all-reduce), dimensions={0,1},
5539     sharding={devices=[1,2,1]0,1}
5540   %subtract = f32[32,24,39295] subtract(%multiply, %broadcast.1),
5541     sharding={devices=[1,2,1]0,1}
5542   ROOT %reduce.1 = f32[32,24] reduce(%subtract, %constant.2), dimensions={2},
5543     to_apply=sum, sharding={devices=[1,2]0,1}
5544 })";
5545 
5546   TF_ASSERT_OK_AND_ASSIGN(auto module,
5547                           PartitionComputation(hlo_string, /*num_devices=*/2));
5548   VLOG(1) << module->ToString();
5549 
5550   const auto root = module->entry_computation()->root_instruction();
5551   const auto lhs = AllOf(
5552       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5553                                 op::Constant(), op::Constant())),
5554       op::Shape("f32[32,12,64,128]"));
5555   const auto rhs =
5556       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5557                                       op::Constant(), op::Reshape(),
5558                                       op::Constant(), op::Constant())),
5559             op::Shape("f32[32,19648,64,128]"));
5560   const auto while_output =
5561       AllOf(op::Slice(op::GetTupleElement(op::While(op::Tuple(
5562                 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant())))),
5563             op::Shape("f32[32,12,39295]"));
5564   // All the multiples, subtracts and reduces should remain in the spmd entry
5565   // computation.
5566   const auto multiply =
5567       AllOf(op::Multiply(while_output, op::Broadcast(op::Constant())),
5568             op::Shape("f32[32,12,39295]"));
5569   EXPECT_THAT(
5570       root,
5571       AllOf(op::Reduce(
5572                 op::Subtract(multiply, op::Broadcast(op::AllReduce(op::Reduce(
5573                                            multiply, op::Constant())))),
5574                 op::Constant()),
5575             op::Shape("f32[32,12]")));
5576 
5577   const auto while_loop =
5578       root->operand(0)->operand(0)->operand(0)->operand(0)->operand(0);
5579   // Check loop condition.
5580   EXPECT_THAT(
5581       while_loop->while_condition()->root_instruction(),
5582       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5583 
5584   // Check loop body. There is not be any multple, subtract, reduce, etc.
5585   // that has been moved into the loop body.
5586   const auto next_i =
5587       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
5588   auto output = op::DynamicUpdateSlice(
5589       op::GetTupleElement(op::Parameter(0)),
5590       op::Dot(op::GetTupleElement(op::Parameter(0)),
5591               op::GetTupleElement(op::Parameter(0))),
5592       op::Constant(), op::Constant(), op::Reshape(op::DynamicSlice()));
5593 
5594   EXPECT_THAT(while_loop->while_body()->root_instruction(),
5595               op::Tuple(op::GetTupleElement(op::Parameter(0)),
5596                         op::Conditional(op::Compare(next_i, op::Constant()),
5597                                         op::GetTupleElement(op::Parameter(0)),
5598                                         op::GetTupleElement(op::Parameter(0))),
5599                         output, op::GetTupleElement(op::Parameter(0)), next_i));
5600 }
5601 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedNonContractingReduce1)5602 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) {
5603   absl::string_view hlo_string = R"(
5604 HloModule module
5605 
5606 sum {
5607   a = f32[] parameter(0)
5608   b = f32[] parameter(1)
5609   ROOT add = f32[] add(a, b)
5610 }
5611 
5612 ENTRY entry {
5613   %lhs = f32[32,24,64,128] parameter(0)
5614   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5615   %rhs = f32[32,39295,64,128] parameter(1)
5616   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5617   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5618     lhs_batch_dims={0}, rhs_batch_dims={0},
5619     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5620     sharding={devices=[1,2,1]0,1}
5621   %constant = f32[] constant(0)
5622   %constant.1 = f32[] constant(2)
5623   %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5624     sharding={devices=[1,2,1]0,1}
5625   %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5626     sharding={devices=[1,2,1]0,1}
5627   ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
5628     to_apply=sum, sharding={devices=[1,2]0,1}
5629 })";
5630 
5631   TF_ASSERT_OK_AND_ASSIGN(auto module,
5632                           PartitionComputation(hlo_string, /*num_devices=*/2));
5633   VLOG(1) << module->ToString();
5634 
5635   const auto root = module->entry_computation()->root_instruction();
5636   const auto lhs = AllOf(
5637       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5638                                 op::Constant(), op::Constant())),
5639       op::Shape("f32[32,12,64,128]"));
5640   const auto rhs =
5641       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5642                                       op::Constant(), op::Reshape(),
5643                                       op::Constant(), op::Constant())),
5644             op::Shape("f32[32,19648,64,128]"));
5645   auto input_subtuple =
5646       op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
5647   EXPECT_THAT(
5648       root,
5649       AllOf(op::GetTupleElement(op::GetTupleElement(op::While(op::Tuple(
5650                 lhs, rhs, input_subtuple, op::Broadcast(), op::Constant())))),
5651             op::Shape("f32[32,12]")));
5652 
5653   const auto while_loop = root->operand(0)->operand(0);
5654   // Check loop condition.
5655   EXPECT_THAT(
5656       while_loop->while_condition()->root_instruction(),
5657       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5658 
5659   // Check loop body.
5660   const auto next_i =
5661       op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
5662   auto output_tuple = op::Tuple(
5663       op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5664       op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5665       op::Add(op::Reduce(
5666                   op::Select(op::Compare(),
5667                              op::Multiply(
5668                                  op::Dot(op::GetTupleElement(op::Parameter(0)),
5669                                          op::GetTupleElement(op::Parameter(0))),
5670                                  op::DynamicSlice()),
5671                              op::Broadcast()),
5672                   op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
5673               op::DynamicSlice(
5674                   op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5675                   op::Constant(), op::Constant())));
5676 
5677   EXPECT_THAT(
5678       while_loop->while_body()->root_instruction(),
5679       op::Tuple(op::GetTupleElement(op::Parameter(0)),
5680                 op::Conditional(op::Compare(next_i, op::Constant()),
5681                                 op::GetTupleElement(op::Parameter(0)),
5682                                 op::GetTupleElement(op::Parameter(0))),
5683                 output_tuple, op::GetTupleElement(op::Parameter(0)), next_i));
5684 }
5685 
TEST_F(SpmdPartitioningTest,UnrollEinsumRHSWindowedNonContractingReduce1)5686 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedNonContractingReduce1) {
5687   absl::string_view hlo_string = R"(
5688 HloModule module
5689 
5690 sum {
5691   a = f32[] parameter(0)
5692   b = f32[] parameter(1)
5693   ROOT add = f32[] add(a, b)
5694 }
5695 
5696 ENTRY entry {
5697   %lhs = f32[32,24,64,128] parameter(0)
5698   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5699   %rhs = f32[32,39295,64,128] parameter(1)
5700   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5701   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5702     lhs_batch_dims={0}, rhs_batch_dims={0},
5703     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5704     sharding={devices=[1,2,1]0,1}
5705   %constant = f32[] constant(0)
5706   %constant.1 = f32[] constant(2)
5707   %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5708     sharding={devices=[1,2,1]0,1}
5709   %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5710   sharding={devices=[1,2,1]0,1}
5711   ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
5712     to_apply=sum, sharding={devices=[1,2]0,1}
5713 })";
5714 
5715   TF_ASSERT_OK_AND_ASSIGN(
5716       auto module,
5717       PartitionComputation(hlo_string, /*num_devices=*/2,
5718                            /*conv_halo_exchange_always_on_lhs =*/true,
5719                            /*choose_faster_windowed_einsum =*/false,
5720                            /*unroll_windowed_einsum =*/true));
5721   VLOG(1) << module->ToString();
5722 
5723   const auto root = module->entry_computation()->root_instruction();
5724   const auto lhs = AllOf(
5725       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5726                                 op::Constant(), op::Constant())),
5727       op::Shape("f32[32,12,64,128]"));
5728   const auto rhs =
5729       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5730                                       op::Constant(), op::Reshape(),
5731                                       op::Constant(), op::Constant())),
5732             op::Shape("f32[32,19648,64,128]"));
5733   auto input_subtuple =
5734       op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
5735   EXPECT_THAT(
5736       root,
5737       AllOf(op::GetTupleElement(op::GetTupleElement(op::While(op::Tuple(
5738                 lhs, rhs, input_subtuple, op::Broadcast(), op::Constant())))),
5739             op::Shape("f32[32,12]")));
5740 
5741   const auto while_loop = root->operand(0)->operand(0);
5742   // Check loop condition.
5743   EXPECT_THAT(
5744       while_loop->while_condition()->root_instruction(),
5745       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5746 
5747   // Check loop body.
5748   const auto next_i =
5749       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5750               op::Constant());
5751   auto intermediate_output = AllOf(
5752       op::Add(
5753           op::Reduce(
5754               op::Select(op::Compare(),
5755                          op::Multiply(
5756                              op::Dot(op::GetTupleElement(op::Parameter(0)),
5757                                      op::CollectivePermute(op::GetTupleElement(
5758                                          op::Parameter(0)))),
5759                              op::DynamicSlice()),
5760                          op::Broadcast()),
5761               op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
5762           op::DynamicSlice(
5763               op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5764               op::Constant(), op::Constant())),
5765       op::Shape("f32[32,12]"));
5766   auto output_tuple = op::Tuple(
5767       op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5768       op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5769       op::Add(op::Reduce(
5770                   op::Select(op::Compare(),
5771                              op::Multiply(
5772                                  op::Dot(op::GetTupleElement(op::Parameter(0)),
5773                                          op::GetTupleElement(op::Parameter(0))),
5774                                  op::DynamicSlice()),
5775                              op::Broadcast()),
5776                   op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
5777               op::DynamicSlice(intermediate_output, op::Constant(),
5778                                op::Constant())));
5779 
5780   EXPECT_THAT(
5781       while_loop->while_body()->root_instruction(),
5782       op::Tuple(op::GetTupleElement(op::Parameter(0)),
5783                 op::CollectivePermute(op::CollectivePermute(
5784                     op::GetTupleElement(op::Parameter(0)))),
5785                 output_tuple, op::GetTupleElement(op::Parameter(0)), next_i));
5786 }
5787 
TEST_F(SpmdPartitioningTest,BidirectionalEinsumRHSWindowedNonContractingReduce1)5788 TEST_F(SpmdPartitioningTest,
5789        BidirectionalEinsumRHSWindowedNonContractingReduce1) {
5790   absl::string_view hlo_string = R"(
5791 HloModule module
5792 
5793 sum {
5794   a = f32[] parameter(0)
5795   b = f32[] parameter(1)
5796   ROOT add = f32[] add(a, b)
5797 }
5798 
5799 ENTRY entry {
5800   %lhs = f32[32,24,64,128] parameter(0)
5801   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
5802   %rhs = f32[32,39295,64,128] parameter(1)
5803   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
5804   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5805     lhs_batch_dims={0}, rhs_batch_dims={0},
5806     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5807     sharding={devices=[1,4,1]0,1,2,3}
5808   %constant = f32[] constant(0)
5809   %constant.1 = f32[] constant(2)
5810   %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5811     sharding={devices=[1,4,1]0,1,2,3}
5812   %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5813   sharding={devices=[1,4,1]0,1,2,3}
5814   ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
5815     to_apply=sum, sharding={devices=[1,4]0,1,2,3}
5816 })";
5817 
5818   TF_ASSERT_OK_AND_ASSIGN(
5819       auto module,
5820       PartitionComputation(hlo_string, /*num_devices=*/4,
5821                            /*conv_halo_exchange_always_on_lhs =*/true,
5822                            /*choose_faster_windowed_einsum =*/false,
5823                            /*unroll_windowed_einsum =*/false,
5824                            /*bidirectional_windowed_einsum =*/true));
5825   VLOG(1) << module->ToString();
5826 
5827   const auto root = module->entry_computation()->root_instruction();
5828   const auto lhs = AllOf(
5829       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5830                                 op::Constant(), op::Constant())),
5831       op::Shape("f32[32,6,64,128]"));
5832   const auto rhs =
5833       AllOf(op::Reshape(op::Copy(op::DynamicSlice(
5834                 op::Pad(op::Parameter(1), op::Constant()), op::Constant(),
5835                 op::Reshape(), op::Constant(), op::Constant()))),
5836             op::Shape("f32[32,1,9824,64,128]"));
5837   auto input_subtuple =
5838       op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
5839   EXPECT_THAT(root,
5840               AllOf(op::GetTupleElement(op::GetTupleElement(op::While(
5841                         op::Tuple(lhs, rhs, input_subtuple,
5842                                   op::CollectivePermute(), op::Constant())))),
5843                     op::Shape("f32[32,6]")));
5844 
5845   const auto while_loop = root->operand(0)->operand(0);
5846   // Check loop condition.
5847   EXPECT_THAT(
5848       while_loop->while_condition()->root_instruction(),
5849       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5850 
5851   // Check loop body.
5852   const auto next_i =
5853       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5854               op::Constant());
5855   auto partial_reduce_pattern = AllOf(
5856       op::Reduce(
5857           op::Select(op::Compare(),
5858                      op::Multiply(op::Reshape(op::Slice(op::Dot(
5859                                       op::GetTupleElement(op::Parameter(0)),
5860                                       op::Concatenate()))),
5861                                   op::DynamicSlice()),
5862                      op::Broadcast()),
5863           op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
5864       op::Shape("f32[32,6]"));
5865   auto intermediate_output1 = AllOf(
5866       op::Add(partial_reduce_pattern,
5867               op::DynamicSlice(
5868                   op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5869                   op::Constant(), op::Constant())),
5870       op::Shape("f32[32,6]"));
5871   auto intermediate_output2 =
5872       AllOf(op::Add(partial_reduce_pattern,
5873                     op::DynamicSlice(intermediate_output1, op::Constant(),
5874                                      op::Constant())),
5875             op::Shape("f32[32,6]"));
5876   auto intermediate_output3 =
5877       AllOf(op::Add(partial_reduce_pattern,
5878                     op::DynamicSlice(intermediate_output2, op::Constant(),
5879                                      op::Constant())),
5880             op::Shape("f32[32,6]"));
5881   auto output_tuple =
5882       op::Tuple(op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5883                 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5884                 op::Add(partial_reduce_pattern,
5885                         op::DynamicSlice(intermediate_output3, op::Constant(),
5886                                          op::Constant())));
5887 
5888   EXPECT_THAT(while_loop->while_body()->root_instruction(),
5889               op::Tuple(op::GetTupleElement(op::Parameter(0)),
5890                         op::CollectivePermute(op::CollectivePermute(
5891                             op::GetTupleElement(op::Parameter(0)))),
5892                         output_tuple,
5893                         op::CollectivePermute(op::CollectivePermute(
5894                             op::GetTupleElement(op::Parameter(0)))),
5895                         next_i));
5896 }
5897 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedNonContractingReduce2)5898 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) {
5899   absl::string_view hlo_string = R"(
5900 HloModule module
5901 
5902 sum {
5903   a = f32[] parameter(0)
5904   b = f32[] parameter(1)
5905   ROOT add = f32[] add(a, b)
5906 }
5907 
5908 ENTRY entry {
5909   %lhs = f32[32,24,64,128] parameter(0)
5910   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5911   %rhs = f32[32,39295,64,128] parameter(1)
5912   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5913   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5914     lhs_batch_dims={0}, rhs_batch_dims={0},
5915     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5916     sharding={devices=[1,2,1]0,1}
5917   %constant = f32[] constant(0)
5918   %constant.1 = f32[] constant(2)
5919   %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5920     sharding={devices=[1,2,1]0,1}
5921   %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5922     sharding={devices=[1,2,1]0,1}
5923   ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
5924     to_apply=sum, sharding={replicated}
5925 })";
5926 
5927   TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
5928                                                             /*num_devices=*/2));
5929   VLOG(1) << module->ToString();
5930   // Involves loop code motion, skips pattern matching.
5931 }
5932 
TEST_F(SpmdPartitioningTest,UnrollEinsumRHSWindowedNonContractingReduce2)5933 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedNonContractingReduce2) {
5934   absl::string_view hlo_string = R"(
5935 HloModule module
5936 
5937 sum {
5938   a = f32[] parameter(0)
5939   b = f32[] parameter(1)
5940   ROOT add = f32[] add(a, b)
5941 }
5942 
5943 ENTRY entry {
5944   %lhs = f32[32,24,64,128] parameter(0)
5945   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5946   %rhs = f32[32,39295,64,128] parameter(1)
5947   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5948   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5949     lhs_batch_dims={0}, rhs_batch_dims={0},
5950     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5951     sharding={devices=[1,2,1]0,1}
5952   %constant = f32[] constant(0)
5953   %constant.1 = f32[] constant(2)
5954   %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5955     sharding={devices=[1,2,1]0,1}
5956   %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5957     sharding={devices=[1,2,1]0,1}
5958   ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
5959     to_apply=sum, sharding={replicated}
5960 })";
5961 
5962   TF_ASSERT_OK_AND_ASSIGN(
5963       auto module,
5964       PartitionComputation(hlo_string, /*num_devices=*/2,
5965                            /*conv_halo_exchange_always_on_lhs =*/true,
5966                            /*choose_faster_windowed_einsum =*/false,
5967                            /*unroll_windowed_einsum =*/true));
5968   VLOG(1) << module->ToString();
5969 
5970   const auto root = module->entry_computation()->root_instruction();
5971   const auto lhs = AllOf(
5972       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5973                                 op::Constant(), op::Constant())),
5974       op::Shape("f32[32,12,64,128]"));
5975   const auto rhs =
5976       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5977                                       op::Constant(), op::Reshape(),
5978                                       op::Constant(), op::Constant())),
5979             op::Shape("f32[32,19648,64,128]"));
5980   auto input_subtuple =
5981       op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
5982   EXPECT_THAT(
5983       root,
5984       AllOf(op::AllReduce(op::Slice(op::GetTupleElement(op::GetTupleElement(
5985                 op::While(op::Tuple(lhs, rhs, input_subtuple, op::Broadcast(),
5986                                     op::Constant())))))),
5987             op::Shape("f32[32,39295]")));
5988 
5989   // AllReduce<-Slice<-GetTupleElement<-GetTupleElement<-While
5990   const auto while_loop = root->operand(0)->operand(0)->operand(0)->operand(0);
5991   // Check loop condition.
5992   EXPECT_THAT(
5993       while_loop->while_condition()->root_instruction(),
5994       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5995 
5996   // Check loop body.
5997   const auto next_i =
5998       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5999               op::Constant());
6000   auto intermediate_output = AllOf(
6001       op::DynamicUpdateSlice(
6002           op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6003           op::Reduce(
6004               op::Multiply(op::Dot(op::GetTupleElement(op::Parameter(0)),
6005                                    op::CollectivePermute(
6006                                        op::GetTupleElement(op::Parameter(0)))),
6007                            op::DynamicSlice()),
6008               op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
6009           op::Constant(), op::Reshape()),
6010       op::Shape("f32[32,39296]"));
6011   auto output_tuple = op::Tuple(
6012       op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6013       op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6014       op::DynamicUpdateSlice(
6015           intermediate_output,
6016           op::Reduce(
6017               op::Multiply(op::Dot(op::GetTupleElement(op::Parameter(0)),
6018                                    op::GetTupleElement(op::Parameter(0))),
6019                            op::DynamicSlice()),
6020               op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
6021           op::Constant(), op::Reshape()));
6022 
6023   EXPECT_THAT(
6024       while_loop->while_body()->root_instruction(),
6025       op::Tuple(op::GetTupleElement(op::Parameter(0)),
6026                 op::CollectivePermute(op::CollectivePermute(
6027                     op::GetTupleElement(op::Parameter(0)))),
6028                 output_tuple, op::GetTupleElement(op::Parameter(0)), next_i));
6029 }
6030 
TEST_F(SpmdPartitioningTest,BidirectionalEinsumRHSWindowedNonContractingReduce2)6031 TEST_F(SpmdPartitioningTest,
6032        BidirectionalEinsumRHSWindowedNonContractingReduce2) {
6033   absl::string_view hlo_string = R"(
6034 HloModule module
6035 
6036 sum {
6037   a = f32[] parameter(0)
6038   b = f32[] parameter(1)
6039   ROOT add = f32[] add(a, b)
6040 }
6041 
6042 ENTRY entry {
6043   %lhs = f32[32,24,64,128] parameter(0)
6044   %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
6045   %rhs = f32[32,39295,64,128] parameter(1)
6046   %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
6047   %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
6048     lhs_batch_dims={0}, rhs_batch_dims={0},
6049     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
6050     sharding={devices=[1,4,1]0,1,2,3}
6051   %constant = f32[] constant(0)
6052   %constant.1 = f32[] constant(2)
6053   %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
6054     sharding={devices=[1,4,1]0,1,2,3}
6055   %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
6056     sharding={devices=[1,4,1]0,1,2,3}
6057   ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
6058     to_apply=sum, sharding={replicated}
6059 })";
6060 
6061   TF_ASSERT_OK_AND_ASSIGN(
6062       auto module,
6063       PartitionComputation(hlo_string, /*num_devices=*/4,
6064                            /*conv_halo_exchange_always_on_lhs =*/true,
6065                            /*choose_faster_windowed_einsum =*/false,
6066                            /*unroll_windowed_einsum =*/false,
6067                            /*bidirectional_windowed_einsum =*/true));
6068   VLOG(1) << module->ToString();
6069 
6070   const auto root = module->entry_computation()->root_instruction();
6071   const auto lhs = AllOf(
6072       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
6073                                 op::Constant(), op::Constant())),
6074       op::Shape("f32[32,6,64,128]"));
6075   const auto rhs =
6076       AllOf(op::Reshape(op::Copy(op::DynamicSlice(
6077                 op::Pad(op::Parameter(1), op::Constant()), op::Constant(),
6078                 op::Reshape(), op::Constant(), op::Constant()))),
6079             op::Shape("f32[32,1,9824,64,128]"));
6080   auto input_subtuple =
6081       op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
6082   EXPECT_THAT(
6083       root, AllOf(op::AllReduce(op::Slice(op::GetTupleElement(
6084                       op::GetTupleElement(op::While(op::Tuple(
6085                           lhs, rhs, input_subtuple, op::CollectivePermute(rhs),
6086                           op::Constant())))))),
6087                   op::Shape("f32[32,39295]")));
6088 
6089   // AllReduce<-Slice<-GetTupleElement<-GetTupleElement<-While
6090   const auto while_loop = root->operand(0)->operand(0)->operand(0)->operand(0);
6091   // Check loop condition.
6092   EXPECT_THAT(
6093       while_loop->while_condition()->root_instruction(),
6094       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
6095 
6096   // Check loop body.
6097   const auto next_i =
6098       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
6099               op::Constant());
6100   auto partial_reduce_pattern = AllOf(
6101       op::Reduce(op::Multiply(op::Reshape(op::Slice(
6102                                   op::Dot(op::GetTupleElement(op::Parameter(0)),
6103                                           op::Concatenate()))),
6104                               op::DynamicSlice(op::Broadcast(), op::Constant(),
6105                                                op::Constant(), op::Reshape())),
6106                  op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
6107       op::Shape("f32[32,9824]"));
6108   auto intermediate_output1 =
6109       AllOf(op::DynamicUpdateSlice(
6110                 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6111                 partial_reduce_pattern, op::Constant(), op::Reshape()),
6112             op::Shape("f32[32,39296]"));
6113   auto intermediate_output2 =
6114       AllOf(op::DynamicUpdateSlice(intermediate_output1, partial_reduce_pattern,
6115                                    op::Constant(), op::Reshape()),
6116             op::Shape("f32[32,39296]"));
6117   auto intermediate_output3 =
6118       AllOf(op::DynamicUpdateSlice(intermediate_output2, partial_reduce_pattern,
6119                                    op::Constant(), op::Reshape()),
6120             op::Shape("f32[32,39296]"));
6121   auto output_tuple = op::Tuple(
6122       op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6123       op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6124       op::DynamicUpdateSlice(intermediate_output3, partial_reduce_pattern,
6125                              op::Constant(), op::Reshape()));
6126 
6127   EXPECT_THAT(while_loop->while_body()->root_instruction(),
6128               op::Tuple(op::GetTupleElement(op::Parameter(0)),
6129                         op::CollectivePermute(op::CollectivePermute(
6130                             op::GetTupleElement(op::Parameter(0)))),
6131                         output_tuple,
6132                         op::CollectivePermute(op::CollectivePermute(
6133                             op::GetTupleElement(op::Parameter(0)))),
6134                         next_i));
6135 }
6136 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedContractingFromBroadcast)6137 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) {
6138   absl::string_view hlo_string = R"(
6139 HloModule module
6140 
6141 ENTRY entry {
6142   %rhs = f32[32,39296,63,128] parameter(0)
6143   %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
6144   %constant.1 = f32[] constant(2)
6145   %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
6146     sharding={devices=[1,2,1,1]0,1}
6147   %add = f32[32,24,63,128] add(%broadcast, %broadcast),
6148     sharding={devices=[1,2,1,1]0,1}
6149   ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
6150     lhs_batch_dims={0}, rhs_batch_dims={0},
6151     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
6152     sharding={devices=[1,2,1]0,1}
6153 })";
6154 
6155   TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
6156                                                             /*num_devices=*/2));
6157   VLOG(1) << module->ToString();
6158   // Involves loop code motion, skips pattern matching.
6159 }
6160 
TEST_F(SpmdPartitioningTest,UnrollEinsumRHSWindowedContractingFromBroadcast)6161 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedContractingFromBroadcast) {
6162   absl::string_view hlo_string = R"(
6163 HloModule module
6164 
6165 ENTRY entry {
6166   %rhs = f32[32,39296,63,128] parameter(0)
6167   %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
6168   %constant.1 = f32[] constant(2)
6169   %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
6170     sharding={devices=[1,2,1,1]0,1}
6171   %add = f32[32,24,63,128] add(%broadcast, %broadcast),
6172     sharding={devices=[1,2,1,1]0,1}
6173   ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
6174     lhs_batch_dims={0}, rhs_batch_dims={0},
6175     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
6176     sharding={devices=[1,2,1]0,1}
6177 })";
6178 
6179   TF_ASSERT_OK_AND_ASSIGN(
6180       auto module,
6181       PartitionComputation(hlo_string, /*num_devices=*/2,
6182                            /*conv_halo_exchange_always_on_lhs =*/true,
6183                            /*choose_faster_windowed_einsum =*/false,
6184                            /*unroll_windowed_einsum =*/true));
6185   VLOG(1) << module->ToString();
6186 
6187   const auto root = module->entry_computation()->root_instruction();
6188   const auto lhs = op::Tuple(op::Constant());
6189   const auto rhs =
6190       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
6191                                       op::Constant(), op::Constant(),
6192                                       op::Reshape(), op::Constant())),
6193             op::Shape("f32[32,39296,32,128]"));
6194   auto masked_rhs =
6195       op::Select(op::Compare(), rhs, op::Broadcast(op::Constant()));
6196   EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(
6197                               op::Tuple(lhs, masked_rhs, op::Broadcast(),
6198                                         op::Broadcast(), op::Constant()))),
6199                           op::Shape("f32[32,12,39296]")));
6200 
6201   const auto while_loop = root->operand(0);
6202   // Check loop condition.
6203   EXPECT_THAT(
6204       while_loop->while_condition()->root_instruction(),
6205       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
6206 
6207   // Check loop body.
6208   const auto next_i =
6209       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
6210               op::Constant());
6211   auto padded_broadcast_sum = op::Pad(
6212       op::Add(op::Broadcast(
6213                   op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
6214               op::Broadcast(
6215                   op::GetTupleElement(op::GetTupleElement(op::Parameter(0))))),
6216       op::Constant());
6217   auto intermediate_output =
6218       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
6219                     op::Dot(op::DynamicSlice(padded_broadcast_sum,
6220                                              op::Constant(), op::Constant(),
6221                                              op::Reshape(), op::Constant()),
6222                             op::GetTupleElement(op::Parameter(0)))),
6223             op::Shape("f32[32,12,39296]"));
6224   auto output = AllOf(
6225       op::Add(
6226           intermediate_output,
6227           op::Dot(
6228               op::DynamicSlice(padded_broadcast_sum, op::Constant(),
6229                                op::Constant(), op::Reshape(), op::Constant()),
6230               op::CollectivePermute(op::GetTupleElement(op::Parameter(0))))),
6231       op::Shape("f32[32,12,39296]"));
6232 
6233   EXPECT_THAT(while_loop->while_body()->root_instruction(),
6234               op::Tuple(op::GetTupleElement(op::Parameter(0)),
6235                         op::CollectivePermute(op::CollectivePermute(
6236                             op::GetTupleElement(op::Parameter(0)))),
6237                         output, op::GetTupleElement(op::Parameter(0)), next_i));
6238 }
6239 
TEST_F(SpmdPartitioningTest,BidirectionalEinsumRHSWindowedContractingFromBroadcast)6240 TEST_F(SpmdPartitioningTest,
6241        BidirectionalEinsumRHSWindowedContractingFromBroadcast) {
6242   absl::string_view hlo_string = R"(
6243 HloModule module
6244 
6245 ENTRY entry {
6246   %rhs = f32[32,39296,63,128] parameter(0)
6247   %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,4,1]0,1,2,3}
6248   %constant.1 = f32[] constant(2)
6249   %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
6250     sharding={devices=[1,4,1,1]0,1,2,3}
6251   %add = f32[32,24,63,128] add(%broadcast, %broadcast),
6252     sharding={devices=[1,4,1,1]0,1,2,3}
6253   ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
6254     lhs_batch_dims={0}, rhs_batch_dims={0},
6255     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
6256     sharding={devices=[1,4,1]0,1,2,3}
6257 })";
6258 
6259   TF_ASSERT_OK_AND_ASSIGN(
6260       auto module,
6261       PartitionComputation(hlo_string, /*num_devices=*/4,
6262                            /*conv_halo_exchange_always_on_lhs =*/true,
6263                            /*choose_faster_windowed_einsum =*/false,
6264                            /*unroll_windowed_einsum =*/false,
6265                            /*bidirectional_windowed_einsum =*/true));
6266   VLOG(1) << module->ToString();
6267 
6268   auto input_subtuple =
6269       op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
6270 
6271   const auto root = module->entry_computation()->root_instruction();
6272   const auto lhs = op::Tuple(op::Constant());
6273   const auto rhs =
6274       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
6275                                       op::Constant(), op::Constant(),
6276                                       op::Reshape(), op::Constant())),
6277             op::Shape("f32[32,39296,16,128]"));
6278   auto masked_rhs = op::Reshape(
6279       op::Select(op::Compare(), rhs, op::Broadcast(op::Constant())));
6280   EXPECT_THAT(root,
6281               AllOf(op::GetTupleElement(op::While(op::Tuple(
6282                         lhs, masked_rhs, op::Broadcast(),
6283                         op::CollectivePermute(masked_rhs), op::Constant()))),
6284                     op::Shape("f32[32,6,39296]")));
6285 
6286   const auto while_loop = root->operand(0);
6287   // Check loop condition.
6288   EXPECT_THAT(
6289       while_loop->while_condition()->root_instruction(),
6290       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
6291 
6292   // Check loop body.
6293   const auto next_i =
6294       op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
6295               op::Constant());
6296   auto output =
6297       AllOf(op::Add(op::Add(op::GetTupleElement(op::Parameter(0)),
6298                             op::Dot(op::Maximum(), op::Concatenate())),
6299                     op::Dot(op::Maximum(), op::Concatenate())),
6300             op::Shape("f32[32,6,39296]"));
6301 
6302   EXPECT_THAT(while_loop->while_body()->root_instruction(),
6303               op::Tuple(op::GetTupleElement(op::Parameter(0)),
6304                         op::CollectivePermute(op::CollectivePermute(
6305                             op::GetTupleElement(op::Parameter(0)))),
6306                         output,
6307                         op::CollectivePermute(op::CollectivePermute(
6308                             op::GetTupleElement(op::Parameter(0)))),
6309                         next_i));
6310 }
6311 
TEST_F(SpmdPartitioningTest,EinsumNonContractingDimPartitionOnTwoDims)6312 TEST_F(SpmdPartitioningTest, EinsumNonContractingDimPartitionOnTwoDims) {
6313   absl::string_view hlo_string = R"(
6314 HloModule module
6315 
6316 ENTRY entry {
6317   %lhs = bf16[8,1024,2,1536] parameter(0)
6318   %lhs.copy = bf16[8,1024,2,1536] copy(lhs),
6319     sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
6320   %rhs = bf16[2,1536,512,1] parameter(1)
6321   %rhs.copy = bf16[2,1536,512,1] copy(rhs),
6322     sharding={devices=[2,1,2,1,2]0,4,2,6,1,5,3,7 last_tile_dim_replicate}
6323   ROOT %convolution = bf16[8,1024,512,1] convolution(lhs.copy, rhs.copy),
6324     window={size=1x2}, dim_labels=0b1f_1io0->0bf1,
6325     sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
6326 })";
6327 
6328   TF_ASSERT_OK_AND_ASSIGN(auto module,
6329                           PartitionComputation(hlo_string, /*num_devices=*/8));
6330   VLOG(1) << module->ToString();
6331   const auto root = module->entry_computation()->root_instruction();
6332   const auto lhs = AllOf(
6333       op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant(),
6334                                 op::Reshape(), op::Constant())),
6335       op::Shape("bf16[2,1024,1,1536]"));
6336   const auto rhs = AllOf(
6337       op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
6338                                 op::Reshape(), op::Constant())),
6339       op::Shape("bf16[1,1536,256,1]"));
6340 
6341   const auto partial_replicate_rhs =
6342       AllOf(op::AllReduce(op::DynamicUpdateSlice(
6343                 op::Broadcast(), rhs, op::Constant(), op::Constant(),
6344                 op::Reshape(), op::Constant())),
6345             op::Shape("bf16[1,1536,512,1]"));
6346   EXPECT_THAT(
6347       root,
6348       AllOf(op::DynamicSlice(
6349                 op::AllReduce(op::Convolution(lhs, partial_replicate_rhs)),
6350                 op::Constant(), op::Constant(), op::Reshape(), op::Constant()),
6351             op::Shape("bf16[2,1024,256,1]")));
6352 }
6353 
TEST_F(SpmdPartitioningTest,EinsumNonContractingDimPartitionOnTwoDims2)6354 TEST_F(SpmdPartitioningTest, EinsumNonContractingDimPartitionOnTwoDims2) {
6355   absl::string_view hlo_string = R"(
6356 HloModule module
6357 
6358 ENTRY entry {
6359   %lhs = bf16[8,1024,2,1536] parameter(0)
6360   %lhs.copy = bf16[8,1024,2,1536] copy(lhs),
6361     sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
6362   %rhs = bf16[2,1536,512,1] parameter(1)
6363   %rhs.copy = bf16[2,1536,512,1] copy(rhs),
6364     sharding={devices=[2,1,2,1,2]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
6365   ROOT %convolution = bf16[8,1024,512,1] convolution(lhs.copy, rhs.copy),
6366     window={size=1x2}, dim_labels=0b1f_1io0->0bf1,
6367     sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
6368 })";
6369 
6370   TF_ASSERT_OK_AND_ASSIGN(auto module,
6371                           PartitionComputation(hlo_string, /*num_devices=*/8));
6372   VLOG(1) << module->ToString();
6373   const auto root = module->entry_computation()->root_instruction();
6374   const auto lhs = AllOf(
6375       op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant(),
6376                                 op::Reshape(), op::Constant())),
6377       op::Shape("bf16[2,1024,1,1536]"));
6378   const auto rhs = AllOf(
6379       op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
6380                                 op::Reshape(), op::Constant())),
6381       op::Shape("bf16[1,1536,256,1]"));
6382 
6383   const auto partial_replicate_rhs =
6384       AllOf(op::AllReduce(op::DynamicUpdateSlice(
6385                 op::Broadcast(), rhs, op::Constant(), op::Constant(),
6386                 op::Reshape(), op::Constant())),
6387             op::Shape("bf16[1,1536,512,1]"));
6388   EXPECT_THAT(
6389       root,
6390       AllOf(op::DynamicSlice(
6391                 op::AllReduce(op::Convolution(lhs, partial_replicate_rhs)),
6392                 op::Constant(), op::Constant(), op::Reshape(), op::Constant()),
6393             op::Shape("bf16[2,1024,256,1]")));
6394 }
6395 
TEST_F(SpmdPartitioningTest,ReplicatedRng)6396 TEST_F(SpmdPartitioningTest, ReplicatedRng) {
6397   absl::string_view hlo_string = R"(
6398 HloModule module
6399 
6400 ENTRY entry {
6401   %lhs = s32[] parameter(0)
6402   %lhs.copy = s32[] copy(%lhs), sharding={replicated}
6403   %rhs = s32[] parameter(1)
6404   %rhs.copy = s32[] copy(%rhs), sharding={replicated}
6405   ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
6406       distribution=rng_uniform, sharding={replicated}
6407 })";
6408 
6409   TF_ASSERT_OK_AND_ASSIGN(auto module,
6410                           PartitionComputation(hlo_string, /*num_devices=*/2));
6411   VLOG(1) << module->ToString();
6412 
6413   const auto root = module->entry_computation()->root_instruction();
6414   const auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]"));
6415   const auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("s32[]"));
6416   EXPECT_THAT(
6417       root,
6418       AllOf(op::AllReduce(op::Select(
6419                 op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
6420                 op::Rng(), op::Broadcast(op::Constant()))),
6421             op::Shape("s32[4]")));
6422 }
6423 
TEST_F(SpmdPartitioningTest,ManualRng)6424 TEST_F(SpmdPartitioningTest, ManualRng) {
6425   absl::string_view hlo_string = R"(
6426 HloModule module
6427 
6428 ENTRY entry {
6429   %lhs = s32[] parameter(0), sharding={manual}
6430   %rhs = s32[] parameter(1), sharding={manual}
6431   ROOT %rng = s32[4]{0} rng(%lhs, %rhs),
6432       distribution=rng_uniform, sharding={manual}
6433 })";
6434 
6435   TF_ASSERT_OK_AND_ASSIGN(auto module,
6436                           PartitionComputation(hlo_string, /*num_devices=*/2));
6437   VLOG(1) << module->ToString();
6438 
6439   const auto root = module->entry_computation()->root_instruction();
6440   EXPECT_THAT(root, AllOf(op::Rng(op::Parameter(0), op::Parameter(1)),
6441                           op::Shape("s32[4]")));
6442 }
6443 
TEST_F(SpmdPartitioningTest,PartitionedRng)6444 TEST_F(SpmdPartitioningTest, PartitionedRng) {
6445   absl::string_view hlo_string = R"(
6446 HloModule module
6447 
6448 ENTRY entry {
6449   %lhs = s32[] parameter(0)
6450   %lhs.copy = s32[] copy(%lhs), sharding={replicated}
6451   %rhs = s32[] parameter(1)
6452   %rhs.copy = s32[] copy(%rhs), sharding={maximal device=1}
6453   ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
6454       distribution=rng_uniform, sharding={devices=[2]0,1}
6455 })";
6456 
6457   TF_ASSERT_OK_AND_ASSIGN(auto module,
6458                           PartitionComputation(hlo_string, /*num_devices=*/2));
6459   VLOG(1) << module->ToString();
6460 
6461   const auto root = module->entry_computation()->root_instruction();
6462   const auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]"));
6463   const auto rhs =
6464       AllOf(op::Copy(op::Copy(op::Parameter(1))), op::Shape("s32[]"));
6465   EXPECT_THAT(root, AllOf(op::Rng(lhs, op::AllReduce(op::Select(
6466                                            op::Broadcast(op::Compare()), rhs,
6467                                            op::Broadcast(op::Constant())))),
6468                           op::Shape("s32[2]")));
6469 }
6470 
TEST_F(SpmdPartitioningTest,PartialReplicatedRng)6471 TEST_F(SpmdPartitioningTest, PartialReplicatedRng) {
6472   absl::string_view hlo_string = R"(
6473 HloModule module
6474 
6475 ENTRY entry {
6476   %lhs = s32[] parameter(0), sharding={replicated}
6477   %rhs = s32[] parameter(1), sharding={replicated}
6478   ROOT %rng = s32[8]{0} rng(%lhs, %rhs),
6479       distribution=rng_uniform,
6480       sharding={devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6481 })";
6482 
6483   TF_ASSERT_OK_AND_ASSIGN(auto module,
6484                           PartitionComputation(hlo_string, /*num_devices=*/8));
6485   VLOG(1) << module->ToString();
6486 
6487   const auto root = module->entry_computation()->root_instruction();
6488   const auto lhs = AllOf(op::Parameter(0), op::Shape("s32[]"));
6489   const auto rhs = AllOf(op::Parameter(1), op::Shape("s32[]"));
6490   auto partition_id =
6491       AllOf(op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
6492             op::Shape("u32[]"));
6493   EXPECT_THAT(
6494       root, AllOf(op::AllReduce(op::Select(
6495                       op::Broadcast(op::Compare(partition_id, op::Constant())),
6496                       op::Rng(lhs, rhs), op::Broadcast(op::Constant()))),
6497                   op::Shape("s32[4]")));
6498 }
6499 
TEST_F(SpmdPartitioningTest,DynamicSliceAlongNonPartitionedDimension)6500 TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) {
6501   absl::string_view hlo_string = R"(
6502 HloModule module
6503 
6504 ENTRY entry {
6505   %input = s32[128,64] parameter(0), sharding={devices=[2,1]0,1}
6506   %index = s32[] parameter(1)
6507   %trivial_index = s32[] parameter(2)
6508   ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input, %trivial_index, %index),
6509     dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1}
6510 })";
6511 
6512   TF_ASSERT_OK_AND_ASSIGN(auto module,
6513                           PartitionComputation(hlo_string, /*num_devices=*/2));
6514   VLOG(1) << module->ToString();
6515 
6516   const auto root = module->entry_computation()->root_instruction();
6517   auto input = AllOf(op::Parameter(0), op::Shape("s32[64,64]"));
6518   EXPECT_THAT(root,
6519               AllOf(op::DynamicSlice(input, op::Constant(), op::Parameter(1)),
6520                     op::Shape("s32[64,2]")));
6521 }
6522 
TEST_F(SpmdPartitioningTest,DynamicUpdateSliceAlongNonPartitionedDimension)6523 TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) {
6524   absl::string_view hlo_string = R"(
6525 HloModule module
6526 
6527 ENTRY entry {
6528   %input = s32[128,64] parameter(0), sharding={devices=[2,1]0,1}
6529   %index = s32[] parameter(1)
6530   %update = s32[128,2] parameter(2)
6531   %trivial_index = s32[] parameter(3)
6532   %update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1}
6533   ROOT %dynamic-update-slice = s32[128,64]
6534     dynamic-update-slice(%input, %update.copy, %trivial_index, %index),
6535     sharding={devices=[2,1]0,1}
6536 })";
6537 
6538   TF_ASSERT_OK_AND_ASSIGN(auto module,
6539                           PartitionComputation(hlo_string, /*num_devices=*/2));
6540   VLOG(1) << module->ToString();
6541 
6542   const auto root = module->entry_computation()->root_instruction();
6543   auto input = AllOf(op::Parameter(0), op::Shape("s32[64,64]"));
6544   auto update = AllOf(op::Copy(op::DynamicSlice(op::Parameter(2), op::Reshape(),
6545                                                 op::Constant())),
6546                       op::Shape("s32[64,2]"));
6547   EXPECT_THAT(root, AllOf(op::DynamicUpdateSlice(input, update, op::Constant(),
6548                                                  op::Parameter(1)),
6549                           op::Shape("s32[64,64]")));
6550 }
6551 
TEST_F(SpmdPartitioningTest,DynamicUpdateSliceAlongPartitionedDimension)6552 TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension) {
6553   absl::string_view hlo_string = R"(
6554 HloModule module
6555 
6556 ENTRY entry {
6557   %input = s32[128,64] parameter(0), sharding={devices=[1,2]0,1}
6558   %index = s32[] parameter(1)
6559   %constant = s32[] constant(60)
6560   %update = s32[128,2] parameter(2), sharding={devices=[1,2]0,1}
6561   ROOT %dynamic-update-slice = s32[128,64]
6562     dynamic-update-slice(%input, %update, %index, %constant),
6563     sharding={devices=[1,2]0,1}
6564 })";
6565 
6566   TF_ASSERT_OK_AND_ASSIGN(auto module,
6567                           PartitionComputation(hlo_string, /*num_devices=*/2));
6568   VLOG(1) << module->ToString();
6569 
6570   const auto root = module->entry_computation()->root_instruction();
6571   auto input = AllOf(op::Parameter(0), op::Shape("s32[128,32]"));
6572   auto update = AllOf(
6573       op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), op::Parameter(2),
6574                                            op::Constant(), op::Reshape())),
6575       op::Shape("s32[128,2]"));
6576 
6577   EXPECT_THAT(root,
6578               AllOf(op::Select(op::Broadcast(),
6579                                op::DynamicUpdateSlice(
6580                                    input, update, op::Constant(), op::Select()),
6581                                input),
6582                     op::Shape("s32[128,32]")));
6583 }
6584 
TEST_F(SpmdPartitioningTest,DynamicUpdateSliceAlongPartitionedDimension2)6585 TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension2) {
6586   absl::string_view hlo_string = R"(
6587 HloModule module
6588 
6589 ENTRY entry {
6590   %input = s32[8,790,2] parameter(0),
6591     sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
6592   %index = s32[] parameter(1)
6593   %constant = s32[] constant(0)
6594   %update = s32[1,790,2] parameter(2),
6595     sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
6596   ROOT %dynamic-update-slice = s32[8,790,2]
6597     dynamic-update-slice(%input, %update, %index, %constant, %constant),
6598     sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
6599 })";
6600 
6601   TF_ASSERT_OK_AND_ASSIGN(auto module,
6602                           PartitionComputation(hlo_string, /*num_devices=*/8));
6603   VLOG(1) << module->ToString();
6604 
6605   const auto root = module->entry_computation()->root_instruction();
6606   auto input = AllOf(op::Parameter(0), op::Shape("s32[1,790,2]"));
6607   auto update = AllOf(op::AllReduce(op::Select(
6608                           op::Broadcast(), op::Parameter(2), op::Broadcast())),
6609                       op::Shape("s32[1,790,2]"));
6610   EXPECT_THAT(
6611       root,
6612       AllOf(op::Select(op::Broadcast(),
6613                        op::DynamicUpdateSlice(input, update, op::Select(),
6614                                               op::Constant(), op::Constant()),
6615                        input),
6616             op::Shape("s32[1,790,2]")));
6617 }
6618 
TEST_F(SpmdPartitioningTest,DynamicUpdateSlicePartitionSliceAndNonSliceDims)6619 TEST_F(SpmdPartitioningTest, DynamicUpdateSlicePartitionSliceAndNonSliceDims) {
6620   absl::string_view hlo_string = R"(
6621 HloModule module
6622 
6623 ENTRY entry {
6624   %input = s32[128,64] parameter(0)
6625   %input.copy = s32[128,64] copy(%input), sharding={devices=[2,2]0,1,2,3}
6626   %constant.0 = s32[] constant(0)
6627   %constant.1 = s32[] constant(60)
6628   %update = s32[128,2] parameter(1)
6629   %update.copy = s32[128,2] copy(%update), sharding={devices=[2,2]0,1,2,3}
6630   ROOT %dynamic-update-slice = s32[128,64]
6631     dynamic-update-slice(%input.copy, %update.copy, %constant.0, %constant.1),
6632     sharding={devices=[2,2]0,1,2,3}
6633 })";
6634 
6635   TF_ASSERT_OK_AND_ASSIGN(auto module,
6636                           PartitionComputation(hlo_string, /*num_devices=*/4));
6637   VLOG(1) << module->ToString();
6638 
6639   const auto root = module->entry_computation()->root_instruction();
6640   auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
6641                                                op::Reshape())),
6642                      op::Shape("s32[64,32]"));
6643   auto update = AllOf(op::AllReduce(op::DynamicUpdateSlice(
6644                           op::Broadcast(),
6645                           op::Copy(op::DynamicSlice(
6646                               op::Parameter(1), op::Reshape(), op::Reshape())),
6647                           op::Constant(), op::Reshape())),
6648                       op::Shape("s32[64,2]"));
6649 
6650   EXPECT_THAT(root,
6651               AllOf(op::Select(op::Broadcast(),
6652                                op::DynamicUpdateSlice(
6653                                    input, update, op::Constant(), op::Select()),
6654                                input),
6655                     op::Shape("s32[64,32]")));
6656 }
6657 
TEST_F(SpmdPartitioningTest,PassthroughGather)6658 TEST_F(SpmdPartitioningTest, PassthroughGather) {
6659   absl::string_view hlo_string = R"(
6660 HloModule module
6661 
6662 ENTRY entry {
6663   %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
6664   %indices = s32[3] parameter(1), sharding={replicated}
6665   ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
6666     collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
6667     slice_sizes={1,9}, sharding={devices=[1,2]0,1}
6668 })";
6669   TF_ASSERT_OK_AND_ASSIGN(auto module,
6670                           PartitionComputation(hlo_string, /*num_devices=*/2));
6671   VLOG(1) << module->ToString();
6672   HloInstruction* root = module->entry_computation()->root_instruction();
6673   EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
6674                           op::Shape("f32[3,5]")));
6675 }
6676 
TEST_F(SpmdPartitioningTest,PassthroughGather_PartialReplicate)6677 TEST_F(SpmdPartitioningTest, PassthroughGather_PartialReplicate) {
6678   absl::string_view hlo_string = R"(
6679 HloModule module
6680 
6681 ENTRY entry {
6682   %input = f32[2,9] parameter(0),
6683     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6684   %indices = s32[3] parameter(1), sharding={replicated}
6685   ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
6686     collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
6687     slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6688 })";
6689   TF_ASSERT_OK_AND_ASSIGN(auto module,
6690                           PartitionComputation(hlo_string, /*num_devices=*/4));
6691   VLOG(1) << module->ToString();
6692   HloInstruction* root = module->entry_computation()->root_instruction();
6693   EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
6694                           op::Shape("f32[3,5]")));
6695 }
6696 
TEST_F(SpmdPartitioningTest,IndexPassthroughGather)6697 TEST_F(SpmdPartitioningTest, IndexPassthroughGather) {
6698   absl::string_view hlo_string = R"(
6699 HloModule module
6700 
6701 ENTRY entry {
6702   %input = f32[2,9,8] parameter(0), sharding={replicated}
6703   %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
6704   ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
6705     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
6706     slice_sizes={1,1,8}, sharding={devices=[1,2,2]0,1,2,3}
6707 })";
6708   TF_ASSERT_OK_AND_ASSIGN(auto module,
6709                           PartitionComputation(hlo_string, /*num_devices=*/4));
6710   VLOG(1) << module->ToString();
6711   HloInstruction* root = module->entry_computation()->root_instruction();
6712   EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
6713                           op::Shape("f32[8,2,2]")));
6714 }
6715 
TEST_F(SpmdPartitioningTest,IndexPassthroughGather_PartialReplicate)6716 TEST_F(SpmdPartitioningTest, IndexPassthroughGather_PartialReplicate) {
6717   absl::string_view hlo_string = R"(
6718 HloModule module
6719 
6720 ENTRY entry {
6721   %input = f32[2,9,8] parameter(0), sharding={replicated}
6722   %indices = s32[4,2,4] parameter(1),
6723     sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6724   ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
6725     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
6726     slice_sizes={1,1,8},
6727     sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6728 })";
6729   TF_ASSERT_OK_AND_ASSIGN(auto module,
6730                           PartitionComputation(hlo_string, /*num_devices=*/8));
6731   VLOG(1) << module->ToString();
6732   HloInstruction* root = module->entry_computation()->root_instruction();
6733   EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
6734                           op::Shape("f32[8,2,2]")));
6735 }
6736 
TEST_F(SpmdPartitioningTest,GatherPartitionedOnTrivialSliceDims)6737 TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) {
6738   absl::string_view hlo_string = R"(
6739 HloModule module
6740 
6741 ENTRY entry {
6742   %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
6743   %indices = s32[2,3] parameter(1), sharding={replicated}
6744   ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
6745     collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
6746     slice_sizes={1,9}, sharding={replicated}
6747 })";
6748   TF_ASSERT_OK_AND_ASSIGN(auto module,
6749                           PartitionComputation(hlo_string, /*num_devices=*/2));
6750   VLOG(1) << module->ToString();
6751   auto offset =
6752       op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
6753   auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"));
6754   auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())),
6755                    op::Shape("s32[2,3]"));
6756   auto clamp = op::Clamp(min, op::Parameter(1), max);
6757   auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min));
6758   auto mask =
6759       op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max));
6760   auto masked =
6761       op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather);
6762   HloInstruction* root = module->entry_computation()->root_instruction();
6763   EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]")));
6764 }
6765 
TEST_F(SpmdPartitioningTest,GatherPartitionedOnTrivialSliceDims_PartialReplicate)6766 TEST_F(SpmdPartitioningTest,
6767        GatherPartitionedOnTrivialSliceDims_PartialReplicate) {
6768   absl::string_view hlo_string = R"(
6769 HloModule module
6770 
6771 ENTRY entry {
6772   %input = f32[17,9] parameter(0),
6773     sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
6774   %indices = s32[2,3] parameter(1), sharding={replicated}
6775   ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
6776     collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
6777     slice_sizes={1,9}, sharding={replicated}
6778 })";
6779   TF_ASSERT_OK_AND_ASSIGN(auto module,
6780                           PartitionComputation(hlo_string, /*num_devices=*/4));
6781   VLOG(1) << module->ToString();
6782   auto offset =
6783       op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
6784   auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"));
6785   auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())),
6786                    op::Shape("s32[2,3]"));
6787   auto clamp = op::Clamp(min, op::Parameter(1), max);
6788   auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min));
6789   auto mask =
6790       op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max));
6791   auto masked =
6792       op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather);
6793   HloInstruction* root = module->entry_computation()->root_instruction();
6794   EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]")));
6795 }
6796 
TEST_F(SpmdPartitioningTest,PassthroughScatter)6797 TEST_F(SpmdPartitioningTest, PassthroughScatter) {
6798   absl::string_view hlo_string = R"(
6799 HloModule module
6800 
6801 add (lhs: f32[], rhs: f32[]) -> f32[] {
6802   lhs = f32[] parameter(0)
6803   rhs = f32[] parameter(1)
6804   ROOT sum = f32[] add(lhs, rhs)
6805 }
6806 
6807 ENTRY entry {
6808   %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
6809   %indices = s32[3] parameter(1), sharding={replicated}
6810   %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1}
6811   ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
6812       to_apply=add,
6813       update_window_dims={1},
6814       inserted_window_dims={0},
6815       scatter_dims_to_operand_dims={0},
6816       index_vector_dim=1, sharding={devices=[1,2]0,1}
6817 })";
6818   TF_ASSERT_OK_AND_ASSIGN(auto module,
6819                           PartitionComputation(hlo_string, /*num_devices=*/2));
6820   VLOG(1) << module->ToString();
6821   HloInstruction* root = module->entry_computation()->root_instruction();
6822   EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
6823                                       op::Parameter(2)),
6824                           op::Shape("f32[2,5]")));
6825 }
6826 
TEST_F(SpmdPartitioningTest,PassthroughScatterVariadic)6827 TEST_F(SpmdPartitioningTest, PassthroughScatterVariadic) {
6828   absl::string_view hlo_string = R"(
6829 HloModule module
6830 
6831 add_min_max {
6832   lhs0 = f32[] parameter(0)
6833   lhs1 = f32[] parameter(1)
6834   rhs0 = f32[] parameter(2)
6835   rhs1 = f32[] parameter(3)
6836   min = minimum(rhs0, rhs1)
6837   max = maximum(rhs0, rhs1)
6838   min_sum = add(lhs0, min)
6839   max_sum = add(lhs1, max)
6840   ROOT tuple = tuple(min_sum, max_sum)
6841 }
6842 
6843 ENTRY entry {
6844   %input0 = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
6845   %input1 = f32[2,9] parameter(1), sharding={devices=[1,2]0,1}
6846   %indices = s32[3] parameter(2), sharding={replicated}
6847   %updates0 = f32[3,9] parameter(3), sharding={devices=[1,2]0,1}
6848   %updates1 = f32[3,9] parameter(4), sharding={devices=[1,2]0,1}
6849   ROOT %scatter = (f32[2,9], f32[2,9])
6850     scatter(%input0, %input1, %indices, %updates0, %updates1),
6851       to_apply=add_min_max, update_window_dims={1}, inserted_window_dims={0},
6852       scatter_dims_to_operand_dims={0}, index_vector_dim=1,
6853       sharding={devices=[1,2]0,1}
6854 })";
6855   TF_ASSERT_OK_AND_ASSIGN(auto module,
6856                           PartitionComputation(hlo_string, /*num_devices=*/2));
6857   HloInstruction* root = module->entry_computation()->root_instruction();
6858   EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
6859                                       op::Parameter(2), op::Parameter(3),
6860                                       op::Parameter(4)),
6861                           op::Shape("(f32[2,5], f32[2,5])")));
6862 }
6863 
TEST_F(SpmdPartitioningTest,PassthroughScatter_PartialReplicate)6864 TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) {
6865   absl::string_view hlo_string = R"(
6866 HloModule module
6867 
6868 add (lhs: f32[], rhs: f32[]) -> f32[] {
6869   lhs = f32[] parameter(0)
6870   rhs = f32[] parameter(1)
6871   ROOT sum = f32[] add(lhs, rhs)
6872 }
6873 
6874 ENTRY entry {
6875   %input = f32[2,9] parameter(0),
6876     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6877   %indices = s32[3] parameter(1), sharding={replicated}
6878   %updates = f32[3,9] parameter(2),
6879     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6880   ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
6881       to_apply=add,
6882       update_window_dims={1},
6883       inserted_window_dims={0},
6884       scatter_dims_to_operand_dims={0},
6885       index_vector_dim=1,
6886       sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6887 })";
6888   TF_ASSERT_OK_AND_ASSIGN(auto module,
6889                           PartitionComputation(hlo_string, /*num_devices=*/4));
6890   VLOG(1) << module->ToString();
6891   HloInstruction* root = module->entry_computation()->root_instruction();
6892   EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
6893                                       op::Parameter(2)),
6894                           op::Shape("f32[2,5]")));
6895 }
6896 
TEST_F(SpmdPartitioningTest,PassthroughScatterVariadic_PartialReplicate)6897 TEST_F(SpmdPartitioningTest, PassthroughScatterVariadic_PartialReplicate) {
6898   absl::string_view hlo_string = R"(
6899 HloModule module
6900 
6901 add_min_max {
6902   lhs0 = f32[] parameter(0)
6903   lhs1 = f32[] parameter(1)
6904   rhs0 = f32[] parameter(2)
6905   rhs1 = f32[] parameter(3)
6906   min = minimum(rhs0, rhs1)
6907   max = maximum(rhs0, rhs1)
6908   min_sum = add(lhs0, min)
6909   max_sum = add(lhs1, max)
6910   ROOT tuple = tuple(min_sum, max_sum)
6911 }
6912 
6913 ENTRY entry {
6914   %input0 = f32[2,9] parameter(0),
6915     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6916   %input1 = f32[2,9] parameter(1),
6917     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6918   %indices = s32[3] parameter(2), sharding={replicated}
6919   %updates0 = f32[3,9] parameter(3),
6920     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6921   %updates1 = f32[3,9] parameter(4),
6922     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6923   ROOT %scatter = (f32[2,9], f32[2,9])
6924     scatter(%input0, %input1, %indices, %updates0, %updates1),
6925       to_apply=add_min_max, update_window_dims={1}, inserted_window_dims={0},
6926       scatter_dims_to_operand_dims={0}, index_vector_dim=1,
6927       sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6928 })";
6929   TF_ASSERT_OK_AND_ASSIGN(auto module,
6930                           PartitionComputation(hlo_string, /*num_devices=*/4));
6931   HloInstruction* root = module->entry_computation()->root_instruction();
6932   EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
6933                                       op::Parameter(2), op::Parameter(3),
6934                                       op::Parameter(4)),
6935                           op::Shape("(f32[2,5], f32[2,5])")));
6936 }
6937 
TEST_F(SpmdPartitioningTest,IndexPassthroughScatter)6938 TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) {
6939   absl::string_view hlo_string = R"(
6940 HloModule module
6941 
6942 add (lhs: f32[], rhs: f32[]) -> f32[] {
6943   lhs = f32[] parameter(0)
6944   rhs = f32[] parameter(1)
6945   ROOT sum = f32[] add(lhs, rhs)
6946 }
6947 
6948 ENTRY entry {
6949   %input = f32[2,9,8] parameter(0), sharding={replicated}
6950   %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
6951   %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3}
6952   ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
6953       to_apply=add,
6954       update_window_dims={2},
6955       inserted_window_dims={0,1},
6956       scatter_dims_to_operand_dims={0,1},
6957       index_vector_dim=1, sharding={replicated}
6958 })";
6959   TF_ASSERT_OK_AND_ASSIGN(auto module,
6960                           PartitionComputation(hlo_string, /*num_devices=*/4));
6961   VLOG(1) << module->ToString();
6962   HloInstruction* root = module->entry_computation()->root_instruction();
6963   EXPECT_THAT(
6964       root,
6965       AllOf(op::AllReduce(op::AllReduce(op::Scatter(
6966                 op::Select(op::Broadcast(op::Convert(op::PartitionId())),
6967                            op::Broadcast(op::Constant()), op::Parameter(0)),
6968                 op::Parameter(1), op::Parameter(2)))),
6969             op::Shape("f32[2,9,8]")));
6970 }
6971 
TEST_F(SpmdPartitioningTest,IndexPassthroughScatter_PartialReplicate)6972 TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) {
6973   absl::string_view hlo_string = R"(
6974 HloModule module
6975 
6976 add (lhs: f32[], rhs: f32[]) -> f32[] {
6977   lhs = f32[] parameter(0)
6978   rhs = f32[] parameter(1)
6979   ROOT sum = f32[] add(lhs, rhs)
6980 }
6981 
6982 ENTRY entry {
6983   %input = f32[2,9,8] parameter(0), sharding={replicated}
6984   %indices = s32[4,2,4] parameter(1),
6985     sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6986   %updates = f32[4,4,8] parameter(2),
6987     sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6988   ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
6989       to_apply=add,
6990       update_window_dims={2},
6991       inserted_window_dims={0,1},
6992       scatter_dims_to_operand_dims={0,1},
6993       index_vector_dim=1, sharding={replicated}
6994 })";
6995   TF_ASSERT_OK_AND_ASSIGN(auto module,
6996                           PartitionComputation(hlo_string, /*num_devices=*/8));
6997   VLOG(1) << module->ToString();
6998   HloInstruction* root = module->entry_computation()->root_instruction();
6999   EXPECT_THAT(
7000       root,
7001       AllOf(op::AllReduce(op::AllReduce(op::Scatter(
7002                 op::Select(op::Broadcast(op::Convert(op::Reshape())),
7003                            op::Broadcast(op::Constant()), op::Parameter(0)),
7004                 op::Parameter(1), op::Parameter(2)))),
7005             op::Shape("f32[2,9,8]")));
7006 }
7007 
TEST_F(SpmdPartitioningTest,IndexPassthroughScatter_Min)7008 TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) {
7009   absl::string_view hlo_string = R"(
7010 HloModule module
7011 
7012 min (lhs: f32[], rhs: f32[]) -> f32[] {
7013   lhs = f32[] parameter(0)
7014   rhs = f32[] parameter(1)
7015   ROOT min = f32[] minimum(lhs, rhs)
7016 }
7017 
7018 ENTRY entry {
7019   %input = f32[2,9,8] parameter(0), sharding={replicated}
7020   %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
7021   %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3}
7022   ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
7023       to_apply=min,
7024       update_window_dims={2},
7025       inserted_window_dims={0,1},
7026       scatter_dims_to_operand_dims={0,1},
7027       index_vector_dim=1, sharding={replicated}
7028 })";
7029   TF_ASSERT_OK_AND_ASSIGN(auto module,
7030                           PartitionComputation(hlo_string, /*num_devices=*/4));
7031   VLOG(1) << module->ToString();
7032   HloInstruction* root = module->entry_computation()->root_instruction();
7033   EXPECT_THAT(
7034       root,
7035       AllOf(op::AllReduce(op::AllReduce(op::Scatter(
7036                 op::Select(op::Broadcast(op::Convert(op::PartitionId())),
7037                            op::Broadcast(op::Constant()), op::Parameter(0)),
7038                 op::Parameter(1), op::Parameter(2)))),
7039             op::Shape("f32[2,9,8]")));
7040 }
7041 
TEST_F(SpmdPartitioningTest,ScatterPartitionedOnTrivialSliceDims)7042 TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) {
7043   absl::string_view hlo_string = R"(
7044 HloModule module
7045 
7046 add (lhs: f32[], rhs: f32[]) -> f32[] {
7047   lhs = f32[] parameter(0)
7048   rhs = f32[] parameter(1)
7049   ROOT sum = f32[] add(lhs, rhs)
7050 }
7051 
7052 ENTRY entry {
7053   %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
7054   %indices = s32[2,3] parameter(1), sharding={replicated}
7055   %updates = f32[2,3,9] parameter(2), sharding={replicated}
7056   ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
7057       to_apply=add,
7058       update_window_dims={2},
7059       inserted_window_dims={0},
7060       scatter_dims_to_operand_dims={0},
7061       index_vector_dim=2, sharding={devices=[2,1]0,1}
7062 })";
7063   TF_ASSERT_OK_AND_ASSIGN(auto module,
7064                           PartitionComputation(hlo_string, /*num_devices=*/2));
7065   VLOG(1) << module->ToString();
7066   auto offset =
7067       op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
7068   auto indices = op::Subtract(
7069       op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
7070   HloInstruction* root = module->entry_computation()->root_instruction();
7071   EXPECT_THAT(root,
7072               AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)),
7073                     op::Shape("f32[9,9]")));
7074 }
7075 
TEST_F(SpmdPartitioningTest,ScatterPartitionedOnTrivialSliceDimsVariadic)7076 TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDimsVariadic) {
7077   absl::string_view hlo_string = R"(
7078 HloModule module
7079 
7080 add_min_max {
7081   lhs0 = f32[] parameter(0)
7082   lhs1 = f32[] parameter(1)
7083   rhs0 = f32[] parameter(2)
7084   rhs1 = f32[] parameter(3)
7085   min = minimum(rhs0, rhs1)
7086   max = maximum(rhs0, rhs1)
7087   min_sum = add(lhs0, min)
7088   max_sum = add(lhs1, max)
7089   ROOT tuple = tuple(min_sum, max_sum)
7090 }
7091 
7092 ENTRY entry {
7093   %input0 = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
7094   %input1 = f32[17,9] parameter(1), sharding={devices=[2,1]0,1}
7095   %indices = s32[2,3] parameter(2), sharding={replicated}
7096   %updates0 = f32[2,3,9] parameter(3), sharding={replicated}
7097   %updates1 = f32[2,3,9] parameter(4), sharding={replicated}
7098   ROOT %scatter = (f32[17,9], f32[17,9])
7099     scatter(%input0, %input1, %indices, %updates0, %updates1),
7100       to_apply=add_min_max, update_window_dims={2}, inserted_window_dims={0},
7101       scatter_dims_to_operand_dims={0}, index_vector_dim=2,
7102       sharding={devices=[2,1]0,1}
7103 })";
7104   TF_ASSERT_OK_AND_ASSIGN(auto module,
7105                           PartitionComputation(hlo_string, /*num_devices=*/2));
7106   auto offset =
7107       op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
7108   auto indices = op::Subtract(
7109       op::Parameter(2), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
7110   HloInstruction* root = module->entry_computation()->root_instruction();
7111   EXPECT_THAT(root,
7112               AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), indices,
7113                                 op::Parameter(3), op::Parameter(4)),
7114                     op::Shape("(f32[9,9], f32[9,9])")));
7115 }
7116 
TEST_F(SpmdPartitioningTest,ScatterPartitionedOnTrivialSliceDims_PartialReplicate)7117 TEST_F(SpmdPartitioningTest,
7118        ScatterPartitionedOnTrivialSliceDims_PartialReplicate) {
7119   absl::string_view hlo_string = R"(
7120 HloModule module
7121 
7122 add (lhs: f32[], rhs: f32[]) -> f32[] {
7123   lhs = f32[] parameter(0)
7124   rhs = f32[] parameter(1)
7125   ROOT sum = f32[] add(lhs, rhs)
7126 }
7127 
7128 ENTRY entry {
7129   %input = f32[17,9] parameter(0),
7130     sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7131   %indices = s32[2,3] parameter(1), sharding={replicated}
7132   %updates = f32[2,3,9] parameter(2), sharding={replicated}
7133   ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
7134       to_apply=add,
7135       update_window_dims={2},
7136       inserted_window_dims={0},
7137       scatter_dims_to_operand_dims={0},
7138       index_vector_dim=2,
7139       sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7140 })";
7141   TF_ASSERT_OK_AND_ASSIGN(auto module,
7142                           PartitionComputation(hlo_string, /*num_devices=*/4));
7143   VLOG(1) << module->ToString();
7144   auto offset =
7145       op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
7146   auto indices = op::Subtract(
7147       op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
7148   HloInstruction* root = module->entry_computation()->root_instruction();
7149   EXPECT_THAT(root,
7150               AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)),
7151                     op::Shape("f32[9,9]")));
7152 }
7153 
TEST_F(SpmdPartitioningTest,ScatterPartitionedOnTrivialSliceDimsVariadic_PartialReplicate)7154 TEST_F(SpmdPartitioningTest,
7155        ScatterPartitionedOnTrivialSliceDimsVariadic_PartialReplicate) {
7156   absl::string_view hlo_string = R"(
7157 HloModule module
7158 
7159 add_min_max {
7160   lhs0 = f32[] parameter(0)
7161   lhs1 = f32[] parameter(1)
7162   rhs0 = f32[] parameter(2)
7163   rhs1 = f32[] parameter(3)
7164   min = minimum(rhs0, rhs1)
7165   max = maximum(rhs0, rhs1)
7166   min_sum = add(lhs0, min)
7167   max_sum = add(lhs1, max)
7168   ROOT tuple = tuple(min_sum, max_sum)
7169 }
7170 
7171 ENTRY entry {
7172   %input0 = f32[17,9] parameter(0),
7173     sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7174   %input1 = f32[17,9] parameter(1),
7175     sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7176   %indices = s32[2,3] parameter(2), sharding={replicated}
7177   %updates0 = f32[2,3,9] parameter(3), sharding={replicated}
7178   %updates1 = f32[2,3,9] parameter(4), sharding={replicated}
7179   ROOT %scatter = (f32[17,9], f32[17,9])
7180     scatter(%input0, %input1, %indices, %updates0, %updates1),
7181       to_apply=add_min_max, update_window_dims={2}, inserted_window_dims={0},
7182       scatter_dims_to_operand_dims={0}, index_vector_dim=2,
7183       sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7184 })";
7185   TF_ASSERT_OK_AND_ASSIGN(auto module,
7186                           PartitionComputation(hlo_string, /*num_devices=*/4));
7187   VLOG(1) << module->ToString();
7188   auto offset =
7189       op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
7190   auto indices = op::Subtract(
7191       op::Parameter(2), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
7192   HloInstruction* root = module->entry_computation()->root_instruction();
7193   EXPECT_THAT(root,
7194               AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), indices,
7195                                 op::Parameter(3), op::Parameter(4)),
7196                     op::Shape("(f32[9,9], f32[9,9])")));
7197 }
7198 
TEST_F(SpmdPartitioningTest,TiledReversePassthrough)7199 TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
7200   absl::string_view hlo_string = R"(
7201 HloModule module
7202 
7203 ENTRY entry {
7204   constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
7205     sharding={devices=[2,1]0,1}
7206   ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1},
7207     sharding={devices=[2,1]0,1}
7208 })";
7209   TF_ASSERT_OK_AND_ASSIGN(auto module,
7210                           PartitionComputation(hlo_string, /*num_devices=*/2));
7211   VLOG(1) << module->ToString();
7212   HloInstruction* root = module->entry_computation()->root_instruction();
7213   EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"),
7214                           op::Reverse(op::DynamicSlice(
7215                               op::Pad(op::Constant(), op::Constant()),
7216                               op::Reshape(), op::Constant()))));
7217 }
7218 
TEST_F(SpmdPartitioningTest,TiledReversePassthroughViaReversedSharding)7219 TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
7220   absl::string_view hlo_string = R"(
7221 HloModule module
7222 
7223 ENTRY entry {
7224   param = f32[4] parameter(0), sharding={devices=[2]0,1}
7225   ROOT reverse = f32[4] reverse(param), dimensions={0},
7226     sharding={devices=[2]1,0}
7227 })";
7228   TF_ASSERT_OK_AND_ASSIGN(auto module,
7229                           PartitionComputation(hlo_string, /*num_devices=*/2));
7230   VLOG(1) << module->ToString();
7231   HloInstruction* root = module->entry_computation()->root_instruction();
7232   EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0))));
7233 }
7234 
TEST_F(SpmdPartitioningTest,TiledReverseSwapShards)7235 TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
7236   absl::string_view hlo_string = R"(
7237 HloModule module
7238 
7239 ENTRY entry {
7240   param = f32[4] parameter(0), sharding={devices=[2]0,1}
7241   ROOT reverse = f32[4] reverse(param), dimensions={0},
7242     sharding={devices=[2]0,1}
7243 })";
7244   TF_ASSERT_OK_AND_ASSIGN(auto module,
7245                           PartitionComputation(hlo_string, /*num_devices=*/2));
7246   VLOG(1) << module->ToString();
7247   HloInstruction* root = module->entry_computation()->root_instruction();
7248   EXPECT_THAT(root,
7249               AllOf(op::Shape("f32[2]"),
7250                     op::Reverse(op::CollectivePermute(op::Parameter(0)))));
7251 }
7252 
TEST_F(SpmdPartitioningTest,TiledReverseHaloExchange)7253 TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
7254   absl::string_view hlo_string = R"(
7255 HloModule module
7256 
7257 ENTRY entry {
7258   param = f32[3] parameter(0), sharding={devices=[2]0,1}
7259   ROOT reverse = f32[3] reverse(param), dimensions={0},
7260     sharding={devices=[2]1,0}
7261 })";
7262   TF_ASSERT_OK_AND_ASSIGN(auto module,
7263                           PartitionComputation(hlo_string, /*num_devices=*/2));
7264   VLOG(1) << module->ToString();
7265   HloInstruction* root = module->entry_computation()->root_instruction();
7266   auto halo_exchange_concat =
7267       op::Concatenate(AllOf(op::Shape("f32[1]"),
7268                             op::CollectivePermute(op::Slice(op::Parameter(0)))),
7269                       op::Slice(op::Parameter(0)));
7270   EXPECT_THAT(root,
7271               AllOf(op::Shape("f32[2]"), op::Reverse(halo_exchange_concat)));
7272 }
7273 
TEST_F(SpmdPartitioningTest,MixWithManualPartitioning)7274 TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
7275   absl::string_view hlo_string = R"(
7276 HloModule module
7277 
7278 ENTRY entry {
7279   param = (f32[8,2], f32[4,2]) parameter(0), sharding={{devices=[2,1]0,1},{manual}}
7280   param0 = f32[8,2] get-tuple-element(param), index=0, sharding={devices=[2,1]0,1}
7281   param1 = f32[4,2] get-tuple-element(param), index=1, sharding={manual}
7282   to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual}
7283   add = f32[4,2] add(to_shard, param1), sharding={manual}
7284   to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
7285   mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
7286   to_shard2 = f32[4,2] custom-call(mul), custom_call_target="SPMDFullToShardShape", sharding={manual}
7287   ROOT tuple = (f32[4,2]) tuple(to_shard2), sharding={{manual}}
7288 })";
7289   TF_ASSERT_OK_AND_ASSIGN(auto module,
7290                           PartitionComputation(hlo_string, /*num_devices=*/2));
7291   VLOG(1) << module->ToString();
7292   HloInstruction* root = module->entry_computation()->root_instruction();
7293   auto p0 = op::GetTupleElement(op::Parameter(0));
7294   auto to_shard = op::Copy(p0);
7295   auto p1 = op::GetTupleElement(op::Parameter(0));
7296   auto mul = AllOf(op::Shape("f32[4,2]"),
7297                    op::Multiply(op::Copy(op::Add(to_shard, p1)), p0));
7298   EXPECT_THAT(root, op::Tuple(op::Copy(mul)));
7299 }
7300 
TEST_F(SpmdPartitioningTest,SubgroupAllToAllReshard)7301 TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
7302   absl::string_view hlo_string = R"(
7303 HloModule module
7304 
7305 ENTRY entry {
7306   %param0 = f32[8,8,8,8] parameter(0),
7307     sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7}
7308   ROOT %copy = f32[8,8,8,8] copy(%param0),
7309     sharding={devices=[1,2,2,2]0,1,4,5,2,3,6,7}
7310 })";
7311 
7312   TF_ASSERT_OK_AND_ASSIGN(auto module,
7313                           PartitionComputation(hlo_string, /*num_devices=*/8));
7314   VLOG(1) << module->ToString();
7315 
7316   const auto root = module->entry_computation()->root_instruction();
7317   auto reshape =
7318       AllOf(op::Shape("f32[4,4,2,4,4]"), op::Reshape(op::Parameter(0)));
7319   auto all_to_all = AllOf(op::Shape("f32[4,4,2,4,4]"), op::AllToAll(reshape));
7320   auto xpose = AllOf(op::Shape("f32[2,4,4,4,4]"), op::Transpose(all_to_all));
7321   EXPECT_THAT(root,
7322               op::Copy(AllOf(op::Reshape(xpose), op::Shape("f32[8,4,4,4]"))));
7323   EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->replica_groups().size(),
7324             4);
7325 }
7326 
TEST_F(SpmdPartitioningTest,SubgroupAllToAllReshard2)7327 TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard2) {
7328   absl::string_view hlo_string = R"(
7329 HloModule module
7330 
7331 ENTRY entry {
7332   %param0 = f32[8,8] parameter(0),
7333     sharding={devices=[2,4]0,1,2,3,4,5,6,7}
7334   ROOT %copy = f32[8,8] copy(%param0),
7335     sharding={devices=[4,2]0,1,4,5,2,3,6,7}
7336 })";
7337 
7338   TF_ASSERT_OK_AND_ASSIGN(auto module,
7339                           PartitionComputation(hlo_string, /*num_devices=*/8));
7340   VLOG(1) << module->ToString();
7341 
7342   const auto root = module->entry_computation()->root_instruction();
7343   auto all_to_all = op::AllToAll(
7344       AllOf(op::Shape("f32[2,2,2]"), op::Reshape(op::Parameter(0))));
7345   auto reshape =
7346       AllOf(op::Shape("f32[2,4]"), op::Reshape(op::Transpose(all_to_all)));
7347   EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape)));
7348 }
7349 
TEST_F(SpmdPartitioningTest,SubgroupAllToAllReshard3)7350 TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard3) {
7351   absl::string_view hlo_string = R"(
7352 HloModule module
7353 
7354 ENTRY entry {
7355   %param0 = f32[8,8,8] parameter(0),
7356     sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
7357   ROOT %copy = f32[8,8,8] copy(%param0),
7358     sharding={devices=[1,2,4]0,1,4,5,2,3,6,7}
7359 })";
7360 
7361   TF_ASSERT_OK_AND_ASSIGN(auto module,
7362                           PartitionComputation(hlo_string, /*num_devices=*/8));
7363   VLOG(1) << module->ToString();
7364 
7365   const auto root = module->entry_computation()->root_instruction();
7366   auto all_to_all = op::AllToAll(
7367       AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(op::Parameter(0))));
7368   auto reshape =
7369       AllOf(op::Shape("f32[4,8,2]"), op::Reshape(op::Transpose(all_to_all)));
7370   auto all_to_all2 =
7371       op::AllToAll(AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(reshape)));
7372   auto reshape2 =
7373       AllOf(op::Shape("f32[8,4,2]"), op::Reshape(op::Transpose(all_to_all2)));
7374   EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2)));
7375 }
7376 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedNonContractingAndContracting0)7377 TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting0) {
7378   absl::string_view hlo_string = R"(
7379 HloModule module
7380 
7381 ENTRY entry {
7382   %lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3}
7383   %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,2,1,3}
7384   ROOT %dot = f32[48,32] dot(%lhs, %rhs),
7385     lhs_batch_dims={}, rhs_batch_dims={},
7386     lhs_contracting_dims={1}, rhs_contracting_dims={1},
7387     sharding={devices=[2,2]0,1,2,3}
7388 })";
7389 
7390   TF_ASSERT_OK_AND_ASSIGN(auto module,
7391                           PartitionComputation(hlo_string, /*num_devices=*/4));
7392   VLOG(1) << module->ToString();
7393 
7394   const auto lhs = AllOf(op::Shape("f32[24,6]"), op::Parameter(0));
7395   auto partial_replicated_lhs =
7396       AllOf(op::Shape("f32[24,12]"),
7397             op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _)));
7398   const auto rhs = AllOf(op::Shape("f32[16,6]"), op::Parameter(1));
7399   auto partial_replicated_rhs =
7400       AllOf(op::Shape("f32[16,12]"),
7401             op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)));
7402   const auto root = module->entry_computation()->root_instruction();
7403   EXPECT_THAT(root,
7404               AllOf(op::Dot(partial_replicated_lhs, partial_replicated_rhs),
7405                     op::Shape("f32[24,16]")));
7406 }
7407 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedNonContractingAndContracting1)7408 TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting1) {
7409   absl::string_view hlo_string = R"(
7410 HloModule module
7411 
7412 ENTRY entry {
7413   %lhs = f32[48,100] parameter(0), sharding={devices=[2,2]0,1,2,3}
7414   %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
7415   ROOT %dot = f32[48,32] dot(%lhs, %rhs),
7416     lhs_batch_dims={}, rhs_batch_dims={},
7417     lhs_contracting_dims={1}, rhs_contracting_dims={1},
7418     sharding={devices=[2,2]0,1,2,3}
7419 })";
7420 
7421   TF_ASSERT_OK_AND_ASSIGN(auto module,
7422                           PartitionComputation(hlo_string, /*num_devices=*/4));
7423   VLOG(1) << module->ToString();
7424 
7425   const auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
7426   const auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
7427   auto partial_replicated_rhs =
7428       AllOf(op::Shape("f32[32,50]"),
7429             op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)));
7430   const auto root = module->entry_computation()->root_instruction();
7431   EXPECT_THAT(
7432       root, AllOf(op::Shape("f32[24,16]"),
7433                   op::DynamicSlice(
7434                       op::AllReduce(AllOf(op::Dot(lhs, partial_replicated_rhs),
7435                                           op::Shape("f32[24,32]"))),
7436                       _, _)));
7437 }
7438 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedNonContractingAndContracting2)7439 TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting2) {
7440   absl::string_view hlo_string = R"(
7441 HloModule module
7442 
7443 ENTRY entry {
7444   %lhs = f32[48,100] parameter(0), sharding={replicated}
7445   %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
7446   ROOT %dot = f32[48,32] dot(%lhs, %rhs),
7447     lhs_batch_dims={}, rhs_batch_dims={},
7448     lhs_contracting_dims={1}, rhs_contracting_dims={1},
7449     sharding={devices=[2,2]0,1,2,3}
7450 })";
7451 
7452   TF_ASSERT_OK_AND_ASSIGN(auto module,
7453                           PartitionComputation(hlo_string, /*num_devices=*/4));
7454   VLOG(1) << module->ToString();
7455 
7456   const auto lhs = AllOf(op::Shape("f32[48,100]"), op::Parameter(0));
7457   const auto lhs_slice =
7458       AllOf(op::Shape("f32[24,100]"), op::DynamicSlice(lhs, _, _));
7459   const auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
7460   auto partial_replicated_rhs = AllOf(
7461       op::Shape("f32[16,100]"), op::AllReduce(op::DynamicUpdateSlice(
7462                                     _, op::CollectivePermute(rhs), _, _)));
7463   const auto root = module->entry_computation()->root_instruction();
7464   EXPECT_THAT(root, AllOf(op::Shape("f32[24,16]"),
7465                           op::Dot(lhs_slice, partial_replicated_rhs)));
7466 }
7467 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedNoncontractingAndContracting3)7468 TEST_F(SpmdPartitioningTest, Dot2DPartitionedNoncontractingAndContracting3) {
7469   absl::string_view hlo_string = R"(
7470 HloModule module
7471 
7472 ENTRY entry {
7473   %lhs = f32[23,24] parameter(0), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7474   %rhs = f32[23,32] parameter(1), sharding={devices=[2,2]0,1,2,3}
7475   ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7476     lhs_contracting_dims={0}, rhs_contracting_dims={0},
7477     sharding={devices=[2,2]1,0,3,2}
7478 })";
7479 
7480   TF_ASSERT_OK_AND_ASSIGN(auto module,
7481                           PartitionComputation(hlo_string, /*num_devices=*/4));
7482   VLOG(1) << module->ToString();
7483 
7484   const auto lhs = AllOf(op::Shape("f32[12,24]"), op::Parameter(0));
7485   auto masked_lhs = op::Select(_, lhs, op::Broadcast(op::Constant()));
7486   const auto rhs = AllOf(op::Shape("f32[12,16]"), op::Parameter(1));
7487   auto masked_rhs = op::Select(_, rhs, op::Broadcast(op::Constant()));
7488   const auto root = module->entry_computation()->root_instruction();
7489   EXPECT_THAT(root,
7490               AllOf(op::Shape("f32[12,16]"),
7491                     op::DynamicSlice(
7492                         AllOf(op::Shape("f32[24,16]"),
7493                               op::AllReduce(op::Dot(masked_lhs, masked_rhs))),
7494                         _, _)));
7495 }
7496 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedBatchAndNonContracting)7497 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
7498   absl::string_view hlo_string = R"(
7499 HloModule module
7500 
7501 ENTRY entry {
7502   %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
7503   %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
7504   ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7505     lhs_batch_dims={0}, rhs_batch_dims={0},
7506     lhs_contracting_dims={2}, rhs_contracting_dims={2},
7507     sharding={devices=[2,2,1]0,1,2,3}
7508 })";
7509 
7510   TF_ASSERT_OK_AND_ASSIGN(auto module,
7511                           PartitionComputation(hlo_string, /*num_devices=*/4));
7512   VLOG(1) << module->ToString();
7513 
7514   const auto lhs = AllOf(op::Shape("f32[2,12,100]"), op::Parameter(0));
7515   const auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
7516   auto partial_replicated_rhs =
7517       AllOf(op::Shape("f32[2,32,100]"),
7518             op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _)));
7519   const auto root = module->entry_computation()->root_instruction();
7520   EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
7521                           op::Dot(lhs, partial_replicated_rhs)));
7522 }
7523 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedBatchAndContracting)7524 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting) {
7525   absl::string_view hlo_string = R"(
7526 HloModule module
7527 
7528 ENTRY entry {
7529   %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
7530   %rhs = f32[4,32,100] parameter(1), sharding={devices=[1,2,2]0,1,2,3}
7531   ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7532     lhs_batch_dims={0}, rhs_batch_dims={0},
7533     lhs_contracting_dims={2}, rhs_contracting_dims={2},
7534     sharding={devices=[2,2,1]0,1,2,3}
7535 })";
7536 
7537   TF_ASSERT_OK_AND_ASSIGN(auto module,
7538                           PartitionComputation(hlo_string, /*num_devices=*/4));
7539   VLOG(1) << module->ToString();
7540 
7541   const auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
7542   const auto rhs = AllOf(op::Shape("f32[4,16,50]"), op::Parameter(1));
7543   auto resharded_rhs =
7544       AllOf(op::Shape("f32[2,32,50]"),
7545             op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs)))));
7546   const auto root = module->entry_computation()->root_instruction();
7547   EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
7548                           op::DynamicSlice(
7549                               AllOf(op::Shape("f32[2,24,32]"),
7550                                     op::AllReduce(op::Dot(lhs, resharded_rhs))),
7551                               _, _, _)));
7552 }
7553 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedBatchAndContracting2)7554 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting2) {
7555   absl::string_view hlo_string = R"(
7556 HloModule module
7557 
7558 ENTRY entry {
7559   %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
7560   %rhs = f32[4,32,100] parameter(1), sharding={replicated}
7561   ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7562     lhs_batch_dims={0}, rhs_batch_dims={0},
7563     lhs_contracting_dims={2}, rhs_contracting_dims={2},
7564     sharding={devices=[2,2,1]0,1,2,3}
7565 })";
7566 
7567   TF_ASSERT_OK_AND_ASSIGN(auto module,
7568                           PartitionComputation(hlo_string, /*num_devices=*/4));
7569   VLOG(1) << module->ToString();
7570 
7571   const auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
7572   auto resharded_lhs =
7573       AllOf(op::Shape("f32[2,12,100]"),
7574             op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs)))));
7575   const auto rhs = AllOf(op::Shape("f32[4,32,100]"), op::Parameter(1));
7576   const auto rhs_slice =
7577       AllOf(op::Shape("f32[2,32,100]"), op::DynamicSlice(rhs, _, _, _));
7578   const auto root = module->entry_computation()->root_instruction();
7579   EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
7580                           op::Dot(resharded_lhs, rhs_slice)));
7581 }
7582 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedBatchNonContractingAndContracting)7583 TEST_F(SpmdPartitioningTest,
7584        Dot2DPartitionedBatchNonContractingAndContracting) {
7585   absl::string_view hlo_string = R"(
7586 HloModule module
7587 
7588 ENTRY entry {
7589   %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
7590   %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
7591   ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7592     lhs_batch_dims={0}, rhs_batch_dims={0},
7593     lhs_contracting_dims={2}, rhs_contracting_dims={2},
7594     sharding={devices=[2,1,2]0,1,2,3}
7595 })";
7596 
7597   TF_ASSERT_OK_AND_ASSIGN(auto module,
7598                           PartitionComputation(hlo_string, /*num_devices=*/4));
7599   VLOG(1) << module->ToString();
7600 
7601   const auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
7602   const auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
7603   auto partial_replicated_lhs =
7604       AllOf(op::Shape("f32[2,24,100]"),
7605             op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _, _)));
7606   const auto root = module->entry_computation()->root_instruction();
7607   EXPECT_THAT(root, AllOf(op::Shape("f32[2,24,16]"),
7608                           op::Dot(partial_replicated_lhs, rhs)));
7609 }
7610 
TEST_F(SpmdPartitioningTest,Dot2DPartitionedBatchAndReshard)7611 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndReshard) {
7612   absl::string_view hlo_string = R"(
7613 HloModule module
7614 
7615 ENTRY entry {
7616   %lhs = f32[4,8,24,100] parameter(0), sharding={devices=[2,1,2,1]0,1,2,3}
7617   %rhs = f32[4,8,32,100] parameter(1), sharding={devices=[2,1,2,1]0,1,2,3}
7618   ROOT %dot = f32[4,8,24,32] dot(%lhs, %rhs),
7619     lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
7620     lhs_contracting_dims={3}, rhs_contracting_dims={3},
7621     sharding={devices=[1,2,2,1]0,1,2,3}
7622 })";
7623 
7624   TF_ASSERT_OK_AND_ASSIGN(auto module,
7625                           PartitionComputation(hlo_string, /*num_devices=*/4));
7626   VLOG(1) << module->ToString();
7627 
7628   const auto lhs = AllOf(op::Shape("f32[2,8,12,100]"), op::Parameter(0));
7629   const auto rhs = AllOf(op::Shape("f32[2,8,16,100]"), op::Parameter(1));
7630   auto partial_replicated_rhs =
7631       AllOf(op::Shape("f32[2,8,32,100]"),
7632             op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _, _)));
7633   auto dot =
7634       AllOf(op::Shape("f32[2,8,12,32]"), op::Dot(lhs, partial_replicated_rhs));
7635   auto reshape = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Reshape(dot));
7636   auto all_to_all = AllOf(op::Shape("f32[2,2,4,12,32]"), op::AllToAll(reshape));
7637   auto xpose = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Transpose(all_to_all));
7638   const auto root = module->entry_computation()->root_instruction();
7639   EXPECT_THAT(root, AllOf(op::Shape("f32[4,4,12,32]"), op::Reshape(xpose)));
7640 }
7641 
TEST_F(SpmdPartitioningTest,SimpleDotPartial)7642 TEST_F(SpmdPartitioningTest, SimpleDotPartial) {
7643   absl::string_view hlo_string = R"(
7644 HloModule module
7645 
7646 ENTRY entry {
7647   %lhs = f32[2,24,100] parameter(0),
7648     sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
7649   %rhs = f32[2,32,100] parameter(1),
7650     sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
7651   ROOT %dot = f32[2,24,32] dot(%lhs, %rhs),
7652     lhs_batch_dims={0}, rhs_batch_dims={0},
7653     lhs_contracting_dims={2}, rhs_contracting_dims={2},
7654     sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
7655 })";
7656 
7657   TF_ASSERT_OK_AND_ASSIGN(auto module,
7658                           PartitionComputation(hlo_string, /*num_devices=*/4));
7659   VLOG(1) << module->ToString();
7660 
7661   const auto lhs = AllOf(op::Shape("f32[1,24,100]"), op::Parameter(0));
7662   const auto rhs = AllOf(op::Shape("f32[1,32,100]"), op::Parameter(1));
7663   auto dot = AllOf(op::Shape("f32[1,24,32]"), op::Dot(lhs, rhs));
7664   const auto root = module->entry_computation()->root_instruction();
7665   EXPECT_THAT(root, dot);
7666 }
7667 
TEST_F(SpmdPartitioningTest,DotPartialContracting)7668 TEST_F(SpmdPartitioningTest, DotPartialContracting) {
7669   absl::string_view hlo_string = R"(
7670 HloModule module
7671 
7672 ENTRY entry {
7673   %lhs = f32[24,100] parameter(0),
7674     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
7675   %rhs = f32[32,100] parameter(1),
7676     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
7677   ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7678     lhs_batch_dims={}, rhs_batch_dims={},
7679     lhs_contracting_dims={1}, rhs_contracting_dims={1},
7680     sharding={replicated}
7681 })";
7682 
7683   TF_ASSERT_OK_AND_ASSIGN(auto module,
7684                           PartitionComputation(hlo_string, /*num_devices=*/4));
7685   VLOG(1) << module->ToString();
7686 
7687   const auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
7688   const auto rhs = AllOf(op::Shape("f32[32,50]"), op::Parameter(1));
7689   auto dot = AllOf(op::Shape("f32[24,32]"), op::Dot(lhs, rhs));
7690   const auto root = module->entry_computation()->root_instruction();
7691   EXPECT_THAT(root, op::AllReduce(dot));
7692 }
7693 
TEST_F(SpmdPartitioningTest,DotPartialContracting2)7694 TEST_F(SpmdPartitioningTest, DotPartialContracting2) {
7695   absl::string_view hlo_string = R"(
7696 HloModule module
7697 
7698 ENTRY entry {
7699   %lhs = f32[24,100] parameter(0),
7700     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
7701   %rhs = f32[32,100] parameter(1),
7702     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
7703   ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7704     lhs_batch_dims={}, rhs_batch_dims={},
7705     lhs_contracting_dims={1}, rhs_contracting_dims={1},
7706     sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
7707 })";
7708 
7709   TF_ASSERT_OK_AND_ASSIGN(auto module,
7710                           PartitionComputation(hlo_string, /*num_devices=*/4));
7711   VLOG(1) << module->ToString();
7712 
7713   const auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
7714   const auto rhs = AllOf(op::Shape("f32[32,50]"), op::Parameter(1));
7715   auto dot =
7716       AllOf(op::Shape("f32[12,32]"),
7717             op::Dot(AllOf(op::Shape("f32[12,50]"), op::DynamicSlice(lhs, _, _)),
7718                     rhs));
7719   const auto root = module->entry_computation()->root_instruction();
7720   EXPECT_THAT(root, op::AllReduce(dot));
7721 }
7722 
TEST_F(SpmdPartitioningTest,DotPartialContracting3)7723 TEST_F(SpmdPartitioningTest, DotPartialContracting3) {
7724   absl::string_view hlo_string = R"(
7725 HloModule module
7726 
7727 ENTRY entry {
7728   %lhs = f32[24,100] parameter(0),
7729     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
7730   %rhs = f32[32,100] parameter(1),
7731     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
7732   ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7733     lhs_batch_dims={}, rhs_batch_dims={},
7734     lhs_contracting_dims={1}, rhs_contracting_dims={1},
7735     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
7736 })";
7737 
7738   TF_ASSERT_OK_AND_ASSIGN(auto module,
7739                           PartitionComputation(hlo_string, /*num_devices=*/8));
7740   VLOG(1) << module->ToString();
7741 
7742   const auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
7743   const auto rhs =
7744       AllOf(op::Shape("f32[16,50]"), op::DynamicSlice(op::Parameter(1), _, _));
7745   auto dot = AllOf(op::Shape("f32[24,16]"), op::Dot(lhs, rhs));
7746   const auto root = module->entry_computation()->root_instruction();
7747   EXPECT_THAT(root, op::CollectivePermute(op::AllReduce(dot)));
7748 }
7749 
TEST_F(SpmdPartitioningTest,DotBatchAndPartialContracting)7750 TEST_F(SpmdPartitioningTest, DotBatchAndPartialContracting) {
7751   absl::string_view hlo_string = R"(
7752 HloModule module
7753 
7754 ENTRY entry {
7755   %lhs = f32[4,24,100] parameter(0),
7756     sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}
7757   %rhs = f32[4,32,100] parameter(1),
7758     sharding={devices=[2,1,2,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}
7759   ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7760     lhs_batch_dims={0}, rhs_batch_dims={0},
7761     lhs_contracting_dims={2}, rhs_contracting_dims={2},
7762     sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
7763 })";
7764 
7765   TF_ASSERT_OK_AND_ASSIGN(auto module,
7766                           PartitionComputation(hlo_string, /*num_devices=*/8));
7767   VLOG(1) << module->ToString();
7768 
7769   const auto lhs = AllOf(op::Shape("f32[2,12,50]"), op::Parameter(0));
7770   const auto rhs = AllOf(op::Shape("f32[2,32,50]"), op::Parameter(1));
7771   auto dot = AllOf(op::Shape("f32[2,12,32]"), op::Dot(lhs, rhs));
7772   const auto root = module->entry_computation()->root_instruction();
7773   EXPECT_THAT(root, op::AllReduce(dot));
7774 }
7775 
TEST_F(SpmdPartitioningTest,DotPartialNonContracting)7776 TEST_F(SpmdPartitioningTest, DotPartialNonContracting) {
7777   absl::string_view hlo_string = R"(
7778 HloModule module
7779 
7780 ENTRY entry {
7781   %lhs = f32[24,8,100] parameter(0),
7782     sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
7783   %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,2,1,3}
7784   ROOT %dot = f32[24,8,32] dot(%lhs, %rhs),
7785     lhs_batch_dims={}, rhs_batch_dims={},
7786     lhs_contracting_dims={2}, rhs_contracting_dims={1},
7787     sharding={devices=[2,1,2]0,1,2,3}
7788 })";
7789 
7790   TF_ASSERT_OK_AND_ASSIGN(auto module,
7791                           PartitionComputation(hlo_string, /*num_devices=*/4));
7792   VLOG(1) << module->ToString();
7793 
7794   const auto lhs = AllOf(op::Shape("f32[12,8,100]"), op::Parameter(0));
7795   const auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
7796   auto partially_replicated_rhs =
7797       AllOf(op::Shape("f32[16,100]"),
7798             op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), rhs, _, _)));
7799   auto dot =
7800       AllOf(op::Shape("f32[12,8,16]"), op::Dot(lhs, partially_replicated_rhs));
7801   const auto root = module->entry_computation()->root_instruction();
7802   EXPECT_THAT(root, dot);
7803 }
7804 
TEST_F(SpmdPartitioningTest,DotPartialNonContractingPartialMatch)7805 TEST_F(SpmdPartitioningTest, DotPartialNonContractingPartialMatch) {
7806   absl::string_view hlo_string = R"(
7807 HloModule module
7808 
7809 ENTRY entry {
7810   %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
7811   %rhs = f32[32,100] parameter(1),
7812     sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
7813   ROOT %dot = f32[24,8,32] dot(%lhs, %rhs),
7814     lhs_batch_dims={}, rhs_batch_dims={},
7815     lhs_contracting_dims={2}, rhs_contracting_dims={1},
7816     sharding={devices=[2,1,2]0,1,2,3}
7817 })";
7818 
7819   TF_ASSERT_OK_AND_ASSIGN(auto module,
7820                           PartitionComputation(hlo_string, /*num_devices=*/4));
7821   VLOG(1) << module->ToString();
7822 
7823   const auto lhs = AllOf(op::Shape("f32[12,4,100]"), op::Parameter(0));
7824   const auto rhs = AllOf(op::Shape("f32[16,100]"), op::Parameter(1));
7825   auto partially_replicated_lhs = AllOf(
7826       op::Shape("f32[12,8,100]"),
7827       op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), lhs, _, _, _)));
7828   auto dot =
7829       AllOf(op::Shape("f32[12,8,16]"), op::Dot(partially_replicated_lhs, rhs));
7830   const auto root = module->entry_computation()->root_instruction();
7831   EXPECT_THAT(root, dot);
7832 }
7833 
TEST_F(SpmdPartitioningTest,DotPartialContractingPartialMatch)7834 TEST_F(SpmdPartitioningTest, DotPartialContractingPartialMatch) {
7835   absl::string_view hlo_string = R"(
7836 HloModule module
7837 
7838 ENTRY entry {
7839   %lhs = f32[24,8,100] parameter(0), sharding={devices=[1,2,2]0,1,2,3}
7840   %rhs = f32[32,8,100] parameter(1),
7841     sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}
7842   ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7843     lhs_batch_dims={}, rhs_batch_dims={},
7844     lhs_contracting_dims={1,2}, rhs_contracting_dims={1,2},
7845     sharding={replicated}
7846 })";
7847 
7848   TF_ASSERT_OK_AND_ASSIGN(auto module,
7849                           PartitionComputation(hlo_string, /*num_devices=*/4));
7850   VLOG(1) << module->ToString();
7851 
7852   const auto lhs = AllOf(op::Shape("f32[24,4,50]"), op::Parameter(0));
7853   const auto rhs = AllOf(op::Shape("f32[32,8,50]"), op::Parameter(1));
7854   auto dot = AllOf(op::Shape("f32[24,32]"),
7855                    op::Dot(lhs, AllOf(op::Shape("f32[32,4,50]"),
7856                                       op::DynamicSlice(rhs, _, _, _))));
7857   const auto root = module->entry_computation()->root_instruction();
7858   EXPECT_THAT(root, op::AllReduce(op::AllReduce(dot)));
7859 }
7860 
TEST_F(SpmdPartitioningTest,DotNonContractingPartialMatchContractingMatch)7861 TEST_F(SpmdPartitioningTest, DotNonContractingPartialMatchContractingMatch) {
7862   absl::string_view hlo_string = R"(
7863 HloModule module
7864 
7865 ENTRY entry {
7866   %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
7867   %rhs = f32[100,50] parameter(1), sharding={devices=[2,2]0,2,1,3}
7868   ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
7869     lhs_batch_dims={}, rhs_batch_dims={},
7870     lhs_contracting_dims={2}, rhs_contracting_dims={0},
7871     sharding={devices=[2,2,1]0,1,2,3}
7872 })";
7873 
7874   TF_ASSERT_OK_AND_ASSIGN(auto module,
7875                           PartitionComputation(hlo_string, /*num_devices=*/4));
7876   VLOG(1) << module->ToString();
7877 
7878   const auto lhs = AllOf(op::Shape("f32[12,8,50]"), op::Parameter(0));
7879   const auto rhs = AllOf(op::Shape("f32[50,25]"), op::Parameter(1));
7880   auto dot = AllOf(
7881       op::Shape("f32[12,8,50]"),
7882       op::Dot(lhs, AllOf(op::Shape("f32[50,50]"),
7883                          op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)))));
7884   const auto root = module->entry_computation()->root_instruction();
7885   EXPECT_THAT(root, AllOf(op::Shape("f32[12,4,50]"),
7886                           op::DynamicSlice(op::AllReduce(dot), _, _, _)))
7887       << module->ToString();
7888 }
7889 
TEST_F(SpmdPartitioningTest,DotLHSMutiNonContractingRHSNotMatch)7890 TEST_F(SpmdPartitioningTest, DotLHSMutiNonContractingRHSNotMatch) {
7891   absl::string_view hlo_string = R"(
7892 HloModule module
7893 
7894 ENTRY entry {
7895   %lhs = f32[24,8,10] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
7896   %rhs = f32[10,50] parameter(1),
7897     sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
7898   ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
7899     lhs_batch_dims={}, rhs_batch_dims={},
7900     lhs_contracting_dims={2}, rhs_contracting_dims={0},
7901     sharding={devices=[2,2,1]0,1,2,3}
7902 })";
7903 
7904   TF_ASSERT_OK_AND_ASSIGN(auto module,
7905                           PartitionComputation(hlo_string, /*num_devices=*/4));
7906   VLOG(1) << module->ToString();
7907 
7908   const auto lhs = AllOf(op::Shape("f32[12,4,10]"), op::Parameter(0));
7909   const auto rhs = AllOf(op::Shape("f32[5,50]"), op::Parameter(1));
7910   auto dot = AllOf(
7911       op::Shape("f32[12,4,50]"),
7912       op::Dot(lhs, AllOf(op::Shape("f32[10,50]"),
7913                          op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)))));
7914   const auto root = module->entry_computation()->root_instruction();
7915   EXPECT_THAT(root, dot) << module->ToString();
7916 }
7917 
TEST_F(SpmdPartitioningTest,ElementwiseTest_SubgroupSharding_TileToReplicate)7918 TEST_F(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) {
7919   absl::string_view hlo_string = R"(
7920 HloModule module
7921 
7922 ENTRY entry {
7923   constant = f32[6,3]{1,0}
7924     constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
7925     sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
7926   constant.1 = f32[6,3]{1,0}
7927     constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
7928     sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
7929    multiply = f32[6,3]{1,0} multiply(constant, constant.1),
7930     sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
7931    ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
7932     sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated, manual}}
7933 }
7934 )";
7935 
7936   TF_ASSERT_OK_AND_ASSIGN(auto module,
7937                           PartitionComputation(hlo_string, /*num_devices=*/4));
7938   VLOG(1) << module->ToString();
7939 
7940   auto multiply_lhs =
7941       AllOf(op::Shape("f32[6,2]"),
7942             op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
7943                              op::Constant(), op::Reshape()));
7944   auto multiply_rhs =
7945       AllOf(op::Shape("f32[6,2]"),
7946             op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
7947                              op::Constant(), op::Reshape()));
7948   auto multiply =
7949       AllOf(op::Shape("f32[6,2]"), op::Multiply(multiply_lhs, multiply_rhs));
7950   auto replicated_lhs =
7951       AllOf(op::Shape("f32[6,3]"),
7952             op::Slice(op::AllReduce(op::DynamicUpdateSlice(
7953                 op::Broadcast(), multiply, op::Constant(), op::Reshape()))));
7954   const auto root = module->entry_computation()->root_instruction();
7955   EXPECT_THAT(root, AllOf(op::Shape("f32[6,3]"),
7956                           op::Add(replicated_lhs, op::Constant())));
7957 }
7958 
TEST_F(SpmdPartitioningTest,ElementwiseTest_SubgroupSharding_ReplicateToTile)7959 TEST_F(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_ReplicateToTile) {
7960   absl::string_view hlo_string = R"(
7961 HloModule module
7962 
7963 ENTRY entry {
7964   constant = f32[6,3]{1,0}
7965     constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
7966     sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
7967   constant.1 = f32[6,3]{1,0}
7968     constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
7969     sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
7970    multiply = f32[6,3]{1,0} multiply(constant, constant.1),
7971     sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
7972    ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
7973     sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
7974 }
7975 )";
7976 
7977   TF_ASSERT_OK_AND_ASSIGN(auto module,
7978                           PartitionComputation(hlo_string, /*num_devices=*/4));
7979   VLOG(1) << module->ToString();
7980 
7981   auto multiply = AllOf(op::Shape("f32[6,3]"),
7982                         op::Multiply(op::Constant(), op::Constant()));
7983   auto add_lhs = AllOf(op::Shape("f32[6,2]"),
7984                        op::DynamicSlice(op::Pad(multiply, op::Constant()),
7985                                         op::Constant(), op::Reshape()));
7986   auto add_rhs = AllOf(op::Shape("f32[6,2]"),
7987                        op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
7988                                         op::Constant(), op::Reshape()));
7989   const auto root = module->entry_computation()->root_instruction();
7990   EXPECT_THAT(root, AllOf(op::Shape("f32[6,2]"), op::Add(add_lhs, add_rhs)));
7991 }
7992 
TEST_F(SpmdPartitioningTest,ElementwiseTest_PartialReplicateToTiledHaloExchange)7993 TEST_F(SpmdPartitioningTest,
7994        ElementwiseTest_PartialReplicateToTiledHaloExchange) {
7995   absl::string_view hlo_string = R"(
7996 HloModule module
7997 
7998 ENTRY entry {
7999   constant = f32[6,3]{1,0}
8000     constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
8001     sharding={replicated}
8002   constant.1 = f32[6,3]{1,0}
8003     constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
8004     sharding={replicated}
8005   multiply = f32[6,3]{1,0} multiply(constant, constant.1),
8006     sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
8007   ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
8008     sharding={devices=[4,1]0,1,2,3}
8009 }
8010 )";
8011 
8012   TF_ASSERT_OK_AND_ASSIGN(auto module,
8013                           PartitionComputation(hlo_string, /*num_devices=*/4));
8014   VLOG(1) << module->ToString();
8015   auto partial_replicate_lhs =
8016       AllOf(op::Shape("f32[3,3]"),
8017             op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
8018   auto partial_replicate_rhs =
8019       AllOf(op::Shape("f32[3,3]"),
8020             op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
8021   auto multiply =
8022       AllOf(op::Shape("f32[3,3]"),
8023             op::Multiply(partial_replicate_lhs, partial_replicate_rhs));
8024   auto right_halo =
8025       AllOf(op::Shape("f32[1,3]"), op::CollectivePermute(op::Slice(multiply)));
8026   auto add_lhs = AllOf(
8027       op::Shape("f32[2,3]"),
8028       op::DynamicSlice(
8029           op::DynamicSlice(
8030               op::Pad(op::Concatenate(multiply, right_halo), op::Constant()),
8031               op::Reshape(), op::Constant()),
8032           op::Subtract(), op::Subtract()));
8033   auto add_rhs = AllOf(op::Shape("f32[2,3]"),
8034                        op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
8035                                         op::Reshape(), op::Constant()));
8036   const auto root = module->entry_computation()->root_instruction();
8037   EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]"), op::Add(add_lhs, add_rhs)));
8038 }
8039 
TEST_F(SpmdPartitioningTest,TileToPartialReplicateReshard)8040 TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshard) {
8041   absl::string_view hlo_string = R"(
8042 HloModule module
8043 
8044 ENTRY entry {
8045   %param0 = f32[8,8] parameter(0)
8046   %copy = f32[8,8] copy(%param0),
8047     sharding={devices=[2,2]0,1,2,3}
8048   ROOT %copy0 = f32[8,8] copy(%copy),
8049     sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
8050 })";
8051 
8052   TF_ASSERT_OK_AND_ASSIGN(auto module,
8053                           PartitionComputation(hlo_string, /*num_devices=*/4));
8054   VLOG(1) << module->ToString();
8055   auto tiled = AllOf(op::Shape("f32[4,4]"),
8056                      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8057                                                op::Reshape())));
8058   auto partially_replicated = AllOf(
8059       op::Shape("f32[4,8]"), op::Copy(op::AllReduce(op::DynamicUpdateSlice(
8060                                  op::Broadcast(_), tiled, _, _))));
8061   const auto root = module->entry_computation()->root_instruction();
8062   EXPECT_THAT(root, partially_replicated);
8063 }
8064 
TEST_F(SpmdPartitioningTest,TileToPartialReplicateReshardUnevenPartition)8065 TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
8066   absl::string_view hlo_string = R"(
8067 HloModule module
8068 
8069 ENTRY entry {
8070   %param0 = f32[8,8] parameter(0),
8071     sharding={devices=[2,3]0,1,2,3,4,5}
8072   ROOT %copy0 = f32[8,8] copy(%param0),
8073     sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
8074 })";
8075 
8076   TF_ASSERT_OK_AND_ASSIGN(auto module,
8077                           PartitionComputation(hlo_string, /*num_devices=*/6));
8078   VLOG(1) << module->ToString();
8079   auto tiled = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
8080   auto partially_replicated = AllOf(
8081       op::Shape("f32[8,4]"),
8082       op::Copy(op::Reshape(
8083           op::Transpose(op::AllToAll(op::Reshape(op::Slice(op::AllReduce(
8084               op::DynamicUpdateSlice(op::Broadcast(), tiled, _, _)))))))));
8085   const auto root = module->entry_computation()->root_instruction();
8086   EXPECT_THAT(root, partially_replicated);
8087 }
8088 
TEST_F(SpmdPartitioningTest,PartialReplicateToTileReshardUnevenPartition)8089 TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
8090   absl::string_view hlo_string = R"(
8091 HloModule module
8092 
8093 ENTRY entry {
8094   %param0 = f32[8,8] parameter(0),
8095     sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
8096   ROOT %copy0 = f32[8,8] copy(%param0),
8097     sharding={devices=[2,3]0,1,2,3,4,5}
8098 })";
8099 
8100   TF_ASSERT_OK_AND_ASSIGN(auto module,
8101                           PartitionComputation(hlo_string, /*num_devices=*/6));
8102   VLOG(1) << module->ToString();
8103   auto partial_replicated = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
8104   auto tiled = AllOf(
8105       op::Shape("f32[4,3]"),
8106       op::Copy(op::DynamicSlice(op::Pad(op::Reshape(op::Transpose(op::AllToAll(
8107                                             op::Reshape(partial_replicated)))),
8108                                         _),
8109                                 _, _)));
8110   const auto root = module->entry_computation()->root_instruction();
8111   EXPECT_THAT(root, tiled);
8112 }
8113 
TEST_F(SpmdPartitioningTest,PartialReplicateToTileReshard)8114 TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
8115   absl::string_view hlo_string = R"(
8116 HloModule module
8117 
8118 ENTRY entry {
8119   %param0 = f32[8,8] parameter(0)
8120   %copy = f32[8,8] copy(%param0),
8121     sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
8122   ROOT %copy0 = f32[8,8] copy(%copy),
8123     sharding={devices=[2,2]0,1,2,3}
8124 })";
8125 
8126   TF_ASSERT_OK_AND_ASSIGN(auto module,
8127                           PartitionComputation(hlo_string, /*num_devices=*/4));
8128   VLOG(1) << module->ToString();
8129   auto partially_replicated =
8130       AllOf(op::Shape("f32[4,8]"),
8131             op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8132                                       op::Constant())));
8133   auto tiled =
8134       AllOf(op::Shape("f32[4,4]"),
8135             op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
8136                                       op::Subtract())));
8137   const auto root = module->entry_computation()->root_instruction();
8138   EXPECT_THAT(root, tiled);
8139 }
8140 
TEST_F(SpmdPartitioningTest,PartialReplicateToPartialReplicateReshard_AllReduce)8141 TEST_F(SpmdPartitioningTest,
8142        PartialReplicateToPartialReplicateReshard_AllReduce) {
8143   absl::string_view hlo_string = R"(
8144 HloModule module
8145 
8146 ENTRY entry {
8147   %param0 = f32[8,8] parameter(0)
8148   %copy = f32[8,8] copy(param0),
8149     sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8150   ROOT %copy0 = f32[8,8] copy(%copy),
8151     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8152 })";
8153 
8154   TF_ASSERT_OK_AND_ASSIGN(auto module,
8155                           PartitionComputation(hlo_string, /*num_devices=*/8));
8156 
8157   VLOG(1) << module->ToString();
8158   auto partially_replicated_init =
8159       AllOf(op::Shape("f32[4,4]"),
8160             op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8161                                       op::Reshape())));
8162   auto partially_replicated =
8163       AllOf(op::Shape("f32[4,8]"),
8164             op::Copy(op::AllReduce(op::DynamicUpdateSlice(
8165                 op::Broadcast(_), partially_replicated_init, _, _))));
8166   const auto root = module->entry_computation()->root_instruction();
8167   EXPECT_THAT(root, partially_replicated);
8168 }
8169 
TEST_F(SpmdPartitioningTest,PartialReplicateToPartialReplicateReshard_DynamicSlice)8170 TEST_F(SpmdPartitioningTest,
8171        PartialReplicateToPartialReplicateReshard_DynamicSlice) {
8172   absl::string_view hlo_string = R"(
8173 HloModule module
8174 
8175 ENTRY entry {
8176   %param0 = f32[8,8] parameter(0)
8177   %copy = f32[8,8] copy(%param0),
8178     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8179   ROOT %copy0 = f32[8,8] copy(%copy),
8180     sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8181 })";
8182 
8183   TF_ASSERT_OK_AND_ASSIGN(auto module,
8184                           PartitionComputation(hlo_string, /*num_devices=*/8));
8185   VLOG(1) << module->ToString();
8186   auto partially_replicated =
8187       AllOf(op::Shape("f32[4,8]"),
8188             op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8189                                       op::Constant())));
8190   auto tiled =
8191       AllOf(op::Shape("f32[4,4]"),
8192             op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
8193                                       op::Subtract())));
8194   const auto root = module->entry_computation()->root_instruction();
8195   EXPECT_THAT(root, tiled);
8196 }
8197 
TEST_F(SpmdPartitioningTest,PartialReplicateToPartialReplicateReshardWithCollectivePermute)8198 TEST_F(SpmdPartitioningTest,
8199        PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
8200   absl::string_view hlo_string = R"(
8201 HloModule module
8202 
8203 ENTRY entry {
8204   %param0 = f32[8,8] parameter(0)
8205   %copy = f32[8,8] copy(param0),
8206     sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8207   ROOT %copy0 = f32[8,8] copy(%copy),
8208     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8209 })";
8210 
8211   TF_ASSERT_OK_AND_ASSIGN(auto module,
8212                           PartitionComputation(hlo_string, /*num_devices=*/8));
8213 
8214   VLOG(1) << module->ToString();
8215   auto partially_replicated_init =
8216       AllOf(op::Shape("f32[4,4]"),
8217             op::CollectivePermute(op::Copy(op::DynamicSlice(
8218                 op::Parameter(0), op::Reshape(), op::Reshape()))));
8219   auto partially_replicated =
8220       AllOf(op::Shape("f32[8,4]"),
8221             op::Copy(op::AllReduce(op::DynamicUpdateSlice(
8222                 op::Broadcast(_), partially_replicated_init, _, _))));
8223   const auto root = module->entry_computation()->root_instruction();
8224   EXPECT_THAT(root, partially_replicated);
8225 }
8226 
TEST_F(SpmdPartitioningTest,PartialReplicateToPartialReplicateReshardCollectivePermute1)8227 TEST_F(SpmdPartitioningTest,
8228        PartialReplicateToPartialReplicateReshardCollectivePermute1) {
8229   absl::string_view hlo_string = R"(
8230 HloModule module
8231 
8232 ENTRY entry {
8233   %param0 = f32[8,8] parameter(0)
8234   %copy = f32[8,8] copy(%param0),
8235     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8236   ROOT %copy0 = f32[8,8] copy(%copy),
8237     sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8238 })";
8239 
8240   TF_ASSERT_OK_AND_ASSIGN(auto module,
8241                           PartitionComputation(hlo_string, /*num_devices=*/8));
8242   VLOG(1) << module->ToString();
8243   auto partially_replicated =
8244       AllOf(op::Shape("f32[8,4]"),
8245             op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
8246                                       op::Reshape())));
8247   auto tiled =
8248       AllOf(op::Shape("f32[4,4]"),
8249             op::Copy(op::CollectivePermute(op::DynamicSlice(
8250                 partially_replicated, op::Subtract(), op::Subtract()))));
8251   const auto root = module->entry_computation()->root_instruction();
8252   EXPECT_THAT(root, tiled);
8253 }
8254 
TEST_F(SpmdPartitioningTest,PartialReplicateToPartialReplicateReshardHaloExchange)8255 TEST_F(SpmdPartitioningTest,
8256        PartialReplicateToPartialReplicateReshardHaloExchange) {
8257   absl::string_view hlo_string = R"(
8258 HloModule module
8259 
8260 ENTRY entry {
8261   %param0 = f32[6,3] parameter(0)
8262   %copy = f32[6,3] copy(param0),
8263     sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8264   ROOT %copy0 = f32[6,3] copy(%copy),
8265     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8266 })";
8267 
8268   TF_ASSERT_OK_AND_ASSIGN(auto module,
8269                           PartitionComputation(hlo_string, /*num_devices=*/8));
8270 
8271   VLOG(1) << module->ToString();
8272   auto partially_replicated_init =
8273       AllOf(op::Shape("f32[2,3]"),
8274             op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
8275                                       op::Reshape(), op::Constant())));
8276   auto slice =
8277       AllOf(op::Shape("f32[2,3]"),
8278             op::DynamicSlice(op::Concatenate(op::CollectivePermute(op::Slice(
8279                                                  partially_replicated_init)),
8280                                              partially_replicated_init),
8281                              _, _));
8282   auto partially_replicated =
8283       AllOf(op::Shape("f32[3,3]"),
8284             op::Copy(op::Slice(op::AllReduce(
8285                 op::DynamicUpdateSlice(op::Broadcast(_), slice, _, _)))));
8286   const auto root = module->entry_computation()->root_instruction();
8287   EXPECT_THAT(root, partially_replicated);
8288 }
8289 
TEST_F(SpmdPartitioningTest,PartialReplicateToPartialReplicateReshardHaloExchange1)8290 TEST_F(SpmdPartitioningTest,
8291        PartialReplicateToPartialReplicateReshardHaloExchange1) {
8292   absl::string_view hlo_string = R"(
8293 HloModule module
8294 
8295 ENTRY entry {
8296   %param0 = f32[6,3] parameter(0)
8297   %copy = f32[6,3] copy(param0),
8298     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8299   ROOT %copy0 = f32[6,3] copy(%copy),
8300     sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8301 })";
8302 
8303   TF_ASSERT_OK_AND_ASSIGN(auto module,
8304                           PartitionComputation(hlo_string, /*num_devices=*/8));
8305 
8306   VLOG(1) << module->ToString();
8307   auto partially_replicated_init =
8308       AllOf(op::Shape("f32[3,3]"),
8309             op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8310                                       op::Constant())));
8311   auto slice = AllOf(
8312       op::Shape("f32[4,3]"),
8313       op::DynamicSlice(op::Pad(op::Concatenate(partially_replicated_init,
8314                                                op::CollectivePermute(op::Slice(
8315                                                    partially_replicated_init))),
8316                                op::Constant()),
8317                        _, _));
8318   auto partially_replicated =
8319       AllOf(op::Shape("f32[2,3]"), op::Copy(op::DynamicSlice(slice, _, _)));
8320   const auto root = module->entry_computation()->root_instruction();
8321   EXPECT_THAT(root, partially_replicated);
8322 }
8323 
TEST_F(SpmdPartitioningTest,PartitionConvWithBathGroupCount)8324 TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCount) {
8325   absl::string_view hlo_string = R"(
8326 HloModule module
8327 
8328 ENTRY entry {
8329   %lhs = f32[16,801,1,1024] parameter(0)
8330   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8331     sharding={devices=[1,1,1,2]0,1}
8332   %rhs = f32[16,801,1,1024] parameter(1)
8333   %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8334     sharding={devices=[1,1,1,2]0,1}
8335   ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8336     dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8337     window={size=801x1 pad=2_2x0_0},
8338     sharding={devices=[1,1,1,2]0,1}
8339 })";
8340 
8341   TF_ASSERT_OK_AND_ASSIGN(auto module,
8342                           PartitionComputation(hlo_string, /*num_devices=*/2));
8343 
8344   VLOG(1) << module->ToString();
8345   const auto root = module->entry_computation()->root_instruction();
8346   const auto lhs = AllOf(
8347       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8348                                 op::Constant(), op::Reshape())),
8349       op::Shape("f32[16,801,1,512]"));
8350   const auto rhs = AllOf(
8351       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8352                                 op::Constant(), op::Reshape())),
8353       op::Shape("f32[16,801,1,512]"));
8354   EXPECT_THAT(root,
8355               AllOf(op::Convolution(lhs, rhs), op::Shape("f32[5,1,1,512]")));
8356 }
8357 
TEST_F(SpmdPartitioningTest,PartitionConvWithBathGroupCountRHSAlignWithLHS)8358 TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountRHSAlignWithLHS) {
8359   absl::string_view hlo_string = R"(
8360 HloModule module
8361 
8362 ENTRY entry {
8363   %lhs = f32[16,801,1,1024] parameter(0)
8364   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8365     sharding={devices=[1,1,1,2]0,1}
8366   %rhs = f32[16,801,1,1024] parameter(1)
8367   %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8368     sharding={devices=[1,2,1,1]0,1}
8369   ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8370     dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8371     window={size=801x1 pad=2_2x0_0},
8372     sharding={devices=[1,1,1,2]0,1}
8373 })";
8374 
8375   TF_ASSERT_OK_AND_ASSIGN(auto module,
8376                           PartitionComputation(hlo_string, /*num_devices=*/2));
8377   VLOG(1) << module->ToString();
8378   const auto root = module->entry_computation()->root_instruction();
8379   const auto lhs = AllOf(
8380       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8381                                 op::Constant(), op::Reshape())),
8382       op::Shape("f32[16,801,1,512]"));
8383   const auto rhs =
8384       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8385                                       op::Constant(), op::Reshape(),
8386                                       op::Constant(), op::Constant())),
8387             op::Shape("f32[16,401,1,1024]"));
8388   auto resharded_rhs = AllOf(
8389       op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))),
8390       op::Shape("f32[16,801,1,512]"));
8391   EXPECT_THAT(root, AllOf(op::Convolution(lhs, resharded_rhs),
8392                           op::Shape("f32[5,1,1,512]")));
8393 }
8394 
TEST_F(SpmdPartitioningTest,PartitionConvWithBathGroupCountLHSAlignWithRHS)8395 TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountLHSAlignWithRHS) {
8396   absl::string_view hlo_string = R"(
8397 HloModule module
8398 
8399 ENTRY entry {
8400   %lhs = f32[16,801,1,1024] parameter(0)
8401   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8402     sharding={devices=[1,2,1,1]0,1}
8403   %rhs = f32[16,801,1,1024] parameter(1)
8404   %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8405     sharding={devices=[1,1,1,2]0,1}
8406   ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8407     dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8408     window={size=801x1 pad=2_2x0_0},
8409     sharding={devices=[1,1,1,2]0,1}
8410 })";
8411 
8412   TF_ASSERT_OK_AND_ASSIGN(auto module,
8413                           PartitionComputation(hlo_string, /*num_devices=*/2));
8414   VLOG(1) << module->ToString();
8415   const auto root = module->entry_computation()->root_instruction();
8416   const auto rhs = AllOf(
8417       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8418                                 op::Constant(), op::Reshape())),
8419       op::Shape("f32[16,801,1,512]"));
8420   const auto lhs =
8421       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8422                                       op::Constant(), op::Reshape(),
8423                                       op::Constant(), op::Constant())),
8424             op::Shape("f32[16,401,1,1024]"));
8425   auto resharded_lhs = AllOf(
8426       op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))),
8427       op::Shape("f32[16,801,1,512]"));
8428   EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, rhs),
8429                           op::Shape("f32[5,1,1,512]")));
8430 }
8431 
TEST_F(SpmdPartitioningTest,PartitionConvWithBathGroupCountOutputAlignWithLHS)8432 TEST_F(SpmdPartitioningTest,
8433        PartitionConvWithBathGroupCountOutputAlignWithLHS) {
8434   absl::string_view hlo_string = R"(
8435 HloModule module
8436 
8437 ENTRY entry {
8438   %lhs = f32[16,801,1,1024] parameter(0)
8439   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8440     sharding={devices=[1,1,1,2]0,1}
8441   %rhs = f32[16,801,1,1024] parameter(1)
8442   %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8443     sharding={devices=[1,1,1,2]0,1}
8444   ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8445     dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8446     window={size=801x1 pad=2_2x0_0},
8447     sharding={devices=[2,1,1,1]0,1}
8448 })";
8449 
8450   TF_ASSERT_OK_AND_ASSIGN(auto module,
8451                           PartitionComputation(hlo_string, /*num_devices=*/2));
8452   VLOG(1) << module->ToString();
8453   const auto root = module->entry_computation()->root_instruction();
8454   const auto lhs = AllOf(
8455       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8456                                 op::Constant(), op::Reshape())),
8457       op::Shape("f32[16,801,1,512]"));
8458   const auto rhs = AllOf(
8459       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8460                                 op::Constant(), op::Reshape())),
8461       op::Shape("f32[16,801,1,512]"));
8462   auto conv = AllOf(op::Convolution(lhs, rhs), op::Shape("f32[5,1,1,512]"));
8463   EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(op::AllToAll(
8464                               op::Reshape(op::Pad(conv, op::Constant()))))),
8465                           op::Shape("f32[3,1,1,1024]")));
8466 }
8467 
TEST_F(SpmdPartitioningTest,PartitionConvWithBathGroupCountOutputAlignWithRHS)8468 TEST_F(SpmdPartitioningTest,
8469        PartitionConvWithBathGroupCountOutputAlignWithRHS) {
8470   absl::string_view hlo_string = R"(
8471 HloModule module
8472 
8473 ENTRY entry {
8474   %lhs = f32[16,801,1,1024] parameter(0)
8475   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8476     sharding={devices=[1,2,1,1]0,1}
8477   %rhs = f32[16,801,1,1024] parameter(1)
8478   %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8479     sharding={devices=[1,1,1,2]0,1}
8480   ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8481     dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8482     window={size=801x1 pad=2_2x0_0},
8483     sharding={devices=[2,1,1,1]0,1}
8484 })";
8485 
8486   TF_ASSERT_OK_AND_ASSIGN(auto module,
8487                           PartitionComputation(hlo_string, /*num_devices=*/2));
8488   VLOG(1) << module->ToString();
8489   const auto root = module->entry_computation()->root_instruction();
8490   const auto rhs = AllOf(
8491       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8492                                 op::Constant(), op::Reshape())),
8493       op::Shape("f32[16,801,1,512]"));
8494   const auto lhs =
8495       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8496                                       op::Constant(), op::Reshape(),
8497                                       op::Constant(), op::Constant())),
8498             op::Shape("f32[16,401,1,1024]"));
8499   auto resharded_lhs = AllOf(
8500       op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))),
8501       op::Shape("f32[16,801,1,512]"));
8502   auto conv =
8503       AllOf(op::Convolution(resharded_lhs, rhs), op::Shape("f32[5,1,1,512]"));
8504   EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(op::AllToAll(
8505                               op::Reshape(op::Pad(conv, op::Constant()))))),
8506                           op::Shape("f32[3,1,1,1024]")));
8507 }
8508 
TEST_F(SpmdPartitioningTest,PartitionConvWithFeatureGroupCount)8509 TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount) {
8510   absl::string_view hlo_string = R"(
8511 HloModule module
8512 
8513 ENTRY entry {
8514   %lhs = f32[16,801,1,1024] parameter(0)
8515   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8516     sharding={devices=[1,1,1,2]0,1}
8517   %rhs = f32[5,1,1,2048] parameter(1)
8518   %rhs.copy = f32[5,1,1,2048] copy(%rhs),
8519     sharding={devices=[1,1,1,2]0,1}
8520   ROOT %conv = f32[16,801,1,2048] convolution(%lhs.copy, %rhs.copy),
8521     dim_labels=b01f_01io->b01f,feature_group_count=1024,
8522     window={size=5x1 pad=2_2x0_0},
8523     sharding={devices=[1,1,1,2]0,1}
8524 })";
8525 
8526   TF_ASSERT_OK_AND_ASSIGN(auto module,
8527                           PartitionComputation(hlo_string, /*num_devices=*/2));
8528   VLOG(1) << module->ToString();
8529   const auto root = module->entry_computation()->root_instruction();
8530   const auto lhs = AllOf(
8531       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8532                                 op::Constant(), op::Reshape())),
8533       op::Shape("f32[16,801,1,512]"));
8534   const auto rhs = AllOf(
8535       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8536                                 op::Constant(), op::Reshape())),
8537       op::Shape("f32[5,1,1,1024]"));
8538   EXPECT_THAT(
8539       root, AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,1024]")));
8540 }
8541 
TEST_F(SpmdPartitioningTest,PartitionConvWithFeatureGroupCount2)8542 TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount2) {
8543   absl::string_view hlo_string = R"(
8544 HloModule module
8545 
8546 ENTRY entry {
8547   %lhs = f32[64,3,1,3072] parameter(0)
8548   %lhs.copy = f32[64,3,1,3072] copy(%lhs),
8549     sharding={devices=[1,1,1,4,8]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25
8550     ,26,27,28,29,30,31,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
8551   %rhs = f32[3,1,1,3072] parameter(1)
8552   %rhs.copy = f32[3,1,1,3072] copy(%rhs),
8553     sharding={devices=[1,1,1,4,8]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25
8554     ,26,27,28,29,30,31,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
8555   ROOT %conv = f32[64,1,1,3072] convolution(%lhs.copy, %rhs.copy),
8556     dim_labels=b01f_01io->b01f,feature_group_count=3072,
8557     window={size=3x1},
8558     sharding={devices=[8,1,1,4]0,16,24,8,2,18,26,10,4,20,28,12,6,22,30,14,7,23,
8559     31,15,5,21,29,13,3,19,27,11,1,17,25,9}
8560 })";
8561 
8562   TF_ASSERT_OK_AND_ASSIGN(auto module,
8563                           PartitionComputation(hlo_string, /*num_devices=*/32));
8564   VLOG(1) << module->ToString();
8565   const auto root = module->entry_computation()->root_instruction();
8566   const auto lhs =
8567       AllOf(op::DynamicSlice(
8568                 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
8569                                           op::Constant(), op::Constant(),
8570                                           op::Reshape())),
8571                 op::Reshape(), op::Constant(), op::Constant(), op::Constant()),
8572             op::Shape("f32[8,3,1,768]"));
8573   const auto rhs = AllOf(
8574       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8575                                 op::Constant(), op::Reshape())),
8576       op::Shape("f32[3,1,1,768]"));
8577   EXPECT_THAT(root,
8578               AllOf(op::Convolution(lhs, rhs), op::Shape("f32[8,1,1,768]")));
8579 }
8580 
TEST_F(SpmdPartitioningTest,PartitionConvWithFeatureGroupCountRHSAlignWithLHS)8581 TEST_F(SpmdPartitioningTest,
8582        PartitionConvWithFeatureGroupCountRHSAlignWithLHS) {
8583   absl::string_view hlo_string = R"(
8584 HloModule module
8585 
8586 ENTRY entry {
8587   %lhs = f32[16,801,1,1024] parameter(0)
8588   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8589     sharding={devices=[1,1,1,2]0,1}
8590   %rhs = f32[5,1,1,1024] parameter(1)
8591   %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8592     sharding={devices=[2,1,1,1]0,1}
8593   ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8594     dim_labels=b01f_01io->b01f,feature_group_count=1024,
8595     window={size=5x1 pad=2_2x0_0},
8596     sharding={devices=[1,1,1,2]0,1}
8597 })";
8598 
8599   TF_ASSERT_OK_AND_ASSIGN(auto module,
8600                           PartitionComputation(hlo_string, /*num_devices=*/2));
8601   VLOG(1) << module->ToString();
8602   const auto root = module->entry_computation()->root_instruction();
8603   const auto lhs = AllOf(
8604       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8605                                 op::Constant(), op::Reshape())),
8606       op::Shape("f32[16,801,1,512]"));
8607   const auto rhs =
8608       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8609                                       op::Reshape(), op::Constant(),
8610                                       op::Constant(), op::Constant())),
8611             op::Shape("f32[3,1,1,1024]"));
8612   auto resharded_rhs = AllOf(
8613       op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))),
8614       op::Shape("f32[5,1,1,512]"));
8615   EXPECT_THAT(root, AllOf(op::Convolution(lhs, resharded_rhs),
8616                           op::Shape("f32[16,801,1,512]")));
8617 }
8618 
TEST_F(SpmdPartitioningTest,PartitionConvWithFeatureGroupCountLHSAlignWithRHS)8619 TEST_F(SpmdPartitioningTest,
8620        PartitionConvWithFeatureGroupCountLHSAlignWithRHS) {
8621   absl::string_view hlo_string = R"(
8622 HloModule module
8623 
8624 ENTRY entry {
8625   %lhs = f32[16,801,1,1024] parameter(0)
8626   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8627     sharding={devices=[1,2,1,1]0,1}
8628   %rhs = f32[5,1,1,1024] parameter(1)
8629   %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8630     sharding={devices=[1,1,1,2]0,1}
8631   ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8632     dim_labels=b01f_01io->b01f,feature_group_count=1024,
8633     window={size=5x1 pad=2_2x0_0},
8634     sharding={devices=[1,1,1,2]0,1}
8635 })";
8636 
8637   TF_ASSERT_OK_AND_ASSIGN(auto module,
8638                           PartitionComputation(hlo_string, /*num_devices=*/2));
8639   VLOG(1) << module->ToString();
8640   const auto root = module->entry_computation()->root_instruction();
8641   const auto lhs =
8642       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8643                                       op::Constant(), op::Reshape(),
8644                                       op::Constant(), op::Constant())),
8645             op::Shape("f32[16,401,1,1024]"));
8646   const auto rhs = AllOf(
8647       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8648                                 op::Constant(), op::Reshape())),
8649       op::Shape("f32[5,1,1,512]"));
8650   auto resharded_lhs = AllOf(
8651       op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))),
8652       op::Shape("f32[16,801,1,512]"));
8653   EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, rhs),
8654                           op::Shape("f32[16,801,1,512]")));
8655 }
8656 
TEST_F(SpmdPartitioningTest,PartitionConvWithFeatureGroupCountAlignOuputWithLHS)8657 TEST_F(SpmdPartitioningTest,
8658        PartitionConvWithFeatureGroupCountAlignOuputWithLHS) {
8659   absl::string_view hlo_string = R"(
8660 HloModule module
8661 
8662 ENTRY entry {
8663   %lhs = f32[16,801,1,1024] parameter(0)
8664   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8665     sharding={devices=[1,1,1,2]0,1}
8666   %rhs = f32[5,1,1,1024] parameter(1)
8667   %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8668     sharding={devices=[1,1,1,2]0,1}
8669   ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8670     dim_labels=b01f_01io->b01f,feature_group_count=1024,
8671     window={size=5x1 pad=2_2x0_0},
8672     sharding={devices=[2,1,1,1]0,1}
8673 })";
8674 
8675   TF_ASSERT_OK_AND_ASSIGN(auto module,
8676                           PartitionComputation(hlo_string, /*num_devices=*/2));
8677   VLOG(1) << module->ToString();
8678   const auto root = module->entry_computation()->root_instruction();
8679   const auto lhs = AllOf(
8680       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8681                                 op::Constant(), op::Reshape())),
8682       op::Shape("f32[16,801,1,512]"));
8683   const auto rhs = AllOf(
8684       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8685                                 op::Constant(), op::Reshape())),
8686       op::Shape("f32[5,1,1,512]"));
8687   auto conv = AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]"));
8688   EXPECT_THAT(root,
8689               AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(conv)))),
8690                     op::Shape("f32[8,801,1,1024]")));
8691 }
8692 
TEST_F(SpmdPartitioningTest,PartitionConvGroupOnFeatureGroupCount_RHSPartialReplicate)8693 TEST_F(SpmdPartitioningTest,
8694        PartitionConvGroupOnFeatureGroupCount_RHSPartialReplicate) {
8695   absl::string_view hlo_string = R"(
8696 HloModule module
8697 
8698 ENTRY entry {
8699   %lhs = f32[16,801,1,1024] parameter(0)
8700   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8701     sharding={devices=[1,2,1,2]0,1,2,3}
8702   %rhs = f32[5,1,1,1024] parameter(1)
8703   %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8704     sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}
8705   ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8706     dim_labels=b01f_01io->b01f,feature_group_count=1024,
8707     window={size=5x1 pad=2_2x0_0},
8708     sharding={devices=[1,2,1,2]0,1,2,3}
8709 })";
8710 
8711   TF_ASSERT_OK_AND_ASSIGN(auto module,
8712                           PartitionComputation(hlo_string, /*num_devices=*/4));
8713   VLOG(1) << module->ToString();
8714   const auto root = module->entry_computation()->root_instruction();
8715   const auto lhs =
8716       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8717                                       op::Constant(), op::Reshape(),
8718                                       op::Constant(), op::Reshape())),
8719             op::Shape("f32[16,401,1,512]"));
8720   auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8721                          op::CollectivePermute(op::Slice(lhs)));
8722   auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8723                           op::CollectivePermute(op::Slice(lhs)));
8724   const auto rhs = AllOf(
8725       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8726                                 op::Constant(), op::Reshape())),
8727       op::Shape("f32[5,1,1,512]"));
8728   EXPECT_THAT(
8729       root,
8730       AllOf(op::Convolution(
8731                 op::Select(_, op::Concatenate(left_halo, lhs, right_halo), _),
8732                 rhs),
8733             op::Shape("f32[16, 401, 1, 512]")));
8734 }
8735 
TEST_F(SpmdPartitioningTest,PartitionConvGroupOnFeatureGroupCount_RHSAlignWithOutput)8736 TEST_F(SpmdPartitioningTest,
8737        PartitionConvGroupOnFeatureGroupCount_RHSAlignWithOutput) {
8738   absl::string_view hlo_string = R"(
8739 HloModule module
8740 
8741 ENTRY entry {
8742   %lhs = f32[16,801,1,1024] parameter(0)
8743   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8744     sharding={devices=[1,2,1,2]0,1,2,3}
8745   %rhs = f32[5,1,1,1024] parameter(1), sharding={replicated}
8746   ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs),
8747     dim_labels=b01f_01io->b01f,feature_group_count=1024,
8748     window={size=5x1 pad=2_2x0_0},
8749     sharding={devices=[1,2,1,2]0,1,2,3}
8750 })";
8751   TF_ASSERT_OK_AND_ASSIGN(auto module,
8752                           PartitionComputation(hlo_string, /*num_devices=*/4));
8753   VLOG(1) << module->ToString();
8754   const auto root = module->entry_computation()->root_instruction();
8755   const auto lhs =
8756       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8757                                       op::Constant(), op::Reshape(),
8758                                       op::Constant(), op::Reshape())),
8759             op::Shape("f32[16,401,1,512]"));
8760   auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8761                          op::CollectivePermute(op::Slice(lhs)));
8762   auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8763                           op::CollectivePermute(op::Slice(lhs)));
8764   const auto rhs =
8765       AllOf(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8766                              op::Constant(), op::Reshape()),
8767             op::Shape("f32[5,1,1,512]"));
8768   EXPECT_THAT(
8769       root,
8770       AllOf(op::Convolution(
8771                 op::Select(_, op::Concatenate(left_halo, lhs, right_halo), _),
8772                 rhs),
8773             op::Shape("f32[16, 401, 1, 512]")));
8774 }
8775 
TEST_F(SpmdPartitioningTest,PartitionConvGroupOnFeatureGroupCount_LHSAlignWithOutput)8776 TEST_F(SpmdPartitioningTest,
8777        PartitionConvGroupOnFeatureGroupCount_LHSAlignWithOutput) {
8778   absl::string_view hlo_string = R"(
8779 HloModule module
8780 
8781 ENTRY entry {
8782   %lhs = f32[16,801,1,1024] parameter(0)
8783   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8784     sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate}
8785   %rhs = f32[5,1,1,1024] parameter(1)
8786   %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8787     sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}
8788   ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8789     dim_labels=b01f_01io->b01f,feature_group_count=1024,
8790     window={size=5x1 pad=2_2x0_0},
8791     sharding={devices=[1,2,1,2]0,1,2,3}
8792 })";
8793   TF_ASSERT_OK_AND_ASSIGN(auto module,
8794                           PartitionComputation(hlo_string, /*num_devices=*/4));
8795   VLOG(1) << module->ToString();
8796   const auto root = module->entry_computation()->root_instruction();
8797   const auto lhs = AllOf(
8798       op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
8799                                 op::Constant(), op::Constant())),
8800       op::Shape("f32[8,801,1,1024]"));
8801   auto resharded_lhs =
8802       AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(
8803                 op::Pad(op::DynamicSlice(lhs, op::Subtract(), op::Subtract(),
8804                                          op::Subtract(), op::Subtract()),
8805                         op::Constant()))))),
8806             op::Shape("f32[16,401,1,512]"));
8807   auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8808                          op::CollectivePermute(op::Slice(resharded_lhs)));
8809   auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8810                           op::CollectivePermute(op::Slice(resharded_lhs)));
8811   const auto rhs = AllOf(
8812       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8813                                 op::Constant(), op::Reshape())),
8814       op::Shape("f32[5,1,1,512]"));
8815   EXPECT_THAT(
8816       root,
8817       AllOf(
8818           op::Convolution(
8819               op::Select(
8820                   _, op::Concatenate(left_halo, resharded_lhs, right_halo), _),
8821               rhs),
8822           op::Shape("f32[16, 401, 1, 512]")));
8823 }
8824 
TEST_F(SpmdPartitioningTest,PartitionConvGroupOnBatchGroupCount)8825 TEST_F(SpmdPartitioningTest, PartitionConvGroupOnBatchGroupCount) {
8826   absl::string_view hlo_string = R"(
8827 HloModule module
8828 
8829 ENTRY entry {
8830   %lhs = f32[16,801,1,1024] parameter(0)
8831   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8832     sharding={devices=[1,2,1,2]0,1,2,3}
8833   %rhs = f32[16,801,1,1024] parameter(1)
8834   %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8835     sharding={devices=[1,2,1,2]0,1,2,3}
8836   ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8837     dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8838     window={size=801x1 pad=2_2x0_0},
8839     sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
8840 })";
8841 
8842   TF_ASSERT_OK_AND_ASSIGN(auto module,
8843                           PartitionComputation(hlo_string, /*num_devices=*/4));
8844   VLOG(1) << module->ToString();
8845   const auto root = module->entry_computation()->root_instruction();
8846   const auto lhs = AllOf(
8847       op::Select(_,
8848                  op::Copy(op::DynamicSlice(
8849                      op::Pad(op::Parameter(), op::Constant()), op::Constant(),
8850                      op::Reshape(), op::Constant(), op::Reshape())),
8851                  _),
8852       op::Shape("f32[16,401,1,512]"));
8853   auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8854                          op::CollectivePermute(op::Slice(lhs)));
8855   auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8856                           op::CollectivePermute(op::Slice(lhs)));
8857   const auto rhs =
8858       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8859                                       op::Constant(), op::Reshape(),
8860                                       op::Constant(), op::Reshape())),
8861             op::Shape("f32[16,401,1,512]"));
8862   auto conv = AllOf(op::Convolution(op::Concatenate(left_halo, lhs, right_halo),
8863                                     op::Select(_, rhs, _)),
8864                     op::Shape("f32[5,1,1,512]"));
8865   EXPECT_THAT(root, AllOf(op::CollectivePermute(op::AllReduce(conv)),
8866                           op::Shape("f32[5,1,1,512]")));
8867 }
8868 
TEST_F(SpmdPartitioningTest,PartitionConvWithFeatureGroupCountAlignOuputWithRHS)8869 TEST_F(SpmdPartitioningTest,
8870        PartitionConvWithFeatureGroupCountAlignOuputWithRHS) {
8871   absl::string_view hlo_string = R"(
8872 HloModule module
8873 
8874 ENTRY entry {
8875   %lhs = f32[16,801,1,1024] parameter(0)
8876   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8877     sharding={devices=[1,2,1,1]0,1}
8878   %rhs = f32[5,1,1,1024] parameter(1)
8879   %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8880     sharding={devices=[1,1,1,2]0,1}
8881   ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8882     dim_labels=b01f_01io->b01f,feature_group_count=1024,
8883     window={size=5x1 pad=2_2x0_0},
8884     sharding={devices=[2,1,1,1]0,1}
8885 })";
8886 
8887   TF_ASSERT_OK_AND_ASSIGN(auto module,
8888                           PartitionComputation(hlo_string, /*num_devices=*/2));
8889   VLOG(1) << module->ToString();
8890   const auto root = module->entry_computation()->root_instruction();
8891   const auto lhs =
8892       AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8893                                       op::Constant(), op::Reshape(),
8894                                       op::Constant(), op::Constant())),
8895             op::Shape("f32[16,401,1,1024]"));
8896   const auto rhs = AllOf(
8897       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8898                                 op::Constant(), op::Reshape())),
8899       op::Shape("f32[5,1,1,512]"));
8900   auto resharded_lhs = AllOf(
8901       op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))),
8902       op::Shape("f32[16,801,1,512]"));
8903   auto conv = AllOf(op::Convolution(resharded_lhs, rhs),
8904                     op::Shape("f32[16,801,1,512]"));
8905   EXPECT_THAT(root,
8906               AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(conv)))),
8907                     op::Shape("f32[8,801,1,1024]")));
8908 }
8909 
TEST_F(SpmdPartitioningTest,PartitionConvWithFeatureGroupCountBackProp)8910 TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCountBackProp) {
8911   absl::string_view hlo_string = R"(
8912 HloModule module
8913 
8914 ENTRY entry {
8915   %lhs = f32[16,801,1,1024] parameter(0)
8916   %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8917     sharding={devices=[1,1,1,2]0,1}
8918   %rhs = f32[5,1,1024,1] parameter(1)
8919   %rhs.copy = f32[5,1,1024,1] copy(%rhs),
8920     sharding={devices=[1,1,2,1]0,1}
8921   ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8922     dim_labels=b01f_01oi->b01f,feature_group_count=1024,
8923     window={size=5x1 pad=2_2x0_0 rhs_reversal=1x1},
8924     sharding={devices=[1,1,1,2]0,1}
8925 })";
8926 
8927   TF_ASSERT_OK_AND_ASSIGN(auto module,
8928                           PartitionComputation(hlo_string, /*num_devices=*/2));
8929   VLOG(1) << module->ToString();
8930   const auto root = module->entry_computation()->root_instruction();
8931   const auto lhs = AllOf(
8932       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8933                                 op::Constant(), op::Reshape())),
8934       op::Shape("f32[16,801,1,512]"));
8935   const auto rhs = AllOf(
8936       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8937                                 op::Reshape(), op::Constant())),
8938       op::Shape("f32[5,1,512,1]"));
8939   EXPECT_THAT(root,
8940               AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]")));
8941 }
8942 
TEST_F(SpmdPartitioningTest,NoReshardOnBroadcastDims)8943 TEST_F(SpmdPartitioningTest, NoReshardOnBroadcastDims) {
8944   absl::string_view hlo_string = R"(
8945 HloModule module
8946 
8947 ENTRY entry {
8948   %param0 = f32[2,3] parameter(0)
8949   %param1 = f32[2,3,20] parameter(1)
8950   %br0 = f32[20,2,20,3,20] broadcast(%param0), dimensions={1,3}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
8951   %br1 = f32[20,2,20,3,20] broadcast(%param1), dimensions={1,3,4}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
8952   %add = f32[20,2,20,3,20] add(%br0, %br1), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
8953   %reshape = f32[10,4,10,6,20] reshape(%br0), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
8954   %transpose = f32[2,3,20,20,20] transpose(%br0), dimensions={1,3,0,2,4}, sharding={devices=[1,1,2,2,2]0,1,2,3,4,5,6,7}
8955   %copy_add0 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]6,7,2,3,4,5,0,1}
8956   %copy_add1 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1}
8957   %copy_reshape = f32[10,4,10,6,20] copy(%reshape), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1}
8958   %copy_transpose = f32[2,3,20,20,20] copy(%transpose), sharding={devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}
8959   ROOT %tuple = (f32[20,2,20,3,20], f32[20,2,20,3,20], f32[10,4,10,6,20], f32[2,3,20,20,20])
8960     tuple(%copy_add0, %copy_add1, %copy_reshape, %copy_transpose),
8961     sharding={{devices=[2,1,2,1,2]6,7,2,3,4,5,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}}
8962 })";
8963 
8964   TF_ASSERT_OK_AND_ASSIGN(auto module,
8965                           PartitionComputation(hlo_string, /*num_devices=*/8));
8966   VLOG(1) << module->ToString();
8967   const auto root = module->entry_computation()->root_instruction();
8968   // Reshard on copy_add0 only happens on broadcast dims, can be skipped.
8969   auto copy_add0 =
8970       op::Copy(op::Copy(op::Add(op::Broadcast(_), op::Broadcast(_))));
8971   // Reshard on copy_add1 also happens on non-broadcast dims.
8972   auto copy_add1 = op::Copy(
8973       op::CollectivePermute(op::Add(op::Broadcast(_), op::Broadcast(_))));
8974   // Reshard on copy_reshape only happens on broadcast dims, can be skipped.
8975   auto copy_reshape = op::Copy(op::Copy(op::Reshape(op::Broadcast(_))));
8976   // Reshard on copy_transpose only happens on broadcast dims, can be skipped.
8977   auto copy_transpose = op::Copy(op::Copy(op::Transpose(op::Broadcast(_))));
8978   EXPECT_THAT(root,
8979               op::Tuple(copy_add0, copy_add1, copy_reshape, copy_transpose));
8980 }
8981 
TEST_F(SpmdPartitioningTest,ConvolutionFilterIFOFPartitionedInputPartialReplicate)8982 TEST_F(SpmdPartitioningTest,
8983        ConvolutionFilterIFOFPartitionedInputPartialReplicate) {
8984   absl::string_view hlo_string = R"(
8985 HloModule module
8986 
8987 ENTRY entry {
8988   %lhs = f32[128,112,112,12] parameter(0)
8989   %lhs.copy = f32[128,112,112,12] copy(f32[128,112,112,12] %lhs),
8990     sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
8991   %rhs = f32[7,7,12,64] parameter(1)
8992   %rhs.copy = f32[7,7,12,64] copy(f32[7,7,12,64] %rhs),
8993     sharding={devices=[1,1,2,2]0,1,2,3}
8994   ROOT %conv = f32[128,56,56,64] convolution(
8995     f32[128,112,112,12] %lhs.copy,
8996     f32[7,7,12,64] %rhs.copy),
8997     window={size=7x7 stride=2x2 pad=3_3x3_3},
8998     dim_labels=b01f_01io->b01f,
8999     sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
9000 })";
9001 
9002   TF_ASSERT_OK_AND_ASSIGN(auto module,
9003                           PartitionComputation(hlo_string, /*num_devices=*/4));
9004   VLOG(1) << module->ToString();
9005   const auto root = module->entry_computation()->root_instruction();
9006   const auto lhs = AllOf(
9007       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9008                                 op::Constant(), op::Reshape())),
9009       op::Shape("f32[128,112,112,6]"));
9010   const auto rhs = AllOf(
9011       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9012                                 op::Reshape(), op::Reshape())),
9013       op::Shape("f32[7,7,6,32]"));
9014 
9015   EXPECT_THAT(
9016       root,
9017       AllOf(op::CollectivePermute(op::AllReduce(op::Convolution(lhs, rhs))),
9018             op::Shape("f32[128,56,56,32]")));
9019 }
9020 
TEST_F(SpmdPartitioningTest,ConvolutionInputKernelNonContractingDimPartialReplicate)9021 TEST_F(SpmdPartitioningTest,
9022        ConvolutionInputKernelNonContractingDimPartialReplicate) {
9023   absl::string_view hlo_string = R"(
9024 HloModule module
9025 
9026 ENTRY entry {
9027   %lhs = f32[128,56,56,256] parameter(0)
9028   %lhs.copy = f32[128,56,56,256] copy(%lhs),
9029   sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
9030   %rhs = f32[128,28,28,512] parameter(1)
9031   %rhs.copy = f32[128,28,28,512] copy(%rhs),
9032   sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
9033   ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
9034     window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf,
9035     sharding={devices=[1,1,2,2]0,1,2,3}
9036 })";
9037 
9038   TF_ASSERT_OK_AND_ASSIGN(auto module,
9039                           PartitionComputation(hlo_string, /*num_devices=*/4));
9040   VLOG(1) << module->ToString();
9041   const auto root = module->entry_computation()->root_instruction();
9042   const auto lhs = AllOf(
9043       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9044                                 op::Constant(), op::Reshape())),
9045       op::Shape("f32[128,56,56,128]"));
9046   const auto rhs = AllOf(
9047       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9048                                 op::Constant(), op::Reshape())),
9049       op::Shape("f32[128,28,28,256]"));
9050 
9051   EXPECT_THAT(root, AllOf(op::Convolution(lhs, op::CollectivePermute(rhs)),
9052                           op::Shape("f32[1,1,128,256]")));
9053 }
9054 
TEST_F(SpmdPartitioningTest,ConvolutionInputSpatialDimAndFeatureDimParttiioned)9055 TEST_F(SpmdPartitioningTest,
9056        ConvolutionInputSpatialDimAndFeatureDimParttiioned) {
9057   absl::string_view hlo_string = R"(
9058 HloModule module
9059 
9060 ENTRY entry {
9061   %lhs = f32[8,210,210,12] parameter(0)
9062   %lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs),
9063     sharding={devices=[1,2,1,2]0,1,2,3}
9064   %rhs = f32[3,3,12,32] parameter(1)
9065   %rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs),
9066     sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
9067   ROOT %conv = f32[8,210,210,32] convolution(
9068     f32[8,210,210,12] %lhs.copy,
9069     f32[3,3,12,32] %rhs.copy),
9070     window={size=3x3 pad=1_1x1_1},
9071     dim_labels=b01f_01io->b01f,
9072     sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
9073 })";
9074   TF_ASSERT_OK_AND_ASSIGN(auto module,
9075                           PartitionComputation(hlo_string, /*num_devices=*/4));
9076   VLOG(1) << module->ToString();
9077   const auto root = module->entry_computation()->root_instruction();
9078   const auto lhs = AllOf(
9079       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
9080                                 op::Constant(), op::Reshape())),
9081       op::Shape("f32[8,105,210,6]"));
9082   auto left_halo =
9083       AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
9084   auto right_halo =
9085       AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
9086   auto exchanged_lhs = AllOf(
9087       op::Select(op::And(_, _), op::Concatenate(left_halo, lhs, right_halo),
9088                  op::Broadcast(_)),
9089       op::Shape("f32[8,107,210,6]"));
9090   const auto rhs = AllOf(
9091       op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9092                                 op::Reshape(), op::Constant())),
9093       op::Shape("f32[3,3,6,32]"));
9094   EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
9095                               exchanged_lhs, op::CollectivePermute(rhs))),
9096                           op::Shape("f32[8,105,210,32]")));
9097 }
9098 
TEST_F(SpmdPartitioningTest,Fft3D)9099 TEST_F(SpmdPartitioningTest, Fft3D) {
9100   absl::string_view hlo_string = R"(
9101 HloModule module
9102 
9103 ENTRY entry {
9104   constant = c64[1,1,6]
9105     constant({{{(0,0),(1,1),(2,2),(3,3),(4,4),(5,5)}}}),
9106     sharding={devices=[1,1,2]0,1}
9107   ROOT fft = c64[1,1,6] fft(c64[1,1,6] constant), fft_type=FFT, fft_length={6},
9108     sharding={devices=[1,1,2]0,1}
9109 }
9110 )";
9111 
9112   TF_ASSERT_OK_AND_ASSIGN(auto module,
9113                           PartitionComputation(hlo_string, /*num_devices=*/2));
9114   VLOG(1) << module->ToString();
9115   const auto root = module->entry_computation()->root_instruction();
9116   auto input = AllOf(op::DynamicSlice(op::Constant(), op::Constant(),
9117                                       op::Constant(), op::Reshape()),
9118                      op::Shape("c64[1,1,3]"));
9119   auto padded_input =
9120       AllOf(op::DynamicSlice(
9121                 op::Concatenate(input, op::CollectivePermute(op::Slice())),
9122                 op::Constant(), op::Constant(), op::Reshape()),
9123             op::Shape("c64[1,1,4]"));
9124 
9125   auto shuffled_input =
9126       AllOf(op::Slice(op::AllToAll(op::Dot(padded_input, op::Convert()))),
9127             op::Shape("c64[1,1,3]"));
9128 
9129   auto local_fft = AllOf(op::Fft(shuffled_input), op::Shape("c64[1,1,3]"));
9130 
9131   EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(op::Tuple(
9132                               _, op::Multiply(local_fft, op::Exp()), _, _, _))),
9133                           op::Shape("c64[1,1,3]")));
9134 }
9135 
TEST_F(SpmdPartitioningTest,DotInputsAreIdentical)9136 TEST_F(SpmdPartitioningTest, DotInputsAreIdentical) {
9137   absl::string_view hlo_string = R"(
9138 HloModule module
9139 
9140 ENTRY entry {
9141   %parameter.1 = f32[4000,4000]{1,0} parameter(0),
9142     sharding={devices=[2,4]0,1,2,3,4,5,6,7}
9143   ROOT %convolution = f32[4000,4000]{1,0} convolution(
9144     f32[4000,4000]{1,0} %parameter.1, f32[4000,4000]{1,0} %parameter.1),
9145     dim_labels=bf_io->bf, sharding={devices=[2,4]0,1,2,3,4,5,6,7}
9146 }
9147 
9148 )";
9149 
9150   TF_ASSERT_OK_AND_ASSIGN(auto module,
9151                           PartitionComputation(hlo_string, /*num_devices=*/8));
9152   VLOG(1) << module->ToString();
9153   const auto root = module->entry_computation()->root_instruction();
9154   auto param = AllOf(op::Parameter(), op::Shape("f32[2000, 1000]"));
9155   auto resharded_lhs =
9156       AllOf(op::AllReduce(op::DynamicUpdateSlice(_, param, _, _)),
9157             op::Shape("f32[2000, 4000]"));
9158   auto resharded_rhs =
9159       AllOf(op::AllReduce(op::DynamicUpdateSlice(_, op::Copy(param), _, _)),
9160             op::Shape("f32[4000, 1000]"));
9161   EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, resharded_rhs),
9162                           op::Shape("f32[2000, 1000]")));
9163 }
9164 
TEST_F(SpmdPartitioningTest,ConstantSliceReshard)9165 TEST_F(SpmdPartitioningTest, ConstantSliceReshard) {
9166   absl::string_view hlo_string = R"(
9167 HloModule module
9168 
9169 ENTRY entry {
9170   %constant.785 = f32[1,8] constant({{0,1,2,3,4,5,6,7}}),
9171     sharding={devices=[1,8]0,1,2,3,4,5,6,7}
9172   %slice.62 = f32[1,1] slice(%constant.785), slice={[0:1], [0:1]},
9173     sharding={devices=[1,8]0,1,2,3,4,5,6,7}
9174   ROOT %reshape.779 = f32[] reshape(%slice.62), sharding={replicated}
9175 })";
9176   TF_ASSERT_OK_AND_ASSIGN(auto module,
9177                           PartitionComputation(hlo_string, /*num_devices=*/8));
9178   const auto root = module->entry_computation()->root_instruction();
9179   VLOG(1) << module->ToString();
9180   auto slice = AllOf(op::Shape("f32[1,1]"),
9181                      op::Copy(op::DynamicSlice(op::Constant(), _, _)));
9182   EXPECT_THAT(root, op::Reshape(op::AllReduce(op::Select(_, slice, _))));
9183 }
9184 
TEST_F(SpmdPartitioningTest,GatherParallelDimRedistributionOperand)9185 TEST_F(SpmdPartitioningTest, GatherParallelDimRedistributionOperand) {
9186   absl::string_view hlo_string = R"(
9187 HloModule module
9188 
9189 ENTRY %module {
9190   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9191     sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
9192   %constant = s32[4] constant({0, 1, 2, 3}), sharding={replicated}
9193   %iota = s32[1,8,4]{2,1,0} broadcast(%constant), dimensions={2},
9194     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9195   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9196     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9197   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9198     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9199     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9200   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9201     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9202     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9203     collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9204     slice_sizes={1,1,2,2}, sharding={replicated}
9205 })";
9206   TF_ASSERT_OK_AND_ASSIGN(auto module,
9207                           PartitionComputation(hlo_string, /*num_devices=*/8));
9208   const auto root = module->entry_computation()->root_instruction();
9209   VLOG(1) << module->ToString();
9210   auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::Reshape());
9211   auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9212   auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9213   EXPECT_THAT(root,
9214               op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9215 }
9216 
TEST_F(SpmdPartitioningTest,GatherParallelDimRedistributionIndices)9217 TEST_F(SpmdPartitioningTest, GatherParallelDimRedistributionIndices) {
9218   absl::string_view hlo_string = R"(
9219 HloModule module
9220 
9221 ENTRY %module {
9222   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9223     sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
9224   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9225     sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
9226   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9227     sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
9228   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9229     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9230     sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
9231   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(s32[8,4,2,2]{3,2,1,0} %parameter.0,
9232     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9233     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9234     slice_sizes={1,1,2,2}, sharding={replicated}
9235 })";
9236   TF_ASSERT_OK_AND_ASSIGN(auto module,
9237                           PartitionComputation(hlo_string, /*num_devices=*/8));
9238   const auto root = module->entry_computation()->root_instruction();
9239   VLOG(1) << module->ToString();
9240   auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::DynamicSlice());
9241   auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
9242   auto gather = AllOf(op::Shape("s32[2,2,2,2]"), op::Gather(operand, indices));
9243   EXPECT_THAT(root, op::AllReduce(op::AllReduce(
9244                         op::DynamicUpdateSlice(_, gather, _, _, _, _))));
9245 }
9246 
TEST_F(SpmdPartitioningTest,GatherParallelDimReplicatedIndices)9247 TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedIndices) {
9248   absl::string_view hlo_string = R"(
9249 HloModule module
9250 
9251 ENTRY %module {
9252   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9253     sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
9254   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9255     sharding={replicated}
9256   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9257     sharding={replicated}
9258   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9259     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9260     sharding={replicated}
9261   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9262     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9263     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9264     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9265     slice_sizes={1,1,2,2}, sharding={replicated}
9266 })";
9267   TF_ASSERT_OK_AND_ASSIGN(auto module,
9268                           PartitionComputation(hlo_string, /*num_devices=*/8));
9269   const auto root = module->entry_computation()->root_instruction();
9270   VLOG(1) << module->ToString();
9271   auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::Parameter());
9272   auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9273   auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9274   EXPECT_THAT(root,
9275               op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9276 }
9277 
TEST_F(SpmdPartitioningTest,GatherParallelDimReplicatedOperand)9278 TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedOperand) {
9279   absl::string_view hlo_string = R"(
9280 HloModule module
9281 
9282 ENTRY %module {
9283   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={replicated}
9284   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9285     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9286   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9287     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9288   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9289     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9290     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9291   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9292     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9293     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9294     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9295     slice_sizes={1,1,2,2}, sharding={replicated}
9296 })";
9297   TF_ASSERT_OK_AND_ASSIGN(auto module,
9298                           PartitionComputation(hlo_string, /*num_devices=*/8));
9299   const auto root = module->entry_computation()->root_instruction();
9300   VLOG(1) << module->ToString();
9301   auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
9302   auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9303   auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9304   EXPECT_THAT(root,
9305               op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9306 }
9307 
TEST_F(SpmdPartitioningTest,GatherParallelDimPartialReplicatedIndices)9308 TEST_F(SpmdPartitioningTest, GatherParallelDimPartialReplicatedIndices) {
9309   absl::string_view hlo_string = R"(
9310 HloModule module
9311 
9312 ENTRY %module {
9313   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9314     sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
9315   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9316     sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9317   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9318     sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9319   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9320     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9321     sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9322   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9323     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9324     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9325     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9326     slice_sizes={1,1,2,2}, sharding={replicated}
9327 })";
9328   TF_ASSERT_OK_AND_ASSIGN(auto module,
9329                           PartitionComputation(hlo_string, /*num_devices=*/8));
9330   const auto root = module->entry_computation()->root_instruction();
9331   VLOG(1) << module->ToString();
9332   auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::Parameter());
9333   auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9334   auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9335   EXPECT_THAT(root,
9336               op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9337 }
9338 
TEST_F(SpmdPartitioningTest,GatherParallelDimPartialReplicatedOperand)9339 TEST_F(SpmdPartitioningTest, GatherParallelDimPartialReplicatedOperand) {
9340   absl::string_view hlo_string = R"(
9341 HloModule module
9342 
9343 ENTRY %module {
9344   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={
9345     devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9346   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9347     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9348   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9349     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9350   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9351     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9352     sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9353   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9354     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9355     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9356     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9357     slice_sizes={1,1,2,2}, sharding={replicated}
9358 })";
9359   TF_ASSERT_OK_AND_ASSIGN(auto module,
9360                           PartitionComputation(hlo_string, /*num_devices=*/8));
9361   const auto root = module->entry_computation()->root_instruction();
9362   VLOG(1) << module->ToString();
9363   auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
9364   auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9365   auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9366   EXPECT_THAT(root,
9367               op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9368 }
9369 
TEST_F(SpmdPartitioningTest,GatherParallelDimSwappedDimensions)9370 TEST_F(SpmdPartitioningTest, GatherParallelDimSwappedDimensions) {
9371   absl::string_view hlo_string = R"(
9372 HloModule module
9373 
9374 ENTRY %module {
9375   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={
9376     devices=[4,2,1,1]0,1,2,3,4,5,6,7}
9377   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9378     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
9379   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9380     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
9381   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9382     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9383     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
9384   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9385     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9386     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9387     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9388     slice_sizes={1,1,2,2}, sharding={replicated}
9389 })";
9390   TF_ASSERT_OK_AND_ASSIGN(auto module,
9391                           PartitionComputation(hlo_string, /*num_devices=*/8));
9392   const auto root = module->entry_computation()->root_instruction();
9393   VLOG(1) << module->ToString();
9394   auto operand = AllOf(op::Shape("s32[4,1,2,2]"), op::CollectivePermute());
9395   auto indices = AllOf(op::Shape("s32[2,4,1]"), op::Subtract());
9396   auto gather = AllOf(op::Shape("s32[4,1,2,2]"), op::Gather(operand, indices));
9397   EXPECT_THAT(root, op::AllReduce(op::AllReduce(
9398                         op::DynamicUpdateSlice(_, gather, _, _, _, _))));
9399 }
9400 
TEST_F(SpmdPartitioningTest,GatherParallelDimAndNonParallelDimPartitioned)9401 TEST_F(SpmdPartitioningTest, GatherParallelDimAndNonParallelDimPartitioned) {
9402   absl::string_view hlo_string = R"(
9403 HloModule module
9404 
9405 ENTRY %module {
9406   %arg.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
9407   %arg.1 = s32[1,8,4]{2,1,0} parameter(1)
9408   %operand = s32[8,4,2,2]{3,2,1,0} copy(s32[8,4,2,2]{3,2,1,0} %arg.0),
9409     sharding={devices=[2,2,1,1]0,1,2,3}
9410   %indices = s32[1,8,4]{2,1,0} copy(s32[1,8,4]{2,1,0} %arg.1),
9411     sharding={devices=[1,2,2]0,1,2,3}
9412   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9413     sharding={devices=[1,2,2]0,1,2,3}
9414   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9415     s32[1,8,4]{2,1,0} %indices), dimensions={0},
9416     sharding={devices=[1,2,2]0,1,2,3}
9417   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9418     s32[8,4,2,2]{3,2,1,0} %operand,
9419     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9420     collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9421     slice_sizes={1,1,2,2}, sharding={replicated}
9422 })";
9423 
9424   TF_ASSERT_OK_AND_ASSIGN(auto module,
9425                           PartitionComputation(hlo_string, /*num_devices=*/4));
9426   const auto root = module->entry_computation()->root_instruction();
9427   VLOG(1) << module->ToString();
9428   auto operand = AllOf(op::Shape("s32[4,4,2,2]"), op::AllReduce());
9429   auto indices = AllOf(op::Shape("s32[2,4,2]"), op::Subtract());
9430   auto gather = AllOf(op::Shape("s32[4,2,2,2]"), op::Gather(operand, indices));
9431   EXPECT_THAT(
9432       root, op::AllReduce(op::DynamicUpdateSlice(
9433                 _, op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)),
9434                 _, _, _, _)));
9435 }
9436 
TEST_F(SpmdPartitioningTest,GatherMergedParalleIndexPassthrough)9437 TEST_F(SpmdPartitioningTest, GatherMergedParalleIndexPassthrough) {
9438   absl::string_view hlo_string = R"(
9439 HloModule module
9440 
9441 ENTRY %module {
9442   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9443     sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
9444   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9445     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9446   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9447     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9448   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9449     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9450     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9451   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9452     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9453     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9454     collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9455     slice_sizes={1,1,2,2}, sharding={replicated}
9456 })";
9457   TF_ASSERT_OK_AND_ASSIGN(auto module,
9458                           PartitionComputation(hlo_string, /*num_devices=*/8));
9459   VLOG(1) << module->ToString();
9460   const auto root = module->entry_computation()->root_instruction();
9461   auto operand = AllOf(op::Shape("s32[2,4,1,2]"), op::Reshape());
9462   auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
9463   auto gather = AllOf(op::Shape("s32[2,4,1,2]"), op::Gather(operand, indices));
9464   EXPECT_THAT(root, op::AllReduce(op::AllReduce(
9465                         op::DynamicUpdateSlice(_, gather, _, _, _, _))));
9466 }
9467 
TEST_F(SpmdPartitioningTest,GatherParalleIndexAndOperand)9468 TEST_F(SpmdPartitioningTest, GatherParalleIndexAndOperand) {
9469   absl::string_view hlo_string = R"(
9470 HloModule module
9471 
9472 ENTRY %module {
9473   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9474     sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
9475   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9476     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9477   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9478     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9479   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9480     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9481     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9482   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9483     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9484     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9485     collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9486     slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
9487 })";
9488   TF_ASSERT_OK_AND_ASSIGN(auto module,
9489                           PartitionComputation(hlo_string, /*num_devices=*/8));
9490   VLOG(1) << module->ToString();
9491   const auto root = module->entry_computation()->root_instruction();
9492   auto operand = AllOf(op::Shape("s32[2,4,1,2]"), op::Parameter(0));
9493   auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
9494   auto gather = AllOf(op::Shape("s32[2,4,1,2]"), op::Gather(operand, indices));
9495   EXPECT_THAT(root, gather);
9496 }
9497 
TEST_F(SpmdPartitioningTest,GatherReshardParalleIndexAndOperand)9498 TEST_F(SpmdPartitioningTest, GatherReshardParalleIndexAndOperand) {
9499   absl::string_view hlo_string = R"(
9500 HloModule module
9501 
9502 ENTRY %module {
9503   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9504     sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
9505   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9506     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9507   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9508     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9509   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9510     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9511     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9512   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9513     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9514     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9515     collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9516     slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]1,0,3,2,4,5,6,7}
9517 })";
9518   TF_ASSERT_OK_AND_ASSIGN(auto module,
9519                           PartitionComputation(hlo_string, /*num_devices=*/8));
9520   VLOG(1) << module->ToString();
9521   const auto root = module->entry_computation()->root_instruction();
9522   auto operand = AllOf(op::Shape("s32[2,4,1,2]"), op::Parameter(0));
9523   auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
9524   auto gather = AllOf(op::Shape("s32[2,4,1,2]"), op::Gather(operand, indices));
9525   EXPECT_THAT(root, op::CollectivePermute(gather));
9526 }
9527 
TEST_F(SpmdPartitioningTest,GatherParalleIndexAndOperandReshard)9528 TEST_F(SpmdPartitioningTest, GatherParalleIndexAndOperandReshard) {
9529   absl::string_view hlo_string = R"(
9530 HloModule module
9531 
9532 ENTRY %module {
9533   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9534     sharding={devices=[4,1,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9535   %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9536     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9537   %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9538     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9539   %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9540     s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9541     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9542   ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9543     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9544     s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9545     collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9546     slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
9547 })";
9548   TF_ASSERT_OK_AND_ASSIGN(auto module,
9549                           PartitionComputation(hlo_string, /*num_devices=*/8));
9550   VLOG(1) << module->ToString();
9551   const auto root = module->entry_computation()->root_instruction();
9552   auto operand = AllOf(op::Shape("s32[2,4,2,2]"), op::Parameter(0));
9553   auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
9554   auto gather = AllOf(op::Shape("s32[2,4,2,2]"), op::Gather(operand, indices));
9555   EXPECT_THAT(root, op::DynamicSlice(gather, _, _, _, _));
9556 }
9557 
TEST_F(SpmdPartitioningTest,GatherMergedParallelIndexTrivialSlice)9558 TEST_F(SpmdPartitioningTest, GatherMergedParallelIndexTrivialSlice) {
9559   absl::string_view hlo_string = R"(
9560 HloModule module
9561 
9562 ENTRY %module {
9563   %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9564     sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
9565   %parameter.1 = s32[1,8,1]{2,1,0} parameter(1),
9566     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9567   %iota = s32[1,8,1]{2,1,0} iota(), iota_dimension=1,
9568     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9569   %concatenate.19 = s32[2,8,1]{2,1,0} concatenate(
9570     s32[1,8,1]{2,1,0} %parameter.1, s32[1,8,1]{2,1,0} %iota), dimensions={0},
9571     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9572   ROOT %gather.20 = s32[8,1,2,2]{3,2,1,0} gather(
9573     s32[8,4,2,2]{3,2,1,0} %parameter.0,
9574     s32[2,8,1]{2,1,0} %concatenate.19), offset_dims={2,3},
9575     collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9576     slice_sizes={1,1,2,2}, sharding={replicated}
9577 })";
9578   TF_ASSERT_OK_AND_ASSIGN(auto module,
9579                           PartitionComputation(hlo_string, /*num_devices=*/8));
9580   const auto root = module->entry_computation()->root_instruction();
9581   auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::Parameter());
9582   auto indices = AllOf(op::Shape("s32[2,2,1]"), op::Subtract());
9583   auto gather = AllOf(op::Shape("s32[2,1,2,2]"), op::Gather(operand, indices));
9584   VLOG(1) << module->ToString();
9585   EXPECT_THAT(root,
9586               op::AllReduce(op::DynamicUpdateSlice(
9587                   _, op::AllReduce(op::Select(_, _, gather)), _, _, _, _)));
9588 }
9589 
TEST_F(SpmdPartitioningTest,SortTopKNonSortDimension)9590 TEST_F(SpmdPartitioningTest, SortTopKNonSortDimension) {
9591   absl::string_view hlo_string = R"(
9592 HloModule module
9593 
9594 %compare-greater-than.42077 (p.0.lhs.42078: f32[],
9595   p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
9596   %p.0.lhs.42078 = f32[] parameter(0)
9597   %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
9598   %constant.45054 = s32[] constant(0)
9599   %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
9600     s32[] %constant.45054), direction=LT
9601   %constant.45278 = u32[] constant(2147483647)
9602   %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
9603   %subtract.337 = u32[] subtract(u32[] %constant.45278,
9604     u32[] %bitcast-convert.136)
9605   %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
9606   %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
9607     s32[] %bitcast-convert.135)
9608   %p.0.rhs.42079 = f32[] parameter(1)
9609   %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
9610   %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
9611     s32[] %constant.45054), direction=LT
9612   %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
9613   %subtract.338 = u32[] subtract(u32[] %constant.45278,
9614     u32[] %bitcast-convert.139)
9615   %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
9616   %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
9617     s32[] %bitcast-convert.138)
9618   %compare.135 = pred[] compare(s32[] %select.282,
9619     s32[] %select.283), direction=GT
9620   %compare.428 = pred[] compare(s32[] %select.283,
9621     s32[] %select.282), direction=GT
9622   %compare.429 = pred[] compare(pred[] %compare.135,
9623     pred[] %compare.428), direction=EQ
9624   %p.1.lhs.42080 = s32[] parameter(2)
9625   %p.1.rhs.42081 = s32[] parameter(3)
9626   %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
9627     s32[] %p.1.rhs.42081), direction=LT
9628   ROOT %select.579 = pred[] select(pred[] %compare.429,
9629     pred[] %compare.430, pred[] %compare.135)
9630 }
9631 
9632 ENTRY %module {
9633   %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
9634      sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9635   %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
9636     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9637   %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
9638     f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
9639     dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
9640     sharding={{devices=[2,1,4]0,1,2,3,4,5,6,7},
9641     {devices=[2,1,4]0,1,2,3,4,5,6,7}}
9642   output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
9643     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9644   %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
9645     slice={[0:2], [0:64], [0:2]},
9646     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9647   output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
9648     sharding={replicated}
9649   %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
9650     slice={[0:2], [0:64], [0:2]},
9651     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9652   ROOT output.t = (f32[2,64,2]{2,1,0},
9653     s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
9654     sharding={{replicated}, {replicated}}
9655 })";
9656   TF_ASSERT_OK_AND_ASSIGN(auto module,
9657                           PartitionComputation(hlo_string, /*num_devices=*/8));
9658 
9659   const HloInstruction* sort = FindInstruction(module.get(), "sort.0");
9660   EXPECT_NE(sort, nullptr);
9661   auto sort_match =
9662       AllOf(op::Shape("(f32[2,64,32128], s32[2,64,32128])"), op::Sort(_, _));
9663   EXPECT_THAT(sort, sort_match);
9664 }
9665 
TEST_F(SpmdPartitioningTest,SortTopKPropagateBaseShape)9666 TEST_F(SpmdPartitioningTest, SortTopKPropagateBaseShape) {
9667   absl::string_view hlo_string = R"(
9668 HloModule module
9669 
9670 %compare-greater-than.42077 (p.0.lhs.42078: f32[],
9671   p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
9672   %p.0.lhs.42078 = f32[] parameter(0)
9673   %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
9674   %constant.45054 = s32[] constant(0)
9675   %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
9676     s32[] %constant.45054), direction=LT
9677   %constant.45278 = u32[] constant(2147483647)
9678   %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
9679   %subtract.337 = u32[] subtract(u32[] %constant.45278,
9680     u32[] %bitcast-convert.136)
9681   %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
9682   %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
9683     s32[] %bitcast-convert.135)
9684   %p.0.rhs.42079 = f32[] parameter(1)
9685   %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
9686   %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
9687     s32[] %constant.45054), direction=LT
9688   %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
9689   %subtract.338 = u32[] subtract(u32[] %constant.45278,
9690     u32[] %bitcast-convert.139)
9691   %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
9692   %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
9693     s32[] %bitcast-convert.138)
9694   %compare.135 = pred[] compare(s32[] %select.282,
9695     s32[] %select.283), direction=GT
9696   %compare.428 = pred[] compare(s32[] %select.283,
9697     s32[] %select.282), direction=GT
9698   %compare.429 = pred[] compare(pred[] %compare.135,
9699     pred[] %compare.428), direction=EQ
9700   %p.1.lhs.42080 = s32[] parameter(2)
9701   %p.1.rhs.42081 = s32[] parameter(3)
9702   %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
9703     s32[] %p.1.rhs.42081), direction=LT
9704   ROOT %select.579 = pred[] select(pred[] %compare.429,
9705     pred[] %compare.430, pred[] %compare.135)
9706 }
9707 
9708 ENTRY %module {
9709   %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
9710      sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9711   %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
9712     sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9713   %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
9714     f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
9715     dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
9716     sharding={{devices=[1,1,8]0,1,2,3,4,5,6,7},
9717     {devices=[1,1,8]0,1,2,3,4,5,6,7}}
9718   output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
9719     sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9720   %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
9721     slice={[0:2], [0:64], [0:2]},
9722     sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9723   output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
9724     sharding={replicated}
9725   %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
9726     slice={[0:2], [0:64], [0:2]},
9727     sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9728   ROOT output.t = (f32[2,64,2]{2,1,0},
9729     s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
9730     sharding={{replicated}, {replicated}}
9731 })";
9732   TF_ASSERT_OK_AND_ASSIGN(auto module,
9733                           PartitionComputation(hlo_string, /*num_devices=*/8));
9734   VLOG(1) << module->ToString();
9735   const HloInstruction* root = module->entry_computation()->root_instruction();
9736   auto all_reduce_val =
9737       AllOf(op::Shape("f32[2,64,2]"),
9738             op::AllReduce(op::DynamicUpdateSlice(_, _, _, _, _)));
9739   auto all_reduce_idx =
9740       AllOf(op::Shape("s32[2,64,2]"),
9741             op::AllReduce(op::DynamicUpdateSlice(_, _, _, _, _)));
9742   auto tuple = AllOf(op::Shape("(f32[2,64,2], s32[2,64,2])"),
9743                      op::Tuple(all_reduce_val, all_reduce_idx));
9744   EXPECT_THAT(root, tuple);
9745 }
9746 
TEST_F(SpmdPartitioningTest,GatherIndexOnlyCorrectReplacement)9747 TEST_F(SpmdPartitioningTest, GatherIndexOnlyCorrectReplacement) {
9748   absl::string_view hlo_string = R"(
9749 HloModule module
9750 
9751 ENTRY %module {
9752   %parameter.0 = bf16[1,8,6,6]{3,2,1,0} parameter(0),
9753     sharding={replicated}
9754   %parameter.1 = s32[2,4]{1,0} parameter(1),
9755      sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9756   %gather.100 = bf16[2,1,8,1,6]{4,3,2,1,0} gather(
9757     bf16[1,8,6,6]{3,2,1,0} %parameter.0, s32[2,4]{1,0} %parameter.1),
9758     offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1,2,3},
9759     index_vector_dim=1, slice_sizes={1,8,1,6},
9760     sharding={devices=[2,1,4,1,1]0,1,2,3,4,5,6,7}
9761   %constant.45590 = s32[] constant(0), sharding={replicated}
9762   %broadcast.54515 = s32[2,64,1,1]{3,2,1,0} broadcast(s32[] %constant.45590),
9763     dimensions={},
9764     sharding={devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9765   ROOT %reshape.4243 = bf16[2,8,6]{2,1,0} reshape(
9766     bf16[2,1,8,1,6]{4,3,2,1,0} %gather.100),
9767     sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
9768 })";
9769   TF_ASSERT_OK_AND_ASSIGN(auto module,
9770                           PartitionComputation(hlo_string, /*num_devices=*/8));
9771 
9772   const HloInstruction* root = module->entry_computation()->root_instruction();
9773   auto param0 = AllOf(op::Shape("bf16[1,8,6,6]"), op::Parameter());
9774   auto param1 = AllOf(op::Shape("s32[1,4]"), op::Parameter());
9775   auto reshape = AllOf(
9776       op::Shape("bf16[1,2,6]"),
9777       op::Reshape(op::DynamicSlice(op::Gather(param0, param1), _, _, _, _, _)));
9778   EXPECT_THAT(root, reshape);
9779 }
9780 
TEST_F(SpmdPartitioningTest,GatherRegressionTest1)9781 TEST_F(SpmdPartitioningTest, GatherRegressionTest1) {
9782   absl::string_view hlo_string = R"(
9783 HloModule module
9784 
9785 ENTRY %module {
9786   %parameter.0 = s32[1,4] parameter(0), sharding={devices=[1,8]0,1,2,3,4,5,6,7}
9787   %iota.10 = s32[4]{0} iota(), iota_dimension=0, sharding={devices=[8]0,1,2,3,4,5,6,7}
9788   ROOT %gather.44 = s32[1,4]{1,0} gather(%parameter.0, %iota.10),
9789     offset_dims={0}, collapsed_slice_dims={1}, start_index_map={1}, index_vector_dim=1,
9790     slice_sizes={1,1}, sharding={devices=[1,8]0,1,2,3,4,5,6,7}
9791 })";
9792   TF_ASSERT_OK_AND_ASSIGN(auto module,
9793                           PartitionComputation(hlo_string, /*num_devices=*/8));
9794 
9795   const HloInstruction* root = module->entry_computation()->root_instruction();
9796   auto param0 = AllOf(op::Shape("s32[1,1]"), op::Parameter());
9797   EXPECT_THAT(root, op::Gather(param0, _));
9798 }
9799 
TEST_F(SpmdPartitioningTest,WindowedEinsumPreferMemoryFootprint)9800 TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint) {
9801   absl::string_view hlo_string = R"(
9802 HloModule module
9803 
9804 ENTRY %module {
9805   %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
9806     sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
9807   %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
9808     sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
9809   %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
9810     convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
9811     bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
9812     window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
9813     rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
9814     sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
9815   ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
9816     reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
9817     sharding={replicated}
9818 })";
9819   TF_ASSERT_OK_AND_ASSIGN(
9820       auto module,
9821       PartitionComputation(hlo_string, /*num_devices=*/8,
9822                            /*conv_halo_exchange_always_on_lhs =*/true,
9823                            /*choose_faster_windowed_einsum =*/false));
9824   const HloInstruction* while_inst = FindInstruction(module.get(), "while");
9825   EXPECT_NE(while_inst, nullptr);
9826   const HloComputation* cond_comp = while_inst->while_condition();
9827   const HloInstruction* root = cond_comp->root_instruction();
9828   EXPECT_THAT(root, op::Compare(_, op::Constant()));
9829   const HloConstantInstruction* iterations =
9830       Cast<HloConstantInstruction>(root->operand(1));
9831   EXPECT_TRUE(iterations->literal().GetFirstInteger());
9832   EXPECT_EQ(*iterations->literal().GetFirstInteger(), 4);
9833 }
9834 
TEST_F(SpmdPartitioningTest,WindowedEinsumPreferNumberIterations)9835 TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations) {
9836   absl::string_view hlo_string = R"(
9837 HloModule module
9838 
9839 ENTRY %module {
9840   %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
9841     sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
9842   %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
9843     sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
9844   %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
9845     convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
9846     bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
9847     window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
9848     rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
9849     sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
9850   ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
9851     reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
9852     sharding={replicated}
9853 })";
9854   TF_ASSERT_OK_AND_ASSIGN(
9855       auto module,
9856       PartitionComputation(hlo_string, /*num_devices=*/8,
9857                            /*conv_halo_exchange_always_on_lhs =*/true,
9858                            /*choose_faster_windowed_einsum =*/true));
9859   const HloInstruction* while_inst = FindInstruction(module.get(), "while");
9860   EXPECT_NE(while_inst, nullptr);
9861   const HloComputation* cond_comp = while_inst->while_condition();
9862   const HloInstruction* root = cond_comp->root_instruction();
9863   EXPECT_THAT(root, op::Compare(_, op::Constant()));
9864   const HloConstantInstruction* iterations =
9865       Cast<HloConstantInstruction>(root->operand(1));
9866   EXPECT_TRUE(iterations->literal().GetFirstInteger());
9867   EXPECT_EQ(*iterations->literal().GetFirstInteger(), 2);
9868 }
9869 
TEST_F(SpmdPartitioningTest,WindowedEinsumPreferNumberIterations2)9870 TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations2) {
9871   const char* const hlo_string = R"(
9872 HloModule module
9873 
9874 ENTRY entry {
9875   %lhs = bf16[512,1024,16,36,256]{4,3,2,1,0} parameter(0)
9876   %lhs.copy = bf16[512,1024,16,36,256]{4,3,2,1,0} copy(%lhs),
9877   sharding={devices=[8,1,4,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,
9878             18,19,20,21,22,23,24,25,26,27,28,29,30,31}
9879   %rhs = bf16[512,1024,16,4,288]{4,3,2,1,0} parameter(1)
9880   %rhs.copy = bf16[512,1024,16,4,288]{4,3,2,1,0} copy(%rhs),
9881     sharding={devices=[8,1,4,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
9882               17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
9883   %reshape.2556 = bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} reshape(
9884     bf16[512,1024,16,4,288]{4,3,2,1,0} %rhs.copy), sharding={
9885       devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
9886         20,21,22,23,24,25,26,27,28,29,30,31}
9887   %reshape.2570 = bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0}
9888     reshape(bf16[512,1024,16,36,256]{4,3,2,1,0} %lhs.copy), sharding={
9889     devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
9890              20,21,22,23,24,25,26,27,28,29,30,31}
9891   %convolution.10 = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
9892     convolution(bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0} %reshape.2570,
9893     bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} %reshape.2556),
9894     window={size=1x1x16x4x512 pad=0_0x0_0x15_15x3_3x0_0 rhs_reversal=0x0x1x1x0},
9895     dim_labels=4f01b23_4i23o01->01b23f4, sharding={devices=[4,1,1,4,2,1,1]0,4,8,
9896     12,16,20,24,28,1,5,9,13,17,21,25,29,2,6,10,14,18,22,26,30,3,7,11,15,19,23,
9897     27,31}
9898   ROOT %output = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
9899    copy(%convolution.10), sharding={replicated}
9900 })";
9901   TF_ASSERT_OK_AND_ASSIGN(
9902       auto module,
9903       PartitionComputation(hlo_string, /*num_devices=*/32,
9904                            /*conv_halo_exchange_always_on_lhs =*/true,
9905                            /*choose_faster_windowed_einsum =*/true));
9906   const HloInstruction* while_inst = FindInstruction(module.get(), "while");
9907   EXPECT_NE(while_inst, nullptr);
9908   const HloComputation* cond_comp = while_inst->while_condition();
9909   const HloInstruction* root = cond_comp->root_instruction();
9910   EXPECT_THAT(root, op::Compare(_, op::Constant()));
9911   const HloConstantInstruction* iterations =
9912       Cast<HloConstantInstruction>(root->operand(1));
9913   EXPECT_TRUE(iterations->literal().GetFirstInteger());
9914   EXPECT_EQ(*iterations->literal().GetFirstInteger(), 4);
9915 }
9916 
TEST_F(SpmdPartitioningTest,WindowedEinsumPreferMemoryFootprint2)9917 TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint2) {
9918   const char* const hlo_string = R"(
9919 HloModule module
9920 
9921 ENTRY entry {
9922   %lhs = bf16[512,1024,16,36,256]{4,3,2,1,0} parameter(0)
9923   %lhs.copy = bf16[512,1024,16,36,256]{4,3,2,1,0} copy(%lhs),
9924   sharding={devices=[8,1,4,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,
9925             18,19,20,21,22,23,24,25,26,27,28,29,30,31}
9926   %rhs = bf16[512,1024,16,4,288]{4,3,2,1,0} parameter(1)
9927   %rhs.copy = bf16[512,1024,16,4,288]{4,3,2,1,0} copy(%rhs),
9928     sharding={devices=[8,1,4,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
9929               17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
9930   %reshape.2556 = bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} reshape(
9931     bf16[512,1024,16,4,288]{4,3,2,1,0} %rhs.copy), sharding={
9932       devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
9933         20,21,22,23,24,25,26,27,28,29,30,31}
9934   %reshape.2570 = bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0}
9935     reshape(bf16[512,1024,16,36,256]{4,3,2,1,0} %lhs.copy), sharding={
9936     devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
9937              20,21,22,23,24,25,26,27,28,29,30,31}
9938   %convolution.10 = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
9939     convolution(bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0} %reshape.2570,
9940     bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} %reshape.2556),
9941     window={size=1x1x16x4x512 pad=0_0x0_0x15_15x3_3x0_0 rhs_reversal=0x0x1x1x0},
9942     dim_labels=4f01b23_4i23o01->01b23f4, sharding={devices=[4,1,1,4,2,1,1]0,4,8,
9943     12,16,20,24,28,1,5,9,13,17,21,25,29,2,6,10,14,18,22,26,30,3,7,11,15,19,23,
9944     27,31}
9945   ROOT %output = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
9946    copy(%convolution.10), sharding={replicated}
9947 })";
9948   TF_ASSERT_OK_AND_ASSIGN(
9949       auto module,
9950       PartitionComputation(hlo_string, /*num_devices=*/32,
9951                            /*conv_halo_exchange_always_on_lhs =*/true,
9952                            /*choose_faster_windowed_einsum =*/false));
9953   const HloInstruction* while_inst = FindInstruction(module.get(), "while");
9954   EXPECT_NE(while_inst, nullptr);
9955   const HloComputation* cond_comp = while_inst->while_condition();
9956   const HloInstruction* root = cond_comp->root_instruction();
9957   EXPECT_THAT(root, op::Compare(_, op::Constant()));
9958   const HloConstantInstruction* iterations =
9959       Cast<HloConstantInstruction>(root->operand(1));
9960   EXPECT_TRUE(iterations->literal().GetFirstInteger());
9961   EXPECT_EQ(*iterations->literal().GetFirstInteger(), 8);
9962 }
9963 
TEST_F(SpmdPartitioningTest,ContractingPartitionDotOperandsSlicedWrong)9964 TEST_F(SpmdPartitioningTest, ContractingPartitionDotOperandsSlicedWrong) {
9965   const char* const hlo_string = R"(
9966 HloModule module
9967 
9968 ENTRY entry {
9969   %lhs = f32[8,2,15,4] parameter(0)
9970   %lhs.copy = f32[8,2,15,4] copy(%lhs),
9971     sharding={devices=[1,2,4,1]0,1,2,3,4,5,6,7}
9972   %rhs = f32[2,15,4] parameter(1)
9973   %rhs.copy = f32[2,15,4] copy(%rhs),
9974     sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
9975   %dot = f32[8,2,2] dot(%lhs.copy, %rhs.copy),
9976     lhs_batch_dims={}, rhs_batch_dims={},
9977     lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2},
9978     operand_precision={HIGH,HIGH},
9979     sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}
9980   ROOT %output = f32[8,2,2] copy(%dot), sharding={replicated}
9981 })";
9982   TF_ASSERT_OK_AND_ASSIGN(
9983       auto module,
9984       PartitionComputation(hlo_string, /*num_devices=*/8,
9985                            /*conv_halo_exchange_always_on_lhs =*/true,
9986                            /*choose_faster_windowed_einsum =*/true));
9987 
9988   const HloInstruction* dot_op = FindInstruction(module.get(), HloOpcode::kDot);
9989   auto op1 = op::Shape("f32[4,2,4,4]");
9990   auto op2 = op::Shape("f32[2,4,4]");
9991   EXPECT_THAT(dot_op, op::Dot(op1, op2));
9992 }
9993 
TEST_F(SpmdPartitioningTest,PartitionDotGroupOnBatchContractingReshard)9994 TEST_F(SpmdPartitioningTest, PartitionDotGroupOnBatchContractingReshard) {
9995   absl::string_view hlo_string = R"(
9996 HloModule module
9997 
9998 ENTRY entry {
9999   %lhs = f32[32,32,24,4096] parameter(0),
10000     sharding={devices=[2,1,1,2]0,1,2,3}
10001   %rhs = f32[32,4096,1024] parameter(1),
10002     sharding={devices=[2,2,1]0,1,2,3}
10003   ROOT %dot = f32[32,32,24,1024] dot(%lhs, %rhs),
10004     lhs_batch_dims={0}, rhs_batch_dims={0},
10005     lhs_contracting_dims={3}, rhs_contracting_dims={1},
10006     sharding={devices=[1,2,1,2]0,1,2,3}
10007 })";
10008 
10009   TF_ASSERT_OK_AND_ASSIGN(
10010       auto module,
10011       PartitionComputation(hlo_string, /*num_devices=*/4,
10012                            /*conv_halo_exchange_always_on_lhs=*/true,
10013                            /*choose_faster_windowed_einsum=*/true));
10014   VLOG(1) << module->ToString();
10015   const auto root = module->entry_computation()->root_instruction();
10016   auto dot = AllOf(op::Shape("f32[16,32,24,1024]"),
10017                    op::Dot(op::Parameter(0), op::Parameter(1)));
10018   auto reduce_scatter = AllOf(op::Shape("f32[16,32,24,512]"),
10019                               op::DynamicSlice(op::AllReduce(dot), _, _, _, _));
10020   EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(
10021                               op::AllToAll(op::Reshape(reduce_scatter)))),
10022                           op::Shape("f32[32,16,24,512]")));
10023 }
10024 
TEST_F(SpmdPartitioningTest,PartitionPassthroughScatterCorrectOutputSharding)10025 TEST_F(SpmdPartitioningTest, PartitionPassthroughScatterCorrectOutputSharding) {
10026   absl::string_view hlo_string = R"(
10027 HloModule module
10028 
10029 %scatter_add (parameter.0: bf16[], parameter.1: bf16[]) -> bf16[] {
10030   %parameter.0 = bf16[] parameter(0)
10031   %parameter.1 = bf16[] parameter(1)
10032   ROOT %add = bf16[] add(bf16[] %parameter.0, bf16[] %parameter.1)
10033 }
10034 
10035 ENTRY entry {
10036   %operand = bf16[2,1024]{1,0} parameter(0),
10037     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
10038   %indices = s32[8,512,1]{2,1,0} parameter(1),
10039     sharding={devices=[2,1,1,2]0,2,1,3 last_tile_dim_replicate}
10040   %updates = bf16[8,512,1024]{2,1,0} parameter(2),
10041     sharding={devices=[2,1,2]0,2,1,3}
10042   ROOT %scatter = bf16[2,1024]{1,0} scatter(bf16[2,1024]{1,0} %operand,
10043     s32[8,512,1]{2,1,0} %indices,
10044     bf16[8,512,1024]{2,1,0} %updates), update_window_dims={2},
10045     inserted_window_dims={0}, scatter_dims_to_operand_dims={0},
10046     index_vector_dim=2, to_apply=%scatter_add,
10047     sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
10048 })";
10049 
10050   TF_ASSERT_OK_AND_ASSIGN(auto module,
10051                           PartitionComputation(hlo_string, /*num_devices=*/4));
10052   VLOG(1) << module->ToString();
10053   const auto root = module->entry_computation()->root_instruction();
10054   auto scatter = AllOf(op::Shape("bf16[2,512]"), op::Scatter(_, _, _));
10055   EXPECT_THAT(root, scatter);
10056 }
10057 
IsTrivialCollectivePermute(HloInstruction * hlo)10058 bool IsTrivialCollectivePermute(HloInstruction* hlo) {
10059   if (hlo->opcode() != HloOpcode::kCollectivePermute) {
10060     return false;
10061   }
10062   if (hlo->source_target_pairs().empty()) {
10063     return true;
10064   }
10065   return absl::c_all_of(hlo->source_target_pairs(),
10066                         [](const std::pair<int64_t, int64_t>& pair) {
10067                           return pair.first == pair.second;
10068                         });
10069 }
10070 
TEST_F(SpmdPartitioningTest,CollectivePermuteSimplifyIdentity)10071 TEST_F(SpmdPartitioningTest, CollectivePermuteSimplifyIdentity) {
10072   absl::string_view hlo_string = R"(
10073 HloModule test
10074 
10075 ENTRY entry {
10076   %parameter.7 = f32[3,16] parameter(0), sharding={devices=[1,2]0,1}
10077   %constant.7 = f32[] constant(0)
10078   %pad.3 = f32[3,18] pad(f32[3,16] %parameter.7, f32[] %constant.7), padding=0_0x1_1, sharding={devices=[1,2]0,1}
10079   // Shift right by 16.
10080   %slice.8 = f32[3,16] slice(f32[3,18] %pad.3), slice={[0:3], [2:18]}, sharding={devices=[1,2]0,1}
10081   %slice.9 = f32[3,2] slice(f32[3,18] %pad.3), slice={[0:3], [0:2]}, sharding={devices=[1,2]0,1}
10082   ROOT %concatenate.6 = f32[3,18] concatenate(f32[3,16] %slice.8, f32[3,2] %slice.9), dimensions={1}, sharding={devices=[1,2]0,1}
10083 }
10084 )";
10085 
10086   TF_ASSERT_OK_AND_ASSIGN(auto module,
10087                           PartitionComputation(hlo_string, /*num_devices=*/2));
10088   VLOG(1) << module->ToString();
10089 
10090   // Check that the partitioned code does not have a "trivial" collective
10091   // permute (which would degenerate to a copy).
10092   for (HloComputation* computation : module->computations()) {
10093     for (HloInstruction* hlo : computation->instructions()) {
10094       EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10095     }
10096   }
10097 }
10098 
TEST_F(SpmdPartitioningTest,CollectivePermuteSimplifyZero)10099 TEST_F(SpmdPartitioningTest, CollectivePermuteSimplifyZero) {
10100   absl::string_view hlo_string = R"(
10101 HloModule test
10102 
10103 ENTRY entry {
10104   %parameter = f32[3,16,16,16,16,132]{5,4,3,2,1,0} parameter(0), sharding={devices=[1,2,1,1,1,1]0,1}
10105   %slice = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter), slice={[0:3], [15:16], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
10106   %c0 = f32[] constant(0)
10107   ROOT %pad = f32[3,18,16,16,16,132]{5,4,3,2,1,0} pad(f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice, f32[] %c0), padding=0_0x0_17x0_0x0_0x0_0x0_0, sharding={devices=[1,2,1,1,1,1]0,1}
10108 }
10109 )";
10110 
10111   TF_ASSERT_OK_AND_ASSIGN(auto module,
10112                           PartitionComputation(hlo_string, /*num_devices=*/2));
10113   VLOG(1) << module->ToString();
10114 
10115   // Check that the partitioned code does not have a collective permute with an
10116   // empty source_target_pair list.
10117   for (HloComputation* computation : module->computations()) {
10118     for (HloInstruction* hlo : computation->instructions()) {
10119       EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10120     }
10121   }
10122 }
10123 
TEST_F(SpmdPartitioningTest,PadWithWrapPattern)10124 TEST_F(SpmdPartitioningTest, PadWithWrapPattern) {
10125   absl::string_view hlo_string = R"(
10126 HloModule xla_computation_apply_fn__4.61
10127 
10128 ENTRY %xla_computation_apply_fn__4.61 (parameter.7: f32[3,16,16,16,16,132]) -> f32[3,18,16,16,16,132] {
10129   %parameter.7 = f32[3,16,16,16,16,132]{5,4,3,2,1,0} parameter(0), sharding={devices=[1,2,1,1,1,1]0,1}
10130   %slice.2 = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7), slice={[0:3], [15:16], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
10131   %slice.3 = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7), slice={[0:3], [0:1], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
10132   ROOT %concatenate.3 = f32[3,18,16,16,16,132]{5,4,3,2,1,0} concatenate(f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice.2, f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7, f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice.3), dimensions={1}, sharding={devices=[1,2,1,1,1,1]0,1}
10133 }
10134 )";
10135 
10136   TF_ASSERT_OK_AND_ASSIGN(auto module,
10137                           PartitionComputation(hlo_string, /*num_devices=*/2));
10138   VLOG(1) << module->ToString();
10139 
10140   // Check that the partitioned code does not have all-reduce and two
10141   // non-trivial collective permute instructions.
10142   for (HloComputation* computation : module->computations()) {
10143     for (HloInstruction* hlo : computation->instructions()) {
10144       EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10145       EXPECT_NE(hlo->opcode(), HloOpcode::kAllReduce) << hlo->ToString();
10146     }
10147   }
10148 }
10149 
TEST_F(SpmdPartitioningTest,PadWrapWithNegatePattern)10150 TEST_F(SpmdPartitioningTest, PadWrapWithNegatePattern) {
10151   absl::string_view hlo_string = R"(
10152 HloModule module
10153 
10154 ENTRY entry {
10155   %parameter.1 = f32[1,18] parameter(0), sharding={devices=[1,2]0,1}
10156   %slice.16 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [16:18]}, sharding={devices=[1,2]0,1}
10157   %negate.2 = f32[1,2] negate(f32[1,2] %slice.16), sharding={devices=[1,2]0,1}
10158   %slice.17 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [0:2]}, sharding={devices=[1,2]0,1}
10159   %negate.3 = f32[1,2] negate(f32[1,2] %slice.17), sharding={devices=[1,2]0,1}
10160   ROOT %concatenate.13 = f32[1,22] concatenate(f32[1,2] %negate.2, f32[1,18] %parameter.1, f32[1,2] %negate.3), dimensions={1}, sharding={devices=[1,2]0,1}
10161 }
10162 )";
10163   TF_ASSERT_OK_AND_ASSIGN(auto module,
10164                           PartitionComputation(hlo_string, /*num_devices=*/2));
10165   VLOG(1) << module->ToString();
10166 
10167   // Check that the partitioned code does not have all-reduce or trivial
10168   // collective permute
10169   for (HloComputation* computation : module->computations()) {
10170     for (HloInstruction* hlo : computation->instructions()) {
10171       EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10172       EXPECT_NE(hlo->opcode(), HloOpcode::kAllReduce) << hlo->ToString();
10173     }
10174   }
10175 }
10176 
TEST_F(SpmdPartitioningTest,PadWrapWithMultipleModifiersPattern)10177 TEST_F(SpmdPartitioningTest, PadWrapWithMultipleModifiersPattern) {
10178   absl::string_view hlo_string = R"(
10179 HloModule module
10180 
10181 ENTRY entry {
10182   %parameter.1 = f32[1,18] parameter(0), sharding={devices=[1,2]0,1}
10183   %slice.16 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [16:18]}, sharding={devices=[1,2]0,1}
10184   %mod0.16 = f32[1,2] rsqrt(f32[1,2] %slice.16), sharding={devices=[1,2]0,1}
10185   %mod1.16 = f32[1,2] sine(f32[1,2] %mod0.16), sharding={devices=[1,2]0,1}
10186   %slice.17 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [0:2]}, sharding={devices=[1,2]0,1}
10187   %mod0.17 = f16[1,2] convert(f32[1,2] %slice.17), sharding={devices=[1,2]0,1}
10188   %mod1.17 = f16[1,2] cosine(f16[1,2] %mod0.17), sharding={devices=[1,2]0,1}
10189   %mod2.17 = f32[1,2] convert(f16[1,2] %mod1.17), sharding={devices=[1,2]0,1}
10190   ROOT %concatenate.13 = f32[1,22] concatenate(f32[1,2] %mod1.16, f32[1,18] %parameter.1, f32[1,2] %mod2.17), dimensions={1}, sharding={devices=[1,2]0,1}
10191 }
10192 )";
10193   TF_ASSERT_OK_AND_ASSIGN(auto module,
10194                           PartitionComputation(hlo_string, /*num_devices=*/2));
10195   VLOG(1) << module->ToString();
10196 
10197   // Check that the partitioned code does not have all-reduce or trivial
10198   // collective permute. Also make sure modifiers have the right dependencies.
10199   for (HloComputation* computation : module->computations()) {
10200     for (HloInstruction* hlo : computation->instructions()) {
10201       const HloOpcode op = hlo->opcode();
10202       EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10203       EXPECT_NE(op, HloOpcode::kAllReduce) << hlo->ToString();
10204       if (hlo->operand_count() != 1) {
10205         continue;
10206       }
10207       const PrimitiveType type = hlo->shape().element_type();
10208       const HloOpcode child_op = hlo->operand(0)->opcode();
10209       const PrimitiveType child_type = hlo->operand(0)->shape().element_type();
10210 
10211       if (op == HloOpcode::kSin) {
10212         EXPECT_EQ(child_op, HloOpcode::kRsqrt);
10213       } else if (op == HloOpcode::kConvert && type == F32) {
10214         EXPECT_EQ(child_op, HloOpcode::kCos);
10215         EXPECT_EQ(child_type, F16);
10216       } else if (op == HloOpcode::kCos) {
10217         EXPECT_EQ(child_op, HloOpcode::kConvert);
10218         EXPECT_EQ(child_type, F16);
10219       }
10220     }
10221   }
10222 }
10223 
TEST_F(SpmdPartitioningTest,BroadcastAsReplicate)10224 TEST_F(SpmdPartitioningTest, BroadcastAsReplicate) {
10225   absl::string_view hlo_string = R"(
10226 HloModule module
10227 
10228 ENTRY entry {
10229   %param0 = f32[1,1] parameter(0), sharding={devices=[2,2]0,1,2,3}
10230   ROOT %copy = f32[1,1] copy(%param0), sharding={replicated}
10231 })";
10232 
10233   TF_ASSERT_OK_AND_ASSIGN(auto module,
10234                           PartitionComputation(hlo_string, /*num_devices=*/4));
10235   VLOG(1) << module->ToString();
10236 
10237   const auto root = module->entry_computation()->root_instruction();
10238   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
10239   EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(op::Select(_, param0, _))),
10240                           op::Shape("f32[1,1]")));
10241 }
10242 
TEST_F(SpmdPartitioningTest,BroadcastAsReplicate2)10243 TEST_F(SpmdPartitioningTest, BroadcastAsReplicate2) {
10244   absl::string_view hlo_string = R"(
10245 HloModule module
10246 
10247 ENTRY entry {
10248   %param0 = f32[1,2] parameter(0), sharding={devices=[2,2]0,1,2,3}
10249   ROOT %copy = f32[1,2] copy(%param0), sharding={replicated}
10250 })";
10251 
10252   TF_ASSERT_OK_AND_ASSIGN(auto module,
10253                           PartitionComputation(hlo_string, /*num_devices=*/4));
10254   VLOG(1) << module->ToString();
10255 
10256   const auto root = module->entry_computation()->root_instruction();
10257   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
10258   auto broadcast =
10259       AllOf(op::AllReduce(op::Select(_, param0, _)), op::Shape("f32[1,1]"));
10260   EXPECT_THAT(
10261       root,
10262       AllOf(op::Copy(op::AllReduce(op::DynamicUpdateSlice(_, broadcast, _, _))),
10263             op::Shape("f32[1,2]")));
10264 }
10265 
TEST_F(SpmdPartitioningTest,BroadcastAsReplicate3)10266 TEST_F(SpmdPartitioningTest, BroadcastAsReplicate3) {
10267   absl::string_view hlo_string = R"(
10268 HloModule module
10269 
10270 ENTRY entry {
10271   %param0 = f32[1,1] parameter(0),
10272     sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
10273   ROOT %copy = f32[1,1] copy(%param0), sharding={replicated}
10274 })";
10275 
10276   TF_ASSERT_OK_AND_ASSIGN(auto module,
10277                           PartitionComputation(hlo_string, /*num_devices=*/4));
10278   VLOG(1) << module->ToString();
10279 
10280   const auto root = module->entry_computation()->root_instruction();
10281   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
10282   EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(op::Select(_, param0, _))),
10283                           op::Shape("f32[1,1]")));
10284 }
10285 
TEST_F(SpmdPartitioningTest,TupleWithSubgroupManual)10286 TEST_F(SpmdPartitioningTest, TupleWithSubgroupManual) {
10287   absl::string_view hlo_string = R"(
10288 HloModule module
10289 
10290 ENTRY entry {
10291   constant = f32[6,3]{1,0}
10292     constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
10293     sharding={replicated}
10294   param = (f32[6,3]{1,0}, f32[]) parameter(0),
10295     sharding={{devices=[2,1,2]0,1,2,3 last_tile_dims={manual}},{replicated}}
10296   gte = f32[6,3]{1,0} get-tuple-element(param), index=0,
10297     sharding={devices=[2,1,2]0,1,2,3 last_tile_dims={manual}}
10298   ROOT tuple = (f32[6,3]{1,0}, f32[6,3]{1,0}) tuple(constant, gte),
10299     sharding={{replicated},{devices=[2,1,2]0,1,2,3 last_tile_dims={manual}}}
10300 }
10301 )";
10302 
10303   TF_ASSERT_OK_AND_ASSIGN(auto module,
10304                           PartitionComputation(hlo_string, /*num_devices=*/4));
10305   VLOG(1) << module->ToString();
10306   const auto root = module->entry_computation()->root_instruction();
10307   EXPECT_THAT(root,
10308               op::Tuple(op::Constant(), op::GetTupleElement(op::Parameter(0))));
10309 }
10310 
TEST_F(SpmdPartitioningTest,SubgroupManualSharedOperand)10311 TEST_F(SpmdPartitioningTest, SubgroupManualSharedOperand) {
10312   absl::string_view hlo_string = R"(
10313 HloModule module
10314 
10315 ENTRY entry {
10316   constant = f32[] constant(1), sharding={replicated}
10317   broadcast = f32[2,2] broadcast(constant), dimensions={},
10318     sharding={devices=[2,1,2]0,1,2,3 last_tile_dims={manual}}
10319   ROOT add = f32[2,2] add(broadcast, broadcast),
10320     sharding={devices=[2,1,2]0,1,2,3 last_tile_dims={manual}}
10321 }
10322 )";
10323 
10324   TF_ASSERT_OK_AND_ASSIGN(auto module,
10325                           PartitionComputation(hlo_string, /*num_devices=*/4));
10326   VLOG(1) << module->ToString();
10327   const auto root = module->entry_computation()->root_instruction();
10328   EXPECT_THAT(root, op::Add(op::Broadcast(op::Constant()),
10329                             op::Broadcast(op::Constant())));
10330 }
10331 
TEST_F(SpmdPartitioningTest,SubgroupManualAllReduce)10332 TEST_F(SpmdPartitioningTest, SubgroupManualAllReduce) {
10333   absl::string_view hlo_string = R"(
10334 HloModule module
10335 
10336 sum {
10337   a = f32[] parameter(0)
10338   b = f32[] parameter(1)
10339   ROOT add = f32[] add(a, b)
10340 }
10341 
10342 ENTRY entry {
10343   param = f32[2,2] parameter(0),
10344     sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10345   ROOT all-reduce = f32[2,2]{1,0} all-reduce(param), to_apply=sum,
10346     replica_groups={{2,0},{1,3}}, use_global_device_ids=true, channel_id=1,
10347     sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10348 }
10349 )";
10350 
10351   TF_ASSERT_OK_AND_ASSIGN(auto module,
10352                           PartitionComputation(hlo_string, /*num_devices=*/4));
10353   VLOG(1) << module->ToString();
10354   const auto root = module->entry_computation()->root_instruction();
10355   EXPECT_THAT(root,
10356               AllOf(op::AllReduce(op::Parameter(0)), op::Shape("f32[1,2]")));
10357   EXPECT_EQ(root->replica_groups().size(), 2);
10358 }
10359 
TEST_F(SpmdPartitioningTest,SubgroupIllegalManualAllReduce)10360 TEST_F(SpmdPartitioningTest, SubgroupIllegalManualAllReduce) {
10361   absl::string_view hlo_string = R"(
10362 HloModule module
10363 
10364 sum {
10365   a = f32[] parameter(0)
10366   b = f32[] parameter(1)
10367   ROOT add = f32[] add(a, b)
10368 }
10369 
10370 ENTRY entry {
10371   param = f32[2,2] parameter(0),
10372     sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10373   ROOT all-reduce = f32[2,2]{1,0} all-reduce(param), to_apply=sum,
10374     replica_groups={{1,0},{2,3}}, use_global_device_ids=true, channel_id=1,
10375     sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10376 }
10377 )";
10378 
10379   auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4);
10380   EXPECT_FALSE(module_status.status().ok());
10381   EXPECT_THAT(module_status.status().ToString(),
10382               ::testing::HasSubstr("Manual all-reduce across devices that "
10383                                    "belong to different manual subgroups"));
10384 }
10385 
TEST_F(SpmdPartitioningTest,SubgroupManualReduce)10386 TEST_F(SpmdPartitioningTest, SubgroupManualReduce) {
10387   absl::string_view hlo_string = R"(
10388 HloModule module
10389 
10390 sum {
10391   a = f32[] parameter(0)
10392   b = f32[] parameter(1)
10393   ROOT add = f32[] add(a, b)
10394 }
10395 
10396 ENTRY entry {
10397   constant = f32[] constant(0),
10398     sharding={devices=[2,2]0,1,2,3 last_tile_dims={manual,replicated}}
10399   param = f32[2,2] parameter(0),
10400     sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10401   ROOT reduce = f32[2] reduce(param, constant), dimensions={0}, to_apply=sum,
10402     sharding={devices=[2,2]0,1,2,3 last_tile_dims={manual,replicated}}
10403 }
10404 )";
10405 
10406   TF_ASSERT_OK_AND_ASSIGN(auto module,
10407                           PartitionComputation(hlo_string, /*num_devices=*/4));
10408   VLOG(1) << module->ToString();
10409   const auto root = module->entry_computation()->root_instruction();
10410   EXPECT_THAT(root,
10411               op::AllReduce(op::Reduce(op::Parameter(0), op::Constant())));
10412   EXPECT_EQ(root->replica_groups().size(), 2);
10413 }
10414 
TEST_F(SpmdPartitioningTest,ScatterPreferUpdateIndexIfSmaller)10415 TEST_F(SpmdPartitioningTest, ScatterPreferUpdateIndexIfSmaller) {
10416   absl::string_view hlo_string = R"(
10417 HloModule module
10418 
10419 %scatter_add_reducer__33.191857 (parameter.191858: bf16[], parameter.191859: bf16[]) -> bf16[] {
10420   %parameter.191858 = bf16[] parameter(0)
10421   %parameter.191859 = bf16[] parameter(1)
10422   ROOT %add.4425 = bf16[] add(bf16[] %parameter.191858, bf16[] %parameter.191859)
10423 }
10424 
10425 ENTRY entry {
10426   p1 = s32[2048,1024,1]{2,1,0} parameter(0)
10427   p2 = bf16[2048,1024,2040]{2,1,0} parameter(1)
10428   %constant.8635 = bf16[] constant(0)
10429   %broadcast.21781 = bf16[50048,2040]{1,0} broadcast(bf16[] %constant.8635), dimensions={},
10430     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10431   %select.1954 = s32[2048,1024,1]{2,1,0} copy(%p1), sharding={devices=[4,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10432   %slice.1274 = bf16[2048,1024,2040]{2,1,0} copy(%p2),
10433   sharding={devices=[4,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10434   %scatter.34 = bf16[50048,2040]{1,0} scatter(bf16[50048,2040]{1,0} %broadcast.21781,
10435     s32[2048,1024,1]{2,1,0} %select.1954, bf16[2048,1024,2040]{2,1,0} %slice.1274),
10436     update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0},
10437     index_vector_dim=2, to_apply=%scatter_add_reducer__33.191857,
10438     sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10439   ROOT c = bf16[50048,2040]{1,0} copy(scatter.34),
10440     sharding={replicated}
10441 }
10442 )";
10443 
10444   TF_ASSERT_OK_AND_ASSIGN(auto module,
10445                           PartitionComputation(hlo_string, /*num_devices=*/8));
10446   VLOG(1) << module->ToString();
10447   const auto root = module->entry_computation()->root_instruction();
10448   EXPECT_THAT(root,
10449               op::Copy(op::AllReduce(op::Scatter(
10450                   op::Shape("bf16[50048,2040]"), op::Shape("s32[512,1024,1]"),
10451                   op::Shape("bf16[512,1024,2040]")))));
10452 }
10453 
TEST_F(SpmdPartitioningTest,ScatterPreferTrivialIfSmallerThanIndices)10454 TEST_F(SpmdPartitioningTest, ScatterPreferTrivialIfSmallerThanIndices) {
10455   absl::string_view hlo_string = R"(
10456 HloModule module
10457 
10458 %scatter_add_reducer__33.191857 (parameter.191858: bf16[], parameter.191859: bf16[]) -> bf16[] {
10459   %parameter.191858 = bf16[] parameter(0)
10460   %parameter.191859 = bf16[] parameter(1)
10461   ROOT %add.4425 = bf16[] add(bf16[] %parameter.191858, bf16[] %parameter.191859)
10462 }
10463 
10464 ENTRY entry {
10465   p1 = s32[32,512,3]{2,1,0} parameter(0)
10466   p2 = bf16[32,512]{1,0} parameter(1)
10467   %constant.8635 = bf16[] constant(0)
10468   %broadcast.21781 = bf16[32,512,50001]{2,1,0} broadcast(bf16[] %constant.8635), dimensions={},
10469     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10470   %select.1954 = s32[32,512,3]{2,1,0} copy(%p1), sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10471   %slice.1274 = bf16[32,512]{1,0} copy(%p2),
10472   sharding={devices=[1,4,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10473   %scatter.34 = bf16[32,512,50001]{2,1,0} scatter(bf16[32,512,50001]{2,1,0} %broadcast.21781,
10474     s32[32,512,3]{2,1,0} %select.1954, bf16[32,512]{1,0} %slice.1274),
10475     update_window_dims={}, inserted_window_dims={0,1,2}, scatter_dims_to_operand_dims={0,1,2},
10476     index_vector_dim=2, to_apply=%scatter_add_reducer__33.191857,
10477     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10478   ROOT c = bf16[32,512,50001]{2,1,0} copy(scatter.34),
10479     sharding={replicated}
10480 }
10481 )";
10482 
10483   TF_ASSERT_OK_AND_ASSIGN(auto module,
10484                           PartitionComputation(hlo_string, /*num_devices=*/8));
10485   VLOG(1) << module->ToString();
10486   const auto root = module->entry_computation()->root_instruction();
10487   EXPECT_THAT(
10488       root,
10489       op::Copy(op::AllReduce(op::DynamicUpdateSlice(
10490           _,
10491           op::Scatter(op::Shape("bf16[32,128,50001]"),
10492                       op::Shape("s32[32,512,3]"), op::Shape("bf16[32,512]")),
10493           _, _, _))));
10494 }
10495 
TEST_F(SpmdPartitioningTest,GatherOperandPassthroughIndexPassthrough)10496 TEST_F(SpmdPartitioningTest, GatherOperandPassthroughIndexPassthrough) {
10497   const char* const hlo_string = R"(
10498 HloModule module
10499 
10500 ENTRY entry {
10501   %input = f32[2,9] parameter(0), sharding={replicated}
10502   %indices = s32[7] parameter(1), sharding={replicated}
10503   %input.copy = f32[2,9] copy(%input), sharding={devices=[1,2,2]1,0,3,2 last_tile_dim_replicate}
10504   %indices.copy = s32[7] copy(%indices), sharding={devices=[2,2]1,2,3,0 last_tile_dim_replicate}
10505   %gather = f32[7,9] gather(%input.copy, %indices.copy), offset_dims={1},
10506     collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
10507     slice_sizes={1,9}, sharding={devices=[2,2]0,1,2,3}
10508   ROOT %copy = f32[7,9] copy(%gather), sharding={replicated}
10509 })";
10510 
10511   TF_ASSERT_OK_AND_ASSIGN(auto module,
10512                           PartitionComputation(hlo_string, /*num_devices=*/4));
10513   VLOG(1) << module->ToString();
10514   const HloInstruction* gather = FindInstruction(module.get(), "gather.1");
10515   EXPECT_NE(gather, nullptr);
10516   EXPECT_THAT(gather,
10517               AllOf(op::Shape("f32[4,5]"),
10518                     op::Gather(op::Shape("f32[2,5]"), op::Shape("s32[4]"))));
10519 }
10520 
TEST_F(SpmdPartitioningTest,GatherIndexPassthroughTrivialSlice)10521 TEST_F(SpmdPartitioningTest, GatherIndexPassthroughTrivialSlice) {
10522   const char* const hlo_string = R"(
10523 HloModule module
10524 
10525 ENTRY entry {
10526   %input = f32[17,9] parameter(0)
10527   %indices = s32[2,3] parameter(1)
10528   %input.copy = f32[17,9] copy(%input), sharding={devices=[2,1,2]3,2,1,0 last_tile_dim_replicate}
10529   %indices.copy = s32[2,3] copy(%indices), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
10530   %gather = f32[2,3,9] gather(%input.copy, %indices.copy), offset_dims={2},
10531     collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
10532     slice_sizes={1,9}, sharding={devices=[2,1,1,2]1,0,3,2 last_tile_dim_replicate}
10533   ROOT %copy = f32[2,3,9] copy(%gather), sharding={replicated}
10534 })";
10535 
10536   TF_ASSERT_OK_AND_ASSIGN(auto module,
10537                           PartitionComputation(hlo_string, /*num_devices=*/4));
10538   VLOG(1) << module->ToString();
10539   const HloInstruction* gather = FindInstruction(module.get(), "gather.1");
10540   EXPECT_NE(gather, nullptr);
10541   EXPECT_THAT(gather,
10542               AllOf(op::Shape("f32[1,3,9]"),
10543                     op::Gather(op::Shape("f32[9,9]"), op::Shape("s32[1,3]"))));
10544 }
10545 
TEST_F(SpmdPartitioningTest,GatherReplicatedCorrectOutput)10546 TEST_F(SpmdPartitioningTest, GatherReplicatedCorrectOutput) {
10547   const char* const hlo_string = R"(
10548 HloModule module
10549 
10550 ENTRY entry {
10551   %input = f32[64,2,250112] parameter(0), sharding={
10552     devices=[16,1,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,
10553                     23,24,25,26,27,28,29,30,31}
10554   %indices = s32[10,1] parameter(1), sharding={replicated}
10555   %input.copy = f32[64,2,250112] copy(%input), sharding={
10556     devices=[16,1,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,
10557                     23,24,25,26,27,28,29,30,31}
10558   %indices.copy = s32[10,1] copy(%indices), sharding={replicated}
10559   %gather = f32[64,2,10] gather(f32[64,2,250112] %input,
10560     s32[10,1]{1,0} %indices.copy), offset_dims={0,1}, collapsed_slice_dims={2},
10561     start_index_map={2}, index_vector_dim=1, slice_sizes={64,2,1},
10562     sharding={devices=[16,1,1,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,
10563                                 19,20,21,22,23,24,25,26,27,28,29,30,
10564                                 31 last_tile_dim_replicate}
10565   ROOT %copy = (f32[64,2,10]) tuple(gather), sharding={{devices=[16,1,1,2]0,1,2,
10566     3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,
10567     30,31 last_tile_dim_replicate}}
10568 })";
10569 
10570   TF_ASSERT_OK_AND_ASSIGN(auto module,
10571                           PartitionComputation(hlo_string, /*num_devices=*/32));
10572   VLOG(1) << module->ToString();
10573   EXPECT_THAT(module->entry_computation()->root_instruction(),
10574               op::Shape("(f32[4,2,10])"));
10575 }
10576 
TEST_F(SpmdPartitioningTest,GatherTrivialRestoreSharding)10577 TEST_F(SpmdPartitioningTest, GatherTrivialRestoreSharding) {
10578   const char* const hlo_string = R"(
10579 HloModule module
10580 
10581 ENTRY entry {
10582   %input = bf16[250112,4096] parameter(0), sharding={replicated}
10583   %cpy.input = bf16[250112,4096] copy(%input), sharding={
10584     devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,
10585                     23,24,25,26,27,28,29,30,31}
10586   %indices = s32[64,1,1] parameter(1), sharding={replicated}
10587   %cpy.indices = s32[64,1,1] copy(%indices), sharding={replicated}
10588   %gather = bf16[64,1,4096] gather(bf16[250112,4096] %cpy.input, s32[64,1,1] %cpy.indices),
10589     offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0},
10590     index_vector_dim=2, slice_sizes={1,4096}, sharding={replicated}
10591   ROOT %copy = bf16[64,1,4096] copy(gather), sharding={replicated}
10592 })";
10593 
10594   TF_ASSERT_OK_AND_ASSIGN(auto module,
10595                           PartitionComputation(hlo_string, /*num_devices=*/32));
10596   VLOG(1) << module->ToString();
10597   EXPECT_THAT(module->entry_computation()->root_instruction(),
10598               op::Shape("bf16[64,1,4096]"));
10599   EXPECT_THAT(module->entry_computation()->root_instruction(),
10600               op::Copy(op::AllReduce(op::Select(
10601                   _, _, op::Gather(op::Shape("bf16[7816,4096]"), _)))));
10602 }
10603 
TEST_F(SpmdPartitioningTest,SliceTo1)10604 TEST_F(SpmdPartitioningTest, SliceTo1) {
10605   const char* const hlo_string = R"(
10606 HloModule module
10607 
10608 ENTRY entry {
10609   %input = f32[512] parameter(0), sharding={devices=[4]0,1,2,3}
10610   ROOT slice.134 = f32[1] slice(input), slice={[0:1]},
10611     sharding={devices=[4]0,1,2,3}
10612 })";
10613 
10614   TF_ASSERT_OK_AND_ASSIGN(auto module,
10615                           PartitionComputation(hlo_string, /*num_devices=*/4));
10616   VLOG(1) << module->ToString();
10617   EXPECT_THAT(module->entry_computation()->root_instruction(),
10618               AllOf(op::Slice(op::Parameter()), op::Shape("f32[1]")));
10619 }
10620 
TEST_F(SpmdPartitioningTest,SliceTo1_8Shards)10621 TEST_F(SpmdPartitioningTest, SliceTo1_8Shards) {
10622   const char* const hlo_string = R"(
10623 HloModule module
10624 
10625 ENTRY entry {
10626   %input = f32[4,4] parameter(0), sharding={devices=[4,2]0,1,2,3,4,5,6,7}
10627   ROOT %slice = f32[1,4] slice(%input), slice={[0:1], [0:4]},
10628     sharding={devices=[4,2]0,1,2,3,4,5,6,7}
10629 })";
10630 
10631   TF_ASSERT_OK_AND_ASSIGN(auto module,
10632                           PartitionComputation(hlo_string, /*num_devices=*/8));
10633   VLOG(1) << module->ToString();
10634   EXPECT_THAT(module->entry_computation()->root_instruction(),
10635               AllOf(op::Copy(op::Parameter()), op::Shape("f32[1,2]")));
10636 }
10637 
TEST_F(SpmdPartitioningTest,SliceTo1PartialReplicate)10638 TEST_F(SpmdPartitioningTest, SliceTo1PartialReplicate) {
10639   const char* const hlo_string = R"(
10640 HloModule module
10641 
10642 ENTRY entry {
10643   %input = f32[16] parameter(0),
10644     sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate}
10645   ROOT slice.134 = f32[1] slice(input), slice={[0:1]},
10646     sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate}
10647 })";
10648 
10649   TF_ASSERT_OK_AND_ASSIGN(auto module,
10650                           PartitionComputation(hlo_string, /*num_devices=*/4));
10651   VLOG(1) << module->ToString();
10652   EXPECT_THAT(module->entry_computation()->root_instruction(),
10653               AllOf(op::Slice(op::Parameter()), op::Shape("f32[1]")));
10654 }
10655 
TEST_F(SpmdPartitioningTest,SliceTo2)10656 TEST_F(SpmdPartitioningTest, SliceTo2) {
10657   const char* const hlo_string = R"(
10658 HloModule module
10659 
10660 ENTRY entry {
10661   %input = f32[512] parameter(0), sharding={devices=[4]0,1,2,3}
10662   ROOT slice.134 = f32[2] slice(input), slice={[0:2]},
10663     sharding={devices=[4]0,1,2,3}
10664 })";
10665 
10666   TF_ASSERT_OK_AND_ASSIGN(auto module,
10667                           PartitionComputation(hlo_string, /*num_devices=*/4));
10668   VLOG(1) << module->ToString();
10669   auto slice1 = AllOf(op::Slice(op::Parameter()), op::Shape("f32[2]"));
10670   auto halo =
10671       op::CollectivePermute(AllOf(op::Slice(slice1), op::Shape("f32[1]")));
10672   auto slice_self = AllOf(op::Slice(slice1), op::Shape("f32[1]"));
10673   EXPECT_THAT(
10674       module->entry_computation()->root_instruction(),
10675       op::Copy(AllOf(op::DynamicSlice(op::Concatenate(halo, slice_self), _),
10676                      op::Shape("f32[1]"))));
10677 }
10678 
TEST_F(SpmdPartitioningTest,SliceToMiddle2)10679 TEST_F(SpmdPartitioningTest, SliceToMiddle2) {
10680   const char* const hlo_string = R"(
10681 HloModule module
10682 
10683 ENTRY entry {
10684   %input = f32[512] parameter(0), sharding={devices=[8]0,1,2,3,4,5,6,7}
10685   ROOT %slice = f32[2] slice(input), slice={[300:302]},
10686     sharding={devices=[8]0,1,2,3,4,5,6,7}
10687 })";
10688 
10689   TF_ASSERT_OK_AND_ASSIGN(auto module,
10690                           PartitionComputation(hlo_string, /*num_devices=*/8));
10691   auto slice = AllOf(op::Slice(op::Parameter()), op::Shape("f32[2]"));
10692   auto halo1 = AllOf(op::CollectivePermute(slice), op::Shape("f32[2]"));
10693   auto halo2 =
10694       AllOf(op::CollectivePermute(op::Slice(slice)), op::Shape("f32[1]"));
10695   VLOG(1) << module->ToString();
10696   EXPECT_THAT(module->entry_computation()->root_instruction(),
10697               op::Copy(AllOf(op::DynamicSlice(op::Concatenate(halo1, halo2), _),
10698                              op::Shape("f32[1]"))));
10699 }
10700 
TEST_F(SpmdPartitioningTest,SliceToMiddle2PartiallyReplicated)10701 TEST_F(SpmdPartitioningTest, SliceToMiddle2PartiallyReplicated) {
10702   const char* const hlo_string = R"(
10703 HloModule module
10704 
10705 ENTRY entry {
10706   %input = f32[512] parameter(0),
10707     sharding={devices=[8,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
10708   ROOT %slice = f32[2] slice(input), slice={[300:302]},
10709     sharding={devices=[8,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
10710 })";
10711 
10712   TF_ASSERT_OK_AND_ASSIGN(auto module,
10713                           PartitionComputation(hlo_string, /*num_devices=*/16));
10714   auto slice = AllOf(op::Slice(op::Parameter()), op::Shape("f32[2]"));
10715   auto halo1 = AllOf(op::CollectivePermute(slice), op::Shape("f32[2]"));
10716   auto halo2 =
10717       AllOf(op::CollectivePermute(op::Slice(slice)), op::Shape("f32[1]"));
10718   VLOG(1) << module->ToString();
10719   EXPECT_THAT(module->entry_computation()->root_instruction(),
10720               op::Copy(AllOf(op::DynamicSlice(op::Concatenate(halo1, halo2), _),
10721                              op::Shape("f32[1]"))));
10722 }
10723 
TEST_F(SpmdPartitioningTest,PartialDusReplicate)10724 TEST_F(SpmdPartitioningTest, PartialDusReplicate) {
10725   const char* const hlo_string = R"(
10726 HloModule module
10727 
10728 ENTRY entry {
10729   %input = f32[3,2] parameter(0),
10730     sharding={devices=[8,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
10731   ROOT %copy = f32[3,2] copy(input), sharding={replicated}
10732 })";
10733 
10734   TF_ASSERT_OK_AND_ASSIGN(auto module,
10735                           PartitionComputation(hlo_string, /*num_devices=*/16));
10736   VLOG(1) << module->ToString();
10737   auto dus =
10738       AllOf(op::Shape("f32[3,2]"),
10739             op::DynamicUpdateSlice(op::Broadcast(),
10740                                    op::Select(_, op::Parameter(0), _), _, _));
10741   EXPECT_THAT(module->entry_computation()->root_instruction(),
10742               op::Copy(AllOf(op::AllReduce(op::AllReduce(dus)))));
10743 }
10744 
TEST_F(SpmdPartitioningTest,GatherPassthrough)10745 TEST_F(SpmdPartitioningTest, GatherPassthrough) {
10746   const char* const hlo_string = R"(
10747 HloModule module
10748 
10749 ENTRY entry {
10750   p = f32[16,64,768,768]{3,2,1,0} parameter(0), sharding={replicated}
10751   c = f32[16,64,768,768]{3,2,1,0} copy(p), sharding={devices=[1,4,1,1]0,1,2,3}
10752   constant.1669 = s32[] constant(0)
10753   iota.1012 = s32[6]{0} iota(), iota_dimension=0, sharding={replicated}
10754   constant.1748 = s32[] constant(128), sharding={replicated}
10755   broadcast.2642 = s32[6]{0} broadcast(constant.1748), dimensions={}, sharding={replicated}
10756   multiply.92 = s32[6]{0} multiply(iota.1012, broadcast.2642), sharding={replicated}
10757   broadcast.2643 = s32[2,6]{1,0} broadcast(multiply.92), dimensions={1}, sharding={replicated}
10758   transpose.542 = s32[6,2]{0,1} transpose(broadcast.2643), dimensions={1,0}, sharding={replicated}
10759   pad.19 = s32[6,4]{1,0} pad(transpose.542, constant.1669), padding=0_0x2_0, sharding={replicated}
10760   ROOT gather.1 = f32[16,64,6,128,128]{4,3,2,1,0} gather(c, pad.19), offset_dims={0,1,3,4}, collapsed_slice_dims={}, start_index_map={0,1,2,3}, index_vector_dim=1, slice_sizes={16,64,128,128}, sharding={devices=[1,4,1,1,1]0,1,2,3}
10761 })";
10762 
10763   TF_ASSERT_OK_AND_ASSIGN(auto module,
10764                           PartitionComputation(hlo_string, /*num_devices=*/4));
10765   VLOG(1) << module->ToString();
10766 
10767   auto root = module->entry_computation()->root_instruction();
10768   EXPECT_THAT(root, AllOf(op::Gather(), op::Shape("f32[16,16,6,128,128]")));
10769 }
10770 
TEST_F(SpmdPartitioningTest,ComplexReshardFromPartialReplicate)10771 TEST_F(SpmdPartitioningTest, ComplexReshardFromPartialReplicate) {
10772   const char* const hlo_string = R"(
10773 HloModule module
10774 
10775 ENTRY entry {
10776   %p = f32[4,15,4,16] parameter(0)
10777   %p.copy = f32[4,15,4,16] copy(p),
10778     sharding={devices=[1,1,1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
10779   %a = f32[4,15,4,16] add(p.copy, p.copy),
10780     sharding={devices=[1,1,1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
10781   ROOT %c2 = f32[4,15,4,16] copy(a), sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
10782 })";
10783 
10784   TF_ASSERT_OK_AND_ASSIGN(auto module,
10785                           PartitionComputation(hlo_string, /*num_devices=*/8));
10786   VLOG(1) << module->ToString();
10787 
10788   EXPECT_THAT(
10789       module->entry_computation()->root_instruction(),
10790       op::Copy(op::Reshape(op::Reshape(op::Transpose(op::AllToAll(_))))));
10791 }
10792 
TEST_F(SpmdPartitioningTest,ComplexReshardToPartialReplicate)10793 TEST_F(SpmdPartitioningTest, ComplexReshardToPartialReplicate) {
10794   const char* const hlo_string = R"(
10795 HloModule module
10796 
10797 ENTRY entry {
10798   %p = f32[4,15,4,16] parameter(0)
10799   %p.copy = f32[4,15,4,16] copy(p),
10800     sharding={devices=[1,4,2,1]0,1,2,3,4,5,6,7}
10801   %a = f32[4,15,4,16] add(p.copy, p.copy),
10802     sharding={devices=[1,4,2,1]0,1,2,3,4,5,6,7}
10803   ROOT %c2 = f32[4,15,4,16] copy(a), sharding={devices=[1,1,1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
10804 })";
10805 
10806   TF_ASSERT_OK_AND_ASSIGN(auto module,
10807                           PartitionComputation(hlo_string, /*num_devices=*/8));
10808   VLOG(1) << module->ToString();
10809 
10810   EXPECT_THAT(module->entry_computation()->root_instruction(),
10811               op::Copy(op::Reshape(op::Transpose(op::AllToAll(_)))));
10812 }
10813 
TEST_F(SpmdPartitioningTest,ComplexReshardMoveMergeDimensionRight)10814 TEST_F(SpmdPartitioningTest, ComplexReshardMoveMergeDimensionRight) {
10815   const char* const hlo_string = R"(
10816 HloModule module
10817 
10818 ENTRY entry {
10819   %p = f32[4,15,4,15] parameter(0)
10820   %p.copy = f32[4,15,4,15] copy(p),
10821     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
10822   %a = f32[4,15,4,15] add(p.copy, p.copy),
10823     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
10824   ROOT %c2 = f32[4,15,4,15] copy(a), sharding={devices=[1,1,1,8]0,2,4,6,1,3,5,7}
10825 })";
10826 
10827   TF_ASSERT_OK_AND_ASSIGN(auto module,
10828                           PartitionComputation(hlo_string, /*num_devices=*/8));
10829   VLOG(1) << module->ToString();
10830 
10831   EXPECT_THAT(module->entry_computation()->root_instruction(),
10832               op::Copy(op::Reshape(
10833                   op::Slice(op::Reshape(op::Transpose(op::AllToAll(_)))))));
10834 }
10835 
TEST_F(SpmdPartitioningTest,ComplexReshardMoveMergeDimensionLeft)10836 TEST_F(SpmdPartitioningTest, ComplexReshardMoveMergeDimensionLeft) {
10837   const char* const hlo_string = R"(
10838 HloModule module
10839 
10840 ENTRY entry {
10841   %p = f32[2,15,1,2] parameter(0)
10842   %p.copy = f32[2,15,1,2] copy(p),
10843     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
10844   %a = f32[2,15,1,2] add(p.copy, p.copy),
10845     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
10846   ROOT %c2 = f32[2,15,1,2] copy(a), sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
10847 })";
10848 
10849   TF_ASSERT_OK_AND_ASSIGN(auto module,
10850                           PartitionComputation(hlo_string, /*num_devices=*/8));
10851   VLOG(1) << module->ToString();
10852 
10853   EXPECT_THAT(
10854       module->entry_computation()->root_instruction(),
10855       op::Copy(op::Reshape(op::Reshape(op::Transpose(op::AllToAll(_))))));
10856 }
10857 
TEST_F(SpmdPartitioningTest,ComplexReshardMoveMergeDimensionLeftReorder)10858 TEST_F(SpmdPartitioningTest, ComplexReshardMoveMergeDimensionLeftReorder) {
10859   const char* const hlo_string = R"(
10860 HloModule module
10861 
10862 ENTRY entry {
10863   %p = f32[4,15,4,16] parameter(0)
10864   %p.copy = f32[4,15,4,16] copy(p),
10865     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
10866   %a = f32[4,15,4,16] add(p.copy, p.copy),
10867     sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
10868   ROOT %c2 = f32[4,15,4,16] copy(a), sharding={devices=[1,8,1,1]0,2,4,6,1,3,5,7}
10869 })";
10870 
10871   TF_ASSERT_OK_AND_ASSIGN(auto module,
10872                           PartitionComputation(hlo_string, /*num_devices=*/8));
10873   VLOG(1) << module->ToString();
10874 
10875   EXPECT_THAT(module->entry_computation()->root_instruction(),
10876               op::Copy(op::CollectivePermute(
10877                   op::Reshape(op::Reshape(op::Transpose(op::AllToAll(_)))))));
10878 }
10879 
TEST_F(SpmdPartitioningTest,PaddedConvReshard)10880 TEST_F(SpmdPartitioningTest, PaddedConvReshard) {
10881   const char* const hlo_string = R"(
10882 HloModule module
10883 
10884 ENTRY entry {
10885   %p = bf16[16,256,256,384]{3,2,1,0} parameter(0)
10886   %p2 = bf16[3,3,384,384]{3,2,1,0} parameter(1)
10887   %p.copy = bf16[16,256,256,384]{3,2,1,0} copy(%p), sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
10888   %p2.copy = bf16[3,3,384,384]{3,2,1,0} copy(%p2), sharding={replicated}
10889   ROOT %convolution.10115 = bf16[16,256,256,384]{3,2,1,0} convolution(%p.copy, %p2.copy), window={size=3x3 pad=128_128x128_128 rhs_dilate=128x128}, dim_labels=b01f_01io->b01f, sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
10890 })";
10891 
10892   TF_ASSERT_OK_AND_ASSIGN(auto module,
10893                           PartitionComputation(hlo_string, /*num_devices=*/8));
10894   EXPECT_THAT(module->entry_computation()->root_instruction(),
10895               op::Convolution(
10896                   op::DynamicSlice(op::Pad(_, op::Constant()), _, _, _, _), _));
10897 }
10898 
TEST_F(SpmdPartitioningTest,KeepPartitionedNonSlicedDimension)10899 TEST_F(SpmdPartitioningTest, KeepPartitionedNonSlicedDimension) {
10900   const char* const hlo_string = R"(
10901 HloModule module
10902 
10903 ENTRY entry {
10904   %p = bf16[16,128,128,384]{3,2,1,0} parameter(0), sharding={replicated}
10905   %constant.1165 = s32[] constant(0), sharding={replicated}
10906   constant.1151 = s32[] constant(192), sharding={replicated}
10907   broadcast.1152 = s32[2]{0} broadcast(constant.1151), dimensions={}, sharding={replicated}
10908   slice.1576 = s32[1]{0} slice(broadcast.1152), slice={[0:1]}, sharding={replicated}
10909   reshape.1888 = s32[] reshape(slice.1576), sharding={replicated}
10910   slice.1546 = s32[1]{0} slice(broadcast.1152), slice={[1:2]}, sharding={replicated}
10911   reshape.1890 = s32[] reshape(slice.1546), sharding={replicated}
10912   constant.861 = bf16[] constant(0), sharding={replicated}
10913   broadcast.862 = bf16[16,512,512,384]{3,2,1,0} broadcast(constant.861), dimensions={}, sharding={devices=[2,2,1,1]0,1,2,3}
10914   %c = bf16[16,128,128,384]{3,2,1,0} copy(p), sharding={devices=[2,2,1,1]0,1,2,3}
10915   add.228 = bf16[16,128,128,384]{3,2,1,0} add(c, c), sharding={devices=[2,2,1,1]0,1,2,3}
10916   ROOT dynamic-update-slice.111 = bf16[16,512,512,384]{3,2,1,0} dynamic-update-slice(broadcast.862, add.228, constant.1165, reshape.1888, reshape.1890, /*index=5*/constant.1165), sharding={devices=[2,2,1,1]0,1,2,3}
10917 })";
10918 
10919   TF_ASSERT_OK_AND_ASSIGN(auto module,
10920                           PartitionComputation(hlo_string, /*num_devices=*/4));
10921 
10922   XLA_VLOG_LINES(1, module->ToString());
10923   EXPECT_THAT(module->entry_computation()->root_instruction(),
10924               op::DynamicSlice(AllOf(op::DynamicUpdateSlice(),
10925                                      op::Shape("bf16[8,512,512,384]")),
10926                                _, _, _, _));
10927 }
10928 
TEST_F(SpmdPartitioningTest,KeepPartitionedNonSlicedDimensionWithConstantIndices)10929 TEST_F(SpmdPartitioningTest,
10930        KeepPartitionedNonSlicedDimensionWithConstantIndices) {
10931   const char* const hlo_string = R"(
10932 HloModule module
10933 
10934 ENTRY entry {
10935   p1 = bf16[16,192,192,384]{3,2,1,0} parameter(0), sharding={replicated}
10936   p2 = bf16[16,128,128,384]{3,2,1,0} parameter(1), sharding={replicated}
10937   c1 = bf16[16,192,192,384]{3,2,1,0} copy(p1), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
10938   c2 = bf16[16,128,128,384]{3,2,1,0} copy(p2), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
10939   constant.1163 = bf16[] constant(0), sharding={replicated}
10940   constant.1165 = s32[] constant(0), sharding={replicated}
10941   pad.179 = bf16[16,224,224,384]{3,2,1,0} pad(c1, constant.1163), padding=0_0x16_16x16_16x0_0, sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
10942   add.439 = bf16[16,128,128,384]{3,2,1,0} add(c2, c2), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
10943   constant.1070 = s32[] constant(48), sharding={replicated}
10944   dynamic-update-slice.128 = bf16[16,224,224,384]{3,2,1,0} dynamic-update-slice(pad.179, add.439, constant.1165, constant.1070, constant.1070, /*index=5*/constant.1165), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
10945   ROOT c = bf16[16,224,224,384]{3,2,1,0} copy(dynamic-update-slice.128), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
10946 })";
10947 
10948   TF_ASSERT_OK_AND_ASSIGN(auto module,
10949                           PartitionComputation(hlo_string, /*num_devices=*/8));
10950 
10951   XLA_VLOG_LINES(1, module->ToString());
10952   EXPECT_THAT(
10953       module->entry_computation()->root_instruction(),
10954       op::Copy(op::DynamicSlice(
10955           AllOf(op::DynamicUpdateSlice(), op::Shape("bf16[8,224, 224,384]")), _,
10956           _, _, _)));
10957 }
10958 
TEST_F(SpmdPartitioningTest,CustomCallManualSharding)10959 TEST_F(SpmdPartitioningTest, CustomCallManualSharding) {
10960   const char* const hlo_string = R"(
10961 HloModule pjit_xmap_dummy.5
10962 
10963 ENTRY %main.21 (Arg_0.1: f32[4,4,8], Arg_1.2: f32[4,8]) -> (f32[4,4,8], f32[4]) {
10964   %Arg_0.1 = f32[4,4,8]{2,1,0} parameter(0), sharding={devices=[4,1,1]0,1,2,3}
10965   %copy.3 = f32[4,4,8]{2,1,0} copy(f32[4,4,8]{2,1,0} %Arg_0.1), sharding={devices=[4,1,1]0,1,2,3}
10966   %custom-call.4 = f32[1,4,8]{2,1,0} custom-call(f32[4,4,8]{2,1,0} %copy.3), custom_call_target="SPMDFullToShardShape", sharding={manual}
10967   %reshape.7 = f32[4,8]{1,0} reshape(f32[1,4,8]{2,1,0} %custom-call.4), sharding={manual}
10968   %Arg_1.2 = f32[4,8]{1,0} parameter(1), sharding={replicated}
10969   %copy.2 = f32[4,8]{1,0} copy(f32[4,8]{1,0} %Arg_1.2), sharding={replicated}
10970   %custom-call.6 = f32[4,8]{1,0} custom-call(f32[4,8]{1,0} %copy.2), custom_call_target="SPMDFullToShardShape", sharding={manual}
10971   %custom-call.8 = (f32[4,8]{1,0}, f32[1]{0}) custom-call(f32[4,8]{1,0} %reshape.7, f32[4,8]{1,0} %custom-call.6), custom_call_target="dummy", operand_layout_constraints={f32[4,8]{1,0}, f32[4,8]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, sharding={{manual}, {manual}}
10972   %get-tuple-element.9 = f32[4,8]{1,0} get-tuple-element((f32[4,8]{1,0}, f32[1]{0}) %custom-call.8), index=0, sharding={manual}
10973   %reshape.11 = f32[1,4,8]{2,1,0} reshape(f32[4,8]{1,0} %get-tuple-element.9), sharding={manual}
10974   %copy.1 = f32[1,4,8]{2,1,0} copy(f32[1,4,8]{2,1,0} %reshape.11), sharding={manual}
10975   %custom-call.14 = f32[4,4,8]{2,1,0} custom-call(f32[1,4,8]{2,1,0} %copy.1), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1]0,1,2,3}
10976   %reshape.18 = f32[4,4,8]{2,1,0} reshape(f32[4,4,8]{2,1,0} %custom-call.14), sharding={devices=[4,1,1]0,1,2,3}
10977   %get-tuple-element.10 = f32[1]{0} get-tuple-element((f32[4,8]{1,0}, f32[1]{0}) %custom-call.8), index=1, sharding={manual}
10978   %reshape.12 = f32[1,1]{1,0} reshape(f32[1]{0} %get-tuple-element.10), sharding={manual}
10979   %copy = f32[1,1]{1,0} copy(f32[1,1]{1,0} %reshape.12), sharding={manual}
10980   %custom-call.16 = f32[4,1]{1,0} custom-call(f32[1,1]{1,0} %copy), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1]0,1,2,3}
10981   %reshape.17 = f32[4]{0} reshape(f32[4,1]{1,0} %custom-call.16), sharding={devices=[4]0,1,2,3}
10982   %reshape.19 = f32[4]{0} reshape(f32[4]{0} %reshape.17), sharding={devices=[4]0,1,2,3}
10983   ROOT %tuple.20 = (f32[4,4,8]{2,1,0}, f32[4]{0}) tuple(f32[4,4,8]{2,1,0} %reshape.18, f32[4]{0} %reshape.19), sharding={{replicated}, {replicated}}
10984 }
10985 )";
10986 
10987   TF_ASSERT_OK_AND_ASSIGN(auto module,
10988                           PartitionComputation(hlo_string, /*num_devices=*/4));
10989 
10990   XLA_VLOG_LINES(1, module->ToString());
10991   EXPECT_THAT(module->entry_computation()->root_instruction(),
10992               op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
10993                             _, op::Shape("f32[1,4,8]"), _, _, _)),
10994                         op::AllReduce(op::DynamicUpdateSlice(
10995                             _, op::Shape("f32[1]"), _))));
10996 }
10997 
10998 }  // namespace
10999 }  // namespace spmd
11000 }  // namespace xla
11001