1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/pattern_matcher.h"

#include <memory>
#include <sstream>
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"
26
27 namespace xla {
28 namespace {
29
30 namespace m = match;
31 using PatternMatcherTest = HloTestBase;
32
// Matches the root `add` of a tiny module against an Op() pattern that
// constrains its name, opcode, output shape and operand 0. Verifies that every
// capture slot is populated on a successful match and that a failed match
// leaves its capture slot untouched.
TEST_F(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
ENTRY %two_plus_two_computation () -> f32[] {
%two = f32[] constant(2)
ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();

  // Initialize the capture slots so the checks below never read an
  // indeterminate pointer.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;

  ASSERT_TRUE(Match(
      root,
      match::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(match::Shape(&matched_shape).IsDenseArray())
          .WithOperand(
              0,
              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst, root);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);
  // The nested operand and shape patterns must have captured as well.
  ASSERT_NE(matched_operand, nullptr);
  EXPECT_EQ(matched_operand, root->operand(0));
  EXPECT_EQ(matched_operand->opcode(), HloOpcode::kConstant);
  ASSERT_NE(matched_shape, nullptr);
  EXPECT_TRUE(ShapeUtil::Equal(*matched_shape, root->shape()));

  EXPECT_TRUE(Match(root, match::Add(match::Constant(), match::Constant())));

  EXPECT_FALSE(Match(root, match::Op().WithName("bad_name")));
  matched_inst = nullptr;
  EXPECT_FALSE(
      Match(root, match::Multiply(&matched_inst, match::Op(), match::Op())));
  // A failed match must not write to its capture slot.
  EXPECT_EQ(matched_inst, nullptr);
}
69
// Checks which shape predicates hold for a rank-0 F32 array shape.
TEST_F(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  // Initialized because EXPECT_* does not abort on failure: if the match below
  // failed, the following EXPECT_EQ would otherwise read an indeterminate
  // pointer.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
  // A non-tuple shape has no subshape at {0}, so the pattern fails even though
  // the element type constraint alone would match.
  EXPECT_FALSE(Match(
      &scalar_shape,
      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
84
// Checks which shape predicates hold for a rank-3 F32 dense array shape, and
// that the shape and layout capture slots point into the matched shape.
TEST_F(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  // Initialized so the EXPECT_EQ below is well-defined even if the match fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(Match(&array_shape, match::Shape().WithRank(2)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(match::Layout(&matched_layout))));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
103
// Checks tuple-shape predicates and WithSubshape() navigation into each
// element of a two-element tuple, including out-of-range subshape indices.
TEST_F(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));

  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {0}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {0})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {0}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {1})))));

  // Reset the slot so the ASSERT_NE below actually verifies the second capture
  // rather than observing the value left over from the first one.
  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {1}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {1})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {1}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {0})))));

  // Indices past the end of the tuple, or into a non-tuple element, fail.
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
151
// A kLoop fusion matches only a kLoop fusion-kind pattern; its operand (a
// parameter, not a fusion) matches no fusion-kind pattern at all.
TEST_F(PatternMatcherTest, FusionKind) {
  constexpr char kModuleStr[] = R"(
HloModule test_module

fused_computation {
ROOT fp0 = f32[] parameter(0)
}

ENTRY while.v11 {
p0 = f32[] parameter(0)
ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloInstruction* fusion_instr =
      verified_module->entry_computation()->root_instruction();
  const HloInstruction* param_instr = fusion_instr->operand(0);
  const auto loop_kind = HloInstruction::FusionKind::kLoop;
  const auto input_kind = HloInstruction::FusionKind::kInput;

  EXPECT_TRUE(Match(fusion_instr, match::Op().WithFusionKind(loop_kind)));
  EXPECT_FALSE(Match(fusion_instr, match::Op().WithFusionKind(input_kind)));
  EXPECT_FALSE(Match(param_instr, match::Op().WithFusionKind(loop_kind)));
}
175
// A get-tuple-element with index=1 matches only patterns that ask for tuple
// index 1, via either WithTupleIndex() or the GetTupleElement() helper.
TEST_F(PatternMatcherTest, GetTupleElement) {
  constexpr char kModuleStr[] = R"(
HloModule test_module

ENTRY while.v11 {
p0 = (f32[], f32[], f32[]) parameter(0)
ROOT gte = f32[] get-tuple-element(p0), index=1
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* gte_instr =
      verified_module->entry_computation()->root_instruction();

  // The actual index is accepted...
  EXPECT_TRUE(Match(gte_instr, match::Op().WithTupleIndex(1)));
  EXPECT_TRUE(Match(gte_instr, match::GetTupleElement(match::Op(), 1)));

  // ...and every other index is rejected.
  EXPECT_FALSE(Match(gte_instr, match::Op().WithTupleIndex(0)));
  EXPECT_FALSE(Match(gte_instr, match::Op().WithTupleIndex(2)));
  EXPECT_FALSE(Match(gte_instr, match::GetTupleElement(match::Op(), 0)));
}
194
// AnyOf matches when at least one alternative matches, regardless of where
// that alternative appears, and fails when none does.
TEST_F(PatternMatcherTest, AnyOf) {
  using match::AnyOf;
  using match::ConstantScalar;

  constexpr char kModuleStr[] = R"(
HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* one = verified_module->entry_computation()->root_instruction();

  EXPECT_TRUE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(0), ConstantScalar(1))));
  EXPECT_TRUE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(1), ConstantScalar(0))));
  EXPECT_FALSE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(0), ConstantScalar(2))));
}
212
// Exercises ConstantScalar() (rank-0 constants only) and
// ConstantEffectiveScalar() (constants with a single element of any rank), with
// and without an expected value and with and without a capture slot.
TEST_F(PatternMatcherTest, ConstantScalar) {
  // Only the helpers actually used below are imported; the previously
  // declared match::Op and match::Tuple were unused.
  using match::ConstantEffectiveScalar;
  using match::ConstantScalar;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
a = s32[] constant(1)
b = s32[1,1] constant({{2}})
c = s32[1,2] constant({{2,2}})
d = f32[] constant(1)
e = f32[] constant(1.25)
ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  const HloInstruction* a = root->operand(0);
  const HloInstruction* b = root->operand(1);
  const HloInstruction* c = root->operand(2);
  const HloInstruction* d = root->operand(3);
  const HloInstruction* e = root->operand(4);
  // `a`: a true scalar matches both helpers.
  EXPECT_TRUE(Match(a, ConstantScalar()));
  EXPECT_TRUE(Match(a, ConstantScalar(1)));
  EXPECT_TRUE(Match(a, ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1)));
  EXPECT_FALSE(Match(a, ConstantScalar(2)));
  EXPECT_FALSE(Match(a, ConstantScalar(2.01)));
  EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2)));
  EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01)));

  // `b`: a single-element rank-2 array is only an *effective* scalar.
  EXPECT_FALSE(Match(b, ConstantScalar()));
  EXPECT_FALSE(Match(b, ConstantScalar(2)));
  EXPECT_TRUE(Match(b, ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2)));

  // `c`: two elements, so it is neither.
  EXPECT_FALSE(Match(c, ConstantScalar()));
  EXPECT_FALSE(Match(c, ConstantScalar(2)));
  EXPECT_FALSE(Match(c, ConstantEffectiveScalar()));
  EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2)));

  // `d`: floating-point 1 matches both integer and floating expected values.
  EXPECT_TRUE(Match(d, ConstantScalar(1)));
  EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1)));
  EXPECT_TRUE(Match(d, ConstantScalar(1.0)));
  EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0)));

  // `e`: a non-integral value does not match an integral expected value.
  EXPECT_TRUE(Match(e, ConstantScalar(1.25f)));
  EXPECT_TRUE(Match(e, ConstantScalar(1.25)));
  EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25)));
  EXPECT_FALSE(Match(e, ConstantScalar(1)));
  EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1)));

  // Every overload that takes a capture slot fills it on success.
  const HloInstruction* instr = nullptr;
  EXPECT_TRUE(Match(a, ConstantScalar(&instr)));
  EXPECT_EQ(instr, a);

  instr = nullptr;
  EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1)));
  EXPECT_EQ(instr, a);

  instr = nullptr;
  EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr)));
  EXPECT_EQ(instr, a);

  instr = nullptr;
  EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1)));
  EXPECT_EQ(instr, a);
}
284
// MultiplyAnyOrder matches a multiply whose operands satisfy the two
// sub-patterns in either order, captures the multiply itself, and supports the
// same chained predicates as Op().
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
lhs = f16[] constant(42)
rhs = f16[] constant(52)
ROOT multiply = f16[] multiply(lhs, rhs)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  // Initialized (and reset between matches) so each EXPECT_EQ below checks the
  // capture made by the immediately preceding match.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
  EXPECT_EQ(instr, root);

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it.
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_EQ(instr, root);
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
315
// AnyOf stops at the first alternative that matches: captures belonging to
// later alternatives must stay untouched.
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
  using match::AnyOf;
  using match::Multiply;
  using match::Op;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
lhs = f16[] constant(42)
rhs = f16[] constant(52)
ROOT multiply = f16[] multiply(lhs, rhs)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* multiply =
      verified_module->entry_computation()->root_instruction();

  const HloInstruction* mul_capture = nullptr;
  const HloInstruction* op_capture = nullptr;

  // Multiply pattern first: it matches, so the Op() alternative is not used.
  ASSERT_TRUE(Match(multiply,
                    AnyOf<HloInstruction>(Multiply(&mul_capture, Op(), Op()),
                                          Op(&op_capture))));
  EXPECT_NE(nullptr, mul_capture);
  EXPECT_EQ(nullptr, op_capture);

  // Order swapped: Op() matches first, so the Multiply capture stays empty.
  mul_capture = nullptr;
  op_capture = nullptr;
  ASSERT_TRUE(Match(multiply,
                    AnyOf<HloInstruction>(Op(&op_capture),
                                          Multiply(&mul_capture, Op(), Op()))));
  EXPECT_NE(nullptr, op_capture);
  EXPECT_EQ(nullptr, mul_capture);
}
351
// AllOf matches only if every sub-pattern matches; argument order does not
// matter, and one failing sub-pattern makes the whole conjunction fail.
TEST_F(PatternMatcherTest, AllOf) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* f16_constant =
      verified_module->entry_computation()->root_instruction();

  auto f16_scalar = ShapeUtil::MakeShape(F16, {});
  // Three independent descriptions of the root; each must match on its own.
  auto shape_equal = Constant().WithShapeEqualTo(&f16_scalar);
  auto shape_compatible = Constant().WithShapeCompatibleTo(&f16_scalar);
  auto scalar_shaped = Constant().WithShape(match::Shape().IsScalar());
  ASSERT_TRUE(Match(f16_constant, scalar_shaped));
  ASSERT_TRUE(Match(f16_constant, shape_equal));
  ASSERT_TRUE(Match(f16_constant, shape_compatible));

  // Conjunctions of matching patterns match in any order.
  EXPECT_TRUE(Match(f16_constant,
                    AllOf<HloInstruction>(scalar_shaped, shape_equal,
                                          shape_compatible)));
  EXPECT_TRUE(Match(f16_constant,
                    AllOf<HloInstruction>(shape_equal, shape_compatible,
                                          scalar_shaped)));

  // The root is not a broadcast, so adding Broadcast() sinks every conjunction.
  EXPECT_FALSE(Match(f16_constant,
                     AllOf<HloInstruction>(Broadcast(Op()), shape_equal)));
  EXPECT_FALSE(Match(f16_constant,
                     AllOf<HloInstruction>(Broadcast(Op()), shape_compatible)));
  EXPECT_FALSE(Match(f16_constant,
                     AllOf<HloInstruction>(Broadcast(Op()), scalar_shaped)));
}
383
// When an AllOf fails as a whole, captures of the sub-patterns that did match
// individually must not be written.
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
ROOT v = f16[] constant(42)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* root_instr =
      verified_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  // Constant() matches but Broadcast() does not, so the conjunction fails.
  ASSERT_FALSE(Match(root_instr, match::AllOf<HloInstruction>(
                                     match::Constant(&captured),
                                     match::Broadcast(match::Op()))));
  EXPECT_EQ(nullptr, captured);

  // On its own, Constant() matches and performs the capture.
  ASSERT_TRUE(Match(root_instr, match::Constant(&captured)));
  EXPECT_NE(nullptr, captured);
}
406
// With capture disabled in the match options, a successful match leaves the
// capture slot untouched.
TEST_F(PatternMatcherTest, TestNoCapture) {
  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
ROOT v = f16[] constant(42)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* root_instr =
      verified_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  ASSERT_TRUE(
      Match(root_instr, match::Constant(&captured), {/*capture=*/false}));
  EXPECT_EQ(nullptr, captured);
}
423
// Within an AnyOf, only the captures of the alternative that actually matched
// are written.
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Op;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
u = f16[] parameter(0)
v = f16[] parameter(1)
ROOT add = f16[] add(u, v)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* add_instr =
      verified_module->entry_computation()->root_instruction();

  const HloInstruction* first_addend = nullptr;
  const HloInstruction* second_addend = nullptr;
  const HloInstruction* third_addend = nullptr;
  auto two_way_add = Add(Op(&first_addend), Op(&second_addend));
  // Alternatives, tried in order: a nested three-way add (cannot match here,
  // since both addends of the root are parameters), the plain two-way add,
  // and a bare Op().
  auto three_way_add = AnyOf<HloInstruction>(
      AddAnyOrder(two_way_add, Op(&third_addend)), two_way_add,
      Op(&first_addend));

  ASSERT_TRUE(Match(add_instr, three_way_add));
  EXPECT_NE(nullptr, first_addend);
  EXPECT_NE(nullptr, second_addend);
  EXPECT_EQ(nullptr, third_addend);
}
453
// Concatenate() must match operands positionally and requires the exact
// operand count.
TEST_F(PatternMatcherTest, TestConcat) {
  // Only the helpers used below are imported (match::Op was unused).
  using match::Concatenate;
  using match::ConstantScalar;
  using match::Reshape;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
c1 = u32[] constant(1)
c2 = u32[] constant(2)
c3 = u32[] constant(3)
c4 = u32[] constant(4)
r1 = u32[1] reshape(c1)
r2 = u32[1] reshape(c2)
r3 = u32[1] reshape(c3)
r4 = u32[1] reshape(c4)
ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  ASSERT_TRUE(Match(
      root,
      Concatenate(Reshape(ConstantScalar(1)), Reshape(ConstantScalar(2)),
                  Reshape(ConstantScalar(3)), Reshape(ConstantScalar(4)))));
  // The negative checks are independent, so EXPECT_* lets all of them report
  // instead of stopping at the first failure.
  // Operands out of order.
  EXPECT_FALSE(Match(
      root,
      Concatenate(Reshape(ConstantScalar(2)), Reshape(ConstantScalar(1)),
                  Reshape(ConstantScalar(3)), Reshape(ConstantScalar(4)))));
  // Too few operands (dropping the last, then the first).
  EXPECT_FALSE(Match(
      root, Concatenate(Reshape(ConstantScalar(1)), Reshape(ConstantScalar(2)),
                        Reshape(ConstantScalar(3)))));
  EXPECT_FALSE(Match(
      root, Concatenate(Reshape(ConstantScalar(2)), Reshape(ConstantScalar(3)),
                        Reshape(ConstantScalar(4)))));
}
491
// WithElementType() distinguishes an f16 constant from an f32 expectation.
TEST_F(PatternMatcherTest, TestWithElementType) {
  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
ROOT v = f16[] constant(42)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* f16_constant =
      verified_module->entry_computation()->root_instruction();

  const auto wants_f16 = m::Op().WithElementType(F16);
  const auto wants_f32 = m::Op().WithElementType(F32);
  EXPECT_TRUE(Match(f16_constant, wants_f16));
  EXPECT_FALSE(Match(f16_constant, wants_f32));
}
504
// WithOperandIfPresent(i, p) constrains operand i only when it exists; an
// out-of-range index is accepted without checking `p`.
TEST_F(PatternMatcherTest, TestWithOperandIfPresent) {
  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
a = f16[] constant(42)
b = f16[] add(a, a)
ROOT root = tuple(a, b)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* tuple_instr =
      verified_module->entry_computation()->root_instruction();
  const HloInstruction* constant_instr = tuple_instr->operand(0);
  const HloInstruction* add_instr = tuple_instr->operand(1);

  // The constant has no operand 0, so the Iota() constraint is skipped.
  EXPECT_TRUE(
      Match(constant_instr, m::Op().WithOperandIfPresent(0, m::Iota())));

  // Both operands of the add exist and are constants.
  EXPECT_TRUE(Match(add_instr, m::Op().WithOperandIfPresent(0, m::Constant())));
  EXPECT_TRUE(Match(add_instr, m::Op().WithOperandIfPresent(1, m::Constant())));
  EXPECT_FALSE(Match(add_instr, m::Op().WithOperandIfPresent(0, m::Iota())));

  // The add has no operands 2 or 3, so these constraints are skipped.
  EXPECT_TRUE(Match(add_instr, m::Op().WithOperandIfPresent(2, m::Iota())));
  EXPECT_TRUE(Match(add_instr, m::Op().WithOperandIfPresent(3, m::Iota())));
}
529
// WithPredicate() matches exactly when the user-supplied predicate returns
// true for the instruction.
TEST_F(PatternMatcherTest, TestWithPredicate) {
  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
ROOT a = f16[] constant(42)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* root_instr =
      verified_module->entry_computation()->root_instruction();

  auto is_root = [&](const HloInstruction* candidate) {
    return candidate == root_instr;
  };
  auto is_not_root = [&](const HloInstruction* candidate) {
    return candidate != root_instr;
  };

  EXPECT_TRUE(Match(root_instr, m::Op().WithPredicate(is_root)));
  EXPECT_FALSE(Match(root_instr, m::Op().WithPredicate(is_not_root)));
}
549
// Returns the text that `pattern` writes through its DescribeTo() method.
// A write-only std::ostringstream is sufficient here; <sstream> is included
// directly rather than relied upon transitively.
template <typename Pattern>
std::string Description(const Pattern& pattern) {
  std::ostringstream ss;
  pattern.DescribeTo(&ss);
  return ss.str();
}
556
557 template <typename Elem, typename Pattern>
Explanation(Elem * elem,const Pattern & pattern)558 std::string Explanation(Elem* elem, const Pattern& pattern) {
559 std::stringstream ss;
560 MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
561 Match(elem, pattern, options);
562 return ss.str();
563 }
564 template <typename Elem, typename Pattern>
Explanation(const std::unique_ptr<Elem> & elem,const Pattern & pattern)565 std::string Explanation(const std::unique_ptr<Elem>& elem,
566 const Pattern& pattern) {
567 return Explanation(elem.get(), pattern);
568 }
569 template <typename Elem, typename Pattern>
Explanation(const Elem & elem,const Pattern & pattern)570 std::string Explanation(const Elem& elem, const Pattern& pattern) {
571 return Explanation(&elem, pattern);
572 }
573
// Helper macro for checking a pattern's description and the explanation printed
// when attempting to match (and presumably failing) on a given object.
//
// We use a macro rather than a function because we want good line numbers in
// errors. We use this rather than writing a helper that returns a pair of
// (description, explanation) and doing something like
//
//   EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff if multiline string comparison fails,
// while EXPECT_THAT does not. This unified diff makes the errors much easier
// to read.
//
// Every argument is parenthesized at its point of use so that arbitrary
// expressions can be passed. `pattern` is expanded (and so evaluated) twice;
// `elem` is evaluated once.
#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc,      \
                                    expected_explanation)              \
  do {                                                                 \
    EXPECT_EQ(Description((pattern)), (expected_desc));                \
    EXPECT_EQ(Explanation((elem), (pattern)), (expected_explanation)); \
  } while (0)
592
// Checks the description and failure explanation of layout patterns, both for
// a null layout and for an EqualTo() mismatch.
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
  auto expected_layout = LayoutUtil::MakeLayout({1, 2});
  auto actual_layout = LayoutUtil::MakeLayout({2, 2});

  // A null layout pointer is reported as such.
  EXPECT_DESC_AND_EXPLANATION(static_cast<const Layout*>(nullptr), m::Layout(),
                              "a layout", "Layout is null");
  // A mismatch names both the actual and the expected layout.
  EXPECT_DESC_AND_EXPLANATION(actual_layout,
                              m::Layout().EqualTo(&expected_layout),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
}
603
// WithCustomCallTarget() matches only the exact target name; a mismatch is
// described and explained with the expected target and the instruction text.
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
  constexpr char kModuleStr[] = R"(
HloModule test_module

ENTRY test {
ROOT out = f32[] custom-call(), custom_call_target="test_target"
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto verified_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloInstruction* custom_call =
      verified_module->entry_computation()->root_instruction();
  const auto wrong_target = match::Op().WithCustomCallTarget("other_target");

  EXPECT_TRUE(
      Match(custom_call, match::Op().WithCustomCallTarget("test_target")));
  EXPECT_FALSE(Match(custom_call, wrong_target));

  EXPECT_DESC_AND_EXPLANATION(
      custom_call, wrong_target,
      "an HloInstruction custom call with target 'other_target'",
      "HloInstruction is not a custom call with a target 'other_target'\nin "
      "out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
626
// Checks the description and failure explanation of each shape-pattern
// predicate, including nested subshape patterns, whose explanations include
// the path to the failing subshape.
TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
  auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
  // Bind by const reference: the layout is only observed, and `shape` outlives
  // every use below, so copying it is unnecessary.
  const auto& layout = shape.layout();

  EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
                              "a shape", "Shape is null");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
      m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
      "Shape not equal to f32[1,2]{0,1}\n"
      "in f32[1,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
                              m::Shape().CompatibleTo(&shape),
                              "a shape compatible with f32[1,2]",
                              "Shape not compatible with f32[1,2]\n"
                              "in f32[2,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
                              "a shape with element type F16",
                              "Shape does not have element type F16\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
                              "a shape that represents a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
                              "a shape that represents an array",
                              "Shape is not an array\n"
                              "in ()");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
                              "a shape that represents a tuple",
                              "Shape is not a tuple\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
                              "a shape that is an effective scalar",
                              "Shape is not an effective scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
                              "a shape that has 42 dimensions",
                              "Shape does not have rank 42\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
                              "a shape that is a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
                              "a shape:\n"
                              " * that has 1 dimension AND\n"
                              " * that represents an array",
                              "Shape does not have rank 1\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
                              m::Shape().IsArray().WithRank(1),
                              "a shape:\n"
                              " * that represents an array AND\n"
                              " * that has 1 dimension",
                              "Shape is not an array\n"
                              "in ()");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
      m::Shape().WithLayoutEqualTo(&layout),
      "a shape with\n a layout equal to {0,1}",
      "Layout {1,0} is not equal to expected {0,1}\n"
      "in f32[1,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(shape,
                              m::Shape().WithSubshapeEqualTo({10}, &shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape equal to f32[1,2]{0,1}",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
      m::Shape().WithSubshapeEqualTo({0}, &shape),
      "a shape with subshape at index {0} which is\n"
      " a shape equal to f32[1,2]{0,1}",
      "Shape not equal to f32[1,2]{0,1}\n"
      "in f32[2,2]{1,0}\n"
      "in subshape at {0}\n"
      "in (f32[2,2])");
  EXPECT_DESC_AND_EXPLANATION(shape,
                              m::Shape().WithSubshapeCompatibleTo({10}, &shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape compatible with f32[1,2]",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
      m::Shape().WithSubshapeCompatibleTo({0}, &shape),
      "a shape with subshape at index {0} which is\n"
      " a shape compatible with f32[1,2]",
      "Shape not compatible with f32[1,2]\n"
      "in f32[2,2]{1,0}\n"
      "in subshape at {0}\n"
      "in (f32[2,2])");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
      m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
      "a shape with subshape at index {0,0} which is\n"
      " a shape that represents a scalar",
      "Shape is not a scalar\n"
      "in f32[1,2]{0,1}\n"
      "in subshape at {0,0}\n"
      "in ((f32[1,2]))");
}
730
SetName(absl::string_view name,std::unique_ptr<HloInstruction> instr)731 std::unique_ptr<HloInstruction> SetName(absl::string_view name,
732 std::unique_ptr<HloInstruction> instr) {
733 instr->SetAndSanitizeName(name);
734 return instr;
735 }
736
TEST_F(PatternMatcherTest,HloInstructionDescribeToAndExplain)737 TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
738 std::unique_ptr<HloInstruction> iota =
739 SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
740 /*iota_dimension=*/0));
741 std::unique_ptr<HloInstruction> constant =
742 SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
743
744 EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
745 m::Op(), "an HloInstruction",
746 "HloInstruction* is null");
747 EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
748 "an HloInstruction named \"foo\"",
749 "HloInstruction not named \"foo\"\n"
750 "in i = s32[42]{0} iota(), iota_dimension=0");
751 EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
752 "an HloInstruction with opcode add",
753 "HloInstruction doesn't have opcode add\n"
754 "in i = s32[42]{0} iota(), iota_dimension=0");
755 EXPECT_DESC_AND_EXPLANATION(
756 constant, m::Op().IsNonConstant(),
757 "an HloInstruction with any opcode other than constant",
758 "HloInstruction has opcode constant, expected anything else\n"
759 "in c = s32[] constant(0)");
760 EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
761 "an HloInstruction with 42 operands",
762 "HloInstruction doesn't have 42 operands\n"
763 "in i = s32[42]{0} iota(), iota_dimension=0");
764 EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
765 "an HloInstruction outputting\n"
766 " a shape that represents a tuple",
767 "Shape is not a tuple\n"
768 "in s32[42]{0}\n"
769 "in output shape\n"
770 "in i = s32[42]{0} iota(), iota_dimension=0");
771 EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(F32, {42}),
772 "an HloInstruction outputting\n"
773 " a shape:\n"
774 " * with element type F32 AND\n"
775 " * with dimensions [42]",
776 "Shape does not have element type F32\n"
777 "in s32[42]{0}\n"
778 "in output shape\n"
779 "in i = s32[42]{0} iota(), iota_dimension=0");
780 EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(S32, {128}),
781 "an HloInstruction outputting\n"
782 " a shape:\n"
783 " * with element type S32 AND\n"
784 " * with dimensions [128]",
785 "Shape does not have dimensions [128]\n"
786 "in s32[42]{0}\n"
787 "in output shape\n"
788 "in i = s32[42]{0} iota(), iota_dimension=0");
789 EXPECT_DESC_AND_EXPLANATION(
790 iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
791 "an HloInstruction with operand 2 which is:\n"
792 " an HloInstruction with opcode add",
793 "desired operand index 2 is out of bounds\n"
794 "in i = s32[42]{0} iota(), iota_dimension=0");
795
796 EXPECT_DESC_AND_EXPLANATION(
797 SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
798 HloOpcode::kAdd, constant.get(),
799 constant.get())),
800 m::Op().WithOperand(1, m::Op().IsNonConstant()),
801 "an HloInstruction with operand 1 which is:\n"
802 " an HloInstruction with any opcode other than constant",
803 "HloInstruction has opcode constant, expected anything else\n"
804 "in c = s32[] constant(0)\n"
805 "in operand 1\n"
806 "in a = s32[] add(s32[] c, s32[] c)");
807 EXPECT_DESC_AND_EXPLANATION(
808 iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
809 "an HloInstruction with fusion kind kLoop",
810 "HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
811 "in i = s32[42]{0} iota(), iota_dimension=0");
812 EXPECT_DESC_AND_EXPLANATION(
813 iota, m::Op().WithTupleIndex(42),
814 "an HloInstruction which is a GTE with index 42",
815 "HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
816 "in i = s32[42]{0} iota(), iota_dimension=0");
817 EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
818 "an HloInstruction which is a constant scalar",
819 "HloInstruction is not a constant\n"
820 "in i = s32[42]{0} iota(), iota_dimension=0");
821 EXPECT_DESC_AND_EXPLANATION(
822 SetName("c", HloInstruction::CreateConstant(
823 LiteralUtil::CreateR1<int>({1, 2}))),
824 m::Op().IsConstantEffectiveScalar(),
825 "an HloInstruction which is a constant effective scalar",
826 "HloInstruction is not an effective scalar\n"
827 "in c = s32[2]{0} constant({1, 2})");
828 EXPECT_DESC_AND_EXPLANATION(
829 SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
830 m::Op().IsConstantScalar(42),
831 "an HloInstruction which is a constant scalar with value 42",
832 "HloInstruction's constant value 10 did not match expected value 42\n"
833 "in c = s32[] constant(10)");
834 EXPECT_DESC_AND_EXPLANATION(
835 SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
836 m::Op().IsConstantEffectiveScalar(1.25),
837 "an HloInstruction which is a constant effective scalar with value 1.25",
838 "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
839 "in c = f64[] constant(2.25)");
840 EXPECT_DESC_AND_EXPLANATION(
841 constant, m::Op().Is(iota.get()),
842 absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()),
843 " (i = s32[42]{0} iota(), iota_dimension=0)"),
844 absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
845 absl::Hex(iota.get()),
846 " (i = s32[42]{0} iota(), iota_dimension=0)\n"
847 "in c = s32[] constant(0)"));
848
849 EXPECT_DESC_AND_EXPLANATION(
850 SetName("a",
851 HloInstruction::CreateBinary(constant->shape(), HloOpcode::kAdd,
852 constant.get(), constant.get())),
853 m::Op().WithOperandIfPresent(0, m::Iota()), //
854 "an HloInstruction either with fewer than 1 operand, or with an operand "
855 "0 which is:\n"
856 " an HloInstruction with opcode iota",
857 "HloInstruction doesn't have opcode iota\n"
858 "in c = s32[] constant(0)\n"
859 "in operand 0\n"
860 "in a = s32[] add(s32[] c, s32[] c)");
861
862 EXPECT_DESC_AND_EXPLANATION(
863 constant,
864 m::Op().WithPredicate([](const HloInstruction*) { return false; }),
865 "an HloInstruction which matches a user-specified predicate",
866 "HloInstruction does not match user-specified predicate\n"
867 "in c = s32[] constant(0)");
868 }
869
// Verifies how m::AddAnyOrder describes itself and explains two distinct
// failures: an operand pattern that matches neither operand, and an LHS
// operand that matches neither operand pattern.
TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
  const Shape s32_scalar = ShapeUtil::MakeShape(S32, {});

  {
    // "b" satisfies the first pattern, but no operand is named "bar", so the
    // explanation lists the "bar" failure against both LHS and RHS.
    auto const_b =
        SetName("b", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    auto const_c =
        SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    EXPECT_DESC_AND_EXPLANATION(
        SetName("a", HloInstruction::CreateBinary(s32_scalar, HloOpcode::kAdd,
                                                  const_b.get(),
                                                  const_c.get())),
        m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
        "an HloInstruction:\n"
        " * with opcode add AND\n"
        " * with two operands in either order:\n"
        "    - an HloInstruction named \"b\"\n"
        "    - an HloInstruction named \"bar\"",
        "HloInstruction's operands (ignoring order) did not match second "
        "matcher.  Specifically,\n"
        " - an HloInstruction named \"bar\"\n"
        "does not match LHS:\n"
        " - HloInstruction not named \"bar\"\n"
        "   in b = s32[] constant(0)\n"
        "does not match RHS:\n"
        " - HloInstruction not named \"bar\"\n"
        "   in c = s32[] constant(0)\n"
        "in a = s32[] add(s32[] b, s32[] c)");
  }

  {
    // The LHS is a parameter, so it fails both constant-based patterns and
    // the explanation reports each of those failures against the LHS.
    auto param_p = HloInstruction::CreateParameter(0, s32_scalar, "p");
    auto const_c =
        SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    EXPECT_DESC_AND_EXPLANATION(
        SetName("a", HloInstruction::CreateBinary(s32_scalar, HloOpcode::kAdd,
                                                  param_p.get(),
                                                  const_c.get())),
        m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
        "an HloInstruction:\n"
        " * with opcode add AND\n"
        " * with two operands in either order:\n"
        "    - an HloInstruction which is a constant scalar\n"
        "    - an HloInstruction with opcode constant",
        "HloInstruction's LHS operand did not match either of the two matchers.  "
        "Specifically,\n"
        " - an HloInstruction which is a constant scalar\n"
        "does not match LHS:\n"
        " - HloInstruction is not a constant\n"
        "   in p = s32[] parameter(0)\n"
        "and\n"
        " - an HloInstruction with opcode constant\n"
        "does not match LHS:\n"
        " - HloInstruction doesn't have opcode constant\n"
        "   in p = s32[] parameter(0)\n"
        "in a = s32[] add(s32[] p, s32[] c)");
  }
}
925
// Verifies the description of m::AnyOf and the explanation produced when
// every alternative fails. Alternatives are numbered from 1, and each one is
// listed with its own failure reason.
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  auto zero_constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  EXPECT_DESC_AND_EXPLANATION(
      zero_constant,
      m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
                               m::Op().WithName("bar")),
      "any of:\n"
      " - an HloInstruction named \"foo\" OR\n"
      " - an HloInstruction named \"bar\"",
      "None of the following matchers succeeded:\n"
      "Matcher #1\n"
      " - an HloInstruction named \"foo\"\n"
      "failed with\n"
      " - HloInstruction not named \"foo\"\n"
      "   in c = s32[] constant(0)\n"
      "Matcher #2\n"
      " - an HloInstruction named \"bar\"\n"
      "failed with\n"
      " - HloInstruction not named \"bar\"\n"
      "   in c = s32[] constant(0)");
}
946
// Verifies m::Parameter() with and without a parameter number. Also checks
// the explanations for a wrong opcode and for the wrong parameter number.
TEST_F(PatternMatcherTest, Parameter) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto param1 = HloInstruction::CreateParameter(1, f32_scalar, "p1");
  auto constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // A numberless pattern accepts any parameter; a numbered one must agree.
  EXPECT_TRUE(Match(param1.get(), m::Parameter()));
  EXPECT_TRUE(Match(param1.get(), m::Parameter(1)));
  EXPECT_FALSE(Match(param1.get(), m::Parameter(0)));

  // A non-parameter is rejected whether or not a number is given.
  EXPECT_FALSE(Match(constant.get(), m::Parameter()));
  EXPECT_FALSE(Match(constant.get(), m::Parameter(1)));

  EXPECT_DESC_AND_EXPLANATION(constant, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // Right opcode, wrong number: the explanation names the expected number.
  auto param0 = HloInstruction::CreateParameter(0, f32_scalar, "p0");
  EXPECT_EQ(Explanation(param0, m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
970
// Contrasts WithOneUse (exactly one use) with WithOneUser (exactly one user,
// which may consume the instruction more than once). Users of a parameter
// are added and removed over the course of the test.
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
  constexpr char kNoUsersExplanation[] =
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)";
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const Shape f32_vec1 = ShapeUtil::MakeShape(F32, {1});
  auto p0 = HloInstruction::CreateParameter(0, f32_scalar, "p0");

  // With no users, both matchers fail and give the same explanation.
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
  EXPECT_DESC_AND_EXPLANATION(p0, m::Op().WithOneUse(),
                              "an HloInstruction which has exactly one use",
                              kNoUsersExplanation);

  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUser()));
  EXPECT_DESC_AND_EXPLANATION(
      p0, m::Op().WithOneUser(),
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      kNoUsersExplanation);

  // Scoped so that both reshapes are gone before the add below is built.
  // The checks that follow this scope expect the add to be p0's only user.
  {
    auto first_reshape =
        SetName("r", HloInstruction::CreateReshape(f32_vec1, p0.get()));
    EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUse()));
    EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUser()));

    auto second_reshape =
        SetName("r1", HloInstruction::CreateReshape(f32_vec1, p0.get()));
    EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
    EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUser()));

    // Both matchers give the same explanation, listing every user.
    constexpr char kTwoUsersExplanation[] =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUse()),
              kTwoUsersExplanation);
    EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUser()),
              kTwoUsersExplanation);
  }

  // A single user that reads p0 twice: one user, but two uses.
  auto doubled = SetName("add", HloInstruction::CreateBinary(
                                    f32_scalar, HloOpcode::kAdd, p0.get(),
                                    p0.get()));
  EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUser()));
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
  EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUse()),
            "HloInstruction is used 2 times by its user, but is expected to be "
            "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
            "in p0 = f32[] parameter(0)");
}
1025
// Verifies the compare matchers: generic m::Compare(), the direction-specific
// m::Eq / m::Ne / m::Le, ordered and any-order operand forms, and the
// explanation produced when the comparison direction is wrong.
TEST_F(PatternMatcherTest, Comparison) {
  // NOTE(review): a real compare produces PRED, but this test types the
  // compare results as F32 as well. The expected explanation at the bottom
  // spells "f32[1]{0} compare(...)", so changing the shape means updating it.
  const Shape f32_vec1 = ShapeUtil::MakeShape(F32, {1});
  auto param0 = HloInstruction::CreateParameter(0, f32_vec1, "param.0");
  auto param1 = HloInstruction::CreateParameter(1, f32_vec1, "param.1");
  auto cmp_eq = HloInstruction::CreateCompare(f32_vec1, param0.get(),
                                              param1.get(),
                                              ComparisonDirection::kEq);
  auto cmp_ne = HloInstruction::CreateCompare(f32_vec1, param0.get(),
                                              param1.get(),
                                              ComparisonDirection::kNe);
  auto sum = HloInstruction::CreateBinary(f32_vec1, HloOpcode::kAdd,
                                          param0.get(), param1.get());
  auto cmp_le = HloInstruction::CreateCompare(f32_vec1, param0.get(),
                                              sum.get(),
                                              ComparisonDirection::kLe);

  // Positive cases.
  EXPECT_TRUE(Match(cmp_eq.get(), m::Compare()));
  EXPECT_TRUE(Match(cmp_eq.get(), m::Eq()));
  EXPECT_TRUE(Match(cmp_eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(
      Match(cmp_eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
  EXPECT_TRUE(Match(cmp_ne.get(), m::Compare()));
  EXPECT_TRUE(Match(cmp_ne.get(), m::Ne()));
  EXPECT_TRUE(Match(cmp_le.get(),
                    m::Compare(m::Parameter(0),
                               m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(cmp_le.get(),
                    m::Le(m::Parameter(0),
                          m::Add(m::Parameter(0), m::Parameter(1)))));

  // Negative cases: wrong opcode, wrong direction, swapped ordered operands.
  EXPECT_FALSE(Match(cmp_eq.get(), m::Add()));
  EXPECT_FALSE(Match(cmp_eq.get(), m::Ne()));
  EXPECT_FALSE(Match(cmp_le.get(),
                     m::Eq(m::Parameter(0),
                           m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(cmp_eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));

  EXPECT_DESC_AND_EXPLANATION(
      cmp_eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
1068
// Exercises every m::CustomCall overload against a two-operand custom-call.
// The overloads cover a capture pointer (mutable or const), a target name,
// and operand patterns, each optional. Matches with a different target or
// with swapped operand patterns must fail.
TEST_F(PatternMatcherTest, CustomCallMatchers) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      ROOT out = f32[] custom-call(p0, p1), custom_call_target="test_target"
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  // Non-capturing overloads.
  EXPECT_TRUE(Match(root, m::CustomCall()));
  EXPECT_TRUE(Match(root, m::CustomCall("test_target")));
  EXPECT_TRUE(Match(
      root, m::CustomCall("test_target", m::Parameter(0), m::Parameter(1))));

  // Capturing overloads must also bind the matched instruction. The capture
  // is reset before each match, so every EXPECT_EQ checks that specific
  // overload's write, not a value left over from an earlier match.
  HloInstruction* instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr)));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr, "test_target")));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr, "test_target", m::Parameter(0),
                                        m::Parameter(1))));
  EXPECT_EQ(instr, root);

  const HloInstruction* const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr)));
  EXPECT_EQ(const_instr, root);
  const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr, "test_target")));
  EXPECT_EQ(const_instr, root);
  const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr, "test_target",
                                        m::Parameter(0), m::Parameter(1))));
  EXPECT_EQ(const_instr, root);

  // Wrong target, or operand patterns in the wrong order.
  EXPECT_FALSE(Match(root, m::CustomCall("other_target")));
  EXPECT_FALSE(Match(
      root, m::CustomCall("test_target", m::Parameter(1), m::Parameter(0))));
}
1104
1105 } // namespace
1106 } // namespace xla
1107