xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/collective_ops_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "absl/strings/str_replace.h"
19 #include "absl/types/span.h"
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/primitive_util.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
27 #include "tensorflow/compiler/xla/tests/test_macros.h"
28 #include "tensorflow/core/lib/core/blocking_counter.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/platform/env.h"
32 
33 // Tests cross-GPU operations.
34 //
35 // This test requires at least four GPUs.  For instructions on running this
36 // within Google, see go/multi-gpu-unit-test.
37 
38 namespace xla {
39 namespace {
40 
41 class CollectiveOpsTest : public HloTestBase {
42  public:
SetUpTestSuite()43   static void SetUpTestSuite() {
44     // Not needed structly, since this test exercises cross replica collective
45     // permute which does not use NCCL. But keeping it here for testing.
46     tensorflow::setenv("NCCL_LAUNCH_MODE", "PARALLEL", /*overwrite=*/1);
47     HloTestBase::SetUpTestSuite();
48   }
49 
50  protected:
MakeCrsModule(const Shape & shape,std::vector<std::vector<int64_t>> replica_groups,const HloModuleConfig & config,std::string op="add",std::string datatype="f32")51   std::unique_ptr<HloModule> MakeCrsModule(
52       const Shape& shape, std::vector<std::vector<int64_t>> replica_groups,
53       const HloModuleConfig& config, std::string op = "add",
54       std::string datatype = "f32") {
55     std::string hlo_template = R"(
56       HloModule test
57 
58       apply_op {
59         x = DATATYPE[] parameter(0)
60         y = DATATYPE[] parameter(1)
61         ROOT apply_op = DATATYPE[] OP(x, y)
62       }
63 
64       ENTRY test_computation {
65         p = SHAPE parameter(0)
66         p2 = SHAPE bitcast(p)
67         crs = SHAPE all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op
68         copy = SHAPE copy(crs)
69         ROOT out = SHAPE bitcast(copy)
70       }
71     )";
72     std::vector<std::string> replica_group_strs;
73     replica_group_strs.reserve(replica_groups.size());
74     for (const auto& g : replica_groups) {
75       replica_group_strs.push_back(
76           absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
77     }
78     std::string shape_str = shape.ToString(/*print_layout=*/false);
79     if (shape_str == "f32[1]") {
80       // Exercise the scalar codepath.
81       hlo_template = absl::StrReplaceAll(
82           hlo_template,
83           {{"DATATYPE[SHAPE] bitcast(p)", "DATATYPE[] bitcast(p)"},
84            {"DATATYPE[SHAPE] all-reduce", "DATATYPE[] all-reduce"},
85            {"DATATYPE[SHAPE] copy", "DATATYPE[] copy"}});
86     }
87     std::string parameterized_hlo = absl::StrReplaceAll(
88         hlo_template,
89         {{"SHAPE", shape_str},
90          {"REPLICA_GROUPS",
91           absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "))},
92          {"OP", op},
93          {"DATATYPE", datatype}});
94     return ParseAndReturnVerifiedModule(parameterized_hlo, config).ValueOrDie();
95   }
96 
97   template <typename LiteralType>
TestTwoReplicasOneOperand(std::string op,Literal input_value,Literal expected_value)98   void TestTwoReplicasOneOperand(std::string op, Literal input_value,
99                                  Literal expected_value) {
100     const int kNumReplicas = 2;
101     std::string dtype = primitive_util::LowercasePrimitiveTypeName(
102         primitive_util::NativeToPrimitiveType<LiteralType>());
103     auto config = GetModuleConfigForTest();
104     config.set_replica_count(kNumReplicas);
105     auto module = MakeCrsModule(
106         /*shape=*/input_value.shape(),
107         /*replica_groups=*/{}, config,
108         /*op=*/op, /*datatype=*/dtype);
109     TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
110                             ExecuteReplicated(std::move(module), {&input_value},
111                                               /*num_replicas=*/kNumReplicas,
112                                               /*use_threads=*/true));
113     for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
114       EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
115           expected_value, results[replica_idx], ErrorSpec{1e-5, 1e-5}));
116     }
117   }
118 
119   template <typename LiteralType>
TestAllOpsForReduce()120   void TestAllOpsForReduce() {
121     auto cast = [&](int value) { return static_cast<LiteralType>(value); };
122     auto to_literal = [&](absl::Span<const LiteralType> values) {
123       return LiteralUtil::CreateR1<LiteralType>(values);
124     };
125     Literal input_value = to_literal({cast(1), cast(2), cast(3)});
126     TestTwoReplicasOneOperand<LiteralType>(
127         "add",
128         /*input_value=*/input_value.Clone(),
129         /*expected_value=*/to_literal({cast(2), cast(4), cast(6)}));
130     TestTwoReplicasOneOperand<LiteralType>(
131         "multiply",
132         /*input_value=*/input_value.Clone(),
133         /*expected_value=*/to_literal({cast(1), cast(4), cast(9)}));
134     TestTwoReplicasOneOperand<LiteralType>(
135         "maximum",
136         /*input_value=*/input_value.Clone(),
137         /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
138     TestTwoReplicasOneOperand<LiteralType>(
139         "minimum",
140         /*input_value=*/input_value.Clone(),
141         /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
142   }
143 };
144 
// Returns the non-empty subsets of {0, 1, ..., n-1}, ordered by the bitmask
// that encodes each subset (bit j set <=> j is a member).  For example,
// PowerSetOfIota(3) = {{0}, {1}, {0,1}, {2}, {0,2}, {1,2}, {0,1,2}}.
//
// Returns an empty vector when n <= 0.  The result has 2^n - 1 entries, so n
// must be small; n >= 63 is not supported (the mask would not fit in int64_t).
std::vector<std::vector<int64_t>> PowerSetOfIota(int64_t n) {
  std::vector<std::vector<int64_t>> power_set;
  if (n <= 0) {
    return power_set;
  }
  // Shift a 64-bit one: a plain `1 << n` is an int shift and overflows for
  // n >= 31.
  const int64_t num_masks = int64_t{1} << n;
  power_set.reserve(num_masks - 1);
  for (int64_t mask = 1; mask < num_masks; ++mask) {
    power_set.emplace_back();
    std::vector<int64_t>& subset = power_set.back();
    for (int64_t j = 0; j < n; ++j) {
      if (mask & (int64_t{1} << j)) {
        subset.push_back(j);
      }
    }
  }
  return power_set;
}
159 
160 // Makes a DeviceAssignment assigning replica-id i to devices[i].
MakeDeviceAssn(std::vector<int64_t> devices)161 DeviceAssignment MakeDeviceAssn(std::vector<int64_t> devices) {
162   DeviceAssignment assn(/*replica_count=*/devices.size(),
163                         /*computation_count=*/1);
164   for (int64_t i = 0; i < devices.size(); ++i) {
165     assn(i, 0) = devices[i];
166   }
167   return assn;
168 }
169 
170 template <typename T>
ToHalf(T value)171 static Eigen::half ToHalf(T value) {
172   return static_cast<Eigen::half>(value);
173 }
174 
// All-reduce "add" of a rank-2 f32 operand across two replicas that hold the
// same data: every element is expected to double.
XLA_TEST_F(CollectiveOpsTest, AllReduce_sum_float32_2D) {
  Literal operand = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
  Literal doubled = LiteralUtil::CreateR2<float>({{2, 4}, {6, 8}});
  TestTwoReplicasOneOperand<float>("add",
                                   /*input_value=*/std::move(operand),
                                   /*expected_value=*/std::move(doubled));
}
181 
// All-reduce "add" of a one-element f32 vector across two replicas: 1 + 1 = 2.
XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
  Literal operand = LiteralUtil::CreateR1<float>({1});
  Literal doubled = LiteralUtil::CreateR1<float>({2});
  TestTwoReplicasOneOperand<float>("add",
                                   /*input_value=*/std::move(operand),
                                   /*expected_value=*/std::move(doubled));
}
188 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// int8_t elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
  using ElementType = int8_t;
  TestAllOpsForReduce<ElementType>();
}
192 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// uint8_t elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
  using ElementType = uint8_t;
  TestAllOpsForReduce<ElementType>();
}
196 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// uint32_t elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
  using ElementType = uint32_t;
  TestAllOpsForReduce<ElementType>();
}
200 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// int32_t elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
  using ElementType = int32_t;
  TestAllOpsForReduce<ElementType>();
}
204 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// int64_t elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
  using ElementType = int64_t;
  TestAllOpsForReduce<ElementType>();
}
208 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// uint64_t elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
  using ElementType = uint64_t;
  TestAllOpsForReduce<ElementType>();
}
212 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// float elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
  using ElementType = float;
  TestAllOpsForReduce<ElementType>();
}
216 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// double elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
  using ElementType = double;
  TestAllOpsForReduce<ElementType>();
}
220 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// Eigen::half elements.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
  using ElementType = Eigen::half;
  TestAllOpsForReduce<ElementType>();
}
224 
// Runs the add/multiply/maximum/minimum two-replica all-reduce checks on
// bfloat16 elements.  Disabled on the CPU backend.
XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllReduceTwoReplicasOneOperand_bfloat16)) {
  using ElementType = bfloat16;
  TestAllOpsForReduce<ElementType>();
}
229 
// All-reduce "add" of a complex64 vector across two replicas holding the same
// data: real and imaginary parts both double.  Disabled on the CPU backend.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_sum_complex64)) {
  Literal operand = LiteralUtil::CreateR1<complex64>({{1, 2}, {3, 4}});
  Literal doubled = LiteralUtil::CreateR1<complex64>({{2, 4}, {6, 8}});
  TestTwoReplicasOneOperand<complex64>("add",
                                       /*input_value=*/std::move(operand),
                                       /*expected_value=*/std::move(doubled));
}
236 
// All-reduce "add" of a complex128 vector across two replicas holding the same
// data: real and imaginary parts both double.  Disabled on the CPU backend.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_sum_complex128)) {
  Literal operand = LiteralUtil::CreateR1<complex128>({{1, 2}, {3, 4}});
  Literal doubled = LiteralUtil::CreateR1<complex128>({{2, 4}, {6, 8}});
  TestTwoReplicasOneOperand<complex128>("add",
                                        /*input_value=*/std::move(operand),
                                        /*expected_value=*/std::move(doubled));
}
243 
// Checks all-reduce with an "and" reduction on pred values, first with both
// replicas holding the same data and then with replicas that disagree.
XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
  // Test with equal elements.
  TestTwoReplicasOneOperand<bool>(
      "and",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Test with {true, false}: each replica contributes (replica-id == 0), so
  // replica 0 supplies true and replica 1 supplies false.
  const char* hlo_module = R"(
    HloModule test

    apply_op {
      x = pred[] parameter(0)
      y = pred[] parameter(1)
      ROOT apply_op = pred[] and(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      c = u32[] constant(0)
      p = pred[] compare(id, c), direction=EQ
      p2 = pred[1] bitcast(p)
      crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
      copy = pred[1] copy(crs)
      ROOT out = pred[1] bitcast(copy)
    }
  )";

  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  // Report a parse/verify failure as a test failure rather than crashing the
  // whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_module, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/kNumReplicas,
                                            /*use_threads=*/true));
  // Guard the indexing below.
  ASSERT_EQ(results.size(), kNumReplicas);
  // true AND false == false on every replica.
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({false}),
                                       results[replica_idx]));
  }
}
284 
// Checks all-reduce with an "or" reduction on pred values, first with both
// replicas holding the same data and then with replicas that disagree.
XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) {
  // Test with equal elements.
  TestTwoReplicasOneOperand<bool>(
      "or",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Test with {true, false}: each replica contributes (replica-id == 0), so
  // replica 0 supplies true and replica 1 supplies false.
  const char* hlo_module = R"(
    HloModule test

    apply_op {
      x = pred[] parameter(0)
      y = pred[] parameter(1)
      ROOT apply_op = pred[] or(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      c = u32[] constant(0)
      p = pred[] compare(id, c), direction=EQ
      p2 = pred[1] bitcast(p)
      crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
      copy = pred[1] copy(crs)
      ROOT out = pred[1] bitcast(copy)
    }
  )";

  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  // Report a parse/verify failure as a test failure rather than crashing the
  // whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_module, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/kNumReplicas,
                                            /*use_threads=*/true));
  // Guard the indexing below.
  ASSERT_EQ(results.size(), kNumReplicas);
  // true OR false == true on every replica.
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({true}),
                                       results[replica_idx]));
  }
}
325 
326 // Tries all-to-all operations across all 2^kNumDevices - 1 combinations of
327 // devices in sequence.
XLA_TEST_F(CollectiveOpsTest,AllReduce_AllCombinations)328 XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
329   const int64_t kNumDevices = 4;
330   const int64_t kNumElems = 1024;
331 
332   for (std::vector<int64_t> devices : PowerSetOfIota(kNumDevices)) {
333     SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
334                                  absl::StrJoin(devices, ", ")));
335 
336     DeviceAssignment device_assn = MakeDeviceAssn(devices);
337 
338     auto config = GetModuleConfigForTest();
339     config.set_replica_count(devices.size());
340     config.set_static_device_assignment(device_assn);
341 
342     std::vector<float> input_vec(kNumElems);
343     absl::c_iota(input_vec, 0);
344     auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
345 
346     auto module = MakeCrsModule(input_literal.shape(),
347                                 /*replica_groups=*/{}, config);
348 
349     TF_ASSERT_OK_AND_ASSIGN(
350         std::vector<Literal> results,
351         ExecuteReplicated(std::move(module), {&input_literal},
352                           /*num_replicas=*/devices.size(), &device_assn,
353                           /*run_hlo_passes=*/true, /*use_threads=*/true));
354   }
355 }
356 
357 // Runs the same executable many times concurrently.  The all-reduces should not
358 // conflict with one another.
XLA_TEST_F(CollectiveOpsTest,AllReduce_ManyConcurrentAllReduces)359 XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
360   const int64_t kNumElems = 1024;
361   const int64_t kNumThreads = 200;
362   const int64_t kRunsPerThread = 10;
363 
364   std::vector<float> input_vec(kNumElems);
365   absl::c_iota(input_vec, 0);
366   auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
367 
368   auto config = GetModuleConfigForTest();
369   config.set_replica_count(2);
370   auto executable =
371       test_runner_
372           .CreateExecutable(MakeCrsModule(input_literal.shape(),
373                                           /*replica_groups=*/{}, config),
374                             /*run_hlo_passes=*/true)
375           .ValueOrDie();
376   std::vector<int64_t> devices = {0, 1};
377   auto device_assn = MakeDeviceAssn(devices);
378 
379   HloRunner::ReplicatedExecuteOptions opts;
380   opts.num_replicas = devices.size();
381   opts.use_threads = true;
382   opts.arguments.push_back(&input_literal);
383 
384   tensorflow::BlockingCounter done(kNumThreads * kRunsPerThread);
385   tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), TestName(),
386                                       kNumThreads);
387   for (int64_t i = 0; i < kNumThreads * kRunsPerThread; ++i) {
388     pool.Schedule([&] {
389       TF_ASSERT_OK(
390           test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn)
391               .status());
392       done.DecrementCount();
393     });
394   }
395   done.Wait();
396 }
397 
398 // Runs the same executable many times concurrently.  The all-reduces should not
399 // conflict with one another.
XLA_TEST_F(CollectiveOpsTest,AllReduce_CombinableAllReduces)400 XLA_TEST_F(CollectiveOpsTest, AllReduce_CombinableAllReduces) {
401   std::string hlo_string = R"(
402     HloModule test
403 
404     apply_op {
405       x = f32[] parameter(0)
406       y = f32[] parameter(1)
407       ROOT apply_op = f32[] add(x, y)
408     }
409 
410     ENTRY test_computation {
411       p0 = f32[5] parameter(0)
412       p1 = f32[5] parameter(1)
413       crs0 = f32[5] all-reduce(p0), replica_groups={}, to_apply=apply_op
414       crs1 = f32[5] all-reduce(p1), replica_groups={}, to_apply=apply_op
415       ROOT out = (f32[5], f32[5]) tuple(f32[5] crs0, f32[5] crs1)
416     }
417   )";
418   static constexpr int kNumReplicas = 2;
419   auto config = GetModuleConfigForTest();
420   config.set_replica_count(kNumReplicas);
421   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
422                           ParseAndReturnVerifiedModule(hlo_string, config));
423 
424   std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
425   auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
426   std::vector<float> input1_vec = {7., 3., 4., 1., 2.};
427   auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
428 
429   TF_ASSERT_OK_AND_ASSIGN(
430       std::vector<Literal> results,
431       ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
432                         /*num_replicas=*/kNumReplicas,
433                         /*use_threads=*/true));
434   std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
435   auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
436   std::vector<float> expected1_vec = {14., 6., 8., 2., 4.};
437   auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
438   for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
439     auto rs = results[replica_idx].DecomposeTuple();
440     EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
441                                              ErrorSpec{1e-5, 1e-5}));
442     EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
443                                              ErrorSpec{1e-5, 1e-5}));
444   }
445 }
446 
447 // Runs an all-reduce with three partitions:
448 //  {0}, {1,2}, {3}
449 // meaning, the all-reduce is a nop for devices 0 and 3, and only devices 1 and
450 // 2 actually exchange data with each other.
XLA_TEST_F(CollectiveOpsTest,AllReduce_ThreeReplicaGroups)451 XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) {
452   // Test a prime number so it's not all powers of 2.
453   const int64_t kNumElems = 137;
454 
455   auto config = GetModuleConfigForTest();
456   config.set_replica_count(4);
457   std::vector<float> input_vec(kNumElems);
458   absl::c_iota(input_vec, 0);
459   auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
460   auto module = MakeCrsModule(
461       /*shape=*/input_literal.shape(),
462       /*replica_groups=*/{{0}, {1, 2}, {3}}, config);
463 
464   TF_ASSERT_OK_AND_ASSIGN(
465       std::vector<Literal> results,
466       ExecuteReplicated(std::move(module), {&input_literal}, /*num_replicas=*/4,
467                         /*use_threads=*/true));
468 
469   ASSERT_EQ(results.size(), 4);
470 
471   std::vector<float> input_vec_doubled;
472   input_vec_doubled.reserve(input_vec.size());
473   for (float n : input_vec) {
474     input_vec_doubled.push_back(n * 2);
475   }
476   auto input_literal_doubled = LiteralUtil::CreateR1<float>(input_vec_doubled);
477 
478   EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[0]));
479   EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[1]));
480   EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[2]));
481   EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[3]));
482 }
483 
// All-reduce in which every replica forms its own single-member group, so each
// replica is expected to get back its own replica id unchanged.
XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) {
  static constexpr int kNumReplicas = 4;
  const char* const kHloText = R"(
      HloModule test

      apply_op {
        x = u32[] parameter(0)
        y = u32[] parameter(1)
        ROOT apply_op = u32[] add(x, y)
      }

      ENTRY test_computation {
        id = u32[] replica-id()
        ROOT crs = u32[] all-reduce(id), replica_groups={{0},{1},{2},{3}}, to_apply=apply_op
      }
    )";

  auto module_config = GetModuleConfigForTest();
  module_config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(kHloText, module_config));
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(hlo_module), {},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR0Equal<uint32_t>(replica, results[replica]);
  }
}
514 
// Sums the replica ids of four replicas through an all-reduce-start /
// all-reduce-done pair.  Disabled on the CPU backend.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduce)) {
  static constexpr int kNumReplicas = 4;
  const absl::string_view kHloText = R"(
      HloModule test

      apply_op {
        x = u32[] parameter(0)
        y = u32[] parameter(1)
        ROOT apply_op = u32[] add(x, y)
      }

      ENTRY test_computation {
        id = u32[] replica-id()
        start = u32[] all-reduce-start(id), to_apply=apply_op
        ROOT done = u32[] all-reduce-done(start)
      }
    )";

  HloModuleConfig module_config = GetModuleConfigForTest();
  module_config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(kHloText, module_config));
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  // 0 + 1 + 2 + 3: the sum of the replica ids in [0,4).
  const uint32_t id_sum = 6;
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR0Equal<uint32_t>(id_sum, results[replica]);
  }
}
546 
// Reduces two operands (the replica id and its square) through a single
// all-reduce-start / all-reduce-done pair on four replicas.  Disabled on the
// CPU backend.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduceTwoOperands)) {
  static constexpr int kNumReplicas = 4;
  const absl::string_view kHloText = R"(
      HloModule test

      apply_op {
        x = u32[] parameter(0)
        y = u32[] parameter(1)
        ROOT apply_op = u32[] add(x, y)
      }

      ENTRY test_computation {
        id = u32[] replica-id()
        id2 = u32[] multiply(id, id)
        start = (u32[], u32[]) all-reduce-start(id, id2), to_apply=apply_op
        ROOT done = (u32[], u32[]) all-reduce-done(start)
      }
    )";

  HloModuleConfig module_config = GetModuleConfigForTest();
  module_config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(kHloText, module_config));
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  const uint32_t id_sum = 6;           // 0 + 1 + 2 + 3
  const uint32_t id_square_sum = 14;   // 0 + 1 + 4 + 9
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    std::vector<Literal> outputs = results[replica].DecomposeTuple();
    LiteralTestUtil::ExpectR0Equal<uint32_t>(id_sum, outputs[0]);
    LiteralTestUtil::ExpectR0Equal<uint32_t>(id_square_sum, outputs[1]);
  }
}
582 
// Runs a module that returns replica-id() on four replicas and expects replica
// i to produce the value i.
XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    ROOT out = u32[] copy(id)
  }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  // Pass `config` to the parser so the module is actually built with the
  // replica count configured above; previously the config was set up and then
  // never used.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  for (uint32_t i = 0; i < kNumReplicas; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i]));
  }
}
607 
// Collective-permute in which replicas 0 and 1 swap, replica 2 sends to itself
// and replica 3 receives nothing.  Each replica sends {10 + id, 10 + id}.
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
  const int64_t kNumReplicas = 4;
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    replica = u32[] replica-id()
    ten = u32[] constant(10)
    sum = u32[] add(replica, ten)
    p = u32[2] broadcast(sum), dimensions={}
    permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}}
    ROOT copy = u32[2] copy(permute)
  }
  )";

  auto module_config = GetModuleConfigForTest();
  module_config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  // Value expected in both elements of each replica's output, indexed by
  // replica.  Nothing writes to replica 3, so it is memzero'ed.
  const uint32_t kExpected[] = {11, 10, 12, 0};
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    const uint32_t value = kExpected[replica];
    EXPECT_TRUE(LiteralTestUtil::Equal(
        LiteralUtil::CreateR1<uint32_t>({value, value}), results[replica]));
  }
}
641 
// Collective-permute in which every one of the four replicas sends to itself,
// so each replica is expected to get back its own {10 + id, 10 + id}.
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degnerate) {
  const int64_t kNumReplicas = 4;
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    replica = u32[] replica-id()
    ten = u32[] constant(10)
    sum = u32[] add(replica, ten)
    p = u32[2] broadcast(sum), dimensions={}
    permute = u32[2] collective-permute(p), source_target_pairs={{0,0}, {1,1}, {2,2}, {3,3}}
    ROOT copy = u32[2] copy(permute)
  }
  )";

  auto module_config = GetModuleConfigForTest();
  module_config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  // Value expected in both elements of each replica's output, indexed by
  // replica.
  const uint32_t kExpected[] = {10, 11, 12, 13};
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    const uint32_t value = kExpected[replica];
    EXPECT_TRUE(LiteralTestUtil::Equal(
        LiteralUtil::CreateR1<uint32_t>({value, value}), results[replica]));
  }
}
674 
// Collective-permute in which replicas 0, 1 and 2 send to themselves and
// replica 3 takes no part.  Each sender contributes {10 + id, 10 + id}.
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NoDegnerate) {
  const int64_t kNumReplicas = 4;
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    replica = u32[] replica-id()
    ten = u32[] constant(10)
    sum = u32[] add(replica, ten)
    p = u32[2] broadcast(sum), dimensions={}
    permute = u32[2] collective-permute(p), source_target_pairs={{0,0}, {1,1}, {2,2}}
    ROOT copy = u32[2] copy(permute)
  }
  )";

  auto module_config = GetModuleConfigForTest();
  module_config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  // Value expected in both elements of each replica's output, indexed by
  // replica.  Nothing writes to replica 3, so it is memzero'ed.
  const uint32_t kExpected[] = {10, 11, 12, 0};
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    const uint32_t value = kExpected[replica];
    EXPECT_TRUE(LiteralTestUtil::Equal(
        LiteralUtil::CreateR1<uint32_t>({value, value}), results[replica]));
  }
}
708 
// Collective-permute that rotates data by one position: replica i sends its
// {10 + i, 10 + i} to replica (i + 1) mod 4.
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) {
  const int64_t kNumReplicas = 4;
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    replica = u32[] replica-id()
    ten = u32[] constant(10)
    sum = u32[] add(replica, ten)
    p = u32[2] broadcast(sum), dimensions={}
    permute = u32[2] collective-permute(p), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
    ROOT copy = u32[2] copy(permute)
  }
  )";

  auto module_config = GetModuleConfigForTest();
  module_config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  // Value expected in both elements of each replica's output, indexed by
  // replica: each replica holds what its predecessor sent.
  const uint32_t kExpected[] = {13, 10, 11, 12};
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    const uint32_t value = kExpected[replica];
    EXPECT_TRUE(LiteralTestUtil::Equal(
        LiteralUtil::CreateR1<uint32_t>({value, value}), results[replica]));
  }
}
741 
// Tuple all-to-all over four replicas with an empty replica_groups attribute.
// Replica r feeds the operands {10,15}+r, {20,25}+r, {30,35}+r and {40,45}+r;
// the expectations require replica i to receive operand i of every replica,
// concatenated in increasing source-replica order.
XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[2] broadcast(id), dimensions={}
    a0 = u32[2] constant({10, 15})
    b0 = u32[2] constant({20, 25})
    c0 = u32[2] constant({30, 35})
    d0 = u32[2] constant({40, 45})
    a1 = u32[2] add(id2, a0)
    b1 = u32[2] add(id2, b0)
    c1 = u32[2] add(id2, c0)
    d1 = u32[2] add(id2, d0)
    all2all = (u32[2], u32[2], u32[2], u32[2]) all-to-all(a1, b1, c1, d1), replica_groups={}
    a_prime = u32[2] get-tuple-element(all2all), index=0
    b_prime = u32[2] get-tuple-element(all2all), index=1
    c_prime = u32[2] get-tuple-element(all2all), index=2
    d_prime = u32[2] get-tuple-element(all2all), index=3
    ROOT out = u32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
  }
  )";
  constexpr int64_t kNumReplicas = 4;
  auto module_config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> per_replica,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));
  ASSERT_EQ(per_replica.size(), kNumReplicas);

  // Row i holds the flattened u32[8] output expected on replica i.
  const std::vector<std::vector<uint32_t>> expected_by_replica = {
      {10, 15, 11, 16, 12, 17, 13, 18},
      {20, 25, 21, 26, 22, 27, 23, 28},
      {30, 35, 31, 36, 32, 37, 33, 38},
      {40, 45, 41, 46, 42, 47, 43, 48},
  };
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32_t>(expected_by_replica[replica],
                                             per_replica[replica]);
  }
}
782 
// Tuple all-to-all with a single replica group listed in reverse order,
// {3,2,1,0}. The expectations show that a replica's position inside the group
// (not its replica id) selects which operand it receives, and that the
// received pieces are ordered as the group lists its members: replica 0 sits
// at position 3 and therefore gets operand d1 from replicas 3, 2, 1, 0.
XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[2] broadcast(id), dimensions={}
    a0 = u32[2] constant({10, 15})
    b0 = u32[2] constant({20, 25})
    c0 = u32[2] constant({30, 35})
    d0 = u32[2] constant({40, 45})
    a1 = u32[2] add(id2, a0)
    b1 = u32[2] add(id2, b0)
    c1 = u32[2] add(id2, c0)
    d1 = u32[2] add(id2, d0)
    all2all = (u32[2], u32[2], u32[2], u32[2]) all-to-all(a1, b1, c1, d1), replica_groups={{3,2,1,0}}
    a_prime = u32[2] get-tuple-element(all2all), index=0
    b_prime = u32[2] get-tuple-element(all2all), index=1
    c_prime = u32[2] get-tuple-element(all2all), index=2
    d_prime = u32[2] get-tuple-element(all2all), index=3
    ROOT out = u32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
  }
  )";
  constexpr int64_t kNumReplicas = 4;
  auto module_config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> per_replica,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));
  ASSERT_EQ(per_replica.size(), kNumReplicas);

  // Row i holds the flattened u32[8] output expected on replica i.
  const std::vector<std::vector<uint32_t>> expected_by_replica = {
      {43, 48, 42, 47, 41, 46, 40, 45},
      {33, 38, 32, 37, 31, 36, 30, 35},
      {23, 28, 22, 27, 21, 26, 20, 25},
      {13, 18, 12, 17, 11, 16, 10, 15},
  };
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32_t>(expected_by_replica[replica],
                                             per_replica[replica]);
  }
}
823 
// Tuple all-to-all split into two independent groups, {2,1} and {3,0}, with
// two operands per replica. Per the expectations, the replica listed first in
// a group collects the `a1` operands of its group and the second one collects
// the `b1` operands, each ordered as the group lists its members.
XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[2] broadcast(id), dimensions={}
    a0 = u32[2] constant({10, 15})
    b0 = u32[2] constant({20, 25})
    a1 = u32[2] add(id2, a0)
    b1 = u32[2] add(id2, b0)
    all2all = (u32[2], u32[2]) all-to-all(a1, b1), replica_groups={{2,1},{3,0}}
    a_prime = u32[2] get-tuple-element(all2all), index=0
    b_prime = u32[2] get-tuple-element(all2all), index=1
    ROOT out = u32[4] concatenate(a_prime, b_prime), dimensions={0}
  }
  )";
  constexpr int64_t kNumReplicas = 4;
  auto module_config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> per_replica,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));
  ASSERT_EQ(per_replica.size(), kNumReplicas);

  // Row i holds the flattened u32[4] output expected on replica i.
  const std::vector<std::vector<uint32_t>> expected_by_replica = {
      {23, 28, 20, 25},
      {22, 27, 21, 26},
      {12, 17, 11, 16},
      {13, 18, 10, 15},
  };
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32_t>(expected_by_replica[replica],
                                             per_replica[replica]);
  }
}
854 
// Array (non-tuple) all-to-all that splits dimension 0 of a u32[4,2] operand
// across one group of four replicas. Replica r holds rows {10,15}+r,
// {20,25}+r, {30,35}+r, {40,45}+r; the expectations require replica i to end
// up with row i of every replica, stacked in source-replica order.
// Wrapped in DISABLED_ON_CPU, as in the original test declaration.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[4, 2] broadcast(id), dimensions={}
    a0 = u32[4, 2] constant({{10, 15}, {20, 25}, {30, 35}, {40, 45}})
    a1 = u32[4, 2] add(id2, a0)
    all2all = u32[4, 2] all-to-all(a1), replica_groups={{0,1,2,3}}, dimensions={0}
    ROOT out = u32[8] reshape(all2all)
  }
  )";
  constexpr int64_t kNumReplicas = 4;
  auto module_config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> per_replica,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true));
  ASSERT_EQ(per_replica.size(), kNumReplicas);

  // Row i holds the flattened u32[8] output expected on replica i.
  const std::vector<std::vector<uint32_t>> expected_by_replica = {
      {10, 15, 11, 16, 12, 17, 13, 18},
      {20, 25, 21, 26, 22, 27, 23, 28},
      {30, 35, 31, 36, 32, 37, 33, 38},
      {40, 45, 41, 46, 42, 47, 43, 48},
  };
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32_t>(expected_by_replica[replica],
                                             per_replica[replica]);
  }
}
885 
// All-gather along dimension 0. Replica r contributes the single row
// {10 + r, 15 + r}; the gathered u32[4,2] is flattened to u32[8], and every
// replica must observe the same result with rows in replica order.
XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[1, 2] broadcast(id), dimensions={}
    a0 = u32[1, 2] constant({{10, 15}})
    a1 = u32[1, 2] add(id2, a0)
    allgather = u32[4, 2] all-gather(a1), dimensions={0}
    ROOT out = u32[8] reshape(allgather)
  }
  )";
  constexpr int64_t kNumReplicas = 4;
  auto module_config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> per_replica,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  ASSERT_EQ(per_replica.size(), kNumReplicas);

  // The gathered value is identical on all replicas.
  const std::vector<uint32_t> gathered = {10, 15, 11, 16, 12, 17, 13, 18};
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32_t>(gathered, per_replica[replica]);
  }
}
913 
// All-gather along dimension 1. Replica r contributes the column
// {{10 + r}, {15 + r}}; the gathered u32[2,4] is flattened to u32[8], so the
// four "10 + r" values come first, followed by the four "15 + r" values.
// Every replica must observe the same result.
XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[2, 1] broadcast(id), dimensions={}
    a0 = u32[2, 1] constant({{10}, {15}})
    a1 = u32[2, 1] add(id2, a0)
    allgather = u32[2, 4] all-gather(a1), dimensions={1}
    ROOT out = u32[8] reshape(allgather)
  }
  )";
  constexpr int64_t kNumReplicas = 4;
  auto module_config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kHloText, module_config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> per_replica,
      ExecuteReplicated(std::move(hlo_module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  ASSERT_EQ(per_replica.size(), kNumReplicas);

  // The gathered value is identical on all replicas.
  const std::vector<uint32_t> gathered = {10, 11, 12, 13, 15, 16, 17, 18};
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32_t>(gathered, per_replica[replica]);
  }
}
941 
// Variadic (tuple-shaped) sum all-reduce over two operands of different
// sizes, f32[5] and f32[7], on two replicas. The expected outputs are the
// inputs doubled, i.e. the element-wise sum over both replicas.
// NOTE(review): this presumes ExecuteReplicated hands the same argument
// literals to every replica -- confirm against HloTestBase.
XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
  const std::string hlo_string = R"(
    HloModule test

    apply_op {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT apply_op = f32[] add(x, y)
    }

    ENTRY test_computation {
      p0 = f32[5] parameter(0)
      p1 = f32[7] parameter(1)
      ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
    }
  )";
  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string, config));

  const std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
  const std::vector<float> input1_vec = {
      7., 3., 4., 1., 2., 3., 4.,
  };
  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));
  // Fail cleanly instead of indexing past the end if a replica's output is
  // missing; the sibling tests in this file make the same check.
  ASSERT_EQ(results.size(), kNumReplicas);

  const std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
  const std::vector<float> expected1_vec = {14., 6., 8., 2., 4., 6., 8.};
  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    // The root is a two-element tuple; guard the rs[0]/rs[1] accesses below.
    ASSERT_EQ(rs.size(), 2u);
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
                                             ErrorSpec{1e-5, 1e-5}));
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
                                             ErrorSpec{1e-5, 1e-5}));
  }
}
988 
// Variadic all-gather over operands of different element types (u32 and f32)
// on two replicas. Each replica contributes a [2,1] column filled with its
// replica id (once as u32, once converted to f32); gathering along
// dimension 1 yields {{0, 1}, {0, 1}} for both tuple elements on every
// replica. Wrapped in DISABLED_ON_CPU, as in the original test declaration.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGatherMixedTypes)) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    p0 = u32[2, 1] broadcast(id), dimensions={}
    p1 = f32[2, 1] convert(p0)
    allgather = (u32[2, 2], f32[2, 2]) all-gather(p0, p1), dimensions={1}
    ag0 = u32[2, 2] get-tuple-element(allgather), index=0
    ag1 = f32[2, 2] get-tuple-element(allgather), index=1
    r0 = u32[4] reshape(ag0)
    r1 = f32[4] reshape(ag1)
    ROOT out = (u32[4], f32[4]) tuple(r0, r1)
  }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of indexing past the end if a replica's output is
  // missing; the sibling tests in this file make the same check.
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    // The root is a two-element tuple; guard the rs[0]/rs[1] accesses below.
    ASSERT_EQ(rs.size(), 2u);
    LiteralTestUtil::ExpectR1Equal<uint32_t>({0, 1, 0, 1}, rs[0]);
    LiteralTestUtil::ExpectR1Near<float>({0.0, 1.0, 0.0, 1.0}, rs[1],
                                         ErrorSpec{1e-5, 1e-5});
  }
}
1020 
// Sum reduce-scatter along dimension 0 on two replicas. Replica 0 supplies
// c0 and replica 1 supplies c1; their element-wise sum is
// {11, 13, 15, 17, 19, 21, 23, 25}, of which replica 0 must receive the first
// half and replica 1 the second half.
XLA_TEST_F(CollectiveOpsTest, ReduceScatter) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    lhs = u32[] parameter(0)
    rhs = u32[] parameter(1)
    ROOT add = u32[] add(lhs, rhs)
  }

  ENTRY main {
    c0 = u32[8] constant({1, 2, 3, 4, 5, 6, 7, 8})
    c1 = u32[8] constant({10, 11, 12, 13, 14, 15, 16, 17})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[8] broadcast(p), dimensions={}
    // data = c0 for replica 0 and c1 for replica 1
    data = u32[8] select(pb, c0, c1)
    ROOT ars = u32[4] reduce-scatter(data), replica_groups={},
                      dimensions={0}, to_apply=add
  }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of indexing past the end if a replica's output is
  // missing; the sibling tests in this file make the same check.
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({11, 13, 15, 17}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({19, 21, 23, 25}, results[1]);
}
1056 
// Sum reduce-scatter along dimension 1 of a u32[2,4] operand on two replicas.
// The reduced matrix is {{11, 13, 15, 17}, {19, 21, 23, 25}}; replica 0 must
// receive its left two columns and replica 1 its right two columns, each
// flattened row-major to u32[4].
XLA_TEST_F(CollectiveOpsTest, ReduceScatter_Dim1) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    lhs = u32[] parameter(0)
    rhs = u32[] parameter(1)
    ROOT add = u32[] add(lhs, rhs)
  }

  ENTRY main {
    c0 = u32[2, 4] constant({{ 1,  2,  3,  4}, { 5,  6,  7,  8}})
    c1 = u32[2, 4] constant({{10, 11, 12, 13}, {14, 15, 16, 17}})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[2, 4] broadcast(p), dimensions={}
    // data = c0 for replica 0 and c1 for replica 1
    data = u32[2, 4] select(pb, c0, c1)
    // all-reduce result = {{11, 13, 15, 17}, {19, 21, 23, 25}}
    ars = u32[2, 2] reduce-scatter(data), replica_groups={},
                    dimensions={1}, to_apply=add
    ROOT r = u32[4] reshape(ars)
  }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of indexing past the end if a replica's output is
  // missing; the sibling tests in this file make the same check.
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({11, 13, 19, 21}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32_t>({15, 17, 23, 25}, results[1]);
}
1094 
// Adds the results of two independent sum all-reduces on two replicas, with
// HLO passes enabled. Both replicas must agree, and element i of the result
// must equal c0[i] + c1[i] + c2[i] + c3[i].
// NOTE(review): judging by the test name this is meant to cover the
// optimization that reassociates add(all-reduce, all-reduce); nothing in the
// test itself checks that the rewrite fired -- confirm if that matters.
// Wrapped in DISABLED_ON_CPU, as in the original test declaration.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduceReassociate)) {
  const char* const kModuleStr = R"(
  HloModule m
  sum {
    a = f32[] parameter(0)
    b = f32[] parameter(1)
    ROOT add.2 = f32[] add(a, b)
  }

  ENTRY main {
    c0 = f32[8] constant({  1,  2,  3,  4,  5,  6,  7,  8})
    c1 = f32[8] constant({ 11, 12, 13, 14, 15, 16, 17, 18})
    c2 = f32[8] constant({  2,  3,  4,  5,  6,  7,  8,  9})
    c3 = f32[8] constant({ 12, 13, 14, 15, 16, 17, 18, 19})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[8] broadcast(p), dimensions={}
    // data0 = c0 for replica 0 and c1 for replica 1
    data0 = f32[8] select(pb, c0, c1)
    // data1 = c2 for replica 0 and c3 for replica 1
    data1 = f32[8] select(pb, c2, c3)

    ar0 = f32[8] all-reduce(data0), replica_groups={}, to_apply=sum
    ar1 = f32[8] all-reduce(data1), replica_groups={}, to_apply=sum
    ROOT add = f32[8] add(ar0, ar1)
  }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of indexing past the end if a replica's output is
  // missing; the sibling tests in this file make the same check.
  ASSERT_EQ(results.size(), kNumReplicas);

  const ErrorSpec es{1e-5, 1e-5};
  // Both replicas hold the fully reduced value, so they must match each other
  // as well as the expected sums.
  EXPECT_TRUE(LiteralTestUtil::NearOrEqual(results[0], results[1], es));
  LiteralTestUtil::ExpectR1Near<float>(
      {26.0, 30.0, 34.0, 38.0, 42.0, 46.0, 50.0, 54.0}, results[0], es);
}
1138 
// All-gather of a broadcast on two replicas, gathering along dimension 2.
// The per-replica u32[2,3] data (c0 on replica 0, c1 on replica 1) is
// broadcast to u32[2,4,3], where dimension 1 is the one the broadcast adds,
// and then gathered along dimension 2, which carries real data. Each of the
// four repeated rows therefore reads "<c0 row>, <c1 row>", identically on
// both replicas.
// NOTE(review): the name suggests this targets an all-gather/broadcast
// reordering optimization; the test only checks the numeric result.
// Wrapped in DISABLED_ON_CPU, as in the original test declaration.
XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllGatherBroadcastReorder_NonUniform)) {
  const char* const kModuleStr = R"(
  HloModule m

  ENTRY main {
    c0 = u32[2, 3] constant({{ 1,  2,  3}, { 4, 5, 6}})
    c1 = u32[2, 3] constant({{10, 11, 12}, {13, 14, 15}})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[2, 3] broadcast(p), dimensions={}
    // data = c0 for replica 0 and c1 for replica 1
    data = u32[2, 3] select(pb, c0, c1)
    bc = u32[2, 4, 3] broadcast(data), dimensions={0, 2}
    ROOT ag = u32[2, 4, 6] all-gather(bc), dimensions={2}, replica_groups={{0, 1}}
  }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of indexing past the end if a replica's output is
  // missing; the sibling tests in this file make the same check.
  ASSERT_EQ(results.size(), kNumReplicas);

  EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1]));
  LiteralTestUtil::ExpectR3Equal<uint32_t>({{{1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12}},
                                            {{4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15}}},
                                           results[0]);
}
1179 
// All-gather of a broadcast on two replicas, gathering along dimension 1.
// The per-replica u32[2,3] data (c0 on replica 0, c1 on replica 1) is
// broadcast to u32[2,4,3], where dimension 1 is the one the broadcast adds,
// and then gathered along that same dimension. Each outer slice therefore
// holds four copies of the c0 row followed by four copies of the c1 row,
// identically on both replicas.
// NOTE(review): the name suggests this targets an all-gather/broadcast
// reordering optimization; the test only checks the numeric result.
// Wrapped in DISABLED_ON_CPU, as in the original test declaration.
XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllGatherBroadcastReorder_Uniform)) {
  const char* const kModuleStr = R"(
  HloModule m

  ENTRY main {
    c0 = u32[2, 3] constant({{ 1,  2,  3}, { 4, 5, 6}})
    c1 = u32[2, 3] constant({{10, 11, 12}, {13, 14, 15}})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[2, 3] broadcast(p), dimensions={}
    // data = c0 for replica 0 and c1 for replica 1
    data = u32[2, 3] select(pb, c0, c1)
    bc = u32[2, 4, 3] broadcast(data), dimensions={0, 2}
    ROOT ag = u32[2, 8, 3] all-gather(bc), dimensions={1}, replica_groups={{0, 1}}
  }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of indexing past the end if a replica's output is
  // missing; the sibling tests in this file make the same check.
  ASSERT_EQ(results.size(), kNumReplicas);

  EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1]));
  LiteralTestUtil::ExpectR3Equal<uint32_t>({{{1, 2, 3},
                                             {1, 2, 3},
                                             {1, 2, 3},
                                             {1, 2, 3},
                                             {10, 11, 12},
                                             {10, 11, 12},
                                             {10, 11, 12},
                                             {10, 11, 12}},
                                            {{4, 5, 6},
                                             {4, 5, 6},
                                             {4, 5, 6},
                                             {4, 5, 6},
                                             {13, 14, 15},
                                             {13, 14, 15},
                                             {13, 14, 15},
                                             {13, 14, 15}}},
                                           results[0]);
}
1227 
1228 }  // namespace
1229 }  // namespace xla
1230