/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "absl/algorithm/container.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

// Tests cross-GPU operations.
//
// This test requires at least four GPUs. For instructions on running this
// within Google, see go/multi-gpu-unit-test.

namespace xla {
namespace {

class CollectiveOpsTest : public HloTestBase {
 public:
  static void SetUpTestSuite() {
    // Not strictly needed, since this test exercises cross-replica
    // collective-permute, which does not use NCCL, but keep it here for
    // testing.
    tensorflow::setenv("NCCL_LAUNCH_MODE", "PARALLEL", /*overwrite=*/1);
    HloTestBase::SetUpTestSuite();
  }

 protected:
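  // Builds an HLO module whose entry computation all-reduces a single
  // parameter of `shape` with `op` over `replica_groups`, using `datatype` as
  // the element type. The SHAPE/OP/DATATYPE/REPLICA_GROUPS placeholders in the
  // template below are substituted before parsing.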
  std::unique_ptr<HloModule> MakeCrsModule(
      const Shape& shape, std::vector<std::vector<int64_t>> replica_groups,
      const HloModuleConfig& config, std::string op = "add",
      std::string datatype = "f32") {
    std::string hlo_template = R"(
      HloModule test

      apply_op {
        x = DATATYPE[] parameter(0)
        y = DATATYPE[] parameter(1)
        ROOT apply_op = DATATYPE[] OP(x, y)
      }

      ENTRY test_computation {
        p = SHAPE parameter(0)
        p2 = SHAPE bitcast(p)
        crs = SHAPE all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op
        copy = SHAPE copy(crs)
        ROOT out = SHAPE bitcast(copy)
      }
    )";
    std::vector<std::string> replica_group_strs;
    replica_group_strs.reserve(replica_groups.size());
    for (const auto& g : replica_groups) {
      replica_group_strs.push_back(
          absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
    }
    std::string shape_str = shape.ToString(/*print_layout=*/false);
    if (shape_str == "f32[1]") {
      // Exercise the scalar codepath: bitcast the f32[1] input to f32[] so the
      // all-reduce itself operates on a scalar.
      hlo_template = absl::StrReplaceAll(
          hlo_template,
          {{"SHAPE bitcast(p)", "DATATYPE[] bitcast(p)"},
           {"SHAPE all-reduce", "DATATYPE[] all-reduce"},
           {"SHAPE copy", "DATATYPE[] copy"}});
    }
    std::string parameterized_hlo = absl::StrReplaceAll(
        hlo_template,
        {{"SHAPE", shape_str},
         {"REPLICA_GROUPS",
          absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "))},
         {"OP", op},
         {"DATATYPE", datatype}});
    return ParseAndReturnVerifiedModule(parameterized_hlo, config).ValueOrDie();
  }

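  // Runs a two-replica all-reduce of `input_value` with `op` and checks that
  // each replica produces `expected_value`.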
  template <typename LiteralType>
  void TestTwoReplicasOneOperand(std::string op, Literal input_value,
                                 Literal expected_value) {
    const int kNumReplicas = 2;
    std::string dtype = primitive_util::LowercasePrimitiveTypeName(
        primitive_util::NativeToPrimitiveType<LiteralType>());
    auto config = GetModuleConfigForTest();
    config.set_replica_count(kNumReplicas);
    auto module = MakeCrsModule(
        /*shape=*/input_value.shape(),
        /*replica_groups=*/{}, config,
        /*op=*/op, /*datatype=*/dtype);
    TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                            ExecuteReplicated(std::move(module), {&input_value},
                                              /*num_replicas=*/kNumReplicas,
                                              /*use_threads=*/true));
    for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
      EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
          expected_value, results[replica_idx], ErrorSpec{1e-5, 1e-5}));
    }
  }

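  // Runs add/multiply/maximum/minimum two-replica all-reduces for the given
  // element type, with both replicas contributing {1, 2, 3}.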
  template <typename LiteralType>
  void TestAllOpsForReduce() {
    auto cast = [&](int value) { return static_cast<LiteralType>(value); };
    auto to_literal = [&](absl::Span<const LiteralType> values) {
      return LiteralUtil::CreateR1<LiteralType>(values);
    };
    Literal input_value = to_literal({cast(1), cast(2), cast(3)});
    TestTwoReplicasOneOperand<LiteralType>(
        "add",
        /*input_value=*/input_value.Clone(),
        /*expected_value=*/to_literal({cast(2), cast(4), cast(6)}));
    TestTwoReplicasOneOperand<LiteralType>(
        "multiply",
        /*input_value=*/input_value.Clone(),
        /*expected_value=*/to_literal({cast(1), cast(4), cast(9)}));
    TestTwoReplicasOneOperand<LiteralType>(
        "maximum",
        /*input_value=*/input_value.Clone(),
        /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
    TestTwoReplicasOneOperand<LiteralType>(
        "minimum",
        /*input_value=*/input_value.Clone(),
        /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
  }
};

// Returns the non-empty subsets of {0, 1, ..., n-1}. For example,
// PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}.
std::vector<std::vector<int64_t>> PowerSetOfIota(int64_t n) {
  std::vector<std::vector<int64_t>> power_set;
  for (int64_t i = 1; i < (1 << n); ++i) {
    power_set.emplace_back();
    for (int64_t j = 0; j < n; ++j) {
      if (i & (1 << j)) {
        power_set.back().push_back(j);
      }
    }
  }
  return power_set;
}

// Makes a DeviceAssignment assigning replica-id i to devices[i].
DeviceAssignment MakeDeviceAssn(std::vector<int64_t> devices) {
  DeviceAssignment assn(/*replica_count=*/devices.size(),
                        /*computation_count=*/1);
  for (int64_t i = 0; i < devices.size(); ++i) {
    assn(i, 0) = devices[i];
  }
  return assn;
}

template <typename T>
static Eigen::half ToHalf(T value) {
  return static_cast<Eigen::half>(value);
}

XLA_TEST_F(CollectiveOpsTest, AllReduce_sum_float32_2D) {
  TestTwoReplicasOneOperand<float>(
      "add",
      /*input_value=*/LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
      /*expected_value=*/LiteralUtil::CreateR2<float>({{2, 4}, {6, 8}}));
}

XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
  TestTwoReplicasOneOperand<float>(
      "add",
      /*input_value=*/LiteralUtil::CreateR1<float>({1}),
      /*expected_value=*/LiteralUtil::CreateR1<float>({2}));
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
  TestAllOpsForReduce<int8_t>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
  TestAllOpsForReduce<uint8_t>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
  TestAllOpsForReduce<uint32_t>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
  TestAllOpsForReduce<int32_t>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
  TestAllOpsForReduce<int64_t>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
  TestAllOpsForReduce<uint64_t>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
  TestAllOpsForReduce<float>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
  TestAllOpsForReduce<double>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
  TestAllOpsForReduce<Eigen::half>();
}

XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllReduceTwoReplicasOneOperand_bfloat16)) {
  TestAllOpsForReduce<bfloat16>();
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_sum_complex64)) {
  TestTwoReplicasOneOperand<complex64>(
      "add",
      /*input_value=*/LiteralUtil::CreateR1<complex64>({{1, 2}, {3, 4}}),
      /*expected_value=*/LiteralUtil::CreateR1<complex64>({{2, 4}, {6, 8}}));
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_sum_complex128)) {
  TestTwoReplicasOneOperand<complex128>(
      "add",
      /*input_value=*/LiteralUtil::CreateR1<complex128>({{1, 2}, {3, 4}}),
      /*expected_value=*/LiteralUtil::CreateR1<complex128>({{2, 4}, {6, 8}}));
}

XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
  // Both replicas contribute the same {true, false} input.
  TestTwoReplicasOneOperand<bool>(
      "and",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Test with replica-dependent inputs: replica 0 contributes true, replica 1
  // contributes false, so the AND reduces to false everywhere.
  const char* hlo_module = R"(
    HloModule test

    apply_op {
      x = pred[] parameter(0)
      y = pred[] parameter(1)
      ROOT apply_op = pred[] and(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      c = u32[] constant(0)
      p = pred[] compare(id, c), direction=EQ
      p2 = pred[1] bitcast(p)
      crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
      copy = pred[1] copy(crs)
      ROOT out = pred[1] bitcast(copy)
    }
  )";

  auto config = GetModuleConfigForTest();
  config.set_replica_count(2);
  auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/2,
                                            /*use_threads=*/true));
  for (int replica_idx = 0; replica_idx < 2; replica_idx++) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({false}),
                                       results[replica_idx]));
  }
}

XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) {
  // Both replicas contribute the same {true, false} input.
  TestTwoReplicasOneOperand<bool>(
      "or",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Test with replica-dependent inputs: replica 0 contributes true, replica 1
  // contributes false, so the OR reduces to true everywhere.
  const char* hlo_module = R"(
    HloModule test

    apply_op {
      x = pred[] parameter(0)
      y = pred[] parameter(1)
      ROOT apply_op = pred[] or(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      c = u32[] constant(0)
      p = pred[] compare(id, c), direction=EQ
      p2 = pred[1] bitcast(p)
      crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
      copy = pred[1] copy(crs)
      ROOT out = pred[1] bitcast(copy)
    }
  )";

  auto config = GetModuleConfigForTest();
  config.set_replica_count(2);
  auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/2,
                                            /*use_threads=*/true));
  for (int replica_idx = 0; replica_idx < 2; replica_idx++) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({true}),
                                       results[replica_idx]));
  }
}

// Runs an all-reduce across each of the 2^kNumDevices - 1 non-empty subsets of
// devices, in sequence.
XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
  const int64_t kNumDevices = 4;
  const int64_t kNumElems = 1024;

  for (std::vector<int64_t> devices : PowerSetOfIota(kNumDevices)) {
    SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
                                 absl::StrJoin(devices, ", ")));

    DeviceAssignment device_assn = MakeDeviceAssn(devices);

    auto config = GetModuleConfigForTest();
    config.set_replica_count(devices.size());
    config.set_static_device_assignment(device_assn);

    std::vector<float> input_vec(kNumElems);
    absl::c_iota(input_vec, 0);
    auto input_literal = LiteralUtil::CreateR1<float>(input_vec);

    auto module = MakeCrsModule(input_literal.shape(),
                                /*replica_groups=*/{}, config);

    TF_ASSERT_OK_AND_ASSIGN(
        std::vector<Literal> results,
        ExecuteReplicated(std::move(module), {&input_literal},
                          /*num_replicas=*/devices.size(), &device_assn,
                          /*run_hlo_passes=*/true, /*use_threads=*/true));
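    // The numeric results are not checked here; the reduction math is covered
    // by the tests above, so this test only verifies that execution succeeds
    // on every device subset.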
  }
}

// Runs the same executable many times concurrently. The all-reduces should not
// conflict with one another.
XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
  const int64_t kNumElems = 1024;
  const int64_t kNumThreads = 200;
  const int64_t kRunsPerThread = 10;

  std::vector<float> input_vec(kNumElems);
  absl::c_iota(input_vec, 0);
  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);

  auto config = GetModuleConfigForTest();
  config.set_replica_count(2);
  auto executable =
      test_runner_
          .CreateExecutable(MakeCrsModule(input_literal.shape(),
                                          /*replica_groups=*/{}, config),
                            /*run_hlo_passes=*/true)
          .ValueOrDie();
  std::vector<int64_t> devices = {0, 1};
  auto device_assn = MakeDeviceAssn(devices);

  HloRunner::ReplicatedExecuteOptions opts;
  opts.num_replicas = devices.size();
  opts.use_threads = true;
  opts.arguments.push_back(&input_literal);

  tensorflow::BlockingCounter done(kNumThreads * kRunsPerThread);
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), TestName(),
                                      kNumThreads);
  for (int64_t i = 0; i < kNumThreads * kRunsPerThread; ++i) {
    pool.Schedule([&] {
      TF_ASSERT_OK(
          test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn)
              .status());
      done.DecrementCount();
    });
  }
  done.Wait();
}

// Runs a module with two separate all-reduces, which the backend may combine
// into one, and verifies that both results are correct on every replica.
XLA_TEST_F(CollectiveOpsTest, AllReduce_CombinableAllReduces) {
  std::string hlo_string = R"(
    HloModule test

    apply_op {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT apply_op = f32[] add(x, y)
    }

    ENTRY test_computation {
      p0 = f32[5] parameter(0)
      p1 = f32[5] parameter(1)
      crs0 = f32[5] all-reduce(p0), replica_groups={}, to_apply=apply_op
      crs1 = f32[5] all-reduce(p1), replica_groups={}, to_apply=apply_op
      ROOT out = (f32[5], f32[5]) tuple(f32[5] crs0, f32[5] crs1)
    }
  )";
  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string, config));

  std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
  std::vector<float> input1_vec = {7., 3., 4., 1., 2.};
  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));
  std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
  std::vector<float> expected1_vec = {14., 6., 8., 2., 4.};
  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
                                             ErrorSpec{1e-5, 1e-5}));
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
                                             ErrorSpec{1e-5, 1e-5}));
  }
}

// Runs an all-reduce with three replica groups:
//   {0}, {1,2}, {3}
// meaning the all-reduce is a nop for replicas 0 and 3, and only replicas 1
// and 2 actually exchange data with each other.
XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) {
  // Use a prime element count so it's not a power of 2.
  const int64_t kNumElems = 137;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(4);
  std::vector<float> input_vec(kNumElems);
  absl::c_iota(input_vec, 0);
  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
  auto module = MakeCrsModule(
      /*shape=*/input_literal.shape(),
      /*replica_groups=*/{{0}, {1, 2}, {3}}, config);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input_literal}, /*num_replicas=*/4,
                        /*use_threads=*/true));

  ASSERT_EQ(results.size(), 4);

  std::vector<float> input_vec_doubled;
  input_vec_doubled.reserve(input_vec.size());
  for (float n : input_vec) {
    input_vec_doubled.push_back(n * 2);
  }
  auto input_literal_doubled = LiteralUtil::CreateR1<float>(input_vec_doubled);

  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[2]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[3]));
}

XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) {
  const char* const kModuleStr = R"(
    HloModule test

    apply_op {
      x = u32[] parameter(0)
      y = u32[] parameter(1)
      ROOT apply_op = u32[] add(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      ROOT crs = u32[] all-reduce(id), replica_groups={{0},{1},{2},{3}}, to_apply=apply_op
    }
  )";
  static constexpr int kNumReplicas = 4;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  for (int i = 0; i < kNumReplicas; ++i) {
    LiteralTestUtil::ExpectR0Equal<uint32_t>(i, results[i]);
  }
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduce)) {
  const absl::string_view kModuleStr = R"(
    HloModule test

    apply_op {
      x = u32[] parameter(0)
      y = u32[] parameter(1)
      ROOT apply_op = u32[] add(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      start = u32[] all-reduce-start(id), to_apply=apply_op
      ROOT done = u32[] all-reduce-done(start)
    }
  )";
  static constexpr int kNumReplicas = 4;
  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  uint32_t expected = 6;  // sum [0,4)
  for (int i = 0; i < kNumReplicas; ++i) {
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
  }
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduceTwoOperands)) {
  const absl::string_view kModuleStr = R"(
    HloModule test

    apply_op {
      x = u32[] parameter(0)
      y = u32[] parameter(1)
      ROOT apply_op = u32[] add(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[] multiply(id, id)
      start = (u32[], u32[]) all-reduce-start(id, id2), to_apply=apply_op
      ROOT done = (u32[], u32[]) all-reduce-done(start)
    }
  )";
  static constexpr int kNumReplicas = 4;
  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  uint32_t expected0 = 6;   // sum [0,4)
  uint32_t expected1 = 14;  // sum squares [0,4)
  for (int i = 0; i < kNumReplicas; ++i) {
    std::vector<Literal> replica_results = results[i].DecomposeTuple();
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected0, replica_results[0]);
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected1, replica_results[1]);
  }
}

XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      ROOT out = u32[] copy(id)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  for (uint32_t i = 0; i < kNumReplicas; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i]));
  }
}

XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      replica = u32[] replica-id()
      ten = u32[] constant(10)
      sum = u32[] add(replica, ten)
      p = u32[2] broadcast(sum), dimensions={}
      permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}}
      ROOT copy = u32[2] copy(permute)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
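  // Each {source, target} pair sends the source replica's value (10 + source)
  // to the target: replica 0 receives 11 from replica 1, replica 1 receives 10
  // from replica 0, and replica 2 keeps its own 12.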
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({11, 11}),
                                     results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({10, 10}),
                                     results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({12, 12}),
                                     results[2]));
  // Nothing writes to replica 3, so it is memzero'ed.
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({0, 0}),
                                     results[3]));
}

XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      replica = u32[] replica-id()
      ten = u32[] constant(10)
      sum = u32[] add(replica, ten)
      p = u32[2] broadcast(sum), dimensions={}
      permute = u32[2] collective-permute(p), source_target_pairs={{0,0}, {1,1}, {2,2}, {3,3}}
      ROOT copy = u32[2] copy(permute)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({10, 10}),
                                     results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({11, 11}),
                                     results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({12, 12}),
                                     results[2]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({13, 13}),
                                     results[3]));
}

XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NoDegenerate) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      replica = u32[] replica-id()
      ten = u32[] constant(10)
      sum = u32[] add(replica, ten)
      p = u32[2] broadcast(sum), dimensions={}
      permute = u32[2] collective-permute(p), source_target_pairs={{0,0}, {1,1}, {2,2}}
      ROOT copy = u32[2] copy(permute)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({10, 10}),
                                     results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({11, 11}),
                                     results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({12, 12}),
                                     results[2]));
  // Nothing writes to replica 3, so it is memzero'ed.
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({0, 0}),
                                     results[3]));
}

XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      replica = u32[] replica-id()
      ten = u32[] constant(10)
      sum = u32[] add(replica, ten)
      p = u32[2] broadcast(sum), dimensions={}
      permute = u32[2] collective-permute(p), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
      ROOT copy = u32[2] copy(permute)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
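  // The pairs {0,1}, {1,2}, {2,3}, {3,0} rotate the data by one replica, so
  // replica r receives 10 + ((r + 3) % 4) from its predecessor.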
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({13, 13}),
                                     results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({10, 10}),
                                     results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({11, 11}),
                                     results[2]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32_t>({12, 12}),
                                     results[3]));
}

XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[2] broadcast(id), dimensions={}
      a0 = u32[2] constant({10, 15})
      b0 = u32[2] constant({20, 25})
      c0 = u32[2] constant({30, 35})
      d0 = u32[2] constant({40, 45})
      a1 = u32[2] add(id2, a0)
      b1 = u32[2] add(id2, b0)
      c1 = u32[2] add(id2, c0)
      d1 = u32[2] add(id2, d0)
      all2all = (u32[2], u32[2], u32[2], u32[2]) all-to-all(a1, b1, c1, d1), replica_groups={}
      a_prime = u32[2] get-tuple-element(all2all), index=0
      b_prime = u32[2] get-tuple-element(all2all), index=1
      c_prime = u32[2] get-tuple-element(all2all), index=2
      d_prime = u32[2] get-tuple-element(all2all), index=3
      ROOT out = u32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
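  // Replica i sends its r-th operand to replica r, so replica r's output holds
  // operand r from replicas 0..3, i.e. its r-th base constant plus each
  // sender's replica-id.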
  LiteralTestUtil::ExpectR1Equal<uint32_t>({10, 15, 11, 16, 12, 17, 13, 18},
                                           results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({20, 25, 21, 26, 22, 27, 23, 28},
                                           results[1]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({30, 35, 31, 36, 32, 37, 33, 38},
                                           results[2]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({40, 45, 41, 46, 42, 47, 43, 48},
                                           results[3]);
}

XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[2] broadcast(id), dimensions={}
      a0 = u32[2] constant({10, 15})
      b0 = u32[2] constant({20, 25})
      c0 = u32[2] constant({30, 35})
      d0 = u32[2] constant({40, 45})
      a1 = u32[2] add(id2, a0)
      b1 = u32[2] add(id2, b0)
      c1 = u32[2] add(id2, c0)
      d1 = u32[2] add(id2, d0)
      all2all = (u32[2], u32[2], u32[2], u32[2]) all-to-all(a1, b1, c1, d1), replica_groups={{3,2,1,0}}
      a_prime = u32[2] get-tuple-element(all2all), index=0
      b_prime = u32[2] get-tuple-element(all2all), index=1
      c_prime = u32[2] get-tuple-element(all2all), index=2
      d_prime = u32[2] get-tuple-element(all2all), index=3
      ROOT out = u32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
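  // Replica r sits at position 3 - r in the group {3,2,1,0}, so it receives
  // each participant's (3 - r)-th operand, with senders ordered 3, 2, 1, 0.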
  LiteralTestUtil::ExpectR1Equal<uint32_t>({43, 48, 42, 47, 41, 46, 40, 45},
                                           results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({33, 38, 32, 37, 31, 36, 30, 35},
                                           results[1]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({23, 28, 22, 27, 21, 26, 20, 25},
                                           results[2]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({13, 18, 12, 17, 11, 16, 10, 15},
                                           results[3]);
}

XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[2] broadcast(id), dimensions={}
      a0 = u32[2] constant({10, 15})
      b0 = u32[2] constant({20, 25})
      a1 = u32[2] add(id2, a0)
      b1 = u32[2] add(id2, b0)
      all2all = (u32[2], u32[2]) all-to-all(a1, b1), replica_groups={{2,1},{3,0}}
      a_prime = u32[2] get-tuple-element(all2all), index=0
      b_prime = u32[2] get-tuple-element(all2all), index=1
      ROOT out = u32[4] concatenate(a_prime, b_prime), dimensions={0}
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
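  // Within each group, the replica at position k receives every member's k-th
  // operand in group order: e.g. replica 0 (position 1 of group {3,0}) gets b1
  // from replica 3 followed by b1 from replica 0.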
  LiteralTestUtil::ExpectR1Equal<uint32_t>({23, 28, 20, 25}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({22, 27, 21, 26}, results[1]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({12, 17, 11, 16}, results[2]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({13, 18, 10, 15}, results[3]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[4, 2] broadcast(id), dimensions={}
      a0 = u32[4, 2] constant({{10, 15}, {20, 25}, {30, 35}, {40, 45}})
      a1 = u32[4, 2] add(id2, a0)
      all2all = u32[4, 2] all-to-all(a1), replica_groups={{0,1,2,3}}, dimensions={0}
      ROOT out = u32[8] reshape(all2all)
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
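  // Splitting along dimension 0 scatters row i of each replica's [4,2] input
  // to replica i and gathers row r from every replica, so replica r sees its
  // r-th base row plus each sender's replica-id.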
  LiteralTestUtil::ExpectR1Equal<uint32_t>({10, 15, 11, 16, 12, 17, 13, 18},
                                           results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({20, 25, 21, 26, 22, 27, 23, 28},
                                           results[1]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({30, 35, 31, 36, 32, 37, 33, 38},
                                           results[2]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({40, 45, 41, 46, 42, 47, 43, 48},
                                           results[3]);
}

XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[1, 2] broadcast(id), dimensions={}
      a0 = u32[1, 2] constant({{10, 15}})
      a1 = u32[1, 2] add(id2, a0)
      allgather = u32[4, 2] all-gather(a1), dimensions={0}
      ROOT out = u32[8] reshape(allgather)
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
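  // The all-gather concatenates each replica's {10 + id, 15 + id} contribution
  // along dimension 0, so every replica sees the same [4,2] result.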
  for (const Literal& result : results) {
    LiteralTestUtil::ExpectR1Equal<uint32_t>({10, 15, 11, 16, 12, 17, 13, 18},
                                             result);
  }
}

XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[2, 1] broadcast(id), dimensions={}
      a0 = u32[2, 1] constant({{10}, {15}})
      a1 = u32[2, 1] add(id2, a0)
      allgather = u32[2, 4] all-gather(a1), dimensions={1}
      ROOT out = u32[8] reshape(allgather)
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
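  // Gathering the [2,1] columns {10 + id, 15 + id} along dimension 1 yields a
  // [2,4] result whose first row holds the 10-based values and whose second
  // row holds the 15-based values.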
  for (const Literal& result : results) {
    LiteralTestUtil::ExpectR1Equal<uint32_t>({10, 11, 12, 13, 15, 16, 17, 18},
                                             result);
  }
}

XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
  std::string hlo_string = R"(
    HloModule test

    apply_op {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT apply_op = f32[] add(x, y)
    }

    ENTRY test_computation {
      p0 = f32[5] parameter(0)
      p1 = f32[7] parameter(1)
      ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
    }
  )";
  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string, config));

  std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
  std::vector<float> input1_vec = {7., 3., 4., 1., 2., 3., 4.};
  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));
  std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
  std::vector<float> expected1_vec = {14., 6., 8., 2., 4., 6., 8.};
  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
                                             ErrorSpec{1e-5, 1e-5}));
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
                                             ErrorSpec{1e-5, 1e-5}));
  }
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGatherMixedTypes)) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      p0 = u32[2, 1] broadcast(id), dimensions={}
      p1 = f32[2, 1] convert(p0)
      allgather = (u32[2, 2], f32[2, 2]) all-gather(p0, p1), dimensions={1}
      ag0 = u32[2, 2] get-tuple-element(allgather), index=0
      ag1 = f32[2, 2] get-tuple-element(allgather), index=1
      r0 = u32[4] reshape(ag0)
      r1 = f32[4] reshape(ag1)
      ROOT out = (u32[4], f32[4]) tuple(r0, r1)
    }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
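  // Each replica contributes a [2,1] column equal to its replica-id, in both
  // u32 and f32, so gathering along dimension 1 gives {{0, 1}, {0, 1}} for
  // both tuple elements.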
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    LiteralTestUtil::ExpectR1Equal<uint32_t>({0, 1, 0, 1}, rs[0]);
    LiteralTestUtil::ExpectR1Near<float>({0.0, 1.0, 0.0, 1.0}, rs[1],
                                         ErrorSpec{1e-5, 1e-5});
  }
}

XLA_TEST_F(CollectiveOpsTest, ReduceScatter) {
  const char* const kModuleStr = R"(
    HloModule test
    add {
      lhs = u32[] parameter(0)
      rhs = u32[] parameter(1)
      ROOT add = u32[] add(lhs, rhs)
    }

    ENTRY main {
      c0 = u32[8] constant({1, 2, 3, 4, 5, 6, 7, 8})
      c1 = u32[8] constant({10, 11, 12, 13, 14, 15, 16, 17})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[8] broadcast(p), dimensions={}
      // data = c0 for replica 0 and c1 for replica 1
      data = u32[8] select(pb, c0, c1)
      ROOT ars = u32[4] reduce-scatter(data), replica_groups={},
                        dimensions={0}, to_apply=add
    }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
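  // The full all-reduce of the two inputs would be {11, 13, 15, 17, 19, 21,
  // 23, 25}; reduce-scatter gives the first half to replica 0 and the second
  // half to replica 1.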
  LiteralTestUtil::ExpectR1Equal<uint32_t>({11, 13, 15, 17}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({19, 21, 23, 25}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, ReduceScatter_Dim1) {
  const char* const kModuleStr = R"(
    HloModule test
    add {
      lhs = u32[] parameter(0)
      rhs = u32[] parameter(1)
      ROOT add = u32[] add(lhs, rhs)
    }

    ENTRY main {
      c0 = u32[2, 4] constant({{1, 2, 3, 4}, {5, 6, 7, 8}})
      c1 = u32[2, 4] constant({{10, 11, 12, 13}, {14, 15, 16, 17}})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[2, 4] broadcast(p), dimensions={}
      // data = c0 for replica 0 and c1 for replica 1
      data = u32[2, 4] select(pb, c0, c1)
      // all-reduce result = {{11, 13, 15, 17}, {19, 21, 23, 25}}
      ars = u32[2, 2] reduce-scatter(data), replica_groups={},
                      dimensions={1}, to_apply=add
      ROOT r = u32[4] reshape(ars)
    }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  LiteralTestUtil::ExpectR1Equal<uint32_t>({11, 13, 19, 21}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({15, 17, 23, 25}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduceReassociate)) {
  const char* const kModuleStr = R"(
    HloModule m
    sum {
      a = f32[] parameter(0)
      b = f32[] parameter(1)
      ROOT add.2 = f32[] add(a, b)
    }

    ENTRY main {
      c0 = f32[8] constant({1, 2, 3, 4, 5, 6, 7, 8})
      c1 = f32[8] constant({11, 12, 13, 14, 15, 16, 17, 18})
      c2 = f32[8] constant({2, 3, 4, 5, 6, 7, 8, 9})
      c3 = f32[8] constant({12, 13, 14, 15, 16, 17, 18, 19})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[8] broadcast(p), dimensions={}
      // data0 = c0 for replica 0 and c1 for replica 1
      data0 = f32[8] select(pb, c0, c1)
      // data1 = c2 for replica 0 and c3 for replica 1
      data1 = f32[8] select(pb, c2, c3)

      ar0 = f32[8] all-reduce(data0), replica_groups={}, to_apply=sum
      ar1 = f32[8] all-reduce(data1), replica_groups={}, to_apply=sum
      ROOT add = f32[8] add(ar0, ar1)
    }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

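  // Elementwise, ar0 + ar1 == (c0 + c1) + (c2 + c3); the first element is
  // 1 + 11 + 2 + 12 = 26.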
  const ErrorSpec es{1e-5, 1e-5};
  EXPECT_TRUE(LiteralTestUtil::NearOrEqual(results[0], results[1], es));
  LiteralTestUtil::ExpectR1Near<float>(
      {26.0, 30.0, 34.0, 38.0, 42.0, 46.0, 50.0, 54.0}, results[0], es);
}

XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllGatherBroadcastReorder_NonUniform)) {
  const char* const kModuleStr = R"(
    HloModule m

    ENTRY main {
      c0 = u32[2, 3] constant({{1, 2, 3}, {4, 5, 6}})
      c1 = u32[2, 3] constant({{10, 11, 12}, {13, 14, 15}})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[2, 3] broadcast(p), dimensions={}
      // data = c0 for replica 0 and c1 for replica 1
      data = u32[2, 3] select(pb, c0, c1)
      bc = u32[2, 4, 3] broadcast(data), dimensions={0, 2}
      ROOT ag = u32[2, 4, 6] all-gather(bc), dimensions={2}, replica_groups={{0, 1}}
    }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

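  // The gather along dimension 2 concatenates replica 0's broadcast data with
  // replica 1's, so both replicas produce the same [2, 4, 6] result.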
  EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1]));
  LiteralTestUtil::ExpectR3Equal<uint32_t>({{{1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12}},
                                            {{4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15}}},
                                           results[0]);
}

XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllGatherBroadcastReorder_Uniform)) {
  const char* const kModuleStr = R"(
    HloModule m

    ENTRY main {
      c0 = u32[2, 3] constant({{1, 2, 3}, {4, 5, 6}})
      c1 = u32[2, 3] constant({{10, 11, 12}, {13, 14, 15}})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[2, 3] broadcast(p), dimensions={}
      // data = c0 for replica 0 and c1 for replica 1
      data = u32[2, 3] select(pb, c0, c1)
      bc = u32[2, 4, 3] broadcast(data), dimensions={0, 2}
      ROOT ag = u32[2, 8, 3] all-gather(bc), dimensions={1}, replica_groups={{0, 1}}
    }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1]));
  LiteralTestUtil::ExpectR3Equal<uint32_t>({{{1, 2, 3},
                                             {1, 2, 3},
                                             {1, 2, 3},
                                             {1, 2, 3},
                                             {10, 11, 12},
                                             {10, 11, 12},
                                             {10, 11, 12},
                                             {10, 11, 12}},
                                            {{4, 5, 6},
                                             {4, 5, 6},
                                             {4, 5, 6},
                                             {4, 5, 6},
                                             {13, 14, 15},
                                             {13, 14, 15},
                                             {13, 14, 15},
                                             {13, 14, 15}}},
                                           results[0]);
}

}  // namespace
}  // namespace xla