xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/pattern_matcher_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/pattern_matcher.h"

#include <memory>
#include <sstream>
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"
26 
27 namespace xla {
28 namespace {
29 
30 namespace m = match;
31 using PatternMatcherTest = HloTestBase;
32 
// Matches the root `add` of a tiny module against an Op() pattern that
// combines name, opcode, shape and operand constraints, and checks every
// capture slot: populated on success, untouched after a failed match.
TEST_F(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
    ENTRY %two_plus_two_computation () -> f32[] {
      %two = f32[] constant(2)
      ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();

  // Start every capture slot at nullptr so the checks below never read an
  // indeterminate pointer.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;

  ASSERT_TRUE(Match(
      root,
      match::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(match::Shape(&matched_shape).IsDenseArray())
          .WithOperand(
              0,
              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);
  // The nested operand and shape captures must be populated as well.
  ASSERT_NE(matched_operand, nullptr);
  EXPECT_EQ(matched_operand, root->operand(0));
  EXPECT_EQ(matched_operand->opcode(), HloOpcode::kConstant);
  ASSERT_NE(matched_shape, nullptr);
  EXPECT_TRUE(ShapeUtil::Equal(*matched_shape, root->shape()));

  EXPECT_TRUE(Match(root, match::Add(match::Constant(), match::Constant())));

  EXPECT_FALSE(Match(root, match::Op().WithName("bad_name")));

  // A failed match must leave the capture slot untouched.
  matched_inst = nullptr;
  EXPECT_FALSE(
      Match(root, match::Multiply(&matched_inst, match::Op(), match::Op())));
  EXPECT_EQ(matched_inst, nullptr);
}
69 
// Checks the ShapePattern predicates against a rank-0 f32 shape.
TEST_F(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  // Initialized so the EXPECT_EQ below compares a well-defined value; the
  // capturing match is fatal so we never inspect a missing capture.
  Shape* matched_shape = nullptr;
  ASSERT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
  // A non-tuple shape has no subshape at {0}.
  EXPECT_FALSE(Match(
      &scalar_shape,
      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
84 
// Checks the ShapePattern predicates, plus shape and layout capture, against a
// rank-3 f32 array shape.
TEST_F(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  // Capture slots start at nullptr; capturing matches are fatal so the
  // following EXPECT_EQs never compare indeterminate pointers.
  Shape* matched_shape = nullptr;
  ASSERT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  ASSERT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(match::Layout(&matched_layout))));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
103 
// Checks tuple-shape predicates and WithSubshape (capture, EqualTo, and
// out-of-range indices) against a two-element tuple shape.
TEST_F(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));

  // Reset before each capturing match. Otherwise the first ASSERT_NE would
  // read an indeterminate pointer, and the second would see the stale capture
  // from the first match and could never fail.
  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {0}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {0})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {0}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {1})))));

  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {1}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {1})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {1}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {0})))));

  // Indices past the tuple's arity, or into a non-tuple element, fail.
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
151 
// WithFusionKind accepts the loop fusion only for the kind it was created
// with; the fusion's parameter operand does not match either.
TEST_F(PatternMatcherTest, FusionKind) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    fused_computation {
      ROOT fp0 = f32[] parameter(0)
    }

    ENTRY while.v11 {
      p0 = f32[] parameter(0)
      ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));

  using Kind = HloInstruction::FusionKind;
  auto* fusion = parsed->entry_computation()->root_instruction();
  const HloInstruction* param = fusion->operand(0);

  EXPECT_TRUE(Match(fusion, match::Op().WithFusionKind(Kind::kLoop)));
  EXPECT_FALSE(Match(fusion, match::Op().WithFusionKind(Kind::kInput)));
  EXPECT_FALSE(Match(param, match::Op().WithFusionKind(Kind::kLoop)));
}
175 
// WithTupleIndex and GetTupleElement(…, index) accept only the index the
// get-tuple-element actually reads.
TEST_F(PatternMatcherTest, GetTupleElement) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY while.v11 {
      p0 = (f32[], f32[], f32[]) parameter(0)
      ROOT gte = f32[] get-tuple-element(p0), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* gte = parsed->entry_computation()->root_instruction();

  // The GTE in the module reads index 1...
  EXPECT_TRUE(Match(gte, match::Op().WithTupleIndex(1)));
  EXPECT_TRUE(Match(gte, match::GetTupleElement(match::Op(), 1)));
  // ...so the neighbouring indices are rejected.
  EXPECT_FALSE(Match(gte, match::Op().WithTupleIndex(0)));
  EXPECT_FALSE(Match(gte, match::Op().WithTupleIndex(2)));
  EXPECT_FALSE(Match(gte, match::GetTupleElement(match::Op(), 0)));
}
194 
// AnyOf matches when at least one alternative matches, in either order, and
// fails when none does.
TEST_F(PatternMatcherTest, AnyOf) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = parsed->entry_computation()->root_instruction();

  auto is_zero = match::ConstantScalar(0);
  auto is_one = match::ConstantScalar(1);
  auto is_two = match::ConstantScalar(2);

  EXPECT_TRUE(Match(root, match::AnyOf<HloInstruction>(is_zero, is_one)));
  EXPECT_TRUE(Match(root, match::AnyOf<HloInstruction>(is_one, is_zero)));
  EXPECT_FALSE(Match(root, match::AnyOf<HloInstruction>(is_zero, is_two)));
}
212 
TEST_F(PatternMatcherTest,ConstantScalar)213 TEST_F(PatternMatcherTest, ConstantScalar) {
214   using match::ConstantEffectiveScalar;
215   using match::ConstantScalar;
216   using match::Op;
217   using match::Tuple;
218 
219   constexpr char kModuleStr[] = R"(
220     HloModule test_module
221     ENTRY test {
222       a = s32[] constant(1)
223       b = s32[1,1] constant({{2}})
224       c = s32[1,2] constant({{2,2}})
225       d = f32[] constant(1)
226       e = f32[] constant(1.25)
227       ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
228     })";
229   TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
230                           ParseAndReturnVerifiedModule(kModuleStr));
231   auto* root = hlo_module->entry_computation()->root_instruction();
232 
233   const HloInstruction* a = root->operand(0);
234   const HloInstruction* b = root->operand(1);
235   const HloInstruction* c = root->operand(2);
236   const HloInstruction* d = root->operand(3);
237   const HloInstruction* e = root->operand(4);
238   EXPECT_TRUE(Match(a, ConstantScalar()));
239   EXPECT_TRUE(Match(a, ConstantScalar(1)));
240   EXPECT_TRUE(Match(a, ConstantEffectiveScalar()));
241   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1)));
242   EXPECT_FALSE(Match(a, ConstantScalar(2)));
243   EXPECT_FALSE(Match(a, ConstantScalar(2.01)));
244   EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2)));
245   EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01)));
246 
247   EXPECT_FALSE(Match(b, ConstantScalar()));
248   EXPECT_FALSE(Match(b, ConstantScalar(2)));
249   EXPECT_TRUE(Match(b, ConstantEffectiveScalar()));
250   EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2)));
251 
252   EXPECT_FALSE(Match(c, ConstantScalar()));
253   EXPECT_FALSE(Match(c, ConstantScalar(2)));
254   EXPECT_FALSE(Match(c, ConstantEffectiveScalar()));
255   EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2)));
256 
257   EXPECT_TRUE(Match(d, ConstantScalar(1)));
258   EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1)));
259   EXPECT_TRUE(Match(d, ConstantScalar(1.0)));
260   EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0)));
261 
262   EXPECT_TRUE(Match(e, ConstantScalar(1.25f)));
263   EXPECT_TRUE(Match(e, ConstantScalar(1.25)));
264   EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25)));
265   EXPECT_FALSE(Match(e, ConstantScalar(1)));
266   EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1)));
267 
268   const HloInstruction* instr = nullptr;
269   EXPECT_TRUE(Match(a, ConstantScalar(&instr)));
270   EXPECT_EQ(instr, a);
271 
272   instr = nullptr;
273   EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1)));
274   EXPECT_EQ(instr, a);
275 
276   instr = nullptr;
277   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr)));
278   EXPECT_EQ(instr, a);
279 
280   instr = nullptr;
281   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1)));
282   EXPECT_EQ(instr, a);
283 }
284 
// MultiplyAnyOrder matches a multiply whose operands satisfy the two
// sub-patterns in either order, captures the multiply itself, and supports
// the usual Op() refinements such as IsNonConstant().
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  // Reset before each capturing match so every EXPECT_EQ below verifies a
  // fresh capture rather than a stale (or indeterminate) value.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
  EXPECT_EQ(instr, root);

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it.
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_EQ(instr, root);
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
315 
// AnyOf stops at the first alternative that matches, so captures belonging to
// later alternatives stay untouched.
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
  using match::AnyOf;
  using match::Multiply;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = parsed->entry_computation()->root_instruction();

  // Multiply listed first: only its capture is written.
  {
    const HloInstruction* mul_capture = nullptr;
    const HloInstruction* op_capture = nullptr;
    auto pattern = AnyOf<HloInstruction>(Multiply(&mul_capture, Op(), Op()),
                                         Op(&op_capture));
    ASSERT_TRUE(Match(root, pattern));
    EXPECT_NE(nullptr, mul_capture);
    EXPECT_EQ(nullptr, op_capture);
  }
  // Generic Op listed first: it wins and the Multiply capture stays null.
  {
    const HloInstruction* mul_capture = nullptr;
    const HloInstruction* op_capture = nullptr;
    auto pattern = AnyOf<HloInstruction>(Op(&op_capture),
                                         Multiply(&mul_capture, Op(), Op()));
    ASSERT_TRUE(Match(root, pattern));
    EXPECT_NE(nullptr, op_capture);
    EXPECT_EQ(nullptr, mul_capture);
  }
}
351 
// AllOf requires every sub-pattern to match (in any listed order); a single
// non-matching sub-pattern (Broadcast) makes the whole pattern fail.
TEST_F(PatternMatcherTest, AllOf) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = parsed->entry_computation()->root_instruction();

  auto f16_scalar = ShapeUtil::MakeShape(F16, {});
  // Three patterns that each accept the f16[] constant root on their own.
  auto is_scalar = Constant().WithShape(match::Shape().IsScalar());
  auto equals_f16 = Constant().WithShapeEqualTo(&f16_scalar);
  auto compatible_f16 = Constant().WithShapeCompatibleTo(&f16_scalar);
  ASSERT_TRUE(Match(root, is_scalar));
  ASSERT_TRUE(Match(root, equals_f16));
  ASSERT_TRUE(Match(root, compatible_f16));

  EXPECT_TRUE(Match(
      root, AllOf<HloInstruction>(is_scalar, equals_f16, compatible_f16)));
  EXPECT_TRUE(Match(
      root, AllOf<HloInstruction>(equals_f16, compatible_f16, is_scalar)));

  EXPECT_FALSE(Match(root, AllOf<HloInstruction>(Broadcast(Op()), equals_f16)));
  EXPECT_FALSE(
      Match(root, AllOf<HloInstruction>(Broadcast(Op()), compatible_f16)));
  EXPECT_FALSE(Match(root, AllOf<HloInstruction>(Broadcast(Op()), is_scalar)));
}
383 
// A failing AllOf must not leave partial captures behind, even though its
// Constant sub-pattern would match the root on its own.
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = parsed->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  auto constant_and_broadcast =
      AllOf<HloInstruction>(Constant(&captured), Broadcast(Op()));
  ASSERT_FALSE(Match(root, constant_and_broadcast));
  EXPECT_EQ(nullptr, captured);

  // Sanity check: the same Constant sub-pattern captures when used alone.
  ASSERT_TRUE(Match(root, Constant(&captured)));
  EXPECT_NE(nullptr, captured);
}
406 
// With capturing disabled in MatchOption, a successful match leaves the
// capture slot untouched.
TEST_F(PatternMatcherTest, TestNoCapture) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = parsed->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  MatchOption without_capture{/*capture=*/false};
  ASSERT_TRUE(Match(root, match::Constant(&captured), without_capture));
  EXPECT_EQ(nullptr, captured);
}
423 
// When AnyOf succeeds, only the captures of the alternative that matched are
// written.
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      u = f16[] parameter(0)
      v = f16[] parameter(1)
      ROOT add = f16[] add(u, v)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = parsed->entry_computation()->root_instruction();

  const HloInstruction* first_addend = nullptr;
  const HloInstruction* second_addend = nullptr;
  const HloInstruction* third_addend = nullptr;
  auto two_way_add = Add(Op(&first_addend), Op(&second_addend));
  // Alternatives in priority order: a nested three-way add (which cannot match
  // here, since neither operand of the root is an add), the plain two-way
  // add, and a bare Op fallback.
  auto three_way_or_less = AnyOf<HloInstruction>(
      AddAnyOrder(two_way_add, Op(&third_addend)), two_way_add,
      Op(&first_addend));

  ASSERT_TRUE(Match(root, three_way_or_less));
  EXPECT_NE(nullptr, first_addend);
  EXPECT_NE(nullptr, second_addend);
  EXPECT_EQ(nullptr, third_addend);
}
453 
// Concatenate matches its operand patterns positionally and requires exactly
// as many operand patterns as the instruction has operands.
TEST_F(PatternMatcherTest, TestConcat) {
  using match::Concatenate;
  using match::ConstantScalar;
  using match::Reshape;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      c1 = u32[] constant(1)
      c2 = u32[] constant(2)
      c3 = u32[] constant(3)
      c4 = u32[] constant(4)
      r1 = u32[1] reshape(c1)
      r2 = u32[1] reshape(c2)
      r3 = u32[1] reshape(c3)
      r4 = u32[1] reshape(c4)
      ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = parsed->entry_computation()->root_instruction();

  // Pattern for `reshape(constant(value))`, the form of every concat operand.
  auto reshaped = [](int value) { return Reshape(ConstantScalar(value)); };

  ASSERT_TRUE(Match(root, Concatenate(reshaped(1), reshaped(2), reshaped(3),
                                      reshaped(4))));
  // Wrong order.
  ASSERT_FALSE(Match(root, Concatenate(reshaped(2), reshaped(1), reshaped(3),
                                       reshaped(4))));
  // Too few operand patterns (prefix and suffix).
  ASSERT_FALSE(
      Match(root, Concatenate(reshaped(1), reshaped(2), reshaped(3))));
  ASSERT_FALSE(
      Match(root, Concatenate(reshaped(2), reshaped(3), reshaped(4))));
}
491 
// WithElementType accepts only the instruction's declared element type.
TEST_F(PatternMatcherTest, TestWithElementType) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* f16_constant = parsed->entry_computation()->root_instruction();

  EXPECT_FALSE(Match(f16_constant, m::Op().WithElementType(F32)));
  EXPECT_TRUE(Match(f16_constant, m::Op().WithElementType(F16)));
}
504 
// WithOperandIfPresent checks an operand only when that index exists; absent
// operand indices are accepted.
TEST_F(PatternMatcherTest, TestWithOperandIfPresent) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      a = f16[] constant(42)
      b = f16[] add(a, a)
      ROOT root = tuple(a, b)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* tuple = parsed->entry_computation()->root_instruction();
  const HloInstruction* constant = tuple->operand(0);
  const HloInstruction* add = tuple->operand(1);

  // The constant has no operand 0, so the constraint is vacuously satisfied.
  EXPECT_TRUE(Match(constant, m::Op().WithOperandIfPresent(0, m::Iota())));

  // Present operands are checked against the sub-pattern...
  EXPECT_TRUE(Match(add, m::Op().WithOperandIfPresent(0, m::Constant())));
  EXPECT_TRUE(Match(add, m::Op().WithOperandIfPresent(1, m::Constant())));
  EXPECT_FALSE(Match(add, m::Op().WithOperandIfPresent(0, m::Iota())));
  // ...while indices past the last operand are ignored.
  for (int missing : {2, 3}) {
    EXPECT_TRUE(Match(add, m::Op().WithOperandIfPresent(missing, m::Iota())));
  }
}
529 
// WithPredicate defers the decision to an arbitrary callable.
TEST_F(PatternMatcherTest, TestWithPredicate) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT a = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = parsed->entry_computation()->root_instruction();

  auto is_root = [root](const HloInstruction* candidate) {
    return candidate == root;
  };
  auto is_not_root = [root](const HloInstruction* candidate) {
    return candidate != root;
  };
  EXPECT_TRUE(Match(root, m::Op().WithPredicate(is_root)));
  EXPECT_FALSE(Match(root, m::Op().WithPredicate(is_not_root)));
}
549 
550 template <typename Pattern>
Description(const Pattern & pattern)551 std::string Description(const Pattern& pattern) {
552   std::stringstream ss;
553   pattern.DescribeTo(&ss);
554   return ss.str();
555 }
556 
557 template <typename Elem, typename Pattern>
Explanation(Elem * elem,const Pattern & pattern)558 std::string Explanation(Elem* elem, const Pattern& pattern) {
559   std::stringstream ss;
560   MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
561   Match(elem, pattern, options);
562   return ss.str();
563 }
564 template <typename Elem, typename Pattern>
Explanation(const std::unique_ptr<Elem> & elem,const Pattern & pattern)565 std::string Explanation(const std::unique_ptr<Elem>& elem,
566                         const Pattern& pattern) {
567   return Explanation(elem.get(), pattern);
568 }
569 template <typename Elem, typename Pattern>
Explanation(const Elem & elem,const Pattern & pattern)570 std::string Explanation(const Elem& elem, const Pattern& pattern) {
571   return Explanation(&elem, pattern);
572 }
573 
574 // Helper macro for checking a pattern's description and the explanation printed
575 // when attempting to match (and presumably failing) on a given object.
576 //
577 // We use a macro rather than a function because we want good line numbers in
578 // errors.  We use this rather than writing a helper that returns a pair of
579 // (description, explanation) and doing something like
580 //
581 //   EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
582 //
583 // because EXPECT_EQ prints a unified diff if multiline string comparison fails,
584 // while EXPECT_THAT does not.  This unified diff makes the errors much easier
585 // to read.
586 #define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc,    \
587                                     expected_explanation)            \
588   do {                                                               \
589     EXPECT_EQ(Description(pattern), (expected_desc));                \
590     EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \
591   } while (0)
592 
// Description and explanation text of LayoutPattern for a null layout and for
// an EqualTo mismatch.
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
  const Layout* null_layout = nullptr;
  auto expected = LayoutUtil::MakeLayout({1, 2});
  auto actual = LayoutUtil::MakeLayout({2, 2});

  EXPECT_DESC_AND_EXPLANATION(null_layout, m::Layout(), "a layout",
                              "Layout is null");
  EXPECT_DESC_AND_EXPLANATION(actual, m::Layout().EqualTo(&expected),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
}
603 
// WithCustomCallTarget accepts the custom-call with the requested target,
// rejects a different target, and explains the mismatch by printing the
// instruction.
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY test {
      ROOT out = f32[] custom-call(), custom_call_target="test_target"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* custom_call = parsed->entry_computation()->root_instruction();

  auto wrong_target = match::Op().WithCustomCallTarget("other_target");
  EXPECT_TRUE(
      Match(custom_call, match::Op().WithCustomCallTarget("test_target")));
  EXPECT_FALSE(Match(custom_call, wrong_target));

  EXPECT_DESC_AND_EXPLANATION(
      custom_call, wrong_target,
      "an HloInstruction custom call with target 'other_target'",
      "HloInstruction is not a custom call with a target 'other_target'\n"
      "in out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
626 
TEST_F(PatternMatcherTest,ShapeDescribeToAndExplain)627 TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
628   auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
629   auto layout = shape.layout();
630 
631   EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
632                               "a shape", "Shape is null");
633   EXPECT_DESC_AND_EXPLANATION(
634       ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
635       m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
636       "Shape not equal to f32[1,2]{0,1}\n"
637       "in f32[1,2]{1,0}");
638   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
639                               m::Shape().CompatibleTo(&shape),
640                               "a shape compatible with f32[1,2]",
641                               "Shape not compatible with f32[1,2]\n"
642                               "in f32[2,2]{1,0}");
643   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
644                               "a shape with element type F16",
645                               "Shape does not have element type F16\n"
646                               "in f32[1,2]{0,1}");
647   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
648                               "a shape that represents a scalar",
649                               "Shape is not a scalar\n"
650                               "in f32[1,2]{0,1}");
651   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
652                               "a shape that represents an array",
653                               "Shape is not an array\n"
654                               "in ()");
655   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
656                               "a shape that represents a tuple",
657                               "Shape is not a tuple\n"
658                               "in f32[1,2]{0,1}");
659   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
660                               "a shape that is an effective scalar",
661                               "Shape is not an effective scalar\n"
662                               "in f32[1,2]{0,1}");
663   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
664                               "a shape that has 42 dimensions",
665                               "Shape does not have rank 42\n"
666                               "in f32[1,2]{0,1}");
667   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
668                               "a shape that is a scalar",
669                               "Shape is not a scalar\n"
670                               "in f32[1,2]{0,1}");
671   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
672                               "a shape:\n"
673                               " * that has 1 dimension AND\n"
674                               " * that represents an array",
675                               "Shape does not have rank 1\n"
676                               "in f32[1,2]{0,1}");
677   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
678                               m::Shape().IsArray().WithRank(1),
679                               "a shape:\n"
680                               " * that represents an array AND\n"
681                               " * that has 1 dimension",
682                               "Shape is not an array\n"
683                               "in ()");
684   EXPECT_DESC_AND_EXPLANATION(
685       ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
686       m::Shape().WithLayoutEqualTo(&layout),
687       "a shape with\n  a layout equal to {0,1}",
688       "Layout {1,0} is not equal to expected {0,1}\n"
689       "in f32[1,2]{1,0}");
690   EXPECT_DESC_AND_EXPLANATION(shape,
691                               m::Shape().WithSubshapeEqualTo({10}, &shape),
692                               "a shape with subshape at index {10} which is\n"
693                               "  a shape equal to f32[1,2]{0,1}",
694                               "No subshape at {10}\n"
695                               "in f32[1,2]{0,1}");
696   EXPECT_DESC_AND_EXPLANATION(
697       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
698       m::Shape().WithSubshapeEqualTo({0}, &shape),
699       "a shape with subshape at index {0} which is\n"
700       "  a shape equal to f32[1,2]{0,1}",
701       "Shape not equal to f32[1,2]{0,1}\n"
702       "in f32[2,2]{1,0}\n"
703       "in subshape at {0}\n"
704       "in (f32[2,2])");
705   EXPECT_DESC_AND_EXPLANATION(shape,
706                               m::Shape().WithSubshapeCompatibleTo({10}, &shape),
707                               "a shape with subshape at index {10} which is\n"
708                               "  a shape compatible with f32[1,2]",
709                               "No subshape at {10}\n"
710                               "in f32[1,2]{0,1}");
711   EXPECT_DESC_AND_EXPLANATION(
712       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
713       m::Shape().WithSubshapeCompatibleTo({0}, &shape),
714       "a shape with subshape at index {0} which is\n"
715       "  a shape compatible with f32[1,2]",
716       "Shape not compatible with f32[1,2]\n"
717       "in f32[2,2]{1,0}\n"
718       "in subshape at {0}\n"
719       "in (f32[2,2])");
720   EXPECT_DESC_AND_EXPLANATION(
721       ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
722       m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
723       "a shape with subshape at index {0,0} which is\n"
724       "  a shape that represents a scalar",
725       "Shape is not a scalar\n"
726       "in f32[1,2]{0,1}\n"
727       "in subshape at {0,0}\n"
728       "in ((f32[1,2]))");
729 }
730 
SetName(absl::string_view name,std::unique_ptr<HloInstruction> instr)731 std::unique_ptr<HloInstruction> SetName(absl::string_view name,
732                                         std::unique_ptr<HloInstruction> instr) {
733   instr->SetAndSanitizeName(name);
734   return instr;
735 }
736 
TEST_F(PatternMatcherTest,HloInstructionDescribeToAndExplain)737 TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
738   std::unique_ptr<HloInstruction> iota =
739       SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
740                                               /*iota_dimension=*/0));
741   std::unique_ptr<HloInstruction> constant =
742       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
743 
744   EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
745                               m::Op(), "an HloInstruction",
746                               "HloInstruction* is null");
747   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
748                               "an HloInstruction named \"foo\"",
749                               "HloInstruction not named \"foo\"\n"
750                               "in i = s32[42]{0} iota(), iota_dimension=0");
751   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
752                               "an HloInstruction with opcode add",
753                               "HloInstruction doesn't have opcode add\n"
754                               "in i = s32[42]{0} iota(), iota_dimension=0");
755   EXPECT_DESC_AND_EXPLANATION(
756       constant, m::Op().IsNonConstant(),
757       "an HloInstruction with any opcode other than constant",
758       "HloInstruction has opcode constant, expected anything else\n"
759       "in c = s32[] constant(0)");
760   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
761                               "an HloInstruction with 42 operands",
762                               "HloInstruction doesn't have 42 operands\n"
763                               "in i = s32[42]{0} iota(), iota_dimension=0");
764   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
765                               "an HloInstruction outputting\n"
766                               "  a shape that represents a tuple",
767                               "Shape is not a tuple\n"
768                               "in s32[42]{0}\n"
769                               "in output shape\n"
770                               "in i = s32[42]{0} iota(), iota_dimension=0");
771   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(F32, {42}),
772                               "an HloInstruction outputting\n"
773                               "  a shape:\n"
774                               "   * with element type F32 AND\n"
775                               "   * with dimensions [42]",
776                               "Shape does not have element type F32\n"
777                               "in s32[42]{0}\n"
778                               "in output shape\n"
779                               "in i = s32[42]{0} iota(), iota_dimension=0");
780   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(S32, {128}),
781                               "an HloInstruction outputting\n"
782                               "  a shape:\n"
783                               "   * with element type S32 AND\n"
784                               "   * with dimensions [128]",
785                               "Shape does not have dimensions [128]\n"
786                               "in s32[42]{0}\n"
787                               "in output shape\n"
788                               "in i = s32[42]{0} iota(), iota_dimension=0");
789   EXPECT_DESC_AND_EXPLANATION(
790       iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
791       "an HloInstruction with operand 2 which is:\n"
792       "  an HloInstruction with opcode add",
793       "desired operand index 2 is out of bounds\n"
794       "in i = s32[42]{0} iota(), iota_dimension=0");
795 
796   EXPECT_DESC_AND_EXPLANATION(
797       SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
798                                                 HloOpcode::kAdd, constant.get(),
799                                                 constant.get())),
800       m::Op().WithOperand(1, m::Op().IsNonConstant()),
801       "an HloInstruction with operand 1 which is:\n"
802       "  an HloInstruction with any opcode other than constant",
803       "HloInstruction has opcode constant, expected anything else\n"
804       "in c = s32[] constant(0)\n"
805       "in operand 1\n"
806       "in a = s32[] add(s32[] c, s32[] c)");
807   EXPECT_DESC_AND_EXPLANATION(
808       iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
809       "an HloInstruction with fusion kind kLoop",
810       "HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
811       "in i = s32[42]{0} iota(), iota_dimension=0");
812   EXPECT_DESC_AND_EXPLANATION(
813       iota, m::Op().WithTupleIndex(42),
814       "an HloInstruction which is a GTE with index 42",
815       "HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
816       "in i = s32[42]{0} iota(), iota_dimension=0");
817   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
818                               "an HloInstruction which is a constant scalar",
819                               "HloInstruction is not a constant\n"
820                               "in i = s32[42]{0} iota(), iota_dimension=0");
821   EXPECT_DESC_AND_EXPLANATION(
822       SetName("c", HloInstruction::CreateConstant(
823                        LiteralUtil::CreateR1<int>({1, 2}))),
824       m::Op().IsConstantEffectiveScalar(),
825       "an HloInstruction which is a constant effective scalar",
826       "HloInstruction is not an effective scalar\n"
827       "in c = s32[2]{0} constant({1, 2})");
828   EXPECT_DESC_AND_EXPLANATION(
829       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
830       m::Op().IsConstantScalar(42),
831       "an HloInstruction which is a constant scalar with value 42",
832       "HloInstruction's constant value 10 did not match expected value 42\n"
833       "in c = s32[] constant(10)");
834   EXPECT_DESC_AND_EXPLANATION(
835       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
836       m::Op().IsConstantEffectiveScalar(1.25),
837       "an HloInstruction which is a constant effective scalar with value 1.25",
838       "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
839       "in c = f64[] constant(2.25)");
840   EXPECT_DESC_AND_EXPLANATION(
841       constant, m::Op().Is(iota.get()),
842       absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()),
843                    " (i = s32[42]{0} iota(), iota_dimension=0)"),
844       absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
845                    absl::Hex(iota.get()),
846                    " (i = s32[42]{0} iota(), iota_dimension=0)\n"
847                    "in c = s32[] constant(0)"));
848 
849   EXPECT_DESC_AND_EXPLANATION(
850       SetName("a",
851               HloInstruction::CreateBinary(constant->shape(), HloOpcode::kAdd,
852                                            constant.get(), constant.get())),
853       m::Op().WithOperandIfPresent(0, m::Iota()),  //
854       "an HloInstruction either with fewer than 1 operand, or with an operand "
855       "0 which is:\n"
856       "  an HloInstruction with opcode iota",
857       "HloInstruction doesn't have opcode iota\n"
858       "in c = s32[] constant(0)\n"
859       "in operand 0\n"
860       "in a = s32[] add(s32[] c, s32[] c)");
861 
862   EXPECT_DESC_AND_EXPLANATION(
863       constant,
864       m::Op().WithPredicate([](const HloInstruction*) { return false; }),
865       "an HloInstruction which matches a user-specified predicate",
866       "HloInstruction does not match user-specified predicate\n"
867       "in c = s32[] constant(0)");
868 }
869 
TEST_F(PatternMatcherTest,HloInstructionMatcherAnyOrderDescribeTo)870 TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
871   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
872   EXPECT_DESC_AND_EXPLANATION(
873       SetName("a", HloInstruction::CreateBinary(
874                        scalar_s32, HloOpcode::kAdd,
875                        SetName("b", HloInstruction::CreateConstant(
876                                         LiteralUtil::CreateR0(0)))
877                            .get(),
878                        SetName("c", HloInstruction::CreateConstant(
879                                         LiteralUtil::CreateR0(0)))
880                            .get())),
881       m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
882       "an HloInstruction:\n"
883       " * with opcode add AND\n"
884       " * with two operands in either order:\n"
885       "    - an HloInstruction named \"b\"\n"
886       "    - an HloInstruction named \"bar\"",
887       "HloInstruction's operands (ignoring order) did not match second "
888       "matcher.  Specifically,\n"
889       " - an HloInstruction named \"bar\"\n"
890       "does not match LHS:\n"
891       " - HloInstruction not named \"bar\"\n"
892       "   in b = s32[] constant(0)\n"
893       "does not match RHS:\n"
894       " - HloInstruction not named \"bar\"\n"
895       "   in c = s32[] constant(0)\n"
896       "in a = s32[] add(s32[] b, s32[] c)");
897 
898   EXPECT_DESC_AND_EXPLANATION(
899       SetName("a",
900               HloInstruction::CreateBinary(
901                   scalar_s32, HloOpcode::kAdd,
902                   HloInstruction::CreateParameter(0, scalar_s32, "p").get(),
903                   SetName("c", HloInstruction::CreateConstant(
904                                    LiteralUtil::CreateR0(0)))
905                       .get())),
906       m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
907       "an HloInstruction:\n"
908       " * with opcode add AND\n"
909       " * with two operands in either order:\n"
910       "    - an HloInstruction which is a constant scalar\n"
911       "    - an HloInstruction with opcode constant",
912       "HloInstruction's LHS operand did not match either of the two matchers.  "
913       "Specifically,\n"
914       " - an HloInstruction which is a constant scalar\n"
915       "does not match LHS:\n"
916       " - HloInstruction is not a constant\n"
917       "   in p = s32[] parameter(0)\n"
918       "and\n"
919       " - an HloInstruction with opcode constant\n"
920       "does not match LHS:\n"
921       " - HloInstruction doesn't have opcode constant\n"
922       "   in p = s32[] parameter(0)\n"
923       "in a = s32[] add(s32[] p, s32[] c)");
924 }
925 
// m::AnyOf describes itself as an OR of its alternatives. When every
// alternative fails, it lists each matcher together with its own explanation.
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  const std::string first_failure =
      "Matcher #1\n"
      " - an HloInstruction named \"foo\"\n"
      "failed with\n"
      " - HloInstruction not named \"foo\"\n"
      "   in c = s32[] constant(0)\n";
  const std::string second_failure =
      "Matcher #2\n"
      " - an HloInstruction named \"bar\"\n"
      "failed with\n"
      " - HloInstruction not named \"bar\"\n"
      "   in c = s32[] constant(0)";

  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))),
      m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
                               m::Op().WithName("bar")),
      "any of:\n"
      " - an HloInstruction named \"foo\" OR\n"
      " - an HloInstruction named \"bar\"",
      absl::StrCat("None of the following matchers succeeded:\n",
                   first_failure, second_failure));
}
946 
// m::Parameter() accepts any parameter, while m::Parameter(n) also requires
// parameter number n. Non-parameter instructions are rejected by both.
TEST_F(PatternMatcherTest, Parameter) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto p1 = HloInstruction::CreateParameter(1, f32_scalar, "p1");
  auto c =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  EXPECT_TRUE(Match(p1.get(), m::Parameter()));
  EXPECT_TRUE(Match(p1.get(), m::Parameter(1)));
  EXPECT_FALSE(Match(p1.get(), m::Parameter(0)));
  EXPECT_FALSE(Match(c.get(), m::Parameter()));
  EXPECT_FALSE(Match(c.get(), m::Parameter(1)));

  // For a constant, the opcode check is what fails...
  EXPECT_DESC_AND_EXPLANATION(c, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // ...while a parameter with another number fails the number check.
  auto p0 = HloInstruction::CreateParameter(0, f32_scalar, "p0");
  EXPECT_EQ(Explanation(p0.get(), m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
970 
// Contrasts WithOneUse with WithOneUser as users of a parameter come and go.
// WithOneUse needs a single use overall; WithOneUser needs a single distinct
// user, which may use the operand several times.
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const Shape f32_vec1 = ShapeUtil::MakeShape(F32, {1});
  auto p0 = HloInstruction::CreateParameter(0, f32_scalar, "p0");

  // With no users at all, both matchers fail with the same explanation.
  const std::string no_users_explanation =
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)";
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
  EXPECT_DESC_AND_EXPLANATION(p0, m::Op().WithOneUse(),
                              "an HloInstruction which has exactly one use",
                              no_users_explanation);
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUser()));
  EXPECT_DESC_AND_EXPLANATION(
      p0, m::Op().WithOneUser(),
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      no_users_explanation);

  // The reshapes are scoped so they are destroyed before `add` is created
  // below (presumably detaching themselves from p0's user list).
  {
    auto r = SetName("r", HloInstruction::CreateReshape(f32_vec1, p0.get()));
    EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUse()));
    EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUser()));

    auto r1 = SetName("r1", HloInstruction::CreateReshape(f32_vec1, p0.get()));
    EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
    EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUser()));

    // Users are listed in the order they were created.
    const std::string two_users_explanation =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUse()),
              two_users_explanation);
    EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUser()),
              two_users_explanation);
  }

  // One user consuming p0 twice: a single user, but two uses.
  auto add = SetName("add", HloInstruction::CreateBinary(
                                f32_scalar, HloOpcode::kAdd, p0.get(),
                                p0.get()));
  EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUser()));
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
  EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUse()),
            "HloInstruction is used 2 times by its user, but is expected to be "
            "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
            "in p0 = f32[] parameter(0)");
}
1025 
// Covers the generic m::Compare() matcher and the direction-specific
// Eq/Ne/Le matchers, including operand sub-patterns and the AnyOrder variant.
TEST_F(PatternMatcherTest, Comparison) {
  auto shape = ShapeUtil::MakeShape(F32, {1});
  // A comparison yields a PRED result with the operands' dimensions. Build the
  // compares with that output shape so the test instructions are well-formed
  // HLO rather than F32-valued compares.
  auto pred_shape = ShapeUtil::MakeShape(PRED, {1});
  auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
  auto eq = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kEq);
  auto ne = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kNe);
  auto add =
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
  auto le = HloInstruction::CreateCompare(pred_shape, p0.get(), add.get(),
                                          ComparisonDirection::kLe);

  EXPECT_TRUE(Match(eq.get(), m::Compare()));
  EXPECT_TRUE(Match(eq.get(), m::Eq()));
  EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
  EXPECT_TRUE(Match(ne.get(), m::Compare()));
  EXPECT_TRUE(Match(ne.get(), m::Ne()));
  EXPECT_TRUE(Match(
      le.get(),
      m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0),
                                    m::Add(m::Parameter(0), m::Parameter(1)))));

  EXPECT_FALSE(Match(eq.get(), m::Add()));
  EXPECT_FALSE(Match(eq.get(), m::Ne()));
  EXPECT_FALSE(
      Match(le.get(),
            m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  // Without AnyOrder, operand order matters.
  EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));
  EXPECT_DESC_AND_EXPLANATION(
      eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = pred[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
1068 
// Checks every m::CustomCall overload: plain, target-filtered, with operand
// sub-patterns, and the variants that capture the matched instruction through
// a (const or non-const) out-pointer.
TEST_F(PatternMatcherTest, CustomCallMatchers) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      ROOT out = f32[] custom-call(p0, p1), custom_call_target="test_target"
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  EXPECT_TRUE(Match(root, m::CustomCall()));
  EXPECT_TRUE(Match(root, m::CustomCall("test_target")));
  EXPECT_TRUE(Match(
      root, m::CustomCall("test_target", m::Parameter(0), m::Parameter(1))));

  // Capturing overloads must store the matched instruction. The out-pointer
  // is reset before each match so every overload is verified on its own,
  // rather than passing because an earlier match already wrote `root`.
  HloInstruction* instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr)));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr, "test_target")));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr, "test_target", m::Parameter(0),
                                        m::Parameter(1))));
  EXPECT_EQ(instr, root);

  const HloInstruction* const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr)));
  EXPECT_EQ(const_instr, root);
  const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr, "test_target")));
  EXPECT_EQ(const_instr, root);
  const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr, "test_target",
                                        m::Parameter(0), m::Parameter(1))));
  EXPECT_EQ(const_instr, root);

  // A different target name or swapped operands must not match.
  EXPECT_FALSE(Match(root, m::CustomCall("other_target")));
  EXPECT_FALSE(Match(
      root, m::CustomCall("test_target", m::Parameter(1), m::Parameter(0))));
}
1104 
1105 }  // namespace
1106 }  // namespace xla
1107