/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/sharding_propagation.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace spmd {
namespace {

using ::testing::_;
using ::testing::AllOf;
namespace op = xla::testing::opcode_matchers;

class SpmdPartitioningTest : public HloTestBase {
 public:
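  // Parses `hlo_module`, runs the SPMD partitioner for `num_devices`
  // partitions (and a single replica) with verification before and after, and
  // returns the partitioned module. The optional flags map directly onto the
  // SpmdPartitionerOptions knobs exercised by individual tests.
  //
  // Typical usage in the tests below:
  //   TF_ASSERT_OK_AND_ASSIGN(
  //       auto module, PartitionComputation(hlo_string, /*num_devices=*/2));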
  StatusOr<std::unique_ptr<HloModule>> PartitionComputation(
      absl::string_view hlo_module, int64_t num_devices,
      bool conv_halo_exchange_always_on_lhs = true,
      bool choose_faster_windowed_einsum = false,
      bool unroll_windowed_einsum = false,
      bool bidirectional_windowed_einsum = false,
      int64_t threshold_for_windowed_einsum_mib = -1) {
    // Some tests (BackpropFilter convs) set this flag false to test two
    // different paths of the implementation.
    SpmdPartitionerOptions options;
    options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs;
    options.allow_module_signature_change = true;
    options.choose_faster_windowed_einsum_over_mem =
        choose_faster_windowed_einsum;
    options.unroll_windowed_einsum = unroll_windowed_einsum;
    options.bidirectional_windowed_einsum = bidirectional_windowed_einsum;
    if (threshold_for_windowed_einsum_mib >= 0) {
      options.threshold_for_windowed_einsum_mib =
          threshold_for_windowed_einsum_mib;
    }
    auto collective_ops_creator =
        GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1);
    // Do not use all-gather for pattern-matching purpose, as the partitioner
    // might create reshape/transposes around it.
    collective_ops_creator.create_cross_partition_all_gather = nullptr;

    TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(
                                         hlo_module, GetModuleConfigForTest()));
    HloPassPipeline pass("spmd-partitioning");
    pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
                              /*allow_mixed_precision=*/false);
    pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options,
                                  collective_ops_creator);
    pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
                              /*allow_mixed_precision=*/false);
    TF_RETURN_IF_ERROR(pass.Run(module.get()).status());
    return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
  }
};

TEST_F(SpmdPartitioningTest, InvalidSharding) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={maximal device=0}
})";
  auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4);
  EXPECT_FALSE(module_status.status().ok());
  EXPECT_THAT(module_status.status().ToString(),
              ::testing::HasSubstr(
                  "only supports tile sharding that includes all partitions"));
}

TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(
                              op::Select(op::Broadcast(op::Compare()),
                                         op::Constant(), op::Broadcast()))),
                          op::Shape("s32[2,3]")));
}

TEST_F(SpmdPartitioningTest, SingleDeviceCustomCall) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  %cc = s32[2,3] custom-call(%constant), custom_call_target="SomeCustomCall",
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%cc), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* custom_call = FindInstruction(module.get(), "cc.1");
  EXPECT_NE(custom_call, nullptr);
  EXPECT_NE(custom_call->parent(), module->entry_computation());
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(
                              op::Select(op::Broadcast(op::Compare()),
                                         op::Conditional(), op::Broadcast()))),
                          op::Shape("s32[2,3]")));
}

TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  HloInstruction* root = module->entry_computation()->root_instruction();
  VLOG(1) << module->ToString();
  EXPECT_THAT(root, op::Copy(AllOf(op::Copy(op::AllReduce(op::Select(
                                       op::Broadcast(op::Compare()),
                                       op::Constant(), op::Broadcast()))),
                                   op::Shape("s32[2,3]"))));
}

TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant),
    sharding={devices=[2,1]1,0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      AllOf(
          op::Copy(op::DynamicSlice(
              op::AllReduce(op::Select(
                  op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
                  op::Constant(), op::Broadcast())),
              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
              op::Constant())),
          op::Shape("s32[1,3]")));
}

TEST_F(SpmdPartitioningTest, TiledToReplicated) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      op::Copy(op::AllReduce(AllOf(
          op::DynamicUpdateSlice(
              op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")),
              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
              op::Constant()),
          op::Shape("s32[2,3]")))));
}

TEST_F(SpmdPartitioningTest, TiledToSingleDevice) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      op::Copy(op::Copy(op::AllReduce(AllOf(
          op::DynamicUpdateSlice(
              op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")),
              op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
              op::Constant()),
          op::Shape("s32[2,3]"))))));
}

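// Resharding between tilings of different dimensions ([2,1] -> [1,2]) is
// expected to lower to a reshape/all-to-all/transpose/reshape sequence, as
// matched below.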
TEST_F(SpmdPartitioningTest, TiledToTiledEven) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param= s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1}
  ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      AllOf(op::Copy(op::Reshape(op::Transpose(op::AllToAll(AllOf(
                op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]")))))),
            op::Shape("s32[8,1]")));
}

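// When the tiled dimension does not divide evenly, the operand is expected to
// be padded before the all-to-all and the result sliced back to the original
// size, as matched below.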
TEST_F(SpmdPartitioningTest, TiledToTiledUneven) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param= f32[7,31,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1}
  ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      AllOf(op::Copy(op::Slice(op::Reshape(AllOf(op::Transpose(op::AllToAll(
          op::Reshape(AllOf(op::Pad(), op::Shape("f32[8,16,128]")))))))))));
}

TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param.0 = (f32[2,3]{1,0}, u32[]) parameter(0),
    sharding={{maximal device=1}, {maximal device=1}}
  %gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0,
    sharding={maximal device=0}
  %gte.1 = u32[] get-tuple-element(%param.0), index=1,
    sharding={maximal device=0}
  ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1),
    sharding={{maximal device=0},{maximal device=0}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Tuple());

  EXPECT_THAT(root->operand(0),
              op::Copy(op::AllReduce(op::Select(
                  op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
                  op::GetTupleElement(op::Parameter()), op::Broadcast()))));
  EXPECT_THAT(root->operand(1),
              op::Copy(op::AllReduce(op::Select(
                  op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
                  op::GetTupleElement(op::Parameter()), op::Broadcast()))));
}

TEST_F(SpmdPartitioningTest, GetTupleElementTiled) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0),
    sharding={{replicated}, {replicated}}
  gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0,
    sharding={devices=[2,1]0,1}
  gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1,
    sharding={devices=[2,1]0,1}
  ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1),
    sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Tuple());

  auto offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));

  EXPECT_THAT(root->operand(0),
              op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset,
                               op::Constant()));
  EXPECT_THAT(root->operand(1),
              op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset,
                               op::Constant()));
}

TEST_F(SpmdPartitioningTest, TiledInfeed) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      op::Copy(op::AllReduce(op::DynamicUpdateSlice(
          op::Broadcast(),
          op::GetTupleElement(
              AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))),
          op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
          op::Constant()))));
}

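// The uneven infeed tests below split a 9-row operand over 2 partitions; each
// branch of a conditional on the partition id issues its own infeed with a
// different per-shard shape ([5,2] vs. [4,2]), and the smaller shard is padded
// up to the full per-shard shape, as matched below.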
TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[9,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root, AllOf(op::Shape("f32[5,2]"), op::GetTupleElement(op::Conditional(
                                             op::Convert(op::PartitionId()),
                                             op::AfterAll(), op::AfterAll()))));
  EXPECT_THAT(
      root->operand(0)->called_computations()[0]->root_instruction(),
      AllOf(op::Shape("(f32[5,2], token[])"), op::Infeed(op::Parameter())));
  auto second_infeed =
      AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter()));
  EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(),
              AllOf(op::Shape("(f32[5,2], token[])"),
                    op::Tuple(op::Pad(op::GetTupleElement(second_infeed),
                                      op::Constant()),
                              op::GetTupleElement(second_infeed))));
}

TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}}
  ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
    index=0, sharding={{devices=[2,1]0,1}, {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("(f32[5,2], f32[2])"),
                          op::GetTupleElement(op::Conditional(
                              op::Convert(op::PartitionId()), op::AfterAll(),
                              op::AfterAll()))));
  EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(),
              AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
                    op::Infeed(op::Parameter())));
  auto second_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"),
                             op::Infeed(op::Parameter()));
  EXPECT_THAT(
      root->operand(0)->called_computations()[1]->root_instruction(),
      AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
            op::Tuple(op::Tuple(op::Pad(op::GetTupleElement(
                                            op::GetTupleElement(second_infeed)),
                                        op::Constant()),
                                op::GetTupleElement(
                                    op::GetTupleElement(second_infeed))),
                      op::GetTupleElement(second_infeed))));
}

TEST_F(SpmdPartitioningTest, MixedTupleInfeed) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
    sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}}
  ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
    index=0, sharding={{maximal device=0}, {maximal device=1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("(f32[9,2], f32[2])"),
                          op::GetTupleElement(op::Conditional(
                              op::Convert(op::PartitionId()), op::AfterAll(),
                              op::AfterAll()))));
  auto first_infeed = AllOf(op::Shape("((f32[9,2], ()), token[])"),
                            op::Infeed(op::Parameter()));
  EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(),
              AllOf(op::Shape("((f32[9,2], f32[2]), token[])"),
                    op::Tuple(op::Tuple(op::GetTupleElement(
                                            op::GetTupleElement(first_infeed)),
                                        op::Broadcast(op::Constant())),
                              op::GetTupleElement(first_infeed))));
  auto second_infeed =
      AllOf(op::Shape("(((), f32[2]), token[])"), op::Infeed(op::Parameter()));
  EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(),
              AllOf(op::Shape("((f32[9,2], f32[2]), token[])"),
                    op::Tuple(op::Tuple(op::Broadcast(op::Constant()),
                                        op::GetTupleElement(op::GetTupleElement(
                                            second_infeed))),
                              op::GetTupleElement(second_infeed))));
}

TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1},
    to_apply=sum, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      op::AllReduce(op::Reduce(
          op::Select(
              op::Compare(op::Add(op::Iota(), op::Broadcast(op::Reshape())),
                          op::Broadcast(op::Constant())),
              AllOf(op::Shape("f32[2,3]{1,0}"),
                    op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                                     op::Reshape(), op::Constant())),
              op::Broadcast(op::Constant())),
          op::Constant())));
}

TEST_F(SpmdPartitioningTest, TiledElementwise) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}),
    sharding={replicated}
  multiply = f32[3,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[2,1]0,1}
  ROOT add = f32[3,3]{1,0} add(multiply, constant.1),
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      AllOf(
          op::Shape("f32[2,3]{1,0}"),
          op::Add(op::Multiply(
                      op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                                       op::Reshape(), op::Constant()),
                      op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                                       op::Reshape(), op::Constant())),
                  op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                                   op::Reshape(), op::Constant()))));
}

TEST_F(SpmdPartitioningTest, TiledAllReduce) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1}
  ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum,
    replica_groups={}, sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root, AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0))));
}

TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[2,1,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,4,3]{2,1,0}"),
                          op::Broadcast(op::Constant())));
}

TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[1,2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
                          op::Broadcast(op::DynamicSlice(
                              op::Constant(), op::Reshape(), op::Constant()))));
}

TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[2,2,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      AllOf(op::Shape("f32[2,2,3]{2,1,0}"),
            op::Broadcast(AllOf(op::Shape("f32[2,3]{1,0}"),
                                op::DynamicSlice(op::Constant(), op::Reshape(),
                                                 op::Constant())))));
}

TEST_F(SpmdPartitioningTest,
       BroadcastBothOldAndNewDimsShardedPartiallySharded) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  param = f32[4,3] parameter(0),
    sharding={devices=[1,2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
  ROOT broadcast = f32[4,4,3] broadcast(param), dimensions={1,2},
    sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      AllOf(op::Shape("f32[2,4,2]"),
            op::Broadcast(AllOf(op::Shape("f32[4,2]"), op::Parameter(0)))));
}

TEST_F(SpmdPartitioningTest,
       ConvWithParallelDimAndNonParallelSpatialDimPartitioned) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,12,12,24,32] parameter(0)
  %lhs.copy = f32[32,12,12,24,32] copy(%lhs),
    sharding={devices=[2,2,1,1,1]0,1,2,3}
  %rhs = f32[32,6,6,16,32] parameter(1)
  %rhs.copy = f32[32,6,6,16,32] copy(%rhs),
    sharding={devices=[2,2,1,1,1]0,1,2,3}
  ROOT %conv = f32[32,7,7,24,16] convolution(%lhs.copy, %rhs.copy),
    dim_labels=012bf_012oi->012bf,
    window={size=32x6x6 stride=31x1x1 lhs_dilate=32x1x1},
    sharding={devices=[2,2,1,1,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  const auto root = module->entry_computation()->root_instruction();
  const auto lhs = AllOf(op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Reshape(), op::Reshape(),
                             op::Constant(), op::Constant(), op::Constant())),
                         op::Shape("f32[16,6,12,24,32]"));
  const auto rhs = AllOf(op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Reshape(), op::Reshape(),
                             op::Constant(), op::Constant(), op::Constant())),
                         op::Shape("f32[16,3,6,16,32]"));
  auto resharded_rhs =
      AllOf(op::Shape("f32[16,6,6,16,32]"),
            op::AllReduce(op::DynamicUpdateSlice(
                op::Broadcast(), rhs, op::Constant(), op::Reshape(),
                op::Constant(), op::Constant(), op::Constant())));

  auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                         op::Shape("f32[16,2,12,24,32]"));
  auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                          op::Shape("f32[16,3,12,24,32]"));
  EXPECT_THAT(
      root,
      AllOf(op::Convolution(
                op::Select(op::Compare(),
                           op::DynamicSlice(
                               op::Concatenate(left_halo, lhs, right_halo),
                               op::Constant(), op::Add(), op::Constant(),
                               op::Constant(), op::Constant()),
                           op::Broadcast()),
                resharded_rhs),
            op::Shape("f32[16,4,7,24,16]")));
}

TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}),
    sharding={devices=[2,1]0,1}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[1,2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
                          op::Broadcast(op::DynamicSlice(
                              op::Constant(), op::Reshape(), op::Constant()))));
}

TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = f32[1024]{0} parameter(0), sharding={maximal device=0}
  outfeed = token[] outfeed(data, token.0), sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("token[]"),
                          op::Conditional(
                              op::Compare(op::PartitionId(), op::Constant()),
                              op::Tuple(op::Parameter(0), op::AfterAll()),
                              op::Tuple(op::Parameter(0), op::AfterAll()))));

  HloInstruction* root_b0 = root->branch_computation(0)->root_instruction();
  EXPECT_THAT(root_b0,
              AllOf(op::Shape("token[]"),
                    op::Outfeed(op::GetTupleElement(op::Parameter(), 0),
                                op::GetTupleElement(op::Parameter(), 1))));

  HloInstruction* root_b1 = root->branch_computation(1)->root_instruction();
  EXPECT_THAT(root_b1, AllOf(op::Shape("token[]"), op::AfterAll()));
}

TEST_F(SpmdPartitioningTest, OutfeedEvenlyTiled) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = f32[1024]{0} parameter(0), sharding={devices=[2]0,1}
  ROOT outfeed = token[] outfeed(data, token.0), sharding={devices=[2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("token[]"),
                          op::Outfeed(op::Parameter(), op::AfterAll())));
}

TEST_F(SpmdPartitioningTest, OutfeedTupleEvenlyTiled) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
  ROOT outfeed = token[] outfeed(data, token.0),
    outfeed_shape=(f32[1024,2]{0,1}, f32[2]{0}), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("token[]"),
                          op::Outfeed(op::Parameter(), op::AfterAll())));
  auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
  auto expected_layout1 = LayoutUtil::MakeLayout({0});
  EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(0).layout(),
                                expected_layout0));
  EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(1).layout(),
                                expected_layout1));
}

TEST_F(SpmdPartitioningTest, OutfeedReplicated) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {replicated}}
  ROOT outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
    {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("token[]"),
                          op::Outfeed(op::Parameter(), op::AfterAll())));
}

TEST_F(SpmdPartitioningTest, OutfeedUnevenlyTiled) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1023,2]{1,0}, f32[3]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
  outfeed = token[] outfeed(data, token.0),
    outfeed_shape=(f32[1023,2]{0,1}, f32[3]{0}), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root, AllOf(op::Shape("token[]"),
                  op::Conditional(op::Convert(),
                                  op::Tuple(op::Parameter(), op::AfterAll()),
                                  op::Tuple(op::Parameter(), op::AfterAll()))));

  auto first_outfeed =
      AllOf(op::Shape("(f32[512,2], f32[2])"), op::GetTupleElement());
  EXPECT_THAT(root->called_computations()[0]->root_instruction(),
              AllOf(op::Shape("token[]"),
                    op::Outfeed(first_outfeed, op::GetTupleElement())));

  auto second_outfeed = AllOf(op::Shape("(f32[511,2], f32[1])"), op::Tuple());
  EXPECT_THAT(root->called_computations()[1]->root_instruction(),
              AllOf(op::Shape("token[]"),
                    op::Outfeed(second_outfeed, op::GetTupleElement())));

  auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
  auto expected_layout1 = LayoutUtil::MakeLayout({0});
  auto first_outfeed_instr = root->called_computations()[0]->root_instruction();
  auto second_outfeed_instr =
      root->called_computations()[1]->root_instruction();
  EXPECT_TRUE(LayoutUtil::Equal(
      first_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
      expected_layout0));
  EXPECT_TRUE(LayoutUtil::Equal(
      first_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
      expected_layout1));
  EXPECT_TRUE(LayoutUtil::Equal(
      second_outfeed_instr->outfeed_shape().tuple_shapes(0).layout(),
      expected_layout0));
  EXPECT_TRUE(LayoutUtil::Equal(
      second_outfeed_instr->outfeed_shape().tuple_shapes(1).layout(),
      expected_layout1));
}

TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
    sharding={replicated}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1),
    window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      AllOf(op::Shape("f32[2,2]{1,0}"),
            op::ReduceWindow(
                op::DynamicSlice(AllOf(op::Shape("f32[9,2]{1,0}"),
                                       op::Pad(op::Constant(), op::Constant())),
                                 op::Multiply(op::Reshape(), op::Constant()),
                                 op::Constant()),
                op::Constant())));
}

TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1),
    window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();

  auto sharded_input =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::CollectivePermute(op::Slice(sharded_input)));
  auto pre_masking = op::DynamicSlice(
      AllOf(
          op::Shape("f32[6,2]{1,0}"),
          op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())),
      op::Reshape(), op::Constant());
  auto index_in_padded = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto masked =
      op::Select(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
                 pre_masking, op::Broadcast(op::Constant()));
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::ReduceWindow(masked, op::Constant())));
}

TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
    window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
    sharding={devices=[5,1]0,1,2,3,4}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/5));
  VLOG(1) << module->ToString();
  auto halo0 = AllOf(op::Shape("f32[1,2]"),
                     op::CollectivePermute(op::Slice(op::Parameter(0))));
  auto halo1 =
      AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0)));
  auto pre_mask =
      AllOf(op::Shape("f32[4,2]"),
            op::Concatenate(halo0, halo1, op::Slice(op::Parameter(0))));
  auto masked =
      op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())),
                             op::Broadcast(op::Constant())),
                 pre_mask, op::Broadcast(op::Constant()));
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
                          op::ReduceWindow(masked, op::Constant())));
}

TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[9,2]{1,0} constant(
    {{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}),
    sharding={devices=[3,1]0,1,2}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1),
    window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum,
    sharding={devices=[3,1]0,1,2}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/3));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();

  auto sharded_input =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::CollectivePermute(op::Slice(sharded_input)));
  auto pre_masking = op::DynamicSlice(
      AllOf(
          op::Shape("f32[7,2]{1,0}"),
          op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())),
      op::Reshape(), op::Constant());
  auto index_in_padded = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto masked = op::Select(
      op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
              op::Compare(index_in_padded, op::Broadcast(op::Constant()))),
      pre_masking, op::Broadcast(op::Constant()));
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::ReduceWindow(masked, op::Constant())));
}

TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();

  auto sharded_input =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  auto left_halo = AllOf(op::Shape("f32[1,2]{1,0}"),
                         op::CollectivePermute(op::Slice(sharded_input)));
  auto right_halo = AllOf(op::Shape("f32[1,2]{1,0}"),
                          op::CollectivePermute(op::Slice(sharded_input)));
  auto pre_masking = AllOf(
      op::Shape("f32[5,2]{1,0}"),
      op::DynamicSlice(
          AllOf(op::Shape("f32[6,2]{1,0}"),
                op::Pad(op::Concatenate(left_halo, sharded_input, right_halo),
                        op::Constant())),
          op::Reshape(), op::Constant()));
  auto index_in_padded = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto masked = op::Select(
      op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())),
              op::Compare(index_in_padded, op::Broadcast(op::Constant()))),
      pre_masking, op::Broadcast(op::Constant()));
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
                          op::ReduceWindow(masked, op::Constant())));
}

TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0),
    sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}}
  infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0,
    sharding={devices=[2,2,1,1]0,1,2,3}
  constant = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant),
    window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum,
    sharding={devices=[2,2,1,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();

  auto sharded_input = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"),
                             op::GetTupleElement(op::Infeed()));
  auto dim0_left_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
                              op::CollectivePermute(op::Slice(sharded_input)));
  auto dim0_right_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
                               op::CollectivePermute(op::Slice(sharded_input)));
  auto dim0_pre_masking = op::DynamicSlice(
      AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"),
            op::Pad(
                op::Concatenate(dim0_left_halo, sharded_input, dim0_right_halo),
                op::Constant())),
      op::Reshape(), op::Constant(), op::Constant(), op::Constant());
  auto dim0_index_in_padded = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto dim0_masked = op::Select(
      op::And(op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant())),
              op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant()))),
      dim0_pre_masking, op::Broadcast(op::Constant()));
  auto dim0_resharded = AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"), dim0_masked);
  auto dim1_left_halo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
                              op::CollectivePermute(op::Slice(dim0_resharded)));
  auto dim1_right_halo =
      AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
            op::CollectivePermute(op::Slice(dim0_resharded)));
  auto dim1_pre_masking = op::DynamicSlice(
      AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"),
            op::Pad(op::Concatenate(dim1_left_halo, dim0_resharded,
                                    dim1_right_halo),
                    op::Constant())),
      op::Constant(), op::Reshape(), op::Constant(), op::Constant());
  auto dim1_index_in_padded = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto dim1_masked = op::Select(
      op::And(op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant())),
              op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant()))),
      dim1_pre_masking, op::Broadcast(op::Constant()));
  auto dim1_resharded = AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"), dim1_masked);
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"),
                          op::ReduceWindow(dim1_resharded, op::Constant())));
}

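// The convolution tests below tile an LHS spatial dimension; each shard is
// expected to exchange halo regions with its neighbors via collective-permute
// of slices before computing its portion of the output, as matched below.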
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[7,7,3,64] parameter(1)
  %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
    sharding={replicated}
  ROOT %conv = f32[128,112,112,64] convolution(
    f32[128,224,224,3] %lhs.copy,
    f32[7,7,3,64] %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto root = module->entry_computation()->root_instruction();
  const auto lhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[128,112,224,3]"));
  const auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                         op::Shape("f32[128,3,224,3]"));
  auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                          op::Shape("f32[128,2,224,3]"));
  EXPECT_THAT(root,
              AllOf(op::Convolution(
                        op::Select(op::And(),
                                   op::Concatenate(left_halo, lhs, right_halo),
                                   op::Broadcast()),
                        rhs),
                    op::Shape("f32[128,56,112,64]")));
}

TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
    sharding={devices=[2,1,1,1]0,1}
  %rhs = f32[7,7,3,64] parameter(1)
  %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
    sharding={replicated}
  ROOT %conv = f32[128,112,112,64] convolution(
    f32[128,224,224,3] %lhs.copy,
    f32[7,7,3,64] %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto root = module->entry_computation()->root_instruction();
  const auto lhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[64,224,224,3]"));
  auto all_to_all =
      AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[64,2,112,224,3]"));
  auto reshard_lhs = AllOf(op::Reshape(op::Transpose(all_to_all)),
                           op::Shape("f32[128,112,224,3]"));

  const auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  auto left_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)),
                         op::Shape("f32[128,3,224,3]"));
  auto right_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)),
                          op::Shape("f32[128,2,224,3]"));
  EXPECT_THAT(
      root,
      AllOf(op::Convolution(
                op::Select(op::And(),
                           op::Concatenate(left_halo, reshard_lhs, right_halo),
                           op::Broadcast()),
                rhs),
            op::Shape("f32[128,56,112,64]")));
}

TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[224,224,3,128] parameter(0)
  %lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1}
  %rhs = f32[7,7,3,64] parameter(1)
  %rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated}
  ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=01fb_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto root = module->entry_computation()->root_instruction();
  const auto lhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[112,224,3,128]"));
  const auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                         op::Shape("f32[3,224,3,128]"));
  auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                          op::Shape("f32[2,224,3,128]"));
  EXPECT_THAT(root,
              AllOf(op::Convolution(
                        op::Select(op::And(),
                                   op::Concatenate(left_halo, lhs, right_halo),
                                   op::Broadcast()),
                        rhs),
                    op::Shape("f32[128,56,112,64]")));
}

// (stride * per_shard_window_count) % dilation == 0
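// A rough sanity check of this condition for the test below, assuming the
// per-shard window count is the 4-wide output spatial dimension split over
// 2 partitions: stride = 4, lhs_dilate = 2, per_shard_window_count = 2, so
// (4 * 2) % 2 == 0 and every shard starts at the same position within the
// dilated base area.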
TEST_F(SpmdPartitioningTest,
       ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,7,7,512] parameter(0)
  %lhs.copy = f32[128,7,7,512] copy(%lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[3,3,512,512] parameter(1)
  %rhs.copy = f32[3,3,512,512] copy(%rhs),
    sharding={replicated}
  ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy),
    window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto root = module->entry_computation()->root_instruction();
  // There is no halo exchange, and because the last element in the shard is
  // not needed (stride == 4), the LHS will be just a slice.
  auto sliced_lhs =
      AllOf(op::Slice(op::Copy(op::DynamicSlice(
                op::Pad(op::Parameter(), op::Constant()), op::Constant(),
                op::Reshape(), op::Constant(), op::Constant()))),
            op::Shape("f32[128,3,7,512]"));
  const auto rhs =
      AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
  EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs),
                          op::Shape("f32[128,2,4,512]")));
  EXPECT_EQ(root->window().dimensions(0).padding_low(), 1);
  EXPECT_EQ(root->window().dimensions(0).padding_high(), 1);
}

// (stride * per_shard_window_count) % dilation != 0 but stride == 1
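// A rough sanity check of this condition for the test below, assuming the
// per-shard window count is the 14-wide output spatial dimension split over
// 2 partitions: stride = 1 (implicit), lhs_dilate = 2,
// per_shard_window_count = 7, so (1 * 7) % 2 != 0 and the shards need a halo
// exchange plus dynamic offsets, as matched below.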
TEST_F(SpmdPartitioningTest,
       ConvolutionBaseDilationStride1LhsTiledRhsReplicated) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,7,7,512] parameter(0)
  %lhs.copy = f32[128,7,7,512] copy(%lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[3,3,512,512] parameter(1)
  %rhs.copy = f32[3,3,512,512] copy(%rhs),
    sharding={replicated}
  ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy),
    window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto root = module->entry_computation()->root_instruction();
  const auto lhs =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[128,4,7,512]"));
  const auto rhs =
      AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));

  auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                         op::Shape("f32[128,1,7,512]"));
  auto start_window = op::Multiply(op::Reshape(), op::Constant());
  auto start_input_element = op::Divide(start_window, op::Constant());
  auto dynamic_offset_for_padded_concat = op::Subtract(
      op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
                                   start_input_element));
  auto pre_masking =
      AllOf(op::Shape("f32[128,5,7,512]"),
            op::DynamicSlice(
                AllOf(op::Shape("f32[128,6,7,512]"),
                      op::Pad(op::Concatenate(left_halo, lhs), op::Constant())),
                op::Constant(), dynamic_offset_for_padded_concat,
                op::Constant(), op::Constant()));
  auto masked = op::Select(
      op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)),
                  op::Broadcast(op::Constant())),
      pre_masking, op::Broadcast(op::Constant()));
  auto dynamic_offset_on_output = op::Subtract(
      start_window, op::Multiply(start_input_element, op::Constant()));
  EXPECT_THAT(root,
              AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs),
                                           op::Shape("f32[128,8,14,512]")),
                                     op::Constant(), dynamic_offset_on_output,
                                     op::Constant(), op::Constant()),
                    op::Shape("f32[128,7,14,512]")));
  EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1);
  EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
}

TEST_F(SpmdPartitioningTest,SelectAndScatterNoOverlap)1339 TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) {
1340 absl::string_view hlo_string = R"(
1341 HloModule module
1342
1343 ge {
1344 a = f32[] parameter(0)
1345 b = f32[] parameter(1)
1346 ROOT compare = pred[] compare(a, b), direction=GE
1347 }
1348
1349 sum {
1350 c = f32[] parameter(0)
1351 d = f32[] parameter(1)
1352 ROOT add = f32[] add(c, d)
1353 }
1354
1355 ENTRY entry {
1356 %param = f32[11,4]{1,0} parameter(0)
1357 %param.copy = f32[11,4] copy(%param),
1358 sharding={devices=[4,1]0,1,2,3}
1359 constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
1360 sharding={devices=[4,1]0,1,2,3}
1361 constant.1 = f32[] constant(0), sharding={replicated}
1362 ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
1363 constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
1364 select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
1365 })";
1366 TF_ASSERT_OK_AND_ASSIGN(auto module,
1367 PartitionComputation(hlo_string, /*num_devices=*/4));
1368 VLOG(1) << module->ToString();
1369 const auto root = module->entry_computation()->root_instruction();
1370 auto source =
1371 AllOf(op::Shape("f32[1,2]{1,0}"),
1372 op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
1373 auto masked_data = AllOf(
1374 op::Shape("f32[3,4]{1,0}"),
1375 op::Select(
1376 op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
1377 op::Reshape(), op::Constant()))),
1378 op::Broadcast(op::Constant())),
1379 op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
1380 op::Reshape(), op::Constant())),
1381 op::Broadcast(op::Constant())));
1382
1383 EXPECT_THAT(root,
1384 AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
1385 op::Shape("f32[3,4]{1,0}")));
1386 EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
1387 EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
1388 }
1389
1390 TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) {
1391 absl::string_view hlo_string = R"(
1392 HloModule module
1393
1394 ge {
1395 a = f32[] parameter(0)
1396 b = f32[] parameter(1)
1397 ROOT compare = pred[] compare(a, b), direction=GE
1398 }
1399
1400 sum {
1401 c = f32[] parameter(0)
1402 d = f32[] parameter(1)
1403 ROOT add = f32[] add(c, d)
1404 }
1405
1406 ENTRY entry {
1407 %param = f32[11,4]{1,0} parameter(0)
1408 %param.copy = f32[11,4] copy(%param),
1409 sharding={devices=[1,4]0,1,2,3}
1410 constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
1411 sharding={devices=[4,1]0,1,2,3}
1412 constant.1 = f32[] constant(0), sharding={replicated}
1413 ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
1414 constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
1415 select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
1416 })";
1417 TF_ASSERT_OK_AND_ASSIGN(auto module,
1418 PartitionComputation(hlo_string, /*num_devices=*/4));
1419 VLOG(1) << module->ToString();
1420 const auto root = module->entry_computation()->root_instruction();
1421 auto source =
1422 AllOf(op::Shape("f32[1,2]{1,0}"),
1423 op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
1424 auto operand = AllOf(op::Copy(op::DynamicSlice(
1425 op::Parameter(0), op::Constant(), op::Reshape())),
1426 op::Shape("f32[11,1]"));
1427 auto reshard_operand = op::Reshape(op::Transpose(
1428 op::AllToAll(op::Reshape(op::Pad(operand, op::Constant())))));
1429 auto masked_data = AllOf(
1430 op::Shape("f32[3,4]{1,0}"),
1431 op::Select(
1432 op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
1433 op::Reshape(), op::Constant()))),
1434 op::Broadcast(op::Constant())),
1435 reshard_operand, op::Broadcast(op::Constant())));
1436
1437 EXPECT_THAT(root,
1438 AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
1439 op::Shape("f32[3,4]{1,0}")));
1440 EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
1441 EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
1442 }
1443
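// With window size 3 and stride 2 adjacent windows overlap, so the test expects
// both the source and the operand to be extended with halos via
// collective-permute before the local select-and-scatter, and the result to be
// dynamic-sliced back down to the shard.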
1444 TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) {
1445 absl::string_view hlo_string = R"(
1446 HloModule module
1447
1448 ge {
1449 a = f32[] parameter(0)
1450 b = f32[] parameter(1)
1451 ROOT compare = pred[] compare(a, b), direction=GE
1452 }
1453
1454 sum {
1455 c = f32[] parameter(0)
1456 d = f32[] parameter(1)
1457 ROOT add = f32[] add(c, d)
1458 }
1459
1460 ENTRY entry {
1461 %param = f32[11,4]{1,0} parameter(0)
1462 %param.copy = f32[11,4] copy(%param),
1463 sharding={devices=[4,1]0,1,2,3}
1464 constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}),
1465 sharding={devices=[4,1]0,1,2,3}
1466 constant.1 = f32[] constant(0), sharding={replicated}
1467 ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
1468 constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0},
1469 select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
1470 })";
1471 TF_ASSERT_OK_AND_ASSIGN(auto module,
1472 PartitionComputation(hlo_string, /*num_devices=*/4));
1473 VLOG(1) << module->ToString();
1474 const auto root = module->entry_computation()->root_instruction();
1475
1476 auto source_shard =
1477 AllOf(op::Shape("f32[2,2]{1,0}"),
1478 op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant()));
1479 // Max halo size is the same as the shard size, so slice is not needed.
1480 auto source_left_halo = op::CollectivePermute(source_shard);
1481 auto required_source_shard_start =
1482 op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
1483 auto source_with_halo = op::DynamicSlice(
1484 AllOf(op::Shape("f32[5,2]{1,0}"),
1485 op::Pad(op::Concatenate(source_left_halo, source_shard),
1486 op::Constant())),
1487 op::Subtract(op::Constant(),
1488 op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
1489 required_source_shard_start)),
1490 op::Constant());
1491 auto masked_source_with_halo = AllOf(
1492 AllOf(op::Shape("f32[3,2]{1,0}")),
1493 op::Select(
1494 op::Compare(
1495 op::Add(op::Iota(), op::Broadcast(required_source_shard_start)),
1496 op::Broadcast(op::Constant())),
1497 source_with_halo, op::Broadcast(op::Constant())));
1498
1499 auto data_shard =
1500 AllOf(op::Shape("f32[3,4]{1,0}"),
1501 op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
1502 op::Reshape(), op::Constant())));
1503 auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
1504 op::CollectivePermute(op::Slice(data_shard)));
1505 auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
1506 op::CollectivePermute(op::Slice(data_shard)));
1507 auto required_data_start_on_padded =
1508 op::Multiply(required_source_shard_start, op::Constant());
1509 auto left_halo_size = op::Subtract(
1510 op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()),
1511 required_data_start_on_padded);
1512 auto data_with_halo =
1513 AllOf(op::Shape("f32[7,4]{1,0}"),
1514 op::DynamicSlice(
1515 AllOf(op::Shape("f32[8,4]{1,0}"),
1516 op::Pad(op::Concatenate(data_left_halo, data_shard,
1517 data_right_halo),
1518 op::Constant())),
1519 op::Subtract(op::Constant(), left_halo_size), op::Constant()));
1520 auto index_on_padded =
1521 op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded));
1522 auto masked_data_with_halo = op::Select(
1523 op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())),
1524 op::Compare(index_on_padded, op::Broadcast(op::Constant()))),
1525 data_with_halo, op::Broadcast(op::Constant()));
1526
1527 EXPECT_THAT(
1528 root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo,
1529 masked_source_with_halo,
1530 op::Constant()),
1531 left_halo_size, op::Constant()),
1532 op::Shape("f32[3,4]{1,0}")));
1533 EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0);
1534 EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
1535 }
1536
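// The convolution tests below shard both operands along a spatial dimension
// that the convolution contracts, so each partition computes a partial result
// that is combined with an all-reduce.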
1537 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) {
1538 absl::string_view hlo_string = R"(
1539 HloModule module
1540
1541 ENTRY entry {
1542 %lhs = f32[128,56,56,64] parameter(0)
1543 %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1544 %rhs = f32[128,56,56,256] parameter(1)
1545 %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1546 ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
1547 window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1548 })";
1549
1550 TF_ASSERT_OK_AND_ASSIGN(auto module,
1551 PartitionComputation(hlo_string, /*num_devices=*/2));
1552 VLOG(1) << module->ToString();
1553
1554 const auto root = module->entry_computation()->root_instruction();
1555 const auto lhs = AllOf(
1556 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1557 op::Constant(), op::Constant())),
1558 op::Shape("f32[128,28,56,64]"));
1559 const auto rhs = AllOf(
1560 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1561 op::Constant(), op::Constant())),
1562 op::Shape("f32[128,28,56,256]"));
1563
1564 EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
1565 op::Shape("f32[1,1,64,256]")));
1566 }
1567
1568 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowReversal) {
1569 absl::string_view hlo_string = R"(
1570 HloModule module
1571
1572 ENTRY entry {
1573 %lhs = f32[5,128,64] parameter(0), sharding={devices=[2,1,1]0,1}
1574 %rhs = f32[5,128,256] parameter(1), sharding={devices=[2,1,1]1,0}
1575 ROOT %conv = f32[1,64,256] convolution(%lhs, %rhs),
1576 window={size=5 rhs_reversal=1}, dim_labels=0fb_0io->0bf,
1577 sharding={replicated}
1578 })";
1579
1580 TF_ASSERT_OK_AND_ASSIGN(auto module,
1581 PartitionComputation(hlo_string, /*num_devices=*/2));
1582 VLOG(1) << module->ToString();
1583
1584 const auto lhs_masked =
1585 AllOf(op::Shape("f32[3,128,64]"), op::Select(_, op::Parameter(0), _));
1586 const auto rhs_left_padded =
1587 op::Concatenate(op::CollectivePermute(op::Slice(op::Parameter(1))),
1588 op::Slice(op::Parameter(1)));
1589 const auto rhs_masked =
1590 AllOf(op::Shape("f32[3,128,256]"), op::Select(_, rhs_left_padded, _));
1591
1592 const auto root = module->entry_computation()->root_instruction();
1593 EXPECT_THAT(root,
1594 AllOf(op::AllReduce(op::Convolution(lhs_masked, rhs_masked)),
1595 op::Shape("f32[1,64,256]")));
1596 }
1597
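// Here the operands are sharded along different dimensions; the test expects
// the LHS to be resharded with a reshape + all-to-all + transpose to line up
// with the RHS before the partial convolution.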
1598 TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) {
1599 absl::string_view hlo_string = R"(
1600 HloModule module
1601
1602 ENTRY entry {
1603 %lhs = f32[128,56,56,64] parameter(0)
1604 %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1605 %rhs = f32[128,56,56,256] parameter(1)
1606 %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
1607 ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
1608 window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1609 })";
1610
1611 TF_ASSERT_OK_AND_ASSIGN(auto module,
1612 PartitionComputation(hlo_string, /*num_devices=*/2));
1613 VLOG(1) << module->ToString();
1614
1615 const auto root = module->entry_computation()->root_instruction();
1616 const auto lhs = AllOf(
1617 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1618 op::Constant(), op::Constant())),
1619 op::Shape("f32[128,28,56,64]"));
1620 const auto rhs = AllOf(
1621 op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
1622 op::Constant(), op::Constant())),
1623 op::Shape("f32[64,56,56,256]"));
1624 auto all_to_all =
1625 AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]"));
1626 auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all)));
1627
1628 EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)),
1629 op::Shape("f32[1,1,64,256]")));
1630 }
1631
1632 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) {
1633 absl::string_view hlo_string = R"(
1634 HloModule module
1635
1636 ENTRY entry {
1637 %lhs = f32[128,56,56,512] parameter(0)
1638 %lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1639 %rhs = f32[128,28,28,64] parameter(1)
1640 %rhs.copy = f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
1641 ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy),
1642 window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2},
1643 dim_labels=f01b_i01o->01bf, sharding={replicated}
1644 })";
1645
1646 TF_ASSERT_OK_AND_ASSIGN(auto module,
1647 PartitionComputation(hlo_string, /*num_devices=*/2));
1648 VLOG(1) << module->ToString();
1649
1650 const auto root = module->entry_computation()->root_instruction();
1651 const auto lhs = AllOf(
1652 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1653 op::Constant(), op::Constant())),
1654 op::Shape("f32[128,28,56,512]"));
1655 const auto rhs = AllOf(
1656 op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
1657 op::Constant(), op::Constant())),
1658 op::Shape("f32[64,28,28,64]"));
1659 auto all_to_all =
1660 AllOf(op::AllToAll(op::Reshape(rhs)), op::Shape("f32[64,2,14,28,64]"));
1661 auto reshard = op::Reshape(op::Transpose(all_to_all));
1662
1663 EXPECT_THAT(root,
1664 AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), reshard)),
1665 op::Shape("f32[1,1,512,64]")));
1666 }
1667
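// When the rhs-dilated operand is unevenly partitioned, the test expects the
// RHS shard to be padded and masked with a select, and the LHS to be extended
// with a right halo before the partial convolution.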
1668 TEST_F(SpmdPartitioningTest,
1669 ConvolutionLhsTiledRhsTiled_UnevenDilatedRHSPartitioned) {
1670 absl::string_view hlo_string = R"(
1671 HloModule module
1672
1673 ENTRY entry {
1674 %lhs = f32[8,28,28,8] parameter(0)
1675 %lhs.copy = f32[8,28,28,8] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
1676 %rhs = f32[8,14,14,64] parameter(1)
1677 %rhs.copy = f32[8,14,14,64] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
1678 ROOT %conv = f32[1,1,8,64] convolution(%lhs.copy, %rhs.copy),
1679 window={size=14x14 pad=0_-1x0_-1 rhs_dilate=2x2},
1680 dim_labels=f01b_i01o->01bf, sharding={replicated}
1681 })";
1682
1683 TF_ASSERT_OK_AND_ASSIGN(auto module,
1684 PartitionComputation(hlo_string, /*num_devices=*/4));
1685 VLOG(1) << module->ToString();
1686
1687 const auto root = module->entry_computation()->root_instruction();
1688 const auto lhs = AllOf(
1689 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1690 op::Constant(), op::Constant())),
1691 op::Shape("f32[8,7,28,8]"));
1692 const auto rhs = AllOf(op::Pad(op::Parameter(), op::Constant()),
1693 op::Shape("f32[8,16,14,64]"));
1694 auto selected_rhs = AllOf(
1695 op::Select(op::Compare(),
1696 op::Copy(op::DynamicSlice(rhs, op::Constant(), op::Reshape(),
1697 op::Constant(), op::Constant())),
1698 op::Broadcast()),
1699 op::Shape("f32[8,4,14,64]"));
1700 auto right_halo =
1701 AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,2,28,8]"));
1702 auto selected_lhs =
1703 AllOf(op::DynamicSlice(
1704 op::Pad(op::Concatenate(lhs, right_halo), op::Constant()),
1705 op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
1706 op::Shape("f32[8,7,28,8]"));
1707 EXPECT_THAT(root,
1708 AllOf(op::AllReduce(op::Convolution(selected_lhs, selected_rhs)),
1709 op::Shape("f32[1,1,8,64]")));
1710 }
1711
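// The tests below pass conv_halo_exchange_always_on_lhs=false, so any halo
// exchange is placed on the RHS operand instead of the LHS.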
1712 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) {
1713 absl::string_view hlo_string = R"(
1714 HloModule module
1715
1716 ENTRY entry {
1717 %lhs = f32[32,28,28,128] parameter(0)
1718 %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1719 %rhs = f32[32,28,28,64] parameter(1)
1720 %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1721 ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
1722 window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1723 })";
1724
1725 TF_ASSERT_OK_AND_ASSIGN(
1726 auto module,
1727 PartitionComputation(hlo_string, /*num_devices=*/2,
1728 /*conv_halo_exchange_always_on_lhs=*/false));
1729 VLOG(1) << module->ToString();
1730
1731 const auto root = module->entry_computation()->root_instruction();
1732 const auto lhs = AllOf(
1733 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1734 op::Constant(), op::Constant())),
1735 op::Shape("f32[32,14,28,128]"));
1736 const auto rhs = AllOf(
1737 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1738 op::Constant(), op::Constant())),
1739 op::Shape("f32[32,14,28,64]"));
1740
1741 auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1742 op::Shape("f32[32,1,28,64]"));
1743 auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1744 op::Shape("f32[32,1,28,64]"));
1745 EXPECT_THAT(root,
1746 AllOf(op::AllReduce(op::Convolution(
1747 lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo),
1748 op::Shape("f32[32,16,28,64]")))),
1749 op::Shape("f32[3,3,128,64]")));
1750 }
1751
1752 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) {
1753 absl::string_view hlo_string = R"(
1754 HloModule module
1755
1756 ENTRY entry {
1757 %lhs = f32[128,224,224,3] parameter(0)
1758 %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1759 %rhs = f32[128,112,112,64] parameter(1)
1760 %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1761 ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
1762 window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1763 })";
1764
1765 TF_ASSERT_OK_AND_ASSIGN(
1766 auto module,
1767 PartitionComputation(hlo_string, /*num_devices=*/2,
1768 /*conv_halo_exchange_always_on_lhs=*/false));
1769 VLOG(1) << module->ToString();
1770
1771 const auto root = module->entry_computation()->root_instruction();
1772 const auto lhs = AllOf(
1773 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1774 op::Constant(), op::Constant())),
1775 op::Shape("f32[128,112,224,3]"));
1776 const auto rhs = AllOf(
1777 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1778 op::Constant(), op::Constant())),
1779 op::Shape("f32[128,56,112,64]"));
1780
1781 auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1782 op::Shape("f32[128,2,112,64]"));
1783 auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1784 op::Shape("f32[128,2,112,64]"));
1785 EXPECT_THAT(root,
1786 AllOf(op::AllReduce(op::Convolution(
1787 lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo),
1788 op::Shape("f32[128,60,112,64]")))),
1789 op::Shape("f32[7,7,3,64]")));
1790 }
1791
1792 TEST_F(SpmdPartitioningTest,
1793 ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) {
1794 absl::string_view hlo_string = R"(
1795 HloModule module
1796
1797 ENTRY entry {
1798 %lhs = f32[128,56,56,256] parameter(0)
1799 %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1800 %rhs = f32[128,28,28,512] parameter(1)
1801 %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1802 ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
1803 window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1804 })";
1805
1806 TF_ASSERT_OK_AND_ASSIGN(
1807 auto module,
1808 PartitionComputation(hlo_string, /*num_devices=*/2,
1809 /*conv_halo_exchange_always_on_lhs=*/false));
1810 VLOG(1) << module->ToString();
1811
1812 const auto root = module->entry_computation()->root_instruction();
1813 const auto lhs = AllOf(
1814 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1815 op::Constant(), op::Constant())),
1816 op::Shape("f32[128,28,56,256]"));
1817 const auto rhs = AllOf(
1818 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1819 op::Constant(), op::Constant())),
1820 op::Shape("f32[128,14,28,512]"));
1821
1822 EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
1823 op::Shape("f32[1,1,256,512]")));
1824 }
1825
1826 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) {
1827 absl::string_view hlo_string = R"(
1828 HloModule module
1829
1830 ENTRY entry {
1831 %lhs = f32[128,14,14,512] parameter(0)
1832 %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1833 %rhs = f32[128,7,7,512] parameter(1)
1834 %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1835 ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
1836 window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1837 })";
1838
1839 TF_ASSERT_OK_AND_ASSIGN(
1840 auto module,
1841 PartitionComputation(hlo_string, /*num_devices=*/2,
1842 /*conv_halo_exchange_always_on_lhs=*/false));
1843 VLOG(1) << module->ToString();
1844
1845 const auto root = module->entry_computation()->root_instruction();
1846 const auto lhs = AllOf(
1847 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1848 op::Constant(), op::Constant())),
1849 op::Shape("f32[128,7,14,512]"));
1850 const auto rhs = AllOf(
1851 op::Select(op::Compare(),
1852 op::Copy(op::DynamicSlice(
1853 op::Pad(op::Parameter(), op::Constant()), op::Constant(),
1854 op::Reshape(), op::Constant(), op::Constant())),
1855 op::Broadcast()),
1856 op::Shape("f32[128,4,7,512]"));
1857
1858 auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)),
1859 op::Shape("f32[128,1,7,512]"));
1860 EXPECT_THAT(root,
1861 AllOf(op::AllReduce(op::Convolution(
1862 AllOf(op::DynamicSlice(op::Pad(lhs, op::Constant()),
1863 op::Constant(), op::Subtract(),
1864 op::Constant(), op::Constant()),
1865 op::Shape("f32[128,10,14,512]")),
1866 AllOf(op::Concatenate(left_halo, rhs),
1867 op::Shape("f32[128,5,7,512]")))),
1868 op::Shape("f32[3,3,512,512]")));
1869 }
1870
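// The *_HaloOnLhs variants repeat the cases above with the default option, so
// the halo exchange (when one is needed) appears on the LHS operand.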
1871 TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) {
1872 absl::string_view hlo_string = R"(
1873 HloModule module
1874
1875 ENTRY entry {
1876 %lhs = f32[32,28,28,128] parameter(0)
1877 %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1878 %rhs = f32[32,28,28,64] parameter(1)
1879 %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1880 ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
1881 window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1882 })";
1883
1884 TF_ASSERT_OK_AND_ASSIGN(auto module,
1885 PartitionComputation(hlo_string, /*num_devices=*/2));
1886 VLOG(1) << module->ToString();
1887
1888 const auto root = module->entry_computation()->root_instruction();
1889 const auto lhs = AllOf(
1890 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1891 op::Constant(), op::Constant())),
1892 op::Shape("f32[32,14,28,128]"));
1893 const auto rhs = AllOf(
1894 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1895 op::Constant(), op::Constant())),
1896 op::Shape("f32[32,14,28,64]"));
1897
1898 auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1899 op::Shape("f32[32,1,28,128]"));
1900 auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1901 op::Shape("f32[32,1,28,128]"));
1902 EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
1903 AllOf(op::Concatenate(left_halo, lhs, right_halo),
1904 op::Shape("f32[32,16,28,128]")),
1905 rhs)),
1906 op::Shape("f32[3,3,128,64]")));
1907 }
1908
1909 TEST_F(SpmdPartitioningTest,
1910 ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) {
1911 absl::string_view hlo_string = R"(
1912 HloModule module
1913
1914 ENTRY entry {
1915 %lhs = f32[128,224,224,3] parameter(0)
1916 %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1917 %rhs = f32[128,112,112,64] parameter(1)
1918 %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1919 ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
1920 window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1921 })";
1922
1923 TF_ASSERT_OK_AND_ASSIGN(auto module,
1924 PartitionComputation(hlo_string, /*num_devices=*/2));
1925 VLOG(1) << module->ToString();
1926
1927 const auto root = module->entry_computation()->root_instruction();
1928 const auto lhs = AllOf(
1929 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1930 op::Constant(), op::Constant())),
1931 op::Shape("f32[128,112,224,3]"));
1932 const auto rhs = AllOf(
1933 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1934 op::Constant(), op::Constant())),
1935 op::Shape("f32[128,56,112,64]"));
1936
1937 auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1938 op::Shape("f32[128,3,224,3]"));
1939 auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1940 op::Shape("f32[128,2,224,3]"));
1941 EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
1942 AllOf(op::Concatenate(left_halo, lhs, right_halo),
1943 op::Shape("f32[128,117,224,3]")),
1944 rhs)),
1945 op::Shape("f32[7,7,3,64]")));
1946 }
1947
1948 TEST_F(SpmdPartitioningTest,
1949 ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) {
1950 absl::string_view hlo_string = R"(
1951 HloModule module
1952
1953 ENTRY entry {
1954 %lhs = f32[128,56,56,256] parameter(0)
1955 %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1956 %rhs = f32[128,28,28,512] parameter(1)
1957 %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1958 ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
1959 window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1960 })";
1961
1962 TF_ASSERT_OK_AND_ASSIGN(auto module,
1963 PartitionComputation(hlo_string, /*num_devices=*/2));
1964 VLOG(1) << module->ToString();
1965
1966 const auto root = module->entry_computation()->root_instruction();
1967 const auto lhs = AllOf(
1968 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1969 op::Constant(), op::Constant())),
1970 op::Shape("f32[128,28,56,256]"));
1971 const auto rhs = AllOf(
1972 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
1973 op::Constant(), op::Constant())),
1974 op::Shape("f32[128,14,28,512]"));
1975
1976 EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), rhs)),
1977 op::Shape("f32[1,1,256,512]")));
1978 }
1979
1980 TEST_F(SpmdPartitioningTest,
1981 ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) {
1982 absl::string_view hlo_string = R"(
1983 HloModule module
1984
1985 ENTRY entry {
1986 %lhs = f32[128,14,14,512] parameter(0)
1987 %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
1988 %rhs = f32[128,7,7,512] parameter(1)
1989 %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
1990 ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
1991 window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
1992 })";
1993
1994 TF_ASSERT_OK_AND_ASSIGN(auto module,
1995 PartitionComputation(hlo_string, /*num_devices=*/2));
1996 VLOG(1) << module->ToString();
1997
1998 const auto root = module->entry_computation()->root_instruction();
1999 const auto lhs = AllOf(
2000 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
2001 op::Constant(), op::Constant())),
2002 op::Shape("f32[128,7,14,512]"));
2003 const auto rhs = AllOf(
2004 op::Select(op::Compare(),
2005 op::Copy(op::DynamicSlice(
2006 op::Pad(op::Parameter(), op::Constant()), op::Constant(),
2007 op::Reshape(), op::Constant(), op::Constant())),
2008 op::Broadcast()),
2009 op::Shape("f32[128,4,7,512]"));
2010
2011 auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
2012 op::Shape("f32[128,1,14,512]"));
2013 EXPECT_THAT(
2014 root, AllOf(op::AllReduce(op::Convolution(
2015 AllOf(op::DynamicSlice(
2016 AllOf(op::Pad(op::Concatenate(lhs, right_halo),
2017 op::Constant()),
2018 op::Shape("f32[128,10,14,512]")),
2019 op::Constant(), op::Reshape(), op::Constant(),
2020 op::Constant()),
2021 op::Shape("f32[128,9,14,512]")),
2022 rhs)),
2023 op::Shape("f32[3,3,512,512]")));
2024 }
2025
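// Concatenation along a non-partitioned dimension stays local to each shard.
// Along the partitioned dimension, the tests expect the operands to be
// scattered into a padded buffer with dynamic-update-slices, combined with an
// all-reduce, and the shard re-extracted with a dynamic-slice.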
2026 TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) {
2027 absl::string_view hlo_string = R"(
2028 HloModule module
2029
2030 ENTRY entry {
2031 %param0 = f32[14,257] parameter(0)
2032 %param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1}
2033 %param1 = f32[14,116] parameter(1)
2034 %param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1}
2035 ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
2036 dimensions={1}, sharding={devices=[2,1]0,1}
2037 })";
2038
2039 TF_ASSERT_OK_AND_ASSIGN(auto module,
2040 PartitionComputation(hlo_string, /*num_devices=*/2));
2041 VLOG(1) << module->ToString();
2042
2043 const auto root = module->entry_computation()->root_instruction();
2044 auto param0 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
2045 op::Constant())),
2046 op::Shape("f32[7,257]"));
2047 auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
2048 op::Constant())),
2049 op::Shape("f32[7,116]"));
2050 EXPECT_THAT(root,
2051 AllOf(op::Concatenate(param0, param1), op::Shape("f32[7,373]")));
2052 }
2053
2054 TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) {
2055 absl::string_view hlo_string = R"(
2056 HloModule module
2057
2058 ENTRY entry {
2059 %param0 = f32[14,257] parameter(0)
2060 %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1}
2061 %param1 = f32[14,116] parameter(1)
2062 %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1}
2063 ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
2064 dimensions={1}, sharding={devices=[1,2]0,1}
2065 })";
2066
2067 TF_ASSERT_OK_AND_ASSIGN(auto module,
2068 PartitionComputation(hlo_string, /*num_devices=*/2));
2069 VLOG(1) << module->ToString();
2070
2071 const auto root = module->entry_computation()->root_instruction();
2072 auto param0 =
2073 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
2074 op::Constant(), op::Reshape())),
2075 op::Shape("f32[14,129]"));
2076 auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
2077 op::Reshape())),
2078 op::Shape("f32[14,58]"));
2079 EXPECT_THAT(root, AllOf(op::DynamicSlice(
2080 AllOf(op::AllReduce(op::DynamicUpdateSlice(
2081 op::DynamicUpdateSlice(
2082 op::Broadcast(), param0,
2083 op::Constant(), op::Multiply()),
2084 param1, op::Constant(), op::Add())),
2085 op::Shape("f32[14,374]")),
2086 op::Constant(), op::Multiply()),
2087 op::Shape("f32[14,187]")));
2088 }
2089
2090 TEST_F(SpmdPartitioningTest, ConcatenateAlongBothDimensions) {
2091 const char* const hlo_string = R"(
2092 HloModule module
2093
2094 ENTRY entry {
2095 %param0 = f32[14,257] parameter(0), sharding={devices=[2,2]0,1,2,3}
2096 %param1 = f32[14,116] parameter(1), sharding={devices=[2,2]0,1,2,3}
2097 ROOT %concatenate = f32[14,373] concatenate(%param0, %param1),
2098 dimensions={1}, sharding={devices=[2,2]0,1,2,3}
2099 })";
2100
2101 TF_ASSERT_OK_AND_ASSIGN(auto module,
2102 PartitionComputation(hlo_string, /*num_devices=*/4));
2103 VLOG(1) << module->ToString();
2104
2105 const auto root = module->entry_computation()->root_instruction();
2106 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[7,129]"));
2107 auto param1 = AllOf(op::Parameter(1), op::Shape("f32[7,58]"));
2108 EXPECT_THAT(root, AllOf(op::DynamicSlice(
2109 AllOf(op::AllReduce(op::DynamicUpdateSlice(
2110 op::DynamicUpdateSlice(
2111 op::Broadcast(), param0,
2112 op::Constant(), op::Multiply()),
2113 param1, op::Constant(), op::Add())),
2114 op::Shape("f32[7,374]")),
2115 op::Constant(), op::Multiply()),
2116 op::Shape("f32[7,187]")));
2117 }
2118
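// Padding along a non-partitioned dimension is local. Along the partitioned
// dimension, the tests expect a halo exchange so that edge and interior padding
// land in the correct shard, followed by masking or slicing of the result.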
2119 TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) {
2120 absl::string_view hlo_string = R"(
2121 HloModule module
2122
2123 ENTRY entry {
2124 %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
2125 %const = f32[] constant(0)
2126 ROOT %pad = f32[128,17,257] pad(%param0, %const), padding=0_0x1_2x0_0,
2127 sharding={devices=[1,1,2]0,1}
2128 })";
2129
2130 TF_ASSERT_OK_AND_ASSIGN(auto module,
2131 PartitionComputation(hlo_string, /*num_devices=*/2));
2132 VLOG(1) << module->ToString();
2133
2134 const auto root = module->entry_computation()->root_instruction();
2135 auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
2136 EXPECT_THAT(root, AllOf(op::Pad(param0, op::Constant()),
2137 op::Shape("f32[128,17,129]")));
2138 }
2139
2140 TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimension) {
2141 absl::string_view hlo_string = R"(
2142 HloModule module
2143
2144 ENTRY entry {
2145 %param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1}
2146 %const = f32[] constant(0)
2147 ROOT %pad = f32[14,259] pad(%param0, %const), padding=0_0x0_2,
2148 sharding={devices=[1,2]0,1}
2149 })";
2150
2151 TF_ASSERT_OK_AND_ASSIGN(auto module,
2152 PartitionComputation(hlo_string, /*num_devices=*/2));
2153 VLOG(1) << module->ToString();
2154
2155 const auto root = module->entry_computation()->root_instruction();
2156 auto param0 = AllOf(op::Parameter(), op::Shape("f32[14,129]"));
2157 auto after_halo_exchange =
2158 AllOf(op::Shape("f32[14,130]"),
2159 op::Concatenate(param0, op::CollectivePermute(op::Slice(param0))));
2160 auto pad = AllOf(op::Shape("f32[14,131]"),
2161 op::Pad(after_halo_exchange, op::Constant()));
2162 EXPECT_THAT(root, op::Select(_, op::DynamicSlice(pad, op::Constant(), _), _));
2163 }
2164
2165 TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimensionWithInteriorPadding) {
2166 absl::string_view hlo_string = R"(
2167 HloModule module
2168
2169 ENTRY entry {
2170 %param0 = f32[7] parameter(0), sharding={devices=[2]0,1}
2171 %param1 = f32[] parameter(1), sharding={replicated}
2172 ROOT %pad = f32[22] pad(%param0, %param1), padding=2_1_2,
2173 sharding={devices=[2]0,1}
2174 })";
2175
2176 TF_ASSERT_OK_AND_ASSIGN(auto module,
2177 PartitionComputation(hlo_string, /*num_devices=*/2));
2178 VLOG(1) << module->ToString();
2179 const auto root = module->entry_computation()->root_instruction();
2180
2181 auto param0 = AllOf(op::Parameter(), op::Shape("f32[4]"));
2182 auto after_halo_exchange = AllOf(
2183 op::Shape("f32[4]"),
2184 op::DynamicSlice(
2185 AllOf(op::Shape("f32[5]"),
2186 op::Pad(AllOf(op::Shape("f32[4]"),
2187 op::Concatenate(
2188 op::CollectivePermute(op::Slice(param0)),
2189 op::Slice(param0))),
2190 op::Parameter(1))),
2191 _));
2192 auto pad = op::Pad(after_halo_exchange, op::Parameter(1));
2193 EXPECT_THAT(root, op::DynamicSlice(pad, _));
2194 }
2195
2196 TEST_F(SpmdPartitioningTest, PartialReplicatePad) {
2197 absl::string_view hlo_string = R"(
2198 HloModule module
2199
2200 ENTRY entry {
2201 %param0 = f32[11,7] parameter(0),
2202 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
2203 %param1 = f32[] parameter(1), sharding={replicated}
2204 ROOT %pad = f32[27,22] pad(%param0, %param1), padding=2_4_1x2_1_2,
2205 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
2206 })";
2207
2208 TF_ASSERT_OK_AND_ASSIGN(auto module,
2209 PartitionComputation(hlo_string, /*num_devices=*/4));
2210 VLOG(1) << module->ToString();
2211 const auto root = module->entry_computation()->root_instruction();
2212
2213 auto param0 = AllOf(op::Parameter(), op::Shape("f32[11,4]"));
2214 auto after_halo_exchange = AllOf(
2215 op::Shape("f32[11,4]"),
2216 op::DynamicSlice(
2217 AllOf(op::Shape("f32[11,5]"),
2218 op::Pad(AllOf(op::Shape("f32[11,4]"),
2219 op::Concatenate(
2220 op::CollectivePermute(op::Slice(param0)),
2221 op::Slice(param0))),
2222 op::Parameter(1))),
2223 op::Constant(), _));
2224 auto pad = op::Pad(after_halo_exchange, op::Parameter(1));
2225 EXPECT_THAT(root, AllOf(op::DynamicSlice(pad, op::Constant(), _),
2226 op::Shape("f32[27,11]")));
2227 }
2228
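// Slicing along a non-partitioned dimension is local. A strided slice along the
// partitioned dimension is expected to need a halo exchange plus a
// dynamic-slice whose offset is derived from the partition before the per-shard
// slice.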
2229 TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) {
2230 absl::string_view hlo_string = R"(
2231 HloModule module
2232
2233 ENTRY entry {
2234 %param0 = f32[128,14,257] parameter(0)
2235 %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1}
2236 ROOT %slice = f32[128,11,257] slice(%param0.copy),
2237 slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1}
2238 })";
2239
2240 TF_ASSERT_OK_AND_ASSIGN(auto module,
2241 PartitionComputation(hlo_string, /*num_devices=*/2));
2242 VLOG(1) << module->ToString();
2243
2244 const auto root = module->entry_computation()->root_instruction();
2245 auto param0 = AllOf(
2246 op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
2247 op::Constant(), op::Constant(), op::Reshape())),
2248 op::Shape("f32[128,14,129]"));
2249 EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]")));
2250 }
2251
2252 TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) {
2253 absl::string_view hlo_string = R"(
2254 HloModule module
2255
2256 ENTRY entry {
2257 %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
2258 ROOT %slice = f32[63,14,251] slice(%param0),
2259 slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1}
2260 })";
2261
2262 TF_ASSERT_OK_AND_ASSIGN(auto module,
2263 PartitionComputation(hlo_string, /*num_devices=*/2));
2264 VLOG(1) << module->ToString();
2265
2266 const auto root = module->entry_computation()->root_instruction();
2267 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[128,14,129]"));
2268 EXPECT_THAT(
2269 root,
2270 AllOf(op::Slice(AllOf(
2271 op::DynamicSlice(
2272 AllOf(op::Concatenate(
2273 op::Slice(param0),
2274 AllOf(op::CollectivePermute(op::Slice(param0)),
2275 op::Shape("f32[128,14,2]"))),
2276 op::Shape("f32[128,14,129]")),
2277 op::Constant(), op::Constant(), op::Add()),
2278 op::Shape("f32[128,14,126]"))),
2279 op::Shape("f32[63,14,126]")));
2280 }
2281
2282 TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension2) {
2283 absl::string_view hlo_string = R"(
2284 HloModule module
2285
2286 ENTRY entry {
2287 %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
2288 ROOT %slice = f32[1] slice(%param0),
2289 slice={[3:4]}, sharding={devices=[4]0,1,2,3}
2290 })";
2291
2292 TF_ASSERT_OK_AND_ASSIGN(auto module,
2293 PartitionComputation(hlo_string, /*num_devices=*/4));
2294 VLOG(1) << module->ToString();
2295
2296 const auto root = module->entry_computation()->root_instruction();
2297 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1]"));
2298 EXPECT_THAT(root, AllOf(op::Copy(op::CollectivePermute(param0)),
2299 op::Shape("f32[1]")));
2300 }
2301
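// The following tests check that a pad-then-slice or slice-then-concat pattern
// that amounts to shifting or rotating the tensor is lowered to a
// collective-permute between partitions, plus a select when masking is needed.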
2302 TEST_F(SpmdPartitioningTest, MergedPadThenSliceShiftRight) {
2303 absl::string_view hlo_string = R"(
2304 HloModule module
2305
2306 ENTRY entry {
2307 %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
2308 %init = f32[] constant(2.0)
2309 %pad = f32[5] pad(%param0, %init), padding=1_0, sharding={devices=[4]0,1,2,3}
2310 %copy = f32[5] copy(%pad), sharding={devices=[4]0,1,2,3}
2311 %copy.1 = f32[5] copy(%copy), sharding={devices=[4]0,1,2,3}
2312 ROOT %slice = f32[4] slice(%copy.1), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2313 })";
2314
2315 TF_ASSERT_OK_AND_ASSIGN(auto module,
2316 PartitionComputation(hlo_string, /*num_devices=*/4));
2317 VLOG(1) << module->ToString();
2318
2319 const auto root = module->entry_computation()->root_instruction();
2320 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1]"));
2321 EXPECT_THAT(root, AllOf(op::Select(_, op::CollectivePermute(param0), _),
2322 op::Shape("f32[1]")));
2323 }
2324
2325 // Same as above except that it uses zero padding, so there is no need for
2326 // masking.
2327 TEST_F(SpmdPartitioningTest, MergedPadThenSliceShiftRightNoMasking) {
2328 absl::string_view hlo_string = R"(
2329 HloModule module
2330
2331 ENTRY entry {
2332 %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
2333 %init = f32[] constant(0)
2334 %pad = f32[5] pad(%param0, %init), padding=1_0, sharding={devices=[4]0,1,2,3}
2335 %copy = f32[5] copy(%pad), sharding={devices=[4]0,1,2,3}
2336 %copy.1 = f32[5] copy(%copy), sharding={devices=[4]0,1,2,3}
2337 ROOT %slice = f32[4] slice(%copy.1), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2338 })";
2339
2340 TF_ASSERT_OK_AND_ASSIGN(auto module,
2341 PartitionComputation(hlo_string, /*num_devices=*/4));
2342 VLOG(1) << module->ToString();
2343
2344 const auto root = module->entry_computation()->root_instruction();
2345 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1]"));
2346 EXPECT_THAT(root, AllOf(op::CollectivePermute(param0), op::Shape("f32[1]")));
2347 }
2348
2349 TEST_F(SpmdPartitioningTest, MergedSliceThenConcatRotateRight) {
2350 absl::string_view hlo_string = R"(
2351 HloModule module
2352
2353 ENTRY entry {
2354 %param0 = f32[12] parameter(0), sharding={devices=[4]0,1,2,3}
2355 %slice0 = f32[2] slice(%param0), slice={[10:12]}, sharding={devices=[4]0,1,2,3}
2356 %slice1 = f32[10] slice(%param0), slice={[0:10]}, sharding={devices=[4]0,1,2,3}
2357 ROOT %concat = f32[12] concatenate(%slice0, %slice1), dimensions={0},
2358 sharding={devices=[4]0,1,2,3}
2359 })";
2360
2361 TF_ASSERT_OK_AND_ASSIGN(auto module,
2362 PartitionComputation(hlo_string, /*num_devices=*/4));
2363 VLOG(1) << module->ToString();
2364
2365 const auto root = module->entry_computation()->root_instruction();
2366 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[3]"));
2367 auto rotate = op::Concatenate(op::CollectivePermute(op::Slice(param0)),
2368 op::Slice(param0));
2369 EXPECT_THAT(root, AllOf(rotate, op::Shape("f32[3]")));
2370 }
2371
2372 TEST_F(SpmdPartitioningTest,
2373 MergedSliceThenConcatRotateRightWithAlignedPadding) {
2374 absl::string_view hlo_string = R"(
2375 HloModule module
2376
2377 ENTRY entry {
2378 %param0 = f32[6] parameter(0), sharding={devices=[4]0,1,2,3}
2379 %slice0 = f32[2] slice(%param0), slice={[4:6]}, sharding={devices=[4]0,1,2,3}
2380 %slice1 = f32[4] slice(%param0), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2381 ROOT %concat = f32[6] concatenate(%slice0, %slice1), dimensions={0},
2382 sharding={devices=[4]0,1,2,3}
2383 })";
2384
2385 TF_ASSERT_OK_AND_ASSIGN(auto module,
2386 PartitionComputation(hlo_string, /*num_devices=*/4));
2387 VLOG(1) << module->ToString();
2388
2389 const auto root = module->entry_computation()->root_instruction();
2390 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[2]"));
2391 EXPECT_THAT(root, op::CollectivePermute(param0));
2392 }
2393
2394 TEST_F(SpmdPartitioningTest,
2395 MergedSliceThenConcatRotateRightWithUnalignedPadding) {
2396 absl::string_view hlo_string = R"(
2397 HloModule module
2398
2399 ENTRY entry {
2400 %param0 = f32[10] parameter(0), sharding={devices=[4]0,1,2,3}
2401 %slice0 = f32[6] slice(%param0), slice={[4:10]}, sharding={devices=[4]0,1,2,3}
2402 %slice1 = f32[4] slice(%param0), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2403 ROOT %concat = f32[10] concatenate(%slice0, %slice1), dimensions={0},
2404 sharding={devices=[4]0,1,2,3}
2405 })";
2406
2407 TF_ASSERT_OK_AND_ASSIGN(auto module,
2408 PartitionComputation(hlo_string, /*num_devices=*/4));
2409 VLOG(1) << module->ToString();
2410
2411 const auto root = module->entry_computation()->root_instruction();
2412 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[3]"));
2413 auto rotate0 = op::CollectivePermute(param0);
2414 auto rotate1 = op::Concatenate(op::CollectivePermute(op::Slice(param0)),
2415 op::CollectivePermute(op::Slice(param0)));
2416 EXPECT_THAT(root,
2417 AllOf(op::Select(_, rotate1, rotate0), op::Shape("f32[3]")));
2418 }
2419
2420 TEST_F(SpmdPartitioningTest,
2421 PartialReplicateSliceAlongNonPartitionedDimension) {
2422 absl::string_view hlo_string = R"(
2423 HloModule module
2424
2425 ENTRY entry {
2426 %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
2427 ROOT %slice = f32[128,11,257] slice(%param0),
2428 slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
2429 })";
2430
2431 TF_ASSERT_OK_AND_ASSIGN(auto module,
2432 PartitionComputation(hlo_string, /*num_devices=*/4));
2433 VLOG(1) << module->ToString();
2434
2435 const auto root = module->entry_computation()->root_instruction();
2436 auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
2437 EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]")));
2438 }
2439
2440 TEST_F(SpmdPartitioningTest, PartialReplicateSliceAlongPartitionedDimension) {
2441 absl::string_view hlo_string = R"(
2442 HloModule module
2443
2444 ENTRY entry {
2445 %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
2446 ROOT %slice = f32[63,14,251] slice(%param0),
2447 slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
2448 })";
2449
2450 TF_ASSERT_OK_AND_ASSIGN(auto module,
2451 PartitionComputation(hlo_string, /*num_devices=*/4));
2452 VLOG(1) << module->ToString();
2453
2454 const auto root = module->entry_computation()->root_instruction();
2455 auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
2456 EXPECT_THAT(
2457 root,
2458 AllOf(
2459 op::Slice(AllOf(
2460 op::DynamicSlice(
2461 AllOf(op::Concatenate(
2462 op::Slice(param0),
2463 AllOf(op::CollectivePermute(op::Slice(param0)),
2464 op::Shape("f32[128,14,2]"))),
2465 op::Shape("f32[128,14,129]")),
2466 op::Constant(), op::Constant(),
2467 op::Add(op::Multiply(op::Reshape(op::DynamicSlice(
2468 op::Constant(), op::PartitionId())),
2469 op::Constant()),
2470 op::Constant())),
2471 op::Shape("f32[128,14,126]"))),
2472 op::Shape("f32[63,14,126]")));
2473 }
2474
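// Sorting along a dimension that is not partitioned can be done independently
// on each shard, so the sort is expected to apply directly to the per-shard
// operands.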
2475 TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) {
2476 absl::string_view hlo_string = R"(
2477 HloModule module
2478
2479 ge {
2480 p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated}
2481 bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
2482 constant = s32[]{:T(256)} constant(0), sharding={replicated}
2483 compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated}
2484 constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated}
2485 bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
2486 subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated}
2487 bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated}
2488 select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated}
2489 p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated}
2490 bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
2491 compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated}
2492 bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
2493 subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated}
2494 bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated}
2495 select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated}
2496 compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated}
2497 compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated}
2498 compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated}
2499 p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated}
2500 p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated}
2501 compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated}
2502 ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated}
2503 }
2504
2505 ENTRY entry {
2506 %param0 = f32[128,14,257] parameter(0)
2507 %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1}
2508 %param1 = s32[128,14,257] parameter(1)
2509 %param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1}
2510 ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)})
2511 sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true,
2512 to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}}
2513 })";
2514
2515 TF_ASSERT_OK_AND_ASSIGN(auto module,
2516 PartitionComputation(hlo_string, /*num_devices=*/2));
2517 VLOG(1) << module->ToString();
2518
2519 const auto root = module->entry_computation()->root_instruction();
2520 auto param0 =
2521 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
2522 op::Reshape(), op::Constant())),
2523 op::Shape("f32[128,7,257]"));
2524 auto param1 =
2525 AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
2526 op::Reshape(), op::Constant())),
2527 op::Shape("s32[128,7,257]"));
2528 EXPECT_THAT(root, AllOf(op::Sort(param0, param1),
2529 op::Shape("(f32[128,7,257], s32[128,7,257])")));
2530 }
2531
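// TopK custom calls partitioned along the sorted dimension: the tests check
// that each partition runs the custom call on its shard (209664 / 2 = 104832
// columns here) and that a final sort over the combined 2 * 2000 = 4000
// candidates produces the global result.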
2532 TEST_F(SpmdPartitioningTest, PartitionCustomCall) {
2533 absl::string_view hlo_string = R"(
2534 HloModule cluster_2013453984438090939__.47
2535
2536 ENTRY %cluster_2013453984438090939__.47
2537 (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
2538 %arg_tuple.1 = bf16[2,209664] parameter(0)
2539 %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2540 %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
2541 custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK"
2542 %get-tuple-element = bf16[2,2000]{1,0}
2543 get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call),
2544 index=0, sharding={replicated}
2545 %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0},
2546 s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated}
2547 ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
2548 tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0}
2549 %get-tuple-element.1), sharding={{replicated}, {replicated}},
2550 metadata={op_name="XLA_Retvals"}
2551 })";
2552
2553 TF_ASSERT_OK_AND_ASSIGN(auto module,
2554 PartitionComputation(hlo_string, /*num_devices=*/2));
2555 VLOG(1) << module->ToString();
2556 auto custom_call = FindInstruction(module.get(), "custom-call.1");
2557 EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832);
2558 auto sort = FindInstruction(module.get(), "sort");
2559 EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000);
2560 EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000);
2561 }
2562
2563 TEST_F(SpmdPartitioningTest, PartitionCustomCall_TwoPartitionedDims) {
2564 absl::string_view hlo_string = R"(
2565 HloModule module
2566
2567 ENTRY entry {
2568 %param0 = f32[8,32128] parameter(0)
2569 %copy.0 = f32[8,32128] copy(%param0),
2570 sharding={devices=[4,2]0,1,2,3,4,5,6,7}
2571 %custom-call = (f32[8,2]{1,0}, s32[8,2]{1,0})
2572 custom-call(%copy.0), custom_call_target="TopK"
2573 %get-tuple-element = f32[8,2]{1,0}
2574 get-tuple-element((f32[8,2]{1,0}, s32[8,2]{1,0}) %custom-call), index=0,
2575 sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
2576 %get-tuple-element.1 = s32[8,2]{1,0}
2577 get-tuple-element((f32[8,2]{1,0}, s32[8,2]{1,0}) %custom-call), index=1,
2578 sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
2579 ROOT %tuple = (f32[8,2]{1,0}, s32[8,2]{1,0})
2580 tuple(%get-tuple-element, %get-tuple-element.1),
2581 sharding={{replicated}, {replicated}}
2582 })";
2583
2584 TF_ASSERT_OK_AND_ASSIGN(auto module,
2585 PartitionComputation(hlo_string, /*num_devices=*/8));
2586 VLOG(1) << module->ToString();
2587 auto custom_call = FindInstruction(module.get(), "custom-call.1");
2588 EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 16064);
2589 auto sort = FindInstruction(module.get(), "sort");
2590 EXPECT_EQ(sort->operand(0)->shape().dimensions(0), 2);
2591 EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4);
2592 EXPECT_EQ(sort->operand(1)->shape().dimensions(0), 2);
2593 EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4);
2594 }
2595
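// The same strategy is expected when TopK is written as sort + slice in HLO:
// a per-shard sort over 104832 columns followed by a final sort over the 4000
// combined candidates.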
2596 TEST_F(SpmdPartitioningTest, PartitionSortInTopK) {
2597 absl::string_view hlo_string = R"(
2598 HloModule module
2599
2600 %compare-greater-than.8 (p.0.lhs.9: bf16[], p.0.rhs.10: bf16[], p.1.lhs.11:
2601 s32[], p.1.rhs.12: s32[]) -> pred[] {
2602 %p.1.lhs.11 = s32[] parameter(2)
2603 %p.1.rhs.12 = s32[] parameter(3)
2604 %p.0.lhs.9 = bf16[] parameter(0)
2605 %convert.13 = f32[] convert(bf16[] %p.0.lhs.9)
2606 %bitcast-convert.16 = s32[] bitcast-convert(f32[] %convert.13)
2607 %constant.20 = s32[] constant(0)
2608 %compare.21 = pred[] compare(s32[] %bitcast-convert.16, s32[] %constant.20),
2609 direction=LT
2610 %constant.15 = u32[] constant(2147483647)
2611 %bitcast-convert.17 = u32[] bitcast-convert(f32[] %convert.13)
2612 %subtract.18 = u32[] subtract(u32[] %constant.15, u32[] %bitcast-convert.17)
2613 %bitcast-convert.19 = s32[] bitcast-convert(u32[] %subtract.18)
2614 %select.22 = s32[] select(pred[] %compare.21, s32[] %bitcast-convert.19, s32[]
2615 %bitcast-convert.16)
2616 %p.0.rhs.10 = bf16[] parameter(1)
2617 %convert.14 = f32[] convert(bf16[] %p.0.rhs.10)
2618 %bitcast-convert.24 = s32[] bitcast-convert(f32[] %convert.14)
2619 %constant.28 = s32[] constant(0)
2620 %compare.29 = pred[] compare(s32[] %bitcast-convert.24, s32[] %constant.28),
2621 direction=LT
2622 %constant.23 = u32[] constant(2147483647)
2623 %bitcast-convert.25 = u32[] bitcast-convert(f32[] %convert.14)
2624 %subtract.26 = u32[] subtract(u32[] %constant.23, u32[] %bitcast-convert.25)
2625 %bitcast-convert.27 = s32[] bitcast-convert(u32[] %subtract.26)
2626 %select.30 = s32[] select(pred[] %compare.29, s32[] %bitcast-convert.27, s32[]
2627 %bitcast-convert.24)
2628 ROOT %compare.31 = pred[] compare(s32[] %select.22, s32[] %select.30),
2629 direction=GT
2630 }
2631
2632 ENTRY entry
2633 (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
2634 %arg_tuple.1 = bf16[2,209664] parameter(0)
2635 %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2636 %iota.7 = s32[2,209664] iota(), iota_dimension=1,
2637 metadata={op_type="TopKV2" op_name="TopKV2"}
2638 %sort.32 = (bf16[2,209664], s32[2,209664])
2639 sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
2640 dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2641 metadata={op_type="TopKV2" op_name="TopKV2"}
2642 %get-tuple-element.33 = bf16[2,209664]
2643 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2644 index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2645 %slice.34 = bf16[2,2000] slice(bf16[2,209664]
2646 %get-tuple-element.33), slice={[0:2], [0:2000]},
2647 metadata={op_type="TopKV2" op_name="TopKV2"}
2648 %get-tuple-element.35 = s32[2,209664]
2649 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2650 index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2651 %slice.36 = s32[2,2000] slice(s32[2,209664]
2652 %get-tuple-element.35), slice={[0:2], [0:2000]},
2653 metadata={op_type="TopKV2" op_name="TopKV2"}
2654 ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
2655 tuple(bf16[2,2000] %slice.34, s32[2,2000]
2656 %slice.36), sharding={{replicated}, {replicated}},
2657 metadata={op_name="XLA_Retvals"}
2658 })";
2659
2660 TF_ASSERT_OK_AND_ASSIGN(auto module,
2661 PartitionComputation(hlo_string, /*num_devices=*/2));
2662 VLOG(1) << module->ToString();
2663 auto sort = FindInstruction(module.get(), "sort.0");
2664 EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
2665 EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
2666 auto final_sort = FindInstruction(module.get(), "sort.1");
2667 EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
2668 EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
2669 }
2670
2671 TEST_F(SpmdPartitioningTest, PartitionSortInTopKWhenComparisonWithSelect) {
2672 absl::string_view hlo_string = R"(
2673 HloModule module
2674
2675 %compare-greater-than.8 (p.0.lhs.2566: bf16[],
2676 p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
2677 %p.0.lhs.2566 = bf16[] parameter(0)
2678 %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
2679 %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
2680 %constant.285 = s32[] constant(0)
2681 %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
2682 direction=LT
2683 %constant.286 = u32[] constant(2147483647)
2684 %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
2685 %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
2686 %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
2687 %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
2688 s32[] %bitcast-convert.48)
2689 %p.0.rhs.2567 = bf16[] parameter(1)
2690 %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
2691 %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
2692 %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
2693 direction=LT
2694 %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
2695 %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
2696 %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
2697 %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
2698 s32[] %bitcast-convert.51)
2699 %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
2700 %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
2701 %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
2702 direction=EQ
2703 %p.1.lhs.2586 = s32[] parameter(2)
2704 %p.1.rhs.2587 = s32[] parameter(3)
2705 %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
2706 direction=LT
2707 ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
2708 pred[] %compare.86)
2709 }
2710
2711 ENTRY entry
2712 (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
2713 %arg_tuple.1 = bf16[2,209664] parameter(0)
2714 %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2715 %iota.7 = s32[2,209664] iota(), iota_dimension=1,
2716 metadata={op_type="TopKV2" op_name="TopKV2"}
2717 %sort.32 = (bf16[2,209664], s32[2,209664])
2718 sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
2719 dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2720 metadata={op_type="TopKV2" op_name="TopKV2"}
2721 %get-tuple-element.33 = bf16[2,209664]
2722 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2723 index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2724 %slice.34 = bf16[2,2000] slice(bf16[2,209664]
2725 %get-tuple-element.33), slice={[0:2], [0:2000]},
2726 metadata={op_type="TopKV2" op_name="TopKV2"}
2727 %get-tuple-element.35 = s32[2,209664]
2728 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2729 index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2730 %slice.36 = s32[2,2000] slice(s32[2,209664]
2731 %get-tuple-element.35), slice={[0:2], [0:2000]},
2732 metadata={op_type="TopKV2" op_name="TopKV2"}
2733 ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
2734 tuple(bf16[2,2000] %slice.34, s32[2,2000]
2735 %slice.36), sharding={{replicated}, {replicated}},
2736 metadata={op_name="XLA_Retvals"}
2737 })";
2738
2739 TF_ASSERT_OK_AND_ASSIGN(auto module,
2740 PartitionComputation(hlo_string, /*num_devices=*/2));
2741 VLOG(1) << module->ToString();
2742 auto sort = FindInstruction(module.get(), "sort.0");
2743 EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
2744 EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
2745 auto final_sort = FindInstruction(module.get(), "sort.1");
2746 EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
2747 EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
2748 }
2749
2750 TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSecondOperandIsNotIota) {
2751 absl::string_view hlo_string = R"(
2752 HloModule module
2753
2754 %compare-greater-than.8 (p.0.lhs.2566: bf16[],
2755 p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
2756 %p.0.lhs.2566 = bf16[] parameter(0)
2757 %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
2758 %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
2759 %constant.285 = s32[] constant(0)
2760 %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
2761 direction=LT
2762 %constant.286 = u32[] constant(2147483647)
2763 %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
2764 %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
2765 %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
2766 %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
2767 s32[] %bitcast-convert.48)
2768 %p.0.rhs.2567 = bf16[] parameter(1)
2769 %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
2770 %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
2771 %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
2772 direction=LT
2773 %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
2774 %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
2775 %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
2776 %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
2777 s32[] %bitcast-convert.51)
2778 %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
2779 %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
2780 %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
2781 direction=EQ
2782 %p.1.lhs.2586 = s32[] parameter(2)
2783 %p.1.rhs.2587 = s32[] parameter(3)
2784 %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
2785 direction=LT
2786 ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
2787 pred[] %compare.86)
2788 }
2789
2790 ENTRY entry {
2791 %arg_tuple.1 = bf16[2,209664] parameter(0)
2792 %arg_tuple.2 = s32[2,209664] parameter(1)
2793 %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2794 %sort.32 = (bf16[2,209664], s32[2,209664])
2795 sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %arg_tuple.2),
2796 dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2797 metadata={op_type="TopKV2" op_name="TopKV2"}
2798 %get-tuple-element.33 = bf16[2,209664]
2799 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2800 index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2801 %slice.34 = bf16[2,2000] slice(bf16[2,209664]
2802 %get-tuple-element.33), slice={[0:2], [0:2000]},
2803 metadata={op_type="TopKV2" op_name="TopKV2"}
2804 %get-tuple-element.35 = s32[2,209664]
2805 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2806 index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2807 %slice.36 = s32[2,2000] slice(s32[2,209664]
2808 %get-tuple-element.35), slice={[0:2], [0:2000]},
2809 metadata={op_type="TopKV2" op_name="TopKV2"}
2810 ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
2811 tuple(bf16[2,2000] %slice.34, s32[2,2000]
2812 %slice.36), sharding={{replicated}, {replicated}},
2813 metadata={op_name="XLA_Retvals"}
2814 })";
2815
2816 TF_ASSERT_OK_AND_ASSIGN(auto module,
2817 PartitionComputation(hlo_string, /*num_devices=*/2));
2818 VLOG(1) << module->ToString();
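  // The second sort operand is a plain parameter rather than an iota, so the
  // top-k pattern is not recognized and the sort dimension stays at 209664 elements.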
2819 auto sort = FindInstruction(module.get(), "sort.0");
2820 EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
2821 EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
2822 }
2823
2824 TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenNoPartitionInSortDim) {
2825 absl::string_view hlo_string = R"(
2826 HloModule module
2827
2828 %compare-greater-than.8 (p.0.lhs.2566: bf16[],
2829 p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
2830 %p.0.lhs.2566 = bf16[] parameter(0)
2831 %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
2832 %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
2833 %constant.285 = s32[] constant(0)
2834 %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
2835 direction=LT
2836 %constant.286 = u32[] constant(2147483647)
2837 %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
2838 %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
2839 %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
2840 %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
2841 s32[] %bitcast-convert.48)
2842 %p.0.rhs.2567 = bf16[] parameter(1)
2843 %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
2844 %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
2845 %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
2846 direction=LT
2847 %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
2848 %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
2849 %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
2850 %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
2851 s32[] %bitcast-convert.51)
2852 %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
2853 %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
2854 %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
2855 direction=EQ
2856 %p.1.lhs.2586 = s32[] parameter(2)
2857 %p.1.rhs.2587 = s32[] parameter(3)
2858 %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
2859 direction=LT
2860 ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
2861 pred[] %compare.86)
2862 }
2863
2864 ENTRY entry
2865 (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
2866 %arg_tuple.1 = bf16[2,209664] parameter(0)
2867 %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[2,1]0,1}
2868 %iota.7 = s32[2,209664] iota(), iota_dimension=1,
2869 metadata={op_type="TopKV2" op_name="TopKV2"}
2870 %sort.32 = (bf16[2,209664], s32[2,209664])
2871 sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
2872 dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2873 metadata={op_type="TopKV2" op_name="TopKV2"}
2874 %get-tuple-element.33 = bf16[2,209664]
2875 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2876 index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2877 %slice.34 = bf16[2,2000] slice(bf16[2,209664]
2878 %get-tuple-element.33), slice={[0:2], [0:2000]},
2879 metadata={op_type="TopKV2" op_name="TopKV2"}
2880 %get-tuple-element.35 = s32[2,209664]
2881 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2882 index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2883 %slice.36 = s32[2,2000] slice(s32[2,209664]
2884 %get-tuple-element.35), slice={[0:2], [0:2000]},
2885 metadata={op_type="TopKV2" op_name="TopKV2"}
2886 ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
2887 tuple(bf16[2,2000] %slice.34, s32[2,2000]
2888 %slice.36), sharding={{replicated}, {replicated}},
2889 metadata={op_name="XLA_Retvals"}
2890 })";
2891
2892 TF_ASSERT_OK_AND_ASSIGN(auto module,
2893 PartitionComputation(hlo_string, /*num_devices=*/2));
2894 VLOG(1) << module->ToString();
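  // The input is sharded along dimension 0 rather than the sort dimension, so the
  // sort dimension stays unpartitioned at 209664 elements.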
2895 auto sort = FindInstruction(module.get(), "sort.0");
2896 EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
2897 EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
2898 }
2899
2900 TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSliceInOtherDim) {
2901 absl::string_view hlo_string = R"(
2902 HloModule module
2903
2904 %compare-greater-than.8 (p.0.lhs.2566: bf16[],
2905 p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
2906 %p.0.lhs.2566 = bf16[] parameter(0)
2907 %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
2908 %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
2909 %constant.285 = s32[] constant(0)
2910 %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
2911 direction=LT
2912 %constant.286 = u32[] constant(2147483647)
2913 %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
2914 %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
2915 %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
2916 %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
2917 s32[] %bitcast-convert.48)
2918 %p.0.rhs.2567 = bf16[] parameter(1)
2919 %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
2920 %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
2921 %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
2922 direction=LT
2923 %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
2924 %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
2925 %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
2926 %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
2927 s32[] %bitcast-convert.51)
2928 %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
2929 %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
2930 %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
2931 direction=EQ
2932 %p.1.lhs.2586 = s32[] parameter(2)
2933 %p.1.rhs.2587 = s32[] parameter(3)
2934 %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
2935 direction=LT
2936 ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
2937 pred[] %compare.86)
2938 }
2939
2940 ENTRY entry {
2941 %arg_tuple.1 = bf16[2,209664] parameter(0)
2942 %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
2943 %iota.7 = s32[2,209664] iota(), iota_dimension=1,
2944 metadata={op_type="TopKV2" op_name="TopKV2"}
2945 %sort.32 = (bf16[2,209664], s32[2,209664])
2946 sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
2947 dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
2948 metadata={op_type="TopKV2" op_name="TopKV2"}
2949 %get-tuple-element.33 = bf16[2,209664]
2950 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2951 index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
2952 %slice.34 = bf16[1,209664] slice(bf16[2,209664]
2953 %get-tuple-element.33), slice={[0:1], [0:209664]},
2954 metadata={op_type="TopKV2" op_name="TopKV2"}
2955 %get-tuple-element.35 = s32[2,209664]
2956 get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
2957 index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
2958 %slice.36 = s32[1,209664] slice(s32[2,209664]
2959 %get-tuple-element.35), slice={[0:1], [0:209664]},
2960 metadata={op_type="TopKV2" op_name="TopKV2"}
2961 ROOT %tuple.46 = (bf16[1,209664], s32[1,209664])
2962 tuple(bf16[1,209664] %slice.34, s32[1,209664]
2963 %slice.36), sharding={{replicated}, {replicated}},
2964 metadata={op_name="XLA_Retvals"}
2965 })";
2966
2967 TF_ASSERT_OK_AND_ASSIGN(auto module,
2968 PartitionComputation(hlo_string, /*num_devices=*/2));
2969 VLOG(1) << module->ToString();
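  // The slice shrinks dimension 0 instead of the sorted dimension, so this is not a
  // top-k pattern and the sort dimension stays at 209664 elements.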
2970 auto sort = FindInstruction(module.get(), "sort.0");
2971 EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
2972 EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
2973 }
2974
2975 TEST_F(SpmdPartitioningTest, ShardableTranspose) {
2976 absl::string_view hlo_string = R"(
2977 HloModule module
2978
2979 ENTRY entry {
2980 %param0 = f32[16,38,38,4] parameter(0)
2981 %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
2982 ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
2983 dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1}
2984 })";
2985
2986 TF_ASSERT_OK_AND_ASSIGN(auto module,
2987 PartitionComputation(hlo_string, /*num_devices=*/2));
2988 VLOG(1) << module->ToString();
2989
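  // Dimension 1 (size 38) is split into 19 elements per device and stays local
  // through the transpose, so no resharding is needed.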
2990 const auto root = module->entry_computation()->root_instruction();
2991 auto param0 = AllOf(
2992 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
2993 op::Constant(), op::Constant())),
2994 op::Shape("f32[16,19,38,4]"));
2995 EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]")));
2996 }
2997
2998 TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) {
2999 absl::string_view hlo_string = R"(
3000 HloModule module
3001
3002 ENTRY entry {
3003 %param0 = f32[16,38,38,4] parameter(0)
3004 %param0.copy = f32[16,38,38,4] copy(%param0),
3005 sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
3006 ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
3007 dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7}
3008 })";
3009
3010 TF_ASSERT_OK_AND_ASSIGN(auto module,
3011 PartitionComputation(hlo_string, /*num_devices=*/8));
3012 VLOG(1) << module->ToString();
3013
3014 const auto root = module->entry_computation()->root_instruction();
3015 auto param0 = AllOf(
3016 op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
3017 op::Constant(), op::Constant())),
3018 op::Shape("f32[4,19,38,4]"));
3019 EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,4,38]")));
3020 }
3021
3022 TEST_F(SpmdPartitioningTest, NonShardableTranspose) {
3023 absl::string_view hlo_string = R"(
3024 HloModule module
3025
3026 ENTRY entry {
3027 %param0 = f32[16,38,38,4] parameter(0)
3028 %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
3029 ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
3030 dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1}
3031 })";
3032
3033 TF_ASSERT_OK_AND_ASSIGN(auto module,
3034 PartitionComputation(hlo_string, /*num_devices=*/2));
3035 VLOG(1) << module->ToString();
3036
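  // The requested output sharding does not line up with the transposed input
  // sharding, so the input is resharded with an all-to-all before the local transpose.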
3037 const auto root = module->entry_computation()->root_instruction();
3038   auto reshard = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))),
3039 op::Shape("f32[16,38,38,2]"));
3040 EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]")));
3041 }
3042
3043 TEST_F(SpmdPartitioningTest, PartialReplicateShardableTranspose) {
3044 absl::string_view hlo_string = R"(
3045 HloModule module
3046
3047 ENTRY entry {
3048 %param0 = f32[16,38,38,4] parameter(0)
3049 %param0.copy = f32[16,38,38,4] copy(%param0),
3050 sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
3051 ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
3052 dimensions={0,3,1,2},
3053 sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
3054 })";
3055
3056 TF_ASSERT_OK_AND_ASSIGN(auto module,
3057 PartitionComputation(hlo_string, /*num_devices=*/4));
3058 VLOG(1) << module->ToString();
3059
3060 const auto root = module->entry_computation()->root_instruction();
3061 auto param0 = AllOf(
3062 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
3063 op::Constant(), op::Constant())),
3064 op::Shape("f32[16,19,38,4]"));
3065 EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]")));
3066 }
3067
3068 TEST_F(SpmdPartitioningTest, PartialReplicateNonShardableTranspose) {
3069 absl::string_view hlo_string = R"(
3070 HloModule module
3071
3072 ENTRY entry {
3073 %param0 = f32[16,38,38,4] parameter(0)
3074 %param0.copy = f32[16,38,38,4] copy(%param0),
3075 sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
3076 ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
3077 dimensions={0,3,1,2},
3078 sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
3079 })";
3080
3081 TF_ASSERT_OK_AND_ASSIGN(auto module,
3082 PartitionComputation(hlo_string, /*num_devices=*/4));
3083 VLOG(1) << module->ToString();
3084
3085 const auto root = module->entry_computation()->root_instruction();
3086   auto reshard = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))),
3087 op::Shape("f32[16,38,38,2]"));
3088 EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]")));
3089 }
3090
3091 TEST_F(SpmdPartitioningTest, PartialReplicateMultiDimensionShardedTranspose) {
3092 absl::string_view hlo_string = R"(
3093 HloModule module
3094
3095 ENTRY entry {
3096 %param0 = f32[16,38,38,4] parameter(0)
3097 %param0.copy = f32[16,38,38,4] copy(%param0),
3098 sharding={devices=[2,2,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
3099 ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
3100 dimensions={1,3,0,2},
3101 sharding={devices=[2,1,2,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
3102 })";
3103
3104 TF_ASSERT_OK_AND_ASSIGN(auto module,
3105 PartitionComputation(hlo_string, /*num_devices=*/8));
3106 VLOG(1) << module->ToString();
3107
3108 const auto root = module->entry_computation()->root_instruction();
3109 auto param0 = AllOf(
3110 op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
3111 op::Constant(), op::Constant())),
3112 op::Shape("f32[8,19,38,4]"));
3113 EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,8,38]")));
3114 }
3115
3116 TEST_F(SpmdPartitioningTest, ShardableReshape) {
3117 absl::string_view hlo_string = R"(
3118 HloModule module
3119
3120 ENTRY entry {
3121 %param0 = f32[38,38,324] parameter(0)
3122 %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1}
3123 ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
3124 sharding={devices=[2,1,1,1]0,1}
3125 })";
3126
3127 TF_ASSERT_OK_AND_ASSIGN(auto module,
3128 PartitionComputation(hlo_string, /*num_devices=*/2));
3129 VLOG(1) << module->ToString();
3130
3131 const auto root = module->entry_computation()->root_instruction();
3132 auto param0 =
3133 AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
3134 op::Constant(), op::Constant())),
3135 op::Shape("f32[19,38,324]"));
3136 EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
3137 }
3138
3139 TEST_F(SpmdPartitioningTest, ReshapeWithReshard) {
3140 absl::string_view hlo_string = R"(
3141 HloModule module
3142
3143 ENTRY entry {
3144 %param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
3145 ROOT %reshape = f32[38,38,4,81] reshape(%param0),
3146 sharding={devices=[1,2,1,1]0,1}
3147 })";
3148
3149 TF_ASSERT_OK_AND_ASSIGN(auto module,
3150 PartitionComputation(hlo_string, /*num_devices=*/2));
3151 VLOG(1) << module->ToString();
3152
3153 const auto root = module->entry_computation()->root_instruction();
3154 auto input_reshard =
3155 op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(0)))));
3156 EXPECT_THAT(root,
3157 AllOf(op::Reshape(input_reshard), op::Shape("f32[38,19,4,81]")));
3158 }
3159
3160 TEST_F(SpmdPartitioningTest, ReshapeWithReshard2) {
3161 absl::string_view hlo_string = R"(
3162 HloModule module
3163
3164 ENTRY entry {
3165 %param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
3166 ROOT %reshape = f32[38,38,2,162] reshape(%param0),
3167 sharding={devices=[1,1,1,2]0,1}
3168 })";
3169
3170 TF_ASSERT_OK_AND_ASSIGN(auto module,
3171 PartitionComputation(hlo_string, /*num_devices=*/2));
3172 VLOG(1) << module->ToString();
3173
3174 const auto root = module->entry_computation()->root_instruction();
3175 auto local_reshape =
3176 AllOf(op::Reshape(op::Parameter(0)), op::Shape("f32[19,38,2,162]"));
3177 EXPECT_THAT(root, AllOf(op::Shape("f32[38,38,2,81]"),
3178 op::Reshape(op::Transpose(
3179 op::AllToAll(op::Reshape(local_reshape))))));
3180 }
3181
3182 TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) {
3183 absl::string_view hlo_string = R"(
3184 HloModule module
3185
3186 ENTRY entry {
3187 %param0 = f32[38,38,324] parameter(0)
3188 %param0.copy = f32[38,38,324] copy(%param0),
3189 sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
3190 ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
3191 sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate}
3192 })";
3193
3194 TF_ASSERT_OK_AND_ASSIGN(auto module,
3195 PartitionComputation(hlo_string, /*num_devices=*/4));
3196 VLOG(1) << module->ToString();
3197
3198 const auto root = module->entry_computation()->root_instruction();
3199 auto param0 =
3200 AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
3201 op::Constant(), op::Constant())),
3202 op::Shape("f32[19,38,324]"));
3203 EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
3204 }
3205
3206 TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
3207 absl::string_view hlo_string = R"(
3208 HloModule module
3209
3210 ENTRY entry {
3211 %input = s32[2,3,7,10] parameter(0), sharding={devices=[1,1,2,1]0,1}
3212 ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
3213 sharding={devices=[1,1,1,2,1]0,1}
3214 })";
3215
3216 TF_ASSERT_OK_AND_ASSIGN(auto module,
3217 PartitionComputation(hlo_string, /*num_devices=*/2));
3218 VLOG(1) << module->ToString();
3219
3220 auto reshape =
3221 AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]"));
3222 auto halo = op::CollectivePermute(op::Slice(reshape));
3223 auto exchanged = op::DynamicSlice(op::Concatenate(halo, op::Slice(reshape)),
3224 _, _, _, _, _);
3225 const auto root = module->entry_computation()->root_instruction();
3226 EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]")));
3227 }
3228
3229 TEST_F(SpmdPartitioningTest, PartialReplicateReshapeMergeDimsWithHaloExchange) {
3230 absl::string_view hlo_string = R"(
3231 HloModule module
3232
3233 ENTRY entry {
3234 %input = s32[2,3,7,10] parameter(0),
3235 sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
3236 ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
3237 sharding={devices=[1,1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
3238 })";
3239
3240 TF_ASSERT_OK_AND_ASSIGN(auto module,
3241 PartitionComputation(hlo_string, /*num_devices=*/4));
3242 VLOG(1) << module->ToString();
3243
3244 auto reshape =
3245 AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]"));
3246 auto halo = op::CollectivePermute(op::Slice(reshape));
3247 auto exchanged = op::DynamicSlice(op::Concatenate(halo, op::Slice(reshape)),
3248 _, _, _, _, _);
3249 const auto root = module->entry_computation()->root_instruction();
3250 EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]")));
3251 }
3252
3253 // Produces an invalid module after transformation.
3254 TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) {
3255 absl::string_view hlo_string = R"(
3256 HloModule module
3257
3258 sum {
3259 a = f32[] parameter(0)
3260 b = f32[] parameter(1)
3261 ROOT add = f32[] add(a, b)
3262 }
3263
3264 ENTRY entry {
3265 %param0 = f32[128,5,5,768] parameter(0)
3266 %param0.copy = f32[128,5,5,768] copy(%param0),
3267 sharding={devices=[1,4,1,1]0,1,2,3}
3268 %constant.1 = f32[] constant(0), sharding={replicated}
3269 ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1),
3270 window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1},
3271 to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3}
3272 })";
3273
3274 TF_ASSERT_OK_AND_ASSIGN(auto module,
3275 PartitionComputation(hlo_string, /*num_devices=*/4));
3276 VLOG(1) << module->ToString();
3277
3278 auto input_shard = op::Copy(op::DynamicSlice(
3279 op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(),
3280 op::Constant(), op::Constant()));
3281 auto id_mul4_add1 =
3282 op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
3283 auto id_mul5 = op::Multiply(op::Reshape(), op::Constant());
3284 auto id_mul5_add1_div3 =
3285 op::Divide(op::Add(id_mul5, op::Constant()), op::Constant());
3286 auto before_masking = AllOf(
3287 op::Shape("f32[128,3,5,768]"),
3288 op::DynamicSlice(
3289 AllOf(
3290 op::Shape("f32[128,4,5,768]"),
3291 op::Concatenate(op::CollectivePermute(input_shard), input_shard)),
3292 op::Constant(),
3293 op::Subtract(op::Constant(),
3294 op::Subtract(id_mul4_add1, id_mul5_add1_div3)),
3295 op::Constant(), op::Constant()));
3296 auto masked = op::Select(
3297 op::And(op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)),
3298 op::Broadcast(op::Constant())),
3299 op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)),
3300 op::Broadcast(op::Constant()))),
3301 before_masking, op::Broadcast(op::Constant()));
3302 auto rw = AllOf(op::Shape("f32[128,7,17,768]"),
3303 op::ReduceWindow(masked, op::Constant()));
3304 auto final_slice_index = op::Subtract(
3305 id_mul5,
3306 op::Add(op::Multiply(id_mul5_add1_div3, op::Constant()), op::Constant()));
3307 const auto root = module->entry_computation()->root_instruction();
3308 EXPECT_THAT(root,
3309 AllOf(op::Shape("f32[128,5,17,768]"),
3310 op::DynamicSlice(rw, op::Constant(), final_slice_index,
3311 op::Constant(), op::Constant())));
3312 }
3313
3314 TEST_F(SpmdPartitioningTest, TiledToTiledReduce) {
3315 absl::string_view hlo_string = R"(
3316 HloModule module
3317
3318 sum {
3319 a = f32[] parameter(0)
3320 b = f32[] parameter(1)
3321 ROOT add = f32[] add(a, b)
3322 }
3323
3324 ENTRY entry {
3325 %param0 = f32[4,32,32,128] parameter(0)
3326 %param0.copy = f32[4,32,32,128] copy(%param0),
3327 sharding={devices=[1,1,1,2]0,1}
3328 %constant.1 = f32[] constant(0), sharding={replicated}
3329 %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
3330 to_apply=%sum, sharding={devices=[2]0,1}
3331 })";
3332
3333 TF_ASSERT_OK_AND_ASSIGN(auto module,
3334 PartitionComputation(hlo_string, /*num_devices=*/2));
3335 VLOG(1) << module->ToString();
3336
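  // Both the reduced operand and the output are sharded along the feature dimension
  // (128 -> 64 per device), so the reduce stays local with no cross-partition collectives.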
3337 const auto root = module->entry_computation()->root_instruction();
3338 auto param0 = AllOf(
3339 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
3340 op::Constant(), op::Reshape())),
3341 op::Shape("f32[4,32,32,64]"));
3342
3343 EXPECT_THAT(root,
3344 AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]")));
3345 }
3346
3347 TEST_F(SpmdPartitioningTest, PartialTiledToPartialTiledReduce) {
3348 absl::string_view hlo_string = R"(
3349 HloModule module
3350
3351 sum {
3352 a = f32[] parameter(0)
3353 b = f32[] parameter(1)
3354 ROOT add = f32[] add(a, b)
3355 }
3356
3357 ENTRY entry {
3358 %param0 = f32[4,4] parameter(0),
3359 sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
3360 %constant.1 = f32[] constant(0), sharding={replicated}
3361 ROOT %reduce = f32[4] reduce(%param0, %constant.1), dimensions={0},
3362 to_apply=%sum,
3363 sharding={devices=[2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
3364 })";
3365
3366 TF_ASSERT_OK_AND_ASSIGN(auto module,
3367 PartitionComputation(hlo_string, /*num_devices=*/8));
3368 VLOG(1) << module->ToString();
3369
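  // Dimension 0, which is being reduced, is sharded 2-ways, so the local reduce is
  // followed by an all-reduce over that group; the output keeps the 2-way sharding
  // on the remaining dimension (4 -> 2).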
3370 const auto root = module->entry_computation()->root_instruction();
3371 EXPECT_THAT(root,
3372 AllOf(op::AllReduce(op::Reduce(op::Parameter(0), op::Constant())),
3373 op::Shape("f32[2]")));
3374 }
3375
3376 TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) {
3377 absl::string_view hlo_string = R"(
3378 HloModule module
3379
3380 %minmax_func {
3381 %lhs_value = f32[] parameter(0)
3382 %rhs_value = f32[] parameter(2)
3383 %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
3384 %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
3385 %lhs_index = s32[] parameter(1)
3386 %rhs_index = s32[] parameter(3)
3387 %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
3388 ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
3389 }
3390
3391 ENTRY %main {
3392 %param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1}
3393 %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1}
3394 %init0 = f32[] parameter(2)
3395 %init1 = s32[] parameter(3)
3396 ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
3397 dimensions={1}, to_apply=%minmax_func,
3398 sharding={{devices=[2]0,1}, {devices=[2]0,1}}
3399 })";
3400
3401 TF_ASSERT_OK_AND_ASSIGN(auto module,
3402 PartitionComputation(hlo_string, /*num_devices=*/2));
3403 VLOG(1) << module->ToString();
3404
3405 const auto root = module->entry_computation()->root_instruction();
3406 EXPECT_THAT(root, AllOf(op::Reduce(op::Parameter(0), op::Parameter(1),
3407 op::Parameter(2), op::Parameter(3)),
3408 op::Shape("(f32[14], s32[14])")));
3409 }
3410
3411 TEST_F(SpmdPartitioningTest, TiledToPartiallyTiledTupleReduce) {
3412 absl::string_view hlo_string = R"(
3413 HloModule module
3414
3415 %minmax_func {
3416 %lhs_value = f32[] parameter(0)
3417 %rhs_value = f32[] parameter(2)
3418 %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
3419 %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
3420 %lhs_index = s32[] parameter(1)
3421 %rhs_index = s32[] parameter(3)
3422 %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
3423 ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
3424 }
3425
3426 ENTRY %main {
3427 %param0 = f32[28,12] parameter(0), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
3428 %param1 = s32[28,12] parameter(1), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
3429 %init0 = f32[] parameter(2)
3430 %init1 = s32[] parameter(3)
3431 ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
3432 dimensions={1}, to_apply=%minmax_func,
3433 sharding={{devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate},
3434 {devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}}
3435 })";
3436
3437 TF_ASSERT_OK_AND_ASSIGN(auto module,
3438 PartitionComputation(hlo_string, /*num_devices=*/8));
3439 VLOG(1) << module->ToString();
3440
3441 const auto lhs = AllOf(op::Shape("f32[14,3]"), op::Parameter(0));
3442 const auto rhs = AllOf(op::Shape("s32[14,3]"), op::Parameter(1));
3443 auto local_reduce =
3444 AllOf(op::Reduce(lhs, rhs, op::Parameter(2), op::Parameter(3)),
3445 op::Shape("(f32[14], s32[14])"));
3446 auto reshape_l = AllOf(op::Reshape(op::GetTupleElement(local_reduce)),
3447 op::Shape("f32[14,1]"));
3448 auto reshape_r = AllOf(op::Reshape(op::GetTupleElement(local_reduce)),
3449 op::Shape("s32[14,1]"));
3450 auto broadcast_l =
3451 AllOf(op::AllReduce(op::DynamicUpdateSlice(_, reshape_l, _, _)),
3452 op::Shape("f32[14,4]"));
3453 auto broadcast_r =
3454 AllOf(op::AllReduce(op::DynamicUpdateSlice(_, reshape_r, _, _)),
3455 op::Shape("s32[14,4]"));
3456 const auto root = module->entry_computation()->root_instruction();
3457 EXPECT_THAT(root, AllOf(op::Reduce(broadcast_l, broadcast_r, op::Parameter(2),
3458 op::Parameter(3)),
3459 op::Shape("(f32[14], s32[14])")));
3460 }
3461
3462 TEST_F(SpmdPartitioningTest, TupleReduceSubgroupManual) {
3463 absl::string_view hlo_string = R"(
3464 HloModule module
3465
3466 %minmax_func {
3467 %lhs_value = f32[] parameter(0)
3468 %rhs_value = f32[] parameter(2)
3469 %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
3470 %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
3471 %lhs_index = s32[] parameter(1)
3472 %rhs_index = s32[] parameter(3)
3473 %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
3474 ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
3475 }
3476
3477 ENTRY %main {
3478 %param0 = f32[28,12] parameter(0),
3479 sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
3480 %param1 = s32[28,12] parameter(1),
3481 sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
3482 %init0 = f32[] parameter(2),
3483 sharding={devices=[2,2]0,1,2,3 last_tile_dims={replicated,manual}}
3484 %init1 = s32[] parameter(3),
3485 sharding={devices=[2,2]0,1,2,3 last_tile_dims={replicated,manual}}
3486 ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
3487 dimensions={1}, to_apply=%minmax_func,
3488 sharding={{devices=[1,2,2]0,1,2,3 last_tile_dims={replicated,manual}},
3489 {devices=[1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}}
3490 })";
3491
3492 TF_ASSERT_OK_AND_ASSIGN(auto module,
3493 PartitionComputation(hlo_string, /*num_devices=*/4));
3494 VLOG(1) << module->ToString();
3495
3496 const auto lhs = AllOf(op::Shape("f32[28,6]"), op::Parameter(0));
3497 const auto rhs = AllOf(op::Shape("s32[28,6]"), op::Parameter(1));
3498 auto local_reduce =
3499 AllOf(op::Reduce(lhs, rhs, op::Parameter(2), op::Parameter(3)),
3500 op::Shape("(f32[28], s32[28])"));
3501 auto reshape_l = AllOf(op::Reshape(op::GetTupleElement(local_reduce)),
3502 op::Shape("f32[28,1]"));
3503 auto reshape_r = AllOf(op::Reshape(op::GetTupleElement(local_reduce)),
3504 op::Shape("s32[28,1]"));
3505 auto broadcast_l =
3506 AllOf(op::AllReduce(op::DynamicUpdateSlice(_, reshape_l, _, _)),
3507 op::Shape("f32[28,2]"));
3508 auto broadcast_r =
3509 AllOf(op::AllReduce(op::DynamicUpdateSlice(_, reshape_r, _, _)),
3510 op::Shape("s32[28,2]"));
3511 const auto root = module->entry_computation()->root_instruction();
3512 EXPECT_THAT(root, AllOf(op::Reduce(broadcast_l, broadcast_r, op::Parameter(2),
3513 op::Parameter(3)),
3514 op::Shape("(f32[28], s32[28])")));
3515 }
3516
3517 TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) {
3518 absl::string_view hlo_string = R"(
3519 HloModule module
3520
3521 sum {
3522 a = f32[] parameter(0)
3523 b = f32[] parameter(1)
3524 ROOT add = f32[] add(a, b)
3525 }
3526
3527 ENTRY entry {
3528 %param0 = f32[4,32,32,128] parameter(0)
3529 %param0.copy = f32[4,32,32,128] copy(%param0),
3530 sharding={devices=[1,2,1,1]0,1}
3531 %constant.1 = f32[] constant(0), sharding={replicated}
3532 %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
3533 to_apply=%sum, sharding={devices=[2]0,1}
3534 })";
3535
3536 TF_ASSERT_OK_AND_ASSIGN(auto module,
3537 PartitionComputation(hlo_string, /*num_devices=*/2));
3538 VLOG(1) << module->ToString();
3539
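  // The reduced dimensions are sharded, so the partial reduce is all-reduced to the
  // full f32[128] result and then dynamic-sliced to the requested f32[64] output sharding.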
3540 const auto root = module->entry_computation()->root_instruction();
3541 auto param0 = AllOf(
3542 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
3543 op::Constant(), op::Constant())),
3544 op::Shape("f32[4,16,32,128]"));
3545
3546 EXPECT_THAT(root,
3547 AllOf(op::DynamicSlice(
3548 AllOf(op::AllReduce(op::Reduce(param0, op::Constant())),
3549 op::Shape("f32[128]")),
3550 op::Reshape()),
3551 op::Shape("f32[64]")));
3552 }
3553
3554 TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) {
3555 absl::string_view hlo_string = R"(
3556 HloModule module
3557
3558 ENTRY entry {
3559 ROOT %iota = s32[16,80,91] iota(), iota_dimension=1,
3560 sharding={devices=[1,1,2]0,1}
3561 })";
3562
3563 TF_ASSERT_OK_AND_ASSIGN(auto module,
3564 PartitionComputation(hlo_string, /*num_devices=*/2));
3565 VLOG(1) << module->ToString();
3566
3567 const auto root = module->entry_computation()->root_instruction();
3568 EXPECT_THAT(root, AllOf(op::Iota(), op::Shape("s32[16,80,46]")));
3569 }
3570
3571 TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) {
3572 absl::string_view hlo_string = R"(
3573 HloModule module
3574
3575 ENTRY entry {
3576 ROOT %iota = s32[16,80,91] iota(), iota_dimension=2,
3577 sharding={devices=[1,1,2]0,1}
3578 })";
3579
3580 TF_ASSERT_OK_AND_ASSIGN(auto module,
3581 PartitionComputation(hlo_string, /*num_devices=*/2));
3582 VLOG(1) << module->ToString();
3583
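  // Iota along the tiled dimension becomes a local iota plus a per-partition offset
  // (the broadcast), covering 46 of the 91 elements on each device.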
3584 const auto root = module->entry_computation()->root_instruction();
3585 EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()),
3586 op::Shape("s32[16,80,46]")));
3587 }
3588
3589 TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) {
3590 absl::string_view hlo_string = R"(
3591 HloModule module
3592
3593 ENTRY entry {
3594 ROOT %iota = u32[16,80,91] iota(), iota_dimension=2,
3595 sharding={devices=[1,1,2]0,1}
3596 })";
3597
3598 TF_ASSERT_OK_AND_ASSIGN(auto module,
3599 PartitionComputation(hlo_string, /*num_devices=*/2));
3600 VLOG(1) << module->ToString();
3601
3602 const auto root = module->entry_computation()->root_instruction();
3603 EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()),
3604 op::Shape("u32[16,80,46]")));
3605 }
3606
3607 TEST_F(SpmdPartitioningTest, Conditional) {
3608 absl::string_view hlo_string = R"(
3609 HloModule module
3610
3611 Negate {
3612 x = f32[4,5] parameter(0), sharding={replicated}
3613 ROOT negate = f32[4,5] negate(x), sharding={replicated}
3614 }
3615
3616 Identity {
3617 y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1}
3618 ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1}
3619 }
3620
3621 ENTRY entry {
3622 %param.0 = pred[] parameter(0)
3623 %param.0.copy = pred[] copy(%param.0), sharding={maximal device=0}
3624 %param.1 = f32[4,5] parameter(1)
3625 %param.1.copy = f32[4,5] copy(%param.1), sharding={replicated}
3626 %param.2 = f32[4,5] parameter(2)
3627 %param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1}
3628 ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy),
3629 true_computation=Negate, false_computation=Identity,
3630 sharding={devices=[2,1]0,1}
3631 })";
3632
3633 TF_ASSERT_OK_AND_ASSIGN(auto module,
3634 PartitionComputation(hlo_string, /*num_devices=*/2));
3635 VLOG(1) << module->ToString();
3636
3637   auto param0 = AllOf(op::Copy(op::Copy(op::Parameter())), op::Shape("pred[]"));
3638 auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]"));
3639 auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
3640 op::Constant())),
3641 op::Shape("f32[2,5]"));
3642
3643 const auto root = module->entry_computation()->root_instruction();
3644 EXPECT_THAT(root, AllOf(op::Conditional(op::AllReduce(), param1, param2),
3645 op::Shape("f32[2,5]")));
3646
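  // The true branch still computes on the replicated operand, so its result is sliced
  // to the tiled f32[2,5] output; the false branch is already tiled and only needs a copy.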
3647 auto then_branch_root = root->branch_computation(0)->root_instruction();
3648 EXPECT_THAT(then_branch_root,
3649 AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(),
3650 op::Constant()),
3651 op::Shape("f32[2,5]")));
3652
3653 auto else_branch_root = root->branch_computation(1)->root_instruction();
3654 EXPECT_THAT(else_branch_root,
3655 AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]")));
3656 }
3657
3658 TEST_F(SpmdPartitioningTest, ConditionalManual) {
3659 absl::string_view hlo_string = R"(
3660 HloModule module
3661
3662 Negate {
3663 x = f32[4,5] parameter(0), sharding={manual}
3664 ROOT negate = f32[4,5] negate(x), sharding={manual}
3665 }
3666
3667 Identity {
3668 y = f32[4,5] parameter(0), sharding={manual}
3669 ROOT copy = f32[4,5] copy(y), sharding={manual}
3670 }
3671
3672 ENTRY entry {
3673 %param.0 = pred[] parameter(0), sharding={manual}
3674 %param.1 = f32[4,5] parameter(1), sharding={manual}
3675 %param.2 = f32[4,5] parameter(2), sharding={manual}
3676 ROOT cond = f32[4,5] conditional(%param.0, %param.1, %param.2),
3677 true_computation=Negate, false_computation=Identity, sharding={manual}
3678 })";
3679
3680 TF_ASSERT_OK_AND_ASSIGN(auto module,
3681 PartitionComputation(hlo_string, /*num_devices=*/2));
3682 VLOG(1) << module->ToString();
3683
3684 auto param0 = AllOf(op::Parameter(0), op::Shape("pred[]"));
3685 auto param1 = AllOf(op::Parameter(1), op::Shape("f32[4,5]"));
3686 auto param2 = AllOf(op::Parameter(2), op::Shape("f32[4,5]"));
3687
3688 const auto root = module->entry_computation()->root_instruction();
3689 EXPECT_THAT(root, AllOf(op::Conditional(param0, param1, param2),
3690 op::Shape("f32[4,5]")));
3691 }
3692
3693 TEST_F(SpmdPartitioningTest, WhileManual) {
3694 absl::string_view hlo_string = R"(
3695 HloModule module
3696
3697 LoopCond {
3698 x = s32[] parameter(0), sharding={manual}
3699 const = s32[] constant(5), sharding={manual}
3700 ROOT lt = pred[] compare(x, const), direction=LT, sharding={manual}
3701 }
3702
3703 Inc {
3704 x = s32[] parameter(0), sharding={manual}
3705 const = s32[] constant(1), sharding={manual}
3706 ROOT add = s32[] add(x, const), sharding={manual}
3707 }
3708
3709 ENTRY entry {
3710 zero = s32[] parameter(0), sharding={manual}
3711 ROOT while = s32[] while(zero), body=Inc, condition=LoopCond,
3712 sharding={manual}
3713 })";
3714
3715 TF_ASSERT_OK_AND_ASSIGN(auto module,
3716 PartitionComputation(hlo_string, /*num_devices=*/2));
3717 VLOG(1) << module->ToString();
3718
3719 auto zero = AllOf(op::Parameter(0), op::Shape("s32[]"));
3720 const auto root = module->entry_computation()->root_instruction();
3721 EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]")));
3722 }
3723
3724 TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) {
3725 absl::string_view hlo_string = R"(
3726 HloModule module
3727
3728 ge {
3729 a = f32[] parameter(0)
3730 b = f32[] parameter(1)
3731 ROOT compare = pred[] compare(a, b), direction=GE
3732 }
3733
3734 sum {
3735 c = f32[] parameter(0)
3736 d = f32[] parameter(1)
3737 ROOT add = f32[] add(c, d)
3738 }
3739
3740 ENTRY entry {
3741 %param.0 = f32[32,128,384,64] parameter(0)
3742 %param.0.copy = f32[32,128,384,64] copy(%param.0),
3743 sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
3744 %param.1 = f32[32,64,192,64] parameter(1)
3745 %param.1.copy = f32[32,64,192,64] copy(%param.1),
3746 sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
3747 constant.1 = f32[] constant(0), sharding={replicated}
3748 ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy,
3749 %param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1},
3750 select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
3751 })";
3752 TF_ASSERT_OK_AND_ASSIGN(auto module,
3753 PartitionComputation(hlo_string, /*num_devices=*/8));
3754 VLOG(1) << module->ToString();
3755
3756 const auto root = module->entry_computation()->root_instruction();
3757 auto source = AllOf(
3758 op::Shape("f32[32,8,192,64]"),
3759 op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
3760 op::Constant(), op::Constant())));
3761 auto data = AllOf(
3762 op::Shape("f32[32,16,384,64]"),
3763 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
3764 op::Constant(), op::Constant())));
3765
3766 EXPECT_THAT(root, op::SelectAndScatter(data, source, op::Constant()));
3767 EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
3768 EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
3769 }
3770
3771 TEST_F(SpmdPartitioningTest, TiledDot) {
3772 absl::string_view hlo_string = R"(
3773 HloModule module
3774
3775 ENTRY entry {
3776 %lhs = f32[128,64] parameter(0)
3777 %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
3778 %rhs = f32[64,256] parameter(1)
3779 %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
3780 ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
3781 dim_labels=bf_io->bf, sharding={replicated}
3782 })";
3783
3784 TF_ASSERT_OK_AND_ASSIGN(
3785 auto module,
3786 PartitionComputation(hlo_string, /*num_devices=*/2,
3787 /*conv_halo_exchange_always_on_lhs=*/false));
3788 VLOG(1) << module->ToString();
3789
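  // Both operands are sharded along the contracting dimension (64 -> 32 per device),
  // so the partial products are combined with an all-reduce to form the replicated output.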
3790 const auto root = module->entry_computation()->root_instruction();
3791 const auto lhs = AllOf(op::Copy(op::DynamicSlice(
3792 op::Parameter(), op::Constant(), op::Reshape())),
3793 op::Shape("f32[128,32]"));
3794 const auto rhs = AllOf(op::Copy(op::DynamicSlice(
3795 op::Parameter(), op::Reshape(), op::Constant())),
3796 op::Shape("f32[32,256]"));
3797 EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
3798 op::Shape("f32[128,256]")));
3799 }
3800
3801 TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) {
3802 absl::string_view hlo_string = R"(
3803 HloModule module
3804
3805 ENTRY entry {
3806 %lhs = f32[128,64] parameter(0)
3807 %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
3808 %rhs = f32[64,256] parameter(1)
3809 %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
3810 ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
3811 dim_labels=bf_io->bf, sharding={devices=[1,2]0,1}
3812 })";
3813
3814 TF_ASSERT_OK_AND_ASSIGN(auto module,
3815 PartitionComputation(hlo_string, /*num_devices=*/2));
3816 VLOG(1) << module->ToString();
3817
3818 const auto root = module->entry_computation()->root_instruction();
3819 const auto lhs = AllOf(op::Copy(op::DynamicSlice(
3820 op::Parameter(), op::Constant(), op::Reshape())),
3821 op::Shape("f32[128,32]"));
3822 const auto rhs = AllOf(op::Copy(op::DynamicSlice(
3823 op::Parameter(), op::Reshape(), op::Constant())),
3824 op::Shape("f32[32,256]"));
3825 EXPECT_THAT(root, AllOf(op::DynamicSlice(
3826 AllOf(op::AllReduce(op::Convolution(lhs, rhs)),
3827 op::Shape("f32[128,256]")),
3828 op::Constant(), op::Reshape()),
3829 op::Shape("f32[128,128]")));
3830 }
3831
3832 TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) {
3833 absl::string_view hlo_string = R"(
3834 HloModule module
3835
3836 ENTRY entry {
3837 %lhs = f32[128,256,256] parameter(0)
3838 %lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1}
3839 %rhs = f32[256,8,1] parameter(1)
3840 %rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated}
3841 ROOT %conv = f32[128,256,8] convolution(%lhs.copy, %rhs.copy),
3842 window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1}
3843 })";
3844
3845 TF_ASSERT_OK_AND_ASSIGN(auto module,
3846 PartitionComputation(hlo_string, /*num_devices=*/2));
3847 VLOG(1) << module->ToString();
3848
3849 const auto root = module->entry_computation()->root_instruction();
3850 const auto lhs =
3851 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
3852 op::Reshape(), op::Constant())),
3853 op::Shape("f32[128,128,256]"));
3854 const auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[256,8,1]"));
3855 EXPECT_THAT(root,
3856 AllOf(op::Convolution(lhs, rhs), op::Shape("f32[128,128,8]")));
3857 }
3858
3859 TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) {
3860 absl::string_view hlo_string = R"(
3861 HloModule module
3862
3863 ENTRY entry {
3864 %lhs = f32[24,64] parameter(0)
3865 %lhs.copy = f32[24,64] copy(%lhs), sharding={replicated}
3866 %rhs = f32[39296,64] parameter(1)
3867 %rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1}
3868 ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy),
3869 lhs_batch_dims={}, rhs_batch_dims={},
3870 lhs_contracting_dims={1}, rhs_contracting_dims={1},
3871 sharding={devices=[1,2]0,1}
3872 })";
3873
3874 TF_ASSERT_OK_AND_ASSIGN(auto module,
3875 PartitionComputation(hlo_string, /*num_devices=*/2));
3876 VLOG(1) << module->ToString();
3877
3878 const auto root = module->entry_computation()->root_instruction();
3879 const auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[24,64]"));
3880 const auto rhs = AllOf(op::Copy(op::DynamicSlice(
3881 op::Parameter(1), op::Reshape(), op::Constant())),
3882 op::Shape("f32[19648,64]"));
3883 EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]")));
3884 }
3885
3886 TEST_F(SpmdPartitioningTest, WindowedEinsumTwoContractingDimsLhsReshard) {
3887 absl::string_view hlo_string = R"(
3888 HloModule module
3889
3890 ENTRY entry {
3891 %p0 = f32[2048,2,3264]{2,1,0} parameter(0), sharding={devices=[1,1,2]0,1}
3892 %p1 = f32[2,3264,2176]{2,1,0} parameter(1), sharding={devices=[2,1,1]0,1}
3893 ROOT %dot.224 = f32[2048,2176]{1,0} dot(f32[2048,2,3264]{2,1,0} %p0, f32[2,3264,2176]{2,1,0} %p1), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, sharding={devices=[1,2]0,1}
3894 })";
3895
3896 TF_ASSERT_OK_AND_ASSIGN(
3897 auto module,
3898 PartitionComputation(hlo_string, /*num_devices=*/2,
3899 /*conv_halo_exchange_always_on_lhs=*/true,
3900 /*choose_faster_windowed_einsum=*/false,
3901 /*unroll_windowed_einsum=*/false,
3902 /*bidirectional_windowed_einsum=*/false,
3903 /*threshold_for_windowed_einsum_mib=*/0));
3904 VLOG(1) << module->ToString();
3905
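  // With threshold_for_windowed_einsum_mib=0 the dot is lowered to a windowed einsum:
  // the LHS is resharded (all-to-all) to match the RHS sharding on the contracting
  // dimensions, and a while loop builds the [2048,1088] output one RHS slice at a time.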
3906 // Check while op.
3907 const auto arg0 = AllOf(
3908 op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(0))))),
3909 op::Shape("f32[2048,1,3264]"));
3910 const auto arg1 = AllOf(op::Parameter(1), op::Shape("f32[1,3264,2176]"));
3911
3912 const auto while_op =
3913 AllOf(op::While(op::Tuple(arg0, arg1, op::Broadcast(), op::Broadcast(),
3914 op::Constant())),
3915 op::Shape("(f32[2048,1,3264]{2,1,0}, f32[1,3264,2176]{2,1,0},"
3916 " f32[2048,1088]{1,0}, f32[2048,1088]{1,0}, u32[])"));
3917 const auto root = module->entry_computation()->root_instruction();
3918 EXPECT_THAT(
3919 root, AllOf(op::GetTupleElement(while_op), op::Shape("f32[2048,1088]")));
3920
3921 // Check while op body.
3922 const auto while_loop = root->operand(0);
3923 const auto next_i =
3924 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
3925 auto lhs = AllOf(op::GetTupleElement(op::Parameter(0)),
3926 op::Shape("f32[2048,1,3264]"));
3927 auto rhs = AllOf(op::DynamicSlice(), op::Shape("f32[1,3264,1088]"));
3928 auto dot_op = op::Dot(lhs, rhs);
3929 auto add_op = op::Add(op::GetTupleElement(op::Parameter(0)), dot_op);
3930 auto cond_op =
3931 op::Conditional(op::Compare(next_i, op::Constant()), add_op, add_op);
3932 EXPECT_THAT(while_loop->while_body()->root_instruction(),
3933 op::Tuple(op::GetTupleElement(op::Parameter(0)),
3934 op::GetTupleElement(op::Parameter(0)), cond_op,
3935 op::GetTupleElement(op::Parameter(0)), next_i));
3936 }
3937
3938 TEST_F(SpmdPartitioningTest, WindowedEinsumTwoContractingDimsRhsReshard) {
3939 absl::string_view hlo_string = R"(
3940 HloModule module
3941
3942 ENTRY entry {
3943 %p0 = f32[4096,2,3264]{2,1,0} parameter(0), sharding={devices=[1,1,2]0,1}
3944 %p1 = f32[2,3264,2176]{2,1,0} parameter(1), sharding={devices=[2,1,1]0,1}
3945 ROOT %dot.224 = f32[4096,2176]{1,0} dot(f32[4096,2,3264]{2,1,0} %p0, f32[2,3264,2176]{2,1,0} %p1), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, sharding={devices=[1,2]0,1}
3946 })";
3947
3948 TF_ASSERT_OK_AND_ASSIGN(
3949 auto module,
3950 PartitionComputation(hlo_string, /*num_devices=*/2,
3951 /*conv_halo_exchange_always_on_lhs=*/true,
3952 /*choose_faster_windowed_einsum=*/false,
3953 /*unroll_windowed_einsum=*/false,
3954 /*bidirectional_windowed_einsum=*/false,
3955 /*threshold_for_windowed_einsum_mib=*/0));
3956 VLOG(1) << module->ToString();
3957
3958 // Check while op.
3959 const auto arg0 = AllOf(op::Parameter(0), op::Shape("f32[4096,2,1632]"));
3960 const auto arg1 = AllOf(
3961 op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(1))))),
3962 op::Shape("f32[2,1632,2176]"));
3963
3964 const auto while_op =
3965 AllOf(op::While(op::Tuple(arg0, arg1, op::Broadcast(), op::Broadcast(),
3966 op::Constant())),
3967 op::Shape("(f32[4096,2,1632]{2,1,0}, f32[2,1632,2176]{2,1,0},"
3968 " f32[4096,1088]{1,0}, f32[4096,1088]{1,0}, u32[])"));
3969 const auto root = module->entry_computation()->root_instruction();
3970 EXPECT_THAT(
3971 root, AllOf(op::GetTupleElement(while_op), op::Shape("f32[4096,1088]")));
3972
3973 // Check while op body.
3974 const auto while_loop = root->operand(0);
3975 const auto next_i =
3976 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
3977 auto lhs = AllOf(op::GetTupleElement(op::Parameter(0)),
3978 op::Shape("f32[4096,2,1632]"));
3979 auto rhs = AllOf(op::DynamicSlice(), op::Shape("f32[2,1632,1088]"));
3980 auto dot_op = op::Dot(lhs, rhs);
3981 auto add_op = op::Add(op::GetTupleElement(op::Parameter(0)), dot_op);
3982 auto cond_op =
3983 op::Conditional(op::Compare(next_i, op::Constant()), add_op, add_op);
3984 EXPECT_THAT(while_loop->while_body()->root_instruction(),
3985 op::Tuple(op::GetTupleElement(op::Parameter(0)),
3986 op::GetTupleElement(op::Parameter(0)), cond_op,
3987 op::GetTupleElement(op::Parameter(0)), next_i));
3988 }
3989
3990 TEST_F(SpmdPartitioningTest, DotPartialDeviceOrder) {
3991 absl::string_view hlo_string = R"(
3992 HloModule module
3993
3994 ENTRY entry {
3995 %lhs = f32[16,256,4096] parameter(0), sharding={devices=[1,1,2,2]1,3,0,2 last_tile_dim_replicate}
3996 %rhs = f32[4096,2048] parameter(1), sharding={devices=[2,2]3,1,2,0}
3997 ROOT %dot = f32[16,256,2048] dot(%lhs, %rhs),
3998 lhs_batch_dims={}, rhs_batch_dims={},
3999 lhs_contracting_dims={2}, rhs_contracting_dims={0},
4000 sharding={devices=[1,1,2,2]2,3,0,1 last_tile_dim_replicate}
4001 })";
4002
4003 TF_ASSERT_OK_AND_ASSIGN(auto module,
4004 PartitionComputation(hlo_string, /*num_devices=*/4));
4005 VLOG(1) << module->ToString();
4006
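  // The contracting dimension (4096) is split two ways, so each partition
  // computes a partial dot over a 2048-wide slice and the partial results are
  // combined with an all-reduce, even with the non-default device order used
  // by the partial shardings.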
4007 const auto root = module->entry_computation()->root_instruction();
4008 const auto lhs = AllOf(op::Parameter(0), op::Shape("f32[16,256,2048]"));
4009 const auto rhs = AllOf(op::Parameter(1), op::Shape("f32[2048,1024]"));
4010 EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)),
4011 op::Shape("f32[16,256,1024]")));
4012 }
4013
4014 TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) {
4015 absl::string_view hlo_string = R"(
4016 HloModule module
4017
4018 ENTRY entry {
4019 %lhs = f32[32,24,64] parameter(0)
4020 %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
4021 %rhs = f32[32,39296,64] parameter(1)
4022 %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
4023 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4024 lhs_batch_dims={0}, rhs_batch_dims={0},
4025 lhs_contracting_dims={2}, rhs_contracting_dims={2},
4026 sharding={devices=[2,1,1]0,1}
4027 })";
4028
4029 TF_ASSERT_OK_AND_ASSIGN(auto module,
4030 PartitionComputation(hlo_string, /*num_devices=*/2));
4031 VLOG(1) << module->ToString();
4032
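  // Both operands and the output are sharded along the batch dimension
  // (32 -> 16 per device), so the dot is computed locally without any
  // cross-partition communication.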
4033 const auto root = module->entry_computation()->root_instruction();
4034 const auto lhs =
4035 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4036 op::Constant(), op::Constant())),
4037 op::Shape("f32[16,24,64]"));
4038 const auto rhs =
4039 AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4040 op::Constant(), op::Constant())),
4041 op::Shape("f32[16,39296,64]"));
4042 EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[16,24,39296]")));
4043 }
4044
4045 TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) {
4046 absl::string_view hlo_string = R"(
4047 HloModule module
4048
4049 ENTRY entry {
4050 %lhs = f32[32,24,64] parameter(0)
4051 %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
4052 %rhs = f32[32,39296,64] parameter(1)
4053 %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
4054 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4055 lhs_batch_dims={0}, rhs_batch_dims={0},
4056 lhs_contracting_dims={2}, rhs_contracting_dims={2},
4057 sharding={devices=[2,1,1]0,1}
4058 })";
4059
4060 TF_ASSERT_OK_AND_ASSIGN(auto module,
4061 PartitionComputation(hlo_string, /*num_devices=*/2));
4062 VLOG(1) << module->ToString();
4063
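  // The replicated RHS is dynamic-sliced along the batch dimension to match
  // the batch-sharded LHS and output.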
4064 const auto root = module->entry_computation()->root_instruction();
4065 const auto lhs =
4066 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4067 op::Constant(), op::Constant())),
4068 op::Shape("f32[16,24,64]"));
4069 const auto rhs =
4070 AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]"));
4071 EXPECT_THAT(root, AllOf(op::Dot(lhs, op::DynamicSlice(rhs, op::Reshape(),
4072 op::Constant(),
4073 op::Constant())),
4074 op::Shape("f32[16,24,39296]")));
4075 }
4076
4077 TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) {
4078 absl::string_view hlo_string = R"(
4079 HloModule module
4080
4081 ENTRY entry {
4082 %lhs = f32[32,24,64] parameter(0)
4083 %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1}
4084 %rhs = f32[32,39296,64] parameter(1)
4085 %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
4086 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4087 lhs_batch_dims={0}, rhs_batch_dims={0},
4088 lhs_contracting_dims={2}, rhs_contracting_dims={2},
4089 sharding={devices=[2,1,1]0,1}
4090 })";
4091
4092 TF_ASSERT_OK_AND_ASSIGN(auto module,
4093 PartitionComputation(hlo_string, /*num_devices=*/2));
4094 VLOG(1) << module->ToString();
4095
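  // The LHS is sharded on a non-contracting dimension, so it is resharded to
  // the batch dimension with an all-to-all before the local dot.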
4096 const auto root = module->entry_computation()->root_instruction();
4097 const auto lhs =
4098 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4099 op::Reshape(), op::Constant())),
4100 op::Shape("f32[32,12,64]"));
4101 const auto rhs =
4102 AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4103 op::Constant(), op::Constant())),
4104 op::Shape("f32[16,39296,64]"));
4105 const auto lhs_reshard =
4106 op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))));
4107 EXPECT_THAT(root,
4108 AllOf(op::Dot(lhs_reshard, rhs), op::Shape("f32[16,24,39296]")));
4109 }
4110
4111 TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) {
4112 absl::string_view hlo_string = R"(
4113 HloModule module
4114
4115 ENTRY entry {
4116 %lhs = f32[32,24,64] parameter(0)
4117 %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
4118 %rhs = f32[32,39296,64] parameter(1)
4119 %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
4120 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4121 lhs_batch_dims={0}, rhs_batch_dims={0},
4122 lhs_contracting_dims={2}, rhs_contracting_dims={2},
4123 sharding={devices=[2,1,1]0,1}
4124 })";
4125
4126 TF_ASSERT_OK_AND_ASSIGN(auto module,
4127 PartitionComputation(hlo_string, /*num_devices=*/2));
4128 VLOG(1) << module->ToString();
4129
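  // Both operands are replicated; each partition slices out its shard of the
  // batch dimension and computes the dot locally.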
4130 const auto root = module->entry_computation()->root_instruction();
4131 const auto lhs_slice =
4132 AllOf(op::DynamicSlice(op::Copy(op::Parameter(0)), op::Reshape(),
4133 op::Constant(), op::Constant()),
4134 op::Shape("f32[16,24,64]"));
4135 const auto rhs_slice =
4136 AllOf(op::DynamicSlice(op::Copy(op::Parameter(1)), op::Reshape(),
4137 op::Constant(), op::Constant()),
4138 op::Shape("f32[16,39296,64]"));
4139 EXPECT_THAT(root, AllOf(op::Dot(lhs_slice, rhs_slice),
4140 op::Shape("f32[16,24,39296]")));
4141 }
4142
4143 TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) {
4144 absl::string_view hlo_string = R"(
4145 HloModule module
4146
4147 ENTRY entry {
4148 %lhs = f32[32,24,64,128] parameter(0)
4149 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3}
4150 %rhs = f32[32,39296,64,128] parameter(1)
4151 %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3}
4152 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4153 lhs_batch_dims={0}, rhs_batch_dims={0},
4154 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4155 sharding={replicated}
4156 })";
4157
4158 TF_ASSERT_OK_AND_ASSIGN(auto module,
4159 PartitionComputation(hlo_string, /*num_devices=*/4));
4160 VLOG(1) << module->ToString();
4161
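  // Both contracting dimensions are partitioned (64 -> 32 and 128 -> 64), so
  // the local dot produces a partial result that is combined with one
  // all-reduce per partitioned contracting dimension.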
4162 const auto root = module->entry_computation()->root_instruction();
4163 const auto lhs = AllOf(
4164 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4165 op::Constant(), op::Reshape(), op::Reshape())),
4166 op::Shape("f32[32,24,32,64]"));
4167 const auto rhs = AllOf(
4168 op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4169 op::Constant(), op::Reshape(), op::Reshape())),
4170 op::Shape("f32[32,39296,32,64]"));
4171 EXPECT_THAT(root, AllOf(op::AllReduce(op::AllReduce(op::Dot(lhs, rhs))),
4172 op::Shape("f32[32,24,39296]")));
4173 }
4174
4175 TEST_F(SpmdPartitioningTest,
4176 EinsumContractingDimsPartitionedResultPartiallySliced) {
4177 absl::string_view hlo_string = R"(
4178 HloModule module
4179
4180 ENTRY entry {
4181 %lhs = f32[32,64] parameter(0), sharding={devices=[1,4]0,1,2,3}
4182 %rhs = f32[64,128] parameter(1), sharding={devices=[4,1]0,1,2,3}
4183 ROOT %dot = f32[32,128] dot(%lhs, %rhs),
4184 lhs_contracting_dims={1}, rhs_contracting_dims={0},
4185 sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
4186 })";
4187
4188 TF_ASSERT_OK_AND_ASSIGN(auto module,
4189 PartitionComputation(hlo_string, /*num_devices=*/4));
4190 VLOG(1) << module->ToString();
4191
4192 const auto root = module->entry_computation()->root_instruction();
4193 const auto lhs = AllOf(op::Parameter(0), op::Shape("f32[32,16]"));
4194 const auto rhs = AllOf(op::Parameter(1), op::Shape("f32[16,128]"));
4195 EXPECT_THAT(root, AllOf(op::AllReduce(op::DynamicSlice(
4196 op::AllReduce(op::Dot(lhs, rhs)), _, _)),
4197 op::Shape("f32[16,128]")));
4198 }
4199
4200 TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) {
4201 absl::string_view hlo_string = R"(
4202 HloModule module
4203
4204 ENTRY entry {
4205 %lhs = f32[32,24,64,128] parameter(0)
4206 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3}
4207 %rhs = f32[32,39296,64] parameter(1)
4208 %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
4209 ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy),
4210 lhs_batch_dims={0}, rhs_batch_dims={0},
4211 lhs_contracting_dims={2}, rhs_contracting_dims={2},
4212 sharding={devices=[1,2,2,1]0,1,2,3}
4213 })";
4214
4215 TF_ASSERT_OK_AND_ASSIGN(auto module,
4216 PartitionComputation(hlo_string, /*num_devices=*/4));
4217 VLOG(1) << module->ToString();
4218
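  // The LHS non-contracting dimensions are sharded consistently with the
  // output, so the dot is computed locally against the replicated RHS.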
4219 const auto root = module->entry_computation()->root_instruction();
4220 const auto lhs = AllOf(
4221 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
4222 op::Constant(), op::Reshape())),
4223 op::Shape("f32[32,12,64,64]"));
4224 const auto rhs =
4225 AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]"));
4226 EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,12,64,39296]")));
4227 }
4228
4229 TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) {
4230 absl::string_view hlo_string = R"(
4231 HloModule module
4232
4233 ENTRY entry {
4234 %lhs = f32[32,24,64] parameter(0)
4235 %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
4236 %rhs = f32[32,39296,64,128] parameter(1)
4237 %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3}
4238 ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy),
4239 lhs_batch_dims={0}, rhs_batch_dims={0},
4240 lhs_contracting_dims={2}, rhs_contracting_dims={2},
4241 sharding={devices=[1,1,2,2]0,1,2,3}
4242 })";
4243
4244 TF_ASSERT_OK_AND_ASSIGN(auto module,
4245 PartitionComputation(hlo_string, /*num_devices=*/4));
4246 VLOG(1) << module->ToString();
4247
4248 const auto root = module->entry_computation()->root_instruction();
4249 const auto lhs =
4250 AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]"));
4251 const auto rhs = AllOf(
4252 op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
4253 op::Constant(), op::Reshape())),
4254 op::Shape("f32[32,19648,64,64]"));
4255 EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,24,19648,64]")));
4256 }
4257
4258 TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) {
4259 absl::string_view hlo_string = R"(
4260 HloModule module
4261
4262 ENTRY entry {
4263 %lhs = f32[32,24,64,128] parameter(0)
4264 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
4265 %rhs = f32[32,39296,64,128] parameter(1)
4266 %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
4267 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4268 lhs_batch_dims={0}, rhs_batch_dims={0},
4269 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4270 sharding={devices=[1,2,1]0,1}
4271 })";
4272
4273 TF_ASSERT_OK_AND_ASSIGN(auto module,
4274 PartitionComputation(hlo_string, /*num_devices=*/2));
4275 VLOG(1) << module->ToString();
4276
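  // Only the output is sharded (along the LHS non-contracting dimension), so
  // the replicated LHS is dynamic-sliced to that shard before the local dot.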
4277 const auto root = module->entry_computation()->root_instruction();
4278 const auto lhs =
4279 AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
4280 const auto rhs =
4281 AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
4282 EXPECT_THAT(
4283 root,
4284 AllOf(op::Dot(AllOf(op::DynamicSlice(lhs, op::Constant(), op::Reshape(),
4285 op::Constant(), op::Constant()),
4286 op::Shape("f32[32,12,64,128]")),
4287 rhs),
4288 op::Shape("f32[32,12,39296]")));
4289 }
4290
4291 TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) {
4292 absl::string_view hlo_string = R"(
4293 HloModule module
4294
4295 ENTRY entry {
4296 %lhs = f32[32,24,64,128] parameter(0)
4297 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
4298 %rhs = f32[32,39296,64,128] parameter(1)
4299 %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
4300 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
4301 lhs_batch_dims={0}, rhs_batch_dims={0},
4302 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4303 sharding={devices=[1,1,2]0,1}
4304 })";
4305
4306 TF_ASSERT_OK_AND_ASSIGN(auto module,
4307 PartitionComputation(hlo_string, /*num_devices=*/2));
4308 VLOG(1) << module->ToString();
4309
4310 const auto root = module->entry_computation()->root_instruction();
4311 const auto lhs =
4312 AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
4313 const auto rhs =
4314 AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
4315 EXPECT_THAT(root,
4316 AllOf(op::Dot(lhs, AllOf(op::DynamicSlice(
4317 rhs, op::Constant(), op::Reshape(),
4318 op::Constant(), op::Constant()),
4319 op::Shape("f32[32,19648,64,128]"))),
4320 op::Shape("f32[32,24,19648]")));
4321 }
4322
4323 TEST_F(SpmdPartitioningTest,
4324 EinsumRHSWindowedInContractingOutNonContractingPartitioned) {
4325 absl::string_view hlo_string = R"(
4326 HloModule module
4327
4328 ENTRY entry {
4329 %lhs = f32[320,25,64,128] parameter(0)
4330 %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
4331 %rhs = f32[320,39296,64,128] parameter(1)
4332 %rhs.copy = f32[320,39296,64,128] copy(%rhs),
4333 sharding={devices=[1,1,4,1]0,1,2,3}
4334 ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
4335 lhs_batch_dims={0}, rhs_batch_dims={0},
4336 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4337 sharding={devices=[1,4,1]0,1,2,3}
4338 })";
4339
4340 TF_ASSERT_OK_AND_ASSIGN(auto module,
4341 PartitionComputation(hlo_string, /*num_devices=*/4));
4342 VLOG(1) << module->ToString();
4343
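  // The operands are sharded on a contracting dimension while the output is
  // sharded on the LHS non-contracting dimension, so a windowed einsum loop is
  // generated: each iteration dynamic-slices the local LHS, accumulates a
  // partial dot into the result buffer, and conditionally collective-permutes
  // the accumulator between partitions.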
4344 const auto root = module->entry_computation()->root_instruction();
4345 const auto lhs = AllOf(
4346 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4347 op::Constant(), op::Reshape(), op::Constant())),
4348 op::Shape("f32[320,25,16,128]"));
4349 const auto rhs = AllOf(
4350 op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4351 op::Constant(), op::Reshape(), op::Constant())),
4352 op::Shape("f32[320,39296,16,128]"));
4353 EXPECT_THAT(
4354 root,
4355 AllOf(op::GetTupleElement(op::While(op::Tuple(
4356 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant()))),
4357 op::Shape("f32[320,7,39296]")));
4358
4359 const auto while_loop = root->operand(0);
4360 // Check loop condition.
4361 EXPECT_THAT(
4362 while_loop->while_condition()->root_instruction(),
4363 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4364
4365 // Check loop body.
4366 const auto next_i =
4367 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4368 auto ds =
4369 AllOf(op::DynamicSlice(
4370 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4371 op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
4372 op::Shape("f32[320,7,16,128]"));
4373 auto partial_output =
4374 AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
4375 op::Dot(ds, op::GetTupleElement(op::Parameter(0)))),
4376 op::Shape("f32[320,7,39296]"));
4377 auto window = op::Conditional(op::Compare(next_i, op::Constant()),
4378 partial_output, partial_output);
4379 EXPECT_THAT(while_loop->while_body()->root_instruction(),
4380 op::Tuple(op::GetTupleElement(op::Parameter(0)),
4381 op::GetTupleElement(op::Parameter(0)), window,
4382 op::GetTupleElement(op::Parameter(0)), next_i));
4383
4384 // Check the conditional that contains the collective permute.
4385 auto cp_conditional =
4386 while_loop->while_body()->root_instruction()->operand(2);
4387 EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
4388 op::CollectivePermute(op::Parameter(0)));
4389 EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
4390 op::Parameter(0));
4391 }
4392
4393 TEST_F(SpmdPartitioningTest,
4394 UnrolledEinsumRHSWindowedInContractingOutNonContractingPartitioned) {
4395 absl::string_view hlo_string = R"(
4396 HloModule module
4397
4398 ENTRY entry {
4399 %lhs = f32[320,25,64,128] parameter(0)
4400 %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
4401 %rhs = f32[320,39296,64,128] parameter(1)
4402 %rhs.copy = f32[320,39296,64,128] copy(%rhs),
4403 sharding={devices=[1,1,4,1]0,1,2,3}
4404 ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
4405 lhs_batch_dims={0}, rhs_batch_dims={0},
4406 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4407 sharding={devices=[1,4,1]0,1,2,3}
4408 })";
4409
4410 TF_ASSERT_OK_AND_ASSIGN(
4411 auto module,
4412 PartitionComputation(hlo_string, /*num_devices=*/4,
4413 /*conv_halo_exchange_always_on_lhs =*/true,
4414 /*choose_faster_windowed_einsum =*/false,
4415 /*unroll_windowed_einsum =*/true));
4416 VLOG(1) << module->ToString();
4417
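  // Same computation as above, but with the windowed loop unrolled by a factor
  // of two: each iteration produces two partial results, and the final combine
  // (adding the permuted and unpermuted accumulators) happens outside the
  // loop.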
4418 const auto lhs = AllOf(
4419 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4420 op::Constant(), op::Reshape(), op::Constant())),
4421 op::Shape("f32[320,25,16,128]"));
4422 const auto rhs = AllOf(
4423 op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4424 op::Constant(), op::Reshape(), op::Constant())),
4425 op::Shape("f32[320,39296,16,128]"));
4426 const auto while_op = AllOf(
4427 op::While(op::Tuple(lhs, rhs, op::Broadcast(), op::Broadcast(),
4428 op::Constant())),
4429 op::Shape("(f32[320,25,16,128], f32[320,39296,16,128], f32[320,7,39296],"
4430 " f32[320,7,39296], u32[])"));
4431 const auto root = module->entry_computation()->root_instruction();
4432 EXPECT_THAT(
4433 root, AllOf(op::Add(op::CollectivePermute(op::GetTupleElement(while_op)),
4434 op::GetTupleElement(while_op)),
4435 op::Shape("f32[320,7,39296]")));
4436
4437 const auto while_loop = root->operand(1)->operand(0);
4438 // Check loop condition.
4439 EXPECT_THAT(
4440 while_loop->while_condition()->root_instruction(),
4441 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4442
4443 // Check loop body.
4444 const auto next_i =
4445 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4446 auto ds =
4447 AllOf(op::DynamicSlice(
4448 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4449 op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
4450 op::Shape("f32[320,7,16,128]"));
4451 auto partial_output = AllOf(
4452 op::Add(op::CollectivePermute(op::GetTupleElement(op::Parameter(0))),
4453 op::Dot(ds, op::GetTupleElement(op::Parameter(0)))),
4454 op::Shape("f32[320,7,39296]"));
4455 auto partial_output2 =
4456 AllOf(op::CollectivePermute(
4457 op::Add(op::GetTupleElement(op::Parameter(0)),
4458 op::Dot(ds, op::GetTupleElement(op::Parameter(0))))),
4459 op::Shape("f32[320,7,39296]"));
4460
4461 EXPECT_THAT(while_loop->while_body()->root_instruction(),
4462 op::Tuple(op::GetTupleElement(op::Parameter(0)),
4463 op::GetTupleElement(op::Parameter(0)), partial_output,
4464 partial_output2, next_i));
4465 }
4466
4467 TEST_F(
4468 SpmdPartitioningTest,
4469 BidirectionalEinsumRHSWindowedInContractingOutNonContractingPartitioned) {
4470 absl::string_view hlo_string = R"(
4471 HloModule module
4472
4473 ENTRY entry {
4474 %lhs = f32[320,25,64,128] parameter(0)
4475 %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
4476 %rhs = f32[320,39296,64,128] parameter(1)
4477 %rhs.copy = f32[320,39296,64,128] copy(%rhs),
4478 sharding={devices=[1,1,4,1]0,1,2,3}
4479 ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
4480 lhs_batch_dims={0}, rhs_batch_dims={0},
4481 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4482 sharding={devices=[1,4,1]0,1,2,3}
4483 })";
4484
4485 TF_ASSERT_OK_AND_ASSIGN(
4486 auto module,
4487 PartitionComputation(hlo_string, /*num_devices=*/4,
4488 /*conv_halo_exchange_always_on_lhs =*/true,
4489 /*choose_faster_windowed_einsum =*/false,
4490 /*unroll_windowed_einsum =*/false,
4491 /*bidirectional_windowed_einsum =*/true));
4492 VLOG(1) << module->ToString();
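  // In the bidirectional variant partial results flow in both directions each
  // step, and the final output adds the accumulator to its collective-permuted
  // counterpart.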
4493 const auto lhs = AllOf(
4494 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4495 op::Constant(), op::Reshape(), op::Constant())),
4496 op::Shape("f32[320,25,16,128]"));
4497 const auto rhs = AllOf(
4498 op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4499 op::Constant(), op::Reshape(), op::Constant())),
4500 op::Shape("f32[320,39296,16,128]"));
4501 const auto while_op = AllOf(
4502 op::While(op::Tuple(lhs, rhs, op::Broadcast(), op::Broadcast(),
4503 op::Constant())),
4504 op::Shape("(f32[320,25,16,128], f32[320,39296,16,128], f32[320,7,39296],"
4505 " f32[320,7,39296], u32[])"));
4506 const auto root = module->entry_computation()->root_instruction();
4507 EXPECT_THAT(
4508 root, AllOf(op::Add(op::GetTupleElement(while_op),
4509 op::CollectivePermute(op::GetTupleElement(while_op))),
4510 op::Shape("f32[320,7,39296]")));
4511
4512 const auto while_loop = root->operand(0)->operand(0);
4513 // Check loop condition.
4514 EXPECT_THAT(
4515 while_loop->while_condition()->root_instruction(),
4516 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4517
4518 // Check loop body.
4519 const auto next_i =
4520 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4521 op::Constant());
4522 const auto partial_dot_pattern =
4523 AllOf(op::Reshape(op::Slice(
4524 op::Dot(op::Maximum(), op::GetTupleElement(op::Parameter(0))))),
4525 op::Shape("f32[320,7,39296]"));
4526 const auto partial_output_pattern = AllOf(
4527 op::Add(op::CollectivePermute(op::Add(
4528 op::CollectivePermute(op::GetTupleElement(op::Parameter(0))),
4529 partial_dot_pattern)),
4530 partial_dot_pattern),
4531 op::Shape("f32[320,7,39296]"));
4532
4533 EXPECT_THAT(
4534 while_loop->while_body()->root_instruction(),
4535 op::Tuple(op::GetTupleElement(op::Parameter(0)),
4536 op::GetTupleElement(op::Parameter(0)), partial_output_pattern,
4537 partial_output_pattern, next_i));
4538 }
4539
4540 TEST_F(SpmdPartitioningTest,
4541 EinsumRHSWindowedInContractingOutNonContractingFromBroadcast) {
4542 absl::string_view hlo_string = R"(
4543 HloModule module
4544
4545 ENTRY entry {
4546 %constant.1 = f32[] constant(2)
4547 %broadcast = f32[32,25,64,128] broadcast(%constant.1), dimensions={},
4548 sharding={devices=[1,1,4,1]0,1,2,3}
4549 %add = f32[32,25,64,128] add(%broadcast, %broadcast),
4550 sharding={devices=[1,1,4,1]0,1,2,3}
4551 %rhs = f32[32,39296,64,128] parameter(0)
4552 %rhs.copy = f32[32,39296,64,128] copy(%rhs),
4553 sharding={devices=[1,1,4,1]0,1,2,3}
4554 ROOT %dot = f32[32,25,39296] dot(%add, %rhs.copy),
4555 lhs_batch_dims={0}, rhs_batch_dims={0},
4556 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4557 sharding={devices=[1,4,1]0,1,2,3}
4558 })";
4559
4560 TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
4561 /*num_devices=*/4));
4562 VLOG(1) << module->ToString();
4563   // This case involves loop code motion, so pattern matching is skipped.
4564 }
4565
4566 TEST_F(SpmdPartitioningTest,
4567 EinsumLHSWindowedInContractingOutNonContractingPartitioned) {
4568 absl::string_view hlo_string = R"(
4569 HloModule module
4570
4571 ENTRY entry {
4572 %lhs = f32[16,1024,16384] parameter(0)
4573 %lhs.copy = f32[16,1024,16384] copy(%lhs),
4574 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4575 %rhs = f32[16384,67,128] parameter(1)
4576 %rhs.copy = f32[16384,67,128] copy(%rhs),
4577 sharding={devices=[4,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4578 ROOT %dot = f32[16,1024,67,128] dot(%lhs.copy, %rhs.copy),
4579 lhs_batch_dims={}, rhs_batch_dims={},
4580 lhs_contracting_dims={2}, rhs_contracting_dims={0},
4581 sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
4582 })";
4583
4584 TF_ASSERT_OK_AND_ASSIGN(auto module,
4585 PartitionComputation(hlo_string, /*num_devices=*/8));
4586 VLOG(1) << module->ToString();
4587
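  // Mirror image of the RHS-windowed case: the loop dynamic-slices the local
  // RHS, accumulates partial dots into the result buffer, and conditionally
  // collective-permutes the accumulator between partitions.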
4588 const auto root = module->entry_computation()->root_instruction();
4589 const auto lhs =
4590 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4591 op::Constant(), op::Reshape())),
4592 op::Shape("f32[8,1024,4096]"));
4593 const auto rhs =
4594 AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4595 op::Constant(), op::Constant())),
4596 op::Shape("f32[4096,67,128]"));
4597 EXPECT_THAT(
4598 root,
4599 AllOf(op::GetTupleElement(op::While(op::Tuple(
4600 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant()))),
4601 op::Shape("f32[8,1024,17,128]")));
4602
4603 const auto while_loop = root->operand(0);
4604 // Check loop condition.
4605 EXPECT_THAT(
4606 while_loop->while_condition()->root_instruction(),
4607 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4608
4609 // Check loop body.
4610 const auto next_i =
4611 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4612 auto ds =
4613 AllOf(op::DynamicSlice(
4614 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4615 op::Constant(), op::Reshape(), op::Constant()),
4616 op::Shape("f32[4096,17,128]"));
4617 auto partial_output =
4618 AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
4619 op::Dot(op::GetTupleElement(op::Parameter(0)), ds)),
4620 op::Shape("f32[8,1024,17,128]"));
4621 auto window = op::Conditional(op::Compare(next_i, op::Constant()),
4622 partial_output, partial_output);
4623 EXPECT_THAT(while_loop->while_body()->root_instruction(),
4624 op::Tuple(op::GetTupleElement(op::Parameter(0)),
4625 op::GetTupleElement(op::Parameter(0)), window,
4626 op::GetTupleElement(op::Parameter(0)), next_i));
4627
4628 // Check the conditional that contains the collective permute.
4629 auto cp_conditional =
4630 while_loop->while_body()->root_instruction()->operand(2);
4631 EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
4632 op::CollectivePermute(op::Parameter(0)));
4633 EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
4634 op::Parameter(0));
4635 }
4636
4637 TEST_F(SpmdPartitioningTest,
4638 UnrollEinsumLHSWindowedInContractingOutNonContractingPartitioned) {
4639 absl::string_view hlo_string = R"(
4640 HloModule module
4641
4642 ENTRY entry {
4643 %lhs = f32[16,1024,16384] parameter(0)
4644 %lhs.copy = f32[16,1024,16384] copy(%lhs),
4645 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4646 %rhs = f32[16384,67,128] parameter(1)
4647 %rhs.copy = f32[16384,67,128] copy(%rhs),
4648 sharding={devices=[4,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4649 ROOT %dot = f32[16,1024,67,128] dot(%lhs.copy, %rhs.copy),
4650 lhs_batch_dims={}, rhs_batch_dims={},
4651 lhs_contracting_dims={2}, rhs_contracting_dims={0},
4652 sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
4653 })";
4654
4655 TF_ASSERT_OK_AND_ASSIGN(
4656 auto module,
4657 PartitionComputation(hlo_string, /*num_devices=*/8,
4658 /*conv_halo_exchange_always_on_lhs =*/true,
4659 /*choose_faster_windowed_einsum =*/false,
4660 /*unroll_windowed_einsum =*/true));
4661 VLOG(1) << module->ToString();
4662
4663 const auto lhs =
4664 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4665 op::Constant(), op::Reshape())),
4666 op::Shape("f32[8,1024,4096]"));
4667 const auto rhs =
4668 AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4669 op::Constant(), op::Constant())),
4670 op::Shape("f32[4096,67,128]"));
4671 const auto while_op =
4672 AllOf(op::While(op::Tuple(lhs, rhs, op::Broadcast(), op::Broadcast(),
4673 op::Constant())),
4674 op::Shape("(f32[8,1024,4096], f32[4096,67,128], f32[8,1024,17,128],"
4675 " f32[8,1024,17,128], u32[])"));
4676 const auto root = module->entry_computation()->root_instruction();
4677 EXPECT_THAT(
4678 root, AllOf(op::Add(op::CollectivePermute(op::GetTupleElement(while_op)),
4679 op::GetTupleElement(while_op)),
4680 op::Shape("f32[8,1024,17,128]")));
4681
4682 const auto while_loop = root->operand(1)->operand(0);
4683 // Check loop condition.
4684 EXPECT_THAT(
4685 while_loop->while_condition()->root_instruction(),
4686 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4687
4688 // Check loop body.
4689 const auto next_i =
4690 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4691 auto ds =
4692 AllOf(op::DynamicSlice(
4693 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4694 op::Constant(), op::Reshape(), op::Constant()),
4695 op::Shape("f32[4096,17,128]"));
4696 auto partial_output = AllOf(
4697 op::Add(op::CollectivePermute(op::GetTupleElement(op::Parameter(0))),
4698 op::Dot(op::GetTupleElement(op::Parameter(0)), ds)),
4699 op::Shape("f32[8,1024,17,128]"));
4700 auto partial_output2 =
4701 AllOf(op::CollectivePermute(
4702 op::Add(op::GetTupleElement(op::Parameter(0)),
4703 op::Dot(op::GetTupleElement(op::Parameter(0)), ds))),
4704 op::Shape("f32[8,1024,17,128]"));
4705
4706 EXPECT_THAT(while_loop->while_body()->root_instruction(),
4707 op::Tuple(op::GetTupleElement(op::Parameter(0)),
4708 op::GetTupleElement(op::Parameter(0)), partial_output,
4709 partial_output2, next_i));
4710 }
4711
4712 TEST_F(
4713 SpmdPartitioningTest,
4714 BidirectionalEinsumLHSWindowedInContractingOutNonContractingPartitioned) {
4715 absl::string_view hlo_string = R"(
4716 HloModule module
4717
4718 ENTRY entry {
4719 %lhs = f32[16,1024,16384] parameter(0)
4720 %lhs.copy = f32[16,1024,16384] copy(%lhs),
4721 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4722 %rhs = f32[16384,67,128] parameter(1)
4723 %rhs.copy = f32[16384,67,128] copy(%rhs),
4724 sharding={devices=[4,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4725 ROOT %dot = f32[16,1024,67,128] dot(%lhs.copy, %rhs.copy),
4726 lhs_batch_dims={}, rhs_batch_dims={},
4727 lhs_contracting_dims={2}, rhs_contracting_dims={0},
4728 sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
4729 })";
4730
4731 TF_ASSERT_OK_AND_ASSIGN(
4732 auto module,
4733 PartitionComputation(hlo_string, /*num_devices=*/8,
4734 /*conv_halo_exchange_always_on_lhs =*/true,
4735 /*choose_faster_windowed_einsum =*/false,
4736 /*unroll_windowed_einsum =*/false,
4737 /*bidirectional_windowed_einsum =*/true));
4738 VLOG(1) << module->ToString();
4739
4740 const auto lhs =
4741 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4742 op::Constant(), op::Reshape())),
4743 op::Shape("f32[8,1024,4096]"));
4744 const auto rhs =
4745 AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
4746 op::Constant(), op::Constant())),
4747 op::Shape("f32[4096,67,128]"));
4748 const auto while_op =
4749 AllOf(op::While(op::Tuple(lhs, rhs, op::Broadcast(), op::Broadcast(),
4750 op::Constant())),
4751 op::Shape("(f32[8,1024,4096], f32[4096,67,128], f32[8,1024,17,128],"
4752 " f32[8,1024,17,128], u32[])"));
4753 const auto root = module->entry_computation()->root_instruction();
4754 EXPECT_THAT(
4755 root, AllOf(op::Add(op::GetTupleElement(while_op),
4756 op::CollectivePermute(op::GetTupleElement(while_op))),
4757 op::Shape("f32[8,1024,17,128]")));
4758
4759 const auto while_loop = root->operand(0)->operand(0);
4760 // Check loop condition.
4761 EXPECT_THAT(
4762 while_loop->while_condition()->root_instruction(),
4763 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4764
4765 // Check loop body.
4766 const auto next_i =
4767 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4768 op::Constant());
4769 const auto partial_dot_pattern =
4770 AllOf(op::Reshape(op::Slice(
4771 op::Dot(op::GetTupleElement(op::Parameter(0)), op::Maximum()))),
4772 op::Shape("f32[8,1024,17,128]"));
4773 const auto partial_output_pattern = AllOf(
4774 op::Add(op::CollectivePermute(op::Add(
4775 op::CollectivePermute(op::GetTupleElement(op::Parameter(0))),
4776 partial_dot_pattern)),
4777 partial_dot_pattern),
4778 op::Shape("f32[8,1024,17,128]"));
4779
4780 EXPECT_THAT(
4781 while_loop->while_body()->root_instruction(),
4782 op::Tuple(op::GetTupleElement(op::Parameter(0)),
4783 op::GetTupleElement(op::Parameter(0)), partial_output_pattern,
4784 partial_output_pattern, next_i));
4785 }
4786
4787 TEST_F(SpmdPartitioningTest,
4788 EinsumLHSWindowedInContractingOutNonContractingPartitioned2) {
4789 absl::string_view hlo_string = R"(
4790 HloModule module
4791
4792 ENTRY entry {
4793 %lhs = f32[16,1024,16384] parameter(0)
4794 %lhs.copy = f32[16,1024,16384] copy(%lhs),
4795 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4796 %rhs = f32[16384,2,33,128] parameter(1)
4797 %rhs.copy = f32[16384,2,33,128] copy(%rhs),
4798 sharding={devices=[4,1,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4799 ROOT %dot = f32[16,1024,2,33,128] dot(%lhs.copy, %rhs.copy),
4800 lhs_batch_dims={}, rhs_batch_dims={},
4801 lhs_contracting_dims={2}, rhs_contracting_dims={0},
4802 sharding={devices=[2,1,2,2,1]0,1,2,3,4,5,6,7}
4803 })";
4804
4805 TF_ASSERT_OK_AND_ASSIGN(auto module,
4806 PartitionComputation(hlo_string, /*num_devices=*/8));
4807 VLOG(1) << module->ToString();
4808
4809 const auto root = module->entry_computation()->root_instruction();
4810 const auto lhs =
4811 AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4812 op::Constant(), op::Reshape())),
4813 op::Shape("f32[8,1024,4096]"));
4814 const auto rhs = AllOf(
4815 op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
4816 op::Constant(), op::Constant())),
4817 op::Shape("f32[4096,2,33,128]"));
4818 EXPECT_THAT(
4819 root,
4820 AllOf(op::GetTupleElement(op::While(op::Tuple(
4821 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant()))),
4822 op::Shape("f32[8,1024,1,17,128]")));
4823
4824 const auto while_loop = root->operand(0);
4825 // Check loop condition.
4826 EXPECT_THAT(
4827 while_loop->while_condition()->root_instruction(),
4828 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4829
4830 // Check loop body.
4831 const auto next_i =
4832 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4833 auto ds =
4834 AllOf(op::DynamicSlice(
4835 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4836 op::Constant(), op::Reshape(), op::Reshape(), op::Constant()),
4837 op::Shape("f32[4096,1,17,128]"));
4838 auto partial_output =
4839 AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
4840 op::Dot(op::GetTupleElement(op::Parameter(0)), ds)),
4841 op::Shape("f32[8,1024,1,17,128]"));
4842 auto window = op::Conditional(op::Compare(next_i, op::Constant()),
4843 partial_output, partial_output);
4844 EXPECT_THAT(while_loop->while_body()->root_instruction(),
4845 op::Tuple(op::GetTupleElement(op::Parameter(0)),
4846 op::GetTupleElement(op::Parameter(0)), window,
4847 op::GetTupleElement(op::Parameter(0)), next_i));
4848
4849 // Check the conditional that contains the collective permute.
4850 auto cp_conditional =
4851 while_loop->while_body()->root_instruction()->operand(2);
4852 EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
4853 op::CollectivePermute(op::Parameter(0)));
4854 EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
4855 op::Parameter(0));
4856 }
4857
4858 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingNoDoubleAG) {
4859 absl::string_view hlo_string = R"(
4860 HloModule module
4861
4862 ENTRY entry {
4863 %lhs = f32[32,24,64,128] parameter(0)
4864 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
4865 %lhs2 = f32[32,24,64,128] parameter(2)
4866 %lhs2.copy = f32[32,24,64,128] copy(%lhs2), sharding={devices=[1,2,1,1]0,1}
4867 %rhs = f32[32,39295,64,128] parameter(1)
4868 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
4869 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
4870 lhs_batch_dims={0}, rhs_batch_dims={0},
4871 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4872 sharding={devices=[1,2,1]0,1}
4873 %dot2 = f32[32,24,39295] dot(%lhs2.copy, %rhs.copy),
4874 lhs_batch_dims={0}, rhs_batch_dims={0},
4875 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4876 sharding={devices=[1,2,1]0,1}
4877 ROOT %t = tuple(%dot, %dot2)
4878 })";
4879
4880 TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
4881 /*num_devices=*/2));
4882 VLOG(1) << module->ToString();
4883 const auto root = module->entry_computation()->root_instruction();
4884 EXPECT_THAT(root, op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
4885 _, op::Dot(_, op::Slice(_)), _, _, _)),
4886 op::AllReduce(op::DynamicUpdateSlice(
4887 _, op::Dot(_, op::Slice(_)), _, _, _))));
4888 }
4889
4890 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingNoSharedSharding) {
4891 absl::string_view hlo_string = R"(
4892 HloModule module
4893
4894 ENTRY entry {
4895 %lhs = f32[32,24,64,128] parameter(0)
4896 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
4897 %lhs2 = f32[32,24,64,128] parameter(2)
4898 %lhs2.copy = f32[32,24,64,128] copy(%lhs2), sharding={devices=[1,1,2,1]0,1}
4899 %rhs = f32[32,39295,64,128] parameter(1)
4900 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
4901 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
4902 lhs_batch_dims={0}, rhs_batch_dims={0},
4903 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4904 sharding={devices=[1,2,1]0,1}
4905 %dot2 = f32[32,24,39295] dot(%lhs2.copy, %rhs.copy),
4906 lhs_batch_dims={0}, rhs_batch_dims={0},
4907 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4908 sharding={devices=[2,1,1]0,1}
4909 ROOT %t = tuple(%dot, %dot2)
4910 })";
4911
4912 TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
4913 /*num_devices=*/2));
4914 VLOG(1) << module->ToString();
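  // Only the first dot is turned into a windowed-einsum while loop; the second
  // dot's LHS uses an incompatible sharding, so it is partitioned without a
  // loop.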
4915 const auto root = module->entry_computation()->root_instruction();
4916 EXPECT_THAT(
4917 root,
4918 op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
4919 _, op::Slice(op::GetTupleElement(op::While(_))), _, _, _)),
4920 op::AllReduce(op::DynamicUpdateSlice(
4921 _, op::Dot(_, op::Slice(_)), _, _, _))));
4922 }
4923
4924 TEST_F(SpmdPartitioningTest,
4925 UnrollEinsumRHSWindowedNonContractingNoSharedSharding) {
4926 absl::string_view hlo_string = R"(
4927 HloModule module
4928
4929 ENTRY entry {
4930 %lhs = f32[32,24,64,128] parameter(0)
4931 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
4932 %lhs2 = f32[32,24,64,128] parameter(2)
4933 %lhs2.copy = f32[32,24,64,128] copy(%lhs2), sharding={devices=[1,1,2,1]0,1}
4934 %rhs = f32[32,39295,64,128] parameter(1)
4935 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
4936 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
4937 lhs_batch_dims={0}, rhs_batch_dims={0},
4938 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4939 sharding={devices=[1,2,1]0,1}
4940 %dot2 = f32[32,24,39295] dot(%lhs2.copy, %rhs.copy),
4941 lhs_batch_dims={0}, rhs_batch_dims={0},
4942 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4943 sharding={devices=[2,1,1]0,1}
4944 ROOT %t = tuple(%dot, %dot2)
4945 })";
4946
4947 TF_ASSERT_OK_AND_ASSIGN(
4948 auto module,
4949 PartitionComputation(hlo_string, /*num_devices=*/2,
4950 /*conv_halo_exchange_always_on_lhs =*/true,
4951 /*choose_faster_windowed_einsum =*/false,
4952 /*unroll_windowed_einsum =*/true));
4953 VLOG(1) << module->ToString();
4954 const auto root = module->entry_computation()->root_instruction();
4955 EXPECT_THAT(
4956 root,
4957 op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
4958 _, op::Slice(op::GetTupleElement(op::While(_))), _, _, _)),
4959 op::AllReduce(op::DynamicUpdateSlice(
4960 _, op::Dot(_, op::Slice(_)), _, _, _))));
4961
4962 // Tuple<-AllReduce<-DynamicUpdateSlice<-Slice<-GetTupleElement<-While
4963 const auto while_loop =
4964 root->operand(0)->operand(0)->operand(1)->operand(0)->operand(0);
4965 // Check loop condition.
4966 EXPECT_THAT(
4967 while_loop->while_condition()->root_instruction(),
4968 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4969
4970 // Check loop body.
4971 const auto next_i =
4972 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4973 op::Constant());
4974 auto intermediate_output = AllOf(
4975 op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
4976 op::Dot(op::GetTupleElement(op::Parameter(0)),
4977 op::GetTupleElement(op::Parameter(0))),
4978 op::Constant(), op::Constant(), op::Reshape()),
4979 op::Shape("f32[32,12,39296]"));
4980 auto output = AllOf(
4981 op::DynamicUpdateSlice(
4982 intermediate_output,
4983 op::Dot(op::GetTupleElement(op::Parameter(0)),
4984 op::CollectivePermute(op::GetTupleElement(op::Parameter(0)))),
4985 op::Constant(), op::Constant(), op::Reshape()),
4986 op::Shape("f32[32,12,39296]"));
4987
4988 EXPECT_THAT(while_loop->while_body()->root_instruction(),
4989 op::Tuple(op::GetTupleElement(op::Parameter(0)),
4990 op::CollectivePermute(op::CollectivePermute(
4991 op::GetTupleElement(op::Parameter(0)))),
4992 output, op::GetTupleElement(op::Parameter(0)), next_i));
4993 }
4994
4995 TEST_F(SpmdPartitioningTest,
4996 BidirectionalEinsumRHSWindowedNonContractingNoSharedSharding) {
4997 absl::string_view hlo_string = R"(
4998 HloModule module
4999
5000 ENTRY entry {
5001 %lhs = f32[32,24,64,128] parameter(0)
5002 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
5003 %lhs2 = f32[32,24,64,128] parameter(2)
5004 %lhs2.copy = f32[32,24,64,128] copy(%lhs2), sharding={devices=[1,1,4,1]0,1,2,3}
5005 %rhs = f32[32,39295,64,128] parameter(1)
5006 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
5007 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5008 lhs_batch_dims={0}, rhs_batch_dims={0},
5009 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5010 sharding={devices=[1,4,1]0,1,2,3}
5011 %dot2 = f32[32,24,39295] dot(%lhs2.copy, %rhs.copy),
5012 lhs_batch_dims={0}, rhs_batch_dims={0},
5013 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5014 sharding={devices=[4,1,1]0,1,2,3}
5015 ROOT %t = tuple(%dot, %dot2)
5016 })";
5017
5018 TF_ASSERT_OK_AND_ASSIGN(
5019 auto module,
5020 PartitionComputation(hlo_string, /*num_devices=*/4,
5021 /*conv_halo_exchange_always_on_lhs =*/true,
5022 /*choose_faster_windowed_einsum =*/false,
5023 /*unroll_windowed_einsum =*/false,
5024 /*bidirectional_windowed_einsum =*/true));
5025 VLOG(1) << module->ToString();
5026 const auto root = module->entry_computation()->root_instruction();
5027 EXPECT_THAT(
5028 root,
5029 op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
5030 _, op::Slice(op::GetTupleElement(op::While(_))), _, _, _)),
5031 op::AllReduce(op::DynamicUpdateSlice(
5032 _, op::Dot(_, op::Slice(_)), _, _, _))));
5033
5034 // Tuple<-AllReduce<-DynamicUpdateSlice<-Slice<-GetTupleElement<-While
5035 const auto while_loop =
5036 root->operand(0)->operand(0)->operand(1)->operand(0)->operand(0);
5037 // Check loop condition.
5038 EXPECT_THAT(
5039 while_loop->while_condition()->root_instruction(),
5040 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5041
5042 // Check loop body.
5043 const auto next_i =
5044 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5045 op::Constant());
5046 const auto partial_dot_pattern =
5047 AllOf(op::Reshape(op::Slice(op::Dot(op::GetTupleElement(op::Parameter(0)),
5048 op::Concatenate()))),
5049 op::Shape("f32[32,6,9824]"));
5050 auto intermediate_output1 =
5051 AllOf(op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
5052 partial_dot_pattern, op::Constant(),
5053 op::Constant(), op::Reshape()),
5054 op::Shape("f32[32,6,39296]"));
5055 auto intermediate_output2 = AllOf(
5056 op::DynamicUpdateSlice(intermediate_output1, partial_dot_pattern,
5057 op::Constant(), op::Constant(), op::Reshape()),
5058 op::Shape("f32[32,6,39296]"));
5059 auto intermediate_output3 = AllOf(
5060 op::DynamicUpdateSlice(intermediate_output2, partial_dot_pattern,
5061 op::Constant(), op::Constant(), op::Reshape()),
5062 op::Shape("f32[32,6,39296]"));
5063 auto partial_output = AllOf(
5064 op::DynamicUpdateSlice(intermediate_output3, partial_dot_pattern,
5065 op::Constant(), op::Constant(), op::Reshape()),
5066 op::Shape("f32[32,6,39296]"));
5067
5068 EXPECT_THAT(while_loop->while_body()->root_instruction(),
5069 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5070 op::CollectivePermute(op::CollectivePermute(
5071 op::GetTupleElement(op::Parameter(0)))),
5072 partial_output,
5073 op::CollectivePermute(op::CollectivePermute(
5074 op::GetTupleElement(op::Parameter(0)))),
5075 next_i));
5076 }
5077
5078 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) {
5079 absl::string_view hlo_string = R"(
5080 HloModule module
5081
5082 ENTRY entry {
5083 %lhs = f32[32,24,64,128] parameter(0)
5084 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5085 %rhs = f32[32,39295,64,128] parameter(1)
5086 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5087 ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5088 lhs_batch_dims={0}, rhs_batch_dims={0},
5089 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5090 sharding={devices=[1,2,1]0,1}
5091 })";
5092
5093 TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
5094 /*num_devices=*/2));
5095 VLOG(1) << module->ToString();
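  // The RHS is sharded on its non-contracting dimension (39295, padded to
  // 39296) while the output is sharded on the LHS non-contracting dimension,
  // so the loop rotates the RHS shards via collective-permute and writes each
  // partial dot into the output with a dynamic-update-slice; the final result
  // is sliced back to 39295.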
5096 const auto root = module->entry_computation()->root_instruction();
5097 const auto lhs = AllOf(
5098 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5099 op::Constant(), op::Constant())),
5100 op::Shape("f32[32,12,64,128]"));
5101 const auto rhs =
5102 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5103 op::Constant(), op::Reshape(),
5104 op::Constant(), op::Constant())),
5105 op::Shape("f32[32,19648,64,128]"));
5106 EXPECT_THAT(root,
5107 AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple(
5108 lhs, rhs, op::Broadcast(),
5109 op::Broadcast(), op::Constant()))),
5110 op::Shape("f32[32,12,39296]"))),
5111 op::Shape("f32[32,12,39295]")));
5112 const auto while_loop = root->operand(0)->operand(0);
5113 // Check loop condition.
5114 EXPECT_THAT(
5115 while_loop->while_condition()->root_instruction(),
5116 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5117
5118 // Check loop body.
5119 const auto next_i =
5120 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
5121 auto window = op::Conditional(op::Compare(next_i, op::Constant()),
5122 op::GetTupleElement(op::Parameter(0)),
5123 op::GetTupleElement(op::Parameter(0)));
5124 auto partial_output = op::Dot(op::GetTupleElement(op::Parameter(0)),
5125 op::GetTupleElement(op::Parameter(0)));
5126 EXPECT_THAT(
5127 while_loop->while_body()->root_instruction(),
5128 op::Tuple(op::GetTupleElement(op::Parameter(0)), window,
5129 op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
5130 partial_output, op::Constant(),
5131 op::Constant(), op::Reshape()),
5132 op::GetTupleElement(op::Parameter(0)), next_i));
5133
5134 // Check the conditional that contains the collective permute.
5135 auto cp_conditional =
5136 while_loop->while_body()->root_instruction()->operand(1);
5137 EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
5138 op::CollectivePermute(op::Parameter(0)));
5139 EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
5140 op::Parameter(0));
5141 }
5142
5143 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedNonContracting) {
5144 absl::string_view hlo_string = R"(
5145 HloModule module
5146
5147 ENTRY entry {
5148 %lhs = f32[32,24,64,128] parameter(0)
5149 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5150 %rhs = f32[32,39295,64,128] parameter(1)
5151 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5152 ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5153 lhs_batch_dims={0}, rhs_batch_dims={0},
5154 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5155 sharding={devices=[1,2,1]0,1}
5156 })";
5157
5158 TF_ASSERT_OK_AND_ASSIGN(
5159 auto module,
5160 PartitionComputation(hlo_string, /*num_devices=*/2,
5161 /*conv_halo_exchange_always_on_lhs =*/true,
5162 /*choose_faster_windowed_einsum =*/false,
5163 /*unroll_windowed_einsum =*/true));
5164 VLOG(1) << module->ToString();
5165
5166 const auto root = module->entry_computation()->root_instruction();
5167 const auto lhs = AllOf(
5168 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5169 op::Constant(), op::Constant())),
5170 op::Shape("f32[32,12,64,128]"));
5171 const auto rhs =
5172 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5173 op::Constant(), op::Reshape(),
5174 op::Constant(), op::Constant())),
5175 op::Shape("f32[32,19648,64,128]"));
5176 EXPECT_THAT(root,
5177 AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple(
5178 lhs, rhs, op::Broadcast(),
5179 op::Broadcast(), op::Constant()))),
5180 op::Shape("f32[32,12,39296]"))),
5181 op::Shape("f32[32,12,39295]")));
5182
5183 const auto while_loop = root->operand(0)->operand(0);
5184 // Check loop condition.
5185 EXPECT_THAT(
5186 while_loop->while_condition()->root_instruction(),
5187 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5188
5189 // Check loop body.
5190 const auto next_i =
5191 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5192 op::Constant());
5193 auto intermediate_output = AllOf(
5194 op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
5195 op::Dot(op::GetTupleElement(op::Parameter(0)),
5196 op::GetTupleElement(op::Parameter(0))),
5197 op::Constant(), op::Constant(), op::Reshape()),
5198 op::Shape("f32[32,12,39296]"));
5199 auto output = AllOf(
5200 op::DynamicUpdateSlice(
5201 intermediate_output,
5202 op::Dot(op::GetTupleElement(op::Parameter(0)),
5203 op::CollectivePermute(op::GetTupleElement(op::Parameter(0)))),
5204 op::Constant(), op::Constant(), op::Reshape()),
5205 op::Shape("f32[32,12,39296]"));
5206
5207 EXPECT_THAT(while_loop->while_body()->root_instruction(),
5208 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5209 op::CollectivePermute(op::CollectivePermute(
5210 op::GetTupleElement(op::Parameter(0)))),
5211 output, op::GetTupleElement(op::Parameter(0)), next_i));
5212 }
5213
5214 TEST_F(SpmdPartitioningTest, BidirectionalEinsumRHSWindowedNonContracting) {
5215 absl::string_view hlo_string = R"(
5216 HloModule module
5217
5218 ENTRY entry {
5219 %lhs = f32[32,24,64,128] parameter(0)
5220 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
5221 %rhs = f32[32,39295,64,128] parameter(1)
5222 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
5223 ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5224 lhs_batch_dims={0}, rhs_batch_dims={0},
5225 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5226 sharding={devices=[1,4,1]0,1,2,3}
5227 })";
5228
5229 TF_ASSERT_OK_AND_ASSIGN(
5230 auto module,
5231 PartitionComputation(hlo_string, /*num_devices=*/4,
5232 /*conv_halo_exchange_always_on_lhs =*/true,
5233 /*choose_faster_windowed_einsum =*/false,
5234 /*unroll_windowed_einsum =*/false,
5235 /*bidirectional_windowed_einsum =*/true));
5236 VLOG(1) << module->ToString();
5237
5238 const auto root = module->entry_computation()->root_instruction();
5239 const auto lhs = AllOf(
5240 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5241 op::Constant(), op::Constant())),
5242 op::Shape("f32[32,6,64,128]"));
5243 const auto rhs =
5244 AllOf(op::Reshape(op::Copy(op::DynamicSlice(
5245 op::Pad(op::Parameter(1), op::Constant()), op::Constant(),
5246 op::Reshape(), op::Constant(), op::Constant()))),
5247 op::Shape("f32[32,1,9824,64,128]"));
5248 EXPECT_THAT(
5249 root,
5250 AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple(
5251 lhs, rhs, op::Broadcast(),
5252 op::CollectivePermute(rhs), op::Constant()))),
5253 op::Shape("f32[32,6,39296]"))),
5254 op::Shape("f32[32,6,39295]")));
5255
5256 const auto while_loop = root->operand(0)->operand(0);
5257 // Check loop condition.
5258 EXPECT_THAT(
5259 while_loop->while_condition()->root_instruction(),
5260 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5261
5262 // Check loop body.
5263 const auto next_i =
5264 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5265 op::Constant());
5266 const auto partial_dot_pattern =
5267 AllOf(op::Reshape(op::Slice(op::Dot(op::GetTupleElement(op::Parameter(0)),
5268 op::Concatenate()))),
5269 op::Shape("f32[32,6,9824]"));
5270 auto intermediate_output1 =
5271 AllOf(op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)),
5272 partial_dot_pattern, op::Constant(),
5273 op::Constant(), op::Reshape()),
5274 op::Shape("f32[32,6,39296]"));
5275 auto intermediate_output2 = AllOf(
5276 op::DynamicUpdateSlice(intermediate_output1, partial_dot_pattern,
5277 op::Constant(), op::Constant(), op::Reshape()),
5278 op::Shape("f32[32,6,39296]"));
5279 auto intermediate_output3 = AllOf(
5280 op::DynamicUpdateSlice(intermediate_output2, partial_dot_pattern,
5281 op::Constant(), op::Constant(), op::Reshape()),
5282 op::Shape("f32[32,6,39296]"));
5283 auto partial_output = AllOf(
5284 op::DynamicUpdateSlice(intermediate_output3, partial_dot_pattern,
5285 op::Constant(), op::Constant(), op::Reshape()),
5286 op::Shape("f32[32,6,39296]"));
5287
5288 EXPECT_THAT(while_loop->while_body()->root_instruction(),
5289 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5290 op::CollectivePermute(op::CollectivePermute(
5291 op::GetTupleElement(op::Parameter(0)))),
5292 partial_output,
5293 op::CollectivePermute(op::CollectivePermute(
5294 op::GetTupleElement(op::Parameter(0)))),
5295 next_i));
5296 }
5297
5298 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) {
5299 absl::string_view hlo_string = R"(
5300 HloModule module
5301
5302 ENTRY entry {
5303 %lhs = f32[32,24,63,128] parameter(0)
5304 %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5305 %rhs = f32[32,39296,63,128] parameter(1)
5306 %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
5307 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
5308 lhs_batch_dims={0}, rhs_batch_dims={0},
5309 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5310 sharding={devices=[1,2,1]0,1}
5311 })";
5312
5313 TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
5314 /*num_devices=*/2));
5315 VLOG(1) << module->ToString();
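  // The contracting dimension (63) is padded to 64 and split across the
  // partitions; the padded element is masked to zero with a select so that it
  // does not contribute to the accumulated partial dots.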
5316 const auto root = module->entry_computation()->root_instruction();
5317 const auto lhs = AllOf(
5318 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5319 op::Constant(), op::Constant())),
5320 op::Shape("f32[32,12,63,128]"));
5321 const auto rhs =
5322 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5323 op::Constant(), op::Constant(),
5324 op::Reshape(), op::Constant())),
5325 op::Shape("f32[32,39296,32,128]"));
5326 auto masked_rhs =
5327 op::Select(op::Compare(), rhs, op::Broadcast(op::Constant()));
5328 EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(
5329 op::Tuple(lhs, masked_rhs, op::Broadcast(),
5330 op::Broadcast(), op::Constant()))),
5331 op::Shape("f32[32,12,39296]")));
5332 const auto while_loop = root->operand(0);
5333 // Check loop condition.
5334 EXPECT_THAT(
5335 while_loop->while_condition()->root_instruction(),
5336 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5337
5338 // Check loop body.
5339 const auto next_i =
5340 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
5341 auto window = op::Conditional(op::Compare(next_i, op::Constant()),
5342 op::GetTupleElement(op::Parameter(0)),
5343 op::GetTupleElement(op::Parameter(0)));
5344 auto partial_output = op::Dot(
5345 op::DynamicSlice(
5346 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5347 op::Constant(), op::Constant(), op::Reshape(), op::Constant()),
5348 op::GetTupleElement(op::Parameter(0)));
5349 EXPECT_THAT(
5350 while_loop->while_body()->root_instruction(),
5351 op::Tuple(op::GetTupleElement(op::Parameter(0)), window,
5352 op::Add(op::GetTupleElement(op::Parameter(0)), partial_output),
5353 op::GetTupleElement(op::Parameter(0)), next_i));
5354
5355 // Check the conditional that contains the collective permute.
5356 auto cp_conditional =
5357 while_loop->while_body()->root_instruction()->operand(1);
5358 EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
5359 op::CollectivePermute(op::Parameter(0)));
5360 EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
5361 op::Parameter(0));
5362 }
5363
5364 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedContracting) {
5365 absl::string_view hlo_string = R"(
5366 HloModule module
5367
5368 ENTRY entry {
5369 %lhs = f32[32,24,63,128] parameter(0)
5370 %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5371 %rhs = f32[32,39296,63,128] parameter(1)
5372 %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
5373 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
5374 lhs_batch_dims={0}, rhs_batch_dims={0},
5375 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5376 sharding={devices=[1,2,1]0,1}
5377 })";
5378
5379 TF_ASSERT_OK_AND_ASSIGN(
5380 auto module,
5381 PartitionComputation(hlo_string, /*num_devices=*/2,
5382 /*conv_halo_exchange_always_on_lhs =*/true,
5383 /*choose_faster_windowed_einsum =*/false,
5384 /*unroll_windowed_einsum =*/true));
5385 VLOG(1) << module->ToString();
5386
5387 const auto root = module->entry_computation()->root_instruction();
5388 const auto lhs = AllOf(
5389 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5390 op::Constant(), op::Constant())),
5391 op::Shape("f32[32,12,63,128]"));
5392 const auto rhs =
5393 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5394 op::Constant(), op::Constant(),
5395 op::Reshape(), op::Constant())),
5396 op::Shape("f32[32,39296,32,128]"));
5397 auto masked_rhs =
5398 op::Select(op::Compare(), rhs, op::Broadcast(op::Constant()));
5399 EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(
5400 op::Tuple(lhs, masked_rhs, op::Broadcast(),
5401 op::Broadcast(), op::Constant()))),
5402 op::Shape("f32[32,12,39296]")));
5403 const auto while_loop = root->operand(0);
5404 // Check loop condition.
5405 EXPECT_THAT(
5406 while_loop->while_condition()->root_instruction(),
5407 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5408
5409 // Check loop body.
5410 const auto next_i =
5411 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5412 op::Constant());
5413 auto intermediate_output =
5414 AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
5415 op::Dot(op::DynamicSlice(
5416 op::Pad(op::GetTupleElement(op::Parameter(0)),
5417 op::Constant()),
5418 op::Constant(), op::Constant(), op::Reshape(),
5419 op::Constant()),
5420 op::GetTupleElement(op::Parameter(0)))),
5421 op::Shape("f32[32,12,39296]"));
5422 auto output = AllOf(
5423 op::Add(
5424 intermediate_output,
5425 op::Dot(
5426 op::DynamicSlice(op::Pad(op::GetTupleElement(op::Parameter(0)),
5427 op::Constant()),
5428 op::Constant(), op::Constant(), op::Reshape(),
5429 op::Constant()),
5430 op::CollectivePermute(op::GetTupleElement(op::Parameter(0))))),
5431 op::Shape("f32[32,12,39296]"));
5432
5433 EXPECT_THAT(while_loop->while_body()->root_instruction(),
5434 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5435 op::CollectivePermute(op::CollectivePermute(
5436 op::GetTupleElement(op::Parameter(0)))),
5437 output, op::GetTupleElement(op::Parameter(0)), next_i));
5438 }
5439
5440 TEST_F(SpmdPartitioningTest, BidirectionalEinsumRHSWindowedContracting) {
5441 absl::string_view hlo_string = R"(
5442 HloModule module
5443
5444 ENTRY entry {
5445 %lhs = f32[32,24,63,128] parameter(0)
5446 %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
5447 %rhs = f32[32,39296,63,128] parameter(1)
5448 %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,4,1]0,1,2,3}
5449 ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
5450 lhs_batch_dims={0}, rhs_batch_dims={0},
5451 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5452 sharding={devices=[1,4,1]0,1,2,3}
5453 })";
5454
5455 TF_ASSERT_OK_AND_ASSIGN(
5456 auto module,
5457 PartitionComputation(hlo_string, /*num_devices=*/4,
5458 /*conv_halo_exchange_always_on_lhs =*/true,
5459 /*choose_faster_windowed_einsum =*/false,
5460 /*unroll_windowed_einsum =*/false,
5461 /*bidirectional_windowed_einsum =*/true));
5462 VLOG(1) << module->ToString();
5463
5464 const auto root = module->entry_computation()->root_instruction();
5465 const auto lhs = AllOf(
5466 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5467 op::Constant(), op::Constant())),
5468 op::Shape("f32[32,6,63,128]"));
5469 const auto rhs =
5470 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5471 op::Constant(), op::Constant(),
5472 op::Reshape(), op::Constant())),
5473 op::Shape("f32[32,39296,16,128]"));
5474 auto masked_rhs = op::Reshape(
5475 op::Select(op::Compare(), rhs, op::Broadcast(op::Constant())));
5476 EXPECT_THAT(root,
5477 AllOf(op::GetTupleElement(op::While(op::Tuple(
5478 lhs, masked_rhs, op::Broadcast(),
5479 op::CollectivePermute(masked_rhs), op::Constant()))),
5480 op::Shape("f32[32,6,39296]")));
5481 const auto while_loop = root->operand(0);
5482 // Check loop condition.
5483 EXPECT_THAT(
5484 while_loop->while_condition()->root_instruction(),
5485 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5486
5487 // Check loop body.
5488 const auto next_i =
5489 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5490 op::Constant());
5491 auto partial_output =
5492 AllOf(op::Add(op::Add(op::GetTupleElement(op::Parameter(0)),
5493 op::Dot(op::Maximum(), op::Concatenate())),
5494 op::Dot(op::Maximum(), op::Concatenate())),
5495 op::Shape("f32[32,6,39296]"));
5496
5497 EXPECT_THAT(while_loop->while_body()->root_instruction(),
5498 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5499 op::CollectivePermute(op::CollectivePermute(
5500 op::GetTupleElement(op::Parameter(0)))),
5501 partial_output,
5502 op::CollectivePermute(op::CollectivePermute(
5503 op::GetTupleElement(op::Parameter(0)))),
5504 next_i));
5505 }
5506
5507 TEST_F(SpmdPartitioningTest,
5508 EinsumWindowedNonContractingDimensionsNoCodeMotionWithDependentNodes) {
5509 absl::string_view hlo_string = R"(
5510 HloModule module
5511
5512 sum {
5513 a = f32[] parameter(0)
5514 b = f32[] parameter(1)
5515 ROOT add = f32[] add(a, b)
5516 }
5517
5518 ENTRY entry {
5519 %lhs = f32[32,24,64,128] parameter(0)
5520 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5521 %rhs = f32[32,39295,64,128] parameter(1)
5522 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5523 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5524 lhs_batch_dims={0}, rhs_batch_dims={0},
5525 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5526 sharding={devices=[1,2,1]0,1}
5527 %constant = f32[] constant(0)
5528 %constant.1 = f32[] constant(2)
5529 %constant.2 = f32[] constant(4)
5530 %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5531 sharding={devices=[1,2,1]0,1}
5532 %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5533 sharding={devices=[1,2,1]0,1}
5534 %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
5535 to_apply=sum, sharding={devices=[1,2]0,1}
5536 %all-reduce = f32[32,24] all-reduce(%reduce),
5537 to_apply=sum, sharding={devices=[1,2]0,1}
5538 %broadcast.1 = f32[32,24,39295] broadcast(%all-reduce), dimensions={0,1},
5539 sharding={devices=[1,2,1]0,1}
5540 %subtract = f32[32,24,39295] subtract(%multiply, %broadcast.1),
5541 sharding={devices=[1,2,1]0,1}
5542 ROOT %reduce.1 = f32[32,24] reduce(%subtract, %constant.2), dimensions={2},
5543 to_apply=sum, sharding={devices=[1,2]0,1}
5544 })";
5545
5546 TF_ASSERT_OK_AND_ASSIGN(auto module,
5547 PartitionComputation(hlo_string, /*num_devices=*/2));
5548 VLOG(1) << module->ToString();
5549
5550 const auto root = module->entry_computation()->root_instruction();
5551 const auto lhs = AllOf(
5552 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5553 op::Constant(), op::Constant())),
5554 op::Shape("f32[32,12,64,128]"));
5555 const auto rhs =
5556 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5557 op::Constant(), op::Reshape(),
5558 op::Constant(), op::Constant())),
5559 op::Shape("f32[32,19648,64,128]"));
5560 const auto while_output =
5561 AllOf(op::Slice(op::GetTupleElement(op::While(op::Tuple(
5562 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant())))),
5563 op::Shape("f32[32,12,39295]"));
5564   // All the multiplies, subtracts, and reduces should remain in the SPMD
5565   // entry computation.
5566 const auto multiply =
5567 AllOf(op::Multiply(while_output, op::Broadcast(op::Constant())),
5568 op::Shape("f32[32,12,39295]"));
5569 EXPECT_THAT(
5570 root,
5571 AllOf(op::Reduce(
5572 op::Subtract(multiply, op::Broadcast(op::AllReduce(op::Reduce(
5573 multiply, op::Constant())))),
5574 op::Constant()),
5575 op::Shape("f32[32,12]")));
5576
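  // Reduce<-Subtract<-Multiply<-Slice<-GetTupleElement<-While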
5577 const auto while_loop =
5578 root->operand(0)->operand(0)->operand(0)->operand(0)->operand(0);
5579 // Check loop condition.
5580 EXPECT_THAT(
5581 while_loop->while_condition()->root_instruction(),
5582 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5583
5584   // Check loop body. No multiply, subtract, reduce, etc. should have been
5585   // moved into the loop body.
5586 const auto next_i =
5587 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
5588 auto output = op::DynamicUpdateSlice(
5589 op::GetTupleElement(op::Parameter(0)),
5590 op::Dot(op::GetTupleElement(op::Parameter(0)),
5591 op::GetTupleElement(op::Parameter(0))),
5592 op::Constant(), op::Constant(), op::Reshape(op::DynamicSlice()));
5593
5594 EXPECT_THAT(while_loop->while_body()->root_instruction(),
5595 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5596 op::Conditional(op::Compare(next_i, op::Constant()),
5597 op::GetTupleElement(op::Parameter(0)),
5598 op::GetTupleElement(op::Parameter(0))),
5599 output, op::GetTupleElement(op::Parameter(0)), next_i));
5600 }
5601
5602 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) {
5603 absl::string_view hlo_string = R"(
5604 HloModule module
5605
5606 sum {
5607 a = f32[] parameter(0)
5608 b = f32[] parameter(1)
5609 ROOT add = f32[] add(a, b)
5610 }
5611
5612 ENTRY entry {
5613 %lhs = f32[32,24,64,128] parameter(0)
5614 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5615 %rhs = f32[32,39295,64,128] parameter(1)
5616 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5617 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5618 lhs_batch_dims={0}, rhs_batch_dims={0},
5619 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5620 sharding={devices=[1,2,1]0,1}
5621 %constant = f32[] constant(0)
5622 %constant.1 = f32[] constant(2)
5623 %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5624 sharding={devices=[1,2,1]0,1}
5625 %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5626 sharding={devices=[1,2,1]0,1}
5627 ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
5628 to_apply=sum, sharding={devices=[1,2]0,1}
5629 })";
5630
5631 TF_ASSERT_OK_AND_ASSIGN(auto module,
5632 PartitionComputation(hlo_string, /*num_devices=*/2));
5633 VLOG(1) << module->ToString();
5634
5635 const auto root = module->entry_computation()->root_instruction();
5636 const auto lhs = AllOf(
5637 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5638 op::Constant(), op::Constant())),
5639 op::Shape("f32[32,12,64,128]"));
5640 const auto rhs =
5641 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5642 op::Constant(), op::Reshape(),
5643 op::Constant(), op::Constant())),
5644 op::Shape("f32[32,19648,64,128]"));
5645 auto input_subtuple =
5646 op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
5647 EXPECT_THAT(
5648 root,
5649 AllOf(op::GetTupleElement(op::GetTupleElement(op::While(op::Tuple(
5650 lhs, rhs, input_subtuple, op::Broadcast(), op::Constant())))),
5651 op::Shape("f32[32,12]")));
5652
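  // GetTupleElement<-GetTupleElement<-While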
5653 const auto while_loop = root->operand(0)->operand(0);
5654 // Check loop condition.
5655 EXPECT_THAT(
5656 while_loop->while_condition()->root_instruction(),
5657 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5658
5659 // Check loop body.
5660 const auto next_i =
5661 op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
5662 auto output_tuple = op::Tuple(
5663 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5664 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5665 op::Add(op::Reduce(
5666 op::Select(op::Compare(),
5667 op::Multiply(
5668 op::Dot(op::GetTupleElement(op::Parameter(0)),
5669 op::GetTupleElement(op::Parameter(0))),
5670 op::DynamicSlice()),
5671 op::Broadcast()),
5672 op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
5673 op::DynamicSlice(
5674 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5675 op::Constant(), op::Constant())));
5676
5677 EXPECT_THAT(
5678 while_loop->while_body()->root_instruction(),
5679 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5680 op::Conditional(op::Compare(next_i, op::Constant()),
5681 op::GetTupleElement(op::Parameter(0)),
5682 op::GetTupleElement(op::Parameter(0))),
5683 output_tuple, op::GetTupleElement(op::Parameter(0)), next_i));
5684 }
5685
5686 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedNonContractingReduce1) {
5687 absl::string_view hlo_string = R"(
5688 HloModule module
5689
5690 sum {
5691 a = f32[] parameter(0)
5692 b = f32[] parameter(1)
5693 ROOT add = f32[] add(a, b)
5694 }
5695
5696 ENTRY entry {
5697 %lhs = f32[32,24,64,128] parameter(0)
5698 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5699 %rhs = f32[32,39295,64,128] parameter(1)
5700 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5701 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5702 lhs_batch_dims={0}, rhs_batch_dims={0},
5703 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5704 sharding={devices=[1,2,1]0,1}
5705 %constant = f32[] constant(0)
5706 %constant.1 = f32[] constant(2)
5707 %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5708 sharding={devices=[1,2,1]0,1}
5709 %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5710 sharding={devices=[1,2,1]0,1}
5711 ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
5712 to_apply=sum, sharding={devices=[1,2]0,1}
5713 })";
5714
5715 TF_ASSERT_OK_AND_ASSIGN(
5716 auto module,
5717 PartitionComputation(hlo_string, /*num_devices=*/2,
5718 /*conv_halo_exchange_always_on_lhs =*/true,
5719 /*choose_faster_windowed_einsum =*/false,
5720 /*unroll_windowed_einsum =*/true));
5721 VLOG(1) << module->ToString();
5722
5723 const auto root = module->entry_computation()->root_instruction();
5724 const auto lhs = AllOf(
5725 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5726 op::Constant(), op::Constant())),
5727 op::Shape("f32[32,12,64,128]"));
5728 const auto rhs =
5729 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5730 op::Constant(), op::Reshape(),
5731 op::Constant(), op::Constant())),
5732 op::Shape("f32[32,19648,64,128]"));
5733 auto input_subtuple =
5734 op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
5735 EXPECT_THAT(
5736 root,
5737 AllOf(op::GetTupleElement(op::GetTupleElement(op::While(op::Tuple(
5738 lhs, rhs, input_subtuple, op::Broadcast(), op::Constant())))),
5739 op::Shape("f32[32,12]")));
5740
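  // GetTupleElement<-GetTupleElement<-While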
5741 const auto while_loop = root->operand(0)->operand(0);
5742 // Check loop condition.
5743 EXPECT_THAT(
5744 while_loop->while_condition()->root_instruction(),
5745 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5746
5747 // Check loop body.
5748 const auto next_i =
5749 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5750 op::Constant());
5751 auto intermediate_output = AllOf(
5752 op::Add(
5753 op::Reduce(
5754 op::Select(op::Compare(),
5755 op::Multiply(
5756 op::Dot(op::GetTupleElement(op::Parameter(0)),
5757 op::CollectivePermute(op::GetTupleElement(
5758 op::Parameter(0)))),
5759 op::DynamicSlice()),
5760 op::Broadcast()),
5761 op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
5762 op::DynamicSlice(
5763 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5764 op::Constant(), op::Constant())),
5765 op::Shape("f32[32,12]"));
5766 auto output_tuple = op::Tuple(
5767 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5768 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5769 op::Add(op::Reduce(
5770 op::Select(op::Compare(),
5771 op::Multiply(
5772 op::Dot(op::GetTupleElement(op::Parameter(0)),
5773 op::GetTupleElement(op::Parameter(0))),
5774 op::DynamicSlice()),
5775 op::Broadcast()),
5776 op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
5777 op::DynamicSlice(intermediate_output, op::Constant(),
5778 op::Constant())));
5779
5780 EXPECT_THAT(
5781 while_loop->while_body()->root_instruction(),
5782 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5783 op::CollectivePermute(op::CollectivePermute(
5784 op::GetTupleElement(op::Parameter(0)))),
5785 output_tuple, op::GetTupleElement(op::Parameter(0)), next_i));
5786 }
5787
5788 TEST_F(SpmdPartitioningTest,
5789 BidirectionalEinsumRHSWindowedNonContractingReduce1) {
5790 absl::string_view hlo_string = R"(
5791 HloModule module
5792
5793 sum {
5794 a = f32[] parameter(0)
5795 b = f32[] parameter(1)
5796 ROOT add = f32[] add(a, b)
5797 }
5798
5799 ENTRY entry {
5800 %lhs = f32[32,24,64,128] parameter(0)
5801 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
5802 %rhs = f32[32,39295,64,128] parameter(1)
5803 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
5804 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5805 lhs_batch_dims={0}, rhs_batch_dims={0},
5806 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5807 sharding={devices=[1,4,1]0,1,2,3}
5808 %constant = f32[] constant(0)
5809 %constant.1 = f32[] constant(2)
5810 %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5811 sharding={devices=[1,4,1]0,1,2,3}
5812 %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5813 sharding={devices=[1,4,1]0,1,2,3}
5814 ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
5815 to_apply=sum, sharding={devices=[1,4]0,1,2,3}
5816 })";
5817
5818 TF_ASSERT_OK_AND_ASSIGN(
5819 auto module,
5820 PartitionComputation(hlo_string, /*num_devices=*/4,
5821 /*conv_halo_exchange_always_on_lhs =*/true,
5822 /*choose_faster_windowed_einsum =*/false,
5823 /*unroll_windowed_einsum =*/false,
5824 /*bidirectional_windowed_einsum =*/true));
5825 VLOG(1) << module->ToString();
5826
5827 const auto root = module->entry_computation()->root_instruction();
5828 const auto lhs = AllOf(
5829 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5830 op::Constant(), op::Constant())),
5831 op::Shape("f32[32,6,64,128]"));
5832 const auto rhs =
5833 AllOf(op::Reshape(op::Copy(op::DynamicSlice(
5834 op::Pad(op::Parameter(1), op::Constant()), op::Constant(),
5835 op::Reshape(), op::Constant(), op::Constant()))),
5836 op::Shape("f32[32,1,9824,64,128]"));
5837 auto input_subtuple =
5838 op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
5839 EXPECT_THAT(root,
5840 AllOf(op::GetTupleElement(op::GetTupleElement(op::While(
5841 op::Tuple(lhs, rhs, input_subtuple,
5842 op::CollectivePermute(), op::Constant())))),
5843 op::Shape("f32[32,6]")));
5844
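  // GetTupleElement<-GetTupleElement<-While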
5845 const auto while_loop = root->operand(0)->operand(0);
5846 // Check loop condition.
5847 EXPECT_THAT(
5848 while_loop->while_condition()->root_instruction(),
5849 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5850
5851 // Check loop body.
5852 const auto next_i =
5853 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5854 op::Constant());
5855 auto partial_reduce_pattern = AllOf(
5856 op::Reduce(
5857 op::Select(op::Compare(),
5858 op::Multiply(op::Reshape(op::Slice(op::Dot(
5859 op::GetTupleElement(op::Parameter(0)),
5860 op::Concatenate()))),
5861 op::DynamicSlice()),
5862 op::Broadcast()),
5863 op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
5864 op::Shape("f32[32,6]"));
5865 auto intermediate_output1 = AllOf(
5866 op::Add(partial_reduce_pattern,
5867 op::DynamicSlice(
5868 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5869 op::Constant(), op::Constant())),
5870 op::Shape("f32[32,6]"));
5871 auto intermediate_output2 =
5872 AllOf(op::Add(partial_reduce_pattern,
5873 op::DynamicSlice(intermediate_output1, op::Constant(),
5874 op::Constant())),
5875 op::Shape("f32[32,6]"));
5876 auto intermediate_output3 =
5877 AllOf(op::Add(partial_reduce_pattern,
5878 op::DynamicSlice(intermediate_output2, op::Constant(),
5879 op::Constant())),
5880 op::Shape("f32[32,6]"));
5881 auto output_tuple =
5882 op::Tuple(op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5883 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
5884 op::Add(partial_reduce_pattern,
5885 op::DynamicSlice(intermediate_output3, op::Constant(),
5886 op::Constant())));
5887
5888 EXPECT_THAT(while_loop->while_body()->root_instruction(),
5889 op::Tuple(op::GetTupleElement(op::Parameter(0)),
5890 op::CollectivePermute(op::CollectivePermute(
5891 op::GetTupleElement(op::Parameter(0)))),
5892 output_tuple,
5893 op::CollectivePermute(op::CollectivePermute(
5894 op::GetTupleElement(op::Parameter(0)))),
5895 next_i));
5896 }
5897
5898 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) {
5899 absl::string_view hlo_string = R"(
5900 HloModule module
5901
5902 sum {
5903 a = f32[] parameter(0)
5904 b = f32[] parameter(1)
5905 ROOT add = f32[] add(a, b)
5906 }
5907
5908 ENTRY entry {
5909 %lhs = f32[32,24,64,128] parameter(0)
5910 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5911 %rhs = f32[32,39295,64,128] parameter(1)
5912 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5913 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5914 lhs_batch_dims={0}, rhs_batch_dims={0},
5915 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5916 sharding={devices=[1,2,1]0,1}
5917 %constant = f32[] constant(0)
5918 %constant.1 = f32[] constant(2)
5919 %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5920 sharding={devices=[1,2,1]0,1}
5921 %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5922 sharding={devices=[1,2,1]0,1}
5923 ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
5924 to_apply=sum, sharding={replicated}
5925 })";
5926
5927 TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
5928 /*num_devices=*/2));
5929 VLOG(1) << module->ToString();
5930   // Involves loop code motion, so pattern matching is skipped.
5931 }
5932
5933 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedNonContractingReduce2) {
5934 absl::string_view hlo_string = R"(
5935 HloModule module
5936
5937 sum {
5938 a = f32[] parameter(0)
5939 b = f32[] parameter(1)
5940 ROOT add = f32[] add(a, b)
5941 }
5942
5943 ENTRY entry {
5944 %lhs = f32[32,24,64,128] parameter(0)
5945 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
5946 %rhs = f32[32,39295,64,128] parameter(1)
5947 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
5948 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
5949 lhs_batch_dims={0}, rhs_batch_dims={0},
5950 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
5951 sharding={devices=[1,2,1]0,1}
5952 %constant = f32[] constant(0)
5953 %constant.1 = f32[] constant(2)
5954 %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
5955 sharding={devices=[1,2,1]0,1}
5956 %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
5957 sharding={devices=[1,2,1]0,1}
5958 ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
5959 to_apply=sum, sharding={replicated}
5960 })";
5961
5962 TF_ASSERT_OK_AND_ASSIGN(
5963 auto module,
5964 PartitionComputation(hlo_string, /*num_devices=*/2,
5965 /*conv_halo_exchange_always_on_lhs =*/true,
5966 /*choose_faster_windowed_einsum =*/false,
5967 /*unroll_windowed_einsum =*/true));
5968 VLOG(1) << module->ToString();
5969
5970 const auto root = module->entry_computation()->root_instruction();
5971 const auto lhs = AllOf(
5972 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
5973 op::Constant(), op::Constant())),
5974 op::Shape("f32[32,12,64,128]"));
5975 const auto rhs =
5976 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
5977 op::Constant(), op::Reshape(),
5978 op::Constant(), op::Constant())),
5979 op::Shape("f32[32,19648,64,128]"));
5980 auto input_subtuple =
5981 op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
5982 EXPECT_THAT(
5983 root,
5984 AllOf(op::AllReduce(op::Slice(op::GetTupleElement(op::GetTupleElement(
5985 op::While(op::Tuple(lhs, rhs, input_subtuple, op::Broadcast(),
5986 op::Constant())))))),
5987 op::Shape("f32[32,39295]")));
5988
5989 // AllReduce<-Slice<-GetTupleElement<-GetTupleElement<-While
5990 const auto while_loop = root->operand(0)->operand(0)->operand(0)->operand(0);
5991 // Check loop condition.
5992 EXPECT_THAT(
5993 while_loop->while_condition()->root_instruction(),
5994 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
5995
5996 // Check loop body.
5997 const auto next_i =
5998 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
5999 op::Constant());
6000 auto intermediate_output = AllOf(
6001 op::DynamicUpdateSlice(
6002 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6003 op::Reduce(
6004 op::Multiply(op::Dot(op::GetTupleElement(op::Parameter(0)),
6005 op::CollectivePermute(
6006 op::GetTupleElement(op::Parameter(0)))),
6007 op::DynamicSlice()),
6008 op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
6009 op::Constant(), op::Reshape()),
6010 op::Shape("f32[32,39296]"));
6011 auto output_tuple = op::Tuple(
6012 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6013 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6014 op::DynamicUpdateSlice(
6015 intermediate_output,
6016 op::Reduce(
6017 op::Multiply(op::Dot(op::GetTupleElement(op::Parameter(0)),
6018 op::GetTupleElement(op::Parameter(0))),
6019 op::DynamicSlice()),
6020 op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
6021 op::Constant(), op::Reshape()));
6022
6023 EXPECT_THAT(
6024 while_loop->while_body()->root_instruction(),
6025 op::Tuple(op::GetTupleElement(op::Parameter(0)),
6026 op::CollectivePermute(op::CollectivePermute(
6027 op::GetTupleElement(op::Parameter(0)))),
6028 output_tuple, op::GetTupleElement(op::Parameter(0)), next_i));
6029 }
6030
6031 TEST_F(SpmdPartitioningTest,
6032 BidirectionalEinsumRHSWindowedNonContractingReduce2) {
6033 absl::string_view hlo_string = R"(
6034 HloModule module
6035
6036 sum {
6037 a = f32[] parameter(0)
6038 b = f32[] parameter(1)
6039 ROOT add = f32[] add(a, b)
6040 }
6041
6042 ENTRY entry {
6043 %lhs = f32[32,24,64,128] parameter(0)
6044 %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
6045 %rhs = f32[32,39295,64,128] parameter(1)
6046 %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
6047 %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
6048 lhs_batch_dims={0}, rhs_batch_dims={0},
6049 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
6050 sharding={devices=[1,4,1]0,1,2,3}
6051 %constant = f32[] constant(0)
6052 %constant.1 = f32[] constant(2)
6053 %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
6054 sharding={devices=[1,4,1]0,1,2,3}
6055 %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
6056 sharding={devices=[1,4,1]0,1,2,3}
6057 ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
6058 to_apply=sum, sharding={replicated}
6059 })";
6060
6061 TF_ASSERT_OK_AND_ASSIGN(
6062 auto module,
6063 PartitionComputation(hlo_string, /*num_devices=*/4,
6064 /*conv_halo_exchange_always_on_lhs =*/true,
6065 /*choose_faster_windowed_einsum =*/false,
6066 /*unroll_windowed_einsum =*/false,
6067 /*bidirectional_windowed_einsum =*/true));
6068 VLOG(1) << module->ToString();
6069
6070 const auto root = module->entry_computation()->root_instruction();
6071 const auto lhs = AllOf(
6072 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
6073 op::Constant(), op::Constant())),
6074 op::Shape("f32[32,6,64,128]"));
6075 const auto rhs =
6076 AllOf(op::Reshape(op::Copy(op::DynamicSlice(
6077 op::Pad(op::Parameter(1), op::Constant()), op::Constant(),
6078 op::Reshape(), op::Constant(), op::Constant()))),
6079 op::Shape("f32[32,1,9824,64,128]"));
6080 auto input_subtuple =
6081 op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
6082 EXPECT_THAT(
6083 root, AllOf(op::AllReduce(op::Slice(op::GetTupleElement(
6084 op::GetTupleElement(op::While(op::Tuple(
6085 lhs, rhs, input_subtuple, op::CollectivePermute(rhs),
6086 op::Constant())))))),
6087 op::Shape("f32[32,39295]")));
6088
6089 // AllReduce<-Slice<-GetTupleElement<-GetTupleElement<-While
6090 const auto while_loop = root->operand(0)->operand(0)->operand(0)->operand(0);
6091 // Check loop condition.
6092 EXPECT_THAT(
6093 while_loop->while_condition()->root_instruction(),
6094 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
6095
6096 // Check loop body.
6097 const auto next_i =
6098 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
6099 op::Constant());
6100 auto partial_reduce_pattern = AllOf(
6101 op::Reduce(op::Multiply(op::Reshape(op::Slice(
6102 op::Dot(op::GetTupleElement(op::Parameter(0)),
6103 op::Concatenate()))),
6104 op::DynamicSlice(op::Broadcast(), op::Constant(),
6105 op::Constant(), op::Reshape())),
6106 op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
6107 op::Shape("f32[32,9824]"));
6108 auto intermediate_output1 =
6109 AllOf(op::DynamicUpdateSlice(
6110 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6111 partial_reduce_pattern, op::Constant(), op::Reshape()),
6112 op::Shape("f32[32,39296]"));
6113 auto intermediate_output2 =
6114 AllOf(op::DynamicUpdateSlice(intermediate_output1, partial_reduce_pattern,
6115 op::Constant(), op::Reshape()),
6116 op::Shape("f32[32,39296]"));
6117 auto intermediate_output3 =
6118 AllOf(op::DynamicUpdateSlice(intermediate_output2, partial_reduce_pattern,
6119 op::Constant(), op::Reshape()),
6120 op::Shape("f32[32,39296]"));
6121 auto output_tuple = op::Tuple(
6122 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6123 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))),
6124 op::DynamicUpdateSlice(intermediate_output3, partial_reduce_pattern,
6125 op::Constant(), op::Reshape()));
6126
6127 EXPECT_THAT(while_loop->while_body()->root_instruction(),
6128 op::Tuple(op::GetTupleElement(op::Parameter(0)),
6129 op::CollectivePermute(op::CollectivePermute(
6130 op::GetTupleElement(op::Parameter(0)))),
6131 output_tuple,
6132 op::CollectivePermute(op::CollectivePermute(
6133 op::GetTupleElement(op::Parameter(0)))),
6134 next_i));
6135 }
6136
6137 TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) {
6138 absl::string_view hlo_string = R"(
6139 HloModule module
6140
6141 ENTRY entry {
6142 %rhs = f32[32,39296,63,128] parameter(0)
6143 %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
6144 %constant.1 = f32[] constant(2)
6145 %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
6146 sharding={devices=[1,2,1,1]0,1}
6147 %add = f32[32,24,63,128] add(%broadcast, %broadcast),
6148 sharding={devices=[1,2,1,1]0,1}
6149 ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
6150 lhs_batch_dims={0}, rhs_batch_dims={0},
6151 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
6152 sharding={devices=[1,2,1]0,1}
6153 })";
6154
6155 TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
6156 /*num_devices=*/2));
6157 VLOG(1) << module->ToString();
6158   // Involves loop code motion, so pattern matching is skipped.
6159 }
6160
6161 TEST_F(SpmdPartitioningTest, UnrollEinsumRHSWindowedContractingFromBroadcast) {
6162 absl::string_view hlo_string = R"(
6163 HloModule module
6164
6165 ENTRY entry {
6166 %rhs = f32[32,39296,63,128] parameter(0)
6167 %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
6168 %constant.1 = f32[] constant(2)
6169 %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
6170 sharding={devices=[1,2,1,1]0,1}
6171 %add = f32[32,24,63,128] add(%broadcast, %broadcast),
6172 sharding={devices=[1,2,1,1]0,1}
6173 ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
6174 lhs_batch_dims={0}, rhs_batch_dims={0},
6175 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
6176 sharding={devices=[1,2,1]0,1}
6177 })";
6178
6179 TF_ASSERT_OK_AND_ASSIGN(
6180 auto module,
6181 PartitionComputation(hlo_string, /*num_devices=*/2,
6182 /*conv_halo_exchange_always_on_lhs =*/true,
6183 /*choose_faster_windowed_einsum =*/false,
6184 /*unroll_windowed_einsum =*/true));
6185 VLOG(1) << module->ToString();
6186
6187 const auto root = module->entry_computation()->root_instruction();
6188 const auto lhs = op::Tuple(op::Constant());
6189 const auto rhs =
6190 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
6191 op::Constant(), op::Constant(),
6192 op::Reshape(), op::Constant())),
6193 op::Shape("f32[32,39296,32,128]"));
6194 auto masked_rhs =
6195 op::Select(op::Compare(), rhs, op::Broadcast(op::Constant()));
6196 EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(
6197 op::Tuple(lhs, masked_rhs, op::Broadcast(),
6198 op::Broadcast(), op::Constant()))),
6199 op::Shape("f32[32,12,39296]")));
6200
6201 const auto while_loop = root->operand(0);
6202 // Check loop condition.
6203 EXPECT_THAT(
6204 while_loop->while_condition()->root_instruction(),
6205 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
6206
6207 // Check loop body.
6208 const auto next_i =
6209 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
6210 op::Constant());
6211 auto padded_broadcast_sum = op::Pad(
6212 op::Add(op::Broadcast(
6213 op::GetTupleElement(op::GetTupleElement(op::Parameter(0)))),
6214 op::Broadcast(
6215 op::GetTupleElement(op::GetTupleElement(op::Parameter(0))))),
6216 op::Constant());
6217 auto intermediate_output =
6218 AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
6219 op::Dot(op::DynamicSlice(padded_broadcast_sum,
6220 op::Constant(), op::Constant(),
6221 op::Reshape(), op::Constant()),
6222 op::GetTupleElement(op::Parameter(0)))),
6223 op::Shape("f32[32,12,39296]"));
6224 auto output = AllOf(
6225 op::Add(
6226 intermediate_output,
6227 op::Dot(
6228 op::DynamicSlice(padded_broadcast_sum, op::Constant(),
6229 op::Constant(), op::Reshape(), op::Constant()),
6230 op::CollectivePermute(op::GetTupleElement(op::Parameter(0))))),
6231 op::Shape("f32[32,12,39296]"));
6232
6233 EXPECT_THAT(while_loop->while_body()->root_instruction(),
6234 op::Tuple(op::GetTupleElement(op::Parameter(0)),
6235 op::CollectivePermute(op::CollectivePermute(
6236 op::GetTupleElement(op::Parameter(0)))),
6237 output, op::GetTupleElement(op::Parameter(0)), next_i));
6238 }
6239
6240 TEST_F(SpmdPartitioningTest,
6241 BidirectionalEinsumRHSWindowedContractingFromBroadcast) {
6242 absl::string_view hlo_string = R"(
6243 HloModule module
6244
6245 ENTRY entry {
6246 %rhs = f32[32,39296,63,128] parameter(0)
6247 %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,4,1]0,1,2,3}
6248 %constant.1 = f32[] constant(2)
6249 %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
6250 sharding={devices=[1,4,1,1]0,1,2,3}
6251 %add = f32[32,24,63,128] add(%broadcast, %broadcast),
6252 sharding={devices=[1,4,1,1]0,1,2,3}
6253 ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
6254 lhs_batch_dims={0}, rhs_batch_dims={0},
6255 lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
6256 sharding={devices=[1,4,1]0,1,2,3}
6257 })";
6258
6259 TF_ASSERT_OK_AND_ASSIGN(
6260 auto module,
6261 PartitionComputation(hlo_string, /*num_devices=*/4,
6262 /*conv_halo_exchange_always_on_lhs =*/true,
6263 /*choose_faster_windowed_einsum =*/false,
6264 /*unroll_windowed_einsum =*/false,
6265 /*bidirectional_windowed_einsum =*/true));
6266 VLOG(1) << module->ToString();
6267
6268 auto input_subtuple =
6269 op::Tuple(op::Constant(), op::Constant(), op::Broadcast(op::Constant()));
6270
6271 const auto root = module->entry_computation()->root_instruction();
6272 const auto lhs = op::Tuple(op::Constant());
6273 const auto rhs =
6274 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
6275 op::Constant(), op::Constant(),
6276 op::Reshape(), op::Constant())),
6277 op::Shape("f32[32,39296,16,128]"));
6278 auto masked_rhs = op::Reshape(
6279 op::Select(op::Compare(), rhs, op::Broadcast(op::Constant())));
6280 EXPECT_THAT(root,
6281 AllOf(op::GetTupleElement(op::While(op::Tuple(
6282 lhs, masked_rhs, op::Broadcast(),
6283 op::CollectivePermute(masked_rhs), op::Constant()))),
6284 op::Shape("f32[32,6,39296]")));
6285
6286 const auto while_loop = root->operand(0);
6287 // Check loop condition.
6288 EXPECT_THAT(
6289 while_loop->while_condition()->root_instruction(),
6290 op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
6291
6292 // Check loop body.
6293 const auto next_i =
6294 op::Add(op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()),
6295 op::Constant());
6296 auto output =
6297 AllOf(op::Add(op::Add(op::GetTupleElement(op::Parameter(0)),
6298 op::Dot(op::Maximum(), op::Concatenate())),
6299 op::Dot(op::Maximum(), op::Concatenate())),
6300 op::Shape("f32[32,6,39296]"));
6301
6302 EXPECT_THAT(while_loop->while_body()->root_instruction(),
6303 op::Tuple(op::GetTupleElement(op::Parameter(0)),
6304 op::CollectivePermute(op::CollectivePermute(
6305 op::GetTupleElement(op::Parameter(0)))),
6306 output,
6307 op::CollectivePermute(op::CollectivePermute(
6308 op::GetTupleElement(op::Parameter(0)))),
6309 next_i));
6310 }
6311
6312 TEST_F(SpmdPartitioningTest, EinsumNonContractingDimPartitionOnTwoDims) {
6313 absl::string_view hlo_string = R"(
6314 HloModule module
6315
6316 ENTRY entry {
6317 %lhs = bf16[8,1024,2,1536] parameter(0)
6318 %lhs.copy = bf16[8,1024,2,1536] copy(lhs),
6319 sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
6320 %rhs = bf16[2,1536,512,1] parameter(1)
6321 %rhs.copy = bf16[2,1536,512,1] copy(rhs),
6322 sharding={devices=[2,1,2,1,2]0,4,2,6,1,5,3,7 last_tile_dim_replicate}
6323 ROOT %convolution = bf16[8,1024,512,1] convolution(lhs.copy, rhs.copy),
6324 window={size=1x2}, dim_labels=0b1f_1io0->0bf1,
6325 sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
6326 })";
6327
6328 TF_ASSERT_OK_AND_ASSIGN(auto module,
6329 PartitionComputation(hlo_string, /*num_devices=*/8));
6330 VLOG(1) << module->ToString();
6331 const auto root = module->entry_computation()->root_instruction();
6332 const auto lhs = AllOf(
6333 op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant(),
6334 op::Reshape(), op::Constant())),
6335 op::Shape("bf16[2,1024,1,1536]"));
6336 const auto rhs = AllOf(
6337 op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
6338 op::Reshape(), op::Constant())),
6339 op::Shape("bf16[1,1536,256,1]"));
6340
6341 const auto partial_replicate_rhs =
6342 AllOf(op::AllReduce(op::DynamicUpdateSlice(
6343 op::Broadcast(), rhs, op::Constant(), op::Constant(),
6344 op::Reshape(), op::Constant())),
6345 op::Shape("bf16[1,1536,512,1]"));
6346 EXPECT_THAT(
6347 root,
6348 AllOf(op::DynamicSlice(
6349 op::AllReduce(op::Convolution(lhs, partial_replicate_rhs)),
6350 op::Constant(), op::Constant(), op::Reshape(), op::Constant()),
6351 op::Shape("bf16[2,1024,256,1]")));
6352 }
6353
6354 TEST_F(SpmdPartitioningTest, EinsumNonContractingDimPartitionOnTwoDims2) {
6355 absl::string_view hlo_string = R"(
6356 HloModule module
6357
6358 ENTRY entry {
6359 %lhs = bf16[8,1024,2,1536] parameter(0)
6360 %lhs.copy = bf16[8,1024,2,1536] copy(lhs),
6361 sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
6362 %rhs = bf16[2,1536,512,1] parameter(1)
6363 %rhs.copy = bf16[2,1536,512,1] copy(rhs),
6364 sharding={devices=[2,1,2,1,2]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
6365 ROOT %convolution = bf16[8,1024,512,1] convolution(lhs.copy, rhs.copy),
6366 window={size=1x2}, dim_labels=0b1f_1io0->0bf1,
6367 sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
6368 })";
6369
6370 TF_ASSERT_OK_AND_ASSIGN(auto module,
6371 PartitionComputation(hlo_string, /*num_devices=*/8));
6372 VLOG(1) << module->ToString();
6373 const auto root = module->entry_computation()->root_instruction();
6374 const auto lhs = AllOf(
6375 op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant(),
6376 op::Reshape(), op::Constant())),
6377 op::Shape("bf16[2,1024,1,1536]"));
6378 const auto rhs = AllOf(
6379 op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
6380 op::Reshape(), op::Constant())),
6381 op::Shape("bf16[1,1536,256,1]"));
6382
6383 const auto partial_replicate_rhs =
6384 AllOf(op::AllReduce(op::DynamicUpdateSlice(
6385 op::Broadcast(), rhs, op::Constant(), op::Constant(),
6386 op::Reshape(), op::Constant())),
6387 op::Shape("bf16[1,1536,512,1]"));
6388 EXPECT_THAT(
6389 root,
6390 AllOf(op::DynamicSlice(
6391 op::AllReduce(op::Convolution(lhs, partial_replicate_rhs)),
6392 op::Constant(), op::Constant(), op::Reshape(), op::Constant()),
6393 op::Shape("bf16[2,1024,256,1]")));
6394 }
6395
6396 TEST_F(SpmdPartitioningTest, ReplicatedRng) {
6397 absl::string_view hlo_string = R"(
6398 HloModule module
6399
6400 ENTRY entry {
6401 %lhs = s32[] parameter(0)
6402 %lhs.copy = s32[] copy(%lhs), sharding={replicated}
6403 %rhs = s32[] parameter(1)
6404 %rhs.copy = s32[] copy(%rhs), sharding={replicated}
6405 ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
6406 distribution=rng_uniform, sharding={replicated}
6407 })";
6408
6409 TF_ASSERT_OK_AND_ASSIGN(auto module,
6410 PartitionComputation(hlo_string, /*num_devices=*/2));
6411 VLOG(1) << module->ToString();
6412
6413 const auto root = module->entry_computation()->root_instruction();
6414 const auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]"));
6415 const auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("s32[]"));
6416 EXPECT_THAT(
6417 root,
6418 AllOf(op::AllReduce(op::Select(
6419 op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
6420 op::Rng(), op::Broadcast(op::Constant()))),
6421 op::Shape("s32[4]")));
6422 }
6423
6424 TEST_F(SpmdPartitioningTest, ManualRng) {
6425 absl::string_view hlo_string = R"(
6426 HloModule module
6427
6428 ENTRY entry {
6429 %lhs = s32[] parameter(0), sharding={manual}
6430 %rhs = s32[] parameter(1), sharding={manual}
6431 ROOT %rng = s32[4]{0} rng(%lhs, %rhs),
6432 distribution=rng_uniform, sharding={manual}
6433 })";
6434
6435 TF_ASSERT_OK_AND_ASSIGN(auto module,
6436 PartitionComputation(hlo_string, /*num_devices=*/2));
6437 VLOG(1) << module->ToString();
6438
6439 const auto root = module->entry_computation()->root_instruction();
6440 EXPECT_THAT(root, AllOf(op::Rng(op::Parameter(0), op::Parameter(1)),
6441 op::Shape("s32[4]")));
6442 }
6443
6444 TEST_F(SpmdPartitioningTest, PartitionedRng) {
6445 absl::string_view hlo_string = R"(
6446 HloModule module
6447
6448 ENTRY entry {
6449 %lhs = s32[] parameter(0)
6450 %lhs.copy = s32[] copy(%lhs), sharding={replicated}
6451 %rhs = s32[] parameter(1)
6452 %rhs.copy = s32[] copy(%rhs), sharding={maximal device=1}
6453 ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
6454 distribution=rng_uniform, sharding={devices=[2]0,1}
6455 })";
6456
6457 TF_ASSERT_OK_AND_ASSIGN(auto module,
6458 PartitionComputation(hlo_string, /*num_devices=*/2));
6459 VLOG(1) << module->ToString();
6460
6461 const auto root = module->entry_computation()->root_instruction();
6462 const auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]"));
6463 const auto rhs =
6464 AllOf(op::Copy(op::Copy(op::Parameter(1))), op::Shape("s32[]"));
6465 EXPECT_THAT(root, AllOf(op::Rng(lhs, op::AllReduce(op::Select(
6466 op::Broadcast(op::Compare()), rhs,
6467 op::Broadcast(op::Constant())))),
6468 op::Shape("s32[2]")));
6469 }
6470
6471 TEST_F(SpmdPartitioningTest, PartialReplicatedRng) {
6472 absl::string_view hlo_string = R"(
6473 HloModule module
6474
6475 ENTRY entry {
6476 %lhs = s32[] parameter(0), sharding={replicated}
6477 %rhs = s32[] parameter(1), sharding={replicated}
6478 ROOT %rng = s32[8]{0} rng(%lhs, %rhs),
6479 distribution=rng_uniform,
6480 sharding={devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6481 })";
6482
6483 TF_ASSERT_OK_AND_ASSIGN(auto module,
6484 PartitionComputation(hlo_string, /*num_devices=*/8));
6485 VLOG(1) << module->ToString();
6486
6487 const auto root = module->entry_computation()->root_instruction();
6488 const auto lhs = AllOf(op::Parameter(0), op::Shape("s32[]"));
6489 const auto rhs = AllOf(op::Parameter(1), op::Shape("s32[]"));
6490 auto partition_id =
6491 AllOf(op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
6492 op::Shape("u32[]"));
6493 EXPECT_THAT(
6494 root, AllOf(op::AllReduce(op::Select(
6495 op::Broadcast(op::Compare(partition_id, op::Constant())),
6496 op::Rng(lhs, rhs), op::Broadcast(op::Constant()))),
6497 op::Shape("s32[4]")));
6498 }
6499
6500 TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) {
6501 absl::string_view hlo_string = R"(
6502 HloModule module
6503
6504 ENTRY entry {
6505 %input = s32[128,64] parameter(0), sharding={devices=[2,1]0,1}
6506 %index = s32[] parameter(1)
6507 %trivial_index = s32[] parameter(2)
6508 ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input, %trivial_index, %index),
6509 dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1}
6510 })";
6511
6512 TF_ASSERT_OK_AND_ASSIGN(auto module,
6513 PartitionComputation(hlo_string, /*num_devices=*/2));
6514 VLOG(1) << module->ToString();
6515
6516 const auto root = module->entry_computation()->root_instruction();
6517 auto input = AllOf(op::Parameter(0), op::Shape("s32[64,64]"));
6518 EXPECT_THAT(root,
6519 AllOf(op::DynamicSlice(input, op::Constant(), op::Parameter(1)),
6520 op::Shape("s32[64,2]")));
6521 }
6522
6523 TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) {
6524 absl::string_view hlo_string = R"(
6525 HloModule module
6526
6527 ENTRY entry {
6528 %input = s32[128,64] parameter(0), sharding={devices=[2,1]0,1}
6529 %index = s32[] parameter(1)
6530 %update = s32[128,2] parameter(2)
6531 %trivial_index = s32[] parameter(3)
6532 %update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1}
6533 ROOT %dynamic-update-slice = s32[128,64]
6534 dynamic-update-slice(%input, %update.copy, %trivial_index, %index),
6535 sharding={devices=[2,1]0,1}
6536 })";
6537
6538 TF_ASSERT_OK_AND_ASSIGN(auto module,
6539 PartitionComputation(hlo_string, /*num_devices=*/2));
6540 VLOG(1) << module->ToString();
6541
6542 const auto root = module->entry_computation()->root_instruction();
6543 auto input = AllOf(op::Parameter(0), op::Shape("s32[64,64]"));
6544 auto update = AllOf(op::Copy(op::DynamicSlice(op::Parameter(2), op::Reshape(),
6545 op::Constant())),
6546 op::Shape("s32[64,2]"));
6547 EXPECT_THAT(root, AllOf(op::DynamicUpdateSlice(input, update, op::Constant(),
6548 op::Parameter(1)),
6549 op::Shape("s32[64,64]")));
6550 }
6551
6552 TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension) {
6553 absl::string_view hlo_string = R"(
6554 HloModule module
6555
6556 ENTRY entry {
6557 %input = s32[128,64] parameter(0), sharding={devices=[1,2]0,1}
6558 %index = s32[] parameter(1)
6559 %constant = s32[] constant(60)
6560 %update = s32[128,2] parameter(2), sharding={devices=[1,2]0,1}
6561 ROOT %dynamic-update-slice = s32[128,64]
6562 dynamic-update-slice(%input, %update, %index, %constant),
6563 sharding={devices=[1,2]0,1}
6564 })";
6565
6566 TF_ASSERT_OK_AND_ASSIGN(auto module,
6567 PartitionComputation(hlo_string, /*num_devices=*/2));
6568 VLOG(1) << module->ToString();
6569
6570 const auto root = module->entry_computation()->root_instruction();
6571 auto input = AllOf(op::Parameter(0), op::Shape("s32[128,32]"));
6572 auto update = AllOf(
6573 op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), op::Parameter(2),
6574 op::Constant(), op::Reshape())),
6575 op::Shape("s32[128,2]"));
6576
6577 EXPECT_THAT(root,
6578 AllOf(op::Select(op::Broadcast(),
6579 op::DynamicUpdateSlice(
6580 input, update, op::Constant(), op::Select()),
6581 input),
6582 op::Shape("s32[128,32]")));
6583 }
6584
6585 TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension2) {
6586 absl::string_view hlo_string = R"(
6587 HloModule module
6588
6589 ENTRY entry {
6590 %input = s32[8,790,2] parameter(0),
6591 sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
6592 %index = s32[] parameter(1)
6593 %constant = s32[] constant(0)
6594 %update = s32[1,790,2] parameter(2),
6595 sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
6596 ROOT %dynamic-update-slice = s32[8,790,2]
6597 dynamic-update-slice(%input, %update, %index, %constant, %constant),
6598 sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
6599 })";
6600
6601 TF_ASSERT_OK_AND_ASSIGN(auto module,
6602 PartitionComputation(hlo_string, /*num_devices=*/8));
6603 VLOG(1) << module->ToString();
6604
6605 const auto root = module->entry_computation()->root_instruction();
6606 auto input = AllOf(op::Parameter(0), op::Shape("s32[1,790,2]"));
6607 auto update = AllOf(op::AllReduce(op::Select(
6608 op::Broadcast(), op::Parameter(2), op::Broadcast())),
6609 op::Shape("s32[1,790,2]"));
6610 EXPECT_THAT(
6611 root,
6612 AllOf(op::Select(op::Broadcast(),
6613 op::DynamicUpdateSlice(input, update, op::Select(),
6614 op::Constant(), op::Constant()),
6615 input),
6616 op::Shape("s32[1,790,2]")));
6617 }
6618
6619 TEST_F(SpmdPartitioningTest, DynamicUpdateSlicePartitionSliceAndNonSliceDims) {
6620 absl::string_view hlo_string = R"(
6621 HloModule module
6622
6623 ENTRY entry {
6624 %input = s32[128,64] parameter(0)
6625 %input.copy = s32[128,64] copy(%input), sharding={devices=[2,2]0,1,2,3}
6626 %constant.0 = s32[] constant(0)
6627 %constant.1 = s32[] constant(60)
6628 %update = s32[128,2] parameter(1)
6629 %update.copy = s32[128,2] copy(%update), sharding={devices=[2,2]0,1,2,3}
6630 ROOT %dynamic-update-slice = s32[128,64]
6631 dynamic-update-slice(%input.copy, %update.copy, %constant.0, %constant.1),
6632 sharding={devices=[2,2]0,1,2,3}
6633 })";
6634
6635 TF_ASSERT_OK_AND_ASSIGN(auto module,
6636 PartitionComputation(hlo_string, /*num_devices=*/4));
6637 VLOG(1) << module->ToString();
6638
6639 const auto root = module->entry_computation()->root_instruction();
6640 auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
6641 op::Reshape())),
6642 op::Shape("s32[64,32]"));
6643 auto update = AllOf(op::AllReduce(op::DynamicUpdateSlice(
6644 op::Broadcast(),
6645 op::Copy(op::DynamicSlice(
6646 op::Parameter(1), op::Reshape(), op::Reshape())),
6647 op::Constant(), op::Reshape())),
6648 op::Shape("s32[64,2]"));
6649
6650 EXPECT_THAT(root,
6651 AllOf(op::Select(op::Broadcast(),
6652 op::DynamicUpdateSlice(
6653 input, update, op::Constant(), op::Select()),
6654 input),
6655 op::Shape("s32[64,32]")));
6656 }
6657
6658 TEST_F(SpmdPartitioningTest, PassthroughGather) {
6659 absl::string_view hlo_string = R"(
6660 HloModule module
6661
6662 ENTRY entry {
6663 %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
6664 %indices = s32[3] parameter(1), sharding={replicated}
6665 ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
6666 collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
6667 slice_sizes={1,9}, sharding={devices=[1,2]0,1}
6668 })";
6669 TF_ASSERT_OK_AND_ASSIGN(auto module,
6670 PartitionComputation(hlo_string, /*num_devices=*/2));
6671 VLOG(1) << module->ToString();
6672 HloInstruction* root = module->entry_computation()->root_instruction();
6673 EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
6674 op::Shape("f32[3,5]")));
6675 }
6676
6677 TEST_F(SpmdPartitioningTest, PassthroughGather_PartialReplicate) {
6678 absl::string_view hlo_string = R"(
6679 HloModule module
6680
6681 ENTRY entry {
6682 %input = f32[2,9] parameter(0),
6683 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6684 %indices = s32[3] parameter(1), sharding={replicated}
6685 ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
6686 collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
6687 slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6688 })";
6689 TF_ASSERT_OK_AND_ASSIGN(auto module,
6690 PartitionComputation(hlo_string, /*num_devices=*/4));
6691 VLOG(1) << module->ToString();
6692 HloInstruction* root = module->entry_computation()->root_instruction();
6693 EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
6694 op::Shape("f32[3,5]")));
6695 }
6696
6697 TEST_F(SpmdPartitioningTest, IndexPassthroughGather) {
6698 absl::string_view hlo_string = R"(
6699 HloModule module
6700
6701 ENTRY entry {
6702 %input = f32[2,9,8] parameter(0), sharding={replicated}
6703 %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
6704 ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
6705 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
6706 slice_sizes={1,1,8}, sharding={devices=[1,2,2]0,1,2,3}
6707 })";
6708 TF_ASSERT_OK_AND_ASSIGN(auto module,
6709 PartitionComputation(hlo_string, /*num_devices=*/4));
6710 VLOG(1) << module->ToString();
6711 HloInstruction* root = module->entry_computation()->root_instruction();
6712 EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
6713 op::Shape("f32[8,2,2]")));
6714 }
6715
6716 TEST_F(SpmdPartitioningTest, IndexPassthroughGather_PartialReplicate) {
6717 absl::string_view hlo_string = R"(
6718 HloModule module
6719
6720 ENTRY entry {
6721 %input = f32[2,9,8] parameter(0), sharding={replicated}
6722 %indices = s32[4,2,4] parameter(1),
6723 sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6724 ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
6725 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
6726 slice_sizes={1,1,8},
6727 sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6728 })";
6729 TF_ASSERT_OK_AND_ASSIGN(auto module,
6730 PartitionComputation(hlo_string, /*num_devices=*/8));
6731 VLOG(1) << module->ToString();
6732 HloInstruction* root = module->entry_computation()->root_instruction();
6733 EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
6734 op::Shape("f32[8,2,2]")));
6735 }
6736
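// Editorial note (added): when the operand is tiled along a trivially sliced
// dimension (slice size 1 on the indexed dimension), the expected lowering,
// as asserted below, offsets the indices by each partition's start, clamps
// them into the local shard, masks rows whose original indices fell outside
// that shard, and combines the partial results with an all-reduce.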
6737 TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) {
6738 absl::string_view hlo_string = R"(
6739 HloModule module
6740
6741 ENTRY entry {
6742 %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
6743 %indices = s32[2,3] parameter(1), sharding={replicated}
6744 ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
6745 collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
6746 slice_sizes={1,9}, sharding={replicated}
6747 })";
6748 TF_ASSERT_OK_AND_ASSIGN(auto module,
6749 PartitionComputation(hlo_string, /*num_devices=*/2));
6750 VLOG(1) << module->ToString();
6751 auto offset =
6752 op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
6753 auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"));
6754 auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())),
6755 op::Shape("s32[2,3]"));
6756 auto clamp = op::Clamp(min, op::Parameter(1), max);
6757 auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min));
6758 auto mask =
6759 op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max));
6760 auto masked =
6761 op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather);
6762 HloInstruction* root = module->entry_computation()->root_instruction();
6763 EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]")));
6764 }
6765
6766 TEST_F(SpmdPartitioningTest,
6767 GatherPartitionedOnTrivialSliceDims_PartialReplicate) {
6768 absl::string_view hlo_string = R"(
6769 HloModule module
6770
6771 ENTRY entry {
6772 %input = f32[17,9] parameter(0),
6773 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
6774 %indices = s32[2,3] parameter(1), sharding={replicated}
6775 ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
6776 collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
6777 slice_sizes={1,9}, sharding={replicated}
6778 })";
6779 TF_ASSERT_OK_AND_ASSIGN(auto module,
6780 PartitionComputation(hlo_string, /*num_devices=*/4));
6781 VLOG(1) << module->ToString();
6782 auto offset =
6783 op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
6784 auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"));
6785 auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())),
6786 op::Shape("s32[2,3]"));
6787 auto clamp = op::Clamp(min, op::Parameter(1), max);
6788 auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min));
6789 auto mask =
6790 op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max));
6791 auto masked =
6792 op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather);
6793 HloInstruction* root = module->entry_computation()->root_instruction();
6794 EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]")));
6795 }
6796
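// Editorial note (added): scatter analogues of the passthrough gather tests
// above. When inputs and updates share the same tiling along the update
// window dimension, the scatter (including the variadic, multi-result form)
// is expected to stay local to each partition, producing f32[2,5] shards.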
6797 TEST_F(SpmdPartitioningTest, PassthroughScatter) {
6798 absl::string_view hlo_string = R"(
6799 HloModule module
6800
6801 add (lhs: f32[], rhs: f32[]) -> f32[] {
6802 lhs = f32[] parameter(0)
6803 rhs = f32[] parameter(1)
6804 ROOT sum = f32[] add(lhs, rhs)
6805 }
6806
6807 ENTRY entry {
6808 %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
6809 %indices = s32[3] parameter(1), sharding={replicated}
6810 %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1}
6811 ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
6812 to_apply=add,
6813 update_window_dims={1},
6814 inserted_window_dims={0},
6815 scatter_dims_to_operand_dims={0},
6816 index_vector_dim=1, sharding={devices=[1,2]0,1}
6817 })";
6818 TF_ASSERT_OK_AND_ASSIGN(auto module,
6819 PartitionComputation(hlo_string, /*num_devices=*/2));
6820 VLOG(1) << module->ToString();
6821 HloInstruction* root = module->entry_computation()->root_instruction();
6822 EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
6823 op::Parameter(2)),
6824 op::Shape("f32[2,5]")));
6825 }
6826
6827 TEST_F(SpmdPartitioningTest, PassthroughScatterVariadic) {
6828 absl::string_view hlo_string = R"(
6829 HloModule module
6830
6831 add_min_max {
6832 lhs0 = f32[] parameter(0)
6833 lhs1 = f32[] parameter(1)
6834 rhs0 = f32[] parameter(2)
6835 rhs1 = f32[] parameter(3)
6836 min = minimum(rhs0, rhs1)
6837 max = maximum(rhs0, rhs1)
6838 min_sum = add(lhs0, min)
6839 max_sum = add(lhs1, max)
6840 ROOT tuple = tuple(min_sum, max_sum)
6841 }
6842
6843 ENTRY entry {
6844 %input0 = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
6845 %input1 = f32[2,9] parameter(1), sharding={devices=[1,2]0,1}
6846 %indices = s32[3] parameter(2), sharding={replicated}
6847 %updates0 = f32[3,9] parameter(3), sharding={devices=[1,2]0,1}
6848 %updates1 = f32[3,9] parameter(4), sharding={devices=[1,2]0,1}
6849 ROOT %scatter = (f32[2,9], f32[2,9])
6850 scatter(%input0, %input1, %indices, %updates0, %updates1),
6851 to_apply=add_min_max, update_window_dims={1}, inserted_window_dims={0},
6852 scatter_dims_to_operand_dims={0}, index_vector_dim=1,
6853 sharding={devices=[1,2]0,1}
6854 })";
6855 TF_ASSERT_OK_AND_ASSIGN(auto module,
6856 PartitionComputation(hlo_string, /*num_devices=*/2));
6857 HloInstruction* root = module->entry_computation()->root_instruction();
6858 EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
6859 op::Parameter(2), op::Parameter(3),
6860 op::Parameter(4)),
6861 op::Shape("(f32[2,5], f32[2,5])")));
6862 }
6863
6864 TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) {
6865 absl::string_view hlo_string = R"(
6866 HloModule module
6867
6868 add (lhs: f32[], rhs: f32[]) -> f32[] {
6869 lhs = f32[] parameter(0)
6870 rhs = f32[] parameter(1)
6871 ROOT sum = f32[] add(lhs, rhs)
6872 }
6873
6874 ENTRY entry {
6875 %input = f32[2,9] parameter(0),
6876 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6877 %indices = s32[3] parameter(1), sharding={replicated}
6878 %updates = f32[3,9] parameter(2),
6879 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6880 ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
6881 to_apply=add,
6882 update_window_dims={1},
6883 inserted_window_dims={0},
6884 scatter_dims_to_operand_dims={0},
6885 index_vector_dim=1,
6886 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6887 })";
6888 TF_ASSERT_OK_AND_ASSIGN(auto module,
6889 PartitionComputation(hlo_string, /*num_devices=*/4));
6890 VLOG(1) << module->ToString();
6891 HloInstruction* root = module->entry_computation()->root_instruction();
6892 EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
6893 op::Parameter(2)),
6894 op::Shape("f32[2,5]")));
6895 }
6896
6897 TEST_F(SpmdPartitioningTest, PassthroughScatterVariadic_PartialReplicate) {
6898 absl::string_view hlo_string = R"(
6899 HloModule module
6900
6901 add_min_max {
6902 lhs0 = f32[] parameter(0)
6903 lhs1 = f32[] parameter(1)
6904 rhs0 = f32[] parameter(2)
6905 rhs1 = f32[] parameter(3)
6906 min = minimum(rhs0, rhs1)
6907 max = maximum(rhs0, rhs1)
6908 min_sum = add(lhs0, min)
6909 max_sum = add(lhs1, max)
6910 ROOT tuple = tuple(min_sum, max_sum)
6911 }
6912
6913 ENTRY entry {
6914 %input0 = f32[2,9] parameter(0),
6915 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6916 %input1 = f32[2,9] parameter(1),
6917 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6918 %indices = s32[3] parameter(2), sharding={replicated}
6919 %updates0 = f32[3,9] parameter(3),
6920 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6921 %updates1 = f32[3,9] parameter(4),
6922 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6923 ROOT %scatter = (f32[2,9], f32[2,9])
6924 scatter(%input0, %input1, %indices, %updates0, %updates1),
6925 to_apply=add_min_max, update_window_dims={1}, inserted_window_dims={0},
6926 scatter_dims_to_operand_dims={0}, index_vector_dim=1,
6927 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
6928 })";
6929 TF_ASSERT_OK_AND_ASSIGN(auto module,
6930 PartitionComputation(hlo_string, /*num_devices=*/4));
6931 HloInstruction* root = module->entry_computation()->root_instruction();
6932 EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1),
6933 op::Parameter(2), op::Parameter(3),
6934 op::Parameter(4)),
6935 op::Shape("(f32[2,5], f32[2,5])")));
6936 }
6937
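// Editorial note (added): index-passthrough scatter tests. Indices and
// updates stay sharded while the replicated operand is masked with a
// broadcast constant (presumably the reduction's identity) on all but one
// partition, and the partial scatters are combined with nested all-reduces,
// as the matchers below expect.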
6938 TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) {
6939 absl::string_view hlo_string = R"(
6940 HloModule module
6941
6942 add (lhs: f32[], rhs: f32[]) -> f32[] {
6943 lhs = f32[] parameter(0)
6944 rhs = f32[] parameter(1)
6945 ROOT sum = f32[] add(lhs, rhs)
6946 }
6947
6948 ENTRY entry {
6949 %input = f32[2,9,8] parameter(0), sharding={replicated}
6950 %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
6951 %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3}
6952 ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
6953 to_apply=add,
6954 update_window_dims={2},
6955 inserted_window_dims={0,1},
6956 scatter_dims_to_operand_dims={0,1},
6957 index_vector_dim=1, sharding={replicated}
6958 })";
6959 TF_ASSERT_OK_AND_ASSIGN(auto module,
6960 PartitionComputation(hlo_string, /*num_devices=*/4));
6961 VLOG(1) << module->ToString();
6962 HloInstruction* root = module->entry_computation()->root_instruction();
6963 EXPECT_THAT(
6964 root,
6965 AllOf(op::AllReduce(op::AllReduce(op::Scatter(
6966 op::Select(op::Broadcast(op::Convert(op::PartitionId())),
6967 op::Broadcast(op::Constant()), op::Parameter(0)),
6968 op::Parameter(1), op::Parameter(2)))),
6969 op::Shape("f32[2,9,8]")));
6970 }
6971
6972 TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) {
6973 absl::string_view hlo_string = R"(
6974 HloModule module
6975
6976 add (lhs: f32[], rhs: f32[]) -> f32[] {
6977 lhs = f32[] parameter(0)
6978 rhs = f32[] parameter(1)
6979 ROOT sum = f32[] add(lhs, rhs)
6980 }
6981
6982 ENTRY entry {
6983 %input = f32[2,9,8] parameter(0), sharding={replicated}
6984 %indices = s32[4,2,4] parameter(1),
6985 sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6986 %updates = f32[4,4,8] parameter(2),
6987 sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
6988 ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
6989 to_apply=add,
6990 update_window_dims={2},
6991 inserted_window_dims={0,1},
6992 scatter_dims_to_operand_dims={0,1},
6993 index_vector_dim=1, sharding={replicated}
6994 })";
6995 TF_ASSERT_OK_AND_ASSIGN(auto module,
6996 PartitionComputation(hlo_string, /*num_devices=*/8));
6997 VLOG(1) << module->ToString();
6998 HloInstruction* root = module->entry_computation()->root_instruction();
6999 EXPECT_THAT(
7000 root,
7001 AllOf(op::AllReduce(op::AllReduce(op::Scatter(
7002 op::Select(op::Broadcast(op::Convert(op::Reshape())),
7003 op::Broadcast(op::Constant()), op::Parameter(0)),
7004 op::Parameter(1), op::Parameter(2)))),
7005 op::Shape("f32[2,9,8]")));
7006 }
7007
7008 TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) {
7009 absl::string_view hlo_string = R"(
7010 HloModule module
7011
7012 min (lhs: f32[], rhs: f32[]) -> f32[] {
7013 lhs = f32[] parameter(0)
7014 rhs = f32[] parameter(1)
7015 ROOT min = f32[] minimum(lhs, rhs)
7016 }
7017
7018 ENTRY entry {
7019 %input = f32[2,9,8] parameter(0), sharding={replicated}
7020 %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
7021 %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3}
7022 ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
7023 to_apply=min,
7024 update_window_dims={2},
7025 inserted_window_dims={0,1},
7026 scatter_dims_to_operand_dims={0,1},
7027 index_vector_dim=1, sharding={replicated}
7028 })";
7029 TF_ASSERT_OK_AND_ASSIGN(auto module,
7030 PartitionComputation(hlo_string, /*num_devices=*/4));
7031 VLOG(1) << module->ToString();
7032 HloInstruction* root = module->entry_computation()->root_instruction();
7033 EXPECT_THAT(
7034 root,
7035 AllOf(op::AllReduce(op::AllReduce(op::Scatter(
7036 op::Select(op::Broadcast(op::Convert(op::PartitionId())),
7037 op::Broadcast(op::Constant()), op::Parameter(0)),
7038 op::Parameter(1), op::Parameter(2)))),
7039 op::Shape("f32[2,9,8]")));
7040 }
7041
7042 TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) {
7043 absl::string_view hlo_string = R"(
7044 HloModule module
7045
7046 add (lhs: f32[], rhs: f32[]) -> f32[] {
7047 lhs = f32[] parameter(0)
7048 rhs = f32[] parameter(1)
7049 ROOT sum = f32[] add(lhs, rhs)
7050 }
7051
7052 ENTRY entry {
7053 %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
7054 %indices = s32[2,3] parameter(1), sharding={replicated}
7055 %updates = f32[2,3,9] parameter(2), sharding={replicated}
7056 ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
7057 to_apply=add,
7058 update_window_dims={2},
7059 inserted_window_dims={0},
7060 scatter_dims_to_operand_dims={0},
7061 index_vector_dim=2, sharding={devices=[2,1]0,1}
7062 })";
7063 TF_ASSERT_OK_AND_ASSIGN(auto module,
7064 PartitionComputation(hlo_string, /*num_devices=*/2));
7065 VLOG(1) << module->ToString();
7066 auto offset =
7067 op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
7068 auto indices = op::Subtract(
7069 op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
7070 HloInstruction* root = module->entry_computation()->root_instruction();
7071 EXPECT_THAT(root,
7072 AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)),
7073 op::Shape("f32[9,9]")));
7074 }
7075
7076 TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDimsVariadic) {
7077 absl::string_view hlo_string = R"(
7078 HloModule module
7079
7080 add_min_max {
7081 lhs0 = f32[] parameter(0)
7082 lhs1 = f32[] parameter(1)
7083 rhs0 = f32[] parameter(2)
7084 rhs1 = f32[] parameter(3)
7085 min = minimum(rhs0, rhs1)
7086 max = maximum(rhs0, rhs1)
7087 min_sum = add(lhs0, min)
7088 max_sum = add(lhs1, max)
7089 ROOT tuple = tuple(min_sum, max_sum)
7090 }
7091
7092 ENTRY entry {
7093 %input0 = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
7094 %input1 = f32[17,9] parameter(1), sharding={devices=[2,1]0,1}
7095 %indices = s32[2,3] parameter(2), sharding={replicated}
7096 %updates0 = f32[2,3,9] parameter(3), sharding={replicated}
7097 %updates1 = f32[2,3,9] parameter(4), sharding={replicated}
7098 ROOT %scatter = (f32[17,9], f32[17,9])
7099 scatter(%input0, %input1, %indices, %updates0, %updates1),
7100 to_apply=add_min_max, update_window_dims={2}, inserted_window_dims={0},
7101 scatter_dims_to_operand_dims={0}, index_vector_dim=2,
7102 sharding={devices=[2,1]0,1}
7103 })";
7104 TF_ASSERT_OK_AND_ASSIGN(auto module,
7105 PartitionComputation(hlo_string, /*num_devices=*/2));
7106 auto offset =
7107 op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
7108 auto indices = op::Subtract(
7109 op::Parameter(2), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
7110 HloInstruction* root = module->entry_computation()->root_instruction();
7111 EXPECT_THAT(root,
7112 AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), indices,
7113 op::Parameter(3), op::Parameter(4)),
7114 op::Shape("(f32[9,9], f32[9,9])")));
7115 }
7116
7117 TEST_F(SpmdPartitioningTest,
7118 ScatterPartitionedOnTrivialSliceDims_PartialReplicate) {
7119 absl::string_view hlo_string = R"(
7120 HloModule module
7121
7122 add (lhs: f32[], rhs: f32[]) -> f32[] {
7123 lhs = f32[] parameter(0)
7124 rhs = f32[] parameter(1)
7125 ROOT sum = f32[] add(lhs, rhs)
7126 }
7127
7128 ENTRY entry {
7129 %input = f32[17,9] parameter(0),
7130 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7131 %indices = s32[2,3] parameter(1), sharding={replicated}
7132 %updates = f32[2,3,9] parameter(2), sharding={replicated}
7133 ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
7134 to_apply=add,
7135 update_window_dims={2},
7136 inserted_window_dims={0},
7137 scatter_dims_to_operand_dims={0},
7138 index_vector_dim=2,
7139 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7140 })";
7141 TF_ASSERT_OK_AND_ASSIGN(auto module,
7142 PartitionComputation(hlo_string, /*num_devices=*/4));
7143 VLOG(1) << module->ToString();
7144 auto offset =
7145 op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
7146 auto indices = op::Subtract(
7147 op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
7148 HloInstruction* root = module->entry_computation()->root_instruction();
7149 EXPECT_THAT(root,
7150 AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)),
7151 op::Shape("f32[9,9]")));
7152 }
7153
7154 TEST_F(SpmdPartitioningTest,
7155 ScatterPartitionedOnTrivialSliceDimsVariadic_PartialReplicate) {
7156 absl::string_view hlo_string = R"(
7157 HloModule module
7158
7159 add_min_max {
7160 lhs0 = f32[] parameter(0)
7161 lhs1 = f32[] parameter(1)
7162 rhs0 = f32[] parameter(2)
7163 rhs1 = f32[] parameter(3)
7164 min = minimum(rhs0, rhs1)
7165 max = maximum(rhs0, rhs1)
7166 min_sum = add(lhs0, min)
7167 max_sum = add(lhs1, max)
7168 ROOT tuple = tuple(min_sum, max_sum)
7169 }
7170
7171 ENTRY entry {
7172 %input0 = f32[17,9] parameter(0),
7173 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7174 %input1 = f32[17,9] parameter(1),
7175 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7176 %indices = s32[2,3] parameter(2), sharding={replicated}
7177 %updates0 = f32[2,3,9] parameter(3), sharding={replicated}
7178 %updates1 = f32[2,3,9] parameter(4), sharding={replicated}
7179 ROOT %scatter = (f32[17,9], f32[17,9])
7180 scatter(%input0, %input1, %indices, %updates0, %updates1),
7181 to_apply=add_min_max, update_window_dims={2}, inserted_window_dims={0},
7182 scatter_dims_to_operand_dims={0}, index_vector_dim=2,
7183 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7184 })";
7185 TF_ASSERT_OK_AND_ASSIGN(auto module,
7186 PartitionComputation(hlo_string, /*num_devices=*/4));
7187 VLOG(1) << module->ToString();
7188 auto offset =
7189 op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
7190 auto indices = op::Subtract(
7191 op::Parameter(2), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")));
7192 HloInstruction* root = module->entry_computation()->root_instruction();
7193 EXPECT_THAT(root,
7194 AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), indices,
7195 op::Parameter(3), op::Parameter(4)),
7196 op::Shape("(f32[9,9], f32[9,9])")));
7197 }
7198
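// Editorial note (added): reverse tests. Reversing along an unpartitioned
// dimension passes through on the local shard; reversing along the
// partitioned dimension is expected to either reuse a reversed tile
// assignment, swap shards with a collective-permute, or do a halo exchange
// when the shard boundary does not line up, matching the four cases below.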
7199 TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
7200 absl::string_view hlo_string = R"(
7201 HloModule module
7202
7203 ENTRY entry {
7204 constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
7205 sharding={devices=[2,1]0,1}
7206 ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1},
7207 sharding={devices=[2,1]0,1}
7208 })";
7209 TF_ASSERT_OK_AND_ASSIGN(auto module,
7210 PartitionComputation(hlo_string, /*num_devices=*/2));
7211 VLOG(1) << module->ToString();
7212 HloInstruction* root = module->entry_computation()->root_instruction();
7213 EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"),
7214 op::Reverse(op::DynamicSlice(
7215 op::Pad(op::Constant(), op::Constant()),
7216 op::Reshape(), op::Constant()))));
7217 }
7218
7219 TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
7220 absl::string_view hlo_string = R"(
7221 HloModule module
7222
7223 ENTRY entry {
7224 param = f32[4] parameter(0), sharding={devices=[2]0,1}
7225 ROOT reverse = f32[4] reverse(param), dimensions={0},
7226 sharding={devices=[2]1,0}
7227 })";
7228 TF_ASSERT_OK_AND_ASSIGN(auto module,
7229 PartitionComputation(hlo_string, /*num_devices=*/2));
7230 VLOG(1) << module->ToString();
7231 HloInstruction* root = module->entry_computation()->root_instruction();
7232 EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0))));
7233 }
7234
7235 TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
7236 absl::string_view hlo_string = R"(
7237 HloModule module
7238
7239 ENTRY entry {
7240 param = f32[4] parameter(0), sharding={devices=[2]0,1}
7241 ROOT reverse = f32[4] reverse(param), dimensions={0},
7242 sharding={devices=[2]0,1}
7243 })";
7244 TF_ASSERT_OK_AND_ASSIGN(auto module,
7245 PartitionComputation(hlo_string, /*num_devices=*/2));
7246 VLOG(1) << module->ToString();
7247 HloInstruction* root = module->entry_computation()->root_instruction();
7248 EXPECT_THAT(root,
7249 AllOf(op::Shape("f32[2]"),
7250 op::Reverse(op::CollectivePermute(op::Parameter(0)))));
7251 }
7252
7253 TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
7254 absl::string_view hlo_string = R"(
7255 HloModule module
7256
7257 ENTRY entry {
7258 param = f32[3] parameter(0), sharding={devices=[2]0,1}
7259 ROOT reverse = f32[3] reverse(param), dimensions={0},
7260 sharding={devices=[2]1,0}
7261 })";
7262 TF_ASSERT_OK_AND_ASSIGN(auto module,
7263 PartitionComputation(hlo_string, /*num_devices=*/2));
7264 VLOG(1) << module->ToString();
7265 HloInstruction* root = module->entry_computation()->root_instruction();
7266 auto halo_exchange_concat =
7267 op::Concatenate(AllOf(op::Shape("f32[1]"),
7268 op::CollectivePermute(op::Slice(op::Parameter(0)))),
7269 op::Slice(op::Parameter(0)));
7270 EXPECT_THAT(root,
7271 AllOf(op::Shape("f32[2]"), op::Reverse(halo_exchange_concat)));
7272 }
7273
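// Editorial note (added): mixing SPMD with manual partitioning. The
// SPMDFullToShardShape and SPMDShardToFullShape custom calls are expected to
// lower to plain copies, so the manually partitioned region computes directly
// on shard-shaped values, as the matchers below check.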
7274 TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
7275 absl::string_view hlo_string = R"(
7276 HloModule module
7277
7278 ENTRY entry {
7279 param = (f32[8,2], f32[4,2]) parameter(0), sharding={{devices=[2,1]0,1},{manual}}
7280 param0 = f32[8,2] get-tuple-element(param), index=0, sharding={devices=[2,1]0,1}
7281 param1 = f32[4,2] get-tuple-element(param), index=1, sharding={manual}
7282 to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual}
7283 add = f32[4,2] add(to_shard, param1), sharding={manual}
7284 to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
7285 mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
7286 to_shard2 = f32[4,2] custom-call(mul), custom_call_target="SPMDFullToShardShape", sharding={manual}
7287 ROOT tuple = (f32[4,2]) tuple(to_shard2), sharding={{manual}}
7288 })";
7289 TF_ASSERT_OK_AND_ASSIGN(auto module,
7290 PartitionComputation(hlo_string, /*num_devices=*/2));
7291 VLOG(1) << module->ToString();
7292 HloInstruction* root = module->entry_computation()->root_instruction();
7293 auto p0 = op::GetTupleElement(op::Parameter(0));
7294 auto to_shard = op::Copy(p0);
7295 auto p1 = op::GetTupleElement(op::Parameter(0));
7296 auto mul = AllOf(op::Shape("f32[4,2]"),
7297 op::Multiply(op::Copy(op::Add(to_shard, p1)), p0));
7298 EXPECT_THAT(root, op::Tuple(op::Copy(mul)));
7299 }
7300
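// Editorial note (added): subgroup all-to-all reshard tests. Moving a tiled
// dimension between positions in the device mesh is expected to lower to
// reshape -> subgrouped all-to-all -> transpose -> reshape, with a
// collective-permute appended when the device order also changes, which is
// the pattern the next three tests assert.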
7301 TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
7302 absl::string_view hlo_string = R"(
7303 HloModule module
7304
7305 ENTRY entry {
7306 %param0 = f32[8,8,8,8] parameter(0),
7307 sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7}
7308 ROOT %copy = f32[8,8,8,8] copy(%param0),
7309 sharding={devices=[1,2,2,2]0,1,4,5,2,3,6,7}
7310 })";
7311
7312 TF_ASSERT_OK_AND_ASSIGN(auto module,
7313 PartitionComputation(hlo_string, /*num_devices=*/8));
7314 VLOG(1) << module->ToString();
7315
7316 const auto root = module->entry_computation()->root_instruction();
7317 auto reshape =
7318 AllOf(op::Shape("f32[4,4,2,4,4]"), op::Reshape(op::Parameter(0)));
7319 auto all_to_all = AllOf(op::Shape("f32[4,4,2,4,4]"), op::AllToAll(reshape));
7320 auto xpose = AllOf(op::Shape("f32[2,4,4,4,4]"), op::Transpose(all_to_all));
7321 EXPECT_THAT(root,
7322 op::Copy(AllOf(op::Reshape(xpose), op::Shape("f32[8,4,4,4]"))));
7323 EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->replica_groups().size(),
7324 4);
7325 }
7326
7327 TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard2) {
7328 absl::string_view hlo_string = R"(
7329 HloModule module
7330
7331 ENTRY entry {
7332 %param0 = f32[8,8] parameter(0),
7333 sharding={devices=[2,4]0,1,2,3,4,5,6,7}
7334 ROOT %copy = f32[8,8] copy(%param0),
7335 sharding={devices=[4,2]0,1,4,5,2,3,6,7}
7336 })";
7337
7338 TF_ASSERT_OK_AND_ASSIGN(auto module,
7339 PartitionComputation(hlo_string, /*num_devices=*/8));
7340 VLOG(1) << module->ToString();
7341
7342 const auto root = module->entry_computation()->root_instruction();
7343 auto all_to_all = op::AllToAll(
7344 AllOf(op::Shape("f32[2,2,2]"), op::Reshape(op::Parameter(0))));
7345 auto reshape =
7346 AllOf(op::Shape("f32[2,4]"), op::Reshape(op::Transpose(all_to_all)));
7347 EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape)));
7348 }
7349
7350 TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard3) {
7351 absl::string_view hlo_string = R"(
7352 HloModule module
7353
7354 ENTRY entry {
7355 %param0 = f32[8,8,8] parameter(0),
7356 sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
7357 ROOT %copy = f32[8,8,8] copy(%param0),
7358 sharding={devices=[1,2,4]0,1,4,5,2,3,6,7}
7359 })";
7360
7361 TF_ASSERT_OK_AND_ASSIGN(auto module,
7362 PartitionComputation(hlo_string, /*num_devices=*/8));
7363 VLOG(1) << module->ToString();
7364
7365 const auto root = module->entry_computation()->root_instruction();
7366 auto all_to_all = op::AllToAll(
7367 AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(op::Parameter(0))));
7368 auto reshape =
7369 AllOf(op::Shape("f32[4,8,2]"), op::Reshape(op::Transpose(all_to_all)));
7370 auto all_to_all2 =
7371 op::AllToAll(AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(reshape)));
7372 auto reshape2 =
7373 AllOf(op::Shape("f32[8,4,2]"), op::Reshape(op::Transpose(all_to_all2)));
7374 EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2)));
7375 }
7376
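// Editorial note (added): the dot tests below cover 2D-mesh partitionings of
// the contracting, non-contracting, and batch dimensions. Depending on how
// operand and output shardings line up, the expected lowerings reshard one
// operand (all-reduce of a dynamic-update-slice, all-to-all, or
// collective-permute) and/or combine partial products with an all-reduce
// followed by a dynamic-slice.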
7377 TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting0) {
7378 absl::string_view hlo_string = R"(
7379 HloModule module
7380
7381 ENTRY entry {
7382 %lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3}
7383 %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,2,1,3}
7384 ROOT %dot = f32[48,32] dot(%lhs, %rhs),
7385 lhs_batch_dims={}, rhs_batch_dims={},
7386 lhs_contracting_dims={1}, rhs_contracting_dims={1},
7387 sharding={devices=[2,2]0,1,2,3}
7388 })";
7389
7390 TF_ASSERT_OK_AND_ASSIGN(auto module,
7391 PartitionComputation(hlo_string, /*num_devices=*/4));
7392 VLOG(1) << module->ToString();
7393
7394 const auto lhs = AllOf(op::Shape("f32[24,6]"), op::Parameter(0));
7395 auto partial_replicated_lhs =
7396 AllOf(op::Shape("f32[24,12]"),
7397 op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _)));
7398 const auto rhs = AllOf(op::Shape("f32[16,6]"), op::Parameter(1));
7399 auto partial_replicated_rhs =
7400 AllOf(op::Shape("f32[16,12]"),
7401 op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)));
7402 const auto root = module->entry_computation()->root_instruction();
7403 EXPECT_THAT(root,
7404 AllOf(op::Dot(partial_replicated_lhs, partial_replicated_rhs),
7405 op::Shape("f32[24,16]")));
7406 }
7407
7408 TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting1) {
7409 absl::string_view hlo_string = R"(
7410 HloModule module
7411
7412 ENTRY entry {
7413 %lhs = f32[48,100] parameter(0), sharding={devices=[2,2]0,1,2,3}
7414 %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
7415 ROOT %dot = f32[48,32] dot(%lhs, %rhs),
7416 lhs_batch_dims={}, rhs_batch_dims={},
7417 lhs_contracting_dims={1}, rhs_contracting_dims={1},
7418 sharding={devices=[2,2]0,1,2,3}
7419 })";
7420
7421 TF_ASSERT_OK_AND_ASSIGN(auto module,
7422 PartitionComputation(hlo_string, /*num_devices=*/4));
7423 VLOG(1) << module->ToString();
7424
7425 const auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
7426 const auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
7427 auto partial_replicated_rhs =
7428 AllOf(op::Shape("f32[32,50]"),
7429 op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)));
7430 const auto root = module->entry_computation()->root_instruction();
7431 EXPECT_THAT(
7432 root, AllOf(op::Shape("f32[24,16]"),
7433 op::DynamicSlice(
7434 op::AllReduce(AllOf(op::Dot(lhs, partial_replicated_rhs),
7435 op::Shape("f32[24,32]"))),
7436 _, _)));
7437 }
7438
7439 TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting2) {
7440 absl::string_view hlo_string = R"(
7441 HloModule module
7442
7443 ENTRY entry {
7444 %lhs = f32[48,100] parameter(0), sharding={replicated}
7445 %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
7446 ROOT %dot = f32[48,32] dot(%lhs, %rhs),
7447 lhs_batch_dims={}, rhs_batch_dims={},
7448 lhs_contracting_dims={1}, rhs_contracting_dims={1},
7449 sharding={devices=[2,2]0,1,2,3}
7450 })";
7451
7452 TF_ASSERT_OK_AND_ASSIGN(auto module,
7453 PartitionComputation(hlo_string, /*num_devices=*/4));
7454 VLOG(1) << module->ToString();
7455
7456 const auto lhs = AllOf(op::Shape("f32[48,100]"), op::Parameter(0));
7457 const auto lhs_slice =
7458 AllOf(op::Shape("f32[24,100]"), op::DynamicSlice(lhs, _, _));
7459 const auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
7460 auto partial_replicated_rhs = AllOf(
7461 op::Shape("f32[16,100]"), op::AllReduce(op::DynamicUpdateSlice(
7462 _, op::CollectivePermute(rhs), _, _)));
7463 const auto root = module->entry_computation()->root_instruction();
7464 EXPECT_THAT(root, AllOf(op::Shape("f32[24,16]"),
7465 op::Dot(lhs_slice, partial_replicated_rhs)));
7466 }
7467
7468 TEST_F(SpmdPartitioningTest, Dot2DPartitionedNoncontractingAndContracting3) {
7469 absl::string_view hlo_string = R"(
7470 HloModule module
7471
7472 ENTRY entry {
7473 %lhs = f32[23,24] parameter(0), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
7474 %rhs = f32[23,32] parameter(1), sharding={devices=[2,2]0,1,2,3}
7475 ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7476 lhs_contracting_dims={0}, rhs_contracting_dims={0},
7477 sharding={devices=[2,2]1,0,3,2}
7478 })";
7479
7480 TF_ASSERT_OK_AND_ASSIGN(auto module,
7481 PartitionComputation(hlo_string, /*num_devices=*/4));
7482 VLOG(1) << module->ToString();
7483
7484 const auto lhs = AllOf(op::Shape("f32[12,24]"), op::Parameter(0));
7485 auto masked_lhs = op::Select(_, lhs, op::Broadcast(op::Constant()));
7486 const auto rhs = AllOf(op::Shape("f32[12,16]"), op::Parameter(1));
7487 auto masked_rhs = op::Select(_, rhs, op::Broadcast(op::Constant()));
7488 const auto root = module->entry_computation()->root_instruction();
7489 EXPECT_THAT(root,
7490 AllOf(op::Shape("f32[12,16]"),
7491 op::DynamicSlice(
7492 AllOf(op::Shape("f32[24,16]"),
7493 op::AllReduce(op::Dot(masked_lhs, masked_rhs))),
7494 _, _)));
7495 }
7496
7497 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
7498 absl::string_view hlo_string = R"(
7499 HloModule module
7500
7501 ENTRY entry {
7502 %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
7503 %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
7504 ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7505 lhs_batch_dims={0}, rhs_batch_dims={0},
7506 lhs_contracting_dims={2}, rhs_contracting_dims={2},
7507 sharding={devices=[2,2,1]0,1,2,3}
7508 })";
7509
7510 TF_ASSERT_OK_AND_ASSIGN(auto module,
7511 PartitionComputation(hlo_string, /*num_devices=*/4));
7512 VLOG(1) << module->ToString();
7513
7514 const auto lhs = AllOf(op::Shape("f32[2,12,100]"), op::Parameter(0));
7515 const auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
7516 auto partial_replicated_rhs =
7517 AllOf(op::Shape("f32[2,32,100]"),
7518 op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _)));
7519 const auto root = module->entry_computation()->root_instruction();
7520 EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
7521 op::Dot(lhs, partial_replicated_rhs)));
7522 }
7523
7524 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting) {
7525 absl::string_view hlo_string = R"(
7526 HloModule module
7527
7528 ENTRY entry {
7529 %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
7530 %rhs = f32[4,32,100] parameter(1), sharding={devices=[1,2,2]0,1,2,3}
7531 ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7532 lhs_batch_dims={0}, rhs_batch_dims={0},
7533 lhs_contracting_dims={2}, rhs_contracting_dims={2},
7534 sharding={devices=[2,2,1]0,1,2,3}
7535 })";
7536
7537 TF_ASSERT_OK_AND_ASSIGN(auto module,
7538 PartitionComputation(hlo_string, /*num_devices=*/4));
7539 VLOG(1) << module->ToString();
7540
7541 const auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
7542 const auto rhs = AllOf(op::Shape("f32[4,16,50]"), op::Parameter(1));
7543 auto resharded_rhs =
7544 AllOf(op::Shape("f32[2,32,50]"),
7545 op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs)))));
7546 const auto root = module->entry_computation()->root_instruction();
7547 EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
7548 op::DynamicSlice(
7549 AllOf(op::Shape("f32[2,24,32]"),
7550 op::AllReduce(op::Dot(lhs, resharded_rhs))),
7551 _, _, _)));
7552 }
7553
7554 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting2) {
7555 absl::string_view hlo_string = R"(
7556 HloModule module
7557
7558 ENTRY entry {
7559 %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
7560 %rhs = f32[4,32,100] parameter(1), sharding={replicated}
7561 ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7562 lhs_batch_dims={0}, rhs_batch_dims={0},
7563 lhs_contracting_dims={2}, rhs_contracting_dims={2},
7564 sharding={devices=[2,2,1]0,1,2,3}
7565 })";
7566
7567 TF_ASSERT_OK_AND_ASSIGN(auto module,
7568 PartitionComputation(hlo_string, /*num_devices=*/4));
7569 VLOG(1) << module->ToString();
7570
7571 const auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
7572 auto resharded_lhs =
7573 AllOf(op::Shape("f32[2,12,100]"),
7574 op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs)))));
7575 const auto rhs = AllOf(op::Shape("f32[4,32,100]"), op::Parameter(1));
7576 const auto rhs_slice =
7577 AllOf(op::Shape("f32[2,32,100]"), op::DynamicSlice(rhs, _, _, _));
7578 const auto root = module->entry_computation()->root_instruction();
7579 EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
7580 op::Dot(resharded_lhs, rhs_slice)));
7581 }
7582
7583 TEST_F(SpmdPartitioningTest,
7584 Dot2DPartitionedBatchNonContractingAndContracting) {
7585 absl::string_view hlo_string = R"(
7586 HloModule module
7587
7588 ENTRY entry {
7589 %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
7590 %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
7591 ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7592 lhs_batch_dims={0}, rhs_batch_dims={0},
7593 lhs_contracting_dims={2}, rhs_contracting_dims={2},
7594 sharding={devices=[2,1,2]0,1,2,3}
7595 })";
7596
7597 TF_ASSERT_OK_AND_ASSIGN(auto module,
7598 PartitionComputation(hlo_string, /*num_devices=*/4));
7599 VLOG(1) << module->ToString();
7600
7601 const auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
7602 const auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
7603 auto partial_replicated_lhs =
7604 AllOf(op::Shape("f32[2,24,100]"),
7605 op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _, _)));
7606 const auto root = module->entry_computation()->root_instruction();
7607 EXPECT_THAT(root, AllOf(op::Shape("f32[2,24,16]"),
7608 op::Dot(partial_replicated_lhs, rhs)));
7609 }
7610
7611 TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndReshard) {
7612 absl::string_view hlo_string = R"(
7613 HloModule module
7614
7615 ENTRY entry {
7616 %lhs = f32[4,8,24,100] parameter(0), sharding={devices=[2,1,2,1]0,1,2,3}
7617 %rhs = f32[4,8,32,100] parameter(1), sharding={devices=[2,1,2,1]0,1,2,3}
7618 ROOT %dot = f32[4,8,24,32] dot(%lhs, %rhs),
7619 lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
7620 lhs_contracting_dims={3}, rhs_contracting_dims={3},
7621 sharding={devices=[1,2,2,1]0,1,2,3}
7622 })";
7623
7624 TF_ASSERT_OK_AND_ASSIGN(auto module,
7625 PartitionComputation(hlo_string, /*num_devices=*/4));
7626 VLOG(1) << module->ToString();
7627
7628 const auto lhs = AllOf(op::Shape("f32[2,8,12,100]"), op::Parameter(0));
7629 const auto rhs = AllOf(op::Shape("f32[2,8,16,100]"), op::Parameter(1));
7630 auto partial_replicated_rhs =
7631 AllOf(op::Shape("f32[2,8,32,100]"),
7632 op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _, _)));
7633 auto dot =
7634 AllOf(op::Shape("f32[2,8,12,32]"), op::Dot(lhs, partial_replicated_rhs));
7635 auto reshape = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Reshape(dot));
7636 auto all_to_all = AllOf(op::Shape("f32[2,2,4,12,32]"), op::AllToAll(reshape));
7637 auto xpose = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Transpose(all_to_all));
7638 const auto root = module->entry_computation()->root_instruction();
7639 EXPECT_THAT(root, AllOf(op::Shape("f32[4,4,12,32]"), op::Reshape(xpose)));
7640 }
7641
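// Editorial note (added): dots with partially replicated operands
// (last_tile_dim_replicate). When the partial replication already matches,
// the dot is expected to run locally; when the contracting dimension is
// partially tiled, partial results are combined with an all-reduce,
// optionally followed by a dynamic-slice or collective-permute to reach the
// requested output sharding.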
7642 TEST_F(SpmdPartitioningTest, SimpleDotPartial) {
7643 absl::string_view hlo_string = R"(
7644 HloModule module
7645
7646 ENTRY entry {
7647 %lhs = f32[2,24,100] parameter(0),
7648 sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
7649 %rhs = f32[2,32,100] parameter(1),
7650 sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
7651 ROOT %dot = f32[2,24,32] dot(%lhs, %rhs),
7652 lhs_batch_dims={0}, rhs_batch_dims={0},
7653 lhs_contracting_dims={2}, rhs_contracting_dims={2},
7654 sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
7655 })";
7656
7657 TF_ASSERT_OK_AND_ASSIGN(auto module,
7658 PartitionComputation(hlo_string, /*num_devices=*/4));
7659 VLOG(1) << module->ToString();
7660
7661 const auto lhs = AllOf(op::Shape("f32[1,24,100]"), op::Parameter(0));
7662 const auto rhs = AllOf(op::Shape("f32[1,32,100]"), op::Parameter(1));
7663 auto dot = AllOf(op::Shape("f32[1,24,32]"), op::Dot(lhs, rhs));
7664 const auto root = module->entry_computation()->root_instruction();
7665 EXPECT_THAT(root, dot);
7666 }
7667
7668 TEST_F(SpmdPartitioningTest, DotPartialContracting) {
7669 absl::string_view hlo_string = R"(
7670 HloModule module
7671
7672 ENTRY entry {
7673 %lhs = f32[24,100] parameter(0),
7674 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
7675 %rhs = f32[32,100] parameter(1),
7676 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
7677 ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7678 lhs_batch_dims={}, rhs_batch_dims={},
7679 lhs_contracting_dims={1}, rhs_contracting_dims={1},
7680 sharding={replicated}
7681 })";
7682
7683 TF_ASSERT_OK_AND_ASSIGN(auto module,
7684 PartitionComputation(hlo_string, /*num_devices=*/4));
7685 VLOG(1) << module->ToString();
7686
7687 const auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
7688 const auto rhs = AllOf(op::Shape("f32[32,50]"), op::Parameter(1));
7689 auto dot = AllOf(op::Shape("f32[24,32]"), op::Dot(lhs, rhs));
7690 const auto root = module->entry_computation()->root_instruction();
7691 EXPECT_THAT(root, op::AllReduce(dot));
7692 }
7693
7694 TEST_F(SpmdPartitioningTest, DotPartialContracting2) {
7695 absl::string_view hlo_string = R"(
7696 HloModule module
7697
7698 ENTRY entry {
7699 %lhs = f32[24,100] parameter(0),
7700 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
7701 %rhs = f32[32,100] parameter(1),
7702 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
7703 ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7704 lhs_batch_dims={}, rhs_batch_dims={},
7705 lhs_contracting_dims={1}, rhs_contracting_dims={1},
7706 sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
7707 })";
7708
7709 TF_ASSERT_OK_AND_ASSIGN(auto module,
7710 PartitionComputation(hlo_string, /*num_devices=*/4));
7711 VLOG(1) << module->ToString();
7712
7713 const auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
7714 const auto rhs = AllOf(op::Shape("f32[32,50]"), op::Parameter(1));
7715 auto dot =
7716 AllOf(op::Shape("f32[12,32]"),
7717 op::Dot(AllOf(op::Shape("f32[12,50]"), op::DynamicSlice(lhs, _, _)),
7718 rhs));
7719 const auto root = module->entry_computation()->root_instruction();
7720 EXPECT_THAT(root, op::AllReduce(dot));
7721 }
7722
7723 TEST_F(SpmdPartitioningTest, DotPartialContracting3) {
7724 absl::string_view hlo_string = R"(
7725 HloModule module
7726
7727 ENTRY entry {
7728 %lhs = f32[24,100] parameter(0),
7729 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
7730 %rhs = f32[32,100] parameter(1),
7731 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
7732 ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7733 lhs_batch_dims={}, rhs_batch_dims={},
7734 lhs_contracting_dims={1}, rhs_contracting_dims={1},
7735 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
7736 })";
7737
7738 TF_ASSERT_OK_AND_ASSIGN(auto module,
7739 PartitionComputation(hlo_string, /*num_devices=*/8));
7740 VLOG(1) << module->ToString();
7741
7742 const auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
7743 const auto rhs =
7744 AllOf(op::Shape("f32[16,50]"), op::DynamicSlice(op::Parameter(1), _, _));
7745 auto dot = AllOf(op::Shape("f32[24,16]"), op::Dot(lhs, rhs));
7746 const auto root = module->entry_computation()->root_instruction();
7747 EXPECT_THAT(root, op::CollectivePermute(op::AllReduce(dot)));
7748 }
7749
7750 TEST_F(SpmdPartitioningTest, DotBatchAndPartialContracting) {
7751 absl::string_view hlo_string = R"(
7752 HloModule module
7753
7754 ENTRY entry {
7755 %lhs = f32[4,24,100] parameter(0),
7756 sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}
7757 %rhs = f32[4,32,100] parameter(1),
7758 sharding={devices=[2,1,2,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}
7759 ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
7760 lhs_batch_dims={0}, rhs_batch_dims={0},
7761 lhs_contracting_dims={2}, rhs_contracting_dims={2},
7762 sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
7763 })";
7764
7765 TF_ASSERT_OK_AND_ASSIGN(auto module,
7766 PartitionComputation(hlo_string, /*num_devices=*/8));
7767 VLOG(1) << module->ToString();
7768
7769 const auto lhs = AllOf(op::Shape("f32[2,12,50]"), op::Parameter(0));
7770 const auto rhs = AllOf(op::Shape("f32[2,32,50]"), op::Parameter(1));
7771 auto dot = AllOf(op::Shape("f32[2,12,32]"), op::Dot(lhs, rhs));
7772 const auto root = module->entry_computation()->root_instruction();
7773 EXPECT_THAT(root, op::AllReduce(dot));
7774 }
7775
7776 TEST_F(SpmdPartitioningTest, DotPartialNonContracting) {
7777 absl::string_view hlo_string = R"(
7778 HloModule module
7779
7780 ENTRY entry {
7781 %lhs = f32[24,8,100] parameter(0),
7782 sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
7783 %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,2,1,3}
7784 ROOT %dot = f32[24,8,32] dot(%lhs, %rhs),
7785 lhs_batch_dims={}, rhs_batch_dims={},
7786 lhs_contracting_dims={2}, rhs_contracting_dims={1},
7787 sharding={devices=[2,1,2]0,1,2,3}
7788 })";
7789
7790 TF_ASSERT_OK_AND_ASSIGN(auto module,
7791 PartitionComputation(hlo_string, /*num_devices=*/4));
7792 VLOG(1) << module->ToString();
7793
7794 const auto lhs = AllOf(op::Shape("f32[12,8,100]"), op::Parameter(0));
7795 const auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
7796 auto partially_replicated_rhs =
7797 AllOf(op::Shape("f32[16,100]"),
7798 op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), rhs, _, _)));
7799 auto dot =
7800 AllOf(op::Shape("f32[12,8,16]"), op::Dot(lhs, partially_replicated_rhs));
7801 const auto root = module->entry_computation()->root_instruction();
7802 EXPECT_THAT(root, dot);
7803 }
7804
7805 TEST_F(SpmdPartitioningTest, DotPartialNonContractingPartialMatch) {
7806 absl::string_view hlo_string = R"(
7807 HloModule module
7808
7809 ENTRY entry {
7810 %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
7811 %rhs = f32[32,100] parameter(1),
7812 sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
7813 ROOT %dot = f32[24,8,32] dot(%lhs, %rhs),
7814 lhs_batch_dims={}, rhs_batch_dims={},
7815 lhs_contracting_dims={2}, rhs_contracting_dims={1},
7816 sharding={devices=[2,1,2]0,1,2,3}
7817 })";
7818
7819 TF_ASSERT_OK_AND_ASSIGN(auto module,
7820 PartitionComputation(hlo_string, /*num_devices=*/4));
7821 VLOG(1) << module->ToString();
7822
7823 const auto lhs = AllOf(op::Shape("f32[12,4,100]"), op::Parameter(0));
7824 const auto rhs = AllOf(op::Shape("f32[16,100]"), op::Parameter(1));
7825 auto partially_replicated_lhs = AllOf(
7826 op::Shape("f32[12,8,100]"),
7827 op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), lhs, _, _, _)));
7828 auto dot =
7829 AllOf(op::Shape("f32[12,8,16]"), op::Dot(partially_replicated_lhs, rhs));
7830 const auto root = module->entry_computation()->root_instruction();
7831 EXPECT_THAT(root, dot);
7832 }
7833
7834 TEST_F(SpmdPartitioningTest, DotPartialContractingPartialMatch) {
7835 absl::string_view hlo_string = R"(
7836 HloModule module
7837
7838 ENTRY entry {
7839 %lhs = f32[24,8,100] parameter(0), sharding={devices=[1,2,2]0,1,2,3}
7840 %rhs = f32[32,8,100] parameter(1),
7841 sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}
7842 ROOT %dot = f32[24,32] dot(%lhs, %rhs),
7843 lhs_batch_dims={}, rhs_batch_dims={},
7844 lhs_contracting_dims={1,2}, rhs_contracting_dims={1,2},
7845 sharding={replicated}
7846 })";
7847
7848 TF_ASSERT_OK_AND_ASSIGN(auto module,
7849 PartitionComputation(hlo_string, /*num_devices=*/4));
7850 VLOG(1) << module->ToString();
7851
7852 const auto lhs = AllOf(op::Shape("f32[24,4,50]"), op::Parameter(0));
7853 const auto rhs = AllOf(op::Shape("f32[32,8,50]"), op::Parameter(1));
7854 auto dot = AllOf(op::Shape("f32[24,32]"),
7855 op::Dot(lhs, AllOf(op::Shape("f32[32,4,50]"),
7856 op::DynamicSlice(rhs, _, _, _))));
7857 const auto root = module->entry_computation()->root_instruction();
7858 EXPECT_THAT(root, op::AllReduce(op::AllReduce(dot)));
7859 }
7860
7861 TEST_F(SpmdPartitioningTest, DotNonContractingPartialMatchContractingMatch) {
7862 absl::string_view hlo_string = R"(
7863 HloModule module
7864
7865 ENTRY entry {
7866 %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
7867 %rhs = f32[100,50] parameter(1), sharding={devices=[2,2]0,2,1,3}
7868 ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
7869 lhs_batch_dims={}, rhs_batch_dims={},
7870 lhs_contracting_dims={2}, rhs_contracting_dims={0},
7871 sharding={devices=[2,2,1]0,1,2,3}
7872 })";
7873
7874 TF_ASSERT_OK_AND_ASSIGN(auto module,
7875 PartitionComputation(hlo_string, /*num_devices=*/4));
7876 VLOG(1) << module->ToString();
7877
7878 const auto lhs = AllOf(op::Shape("f32[12,8,50]"), op::Parameter(0));
7879 const auto rhs = AllOf(op::Shape("f32[50,25]"), op::Parameter(1));
7880 auto dot = AllOf(
7881 op::Shape("f32[12,8,50]"),
7882 op::Dot(lhs, AllOf(op::Shape("f32[50,50]"),
7883 op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)))));
7884 const auto root = module->entry_computation()->root_instruction();
7885 EXPECT_THAT(root, AllOf(op::Shape("f32[12,4,50]"),
7886 op::DynamicSlice(op::AllReduce(dot), _, _, _)))
7887 << module->ToString();
7888 }
7889
7890 TEST_F(SpmdPartitioningTest, DotLHSMutiNonContractingRHSNotMatch) {
7891 absl::string_view hlo_string = R"(
7892 HloModule module
7893
7894 ENTRY entry {
7895 %lhs = f32[24,8,10] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
7896 %rhs = f32[10,50] parameter(1),
7897 sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
7898 ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
7899 lhs_batch_dims={}, rhs_batch_dims={},
7900 lhs_contracting_dims={2}, rhs_contracting_dims={0},
7901 sharding={devices=[2,2,1]0,1,2,3}
7902 })";
7903
7904 TF_ASSERT_OK_AND_ASSIGN(auto module,
7905 PartitionComputation(hlo_string, /*num_devices=*/4));
7906 VLOG(1) << module->ToString();
7907
7908 const auto lhs = AllOf(op::Shape("f32[12,4,10]"), op::Parameter(0));
7909 const auto rhs = AllOf(op::Shape("f32[5,50]"), op::Parameter(1));
7910 auto dot = AllOf(
7911 op::Shape("f32[12,4,50]"),
7912 op::Dot(lhs, AllOf(op::Shape("f32[10,50]"),
7913 op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)))));
7914 const auto root = module->entry_computation()->root_instruction();
7915 EXPECT_THAT(root, dot) << module->ToString();
7916 }
7917
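// Editorial note (added): elementwise ops under subgroup (manual /
// replicated last-tile) shardings. Resharding between the tiled and
// replicated subgroups is expected to show up as dynamic-slice/pad on the
// tiled side and an all-reduce of a dynamic-update-slice (plus a halo
// exchange in the uneven case) on the replicated side, while the manual
// subgroup dimension is left untouched.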
7918 TEST_F(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) {
7919 absl::string_view hlo_string = R"(
7920 HloModule module
7921
7922 ENTRY entry {
7923 constant = f32[6,3]{1,0}
7924 constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
7925 sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
7926 constant.1 = f32[6,3]{1,0}
7927 constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
7928 sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
7929 multiply = f32[6,3]{1,0} multiply(constant, constant.1),
7930 sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
7931 ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
7932 sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated, manual}}
7933 }
7934 )";
7935
7936 TF_ASSERT_OK_AND_ASSIGN(auto module,
7937 PartitionComputation(hlo_string, /*num_devices=*/4));
7938 VLOG(1) << module->ToString();
7939
7940 auto multiply_lhs =
7941 AllOf(op::Shape("f32[6,2]"),
7942 op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
7943 op::Constant(), op::Reshape()));
7944 auto multiply_rhs =
7945 AllOf(op::Shape("f32[6,2]"),
7946 op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
7947 op::Constant(), op::Reshape()));
7948 auto multiply =
7949 AllOf(op::Shape("f32[6,2]"), op::Multiply(multiply_lhs, multiply_rhs));
7950 auto replicated_lhs =
7951 AllOf(op::Shape("f32[6,3]"),
7952 op::Slice(op::AllReduce(op::DynamicUpdateSlice(
7953 op::Broadcast(), multiply, op::Constant(), op::Reshape()))));
7954 const auto root = module->entry_computation()->root_instruction();
7955 EXPECT_THAT(root, AllOf(op::Shape("f32[6,3]"),
7956 op::Add(replicated_lhs, op::Constant())));
7957 }
7958
7959 TEST_F(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_ReplicateToTile) {
7960 absl::string_view hlo_string = R"(
7961 HloModule module
7962
7963 ENTRY entry {
7964 constant = f32[6,3]{1,0}
7965 constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
7966 sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
7967 constant.1 = f32[6,3]{1,0}
7968 constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
7969 sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
7970 multiply = f32[6,3]{1,0} multiply(constant, constant.1),
7971 sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
7972 ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
7973 sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
7974 }
7975 )";
7976
7977 TF_ASSERT_OK_AND_ASSIGN(auto module,
7978 PartitionComputation(hlo_string, /*num_devices=*/4));
7979 VLOG(1) << module->ToString();
7980
7981 auto multiply = AllOf(op::Shape("f32[6,3]"),
7982 op::Multiply(op::Constant(), op::Constant()));
7983 auto add_lhs = AllOf(op::Shape("f32[6,2]"),
7984 op::DynamicSlice(op::Pad(multiply, op::Constant()),
7985 op::Constant(), op::Reshape()));
7986 auto add_rhs = AllOf(op::Shape("f32[6,2]"),
7987 op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
7988 op::Constant(), op::Reshape()));
7989 const auto root = module->entry_computation()->root_instruction();
7990 EXPECT_THAT(root, AllOf(op::Shape("f32[6,2]"), op::Add(add_lhs, add_rhs)));
7991 }
7992
7993 TEST_F(SpmdPartitioningTest,
7994 ElementwiseTest_PartialReplicateToTiledHaloExchange) {
7995 absl::string_view hlo_string = R"(
7996 HloModule module
7997
7998 ENTRY entry {
7999 constant = f32[6,3]{1,0}
8000 constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
8001 sharding={replicated}
8002 constant.1 = f32[6,3]{1,0}
8003 constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
8004 sharding={replicated}
8005 multiply = f32[6,3]{1,0} multiply(constant, constant.1),
8006 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
8007 ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
8008 sharding={devices=[4,1]0,1,2,3}
8009 }
8010 )";
8011
8012 TF_ASSERT_OK_AND_ASSIGN(auto module,
8013 PartitionComputation(hlo_string, /*num_devices=*/4));
8014 VLOG(1) << module->ToString();
8015 auto partial_replicate_lhs =
8016 AllOf(op::Shape("f32[3,3]"),
8017 op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
8018 auto partial_replicate_rhs =
8019 AllOf(op::Shape("f32[3,3]"),
8020 op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
8021 auto multiply =
8022 AllOf(op::Shape("f32[3,3]"),
8023 op::Multiply(partial_replicate_lhs, partial_replicate_rhs));
8024 auto right_halo =
8025 AllOf(op::Shape("f32[1,3]"), op::CollectivePermute(op::Slice(multiply)));
8026 auto add_lhs = AllOf(
8027 op::Shape("f32[2,3]"),
8028 op::DynamicSlice(
8029 op::DynamicSlice(
8030 op::Pad(op::Concatenate(multiply, right_halo), op::Constant()),
8031 op::Reshape(), op::Constant()),
8032 op::Subtract(), op::Subtract()));
8033 auto add_rhs = AllOf(op::Shape("f32[2,3]"),
8034 op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
8035 op::Reshape(), op::Constant()));
8036 const auto root = module->entry_computation()->root_instruction();
8037 EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]"), op::Add(add_lhs, add_rhs)));
8038 }
8039
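// Editorial note (added): reshards between tiled and partially replicated
// (last_tile_dim_replicate) shardings. Evenly divisible cases are expected to
// use a dynamic-slice (to tile) or an all-reduce of a dynamic-update-slice
// (to partially replicate), while uneven tile counts additionally go through
// a reshape/all-to-all/transpose step with padding or slicing to fix up
// shard sizes, as the following tests assert.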
8040 TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshard) {
8041 absl::string_view hlo_string = R"(
8042 HloModule module
8043
8044 ENTRY entry {
8045 %param0 = f32[8,8] parameter(0)
8046 %copy = f32[8,8] copy(%param0),
8047 sharding={devices=[2,2]0,1,2,3}
8048 ROOT %copy0 = f32[8,8] copy(%copy),
8049 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
8050 })";
8051
8052 TF_ASSERT_OK_AND_ASSIGN(auto module,
8053 PartitionComputation(hlo_string, /*num_devices=*/4));
8054 VLOG(1) << module->ToString();
8055 auto tiled = AllOf(op::Shape("f32[4,4]"),
8056 op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8057 op::Reshape())));
8058 auto partially_replicated = AllOf(
8059 op::Shape("f32[4,8]"), op::Copy(op::AllReduce(op::DynamicUpdateSlice(
8060 op::Broadcast(_), tiled, _, _))));
8061 const auto root = module->entry_computation()->root_instruction();
8062 EXPECT_THAT(root, partially_replicated);
8063 }
8064
TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
8066 absl::string_view hlo_string = R"(
8067 HloModule module
8068
8069 ENTRY entry {
8070 %param0 = f32[8,8] parameter(0),
8071 sharding={devices=[2,3]0,1,2,3,4,5}
8072 ROOT %copy0 = f32[8,8] copy(%param0),
8073 sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
8074 })";
8075
8076 TF_ASSERT_OK_AND_ASSIGN(auto module,
8077 PartitionComputation(hlo_string, /*num_devices=*/6));
8078 VLOG(1) << module->ToString();
8079 auto tiled = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
8080 auto partially_replicated = AllOf(
8081 op::Shape("f32[8,4]"),
8082 op::Copy(op::Reshape(
8083 op::Transpose(op::AllToAll(op::Reshape(op::Slice(op::AllReduce(
8084 op::DynamicUpdateSlice(op::Broadcast(), tiled, _, _)))))))));
8085 const auto root = module->entry_computation()->root_instruction();
8086 EXPECT_THAT(root, partially_replicated);
8087 }
8088
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
8090 absl::string_view hlo_string = R"(
8091 HloModule module
8092
8093 ENTRY entry {
8094 %param0 = f32[8,8] parameter(0),
8095 sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
8096 ROOT %copy0 = f32[8,8] copy(%param0),
8097 sharding={devices=[2,3]0,1,2,3,4,5}
8098 })";
8099
8100 TF_ASSERT_OK_AND_ASSIGN(auto module,
8101 PartitionComputation(hlo_string, /*num_devices=*/6));
8102 VLOG(1) << module->ToString();
8103 auto partial_replicated = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
8104 auto tiled = AllOf(
8105 op::Shape("f32[4,3]"),
8106 op::Copy(op::DynamicSlice(op::Pad(op::Reshape(op::Transpose(op::AllToAll(
8107 op::Reshape(partial_replicated)))),
8108 _),
8109 _, _)));
8110 const auto root = module->entry_computation()->root_instruction();
8111 EXPECT_THAT(root, tiled);
8112 }
8113
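// The reverse direction: partial replication to a full tile only needs a
// local DynamicSlice, since every device in a replication group already
// holds the data for its target shard; the slice offsets are derived from
// the partition id (the Subtract operands below).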
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
8115 absl::string_view hlo_string = R"(
8116 HloModule module
8117
8118 ENTRY entry {
8119 %param0 = f32[8,8] parameter(0)
8120 %copy = f32[8,8] copy(%param0),
8121 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
8122 ROOT %copy0 = f32[8,8] copy(%copy),
8123 sharding={devices=[2,2]0,1,2,3}
8124 })";
8125
8126 TF_ASSERT_OK_AND_ASSIGN(auto module,
8127 PartitionComputation(hlo_string, /*num_devices=*/4));
8128 VLOG(1) << module->ToString();
8129 auto partially_replicated =
8130 AllOf(op::Shape("f32[4,8]"),
8131 op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8132 op::Constant())));
8133 auto tiled =
8134 AllOf(op::Shape("f32[4,4]"),
8135 op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
8136 op::Subtract())));
8137 const auto root = module->entry_computation()->root_instruction();
8138 EXPECT_THAT(root, tiled);
8139 }
8140
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshard_AllReduce) {
8143 absl::string_view hlo_string = R"(
8144 HloModule module
8145
8146 ENTRY entry {
8147 %param0 = f32[8,8] parameter(0)
8148 %copy = f32[8,8] copy(param0),
8149 sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8150 ROOT %copy0 = f32[8,8] copy(%copy),
8151 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8152 })";
8153
8154 TF_ASSERT_OK_AND_ASSIGN(auto module,
8155 PartitionComputation(hlo_string, /*num_devices=*/8));
8156
8157 VLOG(1) << module->ToString();
8158 auto partially_replicated_init =
8159 AllOf(op::Shape("f32[4,4]"),
8160 op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8161 op::Reshape())));
8162 auto partially_replicated =
8163 AllOf(op::Shape("f32[4,8]"),
8164 op::Copy(op::AllReduce(op::DynamicUpdateSlice(
8165 op::Broadcast(_), partially_replicated_init, _, _))));
8166 const auto root = module->entry_computation()->root_instruction();
8167 EXPECT_THAT(root, partially_replicated);
8168 }
8169
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshard_DynamicSlice) {
8172 absl::string_view hlo_string = R"(
8173 HloModule module
8174
8175 ENTRY entry {
8176 %param0 = f32[8,8] parameter(0)
8177 %copy = f32[8,8] copy(%param0),
8178 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8179 ROOT %copy0 = f32[8,8] copy(%copy),
8180 sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8181 })";
8182
8183 TF_ASSERT_OK_AND_ASSIGN(auto module,
8184 PartitionComputation(hlo_string, /*num_devices=*/8));
8185 VLOG(1) << module->ToString();
8186 auto partially_replicated =
8187 AllOf(op::Shape("f32[4,8]"),
8188 op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8189 op::Constant())));
8190 auto tiled =
8191 AllOf(op::Shape("f32[4,4]"),
8192 op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
8193 op::Subtract())));
8194 const auto root = module->entry_computation()->root_instruction();
8195 EXPECT_THAT(root, tiled);
8196 }
8197
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
8200 absl::string_view hlo_string = R"(
8201 HloModule module
8202
8203 ENTRY entry {
8204 %param0 = f32[8,8] parameter(0)
8205 %copy = f32[8,8] copy(param0),
8206 sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8207 ROOT %copy0 = f32[8,8] copy(%copy),
8208 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8209 })";
8210
8211 TF_ASSERT_OK_AND_ASSIGN(auto module,
8212 PartitionComputation(hlo_string, /*num_devices=*/8));
8213
8214 VLOG(1) << module->ToString();
8215 auto partially_replicated_init =
8216 AllOf(op::Shape("f32[4,4]"),
8217 op::CollectivePermute(op::Copy(op::DynamicSlice(
8218 op::Parameter(0), op::Reshape(), op::Reshape()))));
8219 auto partially_replicated =
8220 AllOf(op::Shape("f32[8,4]"),
8221 op::Copy(op::AllReduce(op::DynamicUpdateSlice(
8222 op::Broadcast(_), partially_replicated_init, _, _))));
8223 const auto root = module->entry_computation()->root_instruction();
8224 EXPECT_THAT(root, partially_replicated);
8225 }
8226
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardCollectivePermute1) {
8229 absl::string_view hlo_string = R"(
8230 HloModule module
8231
8232 ENTRY entry {
8233 %param0 = f32[8,8] parameter(0)
8234 %copy = f32[8,8] copy(%param0),
8235 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8236 ROOT %copy0 = f32[8,8] copy(%copy),
8237 sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8238 })";
8239
8240 TF_ASSERT_OK_AND_ASSIGN(auto module,
8241 PartitionComputation(hlo_string, /*num_devices=*/8));
8242 VLOG(1) << module->ToString();
8243 auto partially_replicated =
8244 AllOf(op::Shape("f32[8,4]"),
8245 op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
8246 op::Reshape())));
8247 auto tiled =
8248 AllOf(op::Shape("f32[4,4]"),
8249 op::Copy(op::CollectivePermute(op::DynamicSlice(
8250 partially_replicated, op::Subtract(), op::Subtract()))));
8251 const auto root = module->entry_computation()->root_instruction();
8252 EXPECT_THAT(root, tiled);
8253 }
8254
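// When the source and target partial shardings split a dimension into
// different per-device sizes, a shard may need rows owned by a neighbor.
// The expected lowering exchanges a halo with a CollectivePermute of a
// slice, concatenates it with the local data, and rebuilds the target shard
// (with a DynamicUpdateSlice + AllReduce where the data has to change
// replication groups).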
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardHaloExchange) {
8257 absl::string_view hlo_string = R"(
8258 HloModule module
8259
8260 ENTRY entry {
8261 %param0 = f32[6,3] parameter(0)
8262 %copy = f32[6,3] copy(param0),
8263 sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8264 ROOT %copy0 = f32[6,3] copy(%copy),
8265 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8266 })";
8267
8268 TF_ASSERT_OK_AND_ASSIGN(auto module,
8269 PartitionComputation(hlo_string, /*num_devices=*/8));
8270
8271 VLOG(1) << module->ToString();
8272 auto partially_replicated_init =
8273 AllOf(op::Shape("f32[2,3]"),
8274 op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
8275 op::Reshape(), op::Constant())));
8276 auto slice =
8277 AllOf(op::Shape("f32[2,3]"),
8278 op::DynamicSlice(op::Concatenate(op::CollectivePermute(op::Slice(
8279 partially_replicated_init)),
8280 partially_replicated_init),
8281 _, _));
8282 auto partially_replicated =
8283 AllOf(op::Shape("f32[3,3]"),
8284 op::Copy(op::Slice(op::AllReduce(
8285 op::DynamicUpdateSlice(op::Broadcast(_), slice, _, _)))));
8286 const auto root = module->entry_computation()->root_instruction();
8287 EXPECT_THAT(root, partially_replicated);
8288 }
8289
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardHaloExchange1) {
8292 absl::string_view hlo_string = R"(
8293 HloModule module
8294
8295 ENTRY entry {
8296 %param0 = f32[6,3] parameter(0)
8297 %copy = f32[6,3] copy(param0),
8298 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8299 ROOT %copy0 = f32[6,3] copy(%copy),
8300 sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
8301 })";
8302
8303 TF_ASSERT_OK_AND_ASSIGN(auto module,
8304 PartitionComputation(hlo_string, /*num_devices=*/8));
8305
8306 VLOG(1) << module->ToString();
8307 auto partially_replicated_init =
8308 AllOf(op::Shape("f32[3,3]"),
8309 op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
8310 op::Constant())));
8311 auto slice = AllOf(
8312 op::Shape("f32[4,3]"),
8313 op::DynamicSlice(op::Pad(op::Concatenate(partially_replicated_init,
8314 op::CollectivePermute(op::Slice(
8315 partially_replicated_init))),
8316 op::Constant()),
8317 _, _));
8318 auto partially_replicated =
8319 AllOf(op::Shape("f32[2,3]"), op::Copy(op::DynamicSlice(slice, _, _)));
8320 const auto root = module->entry_computation()->root_instruction();
8321 EXPECT_THAT(root, partially_replicated);
8322 }
8323
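// Grouped-convolution tests. With batch_group_count equal to the feature
// size and LHS, RHS and output all sharded on that dimension, each
// partition can compute its share of the groups locally, so the expected
// partitioned convolution needs no cross-partition collectives.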
TEST_F(SpmdPartitioningTest, PartitionConvWithBatchGroupCount) {
8325 absl::string_view hlo_string = R"(
8326 HloModule module
8327
8328 ENTRY entry {
8329 %lhs = f32[16,801,1,1024] parameter(0)
8330 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8331 sharding={devices=[1,1,1,2]0,1}
8332 %rhs = f32[16,801,1,1024] parameter(1)
8333 %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8334 sharding={devices=[1,1,1,2]0,1}
8335 ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8336 dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8337 window={size=801x1 pad=2_2x0_0},
8338 sharding={devices=[1,1,1,2]0,1}
8339 })";
8340
8341 TF_ASSERT_OK_AND_ASSIGN(auto module,
8342 PartitionComputation(hlo_string, /*num_devices=*/2));
8343
8344 VLOG(1) << module->ToString();
8345 const auto root = module->entry_computation()->root_instruction();
8346 const auto lhs = AllOf(
8347 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8348 op::Constant(), op::Reshape())),
8349 op::Shape("f32[16,801,1,512]"));
8350 const auto rhs = AllOf(
8351 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8352 op::Constant(), op::Reshape())),
8353 op::Shape("f32[16,801,1,512]"));
8354 EXPECT_THAT(root,
8355 AllOf(op::Convolution(lhs, rhs), op::Shape("f32[5,1,1,512]")));
8356 }
8357
TEST_F(SpmdPartitioningTest, PartitionConvWithBatchGroupCountRHSAlignWithLHS) {
8359 absl::string_view hlo_string = R"(
8360 HloModule module
8361
8362 ENTRY entry {
8363 %lhs = f32[16,801,1,1024] parameter(0)
8364 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8365 sharding={devices=[1,1,1,2]0,1}
8366 %rhs = f32[16,801,1,1024] parameter(1)
8367 %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8368 sharding={devices=[1,2,1,1]0,1}
8369 ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8370 dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8371 window={size=801x1 pad=2_2x0_0},
8372 sharding={devices=[1,1,1,2]0,1}
8373 })";
8374
8375 TF_ASSERT_OK_AND_ASSIGN(auto module,
8376 PartitionComputation(hlo_string, /*num_devices=*/2));
8377 VLOG(1) << module->ToString();
8378 const auto root = module->entry_computation()->root_instruction();
8379 const auto lhs = AllOf(
8380 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8381 op::Constant(), op::Reshape())),
8382 op::Shape("f32[16,801,1,512]"));
8383 const auto rhs =
8384 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8385 op::Constant(), op::Reshape(),
8386 op::Constant(), op::Constant())),
8387 op::Shape("f32[16,401,1,1024]"));
8388 auto resharded_rhs = AllOf(
8389 op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))),
8390 op::Shape("f32[16,801,1,512]"));
8391 EXPECT_THAT(root, AllOf(op::Convolution(lhs, resharded_rhs),
8392 op::Shape("f32[5,1,1,512]")));
8393 }
8394
TEST_F(SpmdPartitioningTest, PartitionConvWithBatchGroupCountLHSAlignWithRHS) {
8396 absl::string_view hlo_string = R"(
8397 HloModule module
8398
8399 ENTRY entry {
8400 %lhs = f32[16,801,1,1024] parameter(0)
8401 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8402 sharding={devices=[1,2,1,1]0,1}
8403 %rhs = f32[16,801,1,1024] parameter(1)
8404 %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8405 sharding={devices=[1,1,1,2]0,1}
8406 ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8407 dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8408 window={size=801x1 pad=2_2x0_0},
8409 sharding={devices=[1,1,1,2]0,1}
8410 })";
8411
8412 TF_ASSERT_OK_AND_ASSIGN(auto module,
8413 PartitionComputation(hlo_string, /*num_devices=*/2));
8414 VLOG(1) << module->ToString();
8415 const auto root = module->entry_computation()->root_instruction();
8416 const auto rhs = AllOf(
8417 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8418 op::Constant(), op::Reshape())),
8419 op::Shape("f32[16,801,1,512]"));
8420 const auto lhs =
8421 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8422 op::Constant(), op::Reshape(),
8423 op::Constant(), op::Constant())),
8424 op::Shape("f32[16,401,1,1024]"));
8425 auto resharded_lhs = AllOf(
8426 op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))),
8427 op::Shape("f32[16,801,1,512]"));
8428 EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, rhs),
8429 op::Shape("f32[5,1,1,512]")));
8430 }
8431
TEST_F(SpmdPartitioningTest,
       PartitionConvWithBatchGroupCountOutputAlignWithLHS) {
8434 absl::string_view hlo_string = R"(
8435 HloModule module
8436
8437 ENTRY entry {
8438 %lhs = f32[16,801,1,1024] parameter(0)
8439 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8440 sharding={devices=[1,1,1,2]0,1}
8441 %rhs = f32[16,801,1,1024] parameter(1)
8442 %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8443 sharding={devices=[1,1,1,2]0,1}
8444 ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8445 dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8446 window={size=801x1 pad=2_2x0_0},
8447 sharding={devices=[2,1,1,1]0,1}
8448 })";
8449
8450 TF_ASSERT_OK_AND_ASSIGN(auto module,
8451 PartitionComputation(hlo_string, /*num_devices=*/2));
8452 VLOG(1) << module->ToString();
8453 const auto root = module->entry_computation()->root_instruction();
8454 const auto lhs = AllOf(
8455 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8456 op::Constant(), op::Reshape())),
8457 op::Shape("f32[16,801,1,512]"));
8458 const auto rhs = AllOf(
8459 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8460 op::Constant(), op::Reshape())),
8461 op::Shape("f32[16,801,1,512]"));
8462 auto conv = AllOf(op::Convolution(lhs, rhs), op::Shape("f32[5,1,1,512]"));
8463 EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(op::AllToAll(
8464 op::Reshape(op::Pad(conv, op::Constant()))))),
8465 op::Shape("f32[3,1,1,1024]")));
8466 }
8467
TEST_F(SpmdPartitioningTest,
       PartitionConvWithBatchGroupCountOutputAlignWithRHS) {
8470 absl::string_view hlo_string = R"(
8471 HloModule module
8472
8473 ENTRY entry {
8474 %lhs = f32[16,801,1,1024] parameter(0)
8475 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8476 sharding={devices=[1,2,1,1]0,1}
8477 %rhs = f32[16,801,1,1024] parameter(1)
8478 %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8479 sharding={devices=[1,1,1,2]0,1}
8480 ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8481 dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8482 window={size=801x1 pad=2_2x0_0},
8483 sharding={devices=[2,1,1,1]0,1}
8484 })";
8485
8486 TF_ASSERT_OK_AND_ASSIGN(auto module,
8487 PartitionComputation(hlo_string, /*num_devices=*/2));
8488 VLOG(1) << module->ToString();
8489 const auto root = module->entry_computation()->root_instruction();
8490 const auto rhs = AllOf(
8491 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8492 op::Constant(), op::Reshape())),
8493 op::Shape("f32[16,801,1,512]"));
8494 const auto lhs =
8495 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8496 op::Constant(), op::Reshape(),
8497 op::Constant(), op::Constant())),
8498 op::Shape("f32[16,401,1,1024]"));
8499 auto resharded_lhs = AllOf(
8500 op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))),
8501 op::Shape("f32[16,801,1,512]"));
8502 auto conv =
8503 AllOf(op::Convolution(resharded_lhs, rhs), op::Shape("f32[5,1,1,512]"));
8504 EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(op::AllToAll(
8505 op::Reshape(op::Pad(conv, op::Constant()))))),
8506 op::Shape("f32[3,1,1,1024]")));
8507 }
8508
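// Same idea as the batch_group_count tests above, but grouping on
// feature_group_count: with every operand sharded on the grouped feature
// dimension, the convolution splits trivially into per-device groups.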
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount) {
8510 absl::string_view hlo_string = R"(
8511 HloModule module
8512
8513 ENTRY entry {
8514 %lhs = f32[16,801,1,1024] parameter(0)
8515 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8516 sharding={devices=[1,1,1,2]0,1}
8517 %rhs = f32[5,1,1,2048] parameter(1)
8518 %rhs.copy = f32[5,1,1,2048] copy(%rhs),
8519 sharding={devices=[1,1,1,2]0,1}
8520 ROOT %conv = f32[16,801,1,2048] convolution(%lhs.copy, %rhs.copy),
8521 dim_labels=b01f_01io->b01f,feature_group_count=1024,
8522 window={size=5x1 pad=2_2x0_0},
8523 sharding={devices=[1,1,1,2]0,1}
8524 })";
8525
8526 TF_ASSERT_OK_AND_ASSIGN(auto module,
8527 PartitionComputation(hlo_string, /*num_devices=*/2));
8528 VLOG(1) << module->ToString();
8529 const auto root = module->entry_computation()->root_instruction();
8530 const auto lhs = AllOf(
8531 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8532 op::Constant(), op::Reshape())),
8533 op::Shape("f32[16,801,1,512]"));
8534 const auto rhs = AllOf(
8535 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8536 op::Constant(), op::Reshape())),
8537 op::Shape("f32[5,1,1,1024]"));
8538 EXPECT_THAT(
8539 root, AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,1024]")));
8540 }
8541
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount2) {
8543 absl::string_view hlo_string = R"(
8544 HloModule module
8545
8546 ENTRY entry {
8547 %lhs = f32[64,3,1,3072] parameter(0)
8548 %lhs.copy = f32[64,3,1,3072] copy(%lhs),
8549 sharding={devices=[1,1,1,4,8]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25
8550 ,26,27,28,29,30,31,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
8551 %rhs = f32[3,1,1,3072] parameter(1)
8552 %rhs.copy = f32[3,1,1,3072] copy(%rhs),
8553 sharding={devices=[1,1,1,4,8]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25
8554 ,26,27,28,29,30,31,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
8555 ROOT %conv = f32[64,1,1,3072] convolution(%lhs.copy, %rhs.copy),
8556 dim_labels=b01f_01io->b01f,feature_group_count=3072,
8557 window={size=3x1},
8558 sharding={devices=[8,1,1,4]0,16,24,8,2,18,26,10,4,20,28,12,6,22,30,14,7,23,
8559 31,15,5,21,29,13,3,19,27,11,1,17,25,9}
8560 })";
8561
8562 TF_ASSERT_OK_AND_ASSIGN(auto module,
8563 PartitionComputation(hlo_string, /*num_devices=*/32));
8564 VLOG(1) << module->ToString();
8565 const auto root = module->entry_computation()->root_instruction();
8566 const auto lhs =
8567 AllOf(op::DynamicSlice(
8568 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
8569 op::Constant(), op::Constant(),
8570 op::Reshape())),
8571 op::Reshape(), op::Constant(), op::Constant(), op::Constant()),
8572 op::Shape("f32[8,3,1,768]"));
8573 const auto rhs = AllOf(
8574 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8575 op::Constant(), op::Reshape())),
8576 op::Shape("f32[3,1,1,768]"));
8577 EXPECT_THAT(root,
8578 AllOf(op::Convolution(lhs, rhs), op::Shape("f32[8,1,1,768]")));
8579 }
8580
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountRHSAlignWithLHS) {
8583 absl::string_view hlo_string = R"(
8584 HloModule module
8585
8586 ENTRY entry {
8587 %lhs = f32[16,801,1,1024] parameter(0)
8588 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8589 sharding={devices=[1,1,1,2]0,1}
8590 %rhs = f32[5,1,1,1024] parameter(1)
8591 %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8592 sharding={devices=[2,1,1,1]0,1}
8593 ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8594 dim_labels=b01f_01io->b01f,feature_group_count=1024,
8595 window={size=5x1 pad=2_2x0_0},
8596 sharding={devices=[1,1,1,2]0,1}
8597 })";
8598
8599 TF_ASSERT_OK_AND_ASSIGN(auto module,
8600 PartitionComputation(hlo_string, /*num_devices=*/2));
8601 VLOG(1) << module->ToString();
8602 const auto root = module->entry_computation()->root_instruction();
8603 const auto lhs = AllOf(
8604 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8605 op::Constant(), op::Reshape())),
8606 op::Shape("f32[16,801,1,512]"));
8607 const auto rhs =
8608 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8609 op::Reshape(), op::Constant(),
8610 op::Constant(), op::Constant())),
8611 op::Shape("f32[3,1,1,1024]"));
8612 auto resharded_rhs = AllOf(
8613 op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))),
8614 op::Shape("f32[5,1,1,512]"));
8615 EXPECT_THAT(root, AllOf(op::Convolution(lhs, resharded_rhs),
8616 op::Shape("f32[16,801,1,512]")));
8617 }
8618
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountLHSAlignWithRHS) {
8621 absl::string_view hlo_string = R"(
8622 HloModule module
8623
8624 ENTRY entry {
8625 %lhs = f32[16,801,1,1024] parameter(0)
8626 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8627 sharding={devices=[1,2,1,1]0,1}
8628 %rhs = f32[5,1,1,1024] parameter(1)
8629 %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8630 sharding={devices=[1,1,1,2]0,1}
8631 ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8632 dim_labels=b01f_01io->b01f,feature_group_count=1024,
8633 window={size=5x1 pad=2_2x0_0},
8634 sharding={devices=[1,1,1,2]0,1}
8635 })";
8636
8637 TF_ASSERT_OK_AND_ASSIGN(auto module,
8638 PartitionComputation(hlo_string, /*num_devices=*/2));
8639 VLOG(1) << module->ToString();
8640 const auto root = module->entry_computation()->root_instruction();
8641 const auto lhs =
8642 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8643 op::Constant(), op::Reshape(),
8644 op::Constant(), op::Constant())),
8645 op::Shape("f32[16,401,1,1024]"));
8646 const auto rhs = AllOf(
8647 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8648 op::Constant(), op::Reshape())),
8649 op::Shape("f32[5,1,1,512]"));
8650 auto resharded_lhs = AllOf(
8651 op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))),
8652 op::Shape("f32[16,801,1,512]"));
8653 EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, rhs),
8654 op::Shape("f32[16,801,1,512]")));
8655 }
8656
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountAlignOutputWithLHS) {
8659 absl::string_view hlo_string = R"(
8660 HloModule module
8661
8662 ENTRY entry {
8663 %lhs = f32[16,801,1,1024] parameter(0)
8664 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8665 sharding={devices=[1,1,1,2]0,1}
8666 %rhs = f32[5,1,1,1024] parameter(1)
8667 %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8668 sharding={devices=[1,1,1,2]0,1}
8669 ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8670 dim_labels=b01f_01io->b01f,feature_group_count=1024,
8671 window={size=5x1 pad=2_2x0_0},
8672 sharding={devices=[2,1,1,1]0,1}
8673 })";
8674
8675 TF_ASSERT_OK_AND_ASSIGN(auto module,
8676 PartitionComputation(hlo_string, /*num_devices=*/2));
8677 VLOG(1) << module->ToString();
8678 const auto root = module->entry_computation()->root_instruction();
8679 const auto lhs = AllOf(
8680 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8681 op::Constant(), op::Reshape())),
8682 op::Shape("f32[16,801,1,512]"));
8683 const auto rhs = AllOf(
8684 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8685 op::Constant(), op::Reshape())),
8686 op::Shape("f32[5,1,1,512]"));
8687 auto conv = AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]"));
8688 EXPECT_THAT(root,
8689 AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(conv)))),
8690 op::Shape("f32[8,801,1,1024]")));
8691 }
8692
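// Here the LHS is also sharded on a spatial dimension, so besides grouping
// on feature_group_count the partitioner has to halo-exchange the spatial
// boundary (a size-5 window with padding 2 needs two rows from each
// neighbor) and mask the rows that fall into padding with a Select.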
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_RHSPartialReplicate) {
8695 absl::string_view hlo_string = R"(
8696 HloModule module
8697
8698 ENTRY entry {
8699 %lhs = f32[16,801,1,1024] parameter(0)
8700 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8701 sharding={devices=[1,2,1,2]0,1,2,3}
8702 %rhs = f32[5,1,1,1024] parameter(1)
8703 %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8704 sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}
8705 ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8706 dim_labels=b01f_01io->b01f,feature_group_count=1024,
8707 window={size=5x1 pad=2_2x0_0},
8708 sharding={devices=[1,2,1,2]0,1,2,3}
8709 })";
8710
8711 TF_ASSERT_OK_AND_ASSIGN(auto module,
8712 PartitionComputation(hlo_string, /*num_devices=*/4));
8713 VLOG(1) << module->ToString();
8714 const auto root = module->entry_computation()->root_instruction();
8715 const auto lhs =
8716 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8717 op::Constant(), op::Reshape(),
8718 op::Constant(), op::Reshape())),
8719 op::Shape("f32[16,401,1,512]"));
8720 auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8721 op::CollectivePermute(op::Slice(lhs)));
8722 auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8723 op::CollectivePermute(op::Slice(lhs)));
8724 const auto rhs = AllOf(
8725 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8726 op::Constant(), op::Reshape())),
8727 op::Shape("f32[5,1,1,512]"));
8728 EXPECT_THAT(
8729 root,
8730 AllOf(op::Convolution(
8731 op::Select(_, op::Concatenate(left_halo, lhs, right_halo), _),
8732 rhs),
8733 op::Shape("f32[16, 401, 1, 512]")));
8734 }
8735
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_RHSAlignWithOutput) {
8738 absl::string_view hlo_string = R"(
8739 HloModule module
8740
8741 ENTRY entry {
8742 %lhs = f32[16,801,1,1024] parameter(0)
8743 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8744 sharding={devices=[1,2,1,2]0,1,2,3}
8745 %rhs = f32[5,1,1,1024] parameter(1), sharding={replicated}
8746 ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs),
8747 dim_labels=b01f_01io->b01f,feature_group_count=1024,
8748 window={size=5x1 pad=2_2x0_0},
8749 sharding={devices=[1,2,1,2]0,1,2,3}
8750 })";
8751 TF_ASSERT_OK_AND_ASSIGN(auto module,
8752 PartitionComputation(hlo_string, /*num_devices=*/4));
8753 VLOG(1) << module->ToString();
8754 const auto root = module->entry_computation()->root_instruction();
8755 const auto lhs =
8756 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8757 op::Constant(), op::Reshape(),
8758 op::Constant(), op::Reshape())),
8759 op::Shape("f32[16,401,1,512]"));
8760 auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8761 op::CollectivePermute(op::Slice(lhs)));
8762 auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8763 op::CollectivePermute(op::Slice(lhs)));
8764 const auto rhs =
8765 AllOf(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8766 op::Constant(), op::Reshape()),
8767 op::Shape("f32[5,1,1,512]"));
8768 EXPECT_THAT(
8769 root,
8770 AllOf(op::Convolution(
8771 op::Select(_, op::Concatenate(left_halo, lhs, right_halo), _),
8772 rhs),
8773 op::Shape("f32[16, 401, 1, 512]")));
8774 }
8775
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_LHSAlignWithOutput) {
8778 absl::string_view hlo_string = R"(
8779 HloModule module
8780
8781 ENTRY entry {
8782 %lhs = f32[16,801,1,1024] parameter(0)
8783 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8784 sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate}
8785 %rhs = f32[5,1,1,1024] parameter(1)
8786 %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8787 sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}
8788 ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8789 dim_labels=b01f_01io->b01f,feature_group_count=1024,
8790 window={size=5x1 pad=2_2x0_0},
8791 sharding={devices=[1,2,1,2]0,1,2,3}
8792 })";
8793 TF_ASSERT_OK_AND_ASSIGN(auto module,
8794 PartitionComputation(hlo_string, /*num_devices=*/4));
8795 VLOG(1) << module->ToString();
8796 const auto root = module->entry_computation()->root_instruction();
8797 const auto lhs = AllOf(
8798 op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
8799 op::Constant(), op::Constant())),
8800 op::Shape("f32[8,801,1,1024]"));
8801 auto resharded_lhs =
8802 AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(
8803 op::Pad(op::DynamicSlice(lhs, op::Subtract(), op::Subtract(),
8804 op::Subtract(), op::Subtract()),
8805 op::Constant()))))),
8806 op::Shape("f32[16,401,1,512]"));
8807 auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8808 op::CollectivePermute(op::Slice(resharded_lhs)));
8809 auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8810 op::CollectivePermute(op::Slice(resharded_lhs)));
8811 const auto rhs = AllOf(
8812 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8813 op::Constant(), op::Reshape())),
8814 op::Shape("f32[5,1,1,512]"));
8815 EXPECT_THAT(
8816 root,
8817 AllOf(
8818 op::Convolution(
8819 op::Select(
8820 _, op::Concatenate(left_halo, resharded_lhs, right_halo), _),
8821 rhs),
8822 op::Shape("f32[16, 401, 1, 512]")));
8823 }
8824
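// Batch-group-count convolution with both operands sharded on a spatial
// dimension: in addition to the halo exchange, the spatially partitioned
// contraction yields partial results, so the output is expected to be
// combined with an AllReduce followed by a CollectivePermute that moves
// each shard to its owning device.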
TEST_F(SpmdPartitioningTest, PartitionConvGroupOnBatchGroupCount) {
8826 absl::string_view hlo_string = R"(
8827 HloModule module
8828
8829 ENTRY entry {
8830 %lhs = f32[16,801,1,1024] parameter(0)
8831 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8832 sharding={devices=[1,2,1,2]0,1,2,3}
8833 %rhs = f32[16,801,1,1024] parameter(1)
8834 %rhs.copy = f32[16,801,1,1024] copy(%rhs),
8835 sharding={devices=[1,2,1,2]0,1,2,3}
8836 ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
8837 dim_labels=f01b_i01o->01bf,batch_group_count=1024,
8838 window={size=801x1 pad=2_2x0_0},
8839 sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
8840 })";
8841
8842 TF_ASSERT_OK_AND_ASSIGN(auto module,
8843 PartitionComputation(hlo_string, /*num_devices=*/4));
8844 VLOG(1) << module->ToString();
8845 const auto root = module->entry_computation()->root_instruction();
8846 const auto lhs = AllOf(
8847 op::Select(_,
8848 op::Copy(op::DynamicSlice(
8849 op::Pad(op::Parameter(), op::Constant()), op::Constant(),
8850 op::Reshape(), op::Constant(), op::Reshape())),
8851 _),
8852 op::Shape("f32[16,401,1,512]"));
8853 auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8854 op::CollectivePermute(op::Slice(lhs)));
8855 auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
8856 op::CollectivePermute(op::Slice(lhs)));
8857 const auto rhs =
8858 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8859 op::Constant(), op::Reshape(),
8860 op::Constant(), op::Reshape())),
8861 op::Shape("f32[16,401,1,512]"));
8862 auto conv = AllOf(op::Convolution(op::Concatenate(left_halo, lhs, right_halo),
8863 op::Select(_, rhs, _)),
8864 op::Shape("f32[5,1,1,512]"));
8865 EXPECT_THAT(root, AllOf(op::CollectivePermute(op::AllReduce(conv)),
8866 op::Shape("f32[5,1,1,512]")));
8867 }
8868
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountAlignOutputWithRHS) {
8871 absl::string_view hlo_string = R"(
8872 HloModule module
8873
8874 ENTRY entry {
8875 %lhs = f32[16,801,1,1024] parameter(0)
8876 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8877 sharding={devices=[1,2,1,1]0,1}
8878 %rhs = f32[5,1,1,1024] parameter(1)
8879 %rhs.copy = f32[5,1,1,1024] copy(%rhs),
8880 sharding={devices=[1,1,1,2]0,1}
8881 ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8882 dim_labels=b01f_01io->b01f,feature_group_count=1024,
8883 window={size=5x1 pad=2_2x0_0},
8884 sharding={devices=[2,1,1,1]0,1}
8885 })";
8886
8887 TF_ASSERT_OK_AND_ASSIGN(auto module,
8888 PartitionComputation(hlo_string, /*num_devices=*/2));
8889 VLOG(1) << module->ToString();
8890 const auto root = module->entry_computation()->root_instruction();
8891 const auto lhs =
8892 AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
8893 op::Constant(), op::Reshape(),
8894 op::Constant(), op::Constant())),
8895 op::Shape("f32[16,401,1,1024]"));
8896 const auto rhs = AllOf(
8897 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8898 op::Constant(), op::Reshape())),
8899 op::Shape("f32[5,1,1,512]"));
8900 auto resharded_lhs = AllOf(
8901 op::Slice(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))),
8902 op::Shape("f32[16,801,1,512]"));
8903 auto conv = AllOf(op::Convolution(resharded_lhs, rhs),
8904 op::Shape("f32[16,801,1,512]"));
8905 EXPECT_THAT(root,
8906 AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(conv)))),
8907 op::Shape("f32[8,801,1,1024]")));
8908 }
8909
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCountBackProp) {
8911 absl::string_view hlo_string = R"(
8912 HloModule module
8913
8914 ENTRY entry {
8915 %lhs = f32[16,801,1,1024] parameter(0)
8916 %lhs.copy = f32[16,801,1,1024] copy(%lhs),
8917 sharding={devices=[1,1,1,2]0,1}
8918 %rhs = f32[5,1,1024,1] parameter(1)
8919 %rhs.copy = f32[5,1,1024,1] copy(%rhs),
8920 sharding={devices=[1,1,2,1]0,1}
8921 ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
8922 dim_labels=b01f_01oi->b01f,feature_group_count=1024,
8923 window={size=5x1 pad=2_2x0_0 rhs_reversal=1x1},
8924 sharding={devices=[1,1,1,2]0,1}
8925 })";
8926
8927 TF_ASSERT_OK_AND_ASSIGN(auto module,
8928 PartitionComputation(hlo_string, /*num_devices=*/2));
8929 VLOG(1) << module->ToString();
8930 const auto root = module->entry_computation()->root_instruction();
8931 const auto lhs = AllOf(
8932 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8933 op::Constant(), op::Reshape())),
8934 op::Shape("f32[16,801,1,512]"));
8935 const auto rhs = AllOf(
8936 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
8937 op::Reshape(), op::Constant())),
8938 op::Shape("f32[5,1,512,1]"));
8939 EXPECT_THAT(root,
8940 AllOf(op::Convolution(lhs, rhs), op::Shape("f32[16,801,1,512]")));
8941 }
8942
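// A reshard whose device permutation only differs along dimensions created
// by a broadcast can be skipped, because the data is identical across those
// dimensions; this test checks which of the copies below keep a
// CollectivePermute and which collapse to a plain Copy.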
TEST_F(SpmdPartitioningTest, NoReshardOnBroadcastDims) {
8944 absl::string_view hlo_string = R"(
8945 HloModule module
8946
8947 ENTRY entry {
8948 %param0 = f32[2,3] parameter(0)
8949 %param1 = f32[2,3,20] parameter(1)
8950 %br0 = f32[20,2,20,3,20] broadcast(%param0), dimensions={1,3}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
8951 %br1 = f32[20,2,20,3,20] broadcast(%param1), dimensions={1,3,4}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
8952 %add = f32[20,2,20,3,20] add(%br0, %br1), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
8953 %reshape = f32[10,4,10,6,20] reshape(%br0), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
8954 %transpose = f32[2,3,20,20,20] transpose(%br0), dimensions={1,3,0,2,4}, sharding={devices=[1,1,2,2,2]0,1,2,3,4,5,6,7}
8955 %copy_add0 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]6,7,2,3,4,5,0,1}
8956 %copy_add1 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1}
8957 %copy_reshape = f32[10,4,10,6,20] copy(%reshape), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1}
8958 %copy_transpose = f32[2,3,20,20,20] copy(%transpose), sharding={devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}
8959 ROOT %tuple = (f32[20,2,20,3,20], f32[20,2,20,3,20], f32[10,4,10,6,20], f32[2,3,20,20,20])
8960 tuple(%copy_add0, %copy_add1, %copy_reshape, %copy_transpose),
8961 sharding={{devices=[2,1,2,1,2]6,7,2,3,4,5,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}}
8962 })";
8963
8964 TF_ASSERT_OK_AND_ASSIGN(auto module,
8965 PartitionComputation(hlo_string, /*num_devices=*/8));
8966 VLOG(1) << module->ToString();
8967 const auto root = module->entry_computation()->root_instruction();
8968 // Reshard on copy_add0 only happens on broadcast dims, can be skipped.
8969 auto copy_add0 =
8970 op::Copy(op::Copy(op::Add(op::Broadcast(_), op::Broadcast(_))));
8971 // Reshard on copy_add1 also happens on non-broadcast dims.
8972 auto copy_add1 = op::Copy(
8973 op::CollectivePermute(op::Add(op::Broadcast(_), op::Broadcast(_))));
8974 // Reshard on copy_reshape only happens on broadcast dims, can be skipped.
8975 auto copy_reshape = op::Copy(op::Copy(op::Reshape(op::Broadcast(_))));
8976 // Reshard on copy_transpose only happens on broadcast dims, can be skipped.
8977 auto copy_transpose = op::Copy(op::Copy(op::Transpose(op::Broadcast(_))));
8978 EXPECT_THAT(root,
8979 op::Tuple(copy_add0, copy_add1, copy_reshape, copy_transpose));
8980 }
8981
TEST_F(SpmdPartitioningTest,
       ConvolutionFilterIFOFPartitionedInputPartialReplicate) {
8984 absl::string_view hlo_string = R"(
8985 HloModule module
8986
8987 ENTRY entry {
8988 %lhs = f32[128,112,112,12] parameter(0)
8989 %lhs.copy = f32[128,112,112,12] copy(f32[128,112,112,12] %lhs),
8990 sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
8991 %rhs = f32[7,7,12,64] parameter(1)
8992 %rhs.copy = f32[7,7,12,64] copy(f32[7,7,12,64] %rhs),
8993 sharding={devices=[1,1,2,2]0,1,2,3}
8994 ROOT %conv = f32[128,56,56,64] convolution(
8995 f32[128,112,112,12] %lhs.copy,
8996 f32[7,7,12,64] %rhs.copy),
8997 window={size=7x7 stride=2x2 pad=3_3x3_3},
8998 dim_labels=b01f_01io->b01f,
8999 sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
9000 })";
9001
9002 TF_ASSERT_OK_AND_ASSIGN(auto module,
9003 PartitionComputation(hlo_string, /*num_devices=*/4));
9004 VLOG(1) << module->ToString();
9005 const auto root = module->entry_computation()->root_instruction();
9006 const auto lhs = AllOf(
9007 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9008 op::Constant(), op::Reshape())),
9009 op::Shape("f32[128,112,112,6]"));
9010 const auto rhs = AllOf(
9011 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9012 op::Reshape(), op::Reshape())),
9013 op::Shape("f32[7,7,6,32]"));
9014
9015 EXPECT_THAT(
9016 root,
9017 AllOf(op::CollectivePermute(op::AllReduce(op::Convolution(lhs, rhs))),
9018 op::Shape("f32[128,56,56,32]")));
9019 }
9020
TEST_F(SpmdPartitioningTest,
       ConvolutionInputKernelNonContractingDimPartialReplicate) {
9023 absl::string_view hlo_string = R"(
9024 HloModule module
9025
9026 ENTRY entry {
9027 %lhs = f32[128,56,56,256] parameter(0)
9028 %lhs.copy = f32[128,56,56,256] copy(%lhs),
9029 sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
9030 %rhs = f32[128,28,28,512] parameter(1)
9031 %rhs.copy = f32[128,28,28,512] copy(%rhs),
9032 sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
9033 ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
9034 window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf,
9035 sharding={devices=[1,1,2,2]0,1,2,3}
9036 })";
9037
9038 TF_ASSERT_OK_AND_ASSIGN(auto module,
9039 PartitionComputation(hlo_string, /*num_devices=*/4));
9040 VLOG(1) << module->ToString();
9041 const auto root = module->entry_computation()->root_instruction();
9042 const auto lhs = AllOf(
9043 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9044 op::Constant(), op::Reshape())),
9045 op::Shape("f32[128,56,56,128]"));
9046 const auto rhs = AllOf(
9047 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9048 op::Constant(), op::Reshape())),
9049 op::Shape("f32[128,28,28,256]"));
9050
9051 EXPECT_THAT(root, AllOf(op::Convolution(lhs, op::CollectivePermute(rhs)),
9052 op::Shape("f32[1,1,128,256]")));
9053 }
9054
TEST_F(SpmdPartitioningTest,
       ConvolutionInputSpatialDimAndFeatureDimPartitioned) {
9057 absl::string_view hlo_string = R"(
9058 HloModule module
9059
9060 ENTRY entry {
9061 %lhs = f32[8,210,210,12] parameter(0)
9062 %lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs),
9063 sharding={devices=[1,2,1,2]0,1,2,3}
9064 %rhs = f32[3,3,12,32] parameter(1)
9065 %rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs),
9066 sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
9067 ROOT %conv = f32[8,210,210,32] convolution(
9068 f32[8,210,210,12] %lhs.copy,
9069 f32[3,3,12,32] %rhs.copy),
9070 window={size=3x3 pad=1_1x1_1},
9071 dim_labels=b01f_01io->b01f,
9072 sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
9073 })";
9074 TF_ASSERT_OK_AND_ASSIGN(auto module,
9075 PartitionComputation(hlo_string, /*num_devices=*/4));
9076 VLOG(1) << module->ToString();
9077 const auto root = module->entry_computation()->root_instruction();
9078 const auto lhs = AllOf(
9079 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
9080 op::Constant(), op::Reshape())),
9081 op::Shape("f32[8,105,210,6]"));
9082 auto left_halo =
9083 AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
9084 auto right_halo =
9085 AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,1,210,6]"));
9086 auto exchanged_lhs = AllOf(
9087 op::Select(op::And(_, _), op::Concatenate(left_halo, lhs, right_halo),
9088 op::Broadcast(_)),
9089 op::Shape("f32[8,107,210,6]"));
9090 const auto rhs = AllOf(
9091 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
9092 op::Reshape(), op::Constant())),
9093 op::Shape("f32[3,3,6,32]"));
9094 EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
9095 exchanged_lhs, op::CollectivePermute(rhs))),
9096 op::Shape("f32[8,105,210,32]")));
9097 }
9098
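// FFT along the sharded innermost dimension. The expected lowering
// (roughly) pads the local shard, redistributes it with an AllToAll-based
// shuffle, runs a per-partition FFT, and applies twiddle-factor corrections
// (the Exp multiply) inside a While loop.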
TEST_F(SpmdPartitioningTest, Fft3D) {
9100 absl::string_view hlo_string = R"(
9101 HloModule module
9102
9103 ENTRY entry {
9104 constant = c64[1,1,6]
9105 constant({{{(0,0),(1,1),(2,2),(3,3),(4,4),(5,5)}}}),
9106 sharding={devices=[1,1,2]0,1}
9107 ROOT fft = c64[1,1,6] fft(c64[1,1,6] constant), fft_type=FFT, fft_length={6},
9108 sharding={devices=[1,1,2]0,1}
9109 }
9110 )";
9111
9112 TF_ASSERT_OK_AND_ASSIGN(auto module,
9113 PartitionComputation(hlo_string, /*num_devices=*/2));
9114 VLOG(1) << module->ToString();
9115 const auto root = module->entry_computation()->root_instruction();
9116 auto input = AllOf(op::DynamicSlice(op::Constant(), op::Constant(),
9117 op::Constant(), op::Reshape()),
9118 op::Shape("c64[1,1,3]"));
9119 auto padded_input =
9120 AllOf(op::DynamicSlice(
9121 op::Concatenate(input, op::CollectivePermute(op::Slice())),
9122 op::Constant(), op::Constant(), op::Reshape()),
9123 op::Shape("c64[1,1,4]"));
9124
9125 auto shuffled_input =
9126 AllOf(op::Slice(op::AllToAll(op::Dot(padded_input, op::Convert()))),
9127 op::Shape("c64[1,1,3]"));
9128
9129 auto local_fft = AllOf(op::Fft(shuffled_input), op::Shape("c64[1,1,3]"));
9130
9131 EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(op::Tuple(
9132 _, op::Multiply(local_fft, op::Exp()), _, _, _))),
9133 op::Shape("c64[1,1,3]")));
9134 }
9135
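// Both convolution operands are the same instruction, but the two uses are
// resharded independently: each side keeps the full contracting dimension
// and is tiled only on its own non-contracting dimension, so the local
// convolution directly produces the [2,4]-tiled result without a reduction.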
TEST_F(SpmdPartitioningTest, DotInputsAreIdentical) {
9137 absl::string_view hlo_string = R"(
9138 HloModule module
9139
9140 ENTRY entry {
9141 %parameter.1 = f32[4000,4000]{1,0} parameter(0),
9142 sharding={devices=[2,4]0,1,2,3,4,5,6,7}
9143 ROOT %convolution = f32[4000,4000]{1,0} convolution(
9144 f32[4000,4000]{1,0} %parameter.1, f32[4000,4000]{1,0} %parameter.1),
9145 dim_labels=bf_io->bf, sharding={devices=[2,4]0,1,2,3,4,5,6,7}
9146 }
9147
9148 )";
9149
9150 TF_ASSERT_OK_AND_ASSIGN(auto module,
9151 PartitionComputation(hlo_string, /*num_devices=*/8));
9152 VLOG(1) << module->ToString();
9153 const auto root = module->entry_computation()->root_instruction();
9154 auto param = AllOf(op::Parameter(), op::Shape("f32[2000, 1000]"));
9155 auto resharded_lhs =
9156 AllOf(op::AllReduce(op::DynamicUpdateSlice(_, param, _, _)),
9157 op::Shape("f32[2000, 4000]"));
9158 auto resharded_rhs =
9159 AllOf(op::AllReduce(op::DynamicUpdateSlice(_, op::Copy(param), _, _)),
9160 op::Shape("f32[4000, 1000]"));
9161 EXPECT_THAT(root, AllOf(op::Convolution(resharded_lhs, resharded_rhs),
9162 op::Shape("f32[2000, 1000]")));
9163 }
9164
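// The 1x1 slice of a dimension that is sharded 8 ways lives on a single
// device, so replicating it is expected to use a Select that zeroes the
// non-owning partitions followed by an AllReduce.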
TEST_F(SpmdPartitioningTest, ConstantSliceReshard) {
9166 absl::string_view hlo_string = R"(
9167 HloModule module
9168
9169 ENTRY entry {
9170 %constant.785 = f32[1,8] constant({{0,1,2,3,4,5,6,7}}),
9171 sharding={devices=[1,8]0,1,2,3,4,5,6,7}
9172 %slice.62 = f32[1,1] slice(%constant.785), slice={[0:1], [0:1]},
9173 sharding={devices=[1,8]0,1,2,3,4,5,6,7}
9174 ROOT %reshape.779 = f32[] reshape(%slice.62), sharding={replicated}
9175 })";
9176 TF_ASSERT_OK_AND_ASSIGN(auto module,
9177 PartitionComputation(hlo_string, /*num_devices=*/8));
9178 const auto root = module->entry_computation()->root_instruction();
9179 VLOG(1) << module->ToString();
9180 auto slice = AllOf(op::Shape("f32[1,1]"),
9181 op::Copy(op::DynamicSlice(op::Constant(), _, _)));
9182 EXPECT_THAT(root, op::Reshape(op::AllReduce(op::Select(_, slice, _))));
9183 }
9184
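// Gather tests where some index dimensions are "parallel" to operand
// dimensions (their indices are an iota along that dimension). The gather
// can then be partitioned along those dimensions like an elementwise op,
// resharding the operand and/or indices as needed and shifting the start
// indices by the per-partition offset (the Subtract in the matchers).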
TEST_F(SpmdPartitioningTest, GatherParallelDimRedistributionOperand) {
9186 absl::string_view hlo_string = R"(
9187 HloModule module
9188
9189 ENTRY %module {
9190 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9191 sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
9192 %constant = s32[4] constant({0, 1, 2, 3}), sharding={replicated}
9193 %iota = s32[1,8,4]{2,1,0} broadcast(%constant), dimensions={2},
9194 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9195 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9196 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9197 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9198 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9199 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9200 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9201 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9202 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9203 collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9204 slice_sizes={1,1,2,2}, sharding={replicated}
9205 })";
9206 TF_ASSERT_OK_AND_ASSIGN(auto module,
9207 PartitionComputation(hlo_string, /*num_devices=*/8));
9208 const auto root = module->entry_computation()->root_instruction();
9209 VLOG(1) << module->ToString();
9210 auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::Reshape());
9211 auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9212 auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9213 EXPECT_THAT(root,
9214 op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9215 }
9216
TEST_F(SpmdPartitioningTest, GatherParallelDimRedistributionIndices) {
9218 absl::string_view hlo_string = R"(
9219 HloModule module
9220
9221 ENTRY %module {
9222 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9223 sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
9224 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9225 sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
9226 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9227 sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
9228 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9229 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9230 sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
9231 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(s32[8,4,2,2]{3,2,1,0} %parameter.0,
9232 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9233 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9234 slice_sizes={1,1,2,2}, sharding={replicated}
9235 })";
9236 TF_ASSERT_OK_AND_ASSIGN(auto module,
9237 PartitionComputation(hlo_string, /*num_devices=*/8));
9238 const auto root = module->entry_computation()->root_instruction();
9239 VLOG(1) << module->ToString();
9240 auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::DynamicSlice());
9241 auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
9242 auto gather = AllOf(op::Shape("s32[2,2,2,2]"), op::Gather(operand, indices));
9243 EXPECT_THAT(root, op::AllReduce(op::AllReduce(
9244 op::DynamicUpdateSlice(_, gather, _, _, _, _))));
9245 }
9246
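// As above, but the indices start out replicated and are expected to be
// offset/sliced to match the operand's 8-way sharding on the parallel
// dimension before the per-partition Gather.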
TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedIndices) {
9248 absl::string_view hlo_string = R"(
9249 HloModule module
9250
9251 ENTRY %module {
9252 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9253 sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
9254 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9255 sharding={replicated}
9256 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9257 sharding={replicated}
9258 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9259 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9260 sharding={replicated}
9261 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9262 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9263 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9264 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9265 slice_sizes={1,1,2,2}, sharding={replicated}
9266 })";
9267 TF_ASSERT_OK_AND_ASSIGN(auto module,
9268 PartitionComputation(hlo_string, /*num_devices=*/8));
9269 const auto root = module->entry_computation()->root_instruction();
9270 VLOG(1) << module->ToString();
9271 auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::Parameter());
9272 auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9273 auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9274 EXPECT_THAT(root,
9275 op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9276 }
9277
TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedOperand) {
9279 absl::string_view hlo_string = R"(
9280 HloModule module
9281
9282 ENTRY %module {
9283 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={replicated}
9284 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9285 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9286 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9287 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9288 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9289 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9290 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9291 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9292 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9293 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9294 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9295 slice_sizes={1,1,2,2}, sharding={replicated}
9296 })";
9297 TF_ASSERT_OK_AND_ASSIGN(auto module,
9298 PartitionComputation(hlo_string, /*num_devices=*/8));
9299 const auto root = module->entry_computation()->root_instruction();
9300 VLOG(1) << module->ToString();
9301 auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
9302 auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9303 auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9304 EXPECT_THAT(root,
9305 op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9306 }
9307
TEST_F(SpmdPartitioningTest, GatherParallelDimPartialReplicatedIndices) {
9309 absl::string_view hlo_string = R"(
9310 HloModule module
9311
9312 ENTRY %module {
9313 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9314 sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
9315 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9316 sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9317 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9318 sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9319 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9320 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9321 sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9322 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9323 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9324 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9325 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9326 slice_sizes={1,1,2,2}, sharding={replicated}
9327 })";
9328 TF_ASSERT_OK_AND_ASSIGN(auto module,
9329 PartitionComputation(hlo_string, /*num_devices=*/8));
9330 const auto root = module->entry_computation()->root_instruction();
9331 VLOG(1) << module->ToString();
9332 auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::Parameter());
9333 auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9334 auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9335 EXPECT_THAT(root,
9336 op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9337 }
9338
TEST_F(SpmdPartitioningTest, GatherParallelDimPartialReplicatedOperand) {
9340 absl::string_view hlo_string = R"(
9341 HloModule module
9342
9343 ENTRY %module {
9344 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={
9345 devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9346 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9347 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9348 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9349 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9350 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9351 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9352 sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
9353 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9354 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9355 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9356 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9357 slice_sizes={1,1,2,2}, sharding={replicated}
9358 })";
9359 TF_ASSERT_OK_AND_ASSIGN(auto module,
9360 PartitionComputation(hlo_string, /*num_devices=*/8));
9361 const auto root = module->entry_computation()->root_instruction();
9362 VLOG(1) << module->ToString();
9363 auto operand = AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
9364 auto indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
9365 auto gather = AllOf(op::Shape("s32[1,4,2,2]"), op::Gather(operand, indices));
9366 EXPECT_THAT(root,
9367 op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
9368 }
9369
9370 TEST_F(SpmdPartitioningTest, GatherParallelDimSwappedDimensions) {
9371 absl::string_view hlo_string = R"(
9372 HloModule module
9373
9374 ENTRY %module {
9375 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={
9376 devices=[4,2,1,1]0,1,2,3,4,5,6,7}
9377 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9378 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
9379 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9380 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
9381 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9382 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9383 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
9384 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9385 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9386 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9387 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9388 slice_sizes={1,1,2,2}, sharding={replicated}
9389 })";
9390 TF_ASSERT_OK_AND_ASSIGN(auto module,
9391 PartitionComputation(hlo_string, /*num_devices=*/8));
9392 const auto root = module->entry_computation()->root_instruction();
9393 VLOG(1) << module->ToString();
9394 auto operand = AllOf(op::Shape("s32[4,1,2,2]"), op::CollectivePermute());
9395 auto indices = AllOf(op::Shape("s32[2,4,1]"), op::Subtract());
9396 auto gather = AllOf(op::Shape("s32[4,1,2,2]"), op::Gather(operand, indices));
9397 EXPECT_THAT(root, op::AllReduce(op::AllReduce(
9398 op::DynamicUpdateSlice(_, gather, _, _, _, _))));
9399 }
9400
9401 TEST_F(SpmdPartitioningTest, GatherParallelDimAndNonParallelDimPartitioned) {
9402 absl::string_view hlo_string = R"(
9403 HloModule module
9404
9405 ENTRY %module {
9406 %arg.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
9407 %arg.1 = s32[1,8,4]{2,1,0} parameter(1)
9408 %operand = s32[8,4,2,2]{3,2,1,0} copy(s32[8,4,2,2]{3,2,1,0} %arg.0),
9409 sharding={devices=[2,2,1,1]0,1,2,3}
9410 %indices = s32[1,8,4]{2,1,0} copy(s32[1,8,4]{2,1,0} %arg.1),
9411 sharding={devices=[1,2,2]0,1,2,3}
9412 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9413 sharding={devices=[1,2,2]0,1,2,3}
9414 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9415 s32[1,8,4]{2,1,0} %indices), dimensions={0},
9416 sharding={devices=[1,2,2]0,1,2,3}
9417 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9418 s32[8,4,2,2]{3,2,1,0} %operand,
9419 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9420 collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
9421 slice_sizes={1,1,2,2}, sharding={replicated}
9422 })";
9423
9424 TF_ASSERT_OK_AND_ASSIGN(auto module,
9425 PartitionComputation(hlo_string, /*num_devices=*/4));
9426 const auto root = module->entry_computation()->root_instruction();
9427 VLOG(1) << module->ToString();
9428 auto operand = AllOf(op::Shape("s32[4,4,2,2]"), op::AllReduce());
9429 auto indices = AllOf(op::Shape("s32[2,4,2]"), op::Subtract());
9430 auto gather = AllOf(op::Shape("s32[4,2,2,2]"), op::Gather(operand, indices));
9431 EXPECT_THAT(
9432 root, op::AllReduce(op::DynamicUpdateSlice(
9433 _, op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)),
9434 _, _, _, _)));
9435 }
9436
9437 TEST_F(SpmdPartitioningTest, GatherMergedParalleIndexPassthrough) {
9438 absl::string_view hlo_string = R"(
9439 HloModule module
9440
9441 ENTRY %module {
9442 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9443 sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
9444 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9445 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9446 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9447 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9448 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9449 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9450 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9451 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9452 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9453 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9454 collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9455 slice_sizes={1,1,2,2}, sharding={replicated}
9456 })";
9457 TF_ASSERT_OK_AND_ASSIGN(auto module,
9458 PartitionComputation(hlo_string, /*num_devices=*/8));
9459 VLOG(1) << module->ToString();
9460 const auto root = module->entry_computation()->root_instruction();
9461 auto operand = AllOf(op::Shape("s32[2,4,1,2]"), op::Reshape());
9462 auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
9463 auto gather = AllOf(op::Shape("s32[2,4,1,2]"), op::Gather(operand, indices));
9464 EXPECT_THAT(root, op::AllReduce(op::AllReduce(
9465 op::DynamicUpdateSlice(_, gather, _, _, _, _))));
9466 }
9467
9468 TEST_F(SpmdPartitioningTest, GatherParalleIndexAndOperand) {
9469 absl::string_view hlo_string = R"(
9470 HloModule module
9471
9472 ENTRY %module {
9473 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9474 sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
9475 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9476 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9477 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9478 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9479 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9480 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9481 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9482 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9483 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9484 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9485 collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9486 slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
9487 })";
9488 TF_ASSERT_OK_AND_ASSIGN(auto module,
9489 PartitionComputation(hlo_string, /*num_devices=*/8));
9490 VLOG(1) << module->ToString();
9491 const auto root = module->entry_computation()->root_instruction();
9492 auto operand = AllOf(op::Shape("s32[2,4,1,2]"), op::Parameter(0));
9493 auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
9494 auto gather = AllOf(op::Shape("s32[2,4,1,2]"), op::Gather(operand, indices));
9495 EXPECT_THAT(root, gather);
9496 }
9497
9498 TEST_F(SpmdPartitioningTest, GatherReshardParalleIndexAndOperand) {
9499 absl::string_view hlo_string = R"(
9500 HloModule module
9501
9502 ENTRY %module {
9503 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9504 sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
9505 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9506 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9507 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9508 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9509 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9510 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9511 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9512 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9513 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9514 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9515 collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9516 slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]1,0,3,2,4,5,6,7}
9517 })";
9518 TF_ASSERT_OK_AND_ASSIGN(auto module,
9519 PartitionComputation(hlo_string, /*num_devices=*/8));
9520 VLOG(1) << module->ToString();
9521 const auto root = module->entry_computation()->root_instruction();
9522 auto operand = AllOf(op::Shape("s32[2,4,1,2]"), op::Parameter(0));
9523 auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
9524 auto gather = AllOf(op::Shape("s32[2,4,1,2]"), op::Gather(operand, indices));
9525 EXPECT_THAT(root, op::CollectivePermute(gather));
9526 }
9527
9528 TEST_F(SpmdPartitioningTest, GatherParalleIndexAndOperandReshard) {
9529 absl::string_view hlo_string = R"(
9530 HloModule module
9531
9532 ENTRY %module {
9533 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9534 sharding={devices=[4,1,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9535 %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
9536 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9537 %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
9538 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9539 %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
9540 s32[1,8,4]{2,1,0} %iota2), dimensions={0},
9541 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9542 ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
9543 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9544 s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
9545 collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9546 slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
9547 })";
9548 TF_ASSERT_OK_AND_ASSIGN(auto module,
9549 PartitionComputation(hlo_string, /*num_devices=*/8));
9550 VLOG(1) << module->ToString();
9551 const auto root = module->entry_computation()->root_instruction();
9552 auto operand = AllOf(op::Shape("s32[2,4,2,2]"), op::Parameter(0));
9553 auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
9554 auto gather = AllOf(op::Shape("s32[2,4,2,2]"), op::Gather(operand, indices));
9555 EXPECT_THAT(root, op::DynamicSlice(gather, _, _, _, _));
9556 }
9557
9558 TEST_F(SpmdPartitioningTest, GatherMergedParallelIndexTrivialSlice) {
9559 absl::string_view hlo_string = R"(
9560 HloModule module
9561
9562 ENTRY %module {
9563 %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
9564 sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
9565 %parameter.1 = s32[1,8,1]{2,1,0} parameter(1),
9566 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9567 %iota = s32[1,8,1]{2,1,0} iota(), iota_dimension=1,
9568 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9569 %concatenate.19 = s32[2,8,1]{2,1,0} concatenate(
9570 s32[1,8,1]{2,1,0} %parameter.1, s32[1,8,1]{2,1,0} %iota), dimensions={0},
9571 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9572 ROOT %gather.20 = s32[8,1,2,2]{3,2,1,0} gather(
9573 s32[8,4,2,2]{3,2,1,0} %parameter.0,
9574 s32[2,8,1]{2,1,0} %concatenate.19), offset_dims={2,3},
9575 collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
9576 slice_sizes={1,1,2,2}, sharding={replicated}
9577 })";
9578 TF_ASSERT_OK_AND_ASSIGN(auto module,
9579 PartitionComputation(hlo_string, /*num_devices=*/8));
9580 const auto root = module->entry_computation()->root_instruction();
9581 auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::Parameter());
9582 auto indices = AllOf(op::Shape("s32[2,2,1]"), op::Subtract());
9583 auto gather = AllOf(op::Shape("s32[2,1,2,2]"), op::Gather(operand, indices));
9584 VLOG(1) << module->ToString();
9585 EXPECT_THAT(root,
9586 op::AllReduce(op::DynamicUpdateSlice(
9587 _, op::AllReduce(op::Select(_, _, gather)), _, _, _, _)));
9588 }
9589
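// Sort-based top-k: values and iota indices are stably sorted, then sliced to
// the top entries. This test checks that, with the sort dimension sharded, the
// partitioned sort still runs on the full (base-shaped) input.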
9590 TEST_F(SpmdPartitioningTest, SortTopKNonSortDimension) {
9591 absl::string_view hlo_string = R"(
9592 HloModule module
9593
9594 %compare-greater-than.42077 (p.0.lhs.42078: f32[],
9595 p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
9596 %p.0.lhs.42078 = f32[] parameter(0)
9597 %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
9598 %constant.45054 = s32[] constant(0)
9599 %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
9600 s32[] %constant.45054), direction=LT
9601 %constant.45278 = u32[] constant(2147483647)
9602 %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
9603 %subtract.337 = u32[] subtract(u32[] %constant.45278,
9604 u32[] %bitcast-convert.136)
9605 %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
9606 %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
9607 s32[] %bitcast-convert.135)
9608 %p.0.rhs.42079 = f32[] parameter(1)
9609 %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
9610 %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
9611 s32[] %constant.45054), direction=LT
9612 %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
9613 %subtract.338 = u32[] subtract(u32[] %constant.45278,
9614 u32[] %bitcast-convert.139)
9615 %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
9616 %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
9617 s32[] %bitcast-convert.138)
9618 %compare.135 = pred[] compare(s32[] %select.282,
9619 s32[] %select.283), direction=GT
9620 %compare.428 = pred[] compare(s32[] %select.283,
9621 s32[] %select.282), direction=GT
9622 %compare.429 = pred[] compare(pred[] %compare.135,
9623 pred[] %compare.428), direction=EQ
9624 %p.1.lhs.42080 = s32[] parameter(2)
9625 %p.1.rhs.42081 = s32[] parameter(3)
9626 %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
9627 s32[] %p.1.rhs.42081), direction=LT
9628 ROOT %select.579 = pred[] select(pred[] %compare.429,
9629 pred[] %compare.430, pred[] %compare.135)
9630 }
9631
9632 ENTRY %module {
9633 %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
9634 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9635 %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
9636 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9637 %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
9638 f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
9639 dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
9640 sharding={{devices=[2,1,4]0,1,2,3,4,5,6,7},
9641 {devices=[2,1,4]0,1,2,3,4,5,6,7}}
9642 output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
9643 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9644 %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
9645 slice={[0:2], [0:64], [0:2]},
9646 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9647 output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
9648 sharding={replicated}
9649 %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
9650 slice={[0:2], [0:64], [0:2]},
9651 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
9652 ROOT output.t = (f32[2,64,2]{2,1,0},
9653 s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
9654 sharding={{replicated}, {replicated}}
9655 })";
9656 TF_ASSERT_OK_AND_ASSIGN(auto module,
9657 PartitionComputation(hlo_string, /*num_devices=*/8));
9658
9659 const HloInstruction* sort = FindInstruction(module.get(), "sort.0");
9660 EXPECT_NE(sort, nullptr);
9661 auto sort_match =
9662 AllOf(op::Shape("(f32[2,64,32128], s32[2,64,32128])"), op::Sort(_, _));
9663 EXPECT_THAT(sort, sort_match);
9664 }
9665
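// Same top-k pattern, sharded only along the sort dimension: the sliced
// results are expected to be reassembled to their base shapes with
// dynamic-update-slice + all-reduce.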
9666 TEST_F(SpmdPartitioningTest, SortTopKPropagateBaseShape) {
9667 absl::string_view hlo_string = R"(
9668 HloModule module
9669
9670 %compare-greater-than.42077 (p.0.lhs.42078: f32[],
9671 p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
9672 %p.0.lhs.42078 = f32[] parameter(0)
9673 %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
9674 %constant.45054 = s32[] constant(0)
9675 %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
9676 s32[] %constant.45054), direction=LT
9677 %constant.45278 = u32[] constant(2147483647)
9678 %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
9679 %subtract.337 = u32[] subtract(u32[] %constant.45278,
9680 u32[] %bitcast-convert.136)
9681 %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
9682 %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
9683 s32[] %bitcast-convert.135)
9684 %p.0.rhs.42079 = f32[] parameter(1)
9685 %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
9686 %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
9687 s32[] %constant.45054), direction=LT
9688 %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
9689 %subtract.338 = u32[] subtract(u32[] %constant.45278,
9690 u32[] %bitcast-convert.139)
9691 %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
9692 %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
9693 s32[] %bitcast-convert.138)
9694 %compare.135 = pred[] compare(s32[] %select.282,
9695 s32[] %select.283), direction=GT
9696 %compare.428 = pred[] compare(s32[] %select.283,
9697 s32[] %select.282), direction=GT
9698 %compare.429 = pred[] compare(pred[] %compare.135,
9699 pred[] %compare.428), direction=EQ
9700 %p.1.lhs.42080 = s32[] parameter(2)
9701 %p.1.rhs.42081 = s32[] parameter(3)
9702 %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
9703 s32[] %p.1.rhs.42081), direction=LT
9704 ROOT %select.579 = pred[] select(pred[] %compare.429,
9705 pred[] %compare.430, pred[] %compare.135)
9706 }
9707
9708 ENTRY %module {
9709 %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
9710 sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9711 %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
9712 sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9713 %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
9714 f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
9715 dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
9716 sharding={{devices=[1,1,8]0,1,2,3,4,5,6,7},
9717 {devices=[1,1,8]0,1,2,3,4,5,6,7}}
9718 output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
9719 sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9720 %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
9721 slice={[0:2], [0:64], [0:2]},
9722 sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9723 output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
9724 sharding={replicated}
9725 %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
9726 slice={[0:2], [0:64], [0:2]},
9727 sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
9728 ROOT output.t = (f32[2,64,2]{2,1,0},
9729 s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
9730 sharding={{replicated}, {replicated}}
9731 })";
9732 TF_ASSERT_OK_AND_ASSIGN(auto module,
9733 PartitionComputation(hlo_string, /*num_devices=*/8));
9734 VLOG(1) << module->ToString();
9735 const HloInstruction* root = module->entry_computation()->root_instruction();
9736 auto all_reduce_val =
9737 AllOf(op::Shape("f32[2,64,2]"),
9738 op::AllReduce(op::DynamicUpdateSlice(_, _, _, _, _)));
9739 auto all_reduce_idx =
9740 AllOf(op::Shape("s32[2,64,2]"),
9741 op::AllReduce(op::DynamicUpdateSlice(_, _, _, _, _)));
9742 auto tuple = AllOf(op::Shape("(f32[2,64,2], s32[2,64,2])"),
9743 op::Tuple(all_reduce_val, all_reduce_idx));
9744 EXPECT_THAT(root, tuple);
9745 }
9746
9747 TEST_F(SpmdPartitioningTest, GatherIndexOnlyCorrectReplacement) {
9748 absl::string_view hlo_string = R"(
9749 HloModule module
9750
9751 ENTRY %module {
9752 %parameter.0 = bf16[1,8,6,6]{3,2,1,0} parameter(0),
9753 sharding={replicated}
9754 %parameter.1 = s32[2,4]{1,0} parameter(1),
9755 sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9756 %gather.100 = bf16[2,1,8,1,6]{4,3,2,1,0} gather(
9757 bf16[1,8,6,6]{3,2,1,0} %parameter.0, s32[2,4]{1,0} %parameter.1),
9758 offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1,2,3},
9759 index_vector_dim=1, slice_sizes={1,8,1,6},
9760 sharding={devices=[2,1,4,1,1]0,1,2,3,4,5,6,7}
9761 %constant.45590 = s32[] constant(0), sharding={replicated}
9762 %broadcast.54515 = s32[2,64,1,1]{3,2,1,0} broadcast(s32[] %constant.45590),
9763 dimensions={},
9764 sharding={devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
9765 ROOT %reshape.4243 = bf16[2,8,6]{2,1,0} reshape(
9766 bf16[2,1,8,1,6]{4,3,2,1,0} %gather.100),
9767 sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
9768 })";
9769 TF_ASSERT_OK_AND_ASSIGN(auto module,
9770 PartitionComputation(hlo_string, /*num_devices=*/8));
9771
9772 const HloInstruction* root = module->entry_computation()->root_instruction();
9773 auto param0 = AllOf(op::Shape("bf16[1,8,6,6]"), op::Parameter());
9774 auto param1 = AllOf(op::Shape("s32[1,4]"), op::Parameter());
9775 auto reshape = AllOf(
9776 op::Shape("bf16[1,2,6]"),
9777 op::Reshape(op::DynamicSlice(op::Gather(param0, param1), _, _, _, _, _)));
9778 EXPECT_THAT(root, reshape);
9779 }
9780
9781 TEST_F(SpmdPartitioningTest, GatherRegressionTest1) {
9782 absl::string_view hlo_string = R"(
9783 HloModule module
9784
9785 ENTRY %module {
9786 %parameter.0 = s32[1,4] parameter(0), sharding={devices=[1,8]0,1,2,3,4,5,6,7}
9787 %iota.10 = s32[4]{0} iota(), iota_dimension=0, sharding={devices=[8]0,1,2,3,4,5,6,7}
9788 ROOT %gather.44 = s32[1,4]{1,0} gather(%parameter.0, %iota.10),
9789 offset_dims={0}, collapsed_slice_dims={1}, start_index_map={1}, index_vector_dim=1,
9790 slice_sizes={1,1}, sharding={devices=[1,8]0,1,2,3,4,5,6,7}
9791 })";
9792 TF_ASSERT_OK_AND_ASSIGN(auto module,
9793 PartitionComputation(hlo_string, /*num_devices=*/8));
9794
9795 const HloInstruction* root = module->entry_computation()->root_instruction();
9796 auto param0 = AllOf(op::Shape("s32[1,1]"), op::Parameter());
9797 EXPECT_THAT(root, op::Gather(param0, _));
9798 }
9799
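// The next four tests check how choose_faster_windowed_einsum affects the
// windowed einsum loop generated for these partitioned convolutions. The
// iteration count is read from the constant in the while loop's condition;
// preferring speed over memory footprint halves the number of iterations for
// the same HLO.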
9800 TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint) {
9801 absl::string_view hlo_string = R"(
9802 HloModule module
9803
9804 ENTRY %module {
9805 %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
9806 sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
9807 %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
9808 sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
9809 %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
9810 convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
9811 bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
9812 window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
9813 rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
9814 sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
9815 ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
9816 reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
9817 sharding={replicated}
9818 })";
9819 TF_ASSERT_OK_AND_ASSIGN(
9820 auto module,
9821 PartitionComputation(hlo_string, /*num_devices=*/8,
9822 /*conv_halo_exchange_always_on_lhs =*/true,
9823 /*choose_faster_windowed_einsum =*/false));
9824 const HloInstruction* while_inst = FindInstruction(module.get(), "while");
9825 EXPECT_NE(while_inst, nullptr);
9826 const HloComputation* cond_comp = while_inst->while_condition();
9827 const HloInstruction* root = cond_comp->root_instruction();
9828 EXPECT_THAT(root, op::Compare(_, op::Constant()));
9829 const HloConstantInstruction* iterations =
9830 Cast<HloConstantInstruction>(root->operand(1));
9831 EXPECT_TRUE(iterations->literal().GetFirstInteger());
9832 EXPECT_EQ(*iterations->literal().GetFirstInteger(), 4);
9833 }
9834
9835 TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations) {
9836 absl::string_view hlo_string = R"(
9837 HloModule module
9838
9839 ENTRY %module {
9840 %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
9841 sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
9842 %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
9843 sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
9844 %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
9845 convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
9846 bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
9847 window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
9848 rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
9849 sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
9850 ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
9851 reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
9852 sharding={replicated}
9853 })";
9854 TF_ASSERT_OK_AND_ASSIGN(
9855 auto module,
9856 PartitionComputation(hlo_string, /*num_devices=*/8,
9857 /*conv_halo_exchange_always_on_lhs =*/true,
9858 /*choose_faster_windowed_einsum =*/true));
9859 const HloInstruction* while_inst = FindInstruction(module.get(), "while");
9860 EXPECT_NE(while_inst, nullptr);
9861 const HloComputation* cond_comp = while_inst->while_condition();
9862 const HloInstruction* root = cond_comp->root_instruction();
9863 EXPECT_THAT(root, op::Compare(_, op::Constant()));
9864 const HloConstantInstruction* iterations =
9865 Cast<HloConstantInstruction>(root->operand(1));
9866 EXPECT_TRUE(iterations->literal().GetFirstInteger());
9867 EXPECT_EQ(*iterations->literal().GetFirstInteger(), 2);
9868 }
9869
9870 TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations2) {
9871 const char* const hlo_string = R"(
9872 HloModule module
9873
9874 ENTRY entry {
9875 %lhs = bf16[512,1024,16,36,256]{4,3,2,1,0} parameter(0)
9876 %lhs.copy = bf16[512,1024,16,36,256]{4,3,2,1,0} copy(%lhs),
9877 sharding={devices=[8,1,4,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,
9878 18,19,20,21,22,23,24,25,26,27,28,29,30,31}
9879 %rhs = bf16[512,1024,16,4,288]{4,3,2,1,0} parameter(1)
9880 %rhs.copy = bf16[512,1024,16,4,288]{4,3,2,1,0} copy(%rhs),
9881 sharding={devices=[8,1,4,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
9882 17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
9883 %reshape.2556 = bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} reshape(
9884 bf16[512,1024,16,4,288]{4,3,2,1,0} %rhs.copy), sharding={
9885 devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
9886 20,21,22,23,24,25,26,27,28,29,30,31}
9887 %reshape.2570 = bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0}
9888 reshape(bf16[512,1024,16,36,256]{4,3,2,1,0} %lhs.copy), sharding={
9889 devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
9890 20,21,22,23,24,25,26,27,28,29,30,31}
9891 %convolution.10 = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
9892 convolution(bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0} %reshape.2570,
9893 bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} %reshape.2556),
9894 window={size=1x1x16x4x512 pad=0_0x0_0x15_15x3_3x0_0 rhs_reversal=0x0x1x1x0},
9895 dim_labels=4f01b23_4i23o01->01b23f4, sharding={devices=[4,1,1,4,2,1,1]0,4,8,
9896 12,16,20,24,28,1,5,9,13,17,21,25,29,2,6,10,14,18,22,26,30,3,7,11,15,19,23,
9897 27,31}
9898 ROOT %output = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
9899 copy(%convolution.10), sharding={replicated}
9900 })";
9901 TF_ASSERT_OK_AND_ASSIGN(
9902 auto module,
9903 PartitionComputation(hlo_string, /*num_devices=*/32,
9904 /*conv_halo_exchange_always_on_lhs =*/true,
9905 /*choose_faster_windowed_einsum =*/true));
9906 const HloInstruction* while_inst = FindInstruction(module.get(), "while");
9907 EXPECT_NE(while_inst, nullptr);
9908 const HloComputation* cond_comp = while_inst->while_condition();
9909 const HloInstruction* root = cond_comp->root_instruction();
9910 EXPECT_THAT(root, op::Compare(_, op::Constant()));
9911 const HloConstantInstruction* iterations =
9912 Cast<HloConstantInstruction>(root->operand(1));
9913 EXPECT_TRUE(iterations->literal().GetFirstInteger());
9914 EXPECT_EQ(*iterations->literal().GetFirstInteger(), 4);
9915 }
9916
9917 TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint2) {
9918 const char* const hlo_string = R"(
9919 HloModule module
9920
9921 ENTRY entry {
9922 %lhs = bf16[512,1024,16,36,256]{4,3,2,1,0} parameter(0)
9923 %lhs.copy = bf16[512,1024,16,36,256]{4,3,2,1,0} copy(%lhs),
9924 sharding={devices=[8,1,4,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,
9925 18,19,20,21,22,23,24,25,26,27,28,29,30,31}
9926 %rhs = bf16[512,1024,16,4,288]{4,3,2,1,0} parameter(1)
9927 %rhs.copy = bf16[512,1024,16,4,288]{4,3,2,1,0} copy(%rhs),
9928 sharding={devices=[8,1,4,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
9929 17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
9930 %reshape.2556 = bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} reshape(
9931 bf16[512,1024,16,4,288]{4,3,2,1,0} %rhs.copy), sharding={
9932 devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
9933 20,21,22,23,24,25,26,27,28,29,30,31}
9934 %reshape.2570 = bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0}
9935 reshape(bf16[512,1024,16,36,256]{4,3,2,1,0} %lhs.copy), sharding={
9936 devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
9937 20,21,22,23,24,25,26,27,28,29,30,31}
9938 %convolution.10 = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
9939 convolution(bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0} %reshape.2570,
9940 bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} %reshape.2556),
9941 window={size=1x1x16x4x512 pad=0_0x0_0x15_15x3_3x0_0 rhs_reversal=0x0x1x1x0},
9942 dim_labels=4f01b23_4i23o01->01b23f4, sharding={devices=[4,1,1,4,2,1,1]0,4,8,
9943 12,16,20,24,28,1,5,9,13,17,21,25,29,2,6,10,14,18,22,26,30,3,7,11,15,19,23,
9944 27,31}
9945 ROOT %output = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
9946 copy(%convolution.10), sharding={replicated}
9947 })";
9948 TF_ASSERT_OK_AND_ASSIGN(
9949 auto module,
9950 PartitionComputation(hlo_string, /*num_devices=*/32,
9951 /*conv_halo_exchange_always_on_lhs =*/true,
9952 /*choose_faster_windowed_einsum =*/false));
9953 const HloInstruction* while_inst = FindInstruction(module.get(), "while");
9954 EXPECT_NE(while_inst, nullptr);
9955 const HloComputation* cond_comp = while_inst->while_condition();
9956 const HloInstruction* root = cond_comp->root_instruction();
9957 EXPECT_THAT(root, op::Compare(_, op::Constant()));
9958 const HloConstantInstruction* iterations =
9959 Cast<HloConstantInstruction>(root->operand(1));
9960 EXPECT_TRUE(iterations->literal().GetFirstInteger());
9961 EXPECT_EQ(*iterations->literal().GetFirstInteger(), 8);
9962 }
9963
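// Regression test for dot operands being sliced incorrectly when the
// contracting dimensions are partitioned; the expected per-shard operand
// shapes are pinned down below.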
9964 TEST_F(SpmdPartitioningTest, ContractingPartitionDotOperandsSlicedWrong) {
9965 const char* const hlo_string = R"(
9966 HloModule module
9967
9968 ENTRY entry {
9969 %lhs = f32[8,2,15,4] parameter(0)
9970 %lhs.copy = f32[8,2,15,4] copy(%lhs),
9971 sharding={devices=[1,2,4,1]0,1,2,3,4,5,6,7}
9972 %rhs = f32[2,15,4] parameter(1)
9973 %rhs.copy = f32[2,15,4] copy(%rhs),
9974 sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
9975 %dot = f32[8,2,2] dot(%lhs.copy, %rhs.copy),
9976 lhs_batch_dims={}, rhs_batch_dims={},
9977 lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2},
9978 operand_precision={HIGH,HIGH},
9979 sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}
9980 ROOT %output = f32[8,2,2] copy(%dot), sharding={replicated}
9981 })";
9982 TF_ASSERT_OK_AND_ASSIGN(
9983 auto module,
9984 PartitionComputation(hlo_string, /*num_devices=*/8,
9985 /*conv_halo_exchange_always_on_lhs =*/true,
9986 /*choose_faster_windowed_einsum =*/true));
9987
9988 const HloInstruction* dot_op = FindInstruction(module.get(), HloOpcode::kDot);
9989 auto op1 = op::Shape("f32[4,2,4,4]");
9990 auto op2 = op::Shape("f32[2,4,4]");
9991 EXPECT_THAT(dot_op, op::Dot(op1, op2));
9992 }
9993
9994 TEST_F(SpmdPartitioningTest, PartitionDotGroupOnBatchContractingReshard) {
9995 absl::string_view hlo_string = R"(
9996 HloModule module
9997
9998 ENTRY entry {
9999 %lhs = f32[32,32,24,4096] parameter(0),
10000 sharding={devices=[2,1,1,2]0,1,2,3}
10001 %rhs = f32[32,4096,1024] parameter(1),
10002 sharding={devices=[2,2,1]0,1,2,3}
10003 ROOT %dot = f32[32,32,24,1024] dot(%lhs, %rhs),
10004 lhs_batch_dims={0}, rhs_batch_dims={0},
10005 lhs_contracting_dims={3}, rhs_contracting_dims={1},
10006 sharding={devices=[1,2,1,2]0,1,2,3}
10007 })";
10008
10009 TF_ASSERT_OK_AND_ASSIGN(
10010 auto module,
10011 PartitionComputation(hlo_string, /*num_devices=*/4,
10012 /*conv_halo_exchange_always_on_lhs=*/true,
10013 /*choose_faster_windowed_einsum=*/true));
10014 VLOG(1) << module->ToString();
10015 const auto root = module->entry_computation()->root_instruction();
10016 auto dot = AllOf(op::Shape("f32[16,32,24,1024]"),
10017 op::Dot(op::Parameter(0), op::Parameter(1)));
10018 auto reduce_scatter = AllOf(op::Shape("f32[16,32,24,512]"),
10019 op::DynamicSlice(op::AllReduce(dot), _, _, _, _));
10020 EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(
10021 op::AllToAll(op::Reshape(reduce_scatter)))),
10022 op::Shape("f32[32,16,24,512]")));
10023 }
10024
10025 TEST_F(SpmdPartitioningTest, PartitionPassthroughScatterCorrectOutputSharding) {
10026 absl::string_view hlo_string = R"(
10027 HloModule module
10028
10029 %scatter_add (parameter.0: bf16[], parameter.1: bf16[]) -> bf16[] {
10030 %parameter.0 = bf16[] parameter(0)
10031 %parameter.1 = bf16[] parameter(1)
10032 ROOT %add = bf16[] add(bf16[] %parameter.0, bf16[] %parameter.1)
10033 }
10034
10035 ENTRY entry {
10036 %operand = bf16[2,1024]{1,0} parameter(0),
10037 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
10038 %indices = s32[8,512,1]{2,1,0} parameter(1),
10039 sharding={devices=[2,1,1,2]0,2,1,3 last_tile_dim_replicate}
10040 %updates = bf16[8,512,1024]{2,1,0} parameter(2),
10041 sharding={devices=[2,1,2]0,2,1,3}
10042 ROOT %scatter = bf16[2,1024]{1,0} scatter(bf16[2,1024]{1,0} %operand,
10043 s32[8,512,1]{2,1,0} %indices,
10044 bf16[8,512,1024]{2,1,0} %updates), update_window_dims={2},
10045 inserted_window_dims={0}, scatter_dims_to_operand_dims={0},
10046 index_vector_dim=2, to_apply=%scatter_add,
10047 sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
10048 })";
10049
10050 TF_ASSERT_OK_AND_ASSIGN(auto module,
10051 PartitionComputation(hlo_string, /*num_devices=*/4));
10052 VLOG(1) << module->ToString();
10053 const auto root = module->entry_computation()->root_instruction();
10054 auto scatter = AllOf(op::Shape("bf16[2,512]"), op::Scatter(_, _, _));
10055 EXPECT_THAT(root, scatter);
10056 }
10057
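// A collective-permute is "trivial" if it has no source-target pairs or if
// every pair sends a shard to its own device; such an instruction degenerates
// to a copy and should not appear in partitioned code.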
10058 bool IsTrivialCollectivePermute(HloInstruction* hlo) {
10059 if (hlo->opcode() != HloOpcode::kCollectivePermute) {
10060 return false;
10061 }
10062 if (hlo->source_target_pairs().empty()) {
10063 return true;
10064 }
10065 return absl::c_all_of(hlo->source_target_pairs(),
10066 [](const std::pair<int64_t, int64_t>& pair) {
10067 return pair.first == pair.second;
10068 });
10069 }
10070
10071 TEST_F(SpmdPartitioningTest, CollectivePermuteSimplifyIdentity) {
10072 absl::string_view hlo_string = R"(
10073 HloModule test
10074
10075 ENTRY entry {
10076 %parameter.7 = f32[3,16] parameter(0), sharding={devices=[1,2]0,1}
10077 %constant.7 = f32[] constant(0)
10078 %pad.3 = f32[3,18] pad(f32[3,16] %parameter.7, f32[] %constant.7), padding=0_0x1_1, sharding={devices=[1,2]0,1}
10079 // Shift right by 16.
10080 %slice.8 = f32[3,16] slice(f32[3,18] %pad.3), slice={[0:3], [2:18]}, sharding={devices=[1,2]0,1}
10081 %slice.9 = f32[3,2] slice(f32[3,18] %pad.3), slice={[0:3], [0:2]}, sharding={devices=[1,2]0,1}
10082 ROOT %concatenate.6 = f32[3,18] concatenate(f32[3,16] %slice.8, f32[3,2] %slice.9), dimensions={1}, sharding={devices=[1,2]0,1}
10083 }
10084 )";
10085
10086 TF_ASSERT_OK_AND_ASSIGN(auto module,
10087 PartitionComputation(hlo_string, /*num_devices=*/2));
10088 VLOG(1) << module->ToString();
10089
10090 // Check that the partitioned code does not have a "trivial" collective
10091 // permute (which would degenerate to a copy).
10092 for (HloComputation* computation : module->computations()) {
10093 for (HloInstruction* hlo : computation->instructions()) {
10094 EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10095 }
10096 }
10097 }
10098
10099 TEST_F(SpmdPartitioningTest, CollectivePermuteSimplifyZero) {
10100 absl::string_view hlo_string = R"(
10101 HloModule test
10102
10103 ENTRY entry {
10104 %parameter = f32[3,16,16,16,16,132]{5,4,3,2,1,0} parameter(0), sharding={devices=[1,2,1,1,1,1]0,1}
10105 %slice = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter), slice={[0:3], [15:16], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
10106 %c0 = f32[] constant(0)
10107 ROOT %pad = f32[3,18,16,16,16,132]{5,4,3,2,1,0} pad(f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice, f32[] %c0), padding=0_0x0_17x0_0x0_0x0_0x0_0, sharding={devices=[1,2,1,1,1,1]0,1}
10108 }
10109 )";
10110
10111 TF_ASSERT_OK_AND_ASSIGN(auto module,
10112 PartitionComputation(hlo_string, /*num_devices=*/2));
10113 VLOG(1) << module->ToString();
10114
10115 // Check that the partitioned code does not have a collective permute with an
10116 // empty source_target_pair list.
10117 for (HloComputation* computation : module->computations()) {
10118 for (HloInstruction* hlo : computation->instructions()) {
10119 EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10120 }
10121 }
10122 }
10123
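// The next three tests cover concatenates that implement wrap-around
// (circular) padding, optionally with elementwise modifiers applied to the
// wrapped slices. The partitioner is expected to handle the pattern without
// emitting all-reduces or trivial collective permutes.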
10124 TEST_F(SpmdPartitioningTest, PadWithWrapPattern) {
10125 absl::string_view hlo_string = R"(
10126 HloModule xla_computation_apply_fn__4.61
10127
10128 ENTRY %xla_computation_apply_fn__4.61 (parameter.7: f32[3,16,16,16,16,132]) -> f32[3,18,16,16,16,132] {
10129 %parameter.7 = f32[3,16,16,16,16,132]{5,4,3,2,1,0} parameter(0), sharding={devices=[1,2,1,1,1,1]0,1}
10130 %slice.2 = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7), slice={[0:3], [15:16], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
10131 %slice.3 = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7), slice={[0:3], [0:1], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
10132 ROOT %concatenate.3 = f32[3,18,16,16,16,132]{5,4,3,2,1,0} concatenate(f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice.2, f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7, f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice.3), dimensions={1}, sharding={devices=[1,2,1,1,1,1]0,1}
10133 }
10134 )";
10135
10136 TF_ASSERT_OK_AND_ASSIGN(auto module,
10137 PartitionComputation(hlo_string, /*num_devices=*/2));
10138 VLOG(1) << module->ToString();
10139
10140 // Check that the partitioned code has neither an all-reduce nor a trivial
10141 // collective permute instruction.
10142 for (HloComputation* computation : module->computations()) {
10143 for (HloInstruction* hlo : computation->instructions()) {
10144 EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10145 EXPECT_NE(hlo->opcode(), HloOpcode::kAllReduce) << hlo->ToString();
10146 }
10147 }
10148 }
10149
10150 TEST_F(SpmdPartitioningTest, PadWrapWithNegatePattern) {
10151 absl::string_view hlo_string = R"(
10152 HloModule module
10153
10154 ENTRY entry {
10155 %parameter.1 = f32[1,18] parameter(0), sharding={devices=[1,2]0,1}
10156 %slice.16 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [16:18]}, sharding={devices=[1,2]0,1}
10157 %negate.2 = f32[1,2] negate(f32[1,2] %slice.16), sharding={devices=[1,2]0,1}
10158 %slice.17 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [0:2]}, sharding={devices=[1,2]0,1}
10159 %negate.3 = f32[1,2] negate(f32[1,2] %slice.17), sharding={devices=[1,2]0,1}
10160 ROOT %concatenate.13 = f32[1,22] concatenate(f32[1,2] %negate.2, f32[1,18] %parameter.1, f32[1,2] %negate.3), dimensions={1}, sharding={devices=[1,2]0,1}
10161 }
10162 )";
10163 TF_ASSERT_OK_AND_ASSIGN(auto module,
10164 PartitionComputation(hlo_string, /*num_devices=*/2));
10165 VLOG(1) << module->ToString();
10166
10167 // Check that the partitioned code does not have an all-reduce or a trivial
10168 // collective permute.
10169 for (HloComputation* computation : module->computations()) {
10170 for (HloInstruction* hlo : computation->instructions()) {
10171 EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10172 EXPECT_NE(hlo->opcode(), HloOpcode::kAllReduce) << hlo->ToString();
10173 }
10174 }
10175 }
10176
10177 TEST_F(SpmdPartitioningTest, PadWrapWithMultipleModifiersPattern) {
10178 absl::string_view hlo_string = R"(
10179 HloModule module
10180
10181 ENTRY entry {
10182 %parameter.1 = f32[1,18] parameter(0), sharding={devices=[1,2]0,1}
10183 %slice.16 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [16:18]}, sharding={devices=[1,2]0,1}
10184 %mod0.16 = f32[1,2] rsqrt(f32[1,2] %slice.16), sharding={devices=[1,2]0,1}
10185 %mod1.16 = f32[1,2] sine(f32[1,2] %mod0.16), sharding={devices=[1,2]0,1}
10186 %slice.17 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [0:2]}, sharding={devices=[1,2]0,1}
10187 %mod0.17 = f16[1,2] convert(f32[1,2] %slice.17), sharding={devices=[1,2]0,1}
10188 %mod1.17 = f16[1,2] cosine(f16[1,2] %mod0.17), sharding={devices=[1,2]0,1}
10189 %mod2.17 = f32[1,2] convert(f16[1,2] %mod1.17), sharding={devices=[1,2]0,1}
10190 ROOT %concatenate.13 = f32[1,22] concatenate(f32[1,2] %mod1.16, f32[1,18] %parameter.1, f32[1,2] %mod2.17), dimensions={1}, sharding={devices=[1,2]0,1}
10191 }
10192 )";
10193 TF_ASSERT_OK_AND_ASSIGN(auto module,
10194 PartitionComputation(hlo_string, /*num_devices=*/2));
10195 VLOG(1) << module->ToString();
10196
10197 // Check that the partitioned code does not have all-reduce or trivial
10198 // collective permute. Also make sure modifiers have the right dependencies.
10199 for (HloComputation* computation : module->computations()) {
10200 for (HloInstruction* hlo : computation->instructions()) {
10201 const HloOpcode op = hlo->opcode();
10202 EXPECT_FALSE(IsTrivialCollectivePermute(hlo)) << hlo->ToString();
10203 EXPECT_NE(op, HloOpcode::kAllReduce) << hlo->ToString();
10204 if (hlo->operand_count() != 1) {
10205 continue;
10206 }
10207 const PrimitiveType type = hlo->shape().element_type();
10208 const HloOpcode child_op = hlo->operand(0)->opcode();
10209 const PrimitiveType child_type = hlo->operand(0)->shape().element_type();
10210
10211 if (op == HloOpcode::kSin) {
10212 EXPECT_EQ(child_op, HloOpcode::kRsqrt);
10213 } else if (op == HloOpcode::kConvert && type == F32) {
10214 EXPECT_EQ(child_op, HloOpcode::kCos);
10215 EXPECT_EQ(child_type, F16);
10216 } else if (op == HloOpcode::kCos) {
10217 EXPECT_EQ(child_op, HloOpcode::kConvert);
10218 EXPECT_EQ(child_type, F16);
10219 }
10220 }
10221 }
10222 }
10223
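// The BroadcastAsReplicate tests reshard a tiny tensor, tiled across more
// devices than it has elements, back to replicated: the valid shard is
// selected (masking the padded shards) and combined with an all-reduce, plus a
// dynamic-update-slice step for a dimension that is genuinely partitioned.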
10224 TEST_F(SpmdPartitioningTest, BroadcastAsReplicate) {
10225 absl::string_view hlo_string = R"(
10226 HloModule module
10227
10228 ENTRY entry {
10229 %param0 = f32[1,1] parameter(0), sharding={devices=[2,2]0,1,2,3}
10230 ROOT %copy = f32[1,1] copy(%param0), sharding={replicated}
10231 })";
10232
10233 TF_ASSERT_OK_AND_ASSIGN(auto module,
10234 PartitionComputation(hlo_string, /*num_devices=*/4));
10235 VLOG(1) << module->ToString();
10236
10237 const auto root = module->entry_computation()->root_instruction();
10238 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
10239 EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(op::Select(_, param0, _))),
10240 op::Shape("f32[1,1]")));
10241 }
10242
10243 TEST_F(SpmdPartitioningTest, BroadcastAsReplicate2) {
10244 absl::string_view hlo_string = R"(
10245 HloModule module
10246
10247 ENTRY entry {
10248 %param0 = f32[1,2] parameter(0), sharding={devices=[2,2]0,1,2,3}
10249 ROOT %copy = f32[1,2] copy(%param0), sharding={replicated}
10250 })";
10251
10252 TF_ASSERT_OK_AND_ASSIGN(auto module,
10253 PartitionComputation(hlo_string, /*num_devices=*/4));
10254 VLOG(1) << module->ToString();
10255
10256 const auto root = module->entry_computation()->root_instruction();
10257 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
10258 auto broadcast =
10259 AllOf(op::AllReduce(op::Select(_, param0, _)), op::Shape("f32[1,1]"));
10260 EXPECT_THAT(
10261 root,
10262 AllOf(op::Copy(op::AllReduce(op::DynamicUpdateSlice(_, broadcast, _, _))),
10263 op::Shape("f32[1,2]")));
10264 }
10265
10266 TEST_F(SpmdPartitioningTest, BroadcastAsReplicate3) {
10267 absl::string_view hlo_string = R"(
10268 HloModule module
10269
10270 ENTRY entry {
10271 %param0 = f32[1,1] parameter(0),
10272 sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
10273 ROOT %copy = f32[1,1] copy(%param0), sharding={replicated}
10274 })";
10275
10276 TF_ASSERT_OK_AND_ASSIGN(auto module,
10277 PartitionComputation(hlo_string, /*num_devices=*/4));
10278 VLOG(1) << module->ToString();
10279
10280 const auto root = module->entry_computation()->root_instruction();
10281 auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
10282 EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(op::Select(_, param0, _))),
10283 op::Shape("f32[1,1]")));
10284 }
10285
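// The subgroup-manual tests use shardings with last_tile_dims={manual}: the
// manual subgroup is partitioned by the user rather than by the SPMD
// partitioner, and collectives such as all-reduce must not cross manual
// subgroup boundaries (see SubgroupIllegalManualAllReduce below).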
10286 TEST_F(SpmdPartitioningTest, TupleWithSubgroupManual) {
10287 absl::string_view hlo_string = R"(
10288 HloModule module
10289
10290 ENTRY entry {
10291 constant = f32[6,3]{1,0}
10292 constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
10293 sharding={replicated}
10294 param = (f32[6,3]{1,0}, f32[]) parameter(0),
10295 sharding={{devices=[2,1,2]0,1,2,3 last_tile_dims={manual}},{replicated}}
10296 gte = f32[6,3]{1,0} get-tuple-element(param), index=0,
10297 sharding={devices=[2,1,2]0,1,2,3 last_tile_dims={manual}}
10298 ROOT tuple = (f32[6,3]{1,0}, f32[6,3]{1,0}) tuple(constant, gte),
10299 sharding={{replicated},{devices=[2,1,2]0,1,2,3 last_tile_dims={manual}}}
10300 }
10301 )";
10302
10303 TF_ASSERT_OK_AND_ASSIGN(auto module,
10304 PartitionComputation(hlo_string, /*num_devices=*/4));
10305 VLOG(1) << module->ToString();
10306 const auto root = module->entry_computation()->root_instruction();
10307 EXPECT_THAT(root,
10308 op::Tuple(op::Constant(), op::GetTupleElement(op::Parameter(0))));
10309 }
10310
10311 TEST_F(SpmdPartitioningTest, SubgroupManualSharedOperand) {
10312 absl::string_view hlo_string = R"(
10313 HloModule module
10314
10315 ENTRY entry {
10316 constant = f32[] constant(1), sharding={replicated}
10317 broadcast = f32[2,2] broadcast(constant), dimensions={},
10318 sharding={devices=[2,1,2]0,1,2,3 last_tile_dims={manual}}
10319 ROOT add = f32[2,2] add(broadcast, broadcast),
10320 sharding={devices=[2,1,2]0,1,2,3 last_tile_dims={manual}}
10321 }
10322 )";
10323
10324 TF_ASSERT_OK_AND_ASSIGN(auto module,
10325 PartitionComputation(hlo_string, /*num_devices=*/4));
10326 VLOG(1) << module->ToString();
10327 const auto root = module->entry_computation()->root_instruction();
10328 EXPECT_THAT(root, op::Add(op::Broadcast(op::Constant()),
10329 op::Broadcast(op::Constant())));
10330 }
10331
10332 TEST_F(SpmdPartitioningTest, SubgroupManualAllReduce) {
10333 absl::string_view hlo_string = R"(
10334 HloModule module
10335
10336 sum {
10337 a = f32[] parameter(0)
10338 b = f32[] parameter(1)
10339 ROOT add = f32[] add(a, b)
10340 }
10341
10342 ENTRY entry {
10343 param = f32[2,2] parameter(0),
10344 sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10345 ROOT all-reduce = f32[2,2]{1,0} all-reduce(param), to_apply=sum,
10346 replica_groups={{2,0},{1,3}}, use_global_device_ids=true, channel_id=1,
10347 sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10348 }
10349 )";
10350
10351 TF_ASSERT_OK_AND_ASSIGN(auto module,
10352 PartitionComputation(hlo_string, /*num_devices=*/4));
10353 VLOG(1) << module->ToString();
10354 const auto root = module->entry_computation()->root_instruction();
10355 EXPECT_THAT(root,
10356 AllOf(op::AllReduce(op::Parameter(0)), op::Shape("f32[1,2]")));
10357 EXPECT_EQ(root->replica_groups().size(), 2);
10358 }
10359
10360 TEST_F(SpmdPartitioningTest, SubgroupIllegalManualAllReduce) {
10361 absl::string_view hlo_string = R"(
10362 HloModule module
10363
10364 sum {
10365 a = f32[] parameter(0)
10366 b = f32[] parameter(1)
10367 ROOT add = f32[] add(a, b)
10368 }
10369
10370 ENTRY entry {
10371 param = f32[2,2] parameter(0),
10372 sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10373 ROOT all-reduce = f32[2,2]{1,0} all-reduce(param), to_apply=sum,
10374 replica_groups={{1,0},{2,3}}, use_global_device_ids=true, channel_id=1,
10375 sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10376 }
10377 )";
10378
10379 auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4);
10380 EXPECT_FALSE(module_status.status().ok());
10381 EXPECT_THAT(module_status.status().ToString(),
10382 ::testing::HasSubstr("Manual all-reduce across devices that "
10383 "belong to different manual subgroups"));
10384 }
10385
10386 TEST_F(SpmdPartitioningTest, SubgroupManualReduce) {
10387 absl::string_view hlo_string = R"(
10388 HloModule module
10389
10390 sum {
10391 a = f32[] parameter(0)
10392 b = f32[] parameter(1)
10393 ROOT add = f32[] add(a, b)
10394 }
10395
10396 ENTRY entry {
10397 constant = f32[] constant(0),
10398 sharding={devices=[2,2]0,1,2,3 last_tile_dims={manual,replicated}}
10399 param = f32[2,2] parameter(0),
10400 sharding={devices=[2,1,2]0,2,1,3 last_tile_dims={manual}}
10401 ROOT reduce = f32[2] reduce(param, constant), dimensions={0}, to_apply=sum,
10402 sharding={devices=[2,2]0,1,2,3 last_tile_dims={manual,replicated}}
10403 }
10404 )";
10405
10406 TF_ASSERT_OK_AND_ASSIGN(auto module,
10407 PartitionComputation(hlo_string, /*num_devices=*/4));
10408 VLOG(1) << module->ToString();
10409 const auto root = module->entry_computation()->root_instruction();
10410 EXPECT_THAT(root,
10411 op::AllReduce(op::Reduce(op::Parameter(0), op::Constant())));
10412 EXPECT_EQ(root->replica_groups().size(), 2);
10413 }
10414
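// The two scatter tests below pin down which sharding the partitioner picks
// when operand and indices/updates disagree: the first expects the scatter to
// be partitioned along the index dimensions (operand replicated, result
// all-reduced), the second along a trivially sliceable operand dimension
// (result reassembled with dynamic-update-slice).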
10415 TEST_F(SpmdPartitioningTest, ScatterPreferUpdateIndexIfSmaller) {
10416 absl::string_view hlo_string = R"(
10417 HloModule module
10418
10419 %scatter_add_reducer__33.191857 (parameter.191858: bf16[], parameter.191859: bf16[]) -> bf16[] {
10420 %parameter.191858 = bf16[] parameter(0)
10421 %parameter.191859 = bf16[] parameter(1)
10422 ROOT %add.4425 = bf16[] add(bf16[] %parameter.191858, bf16[] %parameter.191859)
10423 }
10424
10425 ENTRY entry {
10426 p1 = s32[2048,1024,1]{2,1,0} parameter(0)
10427 p2 = bf16[2048,1024,2040]{2,1,0} parameter(1)
10428 %constant.8635 = bf16[] constant(0)
10429 %broadcast.21781 = bf16[50048,2040]{1,0} broadcast(bf16[] %constant.8635), dimensions={},
10430 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10431 %select.1954 = s32[2048,1024,1]{2,1,0} copy(%p1), sharding={devices=[4,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10432 %slice.1274 = bf16[2048,1024,2040]{2,1,0} copy(%p2),
10433 sharding={devices=[4,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10434 %scatter.34 = bf16[50048,2040]{1,0} scatter(bf16[50048,2040]{1,0} %broadcast.21781,
10435 s32[2048,1024,1]{2,1,0} %select.1954, bf16[2048,1024,2040]{2,1,0} %slice.1274),
10436 update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0},
10437 index_vector_dim=2, to_apply=%scatter_add_reducer__33.191857,
10438 sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10439 ROOT c = bf16[50048,2040]{1,0} copy(scatter.34),
10440 sharding={replicated}
10441 }
10442 )";
10443
10444 TF_ASSERT_OK_AND_ASSIGN(auto module,
10445 PartitionComputation(hlo_string, /*num_devices=*/8));
10446 VLOG(1) << module->ToString();
10447 const auto root = module->entry_computation()->root_instruction();
10448 EXPECT_THAT(root,
10449 op::Copy(op::AllReduce(op::Scatter(
10450 op::Shape("bf16[50048,2040]"), op::Shape("s32[512,1024,1]"),
10451 op::Shape("bf16[512,1024,2040]")))));
10452 }
10453
10454 TEST_F(SpmdPartitioningTest, ScatterPreferTrivialIfSmallerThanIndices) {
10455 absl::string_view hlo_string = R"(
10456 HloModule module
10457
10458 %scatter_add_reducer__33.191857 (parameter.191858: bf16[], parameter.191859: bf16[]) -> bf16[] {
10459 %parameter.191858 = bf16[] parameter(0)
10460 %parameter.191859 = bf16[] parameter(1)
10461 ROOT %add.4425 = bf16[] add(bf16[] %parameter.191858, bf16[] %parameter.191859)
10462 }
10463
10464 ENTRY entry {
10465 p1 = s32[32,512,3]{2,1,0} parameter(0)
10466 p2 = bf16[32,512]{1,0} parameter(1)
10467 %constant.8635 = bf16[] constant(0)
10468 %broadcast.21781 = bf16[32,512,50001]{2,1,0} broadcast(bf16[] %constant.8635), dimensions={},
10469 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10470 %select.1954 = s32[32,512,3]{2,1,0} copy(%p1), sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10471 %slice.1274 = bf16[32,512]{1,0} copy(%p2),
10472 sharding={devices=[1,4,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10473 %scatter.34 = bf16[32,512,50001]{2,1,0} scatter(bf16[32,512,50001]{2,1,0} %broadcast.21781,
10474 s32[32,512,3]{2,1,0} %select.1954, bf16[32,512]{1,0} %slice.1274),
10475 update_window_dims={}, inserted_window_dims={0,1,2}, scatter_dims_to_operand_dims={0,1,2},
10476 index_vector_dim=2, to_apply=%scatter_add_reducer__33.191857,
10477 sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
10478 ROOT c = bf16[32,512,50001]{2,1,0} copy(scatter.34),
10479 sharding={replicated}
10480 }
10481 )";
10482
10483 TF_ASSERT_OK_AND_ASSIGN(auto module,
10484 PartitionComputation(hlo_string, /*num_devices=*/8));
10485 VLOG(1) << module->ToString();
10486 const auto root = module->entry_computation()->root_instruction();
10487 EXPECT_THAT(
10488 root,
10489 op::Copy(op::AllReduce(op::DynamicUpdateSlice(
10490 _,
10491 op::Scatter(op::Shape("bf16[32,128,50001]"),
10492 op::Shape("s32[32,512,3]"), op::Shape("bf16[32,512]")),
10493 _, _, _))));
10494 }
10495
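// The remaining gather tests check per-shard shapes when operand and index
// shardings can both be passed through to the partitioned gather, when only a
// trivial slice of the operand is needed, and when the output has to be
// restored to its replicated base shape.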
10496 TEST_F(SpmdPartitioningTest, GatherOperandPassthroughIndexPassthrough) {
10497 const char* const hlo_string = R"(
10498 HloModule module
10499
10500 ENTRY entry {
10501 %input = f32[2,9] parameter(0), sharding={replicated}
10502 %indices = s32[7] parameter(1), sharding={replicated}
10503 %input.copy = f32[2,9] copy(%input), sharding={devices=[1,2,2]1,0,3,2 last_tile_dim_replicate}
10504 %indices.copy = s32[7] copy(%indices), sharding={devices=[2,2]1,2,3,0 last_tile_dim_replicate}
10505 %gather = f32[7,9] gather(%input.copy, %indices.copy), offset_dims={1},
10506 collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
10507 slice_sizes={1,9}, sharding={devices=[2,2]0,1,2,3}
10508 ROOT %copy = f32[7,9] copy(%gather), sharding={replicated}
10509 })";
10510
10511 TF_ASSERT_OK_AND_ASSIGN(auto module,
10512 PartitionComputation(hlo_string, /*num_devices=*/4));
10513 VLOG(1) << module->ToString();
10514 const HloInstruction* gather = FindInstruction(module.get(), "gather.1");
10515 EXPECT_NE(gather, nullptr);
10516 EXPECT_THAT(gather,
10517 AllOf(op::Shape("f32[4,5]"),
10518 op::Gather(op::Shape("f32[2,5]"), op::Shape("s32[4]"))));
10519 }
10520
10521 TEST_F(SpmdPartitioningTest, GatherIndexPassthroughTrivialSlice) {
10522 const char* const hlo_string = R"(
10523 HloModule module
10524
10525 ENTRY entry {
10526 %input = f32[17,9] parameter(0)
10527 %indices = s32[2,3] parameter(1)
10528 %input.copy = f32[17,9] copy(%input), sharding={devices=[2,1,2]3,2,1,0 last_tile_dim_replicate}
10529 %indices.copy = s32[2,3] copy(%indices), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
10530 %gather = f32[2,3,9] gather(%input.copy, %indices.copy), offset_dims={2},
10531 collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
10532 slice_sizes={1,9}, sharding={devices=[2,1,1,2]1,0,3,2 last_tile_dim_replicate}
10533 ROOT %copy = f32[2,3,9] copy(%gather), sharding={replicated}
10534 })";
10535
10536 TF_ASSERT_OK_AND_ASSIGN(auto module,
10537 PartitionComputation(hlo_string, /*num_devices=*/4));
10538 VLOG(1) << module->ToString();
10539 const HloInstruction* gather = FindInstruction(module.get(), "gather.1");
10540 EXPECT_NE(gather, nullptr);
10541 EXPECT_THAT(gather,
10542 AllOf(op::Shape("f32[1,3,9]"),
10543 op::Gather(op::Shape("f32[9,9]"), op::Shape("s32[1,3]"))));
10544 }
10545
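// The gather output is partially replicated over 32 devices with the first
// dimension tiled 16 ways, so the root tuple should keep a (f32[4,2,10])
// shard shape.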
TEST_F(SpmdPartitioningTest, GatherReplicatedCorrectOutput) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[64,2,250112] parameter(0), sharding={
    devices=[16,1,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,
    23,24,25,26,27,28,29,30,31}
  %indices = s32[10,1] parameter(1), sharding={replicated}
  %input.copy = f32[64,2,250112] copy(%input), sharding={
    devices=[16,1,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,
    23,24,25,26,27,28,29,30,31}
  %indices.copy = s32[10,1] copy(%indices), sharding={replicated}
  %gather = f32[64,2,10] gather(f32[64,2,250112] %input,
    s32[10,1]{1,0} %indices.copy), offset_dims={0,1}, collapsed_slice_dims={2},
    start_index_map={2}, index_vector_dim=1, slice_sizes={64,2,1},
    sharding={devices=[16,1,1,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,
    19,20,21,22,23,24,25,26,27,28,29,30,
    31 last_tile_dim_replicate}
  ROOT %copy = (f32[64,2,10]) tuple(gather), sharding={{devices=[16,1,1,2]0,1,2,
    3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,
    30,31 last_tile_dim_replicate}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/32));
  VLOG(1) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Shape("(f32[4,2,10])"));
}

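// The operand is tiled 32 ways along the gathered dimension (250112 / 32 =
// 7816 rows per shard); the partitioned gather should read the local
// bf16[7816,4096] shard, mask out-of-range rows with a select, and all-reduce
// the partial results.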
TEST_F(SpmdPartitioningTest, GatherTrivialRestoreSharding) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = bf16[250112,4096] parameter(0), sharding={replicated}
  %cpy.input = bf16[250112,4096] copy(%input), sharding={
    devices=[32,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,
    23,24,25,26,27,28,29,30,31}
  %indices = s32[64,1,1] parameter(1), sharding={replicated}
  %cpy.indices = s32[64,1,1] copy(%indices), sharding={replicated}
  %gather = bf16[64,1,4096] gather(bf16[250112,4096] %cpy.input, s32[64,1,1] %cpy.indices),
    offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0},
    index_vector_dim=2, slice_sizes={1,4096}, sharding={replicated}
  ROOT %copy = bf16[64,1,4096] copy(gather), sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/32));
  VLOG(1) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Shape("bf16[64,1,4096]"));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(op::AllReduce(op::Select(
                  _, _, op::Gather(op::Shape("bf16[7816,4096]"), _)))));
}

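// Slicing [0:1] from a 4-way tiled f32[512] requires no data movement: each
// partition slices its own shard.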
TEST_F(SpmdPartitioningTest, SliceTo1) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[512] parameter(0), sharding={devices=[4]0,1,2,3}
  ROOT slice.134 = f32[1] slice(input), slice={[0:1]},
    sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Slice(op::Parameter()), op::Shape("f32[1]")));
}

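// Each partition's f32[1,2] tile already matches the sliced shard shape, so
// the result is just a copy of the local shard.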
TEST_F(SpmdPartitioningTest, SliceTo1_8Shards) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[4,4] parameter(0), sharding={devices=[4,2]0,1,2,3,4,5,6,7}
  ROOT %slice = f32[1,4] slice(%input), slice={[0:1], [0:4]},
    sharding={devices=[4,2]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Copy(op::Parameter()), op::Shape("f32[1,2]")));
}

TEST_F(SpmdPartitioningTest, SliceTo1PartialReplicate) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[16] parameter(0),
    sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT slice.134 = f32[1] slice(input), slice={[0:1]},
    sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Slice(op::Parameter()), op::Shape("f32[1]")));
}

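// Slicing [0:2] from a 4-way tiled f32[512] leaves one element per output
// shard, so a one-element halo is exchanged via collective-permute and
// concatenated with the locally kept element before the final dynamic-slice.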
TEST_F(SpmdPartitioningTest, SliceTo2) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[512] parameter(0), sharding={devices=[4]0,1,2,3}
  ROOT slice.134 = f32[2] slice(input), slice={[0:2]},
    sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto slice1 = AllOf(op::Slice(op::Parameter()), op::Shape("f32[2]"));
  auto halo =
      op::CollectivePermute(AllOf(op::Slice(slice1), op::Shape("f32[1]")));
  auto slice_self = AllOf(op::Slice(slice1), op::Shape("f32[1]"));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Copy(AllOf(op::DynamicSlice(op::Concatenate(halo, slice_self), _),
                     op::Shape("f32[1]"))));
}

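// Slicing [300:302] from an 8-way tiled input: the sliced elements live on
// other partitions, so each output shard is assembled entirely from
// collective-permuted halos.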
TEST_F(SpmdPartitioningTest, SliceToMiddle2) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[512] parameter(0), sharding={devices=[8]0,1,2,3,4,5,6,7}
  ROOT %slice = f32[2] slice(input), slice={[300:302]},
    sharding={devices=[8]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  auto slice = AllOf(op::Slice(op::Parameter()), op::Shape("f32[2]"));
  auto halo1 = AllOf(op::CollectivePermute(slice), op::Shape("f32[2]"));
  auto halo2 =
      AllOf(op::CollectivePermute(op::Slice(slice)), op::Shape("f32[1]"));
  VLOG(1) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(AllOf(op::DynamicSlice(op::Concatenate(halo1, halo2), _),
                             op::Shape("f32[1]"))));
}

TEST_F(SpmdPartitioningTest, SliceToMiddle2PartiallyReplicated) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[512] parameter(0),
    sharding={devices=[8,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
  ROOT %slice = f32[2] slice(input), slice={[300:302]},
    sharding={devices=[8,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/16));
  auto slice = AllOf(op::Slice(op::Parameter()), op::Shape("f32[2]"));
  auto halo1 = AllOf(op::CollectivePermute(slice), op::Shape("f32[2]"));
  auto halo2 =
      AllOf(op::CollectivePermute(op::Slice(slice)), op::Shape("f32[1]"));
  VLOG(1) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(AllOf(op::DynamicSlice(op::Concatenate(halo1, halo2), _),
                             op::Shape("f32[1]"))));
}

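// Replicating an f32[3,2] tensor tiled over an 8x2 grid (finer than the data):
// each shard is masked with a select, written into a zero-initialized
// full-shape buffer via dynamic-update-slice, and all-reduced over both mesh
// dimensions.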
TEST_F(SpmdPartitioningTest, PartialDusReplicate) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %input = f32[3,2] parameter(0),
    sharding={devices=[8,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
  ROOT %copy = f32[3,2] copy(input), sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/16));
  VLOG(1) << module->ToString();
  auto dus =
      AllOf(op::Shape("f32[3,2]"),
            op::DynamicUpdateSlice(op::Broadcast(),
                                   op::Select(_, op::Parameter(0), _), _, _));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(AllOf(op::AllReduce(op::AllReduce(dus)))));
}

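// The operand's tiled dimension (64, split 4 ways) passes through the gather
// unchanged, so the gather runs on the local shard and yields
// f32[16,16,6,128,128] per partition.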
TEST_F(SpmdPartitioningTest, GatherPassthrough) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  p = f32[16,64,768,768]{3,2,1,0} parameter(0), sharding={replicated}
  c = f32[16,64,768,768]{3,2,1,0} copy(p), sharding={devices=[1,4,1,1]0,1,2,3}
  constant.1669 = s32[] constant(0)
  iota.1012 = s32[6]{0} iota(), iota_dimension=0, sharding={replicated}
  constant.1748 = s32[] constant(128), sharding={replicated}
  broadcast.2642 = s32[6]{0} broadcast(constant.1748), dimensions={}, sharding={replicated}
  multiply.92 = s32[6]{0} multiply(iota.1012, broadcast.2642), sharding={replicated}
  broadcast.2643 = s32[2,6]{1,0} broadcast(multiply.92), dimensions={1}, sharding={replicated}
  transpose.542 = s32[6,2]{0,1} transpose(broadcast.2643), dimensions={1,0}, sharding={replicated}
  pad.19 = s32[6,4]{1,0} pad(transpose.542, constant.1669), padding=0_0x2_0, sharding={replicated}
  ROOT gather.1 = f32[16,64,6,128,128]{4,3,2,1,0} gather(c, pad.19), offset_dims={0,1,3,4}, collapsed_slice_dims={}, start_index_map={0,1,2,3}, index_vector_dim=1, slice_sizes={16,64,128,128}, sharding={devices=[1,4,1,1,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Gather(), op::Shape("f32[16,16,6,128,128]")));
}

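// Resharding from a partially replicated sharding to an 8-way tiling of a
// different dimension should lower to a single all-to-all followed by a
// transpose and reshapes, rather than full rematerialization.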
TEST_F(SpmdPartitioningTest, ComplexReshardFromPartialReplicate) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %p = f32[4,15,4,16] parameter(0)
  %p.copy = f32[4,15,4,16] copy(p),
    sharding={devices=[1,1,1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
  %a = f32[4,15,4,16] add(p.copy, p.copy),
    sharding={devices=[1,1,1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
  ROOT %c2 = f32[4,15,4,16] copy(a), sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Copy(op::Reshape(op::Reshape(op::Transpose(op::AllToAll(_))))));
}

TEST_F(SpmdPartitioningTest, ComplexReshardToPartialReplicate) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %p = f32[4,15,4,16] parameter(0)
  %p.copy = f32[4,15,4,16] copy(p),
    sharding={devices=[1,4,2,1]0,1,2,3,4,5,6,7}
  %a = f32[4,15,4,16] add(p.copy, p.copy),
    sharding={devices=[1,4,2,1]0,1,2,3,4,5,6,7}
  ROOT %c2 = f32[4,15,4,16] copy(a), sharding={devices=[1,1,1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(op::Reshape(op::Transpose(op::AllToAll(_)))));
}

TEST_F(SpmdPartitioningTest, ComplexReshardMoveMergeDimensionRight) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %p = f32[4,15,4,15] parameter(0)
  %p.copy = f32[4,15,4,15] copy(p),
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
  %a = f32[4,15,4,15] add(p.copy, p.copy),
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
  ROOT %c2 = f32[4,15,4,15] copy(a), sharding={devices=[1,1,1,8]0,2,4,6,1,3,5,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(op::Reshape(
                  op::Slice(op::Reshape(op::Transpose(op::AllToAll(_)))))));
}

TEST_F(SpmdPartitioningTest, ComplexReshardMoveMergeDimensionLeft) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %p = f32[2,15,1,2] parameter(0)
  %p.copy = f32[2,15,1,2] copy(p),
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
  %a = f32[2,15,1,2] add(p.copy, p.copy),
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
  ROOT %c2 = f32[2,15,1,2] copy(a), sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Copy(op::Reshape(op::Reshape(op::Transpose(op::AllToAll(_))))));
}

TEST_F(SpmdPartitioningTest, ComplexReshardMoveMergeDimensionLeftReorder) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %p = f32[4,15,4,16] parameter(0)
  %p.copy = f32[4,15,4,16] copy(p),
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
  %a = f32[4,15,4,16] add(p.copy, p.copy),
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7}
  ROOT %c2 = f32[4,15,4,16] copy(a), sharding={devices=[1,8,1,1]0,2,4,6,1,3,5,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(op::CollectivePermute(
                  op::Reshape(op::Reshape(op::Transpose(op::AllToAll(_)))))));
}

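// With large explicit padding and rhs dilation, the partitioned lhs is
// expected to be built by padding the operand first and then dynamic-slicing
// out the per-partition region before the convolution.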
TEST_F(SpmdPartitioningTest, PaddedConvReshard) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %p = bf16[16,256,256,384]{3,2,1,0} parameter(0)
  %p2 = bf16[3,3,384,384]{3,2,1,0} parameter(1)
  %p.copy = bf16[16,256,256,384]{3,2,1,0} copy(%p), sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
  %p2.copy = bf16[3,3,384,384]{3,2,1,0} copy(%p2), sharding={replicated}
  ROOT %convolution.10115 = bf16[16,256,256,384]{3,2,1,0} convolution(%p.copy, %p2.copy), window={size=3x3 pad=128_128x128_128 rhs_dilate=128x128}, dim_labels=b01f_01io->b01f, sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Convolution(
                  op::DynamicSlice(op::Pad(_, op::Constant()), _, _, _, _), _));
}

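// The dynamic-update-slice writes only along the spatial dimensions, so the
// partitioner should keep the non-sliced batch dimension partitioned (a
// bf16[8,512,512,384] shard) and reshard only the sliced dimensions.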
TEST_F(SpmdPartitioningTest, KeepPartitionedNonSlicedDimension) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %p = bf16[16,128,128,384]{3,2,1,0} parameter(0), sharding={replicated}
  %constant.1165 = s32[] constant(0), sharding={replicated}
  constant.1151 = s32[] constant(192), sharding={replicated}
  broadcast.1152 = s32[2]{0} broadcast(constant.1151), dimensions={}, sharding={replicated}
  slice.1576 = s32[1]{0} slice(broadcast.1152), slice={[0:1]}, sharding={replicated}
  reshape.1888 = s32[] reshape(slice.1576), sharding={replicated}
  slice.1546 = s32[1]{0} slice(broadcast.1152), slice={[1:2]}, sharding={replicated}
  reshape.1890 = s32[] reshape(slice.1546), sharding={replicated}
  constant.861 = bf16[] constant(0), sharding={replicated}
  broadcast.862 = bf16[16,512,512,384]{3,2,1,0} broadcast(constant.861), dimensions={}, sharding={devices=[2,2,1,1]0,1,2,3}
  %c = bf16[16,128,128,384]{3,2,1,0} copy(p), sharding={devices=[2,2,1,1]0,1,2,3}
  add.228 = bf16[16,128,128,384]{3,2,1,0} add(c, c), sharding={devices=[2,2,1,1]0,1,2,3}
  ROOT dynamic-update-slice.111 = bf16[16,512,512,384]{3,2,1,0} dynamic-update-slice(broadcast.862, add.228, constant.1165, reshape.1888, reshape.1890, /*index=5*/constant.1165), sharding={devices=[2,2,1,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));

  XLA_VLOG_LINES(1, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::DynamicSlice(AllOf(op::DynamicUpdateSlice(),
                                     op::Shape("bf16[8,512,512,384]")),
                               _, _, _, _));
}

TEST_F(SpmdPartitioningTest,
       KeepPartitionedNonSlicedDimensionWithConstantIndices) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  p1 = bf16[16,192,192,384]{3,2,1,0} parameter(0), sharding={replicated}
  p2 = bf16[16,128,128,384]{3,2,1,0} parameter(1), sharding={replicated}
  c1 = bf16[16,192,192,384]{3,2,1,0} copy(p1), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
  c2 = bf16[16,128,128,384]{3,2,1,0} copy(p2), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
  constant.1163 = bf16[] constant(0), sharding={replicated}
  constant.1165 = s32[] constant(0), sharding={replicated}
  pad.179 = bf16[16,224,224,384]{3,2,1,0} pad(c1, constant.1163), padding=0_0x16_16x16_16x0_0, sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
  add.439 = bf16[16,128,128,384]{3,2,1,0} add(c2, c2), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
  constant.1070 = s32[] constant(48), sharding={replicated}
  dynamic-update-slice.128 = bf16[16,224,224,384]{3,2,1,0} dynamic-update-slice(pad.179, add.439, constant.1165, constant.1070, constant.1070, /*index=5*/constant.1165), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
  ROOT c = bf16[16,224,224,384]{3,2,1,0} copy(dynamic-update-slice.128), sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));

  XLA_VLOG_LINES(1, module->ToString());
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Copy(op::DynamicSlice(
          AllOf(op::DynamicUpdateSlice(), op::Shape("bf16[8,224, 224,384]")), _,
          _, _, _)));
}

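// SPMDFullToShardShape/SPMDShardToFullShape custom calls switch between
// automatic and manual sharding; after partitioning, each shard-to-full result
// should be reassembled by a dynamic-update-slice into a full-shape buffer
// wrapped in an all-reduce.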
TEST_F(SpmdPartitioningTest, CustomCallManualSharding) {
  const char* const hlo_string = R"(
HloModule pjit_xmap_dummy.5

ENTRY %main.21 (Arg_0.1: f32[4,4,8], Arg_1.2: f32[4,8]) -> (f32[4,4,8], f32[4]) {
  %Arg_0.1 = f32[4,4,8]{2,1,0} parameter(0), sharding={devices=[4,1,1]0,1,2,3}
  %copy.3 = f32[4,4,8]{2,1,0} copy(f32[4,4,8]{2,1,0} %Arg_0.1), sharding={devices=[4,1,1]0,1,2,3}
  %custom-call.4 = f32[1,4,8]{2,1,0} custom-call(f32[4,4,8]{2,1,0} %copy.3), custom_call_target="SPMDFullToShardShape", sharding={manual}
  %reshape.7 = f32[4,8]{1,0} reshape(f32[1,4,8]{2,1,0} %custom-call.4), sharding={manual}
  %Arg_1.2 = f32[4,8]{1,0} parameter(1), sharding={replicated}
  %copy.2 = f32[4,8]{1,0} copy(f32[4,8]{1,0} %Arg_1.2), sharding={replicated}
  %custom-call.6 = f32[4,8]{1,0} custom-call(f32[4,8]{1,0} %copy.2), custom_call_target="SPMDFullToShardShape", sharding={manual}
  %custom-call.8 = (f32[4,8]{1,0}, f32[1]{0}) custom-call(f32[4,8]{1,0} %reshape.7, f32[4,8]{1,0} %custom-call.6), custom_call_target="dummy", operand_layout_constraints={f32[4,8]{1,0}, f32[4,8]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, sharding={{manual}, {manual}}
  %get-tuple-element.9 = f32[4,8]{1,0} get-tuple-element((f32[4,8]{1,0}, f32[1]{0}) %custom-call.8), index=0, sharding={manual}
  %reshape.11 = f32[1,4,8]{2,1,0} reshape(f32[4,8]{1,0} %get-tuple-element.9), sharding={manual}
  %copy.1 = f32[1,4,8]{2,1,0} copy(f32[1,4,8]{2,1,0} %reshape.11), sharding={manual}
  %custom-call.14 = f32[4,4,8]{2,1,0} custom-call(f32[1,4,8]{2,1,0} %copy.1), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1]0,1,2,3}
  %reshape.18 = f32[4,4,8]{2,1,0} reshape(f32[4,4,8]{2,1,0} %custom-call.14), sharding={devices=[4,1,1]0,1,2,3}
  %get-tuple-element.10 = f32[1]{0} get-tuple-element((f32[4,8]{1,0}, f32[1]{0}) %custom-call.8), index=1, sharding={manual}
  %reshape.12 = f32[1,1]{1,0} reshape(f32[1]{0} %get-tuple-element.10), sharding={manual}
  %copy = f32[1,1]{1,0} copy(f32[1,1]{1,0} %reshape.12), sharding={manual}
  %custom-call.16 = f32[4,1]{1,0} custom-call(f32[1,1]{1,0} %copy), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1]0,1,2,3}
  %reshape.17 = f32[4]{0} reshape(f32[4,1]{1,0} %custom-call.16), sharding={devices=[4]0,1,2,3}
  %reshape.19 = f32[4]{0} reshape(f32[4]{0} %reshape.17), sharding={devices=[4]0,1,2,3}
  ROOT %tuple.20 = (f32[4,4,8]{2,1,0}, f32[4]{0}) tuple(f32[4,4,8]{2,1,0} %reshape.18, f32[4]{0} %reshape.19), sharding={{replicated}, {replicated}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));

  XLA_VLOG_LINES(1, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::DynamicUpdateSlice(
                            _, op::Shape("f32[1,4,8]"), _, _, _)),
                        op::AllReduce(op::DynamicUpdateSlice(
                            _, op::Shape("f32[1]"), _))));
}

}  // namespace
}  // namespace spmd
}  // namespace xla