1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/all_reduce_reassociate.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22
23 namespace xla {
24 namespace {
25
26 namespace m = xla::testing::opcode_matchers;
27
28 class AllReduceSimplifierTest : public HloTestBase {
29 public:
RunPass(absl::string_view hlo_module,bool expect_change)30 StatusOr<std::unique_ptr<HloModule>> RunPass(absl::string_view hlo_module,
31 bool expect_change) {
32 TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module));
33 auto changed = AllReduceReassociate().Run(module.get());
34 if (!changed.ok()) {
35 return changed.status();
36 }
37 EXPECT_EQ(changed.ValueOrDie(), expect_change);
38 return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
39 }
40
AllReduceCount(std::unique_ptr<HloModule> & module)41 size_t AllReduceCount(std::unique_ptr<HloModule>& module) {
42 return absl::c_count_if(module->entry_computation()->instructions(),
43 [](const HloInstruction* inst) {
44 return inst->opcode() == HloOpcode::kAllReduce;
45 });
46 }
47 };
48
// add(all-reduce(p0), all-reduce(p1)) with matching sum all-reduces is
// expected to become a single all-reduce(add(p0, p1)).
TEST_F(AllReduceSimplifierTest, Simple) {
  constexpr char kHloModule[] = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto reassociated,
                          RunPass(kHloModule, /*expect_change=*/true));
  const HloInstruction* root =
      reassociated->entry_computation()->root_instruction();
  EXPECT_THAT(root, m::AllReduce(m::Add(m::Parameter(0), m::Parameter(1))));
  EXPECT_EQ(AllReduceCount(reassociated), 1);
}
73
// Both all-reduces carry the same channel_id; the pair is still expected to be
// merged into one all-reduce of the summed operands.
TEST_F(AllReduceSimplifierTest, SimpleWithChannelId) {
  constexpr char kHloModule[] = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), channel_id=1, replica_groups={}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), channel_id=1, replica_groups={}, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto reassociated,
                          RunPass(kHloModule, /*expect_change=*/true));
  const HloInstruction* root =
      reassociated->entry_computation()->root_instruction();
  EXPECT_THAT(root, m::AllReduce(m::Add(m::Parameter(0), m::Parameter(1))));
  EXPECT_EQ(AllReduceCount(reassociated), 1);
}
98
99 // Checks whether a linear chain of adds of ARs is reassociated iin a single
100 // pass.
TEST_F(AllReduceSimplifierTest,SimpleChain)101 TEST_F(AllReduceSimplifierTest, SimpleChain) {
102 absl::string_view hlo_string = R"(
103 HloModule m
104
105 sum {
106 a = f32[] parameter(0)
107 b = f32[] parameter(1)
108 ROOT add.2 = f32[] add(a, b)
109 }
110
111 ENTRY main {
112 p0 = f32[8] parameter(0)
113 p1 = f32[8] parameter(1)
114 p2 = f32[8] parameter(2)
115 p3 = f32[8] parameter(3)
116 ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
117 ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
118 ar2 = f32[8] all-reduce(p2), replica_groups={}, to_apply=sum
119 ar3 = f32[8] all-reduce(p3), replica_groups={}, to_apply=sum
120 add0 = f32[8] add(ar0, ar1)
121 add1 = f32[8] add(add0, ar2)
122 ROOT add2 = f32[8] add(add1, ar3)
123 }
124 )";
125 TF_ASSERT_OK_AND_ASSIGN(auto module,
126 RunPass(hlo_string, /*expect_change=*/true));
127 EXPECT_THAT(
128 module->entry_computation()->root_instruction(),
129 m::AllReduce(m::Add(
130 m::Add(m::Add(m::Parameter(0), m::Parameter(1)), m::Parameter(2)),
131 m::Parameter(3))));
132 EXPECT_EQ(AllReduceCount(module), 1);
133 }
134
135 // Checks whether a tree of add of ARs is reassociated in a single pass.
TEST_F(AllReduceSimplifierTest,SimpleTree)136 TEST_F(AllReduceSimplifierTest, SimpleTree) {
137 absl::string_view hlo_string = R"(
138 HloModule m
139
140 sum {
141 a = f32[] parameter(0)
142 b = f32[] parameter(1)
143 ROOT add.2 = f32[] add(a, b)
144 }
145
146 ENTRY main {
147 p0 = f32[8] parameter(0)
148 p1 = f32[8] parameter(1)
149 p2 = f32[8] parameter(2)
150 p3 = f32[8] parameter(3)
151 ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
152 ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
153 ar2 = f32[8] all-reduce(p2), replica_groups={}, to_apply=sum
154 ar3 = f32[8] all-reduce(p3), replica_groups={}, to_apply=sum
155 add0 = f32[8] add(ar0, ar1)
156 add1 = f32[8] add(ar2, ar3)
157 ROOT add2 = f32[8] add(add0, add1)
158 }
159 )";
160 TF_ASSERT_OK_AND_ASSIGN(auto module,
161 RunPass(hlo_string, /*expect_change=*/true));
162 EXPECT_THAT(module->entry_computation()->root_instruction(),
163 m::AllReduce(m::Add(m::Add(m::Parameter(0), m::Parameter(1)),
164 m::Add(m::Parameter(2), m::Parameter(3)))));
165 EXPECT_EQ(AllReduceCount(module), 1);
166 }
167
// The two all-reduces use different reductions (sum vs. max), so the pass is
// expected to leave the module unchanged.
TEST_F(AllReduceSimplifierTest, MismatchOp0) {
  constexpr char kHloModule[] = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

max {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT r = f32[] maximum(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=max
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unchanged_module,
                          RunPass(kHloModule, /*expect_change=*/false));
}
195
// Both all-reduces reduce with `max`, which does not match the elementwise
// `add` that combines their results, so no rewrite is expected.
TEST_F(AllReduceSimplifierTest, MismatchOp1) {
  constexpr char kHloModule[] = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

max {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT r = f32[] maximum(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=max
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=max
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unchanged_module,
                          RunPass(kHloModule, /*expect_change=*/false));
}
223
// The all-reduces use different replica groups ({{0}} vs. {}), so they are
// not expected to be merged.
TEST_F(AllReduceSimplifierTest, MismatchReplicaGroups) {
  constexpr char kHloModule[] = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={{0}}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unchanged_module,
                          RunPass(kHloModule, /*expect_change=*/false));
}
245
// Only one of the two all-reduces has a channel_id, so they are not expected
// to be merged.
TEST_F(AllReduceSimplifierTest, MismatchHasChannelId) {
  constexpr char kHloModule[] = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, channel_id=3, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unchanged_module,
                          RunPass(kHloModule, /*expect_change=*/false));
}
267
// Only ar0 sets use_global_device_ids=true, so the all-reduces are not
// expected to be merged.
// NOTE(review): the channel_id values also differ here (3 vs. 4), so this test
// does not isolate use_global_device_ids as the only mismatch.
TEST_F(AllReduceSimplifierTest, MismatchUseGlobalDeviceId) {
  constexpr char kHloModule[] = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={{0, 1}}, channel_id=3, use_global_device_ids=true, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={{0, 1}}, channel_id=4, to_apply=sum
  ROOT add = f32[8] add(ar0, ar1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unchanged_module,
                          RunPass(kHloModule, /*expect_change=*/false));
}
289
// ar0 is used by both the add and the root tuple; with that extra user the
// pass is expected not to rewrite the add.
TEST_F(AllReduceSimplifierTest, NotSingleUser) {
  constexpr char kHloModule[] = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
  ar1 = f32[8] all-reduce(p1), replica_groups={}, to_apply=sum
  add = f32[8] add(ar0, ar1)
  ROOT t = (f32[8], f32[8]) tuple(ar0, add)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto unchanged_module,
                          RunPass(kHloModule, /*expect_change=*/false));
}
312
// Both operands of the add are the same all-reduce (ar0). The pass is expected
// to rewrite add(ar0, ar0) into all-reduce(add(p0, p0)) and leave exactly one
// all-reduce behind. Previously this test only checked that the pass reported
// a change, so a malformed rewrite would have gone unnoticed.
TEST_F(AllReduceSimplifierTest, DoubleUse) {
  absl::string_view hlo_string = R"(
HloModule m

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add.2 = f32[] add(a, b)
}

ENTRY main {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum
  add = f32[8] add(ar0, ar0)
  ROOT c = f32[8] copy(add)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          RunPass(hlo_string, /*expect_change=*/true));
  // The copy is untouched; its operand becomes the new all-reduce, whose
  // operand adds the shared all-reduce's input to itself.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              m::Copy(m::AllReduce(m::Add(m::Parameter(0), m::Parameter(0)))));
  EXPECT_EQ(AllReduceCount(module), 1);
}
334
335 } // namespace
336 } // namespace xla
337