1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/all_reduce_blueconnect.h"
17
18 #include <memory>
19
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/tests/test_utils.h"
26 #include "tensorflow/core/platform/status_matchers.h"
27
28 namespace xla {
29 namespace {
30
31 using ::tensorflow::testing::IsOkAndHolds;
32 using ::testing::AllOf;
33 namespace op = xla::testing::opcode_matchers;
34
35 using AllReduceBlueConnectTest = HloTestBase;
36
SetModuleConfig(HloModule & module,size_t replica_count)37 void SetModuleConfig(HloModule& module, size_t replica_count) {
38 DeviceAssignment device_assignment(replica_count, /*computation_count=*/1);
39 device_assignment.FillIota(0);
40 module.config().set_replica_count(replica_count);
41 module.config().set_static_device_assignment(device_assignment);
42 }
43
// A single-operand all-reduce without explicit replica_groups, run with 8
// replicas and num_devices_per_host=4. The pass must report a change, and the
// rewritten entry computation is expected to be
//   bitcast -> reduce-scatter -> all-reduce -> all-gather -> bitcast
// where the reduce-scatter/all-gather use groups of four replicas and the
// inner all-reduce uses groups of two.
TEST_F(AllReduceBlueConnectTest, OneStage) {
  constexpr absl::string_view kHloText = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  SetModuleConfig(*hlo_module, /*replica_count=*/8);

  AllReduceBlueConnect blueconnect(/*num_devices_per_host=*/4);
  EXPECT_THAT(blueconnect.Run(hlo_module.get()), IsOkAndHolds(true));

  using Groups = std::vector<std::vector<int64_t>>;
  // clang-format off
  const Groups rs_ag_groups = {{0, 1, 2, 3}, {4, 5, 6, 7}};
  const Groups ar_groups = {{0, 4}, {1, 5}, {2, 6}, {3, 7}};
  // clang-format on

  // Build the expected instruction chain from the parameter outwards.
  auto flattened_param =
      AllOf(op::Shape("f32[16]"), op::Bitcast(op::Parameter(0)));
  auto scattered =
      AllOf(op::Shape("f32[4]"), op::ReduceScatter(flattened_param),
            op::ReplicaGroups(rs_ag_groups));
  auto reduced = AllOf(op::Shape("f32[4]"), op::AllReduce(scattered),
                       op::ReplicaGroups(ar_groups));
  auto gathered = AllOf(op::Shape("f32[16]"), op::AllGather(reduced),
                        op::ReplicaGroups(rs_ag_groups));

  // The root restores the original f32[4,4] shape.
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[4,4]"), op::Bitcast(gathered)));
}
82
// The same all-reduce as in OneStage, but with 16 replicas and
// num_devices_per_host=4. The pass must report a change, and the expected
// rewrite nests two reduce-scatter/all-gather levels (groups of four, then
// groups of two) around an all-reduce over groups of two, with bitcasts
// between the levels.
TEST_F(AllReduceBlueConnectTest, TwoStage) {
  constexpr absl::string_view kHloText = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  ROOT crs = f32[4,4] all-reduce(p0), to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  SetModuleConfig(*hlo_module, /*replica_count=*/16);

  AllReduceBlueConnect blueconnect(/*num_devices_per_host=*/4);
  EXPECT_THAT(blueconnect.Run(hlo_module.get()), IsOkAndHolds(true));

  using Groups = std::vector<std::vector<int64_t>>;
  const Groups outer_groups = {
      {0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}, {12, 13, 14, 15}};
  const Groups inner_groups = {
      {0, 4}, {8, 12}, {1, 5}, {9, 13}, {2, 6}, {10, 14}, {3, 7}, {11, 15}};
  const Groups ar_groups = {
      {0, 8}, {4, 12}, {1, 9}, {5, 13}, {2, 10}, {6, 14}, {3, 11}, {7, 15}};

  // Scatter phase: f32[16] -> f32[4] (outer) -> f32[2] (inner).
  auto flattened_param =
      AllOf(op::Shape("f32[16]"), op::Bitcast(op::Parameter(0)));
  auto outer_scatter =
      AllOf(op::Shape("f32[4]"), op::ReduceScatter(flattened_param),
            op::ReplicaGroups(outer_groups));
  auto outer_scatter_bitcast =
      AllOf(op::Shape("f32[4]"), op::Bitcast(outer_scatter));
  auto inner_scatter =
      AllOf(op::Shape("f32[2]"), op::ReduceScatter(outer_scatter_bitcast),
            op::ReplicaGroups(inner_groups));

  // All-reduce on the fully scattered f32[2] shard.
  auto reduced = AllOf(op::Shape("f32[2]"), op::AllReduce(inner_scatter),
                       op::ReplicaGroups(ar_groups));

  // Gather phase mirrors the scatter phase: f32[2] -> f32[4] -> f32[16].
  auto inner_gather = AllOf(op::Shape("f32[4]"), op::AllGather(reduced),
                            op::ReplicaGroups(inner_groups));
  auto inner_gather_bitcast =
      AllOf(op::Shape("f32[4]"), op::Bitcast(inner_gather));
  auto outer_gather =
      AllOf(op::Shape("f32[16]"), op::AllGather(inner_gather_bitcast),
            op::ReplicaGroups(outer_groups));

  // The root restores the original f32[4,4] shape.
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[4,4]"), op::Bitcast(outer_gather)));
}
127
// A variadic all-reduce over two operands (f32[4,4] and f32[4,4,2]) with 8
// replicas and num_devices_per_host=4. The pass must report a change, and the
// expected rewrite applies the reduce-scatter / all-reduce / all-gather chain
// to both operands as tuples, then bitcasts each tuple element back to its
// original shape under a root tuple.
TEST_F(AllReduceBlueConnectTest, TwoOperands) {
  constexpr absl::string_view kHloText = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  p1 = f32[4,4,2] parameter(1)
  ROOT crs = (f32[4,4], f32[4,4,2]) all-reduce(p0, p1), to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  SetModuleConfig(*hlo_module, /*replica_count=*/8);

  AllReduceBlueConnect blueconnect(/*num_devices_per_host=*/4);
  EXPECT_THAT(blueconnect.Run(hlo_module.get()), IsOkAndHolds(true));

  using Groups = std::vector<std::vector<int64_t>>;
  // clang-format off
  const Groups rs_ag_groups = {{0, 1, 2, 3}, {4, 5, 6, 7}};
  const Groups ar_groups = {{0, 4}, {1, 5}, {2, 6}, {3, 7}};
  // clang-format on

  // Each parameter is flattened to rank 1 before entering the collectives.
  auto flat_p0 = AllOf(op::Shape("f32[16]"), op::Bitcast(op::Parameter(0)));
  auto flat_p1 = AllOf(op::Shape("f32[32]"), op::Bitcast(op::Parameter(1)));

  // Tuple-shaped collectives; consumers read them via get-tuple-element.
  auto scattered = AllOf(op::Shape("(f32[4], f32[8])"),
                         op::ReduceScatter(flat_p0, flat_p1),
                         op::ReplicaGroups(rs_ag_groups));
  auto reduced = AllOf(op::Shape("(f32[4], f32[8])"),
                       op::AllReduce(op::GetTupleElement(scattered, 0),
                                     op::GetTupleElement(scattered, 1)),
                       op::ReplicaGroups(ar_groups));
  auto gathered = AllOf(op::Shape("(f32[16], f32[32])"),
                        op::AllGather(op::GetTupleElement(reduced, 0),
                                      op::GetTupleElement(reduced, 1)),
                        op::ReplicaGroups(rs_ag_groups));

  // Each result element is bitcast back to its operand's original shape.
  auto result0 = AllOf(op::Shape("f32[4,4]"),
                       op::Bitcast(op::GetTupleElement(gathered, 0)));
  auto result1 = AllOf(op::Shape("f32[4,4,2]"),
                       op::Bitcast(op::GetTupleElement(gathered, 1)));

  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(result0, result1));
}
177
// Negative case: the all-reduce uses replica groups {0,1,2,7} and {3,4,5,6}
// with 8 replicas and num_devices_per_host=4, and the pass is expected to
// succeed without changing the module.
// NOTE(review): presumably rejected because each group splits 3 + 1 across
// the replica ranges 0-3 and 4-7 (i.e. across hosts) -- confirm against the
// pass implementation.
TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesWithinReplicaGroup) {
  constexpr absl::string_view kHloText = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  ROOT crs = f32[4,4] all-reduce(p0),
    replica_groups={{0,1,2,7},{3,4,5,6}}, to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  SetModuleConfig(*hlo_module, /*replica_count=*/8);

  AllReduceBlueConnect blueconnect(/*num_devices_per_host=*/4);
  // `false` means the pass ran successfully but reported no change.
  EXPECT_THAT(blueconnect.Run(hlo_module.get()), IsOkAndHolds(false));
}
200
// Negative case: with 16 replicas and num_devices_per_host=4, the all-reduce
// uses groups {0,1,4,5}, {2,3,6,7}, {8,9,10,11} and {12,13,14,15}, and the
// pass is expected to succeed without changing the module.
// NOTE(review): presumably rejected because the first two groups take two
// replicas from each of the ranges 0-3 and 4-7 while the last two groups each
// sit inside a single range of four -- confirm against the pass
// implementation.
TEST_F(AllReduceBlueConnectTest, DifferentNumLocalDevicesAcrossReplicaGroups) {
  constexpr absl::string_view kHloText = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  ROOT crs = f32[4,4] all-reduce(p0),
    replica_groups={{0,1,4,5},{2,3,6,7},{8,9,10,11},{12,13,14,15}}, to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  SetModuleConfig(*hlo_module, /*replica_count=*/16);

  AllReduceBlueConnect blueconnect(/*num_devices_per_host=*/4);
  // `false` means the pass ran successfully but reported no change.
  EXPECT_THAT(blueconnect.Run(hlo_module.get()), IsOkAndHolds(false));
}
223
// Negative case: a two-operand all-reduce whose second operand is f32[9],
// with 8 replicas and num_devices_per_host=4. The pass is expected to succeed
// without changing the module.
// NOTE(review): 9 elements cannot be split evenly four ways, which is
// presumably why the rewrite is skipped -- confirm against the pass
// implementation.
TEST_F(AllReduceBlueConnectTest, OperandIndivisible) {
  constexpr absl::string_view kHloText = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[4,4] parameter(0)
  p1 = f32[9] parameter(1)
  ROOT crs = (f32[4,4], f32[9]) all-reduce(p0, p1), to_apply=add
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  SetModuleConfig(*hlo_module, /*replica_count=*/8);

  AllReduceBlueConnect blueconnect(/*num_devices_per_host=*/4);
  // `false` means the pass ran successfully but reported no change.
  EXPECT_THAT(blueconnect.Run(hlo_module.get()), IsOkAndHolds(false));
}
246
247 } // namespace
248 } // namespace xla
249