1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_cse.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/hlo_parser.h"
32 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
33 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_utils.h"
38 #include "tensorflow/compiler/xla/types.h"
39 #include "tensorflow/compiler/xla/util.h"
40 #include "tensorflow/compiler/xla/xla_data.pb.h"
41
42 namespace xla {
43 namespace {
44
45 namespace op = xla::testing::opcode_matchers;
46 namespace m = xla::match;
47
48 class HloCseTest : public HloTestBase {
49 protected:
HloCseTest()50 HloCseTest() {}
51 };
52
TEST_F(HloCseTest, CombineTwoConstants) {
  // Two scalar constants with the same f32 value feed an add. After CSE only
  // one constant may remain, and the module must still evaluate to 84.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(3, entry->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // One constant plus the add are left; the first instruction is the
  // surviving constant.
  EXPECT_EQ(2, entry->instruction_count());
  HloInstruction* survivor = *entry->instructions().begin();
  EXPECT_EQ(42.0f, survivor->literal().Get<float>({}));

  auto actual = ExecuteAndTransfer(module->Clone(), {});
  auto wanted = LiteralUtil::CreateR0<float>(84.0);
  EXPECT_TRUE(LiteralTestUtil::Near(wanted, actual, ErrorSpec(1e-4)));
}
79
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayouts) {
  // Two R2 constants hold the same values but use layouts {0, 1} and {1, 0}.
  // A layout-sensitive CSE pass must leave both in place.
  HloComputation::Builder builder(TestName());
  HloInstruction* col_major = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  HloInstruction* row_major = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      col_major->shape(), HloOpcode::kAdd, col_major, row_major));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(col_major, row_major));

  HloCSE cse(/*is_layout_sensitive=*/true);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  // Nothing was rewritten: same instruction count, same operands.
  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(col_major, row_major));

  auto actual = ExecuteAndTransfer(module->Clone(), {});
  auto wanted = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(wanted, actual, ErrorSpec(1e-4)));
}
109
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
  // Test that constants with the same value but different type are *not*
  // commoned.
  auto builder = HloComputation::Builder(TestName());
  std::vector<HloInstruction*> constants;
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(42))));
  // Integer literals are used for the integer-typed constants so that no
  // implicit floating-point-to-integer conversion is involved.
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64_t>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64_t>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));
  // Duplicate the float constant to verify something happens.
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));

  // Convert every constant to f32 so they can all be summed together.
  const Shape shape_r0 = ShapeUtil::MakeShape(F32, {});
  std::vector<HloInstruction*> converts;
  converts.reserve(constants.size());
  for (HloInstruction* constant : constants) {
    converts.push_back(builder.AddInstruction(
        HloInstruction::CreateConvert(shape_r0, constant)));
  }

  // Chain the converts into a left-leaning sum:
  // ((c0 + c1) + c2) + ... The iterator loop avoids comparing a signed index
  // against the unsigned container size.
  HloInstruction* root = converts.front();
  for (auto it = converts.begin() + 1; it != converts.end(); ++it) {
    root = builder.AddInstruction(
        HloInstruction::CreateBinary(shape_r0, HloOpcode::kAdd, root, *it));
  }

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // 7 constants + 7 converts + 6 adds.
  EXPECT_EQ(20, computation->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // CSE will remove both the second float(42.0f) and the corresponding
  // convert/cast.
  EXPECT_EQ(18, computation->instruction_count());
}
155
TEST_F(HloCseTest, NonscalarConstants) {
  // Identical R2 constants are merged, while a constant of the same shape but
  // different contents is left alone.
  HloComputation::Builder builder(TestName());
  HloInstruction* dup_a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  HloInstruction* dup_b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  // Same shape as the duplicates, different values.
  HloInstruction* distinct =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));

  // A tuple user lets the test look the constants up through its operands
  // after the pass has run.
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({dup_a, dup_b, distinct}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(dup_a, dup_b, distinct));

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  EXPECT_EQ(3, entry->instruction_count());
  // Either duplicate may be the survivor; both tuple slots must point at it.
  const HloInstruction* merged = tuple->operand(0);
  EXPECT_THAT(merged, ::testing::AnyOf(dup_a, dup_b));
  EXPECT_THAT(tuple, op::Tuple(merged, merged, distinct));
}
190
TEST_F(HloCseTest, IdenticalInstructions) {
  // Three exp instructions over the same constant collapse into one.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

  std::vector<HloInstruction*> exps;
  for (int copy = 0; copy < 3; ++copy) {
    exps.push_back(builder.AddInstruction(
        HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input)));
  }
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({exps[0], exps[1], exps[2]}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(5, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exps[0], exps[1], exps[2]));

  HloCSE cse(/*is_layout_sensitive=*/true);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // Constant, one exp and the tuple remain; every tuple slot uses that exp.
  EXPECT_EQ(3, entry->instruction_count());
  const HloInstruction* merged = tuple->operand(0);
  EXPECT_THAT(merged, ::testing::AnyOf(exps[0], exps[1], exps[2]));
  EXPECT_THAT(tuple, op::Tuple(merged, merged, merged));
}
219
220 // Test two identical while loops with same inputs
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsAndBodiesSameInput)221 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
222 const char* const hlo_string = R"(
223 HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
224
225 %body (param: (f32[], f32[])) -> (f32[], f32[]) {
226 %param = (f32[], f32[]) parameter(0)
227 %gte0 = get-tuple-element(%param), index=0
228 %gte1 = get-tuple-element(%param), index=1
229 %add = add(%gte0, %gte1)
230 ROOT %tuple = tuple(%gte0, %add)
231 }
232
233 %condition {
234 %param.1 = (f32[], f32[]) parameter(0)
235 ROOT %constant = pred[] constant(false)
236 }
237
238 %condition.1 {
239 %param.2 = (f32[], f32[]) parameter(0)
240 ROOT %constant.1 = pred[] constant(false)
241 }
242
243 ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput {
244 %c0 = f32[] constant(1)
245 %c1 = f32[] constant(2)
246 %t = tuple(c0, c1)
247 %while = while(%t), condition=%condition, body=%body
248 %while.1 = while(%t), condition=%condition.1, body=%body
249 ROOT r = tuple(while, while.1)
250 })";
251
252 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
253 auto computation = m->entry_computation();
254
255 EXPECT_EQ(6, computation->instruction_count());
256 HloCSE cse(true);
257 EXPECT_TRUE(cse.Run(m.get()).ValueOrDie());
258 EXPECT_EQ(5, computation->instruction_count());
259 }
260
261 // Test two while loops with same conditions, same inputs, but different
262 // bodies
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsSameInputAndDifferentBodies)263 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
264 const char* const hlo_string = R"(
265 HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
266
267 %body {
268 %param = (f32[], f32[]) parameter(0)
269 %get-tuple-element = get-tuple-element(%param), index=0
270 %get-tuple-element.1 = get-tuple-element(%param), index=1
271 %add = add(%get-tuple-element, %get-tuple-element.1)
272 ROOT %tuple = tuple(%get-tuple-element, %add)
273 }
274
275 %body2 {
276 %param.1 = (f32[], f32[]) parameter(0)
277 %get-tuple-element.2 = get-tuple-element(%param.1), index=0
278 %get-tuple-element.3 = get-tuple-element(%param.1), index=1
279 %sub = subtract(%get-tuple-element.2, %get-tuple-element.3)
280 ROOT %tuple.2 = tuple(%get-tuple-element.2, %sub)
281 }
282
283 %condition (param.2: (f32[], f32[])) -> pred[] {
284 %param.2 = (f32[], f32[]) parameter(0)
285 ROOT %constant = pred[] constant(false)
286 }
287
288 %condition.1 (param.3: (f32[], f32[])) -> pred[] {
289 %param.3 = (f32[], f32[]) parameter(0)
290 ROOT %constant.1 = pred[] constant(false)
291 }
292
293 ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies {
294 %constant.2 = f32[] constant(1)
295 %constant.3 = f32[] constant(2)
296 %tuple.1 = tuple(f32[] %constant.2, f32[] %constant.3)
297 %while = while(%tuple.1), condition=%condition, body=%body
298 ROOT %while.1 = while(%tuple.1), condition=%condition.1, body=%body2
299 })";
300
301 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
302 auto computation = m->entry_computation();
303
304 EXPECT_EQ(5, computation->instruction_count());
305 HloCSE cse(true);
306 EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
307 EXPECT_EQ(5, computation->instruction_count());
308 }
309
310 // Test two identical while loops with different inputs
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsAndBodiesDifferentInput)311 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
312 const char* const hlo_string = R"(
313 HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
314
315 %body {
316 %param = (f32[], f32[]) parameter(0)
317 %get-tuple-element = get-tuple-element(%param), index=0
318 %get-tuple-element.1 = get-tuple-element(%param), index=1
319 %add = add(%get-tuple-element, %get-tuple-element.1)
320 ROOT %tuple = tuple(%get-tuple-element, %add)
321 }
322
323 %body.1 {
324 %param.1 = (f32[], f32[]) parameter(0)
325 %gte = get-tuple-element(%param.1), index=0
326 %gte1 = get-tuple-element(%param.1), index=1
327 %add.1 = add(%gte, %gte1)
328 ROOT %tuple = tuple(%gte, %add.1)
329 }
330
331 %condition {
332 %param.1 = (f32[], f32[]) parameter(0)
333 ROOT %constant = pred[] constant(false)
334 }
335
336 %condition.1 {
337 %param.2 = (f32[], f32[]) parameter(0)
338 ROOT %constant.1 = pred[] constant(false)
339 }
340
341 ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput {
342 %constant.2 = f32[] constant(1)
343 %constant.3 = f32[] constant(2)
344 %tuple.1 = tuple(%constant.2, %constant.3)
345 %while = while(%tuple.1), condition=%condition, body=%body
346 %constant.4 = f32[] constant(1)
347 %constant.5 = f32[] constant(3)
348 %tuple.2 = tuple(%constant.4, %constant.5)
349 ROOT %while.1 = while(%tuple.2), condition=%condition.1, body=%body.1
350 })";
351
352 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
353 auto computation = m->entry_computation();
354
355 EXPECT_EQ(8, computation->instruction_count());
356 HloCSE cse(true);
357 EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
358 EXPECT_EQ(8, computation->instruction_count());
359 }
360
361 // Test two while loops with identical bodies and same inputs, but different
362 // conditions
TEST_F(HloCseTest,WhileLoopsIdenticalBodiesAndInputDifferentConditions)363 TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferentConditions) {
364 const char* const hlo_string = R"(
365 HloModule WhileLoopsIdenticalBodiesAndInputDifferentConditions
366
367 %body {
368 %param = (f32[], f32[]) parameter(0)
369 %get-tuple-element = get-tuple-element(%param), index=0
370 %get-tuple-element.1 = get-tuple-element((f32[], f32[]) %param), index=1
371 %add = add(%get-tuple-element, %get-tuple-element.1)
372 ROOT %tuple = tuple(%get-tuple-element, %add)
373 }
374
375 %condition {
376 %param.1 = (f32[], f32[]) parameter(0)
377 ROOT %constant = pred[] constant(false)
378 }
379
380 %condition.1 {
381 %param.2 = (f32[], f32[]) parameter(0)
382 ROOT %constant.1 = pred[] constant(true)
383 }
384
385 ENTRY %WhileLoopsIdenticalBodiesAndInputDifferentConditions {
386 %constant.2 = f32[] constant(1)
387 %constant.3 = f32[] constant(2)
388 %tuple.1 = tuple(%constant.2, %constant.3)
389 %while = while(%tuple.1), condition=%condition, body=%body
390 ROOT %while.1 = while(%tuple.1), condition=%condition.1, body=%body
391 })";
392
393 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
394 auto computation = m->entry_computation();
395
396 EXPECT_EQ(5, computation->instruction_count());
397 HloCSE cse(true);
398 EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
399 EXPECT_EQ(5, computation->instruction_count());
400 }
401
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
  // With a layout-sensitive pass, two exp instructions over the same operand
  // whose result layouts differ must both survive.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));

  HloInstruction* exp_01 = builder.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  *exp_01->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  HloInstruction* exp_10 = builder.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  *exp_10->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({exp_01, exp_10}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp_01, exp_10));

  HloCSE cse(/*is_layout_sensitive=*/true);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  // Graph is unchanged.
  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp_01, exp_10));
}
432
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
  // With a layout-insensitive pass, two exp instructions over the same operand
  // are merged even though their result layouts differ.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));

  HloInstruction* exp_01 = builder.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  *exp_01->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  HloInstruction* exp_10 = builder.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  *exp_10->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({exp_01, exp_10}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp_01, exp_10));

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // One exp is gone; both tuple slots now reference the survivor.
  EXPECT_EQ(3, entry->instruction_count());
  const HloInstruction* merged = tuple->operand(0);
  EXPECT_THAT(merged, ::testing::AnyOf(exp_01, exp_10));
  EXPECT_THAT(tuple, op::Tuple(merged, merged));
}
465
TEST_F(HloCseTest, FusionInternalCSE) {
  // Duplicate adds that live inside a fusion computation are also merged.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, scalar, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, scalar, "p1"));
  HloInstruction* sum_a = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, p0, p1));
  HloInstruction* sum_b = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, p0, p1));
  HloInstruction* product = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kMultiply, sum_a, sum_b));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  // Fuse the multiply together with both adds and inspect the fused body.
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {product, sum_a, sum_b}, HloInstruction::FusionKind::kLoop);
  HloComputation* fused = fusion->fused_instructions_computation();

  EXPECT_EQ(5, fused->instruction_count());
  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
  EXPECT_EQ(4, fused->instruction_count());

  // Both multiply operands are now the same instruction.
  HloInstruction* fused_root = fused->root_instruction();
  EXPECT_THAT(fused_root,
              op::Multiply(fused_root->operand(0), fused_root->operand(0)));
}
499
TEST_F(HloCseTest, IdenticalExpressions) {
  // Whole duplicated expression trees are merged. The computation is:
  //
  //   c      = 42.0
  //   neg_a  = neg(c)
  //   exp_a  = exp(c)
  //   sum_a  = add(neg_a, exp_a)
  //   neg_b  = neg(c)
  //   exp_b  = exp(c)
  //   sum_b  = add(neg_b, exp_b)
  //   tuple  = tuple(sum_a, sum_b)
  //
  // After the pass, the *_a and *_b instructions are expected to coincide.
  HloComputation::Builder builder(TestName());
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  const Shape& scalar = c->shape();

  HloInstruction* neg_a = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, c));
  HloInstruction* exp_a = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, c));
  HloInstruction* sum_a = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, neg_a, exp_a));

  HloInstruction* neg_b = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kNegate, c));
  HloInstruction* exp_b = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kExp, c));
  HloInstruction* sum_b = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, neg_b, exp_b));

  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({sum_a, sum_b}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(8, entry->instruction_count());
  EXPECT_THAT(tuple,
              op::Tuple(op::Add(neg_a, exp_a), op::Add(neg_b, exp_b)));

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // Constant, one negate, one exp, one add and the tuple remain.
  EXPECT_EQ(5, entry->instruction_count());
  const HloInstruction* merged = tuple->operand(0);
  EXPECT_THAT(tuple, op::Tuple(merged, merged));
  EXPECT_THAT(merged, op::Add(op::Negate(), op::Exp()));
}
549
TEST_F(HloCseTest, DoNotCombineRng) {
  // Two RNG instructions with identical operands must stay separate.
  HloComputation::Builder builder(TestName());
  HloInstruction* low = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* high = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  HloInstruction* rng_a = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
      {low, high}));
  HloInstruction* rng_b = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
      {low, high}));

  builder.AddInstruction(HloInstruction::CreateBinary(
      low->shape(), HloOpcode::kAdd, rng_a, rng_b));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(), op::Add(rng_a, rng_b));
  const uint32_t count_before = entry->instruction_count();

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  // Same number of instructions, and the root still adds the two RNGs.
  const uint32_t count_after = entry->instruction_count();
  EXPECT_EQ(count_before, count_after);
  EXPECT_THAT(entry->root_instruction(), op::Add(rng_a, rng_b));
}
583
TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
  // Two map instructions that apply the same RNG-containing computation to
  // the same operand must not be merged; the RNG makes the callee impure.
  auto module = CreateNewVerifiedModule();

  // Callee: rng(0, 1) + param.
  HloComputation* rng_function = nullptr;
  {
    const Shape scalar = ShapeUtil::MakeShape(F32, {});
    HloComputation::Builder callee(TestName() + "_rng_fun");
    HloInstruction* low = callee.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
    HloInstruction* high = callee.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
    HloInstruction* rng = callee.AddInstruction(HloInstruction::CreateRng(
        scalar, RandomDistribution::RNG_UNIFORM, {low, high}));
    HloInstruction* param = callee.AddInstruction(
        HloInstruction::CreateParameter(0, scalar, "param"));
    callee.AddInstruction(
        HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, rng, param));
    rng_function = module->AddEmbeddedComputation(callee.Build());
  }

  // Entry: maps rng_function over the same constant twice and adds the
  // results.
  HloComputation* entry = nullptr;
  {
    HloComputation::Builder caller(TestName());
    HloInstruction* input = caller.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f})));
    HloInstruction* map_a = caller.AddInstruction(
        HloInstruction::CreateMap(input->shape(), {input}, rng_function));
    HloInstruction* map_b = caller.AddInstruction(
        HloInstruction::CreateMap(input->shape(), {input}, rng_function));
    caller.AddInstruction(HloInstruction::CreateBinary(
        input->shape(), HloOpcode::kAdd, map_a, map_b));
    entry = module->AddEntryComputation(caller.Build());
  }

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(entry->root_instruction(), op::Add(op::Map(), op::Map()));

  VLOG(3) << "before: " << module->ToString();

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  VLOG(3) << "after: " << module->ToString();

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(entry->root_instruction(),
              op::Add(op::Map(op::Constant()), op::Map(op::Constant())));
}
638
TEST_F(HloCseTest, CompareComputations) {
  // Two reduces over the same operands whose to_apply computations are
  // distinct but structurally identical end up as the same instruction.
  const char* const kModuleText = R"(
    HloModule m

    add_computation {
      add_lhs = f32[] parameter(0)
      add_rhs = f32[] parameter(1)
      ROOT add_root = add(add_lhs, add_rhs)
    }

    add_computation2 {
      add_lhs2 = f32[] parameter(0)
      add_rhs2 = f32[] parameter(1)
      ROOT add_root2 = add(add_lhs2, add_rhs2)
    }

    ENTRY entry {
      p = f32[10]{0} parameter(0)
      c = f32[] constant(0)
      r1 = reduce(p, c), dimensions={0}, to_apply=add_computation
      r2 = reduce(p, c), dimensions={0}, to_apply=add_computation2
      ROOT f2 = tuple(r1, r2)
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  HloInstruction* tuple = module->entry_computation()->root_instruction();
  EXPECT_EQ(tuple->operand(0), tuple->operand(1));
}
669
TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
  // Equal constants and equal iotas that sit in different (here: disjoint)
  // domains are expected to be left uncombined.
  HloComputation::Builder builder(TestName());
  for (int copy = 0; copy < 2; ++copy) {
    builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(42)));
  }
  for (int copy = 0; copy < 2; ++copy) {
    builder.AddInstruction(
        HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}), 0));
  }

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(4, entry->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  EXPECT_EQ(4, entry->instruction_count());
}
693
TEST_F(HloCseTest, Domain) {
  // Three negate chains are wrapped in sharding domain instructions. Two of
  // them exit to device 1 and the third to device 2. After CSE the two
  // device-1 chains feed the add through the same instruction, while the
  // device-2 chain stays distinct.
  const char* const kModuleText = R"(
    HloModule module
    ENTRY %entry {
      %param = f32[] parameter(0), sharding={maximal device=0}
      %domain.0 = f32[] domain(%param),
        domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
      %domain.1 = f32[] domain(%param),
        domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
      %domain.2 = f32[] domain(%param),
        domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}}
      %negate.0 = f32[] negate(%domain.0)
      %negate.1 = f32[] negate(%domain.1)
      %negate.2 = f32[] negate(%domain.2)
      %domain.3 = f32[] domain(%negate.0),
        domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
      %domain.4 = f32[] domain(%negate.1),
        domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
      %domain.5 = f32[] domain(%negate.2),
        domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
      %add = f32[] add(%domain.3, %domain.4)
      ROOT %sub = f32[] subtract(%add, %domain.5)
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  const HloInstruction* difference =
      module->entry_computation()->root_instruction();
  const HloInstruction* sum = difference->operand(0);
  EXPECT_EQ(sum->operand(0), sum->operand(1));
  EXPECT_NE(sum->operand(0), difference->operand(1));
  EXPECT_NE(sum->operand(1), difference->operand(1));
}
727
TEST_F(HloCseTest, Iota) {
  // i1 and i2 are identical iotas. i3 differs in shape and i4 in
  // iota_dimension. Only i1 and i2 may end up as the same instruction.
  const char* const kModuleText = R"(
    HloModule m

    ENTRY entry {
      i1 = s64[16,16] iota(), iota_dimension=0
      i2 = s64[16,16] iota(), iota_dimension=0
      i3 = s64[17,16] iota(), iota_dimension=0
      i4 = s64[16,16] iota(), iota_dimension=1
      ROOT root = (s64[16,16], s64[16,16], s64[17,16], s64[16,16]) tuple(i1, i2, i3, i4)
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* tuple = module->entry_computation()->root_instruction();
  EXPECT_EQ(tuple->operand(0), tuple->operand(1));
  EXPECT_NE(tuple->operand(0), tuple->operand(2));
  EXPECT_NE(tuple->operand(0), tuple->operand(3));
}
749
TEST_F(HloCseTest, OptimizationBarrier) {
  // The same add is computed before an opt-barrier and again from the
  // barrier's outputs. The pass is expected to report no change.
  const char* const kModuleText = R"(
    HloModule m

    ENTRY entry {
      %param.0 = f32[] parameter(0)
      %param.1 = f32[] parameter(1)
      %add.0 = f32[] add(%param.0, %param.1)
      %cse_tmp.0 = (f32[], f32[], f32[]) tuple(%param.0, %param.1, %add.0)
      %cse_tmp.1 = (f32[], f32[], f32[]) opt-barrier(%cse_tmp.0)

      %param.0.1 = f32[] get-tuple-element(%cse_tmp.1), index=0
      %param.1.1 = f32[] get-tuple-element(%cse_tmp.1), index=1
      %add.0.1 = f32[] get-tuple-element(%cse_tmp.1), index=2

      %add.1 = f32[] add(%param.0.1, %param.1.1)
      ROOT %add.2 = f32[] add(%add.1, %add.0.1)
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleText));
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, module.get()));
  EXPECT_FALSE(changed);
}
774
775 class HloCseCustomCallTest
776 : public HloCseTest,
777 public ::testing::WithParamInterface<std::tuple<
778 std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>> {};
779
// Builds a module holding two copies of the first parameterized op and one
// copy of the second, runs CSE, and checks which of them were merged.
TEST_P(HloCseCustomCallTest, DoIt) {
  const auto& param = GetParam();
  // Instantiated twice in the module, as op0 and op1.
  const std::string& repeated_op = std::get<0>(param);
  // Instantiated once in the module, as op2.
  const std::string& other_op = std::get<1>(param);
  const bool expect_other_merged = std::get<2>(param);

  const char* const kHloTemplate = R"(
    HloModule m
    ENTRY entry {
      p0 = f32[1,1,1] parameter(0)

      op0 = $0
      op1 = $0
      op2 = $1
      ROOT root = tuple(op0, op1, op2)
    }
  )";
  const std::string hlo_text =
      absl::Substitute(kHloTemplate, repeated_op, other_op);
  SCOPED_TRACE(absl::StrCat("Module before CSE:\n", hlo_text));

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, module.get()));

  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", module->ToString()));
  // op0 and op1 are identical, so the pass must always report a change.
  EXPECT_TRUE(changed);

  HloInstruction* tuple_root = module->entry_computation()->root_instruction();
  const HloInstruction* first = tuple_root->operand(0);
  const HloInstruction* second = tuple_root->operand(1);
  const HloInstruction* third = tuple_root->operand(2);

  EXPECT_EQ(first, second) << "Identical ops should be CSE'ed";
  if (!expect_other_merged) {
    EXPECT_NE(first, third) << "Ops should not be CSE'ed";
  } else {
    EXPECT_EQ(first, third) << "Ops should be CSE'ed";
  }
}
814
815 static std::vector<
816 std::tuple<std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>>
CustomCallTests()817 CustomCallTests() {
818 auto build = [](absl::string_view args1, absl::string_view args2) {
819 absl::string_view prefix =
820 "f32[] custom-call(p0), custom_call_target=\"foo\", ";
821 return std::make_tuple(absl::StrCat(prefix, args1),
822 absl::StrCat(prefix, args2), false);
823 };
824 return {
825 {
826 // metadata shouldn't prevent CSE
827 "f32[] custom-call(p0), custom_call_target=\"foo\"",
828 "f32[] custom-call(p0), custom_call_target=\"foo\", "
829 "metadata={op_name=\"bar\"}",
830 true,
831 },
832 {
833 "f32[] custom-call(p0), custom_call_target=\"foo\"",
834 "f32[] custom-call(p0, p0), custom_call_target=\"foo\"",
835 false,
836 },
837 {
838 "f32[1] custom-call(p0), custom_call_target=\"foo\"",
839 "f32[2] custom-call(p0), custom_call_target=\"foo\"",
840 false,
841 },
842 {
843 "f32[] custom-call(p0), custom_call_target=\"foo\"",
844 "f32[] custom-call(p0), custom_call_target=\"bar\"",
845 false,
846 },
847
848 build("window={size=1}", "window={size=2}"),
849 build("dim_labels=b0f_0oi->b0f", "dim_labels=b0f_0oi->bf0"),
850 build("backend_config=\"foo\"", "backend_config=\"bar\""),
851 build("literal=s32[] 0", "literal=s32[] 1"),
852 build("literal=s32[] 0", "literal=f32[] 0"),
853 build("operand_precision={high,default}",
854 "operand_precision={high, high}"),
855 build("api_version=API_VERSION_STATUS_RETURNING",
856 "api_version=API_VERSION_ORIGINAL"),
857 build("feature_group_count=0", "feature_group_count=1"),
858 };
859 }
860
861 INSTANTIATE_TEST_SUITE_P(HloCseCustomCallTestSuite, HloCseCustomCallTest,
862 ::testing::ValuesIn(CustomCallTests()));
863
// Checks that custom-calls whose called_computations lists differ are left
// alone by the pass.
TEST_F(HloCseTest, CustomCallCalledComputations) {
  // op0 and op1 match except that op0 lists {comp} and op1 lists {comp, comp}.
  const char* const kHloText = R"(
    HloModule m

    comp {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT maximum = f32[] maximum(lhs, rhs)
    }

    ENTRY entry {
      p0 = f32[] parameter(0)

      op0 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp}
      op1 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp, comp}
      ROOT root = tuple(op0, op1)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, module.get()));

  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", module->ToString()));
  EXPECT_FALSE(changed);
}
890
// Checks that two textually identical custom-calls are not merged when both
// are marked custom_call_has_side_effect=true.
TEST_F(HloCseTest, CustomCallSideEffects) {
  const char* const kHloText = R"(
    HloModule m

    ENTRY entry {
      p0 = f32[] parameter(0)

      op0 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true
      op1 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true
      ROOT root = tuple(op0, op1)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, module.get()));

  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", module->ToString()));
  EXPECT_FALSE(changed);
}
911
912 class HloCseCommutativeOpTest
913 : public HloCseTest,
914 public ::testing::WithParamInterface<std::string /*op*/> {};
915
// Checks that op(p0, p1) and op(p1, p0) are merged into a single instruction
// for the parameterized opcode.
TEST_P(HloCseCommutativeOpTest, DoIt) {
  const std::string& op = GetParam();
  // $0 is replaced by the opcode; op1 and op2 differ only in operand order.
  const char* kModuleStr = R"(
    HloModule m

    ENTRY test {
      p0 = s32[10] parameter(0)
      p1 = s32[10] parameter(1)
      op1 = s32[10] $0(p0, p1)
      op2 = s32[10] $0(p1, p0)
      ROOT t = tuple(op1, op2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
                                           absl::Substitute(kModuleStr, op)));

  // Go through RunHloPass, like the other tests in this file, so that an error
  // status from the pass is reported as a test failure.
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, module.get()));

  // Install the trace before the first assertion so that every failure below
  // prints the post-CSE module.
  SCOPED_TRACE(module->ToString());
  ASSERT_TRUE(changed);

  const HloInstruction* op0 = nullptr;
  const HloInstruction* op1 = nullptr;
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(m::Op(&op0), m::Op(&op1))));
  // Both tuple operands must now be the very same instruction.
  EXPECT_EQ(op0, op1);
}
941
942 INSTANTIATE_TEST_SUITE_P(AlgebraicSimplifierCanonicalizeCommutativeTestSuite,
943 HloCseCommutativeOpTest,
944 ::testing::Values("add", "multiply", "and", "or",
945 "xor", "minimum", "maximum"));
946
947 } // namespace
948 } // namespace xla
949