xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_cse_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_cse.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/hlo_parser.h"
32 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
33 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_utils.h"
38 #include "tensorflow/compiler/xla/types.h"
39 #include "tensorflow/compiler/xla/util.h"
40 #include "tensorflow/compiler/xla/xla_data.pb.h"
41 
42 namespace xla {
43 namespace {
44 
45 namespace op = xla::testing::opcode_matchers;
46 namespace m = xla::match;
47 
48 class HloCseTest : public HloTestBase {
49  protected:
HloCseTest()50   HloCseTest() {}
51 };
52 
TEST_F(HloCseTest, CombineTwoConstants) {
  // Test that two identical constants are commoned.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      constant1->shape(), HloOpcode::kAdd, constant1, constant2));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(3, computation->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  EXPECT_EQ(2, computation->instruction_count());
  // Look at the surviving constant through the root add's operands instead of
  // relying on the iteration order of computation->instructions(). After
  // commoning, both operands of the add must be the same constant.
  const HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
  EXPECT_EQ(root->operand(0), root->operand(1));
  EXPECT_EQ(42.0f, root->operand(0)->literal().Get<float>({}));

  auto result = ExecuteAndTransfer(module->Clone(), {});
  auto expected = LiteralUtil::CreateR0<float>(84.0f);
  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
79 
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayouts) {
  // Two constants holding equal values but carrying different layouts must
  // *not* be merged when CSE runs in layout-sensitive mode.
  auto b = HloComputation::Builder(TestName());
  HloInstruction* const layout01_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  HloInstruction* const layout10_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  HloInstruction* const sum = b.AddInstruction(HloInstruction::CreateBinary(
      layout01_constant->shape(), HloOpcode::kAdd, layout01_constant,
      layout10_constant));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());

  // Before the pass: both constants feed the add.
  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(layout01_constant, layout10_constant));

  HloCSE layout_sensitive_cse(/*is_layout_sensitive=*/true);
  EXPECT_FALSE(layout_sensitive_cse.Run(module.get()).ValueOrDie());

  // After the pass: the graph is untouched.
  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(layout01_constant, layout10_constant));

  // The program still evaluates to the element-wise doubling of the input.
  auto actual = ExecuteAndTransfer(module->Clone(), {});
  auto want = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(want, actual, ErrorSpec(1e-4)));
}
109 
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
  // Test that constants with the same value but different type are *not*
  // commoned.
  auto builder = HloComputation::Builder(TestName());
  std::vector<HloInstruction*> constants;
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(42))));
  // Integer element types get integer literals (no implicit double->integer
  // conversion).
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64_t>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64_t>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));
  // Duplicate the float constant to verify something happens.
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));

  // Convert every constant to f32 so they can all be summed; from here on the
  // vector holds the converts rather than the constants.
  const Shape shape_r0 = ShapeUtil::MakeShape(F32, {});
  for (HloInstruction*& operand : constants) {
    operand = builder.AddInstruction(
        HloInstruction::CreateConvert(shape_r0, operand));
  }
  // Fold the converts into a left-leaning chain of adds.
  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_r0, HloOpcode::kAdd, constants[0], constants[1]));
  for (size_t i = 2; i < constants.size(); ++i) {
    root = builder.AddInstruction(HloInstruction::CreateBinary(
        shape_r0, HloOpcode::kAdd, root, constants[i]));
  }

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // 7 constants + 7 converts + 6 adds.
  EXPECT_EQ(20, computation->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // CSE will remove both the second float(42.0f) and the corresponding
  // convert/cast.
  EXPECT_EQ(18, computation->instruction_count());
}
155 
TEST_F(HloCseTest, NonscalarConstants) {
  // Rank-2 constants holding identical values are merged into one, while a
  // same-shaped constant holding different values is left alone.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* const dup_a = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  HloInstruction* const dup_b = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  HloInstruction* const distinct = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));

  // A single tuple consumes all three constants, so the effect of CSE can be
  // observed through the tuple's operand list.
  HloInstruction* const tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({dup_a, dup_b, distinct}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(dup_a, dup_b, distinct));

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // One duplicate is gone and its use now points at the survivor.
  EXPECT_EQ(3, computation->instruction_count());
  const HloInstruction* const survivor = tuple->operand(0);
  EXPECT_THAT(survivor, ::testing::AnyOf(dup_a, dup_b));
  EXPECT_THAT(tuple, op::Tuple(survivor, survivor, distinct));
}
190 
TEST_F(HloCseTest, IdenticalInstructions) {
  // Three identical exp(constant) instructions should collapse into one.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* const constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

  // Appends another exp(constant) to the builder.
  auto append_exp = [&]() {
    return builder.AddInstruction(HloInstruction::CreateUnary(
        constant->shape(), HloOpcode::kExp, constant));
  };
  HloInstruction* const exp1 = append_exp();
  HloInstruction* const exp2 = append_exp();
  HloInstruction* const exp3 = append_exp();
  HloInstruction* const tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(5, computation->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3));

  HloCSE cse(/*is_layout_sensitive=*/true);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // Only the constant, one exp and the tuple remain.
  EXPECT_EQ(3, computation->instruction_count());
  const HloInstruction* const kept = tuple->operand(0);
  EXPECT_THAT(kept, ::testing::AnyOf(exp1, exp2, exp3));
  EXPECT_THAT(tuple, op::Tuple(kept, kept, kept));
}
219 
220 // Test two identical while loops with same inputs
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsAndBodiesSameInput)221 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
222   const char* const hlo_string = R"(
223     HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
224 
225     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
226       %param = (f32[], f32[]) parameter(0)
227       %gte0 = get-tuple-element(%param), index=0
228       %gte1 = get-tuple-element(%param), index=1
229       %add = add(%gte0, %gte1)
230       ROOT %tuple = tuple(%gte0, %add)
231     }
232 
233     %condition {
234       %param.1 = (f32[], f32[]) parameter(0)
235       ROOT %constant = pred[] constant(false)
236     }
237 
238     %condition.1 {
239       %param.2 = (f32[], f32[]) parameter(0)
240       ROOT %constant.1 = pred[] constant(false)
241     }
242 
243     ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput {
244       %c0 = f32[] constant(1)
245       %c1 = f32[] constant(2)
246       %t = tuple(c0, c1)
247       %while = while(%t), condition=%condition, body=%body
248       %while.1 = while(%t), condition=%condition.1, body=%body
249       ROOT r = tuple(while, while.1)
250     })";
251 
252   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
253   auto computation = m->entry_computation();
254 
255   EXPECT_EQ(6, computation->instruction_count());
256   HloCSE cse(true);
257   EXPECT_TRUE(cse.Run(m.get()).ValueOrDie());
258   EXPECT_EQ(5, computation->instruction_count());
259 }
260 
261 // Test two while loops with same conditions, same inputs, but different
262 // bodies
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsSameInputAndDifferentBodies)263 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
264   const char* const hlo_string = R"(
265     HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
266 
267     %body {
268       %param = (f32[], f32[]) parameter(0)
269       %get-tuple-element = get-tuple-element(%param), index=0
270       %get-tuple-element.1 = get-tuple-element(%param), index=1
271       %add = add(%get-tuple-element, %get-tuple-element.1)
272       ROOT %tuple = tuple(%get-tuple-element, %add)
273     }
274 
275     %body2 {
276       %param.1 = (f32[], f32[]) parameter(0)
277       %get-tuple-element.2 = get-tuple-element(%param.1), index=0
278       %get-tuple-element.3 = get-tuple-element(%param.1), index=1
279       %sub = subtract(%get-tuple-element.2, %get-tuple-element.3)
280       ROOT %tuple.2 = tuple(%get-tuple-element.2, %sub)
281     }
282 
283     %condition (param.2: (f32[], f32[])) -> pred[] {
284       %param.2 = (f32[], f32[]) parameter(0)
285       ROOT %constant = pred[] constant(false)
286     }
287 
288     %condition.1 (param.3: (f32[], f32[])) -> pred[] {
289       %param.3 = (f32[], f32[]) parameter(0)
290       ROOT %constant.1 = pred[] constant(false)
291     }
292 
293     ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies {
294       %constant.2 = f32[] constant(1)
295       %constant.3 = f32[] constant(2)
296       %tuple.1 = tuple(f32[] %constant.2, f32[] %constant.3)
297       %while = while(%tuple.1), condition=%condition, body=%body
298       ROOT %while.1 = while(%tuple.1), condition=%condition.1, body=%body2
299     })";
300 
301   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
302   auto computation = m->entry_computation();
303 
304   EXPECT_EQ(5, computation->instruction_count());
305   HloCSE cse(true);
306   EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
307   EXPECT_EQ(5, computation->instruction_count());
308 }
309 
310 // Test two identical while loops with different inputs
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsAndBodiesDifferentInput)311 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
312   const char* const hlo_string = R"(
313     HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
314 
315     %body {
316       %param = (f32[], f32[]) parameter(0)
317       %get-tuple-element = get-tuple-element(%param), index=0
318       %get-tuple-element.1 = get-tuple-element(%param), index=1
319       %add = add(%get-tuple-element, %get-tuple-element.1)
320       ROOT %tuple = tuple(%get-tuple-element, %add)
321     }
322 
323     %body.1 {
324       %param.1 = (f32[], f32[]) parameter(0)
325       %gte = get-tuple-element(%param.1), index=0
326       %gte1 = get-tuple-element(%param.1), index=1
327       %add.1 = add(%gte, %gte1)
328       ROOT %tuple = tuple(%gte, %add.1)
329     }
330 
331     %condition {
332       %param.1 = (f32[], f32[]) parameter(0)
333       ROOT %constant = pred[] constant(false)
334     }
335 
336     %condition.1 {
337       %param.2 = (f32[], f32[]) parameter(0)
338       ROOT %constant.1 = pred[] constant(false)
339     }
340 
341     ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput {
342       %constant.2 = f32[] constant(1)
343       %constant.3 = f32[] constant(2)
344       %tuple.1 =  tuple(%constant.2, %constant.3)
345       %while = while(%tuple.1), condition=%condition, body=%body
346       %constant.4 = f32[] constant(1)
347       %constant.5 = f32[] constant(3)
348       %tuple.2 = tuple(%constant.4, %constant.5)
349       ROOT %while.1 = while(%tuple.2), condition=%condition.1, body=%body.1
350     })";
351 
352   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
353   auto computation = m->entry_computation();
354 
355   EXPECT_EQ(8, computation->instruction_count());
356   HloCSE cse(true);
357   EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
358   EXPECT_EQ(8, computation->instruction_count());
359 }
360 
361 // Test two while loops with identical bodies and same inputs, but different
362 // conditions
TEST_F(HloCseTest,WhileLoopsIdenticalBodiesAndInputDifferentConditions)363 TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferentConditions) {
364   const char* const hlo_string = R"(
365     HloModule WhileLoopsIdenticalBodiesAndInputDifferentConditions
366 
367     %body {
368       %param = (f32[], f32[]) parameter(0)
369       %get-tuple-element = get-tuple-element(%param), index=0
370       %get-tuple-element.1 = get-tuple-element((f32[], f32[]) %param), index=1
371       %add = add(%get-tuple-element, %get-tuple-element.1)
372       ROOT %tuple = tuple(%get-tuple-element, %add)
373     }
374 
375     %condition {
376       %param.1 = (f32[], f32[]) parameter(0)
377       ROOT %constant = pred[] constant(false)
378     }
379 
380     %condition.1 {
381       %param.2 = (f32[], f32[]) parameter(0)
382       ROOT %constant.1 = pred[] constant(true)
383     }
384 
385     ENTRY %WhileLoopsIdenticalBodiesAndInputDifferentConditions {
386       %constant.2 = f32[] constant(1)
387       %constant.3 = f32[] constant(2)
388       %tuple.1 = tuple(%constant.2, %constant.3)
389       %while = while(%tuple.1), condition=%condition, body=%body
390       ROOT %while.1 = while(%tuple.1), condition=%condition.1, body=%body
391     })";
392 
393   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
394   auto computation = m->entry_computation();
395 
396   EXPECT_EQ(5, computation->instruction_count());
397   HloCSE cse(true);
398   EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
399   EXPECT_EQ(5, computation->instruction_count());
400 }
401 
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
  // A layout-sensitive CSE pass must keep two otherwise identical
  // instructions apart when their result layouts differ.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* const constant = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));

  // Appends exp(constant) to the builder and overwrites its result layout.
  auto append_exp_with_layout = [&](const Layout& layout) {
    HloInstruction* const instr = builder.AddInstruction(
        HloInstruction::CreateUnary(constant->shape(), HloOpcode::kExp,
                                    constant));
    *instr->mutable_shape()->mutable_layout() = layout;
    return instr;
  };
  HloInstruction* const exp1 =
      append_exp_with_layout(LayoutUtil::MakeLayout({0, 1}));
  HloInstruction* const exp2 =
      append_exp_with_layout(LayoutUtil::MakeLayout({1, 0}));

  HloInstruction* const tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp1, exp2));

  HloCSE cse(/*is_layout_sensitive=*/true);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  // Nothing merged: both exps still feed the tuple.
  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
}
432 
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
  // With layout sensitivity turned off, two exp() instructions that differ
  // only in their result layout are merged.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* const input = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));

  HloInstruction* const exp_a = builder.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  Layout* const layout_a = exp_a->mutable_shape()->mutable_layout();
  *layout_a = LayoutUtil::MakeLayout({0, 1});

  HloInstruction* const exp_b = builder.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  Layout* const layout_b = exp_b->mutable_shape()->mutable_layout();
  *layout_b = LayoutUtil::MakeLayout({1, 0});

  HloInstruction* const tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({exp_a, exp_b}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp_a, exp_b));

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // One exp was dropped; the tuple now uses the survivor twice.
  EXPECT_EQ(3, computation->instruction_count());
  const HloInstruction* const survivor = tuple->operand(0);
  EXPECT_THAT(survivor, ::testing::AnyOf(exp_a, exp_b));
  EXPECT_THAT(tuple, op::Tuple(survivor, survivor));
}
465 
TEST_F(HloCseTest, FusionInternalCSE) {
  // CSE must also work inside the computation of a fusion instruction: the
  // two identical adds fused together with the multiply collapse into one.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* const p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "p0"));
  HloInstruction* const p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "p1"));
  HloInstruction* const sum_a = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, p0, p1));
  HloInstruction* const sum_b = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, p0, p1));
  HloInstruction* const product = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kMultiply, sum_a, sum_b));

  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const fusion = entry->CreateFusionInstruction(
      {product, sum_a, sum_b}, HloInstruction::FusionKind::kLoop);
  HloComputation* const fused = fusion->fused_instructions_computation();

  // Inside the fusion: two parameters, two adds, one multiply.
  EXPECT_EQ(5, fused->instruction_count());
  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
  EXPECT_EQ(4, fused->instruction_count());

  // The multiply now uses the single surviving add for both operands.
  const HloInstruction* const root = fused->root_instruction();
  EXPECT_THAT(root, op::Multiply(root->operand(0), root->operand(0)));
}
499 
TEST_F(HloCseTest, IdenticalExpressions) {
  // Two structurally identical expression trees over one constant should be
  // merged into a single tree. The computation built below is:
  //
  //   constant = 42.0
  //   negate1 = neg(constant)
  //   exp1 = exp(constant)
  //   add1 = add(negate1, exp1)
  //   negate2 = neg(constant)
  //   exp2 = exp(constant)
  //   add2 = add(negate2, exp2)
  //   tuple = tuple(add1, add2)
  //
  // After CSE only one negate/exp/add triple should remain.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* const constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

  // Appends `opcode`(constant) to the builder.
  auto unary_of_constant = [&](HloOpcode opcode) {
    return builder.AddInstruction(
        HloInstruction::CreateUnary(constant->shape(), opcode, constant));
  };
  // Appends add(lhs, rhs) with the constant's shape to the builder.
  auto sum = [&](HloInstruction* lhs, HloInstruction* rhs) {
    return builder.AddInstruction(HloInstruction::CreateBinary(
        constant->shape(), HloOpcode::kAdd, lhs, rhs));
  };

  HloInstruction* const negate1 = unary_of_constant(HloOpcode::kNegate);
  HloInstruction* const exp1 = unary_of_constant(HloOpcode::kExp);
  HloInstruction* const add1 = sum(negate1, exp1);

  HloInstruction* const negate2 = unary_of_constant(HloOpcode::kNegate);
  HloInstruction* const exp2 = unary_of_constant(HloOpcode::kExp);
  HloInstruction* const add2 = sum(negate2, exp2);

  HloInstruction* const tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({add1, add2}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(8, computation->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2)));

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // constant + negate + exp + add + tuple.
  EXPECT_EQ(5, computation->instruction_count());
  const HloInstruction* const merged = tuple->operand(0);
  EXPECT_THAT(tuple, op::Tuple(merged, merged));
  EXPECT_THAT(merged, op::Add(op::Negate(), op::Exp()));
}
549 
TEST_F(HloCseTest, DoNotCombineRng) {
  // Test that two RNG ops are not commoned, even though they have the same
  // shape, distribution and operands.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto rng1 = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
      {constant1, constant2}));
  auto rng2 = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
      {constant1, constant2}));

  builder.AddInstruction(HloInstruction::CreateBinary(
      constant1->shape(), HloOpcode::kAdd, rng1, rng2));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, op::Add(rng1, rng2));

  // Keep instruction_count()'s own return type instead of narrowing it into a
  // uint32_t.
  const auto count_before = computation->instruction_count();

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  const auto count_after = computation->instruction_count();
  EXPECT_EQ(count_before, count_after);
  root = computation->root_instruction();
  EXPECT_THAT(root, op::Add(rng1, rng2));
}
583 
TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
  // Two maps that apply the same impure function to the same operand must not
  // be commoned. The impurity comes from an RNG inside the mapped function.
  auto module = CreateNewVerifiedModule();

  // rng_function(param) adds an RNG_UNIFORM sample (bounds 0.0 and 1.0) to
  // its scalar parameter. The RNG is what makes it impure.
  HloComputation* const rng_function = [&]() {
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    auto fn_builder = HloComputation::Builder(TestName() + "_rng_fun");
    HloInstruction* const lo = fn_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
    HloInstruction* const hi = fn_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
    HloInstruction* const rng = fn_builder.AddInstruction(
        HloInstruction::CreateRng(scalar_shape,
                                  RandomDistribution::RNG_UNIFORM, {lo, hi}));
    HloInstruction* const param = fn_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "param"));
    fn_builder.AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kAdd, rng, param));
    return module->AddEmbeddedComputation(fn_builder.Build());
  }();

  // The entry computation maps rng_function over the same constant twice and
  // adds the two results.
  HloComputation* const computation = [&]() {
    auto entry_builder = HloComputation::Builder(TestName());
    HloInstruction* const operand = entry_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f})));
    HloInstruction* const map1 = entry_builder.AddInstruction(
        HloInstruction::CreateMap(operand->shape(), {operand}, rng_function));
    HloInstruction* const map2 = entry_builder.AddInstruction(
        HloInstruction::CreateMap(operand->shape(), {operand}, rng_function));
    entry_builder.AddInstruction(HloInstruction::CreateBinary(
        operand->shape(), HloOpcode::kAdd, map1, map2));
    return module->AddEntryComputation(entry_builder.Build());
  }();

  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(computation->root_instruction(), op::Add(op::Map(), op::Map()));

  VLOG(3) << "before: " << module->ToString();

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  VLOG(3) << "after: " << module->ToString();

  // Nothing was merged: both maps, each fed by a constant, are still there.
  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Map(op::Constant()), op::Map(op::Constant())));
}
638 
TEST_F(HloCseTest, CompareComputations) {
  // Two reduces over the same operands should be commoned even though they
  // reference different (but structurally identical) reducer computations.
  const char* const hlo_string = R"(
    HloModule m

    add_computation {
      add_lhs = f32[] parameter(0)
      add_rhs = f32[] parameter(1)
      ROOT add_root = add(add_lhs, add_rhs)
    }

    add_computation2 {
      add_lhs2 = f32[] parameter(0)
      add_rhs2 = f32[] parameter(1)
      ROOT add_root2 = add(add_lhs2, add_rhs2)
    }

    ENTRY entry {
      p = f32[10]{0} parameter(0)
      c = f32[] constant(0)
      r1 = reduce(p, c), dimensions={0}, to_apply=add_computation
      r2 = reduce(p, c), dimensions={0}, to_apply=add_computation2
      ROOT f2 = tuple(r1, r2)
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // Both tuple elements now point at the same reduce.
  const HloInstruction* const root =
      module->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(0), root->operand(1));
}
669 
TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
  // Constants and iotas with the same value but in different domains
  // (disjoint in this case: none of these instructions uses another) are not
  // collapsed.
  auto builder = HloComputation::Builder(TestName());
  // Two equal u32 scalar constants, followed by two equal s32[42] iotas.
  for (int copy = 0; copy < 2; ++copy) {
    builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(42)));
  }
  for (int copy = 0; copy < 2; ++copy) {
    builder.AddInstruction(
        HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}), 0));
  }

  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  EXPECT_EQ(4, computation->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  // All four instructions survive the pass.
  EXPECT_EQ(4, computation->instruction_count());
}
693 
TEST_F(HloCseTest, Domain) {
  // Identical negates that sit behind domain instructions with identical
  // sharding metadata are merged. The negate behind a domain with a different
  // exit device is kept separate.
  const char* const hlo_string = R"(
HloModule module
ENTRY %entry {
  %param = f32[] parameter(0), sharding={maximal device=0}
  %domain.0 = f32[] domain(%param),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  %domain.1 = f32[] domain(%param),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  %domain.2 = f32[] domain(%param),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}}
  %negate.0 = f32[] negate(%domain.0)
  %negate.1 = f32[] negate(%domain.1)
  %negate.2 = f32[] negate(%domain.2)
  %domain.3 = f32[] domain(%negate.0),
    domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
  %domain.4 = f32[] domain(%negate.1),
    domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
  %domain.5 = f32[] domain(%negate.2),
    domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
  %add = f32[] add(%domain.3, %domain.4)
  ROOT %sub = f32[] subtract(%add, %domain.5)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // Expected result: sub(add(x, x), y) with x != y.
  const HloInstruction* const sub =
      module->entry_computation()->root_instruction();
  const HloInstruction* const add = sub->operand(0);
  const HloInstruction* const device2_path = sub->operand(1);
  EXPECT_EQ(add->operand(0), add->operand(1));
  EXPECT_NE(add->operand(0), device2_path);
  EXPECT_NE(add->operand(1), device2_path);
}
727 
TEST_F(HloCseTest, Iota) {
  // i1 and i2 are identical and should be merged. i3 differs from i1 in shape
  // and i4 in iota_dimension, so neither may be merged with i1.
  const char* const hlo_string = R"(
    HloModule m

    ENTRY entry {
      i1 = s64[16,16] iota(), iota_dimension=0
      i2 = s64[16,16] iota(), iota_dimension=0
      i3 = s64[17,16] iota(), iota_dimension=0
      i4 = s64[16,16] iota(), iota_dimension=1
      ROOT root = (s64[16,16], s64[16,16], s64[17,16], s64[16,16]) tuple(i1, i2, i3, i4)
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string));
  HloCSE pass(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool modified, RunHloPass(&pass, module.get()));
  EXPECT_TRUE(modified);

  const HloInstruction* tuple = module->entry_computation()->root_instruction();
  const HloInstruction* first = tuple->operand(0);
  EXPECT_EQ(first, tuple->operand(1));
  EXPECT_NE(first, tuple->operand(2));
  EXPECT_NE(first, tuple->operand(3));
}
749 
TEST_F(HloCseTest, OptimizationBarrier) {
  // add.0 and add.1 compute the same sum, but add.1 reads its operands through
  // an opt-barrier (via get-tuple-element). The test expects CSE to leave the
  // module untouched, i.e. not to see through the barrier.
  const char* const hlo_string = R"(
    HloModule m

    ENTRY entry {
      %param.0 = f32[] parameter(0)
      %param.1 = f32[] parameter(1)
      %add.0 = f32[] add(%param.0, %param.1)
      %cse_tmp.0 = (f32[], f32[], f32[]) tuple(%param.0, %param.1, %add.0)
      %cse_tmp.1 = (f32[], f32[], f32[]) opt-barrier(%cse_tmp.0)

      %param.0.1 = f32[] get-tuple-element(%cse_tmp.1), index=0
      %param.1.1 = f32[] get-tuple-element(%cse_tmp.1), index=1
      %add.0.1 = f32[] get-tuple-element(%cse_tmp.1), index=2

      %add.1 = f32[] add(%param.0.1, %param.1.1)
      ROOT %add.2 = f32[] add(%add.1, %add.0.1)
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string));
  HloCSE pass(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool modified, RunHloPass(&pass, module.get()));
  EXPECT_FALSE(modified);
}
774 
775 class HloCseCustomCallTest
776     : public HloCseTest,
777       public ::testing::WithParamInterface<std::tuple<
778           std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>> {};
779 
TEST_P(HloCseCustomCallTest, DoIt) {
  // op0 and op1 in the module are both built from the first parameter string,
  // op2 from the second. CSE must always merge op0 with op1, and must merge
  // op2 into them exactly when should_cse is set.
  const std::string& op1 = std::get<0>(GetParam());
  const std::string& op2 = std::get<1>(GetParam());
  const bool should_cse = std::get<2>(GetParam());

  const char* const hlo_string_tmpl = R"(
    HloModule m
    ENTRY entry {
      p0 = f32[1,1,1] parameter(0)

      op0 = $0
      op1 = $0
      op2 = $1
      ROOT root = tuple(op0, op1, op2)
    }
  )";
  std::string hlo_string = absl::Substitute(hlo_string_tmpl, op1, op2);
  SCOPED_TRACE(absl::StrCat("Module before CSE:\n", hlo_string));

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));

  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
  EXPECT_TRUE(changed);  // op0 and op1 are identical, so CSE always fires.
  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(0), root->operand(1))
      << "Identical ops should be CSE'ed";
  if (should_cse) {
    EXPECT_EQ(root->operand(0), root->operand(2)) << "Ops should be CSE'ed";
  } else {
    EXPECT_NE(root->operand(0), root->operand(2)) << "Ops should not be CSE'ed";
  }
}
814 
// Returns the (op1, op2, should_cse) cases for HloCseCustomCallTest. The first
// four cases are written out in full; the rest come from
// make_distinct_pair(), which attaches two differing attribute strings to the
// same custom-call prefix and expects CSE to keep the results apart.
static std::vector<
    std::tuple<std::string /*op1*/, std::string /*op2*/, bool /*should_cse*/>>
CustomCallTests() {
  const std::string prefix =
      "f32[] custom-call(p0), custom_call_target=\"foo\", ";
  auto make_distinct_pair = [&prefix](const char* attrs1, const char* attrs2) {
    std::string first = prefix + attrs1;
    std::string second = prefix + attrs2;
    return std::make_tuple(std::move(first), std::move(second), false);
  };

  return {
      {
          // metadata shouldn't prevent CSE
          "f32[] custom-call(p0), custom_call_target=\"foo\"",
          "f32[] custom-call(p0), custom_call_target=\"foo\", "
          "metadata={op_name=\"bar\"}",
          true,
      },
      {
          "f32[] custom-call(p0), custom_call_target=\"foo\"",
          "f32[] custom-call(p0, p0), custom_call_target=\"foo\"",
          false,
      },
      {
          "f32[1] custom-call(p0), custom_call_target=\"foo\"",
          "f32[2] custom-call(p0), custom_call_target=\"foo\"",
          false,
      },
      {
          "f32[] custom-call(p0), custom_call_target=\"foo\"",
          "f32[] custom-call(p0), custom_call_target=\"bar\"",
          false,
      },

      make_distinct_pair("window={size=1}", "window={size=2}"),
      make_distinct_pair("dim_labels=b0f_0oi->b0f", "dim_labels=b0f_0oi->bf0"),
      make_distinct_pair("backend_config=\"foo\"", "backend_config=\"bar\""),
      make_distinct_pair("literal=s32[] 0", "literal=s32[] 1"),
      make_distinct_pair("literal=s32[] 0", "literal=f32[] 0"),
      make_distinct_pair("operand_precision={high,default}",
                         "operand_precision={high, high}"),
      make_distinct_pair("api_version=API_VERSION_STATUS_RETURNING",
                         "api_version=API_VERSION_ORIGINAL"),
      make_distinct_pair("feature_group_count=0", "feature_group_count=1"),
  };
}
860 
// Runs HloCseCustomCallTest.DoIt once for every (op1, op2, should_cse) case
// returned by CustomCallTests().
INSTANTIATE_TEST_SUITE_P(HloCseCustomCallTestSuite, HloCseCustomCallTest,
                         ::testing::ValuesIn(CustomCallTests()));
863 
TEST_F(HloCseTest, CustomCallCalledComputations) {
  // op0 and op1 differ only in their called_computations list ({comp} vs.
  // {comp, comp}); the test expects CSE not to merge them.
  const char* const hlo_string = R"(
    HloModule m

    comp {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT maximum = f32[] maximum(lhs, rhs)
    }

    ENTRY entry {
      p0 = f32[] parameter(0)

      op0 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp}
      op1 = f32[] custom-call(p0), custom_call_target="foo", called_computations={comp, comp}
      ROOT root = tuple(op0, op1)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));

  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
  EXPECT_FALSE(changed);
}
890 
TEST_F(HloCseTest, CustomCallSideEffects) {
  // op0 and op1 are textually identical but both are marked
  // custom_call_has_side_effect=true; the test expects CSE to keep both.
  const char* const hlo_string = R"(
    HloModule m

    ENTRY entry {
      p0 = f32[] parameter(0)

      op0 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true
      op1 = f32[] custom-call(p0), custom_call_target="foo", custom_call_has_side_effect=true
      ROOT root = tuple(op0, op1)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  HloCSE cse(/*is_layout_sensitive=*/false);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, m.get()));

  SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString()));
  EXPECT_FALSE(changed);
}
911 
912 class HloCseCommutativeOpTest
913     : public HloCseTest,
914       public ::testing::WithParamInterface<std::string /*op*/> {};
915 
TEST_P(HloCseCommutativeOpTest, DoIt) {
  // op1 and op2 apply the same opcode to (p0, p1) and (p1, p0). Because the
  // opcode is commutative, CSE is expected to merge them into one instruction.
  const std::string& op = GetParam();
  const char* const kModuleStr = R"(
    HloModule m

    ENTRY test {
      p0 = s32[10] parameter(0)
      p1 = s32[10] parameter(1)
      op1 = s32[10] $0(p0, p1)
      op2 = s32[10] $0(p1, p0)
      ROOT t = tuple(op1, op2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
                                           absl::Substitute(kModuleStr, op)));
  HloCSE cse(/*is_layout_sensitive=*/false);
  // Report a failing pass status as a test failure rather than aborting the
  // test binary via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&cse, module.get()));
  // Install the trace before asserting so the module dump accompanies any
  // failure below, including "CSE made no change".
  SCOPED_TRACE(module->ToString());
  ASSERT_TRUE(changed);

  const HloInstruction* op0 = nullptr;
  const HloInstruction* op1 = nullptr;
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(m::Op(&op0), m::Op(&op1))));
  EXPECT_EQ(op0, op1);
}
941 
942 INSTANTIATE_TEST_SUITE_P(AlgebraicSimplifierCanonicalizeCommutativeTestSuite,
943                          HloCseCommutativeOpTest,
944                          ::testing::Values("add", "multiply", "and", "or",
945                                            "xor", "minimum", "maximum"));
946 
947 }  // namespace
948 }  // namespace xla
949