xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/strings/str_replace.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_parser.h"
33 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
34 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
35 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
36 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/test.h"
40 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/window_util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/platform/statusor.h"
46 
47 namespace xla {
48 namespace {
49 
50 using ::testing::ElementsAre;
51 namespace m = match;
52 
53 class AlgebraicSimplifierTest : public HloTestBase {
54  protected:
55   AlgebraicSimplifierOptions default_options_;
56 };
57 
58 // Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Entry computation: ROOT = param0 + 0.0f.
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* zero_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kAdd, operand, zero_constant));
  auto* entry = module->AddEntryComputation(builder.Build());

  // Before the pass runs, the root is still the add.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAdd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After the pass, the root must be the parameter itself.
  EXPECT_EQ(entry->root_instruction(), operand);
}
78 
TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) {
  // Integer p0*p2 + p1*p2 is expected to become (p0 + p1) * p2.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = s32[8] parameter(0)
      p1 = s32[8] parameter(1)
      p2 = s32[8] parameter(2)
      x = s32[8] multiply(p0, p2)
      y = s32[8] multiply(p1, p2)
      ROOT sum = s32[8] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  auto factored_sum = m::AddAnyOrder(m::Parameter(0), m::Parameter(1));
  EXPECT_THAT(root,
              GmockMatch(m::MultiplyAnyOrder(factored_sum, m::Parameter(2))));
}
99 
100 // A*C + B*C => (A+B)*C if C is a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
  // The shared scalar factor is 0.125, a power of two.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      c = f32[] constant(0.125)
      x = f32[] multiply(p0, c)
      y = f32[] multiply(p1, c)
      ROOT sum = f32[] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: (p0 + p1) * 0.125, operands in either order.
  HloInstruction* root = module->entry_computation()->root_instruction();
  auto factored_sum = m::AddAnyOrder(m::Parameter(0), m::Parameter(1));
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(factored_sum,
                                                   m::ConstantScalar(0.125))));
}
120 
121 // (Abs(A)) * (Abs(A)) => (A*A)
TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
  // abs(p) * abs(p) is expected to lose the abs and become p * p.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = f32[] parameter(0)
      a = f32[] abs(p)
      ROOT z = f32[] multiply(a, a)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
136 
137 // (A*C1) * (B*C2) => (A*B)*(C1*C2)
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
  // (p0 * 2) * (p1 * 4): parameters and constants are expected to be grouped
  // into separate multiplies.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      c = f32[] constant(2)
      d = f32[] constant(4)
      x = f32[] multiply(p0, c)
      y = f32[] multiply(p1, d)
      ROOT z = f32[] multiply(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto param_product = m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1));
  auto constant_product =
      m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::MultiplyAnyOrder(param_product, constant_product)));
}
159 
160 // (a*C1)*C2 => a*(C1*C2)
TEST_F(AlgebraicSimplifierTest, MultiplyChain2) {
  // (p0 * 2) * 4 is expected to become p0 * (2 * 4).
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      a = f32[] constant(2)
      b = f32[] constant(4)
      c = f32[] multiply(p0, a)
      ROOT y = f32[] multiply(c, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto constant_product =
      m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::MultiplyAnyOrder(m::Parameter(0), constant_product)));
}
179 
180 // MUL(MUL(X, BROADCAST(constant)), BROADCAST(Y)) ==>
181 // MUL(X, BROADCAST(MUL(Y, BROADCAST(constant))))
TEST_F(AlgebraicSimplifierTest, MultiplyBroadcastReassoc) {
  // Input: broadcast(p1) * (p0 * broadcast(2)).
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[2,2] parameter(0)
      p1 = f32[] parameter(1)
      b = f32[] constant(2)
      c = f32[2, 2] broadcast(b), dimensions={}
      x = f32[2,2] multiply(p0, c)
      y = f32[2,2] broadcast(p1), dimensions={}
      ROOT z = f32[2,2] multiply(y, x)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: p0 * broadcast(p1 * constant), operands in either order.
  auto scalar_product = m::MultiplyAnyOrder(m::Parameter(1), m::Constant());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::MultiplyAnyOrder(m::Parameter(0),
                                             m::Broadcast(scalar_product))));
}
202 
203 // A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) {
  // Both products share the same broadcast of the scalar constant 0.125.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[4] parameter(1)
      c = f32[] constant(0.125)
      b = f32[4] broadcast(c), dimensions={}
      x = f32[4] multiply(p0, b)
      y = f32[4] multiply(p1, b)
      ROOT sum = f32[4] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: (p0 + p1) * broadcast(0.125), operands in either order.
  auto factored_sum = m::AddAnyOrder(m::Parameter(0), m::Parameter(1));
  auto shared_factor = m::Broadcast(m::ConstantScalar(0.125));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::MultiplyAnyOrder(factored_sum, shared_factor)));
}
224 
225 // A*C + B*C => (A+B)*C simplification should not happen if C is not a
226 // floating-point power of 2.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) {
  // The shared factor is 0.3, which is not a power of two.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      c = f32[] constant(0.3)
      x = f32[] multiply(p0, c)
      y = f32[] multiply(p1, c)
      ROOT sum = f32[] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The pass must report that it changed nothing.
  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
242 
243 // A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are
244 // complex numbers.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) {
  // Same A*C + B*C shape as the factoring tests, but with c64 operands.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = c64[8] parameter(0)
      p1 = c64[8] parameter(1)
      p2 = c64[8] parameter(2)
      x = c64[8] multiply(p0, p2)
      y = c64[8] multiply(p1, p2)
      ROOT sum = c64[8] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The pass must report that it changed nothing.
  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
260 
// A*C + B*C => (A+B)*C simplification is OK if A, B, and C are bfloat16.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) {
  // bf16 operands; the shared factor is a broadcast of the constant 0.125.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = bf16[4] parameter(0)
      p1 = bf16[4] parameter(1)
      c = bf16[] constant(0.125)
      b = bf16[4] broadcast(c), dimensions={}
      x = bf16[4] multiply(p0, b)
      y = bf16[4] multiply(p1, b)
      ROOT sum = bf16[4] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: (p0 + p1) * broadcast(0.125), operands in either order.
  auto factored_sum = m::AddAnyOrder(m::Parameter(0), m::Parameter(1));
  auto shared_factor = m::Broadcast(m::ConstantScalar(0.125));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::MultiplyAnyOrder(factored_sum, shared_factor)));
}
282 
TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) {
  // u32 divide by a broadcast 8 is expected to become a logical right shift
  // by a broadcast 3.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT d = u32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto shift_amount = m::Broadcast(m::ConstantScalar(3));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ShiftRightLogical(m::Parameter(0), shift_amount)));
}
299 
TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) {
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT d = s32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected rewrite, built up piece by piece:
  //   negative = p < broadcast(0)
  //   abs      = select(negative, -p, p)
  //   shifted  = abs >> broadcast(3)   (logical shift)
  //   root     = select(negative, -shifted, shifted)
  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto shifted =
      m::ShiftRightLogical(magnitude, m::Broadcast(m::ConstantScalar(3)));

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root, GmockMatch(m::Select(is_negative, m::Negate(shifted), shifted)));
}
322 
TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) {
  // u32 remainder by a broadcast 8 is expected to become a bitwise AND with
  // a broadcast 7.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT r = u32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto low_bits_mask = m::Broadcast(m::ConstantScalar(7));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::AndAnyOrder(m::Parameter(0), low_bits_mask)));
}
339 
TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) {
  const char* const hlo_text = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT r = s32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected rewrite, built up piece by piece:
  //   negative = p < broadcast(0)
  //   abs      = select(negative, -p, p)
  //   masked   = abs & broadcast(7)    (operands in either order)
  //   root     = select(negative, -masked, masked)
  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto masked = m::AndAnyOrder(magnitude, m::Broadcast(m::ConstantScalar(7)));

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Select(is_negative, m::Negate(masked), masked)));
}
362 
363 // Test that A * 0 is simplified to 0
TEST_F(AlgebraicSimplifierTest, MulZero) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});

  // Entry computation: ROOT = param0 * 0 (s32 scalars).
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param0"));
  HloInstruction* zero_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_s32, HloOpcode::kMultiply, operand, zero_constant));
  auto* entry = module->AddEntryComputation(builder.Build());

  // Before the pass runs, the root is still the multiply.
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kMultiply);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After the pass, the root must be the zero constant itself.
  EXPECT_EQ(entry->root_instruction(), zero_constant);
}
382 
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeConstants) {
  // (p0 * 2.0) * 3.0 is expected to become p0 * (2.0 * 3.0).
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      c0 = f32[] constant(2.0)
      c1 = f32[] constant(3.0)
      multiply0 = f32[] multiply(p0, c0)
      ROOT multiply1 = f32[] multiply(multiply0, c1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The matchers here are order-sensitive (m::Multiply, not AnyOrder).
  auto merged_constants =
      m::Multiply(m::ConstantScalar(2.0), m::ConstantScalar(3.0));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), merged_constants)));
}
401 
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
  // (p0 * broadcast(2.0)) * broadcast(3.0) is expected to become
  // p0 * broadcast(2.0 * 3.0).
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(2.0)
      c1 = f32[] constant(3.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      multiply0 = f32[4] multiply(p0, b0)
      ROOT multiply1 = f32[4] multiply(multiply0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The matchers here are order-sensitive (m::Multiply, not AnyOrder).
  auto merged_constants =
      m::Multiply(m::ConstantScalar(2.0), m::ConstantScalar(3.0));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Multiply(m::Parameter(0), m::Broadcast(merged_constants))));
}
423 
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) {
  // Multiply of two broadcasts of scalar parameters.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      b0 = f32[4] broadcast(p0), dimensions={}
      b1 = f32[4] broadcast(p1), dimensions={}
      ROOT multiply = f32[4] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: a broadcast whose operand multiplies broadcast(p1) by
  // broadcast(p0), in that order.
  auto inner_product = m::Multiply(m::Broadcast(m::Parameter(1)),
                                   m::Broadcast(m::Parameter(0)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(inner_product)));
}
442 
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) {
  // One operand broadcasts a scalar constant, the other a rank-1 parameter.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(2.0)
      b0 = f32[4,2] broadcast(c0), dimensions={}
      b1 = f32[4,2] broadcast(p0), dimensions={0}
      ROOT multiply = f32[4,2] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: broadcast(p0 * broadcast(2.0)).
  auto inner_product =
      m::Multiply(m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(inner_product)));
}
460 
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) {
  // Both operands broadcast a rank-1 parameter along the same dimension.
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[4] parameter(1)
      b0 = f32[4,2] broadcast(p0), dimensions={0}
      b1 = f32[4,2] broadcast(p1), dimensions={0}
      ROOT multiply = f32[4,2] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: broadcast(p1 * p0).
  auto inner_product = m::Multiply(m::Parameter(1), m::Parameter(0));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(inner_product)));
}
478 
TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) {
  // The two broadcasts expand along different dimensions ({0} and {1}).
  const char* const hlo_text = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[8] parameter(1)
      b0 = f32[4,8] broadcast(p0), dimensions={0}
      b1 = f32[4,8] broadcast(p1), dimensions={1}
      ROOT multiply = f32[4,8] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The pass must report no change, and the root must keep its shape.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  auto lhs_broadcast = m::Broadcast(m::Parameter(1));
  auto rhs_broadcast = m::Broadcast(m::Parameter(0));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply(lhs_broadcast, rhs_broadcast)));
}
496 
TEST_F(AlgebraicSimplifierTest,
       MultiplyReassociateMultiplyOfConstantAndBroadcast) {
  // (c0 * broadcast(3.0)) * broadcast(4.0), where c0 is a rank-1 constant.
  const char* const hlo_text = R"(
    HloModule m
    test {
      c0 = f32[4] constant({2.0, 3.0, 4.0, 5.0})
      c1 = f32[] constant(3.0)
      c2 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c1), dimensions={}
      b1 = f32[4] broadcast(c2), dimensions={}
      multiply0 = f32[4] multiply(c0, b0)
      ROOT multiply1 = f32[4] multiply(multiply0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: constant * broadcast(3.0 * 4.0), order-sensitive.
  auto merged_scalars =
      m::Multiply(m::ConstantScalar(3.0), m::ConstantScalar(4.0));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Multiply(m::Constant(), m::Broadcast(merged_scalars))));
}
519 
520 // Test that select(true, a, b) is simplified to a
TEST_F(AlgebraicSimplifierTest, SelectTrue) {
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});

  // Entry computation: ROOT = select(true, param0, param1).
  HloComputation::Builder builder(TestName());
  HloInstruction* on_true = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param0"));
  HloInstruction* on_false = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_s32, "param1"));
  HloInstruction* true_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      scalar_s32, HloOpcode::kSelect, true_constant, on_true, on_false));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSelect);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The select must have been replaced by its on-true operand.
  EXPECT_EQ(entry->root_instruction(), on_true);
}
541 
542 // Test that select(false, a, b) is simplified to b
TEST_F(AlgebraicSimplifierTest, SelectFalse) {
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});

  // Entry computation: ROOT = select(false, param0, param1).
  HloComputation::Builder builder(TestName());
  HloInstruction* on_true = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param0"));
  HloInstruction* on_false = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_s32, "param1"));
  HloInstruction* false_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      scalar_s32, HloOpcode::kSelect, false_constant, on_true, on_false));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSelect);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The select must have been replaced by its on-false operand.
  EXPECT_EQ(entry->root_instruction(), on_false);
}
563 
564 // Test that select(a, b, b) is simplified to b
TEST_F(AlgebraicSimplifierTest, SelectIdentical) {
  // The predicate operand of a select must be PRED-typed. Declaring it as S32
  // would build an ill-formed select that only escapes verification because
  // the pass removes it before the verified module is checked.
  Shape pred_ty = ShapeUtil::MakeShape(PRED, {});
  Shape r0s32 = ShapeUtil::MakeShape(S32, {});

  // Entry computation: ROOT = select(param0, param1, param1).
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_ty, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0s32, "param1"));
  builder.AddInstruction(HloInstruction::CreateTernary(
      r0s32, HloOpcode::kSelect, param0, param1, param1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kSelect);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Both branches are the same instruction, so the root must become it.
  EXPECT_EQ(computation->root_instruction(), param1);
}
583 
584 // Test that select(not(pred), a, b) is simplified to select(pred, b, a)
TEST_F(AlgebraicSimplifierTest, SelectWithNotPred) {
  const Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});

  // Entry computation: ROOT = select(not(param0), param1, param2).
  HloComputation::Builder builder(TestName());
  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_pred, "param0"));
  HloInstruction* on_true = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_s32, "param1"));
  HloInstruction* on_false = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_s32, "param2"));
  HloInstruction* negated_predicate = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_pred, HloOpcode::kNot, predicate));
  builder.AddInstruction(HloInstruction::CreateTernary(
      scalar_s32, HloOpcode::kSelect, negated_predicate, on_true, on_false));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSelect);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The new root must use the un-negated predicate with the branches swapped.
  const auto& root_operands = entry->root_instruction()->operands();
  EXPECT_EQ(root_operands[0], predicate);
  EXPECT_EQ(root_operands[1], on_false);
  EXPECT_EQ(root_operands[2], on_true);
}
611 
612 // Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());

  // Init value shared by both reduces.
  HloInstruction* init_zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Scalar f32 add computation used as the reduction function.
  HloComputation* reducer = nullptr;
  {
    HloComputation::Builder reducer_builder(TestName() + ".add");
    const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = reducer_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_f32, "p0"));
    HloInstruction* rhs = reducer_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_f32, "p1"));
    reducer_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
    reducer = module->AddEmbeddedComputation(reducer_builder.Build());
  }

  // First reduce: f32[4,5,6,7] over dimension 0 -> f32[5,6,7].
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  const std::vector<int64_t> first_dims({0});
  const Shape first_shape = ShapeUtil::MakeShape(F32, {5, 6, 7});
  HloInstruction* first_reduce =
      entry_builder.AddInstruction(HloInstruction::CreateReduce(
          first_shape, input, init_zero, first_dims, reducer));

  // Second reduce: f32[5,6,7] over dimensions {1, 2} -> f32[5].
  const std::vector<int64_t> second_dims({1, 2});
  const Shape second_shape = ShapeUtil::MakeShape(F32, {5});
  entry_builder.AddInstruction(HloInstruction::CreateReduce(
      second_shape, first_reduce, init_zero, second_dims, reducer));
  module->AddEntryComputation(entry_builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: one reduce of the parameter over dimensions {0, 2, 3},
  // i.e. the second reduce's dimensions mapped back onto the input.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(init_zero))));
  EXPECT_EQ(root->dimensions(), std::vector<int64_t>({0, 2, 3}));
}
649 
650 // Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Entry computation: ROOT = 42.0f + param0 (constant on the left).
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* literal = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kAdd, literal, operand));
  auto* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAdd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After the pass the constant must be the right-hand operand.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Constant())));
}
670 
671 // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Entry computation: ROOT = (param0 + 42.0f) + 3.14159f.
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  HloInstruction* first_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
  HloInstruction* second_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
  HloInstruction* inner_add =
      builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_f32, HloOpcode::kAdd, operand, first_constant));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kAdd, inner_add, second_constant));
  auto* entry = module->AddEntryComputation(builder.Build());

  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAdd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected root: param0 + (constant1 + constant2), matched by identity of
  // the original instructions.
  auto merged_constants =
      m::Add(m::Op().Is(first_constant), m::Op().Is(second_constant));
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Op().Is(operand), merged_constants)));
}
698 
// Verifies that (A + Broadcast(C1)) + Broadcast(C2) is rewritten to
// A + Broadcast(C1 + C2) for scalar constants C1 and C2.
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) {
  const char* const kHloText = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(1.0)
      c1 = f32[] constant(2.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      add0 = f32[4] add(p0, b0)
      ROOT add1 = f32[4] add(add0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The two scalar constants are summed before a single broadcast.
  const auto merged_constants =
      m::Broadcast(m::Add(m::ConstantScalar(1.0), m::ConstantScalar(2.0)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), merged_constants)));
}
719 
// Verifies that Broadcast(0) + A, with the zero given as a rank-0 constant,
// is simplified to A.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
  auto module = CreateNewVerifiedModule();
  const Shape matrix_f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  auto* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_f32, "param0"));
  auto* scalar_zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // NOTE(review): the operand is rank 0 yet broadcast dimensions {0, 1} are
  // passed; other tests in this file use {} for scalar broadcasts - confirm
  // this is intentional.
  auto* zero_matrix = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_f32, scalar_zero, {0, 1}));
  builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_f32, HloOpcode::kAdd, zero_matrix, operand));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAdd);

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // Only the parameter should remain as the root.
  EXPECT_EQ(entry->root_instruction(), operand);
}
741 
// Verifies that a map whose applied computation is a single scalar add is
// replaced by that add applied directly to the map's operands.
TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Embedded computation applied by the map: p0 + p1 on scalars.
  HloComputation* scalar_add = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
    auto* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_f32, "p0"));
    auto* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_f32, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
    scalar_add = module->AddEmbeddedComputation(add_builder.Build());
  }

  // Entry computation: map(param0, broadcast(0)) using the add above.
  const Shape matrix_f32 = ShapeUtil::MakeShape(F32, {32, 1});
  auto* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_f32, "param0"));
  auto* scalar_zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto* zero_matrix = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_f32, scalar_zero, {}));
  builder.AddInstruction(HloInstruction::CreateMap(
      matrix_f32, {operand, zero_matrix}, scalar_add));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kMap);

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Add(m::Parameter(0),
                        m::Broadcast(m::Op().Is(scalar_zero)))));
}
778 
// Verifies that a map whose applied computation is a fusion (rather than a
// single elementwise op) is left untouched: the pass must report no change.
TEST_F(AlgebraicSimplifierTest, KeepNontrivialMap) {
  const char* const kHloText = R"(
    HloModule m
    fusion {
      x = f32[] parameter(0)
      c = f32[] constant(42)
      m = f32[] multiply(x, x)
      ROOT a = f32[] add(m, c)
    }

    map {
      x = f32[] parameter(0)
      ROOT f = f32[] fusion(x), kind=kLoop, calls=fusion
    }

    ENTRY test {
      p = f32[2,2] parameter(0)
      ROOT map = f32[2,2] map(p), dimensions={0,1}, to_apply=map
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
802 
// Verifies that Broadcast(zeros) + A, with the zeros given as a rank-1
// constant, is simplified to A.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
  auto module = CreateNewVerifiedModule();
  const Shape matrix_f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  auto* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_f32, "param0"));
  auto* zero_vector = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
  // NOTE(review): the operand has 3 elements but is mapped onto output
  // dimension 1, whose size is 2 - confirm this is intentional.
  auto* zero_matrix = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_f32, zero_vector, {1}));
  builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_f32, HloOpcode::kAdd, zero_matrix, operand));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAdd);

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  // Only the parameter should remain as the root.
  EXPECT_EQ(entry->root_instruction(), operand);
}
824 
// Verifies that a rank-1 constant whose elements are all equal is rewritten
// as a broadcast of a constant carrying that single value.
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  HloInstruction* simplified = entry->root_instruction();
  EXPECT_THAT(simplified, GmockMatch(m::Broadcast(m::Constant())));
  // The broadcast operand must hold the repeated value.
  EXPECT_EQ(3.14f, simplified->operand(0)->literal().GetFirstElement<float>());
}
840 
// Verifies that a rank-1 constant with differing elements is left alone: the
// pass reports no change and the root is still a constant.
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  ASSERT_FALSE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));
}
855 
// Verifies that a rank-1 constant holding 0, 1, 2 is rewritten as an iota.
TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
}
870 
871 // Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest,SubZero)872 TEST_F(AlgebraicSimplifierTest, SubZero) {
873   auto m = CreateNewVerifiedModule();
874   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
875   HloComputation::Builder builder(TestName());
876   HloInstruction* param0 = builder.AddInstruction(
877       HloInstruction::CreateParameter(0, r0f32, "param0"));
878   HloInstruction* zero = builder.AddInstruction(
879       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
880   builder.AddInstruction(
881       HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
882 
883   auto computation = m->AddEntryComputation(builder.Build());
884   HloInstruction* root = computation->root_instruction();
885   EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
886   AlgebraicSimplifier simplifier(default_options_);
887   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
888   root = computation->root_instruction();
889   EXPECT_EQ(root, param0);
890 }
891 
892 // Test that A - Const is canonicalized to A + (-Const).
TEST_F(AlgebraicSimplifierTest,SubConstCanonicalization)893 TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
894   auto m = CreateNewVerifiedModule();
895   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
896   HloComputation::Builder builder(TestName());
897   HloInstruction* param0 = builder.AddInstruction(
898       HloInstruction::CreateParameter(0, r0f32, "param0"));
899   HloInstruction* constant = builder.AddInstruction(
900       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
901   builder.AddInstruction(HloInstruction::CreateBinary(
902       r0f32, HloOpcode::kSubtract, param0, constant));
903 
904   auto computation = m->AddEntryComputation(builder.Build());
905   HloInstruction* root = computation->root_instruction();
906   EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
907   AlgebraicSimplifier simplifier(default_options_);
908   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
909   root = computation->root_instruction();
910   EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0),
911                                       m::Negate(m::Op().Is(constant)))));
912 }
913 
914 // Test that A - Broadcast(Const) is canonicalized to A + Broadcast(-Const).
TEST_F(AlgebraicSimplifierTest,SubBroadcastConstCanonicalization)915 TEST_F(AlgebraicSimplifierTest, SubBroadcastConstCanonicalization) {
916   const char* kModuleStr = R"(
917     HloModule m
918     test {
919       p0 = f32[4] parameter(0)
920       c = f32[] constant(0.125)
921       b = f32[4] broadcast(c), dimensions={}
922       ROOT sub = f32[4] subtract(p0, b)
923     }
924   )";
925   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
926   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
927   EXPECT_THAT(
928       m->entry_computation()->root_instruction(),
929       GmockMatch(m::Add(m::Parameter(0),
930                         m::Broadcast(m::Negate(m::ConstantScalar(0.125))))));
931 }
932 
933 // Test that A - A is simplified to 0.
TEST_F(AlgebraicSimplifierTest,SubSame)934 TEST_F(AlgebraicSimplifierTest, SubSame) {
935   const char* kModuleStr = R"(
936     HloModule m
937     test {
938       p0 = s32[2] parameter(0)
939       ROOT sub = s32[2] subtract(p0, p0)
940     }
941   )";
942   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
943   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
944   EXPECT_THAT(m->entry_computation()->root_instruction(),
945               GmockMatch(m::Broadcast(m::ConstantScalar(0))));
946 }
947 
948 // Test that Broadcast(x) where x has degenerate dimensions first removes the
949 // degenerate dimensions.
TEST_F(AlgebraicSimplifierTest,DegenerateDimsInOperandRemovedFromBroadcast)950 TEST_F(AlgebraicSimplifierTest, DegenerateDimsInOperandRemovedFromBroadcast) {
951   const char* kModuleStr = R"(
952     HloModule m
953     test {
954       c = f32[1,4] parameter(0)
955       ROOT b = f32[5,1,4,3] broadcast(c), dimensions={1,2}
956     }
957   )";
958   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
959   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
960   EXPECT_THAT(m->entry_computation()->root_instruction(),
961               GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
962 }
963 
964 // Test to catch a crash where we were overshooting the reshaped_dimensions
965 // vector.
TEST_F(AlgebraicSimplifierTest,ArrayOvershootTest)966 TEST_F(AlgebraicSimplifierTest, ArrayOvershootTest) {
967   const char* kModuleStr = R"(
968     HloModule m
969     test {
970       param0 = f32[18,18,2,1,1,128]{1,0,5,2,4,3} parameter(0)
971       cpy1 = f32[18,18,2,1,1,128]{5,2,1,0,4,3} copy(f32[18,18,2,1,1,128]{1,0,5,2,4,3} param0)
972       bitcast = f32[648,128,1,1]{3,2,1,0} bitcast(cpy1)
973       ROOT cpy2 = f32[648,128,1,1]{3,2,0,1} copy(f32[648,128,1,1]{3,2,1,0} bitcast)
974     }
975   )";
976 
977   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
978   AlgebraicSimplifierOptions options;
979   options.set_is_layout_sensitive(true);
980   AlgebraicSimplifier simplifier(options);
981   // Assert false because algebraic simplifier - at the time of adding this
982   // test - does not change anything. Motivation of the test to make sure it
983   // does not crash the compiler.
984   ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
985 }
986 
987 // Test that (A/B)/C is simplified to A/(B*C).
TEST_F(AlgebraicSimplifierTest,LhsDivOfDiv)988 TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
989   auto m = CreateNewVerifiedModule();
990   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
991   HloComputation::Builder builder(TestName());
992   HloInstruction* param0 = builder.AddInstruction(
993       HloInstruction::CreateParameter(0, r0f32, "param0"));
994   HloInstruction* param1 = builder.AddInstruction(
995       HloInstruction::CreateParameter(1, r0f32, "param1"));
996   HloInstruction* param2 = builder.AddInstruction(
997       HloInstruction::CreateParameter(2, r0f32, "param2"));
998   HloInstruction* div = builder.AddInstruction(
999       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1));
1000   builder.AddInstruction(
1001       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2));
1002 
1003   auto computation = m->AddEntryComputation(builder.Build());
1004 
1005   EXPECT_THAT(computation->root_instruction(),
1006               GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
1007                                    m::Parameter(2))));
1008 
1009   AlgebraicSimplifier simplifier(default_options_);
1010   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1011 
1012   EXPECT_THAT(
1013       computation->root_instruction(),
1014       GmockMatch(m::Divide(m::Parameter(0),
1015                            m::Multiply(m::Parameter(1), m::Parameter(2)))));
1016 }
1017 
1018 // Test that A/(B/C) is simplified to (A*C)/B.
TEST_F(AlgebraicSimplifierTest,RhsDivOfDiv)1019 TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
1020   auto m = CreateNewVerifiedModule();
1021   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1022   HloComputation::Builder builder(TestName());
1023   HloInstruction* param0 = builder.AddInstruction(
1024       HloInstruction::CreateParameter(0, r0f32, "param0"));
1025   HloInstruction* param1 = builder.AddInstruction(
1026       HloInstruction::CreateParameter(1, r0f32, "param1"));
1027   HloInstruction* param2 = builder.AddInstruction(
1028       HloInstruction::CreateParameter(2, r0f32, "param2"));
1029   HloInstruction* div = builder.AddInstruction(
1030       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param1, param2));
1031   builder.AddInstruction(
1032       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div));
1033 
1034   auto computation = m->AddEntryComputation(builder.Build());
1035 
1036   EXPECT_THAT(
1037       computation->root_instruction(),
1038       GmockMatch(m::Divide(m::Parameter(0),
1039                            m::Divide(m::Parameter(1), m::Parameter(2)))));
1040 
1041   AlgebraicSimplifier simplifier(default_options_);
1042   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1043 
1044   EXPECT_THAT(
1045       computation->root_instruction(),
1046       GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)),
1047                            m::Parameter(1))));
1048 }
1049 
1050 // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
TEST_F(AlgebraicSimplifierTest,DivOfDivAndDiv)1051 TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
1052   auto m = CreateNewVerifiedModule();
1053   Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123});
1054   HloComputation::Builder builder(TestName());
1055   HloInstruction* param0 = builder.AddInstruction(
1056       HloInstruction::CreateParameter(0, r2f32, "param0"));
1057   HloInstruction* param1 = builder.AddInstruction(
1058       HloInstruction::CreateParameter(1, r2f32, "param1"));
1059   HloInstruction* param2 = builder.AddInstruction(
1060       HloInstruction::CreateParameter(2, r2f32, "param2"));
1061   HloInstruction* param3 = builder.AddInstruction(
1062       HloInstruction::CreateParameter(3, r2f32, "param3"));
1063   HloInstruction* div0 = builder.AddInstruction(
1064       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1));
1065   HloInstruction* div1 = builder.AddInstruction(
1066       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param2, param3));
1067   builder.AddInstruction(
1068       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1));
1069 
1070   auto computation = m->AddEntryComputation(builder.Build());
1071 
1072   EXPECT_THAT(
1073       computation->root_instruction(),
1074       GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
1075                            m::Divide(m::Parameter(2), m::Parameter(3)))));
1076 
1077   AlgebraicSimplifier simplifier(default_options_);
1078   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1079 
1080   EXPECT_THAT(
1081       computation->root_instruction(),
1082       GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)),
1083                            m::Multiply(m::Parameter(1), m::Parameter(2)))));
1084 }
1085 
1086 // Test that A/exp(B) is simplified to A*exp(-B).
TEST_F(AlgebraicSimplifierTest,DivOfExp)1087 TEST_F(AlgebraicSimplifierTest, DivOfExp) {
1088   auto m = CreateNewVerifiedModule();
1089   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1090   HloComputation::Builder builder(TestName());
1091   HloInstruction* param0 = builder.AddInstruction(
1092       HloInstruction::CreateParameter(0, r0f32, "param0"));
1093   HloInstruction* param1 = builder.AddInstruction(
1094       HloInstruction::CreateParameter(1, r0f32, "param1"));
1095   HloInstruction* exp = builder.AddInstruction(
1096       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
1097   builder.AddInstruction(
1098       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp));
1099 
1100   auto computation = m->AddEntryComputation(builder.Build());
1101 
1102   EXPECT_THAT(computation->root_instruction(),
1103               GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1)))));
1104 
1105   AlgebraicSimplifier simplifier(default_options_);
1106   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1107 
1108   EXPECT_THAT(computation->root_instruction(),
1109               GmockMatch(m::Multiply(m::Parameter(0),
1110                                      m::Exp(m::Negate(m::Parameter(1))))));
1111 }
1112 
1113 // Test that A/pow(B,C) is simplified to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest,DivOfPower)1114 TEST_F(AlgebraicSimplifierTest, DivOfPower) {
1115   auto m = CreateNewVerifiedModule();
1116   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1117   HloComputation::Builder builder(TestName());
1118   HloInstruction* param0 = builder.AddInstruction(
1119       HloInstruction::CreateParameter(0, r0f32, "param0"));
1120   HloInstruction* param1 = builder.AddInstruction(
1121       HloInstruction::CreateParameter(1, r0f32, "param1"));
1122   HloInstruction* param2 = builder.AddInstruction(
1123       HloInstruction::CreateParameter(2, r0f32, "param2"));
1124   HloInstruction* power = builder.AddInstruction(
1125       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param1, param2));
1126   builder.AddInstruction(
1127       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power));
1128 
1129   auto computation = m->AddEntryComputation(builder.Build());
1130 
1131   EXPECT_THAT(
1132       computation->root_instruction(),
1133       GmockMatch(m::Divide(m::Parameter(0),
1134                            m::Power(m::Parameter(1), m::Parameter(2)))));
1135 
1136   AlgebraicSimplifier simplifier(default_options_);
1137   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1138 
1139   EXPECT_THAT(computation->root_instruction(),
1140               GmockMatch(m::Multiply(
1141                   m::Parameter(0),
1142                   m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
1143 }
1144 
1145 // Test that broadcasting is done on the right step when simplifying A/pow(B,C)
1146 // to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest,DivOfBroadcastingPower)1147 TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
1148   auto m = CreateNewVerifiedModule();
1149   Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
1150   HloComputation::Builder builder(TestName());
1151   HloInstruction* param0 = builder.AddInstruction(
1152       HloInstruction::CreateParameter(0, r1f32, "param0"));
1153   HloInstruction* param1 = builder.AddInstruction(
1154       HloInstruction::CreateParameter(1, r1f32, "param1"));
1155   HloInstruction* param2 = builder.AddInstruction(
1156       HloInstruction::CreateParameter(2, r1f32, "param2"));
1157   HloInstruction* power = builder.AddInstruction(
1158       HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2));
1159   builder.AddInstruction(
1160       HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power));
1161 
1162   auto computation = m->AddEntryComputation(builder.Build());
1163 
1164   EXPECT_THAT(
1165       computation->root_instruction(),
1166       GmockMatch(m::Divide(m::Parameter(0),
1167                            m::Power(m::Parameter(1), m::Parameter(2)))));
1168 
1169   AlgebraicSimplifier simplifier(default_options_);
1170   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1171 
1172   ASSERT_THAT(computation->root_instruction(),
1173               GmockMatch(m::Multiply(
1174                   m::Parameter(0),
1175                   m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
1176 }
1177 
1178 // A / Const => A * InvertedConst
TEST_F(AlgebraicSimplifierTest,DivideByConstant)1179 TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
1180   auto m = CreateNewVerifiedModule();
1181   Shape r1f32 = ShapeUtil::MakeShape(F32, {3});
1182   HloComputation::Builder builder(TestName());
1183   HloInstruction* param0 = builder.AddInstruction(
1184       HloInstruction::CreateParameter(0, r1f32, "param0"));
1185   HloInstruction* constant =
1186       builder.AddInstruction(HloInstruction::CreateConstant(
1187           LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
1188   builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
1189                                                       param0, constant));
1190 
1191   auto computation = m->AddEntryComputation(builder.Build());
1192 
1193   AlgebraicSimplifier simplifier(default_options_);
1194   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1195 
1196   EXPECT_THAT(computation->root_instruction(),
1197               GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
1198 }
1199 
1200 // A / Broadcast(Const) => A * Broadcast(InvertedConst)
TEST_F(AlgebraicSimplifierTest,DivideByBroadcastedConstant)1201 TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) {
1202   const char* kModuleStr = R"(
1203     HloModule m
1204     test {
1205       p = f32[4] parameter(0)
1206       c = f32[] constant(256.0)
1207       b = f32[4] broadcast(c), dimensions={}
1208       ROOT d = f32[4] divide(p, b)
1209     }
1210   )";
1211   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
1212   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
1213 
1214   EXPECT_THAT(m->entry_computation()->root_instruction(),
1215               GmockMatch(m::Multiply(
1216                   m::Parameter(0),
1217                   m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f)))));
1218 }
1219 
1220 // pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest,PowerOfPower)1221 TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
1222   auto m = CreateNewVerifiedModule();
1223   Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
1224   HloComputation::Builder builder(TestName());
1225   HloInstruction* base = builder.AddInstruction(
1226       HloInstruction::CreateParameter(0, r1f32, "param0"));
1227   HloInstruction* exp1 = builder.AddInstruction(
1228       HloInstruction::CreateParameter(1, r1f32, "param1"));
1229   HloInstruction* exp2 = builder.AddInstruction(
1230       HloInstruction::CreateParameter(2, r1f32, "param2"));
1231   HloInstruction* inner_power = builder.AddInstruction(
1232       HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
1233   builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
1234                                                       inner_power, exp2));
1235 
1236   AlgebraicSimplifier simplifier(default_options_);
1237   ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1238 }
1239 
1240 // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
1241 // numbers.
TEST_F(AlgebraicSimplifierTest,PowerOfPowerComplex)1242 TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
1243   auto m = CreateNewVerifiedModule();
1244   Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
1245   HloComputation::Builder builder(TestName());
1246   HloInstruction* base = builder.AddInstruction(
1247       HloInstruction::CreateParameter(0, r1c64, "param0"));
1248   HloInstruction* exp1 = builder.AddInstruction(
1249       HloInstruction::CreateParameter(1, r1c64, "param1"));
1250   HloInstruction* exp2 = builder.AddInstruction(
1251       HloInstruction::CreateParameter(2, r1c64, "param2"));
1252   HloInstruction* inner_power = builder.AddInstruction(
1253       HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1));
1254   builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower,
1255                                                       inner_power, exp2));
1256 
1257   m->AddEntryComputation(builder.Build());
1258   AlgebraicSimplifier simplifier(default_options_);
1259   ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1260 }
1261 
1262 // Test that A/1 is simplified to A for a scalar.
TEST_F(AlgebraicSimplifierTest,DivOneScalar)1263 TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
1264   auto m = CreateNewVerifiedModule();
1265   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1266   HloComputation::Builder builder(TestName());
1267   HloInstruction* param0 = builder.AddInstruction(
1268       HloInstruction::CreateParameter(0, r0f32, "param0"));
1269   HloInstruction* one = builder.AddInstruction(
1270       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
1271   HloInstruction* div = builder.AddInstruction(
1272       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
1273 
1274   auto computation = m->AddEntryComputation(builder.Build());
1275   HloInstruction* root = computation->root_instruction();
1276   EXPECT_EQ(root, div);
1277   AlgebraicSimplifier simplifier(default_options_);
1278   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1279   root = computation->root_instruction();
1280   EXPECT_EQ(root, param0);
1281 }
1282 
1283 // Test that A/1 is simplified to A for an array.
TEST_F(AlgebraicSimplifierTest,DivOneArray)1284 TEST_F(AlgebraicSimplifierTest, DivOneArray) {
1285   auto m = CreateNewVerifiedModule();
1286   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
1287   HloComputation::Builder builder(TestName());
1288   HloInstruction* param0 = builder.AddInstruction(
1289       HloInstruction::CreateParameter(0, r2f32, "param0"));
1290   HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
1291       LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
1292   HloInstruction* div = builder.AddInstruction(
1293       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
1294 
1295   auto computation = m->AddEntryComputation(builder.Build());
1296   HloInstruction* root = computation->root_instruction();
1297   EXPECT_EQ(root, div);
1298   AlgebraicSimplifier simplifier(default_options_);
1299   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1300   root = computation->root_instruction();
1301   EXPECT_EQ(root, param0);
1302 }
1303 
1304 // Test that complex(real(c), imag(c)) is simplified to c.
TEST_F(AlgebraicSimplifierTest,ComplexOfRealImagC)1305 TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
1306   auto m = CreateNewVerifiedModule();
1307   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
1308   Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2});
1309   HloComputation::Builder builder(TestName());
1310   HloInstruction* param0 = builder.AddInstruction(
1311       HloInstruction::CreateParameter(0, r2c64, "param0"));
1312   HloInstruction* real = builder.AddInstruction(
1313       HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0));
1314   HloInstruction* imag = builder.AddInstruction(
1315       HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0));
1316   HloInstruction* cplx = builder.AddInstruction(
1317       HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag));
1318 
1319   auto computation = m->AddEntryComputation(builder.Build());
1320   HloInstruction* root = computation->root_instruction();
1321   EXPECT_EQ(root, cplx);
1322   AlgebraicSimplifier simplifier(default_options_);
1323   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1324   root = computation->root_instruction();
1325   EXPECT_EQ(root, param0);
1326 }
1327 
1328 // Test that real(complex(r,i)) is simplified to r.
TEST_F(AlgebraicSimplifierTest,RealOfComplex)1329 TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
1330   auto m = CreateNewVerifiedModule();
1331   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
1332   HloComputation::Builder builder(TestName());
1333   HloInstruction* param0 = builder.AddInstruction(
1334       HloInstruction::CreateParameter(0, r2f32, "param0"));
1335   HloInstruction* param1 = builder.AddInstruction(
1336       HloInstruction::CreateParameter(1, r2f32, "param1"));
1337   HloInstruction* cplx = builder.AddInstruction(
1338       HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
1339                                    HloOpcode::kComplex, param0, param1));
1340   HloInstruction* real = builder.AddInstruction(
1341       HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
1342 
1343   auto computation = m->AddEntryComputation(builder.Build());
1344   HloInstruction* root = computation->root_instruction();
1345   EXPECT_EQ(root, real);
1346   AlgebraicSimplifier simplifier(default_options_);
1347   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1348   root = computation->root_instruction();
1349   EXPECT_EQ(root, param0);
1350 }
1351 
1352 // Test that imag(complex(r,i)) is simplified to i.
TEST_F(AlgebraicSimplifierTest,ImagOfComplex)1353 TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
1354   auto m = CreateNewVerifiedModule();
1355   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
1356   HloComputation::Builder builder(TestName());
1357   HloInstruction* param0 = builder.AddInstruction(
1358       HloInstruction::CreateParameter(0, r2f32, "param0"));
1359   HloInstruction* param1 = builder.AddInstruction(
1360       HloInstruction::CreateParameter(1, r2f32, "param1"));
1361   HloInstruction* cplx = builder.AddInstruction(
1362       HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
1363                                    HloOpcode::kComplex, param0, param1));
1364   HloInstruction* imag = builder.AddInstruction(
1365       HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
1366 
1367   auto computation = m->AddEntryComputation(builder.Build());
1368   HloInstruction* root = computation->root_instruction();
1369   EXPECT_EQ(root, imag);
1370   AlgebraicSimplifier simplifier(default_options_);
1371   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1372   root = computation->root_instruction();
1373   EXPECT_EQ(root, param1);
1374 }
1375 
1376 // Test that get_element(make_tuple({A,B}),1) is simplified to B
TEST_F(AlgebraicSimplifierTest,SelectMakeTuple)1377 TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
1378   auto m = CreateNewVerifiedModule();
1379   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1380   HloComputation::Builder builder(TestName());
1381   HloInstruction* param0 = builder.AddInstruction(
1382       HloInstruction::CreateParameter(0, r0f32, "param0"));
1383   HloInstruction* param1 = builder.AddInstruction(
1384       HloInstruction::CreateParameter(1, r0f32, "param1"));
1385   HloInstruction* param2 = builder.AddInstruction(
1386       HloInstruction::CreateParameter(2, r0f32, "param2"));
1387   HloInstruction* tuple =
1388       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
1389   HloInstruction* get = builder.AddInstruction(
1390       HloInstruction::CreateGetTupleElement(r0f32, tuple, 1));
1391   HloInstruction* add = builder.AddInstruction(
1392       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
1393 
1394   auto computation = m->AddEntryComputation(builder.Build());
1395   HloInstruction* root = computation->root_instruction();
1396   EXPECT_EQ(root, add);
1397   AlgebraicSimplifier simplifier(default_options_);
1398   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1399   root = computation->root_instruction();
1400   EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(1), m::Parameter(2))));
1401 }
1402 
1403 // Test that exp(A)/exp(B) is simplified to exp(A-B)
TEST_F(AlgebraicSimplifierTest,ExpDiv)1404 TEST_F(AlgebraicSimplifierTest, ExpDiv) {
1405   auto m = CreateNewVerifiedModule();
1406   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1407   HloComputation::Builder builder(TestName());
1408   HloInstruction* param0 = builder.AddInstruction(
1409       HloInstruction::CreateParameter(0, r0f32, "param0"));
1410   HloInstruction* param1 = builder.AddInstruction(
1411       HloInstruction::CreateParameter(1, r0f32, "param1"));
1412   HloInstruction* exp0 = builder.AddInstruction(
1413       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
1414   HloInstruction* exp1 = builder.AddInstruction(
1415       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
1416   builder.AddInstruction(
1417       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
1418 
1419   auto computation = m->AddEntryComputation(builder.Build());
1420 
1421   EXPECT_THAT(
1422       computation->root_instruction(),
1423       GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1)))));
1424 
1425   AlgebraicSimplifier simplifier(default_options_);
1426   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1427 
1428   EXPECT_THAT(
1429       computation->root_instruction(),
1430       GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1)))));
1431 }
1432 
1433 // Test that exp(A)*exp(B) is simplified to exp(A+B)
TEST_F(AlgebraicSimplifierTest,ExpMul)1434 TEST_F(AlgebraicSimplifierTest, ExpMul) {
1435   auto m = CreateNewVerifiedModule();
1436   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1437   HloComputation::Builder builder(TestName());
1438   HloInstruction* param0 = builder.AddInstruction(
1439       HloInstruction::CreateParameter(0, r0f32, "param0"));
1440   HloInstruction* param1 = builder.AddInstruction(
1441       HloInstruction::CreateParameter(1, r0f32, "param1"));
1442   HloInstruction* exp0 = builder.AddInstruction(
1443       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
1444   HloInstruction* exp1 = builder.AddInstruction(
1445       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
1446   builder.AddInstruction(
1447       HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1));
1448 
1449   auto computation = m->AddEntryComputation(builder.Build());
1450 
1451   EXPECT_THAT(computation->root_instruction(),
1452               GmockMatch(m::Multiply(m::Exp(m::Parameter(0)),
1453                                      m::Exp(m::Parameter(1)))));
1454 
1455   AlgebraicSimplifier simplifier(default_options_);
1456   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1457 
1458   EXPECT_THAT(computation->root_instruction(),
1459               GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1)))));
1460 }
1461 
1462 // Test that pow(exp(A), B) is simplified to exp(A*B)
TEST_F(AlgebraicSimplifierTest,PowExp)1463 TEST_F(AlgebraicSimplifierTest, PowExp) {
1464   auto m = CreateNewVerifiedModule();
1465   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1466   HloComputation::Builder builder(TestName());
1467   HloInstruction* param0 = builder.AddInstruction(
1468       HloInstruction::CreateParameter(0, r0f32, "param0"));
1469   HloInstruction* param1 = builder.AddInstruction(
1470       HloInstruction::CreateParameter(1, r0f32, "param1"));
1471   HloInstruction* exp0 = builder.AddInstruction(
1472       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
1473   builder.AddInstruction(
1474       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1));
1475 
1476   auto computation = m->AddEntryComputation(builder.Build());
1477 
1478   EXPECT_THAT(computation->root_instruction(),
1479               GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1))));
1480 
1481   AlgebraicSimplifier simplifier(default_options_);
1482   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1483 
1484   EXPECT_THAT(
1485       computation->root_instruction(),
1486       GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1)))));
1487 }
1488 
1489 // Test that ln(pow(A, B)) is simplified to ln(A)*B
TEST_F(AlgebraicSimplifierTest,LnPow)1490 TEST_F(AlgebraicSimplifierTest, LnPow) {
1491   auto m = CreateNewVerifiedModule();
1492   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1493   HloComputation::Builder builder(TestName());
1494   HloInstruction* param0 = builder.AddInstruction(
1495       HloInstruction::CreateParameter(0, r0f32, "param0"));
1496   HloInstruction* param1 = builder.AddInstruction(
1497       HloInstruction::CreateParameter(1, r0f32, "param1"));
1498   HloInstruction* pow = builder.AddInstruction(
1499       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, param1));
1500   builder.AddInstruction(
1501       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow));
1502 
1503   auto computation = m->AddEntryComputation(builder.Build());
1504 
1505   EXPECT_THAT(computation->root_instruction(),
1506               GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1)))));
1507 
1508   AlgebraicSimplifier simplifier(default_options_);
1509   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1510 
1511   EXPECT_THAT(
1512       computation->root_instruction(),
1513       GmockMatch(m::Select(
1514           m::Eq(m::Parameter(1), m::ConstantScalar(0.0f)),
1515           m::ConstantScalar(0.0f),
1516           m::Multiply(m::Log(m::Abs(m::Parameter(0))), m::Parameter(1)))));
1517 }
1518 
// ln(sqrt(A)) must be rewritten as ln(A) * 0.5.
TEST_F(AlgebraicSimplifierTest, LnSqrt) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder b(TestName());
  HloInstruction* a = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* root_of_a = b.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kSqrt, a));
  // The log is added last and therefore becomes the root.
  b.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, root_of_a));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Sqrt(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::ConstantScalar(0.5))));
}
1542 
// ln(rsqrt(A)) must be rewritten as ln(A) * -0.5.
TEST_F(AlgebraicSimplifierTest, LnRsqrt) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder b(TestName());
  HloInstruction* a = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "param0"));
  HloInstruction* inv_root = b.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kRsqrt, a));
  // The log is added last and therefore becomes the root.
  b.AddInstruction(
      HloInstruction::CreateUnary(scalar, HloOpcode::kLog, inv_root));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Rsqrt(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Log(m::Parameter(0)),
                                     m::ConstantScalar(-0.5))));
}
1566 
1567 // Test that ln(exp(A)) is simplified to A
TEST_F(AlgebraicSimplifierTest,LnExp)1568 TEST_F(AlgebraicSimplifierTest, LnExp) {
1569   auto m = CreateNewVerifiedModule();
1570   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1571   HloComputation::Builder builder(TestName());
1572   HloInstruction* param0 = builder.AddInstruction(
1573       HloInstruction::CreateParameter(0, r0f32, "param0"));
1574   HloInstruction* exp0 = builder.AddInstruction(
1575       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
1576   builder.AddInstruction(
1577       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
1578 
1579   auto computation = m->AddEntryComputation(builder.Build());
1580 
1581   EXPECT_THAT(computation->root_instruction(),
1582               GmockMatch(m::Log(m::Exp(m::Parameter(0)))));
1583 
1584   AlgebraicSimplifier simplifier(default_options_);
1585   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1586 
1587   EXPECT_EQ(computation->root_instruction(), param0);
1588 }
1589 
1590 // Test that ln(exp(A)/exp(B)) is simplified to A-B
TEST_F(AlgebraicSimplifierTest,LnExpDiv)1591 TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
1592   auto m = CreateNewVerifiedModule();
1593   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1594   HloComputation::Builder builder(TestName());
1595   HloInstruction* param0 = builder.AddInstruction(
1596       HloInstruction::CreateParameter(0, r0f32, "param0"));
1597   HloInstruction* param1 = builder.AddInstruction(
1598       HloInstruction::CreateParameter(1, r0f32, "param1"));
1599   HloInstruction* exp0 = builder.AddInstruction(
1600       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
1601   HloInstruction* exp1 = builder.AddInstruction(
1602       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
1603   HloInstruction* div = builder.AddInstruction(
1604       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
1605   builder.AddInstruction(
1606       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
1607 
1608   auto computation = m->AddEntryComputation(builder.Build());
1609 
1610   EXPECT_THAT(computation->root_instruction(),
1611               GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)),
1612                                           m::Exp(m::Parameter(1))))));
1613 
1614   AlgebraicSimplifier simplifier(default_options_);
1615   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1616 
1617   EXPECT_THAT(computation->root_instruction(),
1618               GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1))));
1619 }
1620 
1621 // Test that pow(A, 0) where A is a scalar is simplified to the scalar
1622 // constant 1.
TEST_F(AlgebraicSimplifierTest,Pow0Scalar)1623 TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
1624   auto m = CreateNewVerifiedModule();
1625   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1626   HloComputation::Builder builder(TestName());
1627   HloInstruction* param0 = builder.AddInstruction(
1628       HloInstruction::CreateParameter(0, r0f32, "param0"));
1629   HloInstruction* zero = builder.AddInstruction(
1630       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
1631   builder.AddInstruction(
1632       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
1633 
1634   auto computation = m->AddEntryComputation(builder.Build());
1635 
1636   EXPECT_THAT(computation->root_instruction(),
1637               GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));
1638 
1639   AlgebraicSimplifier simplifier(default_options_);
1640   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1641 
1642   HloInstruction* root = computation->root_instruction();
1643   EXPECT_THAT(root, GmockMatch(m::Constant()));
1644   EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
1645 }
1646 
1647 // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
TEST_F(AlgebraicSimplifierTest,Pow0Vector)1648 TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
1649   auto m = CreateNewVerifiedModule();
1650   Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
1651   HloComputation::Builder builder(TestName());
1652   HloInstruction* param0 = builder.AddInstruction(
1653       HloInstruction::CreateParameter(0, r1f32, "param0"));
1654   HloInstruction* zero = builder.AddInstruction(
1655       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
1656   builder.AddInstruction(
1657       HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
1658 
1659   auto computation = m->AddEntryComputation(builder.Build());
1660 
1661   EXPECT_THAT(computation->root_instruction(),
1662               GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));
1663 
1664   AlgebraicSimplifier simplifier(default_options_);
1665   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1666 
1667   HloInstruction* root = computation->root_instruction();
1668   EXPECT_THAT(root, GmockMatch(m::Broadcast()));
1669   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
1670       << ShapeUtil::HumanString(root->shape());
1671   EXPECT_EQ(root->dimensions().size(), 0);
1672   EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
1673   EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
1674 }
1675 
1676 // Test that pow(A, 1) is simplified to A.
TEST_F(AlgebraicSimplifierTest,Pow1)1677 TEST_F(AlgebraicSimplifierTest, Pow1) {
1678   auto m = CreateNewVerifiedModule();
1679   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1680   HloComputation::Builder builder(TestName());
1681   HloInstruction* param0 = builder.AddInstruction(
1682       HloInstruction::CreateParameter(0, r0f32, "param0"));
1683   HloInstruction* one = builder.AddInstruction(
1684       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
1685   builder.AddInstruction(
1686       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
1687 
1688   auto computation = m->AddEntryComputation(builder.Build());
1689 
1690   EXPECT_THAT(computation->root_instruction(),
1691               GmockMatch(m::Power(m::Parameter(0), m::Op().Is(one))));
1692 
1693   AlgebraicSimplifier simplifier(default_options_);
1694   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1695 
1696   EXPECT_EQ(computation->root_instruction(), param0);
1697 }
1698 
1699 // Test that pow(A, 2) is simplified to A*A.
TEST_F(AlgebraicSimplifierTest,Pow2)1700 TEST_F(AlgebraicSimplifierTest, Pow2) {
1701   auto m = CreateNewVerifiedModule();
1702   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1703   HloComputation::Builder builder(TestName());
1704   HloInstruction* param0 = builder.AddInstruction(
1705       HloInstruction::CreateParameter(0, r0f32, "param0"));
1706   HloInstruction* two = builder.AddInstruction(
1707       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
1708   builder.AddInstruction(
1709       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
1710 
1711   auto computation = m->AddEntryComputation(builder.Build());
1712 
1713   EXPECT_THAT(computation->root_instruction(),
1714               GmockMatch(m::Power(m::Parameter(0), m::Op().Is(two))));
1715 
1716   AlgebraicSimplifier simplifier(default_options_);
1717   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1718 
1719   EXPECT_THAT(computation->root_instruction(),
1720               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
1721 }
1722 
1723 // Test that pow(A, 3) is simplified to A*A*A.
TEST_F(AlgebraicSimplifierTest,Pow3)1724 TEST_F(AlgebraicSimplifierTest, Pow3) {
1725   auto m = CreateNewVerifiedModule();
1726   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1727   HloComputation::Builder builder(TestName());
1728   HloInstruction* param0 = builder.AddInstruction(
1729       HloInstruction::CreateParameter(0, r0f32, "param0"));
1730   HloInstruction* three = builder.AddInstruction(
1731       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
1732   builder.AddInstruction(
1733       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, three));
1734 
1735   auto computation = m->AddEntryComputation(builder.Build());
1736 
1737   EXPECT_THAT(computation->root_instruction(),
1738               GmockMatch(m::Power(m::Parameter(0), m::Op().Is(three))));
1739 
1740   AlgebraicSimplifier simplifier(default_options_);
1741   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1742 
1743   EXPECT_THAT(
1744       computation->root_instruction(),
1745       GmockMatch(m::Multiply(m::Parameter(0),
1746                              m::Multiply(m::Parameter(0), m::Parameter(0)))));
1747 }
1748 
1749 // Test that pow(A, -1) is simplified to 1/A.
TEST_F(AlgebraicSimplifierTest,PowNegative1)1750 TEST_F(AlgebraicSimplifierTest, PowNegative1) {
1751   auto m = CreateNewVerifiedModule();
1752   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1753   HloComputation::Builder builder(TestName());
1754   HloInstruction* param0 = builder.AddInstruction(
1755       HloInstruction::CreateParameter(0, r0f32, "param0"));
1756   HloInstruction* negative_one = builder.AddInstruction(
1757       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
1758   builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
1759                                                       param0, negative_one));
1760 
1761   auto computation = m->AddEntryComputation(builder.Build());
1762 
1763   EXPECT_THAT(computation->root_instruction(),
1764               GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one))));
1765 
1766   AlgebraicSimplifier simplifier(default_options_);
1767   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1768 
1769   HloInstruction* root = computation->root_instruction();
1770   EXPECT_THAT(root, GmockMatch(m::Divide(m::Constant(), m::Parameter(0))));
1771   EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
1772 }
1773 
// A convolution whose operands contain a zero-sized dimension must be
// replaced by a broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // Both operands have one dimension of extent 0.
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));
  HloInstruction* kernel = b.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));

  // Input and output use the same layout: batch, spatial, feature.
  ConvolutionDimensionNumbers dim_numbers;
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.add_input_spatial_dimensions(1);
  dim_numbers.add_output_spatial_dimensions(1);
  dim_numbers.set_input_feature_dimension(2);
  dim_numbers.set_output_feature_dimension(2);
  // Kernel layout: spatial, input feature, output feature.
  dim_numbers.add_kernel_spatial_dimensions(0);
  dim_numbers.set_kernel_input_feature_dimension(1);
  dim_numbers.set_kernel_output_feature_dimension(2);

  // One window dimension of size 3: unit stride, no padding, no dilation,
  // no reversal.
  Window conv_window;
  WindowDimension* spatial = conv_window.add_dimensions();
  spatial->set_size(3);
  spatial->set_stride(1);
  spatial->set_padding_low(0);
  spatial->set_padding_high(0);
  spatial->set_base_dilation(1);
  spatial->set_window_dilation(1);
  spatial->set_window_reversal(false);

  // The convolution is added last and therefore becomes the root.
  b.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {3, 3, 3}), input, kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      dim_numbers, DefaultPrecisionConfig(2)));
  module->AddEntryComputation(b.Build());

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Convolution(m::Op().Is(input), m::Op().Is(kernel))));
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1816 
// A reduce-window over a f32[1,2,3,4] operand with a 1x2x3x1 window (result
// f32[1,1,1,4]) must be rewritten as reshape(reduce(operand, init)).
TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "param"));

  // 1x2x3x1 window with unit stride, no padding and no dilation.
  Window window;
  for (const int64_t extent : {1, 2, 3, 1}) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(extent);
    dim->set_stride(1);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }

  // Scalar addition used as the reduction function.
  HloComputation* add_computation = [&] {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kAdd, lhs, rhs));
    return module->AddEmbeddedComputation(add_builder.Build());
  }();

  // The init value is added before the reduce-window, which becomes the root.
  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), operand, init, window,
      add_computation));
  module->AddEntryComputation(entry_builder.Build());

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant()))));
}
1861 
// A reduce-window whose operand has zero elements (f32[3,0]) must be replaced
// by a broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
  // 1x1 window padded by one element on each side. The stride is set
  // explicitly so the window agrees with the declared f32[5,2] result:
  // (3 + 2 - 1) / 1 + 1 == 5 and (0 + 2 - 1) / 1 + 1 == 2.
  Window window;
  for (int64_t i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(1);
    dim->set_padding_low(1);
    dim->set_padding_high(1);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }
  // Create add computation.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(builder.Build());
  }
  // The reduce-window (init value 0) is added last and becomes the root.
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {5, 2}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
      window, add_computation));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1903 
// A two-input reduce-window over a zero-element operand must be replaced by a
// tuple of broadcast constants.
TEST_F(AlgebraicSimplifierTest, ZeroSizedVariadicReduceWindow) {
  const char* const kHloText = R"(
HloModule ZeroSizedVariadicReduceWindow

ZeroSizedVariadicReduceWindow.add {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  p2 = f32[] parameter(2)
  p3 = f32[] parameter(3)
  add.0 = f32[] add(p0, p1)
  add.1 = f32[] add(p2, p3)
  ROOT r = tuple(add.0, add.1)
}

ENTRY ZeroSizedReduceWindow {
  op = f32[3,0] parameter(0)
  constant = f32[] constant(0)
  ROOT reduce-window = (f32[5,2], f32[5,2]) reduce-window(op, op, constant, constant), window={size=1x1 pad=1_1x1_1}, to_apply=ZeroSizedVariadicReduceWindow.add
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);

  // Before the pass: the variadic reduce-window is the root.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Parameter(0),
                                         m::Constant(), m::Constant())));

  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After the pass: one broadcast constant per tuple element.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(m::Broadcast(m::Constant()),
                                  m::Broadcast(m::Constant()))));
}
1935 
// maximum(x, c) and maximum(c, x) must reduce to x when c is the smallest
// value of the element type (integer minimum, 0 for unsigned, -inf for
// floating point).
TEST_F(AlgebraicSimplifierTest, NopMax) {
  const char* const kHloText = R"(
HloModule test

ENTRY test {
  p_s8   = s8[]   parameter(0)
  p_u8   = u8[]   parameter(1)
  p_s16  = s16[]  parameter(2)
  p_u16  = u16[]  parameter(3)
  p_s32  = s32[]  parameter(4)
  p_u32  = u32[]  parameter(5)
  p_s64  = s64[]  parameter(6)
  p_u64  = u64[]  parameter(7)
  p_f16  = f16[]  parameter(8)
  p_bf16 = bf16[] parameter(9)
  p_f32  = f32[]  parameter(10)
  p_f64  = f64[]  parameter(11)

  const_s8   = s8[]   constant(-128)
  const_u8   = u8[]   constant(0)
  const_s16  = s16[]  constant(-32768)
  const_u16  = u16[]  constant(0)
  const_s32  = s32[]  constant(-2147483648)
  const_u32  = u32[]  constant(0)
  const_s64  = s64[]  constant(-9223372036854775808)
  const_u64  = u64[]  constant(0)
  const_f16  = f16[]  constant(-inf)
  const_bf16 = bf16[] constant(-inf)
  const_f32  = f32[]  constant(-inf)
  const_f64  = f64[]  constant(-inf)

  max_s8   = s8[]   maximum(p_s8, const_s8)
  max_u8   = u8[]   maximum(p_u8, const_u8)
  max_s16  = s16[]  maximum(p_s16, const_s16)
  max_u16  = u16[]  maximum(p_u16, const_u16)
  max_s32  = s32[]  maximum(p_s32, const_s32)
  max_u32  = u32[]  maximum(p_u32, const_u32)
  max_s64  = s64[]  maximum(p_s64, const_s64)
  max_u64  = u64[]  maximum(p_u64, const_u64)
  max_f16  = f16[]  maximum(p_f16, const_f16)
  max_bf16 = bf16[] maximum(p_bf16, const_bf16)
  max_f32  = f32[]  maximum(p_f32, const_f32)
  max_f64  = f64[]  maximum(p_f64, const_f64)

  max2_s8   = s8[]   maximum(const_s8, p_s8)
  max2_u8   = u8[]   maximum(const_u8, p_u8)
  max2_s16  = s16[]  maximum(const_s16, p_s16)
  max2_u16  = u16[]  maximum(const_u16, p_u16)
  max2_s32  = s32[]  maximum(const_s32, p_s32)
  max2_u32  = u32[]  maximum(const_u32, p_u32)
  max2_s64  = s64[]  maximum(const_s64, p_s64)
  max2_u64  = u64[]  maximum(const_u64, p_u64)
  max2_f16  = f16[]  maximum(const_f16, p_f16)
  max2_bf16 = bf16[] maximum(const_bf16, p_bf16)
  max2_f32  = f32[]  maximum(const_f32, p_f32)
  max2_f64  = f64[]  maximum(const_f64, p_f64)

  ROOT tuple = tuple(max_s8, max_u8, max_s16, max_u16, max_s32, max_u32,
                     max_s64, max_u64, max_f16, max_bf16, max_f32, max_f64,
                     max2_s8, max2_u8, max2_s16, max2_u16, max2_s32, max2_u32,
                     max2_s64, max2_u64, max2_f16, max2_bf16, max2_f32, max2_f64)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // A single GmockMatch(m::Tuple(m::Parameter(0), m::Parameter(1), ...))
  // would produce a template expression too complicated for our MSVC to
  // compile, so the tuple operands are checked one at a time instead.
  SCOPED_TRACE(module->ToString());
  const HloInstruction* tuple_root = nullptr;
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Op(&tuple_root)
                             .WithOpcode(HloOpcode::kTuple)
                             .WithNumOperands(24)));
  // Both groups of 12 operands (max and max2) must map back to parameters
  // 0..11 in order.
  for (int i = 0; i < tuple_root->operand_count(); ++i) {
    SCOPED_TRACE(absl::StrCat("operand ", i));
    const HloInstruction* element = tuple_root->operand(i);
    ASSERT_EQ(element->opcode(), HloOpcode::kParameter);
    EXPECT_EQ(element->parameter_number(), i % 12);
  }
}
2020 
// Verifies that minimum(x, C) and minimum(C, x) simplify to x when C is the
// largest value of x's element type (the integer maximum, or +inf for the
// floating-point types).
TEST_F(AlgebraicSimplifierTest, NopMin) {
  const char* const hlo_string = R"(
HloModule test

ENTRY test {
  p_s8   = s8[]   parameter(0)
  p_u8   = u8[]   parameter(1)
  p_s16  = s16[]  parameter(2)
  p_u16  = u16[]  parameter(3)
  p_s32  = s32[]  parameter(4)
  p_u32  = u32[]  parameter(5)
  p_s64  = s64[]  parameter(6)
  p_u64  = u64[]  parameter(7)
  p_f16  = f16[]  parameter(8)
  p_bf16 = bf16[] parameter(9)
  p_f32  = f32[]  parameter(10)
  p_f64  = f64[]  parameter(11)

  const_s8   = s8[]   constant(127)
  const_u8   = u8[]   constant(255)
  const_s16  = s16[]  constant(32767)
  const_u16  = u16[]  constant(65535)
  const_s32  = s32[]  constant(2147483647)
  const_u32  = u32[]  constant(4294967295)
  const_s64  = s64[]  constant(9223372036854775807)
  const_u64  = u64[]  constant(18446744073709551615)
  const_f16  = f16[]  constant(inf)
  const_bf16 = bf16[] constant(inf)
  const_f32  = f32[]  constant(inf)
  const_f64  = f64[]  constant(inf)

  min_s8   = s8[]   minimum(p_s8, const_s8)
  min_u8   = u8[]   minimum(p_u8, const_u8)
  min_s16  = s16[]  minimum(p_s16, const_s16)
  min_u16  = u16[]  minimum(p_u16, const_u16)
  min_s32  = s32[]  minimum(p_s32, const_s32)
  min_u32  = u32[]  minimum(p_u32, const_u32)
  min_s64  = s64[]  minimum(p_s64, const_s64)
  min_u64  = u64[]  minimum(p_u64, const_u64)
  min_f16  = f16[]  minimum(p_f16, const_f16)
  min_bf16 = bf16[] minimum(p_bf16, const_bf16)
  min_f32  = f32[]  minimum(p_f32, const_f32)
  min_f64  = f64[]  minimum(p_f64, const_f64)

  min2_s8   = s8[]   minimum(const_s8, p_s8)
  min2_u8   = u8[]   minimum(const_u8, p_u8)
  min2_s16  = s16[]  minimum(const_s16, p_s16)
  min2_u16  = u16[]  minimum(const_u16, p_u16)
  min2_s32  = s32[]  minimum(const_s32, p_s32)
  min2_u32  = u32[]  minimum(const_u32, p_u32)
  min2_s64  = s64[]  minimum(const_s64, p_s64)
  min2_u64  = u64[]  minimum(const_u64, p_u64)
  min2_f16  = f16[]  minimum(const_f16, p_f16)
  min2_bf16 = bf16[] minimum(const_bf16, p_bf16)
  min2_f32  = f32[]  minimum(const_f32, p_f32)
  min2_f64  = f64[]  minimum(const_f64, p_f64)

  ROOT tuple = tuple(min_s8, min_u8, min_s16, min_u16, min_s32, min_u32,
                     min_s64, min_u64, min_f16, min_bf16, min_f32, min_f64,
                     min2_s8, min2_u8, min2_s16, min2_u16, min2_s32, min2_u32,
                     min2_s64, min2_u64, min2_f16, min2_bf16, min2_f32, min2_f64)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Attach the simplified module (once) to any failure reported below.
  SCOPED_TRACE(m->ToString());

  // We can't write GmockMatch(m::Tuple(m::Parameter(0), m::Parameter(1), ...)
  // because this generates a template expression that's too complicated for our
  // MSVC to compile.  :(
  const HloInstruction* root;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::Op(&root).WithOpcode(HloOpcode::kTuple).WithNumOperands(24)));

  // The tuple lists the 12 minimum(p, const) results followed by the 12
  // minimum(const, p) results, so operand i must be parameter i % 12.
  for (int i = 0; i < root->operand_count(); i++) {
    SCOPED_TRACE(absl::StrCat("operand ", i));
    const HloInstruction* operand = root->operand(i);
    ASSERT_EQ(operand->opcode(), HloOpcode::kParameter);
    EXPECT_EQ(operand->parameter_number(), i % 12);
  }
}
2107 
// A reduce-window with a 1x1 window and an add reducer is expected to be
// rewritten as add(operand, broadcast(init value)), in either operand order.
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindow_Add) {
  const char* const hlo_string = R"(
HloModule test

add {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT add = f32[] add(p0, p1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(0)
  ROOT reduce-window = reduce-window(p, constant), window={size=1x1}, to_apply=add
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier pass(default_options_);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::AddAnyOrder(
                  m::Parameter(), m::Broadcast(m::ConstantEffectiveScalar(0)))));
}
2133 
// A reduce-window with a 1x1 window and a minimum reducer is expected to be
// rewritten as minimum(operand, broadcast(+inf init value)), in either operand
// order.
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindow_Min) {
  const char* const hlo_string = R"(
HloModule test

min {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT min = f32[] minimum(p0, p1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(inf)
  ROOT reduce-window = reduce-window(p, constant), window={size=1x1}, to_apply=min
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier pass(default_options_);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The init value of the original reduce-window.
  const float pos_inf = std::numeric_limits<float>::infinity();
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::MinimumAnyOrder(
                  m::Parameter(),
                  m::Broadcast(m::ConstantEffectiveScalar(pos_inf)))));
}
2160 
// A reduce-window with a 1x1 window and a maximum reducer is expected to be
// rewritten as maximum(operand, broadcast(-inf init value)), in either operand
// order.
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindow_Max) {
  const char* const hlo_string = R"(
HloModule test

max {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT max = f32[] maximum(p0, p1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(-inf)
  ROOT reduce-window = reduce-window(p, constant), window={size=1x1}, to_apply=max
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier pass(default_options_);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // The init value of the original reduce-window.
  const float neg_inf = -std::numeric_limits<float>::infinity();
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::MaximumAnyOrder(
                  m::Parameter(),
                  m::Broadcast(m::ConstantEffectiveScalar(neg_inf)))));
}
2187 
// A 1x1 reduce-window that also carries window padding is expected to become
// pad(maximum(operand, broadcast(init)), init): the element-wise maximum is
// padded with the same -inf init value.
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindowWithPad) {
  const char* const hlo_string = R"(
HloModule test

max {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT max = f32[] maximum(p0, p1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(-inf)
  ROOT reduce-window = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4}, to_apply=max
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier pass(default_options_);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  const float neg_inf = -std::numeric_limits<float>::infinity();
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      new_root,
      GmockMatch(m::Pad(
          m::MaximumAnyOrder(m::Parameter(),
                             m::Broadcast(m::ConstantEffectiveScalar(neg_inf))),
          m::ConstantEffectiveScalar(neg_inf))));
}
2217 
// 1x1 reduce-windows that use a stride, lhs/rhs dilation, window reversal, or
// a reducer other than the simple binary ops exercised by the tests above are
// expected to be left untouched: the pass must report that nothing changed.
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindowWithUnsupported) {
  const char* const hlo_string = R"(
HloModule test

max {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT max = f32[] maximum(p0, p1)
}

unsupported_fn {
  p0 = f32[] parameter(0)
  ROOT p1 = f32[] parameter(1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(-inf)
  a = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4 stride=1x2}, to_apply=max
  b = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4 lhs_dilate=2x1}, to_apply=max
  c = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4 rhs_dilate=2x1}, to_apply=max
  d = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4 rhs_reversal=1x1}, to_apply=max
  e = reduce-window(p, constant), window={size=1x1}, to_apply=unsupported_fn
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier pass(default_options_);
  const bool changed = RunHloPass(&pass, module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
2248 
// Padding an f32[3,0] operand (which holds no elements) up to f32[5,2] is
// expected to be replaced by a broadcast of the padding constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));

  // One element of edge padding on each side of both dimensions, no interior
  // padding.
  PaddingConfig padding_config;
  for (int dim = 0; dim < 2; ++dim) {
    auto* dim_config = padding_config.add_dimensions();
    dim_config->set_edge_padding_low(1);
    dim_config->set_edge_padding_high(1);
    dim_config->set_interior_padding(0);
  }

  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
  b.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {5, 2}), operand, pad_value, padding_config));
  module->AddEntryComputation(b.Build());

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Constant())));

  HloPassFix<AlgebraicSimplifier> pass(default_options_);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
2275 
// Builds reshape(broadcast(reshape(op))) where op is f32[3,2], the inner
// reshape flattens it to f32[6], the broadcast only adds a degenerate leading
// dimension (f32[1,6]), and the outer reshape restores f32[3,2].  Running the
// simplifier to a fixed point is expected to collapse the chain back to op.
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
  auto m = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto op = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3, 2}), broadcast));

  auto computation = builder.Build();
  m->AddEntryComputation(std::move(computation));

  // Sanity-check the graph before running the pass.
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op))))));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(m->entry_computation()->root_instruction(), op);
}
2301 
2302 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
TEST_F(AlgebraicSimplifierTest,ConvertBetweenSameType)2303 TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
2304   auto m = CreateNewVerifiedModule();
2305   HloComputation::Builder builder(TestName());
2306   HloInstruction* input = builder.AddInstruction(
2307       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
2308   builder.AddInstruction(
2309       HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
2310 
2311   auto computation = m->AddEntryComputation(builder.Build());
2312 
2313   EXPECT_THAT(computation->root_instruction(),
2314               GmockMatch(m::Convert(m::Op().Is(input))));
2315 
2316   AlgebraicSimplifier simplifier(default_options_);
2317   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2318 
2319   EXPECT_THAT(computation->root_instruction(), input);
2320 }
2321 
2322 // Test that convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
2323 // $TYPE2 and convert(A, $TYP1) is an upcast.
TEST_F(AlgebraicSimplifierTest,EliminateConvertPairUpCast)2324 TEST_F(AlgebraicSimplifierTest, EliminateConvertPairUpCast) {
2325   auto m = CreateNewVerifiedModule();
2326   HloComputation::Builder builder(TestName());
2327   HloInstruction* input =
2328       builder.AddInstruction(HloInstruction::CreateParameter(
2329           0, ShapeUtil::MakeShapeWithLayout(F16, {1, 14, 14, 64}, {3, 2, 1, 0}),
2330           "param"));
2331   HloInstruction* convert_1 =
2332       builder.AddInstruction(HloInstruction::CreateConvert(
2333           ShapeUtil::ChangeElementType(input->shape(), F32), input));
2334   builder.AddInstruction(HloInstruction::CreateConvert(
2335       ShapeUtil::ChangeElementType(convert_1->shape(), F16), convert_1));
2336 
2337   auto computation = m->AddEntryComputation(builder.Build());
2338 
2339   EXPECT_THAT(computation->root_instruction(),
2340               GmockMatch(m::Convert(m::Convert(m::Op().Is(input)))));
2341 
2342   AlgebraicSimplifier simplifier(default_options_);
2343   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2344 
2345   EXPECT_THAT(computation->root_instruction(), input);
2346 }
2347 
2348 // Test that convert(convert(A, $TYPE1), $TYPE2) is NOT simplified to A even if
2349 // A is of $TYPE2 since convert(A, $TYP1) is a downcast.
TEST_F(AlgebraicSimplifierTest,DoNotEliminateConvertPairDownCast)2350 TEST_F(AlgebraicSimplifierTest, DoNotEliminateConvertPairDownCast) {
2351   auto m = CreateNewVerifiedModule();
2352   HloComputation::Builder builder(TestName());
2353   HloInstruction* input =
2354       builder.AddInstruction(HloInstruction::CreateParameter(
2355           0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}),
2356           "param"));
2357   HloInstruction* convert_1 =
2358       builder.AddInstruction(HloInstruction::CreateConvert(
2359           ShapeUtil::ChangeElementType(input->shape(), F16), input));
2360   builder.AddInstruction(HloInstruction::CreateConvert(
2361       ShapeUtil::ChangeElementType(convert_1->shape(), F32), convert_1));
2362 
2363   auto computation = m->AddEntryComputation(builder.Build());
2364 
2365   EXPECT_THAT(computation->root_instruction(),
2366               GmockMatch(m::Convert(m::Convert(m::Op().Is(input)))));
2367   AlgebraicSimplifier simplifier(default_options_);
2368   ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
2369   EXPECT_THAT(computation->root_instruction(),
2370               GmockMatch(m::Convert(m::Convert(m::Op().Is(input)))));
2371 }
2372 
2373 // Test that Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1),
2374 // $TYPE2)), convert(convert(A, $TYPE1), $TYPE2)) is simplified to
2375 // Tuple(convert(A, $TYPE1) , floor(A), A) showing a case where the first
2376 // convert has a fan-out.
TEST_F(AlgebraicSimplifierTest,EliminateConvertPairMultiOut)2377 TEST_F(AlgebraicSimplifierTest, EliminateConvertPairMultiOut) {
2378   auto m = CreateNewVerifiedModule();
2379   HloComputation::Builder builder(TestName());
2380   HloInstruction* input =
2381       builder.AddInstruction(HloInstruction::CreateParameter(
2382           0, ShapeUtil::MakeShapeWithLayout(F16, {1, 14, 14, 64}, {3, 2, 1, 0}),
2383           "param"));
2384   HloInstruction* convert_1 =
2385       builder.AddInstruction(HloInstruction::CreateConvert(
2386           ShapeUtil::ChangeElementType(input->shape(), F32), input));
2387   HloInstruction* convert_2 =
2388       builder.AddInstruction(HloInstruction::CreateConvert(
2389           ShapeUtil::ChangeElementType(convert_1->shape(), F16), convert_1));
2390 
2391   HloInstruction* floor = builder.AddInstruction(HloInstruction::CreateUnary(
2392       convert_2->shape(), HloOpcode::kFloor, convert_2));
2393 
2394   // Collect all the reshapes into a tuple so they are not dead.
2395   builder.AddInstruction(
2396       HloInstruction::CreateTuple({convert_1, convert_2, floor}));
2397 
2398   auto computation = m->AddEntryComputation(builder.Build());
2399   EXPECT_THAT(computation->root_instruction(),
2400               GmockMatch(m::Tuple(m::Op().Is(convert_1), m::Op().Is(convert_2),
2401                                   m::Op().Is(floor))));
2402 
2403   AlgebraicSimplifier simplifier(default_options_);
2404   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2405 
2406   EXPECT_THAT(computation->root_instruction(),
2407               GmockMatch(m::Tuple(m::Op().Is(convert_1), m::Op().Is(input),
2408                                   m::Floor(m::Op().Is(input)))));
2409 }
2410 
2411 // Test that copies are removed.
TEST_F(AlgebraicSimplifierTest,RemoveCopy)2412 TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
2413   auto m = CreateNewVerifiedModule();
2414   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2415   HloComputation::Builder builder(TestName());
2416   HloInstruction* param0 = builder.AddInstruction(
2417       HloInstruction::CreateParameter(0, r0f32, "param0"));
2418   builder.AddInstruction(
2419       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
2420 
2421   auto computation = m->AddEntryComputation(builder.Build());
2422 
2423   EXPECT_THAT(computation->root_instruction(),
2424               GmockMatch(m::Copy(m::Parameter(0))));
2425 
2426   AlgebraicSimplifier simplifier(default_options_);
2427   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2428 
2429   EXPECT_THAT(computation->root_instruction(), param0);
2430 }
2431 
// With a layout-sensitive simplifier, copy(reshape(copy(param))) over the
// layouts used here is expected to collapse into a single bitcast of the
// parameter.
TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}),
      "param"));
  // Same dimensions, reversed minor-to-major order.
  HloInstruction* first_copy = b.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
      HloOpcode::kCopy, input));
  HloInstruction* reshaped = b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1}), first_copy));
  b.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}),
      HloOpcode::kCopy, reshaped));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0))))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier pass(layout_sensitive_options);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Verify that the copy of reshape of copy is replaced.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2460 
// With a layout-sensitive simplifier, reshape(copy(param)) over the layouts
// used here is expected to collapse into a single bitcast of the parameter.
TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}),
      "param"));
  // Same dimensions, reversed minor-to-major order.
  HloInstruction* relayout = b.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
      HloOpcode::kCopy, input));
  b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), relayout));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier pass(layout_sensitive_options);
  const bool changed = pass.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  // Verify that the reshape of copy is replaced.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2486 
// Runs a layout-sensitive simplifier twice over a layout-changing copy: first
// with options built from a callback that always returns false (the copy must
// stay), then with default-constructed options (the copy must become a
// bitcast).
TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
      "param"));
  b.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3}),
      HloOpcode::kCopy, input));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Copy(m::Parameter(0))));

  // Phase 1: the shape-pair callback rejects everything.
  AlgebraicSimplifierOptions rejecting_options(
      [](const Shape&, const Shape&) { return false; });
  rejecting_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier rejecting_pass(rejecting_options);
  ASSERT_FALSE(rejecting_pass.Run(module.get()).ValueOrDie());
  // Verify that the copy is not replaced.
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Copy(m::Parameter(0))));

  // Phase 2: default callback.
  AlgebraicSimplifierOptions default_callback_options;
  default_callback_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier accepting_pass(default_callback_options);
  EXPECT_TRUE(accepting_pass.Run(module.get()).ValueOrDie());
  // Verify that the copy is replaced.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2518 
2519 // Test that unary concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveUnaryConcatenate)2520 TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
2521   auto m = CreateNewVerifiedModule();
2522   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
2523   HloComputation::Builder builder(TestName());
2524   HloInstruction* param0 = builder.AddInstruction(
2525       HloInstruction::CreateParameter(0, r1f32, "param0"));
2526   builder.AddInstruction(
2527       HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
2528 
2529   auto computation = m->AddEntryComputation(builder.Build());
2530 
2531   EXPECT_THAT(computation->root_instruction(),
2532               GmockMatch(m::Concatenate(m::Parameter(0))));
2533 
2534   AlgebraicSimplifier simplifier(default_options_);
2535   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2536 
2537   EXPECT_THAT(computation->root_instruction(), param0);
2538 }
2539 
// slice(reverse(pad)) is expected to be rewritten as reverse(slice(pad)), with
// the slice bounds mirrored along the reversed dimensions.
TEST_F(AlgebraicSimplifierTest, SliceReverse) {
  const char* const hlo_string = R"(
HloModule module

ENTRY test {
  param = f32[6,7,32] parameter(0)
  constant = f32[] constant(0)
  pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0
  rev = f32[8,7,32] reverse(pad), dimensions={0,2}
  slice = f32[1,7,32] slice(rev), slice={[2:3:1], [0:7:1], [0:32:1]}
  ROOT tuple = (f32[1,7,32]) tuple(slice)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier pass(default_options_);
  const bool changed = pass.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad())))));

  // tuple -> reverse -> slice
  const HloInstruction* new_reverse = root->operand(0);
  const HloInstruction* new_slice = new_reverse->operand(0);
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {1, 7, 32});
  EXPECT_TRUE(ShapeUtil::Equal(new_slice->shape(), expected_shape));

  // Dimension 0 (size 8) is reversed, so [2:3) becomes [5:6).  Dimension 1 is
  // not reversed and dimension 2 is sliced over its full extent, so their
  // bounds are unchanged.
  EXPECT_EQ(new_slice->slice_starts(0), 5);
  EXPECT_EQ(new_slice->slice_limits(0), 6);
  EXPECT_EQ(new_slice->slice_starts(1), 0);
  EXPECT_EQ(new_slice->slice_limits(1), 7);
  EXPECT_EQ(new_slice->slice_starts(2), 0);
  EXPECT_EQ(new_slice->slice_limits(2), 32);

  // Strides carry over unchanged.
  EXPECT_EQ(new_slice->slice_strides(0), 1);
  EXPECT_EQ(new_slice->slice_strides(1), 1);
  EXPECT_EQ(new_slice->slice_strides(2), 1);
}
2577 
// Same rewrite as SliceReverse, but every dimension is reversed and the slice
// uses non-unit strides, so the mirrored start of each dimension has to land
// on the last element the original strided slice selected.
TEST_F(AlgebraicSimplifierTest, SliceReverseNonUnitEvenOddStrides) {
  const char* const hlo_string = R"(
HloModule module

ENTRY test {
  param = f32[6,7,32] parameter(0)
  constant = f32[] constant(0)
  pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0
  rev = f32[8,7,32] reverse(pad), dimensions={0,1,2}
  slice = f32[1,2,7] slice(rev), slice={[2:3:2], [0:7:4], [0:32:5]}
  ROOT tuple = (f32[1,2,7]) tuple(slice)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifier pass(default_options_);
  const bool changed = pass.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);

  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad())))));

  // tuple -> reverse -> slice
  const HloInstruction* new_reverse = root->operand(0);
  const HloInstruction* new_slice = new_reverse->operand(0);
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {1, 2, 7});
  EXPECT_TRUE(ShapeUtil::Equal(new_slice->shape(), expected_shape));

  // Start and limit of every dimension are recomputed for the reversal.
  EXPECT_EQ(new_slice->slice_starts(0), 5);
  EXPECT_EQ(new_slice->slice_limits(0), 6);
  EXPECT_EQ(new_slice->slice_starts(1), 2);
  EXPECT_EQ(new_slice->slice_limits(1), 7);
  EXPECT_EQ(new_slice->slice_starts(2), 1);
  EXPECT_EQ(new_slice->slice_limits(2), 32);

  // Strides carry over unchanged.
  EXPECT_EQ(new_slice->slice_strides(0), 2);
  EXPECT_EQ(new_slice->slice_strides(1), 4);
  EXPECT_EQ(new_slice->slice_strides(2), 5);
}
2613 
2614 // Test that empty operands of concatenates are removed.
TEST_F(AlgebraicSimplifierTest,RemoveEmptyConcatenateOperands)2615 TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
2616   auto m = CreateNewVerifiedModule();
2617   const int kParamLength = 100;
2618   Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
2619   HloComputation::Builder builder(TestName());
2620   HloInstruction* param0 = builder.AddInstruction(
2621       HloInstruction::CreateParameter(0, r1f32, "param0"));
2622   HloInstruction* param1 = builder.AddInstruction(
2623       HloInstruction::CreateParameter(1, r1f32, "param1"));
2624   HloInstruction* empty_literal = builder.AddInstruction(
2625       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
2626   HloInstruction* empty_slice =
2627       builder.AddInstruction(HloInstruction::CreateSlice(
2628           ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
2629   Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
2630   builder.AddInstruction(HloInstruction::CreateConcatenate(
2631       result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
2632 
2633   auto computation = m->AddEntryComputation(builder.Build());
2634 
2635   EXPECT_THAT(computation->root_instruction(),
2636               GmockMatch(m::Concatenate(
2637                   m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0),
2638                   m::Op().Is(empty_slice), m::Parameter(1))));
2639 
2640   AlgebraicSimplifier simplifier(default_options_);
2641   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2642 
2643   EXPECT_THAT(computation->root_instruction(),
2644               GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0),
2645                                         m::Parameter(1))));
2646 }
2647 
2648 // Test that reduce of concat is simplified.
TEST_F(AlgebraicSimplifierTest,SimplifyReduceOfConcat)2649 TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) {
2650   auto m = CreateNewVerifiedModule();
2651   const int kParamLength = 100;
2652   Shape r3f32 =
2653       ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength});
2654   HloComputation::Builder builder(TestName());
2655   HloInstruction* param0 = builder.AddInstruction(
2656       HloInstruction::CreateParameter(0, r3f32, "param0"));
2657   HloInstruction* param1 = builder.AddInstruction(
2658       HloInstruction::CreateParameter(1, r3f32, "param1"));
2659   HloInstruction* param2 = builder.AddInstruction(
2660       HloInstruction::CreateParameter(2, r3f32, "param2"));
2661   Shape concat_shape =
2662       ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength});
2663   HloInstruction* Concatenate =
2664       builder.AddInstruction(HloInstruction::CreateConcatenate(
2665           concat_shape, {param0, param1, param2}, 1));
2666   HloComputation* add_computation = nullptr;
2667   {
2668     HloComputation::Builder builder(TestName() + ".add");
2669     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2670     HloInstruction* p0 = builder.AddInstruction(
2671         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
2672     HloInstruction* p1 = builder.AddInstruction(
2673         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
2674     builder.AddInstruction(
2675         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
2676     add_computation = m->AddEmbeddedComputation(builder.Build());
2677   }
2678   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
2679   Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength});
2680 
2681   HloInstruction* zero = builder.AddInstruction(
2682       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
2683   builder.AddInstruction(HloInstruction::CreateReduce(
2684       reduce_shape, Concatenate, zero, {1, 2}, add_computation));
2685 
2686   auto computation = m->AddEntryComputation(builder.Build());
2687 
2688   AlgebraicSimplifier simplifier(default_options_);
2689   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2690 
2691   EXPECT_THAT(
2692       computation->root_instruction(),
2693       GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)),
2694                                m::Reduce(m::Parameter(1), m::Op().Is(zero))),
2695                         m::Reduce(m::Parameter(2), m::Op().Is(zero)))));
2696 }
2697 
2698 // Test a concatenate with only empty operands is removed.
TEST_F(AlgebraicSimplifierTest,OnlyEmptyConcatenateOperands)2699 TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
2700   auto m = CreateNewVerifiedModule();
2701   const int kParamLength = 100;
2702   Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
2703   HloComputation::Builder builder(TestName());
2704   HloInstruction* param0 = builder.AddInstruction(
2705       HloInstruction::CreateParameter(0, r1f32, "param0"));
2706   HloInstruction* empty_literal = builder.AddInstruction(
2707       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
2708   HloInstruction* empty_slice =
2709       builder.AddInstruction(HloInstruction::CreateSlice(
2710           ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
2711   Shape result_shape = ShapeUtil::MakeShape(F32, {0});
2712   builder.AddInstruction(HloInstruction::CreateConcatenate(
2713       result_shape, {empty_literal, empty_slice}, 0));
2714 
2715   auto computation = m->AddEntryComputation(builder.Build());
2716 
2717   EXPECT_THAT(computation->root_instruction(),
2718               GmockMatch(m::Concatenate(m::Op().Is(empty_literal),
2719                                         m::Op().Is(empty_slice))));
2720 
2721   AlgebraicSimplifier simplifier(default_options_);
2722   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2723 
2724   EXPECT_EQ(computation->root_instruction(), empty_literal);
2725 }
2726 
2727 // Test that concat with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,ConcatenateOfBroadcastBecomesPad)2728 TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
2729   auto m = CreateNewVerifiedModule();
2730   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
2731   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2732   HloComputation::Builder builder(TestName());
2733   HloInstruction* param0 = builder.AddInstruction(
2734       HloInstruction::CreateParameter(0, r1f32, "param0"));
2735   HloInstruction* param1 = builder.AddInstruction(
2736       HloInstruction::CreateParameter(1, r0f32, "param1"));
2737   HloInstruction* broadcast = builder.AddInstruction(
2738       HloInstruction::CreateBroadcast(r1f32, param1, {}));
2739   builder.AddInstruction(HloInstruction::CreateConcatenate(
2740       ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
2741 
2742   auto computation = m->AddEntryComputation(builder.Build());
2743 
2744   AlgebraicSimplifier simplifier(default_options_);
2745   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2746   EXPECT_THAT(computation->root_instruction(),
2747               GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1))));
2748 }
2749 
TEST_F(AlgebraicSimplifierTest,SimplifyConcatenateOfSlices)2750 TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
2751   auto m = CreateNewVerifiedModule();
2752   Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99});
2753   Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 90});
2754   HloComputation::Builder builder(TestName());
2755   HloInstruction* param0 = builder.AddInstruction(
2756       HloInstruction::CreateParameter(0, r2f32, "param0"));
2757   HloInstruction* param1 = builder.AddInstruction(
2758       HloInstruction::CreateParameter(1, r2f32, "param1"));
2759 
2760   HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice(
2761       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0},
2762       /*limit_indices=*/{50, 10}, /*strides=*/{1, 1}));
2763 
2764   // Cannot merge 'slice0' and 'slice1' because of different start indices in
2765   // dimension 0.
2766   HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
2767       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10},
2768       /*limit_indices=*/{100, 20}, /*strides=*/{1, 1}));
2769 
2770   // Cannot merge 'slice1' and 'slice2' because of stride in dimension 2.
2771   HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice(
2772       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20},
2773       /*limit_indices=*/{100, 40}, /*strides=*/{1, 2}));
2774 
2775   // Cannot merge 'slice2' and 'slice3' because of stride in dimension 2.
2776   HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice(
2777       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40},
2778       /*limit_indices=*/{100, 50}, /*strides=*/{1, 1}));
2779 
2780   // Can merge 'slice3' and 'slice4'.
2781   HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice(
2782       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50},
2783       /*limit_indices=*/{100, 60}, /*strides=*/{1, 1}));
2784 
2785   // Can merge 'slice4' and 'slice5'.
2786   HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice(
2787       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60},
2788       /*limit_indices=*/{100, 70}, /*strides=*/{1, 1}));
2789 
2790   // Cannot merge 'slice5' and 'slice6' because of overlap.
2791   HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice(
2792       ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69},
2793       /*limit_indices=*/{100, 79}, /*strides=*/{1, 1}));
2794 
2795   // Cannot merge 'slice6' and 'slice7' because of slicing from a different
2796   // parameter.
2797   HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice(
2798       ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79},
2799       /*limit_indices=*/{100, 89}, /*strides=*/{1, 1}));
2800   // Can merge 'slice7' and 'slice8'.
2801   HloInstruction* slice8 = builder.AddInstruction(HloInstruction::CreateSlice(
2802       ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 89},
2803       /*limit_indices=*/{100, 99}, /*strides=*/{1, 1}));
2804 
2805   builder.AddInstruction(HloInstruction::CreateConcatenate(
2806       concat_shape,
2807       {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8},
2808       1));
2809   auto computation = m->AddEntryComputation(builder.Build());
2810 
2811   AlgebraicSimplifier simplifier(default_options_);
2812   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2813   auto s = m::Slice(m::Parameter(0));
2814   EXPECT_THAT(
2815       computation->root_instruction(),
2816       GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1)))));
2817   // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its
2818   // shape should have dimensions {50, 30}.
2819   EXPECT_TRUE(
2820       ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(),
2821                        ShapeUtil::MakeShape(F32, {50, 30})));
2822   EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);
2823 
2824   // The operand 6 should be  merge of 'slice7' and 'slice8', so its
2825   // shape should have dimensions {50, 20}
2826   EXPECT_TRUE(
2827       ShapeUtil::Equal(computation->root_instruction()->operand(5)->shape(),
2828                        ShapeUtil::MakeShape(F32, {50, 20})));
2829 }
2830 
2831 // Test that a simplification which changes layouts is not performed if layout
2832 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithDifferentLayout)2833 TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
2834   auto m = CreateNewVerifiedModule();
2835   HloComputation::Builder builder(TestName());
2836   HloInstruction* param0 =
2837       builder.AddInstruction(HloInstruction::CreateParameter(
2838           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2839   HloInstruction* copy = builder.AddInstruction(
2840       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
2841 
2842   auto computation = m->AddEntryComputation(builder.Build());
2843 
2844   // Set to different layouts.
2845   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2846   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
2847 
2848   EXPECT_THAT(computation->root_instruction(),
2849               GmockMatch(m::Copy(m::Parameter(0))));
2850 
2851   AlgebraicSimplifierOptions options;
2852   options.set_is_layout_sensitive(true);
2853   AlgebraicSimplifier simplifier(options);
2854   EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
2855 
2856   // Copy has not been removed.
2857   EXPECT_THAT(computation->root_instruction(),
2858               GmockMatch(m::Copy(m::Parameter(0))));
2859 }
2860 
2861 // Test that a simplification which preserves layouts is performed if layout
2862 // sensitive is true.
TEST_F(AlgebraicSimplifierTest,CopyWithSameLayout)2863 TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
2864   auto m = CreateNewVerifiedModule();
2865   HloComputation::Builder builder(TestName());
2866   HloInstruction* param0 =
2867       builder.AddInstruction(HloInstruction::CreateParameter(
2868           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2869   HloInstruction* copy = builder.AddInstruction(
2870       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
2871 
2872   auto computation = m->AddEntryComputation(builder.Build());
2873 
2874   // Set to same layouts.
2875   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2876   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2877 
2878   EXPECT_THAT(computation->root_instruction(),
2879               GmockMatch(m::Copy(m::Parameter(0))));
2880 
2881   AlgebraicSimplifierOptions options;
2882   options.set_is_layout_sensitive(true);
2883   AlgebraicSimplifier simplifier(options);
2884   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2885 
2886   // Copy has been removed.
2887   EXPECT_THAT(computation->root_instruction(), param0);
2888 }
2889 
2890 // Test that a reshape which could be replaced with a bitcast is not if
2891 // add_bitcasts is false.
TEST_F(AlgebraicSimplifierTest,NoBitcastAdded)2892 TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
2893   auto m = CreateNewVerifiedModule();
2894   HloComputation::Builder builder(TestName());
2895   HloInstruction* param0 =
2896       builder.AddInstruction(HloInstruction::CreateParameter(
2897           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2898   HloInstruction* reshape =
2899       builder.AddInstruction(HloInstruction::CreateReshape(
2900           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2901 
2902   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2903   *reshape->mutable_shape()->mutable_layout() =
2904       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2905 
2906   auto computation = m->AddEntryComputation(builder.Build());
2907 
2908   EXPECT_THAT(computation->root_instruction(),
2909               GmockMatch(m::Reshape(m::Parameter(0))));
2910 
2911   AlgebraicSimplifierOptions options(
2912       [](const Shape&, const Shape&) { return false; });
2913   options.set_is_layout_sensitive(true);
2914   AlgebraicSimplifier simplifier(options);
2915   EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie());
2916 
2917   // Reshape is not replaced with a bitcast.
2918   EXPECT_THAT(computation->root_instruction(),
2919               GmockMatch(m::Reshape(m::Parameter(0))));
2920 }
2921 
2922 // Test transforming reshapes and transposes of rng.
TEST_F(AlgebraicSimplifierTest,ReshapeOfTransposeOfRngToRng)2923 TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
2924   auto m = CreateNewVerifiedModule();
2925   HloComputation::Builder builder(TestName());
2926   HloInstruction* zero = builder.AddInstruction(
2927       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
2928   HloInstruction* one = builder.AddInstruction(
2929       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
2930   HloInstruction* rng0 = builder.AddInstruction(
2931       HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}),
2932                                 RandomDistribution::RNG_UNIFORM, {zero, one}));
2933 
2934   HloInstruction* transpose = builder.AddInstruction(
2935       HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0}));
2936   Shape reshape_shape = builder
2937                             .AddInstruction(HloInstruction::CreateReshape(
2938                                 ShapeUtil::MakeShape(F32, {4}), transpose))
2939                             ->shape();
2940 
2941   auto computation = m->AddEntryComputation(builder.Build());
2942 
2943   AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
2944   EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
2945 
2946   // Verify that reshape(transpose(rng)) is replace by a single rng of the
2947   // same shape as the reshape.
2948   EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng()));
2949   EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(),
2950                                reshape_shape));
2951 }
2952 
2953 // Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest,ReshapeReplacedWithBitcast)2954 TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
2955   auto m = CreateNewVerifiedModule();
2956   HloComputation::Builder builder(TestName());
2957   HloInstruction* param0 =
2958       builder.AddInstruction(HloInstruction::CreateParameter(
2959           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
2960   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
2961 
2962   // Reshape which can be transformed into a bitcast.
2963   HloInstruction* transformable_reshape =
2964       builder.AddInstruction(HloInstruction::CreateReshape(
2965           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2966   *transformable_reshape->mutable_shape()->mutable_layout() =
2967       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2968 
2969   // Reshape does not just add degenerate dimensions.
2970   HloInstruction* dimensions_wrong_reshape =
2971       builder.AddInstruction(HloInstruction::CreateReshape(
2972           ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
2973   *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
2974       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
2975 
2976   // Reshape has wrong layout.
2977   HloInstruction* layout_wrong_reshape =
2978       builder.AddInstruction(HloInstruction::CreateReshape(
2979           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
2980   *layout_wrong_reshape->mutable_shape()->mutable_layout() =
2981       LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});
2982 
2983   // Collect all the reshapes into a tuple so they are not dead.
2984   builder.AddInstruction(HloInstruction::CreateTuple(
2985       {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
2986 
2987   auto computation = m->AddEntryComputation(builder.Build());
2988 
2989   EXPECT_THAT(computation->root_instruction(),
2990               GmockMatch(m::Tuple(m::Op().Is(transformable_reshape),
2991                                   m::Op().Is(dimensions_wrong_reshape),
2992                                   m::Op().Is(layout_wrong_reshape))));
2993 
2994   AlgebraicSimplifierOptions options;
2995   options.set_is_layout_sensitive(true);
2996   AlgebraicSimplifier simplifier(options);
2997   simplifier.Run(m.get()).ValueOrDie();
2998 
2999   // Verify that only the first reshape is replaced.
3000   EXPECT_THAT(
3001       computation->root_instruction(),
3002       GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape),
3003                           m::Op().Is(layout_wrong_reshape))));
3004 }
3005 
3006 // Regression test for a bug where if we failed to sink a reshape, we'd set the
3007 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest,FailureToSinkReshapeDoesntAffectChangedBit)3008 TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
3009   auto m = CreateNewVerifiedModule();
3010   HloComputation::Builder builder(TestName());
3011 
3012   // This add (param0 + 0) can be simplified.
3013   Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
3014   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
3015       shape, HloOpcode::kAdd,
3016       builder.AddInstruction(
3017           HloInstruction::CreateParameter(0, shape, "param0")),
3018       builder.AddInstruction(HloInstruction::CreateConstant(
3019           LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
3020 
3021   builder.AddInstruction(
3022       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
3023 
3024   AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
3025   m->AddEntryComputation(builder.Build());
3026   EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3027 }
3028 
3029 // Regression test for a bug where if we failed to sink a reshape, we'd set the
3030 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest,FailureToSinkBroadcastDoesntAffectChangedBit)3031 TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
3032   auto m = CreateNewVerifiedModule();
3033   HloComputation::Builder builder(TestName());
3034 
3035   // This add (param0 + 0) can be simplified.
3036   Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
3037   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
3038       shape, HloOpcode::kAdd,
3039       builder.AddInstruction(
3040           HloInstruction::CreateParameter(0, shape, "param0")),
3041       builder.AddInstruction(HloInstruction::CreateConstant(
3042           LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
3043 
3044   builder.AddInstruction(
3045       HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
3046                                       /*broadcast_dimensions=*/{0, 1}));
3047 
3048   AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
3049   m->AddEntryComputation(builder.Build());
3050   EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3051 }
3052 
// Verifies that, in layout-sensitive mode, a transpose is replaced by a
// bitcast of its operand for this combination of shapes and layouts:
// f32[50,14,14,64]{1,2,0,3} transposed with {1,2,0,3} into
// f32[14,14,50,64]{0,1,2,3}.
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param"));
  *input->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 0, 3});

  auto* transposed = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), input, {1, 2, 0, 3}));
  *transposed->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3});

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The transpose has been replaced by a bitcast.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
3082 
// Verifies that, in layout-sensitive mode, a transpose is replaced by a
// bitcast of its operand for this combination of shapes and layouts:
// f32[5,2,3,4]{1,2,3,0} transposed with {0,2,3,1} into
// f32[5,3,4,2]{3,1,2,0}.
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param"));
  *input->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 3, 0});

  auto* transposed = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), input, {0, 2, 3, 1}));
  *transposed->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({3, 1, 2, 0});

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The transpose has been replaced by a bitcast.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
3112 
// Verifies that two back-to-back reshapes are merged into a single reshape of
// the original parameter.
TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  auto* param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
  auto* inner_reshape = builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {2, 1, 2}), param));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), inner_reshape));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Only one reshape remains between the root and the parameter.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));
}
3138 
// Verifies that, in layout-sensitive mode, a copy of a copy is merged into a
// single copy of the original parameter.
TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  auto* param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}), "param0"));

  // Each copy changes the layout, so neither is redundant on its own.
  const Shape first_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2});
  auto* first_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(first_copy_shape, HloOpcode::kCopy, param));

  const Shape second_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1});
  builder.AddInstruction(HloInstruction::CreateUnary(
      second_copy_shape, HloOpcode::kCopy, first_copy));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));
}
3168 
// Verifies that two consecutive transposes are merged into one transpose of
// the parameter, and that the merged permutation is {2, 1, 0}.
TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  auto* param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));
  auto* inner_transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {3, 4, 2}), param, {1, 2, 0}));
  builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {4, 3, 2}), inner_transpose, {1, 0, 2}));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Op().Is(inner_transpose))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Transpose(m::Parameter(0))));
  // f32[2,3,4] -> f32[4,3,2] in one step is a full reversal of the dimensions.
  EXPECT_EQ(std::vector<int64_t>({2, 1, 0}), root->dimensions());
}
3196 
// Verifies that a slice of a broadcast is rewritten so that the slice is
// applied to the broadcast's operand: the expected root is
// broadcast(slice(p0)).
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcast) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      b = f32[10,30,20] broadcast(p0), dimensions={0,2}
      ROOT s = f32[5,5,5] slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_TRUE(simplifier.Run(verified_module.get()).ValueOrDie());

  HloInstruction* new_root =
      verified_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
}
3215 
// Same rewrite as SliceOfBroadcast, but the broadcast and the slice carry an
// explicit tiled layout. The new root must have exactly the shape (layout
// included) that the original slice had.
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastPreserveLayout) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
      ROOT s = f32[5,5,5]{2,0,1:T(256)} slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kHloText));

  // Snapshot the root shape before the pass replaces the root instruction.
  const Shape shape_before =
      verified_module->entry_computation()->root_instruction()->shape();

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_TRUE(simplifier.Run(verified_module.get()).ValueOrDie());

  HloInstruction* new_root =
      verified_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), shape_before));
}
3237 
// Verifies that a dynamic-slice of a broadcast is rewritten into a broadcast
// of a dynamic-slice of the operand. p0 feeds output dimensions 0 and 2, so
// only the start indices i0 and i2 (parameters 1 and 3) are expected to
// remain on the new dynamic-slice.
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      i0 = s32[] parameter(1)
      i1 = s32[] parameter(2)
      i2 = s32[] parameter(3)
      b = f32[10,30,20] broadcast(p0), dimensions={0,2}
      ROOT ds = f32[5,5,5] dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_TRUE(simplifier.Run(verified_module.get()).ValueOrDie());

  HloInstruction* new_root =
      verified_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::Broadcast(m::DynamicSlice(
                  m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
}
3260 
// Same rewrite as DynamicSliceOfBroadcast, but the broadcast and the
// dynamic-slice carry an explicit tiled layout. The new root must have
// exactly the shape (layout included) that the original dynamic-slice had.
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcastPreserveLayout) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      i0 = s32[] parameter(1)
      i1 = s32[] parameter(2)
      i2 = s32[] parameter(3)
      b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
      ROOT ds = f32[5,5,5]{2,0,1:T(256)} dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kHloText));

  // Snapshot the root shape before the pass replaces the root instruction.
  const Shape shape_before =
      verified_module->entry_computation()->root_instruction()->shape();

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_TRUE(simplifier.Run(verified_module.get()).ValueOrDie());

  HloInstruction* new_root =
      verified_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              GmockMatch(m::Broadcast(m::DynamicSlice(
                  m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), shape_before));
}
3286 
// Verifies that reshape -> transpose -> reshape, where the transpose only
// moves size-1 dimensions and the final shape equals the input shape,
// simplifies all the way down to the parameter.
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[10] parameter(0)
      reshaped = f32[1,1,10] reshape(f32[10] param)
      transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0}
      ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_TRUE(simplifier.Run(verified_module.get()).ValueOrDie());

  HloInstruction* new_root =
      verified_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Parameter()));
}
3306 
3307 // Test merging reshape and broadcast.
TEST_F(AlgebraicSimplifierTest,ReshapeAndBroadcastMerged)3308 TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
3309   auto m = CreateNewVerifiedModule();
3310   HloComputation::Builder builder(TestName());
3311   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
3312       0, ShapeUtil::MakeShape(F32, {5}), "param0"));
3313   auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
3314       ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
3315   builder.AddInstruction(HloInstruction::CreateBroadcast(
3316       ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2}));
3317 
3318   auto computation = m->AddEntryComputation(builder.Build());
3319 
3320   EXPECT_THAT(computation->root_instruction(),
3321               GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
3322 
3323   AlgebraicSimplifier simplifier(default_options_);
3324   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3325 
3326   EXPECT_THAT(computation->root_instruction(),
3327               GmockMatch(m::Broadcast(m::Parameter(0))));
3328 }
3329 
3330 // Test merging broadcast and reshape.
TEST_F(AlgebraicSimplifierTest,BroadcastAndReshapeMerged)3331 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
3332   auto m = CreateNewVerifiedModule();
3333   HloComputation::Builder builder(TestName());
3334   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
3335       0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
3336   auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
3337       ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2}));
3338   builder.AddInstruction(HloInstruction::CreateReshape(
3339       ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
3340 
3341   auto computation = m->AddEntryComputation(builder.Build());
3342 
3343   EXPECT_THAT(computation->root_instruction(),
3344               GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
3345 
3346   AlgebraicSimplifier simplifier(default_options_);
3347   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
3348 
3349   EXPECT_THAT(computation->root_instruction(),
3350               GmockMatch(m::Broadcast(m::Parameter(0))));
3351 }
3352 
// Verifies the f32[1] -> broadcast f32[3,1] -> reshape f32[3] case: the
// reshape is expected to move below the broadcast, giving
// broadcast(reshape(param)).
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {1}),
                                      "param"));
  HloInstruction* broadcast_3x1 =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 1}), input, {1}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3}), broadcast_3x1));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
3374 
// f32[4] is broadcast to f32[3,2,4] along dimension 2 and then reshaped to
// f32[6,1,1,4]. The pass must produce a single broadcast of the parameter
// whose dimensions are exactly {3}.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape operand_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {6, 1, 1, 4});

  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  HloInstruction* bcast = b.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, operand, {2}));
  b.AddInstruction(HloInstruction::CreateReshape(reshaped_shape, bcast));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Re-read the root: the pass replaced it.
  HloInstruction* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Broadcast(m::Parameter(0))));
  EXPECT_THAT(new_root->dimensions(), ::testing::ElementsAre(3));
}
3398 
// f32[1] is broadcast to f32[3,2,1] along dimension 2 and then reshaped to
// f32[6,1,1,1]. The pass must leave broadcast(reshape(param)) where the
// broadcast has no dimensions at all.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {3, 2, 1});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {6, 1, 1, 1});

  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  HloInstruction* bcast = b.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, operand, {2}));
  b.AddInstruction(HloInstruction::CreateReshape(reshaped_shape, bcast));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
  EXPECT_EQ(0, entry->root_instruction()->dimensions().size());
}
3421 
// f32[4] is broadcast to f32[3,2,4,2] along dimension 2 and then reshaped to
// f32[6,8]. The pass is expected to report no change and to leave the
// reshape(broadcast(param)) structure in place.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape operand_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {3, 2, 4, 2});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {6, 8});

  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  HloInstruction* bcast = b.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, operand, {2}));
  b.AddInstruction(HloInstruction::CreateReshape(reshaped_shape, bcast));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // Same structure as before the pass.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
}
3443 
// Checks that reshape(iota) is folded into a single iota that already has the
// reshape's result shape.
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});

  HloInstruction* iota =
      b.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  b.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
  EXPECT_TRUE(
      ShapeUtil::Equal(entry->root_instruction()->shape(), result_shape));
}
3464 
// An f32[21] iota reshaped to f32[7,3] must be rewritten as
// iota + iota * broadcast(scalar constant), keeping the f32[7,3] shape.
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadix) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {21});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {7, 3});

  HloInstruction* iota =
      b.AddInstruction(HloInstruction::CreateIota(iota_shape, 0));
  b.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(
                  m::Iota(),
                  m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
  EXPECT_TRUE(
      ShapeUtil::Equal(entry->root_instruction()->shape(), result_shape));
}
// An f32[42,24,15] iota along dimension 1, reshaped to f32[3,14,4,3,2,5,3],
// must be rewritten as a sum of three terms: one iota plus two
// iota * broadcast(scalar constant) products. The result shape is preserved.
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadixExtraDims) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {42, 24, 15});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {3, 14, 4, 3, 2, 5, 3});

  HloInstruction* iota =
      b.AddInstruction(HloInstruction::CreateIota(iota_shape, 1));
  b.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Add(
          m::Add(m::Iota(),
                 m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()))),
          m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar())))));
  EXPECT_TRUE(
      ShapeUtil::Equal(entry->root_instruction()->shape(), result_shape));
}
// An f32[1,1] iota must be replaced by a broadcast of a constant whose first
// element is 0.0f, with the original shape preserved.
TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  HloInstruction* iota = b.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0));
  // Copy the shape now; it is compared against the root after the pass.
  const Shape result_shape = iota->shape();

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(0.0f, new_root->operand(0)->literal().GetFirstElement<float>());
  EXPECT_TRUE(
      ShapeUtil::Equal(entry->root_instruction()->shape(), result_shape));
}
3533 
// An f32[3,2] iota along dimension 1 reshaped to f32[6]: the pass is expected
// to report no change and to keep reshape(iota) as the root.
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {3, 2});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {6});

  HloInstruction* iota =
      b.AddInstruction(HloInstruction::CreateIota(iota_shape, 1));
  b.AddInstruction(HloInstruction::CreateReshape(reshaped_shape, iota));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));
}
3553 
// An f32[3,2,4] iota along dimension 2 reshaped to f32[6,1,1,4] must become a
// single iota whose iota dimension is 3.
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {6, 1, 1, 4});

  HloInstruction* iota =
      b.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  b.AddInstruction(HloInstruction::CreateReshape(reshaped_shape, iota));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
  EXPECT_EQ(
      Cast<HloIotaInstruction>(entry->root_instruction())->iota_dimension(),
      3);
}
3575 
// An f32[3,2,2] iota along dimension 2 reshaped to f32[6,1,1,2] must become a
// single iota. Any of the iota dimensions 1, 2 or 3 is accepted.
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {3, 2, 2});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {6, 1, 1, 2});

  HloInstruction* iota =
      b.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  b.AddInstruction(HloInstruction::CreateReshape(reshaped_shape, iota));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
  const HloIotaInstruction* merged_iota =
      Cast<HloIotaInstruction>(entry->root_instruction());
  const int64_t merged_dim = merged_iota->iota_dimension();
  EXPECT_THAT(merged_dim, ::testing::AnyOf(1, 2, 3));
}
3598 
// An f32[3,2,4,2] iota along dimension 2 reshaped to f32[6,8]: the pass is
// expected to report no change and to keep reshape(iota) as the root.
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape iota_shape = ShapeUtil::MakeShape(F32, {3, 2, 4, 2});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {6, 8});

  HloInstruction* iota =
      b.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  b.AddInstruction(HloInstruction::CreateReshape(reshaped_shape, iota));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reshape(m::Iota())));
}
3618 
// A pad whose low, high and interior padding are all zero in both dimensions
// must be removed, leaving the parameter as the root.
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
  HloComputation::Builder b(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* operand =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // One all-zero entry per dimension of the rank-2 operand.
  PaddingConfig zero_config;
  constexpr int kRank = 2;
  for (int d = 0; d != kRank; ++d) {
    auto* dim = zero_config.add_dimensions();
    dim->set_edge_padding_low(0);
    dim->set_edge_padding_high(0);
    dim->set_interior_padding(0);
  }
  b.AddInstruction(
      HloInstruction::CreatePad(shape, operand, pad_value, zero_config));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(pad_value))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), operand);
}
3647 
// Pads f32[2,2] to f32[5,5] (low padding 2, interior padding 1) and then
// slices with start 2, limit 5 and stride 2 in both dimensions, yielding
// f32[2,2] again. The pass must replace the slice-of-pad by the parameter.
TEST_F(AlgebraicSimplifierTest, RemoveNoopSliceOfPad) {
  HloComputation::Builder b(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape padded_shape = ShapeUtil::MakeShape(F32, {5, 5});

  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Same padding in each of the two dimensions.
  PaddingConfig pad_config;
  constexpr int kRank = 2;
  for (int d = 0; d != kRank; ++d) {
    auto* dim = pad_config.add_dimensions();
    dim->set_edge_padding_low(2);
    dim->set_edge_padding_high(0);
    dim->set_interior_padding(1);
  }
  HloInstruction* padded = b.AddInstruction(
      HloInstruction::CreatePad(padded_shape, operand, pad_value, pad_config));
  b.AddInstruction(HloInstruction::CreateSlice(
      operand_shape, padded, /*start_indices=*/{2, 2},
      /*limit_indices=*/{5, 5}, /*strides=*/{2, 2}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(pad_value)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), operand);
}
3679 
// A pad with negative edge padding must be rewritten as a pad without any
// negative edge padding, followed by a slice.
TEST_F(AlgebraicSimplifierTest, NegativePadding) {
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Per-dimension edge padding; three of the four values are negative.
  constexpr int kRank = 2;
  const int64_t low_edges[kRank] = {-1, -2};
  const int64_t high_edges[kRank] = {2, -3};
  PaddingConfig config;
  for (int d = 0; d != kRank; ++d) {
    auto* dim = config.add_dimensions();
    dim->set_edge_padding_low(low_edges[d]);
    dim->set_edge_padding_high(high_edges[d]);
    dim->set_interior_padding(0);
  }
  HloInstruction* original_pad = b.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), operand, pad_value, config));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);

  // True if any dimension of `instr`'s padding config has a negative low or
  // high edge padding.
  auto has_negative_padding = [](const HloInstruction* instr) {
    for (const auto& dim : instr->padding_config().dimensions()) {
      const bool low_ok = dim.edge_padding_low() >= 0;
      const bool high_ok = dim.edge_padding_high() >= 0;
      if (!(low_ok && high_ok)) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(pad_value))));
  EXPECT_TRUE(has_negative_padding(original_pad));

  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(pad_value)))));
  // The new pad is the slice's operand.
  EXPECT_FALSE(has_negative_padding(entry->root_instruction()->operand(0)));
}
3727 
// Some broadcasts can be sunk (or delayed). This test checks that the
// behavior can be switched off: with enable_sink_broadcast set to false, the
// simplifier reports no change for negate(broadcast(scalar)).
TEST_F(AlgebraicSimplifierTest, CanDisableBroadcastSinking) {
  HloComputation::Builder b(TestName());
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  const Shape wide_shape = ShapeUtil::MakeShape(F32, {512, 16});

  HloInstruction* scalar = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "scalar"));
  HloInstruction* bcast = b.AddInstruction(
      HloInstruction::CreateBroadcast(wide_shape, scalar, {}));
  b.AddInstruction(
      HloInstruction::CreateUnary(wide_shape, HloOpcode::kNegate, bcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Negate(m::Broadcast(m::Parameter(0)))));

  // Start from the fixture defaults and turn off only broadcast sinking.
  AlgebraicSimplifierOptions no_sink_options = default_options_;
  no_sink_options.set_enable_sink_broadcast(false);
  AlgebraicSimplifier simplifier(no_sink_options);

  // With sinking disabled the pass must not report a change.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
3755 
// Verifies that the negative-padding replacement can be switched off: with
// enable_negative_padding_replacement set to false, a pad with negative edge
// padding is left alone and the simplifier reports no change.
TEST_F(AlgebraicSimplifierTest, CanDisableNegativePadding) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Per-dimension edge padding; three of the four values are negative.
  PaddingConfig padding;
  const int64_t low_padding[2] = {-1, -2};
  const int64_t high_padding[2] = {2, -3};
  for (int i = 0; i < 2; ++i) {
    auto* dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(low_padding[i]);
    dimension->set_edge_padding_high(high_padding[i]);
    dimension->set_interior_padding(0);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Start from the fixture defaults and turn off only the negative padding
  // replacement.
  AlgebraicSimplifierOptions opts = default_options_;
  opts.set_enable_negative_padding_replacement(false);

  AlgebraicSimplifier simplifier(opts);

  // True if any dimension of `instr`'s padding config has a negative low or
  // high edge padding. The parameter is deliberately not named `pad`, to
  // avoid shadowing the instruction created above.
  auto has_negative_padding = [](const HloInstruction* instr) {
    for (const auto& padding_dimension :
         instr->padding_config().dimensions()) {
      if (padding_dimension.edge_padding_low() < 0 ||
          padding_dimension.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_TRUE(has_negative_padding(pad));

  // Nothing has changed since the negative padding replacement is disabled.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
3803 
// Interior padding on a size-one dimension must be dropped: after the pass
// the root is still pad(param, zero) but its config has no interior padding.
TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) {
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 1}), "param"));
  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Edge padding is 3 on both sides of both dimensions. Interior padding is
  // 0 for dimension 0 and 3 for dimension 1 (the dimension of size one).
  PaddingConfig config;
  constexpr int kRank = 2;
  for (int d = 0; d != kRank; ++d) {
    auto* dim = config.add_dimensions();
    dim->set_edge_padding_low(3);
    dim->set_edge_padding_high(3);
    dim->set_interior_padding(d * 3);
  }
  HloInstruction* original_pad = b.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {8, 7}), operand, pad_value, config));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);

  // Preconditions: stop the test if the input graph is not as intended.
  ASSERT_THAT(entry->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(pad_value))));
  ASSERT_TRUE(HasInteriorPadding(original_pad->padding_config()));

  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(pad_value))));
  EXPECT_FALSE(HasInteriorPadding(entry->root_instruction()->padding_config()));
}
3839 
// A reshape from f32[2,3] to f32[2,3] must be removed, leaving the parameter
// as the root.
TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
  HloComputation::Builder b(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});

  HloInstruction* operand =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param"));
  b.AddInstruction(HloInstruction::CreateReshape(shape, operand));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), operand);
}
3859 
// A stride-1 slice that starts at 0 and ends at the operand's full extent in
// both dimensions must be removed, leaving the parameter as the root.
TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
  HloComputation::Builder b(TestName());
  const int64_t rows = 2;
  const int64_t cols = 3;
  const Shape shape = ShapeUtil::MakeShape(F32, {rows, cols});

  HloInstruction* operand =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param"));
  b.AddInstruction(HloInstruction::CreateSlice(
      shape, operand, /*start_indices=*/{0, 0},
      /*limit_indices=*/{rows, cols}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), operand);
}
3882 
// Two nested stride-1 slices must be merged into one slice of the parameter.
// The expected start indices are the sums of the two slices' starts.
TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
  HloComputation::Builder b(TestName());
  const int64_t rows = 11;
  const int64_t cols = 12;

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {rows, cols}), "param"));

  // Inner slice: [1, rows - 1) x [2, cols - 2) of the parameter.
  const Shape inner_shape = ShapeUtil::MakeShape(F32, {rows - 2, cols - 4});
  HloInstruction* inner_slice = b.AddInstruction(HloInstruction::CreateSlice(
      inner_shape, operand,
      /*start_indices=*/{1, 2},
      /*limit_indices=*/{rows - 1, cols - 2}, /*strides=*/{1, 1}));

  // Outer slice: [2, rows - 3) x [3, cols - 6) of the inner slice.
  const Shape outer_shape = ShapeUtil::MakeShape(F32, {rows - 5, cols - 9});
  b.AddInstruction(HloInstruction::CreateSlice(
      outer_shape, inner_slice,
      /*start_indices=*/{2, 3},
      /*limit_indices=*/{rows - 3, cols - 6}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Slice(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));
  // Bounds of the merged slice, expressed in the parameter's coordinates.
  EXPECT_EQ(entry->root_instruction()->slice_starts(0), 3);
  EXPECT_EQ(entry->root_instruction()->slice_starts(1), 5);
  EXPECT_EQ(entry->root_instruction()->slice_limits(0), rows - 2);
  EXPECT_EQ(entry->root_instruction()->slice_limits(1), cols - 4);
}
3916 
// A slice of broadcast(param) must be replaced by a broadcast of the
// parameter.
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastToBroadcast) {
  HloComputation::Builder b(TestName());
  const int64_t rows = 11;
  const int64_t cols = 12;

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {rows}), "param"));
  HloInstruction* bcast = b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {rows, cols}), operand, {0}));

  // The slice keeps all of dimension 0 and columns [3, cols - 6).
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {rows, cols - 9});
  b.AddInstruction(HloInstruction::CreateSlice(
      sliced_shape, bcast,
      /*start_indices=*/{0, 3},
      /*limit_indices=*/{rows, cols - 6}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
3943 
// slice(reshape(param)) must be rewritten as reshape(slice(param)) when the
// slice trims only the leading dimension of the reshaped value.
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) {
  HloComputation::Builder b(TestName());
  const int64_t d0 = 11;
  const int64_t d1 = 12;
  const int64_t d2 = 13;

  // The parameter is f32[d0 * d1, d2]; the reshape splits its first dimension
  // into d0 and d1.
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {d0 * d1, d2}), "param"));
  HloInstruction* reshaped = b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {d0, d1, d2}), operand));

  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {d0 - 2, d1, d2});
  b.AddInstruction(HloInstruction::CreateSlice(
      sliced_shape, reshaped,
      /*start_indices=*/{0, 0, 0},
      /*limit_indices=*/{d0 - 2, d1, d2}, /*strides=*/{1, 1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Slice(m::Parameter(0)))));
}
3972 
// f32[1,144,25,1,512] is reshaped to f32[3600,512] and the first 960 rows are
// sliced out. The simplifier is expected to report no change for this graph.
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) {
  HloComputation::Builder b(TestName());

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param"));
  HloInstruction* reshaped = b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3600, 512}), operand));

  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {960, 512});
  b.AddInstruction(HloInstruction::CreateSlice(
      sliced_shape, reshaped,
      /*start_indices=*/{0, 0},
      /*limit_indices=*/{960, 512}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
3995 
// A sort over a single f32[1] operand must be removed, leaving the keys
// parameter as the root.
TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
  HloComputation::Builder b(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* keys = b.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  // MakeSortHlo is handed the builder and module and adds the sort itself;
  // only its status is checked here.
  TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &b,
                           module.get())
                   .status());

  HloComputation* entry = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), keys);
}
4011 
// A key/value sort over one f32[5,0] keys operand and two s32[5,0] values
// operands must be replaced by a tuple of those same three parameters.
TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
  HloComputation::Builder b(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0});
  const Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape});

  HloInstruction* keys = b.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values0 = b.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values0"));
  HloInstruction* values1 = b.AddInstruction(
      HloInstruction::CreateParameter(2, values_shape, "values1"));

  // MakeSortHlo is handed the builder and module and adds the sort itself;
  // only its status is checked here.
  TF_ASSERT_OK(MakeSortHlo(sort_shape, {keys, values0, values1}, 0,
                           /*is_stable=*/false, &b, module.get())
                   .status());

  HloComputation* entry = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0),
                                  m::Op().Is(values1))));
}
4036 
4037 // Test that A && True is simplified to A
TEST_F(AlgebraicSimplifierTest, AndTrue) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());

  // Root is and(param0, true).
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kAnd,
                                                operand, true_constant));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After simplification the parameter itself must be the root.
  EXPECT_EQ(entry->root_instruction(), operand);
}
4057 
4058 // Test that True && A is simplified to A
TEST_F(AlgebraicSimplifierTest, AndTrue2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());

  // Root is and(true, param0): the constant is the left operand this time.
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kAnd,
                                                true_constant, operand));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After simplification the parameter itself must be the root.
  EXPECT_EQ(entry->root_instruction(), operand);
}
4078 
4079 // Test that A && False is simplified to False
TEST_F(AlgebraicSimplifierTest, AndFalse) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());

  // Root is and(param0, false).
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kAnd,
                                                operand, false_constant));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After simplification the false constant must be the root.
  EXPECT_EQ(entry->root_instruction(), false_constant);
}
4099 
4100 // Test that False && A is simplified to False
TEST_F(AlgebraicSimplifierTest, AndFalse2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());

  // Root is and(false, param0): the constant is the left operand this time.
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kAnd,
                                                false_constant, operand));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After simplification the false constant must be the root.
  EXPECT_EQ(entry->root_instruction(), false_constant);
}
4120 
4121 // Test that A || True is simplified to True
TEST_F(AlgebraicSimplifierTest, OrTrue) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());

  // Root is or(param0, true).
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kOr,
                                                operand, true_constant));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After simplification the true constant must be the root.
  EXPECT_EQ(entry->root_instruction(), true_constant);
}
4141 
4142 // Test that True || A is simplified to True
TEST_F(AlgebraicSimplifierTest, OrTrue2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());

  // Root is or(true, param0): the constant is the left operand this time.
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kOr,
                                                true_constant, operand));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After simplification the true constant must be the root.
  EXPECT_EQ(entry->root_instruction(), true_constant);
}
4162 
4163 // Test that A || False is simplified to A
TEST_F(AlgebraicSimplifierTest, OrFalse) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());

  // Root is or(param0, false).
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kOr,
                                                operand, false_constant));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After simplification the parameter itself must be the root.
  EXPECT_EQ(entry->root_instruction(), operand);
}
4183 
4184 // Test that False || A is simplified to A
TEST_F(AlgebraicSimplifierTest, OrFalse2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());

  // Root is or(false, param0): the constant is the left operand this time.
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kOr,
                                                false_constant, operand));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After simplification the parameter itself must be the root.
  EXPECT_EQ(entry->root_instruction(), operand);
}
4204 
4205 // Used for TEST_Ps that test merging (or not) of a kPad instruction into a
4206 // convolution's Window.
4207 struct ConvPaddingTestcase {
ConvPaddingTestcasexla::__anon8c2363a90111::ConvPaddingTestcase4208   ConvPaddingTestcase(absl::string_view padding,
4209                       absl::string_view orig_conv_window,
4210                       absl::string_view expected_conv_window)
4211       : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
4212                             /*pad_value=*/0) {}
4213 
ConvPaddingTestcasexla::__anon8c2363a90111::ConvPaddingTestcase4214   ConvPaddingTestcase(absl::string_view padding,
4215                       absl::string_view orig_conv_window,
4216                       absl::string_view expected_conv_window, float pad_value)
4217       : padding(padding),
4218         orig_conv_window(orig_conv_window),
4219         expected_conv_window(expected_conv_window),
4220         pad_value(pad_value) {}
4221 
ToStringxla::__anon8c2363a90111::ConvPaddingTestcase4222   std::string ToString() const {
4223     return absl::StrFormat(
4224         "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
4225         "pad_value=%f",
4226         padding, orig_conv_window, expected_conv_window, pad_value);
4227   }
4228 
4229   std::string padding;
4230   std::string orig_conv_window;
4231   std::string expected_conv_window;
4232   float pad_value;
4233 };
4234 
4235 // ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
4236 // computation that does
4237 //
4238 //   conv(pad(param0, padding=padding), param1), window=orig_conv_window
4239 //
4240 // gets transformed by AlgebraicSimplifier to
4241 //
4242 //   conv(param0, param1), window=expected_conv_window
4243 //
4244 // or, if expected_conv_window is the empty string, checks that
4245 // AlgebraicSimplifier does *not* transform the original convolution.
4246 class ConvInputPaddingTest
4247     : public AlgebraicSimplifierTest,
4248       public ::testing::WithParamInterface<ConvPaddingTestcase> {};
4249 
4250 INSTANTIATE_TEST_SUITE_P(
4251     ConvInputPaddingTestCases, ConvInputPaddingTest,
4252     ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
4253         // Merge this edge padding into the conv.
4254         {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
4255         // Merge this edge padding with the conv's edge padding.
4256         {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
4257         // Merge this interior-padded kPad with the unpadded conv.  The 3x6
4258         // interior padding gets transformed to 4x7 conv lhs dilation.
4259         {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
4260         // kPad has dilation on one dim, conv has it on the other; merge them.
4261         {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
4262         // kPad has dilation and edge padding on one dim, conv has them on the
4263         // other; merge them.
4264         {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
4265          "pad=0_1x3_0 lhs_dilate=2x10"},
4266 
4267         // Don't transform if the pad value is nonzero.
4268         {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
4269 
4270         // We refuse to transform the following because on some dimension, one
4271         // of the kPad and conv has dilation and the other has some sort of
4272         // padding.
4273         {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
4274         {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
4275         {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
4276         {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
4277         {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
4278         {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
4279 
4280         // We can't merge feature or batch padding into the conv.
4281         {"1_0x0_0x0_0x0_0", "", ""},
4282         {"0_0x1_0x0_0x0_0", "", ""},
4283     }));
4284 
TEST_P(ConvInputPaddingTest,DoTest)4285 TEST_P(ConvInputPaddingTest, DoTest) {
4286   ConvPaddingTestcase testcase = GetParam();
4287 
4288   // It would be better to put the testcase's ToString into the test name, but
4289   // gUnit has constraints on what can go into test names, and any reasonable
4290   // implementation of ToString() seems to violate them.
4291   SCOPED_TRACE(testcase.ToString());
4292 
4293   auto builder = HloComputation::Builder(TestName());
4294   auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
4295       0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}),  // bf01
4296       "input"));
4297   auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
4298       LiteralUtil::CreateR0(testcase.pad_value)));
4299 
4300   PaddingConfig padding_config =
4301       ParsePaddingConfig(testcase.padding).ValueOrDie();
4302   auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
4303       ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
4304                                     padding_config)
4305           .ValueOrDie(),
4306       input, pad_value, padding_config));
4307 
4308   auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
4309       1,
4310       ShapeUtil::MakeShape(
4311           F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}),  // io01
4312       "input"));
4313 
4314   ConvolutionDimensionNumbers dnums =
4315       ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
4316   Window window =
4317       ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
4318           .ValueOrDie();
4319   builder.AddInstruction(HloInstruction::CreateConvolve(
4320       ShapeInference::InferConvolveShape(
4321           lhs_pad->shape(), filter->shape(),
4322           /*feature_group_count=*/1,
4323           /*batch_group_count=*/1, window, dnums,
4324           /*preferred_element_type=*/std::nullopt)
4325           .ValueOrDie(),
4326       lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
4327       window, dnums, DefaultPrecisionConfig(2)));
4328   auto module = CreateNewVerifiedModule();
4329   module->AddEntryComputation(builder.Build());
4330 
4331   AlgebraicSimplifier simplifier(default_options_);
4332   if (testcase.expected_conv_window.empty()) {
4333     ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4334   } else {
4335     ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4336     auto* conv = module->entry_computation()->root_instruction();
4337     SCOPED_TRACE(module->ToString());
4338     ASSERT_THAT(conv,
4339                 GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
4340     EXPECT_EQ(window_util::ToString(conv->window()),
4341               absl::StrCat("size=3x3 ", testcase.expected_conv_window));
4342   }
4343 }
4344 
4345 // ConvFilterPaddingTest (and its one associated TEST_P) checks that a
4346 // computation that does
4347 //
4348 //   conv(param0, pad(param1, padding=padding)), window=orig_conv_window
4349 //
4350 // gets transformed by AlgebraicSimplifier to
4351 //
4352 //   conv(param0, param1), window=expected_conv_window
4353 //
4354 // or, if expected_conv_window is the empty string, checks that
4355 // AlgebraicSimplifier does *not* transform the original convolution.
4356 class ConvFilterPaddingTest
4357     : public AlgebraicSimplifierTest,
4358       public ::testing::WithParamInterface<ConvPaddingTestcase> {};
4359 
4360 INSTANTIATE_TEST_SUITE_P(
4361     ConvFilterPaddingTestCases, ConvFilterPaddingTest,
4362     ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
4363         // Can only merge interior padding on the filter's spatial dimensions;
4364         // all
4365         // other paddings (edge padding and interior padding on the channel
4366         // dims)
4367         // should be rejected out of hand.
4368         {"1_0_0x0_0_0x0_0x0_0", "", ""},
4369         {"0_1_0x0_0_0x0_0x0_0", "", ""},
4370         {"0_0_1x0_0_0x0_0x0_0", "", ""},
4371         {"0_0_0x1_0_0x0_0x0_0", "", ""},
4372         {"0_0_0x0_1_0x0_0x0_0", "", ""},
4373         {"0_0_0x0_0_1x0_0x0_0", "", ""},
4374         {"0_0_0x0_0_0x1_0x0_0", "", ""},
4375         {"0_0_0x0_0_0x0_1x0_0", "", ""},
4376         {"0_0_0x0_0_0x0_0x1_0", "", ""},
4377         {"0_0_0x0_0_0x0_0x0_1", "", ""},
4378 
4379         // Interior padding on channel dims can be merged into the conv, so long
4380         // as the conv and pad don't have interior padding on the same dim.
4381         {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
4382         {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
4383         {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
4384         {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
4385         {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
4386 
4387         // Can't merge if for a given dim there's interior padding on both the
4388         // pad and conv.
4389         {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
4390         {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
4391 
4392         // Don't transform if the pad value is nonzero.
4393         {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
4394     }));
4395 
TEST_P(ConvFilterPaddingTest,DoIt)4396 TEST_P(ConvFilterPaddingTest, DoIt) {
4397   ConvPaddingTestcase testcase = GetParam();
4398 
4399   // It would be better to put the testcase's ToString into the test name, but
4400   // gUnit has constraints on what can go into test names, and any reasonable
4401   // implementation of ToString() seems to violate them.
4402   SCOPED_TRACE(testcase.ToString());
4403 
4404   auto builder = HloComputation::Builder(TestName());
4405   auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
4406       LiteralUtil::CreateR0(testcase.pad_value)));
4407   auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
4408       1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}),  // io01
4409       "input"));
4410   PaddingConfig padding_config =
4411       ParsePaddingConfig(testcase.padding).ValueOrDie();
4412   auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
4413       ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
4414                                     padding_config)
4415           .ValueOrDie(),
4416       filter, pad_value, padding_config));
4417 
4418   auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
4419       0,
4420       ShapeUtil::MakeShape(
4421           F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}),  // bf01
4422       "input"));
4423 
4424   ConvolutionDimensionNumbers dnums =
4425       ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
4426   Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
4427                                               rhs_pad->shape().dimensions(2),
4428                                               rhs_pad->shape().dimensions(3),
4429                                               testcase.orig_conv_window))
4430                       .ValueOrDie();
4431 
4432   // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
4433   // after the transformation.
4434   PrecisionConfig precision_config;
4435   precision_config.add_operand_precision(PrecisionConfig::HIGH);
4436   precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
4437 
4438   builder.AddInstruction(HloInstruction::CreateConvolve(
4439       ShapeInference::InferConvolveShape(
4440           input->shape(), rhs_pad->shape(),
4441           /*feature_group_count=*/1,
4442           /*batch_group_count=*/1, window, dnums,
4443           /*preferred_element_type=*/std::nullopt)
4444           .ValueOrDie(),
4445       input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
4446       window, dnums, precision_config));
4447 
4448   auto module = CreateNewVerifiedModule();
4449   module->AddEntryComputation(builder.Build());
4450 
4451   AlgebraicSimplifier simplifier(default_options_);
4452   if (testcase.expected_conv_window.empty()) {
4453     ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4454   } else {
4455     ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4456     auto* conv = module->entry_computation()->root_instruction();
4457     SCOPED_TRACE(module->ToString());
4458     ASSERT_THAT(conv,
4459                 GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
4460     EXPECT_EQ(window_util::ToString(conv->window()),
4461               absl::StrFormat("size=%dx%d %s",
4462                               conv->operand(1)->shape().dimensions(2),
4463                               conv->operand(1)->shape().dimensions(3),
4464                               testcase.expected_conv_window));
4465     EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
4466                     ->precision_config()
4467                     .operand_precision(),
4468                 ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
4469   }
4470 }
4471 
TEST_F(AlgebraicSimplifierTest,ConvertConvToMatmul)4472 TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
4473   struct ConvTestOptions {
4474     int in_batch = 10;
4475     int in_height = 2;
4476     int in_width = 2;
4477     int in_channels = 3;
4478     int f_width = 1;
4479     int f_height = 1;
4480     int f_output_channels = 10;
4481     int row_stride = 1;
4482     int row_padding = 0;
4483     int col_stride = 1;
4484     int col_padding = 0;
4485     bool input_minor_to_major_layout = false;
4486     bool filter_minor_to_major_layout = false;
4487     bool output_minor_to_major_layout = false;
4488 
4489     const char* dim_order = "NHWC";         // can use chars NHWC in any order.
4490     const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.
4491 
4492     ConvTestOptions& Reset() {
4493       *this = ConvTestOptions();
4494       return *this;
4495     }
4496   };
4497 
4498   ConvTestOptions options;
4499 
4500   // Builds a convolution from <options> and runs algebraic simplification on
4501   // the computation. Returns a string description of the result of
4502   // simplification.
4503   auto build_and_simplify = [&]() -> std::string {
4504     HloComputation::Builder b(TestName());
4505 
4506     Window window;
4507     auto* f_dim_1 = window.add_dimensions();
4508     f_dim_1->set_size(options.f_height);
4509     f_dim_1->set_stride(options.row_stride);
4510     f_dim_1->set_padding_low(options.row_padding);
4511     f_dim_1->set_padding_high(options.row_padding);
4512     f_dim_1->set_window_dilation(1);
4513     f_dim_1->set_base_dilation(1);
4514     auto* f_dim_2 = window.add_dimensions();
4515     f_dim_2->set_size(options.f_width);
4516     f_dim_2->set_stride(options.col_stride);
4517     f_dim_2->set_padding_low(options.col_padding);
4518     f_dim_2->set_padding_high(options.col_padding);
4519     f_dim_2->set_window_dilation(1);
4520     f_dim_2->set_base_dilation(1);
4521 
4522     ConvolutionDimensionNumbers dnums;
4523     std::vector<int64_t> in_dims;
4524     int in_channel_idx = -1;
4525     // filled in later
4526     dnums.add_input_spatial_dimensions(-1);
4527     dnums.add_output_spatial_dimensions(-1);
4528     dnums.add_input_spatial_dimensions(-1);
4529     dnums.add_output_spatial_dimensions(-1);
4530     for (int i = 0; i < strlen(options.dim_order); ++i) {
4531       char ch = options.dim_order[i];
4532       if (ch == 'N') {
4533         dnums.set_input_batch_dimension(i);
4534         dnums.set_output_batch_dimension(i);
4535         in_dims.push_back(options.in_batch);
4536       } else if (ch == 'H') {
4537         dnums.set_input_spatial_dimensions(0, i);
4538         dnums.set_output_spatial_dimensions(0, i);
4539         in_dims.push_back(options.in_height);
4540       } else if (ch == 'W') {
4541         dnums.set_input_spatial_dimensions(1, i);
4542         dnums.set_output_spatial_dimensions(1, i);
4543         in_dims.push_back(options.in_width);
4544       } else if (ch == 'C') {
4545         dnums.set_input_feature_dimension(i);
4546         dnums.set_output_feature_dimension(i);
4547         in_dims.push_back(options.in_channels);
4548         in_channel_idx = i;
4549       }
4550     }
4551 
4552     std::vector<int64_t> f_dims;
4553     dnums.add_kernel_spatial_dimensions(-1);  // filled in later
4554     dnums.add_kernel_spatial_dimensions(-1);  // filled in later
4555     for (int i = 0; i < strlen(options.kernel_dim_order); ++i) {
4556       char ch = options.kernel_dim_order[i];
4557       if (ch == 'H') {
4558         dnums.set_kernel_spatial_dimensions(0, i);
4559         f_dims.push_back(options.f_height);
4560       } else if (ch == 'W') {
4561         dnums.set_kernel_spatial_dimensions(1, i);
4562         f_dims.push_back(options.f_width);
4563       } else if (ch == 'I') {
4564         dnums.set_kernel_input_feature_dimension(i);
4565         f_dims.push_back(options.in_channels);
4566       } else if (ch == 'O') {
4567         dnums.set_kernel_output_feature_dimension(i);
4568         f_dims.push_back(options.f_output_channels);
4569       }
4570     }
4571 
4572     auto make_shape = [](absl::Span<const int64_t> dims,
4573                          bool minor_to_major_layout) {
4574       if (minor_to_major_layout) {
4575         return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
4576       } else {
4577         return ShapeUtil::MakeShape(F32, dims);
4578       }
4579     };
4580     auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
4581     auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
4582 
4583     HloInstruction* input =
4584         b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
4585     HloInstruction* filter =
4586         b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
4587     Shape out_shape = ShapeInference::InferConvolveShape(
4588                           in_shape, f_shape, /*feature_group_count=*/1,
4589                           /*batch_group_count=*/1, window, dnums,
4590                           /*preferred_element_type=*/std::nullopt)
4591                           .ValueOrDie();
4592     if (options.output_minor_to_major_layout) {
4593       out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),
4594                                                  {0, 1, 2, 3});
4595     }
4596 
4597     b.AddInstruction(HloInstruction::CreateConvolve(
4598         out_shape, input, filter,
4599         /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
4600         DefaultPrecisionConfig(2)));
4601 
4602     auto module = CreateNewVerifiedModule();
4603     auto* computation = module->AddEntryComputation(b.Build());
4604 
4605     AlgebraicSimplifierOptions simplifier_options;
4606     simplifier_options.set_is_layout_sensitive(true);
4607     AlgebraicSimplifier simplifier(simplifier_options);
4608     if (!simplifier.Run(module.get()).ValueOrDie()) {
4609       return "NO_CHANGE";
4610     }
4611     auto* root = computation->root_instruction();
4612     if (root->opcode() == HloOpcode::kBitcast &&
4613         root->operand(0)->opcode() == HloOpcode::kDot) {
4614       auto lhs_shape = root->operand(0)->operand(0)->shape();
4615       auto rhs_shape = root->operand(0)->operand(1)->shape();
4616       return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
4617                           absl::StrJoin(rhs_shape.dimensions(), "x"));
4618     }
4619     return "UNEXPECTED CHANGE";
4620   };
4621 
4622   // Default options are the simplest case and succeed.
4623   options.Reset();
4624   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4625 
4626   // Swapping dim spatial and batch order works.
4627   options.Reset().dim_order = "NWHC";
4628   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4629   options.Reset().dim_order = "WHNC";
4630   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4631   // Channel dimension earlier fails.
4632   options.Reset().dim_order = "HWCN";
4633   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4634   options.Reset().dim_order = "CHWN";
4635   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4636 
4637   // Filtering dims spatial dims can be anywhere, since they are 1x1.
4638   options.Reset().kernel_dim_order = "WHIO";
4639   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4640   options.Reset().kernel_dim_order = "IWOH";
4641   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4642   options.Reset().kernel_dim_order = "IWHO";
4643   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4644   // But moving output channel before input channel fails.
4645   options.Reset().kernel_dim_order = "HWOI";
4646   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4647   options.Reset().kernel_dim_order = "WHOI";
4648   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4649   options.Reset().kernel_dim_order = "OWIH";
4650   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4651   options.Reset().kernel_dim_order = "OWHI";
4652   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4653 
4654   // Combine different dim and kernel dim orders.
4655   options.Reset().kernel_dim_order = "IWHO";
4656   options.dim_order = "WHNC";
4657   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4658 
4659   // Test invalid cases from wrong filter size, strides, or padding.
4660   options.Reset().f_width = 2;
4661   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4662   options.Reset().f_height = 2;
4663   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4664   options.Reset().row_stride = 2;
4665   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4666   options.Reset().col_stride = 2;
4667   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4668   options.Reset().col_padding = 1;
4669   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4670   options.Reset().row_padding = 1;
4671   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4672 
4673   // The default dim_order is "NHWC". Col-major layout makes C the most major.
4674   options.Reset().input_minor_to_major_layout = true;
4675   options.output_minor_to_major_layout = true;
4676   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4677 
4678   // The input and output have different layouts.
4679   options.Reset().input_minor_to_major_layout = true;
4680   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4681 
4682   // C is most minor, and I is more major than O.
4683   options.Reset().input_minor_to_major_layout = true;
4684   options.filter_minor_to_major_layout = true;
4685   options.output_minor_to_major_layout = true;
4686   options.dim_order = "CHWN";
4687   options.kernel_dim_order = "OIHW";
4688   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
4689 
4690   // C is not the most minor dimension.
4691   options.Reset().input_minor_to_major_layout = true;
4692   options.filter_minor_to_major_layout = true;
4693   options.output_minor_to_major_layout = true;
4694   options.dim_order = "HWNC";
4695   options.kernel_dim_order = "OIHW";
4696   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4697 
4698   // I is more minor than O.
4699   options.Reset().input_minor_to_major_layout = true;
4700   options.filter_minor_to_major_layout = true;
4701   options.output_minor_to_major_layout = true;
4702   options.dim_order = "CHWN";
4703   options.kernel_dim_order = "IOHW";
4704   EXPECT_EQ("NO_CHANGE", build_and_simplify());
4705 }
4706 
4707 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single
4708 // broadcast.
TEST_F(AlgebraicSimplifierTest,ScalarBroadcastToSlice)4709 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
4710   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
4711   HloComputation::Builder builder(TestName());
4712   HloInstruction* scalar_param = builder.AddInstruction(
4713       HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
4714 
4715   Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
4716   HloInstruction* broadcast = builder.AddInstruction(
4717       HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
4718 
4719   Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
4720   HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
4721       slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
4722 
4723   auto module = CreateNewVerifiedModule();
4724   auto computation = module->AddEntryComputation(builder.Build());
4725 
4726   HloInstruction* root = computation->root_instruction();
4727   EXPECT_EQ(root, slice);
4728   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
4729 
4730   AlgebraicSimplifier simplifier(default_options_);
4731 
4732   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4733 
4734   // Running simplification again should not result in any further changes.
4735   ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4736   EXPECT_THAT(computation->root_instruction(),
4737               GmockMatch(m::Broadcast(m::Op().Is(scalar_param))
4738                              .WithShapeEqualTo(&slice_shape)));
4739 }
4740 
4741 // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
4742 // single broadcast.
TEST_F(AlgebraicSimplifierTest,ScalarBroadcastToTransposeReshape)4743 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
4744   HloComputation::Builder builder(TestName());
4745   HloInstruction* forty_two = builder.AddInstruction(
4746       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
4747 
4748   Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
4749   HloInstruction* broadcast = builder.AddInstruction(
4750       HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {}));
4751 
4752   HloInstruction* transpose =
4753       builder.AddInstruction(HloInstruction::CreateTranspose(
4754           ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0}));
4755 
4756   Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
4757   HloInstruction* reshape = builder.AddInstruction(
4758       HloInstruction::CreateReshape(reshape_shape, transpose));
4759 
4760   auto module = CreateNewVerifiedModule();
4761   auto computation = module->AddEntryComputation(builder.Build());
4762 
4763   HloInstruction* root = computation->root_instruction();
4764   EXPECT_EQ(root, reshape);
4765   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
4766 
4767   AlgebraicSimplifier simplifier(default_options_);
4768   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4769   EXPECT_THAT(computation->root_instruction(),
4770               GmockMatch(m::Broadcast(m::Op().Is(forty_two))
4771                              .WithShapeEqualTo(&reshape_shape)));
4772 }
4773 
4774 // Test that a depth-to-space transformation expressed as
4775 // reshape(transpose(reshape(op))) can simplify to
4776 // reshape(concat(slice(op), ..., slice(op))).
TEST_F(AlgebraicSimplifierTest,TransposeReshapeToConcatSlice)4777 TEST_F(AlgebraicSimplifierTest, TransposeReshapeToConcatSlice) {
4778   const std::string& hlo_string = R"(
4779 HloModule TransposeReshapeDepthToSpace
4780 
4781 ENTRY entry {
4782   %param = f32[8,14,14,128]{0,1,2,3} parameter(0)
4783   %reshape.1 = f32[8,14,14,2,64] reshape(%param)
4784   %transpose = transpose(%reshape.1), dimensions={0,1,3,2,4}
4785   ROOT %reshape.2 = f32[8,28,14,64] reshape(%transpose)
4786 }
4787 )";
4788   TF_ASSERT_OK_AND_ASSIGN(auto module,
4789                           ParseAndReturnVerifiedModule(hlo_string));
4790 
4791   AlgebraicSimplifier simplifier(default_options_);
4792   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4793 
4794   Shape result_shape = ShapeUtil::MakeShape(F32, {8, 28, 14, 64});
4795   EXPECT_THAT(module->entry_computation()->root_instruction(),
4796               GmockMatch(m::Reshape(m::Concatenate(m::Slice(m::Parameter(0)),
4797                                                    m::Slice(m::Parameter(0))))
4798                              .WithShapeEqualTo(&result_shape)));
4799 }
4800 
4801 // Test that a depth-to-space transformation expressed as
4802 // reshape(transpose(reshape(op))) with a large number of chunks
4803 // is not rewritten.
TEST_F(AlgebraicSimplifierTest,TransposeReshapeTooLarge)4804 TEST_F(AlgebraicSimplifierTest, TransposeReshapeTooLarge) {
4805   const std::string& hlo_string = R"(
4806 HloModule TransposeReshapeDepthToSpaceBig
4807 
4808 ENTRY entry {
4809   %param = f32[8,14,14,128]{0,1,2,3} parameter(0)
4810   %reshape.1 = f32[8,14,14,8,16] reshape(%param)
4811   %transpose = transpose(%reshape.1), dimensions={0,1,3,2,4}
4812   ROOT %reshape.2 = f32[8,112,14,16] reshape(%transpose)
4813 }
4814 )";
4815   TF_ASSERT_OK_AND_ASSIGN(auto module,
4816                           ParseAndReturnVerifiedModule(hlo_string));
4817 
4818   AlgebraicSimplifier simplifier(default_options_);
4819   ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4820 }
4821 
4822 // Test that a reshape(transpose(reshape(op))) that does not constitute a
4823 // depth-to-space transformation is not rewritten.
TEST_F(AlgebraicSimplifierTest,TransposeReshapeNotDepthToSpace)4824 TEST_F(AlgebraicSimplifierTest, TransposeReshapeNotDepthToSpace) {
4825   const std::string& hlo_string = R"(
4826 HloModule TransposeReshapeDepthToSpace
4827 
4828 ENTRY entry {
4829   %param = f32[8,14,14,128]{0,1,2,3} parameter(0)
4830   %reshape.1 = f32[8,14,14,2,64] reshape(%param)
4831   %transpose = transpose(%reshape.1), dimensions={0,3,1,2,4}
4832   ROOT %reshape.2 = f32[8,28,14,64] reshape(%transpose)
4833 }
4834 )";
4835   TF_ASSERT_OK_AND_ASSIGN(auto module,
4836                           ParseAndReturnVerifiedModule(hlo_string));
4837 
4838   AlgebraicSimplifier simplifier(default_options_);
4839   ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4840 }
4841 
4842 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest,FoldPadIntoReduceWindow)4843 TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
4844   const std::string& hlo_string = R"(
4845 HloModule test
4846 fn {
4847   p0 = f32[] parameter(0)
4848   p1 = f32[] parameter(1)
4849   ROOT add = f32[] add(p0, p1)
4850 }
4851 ENTRY entry {
4852   param = f32[1,2,3,4] parameter(0)
4853   const = f32[] constant(5)
4854   pad = pad(param, const), padding=0_0x1_0x0_0x0_2
4855   ROOT r = reduce-window(pad, const), to_apply=fn, window={size=2x2x2x2 lhs_dilate=1x1x1x3 pad=10_100x10_100x10_100x10_100}
4856 }
4857 )";
4858   TF_ASSERT_OK_AND_ASSIGN(auto module,
4859                           ParseAndReturnVerifiedModule(hlo_string));
4860   AlgebraicSimplifier simplifier(default_options_);
4861   ASSERT_TRUE(RunHloPass(&simplifier, module.get()).ValueOrDie());
4862   // Running simplification again should not result in any further changes.
4863   ASSERT_FALSE(RunHloPass(&simplifier, module.get()).ValueOrDie());
4864 
4865   // Verify the result
4866   HloInstruction* root = module->entry_computation()->root_instruction();
4867   EXPECT_THAT(root,
4868               GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
4869   EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
4870   EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
4871   EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
4872   EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
4873   EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
4874   EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
4875   EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
4876   EXPECT_EQ(root->window().dimensions(3).padding_high(), 106);
4877 }
4878 
4879 // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
4880 // ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest,FoldConvertedPadIntoReduceWindow)4881 TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
4882   const std::string& hlo_string = R"(
4883 HloModule test
4884 fn {
4885   p0 = f32[] parameter(0)
4886   p1 = f32[] parameter(1)
4887   ROOT add = f32[] add(p0, p1)
4888 }
4889 ENTRY entry {
4890   param = bf16[1,2,3,4] parameter(0)
4891   const = bf16[] constant(5)
4892   pad = pad(param, const), padding=0_0x1_0x0_0x0_2
4893   converted = f32[1,3,3,6] convert(pad)
4894   ROOT r = reduce-window(converted, const), to_apply=fn, window={size=2x2x2x2 pad=10_100x10_100x10_100x10_100}
4895 }
4896 )";
4897   TF_ASSERT_OK_AND_ASSIGN(auto module,
4898                           ParseAndReturnVerifiedModule(hlo_string));
4899   AlgebraicSimplifier simplifier(default_options_);
4900   ASSERT_TRUE(RunHloPass(&simplifier, module.get()).ValueOrDie());
4901   // Running simplification again should not result in any further changes.
4902   ASSERT_FALSE(RunHloPass(&simplifier, module.get()).ValueOrDie());
4903 
4904   // Verify the result
4905   HloInstruction* root = module->entry_computation()->root_instruction();
4906   EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)),
4907                                                m::Constant())));
4908   EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
4909   EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
4910   EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
4911   EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
4912   EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
4913   EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
4914   EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
4915   EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
4916 }
4917 
// Reversing only size-1 dimensions (dims 2 and 3 of [448,2048,1,1]) does not
// move any data; the test expects the reverse to be replaced by its operand.
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});

  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "a"));
  builder.AddInstruction(HloInstruction::CreateReverse(
      operand_shape, operand, /*dimensions=*/{2, 3}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());

  // The parameter itself must now be the root, with its shape unchanged.
  HloInstruction* new_root = entry->root_instruction();
  EXPECT_EQ(operand, new_root);
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), operand_shape));
}
4936 
TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
  // Simplifying a dot can add computations to the parent module. This test
  // builds a module with an embedded dot computation plus an entry computation
  // that calls it, and checks that the pass still completes (no iterator
  // invalidation) when the module's computation list changes mid-run.
  auto module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {1});

  // Embedded computation: a batch dot of two single-element vectors.
  HloComputation::Builder dot_builder(TestName() + ".Dot");
  HloInstruction* lhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* rhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1f32, "y"));
  DotDimensionNumbers batch_dnums;
  batch_dnums.add_lhs_batch_dimensions(0);
  batch_dnums.add_rhs_batch_dimensions(0);
  dot_builder.AddInstruction(HloInstruction::CreateDot(
      r1f32, lhs, rhs, batch_dnums, DefaultPrecisionConfig(2)));
  std::unique_ptr<HloComputation> dot_computation = dot_builder.Build();

  // Entry computation: call the dot computation on two constants.
  HloComputation::Builder entry_builder(TestName() + ".Call");
  HloInstruction* zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
  HloInstruction* one = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
  entry_builder.AddInstruction(
      HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));

  module->AddEmbeddedComputation(std::move(dot_computation));
  module->AddEntryComputation(entry_builder.Build());

  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());
}
4968 
4969 // Test that a constant with tuple shape becomes a tuple of constants.
TEST_F(AlgebraicSimplifierTest,ConstantTupleBecomesTupleOfConstants)4970 TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
4971   auto m = CreateNewVerifiedModule();
4972   HloComputation::Builder builder(TestName());
4973   const float constant_scalar = 7.3f;
4974   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
4975   Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
4976                         LiteralUtil::CreateR1<float>(constant_vector)};
4977   Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
4978   builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
4979 
4980   auto computation = m->AddEntryComputation(builder.Build());
4981 
4982   AlgebraicSimplifier simplifier(default_options_);
4983   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4984   EXPECT_THAT(computation->root_instruction(),
4985               GmockMatch(m::Tuple(m::Constant(), m::Constant())));
4986 }
4987 
4988 // A dynamic-slice is trivial if its start indices are all zeroes and the size
4989 // of its input equals the size of its output.  In this case, the dynamic slice
4990 // is equal to its input.
TEST_F(AlgebraicSimplifierTest,TrivialDynamicSlice)4991 TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
4992   auto m = CreateNewVerifiedModule();
4993   HloComputation::Builder builder(TestName());
4994 
4995   Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
4996   std::vector<HloInstruction*> params;
4997   for (int i = 0; i < 3; ++i) {
4998     params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
4999         i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
5000   }
5001   builder.AddInstruction(HloInstruction::CreateDynamicSlice(
5002       shape,
5003       builder.AddInstruction(
5004           HloInstruction::CreateParameter(0, shape, "slice_from")),
5005       params,
5006       /*slice_sizes=*/{10, 100, 1000}));
5007 
5008   auto computation = m->AddEntryComputation(builder.Build());
5009   AlgebraicSimplifier simplifier(default_options_);
5010   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5011   EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
5012 }
5013 
// A dynamic-slice whose start indices are all constants is expected to be
// turned into a plain (static) slice of the operand.
TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Constant start indices: 2 << 1, 2 << 2, 2 << 3.
  std::vector<HloInstruction*> start_indices;
  for (int i = 0; i < 3; ++i) {
    Literal start = LiteralUtil::CreateR0<int32_t>(2 << (i + 1));
    start_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(start))));
  }

  const Shape operand_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 20, 200});
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "operand"));
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, operand, start_indices,
      /*slice_sizes=*/{2, 20, 200}));

  HloComputation* entry = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier pass(default_options_);
  ASSERT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Parameter())));
}
5038 
5039 // A dynamic-update-slice is trivial if its start indices are all zeroes and the
5040 // size of its "update" equals the size of its output.  In this case, the
5041 // dynamic-update-slice is equal to its update.
TEST_F(AlgebraicSimplifierTest,TrivialDynamicUpdateSlice)5042 TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
5043   auto m = CreateNewVerifiedModule();
5044   HloComputation::Builder builder(TestName());
5045 
5046   Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
5047   Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
5048 
5049   std::vector<HloInstruction*> slice_indices, update_indices;
5050   for (int i = 0; i < 3; ++i) {
5051     slice_indices.push_back(
5052         builder.AddInstruction(HloInstruction::CreateParameter(
5053             i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
5054     update_indices.push_back(
5055         builder.AddInstruction(HloInstruction::CreateParameter(
5056             i + 5, ShapeUtil::MakeShape(U32, {}), "update_indices")));
5057   }
5058   HloInstruction* slice =
5059       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
5060           slice_shape,
5061           builder.AddInstruction(
5062               HloInstruction::CreateParameter(0, full_shape, "slice_from")),
5063           slice_indices,
5064           /*slice_sizes=*/{10, 1, 1000}));
5065 
5066   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
5067       slice_shape,
5068       builder.AddInstruction(
5069           HloInstruction::CreateParameter(4, slice_shape, "to_update")),
5070       slice, update_indices));
5071 
5072   auto computation = m->AddEntryComputation(builder.Build());
5073   AlgebraicSimplifier simplifier(default_options_);
5074   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5075   EXPECT_THAT(computation->root_instruction(),
5076               GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(),
5077                                          m::Parameter(), m::Parameter())));
5078 }
5079 
5080 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest,MergeBroadcasts)5081 TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
5082   auto m = CreateNewVerifiedModule();
5083   HloComputation::Builder builder(TestName());
5084   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
5085   HloInstruction* input_array = builder.AddInstruction(
5086       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
5087   HloInstruction* inner_bcast = builder.AddInstruction(
5088       HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
5089   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
5090   builder.AddInstruction(
5091       HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2}));
5092 
5093   auto computation = m->AddEntryComputation(builder.Build());
5094   HloInstruction* root = computation->root_instruction();
5095   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
5096   AlgebraicSimplifier simplifier(default_options_);
5097   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5098   root = computation->root_instruction();
5099   EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
5100   EXPECT_THAT(root->dimensions(), ElementsAre(2));
5101 }
5102 
5103 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest,MergeBroadcasts2)5104 TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
5105   auto m = CreateNewVerifiedModule();
5106   HloComputation::Builder builder(TestName());
5107   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
5108   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
5109   HloInstruction* param0 = builder.AddInstruction(
5110       HloInstruction::CreateParameter(0, r2f32, "param0"));
5111   // The initial dimensions go to places 0 and 2 in the 3-dim array,
5112   // and to places 1 and 3 in the 4-dim array,
5113   HloInstruction* inner_bcast = builder.AddInstruction(
5114       HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
5115   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
5116   builder.AddInstruction(
5117       HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3}));
5118 
5119   auto computation = m->AddEntryComputation(builder.Build());
5120   HloInstruction* root = computation->root_instruction();
5121   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
5122   AlgebraicSimplifier simplifier(default_options_);
5123   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5124   root = computation->root_instruction();
5125   EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0))));
5126   EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
5127 }
5128 
5129 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota)5130 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
5131   auto m = CreateNewVerifiedModule();
5132   HloComputation::Builder builder(TestName());
5133   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
5134   HloInstruction* iota =
5135       builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
5136   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
5137   builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));
5138 
5139   auto computation = m->AddEntryComputation(builder.Build());
5140   HloInstruction* root = computation->root_instruction();
5141   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
5142   AlgebraicSimplifier simplifier(default_options_);
5143   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5144   root = computation->root_instruction();
5145   EXPECT_THAT(root, GmockMatch(m::Iota()));
5146   EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
5147 }
5148 
5149 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota2)5150 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
5151   auto m = CreateNewVerifiedModule();
5152   HloComputation::Builder builder(TestName());
5153   Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
5154   HloInstruction* iota =
5155       builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
5156   Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
5157   builder.AddInstruction(
5158       HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
5159 
5160   auto computation = m->AddEntryComputation(builder.Build());
5161   HloInstruction* root = computation->root_instruction();
5162   EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
5163   AlgebraicSimplifier simplifier(default_options_);
5164   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5165   root = computation->root_instruction();
5166   EXPECT_THAT(root, GmockMatch(m::Iota()));
5167   EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
5168 }
5169 
// transpose(dot(lhs, rhs)) is expected to become dot(rhs, lhs). The
// per-operand precision must travel with the operands, so {highest, high}
// becomes {high, highest}.
TEST_F(AlgebraicSimplifierTest, TransposeOfDot) {
  const char* const hlo_text = R"(
    HloModule module

    ENTRY test {
      lhs = f32[3,4,5] parameter(0)
      rhs = f32[6,3,4] parameter(1)
      dot = f32[5,6] dot(lhs,rhs), lhs_contracting_dims={0,1}, rhs_contracting_dims={1,2}, operand_precision={highest,high}
      ROOT transpose = f32[6,5] transpose(dot), dimensions={1,0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());

  // The root must be a dot with its operands swapped.
  HloInstruction* root = module->entry_computation()->root_instruction();
  const HloInstruction* swapped_dot = nullptr;
  ASSERT_THAT(root, GmockMatch(m::Dot(&swapped_dot, m::Parameter(1),
                                      m::Parameter(0))));

  const auto& precision = swapped_dot->precision_config().operand_precision();
  EXPECT_EQ(precision[0], PrecisionConfig::HIGH);
  EXPECT_EQ(precision[1], PrecisionConfig::HIGHEST);
}
5195 
// A transpose that swaps only the two non-batch output dimensions of a batch
// dot is expected to be folded away by swapping the dot's operands. Batch
// dimensions stay {0,1}; contracting dimensions and operand precisions follow
// their operands to the other side.
TEST_F(AlgebraicSimplifierTest, TransposeOfBatchDot) {
  const char* const hlo_text = R"(
    HloModule module

    ENTRY test {
      lhs = f32[10,20,30,40] parameter(0)
      rhs = f32[10,20,50,30] parameter(1)
      dot = dot(lhs,rhs), lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
                          lhs_contracting_dims={2}, rhs_contracting_dims={3},
                          operand_precision={high, default}
      ROOT transpose = transpose(dot), dimensions={0,1,3,2}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK_AND_ASSIGN(bool modified, RunHloPass(&pass, module.get()));
  EXPECT_TRUE(modified);

  // The root must be a dot with its operands swapped.
  const HloInstruction* swapped_dot = nullptr;
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Dot(&swapped_dot, m::Parameter(1),
                                m::Parameter(0))));

  DotDimensionNumbers new_dnums = swapped_dot->dot_dimension_numbers();
  EXPECT_THAT(new_dnums.lhs_batch_dimensions(), ElementsAre(0, 1));
  EXPECT_THAT(new_dnums.rhs_batch_dimensions(), ElementsAre(0, 1));
  EXPECT_THAT(new_dnums.lhs_contracting_dimensions(), ElementsAre(3));
  EXPECT_THAT(new_dnums.rhs_contracting_dimensions(), ElementsAre(2));

  const auto& precision = swapped_dot->precision_config().operand_precision();
  EXPECT_EQ(precision[0], PrecisionConfig::DEFAULT);
  EXPECT_EQ(precision[1], PrecisionConfig::HIGH);
}
5228 
// When the transpose also permutes the batch dimensions ({1,0,3,2}), the
// test expects the pass to leave the transpose(dot) untouched.
TEST_F(AlgebraicSimplifierTest, TransposeOfBatchDimsInBatchDotCantSimplify) {
  const char* const hlo_text = R"(
    HloModule module

    ENTRY test {
      lhs = f32[10,20,30,40] parameter(0)
      rhs = f32[10,20,50,30] parameter(1)
      dot = dot(lhs,rhs), lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
                          lhs_contracting_dims={2}, rhs_contracting_dims={3}
      ROOT transpose = transpose(dot), dimensions={1,0,3,2}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK_AND_ASSIGN(bool modified, RunHloPass(&pass, module.get()));
  EXPECT_FALSE(modified);
}
5248 
// Both dots below involve p1, which has two dimensions (3 and 4) that are
// neither batch nor contracting. The test expects that transposes of such
// dots are not folded, whichever side p1 is on.
TEST_F(AlgebraicSimplifierTest, TransposeOfNonCanonicalBatchDotCantSimplify) {
  const char* const hlo_text = R"(
    HloModule module

    ENTRY test {
      p0 = f32[13,11,2,3] parameter(0)
      p1 = f32[13,11,3,7,5] parameter(1)
      dot1 = dot(p0, p1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
      dot2 = dot(p1, p0), rhs_batch_dims={0,1}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, lhs_contracting_dims={2}
      trans1 = transpose(dot1), dimensions={0,1,2,4,3}
      trans2 = transpose(dot2), dimensions={0,1,2,4,3}
      ROOT root = tuple(trans1, trans2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK_AND_ASSIGN(bool modified, RunHloPass(&pass, module.get()));
  EXPECT_FALSE(modified);
}
5270 
// dynamic-slice(transpose(x)) is expected to become
// transpose(dynamic-slice(x)), with the start indices reordered to match the
// untransposed operand (i1 and i2 trade places for dimensions={0,2,1}).
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfTranspose) {
  const char* const hlo_text = R"(
    HloModule module

    ENTRY test {
      param = f32[12,10,8] parameter(0)
      i0 = s32[] parameter(1)
      i1 = s32[] parameter(2)
      i2 = s32[] parameter(3)
      transpose = f32[12,8,10] transpose(param), dimensions={0,2,1}
      ROOT slice = f32[2,3,5] dynamic-slice(transpose, i0, i1, i2),
        dynamic_slice_sizes={2,3,5}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Transpose(
                        m::DynamicSlice(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(3), m::Parameter(2)))));
}
5296 
// dynamic-slice(reshape(x)), where the reshape only moves a size-1 dimension
// ([12,10,1,8] -> [1,12,10,8]), is expected to become
// reshape(dynamic-slice(x)) with the start indices rearranged to the
// operand's dimension order.
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfTrivialReshape) {
  const char* const hlo_text = R"(
    HloModule module

    ENTRY test {
      param = f32[12,10,1,8] parameter(0)
      i0 = s32[] parameter(1)
      i1 = s32[] parameter(2)
      i2 = s32[] parameter(3)
      z = s32[] constant(0)
      reshape = f32[1,12,10,8] reshape(param)
      ROOT slice = f32[1,2,3,5] dynamic-slice(reshape, z, i0, i1, i2),
        dynamic_slice_sizes={1,2,3,5}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The test explicitly runs the pass in layout-insensitive mode.
  AlgebraicSimplifierOptions simplifier_options;
  simplifier_options.set_is_layout_sensitive(false);
  AlgebraicSimplifier pass(simplifier_options);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Reshape(m::DynamicSlice(
                        m::Parameter(0), m::Parameter(1), m::Parameter(2),
                        m::Constant(), m::Parameter(3)))));
}
5324 
// The sliced element (row 2, column 0) lies entirely inside the low padding
// (3 rows, 1 column), so the slice is expected to become a broadcast of the
// padding constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
  const char* const hlo_text = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
}
5345 
// The slice [6:7],[9:10] lies entirely inside the high padding (the operand
// occupies rows 3..5 and columns 1..4 of the padded result), so the result is
// expected to be a broadcast of the padding constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
5366 
// The slice [5:6],[4:5] lies entirely inside the un-padded region of a
// non-scalar operand, so the pad is expected to be bypassed and the result to
// be a slice of the parameter.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* const new_root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Slice(m::Parameter(0))));
}
5387 
// A multi-element slice that stays inside the un-padded region is expected to
// become a slice of the parameter whose start indices are shifted by the low
// padding: [4,2] in the padded result corresponds to [1,1] in the parameter.
TEST_F(AlgebraicSimplifierTest, SliceOfPad) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* const new_slice =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_slice, GmockMatch(m::Slice(m::Parameter(0))));
  EXPECT_THAT(new_slice->slice_starts(), ElementsAre(1, 1));
}
5409 
// The slice [5:6],[9:10] is inside the operand along dimension 0 but inside
// the high padding along dimension 1, so the selected element is padding and
// the result is expected to be a broadcast of the padding constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[3,4] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
5430 
// The f32[1,1] parameter lands at position [3,4] of the padded result and the
// slice selects exactly that element, so the pad and slice are expected to
// cancel out and leave the parameter as the root.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param = f32[1,1] parameter(0)
      constant = f32[] constant(0.0)
      pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5
      ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter()));
}
5451 
// Index 0 is inside the operand along dimensions 0 and 1 but inside the low
// padding along dimension 2 (padding=...x2_0), so the sliced element is the
// padding value; the root is expected to be a broadcast of the scalar -7.
TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY entry () -> f32[1]{0} {
      constant.val = f32[] constant(4)
      constant.pad = f32[] constant(-7)
      reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val)
      pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0
      slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]}
      ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::ConstantScalar(-7.0))));
}
5474 
// The concatenated operands cover [0,2), [2,3) and [3,6); the slice [2:3]
// selects exactly the single-element operand param.1, so the root is expected
// to become that parameter.
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      param.1 = f32[1] parameter(1)
      param.2 = f32[3] parameter(2)
      concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
      ROOT slice = f32[1] slice(concat), slice={[2:3]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(1)));
}
5496 
// The slice [4:5] falls inside param.2, which starts at offset 3 of the
// concatenation, so the root is expected to become a slice [1:2] of param.2.
TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      param.1 = f32[1] parameter(1)
      param.2 = f32[3] parameter(2)
      concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
      ROOT slice = f32[1] slice(concat), slice={[4:5]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* const new_slice =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_slice, GmockMatch(m::Slice(m::Parameter(2))));
  EXPECT_EQ(new_slice->slice_starts(0), 1);
  EXPECT_EQ(new_slice->slice_limits(0), 2);
}
5520 
// Concatenating the same operand six times along its size-1 dimension is
// expected to be rewritten as a broadcast of a reshape of that operand.
TEST_F(AlgebraicSimplifierTest, ConcatToBroadcast) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      p = f32[2,1,4] parameter(0)
      ROOT concat = f32[2,6,4] concatenate(p,p,p,p,p,p), dimensions={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
5539 
// negate(negate(x)) is expected to simplify to x.
TEST_F(AlgebraicSimplifierTest, NegateNegate) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param.0 = f32[2] parameter(0)
      neg.0 = f32[2] negate(param.0)
      ROOT neg.1 = f32[2] negate(neg.0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}
5559 
// not(not(x)) on a pred operand is expected to simplify to x.
TEST_F(AlgebraicSimplifierTest, NotNot) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      param.0 = pred[2] parameter(0)
      not.0 = pred[2] not(param.0)
      ROOT not.1 = pred[2] not(not.0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifierOptions default_opts;
  AlgebraicSimplifier pass(default_opts);
  const bool changed = pass.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}
5579 
// Both dot operands are transposes that swap the two non-batch dimensions.
// The expected rewrite is transpose(dot(rhs, lhs)) on the un-transposed
// parameters, with the per-operand precisions swapped along with the operands.
TEST_F(AlgebraicSimplifierTest, BatchDotTransposeOperands) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      lhs = f32[10,20,30,40] parameter(0)
      rhs = f32[10,20,50,30] parameter(1)
      lhs_t = transpose(lhs), dimensions={0,1,3,2}
      rhs_t = transpose(rhs), dimensions={0,1,3,2}
      ROOT root = dot(lhs_t, rhs_t),
                  lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
                  lhs_contracting_dims={3}, rhs_contracting_dims={2},
                  operand_precision={default, high}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK_AND_ASSIGN(bool module_changed,
                          RunHloPass(&pass, hlo_module.get()));
  EXPECT_TRUE(module_changed);

  // Filled in by the matcher below; ASSERT_THAT aborts the test before the
  // pointer is dereferenced if the match fails.
  const HloInstruction* new_dot = nullptr;
  ASSERT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Transpose(
                  m::Dot(&new_dot, m::Parameter(1), m::Parameter(0)))));

  const auto& precisions = new_dot->precision_config().operand_precision();
  EXPECT_EQ(precisions[0], PrecisionConfig::HIGH);
  EXPECT_EQ(precisions[1], PrecisionConfig::DEFAULT);
}
5610 
// Both dot operands are transposes that only swap the two batch dimensions.
// The expected rewrite is transpose(dot(lhs, rhs)) on the un-transposed
// parameters, keeping operand order and per-operand precisions unchanged.
TEST_F(AlgebraicSimplifierTest, BatchDotTransposeBatchDims) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      lhs = f32[10,20,40,30] parameter(0)
      rhs = f32[10,20,30,50] parameter(1)
      lhs_t = transpose(lhs), dimensions={1,0,2,3}
      rhs_t = transpose(rhs), dimensions={1,0,2,3}
      ROOT root = dot(lhs_t, rhs_t),
                  lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
                  lhs_contracting_dims={3}, rhs_contracting_dims={2},
                  operand_precision={default, high}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK_AND_ASSIGN(bool module_changed,
                          RunHloPass(&pass, hlo_module.get()));
  EXPECT_TRUE(module_changed);

  // Filled in by the matcher below; ASSERT_THAT aborts the test before the
  // pointer is dereferenced if the match fails.
  const HloInstruction* new_dot = nullptr;
  ASSERT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Transpose(
                  m::Dot(&new_dot, m::Parameter(0), m::Parameter(1)))));

  const auto& precisions = new_dot->precision_config().operand_precision();
  EXPECT_EQ(precisions[0], PrecisionConfig::DEFAULT);
  EXPECT_EQ(precisions[1], PrecisionConfig::HIGH);
}
5641 
// Both dot operands are transposes that swap the batch dimensions and the two
// non-batch dimensions. The expected rewrite is transpose(dot(rhs, lhs)) on
// the un-transposed parameters, with the per-operand precisions swapped.
TEST_F(AlgebraicSimplifierTest, BatchDotTransposeBatchDimsAndOperands) {
  const char* const kHloText = R"(
    HloModule module

    ENTRY test {
      lhs = f32[10,20,30,40] parameter(0)
      rhs = f32[10,20,50,30] parameter(1)
      lhs_t = transpose(lhs), dimensions={1,0,3,2}
      rhs_t = transpose(rhs), dimensions={1,0,3,2}
      ROOT root = dot(lhs_t, rhs_t),
                  lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
                  lhs_contracting_dims={3}, rhs_contracting_dims={2},
                  operand_precision={high, default}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  AlgebraicSimplifier pass(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK_AND_ASSIGN(bool module_changed,
                          RunHloPass(&pass, hlo_module.get()));
  EXPECT_TRUE(module_changed);

  // Filled in by the matcher below; ASSERT_THAT aborts the test before the
  // pointer is dereferenced if the match fails.
  const HloInstruction* new_dot = nullptr;
  ASSERT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Transpose(
                  m::Dot(&new_dot, m::Parameter(1), m::Parameter(0)))));

  const auto& precisions = new_dot->precision_config().operand_precision();
  EXPECT_EQ(precisions[0], PrecisionConfig::DEFAULT);
  EXPECT_EQ(precisions[1], PrecisionConfig::HIGH);
}
5672 
5673 struct PadReduceWindowEffectiveBroadcastCase {
5674   std::vector<int64_t> input_spatials;
5675   std::vector<int64_t> symmetric_pad_spatials;
5676   std::vector<int64_t> reduce_window_spatials;
5677   // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
5678   //
5679   // This doesn't test any different functionality but is useful for making sure
5680   // kBroadcast nodes are well formed.
5681   bool prepend_a;
5682   bool should_become_broadcast;
5683 
ToTestCaseNamexla::__anon8c2363a90111::PadReduceWindowEffectiveBroadcastCase5684   std::string ToTestCaseName() const {
5685     return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
5686                         absl::StrJoin(symmetric_pad_spatials, ","), ";",
5687                         absl::StrJoin(reduce_window_spatials, ","), ";",
5688                         prepend_a, ";", should_become_broadcast);
5689   }
5690 };
5691 
PrintTo(const PadReduceWindowEffectiveBroadcastCase & c,std::ostream * os)5692 void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
5693   *os << c.ToTestCaseName();
5694 }
5695 
5696 class PadReduceWindowEffectiveBroadcastTest
5697     : public AlgebraicSimplifierTest,
5698       public ::testing::WithParamInterface<
5699           PadReduceWindowEffectiveBroadcastCase> {};
5700 
TEST_P(PadReduceWindowEffectiveBroadcastTest,DoIt)5701 TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
5702   auto m = CreateNewVerifiedModule();
5703   const auto& param = GetParam();
5704 
5705   // a and b are parallel bounds we can either turn into a B F S0 S1 or
5706   // `B S0 S1 F` kind of pattern.
5707   auto decorate_spatials = [&param](absl::Span<const int64_t> spatials,
5708                                     int64_t a, int64_t b) {
5709     std::vector<int64_t> result;
5710     if (param.prepend_a) {
5711       result.push_back(a);
5712     }
5713     for (int64_t s : spatials) {
5714       result.push_back(s);
5715     }
5716     if (!param.prepend_a) {
5717       result.push_back(a);
5718     }
5719     result.push_back(b);
5720     return result;
5721   };
5722 
5723   HloComputation::Builder builder(TestName());
5724   const Shape input_shape = ShapeUtil::MakeShape(
5725       F32, decorate_spatials(param.input_spatials, 128, 2048));
5726   HloInstruction* input = builder.AddInstruction(
5727       HloInstruction::CreateParameter(0, input_shape, "input"));
5728 
5729   PaddingConfig padding = window_util::MakeSymmetricPadding(
5730       decorate_spatials(param.symmetric_pad_spatials, 0, 0));
5731   TF_ASSERT_OK_AND_ASSIGN(
5732       const Shape pad_shape,
5733       ShapeInference::InferPadShape(input->shape(),
5734                                     ShapeUtil::MakeShape(F32, {}), padding));
5735   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
5736       pad_shape, input,
5737       builder.AddInstruction(
5738           HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
5739       padding));
5740 
5741   HloComputation* add_computation = nullptr;
5742   {
5743     HloComputation::Builder builder(TestName() + ".add");
5744     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
5745     HloInstruction* p0 = builder.AddInstruction(
5746         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
5747     HloInstruction* p1 = builder.AddInstruction(
5748         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
5749     builder.AddInstruction(
5750         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
5751     add_computation = m->AddEmbeddedComputation(builder.Build());
5752   }
5753 
5754   Window window = window_util::MakeWindow(
5755       decorate_spatials(param.reduce_window_spatials, 1, 1));
5756   auto zero = builder.AddInstruction(
5757       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
5758   TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
5759                           ShapeInference::InferReduceWindowShape(
5760                               pad->shape(), zero->shape(), window,
5761                               add_computation->ComputeProgramShape()));
5762   builder.AddInstruction(HloInstruction::CreateReduceWindow(
5763       output_shape, pad, zero, window, add_computation));
5764 
5765   auto computation = m->AddEntryComputation(builder.Build());
5766   AlgebraicSimplifier simplifier(default_options_);
5767   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5768   ASSERT_TRUE(run_successful);
5769   SCOPED_TRACE(m->ToString());
5770 
5771   EXPECT_TRUE(
5772       ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));
5773 
5774   if (param.should_become_broadcast) {
5775     EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast()));
5776   } else {
5777     EXPECT_THAT(computation->root_instruction(),
5778                 GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero))));
5779   }
5780 }
5781 
5782 const std::vector<PadReduceWindowEffectiveBroadcastCase>&
PadReduceWindowEffectiveBroadcastCases()5783 PadReduceWindowEffectiveBroadcastCases() {
5784   static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
5785       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
5786        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
5787        /*should_become_broadcast=*/true},  //
5788       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
5789        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
5790        /*should_become_broadcast=*/true},  //
5791       {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
5792        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
5793        /*should_become_broadcast=*/false},  //
5794       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
5795        /*reduce_window_spatials=*/{2, 2}, /*prepend_a=*/true,
5796        /*should_become_broadcast=*/false},  //
5797       {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
5798        /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
5799        /*should_become_broadcast=*/false},  //
5800   };
5801   return *cases;
5802 }
5803 
5804 INSTANTIATE_TEST_SUITE_P(
5805     PadReduceWindowEffectiveBroadcastInstantiation,
5806     PadReduceWindowEffectiveBroadcastTest,
5807     ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
5808 
5809 class BatchDotStrengthReductionTest
5810     : public AlgebraicSimplifierTest,
5811       public ::testing::WithParamInterface<
5812           ::testing::tuple<int, int, int, PrimitiveType>> {};
TEST_P(BatchDotStrengthReductionTest,BatchDotStrengthReduction)5813 TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
5814   auto module = CreateNewVerifiedModule();
5815   int m, k, n;
5816   PrimitiveType element_type;
5817   std::tie(m, k, n, element_type) = GetParam();
5818   std::vector<int64_t> lhs_dims = {2, 3, 5};
5819   std::vector<int64_t> rhs_dims = lhs_dims;
5820   std::vector<int64_t> output_dims = lhs_dims;
5821   if (m > 0) {
5822     lhs_dims.push_back(m);
5823     output_dims.push_back(m);
5824   }
5825   if (k > 0) {
5826     lhs_dims.push_back(k);
5827     rhs_dims.push_back(k);
5828   }
5829   if (n > 0) {
5830     rhs_dims.push_back(n);
5831     output_dims.push_back(n);
5832   }
5833   Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims);
5834   Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims);
5835   Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims);
5836   HloComputation::Builder builder(TestName());
5837 
5838   auto lhs = builder.AddInstruction(
5839       HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
5840   auto rhs = builder.AddInstruction(
5841       HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
5842   DotDimensionNumbers dot_dnums;
5843   dot_dnums.add_lhs_batch_dimensions(0);
5844   dot_dnums.add_lhs_batch_dimensions(1);
5845   dot_dnums.add_lhs_batch_dimensions(2);
5846   dot_dnums.add_rhs_batch_dimensions(0);
5847   dot_dnums.add_rhs_batch_dimensions(1);
5848   dot_dnums.add_rhs_batch_dimensions(2);
5849   if (k > 0) {
5850     dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3);
5851     dot_dnums.add_rhs_contracting_dimensions(3);
5852   }
5853   builder.AddInstruction(HloInstruction::CreateDot(
5854       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5855   auto computation = module->AddEntryComputation(builder.Build());
5856   AlgebraicSimplifier simplifier(default_options_);
5857   TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
5858   const bool dot_should_be_transformed =
5859       m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1;
5860   EXPECT_EQ(changed, dot_should_be_transformed);
5861   TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
5862   bool has_no_dot = true;
5863   for (const auto& hlo : computation->instructions()) {
5864     if (hlo->opcode() == HloOpcode::kDot) {
5865       has_no_dot = false;
5866       break;
5867     }
5868   }
5869   EXPECT_EQ(has_no_dot, dot_should_be_transformed);
5870 }
5871 
5872 INSTANTIATE_TEST_SUITE_P(
5873     BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest,
5874     ::testing::Combine(::testing::Values(-1, 1, 2), ::testing::Values(-1, 1, 2),
5875                        ::testing::Values(-1, 1, 2),
5876                        ::testing::Values(C128, C64, F64, F32, BF16)));
5877 
5878 class DotStrengthReductionTest
5879     : public AlgebraicSimplifierTest,
5880       public ::testing::WithParamInterface<
5881           ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest,DotStrengthReduction)5882 TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
5883   auto module = CreateNewVerifiedModule();
5884   int m, k, n;
5885   bool transpose_lhs, transpose_rhs;
5886   PrimitiveType element_type;
5887   std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
5888 
5889   Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
5890   Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
5891   Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
5892   Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
5893   Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
5894   HloComputation::Builder builder(TestName());
5895 
5896   auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
5897       0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs"));
5898   if (transpose_lhs) {
5899     lhs = builder.AddInstruction(
5900         HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0}));
5901   }
5902   auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
5903       1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs"));
5904   if (transpose_rhs) {
5905     rhs = builder.AddInstruction(
5906         HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0}));
5907   }
5908   DotDimensionNumbers dot_dnums;
5909   dot_dnums.add_lhs_contracting_dimensions(1);
5910   dot_dnums.add_rhs_contracting_dimensions(0);
5911   builder.AddInstruction(HloInstruction::CreateDot(
5912       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5913   auto computation = module->AddEntryComputation(builder.Build());
5914   AlgebraicSimplifier simplifier(default_options_);
5915   // First pass of algebraic simplifier will remove degenerate dimensions
5916   // and optimize dot(transpose(x),transpose(y))
5917   TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
5918   const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
5919   const bool computation_should_be_modified =
5920       dot_should_be_transformed || (transpose_lhs && transpose_rhs);
5921   EXPECT_EQ(changed, computation_should_be_modified);
5922   // The second pass of algebraic simplifier will remove dots without
5923   // non-contracting dimensions or contracting dimensions.
5924   TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
5925   EXPECT_EQ(changed, computation_should_be_modified);
5926   bool has_no_dot = true;
5927   for (const auto& hlo : computation->instructions()) {
5928     if (hlo->opcode() == HloOpcode::kDot) {
5929       has_no_dot = false;
5930       break;
5931     }
5932   }
5933   EXPECT_EQ(has_no_dot, dot_should_be_transformed);
5934 }
5935 
5936 INSTANTIATE_TEST_SUITE_P(
5937     DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
5938     ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
5939                        ::testing::Values(1, 2), ::testing::Bool(),
5940                        ::testing::Bool(),
5941                        ::testing::Values(C128, C64, F64, F32, BF16)));
5942 
5943 struct DotOfConcatTestSpec {
5944   int64_t m;
5945   int64_t k;
5946   int64_t n;
5947 };
5948 
5949 class DotOfConcatSimplificationTest
5950     : public AlgebraicSimplifierTest,
5951       public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
5952 
5953 // Test that we transform
5954 //  dot(const, concat(A, B, C))
5955 // to
5956 //  add(dot(const_0, A), dot(const_1, B),  dot(const_2, C))
TEST_P(DotOfConcatSimplificationTest,ConstantLHS)5957 TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
5958   auto m = CreateNewVerifiedModule();
5959   HloComputation::Builder builder(TestName());
5960 
5961   DotOfConcatTestSpec spec = GetParam();
5962 
5963   ASSERT_GE(spec.k, 3);
5964 
5965   int64_t k0 = spec.k / 3;
5966   int64_t k1 = spec.k / 3;
5967   int64_t k2 = spec.k - k0 - k1;
5968 
5969   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
5970   auto* lhs = builder.AddInstruction(
5971       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5972           /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
5973 
5974   Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
5975   Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n});
5976   Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n});
5977 
5978   HloInstruction* rhs0 = builder.AddInstruction(
5979       HloInstruction::CreateParameter(0, rhs0_shape, "rhs0"));
5980   HloInstruction* rhs1 = builder.AddInstruction(
5981       HloInstruction::CreateParameter(1, rhs1_shape, "rhs1"));
5982   HloInstruction* rhs2 = builder.AddInstruction(
5983       HloInstruction::CreateParameter(2, rhs2_shape, "rhs2"));
5984 
5985   Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
5986   HloInstruction* rhs = builder.AddInstruction(
5987       HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0));
5988 
5989   DotDimensionNumbers dot_dnums;
5990   dot_dnums.add_lhs_contracting_dimensions(1);
5991   dot_dnums.add_rhs_contracting_dimensions(0);
5992 
5993   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
5994   builder.AddInstruction(HloInstruction::CreateDot(
5995       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5996 
5997   auto computation = m->AddEntryComputation(builder.Build());
5998   AlgebraicSimplifier simplifier(default_options_);
5999   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
6000   ASSERT_TRUE(run_successful);
6001 
6002   EXPECT_TRUE(
6003       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
6004 
6005   auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0));
6006   auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1));
6007   auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2));
6008   EXPECT_THAT(
6009       computation->root_instruction(),
6010       GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2)));
6011 }
6012 
6013 // Test that we transform
6014 //  dot(concat(A, B, C), const)
6015 // to
6016 //  add(dot(A, const_0), dot(B, const_1),  dot(C, const_2))
TEST_P(DotOfConcatSimplificationTest,ConstantRHS)6017 TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
6018   auto m = CreateNewVerifiedModule();
6019   HloComputation::Builder builder(TestName());
6020 
6021   DotOfConcatTestSpec spec = GetParam();
6022 
6023   ASSERT_GE(spec.k, 4);
6024 
6025   int64_t k0 = spec.k / 4;
6026   int64_t k1 = spec.k / 4;
6027   int64_t k2 = spec.k / 4;
6028   int64_t k3 = spec.k - k0 - k1 - k2;
6029 
6030   Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0});
6031   Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1});
6032   Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2});
6033   Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3});
6034 
6035   HloInstruction* lhs0 = builder.AddInstruction(
6036       HloInstruction::CreateParameter(0, lhs0_shape, "lhs0"));
6037   HloInstruction* lhs1 = builder.AddInstruction(
6038       HloInstruction::CreateParameter(1, lhs1_shape, "lhs1"));
6039   HloInstruction* lhs2 = builder.AddInstruction(
6040       HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
6041   HloInstruction* lhs3 = builder.AddInstruction(
6042       HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
6043 
6044   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
6045   HloInstruction* lhs =
6046       builder.AddInstruction(HloInstruction::CreateConcatenate(
6047           lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
6048 
6049   Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
6050   auto* rhs = builder.AddInstruction(
6051       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6052           /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
6053 
6054   DotDimensionNumbers dot_dnums;
6055   dot_dnums.add_lhs_contracting_dimensions(1);
6056   dot_dnums.add_rhs_contracting_dimensions(0);
6057 
6058   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
6059   builder.AddInstruction(HloInstruction::CreateDot(
6060       dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6061 
6062   auto computation = m->AddEntryComputation(builder.Build());
6063   AlgebraicSimplifier simplifier(default_options_);
6064   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
6065   ASSERT_TRUE(run_successful);
6066   EXPECT_TRUE(
6067       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
6068 
6069   auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant()));
6070   auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant()));
6071   auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant()));
6072   auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant()));
6073   EXPECT_THAT(
6074       computation->root_instruction(),
6075       GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2),
6076                         match_dot_3)));
6077 }
6078 
6079 DotOfConcatTestSpec kDotOfConcatTestSpecs[] = {
6080     {/*m=*/3, /*k=*/9, /*n=*/3},    //
6081     {/*m=*/3, /*k=*/20, /*n=*/3},   //
6082     {/*m=*/1, /*k=*/18, /*n=*/5},   //
6083     {/*m=*/20, /*k=*/20, /*n=*/1},  //
6084     {/*m=*/1, /*k=*/16, /*n=*/1},   //
6085 };
6086 
// A rank-1 dot(concat, constant) producing a scalar must be left unchanged
// when dot strength reduction is disabled.
TEST_F(DotOfConcatSimplificationTest, ConcatIntoScalarDot) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      param0 = f32[4] parameter(0)
      param1 = f32[1] parameter(1)
      constant = f32[5] constant({-0.38, 0.07, -0.62, 0.66, 0.20})
      concat = f32[5] concatenate(param0, param1), dimensions={0}
      ROOT dot = f32[] dot(concat, constant), lhs_contracting_dims={0},
                                              rhs_contracting_dims={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options = default_options_;
  options.set_enable_dot_strength_reduction(false);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          AlgebraicSimplifier(options).Run(module.get()));
  ASSERT_FALSE(changed);
}
6103 
// concatenate(concatenate(p0, p1), p2) along the same dimension is expected
// to be flattened into a single concatenate(p0, p1, p2).
TEST_F(DotOfConcatSimplificationTest, UnnestConcatenate) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = f32[2,10] parameter(0)
      p1 = f32[3,10] parameter(1)
      p2 = f32[4,10] parameter(2)
      c0 = f32[5,10] concatenate(p0, p1), dimensions={0}
      ROOT c1 = f32[9,10] concatenate(c0, p2), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  AlgebraicSimplifier pass(default_options_);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  EXPECT_TRUE(changed);

  const auto flattened =
      m::Concatenate(m::Parameter(0), m::Parameter(1), m::Parameter(2));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(flattened));
}
6122 
6123 // Test that DynamicUpdateSlice update param with any dimension equal to zero
6124 // gets removed.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceZeroUpdate)6125 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
6126   auto m = CreateNewVerifiedModule();
6127   HloComputation::Builder builder(TestName());
6128   const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10});
6129   HloInstruction* const operand = builder.AddInstruction(
6130       HloInstruction::CreateParameter(0, dslice_shape, "operand"));
6131   const Shape update_shape = ShapeUtil::MakeShape(F32, {0});
6132   HloInstruction* const update = builder.AddInstruction(
6133       HloInstruction::CreateParameter(1, update_shape, "update"));
6134   HloInstruction* const start_indices = builder.AddInstruction(
6135       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>({})));
6136   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
6137       dslice_shape, operand, update,
6138       std::initializer_list<HloInstruction*>({start_indices})));
6139   const HloComputation* const computation =
6140       m->AddEntryComputation(builder.Build());
6141 
6142   AlgebraicSimplifier simplifier(default_options_);
6143   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
6144   EXPECT_THAT(computation->root_instruction(), operand);
6145 }
6146 
6147 // Test that dynamic-update-slice with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceOfBroadcastToPad)6148 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPad) {
6149   const char* hlo_string = R"(
6150 HloModule AddBroadcastZeroWithDynamicSlice
6151 
6152 ENTRY AddBroadcastZeroWithDynamicSlice {
6153   param0 = f32[1800,12,512]{2,1,0} parameter(0)
6154   constant = f32[] constant(0)
6155   broadcast = f32[1800,12,512]{2,1,0} broadcast(constant), dimensions={}
6156   param1 = f32[1,12,512]{2,1,0} parameter(1)
6157   constant.1 = s32[] constant(0)
6158   dynamic-update-slice = f32[1800,12,512]{2,1,0} dynamic-update-slice(broadcast, param1, constant.1, constant.1, constant.1)
6159   ROOT add = f32[1800,12,512]{2,1,0} add(param0, dynamic-update-slice)
6160 }
6161 )";
6162   TF_ASSERT_OK_AND_ASSIGN(auto module,
6163                           ParseAndReturnVerifiedModule(hlo_string));
6164   VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
6165   AlgebraicSimplifier simplifier(default_options_);
6166   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
6167   VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
6168   auto root = module->entry_computation()->root_instruction();
6169   EXPECT_THAT(root->opcode(), HloOpcode::kAdd);
6170   EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad);
6171 }
6172 
// add(param0, reshape(dynamic-update-slice(broadcast(0), param1, ...))) is
// expected to become a dynamic-update-slice into param0 of
// add(dynamic-slice(param0), reshape(param1)) after two simplifier runs.
TEST_F(AlgebraicSimplifierTest, AddDynamicUpdateSliceToAddSlice) {
  const char* hlo_string = R"(
HloModule AddDynamicUpdateSliceToAddSlice

ENTRY AddDynamicUpdateSliceToAddSlice {
  param0 = f32[1,4,12,512,1,1] parameter(0)
  constant = f32[] constant(0)
  broadcast = f32[4,12,512] broadcast(constant), dimensions={}
  param1 = f32[1,12,512] parameter(1)
  param2 = s32[] parameter(2)
  constant.1 = s32[] constant(0)
  dynamic-update-slice = f32[4,12,512] dynamic-update-slice(
    broadcast, param1, param2, constant.1, constant.1)
  reshape = f32[1,4,12,512,1,1] reshape(dynamic-update-slice)
  ROOT add = f32[1,4,12,512,1,1] add(param0, reshape)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  VLOG(2) << "Before rewrite reshape\n" << module->ToString();
  AlgebraicSimplifier simplifier(default_options_);
  // Both runs are required to report a change. A failing Status is surfaced
  // as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool first_run_changed,
                          simplifier.Run(module.get()));
  ASSERT_TRUE(first_run_changed);
  TF_ASSERT_OK_AND_ASSIGN(bool second_run_changed,
                          simplifier.Run(module.get()));
  ASSERT_TRUE(second_run_changed);
  VLOG(2) << "After rewrite to add slice\n" << module->ToString();
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::DynamicUpdateSlice(
          m::Parameter(0),
          m::Add(m::DynamicSlice(m::Parameter(0), m::Constant(),
                                 m::Parameter(2), m::Constant(), m::Constant(),
                                 m::Constant(), m::Constant()),
                 m::Reshape(m::Parameter(1))),
          m::Constant(), m::Parameter(2), m::Constant(), m::Constant(),
          m::Constant(), m::Constant())));
}
6209 
// With scalar-multiply reduction enabled, the two broadcast-scalar multiplies
// feeding the convolution are expected to end up as a single multiply at the
// root by broadcast(0.5 * 1.109).
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) {
  const char* hlo_string = R"(
HloModule ConstScalarMultiply
ENTRY ConstScalarMultiply {
  param0 = f32[16,512,4096]{2,1,0} parameter(0)
  constant.0 = f32[] constant(0.5)
  broadcast.0 = f32[16,512,4096] broadcast(constant.0), dimensions={}
  multiply.0 = f32[16,512,4096]{2,1,0} multiply(param0, broadcast.0)
  param1 = f32[16,512,4096]{2,1,0} parameter(1)
  multiply.1 = f32[16,512,4096]{2,1,0} multiply(multiply.0, param1)
  param2 = f32[16,512,1024]{2,1,0} parameter(2)
  constant.1 = f32[] constant(1.109)
  broadcast.1 = f32[16,512,1024] broadcast(constant.1), dimensions={}
  multiply.2 = f32[16,512,1024]{2,1,0} multiply(param2, broadcast.1)
  ROOT convolution = f32[4096,1024,1]{1,0,2} convolution(multiply.1, multiply.2), window={size=16}, dim_labels=0fb_0io->bf0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  AlgebraicSimplifier simplifier(options);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  ASSERT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
  EXPECT_THAT(root,
              GmockMatch(m::MultiplyAnyOrder(
                  m::Op(), m::Broadcast(m::ConstantScalar(0.5f * 1.109f)))));
}
6239 
// The convolution here feeds two multiplies (one by a broadcast scalar, one by
// a parameter); the simplifier is expected to leave the module unchanged.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReductionMultiUser) {
  const char* hlo_string = R"(
HloModule ConstScalarMultiply
ENTRY ConstScalarMultiply {
  param0 = f32[16,512,1024] parameter(0)
  param1 = f32[4096,1024,1] parameter(1)
  convolution = f32[16,512,4096] convolution(param0, param1), window={size=1}, dim_labels=0bf_oi0->0bf
  constant.1 = f32[] constant(0.5)
  broadcast.1 = f32[16,512,4096] broadcast(constant.1), dimensions={}
  multiply.1 = f32[16,512,4096] multiply(convolution, broadcast.1)
  param2 = f32[16,512,4096] parameter(2)
  multiply.2 = f32[16,512,4096] multiply(convolution, param2)
  ROOT add.1 = f32[16,512,4096] add(multiply.1, multiply.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  AlgebraicSimplifier simplifier(options);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  ASSERT_FALSE(changed);
}
6262 
6263 INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
6264                          DotOfConcatSimplificationTest,
6265                          ::testing::ValuesIn(kDotOfConcatTestSpecs));
6266 
// One parameter set for the dot-of-gather tests. Field order matters: the
// spec tables use positional aggregate initialization.
struct DotOfGatherTestSpec {
  int64_t m;
  int64_t k;
  int64_t n;
  // Start index for dynamic slice on the non-contracting dimension.
  int s;
  // Left contracting dimension.
  int64_t lcd;
  // Right contracting dimension.
  int64_t rcd;
  // Whether this is a negative testcase.
  bool neg;
};
6276 
6277 class DotOfGatherSimplificationTest
6278     : public AlgebraicSimplifierTest,
6279       public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
6280 
6281 // input: dot(DS(ctA), ctB))
6282 // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
6283 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
6284 // output: DS(dot(ctA, ctB))
6285 // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
TEST_P(DotOfGatherSimplificationTest,ConstantRHS)6286 TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
6287   auto m = CreateNewVerifiedModule();
6288   HloComputation::Builder builder(TestName());
6289 
6290   DotOfGatherTestSpec spec = GetParam();
6291 
6292   ASSERT_LE(spec.s, spec.m);
6293 
6294   // For negative tests, increase k of the dynamic slice argument to prevent the
6295   // optimization (constants ctA, ctB must have equal contracting dimensions).
6296   int64_t k_increase = spec.neg ? 5 : 0;
6297   int64_t lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
6298   int64_t lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
6299   Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
6300   auto* lhs = builder.AddInstruction(
6301       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6302           /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
6303           /*cols=*/lhs_cols)));
6304 
6305   int32_t start_row = (spec.lcd == 0) ? 0 : spec.s;
6306   int32_t start_col = (spec.lcd == 0) ? spec.s : 0;
6307   std::vector<HloInstruction*> start_indices = {
6308       builder.AddInstruction(HloInstruction::CreateConstant(
6309           LiteralUtil::CreateR0<int32_t>(start_row))),
6310       builder.AddInstruction(HloInstruction::CreateConstant(
6311           LiteralUtil::CreateR0<int32_t>(start_col)))};
6312   int64_t slice_row_size = (spec.lcd == 0) ? spec.k : 1;
6313   int64_t slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
6314   std::vector<int64_t> slice_sizes = {slice_row_size, slice_col_size};
6315   Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
6316   auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
6317       ds_shape, lhs, start_indices, slice_sizes));
6318 
6319   int64_t rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
6320   int64_t rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
6321   Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
6322   auto* rhs = builder.AddInstruction(
6323       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6324           /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
6325           /*cols=*/rhs_cols)));
6326 
6327   DotDimensionNumbers dot_dnums;
6328   dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
6329   dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
6330 
6331   int64_t dot_row_size = 1;
6332   int64_t dot_col_size = spec.n;
6333   Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
6334   builder.AddInstruction(HloInstruction::CreateDot(
6335       dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6336 
6337   auto computation = m->AddEntryComputation(builder.Build());
6338   AlgebraicSimplifier simplifier(default_options_);
6339   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
6340   ASSERT_TRUE(run_successful);
6341   EXPECT_TRUE(
6342       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
6343 
6344   if (spec.neg) {
6345     EXPECT_NE(computation->root_instruction()->opcode(),
6346               HloOpcode::kDynamicSlice);
6347   } else {
6348     EXPECT_THAT(computation->root_instruction(),
6349                 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
6350                                            m::Constant(), m::Constant())));
6351   }
6352 }
6353 
6354 // input: dot(ctA, DS(ctB))
6355 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
6356 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
6357 // output: DS(dot(ctA, ctB))
6358 // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
TEST_P(DotOfGatherSimplificationTest,ConstantLHS)6359 TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
6360   auto m = CreateNewVerifiedModule();
6361   HloComputation::Builder builder(TestName());
6362 
6363   DotOfGatherTestSpec spec = GetParam();
6364 
6365   ASSERT_LE(spec.s, spec.n);
6366 
6367   int64_t lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
6368   int64_t lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
6369   Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
6370   auto* lhs = builder.AddInstruction(
6371       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6372           /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
6373           /*cols=*/lhs_cols)));
6374 
6375   // For negative tests increase k of the dynamic slice argument to prevent the
6376   // optimization
6377   int64_t k_increase = spec.neg ? 5 : 0;
6378   int64_t rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
6379   int64_t rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
6380   Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
6381   auto* rhs = builder.AddInstruction(
6382       HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6383           /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
6384           /*cols=*/rhs_cols)));
6385 
6386   int32_t start_row = (spec.rcd == 0) ? 0 : spec.s;
6387   int32_t start_col = (spec.rcd == 0) ? spec.s : 0;
6388   std::vector<HloInstruction*> start_indices = {
6389       builder.AddInstruction(HloInstruction::CreateConstant(
6390           LiteralUtil::CreateR0<int32_t>(start_row))),
6391       builder.AddInstruction(HloInstruction::CreateConstant(
6392           LiteralUtil::CreateR0<int32_t>(start_col)))};
6393   int64_t slice_row_size = (spec.rcd == 0) ? spec.k : 1;
6394   int64_t slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
6395   std::vector<int64_t> slice_sizes = {slice_row_size, slice_col_size};
6396   Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
6397   auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
6398       ds_shape, rhs, start_indices, slice_sizes));
6399 
6400   DotDimensionNumbers dot_dnums;
6401   dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
6402   dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
6403 
6404   int64_t dot_row_size = spec.m;
6405   int64_t dot_col_size = 1;
6406   Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
6407   builder.AddInstruction(HloInstruction::CreateDot(
6408       dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
6409 
6410   auto computation = m->AddEntryComputation(builder.Build());
6411   AlgebraicSimplifier simplifier(default_options_);
6412   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
6413   ASSERT_TRUE(run_successful);
6414   EXPECT_TRUE(
6415       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
6416 
6417   if (spec.neg) {
6418     EXPECT_NE(computation->root_instruction()->opcode(),
6419               HloOpcode::kDynamicSlice);
6420   } else {
6421     EXPECT_THAT(computation->root_instruction(),
6422                 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
6423                                            m::Constant(), m::Constant())));
6424   }
6425 }
6426 
DotOfGatherPositiveNegativeTests()6427 std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
6428   std::vector<DotOfGatherTestSpec> positives = {
6429       // "Classical dot", i.e. matrix multiply:
6430       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
6431        /*neg=*/false},
6432       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
6433        /*neg=*/false},
6434       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
6435        /*neg=*/false},
6436       // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
6437       // dot(ct, ct) before DotOfGather optimization kicks in.
6438       // Contract on rows:
6439       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
6440        /*neg=*/false},
6441       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
6442        /*neg=*/false},
6443       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
6444        /*neg=*/false},
6445       // Reverse matrix multiply:
6446       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
6447        /*neg=*/false},
6448       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
6449        /*neg=*/false},
6450       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
6451        /*neg=*/false},
6452       // Contract on columns:
6453       {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
6454        /*neg=*/false},
6455       {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
6456        /*neg=*/false},
6457       {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
6458        /*neg=*/false},
6459   };
6460   std::vector<DotOfGatherTestSpec> all;
6461   const std::vector<DotOfGatherTestSpec>::size_type positives_size =
6462       positives.size();
6463   all.reserve(positives_size * 2);
6464   for (std::vector<DotOfGatherTestSpec>::size_type i = 0; i < positives_size;
6465        i++) {
6466     DotOfGatherTestSpec positive_test = positives[i];
6467     all.push_back(positive_test);
6468     DotOfGatherTestSpec negative_test = positive_test;
6469     negative_test.neg = true;
6470     all.push_back(negative_test);
6471   }
6472   return all;
6473 }
6474 
6475 INSTANTIATE_TEST_SUITE_P(
6476     DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
6477     ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
6478 
// A gather out of a single-element f32[1,1] operand is expected to be
// rewritten to broadcast(reshape(operand)).
TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) {
  const char* hlo_string = R"(
  HloModule repeat

  ENTRY main {
    o = f32[1,1] parameter(0)
    i = s32[100,2] parameter(1)
    ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1},
                                  start_index_map={0,1},
                                  index_vector_dim=1,
                                  offset_dims={},
                                  slice_sizes={1,1}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
6502 
// A variadic reduce with dimensions={} over scalar inputs is expected to
// become a tuple of reshapes of those inputs.
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
  const char* hlo_string = R"(
HloModule module

reducer {
  parameter.1 = f32[] parameter(0)
  parameter.3 = f32[] parameter(2)
  add.2 = f32[] add(parameter.1, parameter.3)
  parameter.0 = f32[] parameter(1)
  parameter.2 = f32[] parameter(3)
  add.3 = f32[] add(parameter.0, parameter.2)
  ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3)
}

ENTRY entry {
  parameter.6 = (f32[], f32[]) parameter(0)
  get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0
  get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1
  constant = f32[] constant(0)
  ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(
                        m::Reshape(m::GetTupleElement(m::Parameter(), 0)),
                        m::Reshape(m::GetTupleElement(m::Parameter(), 1)))));
}
6536 
// A variadic reduce over a zero-sized dimension is expected to become a tuple
// of broadcasts of the respective init values (0 and 1).
TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) {
  const char* hlo_string = R"(
HloModule module

reducer {
  parameter.1 = f32[] parameter(0)
  parameter.3 = f32[] parameter(2)
  mul.2 = f32[] add(parameter.1, parameter.3)
  parameter.0 = f32[] parameter(1)
  parameter.2 = f32[] parameter(3)
  add.3 = f32[] add(parameter.0, parameter.2)
  ROOT tuple.4 = (f32[], f32[]) tuple(mul.2, add.3)
}

ENTRY entry {
  parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0)
  get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0
  get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1
  constant.0 = f32[] constant(0)
  constant.1 = f32[] constant(1)
  ROOT reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::ConstantScalar(0)),
                                        m::Broadcast(m::ConstantScalar(1)))));
}
6570 
// A reshape whose zero-element result shape carries no layout is expected to
// be replaced by a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1}), "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {0, 1}), param, {1}));

  // Create a reshape with zero sized result and without layout.
  Shape reshaped_shape = ShapeUtil::MakeShape(F32, {0});
  reshaped_shape.clear_layout();
  builder.AddInstruction(
      HloInstruction::CreateReshape(reshaped_shape, broadcast));

  std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Constant()));
}
6595 
// divide(param, constant) on a scalar shape whose layout was cleared is
// expected to be rewritten into a multiply.
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
  Shape shape = ShapeUtil::MakeShape(F32, {});
  shape.clear_layout();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));

  HloInstruction* const_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
  builder.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
                                                      param, const_value));

  std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_TRUE(changed);
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply()));
}
6617 
6618 // Test that 1/sqrt(X) is simplified to rsqrt(X).
TEST_F(AlgebraicSimplifierTest,RecipSqrt)6619 TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
6620   const char* kModuleStr = R"(
6621     HloModule m
6622     test {
6623       p0 = f32[] parameter(0)
6624       p1 = f32[] parameter(1)
6625       sqrt = f32[] sqrt(p0)
6626       ROOT div = f32[] divide(p1, sqrt)
6627     }
6628   )";
6629   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
6630   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
6631   EXPECT_THAT(m->entry_computation()->root_instruction(),
6632               GmockMatch(m::MultiplyAnyOrder(m::Parameter(1),
6633                                              m::Rsqrt(m::Parameter(0)))));
6634 }
6635 
6636 // Test that 1/rsqrt(X) is simplified to sqrt(X).
TEST_F(AlgebraicSimplifierTest,RecipRsqrt)6637 TEST_F(AlgebraicSimplifierTest, RecipRsqrt) {
6638   const char* kModuleStr = R"(
6639     HloModule m
6640     test {
6641       p0 = f32[] parameter(0)
6642       p1 = f32[] parameter(1)
6643       rsqrt = f32[] rsqrt(p0)
6644       ROOT div = f32[] divide(p1, rsqrt)
6645     }
6646   )";
6647   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
6648   ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
6649   EXPECT_THAT(m->entry_computation()->root_instruction(),
6650               GmockMatch(m::MultiplyAnyOrder(m::Parameter(1),
6651                                              m::Sqrt(m::Parameter(0)))));
6652 }
6653 
// In layout-sensitive mode, copy(reshape(p0)) is expected to collapse into a
// single reshape of p0 that carries the copy's original result shape.
TEST_F(AlgebraicSimplifierTest, CopyReshape) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = f32[168,168,48,48]{3,2,1,0} parameter(0)
      r0 = f32[1,168,168,2304]{3,2,1,0} reshape(p0)
      ROOT c0 = f32[1,168,168,2304]{3,0,2,1} copy(r0)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  // Snapshot the root shape before the pass rewrites the graph.
  Shape result_shape =
      module->entry_computation()->root_instruction()->shape();
  // Options are constructed with a shape-pair callback that always answers
  // false.
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          AlgebraicSimplifier(options).Run(module.get()));
  ASSERT_TRUE(changed);
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Reshape(m::Parameter(0)).WithShapeEqualTo(&result_shape)));
}
6672 
// dot(reshape(transpose(param)), constant): the transpose/reshape pair is
// expected to move from the parameter side onto the constant side.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RL) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  // Surface a failing Status as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, AlgebraicSimplifier(default_options_).Run(module.get()));
  ASSERT_TRUE(changed);
  auto shape1 = ShapeUtil::MakeShape(F32, {2, 6});
  auto shape2 = ShapeUtil::MakeShape(F32, {3, 2, 2});
  auto shape3 = ShapeUtil::MakeShape(F32, {2, 3, 2});
  // The transformation of moving transpose and reshape to the constant side
  // is layout insensitive. We ignore layout when checking shapes.
  // Filled in by the matcher; initialized so it never holds garbage.
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Dot(
                  m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&shape1),
                  m::Reshape(m::Transpose(&transpose,
                                          m::Reshape(m::Constant())
                                              .WithShapeCompatibleTo(&shape2))
                                 .WithShapeCompatibleTo(&shape3)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(1, 0, 2));
}
6701 
// Same LHS structure as the _RL case, but the constant RHS is f32[2, 6] and
// both operands contract on dimension 1. Expects the rewritten dot to have
// reshape(param) on the left and reshape(transpose(reshape(constant))) on the
// right.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RR) {
  const char* kHloText = R"(
    HloModule m
    test {
      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
                                {1, 1, 1, 1, 1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const Shape param_reshape_shape = ShapeUtil::MakeShape(F32, {2, 6});
  const Shape constant_reshape_shape = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const Shape transpose_shape = ShapeUtil::MakeShape(F32, {2, 2, 3});
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0))
              .WithShapeCompatibleTo(&param_reshape_shape),
          m::Reshape(
              m::Transpose(m::Reshape(m::Constant())
                               .WithShapeCompatibleTo(&constant_reshape_shape))
                  .WithShapeCompatibleTo(&transpose_shape)))));
}
6726 
// LHS contracts on dimension 0 (lhs is f32[6, 2]) and the constant RHS on
// dimension 1. Expects the pass to change the module so that the transpose
// ends up on the constant operand.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR) {
  const char* kHloText = R"(
    HloModule m
    test {
      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
                                {1, 1, 1, 1, 1, 1}})
      t0 = f32[2, 3, 2] parameter(0)
      t1 = f32[3, 2, 2] transpose(t0), dimensions={1, 0, 2}
      lhs = f32[6, 2] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const Shape param_reshape_shape = ShapeUtil::MakeShape(F32, {6, 2});
  const Shape constant_reshape_shape = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const Shape transpose_shape = ShapeUtil::MakeShape(F32, {2, 2, 3});
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0))
              .WithShapeCompatibleTo(&param_reshape_shape),
          m::Reshape(
              m::Transpose(m::Reshape(m::Constant())
                               .WithShapeCompatibleTo(&constant_reshape_shape))
                  .WithShapeCompatibleTo(&transpose_shape)))));
}
6751 
// Rank-4 parameter transposed with dimensions={0, 2, 3, 1} and reshaped to
// f32[2, 8] before the dot. Checks the permutation of the transpose that the
// pass places on the constant operand.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR2) {
  const char* kHloText = R"(
    HloModule m
    test {
      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{7, 7},{8, 8}})
      t0 = f32[2, 2, 2, 2] parameter(0)
      t1 = f32[2, 2, 2, 2] transpose(t0), dimensions={0, 2, 3, 1}
      lhs = f32[2, 8] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1},
                                            rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const Shape param_reshape_shape = ShapeUtil::MakeShape(F32, {2, 8});
  const Shape constant_reshape_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  const HloInstruction* moved_transpose = nullptr;
  auto* root = module->entry_computation()->root_instruction();
  // ASSERT (not EXPECT): `moved_transpose` is dereferenced right below.
  ASSERT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0))
              .WithShapeCompatibleTo(&param_reshape_shape),
          m::Reshape(m::Transpose(
              &moved_transpose,
              m::Reshape(m::Constant())
                  .WithShapeCompatibleTo(&constant_reshape_shape))))));
  EXPECT_THAT(moved_transpose->dimensions(), ElementsAre(2, 0, 1, 3));
}
6778 
// Batched variant: both operands have batch dimension 0 and contract on
// dimension 1. Expects the transpose to move onto the constant operand with
// permutation {0, 2, 1, 3}.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_MM) {
  const char* kHloText = R"(
    HloModule m
    test {
      rhs = f32[2, 6, 2] constant({{{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}},
                                   {{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}})
      t0 = f32[2, 2, 3, 2] parameter(0)
      t1 = f32[2, 3, 2, 2] transpose(t0), dimensions={0, 2, 1, 3}
      lhs = f32[2, 6, 2] reshape(t1)
      ROOT dot.5 = f32[2, 2, 2] dot(lhs, rhs), lhs_batch_dims={0}, lhs_contracting_dims={1},
                                               rhs_batch_dims={0}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const Shape param_reshape_shape = ShapeUtil::MakeShape(F32, {2, 6, 2});
  const Shape constant_reshape_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 2});
  const Shape transpose_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 2});
  const HloInstruction* moved_transpose = nullptr;
  auto* root = module->entry_computation()->root_instruction();
  // ASSERT (not EXPECT): `moved_transpose` is dereferenced right below.
  ASSERT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0))
              .WithShapeCompatibleTo(&param_reshape_shape),
          m::Reshape(
              m::Transpose(&moved_transpose,
                           m::Reshape(m::Constant())
                               .WithShapeCompatibleTo(&constant_reshape_shape))
                  .WithShapeCompatibleTo(&transpose_shape)))));
  EXPECT_THAT(moved_transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
6807 
// Negative case: the transpose (dimensions={2, 0, 1}) touches a
// non-contracting dimension, so the pass must leave the module unchanged.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegTranspose) {
  const char* kHloText = R"(
    HloModule m
    test {
      rhs = f32[12, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}})
      t0 = f32[3, 4, 2] parameter(0)
      t1 = f32[2, 3, 4] transpose(t0), dimensions={2, 0, 1}
      lhs = f32[2, 12] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // The transpose and reshape must not be moved over to the constant side.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6824 
// Negative case: the reshape (f32[2, 3, 4] -> f32[3, 8]) affects
// non-contracting dimensions, so the pass must leave the module unchanged.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegReshape) {
  const char* kHloText = R"(
    HloModule m
    test {
      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{1, 1},{2, 2},{3, 3},{4, 4}})
      t0 = f32[2, 4, 3] parameter(0)
      t1 = f32[2, 3, 4] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[3, 8] reshape(t1)
      ROOT dot.5 = f32[3, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // The transpose and reshape must not be moved over to the constant side.
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6841 
// Negative case: the RHS is a parameter rather than a constant, so neither
// operand is constant and the rewrite must not fire.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegConstant) {
  const char* kHloText = R"(
    HloModule m
    test {
      t0 = f32[2, 3, 4] parameter(0)
      t1 = f32[2, 4, 3] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 12] reshape(t1)
      rhs = f32[12, 2] parameter(1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6857 
// Negative case: same HLO as the _RL test, but run with a layout-sensitive
// simplifier. The rewrite is layout insensitive and must therefore not fire.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegLayout) {
  const char* kHloText = R"(
    HloModule m
    test {
      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // Reshape-to-bitcast conversion is switched off so the pass does not touch
  // the reshape in this module; that lets the test simply assert that the pass
  // reports no change at all.
  auto reshape_is_never_bitcast = [](const Shape&, const Shape&) {
    return false;
  };
  AlgebraicSimplifierOptions layout_sensitive_options(reshape_is_never_bitcast);
  layout_sensitive_options.set_is_layout_sensitive(true);

  AlgebraicSimplifier simplifier(layout_sensitive_options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6881 
// No rewrite is expected here: the relative order of the `2` and `3`
// dimensions is unchanged by the transpose, so there is nothing to reorder.
// The case is still worth covering because of the size-1 dimensions.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDimsNoChange) {
  const char* kHloText = R"(
    HloModule m
    test {
     param = f32[1,2,5,3] parameter(0)
     transpose = f32[1,5,2,3] transpose(param), dimensions={0,2,1,3}
     reshape = f32[5,6] reshape(transpose)
     constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
     ROOT dot = f32[5,4] dot(reshape, constant),
       lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6900 
// Parameter with a leading size-1 dimension, contracted on dimension 0 of
// both operands. Expects the transpose to move to the constant side with
// permutation {0, 2, 1, 3}.
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
  const char* kHloText = R"(
    HloModule m
    test {
     param = f32[1,2,3,5] parameter(0)
     transpose = f32[1,3,2,5] transpose(param), dimensions={0,2,1,3}
     reshape = f32[6,5] reshape(transpose)
     constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
     ROOT dot = f32[5,4] dot(reshape, constant),
       lhs_contracting_dims={0}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const Shape param_reshape_shape = ShapeUtil::MakeShape(F32, {6, 5});
  const Shape constant_reshape_shape = ShapeUtil::MakeShape(F32, {1, 3, 2, 4});
  const Shape transpose_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
  const HloInstruction* moved_transpose = nullptr;
  auto* root = module->entry_computation()->root_instruction();
  // ASSERT (not EXPECT): `moved_transpose` is dereferenced right below.
  ASSERT_THAT(
      root,
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0))
              .WithShapeCompatibleTo(&param_reshape_shape),
          m::Reshape(
              m::Transpose(&moved_transpose,
                           m::Reshape(m::Constant())
                               .WithShapeCompatibleTo(&constant_reshape_shape))
                  .WithShapeCompatibleTo(&transpose_shape)))));
  EXPECT_THAT(moved_transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
6928 
// Negative case: the transpose does not reorder the contracting dimensions,
// so there is no optimization opportunity and the pass must report no change.
TEST_F(AlgebraicSimplifierTest,
       DotContractingReorder_NoChangeInContractingDimsOrder) {
  const char* kHloText = R"(
    HloModule m
    test {
      param = f32[2,5,1,3] parameter(0)
      transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
      reshape = f32[5,6] reshape(transpose)
      constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
      ROOT dot = f32[5,4] dot(reshape, constant),
        lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6947 
// compare(iota, broadcast(0)) with direction=LT is expected to fold into a
// broadcast of the scalar constant `false`.
TEST_F(AlgebraicSimplifierTest, CompareIota) {
  const char* kHloText = R"(
    HloModule m
    test {
      zero = s32[] constant(0)
      iota = s32[128] iota(), iota_dimension=0
      broad = s32[128] broadcast(zero), dimensions={}
      ROOT compare = pred[128] compare(iota, broad), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}
6962 
// For a u32 parameter, compare(param, 0) with direction=LT is expected to
// fold into the scalar constant `false`.
TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
  const char* kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(false)));
}
6976 
// For a u32 parameter, compare(param, 0) with direction=LE must be left
// alone: the pass reports no change and the root is still Le(param, 0).
TEST_F(AlgebraicSimplifierTest, CompareLeZero) {
  const char* kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=LE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root, GmockMatch(m::Le(m::Parameter(0), m::ConstantEffectiveScalar(0))));
}
6991 
// For a u32 parameter, compare(param, 0) with direction=GE is expected to
// fold into the scalar constant `true`.
TEST_F(AlgebraicSimplifierTest, CompareGeZero) {
  const char* kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(true)));
}
7005 
// For a u32 parameter, compare(param, 0) with direction=GT must be left
// alone: the pass reports no change and the root is still Gt(param, 0).
TEST_F(AlgebraicSimplifierTest, CompareGtZero) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=GT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  // Actually run the pass. Without this the test only inspects the freshly
  // parsed module and could never catch a bad rewrite. Mirrors the sibling
  // no-change tests (CompareLeZero, CompareZeroGe, CompareZeroLt).
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Gt(m::Parameter(0), m::ConstantEffectiveScalar(0))));
}
7019 
// For a u32 parameter, compare(0, param) with direction=GT is expected to
// fold into the scalar constant `false`.
TEST_F(AlgebraicSimplifierTest, CompareZeroGt) {
  const char* kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=GT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(false)));
}
7033 
// For a u32 parameter, compare(0, param) with direction=GE must be left
// alone: the pass reports no change and the root is still Ge(0, param).
TEST_F(AlgebraicSimplifierTest, CompareZeroGe) {
  const char* kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root, GmockMatch(m::Ge(m::ConstantEffectiveScalar(0), m::Parameter(0))));
}
7048 
// For a u32 parameter, compare(0, param) with direction=LE is expected to
// fold into the scalar constant `true`.
TEST_F(AlgebraicSimplifierTest, CompareZeroLe) {
  const char* kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=LE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::ConstantScalar(true)));
}
7062 
// For a u32 parameter, compare(0, param) with direction=LT must be left
// alone: the pass reports no change and the root is still Lt(0, param).
TEST_F(AlgebraicSimplifierTest, CompareZeroLt) {
  const char* kHloText = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root, GmockMatch(m::Lt(m::ConstantEffectiveScalar(0), m::Parameter(0))));
}
7077 
// Comparing an s32 parameter with itself using direction=GE is expected to
// fold into a broadcast of the scalar constant `true`.
TEST_F(AlgebraicSimplifierTest, CompareSame) {
  const char* kHloText = R"(
    HloModule m
    test {
      param = s32[123] parameter(0)
      ROOT compare = pred[123] compare(param, param), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(true))));
}
7090 
// and(param < 10, param < 100) is expected to collapse into a single LT
// comparison whose right-hand operand is the constant scalar 10.
TEST_F(AlgebraicSimplifierTest, CompareSimplified) {
  const char* kHloText = R"(
    HloModule m
    test {
      param = s32[] parameter(0)
      c1 = s32[] constant(10)
      c2 = s32[] constant(100)
      cmp1 = pred[] compare(param, c1), direction=LT
      cmp2 = pred[] compare(param, c2), direction=LT
      ROOT out = pred[] and(cmp1, cmp2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
                             .WithComparisonDirection(ComparisonDirection::kLt)));
}
7109 
// Same as CompareSimplified, except the second comparison is written with
// swapped operands (100 > param). The expected result is again a single LT
// comparison against the constant scalar 10.
TEST_F(AlgebraicSimplifierTest, CompareSimplifiedReversed) {
  const char* kHloText = R"(
    HloModule m
    test {
      param = s32[] parameter(0)
      c1 = s32[] constant(10)
      c2 = s32[] constant(100)
      cmp1 = pred[] compare(param, c1), direction=LT
      cmp2 = pred[] compare(c2, param), direction=GT
      ROOT out = pred[] and(cmp1, cmp2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
                             .WithComparisonDirection(ComparisonDirection::kLt)));
}
7128 
// Some backends may perform better when an outer product stays a Dot instead
// of becoming a broadcast Multiply, so the rewrite has an opt-out switch. This
// test covers both the default (rewrite) and the disabled (no change) paths.
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
  const char* kHloText = R"(
    HloModule m
    test {
      param1 = f32[64] parameter(0)
      param2 = f32[64] parameter(1)
      ROOT compare = f32[64, 64] dot(param1, param2),
        lhs_contracting_dims={}, rhs_contracting_dims={}
    })";

  // Default options: the dot is rewritten into a multiply.
  TF_ASSERT_OK_AND_ASSIGN(auto rewritten_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier default_simplifier(default_options_);
  ASSERT_TRUE(default_simplifier.Run(rewritten_module.get()).ValueOrDie());
  auto* rewritten_root =
      rewritten_module->entry_computation()->root_instruction();
  EXPECT_THAT(rewritten_root, GmockMatch(m::Multiply(m::Op(), m::Op())));

  // Rewrite disabled: a fresh copy of the module must come through unchanged.
  AlgebraicSimplifierOptions no_rewrite_options = default_options_;
  no_rewrite_options.set_enable_dot_to_multiply_rewrite(false);
  TF_ASSERT_OK_AND_ASSIGN(auto untouched_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier restricted_simplifier(no_rewrite_options);
  ASSERT_FALSE(restricted_simplifier.Run(untouched_module.get()).ValueOrDie());
}
7153 
// remainder(iota, broadcast(5)), where the iota runs along a dimension of
// size 5, is expected to simplify to a plain iota.
TEST_F(AlgebraicSimplifierTest, RemainderOfIota) {
  const char* kHloText = R"(
    HloModule m
    test {
      iota = s32[5,1000] iota(), iota_dimension=0
      five = s32[] constant(5)
      five_bcast = s32[5,1000] broadcast(s32[] five), dimensions={}
      ROOT remainder = s32[5,1000] remainder(iota, s32[5,1000] five_bcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
}
7168 
// remainder(iota + broadcast(5), broadcast(5)) is expected to drop the add,
// leaving remainder(iota, broadcast).
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIota) {
  const char* kHloText = R"(
    HloModule m
    test {
      iota = s32[5,1000] iota(), iota_dimension=0
      five = s32[] constant(5)
      five_bcast = s32[5,1000] broadcast(five), dimensions={}
      sum = s32[5,1000] add(iota, five_bcast)
      ROOT remainder = s32[5,1000] remainder(sum, s32[5,1000] five_bcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Remainder(m::Iota(), m::Broadcast())));
}
7184 
7185 // No simplification because 125 + 5 overflows S8.
TEST_F(AlgebraicSimplifierTest,RemainderOfNPlusIotaOverflow)7186 TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
7187   const char* kModuleStr = R"(
7188     HloModule m
7189     test {
7190       iota = s8[126] iota(), iota_dimension=0
7191       five = s8[] constant(5)
7192       five_bcast = s8[126] broadcast(five), dimensions={}
7193       sum = s8[126] add(iota, five_bcast)
7194       ROOT remainder = s8[126] remainder(sum, s8[126] five_bcast)
7195     })";
7196   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
7197   ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
7198 }
7199 
// remainder(remainder(p, q), q) is expected to collapse into a single
// remainder of two parameters.
TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
  const char* kHloText = R"(
    HloModule m
    test {
      p = s32[1000] parameter(0)
      q = s32[1000] parameter(1)
      r = s32[1000] remainder(p, q)
      ROOT rr = s32[1000] remainder(r, q)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
}
7214 
// A pad with negative low padding applied on top of a slice, run through a
// layout-sensitive simplifier. The new root is expected to be a slice whose
// shape equals the original root's shape, layout included.
TEST_F(AlgebraicSimplifierTest, SlicePadLayout) {
  const char* kHloText = R"(
    HloModule m
    test {
      %param.0 = f32[128,9,9,1024]{0,3,2,1} parameter(0)
      %param.1 = f32[] parameter(1)
      %slice = f32[128,9,9,1024]{0,3,2,1} slice(%param.0),
        slice={[0:128], [0:9], [0:9], [0:1024]}
      ROOT %pad = f32[128,8,9,1024]{0,3,2,1} pad(%slice, %param.1),
        padding=0_0x-1_0x0_0x0_0
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // Snapshot the root shape before the pass replaces the root instruction.
  const Shape expected_shape =
      module->entry_computation()->root_instruction()->shape();

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice().WithShapeEqualTo(&expected_shape)));
}
7234 
// minimum(maximum(broadcast(3), p0), broadcast(4)) is expected to become
// clamp(broadcast(3), p0, broadcast(4)).
TEST_F(AlgebraicSimplifierTest, MinOfMaxToClamp) {
  const char* kHloText = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(3.0)
      c1 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      m0 = f32[4] maximum(b0, p0)
      ROOT m1 = f32[4] minimum(m0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Clamp(m::Broadcast(m::ConstantScalar(3.0)),
                                  m::Parameter(0),
                                  m::Broadcast(m::ConstantScalar(4.0)))));
}
7255 
// maximum(broadcast(3), minimum(p0, broadcast(4))) is expected to become
// clamp(broadcast(3), p0, broadcast(4)).
TEST_F(AlgebraicSimplifierTest, MaxOfMinToClamp) {
  const char* kHloText = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(3.0)
      c1 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      m0 = f32[4] minimum(p0, b1)
      ROOT m1 = f32[4] maximum(b0, m0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Clamp(m::Broadcast(m::ConstantScalar(3.0)),
                                  m::Parameter(0),
                                  m::Broadcast(m::ConstantScalar(4.0)))));
}
7276 
// clamp(p0, clamp(p0, p1, p2), p2) is expected to reduce to the inner
// clamp(p0, p1, p2).
TEST_F(AlgebraicSimplifierTest, ClampOfClamp) {
  const char* kHloText = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      p2 = f32[] parameter(2)
      c0 = f32[] clamp(p0, p1, p2)
      ROOT c1 = f32[] clamp(p0, c0, p2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(2))));
}
7294 
// maximum(p0, clamp(p0, p1, p2)) is expected to reduce to clamp(p0, p1, p2).
TEST_F(AlgebraicSimplifierTest, MaxOfClamp) {
  const char* kHloText = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      p2 = f32[] parameter(2)
      c0 = f32[] clamp(p0, p1, p2)
      ROOT m0 = f32[] maximum(p0, c0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(2))));
}
7312 
// A slice that selects exactly the second concatenate operand (rows
// [100:150] of concat(p0[100,50], p1[50,50])) is expected to be replaced by
// that operand, parameter 1.
TEST_F(AlgebraicSimplifierTest, SliceOfConcat) {
  const char* kHloText = R"(
    HloModule m
    test {
      p0 = f32[100,50] parameter(0)
      p1 = f32[50,50] parameter(1)
      c0 = f32[150,50] concatenate(p0, p1), dimensions={0}
      ROOT s0 = f32[50,50] slice(c0), slice={[100:150], [0:50]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(1)));
}
7328 
// A slice covering rows [51:149] of a four-way concatenate only touches the
// two middle operands. Expects a slice of concatenate(p1, p2) with start and
// limit indices (1, 0) and (99, 50).
TEST_F(AlgebraicSimplifierTest, SliceOfMultipleConcatOperands) {
  const char* kHloText = R"(
    HloModule m
    test {
      p0 = f32[50,50] parameter(0)
      p1 = f32[50,50] parameter(1)
      p2 = f32[50,50] parameter(2)
      p3 = f32[50,50] parameter(3)
      c0 = f32[200,50] concatenate(p0, p1, p2, p3), dimensions={0}
      ROOT s0 = f32[98,50] slice(c0), slice={[51:149], [0:50]}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Concatenate(m::Parameter(1),
                                                       m::Parameter(2)))));
  EXPECT_THAT(root->slice_starts(), ElementsAre(1, 0));
  EXPECT_THAT(root->slice_limits(), ElementsAre(99, 50));
}
7351 
// sqrt(p0 * p0) is expected to simplify to abs(p0).
TEST_F(AlgebraicSimplifierTest, SqrtOfSelfMultiply) {
  const char* kHloText = R"(
    HloModule m
    test {
      p0 = f32[32]{0} parameter(0)
      m0 = f32[32]{0} multiply(f32[32]{0} p0, f32[32]{0} p0)
      ROOT s0 = f32[32]{0} sqrt(f32[32]{0} m0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Abs(m::Parameter(0))));
}
7366 
TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) {
  // An add-reduction over batch dimension 0 of a dot is expected to be
  // absorbed into the dot, leaving a dot of the two parameters as the root.
  constexpr char kHlo[] = R"(
    HloModule m
    a {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      ROOT r = f32[] add(p0, p1)
    }
    test {
      p0 = f32[32,8,5,6] parameter(0)
      p1 = f32[8,32,6,7] parameter(1)
      d = f32[32,8,5,7] dot(p0, p1),
        lhs_batch_dims={0,1},
        rhs_batch_dims={1,0},
        rhs_contracting_dims={2},
        lhs_contracting_dims={3}
     c = f32[] constant(0)
     ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Dot(m::Parameter(0), m::Parameter(1))));
}
7392 
TEST_F(AlgebraicSimplifierTest, ReduceAddIsCommutative) {
  // The two reducers differ only in operand order (add(p0, p1) versus
  // add(p1, p0)). The chained reductions are still expected to collapse into a
  // single reduce of the parameter.
  constexpr char kHlo[] = R"(
    HloModule m
    fn1 {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      ROOT r = f32[] add(p0, p1)
    }
    fn2 {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      ROOT r = f32[] add(p1, p0)
    }
    test {
      p0 = f32[10,10,10] parameter(0)
      zero = f32[] constant(0)
      r1 = f32[10,10] reduce(p0, zero), dimensions={0}, to_apply=fn1
      ROOT r2 = f32[10] reduce(r1, zero), dimensions={0}, to_apply=fn2
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Reduce(m::Parameter(0), m::ConstantScalar(0))));
}
7418 
TEST_F(AlgebraicSimplifierTest, RsqrtOfRPower) {
  // With the batchnorm-training custom-call target registered in the options,
  // rsqrt(power(x, -2)) on output 2 of that custom call is expected to become
  // abs(x).
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      c0 = f32[] constant(0.001)
      c1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, c0, c1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      c2 = f32[] constant(-2)
      broadcast = f32[32]{0} broadcast(f32[] c2), dimensions={}
      power = f32[32]{0} power(get-tuple-element, broadcast)
      rsqrt = f32[32]{0} rsqrt(f32[32]{0} power)
      ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // Expect transformation: rsqrt(power(gte.2,-2)) -> abs(gte.2)
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kPower), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);
  const HloInstruction* root = m->entry_computation()->root_instruction();
  // Structural preconditions are fatal so that the operand accesses below are
  // never reached on an unexpectedly shaped graph.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand_count(), 3);
  const HloInstruction* rewritten = root->operand(2);
  ASSERT_EQ(rewritten->opcode(), HloOpcode::kAbs);
  EXPECT_EQ(rewritten->operand(0)->opcode(), HloOpcode::kGetTupleElement);
}
7453 
TEST_F(AlgebraicSimplifierTest, RsqrtDivide) {
  // With the batchnorm-training custom-call target registered in the options,
  // rsqrt(1 / x) on output 2 of that custom call is expected to become
  // sqrt(x).
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      constant.2 = f32[] constant(1)
      broadcast.1 = f32[32]{0} broadcast(constant.2), dimensions={}
      divide = f32[32]{0} divide(broadcast.1, get-tuple-element)
      rsqrt = f32[32]{0} rsqrt(divide)
      ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // Expect transformation: rsqrt(divide(1,gte.2)) -> sqrt(gte.2)
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kDivide), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);
  const HloInstruction* root = m->entry_computation()->root_instruction();
  // Structural preconditions are fatal so that the operand accesses below are
  // never reached on an unexpectedly shaped graph.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand_count(), 3);
  const HloInstruction* rewritten = root->operand(2);
  ASSERT_EQ(rewritten->opcode(), HloOpcode::kSqrt);
  EXPECT_EQ(rewritten->operand(0)->opcode(), HloOpcode::kGetTupleElement);
}
7488 
TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt) {
  // With the batchnorm-training custom-call target registered in the options,
  // rsqrt(x) * rsqrt(x) on output 2 of that custom call is expected to become
  // divide(broadcast, x).
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      rsqrt = f32[32]{0} rsqrt(get-tuple-element)
      multiply = f32[32]{0} multiply(rsqrt, rsqrt)
      ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());

  // Expect transformation: multiply(rsqrt(gte.2), rsqrt(gte.2)) -> divide(1,
  // gte.2)
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kMultiply), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);

  const HloInstruction* root = m->entry_computation()->root_instruction();
  // Structural preconditions are fatal so that the operand accesses below are
  // never reached on an unexpectedly shaped graph.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand_count(), 3);
  const HloInstruction* rewritten = root->operand(2);
  ASSERT_EQ(rewritten->opcode(), HloOpcode::kDivide);
  EXPECT_EQ(rewritten->operand(0)->opcode(), HloOpcode::kBroadcast);
  EXPECT_EQ(rewritten->operand(1)->opcode(), HloOpcode::kGetTupleElement);
}
7525 
TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt_NegativeTestCase) {
  // The target string registered in the options differs from the module's
  // custom_call_target, and the test expects the pass to leave the graph
  // untouched.
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      rsqrt = f32[32]{0} rsqrt(get-tuple-element)
      multiply = f32[32]{0} multiply(rsqrt, rsqrt)
      ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForward");
  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
  // The original multiply/rsqrt survive and no divide/broadcast is introduced.
  EXPECT_NE(FindInstruction(module.get(), HloOpcode::kMultiply), nullptr);
  EXPECT_NE(FindInstruction(module.get(), HloOpcode::kRsqrt), nullptr);
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kDivide), nullptr);
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kBroadcast), nullptr);
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kMultiply);
}
7555 
TEST_F(AlgebraicSimplifierTest, AbsEliminationBatchnormTraining) {
  // With the batchnorm-training custom-call target registered in the options,
  // abs() applied to output 2 of that custom call is expected to be dropped.
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      abs = f32[32]{0} abs(get-tuple-element)
      ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  // No abs instruction may remain; the tuple's third operand is now the
  // get-tuple-element itself.
  EXPECT_EQ(FindInstruction(module.get(), HloOpcode::kAbs), nullptr);
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kGetTupleElement);
}
7582 
TEST_F(AlgebraicSimplifierTest,
       AbsEliminationBatchnormTraining_NegativeTestCase) {
  // The target string registered in the options differs from the module's
  // custom_call_target, and the test expects the abs to be kept.
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
      p1 = f32[32]{0} parameter(1)
      p2 = f32[32]{0} parameter(2)
      constant = f32[] constant(0.001)
      constant.1 = s64[] constant(1)
      custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
      get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
      get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
      get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
      abs = f32[32]{0} abs(get-tuple-element)
      ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardInference");
  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
  EXPECT_NE(FindInstruction(module.get(), HloOpcode::kAbs), nullptr);
}
7607 
TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) {
  // abs(p * p) is expected to simplify to p * p.
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p = f32[32]{0} parameter(0)
      m = f32[32]{0} multiply(p, p)
      ROOT a = f32[32]{0} abs(m)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
7622 
TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) {
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = f32[32]{0} parameter(0)
      c0 = f32[] constant(2)
      b0 = f32[32]{0} broadcast(c0), dimensions={}
      pow = f32[32]{0} power(p0, b0)
      ROOT a = f32[32]{0} abs(pow)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // power(A, 2) is rewritten to A * A, so abs(power(A, 2)) is expected to end
  // up as A * A as well.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
7641 
TEST_F(AlgebraicSimplifierTest, ScatterAddCombined) {
  // shared + scatter.0 + scatter.1, where both scatters start from a zero
  // broadcast, is expected to become one scatter into `shared` whose indices
  // and updates are the concatenation of the original ones.
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4] broadcast(z), dimensions={}
    shared = f32[100,4] parameter(0)
    index0 = s32[20] parameter(1)
    index1 = s32[10] parameter(2)
    update0 = f32[20,4] parameter(3)
    update1 = f32[10,4] parameter(4)
    scatter.0 = f32[100,4] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=1
    scatter.1 = f32[100,4] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=1
    add.0 = f32[100,4] add(shared, scatter.0)
    ROOT add.1 = f32[100,4] add(add.0, scatter.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // First run: combine the scatters.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  // Second run: optimize the add with 0.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Scatter(m::Parameter(0),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
7685 
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedSwapped) {
  // Same graph as ScatterAddCombined, but the root add takes scatter.1 as its
  // first operand. The concatenated indices/updates are then expected in
  // swapped order (scatter.1's first).
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4] broadcast(z), dimensions={}
    shared = f32[100,4] parameter(0)
    index0 = s32[20] parameter(1)
    index1 = s32[10] parameter(2)
    update0 = f32[20,4] parameter(3)
    update1 = f32[10,4] parameter(4)
    scatter.0 = f32[100,4] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=1
    scatter.1 = f32[100,4] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=1
    add.0 = f32[100,4] add(shared, scatter.0)
    ROOT add.1 = f32[100,4] add(scatter.1, add.0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // First run: combine the scatters.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  // Second run: optimize the add with 0.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Scatter(m::Parameter(0),
                            m::Concatenate(m::Parameter(2), m::Parameter(1)),
                            m::Concatenate(m::Parameter(4), m::Parameter(3)))));
}
7729 
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums) {
  // Scatter combining with index_vector_dim=0 and multi-dimensional indices.
  // The sum of the two scatters is expected to become a single scatter into
  // the broadcast with concatenated indices and updates.
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4] broadcast(z), dimensions={}
    shared = f32[100,4] parameter(0)
    index0 = s32[1,4,5] parameter(1)
    index1 = s32[1,2,5] parameter(2)
    update0 = f32[4,4,5] parameter(3)
    update1 = f32[2,4,5] parameter(4)
    scatter.0 = f32[100,4] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=0
    scatter.1 = f32[100,4] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=0
    ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // First run: combine the scatters.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  // Second run: simplify the add.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Scatter(m::Broadcast(),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
7772 
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums2) {
  // Scatter combining with index_vector_dim=2 and update_window_dims={0}.
  // The sum of the two scatters is expected to become a single scatter into
  // the broadcast with concatenated indices and updates.
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4] broadcast(z), dimensions={}
    shared = f32[100,4] parameter(0)
    index0 = s32[4,3,1] parameter(1)
    index1 = s32[4,5,1] parameter(2)
    update0 = f32[4,4,3] parameter(3)
    update1 = f32[4,4,5] parameter(4)
    scatter.0 = f32[100,4] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={0},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=2
    scatter.1 = f32[100,4] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={0},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=2
    ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  // First run: combine the scatters.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  // Second run: simplify the add.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Scatter(m::Broadcast(),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
7815 
TEST_F(AlgebraicSimplifierTest, ScalarScatter) {
  // Both scatters use s32[1] indices with index_vector_dim=0. The test
  // expects the pass to report no change, i.e. these scatters are not
  // combined.
  constexpr char kHlo[] = R"(
  HloModule m
  apply {
   a = f32[] parameter(0)
   b = f32[] parameter(1)
   ROOT c = f32[] add(a, b)
  }
  test {
    z  = f32[] constant(0)
    init = f32[100,4,20] broadcast(z), dimensions={}
    shared = f32[100,4,20] parameter(0)
    index0 = s32[1] parameter(1)
    index1 = s32[1] parameter(2)
    update0 = f32[4,20] parameter(3)
    update1 = f32[4,20] parameter(4)
    scatter.0 = f32[100,4,20] scatter(init, index0, update0),
              to_apply=apply,
              update_window_dims={0, 1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=0
    scatter.1 = f32[100,4,20] scatter(init, index1, update1),
              to_apply=apply,
              update_window_dims={0, 1},
              inserted_window_dims={0},
              scatter_dims_to_operand_dims={0},
              index_vector_dim=0
    ROOT add.1 = f32[100,4,20] add(scatter.0, scatter.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  const bool changed =
      AlgebraicSimplifier(default_options_).Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
7851 
TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
  // The convolution's operands are expected to be swapped by the pass. The
  // resulting window should be 3x3 and reversed, with padding of 1 on each
  // side of both spatial dimensions.
  const char* hlo_string = R"(
  HloModule m
  test {
    a = f32[3,3,160,160] parameter(0)
    b = f32[128,32,32,160] parameter(1)
    ROOT c = f32[128,32,32,160] convolution(a,b),
     window={size=32x32 pad=30_30x30_30 rhs_reversal=1x1},
     dim_labels=01bf_o01i->f01b
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  // Swap the convolution operands.
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  const HloInstruction* conv = m->entry_computation()->root_instruction();
  // Fatal on mismatch: the window accessors queried below presuppose that the
  // root is still a convolution.
  ASSERT_THAT(conv,
              GmockMatch(m::Convolution(m::Parameter(1), m::Parameter(0))));
  EXPECT_EQ(conv->window().dimensions(0).size(), 3);
  EXPECT_EQ(conv->window().dimensions(1).size(), 3);
  EXPECT_TRUE(conv->window().dimensions(0).window_reversal());
  EXPECT_TRUE(conv->window().dimensions(1).window_reversal());
  EXPECT_EQ(conv->window().dimensions(0).padding_low(), 1);
  EXPECT_EQ(conv->window().dimensions(1).padding_low(), 1);
  EXPECT_EQ(conv->window().dimensions(0).padding_high(), 1);
  EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1);
}
7878 
TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
  // convert(pred) / broadcast(p1) is expected to become a multiply (operands
  // in either order) of the convert with broadcast(1 / p1).
  constexpr char kHlo[] = R"(
    HloModule m
    test {
      p0 = pred[2] parameter(0)
      cvt = f32[2] convert(p0)
      p1 = f32[] parameter(1)
      bcast = f32[2] broadcast(p1), dimensions={}
      ROOT div = f32[2] divide(cvt, bcast)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::MultiplyAnyOrder(
          m::Convert(m::Parameter(0)),
          m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
}
7898 
TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) {
  // Two matrix-vector dots of different element types (c64 and f64) in one
  // module. After the pass the module is expected to hold exactly three
  // computations (presumably the entry plus one helper per rewritten dot;
  // confirm against the pass implementation).
  const char* kHlo = R"(
    HloModule test
    ENTRY test {
      a = c64[2,2] parameter(0)
      b = c64[2] parameter(1)
      cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
      c = f64[2,2] parameter(2)
      d = f64[2] parameter(3)
      dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0}
      ROOT tuple = (c64[2], f64[2]) tuple(cd, dd)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(3, module->computation_count());
}
7916 
TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) {
  // A reduce producing a one-element tuple is expected to become a tuple
  // wrapping an array-shaped reduce whose reducer root is a plain add.
  constexpr char kHlo[] = R"(
    HloModule m
    fn {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      a = f32[] add(p0, p1)
      ROOT t = (f32[]) tuple(a)
    }
    test {
      p0 = f32[32,8,6,7] parameter(0)
      c = f32[] constant(0)
      ROOT r = (f32[8,6,7]) reduce(p0, c), dimensions={0}, to_apply=fn
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(
      root,
      GmockMatch(m::Tuple(m::Reduce(m::Parameter(0), m::ConstantScalar(0)))));
  // The reduce must call exactly one computation before its root is inspected.
  const HloInstruction* reduce = root->operand(0);
  ASSERT_EQ(reduce->called_computations().size(), 1);
  EXPECT_THAT(reduce->called_computations()[0]->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Parameter(1))));
}
7950 
TEST_F(AlgebraicSimplifierTest, ReplaceReduceMaxWithReduceArgMax) {
  // The entry computation holds a variadic (value, index) reduce and a
  // separate max reduce over the same input and dimension. After the pass a
  // single tuple-shaped reduce is expected to feed both outputs.
  constexpr char kHlo[] = R"(
HloModule ReplaceReduceMaxWithReduceArgMax

%reduction_computation__1.25287 (parameter.25288: bf16[], parameter.25289: s32[], parameter.25290: bf16[], parameter.25291: s32[]) -> (bf16[], s32[]) {
  %constant.25292 = pred[] constant(false)
  %parameter.25288 = bf16[] parameter(0)
  %parameter.25290 = bf16[] parameter(2)
  %compare.25293 = pred[] compare(bf16[] %parameter.25288, bf16[] %parameter.25290), direction=GT
  %compare.25294 = pred[] compare(bf16[] %parameter.25288, bf16[] %parameter.25288), direction=NE
  %or.25295 = pred[] or(pred[] %compare.25293, pred[] %compare.25294)
  %select.25300 = bf16[] select(pred[] %or.25295, bf16[] %parameter.25288, bf16[] %parameter.25290)
  %compare.25296 = pred[] compare(bf16[] %parameter.25288, bf16[] %parameter.25290), direction=EQ
  %parameter.25289 = s32[] parameter(1)
  %parameter.25291 = s32[] parameter(3)
  %compare.25297 = pred[] compare(s32[] %parameter.25289, s32[] %parameter.25291), direction=LT
  %and.25298 = pred[] and(pred[] %compare.25296, pred[] %compare.25297)
  %or.25299 = pred[] or(pred[] %or.25295, pred[] %and.25298)
  %select.25301 = s32[] select(pred[] %or.25299, s32[] %parameter.25289, s32[] %parameter.25291)
  ROOT %tuple.25302 = (bf16[], s32[]) tuple(bf16[] %select.25300, s32[] %select.25301)
}

%primitive_computation_max.25303 (parameter.25304: bf16[], parameter.25305: bf16[]) -> bf16[] {
  %parameter.25304 = bf16[] parameter(0), metadata={op_type="max" op_name="max"}
  %parameter.25305 = bf16[] parameter(1), metadata={op_type="max" op_name="max"}
  ROOT %maximum.25306 = bf16[] maximum(bf16[] %parameter.25304, bf16[] %parameter.25305), metadata={op_type="max" op_name="max"}
}

ENTRY %main {
  %p0 = bf16[384,128,19392]{2,1,0} parameter(0)

  // Variadic Reduce (ArgMax)
  %iota.25376 = s32[384,128,19392] iota(), iota_dimension=2
  %constant.25377 = bf16[] constant(-inf)
  %constant.25378 = s32[] constant(0)
  %reduce.25379 = (bf16[384,128]{1,0}, s32[384,128]{1,0}) reduce(bf16[384,128,19392]{2,1,0} %p0, s32[384,128,19392] %iota.25376, bf16[] %constant.25377, s32[] %constant.25378), dimensions={2}, to_apply=%reduction_computation__1.25287

  %get-tuple-element.25381 = s32[384,128]{1,0} get-tuple-element((bf16[384,128]{1,0}, s32[384,128]{1,0}) %reduce.25379), index=1

  // Reduce (Max)
  %constant.25382 = bf16[] constant(-inf)
  %reduce.25383 = bf16[384,128]{1,0} reduce(bf16[384,128,19392]{2,1,0} %p0, bf16[] %constant.25382), dimensions={2}, to_apply=%primitive_computation_max.25303

  ROOT %tuple.0 = (bf16[384,128]{1,0}, s32[384,128]{1,0}) tuple(%reduce.25383, %get-tuple-element.25381)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  // Count the reduce instructions left in the entry computation; exactly one
  // is expected after simplification.
  int64_t reduce_count = 0;
  for (const HloInstruction* hlo :
       module->entry_computation()->instructions()) {
    if (hlo->opcode() == HloOpcode::kReduce) {
      ++reduce_count;
    }
  }
  EXPECT_EQ(1, reduce_count);
  // Both tuple outputs must be fed by a tuple-shaped (variadic) reduce.
  auto tuple_shaped_reduce = m::Reduce().WithShape(m::Shape().IsTuple());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root,
              GmockMatch(m::Tuple(m::GetTupleElement(tuple_shaped_reduce, 0),
                                  m::GetTupleElement(tuple_shaped_reduce, 1))));
}
8012 
// The entry computation reduces the same operand twice along dimension 2:
// once with a variadic (value, index) reducer and once with a plain `minimum`
// reducer. After simplification the test expects exactly one reduce to remain
// in the entry computation and both tuple outputs to be get-tuple-elements of
// a tuple-shaped reduce.
TEST_F(AlgebraicSimplifierTest, ReplaceReduceMinWithReduceArgMin) {
  const char* kModuleStr = R"(
HloModule ReplaceReduceMinWithReduceArgMin

%region_3.84 (Arg_0.85: bf16[], Arg_1.86: s32[], Arg_2.87: bf16[], Arg_3.88: s32[]) -> (bf16[], s32[]) {
  %Arg_3.88 = s32[]{:T(256)} parameter(3)
  %Arg_2.87 = bf16[]{:T(512)} parameter(2)
  %Arg_1.86 = s32[]{:T(256)} parameter(1)
  %compare.93 = pred[]{:T(1024)S(6)} compare(s32[]{:T(256)} %Arg_1.86, s32[]{:T(256)} %Arg_3.88), direction=LT, metadata={op_name="lt" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %Arg_0.85 = bf16[]{:T(512)} parameter(0)
  %compare.92 = pred[]{:T(1024)S(6)} compare(bf16[]{:T(512)} %Arg_0.85, bf16[]{:T(512)} %Arg_2.87), direction=EQ, metadata={op_name="eq" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %and.94 = pred[]{:T(1024)S(6)} and(pred[]{:T(1024)S(6)} %compare.92, pred[]{:T(1024)S(6)} %compare.93), metadata={op_name="and" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %compare.90 = pred[]{:T(1024)S(6)} compare(bf16[]{:T(512)} %Arg_0.85, bf16[]{:T(512)} %Arg_0.85), direction=NE, metadata={op_name="ne" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %compare.89 = pred[]{:T(1024)S(6)} compare(bf16[]{:T(512)} %Arg_0.85, bf16[]{:T(512)} %Arg_2.87), direction=LT, metadata={op_name="lt" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %or.91 = pred[]{:T(1024)S(6)} or(pred[]{:T(1024)S(6)} %compare.89, pred[]{:T(1024)S(6)} %compare.90), metadata={op_name="or" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %select.96 = bf16[]{:T(512)} select(pred[]{:T(1024)S(6)} %or.91, bf16[]{:T(512)} %Arg_0.85, bf16[]{:T(512)} %Arg_2.87), metadata={op_name="select_n" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %or.95 = pred[]{:T(1024)S(6)} or(pred[]{:T(1024)S(6)} %or.91, pred[]{:T(1024)S(6)} %and.94), metadata={op_name="or" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %select.97 = s32[]{:T(256)} select(pred[]{:T(1024)S(6)} %or.95, s32[]{:T(256)} %Arg_1.86, s32[]{:T(256)} %Arg_3.88), metadata={op_name="select_n" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  ROOT %tuple.98 = (bf16[]{:T(512)}, s32[]{:T(256)}) tuple(bf16[]{:T(512)} %select.96, s32[]{:T(256)} %select.97)
}

%region_0.8 (Arg_0.9: bf16[], Arg_1.10: bf16[]) -> bf16[] {
  %Arg_1.10 = bf16[]{:T(512)} parameter(1)
  %Arg_0.9 = bf16[]{:T(512)} parameter(0)
  ROOT %minimum.11 = bf16[]{:T(512)} minimum(bf16[]{:T(512)} %Arg_0.9, bf16[]{:T(512)} %Arg_1.10), metadata={op_name="jit(ScaMTPUTopK)/jit(main)/jit(ScaMTPUTopK)/jit(jit_ScaMTPUTopK)/reduce_min[axes=(2,)]" source_file="<ipython-input-4-4f3bd086a82e>" source_line=8}
}

ENTRY %main {
  %param_0.3 = bf16[1024,1024,2048]{2,0,1:T(8,128)(2,1)} parameter(0)

  // ArgMin
  %iota.5.clone.1 = s32[1024,1024,2048]{2,0,1:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(ScaMTPUTopK)/jit(main)/jit(ScaMTPUTopK)/jit(jit_ScaMTPUTopK)/iota[dtype=int32 shape=(1024, 1024, 2048) dimension=2]" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %constant.24 = bf16[]{:T(512)} constant(inf)
  %constant.23 = s32[]{:T(256)} constant(0)
  %reduce.3 = (bf16[1024,1024]{0,1:T(8,128)(2,1)}, s32[1024,1024]{0,1:T(8,128)}) reduce(bf16[1024,1024,2048]{2,0,1:T(8,128)(2,1)} %param_0.3, s32[1024,1024,2048]{2,0,1:T(8,128)} %iota.5.clone.1, bf16[]{:T(512)} %constant.24, s32[]{:T(256)} %constant.23), dimensions={2}, to_apply=%region_3.84

  %gte.0 = s32[1024,1024]{0,1:T(8,128)} get-tuple-element(%reduce.3), index=1

  // ReduceMin
  %constant.25 = bf16[]{:T(512)} constant(inf)
  %reduce.4 = bf16[1024,1024]{0,1:T(8,128)(2,1)} reduce(bf16[1024,1024,2048]{2,0,1:T(8,128)(2,1)} %param_0.3, bf16[]{:T(512)} %constant.25), dimensions={2}, to_apply=%region_0.8

  ROOT %tuple.0 = (bf16[1024,1024]{0,1:T(8,128)(2,1)}, s32[1024,1024]{0,1:T(8,128)}) tuple(%reduce.4, %gte.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // Only reduces in the entry computation are counted; the reducer
  // sub-computations are not inspected here.
  int64_t reduce_count = absl::c_count_if(
      m->entry_computation()->instructions(), [](const HloInstruction* hlo) {
        return hlo->opcode() == HloOpcode::kReduce;
      });
  // Expect one Reduce operation after simplification.
  EXPECT_EQ(1, reduce_count);
  auto variadic_reduce = m::Reduce().WithShape(m::Shape().IsTuple());
  auto root = m->entry_computation()->root_instruction();
  // Expect that both outputs are fed by 'variadic_reduce'.
  ASSERT_THAT(root,
              GmockMatch(m::Tuple(m::GetTupleElement(variadic_reduce, 0),
                                  m::GetTupleElement(variadic_reduce, 1))));
}
8073 
// A reduce-window whose reducer returns a one-element tuple is expected to be
// rewritten into tuple(reduce-window(param, 0)), where the reduce-window calls
// a single computation whose root is the bare add of its two parameters.
TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduceWindow) {
  const char* kModuleStr = R"(
    HloModule m
    fn {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      a = f32[] add(p0, p1)
      ROOT t = (f32[]) tuple(a)
    }
    test {
      p0 = f32[32,8,6,7] parameter(0)
      c = f32[] constant(0)
      ROOT r = (f32[32,8,6,7]) reduce-window(p0, c), to_apply=fn, window={size=1x1x1x1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(
                  m::ReduceWindow(m::Parameter(0), m::ConstantScalar(0)))));
  // ASSERT (not EXPECT) so that indexing called_computations()[0] below is
  // only reached when exactly one computation is present.
  ASSERT_EQ(m->entry_computation()
                ->root_instruction()
                ->operand(0)
                ->called_computations()
                .size(),
            1);
  EXPECT_THAT(m->entry_computation()
                  ->root_instruction()
                  ->operand(0)
                  ->called_computations()[0]
                  ->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Parameter(1))));
}
8107 
// pad(broadcast(scalar constant)) that pads only dimension 0 is expected to be
// rewritten so that the root is broadcast(pad(broadcast(constant), constant)).
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorder) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      c1 = pred[] constant(true)
      b2 = pred[32,1,768]{2,1,0} broadcast(pred[] c1), dimensions={}
      c3 = pred[] constant(false)
      ROOT p4 = pred[4096,1,768]{2,1,0} pad(pred[32,1,768]{2,1,0} b2, pred[] c3), padding=0_4064x0_0x0_0
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(
                  m::Pad(m::Broadcast(m::Constant()), m::Constant()))));
}
8124 
// Same broadcast/pad reordering as BroadcastAndPadReorder, but the pad (here
// padding only the last dimension) is consumed by a tuple rather than being
// the root. The tuple operand is expected to become
// broadcast(pad(broadcast(constant), constant)).
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithUse) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      c1 = pred[] constant(true)
      b2 = pred[1,768,32]{2,1,0} broadcast(pred[] c1), dimensions={}
      c3 = pred[] constant(false)
      p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
      ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(m::Broadcast(
                  m::Pad(m::Broadcast(m::Constant()), m::Constant())))));
}
8142 
// Variant of the broadcast/pad reordering where the broadcast operand is a
// rank-1 parameter (broadcast along dimension 2) instead of a scalar constant.
// The tuple operand is expected to become
// broadcast(pad(broadcast(parameter), constant)).
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      c1 = pred[32] parameter(0)
      b2 = pred[1,768,32]{2,1,0} broadcast(pred[32] c1), dimensions={2}
      c3 = pred[] constant(false)
      p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
      ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(m::Broadcast(
                  m::Pad(m::Broadcast(m::Parameter()), m::Constant())))));
}
8160 
8161 // Test that dynamic-update-slice with a scalar broadcast becomes a pad when the
8162 // start_indices are too big.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceOfBroadcastToPadOob)8163 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPadOob) {
8164   const char* hlo_string = R"(
8165 HloModule module
8166 
8167 ENTRY f {
8168   constant.546 = f32[] constant(0)
8169   broadcast.467 = f32[2]{0} broadcast(constant.546), dimensions={}
8170   parameter.1 = f32[1]{0} parameter(0)
8171   constant.551 = s32[] constant(2)
8172   ROOT dynamic-update-slice.44 = f32[2]{0} dynamic-update-slice(broadcast.467, parameter.1, constant.551)
8173 }
8174 )";
8175   TF_ASSERT_OK_AND_ASSIGN(auto module,
8176                           ParseAndReturnVerifiedModule(hlo_string));
8177   VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
8178   AlgebraicSimplifier simplifier(default_options_);
8179   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
8180   VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
8181   auto* pad = module->entry_computation()->root_instruction();
8182   EXPECT_THAT(pad,
8183               GmockMatch(m::Pad(m::Parameter(0), m::ConstantScalar(0.0f))));
8184   EXPECT_FALSE(HasInteriorPadding(pad->padding_config()));
8185   ASSERT_EQ(pad->padding_config().dimensions_size(), 1);
8186   EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_low(), 1);
8187   EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_high(), 0);
8188 }
8189 
8190 // Test folding of dynamic_slice(iota, index) -> clamp(index, 0, size-1)
TEST_F(AlgebraicSimplifierTest,DynamicSliceOfIota)8191 TEST_F(AlgebraicSimplifierTest, DynamicSliceOfIota) {
8192   const char* hlo_string = R"(
8193 HloModule module
8194 
8195 ENTRY f {
8196   %cst = s32[2]{0} constant({0, 1})
8197   %index = u32[] parameter(0)
8198   ROOT %dynamic-slice = s32[1]{0} dynamic-slice(s32[2]{0} %cst, u32[] %index),
8199                                   dynamic_slice_sizes={1}
8200 }
8201 )";
8202 
8203   TF_ASSERT_OK_AND_ASSIGN(auto module,
8204                           ParseAndReturnVerifiedModule(hlo_string));
8205   AlgebraicSimplifier simplifier(default_options_);
8206   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
8207   VLOG(2) << "After rewrite \n" << module->ToString();
8208 
8209   EXPECT_THAT(module->entry_computation()->root_instruction(),
8210               GmockMatch(m::Reshape(m::Convert(
8211                   m::Clamp(m::Constant(), m::Parameter(0), m::Constant())))));
8212 }
8213 
8214 // Test folding of clamp(pid, 0, limit) -> pid
TEST_F(AlgebraicSimplifierTest,ClampOfPartitionId)8215 TEST_F(AlgebraicSimplifierTest, ClampOfPartitionId) {
8216   const char* hlo_string = R"(
8217 HloModule module
8218 
8219 ENTRY f {
8220   %pid = u32[] partition-id()
8221   %low = u32[] constant(0)
8222   %high = u32[] constant(5)
8223   ROOT %c = u32[] clamp(%low, %pid, %high)
8224 }
8225 )";
8226 
8227   TF_ASSERT_OK_AND_ASSIGN(
8228       auto module, ParseAndReturnVerifiedModule(hlo_string, /*replica_count=*/1,
8229                                                 /*num_partitions=*/6));
8230   AlgebraicSimplifier simplifier(default_options_);
8231   ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
8232   VLOG(2) << "After rewrite \n" << module->ToString();
8233 
8234   EXPECT_THAT(module->entry_computation()->root_instruction(),
8235               GmockMatch(m::PartitionId()));
8236 }
8237 
// An evenly strided constant ({0, 25, 50, 75}) behind a copy is expected to be
// rewritten so that the root is multiply(iota, broadcast).
TEST_F(AlgebraicSimplifierTest, ConstantToIota) {
  const char* hlo_string = R"(
HloModule module

ENTRY f {
  %cst = s32[4] constant({0, 25, 50, 75})
  ROOT %s = s32[4] copy(s32[4] %cst)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  VLOG(2) << "After rewrite \n" << module->ToString();

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply(m::Iota(), m::Broadcast())));
}
8256 
// A size-1 dynamic-slice of the evenly strided constant {0, 25, 50, 75} is
// expected to be rewritten so that the root is
// reshape(multiply(convert(clamp(...)), constant)).
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfStridedIota) {
  const char* hlo_string = R"(
HloModule module

ENTRY f {
  %cst = s32[4] constant({0, 25, 50, 75})
  %index = u32[] parameter(0)
  ROOT %dynamic-slice = s32[1]{0} dynamic-slice(s32[4]{0} %cst, u32[] %index),
                                  dynamic_slice_sizes={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  VLOG(2) << "After rewrite \n" << module->ToString();

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(
                  m::Multiply(m::Convert(m::Clamp()), m::Constant()))));
}
8278 
// abs(select(p1, broadcast(0), maximum(p0, broadcast(0)))): the expected root
// after simplification is the select itself, i.e. the outer abs is gone while
// the select and its two branches are left in place.
TEST_F(AlgebraicSimplifierTest, AbsEliminationSelMaxBcast) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = f32[32]{0} parameter(0)
      p1 = pred[32]{0} parameter(1)
      zero = f32[] constant(0.0)
      bcast = f32[32] broadcast(zero), dimensions={}
      m = f32[32]{0} maximum(p0, bcast)
      s = f32[32]{0} select(p1, bcast, m)
      ROOT a = f32[32]{0} abs(s)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Select(
                  m::Parameter(1), m::Broadcast(m::ConstantScalar()),
                  m::MaximumAnyOrder(m::Parameter(0),
                                     m::Broadcast(m::ConstantScalar())))));
}
8300 
// s32 -> u32 bitcast-converts feeding a concatenate whose result is
// bitcast-converted back to s32: the expected root is a concatenate applied
// directly to the two parameters.
TEST_F(AlgebraicSimplifierTest, SimplifyRedundantBitcastConvert) {
  const char* kModuleStr = R"(
    HloModule m

    ENTRY test {
      p0 = s32[10] parameter(0)
      p1 = s32[10] parameter(1)
      b0 = u32[10] bitcast-convert(p0)
      b1 = u32[10] bitcast-convert(p1)
      c = u32[20] concatenate(b0, b1), dimensions={0}
      ROOT out = s32[20] bitcast-convert(c)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(1))));
}
8319 
// A three-element tuple is passed through an opt-barrier, but only element 1
// is read back from the barrier. The test checks that the tuple feeding the
// barrier has 3 operands before simplification and 2 afterwards.
TEST_F(AlgebraicSimplifierTest, SimplifyOptimizationBarrier) {
  const char* kModuleStr = R"(
    HloModule m

    ENTRY entry {
      param.0 = f32[] parameter(0)
      param.1 = f32[] parameter(1)
      add.0 = f32[] add(param.0, param.1)
      sub.0 = f32[] subtract(param.0, param.1)
      mul.0 = f32[] multiply(param.0, param.1)
      tuple.0 = (f32[], f32[], f32[]) tuple(mul.0, sub.0, add.0)
      b = (f32[], f32[], f32[]) opt-barrier(tuple.0)
      gte.0 = f32[] get-tuple-element(b), index=1
      ROOT  t = (f32[], f32[]) tuple(mul.0,gte.0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  // In the input module: root -> operand(1) is gte.0 -> operand(0) is the
  // opt-barrier -> operand(0) is tuple.0.
  EXPECT_EQ(m->entry_computation()
                ->root_instruction()
                ->operand(1)
                ->operand(0)
                ->operand(0)
                ->operand_count(),
            3);
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // Same operand path after the pass; the operand count is expected to drop.
  EXPECT_EQ(m->entry_computation()
                ->root_instruction()
                ->operand(1)
                ->operand(0)
                ->operand(0)
                ->operand_count(),
            2);
}
8353 
// The parameter and the get-tuple-element carry different sharding
// annotations; the simplifier is expected to report that nothing changed.
TEST_F(AlgebraicSimplifierTest, GTETupleShardingLoss) {
  // Verify the gte(tuple) folding does not happen if it loses sharding info.
  const char* kModuleStr = R"(
    HloModule m

    ENTRY test {
      p0 = s32[10] parameter(0), sharding={devices=[2]0,1}
      t = (s32[10]) tuple(p0)
      ROOT %gte = s32[10] get-tuple-element(t), index=0, sharding={replicated}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
}
8368 
// After the dynamic-slice is simplified, the shape of the instruction that
// feeds the root tuple must still have a layout with exactly one tile.
TEST_F(AlgebraicSimplifierTest, DynamicSliceShapeLayout) {
  // Verify we maintain layout when optimizing dynamic-slice
  const char* kModuleStr = R"(
    HloModule m

    ENTRY test {
      p0 = u32[]{:T(128)} parameter(0)
      %constant.1 = s32[4]{0:T(128)} constant({0, 16, 32, 48})
      %dynamic-slice = s32[1]{0:T(128)} dynamic-slice(s32[4]{0:T(128)} %constant.1, u32[] %p0), dynamic_slice_sizes={1}
      ROOT t = (s32[1]{0:T(128)}) tuple(dynamic-slice)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // operand(0) of the root tuple is whatever replaced the dynamic-slice.
  const Shape& slice_shape =
      m.get()->entry_computation()->root_instruction()->operand(0)->shape();
  EXPECT_TRUE(slice_shape.has_layout());
  EXPECT_EQ(slice_shape.layout().tiles_size(), 1);
}
8388 
8389 // Fold a sequence of copy bitcast copy
TEST_F(AlgebraicSimplifierTest,CopyBitcastCopy)8390 TEST_F(AlgebraicSimplifierTest, CopyBitcastCopy) {
8391   const char* kModuleStr = R"(
8392     HloModule m
8393 
8394     ENTRY test {
8395      fusion.1235 = bf16[1600,50,512]{2,0,1:T(8,128)(2,1)} parameter(0)
8396      copy.3038 = bf16[1600,50,512]{0,2,1:T(8,128)(2,1)} copy(fusion.1235)
8397      bitcast.8 = bf16[1600,50,16,32]{0,3,2,1:T(8,128)(2,1)} bitcast(copy.3038)
8398      copy.3045 = bf16[1600,50,16,32]{1,3,2,0:T(8,128)(2,1)} copy(bitcast.8)
8399     }
8400   )";
8401   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
8402   AlgebraicSimplifierOptions options;
8403   options.set_is_layout_sensitive(true);
8404   AlgebraicSimplifier simplifier(options);
8405   ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
8406   EXPECT_THAT(m->entry_computation()->root_instruction(),
8407               GmockMatch(m::Bitcast(m::Copy(m::Parameter()))));
8408 }
8409 
// Negative case for the copy-bitcast-copy folding: with layout sensitivity
// enabled, the simplifier is expected to leave this module unchanged.
TEST_F(AlgebraicSimplifierTest, CopyBitcastCopy2) {
  const char* kModuleStr = R"(
    HloModule m

    ENTRY test {
     %Arg_0.1 = f32[8,3,3,7,7]{4,0,3,2,1:T(8,128)} parameter(0)
     %copy.1 = f32[8,3,3,7,7]{4,3,2,1,0:T(8,128)} copy(f32[8,3,3,7,7]{4,0,3,2,1:T(8,128)} %Arg_0.1)
     %bitcast = f32[1,72,7,7]{3,2,1,0:T(8,128)} bitcast(f32[8,3,3,7,7]{4,3,2,1,0:T(8,128)} %copy.1)
     %copy.2 = f32[1,72,7,7]{1,3,2,0:T(8,128)} copy(f32[1,72,7,7]{3,2,1,0:T(8,128)} %bitcast)
    }
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
}
8427 
// copy -> bitcast -> copy on a rank-2 array (the middle instruction is named
// `reshape` but is a bitcast). With layout sensitivity enabled, the simplifier
// is expected to report no change.
TEST_F(AlgebraicSimplifierTest, CopyReshapeCopy3) {
  const char* kModuleStr = R"(
   HloModule m

  ENTRY main {
  p = f32[2,3]{0,1} parameter(0)
  copy = f32[2,3]{1,0} copy(p)
  reshape = f32[3,2]{1,0} bitcast(copy)
  ROOT copy.1 = f32[3,2]{0,1} copy(reshape)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
  VLOG(3) << "Module " << m->ToString();
}
8445 
// Rank-3 variant of CopyReshapeCopy3: copy -> bitcast -> copy where the middle
// instruction is named `reshape` but is a bitcast. With layout sensitivity
// enabled, the simplifier is expected to report no change.
TEST_F(AlgebraicSimplifierTest, CopyReshapeCopy4) {
  const char* kModuleStr = R"(
   HloModule m

  ENTRY main {
    p = f32[6,2,3]{0,1,2} parameter(0)
    copy.0 = f32[6,2,3]{0,2,1} copy(p)
    reshape = f32[2,3,6]{1,0,2} bitcast(copy.0)
    ROOT copy.1 = f32[2,3,6]{0,1,2} copy(reshape)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
  VLOG(3) << "Module " << m->ToString();
}
8463 
// A reshape/transpose/copy chain run with layout sensitivity enabled. The
// simplifier is expected to change the module, but the instruction named
// `copy.3` must still be present afterwards.
TEST_F(AlgebraicSimplifierTest, BitcastCopyChain) {
  const char* kModuleStr = R"(
   HloModule m

  ENTRY main {
   p.0 = f32[4,32,32,1]{2,3,1,0} parameter(0)
   reshape.30 = f32[4,32,1,1,32]{4,1,0,3,2} reshape(p.0)
   transpose.1757 = f32[4,1,1,32,32]{3,4,0,1,2} transpose(reshape.30), dimensions={0,3,2,4,1}
   copy.3 = f32[4,1,1,32,32]{4,3,0,2,1} copy(transpose.1757)
   reshape.1758 = f32[4,1,1,1024]{3,2,1,0} reshape(copy.3)
   transpose.61 = f32[1024,4,1,1]{0,3,2,1} transpose(reshape.1758), dimensions={3,0,1,2}
   copy.4 = f32[1024,4,1,1]{0,1,3,2} copy(transpose.61)
   ROOT reshape.107 = f32[1024,4]{0,1} reshape(copy.4)
   })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  auto result = simplifier.Run(m.get()).ValueOrDie();
  // Attach the post-pass module text to any failure reported below.
  SCOPED_TRACE(m->ToString());
  ASSERT_TRUE(result);
  EXPECT_NE(FindInstruction(m.get(), "copy.3"), nullptr);
}
8487 
8488 // Make sure that the following copy-bitcast-copy is not transformed via
8489 // SwapCopyBitcastCopy function. If SwapCopyBitcastCopy does not fire, in this
8490 // case, the last copy will be turned into a bitcast by HandleCopy.
TEST_F(AlgebraicSimplifierTest,BitcastCopyChainSmall)8491 TEST_F(AlgebraicSimplifierTest, BitcastCopyChainSmall) {
8492   const char* kModuleStr = R"(
8493    HloModule m
8494    ENTRY %main (para.0: f32[4,1,1,32,32]) -> f32[1024,4,1,1] {
8495     %para.0 = f32[4,1,1,32,32]{3,4,0,1,2} parameter(0)
8496     %copy.0 = f32[4,1,1,32,32]{4,3,0,2,1} copy(%para.0)
8497     %bitcast.0 = f32[1024,4,1,1]{0,3,2,1} bitcast(%copy.0)
8498     ROOT %copy.1 = f32[1024,4,1,1]{0,1,3,2} copy(%bitcast.0)
8499   }
8500 
8501   )";
8502   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
8503   AlgebraicSimplifierOptions options;
8504   options.set_is_layout_sensitive(true);
8505   AlgebraicSimplifier simplifier(options);
8506   SCOPED_TRACE(m->ToString());
8507   auto result = simplifier.Run(m.get()).ValueOrDie();
8508   SCOPED_TRACE(m->ToString());
8509   ASSERT_TRUE(result);
8510   EXPECT_THAT(m->entry_computation()->root_instruction(),
8511               GmockMatch(m::Bitcast(m::Bitcast(m::Copy(m::Parameter(0))))));
8512 }
8513 
8514 }  // namespace
8515 }  // namespace xla
8516