1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/strings/str_replace.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_parser.h"
33 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
34 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
35 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
36 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/test.h"
40 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/window_util.h"
43 #include "tensorflow/compiler/xla/xla_data.pb.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/platform/statusor.h"
46
47 namespace xla {
48 namespace {
49
50 using ::testing::ElementsAre;
51 namespace m = match;
52
53 class AlgebraicSimplifierTest : public HloTestBase {
54 protected:
55 AlgebraicSimplifierOptions default_options_;
56 };
57
58 // Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest,AddZero)59 TEST_F(AlgebraicSimplifierTest, AddZero) {
60 auto m = CreateNewVerifiedModule();
61 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
62 HloComputation::Builder builder(TestName());
63 HloInstruction* param0 = builder.AddInstruction(
64 HloInstruction::CreateParameter(0, r0f32, "param0"));
65 HloInstruction* zero = builder.AddInstruction(
66 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
67 builder.AddInstruction(
68 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
69
70 auto computation = m->AddEntryComputation(builder.Build());
71 HloInstruction* root = computation->root_instruction();
72 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
73 AlgebraicSimplifier simplifier(default_options_);
74 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
75 root = computation->root_instruction();
76 EXPECT_EQ(root, param0);
77 }
78
// Integer A*C + B*C is factored into (A+B)*C.
TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = s32[8] parameter(0)
      p1 = s32[8] parameter(1)
      p2 = s32[8] parameter(2)
      x = s32[8] multiply(p0, p2)
      y = s32[8] multiply(p1, p2)
      ROOT sum = s32[8] add(x, y)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto sum_of_addends = m::AddAnyOrder(m::Parameter(0), m::Parameter(1));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::MultiplyAnyOrder(sum_of_addends, m::Parameter(2))));
}
99
100 // A*C + B*C => (A+B)*C if C is a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAddition)101 TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
102 const char* kModuleStr = R"(
103 HloModule m
104 test {
105 p0 = f32[] parameter(0)
106 p1 = f32[] parameter(1)
107 c = f32[] constant(0.125)
108 x = f32[] multiply(p0, c)
109 y = f32[] multiply(p1, c)
110 ROOT sum = f32[] add(x, y)
111 }
112 )";
113 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
114 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
115 EXPECT_THAT(m->entry_computation()->root_instruction(),
116 GmockMatch(m::MultiplyAnyOrder(
117 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
118 m::ConstantScalar(0.125))));
119 }
120
121 // (Abs(A)) * (Abs(A)) => (A*A)
TEST_F(AlgebraicSimplifierTest,SquareOfAbs)122 TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
123 const char* kModuleStr = R"(
124 HloModule m
125 test {
126 p = f32[] parameter(0)
127 a = f32[] abs(p)
128 ROOT z = f32[] multiply(a, a)
129 }
130 )";
131 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
132 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
133 EXPECT_THAT(m->entry_computation()->root_instruction(),
134 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
135 }
136
137 // (A*C1) * (B*C2) => (A*B)*(C1*C2)
TEST_F(AlgebraicSimplifierTest,MultiplyChain)138 TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
139 const char* kModuleStr = R"(
140 HloModule m
141 test {
142 p0 = f32[] parameter(0)
143 p1 = f32[] parameter(1)
144 c = f32[] constant(2)
145 d = f32[] constant(4)
146 x = f32[] multiply(p0, c)
147 y = f32[] multiply(p1, d)
148 ROOT z = f32[] multiply(x, y)
149 }
150 )";
151 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
152 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
153 EXPECT_THAT(
154 m->entry_computation()->root_instruction(),
155 GmockMatch(m::MultiplyAnyOrder(
156 m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1)),
157 m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4)))));
158 }
159
160 // (a*C1)*C2 => a*(C1*C2)
TEST_F(AlgebraicSimplifierTest,MultiplyChain2)161 TEST_F(AlgebraicSimplifierTest, MultiplyChain2) {
162 const char* kModuleStr = R"(
163 HloModule m
164 test {
165 p0 = f32[] parameter(0)
166 a = f32[] constant(2)
167 b = f32[] constant(4)
168 c = f32[] multiply(p0, a)
169 ROOT y = f32[] multiply(c, b)
170 }
171 )";
172 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
173 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
174 EXPECT_THAT(m->entry_computation()->root_instruction(),
175 GmockMatch(m::MultiplyAnyOrder(
176 m::Parameter(0), m::MultiplyAnyOrder(m::ConstantScalar(2),
177 m::ConstantScalar(4)))));
178 }
179
180 // MUL(MUL(X, BROADCAST(constant)), BROADCAST(Y)) ==>
181 // MUL(X, BROADCAST(MUL(Y, BROADCAST(constant))))
TEST_F(AlgebraicSimplifierTest,MultiplyBroadcastReassoc)182 TEST_F(AlgebraicSimplifierTest, MultiplyBroadcastReassoc) {
183 const char* kModuleStr = R"(
184 HloModule m
185 test {
186 p0 = f32[2,2] parameter(0)
187 p1 = f32[] parameter(1)
188 b = f32[] constant(2)
189 c = f32[2, 2] broadcast(b), dimensions={}
190 x = f32[2,2] multiply(p0, c)
191 y = f32[2,2] broadcast(p1), dimensions={}
192 ROOT z = f32[2,2] multiply(y, x)
193 }
194 )";
195 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
196 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
197 EXPECT_THAT(m->entry_computation()->root_instruction(),
198 GmockMatch(m::MultiplyAnyOrder(
199 m::Parameter(0), m::Broadcast(m::MultiplyAnyOrder(
200 m::Parameter(1), m::Constant())))));
201 }
202
203 // A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionWithBroadcast)204 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) {
205 const char* kModuleStr = R"(
206 HloModule m
207 test {
208 p0 = f32[4] parameter(0)
209 p1 = f32[4] parameter(1)
210 c = f32[] constant(0.125)
211 b = f32[4] broadcast(c), dimensions={}
212 x = f32[4] multiply(p0, b)
213 y = f32[4] multiply(p1, b)
214 ROOT sum = f32[4] add(x, y)
215 }
216 )";
217 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
218 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
219 EXPECT_THAT(m->entry_computation()->root_instruction(),
220 GmockMatch(m::MultiplyAnyOrder(
221 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
222 m::Broadcast(m::ConstantScalar(0.125)))));
223 }
224
225 // A*C + B*C => (A+B)*C simplification should not happen if C is not a
226 // floating-point power of 2.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionNotPowerOf2)227 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) {
228 const char* kModuleStr = R"(
229 HloModule m
230 test {
231 p0 = f32[] parameter(0)
232 p1 = f32[] parameter(1)
233 c = f32[] constant(0.3)
234 x = f32[] multiply(p0, c)
235 y = f32[] multiply(p1, c)
236 ROOT sum = f32[] add(x, y)
237 }
238 )";
239 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
240 EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
241 }
242
243 // A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are
244 // complex numbers.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionComplex)245 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) {
246 const char* kModuleStr = R"(
247 HloModule m
248 test {
249 p0 = c64[8] parameter(0)
250 p1 = c64[8] parameter(1)
251 p2 = c64[8] parameter(2)
252 x = c64[8] multiply(p0, p2)
253 y = c64[8] multiply(p1, p2)
254 ROOT sum = c64[8] add(x, y)
255 }
256 )";
257 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
258 EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
259 }
260
261 // A*C + B*C => (A+B)*C simplification is OK if A, B, and C are complex.
TEST_F(AlgebraicSimplifierTest,FactorFpAdditionBfloat16)262 TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) {
263 const char* kModuleStr = R"(
264 HloModule m
265 test {
266 p0 = bf16[4] parameter(0)
267 p1 = bf16[4] parameter(1)
268 c = bf16[] constant(0.125)
269 b = bf16[4] broadcast(c), dimensions={}
270 x = bf16[4] multiply(p0, b)
271 y = bf16[4] multiply(p1, b)
272 ROOT sum = bf16[4] add(x, y)
273 }
274 )";
275 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
276 ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
277 EXPECT_THAT(m->entry_computation()->root_instruction(),
278 GmockMatch(m::MultiplyAnyOrder(
279 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
280 m::Broadcast(m::ConstantScalar(0.125)))));
281 }
282
// Unsigned x / 8 => x >> 3, i.e. a logical right shift by log2 of the divisor.
TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT d = u32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto shift_amount = m::Broadcast(m::ConstantScalar(3));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ShiftRightLogical(m::Parameter(0), shift_amount)));
}
299
// Signed x / 8 is rewritten in terms of the magnitude of x:
//   neg = x < 0
//   q   = (neg ? -x : x) >>> 3   (logical shift)
//   res = neg ? -q : q
TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT d = s32[4] divide(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto shifted =
      m::ShiftRightLogical(magnitude, m::Broadcast(m::ConstantScalar(3)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Select(is_negative, m::Negate(shifted), shifted)));
}
322
// Unsigned x % 8 => x & 7, i.e. a mask of the low bits.
TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p = u32[4] parameter(0)
      c = u32[] constant(8)
      b = u32[4] broadcast(c), dimensions={}
      ROOT r = u32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto low_bits_mask = m::Broadcast(m::ConstantScalar(7));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::AndAnyOrder(m::Parameter(0), low_bits_mask)));
}
339
// Signed x % 8 is rewritten in terms of the magnitude of x:
//   neg = x < 0
//   r   = (neg ? -x : x) & 7
//   res = neg ? -r : r
TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p = s32[4] parameter(0)
      c = s32[] constant(8)
      b = s32[4] broadcast(c), dimensions={}
      ROOT r = s32[4] remainder(p, b)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto is_negative =
      m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0)));
  auto magnitude =
      m::Select(is_negative, m::Negate(m::Parameter(0)), m::Parameter(0));
  auto masked =
      m::AndAnyOrder(magnitude, m::Broadcast(m::ConstantScalar(7)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Select(is_negative, m::Negate(masked), masked)));
}
362
363 // Test that A * 0 is simplified to 0
TEST_F(AlgebraicSimplifierTest,MulZero)364 TEST_F(AlgebraicSimplifierTest, MulZero) {
365 auto m = CreateNewVerifiedModule();
366 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
367 HloComputation::Builder builder(TestName());
368 HloInstruction* param0 = builder.AddInstruction(
369 HloInstruction::CreateParameter(0, r0s32, "param0"));
370 HloInstruction* zero = builder.AddInstruction(
371 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
372 builder.AddInstruction(
373 HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero));
374
375 auto computation = m->AddEntryComputation(builder.Build());
376 HloInstruction* root = computation->root_instruction();
377 EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
378 AlgebraicSimplifier simplifier(default_options_);
379 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
380 EXPECT_EQ(computation->root_instruction(), zero);
381 }
382
// (A * C0) * C1 => A * (C0 * C1) for scalar constants C0 and C1.
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeConstants) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      c0 = f32[] constant(2.0)
      c1 = f32[] constant(3.0)
      multiply0 = f32[] multiply(p0, c0)
      ROOT multiply1 = f32[] multiply(multiply0, c1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_constants =
      m::Multiply(m::ConstantScalar(2.0), m::ConstantScalar(3.0));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), merged_constants)));
}
401
// (A * broadcast(C0)) * broadcast(C1) => A * broadcast(C0 * C1): the scalar
// constants are merged beneath a single broadcast.
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(2.0)
      c1 = f32[] constant(3.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      multiply0 = f32[4] multiply(p0, b0)
      ROOT multiply1 = f32[4] multiply(multiply0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_constants =
      m::Multiply(m::ConstantScalar(2.0), m::ConstantScalar(3.0));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Multiply(m::Parameter(0), m::Broadcast(merged_constants))));
}
423
// multiply(broadcast(p1), broadcast(p0)) of two scalar parameters is sunk
// under one outer broadcast. The expected result is
// broadcast(multiply(broadcast(p1), broadcast(p0))).
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      b0 = f32[4] broadcast(p0), dimensions={}
      b1 = f32[4] broadcast(p1), dimensions={}
      ROOT multiply = f32[4] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner_product = m::Multiply(m::Broadcast(m::Parameter(1)),
                                   m::Broadcast(m::Parameter(0)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(inner_product)));
}
442
// multiply(broadcast(p0), broadcast(const)) is sunk below one broadcast. The
// scalar constant is re-broadcast to p0's shape inside the new multiply.
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(2.0)
      b0 = f32[4,2] broadcast(c0), dimensions={}
      b1 = f32[4,2] broadcast(p0), dimensions={0}
      ROOT multiply = f32[4,2] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner_product =
      m::Multiply(m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(inner_product)));
}
460
// Two broadcasts using the same broadcast dimensions ({0}) are merged: the
// multiply runs on the un-broadcast parameters and is broadcast once.
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[4] parameter(1)
      b0 = f32[4,2] broadcast(p0), dimensions={0}
      b1 = f32[4,2] broadcast(p1), dimensions={0}
      ROOT multiply = f32[4,2] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto inner_product = m::Multiply(m::Parameter(1), m::Parameter(0));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(inner_product)));
}
478
// Broadcasts along different dimensions ({0} vs {1}) cannot be merged. The
// pass must report no change and leave the multiply of broadcasts as it is.
TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      p1 = f32[8] parameter(1)
      b0 = f32[4,8] broadcast(p0), dimensions={0}
      b1 = f32[4,8] broadcast(p1), dimensions={1}
      ROOT multiply = f32[4,8] multiply(b1, b0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  auto unchanged = m::Multiply(m::Broadcast(m::Parameter(1)),
                               m::Broadcast(m::Parameter(0)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(unchanged));
}
496
// (C0 * broadcast(c1)) * broadcast(c2), with C0 a non-scalar constant
// => C0 * broadcast(c1 * c2).
TEST_F(AlgebraicSimplifierTest,
       MultiplyReassociateMultiplyOfConstantAndBroadcast) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      c0 = f32[4] constant({2.0, 3.0, 4.0, 5.0})
      c1 = f32[] constant(3.0)
      c2 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c1), dimensions={}
      b1 = f32[4] broadcast(c2), dimensions={}
      multiply0 = f32[4] multiply(c0, b0)
      ROOT multiply1 = f32[4] multiply(multiply0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto scalar_product =
      m::Multiply(m::ConstantScalar(3.0), m::ConstantScalar(4.0));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Multiply(m::Constant(), m::Broadcast(scalar_product))));
}
519
520 // Test that select(true, a, b) is simplified to a
TEST_F(AlgebraicSimplifierTest,SelectTrue)521 TEST_F(AlgebraicSimplifierTest, SelectTrue) {
522 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
523 HloComputation::Builder builder(TestName());
524 HloInstruction* param0 = builder.AddInstruction(
525 HloInstruction::CreateParameter(0, r0s32, "param0"));
526 HloInstruction* param1 = builder.AddInstruction(
527 HloInstruction::CreateParameter(1, r0s32, "param1"));
528 HloInstruction* one = builder.AddInstruction(
529 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
530 builder.AddInstruction(HloInstruction::CreateTernary(
531 r0s32, HloOpcode::kSelect, one, param0, param1));
532
533 auto module = CreateNewVerifiedModule();
534 auto computation = module->AddEntryComputation(builder.Build());
535 HloInstruction* root = computation->root_instruction();
536 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
537 AlgebraicSimplifier simplifier(default_options_);
538 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
539 EXPECT_EQ(computation->root_instruction(), param0);
540 }
541
542 // Test that select(false, a, b) is simplified to b
TEST_F(AlgebraicSimplifierTest,SelectFalse)543 TEST_F(AlgebraicSimplifierTest, SelectFalse) {
544 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
545 HloComputation::Builder builder(TestName());
546 HloInstruction* param0 = builder.AddInstruction(
547 HloInstruction::CreateParameter(0, r0s32, "param0"));
548 HloInstruction* param1 = builder.AddInstruction(
549 HloInstruction::CreateParameter(1, r0s32, "param1"));
550 HloInstruction* zero = builder.AddInstruction(
551 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
552 builder.AddInstruction(HloInstruction::CreateTernary(
553 r0s32, HloOpcode::kSelect, zero, param0, param1));
554
555 auto module = CreateNewVerifiedModule();
556 auto computation = module->AddEntryComputation(builder.Build());
557 HloInstruction* root = computation->root_instruction();
558 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
559 AlgebraicSimplifier simplifier(default_options_);
560 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
561 EXPECT_EQ(computation->root_instruction(), param1);
562 }
563
564 // Test that select(a, b, b) is simplified to b
TEST_F(AlgebraicSimplifierTest,SelectIdentical)565 TEST_F(AlgebraicSimplifierTest, SelectIdentical) {
566 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
567 HloComputation::Builder builder(TestName());
568 HloInstruction* param0 = builder.AddInstruction(
569 HloInstruction::CreateParameter(0, r0s32, "param0"));
570 HloInstruction* param1 = builder.AddInstruction(
571 HloInstruction::CreateParameter(1, r0s32, "param1"));
572 builder.AddInstruction(HloInstruction::CreateTernary(
573 r0s32, HloOpcode::kSelect, param0, param1, param1));
574
575 auto module = CreateNewVerifiedModule();
576 auto computation = module->AddEntryComputation(builder.Build());
577 HloInstruction* root = computation->root_instruction();
578 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
579 AlgebraicSimplifier simplifier(default_options_);
580 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
581 EXPECT_EQ(computation->root_instruction(), param1);
582 }
583
584 // Test that select(not(pred), a, b) is simplified to select(pred, b, a)
TEST_F(AlgebraicSimplifierTest,SelectWithNotPred)585 TEST_F(AlgebraicSimplifierTest, SelectWithNotPred) {
586 Shape pred_ty = ShapeUtil::MakeShape(PRED, {});
587 Shape r0s32 = ShapeUtil::MakeShape(S32, {});
588 HloComputation::Builder builder(TestName());
589 HloInstruction* param0 = builder.AddInstruction(
590 HloInstruction::CreateParameter(0, pred_ty, "param0"));
591 HloInstruction* param1 = builder.AddInstruction(
592 HloInstruction::CreateParameter(1, r0s32, "param1"));
593 HloInstruction* param2 = builder.AddInstruction(
594 HloInstruction::CreateParameter(2, r0s32, "param2"));
595 HloInstruction* pred_instr = builder.AddInstruction(
596 HloInstruction::CreateUnary(pred_ty, HloOpcode::kNot, param0));
597 builder.AddInstruction(HloInstruction::CreateTernary(
598 r0s32, HloOpcode::kSelect, pred_instr, param1, param2));
599
600 auto module = CreateNewVerifiedModule();
601 auto computation = module->AddEntryComputation(builder.Build());
602 HloInstruction* root = computation->root_instruction();
603 EXPECT_EQ(root->opcode(), HloOpcode::kSelect);
604 AlgebraicSimplifier simplifier(default_options_);
605 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
606 const auto& operands = computation->root_instruction()->operands();
607 EXPECT_EQ(operands[0], param0);
608 EXPECT_EQ(operands[1], param2);
609 EXPECT_EQ(operands[2], param1);
610 }
611
612 // Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest,TwoReducesToOne)613 TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
614 auto m = CreateNewVerifiedModule();
615 HloComputation::Builder builder(TestName());
616 // Create add computation.
617 HloInstruction* zero = builder.AddInstruction(
618 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
619 HloComputation* add_computation = nullptr;
620 {
621 HloComputation::Builder builder(TestName() + ".add");
622 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
623 HloInstruction* p0 = builder.AddInstruction(
624 HloInstruction::CreateParameter(0, scalar_shape, "p0"));
625 HloInstruction* p1 = builder.AddInstruction(
626 HloInstruction::CreateParameter(1, scalar_shape, "p1"));
627 builder.AddInstruction(
628 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
629 add_computation = m->AddEmbeddedComputation(builder.Build());
630 }
631 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
632 HloInstruction* param = builder.AddInstruction(
633 HloInstruction::CreateParameter(0, r4f32, "param"));
634 std::vector<int64_t> dims0({0});
635 Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7});
636 HloInstruction* reduce0 = builder.AddInstruction(
637 HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation));
638 std::vector<int64_t> dims1({1, 2});
639 Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
640 builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero,
641 dims1, add_computation));
642 m->AddEntryComputation(builder.Build());
643 AlgebraicSimplifier simplifier(default_options_);
644 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
645 HloInstruction* root = m->entry_computation()->root_instruction();
646 EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero))));
647 EXPECT_EQ(root->dimensions(), std::vector<int64_t>({0, 2, 3}));
648 }
649
650 // Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest,AddConstOnLHS)651 TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
652 auto m = CreateNewVerifiedModule();
653 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
654 HloComputation::Builder builder(TestName());
655 HloInstruction* param0 = builder.AddInstruction(
656 HloInstruction::CreateParameter(0, r0f32, "param0"));
657 HloInstruction* constant = builder.AddInstruction(
658 HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
659 builder.AddInstruction(
660 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
661
662 auto computation = m->AddEntryComputation(builder.Build());
663 HloInstruction* root = computation->root_instruction();
664 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
665 AlgebraicSimplifier simplifier(default_options_);
666 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
667 root = computation->root_instruction();
668 EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Constant())));
669 }
670
671 // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
TEST_F(AlgebraicSimplifierTest,AddReassociateMergeConstants)672 TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
673 auto m = CreateNewVerifiedModule();
674 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
675 HloComputation::Builder builder(TestName());
676 HloInstruction* param0 = builder.AddInstruction(
677 HloInstruction::CreateParameter(0, r0f32, "param0"));
678 HloInstruction* constant1 = builder.AddInstruction(
679 HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
680 HloInstruction* constant2 = builder.AddInstruction(
681 HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
682
683 HloInstruction* add1 = builder.AddInstruction(
684 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
685 builder.AddInstruction(
686 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
687
688 auto computation = m->AddEntryComputation(builder.Build());
689 HloInstruction* root = computation->root_instruction();
690 EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
691 AlgebraicSimplifier simplifier(default_options_);
692 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
693 root = computation->root_instruction();
694 EXPECT_THAT(root, GmockMatch(m::Add(
695 m::Op().Is(param0),
696 m::Add(m::Op().Is(constant1), m::Op().Is(constant2)))));
697 }
698
// (A + broadcast(C0)) + broadcast(C1) => A + broadcast(C0 + C1).
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) {
  constexpr char kHloText[] = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(1.0)
      c1 = f32[] constant(2.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      add0 = f32[4] add(p0, b0)
      ROOT add1 = f32[4] add(add0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto merged_constants =
      m::Add(m::ConstantScalar(1.0), m::ConstantScalar(2.0));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Add(m::Parameter(0), m::Broadcast(merged_constants))));
}
719
// broadcast(0) + A => A, where the zero being broadcast is a rank-0 scalar.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
  auto module = CreateNewVerifiedModule();
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // A broadcast's dimensions list has one entry per operand dimension, so a
  // rank-0 operand takes {}. The previous {0, 1} built invalid HLO; it only
  // escaped verification because the simplifier deletes the broadcast.
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, zero, {}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(computation->root_instruction(), param0);
}
741
// A map whose mapped computation is a single scalar add is inlined as a plain
// elementwise add of the map's operands.
TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Scalar f32 add computation applied by the map.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, lhs, rhs));
    add_computation = module->AddEmbeddedComputation(add_builder.Build());
  }

  const Shape matrix_f32 = ShapeUtil::MakeShape(F32, {32, 1});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* zeros = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_f32, zero, {}));
  builder.AddInstruction(
      HloInstruction::CreateMap(matrix_f32, {input, zeros}, add_computation));

  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kMap);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0),
                                m::Broadcast(m::Op().Is(zero)))));
}
778
// A map whose mapped computation calls a fusion (computing x*x + 42) is not
// inlined; the simplifier reports that nothing changed.
TEST_F(AlgebraicSimplifierTest, KeepNontrivialMap) {
  constexpr char kHloText[] = R"(
HloModule m
fusion {
  x = f32[] parameter(0)
  c = f32[] constant(42)
  m = f32[] multiply(x, x)
  ROOT a = f32[] add(m, c)
}

map {
  x = f32[] parameter(0)
  ROOT f = f32[] fusion(x), kind=kLoop, calls=fusion
}

ENTRY test {
  p = f32[2,2] parameter(0)
  ROOT map = f32[2,2] map(p), dimensions={0,1}, to_apply=map
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(hlo_module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
802
// Test that Broadcast(zeros) + A, where the zeros come from a rank-1
// constant, is simplified to A.
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
  // The f32[3] operand must be mapped onto output dimension 0, which has
  // size 3. Output dimension 1 has size 2, so mapping onto it would make the
  // operand and output sizes disagree, giving an ill-formed broadcast.
  HloInstruction* bcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {0}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));

  auto computation = m->AddEntryComputation(builder.Build());
  HloInstruction* root = computation->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  root = computation->root_instruction();
  EXPECT_EQ(root, param0);
}
824
// A rank-1 constant whose elements are all equal is rewritten as a broadcast
// of a constant whose first element is that shared value.
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());
  hlo_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());

  HloInstruction* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(3.14f, new_root->operand(0)->literal().GetFirstElement<float>());
}
840
// A rank-1 constant whose elements differ (and do not form 0, 1, 2, ...) is
// left untouched.
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());
  hlo_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));
}
855
// A rank-1 constant holding 0, 1, 2 is rewritten as an iota. Despite the
// test's name, the expected result is an iota, not a broadcast.
TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) {
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());
  hlo_builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f})));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Constant()));

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Iota()));
}
870
871 // Test that A - 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, SubZero) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* param0 = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* zero = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSubtract);

  // x - 0 folds away entirely, leaving the parameter as the root.
  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), param0);
}
891
892 // Test that A - Const is canonicalized to A + (-Const).
TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* param0 = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* forty_two = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kSubtract, param0, forty_two));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kSubtract);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // The subtract becomes an add of the negated original constant.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0),
                                m::Negate(m::Op().Is(forty_two)))));
}
913
914 // Test that A - Broadcast(Const) is canonicalized to A + Broadcast(-Const).
TEST_F(AlgebraicSimplifierTest, SubBroadcastConstCanonicalization) {
  constexpr char kHloText[] = R"(
HloModule m
test {
  p0 = f32[4] parameter(0)
  c = f32[] constant(0.125)
  b = f32[4] broadcast(c), dimensions={}
  ROOT sub = f32[4] subtract(p0, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // The negation is applied to the scalar underneath the broadcast, and the
  // subtract is turned into an add.
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Add(
                        m::Parameter(0),
                        m::Broadcast(m::Negate(m::ConstantScalar(0.125))))));
}
932
933 // Test that A - A is simplified to 0.
TEST_F(AlgebraicSimplifierTest, SubSame) {
  constexpr char kHloText[] = R"(
HloModule m
test {
  p0 = s32[2] parameter(0)
  ROOT sub = s32[2] subtract(p0, p0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // x - x on integers becomes a broadcast scalar zero.
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::ConstantScalar(0))));
}
947
948 // Test that Broadcast(x) where x has degenerate dimensions first removes the
949 // degenerate dimensions.
TEST_F(AlgebraicSimplifierTest, DegenerateDimsInOperandRemovedFromBroadcast) {
  constexpr char kHloText[] = R"(
HloModule m
test {
  c = f32[1,4] parameter(0)
  ROOT b = f32[5,1,4,3] broadcast(c), dimensions={1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // A reshape is inserted between the parameter and the broadcast.
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
963
964 // Test to catch a crash where we were overshooting the reshaped_dimensions
965 // vector.
TEST_F(AlgebraicSimplifierTest, ArrayOvershootTest) {
  constexpr char kHloText[] = R"(
HloModule m
test {
  param0 = f32[18,18,2,1,1,128]{1,0,5,2,4,3} parameter(0)
  cpy1 = f32[18,18,2,1,1,128]{5,2,1,0,4,3} copy(f32[18,18,2,1,1,128]{1,0,5,2,4,3} param0)
  bitcast = f32[648,128,1,1]{3,2,1,0} bitcast(cpy1)
  ROOT cpy2 = f32[648,128,1,1]{3,2,0,1} copy(f32[648,128,1,1]{3,2,1,0} bitcast)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  // Regression test: the point is that Run() completes without crashing.
  // When the test was written nothing in this module was simplified, hence
  // the expected "no change".
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
}
986
987 // Test that (A/B)/C is simplified to A/(B*C).
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* a = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* b = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "param1"));
  HloInstruction* c = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(2, r0f32, "param2"));
  HloInstruction* a_over_b = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, a, b));
  hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, a_over_b, c));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  // Before: (A / B) / C.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
                                   m::Parameter(2))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // After: A / (B * C).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Multiply(m::Parameter(1), m::Parameter(2)))));
}
1017
1018 // Test that A/(B/C) is simplified to (A*C)/B.
TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* a = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* b = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "param1"));
  HloInstruction* c = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(2, r0f32, "param2"));
  HloInstruction* b_over_c = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, b, c));
  hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, a, b_over_c));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  // Before: A / (B / C).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Divide(m::Parameter(1), m::Parameter(2)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // After: (A * C) / B.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)),
                           m::Parameter(1))));
}
1049
1050 // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* a = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* b = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r2f32, "param1"));
  HloInstruction* c = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(2, r2f32, "param2"));
  HloInstruction* d = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(3, r2f32, "param3"));
  HloInstruction* a_over_b = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, a, b));
  HloInstruction* c_over_d = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, c, d));
  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      r2f32, HloOpcode::kDivide, a_over_b, c_over_d));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  // Before: (A / B) / (C / D).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)),
                           m::Divide(m::Parameter(2), m::Parameter(3)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // After: (A * D) / (B * C).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)),
                           m::Multiply(m::Parameter(1), m::Parameter(2)))));
}
1085
1086 // Test that A/exp(B) is simplified to A*exp(-B).
TEST_F(AlgebraicSimplifierTest, DivOfExp) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* numerator = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* exponent = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "param1"));
  HloInstruction* exp_of_exponent = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, exponent));
  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kDivide, numerator, exp_of_exponent));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  // Before: A / exp(B).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // After: A * exp(-B).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0),
                                     m::Exp(m::Negate(m::Parameter(1))))));
}
1112
1113 // Test that A/pow(B,C) is simplified to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfPower) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* a = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* b = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "param1"));
  HloInstruction* c = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(2, r0f32, "param2"));
  HloInstruction* b_pow_c = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, b, c));
  hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, a, b_pow_c));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  // Before: A / pow(B, C).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Power(m::Parameter(1), m::Parameter(2)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // After: A * pow(B, -C).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
}
1144
1145 // Test that broadcasting is done on the right step when simplifying A/pow(B,C)
1146 // to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* a = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* b = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1f32, "param1"));
  HloInstruction* c = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(2, r1f32, "param2"));
  HloInstruction* b_pow_c = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, b, c));
  hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, a, b_pow_c));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  // Before: A / pow(B, C) on f32[7] operands.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Parameter(0),
                           m::Power(m::Parameter(1), m::Parameter(2)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // After: A * pow(B, -C).
  ASSERT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Power(m::Parameter(1), m::Negate(m::Parameter(2))))));
}
1177
1178 // A / Const => A * InvertedConst
TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {3});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* param0 = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* divisor = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f})));
  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      r1f32, HloOpcode::kDivide, param0, divisor));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());

  // The division is rewritten as a multiplication by a constant.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), m::Constant())));
}
1199
1200 // A / Broadcast(Const) => A * Broadcast(InvertedConst)
TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) {
  constexpr char kHloText[] = R"(
HloModule m
test {
  p = f32[4] parameter(0)
  c = f32[] constant(256.0)
  b = f32[4] broadcast(c), dimensions={}
  ROOT d = f32[4] divide(p, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // Dividing by broadcast(256) becomes multiplying by broadcast(1/256).
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(
                  m::Parameter(0),
                  m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f)))));
}
1219
1220 // pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest,PowerOfPower)1221 TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
1222 auto m = CreateNewVerifiedModule();
1223 Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
1224 HloComputation::Builder builder(TestName());
1225 HloInstruction* base = builder.AddInstruction(
1226 HloInstruction::CreateParameter(0, r1f32, "param0"));
1227 HloInstruction* exp1 = builder.AddInstruction(
1228 HloInstruction::CreateParameter(1, r1f32, "param1"));
1229 HloInstruction* exp2 = builder.AddInstruction(
1230 HloInstruction::CreateParameter(2, r1f32, "param2"));
1231 HloInstruction* inner_power = builder.AddInstruction(
1232 HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
1233 builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
1234 inner_power, exp2));
1235
1236 AlgebraicSimplifier simplifier(default_options_);
1237 ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie());
1238 }
1239
1240 // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
1241 // numbers.
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* a = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1c64, "param0"));
  HloInstruction* x = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1c64, "param1"));
  HloInstruction* y = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(2, r1c64, "param2"));
  HloInstruction* a_pow_x = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, a, x));
  hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, a_pow_x, y));
  hlo_module->AddEntryComputation(hlo_builder.Build());

  // pow(pow(A, X), Y) on complex values must be left as is.
  ASSERT_FALSE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
}
1261
1262 // Test that A/1 is simplified to A for a scalar.
TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* param0 = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* one = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  HloInstruction* quotient = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_EQ(entry->root_instruction(), quotient);

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  // x / 1 collapses to x itself.
  EXPECT_EQ(entry->root_instruction(), param0);
}
1282
1283 // Test that A/1 is simplified to A for an array.
TEST_F(AlgebraicSimplifierTest, DivOneArray) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* param0 = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* ones =
      hlo_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
  HloInstruction* quotient = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, ones));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_EQ(entry->root_instruction(), quotient);

  ASSERT_TRUE(
      AlgebraicSimplifier(default_options_).Run(hlo_module.get()).ValueOrDie());
  // Dividing by an all-ones array collapses to the dividend.
  EXPECT_EQ(entry->root_instruction(), param0);
}
1303
1304 // Test that complex(real(c), imag(c)) is simplified to c.
TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* c = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2c64, "param0"));
  HloInstruction* real_part = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, c));
  HloInstruction* imag_part = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, c));
  HloInstruction* rebuilt = hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      r2c64, HloOpcode::kComplex, real_part, imag_part));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_EQ(entry->root_instruction(), rebuilt);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  // Reassembling c from its own parts yields c.
  EXPECT_EQ(entry->root_instruction(), c);
}
1327
1328 // Test that real(complex(r,i)) is simplified to r.
TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape r2c64 = ShapeUtil::ChangeElementType(r2f32, C64);
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* re = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* im = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r2f32, "param1"));
  HloInstruction* cplx = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, re, im));
  HloInstruction* real_of_cplx = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_EQ(entry->root_instruction(), real_of_cplx);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  // real(complex(r, i)) forwards r directly.
  EXPECT_EQ(entry->root_instruction(), re);
}
1351
1352 // Test that imag(complex(r,i)) is simplified to i.
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape r2c64 = ShapeUtil::ChangeElementType(r2f32, C64);
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* re = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* im = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r2f32, "param1"));
  HloInstruction* cplx = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, re, im));
  HloInstruction* imag_of_cplx = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_EQ(entry->root_instruction(), imag_of_cplx);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  // imag(complex(r, i)) forwards i directly.
  EXPECT_EQ(entry->root_instruction(), im);
}
1375
1376 // Test that get_element(make_tuple({A,B}),1) is simplified to B
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* first = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* second = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "param1"));
  HloInstruction* addend = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(2, r0f32, "param2"));
  HloInstruction* pair =
      hlo_builder.AddInstruction(HloInstruction::CreateTuple({first, second}));
  HloInstruction* picked = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(r0f32, pair, 1));
  HloInstruction* sum = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, picked, addend));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());
  EXPECT_EQ(entry->root_instruction(), sum);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  // The get-tuple-element of a freshly built tuple is forwarded to the
  // selected tuple operand (param1).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Add(m::Parameter(1), m::Parameter(2))));
}
1402
1403 // Test that exp(A)/exp(B) is simplified to exp(A-B)
TEST_F(AlgebraicSimplifierTest, ExpDiv) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* a = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* b = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "param1"));
  HloInstruction* exp_a = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, a));
  HloInstruction* exp_b = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, b));
  hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp_a, exp_b));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  // Before: exp(A) / exp(B).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // After: exp(A - B).
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1)))));
}
1432
1433 // Test that exp(A)*exp(B) is simplified to exp(A+B)
TEST_F(AlgebraicSimplifierTest, ExpMul) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* a = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* b = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "param1"));
  HloInstruction* exp_a = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, a));
  HloInstruction* exp_b = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, b));
  hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp_a, exp_b));

  HloComputation* entry = hlo_module->AddEntryComputation(hlo_builder.Build());

  // Before: exp(A) * exp(B).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Exp(m::Parameter(0)),
                                     m::Exp(m::Parameter(1)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // After: exp(A + B).
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1)))));
}
1461
1462 // Test that pow(exp(A), B) is simplified to exp(A*B)
TEST_F(AlgebraicSimplifierTest, PowExp) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Build: power(exp(p0), p1).
  HloComputation::Builder builder(TestName());
  auto* base = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  auto* exponent = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "param1"));
  auto* exp_base = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kExp, base));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kPower, exp_base, exponent));
  HloComputation* entry = hlo_module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // pow(exp(A), B) => exp(A * B)
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1)))));
}
1488
1489 // Test that ln(pow(A, B)) is simplified to ln(A)*B
TEST_F(AlgebraicSimplifierTest,LnPow)1490 TEST_F(AlgebraicSimplifierTest, LnPow) {
1491 auto m = CreateNewVerifiedModule();
1492 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1493 HloComputation::Builder builder(TestName());
1494 HloInstruction* param0 = builder.AddInstruction(
1495 HloInstruction::CreateParameter(0, r0f32, "param0"));
1496 HloInstruction* param1 = builder.AddInstruction(
1497 HloInstruction::CreateParameter(1, r0f32, "param1"));
1498 HloInstruction* pow = builder.AddInstruction(
1499 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, param1));
1500 builder.AddInstruction(
1501 HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow));
1502
1503 auto computation = m->AddEntryComputation(builder.Build());
1504
1505 EXPECT_THAT(computation->root_instruction(),
1506 GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1)))));
1507
1508 AlgebraicSimplifier simplifier(default_options_);
1509 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
1510
1511 EXPECT_THAT(
1512 computation->root_instruction(),
1513 GmockMatch(m::Select(
1514 m::Eq(m::Parameter(1), m::ConstantScalar(0.0f)),
1515 m::ConstantScalar(0.0f),
1516 m::Multiply(m::Log(m::Abs(m::Parameter(0))), m::Parameter(1)))));
1517 }
1518
TEST_F(AlgebraicSimplifierTest, LnSqrt) {
  // ln(sqrt(A)) is expected to become ln(A) * 0.5.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  auto* sqrt_of_input = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kSqrt, input));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kLog, sqrt_of_input));
  HloComputation* entry = hlo_module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Sqrt(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::ConstantScalar(0.5))));
}
1542
TEST_F(AlgebraicSimplifierTest, LnRsqrt) {
  // ln(rsqrt(A)) is expected to become ln(A) * -0.5.
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  auto* rsqrt_of_input = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kRsqrt, input));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kLog, rsqrt_of_input));
  HloComputation* entry = hlo_module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Rsqrt(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Log(m::Parameter(0)),
                                     m::ConstantScalar(-0.5))));
}
1566
1567 // Test that ln(exp(A)) is simplified to A
TEST_F(AlgebraicSimplifierTest, LnExp) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Build: log(exp(p0)).
  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  auto* exp_of_input = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kExp, input));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kLog, exp_of_input));
  HloComputation* entry = hlo_module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Exp(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // The log/exp pair cancels, leaving the parameter itself as the root.
  EXPECT_EQ(entry->root_instruction(), input);
}
1589
1590 // Test that ln(exp(A)/exp(B)) is simplified to A-B
TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Build: log(divide(exp(p0), exp(p1))).
  HloComputation::Builder builder(TestName());
  auto* numerator_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  auto* denominator_arg = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "param1"));
  auto* numerator = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kExp, numerator_arg));
  auto* denominator = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_f32, HloOpcode::kExp, denominator_arg));
  auto* quotient = builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kDivide, numerator, denominator));
  builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kLog, quotient));
  HloComputation* entry = hlo_module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)),
                                          m::Exp(m::Parameter(1))))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // ln(exp(A) / exp(B)) => A - B
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1))));
}
1620
1621 // Test that pow(A, 0) where A is a scalar is simplified to the scalar
1622 // constant 1.
TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // ASSERT rather than EXPECT: root->literal() below is only meaningful for a
  // constant, so a failed match must end the test here instead of going on to
  // read a literal from a non-constant instruction.
  ASSERT_THAT(root, GmockMatch(m::Constant()));
  EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
}
1646
1647 // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
  auto m = CreateNewVerifiedModule();
  Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // The checks below call dimensions(), operand(0) and literal(), which assume
  // a broadcast of a constant; stop the test if that shape is not present.
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
      << ShapeUtil::HumanString(root->shape());
  // A scalar operand is broadcast with no mapped dimensions. empty() avoids a
  // signed/unsigned comparison between size() and the literal 0.
  EXPECT_TRUE(root->dimensions().empty());
  EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
1675
1676 // Test that pow(A, 1) is simplified to A.
TEST_F(AlgebraicSimplifierTest, Pow1) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Build: power(p0, 1.0).
  HloComputation::Builder builder(TestName());
  auto* base = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  auto* exponent_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kPower, base, exponent_one));
  HloComputation* entry = hlo_module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent_one))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // The power disappears entirely; the base parameter becomes the root.
  EXPECT_EQ(entry->root_instruction(), base);
}
1698
1699 // Test that pow(A, 2) is simplified to A*A.
TEST_F(AlgebraicSimplifierTest, Pow2) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Build: power(p0, 2.0).
  HloComputation::Builder builder(TestName());
  auto* base = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  auto* exponent_two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kPower, base, exponent_two));
  HloComputation* entry = hlo_module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent_two))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // pow(A, 2) => A * A
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
1722
1723 // Test that pow(A, 3) is simplified to A*A*A.
TEST_F(AlgebraicSimplifierTest, Pow3) {
  auto hlo_module = CreateNewVerifiedModule();
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Build: power(p0, 3.0).
  HloComputation::Builder builder(TestName());
  auto* base = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param0"));
  auto* exponent_three = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_f32, HloOpcode::kPower, base, exponent_three));
  HloComputation* entry = hlo_module->AddEntryComputation(builder.Build());

  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Power(m::Parameter(0), m::Op().Is(exponent_three))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // pow(A, 3) => A * (A * A)
  const auto cube = m::Multiply(m::Parameter(0),
                                m::Multiply(m::Parameter(0), m::Parameter(0)));
  EXPECT_THAT(entry->root_instruction(), GmockMatch(cube));
}
1748
1749 // Test that pow(A, -1) is simplified to 1/A.
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
  auto m = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  HloInstruction* negative_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
                                                      param0, negative_one));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // ASSERT so that operand(0)->literal() below is only reached when the root
  // really is divide(constant, param0).
  ASSERT_THAT(root, GmockMatch(m::Divide(m::Constant(), m::Parameter(0))));
  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
}
1773
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  // The lhs feature dimension (dim 2) and the rhs input-feature dimension
  // (dim 1) are both zero-sized, so the convolution has nothing to accumulate.
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));

  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.set_input_feature_dimension(2);

  dnums.set_output_batch_dimension(0);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_output_feature_dimension(2);

  dnums.add_kernel_spatial_dimensions(0);
  dnums.set_kernel_input_feature_dimension(1);
  dnums.set_kernel_output_feature_dimension(2);
  Window window;
  WindowDimension* dim = window.add_dimensions();
  dim->set_size(3);
  dim->set_padding_low(0);
  dim->set_padding_high(0);
  dim->set_stride(1);
  dim->set_window_dilation(1);
  dim->set_base_dilation(1);
  dim->set_window_reversal(false);
  // Create the convolution. A size-3 window with no padding and stride 1 over
  // a spatial extent of 3 gives an output spatial extent of (3 - 3) / 1 + 1 = 1,
  // so the result is [batch=3, spatial=1, feature=3].
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {3, 1, 3}), lhs, rhs, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs))));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1816
TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "param"));

  // Window sizes are (i % 3) + 1 = 1, 2, 3, 1: the window spans all of the
  // first three dimensions and leaves the last one untouched.
  Window window;
  for (int64_t dim_index = 0; dim_index < 4; ++dim_index) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size((dim_index % 3) + 1);
    window_dim->set_stride(1);
    window_dim->set_padding_low(0);
    window_dim->set_padding_high(0);
    window_dim->set_window_dilation(1);
    window_dim->set_base_dilation(1);
  }

  // Scalar f32 addition used as the reducer.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* lhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* rhs = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, lhs, rhs));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }

  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), operand, init, window,
      add_computation));
  m->AddEntryComputation(builder.Build());

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  // The reduce-window becomes reduce + reshape.
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant()))));
}
1861
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
  auto m = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
  // Size-1 window, stride 1, padding 1 on each side: the [3, 0] operand pads
  // out to [5, 2], matching the declared output shape below. The stride is set
  // explicitly because an unset proto3 integer field reads as 0, which is not
  // a valid window stride (the other window-building tests also set it to 1).
  Window window;
  for (int64_t i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(1);
    dim->set_padding_low(1);
    dim->set_padding_high(1);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }
  // Create add computation. A distinct builder name avoids shadowing the
  // entry computation's builder.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      ShapeUtil::MakeShape(F32, {5, 2}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
      window, add_computation));
  m->AddEntryComputation(builder.Build());
  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
1903
TEST_F(AlgebraicSimplifierTest, ZeroSizedVariadicReduceWindow) {
  // Variadic (two-output) reduce-window over a zero-element operand; each
  // output is expected to collapse to a broadcast constant.
  // NOTE(review): the reducer adds p0+p1 and p2+p3. If the reducer parameters
  // are (acc0, acc1, x0, x1), the conventional pairing would be p0+p2 and
  // p1+p3. This looks irrelevant for a zero-sized input, but confirm.
  constexpr char kHloText[] = R"(
HloModule ZeroSizedVariadicReduceWindow

ZeroSizedVariadicReduceWindow.add {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  p2 = f32[] parameter(2)
  p3 = f32[] parameter(3)
  add.0 = f32[] add(p0, p1)
  add.1 = f32[] add(p2, p3)
  ROOT r = tuple(add.0, add.1)
}

ENTRY ZeroSizedReduceWindow {
  op = f32[3,0] parameter(0)
  constant = f32[] constant(0)
  ROOT reduce-window = (f32[5,2], f32[5,2]) reduce-window(op, op, constant, constant), window={size=1x1 pad=1_1x1_1}, to_apply=ZeroSizedVariadicReduceWindow.add
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  const auto root_of = [&]() {
    return hlo_module->entry_computation()->root_instruction();
  };

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  EXPECT_THAT(root_of(),
              GmockMatch(m::ReduceWindow(m::Parameter(0), m::Parameter(0),
                                         m::Constant(), m::Constant())));
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_THAT(root_of(), GmockMatch(m::Tuple(m::Broadcast(m::Constant()),
                                             m::Broadcast(m::Constant()))));
}
1935
TEST_F(AlgebraicSimplifierTest, NopMax) {
  // For each element type, maximum(x, c) and maximum(c, x) use c = the
  // smallest value of the type (-inf for floating point). Both orders should
  // simplify to the parameter x.
  constexpr char kHloText[] = R"(
HloModule test

ENTRY test {
  p_s8   = s8[] parameter(0)
  p_u8   = u8[] parameter(1)
  p_s16  = s16[] parameter(2)
  p_u16  = u16[] parameter(3)
  p_s32  = s32[] parameter(4)
  p_u32  = u32[] parameter(5)
  p_s64  = s64[] parameter(6)
  p_u64  = u64[] parameter(7)
  p_f16  = f16[] parameter(8)
  p_bf16 = bf16[] parameter(9)
  p_f32  = f32[] parameter(10)
  p_f64  = f64[] parameter(11)

  const_s8   = s8[] constant(-128)
  const_u8   = u8[] constant(0)
  const_s16  = s16[] constant(-32768)
  const_u16  = u16[] constant(0)
  const_s32  = s32[] constant(-2147483648)
  const_u32  = u32[] constant(0)
  const_s64  = s64[] constant(-9223372036854775808)
  const_u64  = u64[] constant(0)
  const_f16  = f16[] constant(-inf)
  const_bf16 = bf16[] constant(-inf)
  const_f32  = f32[] constant(-inf)
  const_f64  = f64[] constant(-inf)

  max_s8   = s8[] maximum(p_s8, const_s8)
  max_u8   = u8[] maximum(p_u8, const_u8)
  max_s16  = s16[] maximum(p_s16, const_s16)
  max_u16  = u16[] maximum(p_u16, const_u16)
  max_s32  = s32[] maximum(p_s32, const_s32)
  max_u32  = u32[] maximum(p_u32, const_u32)
  max_s64  = s64[] maximum(p_s64, const_s64)
  max_u64  = u64[] maximum(p_u64, const_u64)
  max_f16  = f16[] maximum(p_f16, const_f16)
  max_bf16 = bf16[] maximum(p_bf16, const_bf16)
  max_f32  = f32[] maximum(p_f32, const_f32)
  max_f64  = f64[] maximum(p_f64, const_f64)

  max2_s8   = s8[] maximum(const_s8, p_s8)
  max2_u8   = u8[] maximum(const_u8, p_u8)
  max2_s16  = s16[] maximum(const_s16, p_s16)
  max2_u16  = u16[] maximum(const_u16, p_u16)
  max2_s32  = s32[] maximum(const_s32, p_s32)
  max2_u32  = u32[] maximum(const_u32, p_u32)
  max2_s64  = s64[] maximum(const_s64, p_s64)
  max2_u64  = u64[] maximum(const_u64, p_u64)
  max2_f16  = f16[] maximum(const_f16, p_f16)
  max2_bf16 = bf16[] maximum(const_bf16, p_bf16)
  max2_f32  = f32[] maximum(const_f32, p_f32)
  max2_f64  = f64[] maximum(const_f64, p_f64)

  ROOT tuple = tuple(max_s8, max_u8, max_s16, max_u16, max_s32, max_u32,
                     max_s64, max_u64, max_f16, max_bf16, max_f32, max_f64,
                     max2_s8, max2_u8, max2_s16, max2_u16, max2_s32, max2_u32,
                     max2_s64, max2_u64, max2_f16, max2_bf16, max2_f32, max2_f64)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // The tuple is checked one operand at a time. A single
  // GmockMatch(m::Tuple(m::Parameter(0), m::Parameter(1), ...)) expands to a
  // template expression too complex for our MSVC toolchain to compile.
  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* root = nullptr;
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(
          m::Op(&root).WithOpcode(HloOpcode::kTuple).WithNumOperands(24)));
  for (int64_t operand_index = 0; operand_index < root->operand_count();
       ++operand_index) {
    SCOPED_TRACE(absl::StrCat("operand ", operand_index));
    const HloInstruction* operand = root->operand(operand_index);
    ASSERT_EQ(operand->opcode(), HloOpcode::kParameter);
    // Operands 0-11 and 12-23 map to parameters 0-11, in order.
    EXPECT_EQ(operand->parameter_number(), operand_index % 12);
  }
}
2020
TEST_F(AlgebraicSimplifierTest, NopMin) {
  // For each element type, minimum(x, c) and minimum(c, x) use c = the largest
  // value of the type (+inf for floating point). Both orders should simplify
  // to the parameter x.
  const char* const hlo_string = R"(
HloModule test

ENTRY test {
  p_s8 = s8[] parameter(0)
  p_u8 = u8[] parameter(1)
  p_s16 = s16[] parameter(2)
  p_u16 = u16[] parameter(3)
  p_s32 = s32[] parameter(4)
  p_u32 = u32[] parameter(5)
  p_s64 = s64[] parameter(6)
  p_u64 = u64[] parameter(7)
  p_f16 = f16[] parameter(8)
  p_bf16 = bf16[] parameter(9)
  p_f32 = f32[] parameter(10)
  p_f64 = f64[] parameter(11)

  const_s8 = s8[] constant(127)
  const_u8 = u8[] constant(255)
  const_s16 = s16[] constant(32767)
  const_u16 = u16[] constant(65535)
  const_s32 = s32[] constant(2147483647)
  const_u32 = u32[] constant(4294967295)
  const_s64 = s64[] constant(9223372036854775807)
  const_u64 = u64[] constant(18446744073709551615)
  const_f16 = f16[] constant(inf)
  const_bf16 = bf16[] constant(inf)
  const_f32 = f32[] constant(inf)
  const_f64 = f64[] constant(inf)

  min_s8 = s8[] minimum(p_s8, const_s8)
  min_u8 = u8[] minimum(p_u8, const_u8)
  min_s16 = s16[] minimum(p_s16, const_s16)
  min_u16 = u16[] minimum(p_u16, const_u16)
  min_s32 = s32[] minimum(p_s32, const_s32)
  min_u32 = u32[] minimum(p_u32, const_u32)
  min_s64 = s64[] minimum(p_s64, const_s64)
  min_u64 = u64[] minimum(p_u64, const_u64)
  min_f16 = f16[] minimum(p_f16, const_f16)
  min_bf16 = bf16[] minimum(p_bf16, const_bf16)
  min_f32 = f32[] minimum(p_f32, const_f32)
  min_f64 = f64[] minimum(p_f64, const_f64)

  min2_s8 = s8[] minimum(const_s8, p_s8)
  min2_u8 = u8[] minimum(const_u8, p_u8)
  min2_s16 = s16[] minimum(const_s16, p_s16)
  min2_u16 = u16[] minimum(const_u16, p_u16)
  min2_s32 = s32[] minimum(const_s32, p_s32)
  min2_u32 = u32[] minimum(const_u32, p_u32)
  min2_s64 = s64[] minimum(const_s64, p_s64)
  min2_u64 = u64[] minimum(const_u64, p_u64)
  min2_f16 = f16[] minimum(const_f16, p_f16)
  min2_bf16 = bf16[] minimum(const_bf16, p_bf16)
  min2_f32 = f32[] minimum(const_f32, p_f32)
  min2_f64 = f64[] minimum(const_f64, p_f64)

  ROOT tuple = tuple(min_s8, min_u8, min_s16, min_u16, min_s32, min_u32,
                     min_s64, min_u64, min_f16, min_bf16, min_f32, min_f64,
                     min2_s8, min2_u8, min2_s16, min2_u16, min2_s32, min2_u32,
                     min2_s64, min2_u64, min2_f16, min2_bf16, min2_f32, min2_f64)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // We can't write GmockMatch(m::Tuple(m::Parameter(0), m::Parameter(1), ...)
  // because this generates a template expression that's too complicated for our
  // MSVC to compile. :(
  // A single trace is enough; the module used to be dumped twice on failure.
  SCOPED_TRACE(m->ToString());
  const HloInstruction* root;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::Op(&root).WithOpcode(HloOpcode::kTuple).WithNumOperands(24)));
  for (int i = 0; i < root->operand_count(); i++) {
    SCOPED_TRACE(absl::StrCat("operand ", i));
    const HloInstruction* operand = root->operand(i);
    ASSERT_EQ(operand->opcode(), HloOpcode::kParameter);
    // Operands 0-11 and 12-23 both map to parameters 0-11.
    EXPECT_EQ(operand->parameter_number(), i % 12);
  }
}
2107
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindow_Add) {
  // A 1x1 add reduce-window touches one element per output, so it is expected
  // to become an elementwise add of the operand and the broadcast init value.
  constexpr char kHloText[] = R"(
HloModule test

add {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT add = f32[] add(p0, p1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(0)
  ROOT reduce-window = reduce-window(p, constant), window={size=1x1}, to_apply=add
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  const auto init = m::Broadcast(m::ConstantEffectiveScalar(0));
  EXPECT_THAT(root, GmockMatch(m::AddAnyOrder(m::Parameter(), init)));
}
2133
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindow_Min) {
  // A 1x1 min reduce-window is expected to become an elementwise minimum of
  // the operand and the broadcast init value (+inf).
  constexpr char kHloText[] = R"(
HloModule test

min {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT min = f32[] minimum(p0, p1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(inf)
  ROOT reduce-window = reduce-window(p, constant), window={size=1x1}, to_apply=min
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const float positive_inf = std::numeric_limits<float>::infinity();
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::MinimumAnyOrder(
                        m::Parameter(),
                        m::Broadcast(m::ConstantEffectiveScalar(positive_inf)))));
}
2160
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindow_Max) {
  // A 1x1 max reduce-window is expected to become an elementwise maximum of
  // the operand and the broadcast init value (-inf).
  constexpr char kHloText[] = R"(
HloModule test

max {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT max = f32[] maximum(p0, p1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(-inf)
  ROOT reduce-window = reduce-window(p, constant), window={size=1x1}, to_apply=max
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const float negative_inf = -std::numeric_limits<float>::infinity();
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::MaximumAnyOrder(
                        m::Parameter(),
                        m::Broadcast(m::ConstantEffectiveScalar(negative_inf)))));
}
2187
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindowWithPad) {
  // A padded 1x1 max reduce-window is expected to become pad(max(p, init),
  // init), with init = -inf used both inside the maximum and as pad value.
  constexpr char kHloText[] = R"(
HloModule test

max {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT max = f32[] maximum(p0, p1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(-inf)
  ROOT reduce-window = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4}, to_apply=max
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const float negative_inf = -std::numeric_limits<float>::infinity();
  const auto elementwise_max = m::MaximumAnyOrder(
      m::Parameter(), m::Broadcast(m::ConstantEffectiveScalar(negative_inf)));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::Pad(elementwise_max,
                                m::ConstantEffectiveScalar(negative_inf))));
}
2217
TEST_F(AlgebraicSimplifierTest, TrivialReduceWindowWithUnsupported) {
  // None of these 1x1 reduce-windows should be rewritten:
  //   a: non-unit stride, b: lhs (base) dilation, c: rhs (window) dilation,
  //   d: window reversal, e: a reducer that is not a recognized binary op.
  // NOTE(review): there is no explicit ROOT, so only `e` appears to be the
  // root; a-d are unused and are only exercised if the pass visits dead
  // instructions. Confirm this is intended.
  constexpr char kHloText[] = R"(
HloModule test

max {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT max = f32[] maximum(p0, p1)
}

unsupported_fn {
  p0 = f32[] parameter(0)
  ROOT p1 = f32[] parameter(1)
}

ENTRY test {
  p = f32[16,32] parameter(0)
  constant = f32[] constant(-inf)
  a = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4 stride=1x2}, to_apply=max
  b = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4 lhs_dilate=2x1}, to_apply=max
  c = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4 rhs_dilate=2x1}, to_apply=max
  d = reduce-window(p, constant), window={size=1x1 pad=1_2x3_4 rhs_reversal=1x1}, to_apply=max
  e = reduce-window(p, constant), window={size=1x1}, to_apply=unsupported_fn
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = RunHloPass(&simplifier, hlo_module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
2248
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
  // Pads an f32[3,0] operand by one element on both edges of both dimensions.
  // The operand contributes no elements, so after running to a fixed point
  // the root is expected to be a broadcast of a constant.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));

  PaddingConfig padding_config;
  for (int dim = 0; dim < 2; ++dim) {
    auto* dim_config = padding_config.add_dimensions();
    dim_config->set_edge_padding_low(1);
    dim_config->set_edge_padding_high(1);
    dim_config->set_interior_padding(0);
  }

  HloInstruction* pad_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {5, 2}), operand, pad_value, padding_config));
  module->AddEntryComputation(builder.Build());

  HloComputation* entry = module->entry_computation();
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Constant())));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
2275
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
  // reshape(broadcast(reshape(op))): the inner reshape flattens op to f32[6],
  // the broadcast only prepends a size-1 dimension, and the outer reshape
  // restores op's f32[3,2] shape. Run to a fixed point, the simplifier should
  // collapse the whole chain to op.
  auto m = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto op = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3, 2}), broadcast));

  auto computation = builder.Build();
  m->AddEntryComputation(std::move(computation));

  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op))))));

  HloPassFix<AlgebraicSimplifier> simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // The root should now be the parameter itself; compare pointers directly.
  EXPECT_EQ(m->entry_computation()->root_instruction(), op);
}
2301
2302 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  builder.AddInstruction(HloInstruction::CreateConvert(f32_scalar, constant));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Before: the root converts an f32 constant to f32, i.e. a no-op.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert(m::Op().Is(constant))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // After: the convert is gone and the constant is the root.
  EXPECT_EQ(entry->root_instruction(), constant);
}
2321
// Test that convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
// type $TYPE2 and convert(A, $TYPE1) is an upcast.
TEST_F(AlgebraicSimplifierTest, EliminateConvertPairUpCast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  const Shape f16_shape =
      ShapeUtil::MakeShapeWithLayout(F16, {1, 14, 14, 64}, {3, 2, 1, 0});
  const Shape f32_shape = ShapeUtil::ChangeElementType(f16_shape, F32);

  // Round trip f16 -> f32 -> f16, where the first convert widens.
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f16_shape, "param"));
  HloInstruction* widened =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, input));
  builder.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(widened->shape(), F16), widened));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert(m::Convert(m::Op().Is(input)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The round trip collapses back to the original parameter.
  EXPECT_EQ(entry->root_instruction(), input);
}
2347
// Test that convert(convert(A, $TYPE1), $TYPE2) is NOT simplified to A even if
// A is of type $TYPE2, since convert(A, $TYPE1) is a downcast.
TEST_F(AlgebraicSimplifierTest, DoNotEliminateConvertPairDownCast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  const Shape f32_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0});
  const Shape f16_shape = ShapeUtil::ChangeElementType(f32_shape, F16);

  // Round trip f32 -> f16 -> f32, where the first convert narrows, so the
  // pair is not an identity and must be kept.
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "param"));
  HloInstruction* narrowed =
      builder.AddInstruction(HloInstruction::CreateConvert(f16_shape, input));
  builder.AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(narrowed->shape(), F32), narrowed));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  auto convert_pair = m::Convert(m::Convert(m::Op().Is(input)));
  EXPECT_THAT(entry->root_instruction(), GmockMatch(convert_pair));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), GmockMatch(convert_pair));
}
2372
// Test that Tuple(convert(A, $TYPE1), convert(convert(A, $TYPE1), $TYPE2),
// floor(convert(convert(A, $TYPE1), $TYPE2))) is simplified to
// Tuple(convert(A, $TYPE1), A, floor(A)) when A is of type $TYPE2, showing a
// case where the first convert has a fan-out.
TEST_F(AlgebraicSimplifierTest, EliminateConvertPairMultiOut) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F16, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  // f16 -> f32 upcast, then back to f16.
  HloInstruction* convert_1 =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(input->shape(), F32), input));
  HloInstruction* convert_2 =
      builder.AddInstruction(HloInstruction::CreateConvert(
          ShapeUtil::ChangeElementType(convert_1->shape(), F16), convert_1));

  HloInstruction* floor = builder.AddInstruction(HloInstruction::CreateUnary(
      convert_2->shape(), HloOpcode::kFloor, convert_2));

  // Collect both converts and the floor into a tuple so none of them is dead.
  // The tuple also uses convert_1 directly, so convert_1 has two users
  // (convert_2 and the tuple).
  builder.AddInstruction(
      HloInstruction::CreateTuple({convert_1, convert_2, floor}));

  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(convert_1), m::Op().Is(convert_2),
                                  m::Op().Is(floor))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // convert_1 remains as a tuple operand. Every use of the round-trip
  // convert_2, including floor's operand, is replaced by input.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(convert_1), m::Op().Is(input),
                                  m::Floor(m::Op().Is(input)))));
}
2410
2411 // Test that copies are removed.
TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  // A copy whose result shape is exactly its operand's shape.
  builder.AddInstruction(
      HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The copy is dropped and the parameter becomes the root.
  EXPECT_EQ(entry->root_instruction(), param0);
}
2431
TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  // Chain: param -> copy (new layout) -> reshape to [196,64] -> copy (new
  // layout), which becomes the root.
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  const Shape relaid_out_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3});
  HloInstruction* inner_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(relaid_out_shape, HloOpcode::kCopy, param));
  HloInstruction* flattened =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1}),
          inner_copy));
  builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}),
      HloOpcode::kCopy, flattened));
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0))))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The whole copy/reshape/copy chain is expected to become one bitcast.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2460
TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) {
  // A layout-changing copy followed by a reshape should, when the simplifier
  // is layout-sensitive, be replaced by a single bitcast of the parameter.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}),
          "param"));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
      HloOpcode::kCopy, param));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), copy));

  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  // Verify that the reshape of copy is replaced by a bitcast of the parameter.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2486
TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) {
  // A layout-changing copy is replaced by a bitcast only when the options'
  // bitcast callback allows it. The first run uses a callback that always
  // returns false; the second uses default options.
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}),
          "param"));
  builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3}),
      HloOpcode::kCopy, param));
  auto computation = m->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier1(options);
  ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie());
  // Verify that the copy is not replaced.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  AlgebraicSimplifierOptions options2;
  options2.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier2(options2);
  // ASSERT (not EXPECT): if nothing changed, the bitcast check below is moot.
  ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie());
  // Verify that the copy is replaced.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
2518
2519 // Test that unary concatenates are removed.
TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
  auto module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  // Concatenating a single operand is an identity.
  builder.AddInstruction(HloInstruction::CreateConcatenate(
      param0->shape(), {param0}, /*dimension=*/0));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Concatenate(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_EQ(entry->root_instruction(), param0);
}
2539
TEST_F(AlgebraicSimplifierTest, SliceReverse) {
  const char* const hlo_string = R"(
HloModule module

ENTRY test {
  param = f32[6,7,32] parameter(0)
  constant = f32[] constant(0)
  pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0
  rev = f32[8,7,32] reverse(pad), dimensions={0,2}
  slice = f32[1,7,32] slice(rev), slice={[2:3:1], [0:7:1], [0:32:1]}
  ROOT tuple = (f32[1,7,32]) tuple(slice)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // slice(reverse(x)) is expected to be rewritten as reverse(slice(x)).
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad())))));
  const HloInstruction* slice = root->operand(0)->operand(0);
  EXPECT_TRUE(
      ShapeUtil::Equal(slice->shape(), ShapeUtil::MakeShape(F32, {1, 7, 32})));
  // Bounds in the reversed dimensions 0 and 2 are mirrored; for example,
  // [2,3) in an extent of 8 becomes [5,6). Dimension 1 is not reversed and
  // keeps its bounds. Strides are unchanged.
  EXPECT_THAT(slice->slice_starts(), ::testing::ElementsAre(5, 0, 0));
  EXPECT_THAT(slice->slice_limits(), ::testing::ElementsAre(6, 7, 32));
  EXPECT_THAT(slice->slice_strides(), ::testing::ElementsAre(1, 1, 1));
}
2577
TEST_F(AlgebraicSimplifierTest, SliceReverseNonUnitEvenOddStrides) {
  const char* const hlo_string = R"(
HloModule module

ENTRY test {
  param = f32[6,7,32] parameter(0)
  constant = f32[] constant(0)
  pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0
  rev = f32[8,7,32] reverse(pad), dimensions={0,1,2}
  slice = f32[1,2,7] slice(rev), slice={[2:3:2], [0:7:4], [0:32:5]}
  ROOT tuple = (f32[1,2,7]) tuple(slice)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Every dimension is reversed, so every slice start/limit is remapped,
  // while the (even and odd) strides are kept.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad())))));
  const HloInstruction* slice = root->operand(0)->operand(0);
  EXPECT_TRUE(
      ShapeUtil::Equal(slice->shape(), ShapeUtil::MakeShape(F32, {1, 2, 7})));
  EXPECT_THAT(slice->slice_starts(), ::testing::ElementsAre(5, 2, 1));
  EXPECT_THAT(slice->slice_limits(), ::testing::ElementsAre(6, 7, 32));
  EXPECT_THAT(slice->slice_strides(), ::testing::ElementsAre(2, 4, 5));
}
2613
2614 // Test that empty operands of concatenates are removed.
TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
  auto module = CreateNewVerifiedModule();
  const int kParamLength = 100;
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1f32, "param1"));
  // Two zero-element operands: an empty constant and an empty slice of param1.
  HloInstruction* empty_literal = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
  HloInstruction* empty_slice =
      builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
  builder.AddInstruction(HloInstruction::CreateConcatenate(
      ShapeUtil::MakeShape(F32, {3 * kParamLength}),
      {empty_literal, param0, param0, empty_slice, param1}, 0));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Concatenate(
                  m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0),
                  m::Op().Is(empty_slice), m::Parameter(1))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Only the non-empty operands remain, in their original order.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0),
                                        m::Parameter(1))));
}
2647
2648 // Test that reduce of concat is simplified.
TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) {
  // reduce(concat(p0, p1, p2), dims={1,2}) with an add reducer, where the
  // concatenation dimension (1) is one of the reduced dimensions, is expected
  // to become map(map(reduce(p0), reduce(p1)), reduce(p2)).
  auto m = CreateNewVerifiedModule();
  const int kParamLength = 100;
  Shape r3f32 =
      ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r3f32, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r3f32, "param1"));
  HloInstruction* param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, r3f32, "param2"));
  Shape concat_shape =
      ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength});
  HloInstruction* concatenate =
      builder.AddInstruction(HloInstruction::CreateConcatenate(
          concat_shape, {param0, param1, param2}, 1));

  // Scalar f32 add used as the reducer. It uses a separately named builder so
  // it does not shadow the entry computation's builder.
  HloComputation* add_computation = nullptr;
  {
    HloComputation::Builder add_builder(TestName() + ".add");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloInstruction* p0 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    HloInstruction* p1 = add_builder.AddInstruction(
        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    add_builder.AddInstruction(
        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    add_computation = m->AddEmbeddedComputation(add_builder.Build());
  }
  Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength});

  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  builder.AddInstruction(HloInstruction::CreateReduce(
      reduce_shape, concatenate, zero, {1, 2}, add_computation));

  auto computation = m->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  EXPECT_THAT(
      computation->root_instruction(),
      GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)),
                               m::Reduce(m::Parameter(1), m::Op().Is(zero))),
                        m::Reduce(m::Parameter(2), m::Op().Is(zero)))));
}
2697
// Test that a concatenate with only empty operands is removed.
TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
  auto module = CreateNewVerifiedModule();
  const int kParamLength = 100;
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
  const Shape empty_r1f32 = ShapeUtil::MakeShape(F32, {0});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* empty_literal = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
  HloInstruction* empty_slice = builder.AddInstruction(
      HloInstruction::CreateSlice(empty_r1f32, param0, {42}, {42}, {1}));
  // Both operands have zero elements, and so does the result.
  builder.AddInstruction(HloInstruction::CreateConcatenate(
      empty_r1f32, {empty_literal, empty_slice}, 0));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Concatenate(m::Op().Is(empty_literal),
                                        m::Op().Is(empty_slice))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The concatenate is gone and the empty constant is the root.
  EXPECT_EQ(entry->root_instruction(), empty_literal);
}
2726
2727 // Test that concat with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
  auto module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "param1"));
  // The first 100 output elements are the scalar param1; the last 100 are
  // param0.
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r1f32, param1, {}));
  builder.AddInstruction(HloInstruction::CreateConcatenate(
      ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expressed as param0 padded with param1 as the padding value.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1))));
}
2749
TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
  // Adjacent slices of the same operand with unit strides and contiguous
  // bounds along the concatenation dimension (1) should be merged into a
  // single slice. Every other adjacent pair must be left alone.
  auto m = CreateNewVerifiedModule();
  Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99});
  Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 90});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r2f32, "param1"));

  HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0},
      /*limit_indices=*/{50, 10}, /*strides=*/{1, 1}));

  // Cannot merge 'slice0' and 'slice1' because of different start indices in
  // dimension 0.
  HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10},
      /*limit_indices=*/{100, 20}, /*strides=*/{1, 1}));

  // Cannot merge 'slice1' and 'slice2' because of the stride in dimension 1.
  HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20},
      /*limit_indices=*/{100, 40}, /*strides=*/{1, 2}));

  // Cannot merge 'slice2' and 'slice3' because of the stride in dimension 1.
  HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40},
      /*limit_indices=*/{100, 50}, /*strides=*/{1, 1}));

  // Can merge 'slice3' and 'slice4'.
  HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50},
      /*limit_indices=*/{100, 60}, /*strides=*/{1, 1}));

  // Can merge 'slice4' and 'slice5'.
  HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60},
      /*limit_indices=*/{100, 70}, /*strides=*/{1, 1}));

  // Cannot merge 'slice5' and 'slice6' because of overlap.
  HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69},
      /*limit_indices=*/{100, 79}, /*strides=*/{1, 1}));

  // Cannot merge 'slice6' and 'slice7' because of slicing from a different
  // parameter.
  HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79},
      /*limit_indices=*/{100, 89}, /*strides=*/{1, 1}));
  // Can merge 'slice7' and 'slice8'.
  HloInstruction* slice8 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 89},
      /*limit_indices=*/{100, 99}, /*strides=*/{1, 1}));

  builder.AddInstruction(HloInstruction::CreateConcatenate(
      concat_shape,
      {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8},
      1));
  auto computation = m->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  auto s = m::Slice(m::Parameter(0));
  EXPECT_THAT(
      computation->root_instruction(),
      GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1)))));
  // Operand 3 should be the merge of 'slice3', 'slice4' and 'slice5', so its
  // shape should be {50, 30} and it should start where 'slice3' starts.
  EXPECT_TRUE(
      ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(),
                       ShapeUtil::MakeShape(F32, {50, 30})));
  EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);

  // Operand 5 should be the merge of 'slice7' and 'slice8', so its shape
  // should be {50, 20} and it should start where 'slice7' starts.
  EXPECT_TRUE(
      ShapeUtil::Equal(computation->root_instruction()->operand(5)->shape(),
                       ShapeUtil::MakeShape(F32, {50, 20})));
  EXPECT_EQ(computation->root_instruction()->operand(5)->slice_starts(1), 79);
}
2830
2831 // Test that a simplification which changes layouts is not performed if layout
2832 // sensitive is true.
TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Give the copy a layout that differs from its operand's layout.
  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);

  // The layout-changing copy must survive.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));
}
2860
2861 // Test that a simplification which preserves layouts is performed if layout
2862 // sensitive is true.
TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Give the copy and its operand the same layout.
  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The layout-preserving copy is removed.
  EXPECT_EQ(entry->root_instruction(), param0);
}
2889
// Test that a reshape which could be replaced with a bitcast is not replaced
// when the options' bitcast callback rejects every bitcast.
TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
  // Reshape that only inserts size-1 dimensions.
  HloInstruction* reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));

  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *reshape->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});

  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));

  // Options callback that returns false for every (Shape, Shape) pair.
  auto reject_all_bitcasts = [](const Shape&, const Shape&) { return false; };
  AlgebraicSimplifierOptions options(reject_all_bitcasts);
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);

  // The reshape is kept rather than turned into a bitcast.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));
}
2921
2922 // Test transforming reshapes and transposes of rng.
TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Graph: rng(0, 1) of f32[2,2] -> transpose {1,0} -> reshape to f32[4].
  HloInstruction* lower_bound = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* upper_bound = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  HloInstruction* rng = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {2, 2}), RandomDistribution::RNG_UNIFORM,
      {lower_bound, upper_bound}));
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(rng->shape(), rng, {1, 0}));
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), transpose));
  // Copy the shape now; the reshape instruction may not survive the pass.
  const Shape reshape_shape = reshape->shape();

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // reshape(transpose(rng)) is replaced by a single rng that has the same
  // shape as the reshape.
  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Rng()));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
}
2952
2953 // Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  // Reshape which can be transformed into a bitcast: it only inserts
  // size-1 dimensions around the operand's dimensions.
  HloInstruction* transformable_reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
  *transformable_reshape->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});

  // Reshape does not just add degenerate dimensions (it merges 2x2 into 4).
  HloInstruction* dimensions_wrong_reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
  *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});

  // Same dimensions as the transformable reshape, but with a layout the test
  // expects to prevent the bitcast.
  HloInstruction* layout_wrong_reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
  *layout_wrong_reshape->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});

  // Collect all the reshapes into a tuple so they are not dead.
  builder.AddInstruction(HloInstruction::CreateTuple(
      {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(transformable_reshape),
                                  m::Op().Is(dimensions_wrong_reshape),
                                  m::Op().Is(layout_wrong_reshape))));

  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  // The first reshape is rewritten, so the pass must report a change.
  // Previously the returned flag was silently discarded.
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Verify that only the first reshape is replaced.
  EXPECT_THAT(
      computation->root_instruction(),
      GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape),
                          m::Op().Is(layout_wrong_reshape))));
}
3005
3006 // Regression test for a bug where if we failed to sink a reshape, we'd set the
3007 // 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // param0 + zeros is simplifiable, so the pass has a change to report even
  // if the reshape consuming the add is left in place.
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  HloInstruction* zeros = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, zeros));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));

  module->AddEntryComputation(builder.Build());

  // The 'changed' result must stay true.
  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
3028
// Regression test for a bug where if we failed to sink a broadcast, we'd set
// the 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // param0 + zeros is simplifiable, so the pass has a change to report even
  // if the broadcast consuming the add is left in place.
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  HloInstruction* zeros = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, zeros));
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      broadcast_shape, add, /*broadcast_dimensions=*/{0, 1}));

  module->AddEntryComputation(builder.Build());

  // The 'changed' result must stay true.
  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
3052
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Operand f32[50,14,14,64] with minor-to-major {1,2,0,3}, transposed by
  // {1,2,0,3} into f32[14,14,50,64] with minor-to-major {0,1,2,3}. Both walk
  // memory in the same order, so the transpose is expected to be a bitcast.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {50, 14, 14, 64});
  const Shape transposed_shape = ShapeUtil::MakeShape(F32, {14, 14, 50, 64});

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  *operand->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 0, 3});

  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(transposed_shape, operand, {1, 2, 0, 3}));
  *transpose->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1, 2, 3});

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The transpose has become a bitcast of the parameter.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
3082
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Operand f32[5,2,3,4] with minor-to-major {1,2,3,0}, transposed by
  // {0,2,3,1} into f32[5,3,4,2] with minor-to-major {3,1,2,0}. Both walk
  // memory in the same order, so the transpose is expected to be a bitcast.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 4});
  const Shape transposed_shape = ShapeUtil::MakeShape(F32, {5, 3, 4, 2});

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  *operand->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 2, 3, 0});

  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(transposed_shape, operand, {0, 2, 3, 1}));
  *transpose->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({3, 1, 2, 0});

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The transpose has become a bitcast of the parameter.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Bitcast(m::Parameter(0))));
}
3112
TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // f32[2,2] -> f32[2,1,2] -> f32[1,2,1,1,2,1]: two chained reshapes.
  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
  HloInstruction* inner_reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {2, 1, 2}), input));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), inner_reshape));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The two reshapes collapse into one that reads the parameter directly.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));
}
3138
TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Three layouts of the same f32[2,2,2] array: the parameter's descending
  // layout, then two copies that relayout to {0,1,2} and {0,2,1}.
  const Shape descending_shape =
      ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2});
  const Shape first_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2});
  const Shape second_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, descending_shape, "param0"));
  HloInstruction* first_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(first_copy_shape, HloOpcode::kCopy, input));
  builder.AddInstruction(HloInstruction::CreateUnary(
      second_copy_shape, HloOpcode::kCopy, first_copy));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Copy(m::Parameter(0)))));

  AlgebraicSimplifierOptions layout_sensitive;
  layout_sensitive.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The copy chain collapses into a single copy of the parameter.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));
}
3168
TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));

  // transpose {1,2,0} followed by transpose {1,0,2}; composed, these are the
  // single permutation {2,1,0} of the parameter.
  HloInstruction* transpose1 =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0}));

  builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Op().Is(transpose1))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal match: dimensions() is only meaningful on opcodes that carry
  // dimensions. If the root is not a transpose, report a failure here
  // instead of reaching that call.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Transpose(m::Parameter(0))));
  EXPECT_THAT(computation->root_instruction()->dimensions(),
              ::testing::ElementsAre(2, 1, 0));
}
3196
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcast) {
  // slice(broadcast(p0)) is expected to be rewritten into
  // broadcast(slice(p0)).
  const char* const module_text = R"(
HloModule module

ENTRY test {
  p0 = f32[10,20] parameter(0)
  b = f32[10,30,20] broadcast(p0), dimensions={0,2}
  ROOT s = f32[5,5,5] slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
}
3215
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastPreserveLayout) {
  // Same rewrite as SliceOfBroadcast, but the tiled layout on the original
  // slice must carry over to the new root.
  const char* const module_text = R"(
HloModule module

ENTRY test {
  p0 = f32[10,20] parameter(0)
  b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
  ROOT s = f32[5,5,5]{2,0,1:T(256)} slice(b), slice={[0:5:1], [5:25:4], [5:15:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  HloComputation* entry = module->entry_computation();

  // Snapshot the shape (including layout) before the root is replaced.
  const Shape original_slice_shape = entry->root_instruction()->shape();

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Slice(m::Parameter(0)))));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_slice_shape));
}
3237
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcast) {
  // dynamic-slice(broadcast(p0)) is expected to become
  // broadcast(dynamic-slice(p0)). Only the start indices for the broadcast's
  // operand dimensions ({0,2} -> i0 and i2) should remain.
  const char* const module_text = R"(
HloModule module

ENTRY test {
  p0 = f32[10,20] parameter(0)
  i0 = s32[] parameter(1)
  i1 = s32[] parameter(2)
  i2 = s32[] parameter(3)
  b = f32[10,30,20] broadcast(p0), dimensions={0,2}
  ROOT ds = f32[5,5,5] dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::DynamicSlice(
                        m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
}
3260
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfBroadcastPreserveLayout) {
  // Same rewrite as DynamicSliceOfBroadcast, but the tiled layout on the
  // original dynamic-slice must carry over to the new root.
  const char* const module_text = R"(
HloModule module

ENTRY test {
  p0 = f32[10,20] parameter(0)
  i0 = s32[] parameter(1)
  i1 = s32[] parameter(2)
  i2 = s32[] parameter(3)
  b = f32[10,30,20]{2,0,1:T(256)} broadcast(p0), dimensions={0,2}
  ROOT ds = f32[5,5,5]{2,0,1:T(256)} dynamic-slice(b, i0, i1, i2), dynamic_slice_sizes={5,5,5}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  HloComputation* entry = module->entry_computation();

  // Snapshot the shape (including layout) before the root is replaced.
  const Shape original_dynslice_shape = entry->root_instruction()->shape();

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::DynamicSlice(
                        m::Parameter(0), m::Parameter(1), m::Parameter(3)))));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), original_dynslice_shape));
}
3286
TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) {
  // reshape -> transpose -> reshape that round-trips f32[10] back to f32[10].
  // Run to a fixed point, the whole chain should fold away to the parameter.
  const char* const module_text = R"(
HloModule module

ENTRY test {
  param = f32[10] parameter(0)
  reshaped = f32[1,1,10] reshape(f32[10] param)
  transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0}
  ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));

  HloPassFix<AlgebraicSimplifier> fixed_point_simplifier(default_options_);
  EXPECT_TRUE(fixed_point_simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter()));
}
3306
3307 // Test merging reshape and broadcast.
TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // f32[5] -reshape-> f32[1,5,1] -broadcast{0,3,2}-> f32[1,2,3,5,1].
  const Shape input_shape = ShapeUtil::MakeShape(F32, {5});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {1, 5, 1});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param0"));
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(reshaped_shape, input));
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, reshape, {0, 3, 2}));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The reshape is folded into the broadcast.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
3329
3330 // Test merging broadcast and reshape.
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // f32[2,3] -broadcast{1,2}-> f32[1,2,3,7,12,1] -reshape->
  // f32[2,3,7,2,1,3,2].
  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param0"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, input, {1, 2}));
  builder.AddInstruction(
      HloInstruction::CreateReshape(reshaped_shape, broadcast));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The reshape is absorbed: a single broadcast of the parameter remains.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
3352
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // f32[1] -broadcast{1}-> f32[3,1] -reshape-> f32[3].
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {1}),
                                      "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 1}), input, {1}));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The reshape is moved onto the parameter, below the broadcast.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
3374
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  // f32[4] -broadcast{2}-> f32[3,2,4] -reshape-> f32[6,1,1,4].
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {4}), "param"));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal match: dimensions() is only meaningful on the matched broadcast.
  // A mismatch reports a failure here instead of reaching that call.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
  // The parameter's only dimension now maps to output dimension 3.
  EXPECT_THAT(computation->root_instruction()->dimensions(),
              ::testing::ElementsAre(3));
}
3398
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  // f32[1] -broadcast{2}-> f32[3,2,1] -reshape-> f32[6,1,1,1].
  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal match: dimensions() is only meaningful on the matched broadcast.
  ASSERT_THAT(computation->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
  // The test expects no operand dimensions to be mapped by the broadcast.
  // empty() avoids the signed/unsigned comparison of EXPECT_EQ(0, size()).
  EXPECT_TRUE(computation->root_instruction()->dimensions().empty());
}
3421
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // f32[4] -broadcast{2}-> f32[3,2,4,2] -reshape-> f32[6,8]. The reshape
  // merges the broadcast dimension with a neighbour, so no rewrite is expected.
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}),
                                      "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), input, {2}));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 8}), broadcast));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // Graph unchanged.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0)))));
}
3443
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // iota(f32[1,2,3,7,12,1], dim 2) -reshape-> f32[2,3,7,2,1,3,2].
  const Shape iota_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2});
  HloInstruction* iota =
      builder.AddInstruction(HloInstruction::CreateIota(iota_shape, 2));
  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The reshape folds into a single iota of the reshaped shape.
  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}
3464
TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadix) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // iota(f32[21], dim 0) reshaped to f32[7,3] spreads the counter over two
  // output dimensions.
  const Shape result_shape = ShapeUtil::MakeShape(F32, {7, 3});
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {21}), 0));
  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected form: iota + iota * broadcast(scalar constant).
  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Add(
                        m::Iota(), m::Multiply(m::Iota(), m::Broadcast(
                                                              m::ConstantScalar())))));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}

TEST_F(AlgebraicSimplifierTest, IotaAndReshapeToMixedRadixExtraDims) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // iota(f32[42,24,15], dim 1) reshaped to f32[3,14,4,3,2,5,3]; the iota
  // dimension (24) splits across several output dimensions.
  const Shape result_shape = ShapeUtil::MakeShape(F32, {3, 14, 4, 3, 2, 5, 3});
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {42, 24, 15}), 1));
  builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected form: (iota + iota * bcast(c)) + iota * bcast(c).
  auto scaled_iota = [] {
    return m::Multiply(m::Iota(), m::Broadcast(m::ConstantScalar()));
  };
  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Add(m::Add(m::Iota(), scaled_iota()),
                                      scaled_iota())));
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}

TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  // iota over f32[1,1] holds a single element.
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0));
  // Copy the shape now; the iota instruction may not survive the pass.
  const Shape result_shape = iota->shape();

  auto computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota()));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = computation->root_instruction();
  // Fatal match: the next line dereferences operand(0) and calls literal(),
  // which require root to be broadcast(constant). A mismatch must stop the
  // test instead of crashing it.
  ASSERT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
  EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement<float>());
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), result_shape));
}
3533
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // iota(f32[3,2], dim 1) flattened to f32[6]: no rewrite is expected.
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1));
  const Shape flat_shape = ShapeUtil::MakeShape(F32, {6});
  builder.AddInstruction(HloInstruction::CreateReshape(flat_shape, iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // Graph unchanged.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));
}
3553
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  // iota(f32[3,2,4], dim 2) -reshape-> f32[6,1,1,4].
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal match: Cast<HloIotaInstruction> below requires the root to be an
  // iota. A mismatch must fail the test, not abort the binary.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Iota()));
  // The size-4 iota dimension lands on output dimension 3.
  EXPECT_EQ(Cast<HloIotaInstruction>(computation->root_instruction())
                ->iota_dimension(),
            3);
}
3575
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) {
  auto m = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  // iota(f32[3,2,2], dim 2) -reshape-> f32[6,1,1,2]. Element [i,0,0,l] of
  // the reshape is l, so only iota_dimension 3 reproduces it. Dimensions 1
  // and 2 have size 1, and an iota along them would be all zeros.
  auto iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2));
  builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota));

  HloComputation* computation = m->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  // Fatal match: Cast<HloIotaInstruction> below requires an iota root.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Iota()));
  const int64_t iota_dim =
      Cast<HloIotaInstruction>(computation->root_instruction())
          ->iota_dimension();
  // AnyOf(1, 2, 3) would accept semantically wrong rewrites; pin to 3.
  EXPECT_EQ(iota_dim, 3);
}
3598
TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // iota(f32[3,2,4,2], dim 2) -reshape-> f32[6,8]: the iota dimension is
  // merged with its neighbour, so no rewrite is expected.
  HloInstruction* iota = builder.AddInstruction(
      HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2));
  const Shape merged_shape = ShapeUtil::MakeShape(F32, {6, 8});
  builder.AddInstruction(HloInstruction::CreateReshape(merged_shape, iota));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));

  AlgebraicSimplifier simplifier(default_options_);
  EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());

  // Graph unchanged.
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Reshape(m::Iota())));
}
3618
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Zero low, high and interior padding on both dimensions: a no-op pad.
  PaddingConfig zero_padding;
  for (int dim = 0; dim != 2; ++dim) {
    PaddingConfig::PaddingConfigDimension* dim_config =
        zero_padding.add_dimensions();
    dim_config->set_edge_padding_low(0);
    dim_config->set_edge_padding_high(0);
    dim_config->set_interior_padding(0);
  }
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {2, 2}), param, zero, zero_padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The pad disappears and the parameter becomes the root.
  EXPECT_EQ(computation->root_instruction(), param);
}
3647
TEST_F(AlgebraicSimplifierTest, RemoveNoopSliceOfPad) {
  // pad(f32[2,2]) with low=2, high=0, interior=1 in each dimension yields
  // f32[5,5]. The operand's elements land at indices 2 and 4 of every
  // dimension. slice([2:5:2], [2:5:2]) therefore selects exactly the operand,
  // so slice(pad(param)) should simplify to param.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // This padding is not a no-op on its own; only the pad+slice pair is.
  PaddingConfig padding;
  for (int i = 0; i < 2; ++i) {
    auto* dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(2);
    dimension->set_edge_padding_high(0);
    dimension->set_interior_padding(1);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {5, 5}), param, zero, padding));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {2, 2}), pad, /*start_indices=*/{2, 2},
      /*limit_indices=*/{5, 5}, /*strides=*/{2, 2}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), param);
}
3679
TEST_F(AlgebraicSimplifierTest, NegativePadding) {
  // A pad with negative edge padding should be rewritten into a pad that only
  // uses non-negative padding, followed by a slice.
  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // (low, high) edge padding for dimensions 0 and 1; no interior padding.
  const std::pair<int64_t, int64_t> edges[] = {{-1, 2}, {-2, -3}};
  PaddingConfig padding;
  for (const auto& [low, high] : edges) {
    auto* d = padding.add_dimensions();
    d->set_edge_padding_low(low);
    d->set_edge_padding_high(high);
    d->set_interior_padding(0);
  }
  HloInstruction* pad = b.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);

  // True if any dimension of `instr`'s padding config has a negative edge.
  auto has_negative_padding = [](const HloInstruction* instr) {
    bool negative = false;
    for (const auto& d : instr->padding_config().dimensions()) {
      negative =
          negative || d.edge_padding_low() < 0 || d.edge_padding_high() < 0;
    }
    return negative;
  };

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_TRUE(has_negative_padding(pad));

  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The root is now slice(pad(...)), and the inner pad has no negative edges.
  const HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero)))));
  EXPECT_FALSE(has_negative_padding(root->operand(0)));
}
3727
TEST_F(AlgebraicSimplifierTest, CanDisableBroadcastSinking) {
  // negate(broadcast(scalar)) is a broadcast that could be sunk (delayed).
  // With the option turned off, the simplifier must report no change.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {512, 16});

  HloComputation::Builder b(TestName());
  HloInstruction* scalar = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "scalar"));
  HloInstruction* bcast = b.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, scalar, {}));
  b.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kNegate, bcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Negate(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifierOptions options = default_options_;
  options.set_enable_sink_broadcast(false);
  AlgebraicSimplifier simplifier(options);

  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
3755
TEST_F(AlgebraicSimplifierTest, CanDisableNegativePadding) {
  // Verify that when negative-padding replacement is disabled, a pad with
  // negative padding is left untouched. It must not be rewritten into a
  // non-negative pad followed by a slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  PaddingConfig padding;
  int64_t low_padding[2] = {-1, -2};
  int64_t high_padding[2] = {2, -3};
  for (int i = 0; i < 2; ++i) {
    auto* dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(low_padding[i]);
    dimension->set_edge_padding_high(high_padding[i]);
    dimension->set_interior_padding(0);
  }
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Verify that we can disable the negative padding optimization.
  AlgebraicSimplifierOptions opts = default_options_;
  opts.set_enable_negative_padding_replacement(false);

  AlgebraicSimplifier simplifier(opts);

  // True if any dimension of `instr`'s padding config has a negative edge.
  auto has_negative_padding = [](const HloInstruction* instr) {
    for (const auto& padding_dimension : instr->padding_config().dimensions()) {
      if (padding_dimension.edge_padding_low() < 0 ||
          padding_dimension.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_TRUE(has_negative_padding(pad));

  // Nothing has changed since the negative padding replacement is disabled.
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  // The original pad, with its negative padding, must still be the root.
  EXPECT_EQ(computation->root_instruction(), pad);
  EXPECT_TRUE(has_negative_padding(computation->root_instruction()));
}
3803
TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) {
  // Interior padding on a dimension of extent 1 has no element pairs to go
  // between. The simplifier should drop that interior padding and keep the
  // pad itself.
  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 1}), "param"));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Edge padding of 3 on both sides of each dimension. Interior padding is
  // applied only to dimension 1, whose extent is 1.
  const int64_t interior_per_dim[] = {0, 3};
  PaddingConfig config;
  for (const int64_t interior : interior_per_dim) {
    auto* d = config.add_dimensions();
    d->set_edge_padding_low(3);
    d->set_edge_padding_high(3);
    d->set_interior_padding(interior);
  }
  HloInstruction* pad = b.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {8, 7}), param, zero, config));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  AlgebraicSimplifier simplifier(default_options_);

  ASSERT_THAT(entry->root_instruction(),
              GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  ASSERT_TRUE(HasInteriorPadding(pad->padding_config()));

  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const HloInstruction* new_root = entry->root_instruction();
  EXPECT_THAT(new_root, GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero))));
  EXPECT_FALSE(HasInteriorPadding(new_root->padding_config()));
}
3839
TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
  // A reshape to the operand's own shape is the identity and should be
  // replaced by the operand.
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloComputation::Builder b(TestName());
  HloInstruction* param =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param"));
  b.AddInstruction(HloInstruction::CreateReshape(shape, param));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), param);
}
3859
TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
  // A unit-stride slice covering the whole operand is the identity.
  constexpr int64_t kRows = 2;
  constexpr int64_t kCols = 3;
  const Shape full_shape = ShapeUtil::MakeShape(F32, {kRows, kCols});

  HloComputation::Builder b(TestName());
  HloInstruction* param =
      b.AddInstruction(HloInstruction::CreateParameter(0, full_shape, "param"));
  b.AddInstruction(HloInstruction::CreateSlice(
      full_shape, param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{kRows, kCols}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(), param);
}
3882
TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
  // slice(slice(p)) should collapse into a single slice of p. The starts add
  // up (1+2, 2+3). Each limit becomes the outer start plus the inner limit:
  // 1 + (d0-3) = d0-2 and 2 + (d1-6) = d1-4.
  constexpr int64_t d0 = 11;
  constexpr int64_t d1 = 12;

  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {d0, d1}), "param"));

  HloInstruction* outer = b.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {d0 - 2, d1 - 4}), param,
      /*start_indices=*/{1, 2},
      /*limit_indices=*/{d0 - 1, d1 - 2}, /*strides=*/{1, 1}));
  b.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {d0 - 5, d1 - 9}), outer,
      /*start_indices=*/{2, 3},
      /*limit_indices=*/{d0 - 3, d1 - 6}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Slice(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const HloInstruction* merged = entry->root_instruction();
  EXPECT_THAT(merged, GmockMatch(m::Slice(m::Parameter(0))));
  EXPECT_EQ(merged->slice_starts(0), 3);
  EXPECT_EQ(merged->slice_starts(1), 5);
  EXPECT_EQ(merged->slice_limits(0), d0 - 2);
  EXPECT_EQ(merged->slice_limits(1), d1 - 4);
}
3916
TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastToBroadcast) {
  // The slice only narrows dimension 1, which the broadcast introduces.
  // slice(broadcast(p)) should therefore become a smaller broadcast(p).
  constexpr int64_t rows = 11;
  constexpr int64_t cols = 12;

  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {rows}), "param"));
  HloInstruction* bcast = b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {rows, cols}), param, {0}));
  b.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {rows, cols - 9}), bcast,
      /*start_indices=*/{0, 3},
      /*limit_indices=*/{rows, cols - 6}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Broadcast(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Broadcast(m::Parameter(0))));
}
3943
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) {
  // param f32[11*12, 13] is reshaped to f32[11,12,13]. The slice keeps only
  // the first 9 entries of the leading dimension. The test expects the
  // simplifier to swap the ops into reshape(slice(param)).
  constexpr int64_t a = 11;
  constexpr int64_t b_dim = 12;
  constexpr int64_t c = 13;

  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {a * b_dim, c}), "param"));
  HloInstruction* reshape = b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {a, b_dim, c}), param));
  b.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {a - 2, b_dim, c}), reshape,
      /*start_indices=*/{0, 0, 0},
      /*limit_indices=*/{a - 2, b_dim, c}, /*strides=*/{1, 1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Reshape(m::Slice(m::Parameter(0)))));
}
3972
TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) {
  // f32[1,144,25,1,512] is reshaped to f32[3600,512], then the first 960 rows
  // are sliced. 960 is not a multiple of 25, so the slice does not line up
  // with the operand's 144x25 row structure (presumably why no rewrite is
  // possible). The test expects the simplifier to report no change.
  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param"));
  HloInstruction* flat = b.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {3600, 512}), param));
  b.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {960, 512}), flat,
      /*start_indices=*/{0, 0},
      /*limit_indices=*/{960, 512}, /*strides=*/{1, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Slice(m::Reshape(m::Parameter(0)))));

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
3995
TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
  // Sorting a one-element array cannot reorder anything, so the sort should
  // be replaced by its operand.
  HloComputation::Builder b(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape single = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* keys =
      b.AddInstruction(HloInstruction::CreateParameter(0, single, "keys"));
  TF_ASSERT_OK(MakeSortHlo(single, {keys}, /*dimension_to_sort=*/0,
                           /*is_stable=*/false, &b, module.get())
                   .status());

  HloComputation* entry = module->AddEntryComputation(b.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(), keys);
}
4011
TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
  // Key/value sort over [5,0]-shaped arrays, which contain no elements. The
  // sort should be replaced by a plain tuple of its operands.
  HloComputation::Builder b(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape key_shape = ShapeUtil::MakeShape(F32, {5, 0});
  const Shape value_shape = ShapeUtil::MakeShape(S32, {5, 0});
  HloInstruction* keys =
      b.AddInstruction(HloInstruction::CreateParameter(0, key_shape, "keys"));
  HloInstruction* vals0 = b.AddInstruction(
      HloInstruction::CreateParameter(1, value_shape, "values0"));
  HloInstruction* vals1 = b.AddInstruction(
      HloInstruction::CreateParameter(2, value_shape, "values1"));

  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({key_shape, value_shape, value_shape});
  TF_ASSERT_OK(MakeSortHlo(sort_shape, {keys, vals0, vals1},
                           /*dimension_to_sort=*/0, /*is_stable=*/false, &b,
                           module.get())
                   .status());

  HloComputation* entry = module->AddEntryComputation(b.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(vals0),
                                  m::Op().Is(vals1))));
}
4036
4037 // Test that A && True is simplified to A
TEST_F(AlgebraicSimplifierTest, AndTrue) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  // and(p0, true) == p0
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kAnd,
                                                p0, true_literal));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), p0);
}
4057
4058 // Test that True && A is simplified to A
TEST_F(AlgebraicSimplifierTest, AndTrue2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  // and(true, p0) == p0 (constant on the left-hand side).
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kAnd,
                                                true_literal, p0));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), p0);
}
4078
4079 // Test that A && False is simplified to False
TEST_F(AlgebraicSimplifierTest, AndFalse) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  // and(p0, false) == false
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kAnd,
                                                p0, false_literal));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), false_literal);
}
4099
4100 // Test that False && A is simplified to False
TEST_F(AlgebraicSimplifierTest, AndFalse2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  // and(false, p0) == false (constant on the left-hand side).
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kAnd,
                                                false_literal, p0));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kAnd);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), false_literal);
}
4120
4121 // Test that A || True is simplified to True
TEST_F(AlgebraicSimplifierTest, OrTrue) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  // or(p0, true) == true
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kOr,
                                                p0, true_literal));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), true_literal);
}
4141
4142 // Test that True || A is simplified to True
TEST_F(AlgebraicSimplifierTest, OrTrue2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* true_literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  // or(true, p0) == true (constant on the left-hand side).
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kOr,
                                                true_literal, p0));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), true_literal);
}
4162
4163 // Test that A || False is simplified to A
TEST_F(AlgebraicSimplifierTest, OrFalse) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  // or(p0, false) == p0
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kOr,
                                                p0, false_literal));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), p0);
}
4183
4184 // Test that False || A is simplified to A
TEST_F(AlgebraicSimplifierTest, OrFalse2) {
  auto module = CreateNewVerifiedModule();
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar, "param0"));
  HloInstruction* false_literal = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  // or(false, p0) == p0 (constant on the left-hand side).
  b.AddInstruction(HloInstruction::CreateBinary(pred_scalar, HloOpcode::kOr,
                                                false_literal, p0));

  HloComputation* entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(entry->root_instruction()->opcode(), HloOpcode::kOr);

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_EQ(entry->root_instruction(), p0);
}
4204
4205 // Used for TEST_Ps that test merging (or not) of a kPad instruction into a
4206 // convolution's Window.
4207 struct ConvPaddingTestcase {
ConvPaddingTestcasexla::__anon8c2363a90111::ConvPaddingTestcase4208 ConvPaddingTestcase(absl::string_view padding,
4209 absl::string_view orig_conv_window,
4210 absl::string_view expected_conv_window)
4211 : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window,
4212 /*pad_value=*/0) {}
4213
ConvPaddingTestcasexla::__anon8c2363a90111::ConvPaddingTestcase4214 ConvPaddingTestcase(absl::string_view padding,
4215 absl::string_view orig_conv_window,
4216 absl::string_view expected_conv_window, float pad_value)
4217 : padding(padding),
4218 orig_conv_window(orig_conv_window),
4219 expected_conv_window(expected_conv_window),
4220 pad_value(pad_value) {}
4221
ToStringxla::__anon8c2363a90111::ConvPaddingTestcase4222 std::string ToString() const {
4223 return absl::StrFormat(
4224 "padding=%s, orig_conv_window=%s, expected_conv_window=%s, "
4225 "pad_value=%f",
4226 padding, orig_conv_window, expected_conv_window, pad_value);
4227 }
4228
4229 std::string padding;
4230 std::string orig_conv_window;
4231 std::string expected_conv_window;
4232 float pad_value;
4233 };
4234
4235 // ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a
4236 // computation that does
4237 //
4238 // conv(pad(param0, padding=padding), param1), window=orig_conv_window
4239 //
4240 // gets transformed by AlgebraicSimplifier to
4241 //
4242 // conv(param0, param1), window=expected_conv_window
4243 //
4244 // or, if expected_conv_window is the empty string, checks that
4245 // AlgebraicSimplifier does *not* transform the original convolution.
4246 class ConvInputPaddingTest
4247 : public AlgebraicSimplifierTest,
4248 public ::testing::WithParamInterface<ConvPaddingTestcase> {};
4249
4250 INSTANTIATE_TEST_SUITE_P(
4251 ConvInputPaddingTestCases, ConvInputPaddingTest,
4252 ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
4253 // Merge this edge padding into the conv.
4254 {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"},
4255 // Merge this edge padding with the conv's edge padding.
4256 {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"},
4257 // Merge this interior-padded kPad with the unpadded conv. The 3x6
4258 // interior padding gets transformed to 4x7 conv lhs dilation.
4259 {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"},
4260 // kPad has dilation on one dim, conv has it on the other; merge them.
4261 {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"},
4262 // kPad has dilation and edge padding on one dim, conv has them on the
4263 // other; merge them.
4264 {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10",
4265 "pad=0_1x3_0 lhs_dilate=2x10"},
4266
4267 // Don't transform if the pad value is nonzero.
4268 {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1},
4269
4270 // We refuse to transform the following because on some dimension, one
4271 // of the kPad and conv has dilation and the other has some sort of
4272 // padding.
4273 {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""},
4274 {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""},
4275 {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
4276 {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""},
4277 {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""},
4278 {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""},
4279
4280 // We can't merge feature or batch padding into the conv.
4281 {"1_0x0_0x0_0x0_0", "", ""},
4282 {"0_0x1_0x0_0x0_0", "", ""},
4283 }));
4284
TEST_P(ConvInputPaddingTest, DoTest) {
  const ConvPaddingTestcase& testcase = GetParam();

  // It would be better to put the testcase's ToString into the test name, but
  // googletest has constraints on what can go into test names, and any
  // reasonable implementation of ToString() seems to violate them.
  SCOPED_TRACE(testcase.ToString());

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}),  // bf01
      "input"));
  auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(testcase.pad_value)));

  PaddingConfig padding_config =
      ParsePaddingConfig(testcase.padding).ValueOrDie();
  auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeInference::InferPadShape(input->shape(), pad_value->shape(),
                                    padding_config)
          .ValueOrDie(),
      input, pad_value, padding_config));

  // The filter's input-feature dimension tracks the (possibly padded) feature
  // dimension of the lhs so the convolution stays well-formed.
  auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
      1,
      ShapeUtil::MakeShape(
          F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}),  // io01
      "filter"));

  ConvolutionDimensionNumbers dnums =
      ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
  Window window =
      ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
          .ValueOrDie();
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(
          lhs_pad->shape(), filter->shape(),
          /*feature_group_count=*/1,
          /*batch_group_count=*/1, window, dnums,
          /*preferred_element_type=*/std::nullopt)
          .ValueOrDie(),
      lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, DefaultPrecisionConfig(2)));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  if (testcase.expected_conv_window.empty()) {
    // An empty expectation means the simplifier must not touch the graph.
    ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  } else {
    ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
    auto* conv = module->entry_computation()->root_instruction();
    SCOPED_TRACE(module->ToString());
    ASSERT_THAT(conv,
                GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
    EXPECT_EQ(window_util::ToString(conv->window()),
              absl::StrCat("size=3x3 ", testcase.expected_conv_window));
  }
}
4344
4345 // ConvFilterPaddingTest (and its one associated TEST_P) checks that a
4346 // computation that does
4347 //
4348 // conv(param0, pad(param1, padding=padding)), window=orig_conv_window
4349 //
4350 // gets transformed by AlgebraicSimplifier to
4351 //
4352 // conv(param0, param1), window=expected_conv_window
4353 //
4354 // or, if expected_conv_window is the empty string, checks that
4355 // AlgebraicSimplifier does *not* transform the original convolution.
4356 class ConvFilterPaddingTest
4357 : public AlgebraicSimplifierTest,
4358 public ::testing::WithParamInterface<ConvPaddingTestcase> {};
4359
4360 INSTANTIATE_TEST_SUITE_P(
4361 ConvFilterPaddingTestCases, ConvFilterPaddingTest,
4362 ::testing::ValuesIn(std::vector<ConvPaddingTestcase>{
4363 // Can only merge interior padding on the filter's spatial dimensions;
4364 // all
4365 // other paddings (edge padding and interior padding on the channel
4366 // dims)
4367 // should be rejected out of hand.
4368 {"1_0_0x0_0_0x0_0x0_0", "", ""},
4369 {"0_1_0x0_0_0x0_0x0_0", "", ""},
4370 {"0_0_1x0_0_0x0_0x0_0", "", ""},
4371 {"0_0_0x1_0_0x0_0x0_0", "", ""},
4372 {"0_0_0x0_1_0x0_0x0_0", "", ""},
4373 {"0_0_0x0_0_1x0_0x0_0", "", ""},
4374 {"0_0_0x0_0_0x1_0x0_0", "", ""},
4375 {"0_0_0x0_0_0x0_1x0_0", "", ""},
4376 {"0_0_0x0_0_0x0_0x1_0", "", ""},
4377 {"0_0_0x0_0_0x0_0x0_1", "", ""},
4378
4379 // Interior padding on channel dims can be merged into the conv, so long
4380 // as the conv and pad don't have interior padding on the same dim.
4381 {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"},
4382 {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"},
4383 {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"},
4384 {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"},
4385 {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"},
4386
4387 // Can't merge if for a given dim there's interior padding on both the
4388 // pad and conv.
4389 {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""},
4390 {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""},
4391
4392 // Don't transform if the pad value is nonzero.
4393 {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1},
4394 }));
4395
// Builds conv(input, pad(filter)) from the test case, runs AlgebraicSimplifier,
// and checks either that nothing changed (empty expected_conv_window) or that
// the pad was folded into the convolution window with the expected result
// while the original PrecisionConfig was preserved.
TEST_P(ConvFilterPaddingTest, DoIt) {
  ConvPaddingTestcase testcase = GetParam();

  // It would be better to put the testcase's ToString into the test name, but
  // gUnit has constraints on what can go into test names, and any reasonable
  // implementation of ToString() seems to violate them.
  SCOPED_TRACE(testcase.ToString());

  auto builder = HloComputation::Builder(TestName());
  auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR0(testcase.pad_value)));
  // Named "filter" (not "input") so that the two parameters are
  // distinguishable in the module dump printed by SCOPED_TRACE below.
  auto* filter = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}),  // io01
      "filter"));
  PaddingConfig padding_config =
      ParsePaddingConfig(testcase.padding).ValueOrDie();
  auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeInference::InferPadShape(filter->shape(), pad_value->shape(),
                                    padding_config)
          .ValueOrDie(),
      filter, pad_value, padding_config));

  // The input's feature dim must match the (possibly padded) filter's input
  // feature dim.
  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeShape(
          F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}),  // bf01
      "input"));

  ConvolutionDimensionNumbers dnums =
      ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie();
  // The window size tracks the padded filter's spatial extent.
  Window window = ParseWindow(absl::StrFormat("size=%dx%d %s",
                                              rhs_pad->shape().dimensions(2),
                                              rhs_pad->shape().dimensions(3),
                                              testcase.orig_conv_window))
                      .ValueOrDie();

  // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
  // after the transformation.
  PrecisionConfig precision_config;
  precision_config.add_operand_precision(PrecisionConfig::HIGH);
  precision_config.add_operand_precision(PrecisionConfig::HIGHEST);

  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(
          input->shape(), rhs_pad->shape(),
          /*feature_group_count=*/1,
          /*batch_group_count=*/1, window, dnums,
          /*preferred_element_type=*/std::nullopt)
          .ValueOrDie(),
      input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, precision_config));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  if (testcase.expected_conv_window.empty()) {
    ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  } else {
    ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
    auto* conv = module->entry_computation()->root_instruction();
    SCOPED_TRACE(module->ToString());
    ASSERT_THAT(conv,
                GmockMatch(m::Convolution(m::Parameter(), m::Parameter())));
    // After folding, the window size is the unpadded filter's spatial extent.
    EXPECT_EQ(window_util::ToString(conv->window()),
              absl::StrFormat("size=%dx%d %s",
                              conv->operand(1)->shape().dimensions(2),
                              conv->operand(1)->shape().dimensions(3),
                              testcase.expected_conv_window));
    EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
                    ->precision_config()
                    .operand_precision(),
                ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
  }
}
4471
// Checks when the layout-sensitive AlgebraicSimplifier rewrites a convolution
// as a matmul. Each case below adjusts `options`, builds a single convolution
// from them, runs the simplifier and compares a one-line summary:
//   "<lhs dims> DOT <rhs dims>"  the root became bitcast(dot(lhs, rhs));
//   "NO_CHANGE"                  the pass reported no change;
//   "UNEXPECTED CHANGE"          any other change.
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
  struct ConvTestOptions {
    int in_batch = 10;
    int in_height = 2;
    int in_width = 2;
    int in_channels = 3;
    int f_width = 1;
    int f_height = 1;
    int f_output_channels = 10;
    int row_stride = 1;
    int row_padding = 0;
    int col_stride = 1;
    int col_padding = 0;
    bool input_minor_to_major_layout = false;
    bool filter_minor_to_major_layout = false;
    bool output_minor_to_major_layout = false;

    const char* dim_order = "NHWC";         // can use chars NHWC in any order.
    const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.

    ConvTestOptions& Reset() {
      *this = ConvTestOptions();
      return *this;
    }
  };

  ConvTestOptions options;

  // Builds a convolution from <options> and runs algebraic simplification on
  // the computation. Returns a string description of the result of
  // simplification.
  auto build_and_simplify = [&]() -> std::string {
    HloComputation::Builder b(TestName());

    Window window;
    auto* f_dim_1 = window.add_dimensions();
    f_dim_1->set_size(options.f_height);
    f_dim_1->set_stride(options.row_stride);
    f_dim_1->set_padding_low(options.row_padding);
    f_dim_1->set_padding_high(options.row_padding);
    f_dim_1->set_window_dilation(1);
    f_dim_1->set_base_dilation(1);
    auto* f_dim_2 = window.add_dimensions();
    f_dim_2->set_size(options.f_width);
    f_dim_2->set_stride(options.col_stride);
    f_dim_2->set_padding_low(options.col_padding);
    f_dim_2->set_padding_high(options.col_padding);
    f_dim_2->set_window_dilation(1);
    f_dim_2->set_base_dilation(1);

    ConvolutionDimensionNumbers dnums;
    std::vector<int64_t> in_dims;
    // Two spatial slots (H then W); the real indices are filled in below.
    dnums.add_input_spatial_dimensions(-1);
    dnums.add_output_spatial_dimensions(-1);
    dnums.add_input_spatial_dimensions(-1);
    dnums.add_output_spatial_dimensions(-1);
    // Iterate up to the NUL terminator: avoids a signed/unsigned comparison
    // against strlen() and recomputing the length on every iteration.
    for (int i = 0; options.dim_order[i] != '\0'; ++i) {
      const char ch = options.dim_order[i];
      if (ch == 'N') {
        dnums.set_input_batch_dimension(i);
        dnums.set_output_batch_dimension(i);
        in_dims.push_back(options.in_batch);
      } else if (ch == 'H') {
        dnums.set_input_spatial_dimensions(0, i);
        dnums.set_output_spatial_dimensions(0, i);
        in_dims.push_back(options.in_height);
      } else if (ch == 'W') {
        dnums.set_input_spatial_dimensions(1, i);
        dnums.set_output_spatial_dimensions(1, i);
        in_dims.push_back(options.in_width);
      } else if (ch == 'C') {
        dnums.set_input_feature_dimension(i);
        dnums.set_output_feature_dimension(i);
        in_dims.push_back(options.in_channels);
      }
    }

    std::vector<int64_t> f_dims;
    dnums.add_kernel_spatial_dimensions(-1);  // filled in later
    dnums.add_kernel_spatial_dimensions(-1);  // filled in later
    for (int i = 0; options.kernel_dim_order[i] != '\0'; ++i) {
      const char ch = options.kernel_dim_order[i];
      if (ch == 'H') {
        dnums.set_kernel_spatial_dimensions(0, i);
        f_dims.push_back(options.f_height);
      } else if (ch == 'W') {
        dnums.set_kernel_spatial_dimensions(1, i);
        f_dims.push_back(options.f_width);
      } else if (ch == 'I') {
        dnums.set_kernel_input_feature_dimension(i);
        f_dims.push_back(options.in_channels);
      } else if (ch == 'O') {
        dnums.set_kernel_output_feature_dimension(i);
        f_dims.push_back(options.f_output_channels);
      }
    }

    // With minor_to_major_layout, dimension 0 is the most minor.
    auto make_shape = [](absl::Span<const int64_t> dims,
                         bool minor_to_major_layout) {
      if (minor_to_major_layout) {
        return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
      } else {
        return ShapeUtil::MakeShape(F32, dims);
      }
    };
    auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
    auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);

    HloInstruction* input =
        b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
    HloInstruction* filter =
        b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
    Shape out_shape = ShapeInference::InferConvolveShape(
                          in_shape, f_shape, /*feature_group_count=*/1,
                          /*batch_group_count=*/1, window, dnums,
                          /*preferred_element_type=*/std::nullopt)
                          .ValueOrDie();
    if (options.output_minor_to_major_layout) {
      out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),
                                                 {0, 1, 2, 3});
    }

    b.AddInstruction(HloInstruction::CreateConvolve(
        out_shape, input, filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
        DefaultPrecisionConfig(2)));

    auto module = CreateNewVerifiedModule();
    auto* computation = module->AddEntryComputation(b.Build());

    AlgebraicSimplifierOptions simplifier_options;
    simplifier_options.set_is_layout_sensitive(true);
    AlgebraicSimplifier simplifier(simplifier_options);
    if (!simplifier.Run(module.get()).ValueOrDie()) {
      return "NO_CHANGE";
    }
    auto* root = computation->root_instruction();
    if (root->opcode() == HloOpcode::kBitcast &&
        root->operand(0)->opcode() == HloOpcode::kDot) {
      auto lhs_shape = root->operand(0)->operand(0)->shape();
      auto rhs_shape = root->operand(0)->operand(1)->shape();
      return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ",
                          absl::StrJoin(rhs_shape.dimensions(), "x"));
    }
    return "UNEXPECTED CHANGE";
  };

  // Default options are the simplest case and succeed.
  options.Reset();
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // Swapping dim spatial and batch order works.
  options.Reset().dim_order = "NWHC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().dim_order = "WHNC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  // Channel dimension earlier fails.
  options.Reset().dim_order = "HWCN";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().dim_order = "CHWN";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // Filter spatial dims can be anywhere, since they are 1x1.
  options.Reset().kernel_dim_order = "WHIO";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().kernel_dim_order = "IWOH";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().kernel_dim_order = "IWHO";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  // But moving output channel before input channel fails.
  options.Reset().kernel_dim_order = "HWOI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "WHOI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "OWIH";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "OWHI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // Combine different dim and kernel dim orders.
  options.Reset().kernel_dim_order = "IWHO";
  options.dim_order = "WHNC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // Test invalid cases from wrong filter size, strides, or padding.
  options.Reset().f_width = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().f_height = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().row_stride = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().col_stride = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().col_padding = 1;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().row_padding = 1;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The default dim_order is "NHWC". Col-major layout makes C the most major.
  options.Reset().input_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The input and output have different layouts.
  options.Reset().input_minor_to_major_layout = true;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // C is most minor, and I is more major than O.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "CHWN";
  options.kernel_dim_order = "OIHW";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // C is not the most minor dimension.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "HWNC";
  options.kernel_dim_order = "OIHW";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // I is more minor than O.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "CHWN";
  options.kernel_dim_order = "IOHW";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
}
4706
4707 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single
4708 // broadcast.
TEST_F(AlgebraicSimplifierTest,ScalarBroadcastToSlice)4709 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
4710 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
4711 HloComputation::Builder builder(TestName());
4712 HloInstruction* scalar_param = builder.AddInstruction(
4713 HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
4714
4715 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
4716 HloInstruction* broadcast = builder.AddInstruction(
4717 HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
4718
4719 Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
4720 HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
4721 slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
4722
4723 auto module = CreateNewVerifiedModule();
4724 auto computation = module->AddEntryComputation(builder.Build());
4725
4726 HloInstruction* root = computation->root_instruction();
4727 EXPECT_EQ(root, slice);
4728 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
4729
4730 AlgebraicSimplifier simplifier(default_options_);
4731
4732 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4733
4734 // Running simplification again should not result in any further changes.
4735 ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4736 EXPECT_THAT(computation->root_instruction(),
4737 GmockMatch(m::Broadcast(m::Op().Is(scalar_param))
4738 .WithShapeEqualTo(&slice_shape)));
4739 }
4740
4741 // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
4742 // single broadcast.
TEST_F(AlgebraicSimplifierTest,ScalarBroadcastToTransposeReshape)4743 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
4744 HloComputation::Builder builder(TestName());
4745 HloInstruction* forty_two = builder.AddInstruction(
4746 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
4747
4748 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
4749 HloInstruction* broadcast = builder.AddInstruction(
4750 HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {}));
4751
4752 HloInstruction* transpose =
4753 builder.AddInstruction(HloInstruction::CreateTranspose(
4754 ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0}));
4755
4756 Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
4757 HloInstruction* reshape = builder.AddInstruction(
4758 HloInstruction::CreateReshape(reshape_shape, transpose));
4759
4760 auto module = CreateNewVerifiedModule();
4761 auto computation = module->AddEntryComputation(builder.Build());
4762
4763 HloInstruction* root = computation->root_instruction();
4764 EXPECT_EQ(root, reshape);
4765 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
4766
4767 AlgebraicSimplifier simplifier(default_options_);
4768 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4769 EXPECT_THAT(computation->root_instruction(),
4770 GmockMatch(m::Broadcast(m::Op().Is(forty_two))
4771 .WithShapeEqualTo(&reshape_shape)));
4772 }
4773
4774 // Test that a depth-to-space transformation expressed as
4775 // reshape(transpose(reshape(op))) can simplify to
4776 // reshape(concat(slice(op), ..., slice(op))).
TEST_F(AlgebraicSimplifierTest,TransposeReshapeToConcatSlice)4777 TEST_F(AlgebraicSimplifierTest, TransposeReshapeToConcatSlice) {
4778 const std::string& hlo_string = R"(
4779 HloModule TransposeReshapeDepthToSpace
4780
4781 ENTRY entry {
4782 %param = f32[8,14,14,128]{0,1,2,3} parameter(0)
4783 %reshape.1 = f32[8,14,14,2,64] reshape(%param)
4784 %transpose = transpose(%reshape.1), dimensions={0,1,3,2,4}
4785 ROOT %reshape.2 = f32[8,28,14,64] reshape(%transpose)
4786 }
4787 )";
4788 TF_ASSERT_OK_AND_ASSIGN(auto module,
4789 ParseAndReturnVerifiedModule(hlo_string));
4790
4791 AlgebraicSimplifier simplifier(default_options_);
4792 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
4793
4794 Shape result_shape = ShapeUtil::MakeShape(F32, {8, 28, 14, 64});
4795 EXPECT_THAT(module->entry_computation()->root_instruction(),
4796 GmockMatch(m::Reshape(m::Concatenate(m::Slice(m::Parameter(0)),
4797 m::Slice(m::Parameter(0))))
4798 .WithShapeEqualTo(&result_shape)));
4799 }
4800
4801 // Test that a depth-to-space transformation expressed as
4802 // reshape(transpose(reshape(op))) with a large number of chunks
4803 // is not rewritten.
TEST_F(AlgebraicSimplifierTest,TransposeReshapeTooLarge)4804 TEST_F(AlgebraicSimplifierTest, TransposeReshapeTooLarge) {
4805 const std::string& hlo_string = R"(
4806 HloModule TransposeReshapeDepthToSpaceBig
4807
4808 ENTRY entry {
4809 %param = f32[8,14,14,128]{0,1,2,3} parameter(0)
4810 %reshape.1 = f32[8,14,14,8,16] reshape(%param)
4811 %transpose = transpose(%reshape.1), dimensions={0,1,3,2,4}
4812 ROOT %reshape.2 = f32[8,112,14,16] reshape(%transpose)
4813 }
4814 )";
4815 TF_ASSERT_OK_AND_ASSIGN(auto module,
4816 ParseAndReturnVerifiedModule(hlo_string));
4817
4818 AlgebraicSimplifier simplifier(default_options_);
4819 ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4820 }
4821
4822 // Test that a reshape(transpose(reshape(op))) that does not constitute a
4823 // depth-to-space transformation is not rewritten.
TEST_F(AlgebraicSimplifierTest,TransposeReshapeNotDepthToSpace)4824 TEST_F(AlgebraicSimplifierTest, TransposeReshapeNotDepthToSpace) {
4825 const std::string& hlo_string = R"(
4826 HloModule TransposeReshapeDepthToSpace
4827
4828 ENTRY entry {
4829 %param = f32[8,14,14,128]{0,1,2,3} parameter(0)
4830 %reshape.1 = f32[8,14,14,2,64] reshape(%param)
4831 %transpose = transpose(%reshape.1), dimensions={0,3,1,2,4}
4832 ROOT %reshape.2 = f32[8,28,14,64] reshape(%transpose)
4833 }
4834 )";
4835 TF_ASSERT_OK_AND_ASSIGN(auto module,
4836 ParseAndReturnVerifiedModule(hlo_string));
4837
4838 AlgebraicSimplifier simplifier(default_options_);
4839 ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
4840 }
4841
4842 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest,FoldPadIntoReduceWindow)4843 TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
4844 const std::string& hlo_string = R"(
4845 HloModule test
4846 fn {
4847 p0 = f32[] parameter(0)
4848 p1 = f32[] parameter(1)
4849 ROOT add = f32[] add(p0, p1)
4850 }
4851 ENTRY entry {
4852 param = f32[1,2,3,4] parameter(0)
4853 const = f32[] constant(5)
4854 pad = pad(param, const), padding=0_0x1_0x0_0x0_2
4855 ROOT r = reduce-window(pad, const), to_apply=fn, window={size=2x2x2x2 lhs_dilate=1x1x1x3 pad=10_100x10_100x10_100x10_100}
4856 }
4857 )";
4858 TF_ASSERT_OK_AND_ASSIGN(auto module,
4859 ParseAndReturnVerifiedModule(hlo_string));
4860 AlgebraicSimplifier simplifier(default_options_);
4861 ASSERT_TRUE(RunHloPass(&simplifier, module.get()).ValueOrDie());
4862 // Running simplification again should not result in any further changes.
4863 ASSERT_FALSE(RunHloPass(&simplifier, module.get()).ValueOrDie());
4864
4865 // Verify the result
4866 HloInstruction* root = module->entry_computation()->root_instruction();
4867 EXPECT_THAT(root,
4868 GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant())));
4869 EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
4870 EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
4871 EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
4872 EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
4873 EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
4874 EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
4875 EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
4876 EXPECT_EQ(root->window().dimensions(3).padding_high(), 106);
4877 }
4878
4879 // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
4880 // ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest,FoldConvertedPadIntoReduceWindow)4881 TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
4882 const std::string& hlo_string = R"(
4883 HloModule test
4884 fn {
4885 p0 = f32[] parameter(0)
4886 p1 = f32[] parameter(1)
4887 ROOT add = f32[] add(p0, p1)
4888 }
4889 ENTRY entry {
4890 param = bf16[1,2,3,4] parameter(0)
4891 const = bf16[] constant(5)
4892 pad = pad(param, const), padding=0_0x1_0x0_0x0_2
4893 converted = f32[1,3,3,6] convert(pad)
4894 ROOT r = reduce-window(converted, const), to_apply=fn, window={size=2x2x2x2 pad=10_100x10_100x10_100x10_100}
4895 }
4896 )";
4897 TF_ASSERT_OK_AND_ASSIGN(auto module,
4898 ParseAndReturnVerifiedModule(hlo_string));
4899 AlgebraicSimplifier simplifier(default_options_);
4900 ASSERT_TRUE(RunHloPass(&simplifier, module.get()).ValueOrDie());
4901 // Running simplification again should not result in any further changes.
4902 ASSERT_FALSE(RunHloPass(&simplifier, module.get()).ValueOrDie());
4903
4904 // Verify the result
4905 HloInstruction* root = module->entry_computation()->root_instruction();
4906 EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)),
4907 m::Constant())));
4908 EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
4909 EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
4910 EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
4911 EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
4912 EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
4913 EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
4914 EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
4915 EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
4916 }
4917
// Reversing only the size-1 dimensions (2 and 3) leaves the data unchanged, so
// the simplifier is expected to replace the reverse with its operand.
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
  const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  builder.AddInstruction(
      HloInstruction::CreateReverse(shape, param, /*dimensions=*/{2, 3}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* new_root = computation->root_instruction();
  EXPECT_EQ(param, new_root);
  EXPECT_TRUE(ShapeUtil::Equal(new_root->shape(), shape));
}
4936
TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
  // Simplifying dots can add computations to the parent module. Check that
  // updating the module's computation list while the pass runs does not
  // invalidate iterators used for the remaining computations.
  auto module = CreateNewVerifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {1});

  // Callee: a batch dot of two rank-1 parameters.
  HloComputation::Builder dot_builder(TestName() + ".Dot");
  HloInstruction* lhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* rhs = dot_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r1f32, "y"));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(0);
  dot_builder.AddInstruction(HloInstruction::CreateDot(
      r1f32, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
  std::unique_ptr<HloComputation> dot_computation = dot_builder.Build();

  // Entry: calls the dot computation on two constants.
  HloComputation::Builder call_builder(TestName() + ".Call");
  HloInstruction* zero = call_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
  HloInstruction* one = call_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
  call_builder.AddInstruction(
      HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));

  module->AddEmbeddedComputation(std::move(dot_computation));
  module->AddEntryComputation(call_builder.Build());

  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
}
4968
4969 // Test that a constant with tuple shape becomes a tuple of constants.
TEST_F(AlgebraicSimplifierTest,ConstantTupleBecomesTupleOfConstants)4970 TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
4971 auto m = CreateNewVerifiedModule();
4972 HloComputation::Builder builder(TestName());
4973 const float constant_scalar = 7.3f;
4974 std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
4975 Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
4976 LiteralUtil::CreateR1<float>(constant_vector)};
4977 Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
4978 builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
4979
4980 auto computation = m->AddEntryComputation(builder.Build());
4981
4982 AlgebraicSimplifier simplifier(default_options_);
4983 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
4984 EXPECT_THAT(computation->root_instruction(),
4985 GmockMatch(m::Tuple(m::Constant(), m::Constant())));
4986 }
4987
// A dynamic-slice is trivial if its slice sizes equal the dimensions of its
// operand: any start indices are then clamped to zero, so the dynamic slice
// is equal to its input regardless of the index values.
TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  const Shape index_shape = ShapeUtil::MakeShape(U32, {});

  // Start indices are runtime parameters 1..3; the test expects the slice to
  // be removed even though their values are unknown.
  std::vector<HloInstruction*> start_indices;
  for (int param_number = 1; param_number <= 3; ++param_number) {
    start_indices.push_back(
        builder.AddInstruction(HloInstruction::CreateParameter(
            param_number, index_shape, "slice_indices")));
  }
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "slice_from"));
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      shape, operand, start_indices,
      /*slice_sizes=*/{10, 100, 1000}));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter()));
}
5013
// Checks that a dynamic-slice whose start indices are all constants becomes a
// static slice of the operand.
TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});

  // Constant start indices 4, 8 and 16 (2 << 1, 2 << 2, 2 << 3).
  std::vector<HloInstruction*> start_indices;
  for (int32_t start : {4, 8, 16}) {
    start_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(start))));
  }

  const Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200});
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "operand"));
  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      ds_shape, operand, start_indices,
      /*slice_sizes=*/{2, 20, 200}));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Slice(m::Parameter())));
}
5038
// A dynamic-update-slice is trivial if its "update" has the same shape as its
// output: any start indices are then clamped to zero, so the
// dynamic-update-slice is equal to its update regardless of the index values.
TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
  const Shape index_shape = ShapeUtil::MakeShape(U32, {});

  // Slice indices are parameters 1..3 and update indices are parameters 5..7;
  // parameter 4 (the buffer being updated) is created further below.
  std::vector<HloInstruction*> slice_indices;
  std::vector<HloInstruction*> update_indices;
  for (int k = 0; k < 3; ++k) {
    slice_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateParameter(k + 1, index_shape, "slice_indices")));
    update_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateParameter(k + 5, index_shape, "update_indices")));
  }

  HloInstruction* slice_source = builder.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "slice_from"));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateDynamicSlice(slice_shape, slice_source,
                                         slice_indices,
                                         /*slice_sizes=*/{10, 1, 1000}));

  // The update covers the whole buffer, so the result is just the update.
  HloInstruction* buffer = builder.AddInstruction(
      HloInstruction::CreateParameter(4, slice_shape, "to_update"));
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      slice_shape, buffer, update, update_indices));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(),
                                         m::Parameter(), m::Parameter())));
}
5079
5080 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest,MergeBroadcasts)5081 TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
5082 auto m = CreateNewVerifiedModule();
5083 HloComputation::Builder builder(TestName());
5084 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
5085 HloInstruction* input_array = builder.AddInstruction(
5086 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
5087 HloInstruction* inner_bcast = builder.AddInstruction(
5088 HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
5089 Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
5090 builder.AddInstruction(
5091 HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2}));
5092
5093 auto computation = m->AddEntryComputation(builder.Build());
5094 HloInstruction* root = computation->root_instruction();
5095 EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
5096 AlgebraicSimplifier simplifier(default_options_);
5097 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5098 root = computation->root_instruction();
5099 EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant())));
5100 EXPECT_THAT(root->dimensions(), ElementsAre(2));
5101 }
5102
5103 // Test that two consecutive broadcasts can be merged to one.
TEST_F(AlgebraicSimplifierTest,MergeBroadcasts2)5104 TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) {
5105 auto m = CreateNewVerifiedModule();
5106 HloComputation::Builder builder(TestName());
5107 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3});
5108 Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
5109 HloInstruction* param0 = builder.AddInstruction(
5110 HloInstruction::CreateParameter(0, r2f32, "param0"));
5111 // The initial dimensions go to places 0 and 2 in the 3-dim array,
5112 // and to places 1 and 3 in the 4-dim array,
5113 HloInstruction* inner_bcast = builder.AddInstruction(
5114 HloInstruction::CreateBroadcast(r3f32, param0, {0, 2}));
5115 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
5116 builder.AddInstruction(
5117 HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3}));
5118
5119 auto computation = m->AddEntryComputation(builder.Build());
5120 HloInstruction* root = computation->root_instruction();
5121 EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
5122 AlgebraicSimplifier simplifier(default_options_);
5123 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5124 root = computation->root_instruction();
5125 EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0))));
5126 EXPECT_THAT(root->dimensions(), ElementsAre(1, 3));
5127 }
5128
5129 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota)5130 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) {
5131 auto m = CreateNewVerifiedModule();
5132 HloComputation::Builder builder(TestName());
5133 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
5134 HloInstruction* iota =
5135 builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1));
5136 Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
5137 builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2}));
5138
5139 auto computation = m->AddEntryComputation(builder.Build());
5140 HloInstruction* root = computation->root_instruction();
5141 EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
5142 AlgebraicSimplifier simplifier(default_options_);
5143 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5144 root = computation->root_instruction();
5145 EXPECT_THAT(root, GmockMatch(m::Iota()));
5146 EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
5147 }
5148
5149 // Test that a broadcast of an iota can be merged to one iota.
TEST_F(AlgebraicSimplifierTest,MergeBroadcastAndIota2)5150 TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) {
5151 auto m = CreateNewVerifiedModule();
5152 HloComputation::Builder builder(TestName());
5153 Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3});
5154 HloInstruction* iota =
5155 builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1));
5156 Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3});
5157 builder.AddInstruction(
5158 HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3}));
5159
5160 auto computation = m->AddEntryComputation(builder.Build());
5161 HloInstruction* root = computation->root_instruction();
5162 EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
5163 AlgebraicSimplifier simplifier(default_options_);
5164 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
5165 root = computation->root_instruction();
5166 EXPECT_THAT(root, GmockMatch(m::Iota()));
5167 EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2);
5168 }
5169
// A transpose that swaps the two output dimensions of a dot is folded into a
// dot with swapped operands; the operand precisions are swapped with them.
TEST_F(AlgebraicSimplifierTest, TransposeOfDot) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
lhs = f32[3,4,5] parameter(0)
rhs = f32[6,3,4] parameter(1)
dot = f32[5,6] dot(lhs,rhs), lhs_contracting_dims={0,1}, rhs_contracting_dims={1,2}, operand_precision={highest,high}
ROOT transpose = f32[6,5] transpose(dot), dimensions={1,0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  const HloInstruction* dot = nullptr;
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Dot(&dot, m::Parameter(1), m::Parameter(0))));
  const auto& precisions = dot->precision_config().operand_precision();
  EXPECT_EQ(precisions[0], PrecisionConfig::HIGH);
  EXPECT_EQ(precisions[1], PrecisionConfig::HIGHEST);
}
5195
// A transpose that only swaps the two non-batch output dimensions of a batch
// dot is folded by swapping the dot's operands; contracting dimensions and
// operand precisions follow the operands.
TEST_F(AlgebraicSimplifierTest, TransposeOfBatchDot) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
lhs = f32[10,20,30,40] parameter(0)
rhs = f32[10,20,50,30] parameter(1)
dot = dot(lhs,rhs), lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
lhs_contracting_dims={2}, rhs_contracting_dims={3},
operand_precision={high, default}
ROOT transpose = transpose(dot), dimensions={0,1,3,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&simplifier, module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* dot = nullptr;
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Dot(&dot, m::Parameter(1), m::Parameter(0))));

  const DotDimensionNumbers& dims = dot->dot_dimension_numbers();
  EXPECT_THAT(dims.lhs_batch_dimensions(), ElementsAre(0, 1));
  EXPECT_THAT(dims.rhs_batch_dimensions(), ElementsAre(0, 1));
  EXPECT_THAT(dims.lhs_contracting_dimensions(), ElementsAre(3));
  EXPECT_THAT(dims.rhs_contracting_dimensions(), ElementsAre(2));

  const auto& precisions = dot->precision_config().operand_precision();
  EXPECT_EQ(precisions[0], PrecisionConfig::DEFAULT);
  EXPECT_EQ(precisions[1], PrecisionConfig::HIGH);
}
5228
// A transpose that also permutes the batch dimensions of a batch dot's output
// (dimensions={1,0,3,2}) must be left alone: the pass reports no change.
TEST_F(AlgebraicSimplifierTest, TransposeOfBatchDimsInBatchDotCantSimplify) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
lhs = f32[10,20,30,40] parameter(0)
rhs = f32[10,20,50,30] parameter(1)
dot = dot(lhs,rhs), lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
lhs_contracting_dims={2}, rhs_contracting_dims={3}
ROOT transpose = transpose(dot), dimensions={1,0,3,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  TF_ASSERT_OK_AND_ASSIGN(const bool changed,
                          RunHloPass(&simplifier, module.get()));
  EXPECT_FALSE(changed);
}
5248
// p1 has two non-contracting, non-batch dimensions (7 and 5), so neither dot
// is in the one-free-dimension-per-operand form; transposes of their outputs
// must be left alone and the pass reports no change.
TEST_F(AlgebraicSimplifierTest, TransposeOfNonCanonicalBatchDotCantSimplify) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
p0 = f32[13,11,2,3] parameter(0)
p1 = f32[13,11,3,7,5] parameter(1)
dot1 = dot(p0, p1), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
dot2 = dot(p1, p0), rhs_batch_dims={0,1}, rhs_contracting_dims={3}, lhs_batch_dims={0,1}, lhs_contracting_dims={2}
trans1 = transpose(dot1), dimensions={0,1,2,4,3}
trans2 = transpose(dot2), dimensions={0,1,2,4,3}
ROOT root = tuple(trans1, trans2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  TF_ASSERT_OK_AND_ASSIGN(const bool changed,
                          RunHloPass(&simplifier, module.get()));
  EXPECT_FALSE(changed);
}
5270
// A dynamic-slice of a transpose becomes a transpose of a dynamic-slice of the
// transpose's operand, with the start indices permuted to match (i1 and i2
// trade places because the transpose swaps dimensions 1 and 2).
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfTranspose) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param = f32[12,10,8] parameter(0)
i0 = s32[] parameter(1)
i1 = s32[] parameter(2)
i2 = s32[] parameter(3)
transpose = f32[12,8,10] transpose(param), dimensions={0,2,1}
ROOT slice = f32[2,3,5] dynamic-slice(transpose, i0, i1, i2),
dynamic_slice_sizes={2,3,5}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Transpose(m::DynamicSlice(
                        m::Parameter(0), m::Parameter(1), m::Parameter(3),
                        m::Parameter(2)))));
}
5296
// A dynamic-slice of a reshape that only moves a size-1 dimension
// (f32[12,10,1,8] -> f32[1,12,10,8]) becomes a reshape of a dynamic-slice of
// the reshape's operand, with the constant index moved to the size-1 slot.
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfTrivialReshape) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param = f32[12,10,1,8] parameter(0)
i0 = s32[] parameter(1)
i1 = s32[] parameter(2)
i2 = s32[] parameter(3)
z = s32[] constant(0)
reshape = f32[1,12,10,8] reshape(param)
ROOT slice = f32[1,2,3,5] dynamic-slice(reshape, z, i0, i1, i2),
dynamic_slice_sizes={1,2,3,5}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions layout_insensitive_options;
  layout_insensitive_options.set_is_layout_sensitive(false);
  AlgebraicSimplifier simplifier(layout_insensitive_options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Reshape(m::DynamicSlice(
                        m::Parameter(0), m::Parameter(1), m::Parameter(2),
                        m::Constant(), m::Parameter(3)))));
}
5324
// A slice that reads only the low-edge padding of a pad (row 2 < 3 and
// column 0 < 1) is replaced by a broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
5345
// A slice that reads only the high-edge padding of a pad (row 6 and column 9
// both lie past the unpadded data) is replaced by a broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
5366
// A slice lying entirely inside the unpadded region of a pad becomes a slice
// of the pad's operand.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Slice(m::Parameter(0))));
}
5387
// A multi-element slice inside the unpadded region of a pad becomes a slice of
// the pad's operand, with start indices shifted by the low padding (4-3, 2-1).
TEST_F(AlgebraicSimplifierTest, SliceOfPad) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[2,3] slice(f32[8,10] pad), slice={[4:6],[2:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(0))));
  EXPECT_THAT(root->slice_starts(), ElementsAre(1, 1));
}
5409
// The slice's row (5) is inside the data but its column (9) is in the high
// padding, so it reads only padding and becomes a broadcast of a constant.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param = f32[3,4] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Constant())));
}
5430
// A slice that selects exactly the (1x1) operand of a pad is replaced by the
// operand itself.
TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param = f32[1,1] parameter(0)
constant = f32[] constant(0.0)
pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5
ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter()));
}
5451
// The slice overlaps the data in dimensions 0 and 1 but falls in the low
// padding of dimension 2, so the whole expression becomes a broadcast of the
// padding value -7.
TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry () -> f32[1]{0} {
constant.val = f32[] constant(4)
constant.pad = f32[] constant(-7)
reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val)
pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0
slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]}
ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::ConstantScalar(-7.0))));
}
5474
// A slice that exactly covers one operand of a concatenate ([2:3] is param.1)
// is replaced by that operand.
TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param.0 = f32[2] parameter(0)
param.1 = f32[1] parameter(1)
param.2 = f32[3] parameter(2)
concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
ROOT slice = f32[1] slice(concat), slice={[2:3]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(1)));
}
5496
// A slice lying inside a single operand of a concatenate becomes a slice of
// that operand: [4:5] of the concat is [1:2] of param.2, which starts at 3.
TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param.0 = f32[2] parameter(0)
param.1 = f32[1] parameter(1)
param.2 = f32[3] parameter(2)
concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0}
ROOT slice = f32[1] slice(concat), slice={[4:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(2))));
  EXPECT_EQ(root->slice_starts(0), 1);
  EXPECT_EQ(root->slice_limits(0), 2);
}
5520
// Concatenating the same operand, whose concatenation dimension has size 1,
// six times along that dimension is rewritten as a broadcast of a reshape of
// the operand.
TEST_F(AlgebraicSimplifierTest, ConcatToBroadcast) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
p = f32[2,1,4] parameter(0)
ROOT concat = f32[2,6,4] concatenate(p,p,p,p,p,p), dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
5539
// negate(negate(x)) simplifies to x.
TEST_F(AlgebraicSimplifierTest, NegateNegate) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param.0 = f32[2] parameter(0)
neg.0 = f32[2] negate(param.0)
ROOT neg.1 = f32[2] negate(neg.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}
5559
// not(not(x)) on predicates simplifies to x.
TEST_F(AlgebraicSimplifierTest, NotNot) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
param.0 = pred[2] parameter(0)
not.0 = pred[2] not(param.0)
ROOT not.1 = pred[2] not(not.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{});
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Parameter(0)));
}
5579
// Transposes on both operands that swap their two non-batch dimensions are
// absorbed: the result is a transpose of a dot of the untransposed parameters
// in swapped order, with the operand precisions swapped too.
TEST_F(AlgebraicSimplifierTest, BatchDotTransposeOperands) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
lhs = f32[10,20,30,40] parameter(0)
rhs = f32[10,20,50,30] parameter(1)
lhs_t = transpose(lhs), dimensions={0,1,3,2}
rhs_t = transpose(rhs), dimensions={0,1,3,2}
ROOT root = dot(lhs_t, rhs_t),
lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
lhs_contracting_dims={3}, rhs_contracting_dims={2},
operand_precision={default, high}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&simplifier, module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* dot = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Transpose(m::Dot(&dot, m::Parameter(1), m::Parameter(0)))));
  const auto& precisions = dot->precision_config().operand_precision();
  EXPECT_EQ(precisions[0], PrecisionConfig::HIGH);
  EXPECT_EQ(precisions[1], PrecisionConfig::DEFAULT);
}
5610
// Transposes on both operands that swap their two batch dimensions are
// absorbed: the result is a transpose of a dot of the parameters in their
// original order, with the operand precisions unchanged.
TEST_F(AlgebraicSimplifierTest, BatchDotTransposeBatchDims) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
lhs = f32[10,20,40,30] parameter(0)
rhs = f32[10,20,30,50] parameter(1)
lhs_t = transpose(lhs), dimensions={1,0,2,3}
rhs_t = transpose(rhs), dimensions={1,0,2,3}
ROOT root = dot(lhs_t, rhs_t),
lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
lhs_contracting_dims={3}, rhs_contracting_dims={2},
operand_precision={default, high}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&simplifier, module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* dot = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Transpose(m::Dot(&dot, m::Parameter(0), m::Parameter(1)))));
  const auto& precisions = dot->precision_config().operand_precision();
  EXPECT_EQ(precisions[0], PrecisionConfig::DEFAULT);
  EXPECT_EQ(precisions[1], PrecisionConfig::HIGH);
}
5641
// Transposes on both operands that swap both the batch dimensions and the
// non-batch dimensions are absorbed: the result is a transpose of a dot of the
// parameters in swapped order, with the operand precisions swapped too.
TEST_F(AlgebraicSimplifierTest, BatchDotTransposeBatchDimsAndOperands) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY test {
lhs = f32[10,20,30,40] parameter(0)
rhs = f32[10,20,50,30] parameter(1)
lhs_t = transpose(lhs), dimensions={1,0,3,2}
rhs_t = transpose(rhs), dimensions={1,0,3,2}
ROOT root = dot(lhs_t, rhs_t),
lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
lhs_contracting_dims={3}, rhs_contracting_dims={2},
operand_precision={high, default}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&simplifier, module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* dot = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Transpose(m::Dot(&dot, m::Parameter(1), m::Parameter(0)))));
  const auto& precisions = dot->precision_config().operand_precision();
  EXPECT_EQ(precisions[0], PrecisionConfig::DEFAULT);
  EXPECT_EQ(precisions[1], PrecisionConfig::HIGH);
}
5672
5673 struct PadReduceWindowEffectiveBroadcastCase {
5674 std::vector<int64_t> input_spatials;
5675 std::vector<int64_t> symmetric_pad_spatials;
5676 std::vector<int64_t> reduce_window_spatials;
5677 // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
5678 //
5679 // This doesn't test any different functionality but is useful for making sure
5680 // kBroadcast nodes are well formed.
5681 bool prepend_a;
5682 bool should_become_broadcast;
5683
ToTestCaseNamexla::__anon8c2363a90111::PadReduceWindowEffectiveBroadcastCase5684 std::string ToTestCaseName() const {
5685 return absl::StrCat(absl::StrJoin(input_spatials, ","), ";",
5686 absl::StrJoin(symmetric_pad_spatials, ","), ";",
5687 absl::StrJoin(reduce_window_spatials, ","), ";",
5688 prepend_a, ";", should_become_broadcast);
5689 }
5690 };
5691
PrintTo(const PadReduceWindowEffectiveBroadcastCase & c,std::ostream * os)5692 void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
5693 *os << c.ToTestCaseName();
5694 }
5695
5696 class PadReduceWindowEffectiveBroadcastTest
5697 : public AlgebraicSimplifierTest,
5698 public ::testing::WithParamInterface<
5699 PadReduceWindowEffectiveBroadcastCase> {};
5700
TEST_P(PadReduceWindowEffectiveBroadcastTest,DoIt)5701 TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
5702 auto m = CreateNewVerifiedModule();
5703 const auto& param = GetParam();
5704
5705 // a and b are parallel bounds we can either turn into a B F S0 S1 or
5706 // `B S0 S1 F` kind of pattern.
5707 auto decorate_spatials = [¶m](absl::Span<const int64_t> spatials,
5708 int64_t a, int64_t b) {
5709 std::vector<int64_t> result;
5710 if (param.prepend_a) {
5711 result.push_back(a);
5712 }
5713 for (int64_t s : spatials) {
5714 result.push_back(s);
5715 }
5716 if (!param.prepend_a) {
5717 result.push_back(a);
5718 }
5719 result.push_back(b);
5720 return result;
5721 };
5722
5723 HloComputation::Builder builder(TestName());
5724 const Shape input_shape = ShapeUtil::MakeShape(
5725 F32, decorate_spatials(param.input_spatials, 128, 2048));
5726 HloInstruction* input = builder.AddInstruction(
5727 HloInstruction::CreateParameter(0, input_shape, "input"));
5728
5729 PaddingConfig padding = window_util::MakeSymmetricPadding(
5730 decorate_spatials(param.symmetric_pad_spatials, 0, 0));
5731 TF_ASSERT_OK_AND_ASSIGN(
5732 const Shape pad_shape,
5733 ShapeInference::InferPadShape(input->shape(),
5734 ShapeUtil::MakeShape(F32, {}), padding));
5735 HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
5736 pad_shape, input,
5737 builder.AddInstruction(
5738 HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
5739 padding));
5740
5741 HloComputation* add_computation = nullptr;
5742 {
5743 HloComputation::Builder builder(TestName() + ".add");
5744 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
5745 HloInstruction* p0 = builder.AddInstruction(
5746 HloInstruction::CreateParameter(0, scalar_shape, "p0"));
5747 HloInstruction* p1 = builder.AddInstruction(
5748 HloInstruction::CreateParameter(1, scalar_shape, "p1"));
5749 builder.AddInstruction(
5750 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
5751 add_computation = m->AddEmbeddedComputation(builder.Build());
5752 }
5753
5754 Window window = window_util::MakeWindow(
5755 decorate_spatials(param.reduce_window_spatials, 1, 1));
5756 auto zero = builder.AddInstruction(
5757 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
5758 TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
5759 ShapeInference::InferReduceWindowShape(
5760 pad->shape(), zero->shape(), window,
5761 add_computation->ComputeProgramShape()));
5762 builder.AddInstruction(HloInstruction::CreateReduceWindow(
5763 output_shape, pad, zero, window, add_computation));
5764
5765 auto computation = m->AddEntryComputation(builder.Build());
5766 AlgebraicSimplifier simplifier(default_options_);
5767 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
5768 ASSERT_TRUE(run_successful);
5769 SCOPED_TRACE(m->ToString());
5770
5771 EXPECT_TRUE(
5772 ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));
5773
5774 if (param.should_become_broadcast) {
5775 EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast()));
5776 } else {
5777 EXPECT_THAT(computation->root_instruction(),
5778 GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero))));
5779 }
5780 }
5781
5782 const std::vector<PadReduceWindowEffectiveBroadcastCase>&
PadReduceWindowEffectiveBroadcastCases()5783 PadReduceWindowEffectiveBroadcastCases() {
5784 static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
5785 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
5786 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
5787 /*should_become_broadcast=*/true}, //
5788 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
5789 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
5790 /*should_become_broadcast=*/true}, //
5791 {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
5792 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
5793 /*should_become_broadcast=*/false}, //
5794 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
5795 /*reduce_window_spatials=*/{2, 2}, /*prepend_a=*/true,
5796 /*should_become_broadcast=*/false}, //
5797 {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
5798 /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
5799 /*should_become_broadcast=*/false}, //
5800 };
5801 return *cases;
5802 }
5803
5804 INSTANTIATE_TEST_SUITE_P(
5805 PadReduceWindowEffectiveBroadcastInstantiation,
5806 PadReduceWindowEffectiveBroadcastTest,
5807 ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
5808
5809 class BatchDotStrengthReductionTest
5810 : public AlgebraicSimplifierTest,
5811 public ::testing::WithParamInterface<
5812 ::testing::tuple<int, int, int, PrimitiveType>> {};
// Builds a dot with three batch dimensions {2, 3, 5}, plus an optional lhs
// non-contracting dimension (m), contracting dimension (k) and rhs
// non-contracting dimension (n). The simplifier is expected to rewrite the dot
// exactly when one of m, k, n is 1 or absent (-1).
TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) {
  auto module = CreateNewVerifiedModule();
  int m, k, n;
  PrimitiveType element_type;
  std::tie(m, k, n, element_type) = GetParam();

  // All operands share the batch dimensions; the remaining dimensions are
  // appended only when requested (value > 0).
  const std::vector<int64_t> batch_dims = {2, 3, 5};
  std::vector<int64_t> lhs_dims(batch_dims);
  std::vector<int64_t> rhs_dims(batch_dims);
  std::vector<int64_t> output_dims(batch_dims);
  if (m > 0) {
    lhs_dims.push_back(m);
    output_dims.push_back(m);
  }
  if (k > 0) {
    lhs_dims.push_back(k);
    rhs_dims.push_back(k);
  }
  if (n > 0) {
    rhs_dims.push_back(n);
    output_dims.push_back(n);
  }

  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(element_type, lhs_dims), "lhs"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(element_type, rhs_dims), "rhs"));

  DotDimensionNumbers dot_dnums;
  for (int64_t batch_dim = 0; batch_dim < 3; ++batch_dim) {
    dot_dnums.add_lhs_batch_dimensions(batch_dim);
    dot_dnums.add_rhs_batch_dimensions(batch_dim);
  }
  if (k > 0) {
    // On the lhs, k comes after m when m is present.
    dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3);
    dot_dnums.add_rhs_contracting_dimensions(3);
  }
  builder.AddInstruction(HloInstruction::CreateDot(
      ShapeUtil::MakeShape(element_type, output_dims), lhs, rhs, dot_dnums,
      DefaultPrecisionConfig(2)));
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  const bool dot_should_be_transformed =
      m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1;
  AlgebraicSimplifier simplifier(default_options_);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_EQ(changed, dot_should_be_transformed);
  // Give the simplifier a second pass; only the final graph is checked below.
  TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));

  bool found_dot = false;
  for (const auto& instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kDot) {
      found_dot = true;
      break;
    }
  }
  EXPECT_EQ(!found_dot, dot_should_be_transformed);
}
5871
5872 INSTANTIATE_TEST_SUITE_P(
5873 BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest,
5874 ::testing::Combine(::testing::Values(-1, 1, 2), ::testing::Values(-1, 1, 2),
5875 ::testing::Values(-1, 1, 2),
5876 ::testing::Values(C128, C64, F64, F32, BF16)));
5877
5878 class DotStrengthReductionTest
5879 : public AlgebraicSimplifierTest,
5880 public ::testing::WithParamInterface<
5881 ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
// Builds dot(lhs[m,k], rhs[k,n]). Each operand may instead be a {1, 0}
// transpose of a parameter declared with swapped dimensions.
// Expectations:
//  - both simplifier passes report a change iff a dimension is 1 or both
//    operands are transposed;
//  - the dot disappears iff a dimension is 1.
TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
  auto module = CreateNewVerifiedModule();
  int m, k, n;
  bool transpose_lhs, transpose_rhs;
  PrimitiveType element_type;
  std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();

  const Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
  const Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
  HloComputation::Builder builder(TestName());

  HloInstruction* lhs = nullptr;
  if (transpose_lhs) {
    HloInstruction* lhs_param =
        builder.AddInstruction(HloInstruction::CreateParameter(
            0, ShapeUtil::MakeShape(element_type, {k, m}), "lhs"));
    lhs = builder.AddInstruction(
        HloInstruction::CreateTranspose(lhs_shape, lhs_param, {1, 0}));
  } else {
    lhs = builder.AddInstruction(
        HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
  }

  HloInstruction* rhs = nullptr;
  if (transpose_rhs) {
    HloInstruction* rhs_param =
        builder.AddInstruction(HloInstruction::CreateParameter(
            1, ShapeUtil::MakeShape(element_type, {n, k}), "rhs"));
    rhs = builder.AddInstruction(
        HloInstruction::CreateTranspose(rhs_shape, rhs_param, {1, 0}));
  } else {
    rhs = builder.AddInstruction(
        HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
  }

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  builder.AddInstruction(HloInstruction::CreateDot(
      ShapeUtil::MakeShape(element_type, {m, n}), lhs, rhs, dot_dnums,
      DefaultPrecisionConfig(2)));
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
  const bool computation_should_be_modified =
      dot_should_be_transformed || (transpose_lhs && transpose_rhs);
  AlgebraicSimplifier simplifier(default_options_);

  // Pass 1: removes degenerate dimensions and optimizes
  // dot(transpose(x), transpose(y)).
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  EXPECT_EQ(changed, computation_should_be_modified);
  // Pass 2: removes dots that have no contracting or no non-contracting
  // dimensions left.
  TF_ASSERT_OK_AND_ASSIGN(changed, simplifier.Run(module.get()));
  EXPECT_EQ(changed, computation_should_be_modified);

  bool found_dot = false;
  for (const auto& instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kDot) {
      found_dot = true;
      break;
    }
  }
  EXPECT_EQ(!found_dot, dot_should_be_transformed);
}
5935
5936 INSTANTIATE_TEST_SUITE_P(
5937 DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
5938 ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
5939 ::testing::Values(1, 2), ::testing::Bool(),
5940 ::testing::Bool(),
5941 ::testing::Values(C128, C64, F64, F32, BF16)));
5942
// Problem size for the dot-of-concatenate tests: an [m, k] operand is
// multiplied by a [k, n] operand, and the concatenation splits k.
struct DotOfConcatTestSpec {
  int64_t m;  // Rows of the lhs and of the result.
  int64_t k;  // Contracting dimension.
  int64_t n;  // Columns of the rhs and of the result.
};
5948
5949 class DotOfConcatSimplificationTest
5950 : public AlgebraicSimplifierTest,
5951 public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
5952
5953 // Test that we transform
5954 // dot(const, concat(A, B, C))
5955 // to
5956 // add(dot(const_0, A), dot(const_1, B), dot(const_2, C))
TEST_P(DotOfConcatSimplificationTest,ConstantLHS)5957 TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
5958 auto m = CreateNewVerifiedModule();
5959 HloComputation::Builder builder(TestName());
5960
5961 DotOfConcatTestSpec spec = GetParam();
5962
5963 ASSERT_GE(spec.k, 3);
5964
5965 int64_t k0 = spec.k / 3;
5966 int64_t k1 = spec.k / 3;
5967 int64_t k2 = spec.k - k0 - k1;
5968
5969 Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
5970 auto* lhs = builder.AddInstruction(
5971 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
5972 /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
5973
5974 Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
5975 Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n});
5976 Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n});
5977
5978 HloInstruction* rhs0 = builder.AddInstruction(
5979 HloInstruction::CreateParameter(0, rhs0_shape, "rhs0"));
5980 HloInstruction* rhs1 = builder.AddInstruction(
5981 HloInstruction::CreateParameter(1, rhs1_shape, "rhs1"));
5982 HloInstruction* rhs2 = builder.AddInstruction(
5983 HloInstruction::CreateParameter(2, rhs2_shape, "rhs2"));
5984
5985 Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
5986 HloInstruction* rhs = builder.AddInstruction(
5987 HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0));
5988
5989 DotDimensionNumbers dot_dnums;
5990 dot_dnums.add_lhs_contracting_dimensions(1);
5991 dot_dnums.add_rhs_contracting_dimensions(0);
5992
5993 Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
5994 builder.AddInstruction(HloInstruction::CreateDot(
5995 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5996
5997 auto computation = m->AddEntryComputation(builder.Build());
5998 AlgebraicSimplifier simplifier(default_options_);
5999 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
6000 ASSERT_TRUE(run_successful);
6001
6002 EXPECT_TRUE(
6003 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
6004
6005 auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0));
6006 auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1));
6007 auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2));
6008 EXPECT_THAT(
6009 computation->root_instruction(),
6010 GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2)));
6011 }
6012
6013 // Test that we transform
6014 // dot(concat(A, B, C), const)
6015 // to
6016 // add(dot(A, const_0), dot(B, const_1), dot(C, const_2))
TEST_P(DotOfConcatSimplificationTest,ConstantRHS)6017 TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
6018 auto m = CreateNewVerifiedModule();
6019 HloComputation::Builder builder(TestName());
6020
6021 DotOfConcatTestSpec spec = GetParam();
6022
6023 ASSERT_GE(spec.k, 4);
6024
6025 int64_t k0 = spec.k / 4;
6026 int64_t k1 = spec.k / 4;
6027 int64_t k2 = spec.k / 4;
6028 int64_t k3 = spec.k - k0 - k1 - k2;
6029
6030 Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0});
6031 Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1});
6032 Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2});
6033 Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3});
6034
6035 HloInstruction* lhs0 = builder.AddInstruction(
6036 HloInstruction::CreateParameter(0, lhs0_shape, "lhs0"));
6037 HloInstruction* lhs1 = builder.AddInstruction(
6038 HloInstruction::CreateParameter(1, lhs1_shape, "lhs1"));
6039 HloInstruction* lhs2 = builder.AddInstruction(
6040 HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
6041 HloInstruction* lhs3 = builder.AddInstruction(
6042 HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
6043
6044 Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
6045 HloInstruction* lhs =
6046 builder.AddInstruction(HloInstruction::CreateConcatenate(
6047 lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
6048
6049 Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
6050 auto* rhs = builder.AddInstruction(
6051 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6052 /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
6053
6054 DotDimensionNumbers dot_dnums;
6055 dot_dnums.add_lhs_contracting_dimensions(1);
6056 dot_dnums.add_rhs_contracting_dimensions(0);
6057
6058 Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
6059 builder.AddInstruction(HloInstruction::CreateDot(
6060 dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6061
6062 auto computation = m->AddEntryComputation(builder.Build());
6063 AlgebraicSimplifier simplifier(default_options_);
6064 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
6065 ASSERT_TRUE(run_successful);
6066 EXPECT_TRUE(
6067 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
6068
6069 auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant()));
6070 auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant()));
6071 auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant()));
6072 auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant()));
6073 EXPECT_THAT(
6074 computation->root_instruction(),
6075 GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2),
6076 match_dot_3)));
6077 }
6078
6079 DotOfConcatTestSpec kDotOfConcatTestSpecs[] = {
6080 {/*m=*/3, /*k=*/9, /*n=*/3}, //
6081 {/*m=*/3, /*k=*/20, /*n=*/3}, //
6082 {/*m=*/1, /*k=*/18, /*n=*/5}, //
6083 {/*m=*/20, /*k=*/20, /*n=*/1}, //
6084 {/*m=*/1, /*k=*/16, /*n=*/1}, //
6085 };
6086
// With dot strength reduction disabled, a scalar-valued dot of
// concatenate(param0, param1) with a constant is expected to be left
// unchanged by the simplifier.
TEST_F(DotOfConcatSimplificationTest, ConcatIntoScalarDot) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      param0 = f32[4] parameter(0)
      param1 = f32[1] parameter(1)
      constant = f32[5] constant({-0.38, 0.07, -0.62, 0.66, 0.20})
      concat = f32[5] concatenate(param0, param1), dimensions={0}
      ROOT dot = f32[] dot(concat, constant), lhs_contracting_dims={0},
                                              rhs_contracting_dims={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options = default_options_;
  options.set_enable_dot_strength_reduction(false);
  AlgebraicSimplifier simplifier(options);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
6103
// concatenate(concatenate(p0, p1), p2) along dimension 0 should be flattened
// into a single concatenate(p0, p1, p2).
TEST_F(DotOfConcatSimplificationTest, UnnestConcatenate) {
  const char* kModuleStr = R"(
    HloModule m
    test {
      p0 = f32[2,10] parameter(0)
      p1 = f32[3,10] parameter(1)
      p2 = f32[4,10] parameter(2)
      c0 = f32[5,10] concatenate(p0, p1), dimensions={0}
      ROOT c1 = f32[9,10] concatenate(c0, p2), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&simplifier, module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Concatenate(
                        m::Parameter(0), m::Parameter(1), m::Parameter(2))));
}
6122
6123 // Test that DynamicUpdateSlice update param with any dimension equal to zero
6124 // gets removed.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceZeroUpdate)6125 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
6126 auto m = CreateNewVerifiedModule();
6127 HloComputation::Builder builder(TestName());
6128 const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10});
6129 HloInstruction* const operand = builder.AddInstruction(
6130 HloInstruction::CreateParameter(0, dslice_shape, "operand"));
6131 const Shape update_shape = ShapeUtil::MakeShape(F32, {0});
6132 HloInstruction* const update = builder.AddInstruction(
6133 HloInstruction::CreateParameter(1, update_shape, "update"));
6134 HloInstruction* const start_indices = builder.AddInstruction(
6135 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>({})));
6136 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
6137 dslice_shape, operand, update,
6138 std::initializer_list<HloInstruction*>({start_indices})));
6139 const HloComputation* const computation =
6140 m->AddEntryComputation(builder.Build());
6141
6142 AlgebraicSimplifier simplifier(default_options_);
6143 ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
6144 EXPECT_THAT(computation->root_instruction(), operand);
6145 }
6146
6147 // Test that dynamic-update-slice with a scalar broadcast becomes a pad.
TEST_F(AlgebraicSimplifierTest,DynamicUpdateSliceOfBroadcastToPad)6148 TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPad) {
6149 const char* hlo_string = R"(
6150 HloModule AddBroadcastZeroWithDynamicSlice
6151
6152 ENTRY AddBroadcastZeroWithDynamicSlice {
6153 param0 = f32[1800,12,512]{2,1,0} parameter(0)
6154 constant = f32[] constant(0)
6155 broadcast = f32[1800,12,512]{2,1,0} broadcast(constant), dimensions={}
6156 param1 = f32[1,12,512]{2,1,0} parameter(1)
6157 constant.1 = s32[] constant(0)
6158 dynamic-update-slice = f32[1800,12,512]{2,1,0} dynamic-update-slice(broadcast, param1, constant.1, constant.1, constant.1)
6159 ROOT add = f32[1800,12,512]{2,1,0} add(param0, dynamic-update-slice)
6160 }
6161 )";
6162 TF_ASSERT_OK_AND_ASSIGN(auto module,
6163 ParseAndReturnVerifiedModule(hlo_string));
6164 VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
6165 AlgebraicSimplifier simplifier(default_options_);
6166 ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
6167 VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
6168 auto root = module->entry_computation()->root_instruction();
6169 EXPECT_THAT(root->opcode(), HloOpcode::kAdd);
6170 EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad);
6171 }
6172
// add(param0, reshape(dynamic-update-slice(broadcast(0), param1, ...))) should,
// after two simplifier passes, become a dynamic-update-slice of param0 whose
// update is add(dynamic-slice(param0, ...), reshape(param1)).
TEST_F(AlgebraicSimplifierTest, AddDynamicUpdateSliceToAddSlice) {
  const char* hlo_string = R"(
    HloModule AddDynamicUpdateSliceToAddSlice

    ENTRY AddDynamicUpdateSliceToAddSlice {
      param0 = f32[1,4,12,512,1,1] parameter(0)
      constant = f32[] constant(0)
      broadcast = f32[4,12,512] broadcast(constant), dimensions={}
      param1 = f32[1,12,512] parameter(1)
      param2 = s32[] parameter(2)
      constant.1 = s32[] constant(0)
      dynamic-update-slice = f32[4,12,512] dynamic-update-slice(
        broadcast, param1, param2, constant.1, constant.1)
      reshape = f32[1,4,12,512,1,1] reshape(dynamic-update-slice)
      ROOT add = f32[1,4,12,512,1,1] add(param0, reshape)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  VLOG(2) << "Before rewrite reshape\n" << module->ToString();
  AlgebraicSimplifier simplifier(default_options_);
  for (int pass = 0; pass < 2; ++pass) {
    ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  }
  VLOG(2) << "After rewrite to add slice\n" << module->ToString();

  auto sliced_param0 =
      m::DynamicSlice(m::Parameter(0), m::Constant(), m::Parameter(2),
                      m::Constant(), m::Constant(), m::Constant(),
                      m::Constant());
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::DynamicUpdateSlice(
                        m::Parameter(0),
                        m::Add(sliced_param0, m::Reshape(m::Parameter(1))),
                        m::Constant(), m::Parameter(2), m::Constant(),
                        m::Constant(), m::Constant(), m::Constant())));
}
6209
// With scalar-multiply reduction enabled, the root is expected to become a
// multiply by a broadcast of the product of the two scalar constants
// (0.5 * 1.109) that originally scaled the convolution inputs.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) {
  const char* hlo_string = R"(
    HloModule ConstScalarMultiply
    ENTRY ConstScalarMultiply {
      param0 = f32[16,512,4096]{2,1,0} parameter(0)
      constant.0 = f32[] constant(0.5)
      broadcast.0 = f32[16,512,4096] broadcast(constant.0), dimensions={}
      multiply.0 = f32[16,512,4096]{2,1,0} multiply(param0, broadcast.0)
      param1 = f32[16,512,4096]{2,1,0} parameter(1)
      multiply.1 = f32[16,512,4096]{2,1,0} multiply(multiply.0, param1)
      param2 = f32[16,512,1024]{2,1,0} parameter(2)
      constant.1 = f32[] constant(1.109)
      broadcast.1 = f32[16,512,1024] broadcast(constant.1), dimensions={}
      multiply.2 = f32[16,512,1024]{2,1,0} multiply(param2, broadcast.1)
      ROOT convolution = f32[4096,1024,1]{1,0,2} convolution(multiply.1, multiply.2), window={size=16}, dim_labels=0fb_0io->bf0
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  AlgebraicSimplifier simplifier(options);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
  auto folded_scale = m::Broadcast(m::ConstantScalar(0.5f * 1.109f));
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(m::Op(), folded_scale)));
}
6239
// Here the convolution feeds two multiplies (one by a scalar broadcast, one by
// a parameter). The simplifier is expected to report no change even with
// scalar-multiply reduction enabled.
TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReductionMultiUser) {
  const char* hlo_string = R"(
    HloModule ConstScalarMultiply
    ENTRY ConstScalarMultiply {
      param0 = f32[16,512,1024] parameter(0)
      param1 = f32[4096,1024,1] parameter(1)
      convolution = f32[16,512,4096] convolution(param0, param1), window={size=1}, dim_labels=0bf_oi0->0bf
      constant.1 = f32[] constant(0.5)
      broadcast.1 = f32[16,512,4096] broadcast(constant.1), dimensions={}
      multiply.1 = f32[16,512,4096] multiply(convolution, broadcast.1)
      param2 = f32[16,512,4096] parameter(2)
      multiply.2 = f32[16,512,4096] multiply(convolution, param2)
      ROOT add.1 = f32[16,512,4096] add(multiply.1, multiply.2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifierOptions options;
  options.set_enable_scalar_multiply_reduction(true);
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
6262
6263 INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation,
6264 DotOfConcatSimplificationTest,
6265 ::testing::ValuesIn(kDotOfConcatTestSpecs));
6266
// One dot-of-dynamic-slice configuration.
struct DotOfGatherTestSpec {
  // Problem size of the underlying dot: [m, k] x [k, n].
  int64_t m;
  int64_t k;
  int64_t n;
  // Start index of the dynamic slice along the non-contracting dimension.
  int s;
  // Contracting dimension of the left operand.
  int64_t lcd;
  // Contracting dimension of the right operand.
  int64_t rcd;
  // True for a negative test case.
  bool neg;
};
6276
6277 class DotOfGatherSimplificationTest
6278 : public AlgebraicSimplifierTest,
6279 public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
6280
6281 // input: dot(DS(ctA), ctB))
6282 // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
6283 // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
6284 // output: DS(dot(ctA, ctB))
6285 // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}.
TEST_P(DotOfGatherSimplificationTest,ConstantRHS)6286 TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
6287 auto m = CreateNewVerifiedModule();
6288 HloComputation::Builder builder(TestName());
6289
6290 DotOfGatherTestSpec spec = GetParam();
6291
6292 ASSERT_LE(spec.s, spec.m);
6293
6294 // For negative tests, increase k of the dynamic slice argument to prevent the
6295 // optimization (constants ctA, ctB must have equal contracting dimensions).
6296 int64_t k_increase = spec.neg ? 5 : 0;
6297 int64_t lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m;
6298 int64_t lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
6299 Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
6300 auto* lhs = builder.AddInstruction(
6301 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6302 /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
6303 /*cols=*/lhs_cols)));
6304
6305 int32_t start_row = (spec.lcd == 0) ? 0 : spec.s;
6306 int32_t start_col = (spec.lcd == 0) ? spec.s : 0;
6307 std::vector<HloInstruction*> start_indices = {
6308 builder.AddInstruction(HloInstruction::CreateConstant(
6309 LiteralUtil::CreateR0<int32_t>(start_row))),
6310 builder.AddInstruction(HloInstruction::CreateConstant(
6311 LiteralUtil::CreateR0<int32_t>(start_col)))};
6312 int64_t slice_row_size = (spec.lcd == 0) ? spec.k : 1;
6313 int64_t slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
6314 std::vector<int64_t> slice_sizes = {slice_row_size, slice_col_size};
6315 Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
6316 auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
6317 ds_shape, lhs, start_indices, slice_sizes));
6318
6319 int64_t rhs_rows = (spec.rcd == 0) ? spec.k : spec.n;
6320 int64_t rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
6321 Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
6322 auto* rhs = builder.AddInstruction(
6323 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6324 /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
6325 /*cols=*/rhs_cols)));
6326
6327 DotDimensionNumbers dot_dnums;
6328 dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
6329 dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
6330
6331 int64_t dot_row_size = 1;
6332 int64_t dot_col_size = spec.n;
6333 Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
6334 builder.AddInstruction(HloInstruction::CreateDot(
6335 dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6336
6337 auto computation = m->AddEntryComputation(builder.Build());
6338 AlgebraicSimplifier simplifier(default_options_);
6339 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
6340 ASSERT_TRUE(run_successful);
6341 EXPECT_TRUE(
6342 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
6343
6344 if (spec.neg) {
6345 EXPECT_NE(computation->root_instruction()->opcode(),
6346 HloOpcode::kDynamicSlice);
6347 } else {
6348 EXPECT_THAT(computation->root_instruction(),
6349 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
6350 m::Constant(), m::Constant())));
6351 }
6352 }
6353
6354 // input: dot(ctA, DS(ctB))
6355 // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}).
6356 // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
6357 // output: DS(dot(ctA, ctB))
6358 // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}.
TEST_P(DotOfGatherSimplificationTest,ConstantLHS)6359 TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
6360 auto m = CreateNewVerifiedModule();
6361 HloComputation::Builder builder(TestName());
6362
6363 DotOfGatherTestSpec spec = GetParam();
6364
6365 ASSERT_LE(spec.s, spec.n);
6366
6367 int64_t lhs_rows = (spec.lcd == 0) ? spec.k : spec.m;
6368 int64_t lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
6369 Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
6370 auto* lhs = builder.AddInstruction(
6371 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6372 /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
6373 /*cols=*/lhs_cols)));
6374
6375 // For negative tests increase k of the dynamic slice argument to prevent the
6376 // optimization
6377 int64_t k_increase = spec.neg ? 5 : 0;
6378 int64_t rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n;
6379 int64_t rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
6380 Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
6381 auto* rhs = builder.AddInstruction(
6382 HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
6383 /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
6384 /*cols=*/rhs_cols)));
6385
6386 int32_t start_row = (spec.rcd == 0) ? 0 : spec.s;
6387 int32_t start_col = (spec.rcd == 0) ? spec.s : 0;
6388 std::vector<HloInstruction*> start_indices = {
6389 builder.AddInstruction(HloInstruction::CreateConstant(
6390 LiteralUtil::CreateR0<int32_t>(start_row))),
6391 builder.AddInstruction(HloInstruction::CreateConstant(
6392 LiteralUtil::CreateR0<int32_t>(start_col)))};
6393 int64_t slice_row_size = (spec.rcd == 0) ? spec.k : 1;
6394 int64_t slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
6395 std::vector<int64_t> slice_sizes = {slice_row_size, slice_col_size};
6396 Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes);
6397 auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
6398 ds_shape, rhs, start_indices, slice_sizes));
6399
6400 DotDimensionNumbers dot_dnums;
6401 dot_dnums.add_lhs_contracting_dimensions(spec.lcd);
6402 dot_dnums.add_rhs_contracting_dimensions(spec.rcd);
6403
6404 int64_t dot_row_size = spec.m;
6405 int64_t dot_col_size = 1;
6406 Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
6407 builder.AddInstruction(HloInstruction::CreateDot(
6408 dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
6409
6410 auto computation = m->AddEntryComputation(builder.Build());
6411 AlgebraicSimplifier simplifier(default_options_);
6412 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get()));
6413 ASSERT_TRUE(run_successful);
6414 EXPECT_TRUE(
6415 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
6416
6417 if (spec.neg) {
6418 EXPECT_NE(computation->root_instruction()->opcode(),
6419 HloOpcode::kDynamicSlice);
6420 } else {
6421 EXPECT_THAT(computation->root_instruction(),
6422 GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()),
6423 m::Constant(), m::Constant())));
6424 }
6425 }
6426
DotOfGatherPositiveNegativeTests()6427 std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() {
6428 std::vector<DotOfGatherTestSpec> positives = {
6429 // "Classical dot", i.e. matrix multiply:
6430 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0,
6431 /*neg=*/false},
6432 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0,
6433 /*neg=*/false},
6434 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0,
6435 /*neg=*/false},
6436 // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to
6437 // dot(ct, ct) before DotOfGather optimization kicks in.
6438 // Contract on rows:
6439 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0,
6440 /*neg=*/false},
6441 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0,
6442 /*neg=*/false},
6443 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0,
6444 /*neg=*/false},
6445 // Reverse matrix multiply:
6446 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1,
6447 /*neg=*/false},
6448 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1,
6449 /*neg=*/false},
6450 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1,
6451 /*neg=*/false},
6452 // Contract on columns:
6453 {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1,
6454 /*neg=*/false},
6455 {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1,
6456 /*neg=*/false},
6457 {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1,
6458 /*neg=*/false},
6459 };
6460 std::vector<DotOfGatherTestSpec> all;
6461 const std::vector<DotOfGatherTestSpec>::size_type positives_size =
6462 positives.size();
6463 all.reserve(positives_size * 2);
6464 for (std::vector<DotOfGatherTestSpec>::size_type i = 0; i < positives_size;
6465 i++) {
6466 DotOfGatherTestSpec positive_test = positives[i];
6467 all.push_back(positive_test);
6468 DotOfGatherTestSpec negative_test = positive_test;
6469 negative_test.neg = true;
6470 all.push_back(negative_test);
6471 }
6472 return all;
6473 }
6474
6475 INSTANTIATE_TEST_SUITE_P(
6476 DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest,
6477 ::testing::ValuesIn(DotOfGatherPositiveNegativeTests()));
6478
// A gather from a single-element f32[1,1] operand with 1x1 slices should be
// rewritten to broadcast(reshape(operand)).
TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) {
  const char* hlo_string = R"(
    HloModule repeat

    ENTRY main {
      o = f32[1,1] parameter(0)
      i = s32[100,2] parameter(1)
      ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1},
                                      start_index_map={0,1},
                                      index_vector_dim=1,
                                      offset_dims={},
                                      slice_sizes={1,1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifier simplifier{AlgebraicSimplifierOptions()};
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}
6502
// A variadic reduce over scalar inputs with an empty dimensions list should be
// replaced by a tuple of reshapes of the two input tuple elements.
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) {
  const char* hlo_string = R"(
    HloModule module

    reducer {
      parameter.1 = f32[] parameter(0)
      parameter.3 = f32[] parameter(2)
      add.2 = f32[] add(parameter.1, parameter.3)
      parameter.0 = f32[] parameter(1)
      parameter.2 = f32[] parameter(3)
      add.3 = f32[] add(parameter.0, parameter.2)
      ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3)
    }

    ENTRY entry {
      parameter.6 = (f32[], f32[]) parameter(0)
      get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0
      get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1
      constant = f32[] constant(0)
      ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifier simplifier{AlgebraicSimplifierOptions()};
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto reshaped_element = [](int64_t index) {
    return m::Reshape(m::GetTupleElement(m::Parameter(), index));
  };
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(reshaped_element(0),
                                        reshaped_element(1))));
}
6536
// A variadic reduce over a zero-sized dimension has nothing to combine, so
// each tuple output is expected to become a broadcast of its init value.
TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) {
  // The reducer's first combiner is an add. It was previously named `mul.2`,
  // which misdescribed the op; only the instruction name changed.
  const char* hlo_string = R"(
    HloModule module

    reducer {
      parameter.1 = f32[] parameter(0)
      parameter.3 = f32[] parameter(2)
      add.2 = f32[] add(parameter.1, parameter.3)
      parameter.0 = f32[] parameter(1)
      parameter.2 = f32[] parameter(3)
      add.3 = f32[] add(parameter.0, parameter.2)
      ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3)
    }

    ENTRY entry {
      parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0)
      get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0
      get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1
      constant.0 = f32[] constant(0)
      constant.1 = f32[] constant(1)
      ROOT reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AlgebraicSimplifierOptions options;
  AlgebraicSimplifier simplifier(options);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::ConstantScalar(0)),
                                        m::Broadcast(m::ConstantScalar(1)))));
}
6570
// A reshape producing a zero-element f32[0], built with its layout cleared,
// should be replaced by a constant.
TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1}), "param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {0, 1}), param, {1}));

  // The reshape's result shape deliberately carries no layout.
  Shape layoutless_empty = ShapeUtil::MakeShape(F32, {0});
  layoutless_empty.clear_layout();
  builder.AddInstruction(
      HloInstruction::CreateReshape(layoutless_empty, broadcast));

  std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier{AlgebraicSimplifierOptions()};
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Constant()));
}
6595
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
  // param / 20.0f, where every shape involved is a scalar without a layout,
  // is expected to be rewritten into a multiply.
  Shape scalar = ShapeUtil::MakeShape(F32, {});
  scalar.clear_layout();

  HloComputation::Builder b(TestName());
  HloInstruction* numerator =
      b.AddInstruction(HloInstruction::CreateParameter(0, scalar, "param"));
  HloInstruction* divisor = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
  b.AddInstruction(HloInstruction::CreateBinary(scalar, HloOpcode::kDivide,
                                                numerator, divisor));

  std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  AlgebraicSimplifierOptions opts;
  AlgebraicSimplifier simplifier(opts);
  EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply()));
}
6617
// Test that Y/sqrt(X) is simplified to Y*rsqrt(X).
TEST_F(AlgebraicSimplifierTest, RecipSqrt) {
  const char* hlo = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      sqrt = f32[] sqrt(p0)
      ROOT div = f32[] divide(p1, sqrt)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected: p1 * rsqrt(p0), with the operands in either order.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(
                        m::Parameter(1), m::Rsqrt(m::Parameter(0)))));
}
6635
// Test that Y/rsqrt(X) is simplified to Y*sqrt(X).
TEST_F(AlgebraicSimplifierTest, RecipRsqrt) {
  const char* hlo = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      rsqrt = f32[] rsqrt(p0)
      ROOT div = f32[] divide(p1, rsqrt)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Expected: p1 * sqrt(p0), with the operands in either order.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::MultiplyAnyOrder(
                        m::Parameter(1), m::Sqrt(m::Parameter(0)))));
}
6653
TEST_F(AlgebraicSimplifierTest, CopyReshape) {
  // Layout-sensitive mode, with a callback that never treats a reshape as a
  // bitcast. A layout-changing copy of a reshape is expected to fold into a
  // single reshape of the parameter that directly produces the copy's shape.
  const char* hlo = R"(
    HloModule m
    test {
      p0 = f32[168,168,48,48]{3,2,1,0} parameter(0)
      r0 = f32[1,168,168,2304]{3,2,1,0} reshape(p0)
      ROOT c0 = f32[1,168,168,2304]{3,0,2,1} copy(r0)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  const Shape copy_shape =
      module->entry_computation()->root_instruction()->shape();

  AlgebraicSimplifierOptions opts(
      [](const Shape&, const Shape&) { return false; });
  opts.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(opts);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(m::Parameter(0))
                             .WithShapeEqualTo(&copy_shape)));
}
6672
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RL) {
  const char* hlo = R"(
    HloModule m
    test {
      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Moving the transpose/reshape onto the constant operand is a
  // layout-insensitive rewrite, so the expected shapes are compared without
  // regard to layout.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 6});
  const Shape rhs_reshaped = ShapeUtil::MakeShape(F32, {3, 2, 2});
  const Shape rhs_transposed = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_reshaped))
                         .WithShapeCompatibleTo(&rhs_transposed)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(1, 0, 2));
}
6701
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_RR) {
  const char* hlo = R"(
    HloModule m
    test {
      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
                                {1, 1, 1, 1, 1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The rhs contracts on its minor dimension. The reorder should still land
  // on the constant side. Shapes are compared without regard to layout.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 6});
  const Shape rhs_reshaped = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const Shape rhs_transposed = ShapeUtil::MakeShape(F32, {2, 2, 3});
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_reshaped))
                         .WithShapeCompatibleTo(&rhs_transposed)))));
}
6726
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR) {
  const char* hlo = R"(
    HloModule m
    test {
      rhs = f32[2, 6] constant({{1, 2, 3, 4, 5, 6},
                                {1, 1, 1, 1, 1, 1}})
      t0 = f32[2, 3, 2] parameter(0)
      t1 = f32[3, 2, 2] transpose(t0), dimensions={1, 0, 2}
      lhs = f32[6, 2] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The lhs contracts on its major dimension and the rhs on its minor one.
  // Shapes are compared without regard to layout.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {6, 2});
  const Shape rhs_reshaped = ShapeUtil::MakeShape(F32, {2, 3, 2});
  const Shape rhs_transposed = ShapeUtil::MakeShape(F32, {2, 2, 3});
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_reshaped))
                         .WithShapeCompatibleTo(&rhs_transposed)))));
}
6751
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_LR2) {
  const char* hlo = R"(
    HloModule m
    test {
      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{7, 7},{8, 8}})
      t0 = f32[2, 2, 2, 2] parameter(0)
      t1 = f32[2, 2, 2, 2] transpose(t0), dimensions={0, 2, 3, 1}
      lhs = f32[2, 8] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1},
                                            rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Three dimensions are folded into the contracting dimension. The inverse
  // permutation should be applied to the reshaped constant.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 8});
  const Shape rhs_reshaped = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(
              &transpose,
              m::Reshape(m::Constant()).WithShapeCompatibleTo(&rhs_reshaped))))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(2, 0, 1, 3));
}
6778
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_MM) {
  const char* hlo = R"(
    HloModule m
    test {
      rhs = f32[2, 6, 2] constant({{{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}},
                                   {{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}}})
      t0 = f32[2, 2, 3, 2] parameter(0)
      t1 = f32[2, 3, 2, 2] transpose(t0), dimensions={0, 2, 1, 3}
      lhs = f32[2, 6, 2] reshape(t1)
      ROOT dot.5 = f32[2, 2, 2] dot(lhs, rhs), lhs_batch_dims={0}, lhs_contracting_dims={1},
                                               rhs_batch_dims={0}, rhs_contracting_dims={1}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // Batched dot in which both operands contract on a middle dimension.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 6, 2});
  const Shape rhs_reshaped = ShapeUtil::MakeShape(F32, {2, 3, 2, 2});
  const Shape rhs_transposed = ShapeUtil::MakeShape(F32, {2, 2, 3, 2});
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_reshaped))
                         .WithShapeCompatibleTo(&rhs_transposed)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
6807
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegTranspose) {
  const char* hlo = R"(
    HloModule m
    test {
      rhs = f32[12, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6},{1, 1},{2, 2},{3, 3},{4, 4},{5, 5},{6, 6}})
      t0 = f32[3, 4, 2] parameter(0)
      t1 = f32[2, 3, 4] transpose(t0), dimensions={2, 0, 1}
      lhs = f32[2, 12] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));

  // The transpose also moves the non-contracting dimension, so the
  // transpose/reshape must stay on the parameter side: no change expected.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
6824
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegReshape) {
  const char* hlo = R"(
    HloModule m
    test {
      rhs = f32[8, 2] constant({{1, 1},{2, 2},{3, 3},{4, 4},{1, 1},{2, 2},{3, 3},{4, 4}})
      t0 = f32[2, 4, 3] parameter(0)
      t1 = f32[2, 3, 4] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[3, 8] reshape(t1)
      ROOT dot.5 = f32[3, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));

  // The reshape mixes non-contracting dimensions into the contracting one,
  // so nothing may be moved to the constant side: no change expected.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
6841
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegConstant) {
  const char* hlo = R"(
    HloModule m
    test {
      t0 = f32[2, 3, 4] parameter(0)
      t1 = f32[2, 4, 3] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 12] reshape(t1)
      rhs = f32[12, 2] parameter(1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));

  // Neither dot operand is a constant, so there is no constant side to push
  // the reorder onto: no change expected.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
6857
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_NegLayout) {
  const char* hlo = R"(
    HloModule m
    test {
      rhs = f32[6, 2] constant({{1, 2},{3, 4},{5, 6},{1, 1},{1, 1},{1, 1}})
      t0 = f32[2, 2, 3] parameter(0)
      t1 = f32[2, 3, 2] transpose(t0), dimensions={0, 2, 1}
      lhs = f32[2, 6] reshape(t1)
      ROOT dot.5 = f32[2, 2] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));

  // Reshape-to-bitcast conversion is switched off so that the pass has
  // nothing else to rewrite in this module. An unchanged result then shows
  // that the layout-insensitive dot reorder is suppressed in layout-sensitive
  // mode.
  AlgebraicSimplifierOptions opts(
      [](const Shape&, const Shape&) { return false; });
  opts.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(opts);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
6881
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDimsNoChange) {
  // The transpose keeps the `2` and `3` dims in the same relative order, so
  // there is nothing to reorder and the pass must report no change. The case
  // is still worth covering because of the size-1 dimension it carries.
  const char* hlo = R"(
    HloModule m
    test {
      param = f32[1,2,5,3] parameter(0)
      transpose = f32[1,5,2,3] transpose(param), dimensions={0,2,1,3}
      reshape = f32[5,6] reshape(transpose)
      constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
      ROOT dot = f32[5,4] dot(reshape, constant),
        lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
6900
TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
  const char* hlo = R"(
    HloModule m
    test {
      param = f32[1,2,3,5] parameter(0)
      transpose = f32[1,3,2,5] transpose(param), dimensions={0,2,1,3}
      reshape = f32[6,5] reshape(transpose)
      constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
      ROOT dot = f32[5,4] dot(reshape, constant),
        lhs_contracting_dims={0}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  // The contracting dims are swapped next to a leading size-1 dim. The swap
  // should be mirrored onto the constant while keeping that size-1 dim.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {6, 5});
  const Shape rhs_reshaped = ShapeUtil::MakeShape(F32, {1, 3, 2, 4});
  const Shape rhs_transposed = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
  const HloInstruction* transpose = nullptr;
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Dot(
          m::Reshape(m::Parameter(0)).WithShapeCompatibleTo(&lhs_shape),
          m::Reshape(m::Transpose(&transpose,
                                  m::Reshape(m::Constant())
                                      .WithShapeCompatibleTo(&rhs_reshaped))
                         .WithShapeCompatibleTo(&rhs_transposed)))));
  EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
6928
TEST_F(AlgebraicSimplifierTest,
       DotContractingReorder_NoChangeInContractingDimsOrder) {
  // The transpose swaps only a size-2 dim with a size-1 dim. The contracting
  // dims (2 and 3) keep their relative order, so no reorder is possible.
  const char* hlo = R"(
    HloModule m
    test {
      param = f32[2,5,1,3] parameter(0)
      transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
      reshape = f32[5,6] reshape(transpose)
      constant = f32[6,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4},{1,2,3,4}})
      ROOT dot = f32[5,4] dot(reshape, constant),
        lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
6947
TEST_F(AlgebraicSimplifierTest, CompareIota) {
  // An s32 iota starts at 0, so `iota < 0` can never hold. The result is
  // expected to become a broadcast of `false`.
  const char* hlo = R"(
    HloModule m
    test {
      zero = s32[] constant(0)
      iota = s32[128] iota(), iota_dimension=0
      broad = s32[128] broadcast(zero), dimensions={}
      ROOT compare = pred[128] compare(iota, broad), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}
6962
TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
  // For an unsigned value, `x < 0` is always false.
  const char* hlo = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ConstantScalar(false)));
}
6976
TEST_F(AlgebraicSimplifierTest, CompareLeZero) {
  // For an unsigned value, `x <= 0` depends on x, so the pass must leave the
  // comparison untouched.
  const char* hlo = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=LE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Le(m::Parameter(0),
                                     m::ConstantEffectiveScalar(0))));
}
6991
TEST_F(AlgebraicSimplifierTest, CompareGeZero) {
  // For an unsigned value, `x >= 0` is always true.
  const char* hlo = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ConstantScalar(true)));
}
7005
TEST_F(AlgebraicSimplifierTest, CompareGtZero) {
  // For an unsigned value, `x > 0` depends on x (it is false only for x == 0),
  // so the pass must leave the comparison untouched.
  const char* kModuleStr = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(param, zero), direction=GT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  // Actually run the pass. Without this the test only checked the parser's
  // output and could never detect a bad rewrite of `x > 0`. This matches the
  // sibling CompareLeZero / CompareZeroGe / CompareZeroLt tests.
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Gt(m::Parameter(0), m::ConstantEffectiveScalar(0))));
}
7019
TEST_F(AlgebraicSimplifierTest, CompareZeroGt) {
  // For an unsigned value, `0 > x` is always false.
  const char* hlo = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=GT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ConstantScalar(false)));
}
7033
TEST_F(AlgebraicSimplifierTest, CompareZeroGe) {
  // For an unsigned value, `0 >= x` depends on x, so the pass must leave the
  // comparison untouched.
  const char* hlo = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Ge(m::ConstantEffectiveScalar(0),
                                     m::Parameter(0))));
}
7048
TEST_F(AlgebraicSimplifierTest, CompareZeroLe) {
  // For an unsigned value, `0 <= x` is always true.
  const char* hlo = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=LE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::ConstantScalar(true)));
}
7062
TEST_F(AlgebraicSimplifierTest, CompareZeroLt) {
  // For an unsigned value, `0 < x` depends on x, so the pass must leave the
  // comparison untouched.
  const char* hlo = R"(
    HloModule m
    test {
      zero = u32[] constant(0)
      param = u32[] parameter(0)
      ROOT compare = pred[] compare(zero, param), direction=LT
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Lt(m::ConstantEffectiveScalar(0),
                                     m::Parameter(0))));
}
7077
TEST_F(AlgebraicSimplifierTest, CompareSame) {
  // Comparing an s32 array with itself using GE is true everywhere.
  const char* hlo = R"(
    HloModule m
    test {
      param = s32[123] parameter(0)
      ROOT compare = pred[123] compare(param, param), direction=GE
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Broadcast(m::ConstantScalar(true))));
}
7090
TEST_F(AlgebraicSimplifierTest, CompareSimplified) {
  // (param < 10) && (param < 100) should collapse to the tighter bound,
  // param < 10.
  const char* hlo = R"(
    HloModule m
    test {
      param = s32[] parameter(0)
      c1 = s32[] constant(10)
      c2 = s32[] constant(100)
      cmp1 = pred[] compare(param, c1), direction=LT
      cmp2 = pred[] compare(param, c2), direction=LT
      ROOT out = pred[] and(cmp1, cmp2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto tighter_bound =
      m::Compare(m::Op(), m::Op().IsConstantScalar(10))
          .WithComparisonDirection(ComparisonDirection::kLt);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(tighter_bound));
}
7109
TEST_F(AlgebraicSimplifierTest, CompareSimplifiedReversed) {
  // Same as CompareSimplified, but the second bound is written as
  // `100 > param`. The result should still be param < 10.
  const char* hlo = R"(
    HloModule m
    test {
      param = s32[] parameter(0)
      c1 = s32[] constant(10)
      c2 = s32[] constant(100)
      cmp1 = pred[] compare(param, c1), direction=LT
      cmp2 = pred[] compare(c2, param), direction=GT
      ROOT out = pred[] and(cmp1, cmp2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto tighter_bound =
      m::Compare(m::Op(), m::Op().IsConstantScalar(10))
          .WithComparisonDirection(ComparisonDirection::kLt);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(tighter_bound));
}
7128
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
  // An outer product written as a dot is turned into a multiply by default.
  // Some backends perform better when it stays a Dot rather than a broadcast
  // Multiply, so the rewrite must be switchable through the options.
  const char* hlo = R"(
    HloModule m
    test {
      param1 = f32[64] parameter(0)
      param2 = f32[64] parameter(1)
      ROOT compare = f32[64, 64] dot(param1, param2),
        lhs_contracting_dims={}, rhs_contracting_dims={}
    })";

  // Default options: the dot is rewritten.
  TF_ASSERT_OK_AND_ASSIGN(auto rewritten, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier default_pass(default_options_);
  ASSERT_TRUE(default_pass.Run(rewritten.get()).ValueOrDie());
  EXPECT_THAT(rewritten->entry_computation()->root_instruction(),
              GmockMatch(m::Multiply(m::Op(), m::Op())));

  // Rewrite disabled: the pass reports no change.
  AlgebraicSimplifierOptions no_dot_rewrite = default_options_;
  no_dot_rewrite.set_enable_dot_to_multiply_rewrite(false);
  TF_ASSERT_OK_AND_ASSIGN(auto untouched, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier restricted_pass(no_dot_rewrite);
  ASSERT_FALSE(restricted_pass.Run(untouched.get()).ValueOrDie());
}
7153
TEST_F(AlgebraicSimplifierTest, RemainderOfIota) {
  // The iota runs along a dimension of size 5, so its values are 0..4 and
  // taking them modulo 5 is the identity.
  const char* hlo = R"(
    HloModule m
    test {
      iota = s32[5,1000] iota(), iota_dimension=0
      five = s32[] constant(5)
      five_bcast = s32[5,1000] broadcast(s32[] five), dimensions={}
      ROOT remainder = s32[5,1000] remainder(iota, s32[5,1000] five_bcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Iota()));
}
7168
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIota) {
  // (iota + 5) % 5 should drop the addend, leaving remainder(iota, 5).
  const char* hlo = R"(
    HloModule m
    test {
      iota = s32[5,1000] iota(), iota_dimension=0
      five = s32[] constant(5)
      five_bcast = s32[5,1000] broadcast(five), dimensions={}
      sum = s32[5,1000] add(iota, five_bcast)
      ROOT remainder = s32[5,1000] remainder(sum, s32[5,1000] five_bcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Remainder(m::Iota(), m::Broadcast())));
}
7184
// No simplification because the largest iota value, 125, plus 5 overflows S8.
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
  const char* hlo = R"(
    HloModule m
    test {
      iota = s8[126] iota(), iota_dimension=0
      five = s8[] constant(5)
      five_bcast = s8[126] broadcast(five), dimensions={}
      sum = s8[126] add(iota, five_bcast)
      ROOT remainder = s8[126] remainder(sum, s8[126] five_bcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  // iota + 5 can wrap in s8, so dropping the addend would be unsound.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
7199
TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
  // (p % q) % q should collapse to a single remainder of two parameters.
  const char* hlo = R"(
    HloModule m
    test {
      p = s32[1000] parameter(0)
      q = s32[1000] parameter(1)
      r = s32[1000] remainder(p, q)
      ROOT rr = s32[1000] remainder(r, q)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
}
7214
TEST_F(AlgebraicSimplifierTest, SlicePadLayout) {
  // A pad with negative low padding, applied to a full-range slice, is
  // expected to become a single slice. In layout-sensitive mode that slice
  // must produce exactly the pad's shape, layout included.
  const char* hlo = R"(
    HloModule m
    test {
      %param.0 = f32[128,9,9,1024]{0,3,2,1} parameter(0)
      %param.1 = f32[] parameter(1)
      %slice = f32[128,9,9,1024]{0,3,2,1} slice(%param.0),
        slice={[0:128], [0:9], [0:9], [0:1024]}
      ROOT %pad = f32[128,8,9,1024]{0,3,2,1} pad(%slice, %param.1),
        padding=0_0x-1_0x0_0x0_0
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  const Shape pad_shape =
      module->entry_computation()->root_instruction()->shape();

  AlgebraicSimplifierOptions opts;
  opts.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(opts);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Slice().WithShapeEqualTo(&pad_shape)));
}
7234
TEST_F(AlgebraicSimplifierTest, MinOfMaxToClamp) {
  // min(max(3, p0), 4) should become clamp(3, p0, 4).
  const char* hlo = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(3.0)
      c1 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      m0 = f32[4] maximum(b0, p0)
      ROOT m1 = f32[4] minimum(m0, b1)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto lower = m::Broadcast(m::ConstantScalar(3.0));
  const auto upper = m::Broadcast(m::ConstantScalar(4.0));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Clamp(lower, m::Parameter(0), upper)));
}
7255
TEST_F(AlgebraicSimplifierTest, MaxOfMinToClamp) {
  // max(3, min(p0, 4)) should become clamp(3, p0, 4).
  const char* hlo = R"(
    HloModule m
    test {
      p0 = f32[4] parameter(0)
      c0 = f32[] constant(3.0)
      c1 = f32[] constant(4.0)
      b0 = f32[4] broadcast(c0), dimensions={}
      b1 = f32[4] broadcast(c1), dimensions={}
      m0 = f32[4] minimum(p0, b1)
      ROOT m1 = f32[4] maximum(b0, m0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const auto lower = m::Broadcast(m::ConstantScalar(3.0));
  const auto upper = m::Broadcast(m::ConstantScalar(4.0));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Clamp(lower, m::Parameter(0), upper)));
}
7276
TEST_F(AlgebraicSimplifierTest, ClampOfClamp) {
  // clamp(p0, clamp(p0, p1, p2), p2) has redundant outer bounds and should
  // reduce to the inner clamp(p0, p1, p2).
  const char* hlo = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      p2 = f32[] parameter(2)
      c0 = f32[] clamp(p0, p1, p2)
      ROOT c1 = f32[] clamp(p0, c0, p2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(2))));
}
7294
TEST_F(AlgebraicSimplifierTest, MaxOfClamp) {
  // max(p0, clamp(p0, p1, p2)): the clamp's lower bound is already p0, so the
  // maximum is redundant and only the clamp should remain.
  const char* hlo = R"(
    HloModule m
    test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      p2 = f32[] parameter(2)
      c0 = f32[] clamp(p0, p1, p2)
      ROOT m0 = f32[] maximum(p0, c0)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1),
                                        m::Parameter(2))));
}
7312
// Rows [100, 150) of concatenate(p0[100 rows], p1[50 rows]) along dimension 0
// are exactly p1, so the slice should be replaced by parameter(1) itself.
TEST_F(AlgebraicSimplifierTest, SliceOfConcat) {
  const char* kHloText = R"(
HloModule m
test {
p0 = f32[100,50] parameter(0)
p1 = f32[50,50] parameter(1)
c0 = f32[150,50] concatenate(p0, p1), dimensions={0}
ROOT s0 = f32[50,50] slice(c0), slice={[100:150], [0:50]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Parameter(1)));
}
7328
// Slicing rows [51, 149) out of a dimension-0 concatenation of four 50-row
// operands only touches operands 1 (rows [50, 100)) and 2 (rows [100, 150)).
// The simplifier should narrow the concatenation to those two operands and
// rebase the slice bounds by the 50 skipped rows: [51, 149) -> [1, 99).
TEST_F(AlgebraicSimplifierTest, SliceOfMultipleConcatOperands) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[50,50] parameter(0)
p1 = f32[50,50] parameter(1)
p2 = f32[50,50] parameter(2)
p3 = f32[50,50] parameter(3)
c0 = f32[200,50] concatenate(p0, p1, p2, p3), dimensions={0}
ROOT s0 = f32[98,50] slice(c0), slice={[51:149], [0:50]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());

  const HloInstruction* root = m->entry_computation()->root_instruction();
  // Fatal: slice_starts()/slice_limits() below are only valid on a slice, so
  // a root of any other kind must stop the test rather than crash it.
  ASSERT_THAT(root, GmockMatch(m::Slice(
                        m::Concatenate(m::Parameter(1), m::Parameter(2)))));
  EXPECT_THAT(root->slice_starts(), ElementsAre(1, 0));
  EXPECT_THAT(root->slice_limits(), ElementsAre(99, 50));
}
7351
// Checks that sqrt(p0 * p0) is rewritten as abs(p0).
TEST_F(AlgebraicSimplifierTest, SqrtOfSelfMultiply) {
  const char* kHloText = R"(
HloModule m
test {
p0 = f32[32]{0} parameter(0)
m0 = f32[32]{0} multiply(f32[32]{0} p0, f32[32]{0} p0)
ROOT s0 = f32[32]{0} sqrt(f32[32]{0} m0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Abs(m::Parameter(0))));
}
7366
// The root sum-reduces dimension 0 of a batched dot. That dimension is a batch
// dimension (lhs batch dim 0 pairs with rhs batch dim 1). The simplifier is
// expected to fold the reduction into the dot itself, so the root becomes a
// dot taking the two parameters directly.
TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) {
  const char* kHloText = R"(
HloModule m
a {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p0, p1)
}
test {
p0 = f32[32,8,5,6] parameter(0)
p1 = f32[8,32,6,7] parameter(1)
d = f32[32,8,5,7] dot(p0, p1),
lhs_batch_dims={0,1},
rhs_batch_dims={1,0},
rhs_contracting_dims={2},
lhs_contracting_dims={3}
c = f32[] constant(0)
ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Dot(m::Parameter(0), m::Parameter(1))));
}
7392
// Two back-to-back reduces over dimension 0 use adders that differ only in
// operand order (fn1 = add(p0, p1), fn2 = add(p1, p0)). They should still be
// merged into a single reduce of parameter(0) with the zero init value.
TEST_F(AlgebraicSimplifierTest, ReduceAddIsCommutative) {
  const char* kHloText = R"(
HloModule m
fn1 {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p0, p1)
}
fn2 {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p1, p0)
}
test {
p0 = f32[10,10,10] parameter(0)
zero = f32[] constant(0)
r1 = f32[10,10] reduce(p0, zero), dimensions={0}, to_apply=fn1
ROOT r2 = f32[10] reduce(r1, zero), dimensions={0}, to_apply=fn2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Reduce(m::Parameter(0), m::ConstantScalar(0))));
}
7418
// With the cuDNN batchnorm-training metadata configured, the simplifier is
// expected to rewrite rsqrt(power(x, -2)) into abs(x). Here x is element 2 of
// the batchnorm custom call's result tuple. NOTE(review): presumably the
// metadata lets the simplifier recognize that output; confirm in the pass.
TEST_F(AlgebraicSimplifierTest, RsqrtOfRPower) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
c0 = f32[] constant(0.001)
c1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, c0, c1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
c2 = f32[] constant(-2)
broadcast = f32[32]{0} broadcast(f32[] c2), dimensions={}
power = f32[32]{0} power(get-tuple-element, broadcast)
rsqrt = f32[32]{0} rsqrt(f32[32]{0} power)
ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // Expect transformation: rsqrt(power(gte, -2)) -> abs(gte), where gte is the
  // instruction named `get-tuple-element` (tuple index 2). It is NOT
  // `get-tuple-element.2`, which reads tuple index 1.
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kPower), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);

  const HloInstruction* root = m->entry_computation()->root_instruction();
  // These opcode checks guard the operand() walks that follow. They must be
  // fatal so a regression fails cleanly instead of indexing a missing operand.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand(2)->opcode(), HloOpcode::kAbs);
  EXPECT_EQ(root->operand(2)->operand(0)->opcode(),
            HloOpcode::kGetTupleElement);
}
7453
// With the cuDNN batchnorm-training metadata configured, the simplifier is
// expected to rewrite rsqrt(divide(broadcast(1), x)) into sqrt(x). Here x is
// element 2 of the batchnorm custom call's result tuple.
TEST_F(AlgebraicSimplifierTest, RsqrtDivide) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
constant.2 = f32[] constant(1)
broadcast.1 = f32[32]{0} broadcast(constant.2), dimensions={}
divide = f32[32]{0} divide(broadcast.1, get-tuple-element)
rsqrt = f32[32]{0} rsqrt(divide)
ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  // Expect transformation: rsqrt(divide(1, gte)) -> sqrt(gte), where gte is
  // the instruction named `get-tuple-element` (tuple index 2). It is NOT
  // `get-tuple-element.2`, which reads tuple index 1.
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kDivide), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);

  const HloInstruction* root = m->entry_computation()->root_instruction();
  // These opcode checks guard the operand() walks that follow. They must be
  // fatal so a regression fails cleanly instead of indexing a missing operand.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand(2)->opcode(), HloOpcode::kSqrt);
  EXPECT_EQ(root->operand(2)->operand(0)->opcode(),
            HloOpcode::kGetTupleElement);
}
7488
// With the cuDNN batchnorm-training metadata configured, the simplifier is
// expected to rewrite multiply(rsqrt(x), rsqrt(x)) into divide(broadcast(1), x).
// Here x is element 2 of the batchnorm custom call's result tuple.
TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt) {
  const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
rsqrt = f32[32]{0} rsqrt(get-tuple-element)
multiply = f32[32]{0} multiply(rsqrt, rsqrt)
ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());

  // Expect transformation:
  //   multiply(rsqrt(gte), rsqrt(gte)) -> divide(broadcast(1), gte)
  // where gte is the instruction named `get-tuple-element` (tuple index 2).
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kMultiply), nullptr);
  EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);

  const HloInstruction* root = m->entry_computation()->root_instruction();
  // These opcode checks guard the operand() walks that follow. They must be
  // fatal so a regression fails cleanly instead of indexing a missing operand.
  ASSERT_EQ(root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(root->operand(2)->opcode(), HloOpcode::kDivide);
  EXPECT_EQ(root->operand(2)->operand(0)->opcode(), HloOpcode::kBroadcast);
  EXPECT_EQ(root->operand(2)->operand(1)->opcode(),
            HloOpcode::kGetTupleElement);
}
7525
// The configured batchnorm metadata ("__cudnn$batchNormalizationForward") does
// not match the custom call's target
// ("__cudnn$batchNormalizationForwardTraining"). The multiply(rsqrt, rsqrt)
// pattern must therefore be left untouched, and the pass must report no
// change.
TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt_NegativeTestCase) {
  const char* kHloText = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
rsqrt = f32[32]{0} rsqrt(get-tuple-element)
multiply = f32[32]{0} multiply(rsqrt, rsqrt)
ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForward");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // The original ops survive and no replacement ops were introduced.
  HloModule* mod = hlo_module.get();
  EXPECT_NE(FindInstruction(mod, HloOpcode::kMultiply), nullptr);
  EXPECT_NE(FindInstruction(mod, HloOpcode::kRsqrt), nullptr);
  EXPECT_EQ(FindInstruction(mod, HloOpcode::kDivide), nullptr);
  EXPECT_EQ(FindInstruction(mod, HloOpcode::kBroadcast), nullptr);

  const HloInstruction* root = mod->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kMultiply);
}
7555
// With the batchnorm-training metadata matching the custom call's target, the
// abs() applied to element 2 of the custom call's result should be removed.
// Tuple slot 2 of the root then reads the get-tuple-element directly.
TEST_F(AlgebraicSimplifierTest, AbsEliminationBatchnormTraining) {
  const char* kHloText = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
abs = f32[32]{0} abs(get-tuple-element)
ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardTraining");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  // No abs node may remain anywhere in the module.
  EXPECT_EQ(FindInstruction(hlo_module.get(), HloOpcode::kAbs), nullptr);
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kGetTupleElement);
}
7582
// The configured metadata ("...ForwardInference") does not match the custom
// call's target ("...ForwardTraining"). The abs() must therefore be kept, and
// the pass must report no change.
TEST_F(AlgebraicSimplifierTest,
       AbsEliminationBatchnormTraining_NegativeTestCase) {
  const char* kHloText = R"(
HloModule m
test {
p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
p1 = f32[32]{0} parameter(1)
p2 = f32[32]{0} parameter(2)
constant = f32[] constant(0.001)
constant.1 = s64[] constant(1)
custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
abs = f32[32]{0} abs(get-tuple-element)
ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  default_options_.set_cudnn_batchnorm_forward_training_metadata(
      "__cudnn$batchNormalizationForwardInference");
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(hlo_module.get()).ValueOrDie());
  EXPECT_NE(FindInstruction(hlo_module.get(), HloOpcode::kAbs), nullptr);
}
7607
// Checks that abs(p * p) drops the abs and leaves just multiply(p, p).
TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) {
  const char* kHloText = R"(
HloModule m
test {
p = f32[32]{0} parameter(0)
m = f32[32]{0} multiply(p, p)
ROOT a = f32[32]{0} abs(m)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
7622
// power(A, 2) is strength-reduced to A * A, and the abs() wrapping it is then
// dropped. The whole expression abs(power(A, 2)) therefore becomes
// multiply(A, A).
TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) {
  const char* kHloText = R"(
HloModule m
test {
p0 = f32[32]{0} parameter(0)
c0 = f32[] constant(2)
b0 = f32[32]{0} broadcast(c0), dimensions={}
pow = f32[32]{0} power(p0, b0)
ROOT a = f32[32]{0} abs(pow)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
7641
// Two scatter-adds into the same zero broadcast, both summed into `shared`,
// should merge into one scatter. Its indices and updates are the
// concatenations of the originals. The first simplifier run combines the
// scatters. The second folds away the add with zero, leaving `shared`
// (parameter 0) as the scatter operand.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombined) {
  const char* hlo_text = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4] broadcast(z), dimensions={}
shared = f32[100,4] parameter(0)
index0 = s32[20] parameter(1)
index1 = s32[10] parameter(2)
update0 = f32[20,4] parameter(3)
update1 = f32[10,4] parameter(4)
scatter.0 = f32[100,4] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
scatter.1 = f32[100,4] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
add.0 = f32[100,4] add(shared, scatter.0)
ROOT add.1 = f32[100,4] add(add.0, scatter.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Each call runs a freshly constructed simplifier over the module.
  auto run_simplifier = [&]() {
    return AlgebraicSimplifier(default_options_)
        .Run(hlo_module.get())
        .ValueOrDie();
  };
  ASSERT_TRUE(run_simplifier());  // Combine scatters.
  ASSERT_TRUE(run_simplifier());  // Optimize add with 0.

  EXPECT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Scatter(m::Parameter(0),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
7685
// Same graph as the plain combined case, except the final add lists scatter.1
// first. The combined scatter's concatenated indices/updates must follow that
// order: (p2, p1) and (p4, p3).
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedSwapped) {
  const char* hlo_text = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4] broadcast(z), dimensions={}
shared = f32[100,4] parameter(0)
index0 = s32[20] parameter(1)
index1 = s32[10] parameter(2)
update0 = f32[20,4] parameter(3)
update1 = f32[10,4] parameter(4)
scatter.0 = f32[100,4] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
scatter.1 = f32[100,4] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
add.0 = f32[100,4] add(shared, scatter.0)
ROOT add.1 = f32[100,4] add(scatter.1, add.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Each call runs a freshly constructed simplifier over the module.
  auto run_simplifier = [&]() {
    return AlgebraicSimplifier(default_options_)
        .Run(hlo_module.get())
        .ValueOrDie();
  };
  ASSERT_TRUE(run_simplifier());  // Combine scatters.
  ASSERT_TRUE(run_simplifier());  // Optimize add with 0.

  EXPECT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Scatter(m::Parameter(0),
                            m::Concatenate(m::Parameter(2), m::Parameter(1)),
                            m::Concatenate(m::Parameter(4), m::Parameter(3)))));
}
7729
// Scatter-adds whose index vector lives in dimension 0 (index_vector_dim=0)
// should still be combined. The root is the two scatters added directly, so
// after the add simplification the combined scatter writes into the zero
// broadcast.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums) {
  const char* hlo_text = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4] broadcast(z), dimensions={}
shared = f32[100,4] parameter(0)
index0 = s32[1,4,5] parameter(1)
index1 = s32[1,2,5] parameter(2)
update0 = f32[4,4,5] parameter(3)
update1 = f32[2,4,5] parameter(4)
scatter.0 = f32[100,4] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
scatter.1 = f32[100,4] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Each call runs a freshly constructed simplifier over the module.
  auto run_simplifier = [&]() {
    return AlgebraicSimplifier(default_options_)
        .Run(hlo_module.get())
        .ValueOrDie();
  };
  ASSERT_TRUE(run_simplifier());  // Combine scatters.
  ASSERT_TRUE(run_simplifier());  // Simplify add.

  EXPECT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Scatter(m::Broadcast(),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
7772
// Variant where the index vector is the trailing dimension (index_vector_dim=2)
// and the update window is dimension 0. The two scatter-adds should still
// merge into a single scatter over the zero broadcast.
TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums2) {
  const char* hlo_text = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4] broadcast(z), dimensions={}
shared = f32[100,4] parameter(0)
index0 = s32[4,3,1] parameter(1)
index1 = s32[4,5,1] parameter(2)
update0 = f32[4,4,3] parameter(3)
update1 = f32[4,4,5] parameter(4)
scatter.0 = f32[100,4] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={0},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2
scatter.1 = f32[100,4] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={0},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2
ROOT add.1 = f32[100,4] add(scatter.0, scatter.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Each call runs a freshly constructed simplifier over the module.
  auto run_simplifier = [&]() {
    return AlgebraicSimplifier(default_options_)
        .Run(hlo_module.get())
        .ValueOrDie();
  };
  ASSERT_TRUE(run_simplifier());  // Combine scatters.
  ASSERT_TRUE(run_simplifier());  // Simplify add.

  EXPECT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Scatter(m::Broadcast(),
                            m::Concatenate(m::Parameter(1), m::Parameter(2)),
                            m::Concatenate(m::Parameter(3), m::Parameter(4)))));
}
7815
// Negative test: each scatter carries a single scatter index (index operand
// s32[1] with index_vector_dim=0). Such scatter-adds are expected NOT to be
// combined. The pass must report no change and leave the root
// add(scatter, scatter) intact.
TEST_F(AlgebraicSimplifierTest, ScalarScatter) {
  const char* hlo_string = R"(
HloModule m
apply {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT c = f32[] add(a, b)
}
test {
z = f32[] constant(0)
init = f32[100,4,20] broadcast(z), dimensions={}
shared = f32[100,4,20] parameter(0)
index0 = s32[1] parameter(1)
index1 = s32[1] parameter(2)
update0 = f32[4,20] parameter(3)
update1 = f32[4,20] parameter(4)
scatter.0 = f32[100,4,20] scatter(init, index0, update0),
to_apply=apply,
update_window_dims={0, 1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
scatter.1 = f32[100,4,20] scatter(init, index1, update1),
to_apply=apply,
update_window_dims={0, 1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=0
ROOT add.1 = f32[100,4,20] add(scatter.0, scatter.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  // Scalar scatters must not be combined, so nothing should change.
  ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());

  // Pin the untouched structure: the root is still the add of two scatters.
  const HloInstruction* root = m->entry_computation()->root_instruction();
  ASSERT_EQ(root->opcode(), HloOpcode::kAdd);
  EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kScatter);
  EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kScatter);
}
7851
// Per dim_labels, the original convolution slides parameter(1) as a 32x32
// kernel over parameter(0). Parameter(0)'s spatial dims are only 3x3, padded
// by 30 on each side. The simplifier is expected to swap the operands, giving
// a 3x3 window over parameter(1) with padding 1 on each side and window
// reversal preserved.
TEST_F(AlgebraicSimplifierTest, SwapConvOperands) {
  const char* hlo_string = R"(
HloModule m
test {
a = f32[3,3,160,160] parameter(0)
b = f32[128,32,32,160] parameter(1)
ROOT c = f32[128,32,32,160] convolution(a,b),
window={size=32x32 pad=30_30x30_30 rhs_reversal=1x1},
dim_labels=01bf_o01i->f01b
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
  // Swap the convolution's operands.
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  const HloInstruction* conv = m->entry_computation()->root_instruction();
  // Fatal: the window() inspection below only makes sense on a convolution.
  ASSERT_THAT(conv,
              GmockMatch(m::Convolution(m::Parameter(1), m::Parameter(0))));
  const auto& window = conv->window();
  EXPECT_EQ(window.dimensions(0).size(), 3);
  EXPECT_EQ(window.dimensions(1).size(), 3);
  EXPECT_TRUE(window.dimensions(0).window_reversal());
  EXPECT_TRUE(window.dimensions(1).window_reversal());
  EXPECT_EQ(window.dimensions(0).padding_low(), 1);
  EXPECT_EQ(window.dimensions(1).padding_low(), 1);
  EXPECT_EQ(window.dimensions(0).padding_high(), 1);
  EXPECT_EQ(window.dimensions(1).padding_high(), 1);
}
7878
// Dividing a converted predicate by a broadcast scalar should become a
// multiply by the broadcast scalar reciprocal:
//   convert(p0) / broadcast(p1) -> convert(p0) * broadcast(1 / p1)
TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
  const char* kHloText = R"(
HloModule m
test {
p0 = pred[2] parameter(0)
cvt = f32[2] convert(p0)
p1 = f32[] parameter(1)
bcast = f32[2] broadcast(p1), dimensions={}
ROOT div = f32[2] divide(cvt, bcast)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  auto reciprocal =
      m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1)));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              GmockMatch(m::MultiplyAnyOrder(m::Convert(m::Parameter(0)),
                                             reciprocal)));
}
7898
// Strength-reduces two matrix-vector dots of different element types (c64 and
// f64) in the same module and checks the resulting computation count.
TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) {
  constexpr char kHloText[] = R"(
HloModule test
ENTRY test {
a = c64[2,2] parameter(0)
b = c64[2] parameter(1)
cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
c = f64[2,2] parameter(2)
d = f64[2] parameter(3)
dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT tuple = (c64[2], f64[2]) tuple(cd, dd)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());
  // NOTE(review): presumably the entry computation plus one reduction
  // computation per element type (c64, f64) introduced by the strength
  // reductions -- confirm against the pass.
  EXPECT_EQ(3, hlo_module->computation_count());
}
7916
// A variadic reduce with a single operand returns a one-element tuple. It
// should be rewritten as tuple(reduce(...)), and that reduce's computation
// should be a plain scalar add rather than one returning a tuple.
TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) {
  const char* kHloText = R"(
HloModule m
fn {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
a = f32[] add(p0, p1)
ROOT t = (f32[]) tuple(a)
}
test {
p0 = f32[32,8,6,7] parameter(0)
c = f32[] constant(0)
ROOT r = (f32[8,6,7]) reduce(p0, c), dimensions={0}, to_apply=fn
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Tuple(
                        m::Reduce(m::Parameter(0), m::ConstantScalar(0)))));

  const auto& callees = root->operand(0)->called_computations();
  ASSERT_EQ(callees.size(), 1);
  EXPECT_THAT(callees[0]->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Parameter(1))));
}
7950
// The module computes an argmax (variadic reduce over (value, iota)) and a
// plain max-reduce over the same input and dimension. After simplification
// only one reduce should remain: a tuple-shaped (variadic) reduce that feeds
// both outputs of the root tuple.
TEST_F(AlgebraicSimplifierTest, ReplaceReduceMaxWithReduceArgMax) {
  const char* kHloText = R"(
HloModule ReplaceReduceMaxWithReduceArgMax

%reduction_computation__1.25287 (parameter.25288: bf16[], parameter.25289: s32[], parameter.25290: bf16[], parameter.25291: s32[]) -> (bf16[], s32[]) {
%constant.25292 = pred[] constant(false)
%parameter.25288 = bf16[] parameter(0)
%parameter.25290 = bf16[] parameter(2)
%compare.25293 = pred[] compare(bf16[] %parameter.25288, bf16[] %parameter.25290), direction=GT
%compare.25294 = pred[] compare(bf16[] %parameter.25288, bf16[] %parameter.25288), direction=NE
%or.25295 = pred[] or(pred[] %compare.25293, pred[] %compare.25294)
%select.25300 = bf16[] select(pred[] %or.25295, bf16[] %parameter.25288, bf16[] %parameter.25290)
%compare.25296 = pred[] compare(bf16[] %parameter.25288, bf16[] %parameter.25290), direction=EQ
%parameter.25289 = s32[] parameter(1)
%parameter.25291 = s32[] parameter(3)
%compare.25297 = pred[] compare(s32[] %parameter.25289, s32[] %parameter.25291), direction=LT
%and.25298 = pred[] and(pred[] %compare.25296, pred[] %compare.25297)
%or.25299 = pred[] or(pred[] %or.25295, pred[] %and.25298)
%select.25301 = s32[] select(pred[] %or.25299, s32[] %parameter.25289, s32[] %parameter.25291)
ROOT %tuple.25302 = (bf16[], s32[]) tuple(bf16[] %select.25300, s32[] %select.25301)
}

%primitive_computation_max.25303 (parameter.25304: bf16[], parameter.25305: bf16[]) -> bf16[] {
%parameter.25304 = bf16[] parameter(0), metadata={op_type="max" op_name="max"}
%parameter.25305 = bf16[] parameter(1), metadata={op_type="max" op_name="max"}
ROOT %maximum.25306 = bf16[] maximum(bf16[] %parameter.25304, bf16[] %parameter.25305), metadata={op_type="max" op_name="max"}
}

ENTRY %main {
%p0 = bf16[384,128,19392]{2,1,0} parameter(0)

// Variadic Reduce (ArgMax)
%iota.25376 = s32[384,128,19392] iota(), iota_dimension=2
%constant.25377 = bf16[] constant(-inf)
%constant.25378 = s32[] constant(0)
%reduce.25379 = (bf16[384,128]{1,0}, s32[384,128]{1,0}) reduce(bf16[384,128,19392]{2,1,0} %p0, s32[384,128,19392] %iota.25376, bf16[] %constant.25377, s32[] %constant.25378), dimensions={2}, to_apply=%reduction_computation__1.25287

%get-tuple-element.25381 = s32[384,128]{1,0} get-tuple-element((bf16[384,128]{1,0}, s32[384,128]{1,0}) %reduce.25379), index=1

// Reduce (Max)
%constant.25382 = bf16[] constant(-inf)
%reduce.25383 = bf16[384,128]{1,0} reduce(bf16[384,128,19392]{2,1,0} %p0, bf16[] %constant.25382), dimensions={2}, to_apply=%primitive_computation_max.25303

ROOT %tuple.0 = (bf16[384,128]{1,0}, s32[384,128]{1,0}) tuple(%reduce.25383, %get-tuple-element.25381)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(hlo_module.get()).ValueOrDie());

  auto* entry = hlo_module->entry_computation();
  // Exactly one reduce should survive simplification.
  int64_t num_reduces = 0;
  for (const HloInstruction* instr : entry->instructions()) {
    if (instr->opcode() == HloOpcode::kReduce) {
      ++num_reduces;
    }
  }
  EXPECT_EQ(1, num_reduces);

  // Both root outputs must be read from that single tuple-shaped reduce.
  auto variadic_reduce = m::Reduce().WithShape(m::Shape().IsTuple());
  ASSERT_THAT(entry->root_instruction(),
              GmockMatch(m::Tuple(m::GetTupleElement(variadic_reduce, 0),
                                  m::GetTupleElement(variadic_reduce, 1))));
}
8012
// The entry computation below contains a variadic (value, index) reduce over
// dimension 2 of %param_0.3 plus a separate single-output min reduce over the
// same operand and dimension. After simplification the test expects exactly
// one kReduce to remain, tuple-shaped, with both root outputs taken from it
// via get-tuple-element.
TEST_F(AlgebraicSimplifierTest, ReplaceReduceMinWithReduceArgMin) {
  const char* kModuleStr = R"(
HloModule ReplaceReduceMinWithReduceArgMin

%region_3.84 (Arg_0.85: bf16[], Arg_1.86: s32[], Arg_2.87: bf16[], Arg_3.88: s32[]) -> (bf16[], s32[]) {
  %Arg_3.88 = s32[]{:T(256)} parameter(3)
  %Arg_2.87 = bf16[]{:T(512)} parameter(2)
  %Arg_1.86 = s32[]{:T(256)} parameter(1)
  %compare.93 = pred[]{:T(1024)S(6)} compare(s32[]{:T(256)} %Arg_1.86, s32[]{:T(256)} %Arg_3.88), direction=LT, metadata={op_name="lt" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %Arg_0.85 = bf16[]{:T(512)} parameter(0)
  %compare.92 = pred[]{:T(1024)S(6)} compare(bf16[]{:T(512)} %Arg_0.85, bf16[]{:T(512)} %Arg_2.87), direction=EQ, metadata={op_name="eq" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %and.94 = pred[]{:T(1024)S(6)} and(pred[]{:T(1024)S(6)} %compare.92, pred[]{:T(1024)S(6)} %compare.93), metadata={op_name="and" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %compare.90 = pred[]{:T(1024)S(6)} compare(bf16[]{:T(512)} %Arg_0.85, bf16[]{:T(512)} %Arg_0.85), direction=NE, metadata={op_name="ne" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %compare.89 = pred[]{:T(1024)S(6)} compare(bf16[]{:T(512)} %Arg_0.85, bf16[]{:T(512)} %Arg_2.87), direction=LT, metadata={op_name="lt" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %or.91 = pred[]{:T(1024)S(6)} or(pred[]{:T(1024)S(6)} %compare.89, pred[]{:T(1024)S(6)} %compare.90), metadata={op_name="or" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %select.96 = bf16[]{:T(512)} select(pred[]{:T(1024)S(6)} %or.91, bf16[]{:T(512)} %Arg_0.85, bf16[]{:T(512)} %Arg_2.87), metadata={op_name="select_n" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %or.95 = pred[]{:T(1024)S(6)} or(pred[]{:T(1024)S(6)} %or.91, pred[]{:T(1024)S(6)} %and.94), metadata={op_name="or" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %select.97 = s32[]{:T(256)} select(pred[]{:T(1024)S(6)} %or.95, s32[]{:T(256)} %Arg_1.86, s32[]{:T(256)} %Arg_3.88), metadata={op_name="select_n" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  ROOT %tuple.98 = (bf16[]{:T(512)}, s32[]{:T(256)}) tuple(bf16[]{:T(512)} %select.96, s32[]{:T(256)} %select.97)
}

%region_0.8 (Arg_0.9: bf16[], Arg_1.10: bf16[]) -> bf16[] {
  %Arg_1.10 = bf16[]{:T(512)} parameter(1)
  %Arg_0.9 = bf16[]{:T(512)} parameter(0)
  ROOT %minimum.11 = bf16[]{:T(512)} minimum(bf16[]{:T(512)} %Arg_0.9, bf16[]{:T(512)} %Arg_1.10), metadata={op_name="jit(ScaMTPUTopK)/jit(main)/jit(ScaMTPUTopK)/jit(jit_ScaMTPUTopK)/reduce_min[axes=(2,)]" source_file="<ipython-input-4-4f3bd086a82e>" source_line=8}
}

ENTRY %main {
  %param_0.3 = bf16[1024,1024,2048]{2,0,1:T(8,128)(2,1)} parameter(0)

  // ArgMin
  %iota.5.clone.1 = s32[1024,1024,2048]{2,0,1:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(ScaMTPUTopK)/jit(main)/jit(ScaMTPUTopK)/jit(jit_ScaMTPUTopK)/iota[dtype=int32 shape=(1024, 1024, 2048) dimension=2]" source_file="<ipython-input-4-4f3bd086a82e>" source_line=12}
  %constant.24 = bf16[]{:T(512)} constant(inf)
  %constant.23 = s32[]{:T(256)} constant(0)
  %reduce.3 = (bf16[1024,1024]{0,1:T(8,128)(2,1)}, s32[1024,1024]{0,1:T(8,128)}) reduce(bf16[1024,1024,2048]{2,0,1:T(8,128)(2,1)} %param_0.3, s32[1024,1024,2048]{2,0,1:T(8,128)} %iota.5.clone.1, bf16[]{:T(512)} %constant.24, s32[]{:T(256)} %constant.23), dimensions={2}, to_apply=%region_3.84

  %gte.0 = s32[1024,1024]{0,1:T(8,128)} get-tuple-element(%reduce.3), index=1

  // ReduceMin
  %constant.25 = bf16[]{:T(512)} constant(inf)
  %reduce.4 = bf16[1024,1024]{0,1:T(8,128)(2,1)} reduce(bf16[1024,1024,2048]{2,0,1:T(8,128)(2,1)} %param_0.3, bf16[]{:T(512)} %constant.25), dimensions={2}, to_apply=%region_0.8

  ROOT %tuple.0 = (bf16[1024,1024]{0,1:T(8,128)(2,1)}, s32[1024,1024]{0,1:T(8,128)}) tuple(%reduce.4, %gte.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloComputation* entry = m->entry_computation();
  const int64_t num_reduces =
      absl::c_count_if(entry->instructions(), [](const HloInstruction* instr) {
        return instr->opcode() == HloOpcode::kReduce;
      });
  // Only the merged variadic reduce should survive.
  EXPECT_EQ(1, num_reduces);

  // Both root tuple elements must be projections of a tuple-shaped reduce.
  auto tuple_reduce = m::Reduce().WithShape(m::Shape().IsTuple());
  ASSERT_THAT(entry->root_instruction(),
              GmockMatch(m::Tuple(m::GetTupleElement(tuple_reduce, 0),
                                  m::GetTupleElement(tuple_reduce, 1))));
}
8073
// A reduce-window whose reducer `fn` returns a one-element tuple is expected
// to become tuple(reduce-window(p0, 0)) whose single called computation has a
// bare add of its two parameters as root.
TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduceWindow) {
  const char* kModuleStr = R"(
HloModule m
fn {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  a = f32[] add(p0, p1)
  ROOT t = (f32[]) tuple(a)
}
test {
  p0 = f32[32,8,6,7] parameter(0)
  c = f32[] constant(0)
  ROOT r = (f32[32,8,6,7]) reduce-window(p0, c), to_apply=fn, window={size=1x1x1x1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());

  HloInstruction* root = m->entry_computation()->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Tuple(m::ReduceWindow(
                        m::Parameter(0), m::ConstantScalar(0)))));

  // Inspect the reducer attached to the rewritten reduce-window.
  const HloInstruction* reduce_window = root->operand(0);
  const auto& callees = reduce_window->called_computations();
  ASSERT_EQ(callees.size(), 1);
  EXPECT_THAT(callees[0]->root_instruction(),
              GmockMatch(m::Add(m::Parameter(0), m::Parameter(1))));
}
8107
// pad(broadcast(scalar constant), constant) is expected to be rewritten into
// broadcast(pad(broadcast(constant), constant)), i.e. the pad moves beneath an
// outer broadcast.
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorder) {
  const char* kModuleStr = R"(
HloModule m
test {
  c1 = pred[] constant(true)
  b2 = pred[32,1,768]{2,1,0} broadcast(pred[] c1), dimensions={}
  c3 = pred[] constant(false)
  ROOT p4 = pred[4096,1,768]{2,1,0} pad(pred[32,1,768]{2,1,0} b2, pred[] c3), padding=0_4064x0_0x0_0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Pad(
                        m::Broadcast(m::Constant()), m::Constant()))));
}
8124
// Like BroadcastAndPadReorder, but the pad is on the minor dimension and is
// consumed by a tuple rather than being the root. The reordered
// broadcast(pad(broadcast(constant), constant)) should feed that tuple.
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithUse) {
  const char* kModuleStr = R"(
HloModule m
test {
  c1 = pred[] constant(true)
  b2 = pred[1,768,32]{2,1,0} broadcast(pred[] c1), dimensions={}
  c3 = pred[] constant(false)
  p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
  ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::Pad(
                        m::Broadcast(m::Constant()), m::Constant())))));
}
8142
// The broadcast operand here is a rank-1 parameter mapped to dimension 2, not
// a scalar. The reorder is still expected, producing
// tuple(broadcast(pad(broadcast(parameter), constant))).
TEST_F(AlgebraicSimplifierTest, BroadcastAndPadReorderWithNonScalar) {
  const char* kModuleStr = R"(
HloModule m
test {
  c1 = pred[32] parameter(0)
  b2 = pred[1,768,32]{2,1,0} broadcast(pred[32] c1), dimensions={2}
  c3 = pred[] constant(false)
  p4 = pred[1,768,4096]{2,1,0} pad(pred[1,768,32]{2,1,0} b2, pred[] c3), padding=0_0x0_0x0_4064
  ROOT p5 = (pred[1,768,4096]{2,1,0}) tuple(pred[1,768,4096]{2,1,0} p4)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::Pad(
                        m::Broadcast(m::Parameter()), m::Constant())))));
}
8160
8161 // Test that dynamic-update-slice with a scalar broadcast becomes a pad when the
8162 // start_indices are too big.
TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceOfBroadcastToPadOob) {
  // The start index 2 is past the last valid start (1) for writing a
  // 1-element update into f32[2]. The expected pad therefore places the
  // parameter at offset 1: low padding 1, high padding 0.
  const char* hlo_string = R"(
HloModule module

ENTRY f {
  constant.546 = f32[] constant(0)
  broadcast.467 = f32[2]{0} broadcast(constant.546), dimensions={}
  parameter.1 = f32[1]{0} parameter(0)
  constant.551 = s32[] constant(2)
  ROOT dynamic-update-slice.44 = f32[2]{0} dynamic-update-slice(broadcast.467, parameter.1, constant.551)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  VLOG(2) << "Before rewrite dus->pad\n" << module->ToString();
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
  VLOG(2) << "After rewrite dus->pad\n" << module->ToString();
  auto* pad = module->entry_computation()->root_instruction();
  // Fatal check: padding_config() below is only valid on a pad instruction,
  // so stop here instead of calling it on some other opcode if the rewrite
  // did not fire.
  ASSERT_THAT(pad,
              GmockMatch(m::Pad(m::Parameter(0), m::ConstantScalar(0.0f))));
  EXPECT_FALSE(HasInteriorPadding(pad->padding_config()));
  ASSERT_EQ(pad->padding_config().dimensions_size(), 1);
  EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_low(), 1);
  EXPECT_EQ(pad->padding_config().dimensions(0).edge_padding_high(), 0);
}
8189
8190 // Test folding of dynamic_slice(iota, index) -> clamp(index, 0, size-1)
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfIota) {
  // The constant {0, 1} matches an iota. Slicing one element at %index is
  // expected to become reshape(convert(clamp(constant, index, constant))).
  const char* hlo_string = R"(
HloModule module

ENTRY f {
  %cst = s32[2]{0} constant({0, 1})
  %index = u32[] parameter(0)
  ROOT %dynamic-slice = s32[1]{0} dynamic-slice(s32[2]{0} %cst, u32[] %index),
    dynamic_slice_sizes={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  VLOG(2) << "After rewrite \n" << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Convert(m::Clamp(
                        m::Constant(), m::Parameter(0), m::Constant())))));
}
8213
8214 // Test folding of clamp(pid, 0, limit) -> pid
TEST_F(AlgebraicSimplifierTest, ClampOfPartitionId) {
  // With num_partitions=6, partition-id lies in [0, 5]. Clamping it to
  // [0, 5] is therefore a no-op, and the root should reduce to the
  // partition-id itself.
  const char* hlo_string = R"(
HloModule module

ENTRY f {
  %pid = u32[] partition-id()
  %low = u32[] constant(0)
  %high = u32[] constant(5)
  ROOT %c = u32[] clamp(%low, %pid, %high)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(hlo_string, /*replica_count=*/1,
                                                /*num_partitions=*/6));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  VLOG(2) << "After rewrite \n" << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::PartitionId()));
}
8237
// The constant {0, 25, 50, 75} is evenly strided. It is expected to be
// rewritten as multiply(iota, broadcast(...)), presumably a broadcast of the
// stride 25, and the copy disappears along the way.
TEST_F(AlgebraicSimplifierTest, ConstantToIota) {
  const char* hlo_string = R"(
HloModule module

ENTRY f {
  %cst = s32[4] constant({0, 25, 50, 75})
  ROOT %s = s32[4] copy(s32[4] %cst)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  VLOG(2) << "After rewrite \n" << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Multiply(m::Iota(), m::Broadcast())));
}
8256
// Strided variant of DynamicSliceOfIota. Slicing one element of
// {0, 25, 50, 75} at %index is expected to become
// reshape(multiply(convert(clamp(...)), constant)).
TEST_F(AlgebraicSimplifierTest, DynamicSliceOfStridedIota) {
  const char* hlo_string = R"(
HloModule module

ENTRY f {
  %cst = s32[4] constant({0, 25, 50, 75})
  %index = u32[] parameter(0)
  ROOT %dynamic-slice = s32[1]{0} dynamic-slice(s32[4]{0} %cst, u32[] %index),
    dynamic_slice_sizes={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(module.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  VLOG(2) << "After rewrite \n" << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Reshape(m::Multiply(
                        m::Convert(m::Clamp()), m::Constant()))));
}
8278
// Both select arms (a zero broadcast and max(p0, zero)) are non-negative, so
// abs(select(...)) is expected to lose its abs and leave the select as root.
TEST_F(AlgebraicSimplifierTest, AbsEliminationSelMaxBcast) {
  const char* kModuleStr = R"(
HloModule m
test {
  p0 = f32[32]{0} parameter(0)
  p1 = pred[32]{0} parameter(1)
  zero = f32[] constant(0.0)
  bcast = f32[32] broadcast(zero), dimensions={}
  m = f32[32]{0} maximum(p0, bcast)
  s = f32[32]{0} select(p1, bcast, m)
  ROOT a = f32[32]{0} abs(s)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Select(
                        m::Parameter(1), m::Broadcast(m::ConstantScalar()),
                        m::MaximumAnyOrder(
                            m::Parameter(0),
                            m::Broadcast(m::ConstantScalar())))));
}
8300
// s32 -> u32 bitcast-converts feed a concatenate that is bitcast-converted back
// to s32. The round trip should collapse to concatenate(p0, p1).
TEST_F(AlgebraicSimplifierTest, SimplifyRedundantBitcastConvert) {
  const char* kModuleStr = R"(
HloModule m

ENTRY test {
  p0 = s32[10] parameter(0)
  p1 = s32[10] parameter(1)
  b0 = u32[10] bitcast-convert(p0)
  b1 = u32[10] bitcast-convert(p1)
  c = u32[20] concatenate(b0, b1), dimensions={0}
  ROOT out = s32[20] bitcast-convert(c)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(1))));
}
8319
// The tuple feeding the opt-barrier starts with three elements. After
// simplification the test expects it to have two.
TEST_F(AlgebraicSimplifierTest, SimplifyOptimizationBarrier) {
  const char* kModuleStr = R"(
HloModule m

ENTRY entry {
  param.0 = f32[] parameter(0)
  param.1 = f32[] parameter(1)
  add.0 = f32[] add(param.0, param.1)
  sub.0 = f32[] subtract(param.0, param.1)
  mul.0 = f32[] multiply(param.0, param.1)
  tuple.0 = (f32[], f32[], f32[]) tuple(mul.0, sub.0, add.0)
  b = (f32[], f32[], f32[]) opt-barrier(tuple.0)
  gte.0 = f32[] get-tuple-element(b), index=1
  ROOT t = (f32[], f32[]) tuple(mul.0,gte.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));

  // In the module above this walks root tuple -> gte.0 -> opt-barrier -> the
  // barrier's input tuple, and returns that tuple's operand count.
  auto barrier_input_arity = [&m]() {
    return m->entry_computation()
        ->root_instruction()
        ->operand(1)
        ->operand(0)
        ->operand(0)
        ->operand_count();
  };

  EXPECT_EQ(barrier_input_arity(), 3);
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
  EXPECT_EQ(barrier_input_arity(), 2);
}
8353
TEST_F(AlgebraicSimplifierTest, GTETupleShardingLoss) {
  // gte(tuple(p0)) must not be folded to p0 here. p0 is sharded across two
  // devices while the gte carries sharding={replicated}, so folding would lose
  // the gte's sharding annotation. The pass must report no change.
  const char* kModuleStr = R"(
HloModule m

ENTRY test {
  p0 = s32[10] parameter(0), sharding={devices=[2]0,1}
  t = (s32[10]) tuple(p0)
  ROOT %gte = s32[10] get-tuple-element(t), index=0, sharding={replicated}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifier simplifier(default_options_);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
8368
TEST_F(AlgebraicSimplifierTest, DynamicSliceShapeLayout) {
  // Verify that whatever replaces the dynamic-slice (the root tuple's operand
  // 0 after the pass) still carries a layout with exactly one tile, matching
  // the original {0:T(128)}.
  const char* kModuleStr = R"(
HloModule m

ENTRY test {
  p0 = u32[]{:T(128)} parameter(0)
  %constant.1 = s32[4]{0:T(128)} constant({0, 16, 32, 48})
  %dynamic-slice = s32[1]{0:T(128)} dynamic-slice(s32[4]{0:T(128)} %constant.1, u32[] %p0), dynamic_slice_sizes={1}
  ROOT t = (s32[1]{0:T(128)}) tuple(dynamic-slice)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
  const Shape& slice_shape =
      m->entry_computation()->root_instruction()->operand(0)->shape();
  // Fatal check: layout() is only meaningful once has_layout() holds, so do
  // not go on to inspect tiles of a missing layout.
  ASSERT_TRUE(slice_shape.has_layout());
  EXPECT_EQ(slice_shape.layout().tiles_size(), 1);
}
8388
8389 // Fold a sequence of copy bitcast copy
TEST_F(AlgebraicSimplifierTest, CopyBitcastCopy) {
  // No instruction is marked ROOT, so the last one (copy.3045) is the entry
  // root. In layout-sensitive mode the copy -> bitcast -> copy chain is
  // expected to become bitcast(copy(parameter)).
  const char* kModuleStr = R"(
HloModule m

ENTRY test {
  fusion.1235 = bf16[1600,50,512]{2,0,1:T(8,128)(2,1)} parameter(0)
  copy.3038 = bf16[1600,50,512]{0,2,1:T(8,128)(2,1)} copy(fusion.1235)
  bitcast.8 = bf16[1600,50,16,32]{0,3,2,1:T(8,128)(2,1)} bitcast(copy.3038)
  copy.3045 = bf16[1600,50,16,32]{1,3,2,0:T(8,128)(2,1)} copy(bitcast.8)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_TRUE(changed);
  HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Bitcast(m::Copy(m::Parameter()))));
}
8409
// Negative case. With these tiled layouts the layout-sensitive simplifier is
// expected to leave the copy -> bitcast -> copy chain untouched and report no
// change.
TEST_F(AlgebraicSimplifierTest, CopyBitcastCopy2) {
  const char* kModuleStr = R"(
HloModule m

ENTRY test {
  %Arg_0.1 = f32[8,3,3,7,7]{4,0,3,2,1:T(8,128)} parameter(0)
  %copy.1 = f32[8,3,3,7,7]{4,3,2,1,0:T(8,128)} copy(f32[8,3,3,7,7]{4,0,3,2,1:T(8,128)} %Arg_0.1)
  %bitcast = f32[1,72,7,7]{3,2,1,0:T(8,128)} bitcast(f32[8,3,3,7,7]{4,3,2,1,0:T(8,128)} %copy.1)
  %copy.2 = f32[1,72,7,7]{1,3,2,0:T(8,128)} copy(f32[1,72,7,7]{3,2,1,0:T(8,128)} %bitcast)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
8427
// Negative case. The layout-sensitive simplifier is expected to make no change
// to this copy -> bitcast -> copy chain.
TEST_F(AlgebraicSimplifierTest, CopyReshapeCopy3) {
  const char* kModuleStr = R"(
HloModule m

ENTRY main {
  p = f32[2,3]{0,1} parameter(0)
  copy = f32[2,3]{1,0} copy(p)
  reshape = f32[3,2]{1,0} bitcast(copy)
  ROOT copy.1 = f32[3,2]{0,1} copy(reshape)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  // Dump the post-pass module before asserting so it is available exactly
  // when the assertion fails, i.e. when the pass unexpectedly rewrote the
  // module.
  SCOPED_TRACE(m->ToString());
  VLOG(3) << "Module " << m->ToString();
  ASSERT_FALSE(changed);
}
8445
// Negative case (rank 3). The layout-sensitive simplifier is expected to make
// no change to this copy -> bitcast -> copy chain.
TEST_F(AlgebraicSimplifierTest, CopyReshapeCopy4) {
  const char* kModuleStr = R"(
HloModule m

ENTRY main {
  p = f32[6,2,3]{0,1,2} parameter(0)
  copy.0 = f32[6,2,3]{0,2,1} copy(p)
  reshape = f32[2,3,6]{1,0,2} bitcast(copy.0)
  ROOT copy.1 = f32[2,3,6]{0,1,2} copy(reshape)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  // Dump the post-pass module before asserting so it is available exactly
  // when the assertion fails, i.e. when the pass unexpectedly rewrote the
  // module.
  SCOPED_TRACE(m->ToString());
  VLOG(3) << "Module " << m->ToString();
  ASSERT_FALSE(changed);
}
8463
// The layout-sensitive simplifier must report a change on this
// reshape/transpose/copy chain while still leaving copy.3 in the module.
TEST_F(AlgebraicSimplifierTest, BitcastCopyChain) {
  const char* kModuleStr = R"(
HloModule m

ENTRY main {
  p.0 = f32[4,32,32,1]{2,3,1,0} parameter(0)
  reshape.30 = f32[4,32,1,1,32]{4,1,0,3,2} reshape(p.0)
  transpose.1757 = f32[4,1,1,32,32]{3,4,0,1,2} transpose(reshape.30), dimensions={0,3,2,4,1}
  copy.3 = f32[4,1,1,32,32]{4,3,0,2,1} copy(transpose.1757)
  reshape.1758 = f32[4,1,1,1024]{3,2,1,0} reshape(copy.3)
  transpose.61 = f32[1024,4,1,1]{0,3,2,1} transpose(reshape.1758), dimensions={3,0,1,2}
  copy.4 = f32[1024,4,1,1]{0,1,3,2} copy(transpose.61)
  ROOT reshape.107 = f32[1024,4]{0,1} reshape(copy.4)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions layout_sensitive_options;
  layout_sensitive_options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(layout_sensitive_options);
  const bool changed = simplifier.Run(m.get()).ValueOrDie();
  SCOPED_TRACE(m->ToString());
  ASSERT_TRUE(changed);
  const HloInstruction* copy3 = FindInstruction(m.get(), "copy.3");
  EXPECT_NE(copy3, nullptr);
}
8487
8488 // Make sure that the following copy-bitcast-copy is not transformed via
8489 // SwapCopyBitcastCopy function. If SwapCopyBitcastCopy does not fire, in this
8490 // case, the last copy will be turned into a bitcast by HandleCopy.
TEST_F(AlgebraicSimplifierTest, BitcastCopyChainSmall) {
  const char* kModuleStr = R"(
HloModule m
ENTRY %main (para.0: f32[4,1,1,32,32]) -> f32[1024,4,1,1] {
  %para.0 = f32[4,1,1,32,32]{3,4,0,1,2} parameter(0)
  %copy.0 = f32[4,1,1,32,32]{4,3,0,2,1} copy(%para.0)
  %bitcast.0 = f32[1024,4,1,1]{0,3,2,1} bitcast(%copy.0)
  ROOT %copy.1 = f32[1024,4,1,1]{0,1,3,2} copy(%bitcast.0)
}

)";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
  AlgebraicSimplifierOptions options;
  options.set_is_layout_sensitive(true);
  AlgebraicSimplifier simplifier(options);
  // SCOPED_TRACE captures its message when declared, so these two traces
  // record the module before and after the pass. Label them so the two dumps
  // can be told apart in failure output.
  SCOPED_TRACE(absl::StrCat("Module before simplification:\n", m->ToString()));
  auto result = simplifier.Run(m.get()).ValueOrDie();
  SCOPED_TRACE(absl::StrCat("Module after simplification:\n", m->ToString()));
  ASSERT_TRUE(result);
  // Expected result: bitcast(bitcast(copy(para.0))).
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Bitcast(m::Bitcast(m::Copy(m::Parameter(0))))));
}
8513
8514 } // namespace
8515 } // namespace xla
8516