xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/all_reduce_blueconnect_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/all_reduce_blueconnect.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/tests/test_utils.h"
26 #include "tensorflow/core/platform/status_matchers.h"
27 
28 namespace xla {
29 namespace {
30 
31 using ::tensorflow::testing::IsOkAndHolds;
32 using ::testing::AllOf;
33 namespace op = xla::testing::opcode_matchers;
34 
35 using AllReduceBlueConnectTest = HloTestBase;
36 
SetModuleConfig(HloModule & module,size_t replica_count)37 void SetModuleConfig(HloModule& module, size_t replica_count) {
38   DeviceAssignment device_assignment(replica_count, /*computation_count=*/1);
39   device_assignment.FillIota(0);
40   module.config().set_replica_count(replica_count);
41   module.config().set_static_device_assignment(device_assignment);
42 }
43 
// Eight replicas with four devices per host. The single all-reduce is expected
// to become: bitcast to rank 1 -> reduce-scatter over {0..3} and {4..7} ->
// all-reduce over {0,4}, {1,5}, {2,6}, {3,7} -> all-gather over {0..3} and
// {4..7} -> bitcast back to the original f32[4,4] shape.
TEST_F(AllReduceBlueConnectTest, OneStage) {
  constexpr absl::string_view hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  SetModuleConfig(*module, /*replica_count=*/8);

  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
  // Fatal check: the structural matchers below only make sense if the pass
  // succeeded and actually rewrote the module.
  ASSERT_THAT(pass.Run(module.get()), IsOkAndHolds(true));

  // clang-format off
  std::vector<std::vector<int64_t>> scatter_gather_groups = {
      {0, 1, 2, 3}, {4, 5, 6, 7}};
  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
      {0, 4}, {1, 5}, {2, 6}, {3, 7}};
  // clang-format on

  auto bitcast = AllOf(op::Shape("f32[16]"), op::Bitcast(op::Parameter(0)));
  auto reduce_scatter = AllOf(op::Shape("f32[4]"), op::ReduceScatter(bitcast),
                              op::ReplicaGroups(scatter_gather_groups));
  auto all_reduce = AllOf(op::Shape("f32[4]"), op::AllReduce(reduce_scatter),
                          op::ReplicaGroups(new_all_reduce_groups));
  auto all_gather = AllOf(op::Shape("f32[16]"), op::AllGather(all_reduce),
                          op::ReplicaGroups(scatter_gather_groups));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,4]"), op::Bitcast(all_gather)));
}
82 
// Sixteen replicas with four devices per host. The rewrite nests two
// reduce-scatter / all-gather stages around a smaller all-reduce:
//   outer stage: groups {0..3}, {4..7}, {8..11}, {12..15}   (f32[16] -> f32[4])
//   inner stage: pairs such as {0,4}, {8,12}                (f32[4]  -> f32[2])
//   all-reduce:  pairs such as {0,8}, {4,12}                (f32[2])
// and the gathers undo the scatters in reverse order.
TEST_F(AllReduceBlueConnectTest, TwoStage) {
  constexpr absl::string_view hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  SetModuleConfig(*module, /*replica_count=*/16);

  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
  // Fatal check: the structural matchers below only make sense if the pass
  // succeeded and actually rewrote the module.
  ASSERT_THAT(pass.Run(module.get()), IsOkAndHolds(true));

  // clang-format off
  std::vector<std::vector<int64_t>> outer_scatter_gather_groups = {
      {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}};
  std::vector<std::vector<int64_t>> inner_scatter_gather_groups = {
      {0, 4}, {8, 12}, {1, 5}, {9, 13}, {2, 6}, {10, 14}, {3, 7}, {11, 15}};
  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
      {0, 8}, {4, 12}, {1, 9}, {5, 13}, {2, 10}, {6, 14}, {3, 11}, {7, 15}};
  // clang-format on

  auto bitcast0 = AllOf(op::Shape("f32[16]"), op::Bitcast(op::Parameter(0)));
  auto reduce_scatter0 = AllOf(op::Shape("f32[4]"), op::ReduceScatter(bitcast0),
                               op::ReplicaGroups(outer_scatter_gather_groups));
  auto bitcast1 = AllOf(op::Shape("f32[4]"), op::Bitcast(reduce_scatter0));
  auto reduce_scatter1 = AllOf(op::Shape("f32[2]"), op::ReduceScatter(bitcast1),
                               op::ReplicaGroups(inner_scatter_gather_groups));
  auto all_reduce = AllOf(op::Shape("f32[2]"), op::AllReduce(reduce_scatter1),
                          op::ReplicaGroups(new_all_reduce_groups));
  auto all_gather0 = AllOf(op::Shape("f32[4]"), op::AllGather(all_reduce),
                           op::ReplicaGroups(inner_scatter_gather_groups));
  auto bitcast2 = AllOf(op::Shape("f32[4]"), op::Bitcast(all_gather0));
  auto all_gather1 = AllOf(op::Shape("f32[16]"), op::AllGather(bitcast2),
                           op::ReplicaGroups(outer_scatter_gather_groups));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,4]"), op::Bitcast(all_gather1)));
}
127 
// A variadic all-reduce over two operands. Both operands are flattened to
// rank 1 and pushed through one shared reduce-scatter / all-reduce /
// all-gather chain (tuple-shaped at every step). The root tuple is rebuilt
// from bitcasts of the gathered elements back to the original shapes.
TEST_F(AllReduceBlueConnectTest, TwoOperands) {
  constexpr absl::string_view hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  p1 = f32[4,4,2] parameter(1)
  ROOT crs = (f32[4,4], f32[4,4,2]) all-reduce(p0, p1), to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  SetModuleConfig(*module, /*replica_count=*/8);

  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
  // Fatal check: the structural matchers below only make sense if the pass
  // succeeded and actually rewrote the module.
  ASSERT_THAT(pass.Run(module.get()), IsOkAndHolds(true));

  // clang-format off
  std::vector<std::vector<int64_t>> scatter_gather_groups = {
      {0, 1, 2, 3}, {4, 5, 6, 7}};
  std::vector<std::vector<int64_t>> new_all_reduce_groups = {
      {0, 4}, {1, 5}, {2, 6}, {3, 7}};
  // clang-format on

  auto bitcast0 = AllOf(op::Shape("f32[16]"), op::Bitcast(op::Parameter(0)));
  auto bitcast1 = AllOf(op::Shape("f32[32]"), op::Bitcast(op::Parameter(1)));
  auto reduce_scatter = AllOf(op::Shape("(f32[4], f32[8])"),
                              op::ReduceScatter(bitcast0, bitcast1),
                              op::ReplicaGroups(scatter_gather_groups));
  auto all_reduce = AllOf(op::Shape("(f32[4], f32[8])"),
                          op::AllReduce(op::GetTupleElement(reduce_scatter, 0),
                                        op::GetTupleElement(reduce_scatter, 1)),
                          op::ReplicaGroups(new_all_reduce_groups));
  auto all_gather = AllOf(op::Shape("(f32[16], f32[32])"),
                          op::AllGather(op::GetTupleElement(all_reduce, 0),
                                        op::GetTupleElement(all_reduce, 1)),
                          op::ReplicaGroups(scatter_gather_groups));
  auto bitcast2 = AllOf(op::Shape("f32[4,4]"),
                        op::Bitcast(op::GetTupleElement(all_gather, 0)));
  auto bitcast3 = AllOf(op::Shape("f32[4,4,2]"),
                        op::Bitcast(op::GetTupleElement(all_gather, 1)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(bitcast2, bitcast3));
}
177 
// With four devices per host and an iota device assignment, replica group
// {0,1,2,7} has three members on one host and one on another. The pass is
// expected to decline the rewrite and leave the all-reduce untouched.
TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesWithinReplicaGroup) {
  constexpr absl::string_view hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  ROOT crs = f32[4,4] all-reduce(p0),
    replica_groups={{0,1,2,7},{3,4,5,6}}, to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  SetModuleConfig(*module, /*replica_count=*/8);

  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
  // A `false` result must also mean the graph was not modified.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::AllReduce(op::Parameter(0)));
}
200 
// With four devices per host and an iota device assignment, groups {0,1,4,5}
// and {2,3,6,7} each span two hosts (two members per host). Groups
// {8,9,10,11} and {12,13,14,15} each sit on a single host. The per-host member
// count therefore differs between replica groups, and the pass is expected
// to leave the all-reduce untouched.
TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesAcrossReplicaGroups) {
  constexpr absl::string_view hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  ROOT crs = f32[4,4] all-reduce(p0),
    replica_groups={{0,1,4,5},{2,3,6,7},{8,9,10,11},{12,13,14,15}}, to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  SetModuleConfig(*module, /*replica_count=*/16);

  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
  // A `false` result must also mean the graph was not modified.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::AllReduce(op::Parameter(0)));
}
223 
// The second operand has 9 elements, which cannot be split evenly into the
// 4-way scatter groups used with four devices per host. The pass is expected
// to skip the whole (variadic) all-reduce and leave it untouched.
TEST_F(AllReduceBlueConnectTest, OperandIndivisible) {
  constexpr absl::string_view hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  p1 = f32[9] parameter(1)
  ROOT crs = (f32[4,4], f32[9]) all-reduce(p0, p1), to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  SetModuleConfig(*module, /*replica_count=*/8);

  AllReduceBlueConnect pass(/*num_devices_per_host=*/4);
  EXPECT_THAT(pass.Run(module.get()), IsOkAndHolds(false));
  // A `false` result must also mean the graph was not modified: the divisible
  // operand p0 must not have been rewritten on its own either.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::AllReduce(op::Parameter(0), op::Parameter(1)));
}
246 
247 }  // namespace
248 }  // namespace xla
249