1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
18
19 #include <functional>
20 #include <sstream>
21 #include <string>
22 #include <type_traits>
23 #include <utility>
24
25 #include "absl/strings/str_replace.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/utility/utility.h"
28 #include "tensorflow/compiler/xla/layout_util.h"
29 #include "tensorflow/compiler/xla/literal_util.h"
30 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35
36 namespace xla {
37
38 // A pattern matcher for HloInstructions, Shapes, and Layouts.
39 //
40 // The Match function's first argument must be HloInstruction*, Shape*, or
41 // Layout*. The second argument is a pattern that will be matched against the
42 // first argument, as described below.
43 //
44 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
45 // functions. By default, the returned patterns will match any HloInstruction,
46 // Shape, or Layout, respectively. However the match can be made more specific
47 // by using the pattern's modifier methods, for example:
48 //
49 // match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
50 // 0, match::Op().WithOpcode(HloOpcode::kConstant))
51 //
52 // This pattern will match Add instructions whose first operand is a constant.
53 //
54 // Each pattern type has the following modifiers, which are described where
55 // nontrivial.
56 //
57 // Op():
58 // - Is: is the given HloInstruction* (i.e. pointer equality)
59 // - WithName
60 // - WithOpcode
61 // - WithoutOpcode: anything other than the given opcode
62 // - WithShape: instr's shape matches the given pattern
63 // - WithShapeEqualTo: instr's shape is equal to the given Shape
64 // - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
65 // - WithElementType: instr.shape().element_type() matches the given type
66 // - WithNumOperands
67 // - WithOperand: operand at the given index matches the given pattern
68 // - WithOperandIfPresent: instr has fewer than i operands or the i'th one
69 // matches the given pattern
70 // - IsConstant
71 // - IsNonConstant
// - IsConstantScalar/IsConstantEffectiveScalar: Optionally accepts a value,
// e.g. IsConstantScalar() or IsConstantScalar(42).
74 // - WithFusionKind
75 // - WithTupleIndex: get-tuple-element operations with the given tuple index
76 // - WithOneUse: Instruction is used as an operand exactly once.
77 // - WithOneUser: Instruction is used by exactly one other instruction, but
78 // is possibly used more than once as an operand (e.g. multiply(x,x)).
79 // - WithComparisonDirection: instr has the given direction
80 // - WithPredicate: Instruction matches an arbitrary function you pass.
81 // Function must have signature `bool(const HloInstruction*)`.
82 //
83 // Shape():
84 // - EqualTo
85 // - CompatibleTo
86 // - IsScalar/IsEffectiveScalar/IsArray/IsTuple
87 // - IsDenseArray
// - WithLayout: shape's layout matches the given pattern (e.g.
// Layout().EqualTo(&layout))
89 // - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
90 // Layout, but not the result of Layout().foo())
91 // - WithSubshape: shape is a tuple whose subshape matches the given pattern
92 // (e.g. Shape().IsScalar()).
93 // - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
94 // (i.e. another Shape, but not the result of Shape().foo())
95 // - WithElementType: shape is an array/scalar with the given elem type
96 // - WithRank: shape is an array/scalar with the given rank
97 //
98 // Layout():
99 // - EqualTo
100 //
101 // Op(), Shape(), and Layout() may be passed an argument of type
102 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
103 // these pointers. If the pattern is matched, the address of the matched value
104 // will be "captured" and stored at this location.
105 //
106 // For example:
107 // HloInstruction* foo = ...;
108 // HloInstruction* matched_operand;
109 // CHECK(Match(foo,
110 // match::Op().WithOperand(0, match::Op(&matched_operand))));
111 //
112 // Helpers are provided for most HLO instructions. These helpers can be called
113 // with no arguments, in which case they will match any instruction matching the
114 // opcode. They may also be called with matches for the operands and with an
115 // optional capture. (The capture must be the first argument.) Some examples of
116 // these helpers and their equivalents are provided below.
117
118 // Example nullary instruction:
119 // Parameter() == Op().WithOpcode(HloOpcode::kParameter)
120 // Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter)
121 //
122 // Example unary instruction:
123 // Abs() == Op().WithOpcode(HloOpcode::kAbs)
// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs)
// .WithOperand(0, Op(&a))
126 // Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs)
127 // .WithOperand(0, Op(&b))
128 //
129 // Commutative binary instructions have a special form that accepts either order
130 // of args, e.g.:
131 //
132 // AddAnyOrder(Parameter(1), Abs()) ==
133 // Op().WithOpcode(HloOpcode::kAdd)
134 // .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
135 //
136 // MultiplyAnyOrder(&a, Parameter(), Abs()) // Captures the mul in `a`.
137 //
138 // The following additional helpers are provided. In all cases, `&a` is
139 // optional.
140 //
141 // ConstantScalar(&a) == Op(&a).IsConstantScalar();
142 // ConstantScalar(&a, v) == Op(&a).IsConstantScalar(v);
143 // ConstantEffectiveScalar(&a) == Op(&a).IsConstantEffectiveScalar();
// ConstantEffectiveScalar(&a, v) == Op(&a).IsConstantEffectiveScalar(v)
145 // NonConstant(&a) == Op(&a).IsNonConstant()
146 // GetTupleElement(&a, b, index) == Op(&a).WithTupleIndex(index)
147 // .WithOperand(0, b);
148 // Parameter(&a, n) == Op(&a).WithParameterNum(n);
149
// Options passed to every pattern's Match() call.
struct MatchOption {
  // Whether a successful match should write the matched item through the
  // pattern's capture pointer(s). When false, no captures are written.
  bool capture;

  // Optional sink for a human-readable explanation of why a match failed;
  // nullptr disables explanations.
  std::ostream* explain_os;
};
157
158 template <typename Value, typename Pattern>
159 bool Match(Value* value, const Pattern& pattern,
160 MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
161 if (option.capture) {
162 auto new_option = option;
163 new_option.capture = false;
164 if (!pattern.Match(value, new_option)) {
165 return false;
166 }
167 }
168 return pattern.Match(value, option);
169 }
170
171 // If `enable_logging` is false, this is identical to Match(instr, pattern).
172 //
173 // If `enable_logging` is true and the match fails, we try to
174 // Match(instr, filter_pattern). If this is true, then we log an explanation for
175 // why the original Match(instr, pattern) failed.
176 //
177 // This function can be used aid in debugging passes with complex matchers.
178 // For example, in the following snippet we're trying to match
179 // m::Slice(m::Reshape(m::Pad())). Every time we encounter a slice that
180 // doesn't match the larger pattern, we will log an explanation for why it
181 // didn't match the larger pattern.
182 //
183 // if (MatchAndLogIfFailed(instr, "slice of reshape of pad",
184 // m::Slice(m::Reshape(m::Pad())),
185 // VLOG_IS_ON(3), m::Slice())
186 //
187 // TODO(jlebar): Log the caller's absl::SourceLocation once that's in OSS.
188 template <typename FilterPattern, typename Pattern>
MatchAndLogIfFailed(HloInstruction * instr,absl::string_view desc,const Pattern & pattern,bool enable_logging,const FilterPattern & filter_pattern)189 bool MatchAndLogIfFailed(HloInstruction* instr, absl::string_view desc,
190 const Pattern& pattern, bool enable_logging,
191 const FilterPattern& filter_pattern) {
192 bool matched = Match(instr, pattern);
193 if (matched || !enable_logging || !Match(instr, filter_pattern)) {
194 return matched;
195 }
196 std::stringstream os;
197 CHECK(!Match(instr, pattern, {/*capture=*/false, /*explain_os=*/&os}));
198 LOG(ERROR) << "Failed to match " << desc << ":\n" << os.str();
199 return false;
200 }
201
202 namespace match {
203
204 namespace detail {
205
206 // Macro for streaming to option.explain_os if it's not null.
207 //
208 // EXPLAIN << "value of foo(): " << foo()
209 //
210 #pragma push_macro("EXPLAIN")
211 #define EXPLAIN \
212 if (option.explain_os) *option.explain_os
213
// Number of extra spaces the pretty-printers add each time they nest one
// level deeper ("indent by one").
enum { kIndentInc = 2 };
219
// Starts a new output line and pads it with `indent` spaces. No padding is
// written when `indent` <= 0.
//
// Convention followed by every pretty-printer in this file: the *caller*
// decides where line breaks and indentation go, never the callee. To print
//
//   foo:
//    - bar
//
// you would write:
//
//   Foo::DescribeTo(std::ostream* os, int64_t indent) {
//     *os << "foo:";
//     Indent(os, indent);  // Newline at the *current* indent level.
//     *os << " - ";
//     bar.DescribeTo(os, indent + 3);  // + 3 because strlen(" - ") == 3.
//   }
//
//   Bar::DescribeTo(std::ostream* os, int64_t indent) { *os << "bar"; }
//
// Bar::DescribeTo() never calls Indent itself; Foo does the indenting. Letting
// the caller decide whether a sub-matcher starts on a fresh line matters for
// matchers such as AllOf.
//
// Match explanations handle indentation differently. DescribeTo prints a
// whole tree, so indenting is routine there. Match prints only the path that
// leads to a failing node, so indents appear only when a disjunction fails,
// and that case is handled specially where it occurs.
inline void Indent(std::ostream* os, int64_t indent) {
  *os << '\n';
  for (int64_t remaining = indent; remaining > 0; --remaining) {
    *os << ' ';
  }
}
257
// Type trait: true iff T declares a static member `kIsTrivialMatcher` whose
// value is true (detected via SFINAE on std::enable_if).
//
// Trivial matchers are printed specially. For example, when describing a
// conjunction, no "AND" is printed after a trivial matcher, giving
//   "a shape compatible with f32[1,2]"
// instead of
//   "a shape AND compatible with f32[1,2]"
template <typename T, typename Dummy = void>
struct IsTrivialMatcher : std::false_type {};

template <typename T>
struct IsTrivialMatcher<T, typename std::enable_if<T::kIsTrivialMatcher>::type>
    : std::true_type {};
276
277 template <typename Item, typename... Patterns>
278 class AllOfPattern {
279 public:
280 explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
281
282 bool Match(const Item* item, MatchOption option) const {
283 bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
284 // This invariant is guaranteed by the top-level Match and AnyOf.
285 DCHECK(matched || !option.capture);
286 return matched;
287 }
288
289 bool Match(Item* item, MatchOption option) const {
290 bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
291 // This invariant is guaranteed by the top-level Match and AnyOf.
292 DCHECK(matched || !option.capture);
293 return matched;
294 }
295
296 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
297 DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
298 }
299
300 // Accessor for patterns_. Please don't use this outside of this file.
301 const std::tuple<Patterns...>& patterns() const { return patterns_; }
302
303 private:
304 template <typename ItemType, size_t index>
305 bool MatchImpl(ItemType* item, MatchOption option,
306 std::integral_constant<size_t, index>) const {
307 // We don't need to do any EXPLAINing here; it's all correctly handled by
308 // our sub-matchers (if any fail).
309 return std::get<index>(patterns_).Match(item, option) &&
310 MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
311 }
312
313 template <typename ItemType>
314 bool MatchImpl(ItemType* item, MatchOption option,
315 std::integral_constant<size_t, sizeof...(Patterns)>) const {
316 return true;
317 }
318
319 // Pretty-printing a conjunction has some special cases to make it easy to
320 // read in the simple (common) case.
321 //
322 // If sizeof...(Patterns) == 1, prints as e.g.
323 //
324 // a shape
325 //
326 // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
327 // shape") prints as
328 //
329 // a shape compatible with f32[1,2]
330 //
331 // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
332 //
333 // a shape:
334 // * compatible with f32[1,2] AND
335 // * that represents a scalar
336 //
337 // Otherwise prints as:
338 //
339 // all of:
340 // * foo AND
341 // * bar
342 //
343 template <size_t index>
344 void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
345 int64_t indent) const {
346 constexpr bool first_is_trivial =
347 IsTrivialMatcher<typename std::remove_reference<decltype(std::get<0>(
348 patterns_))>::type>::value;
349 constexpr bool is_last = index == sizeof...(Patterns) - 1;
350 const auto& submatcher = std::get<index>(patterns_);
351
352 auto print_bulleted_item = [&] {
353 *os << " * ";
354 submatcher.DescribeTo(os, indent + 3);
355 if (!is_last) {
356 *os << " AND";
357 Indent(os, indent);
358 }
359 };
360
361 if (index == 0) {
362 if (first_is_trivial || is_last) {
363 submatcher.DescribeTo(os, indent + kIndentInc);
364 if (sizeof...(Patterns) > 2) {
365 *os << ":";
366 Indent(os, indent);
367 }
368 } else {
369 *os << "all of:";
370 Indent(os, indent);
371 print_bulleted_item();
372 }
373 } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
374 *os << " ";
375 submatcher.DescribeTo(os, indent);
376 } else {
377 print_bulleted_item();
378 }
379 DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
380 }
381
382 void DescribeToImpl(std::ostream* os,
383 std::integral_constant<size_t, sizeof...(Patterns)>,
384 int64_t indent) const {}
385
386 std::tuple<Patterns...> patterns_;
387 };
388
389 } // namespace detail
390
391 // Returns a pattern that represents the conjunction of all input patterns. All
392 // patterns need to match in order to have the AllOf pattern match.
393 template <typename Item, typename... Patterns>
394 auto AllOf(const Patterns&... patterns) {
395 return detail::AllOfPattern<typename std::remove_const<Item>::type,
396 Patterns...>(patterns...);
397 }
398
399 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
400 //
401 // This transformation is necessary for good pretty-printing.
402 template <typename Item, typename... InnerPs, typename... OuterPs>
403 auto AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
404 const OuterPs&... outer_ps) {
405 // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
406 auto make_all_of = [](const InnerPs&... inner_ps,
407 const OuterPs&... outer_ps) {
408 return detail::AllOfPattern<typename std::remove_const<Item>::type,
409 InnerPs..., OuterPs...>(inner_ps...,
410 outer_ps...);
411 };
412 return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
413 std::make_tuple(outer_ps...)));
414 }
415
416 namespace detail {
417
418 template <typename LayoutType, typename Impl>
419 class LayoutPattern;
420
421 // The base LayoutPattern implementation. Matches only if the layout is not
422 // nullptr.
423 class LayoutPatternBaseImpl {
424 public:
425 bool Match(const ::xla::Layout* layout, MatchOption option) const {
426 if (layout == nullptr) {
427 EXPLAIN << "Layout is null";
428 return false;
429 }
430 return true;
431 }
432
433 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
434 *os << "a layout";
435 }
436
437 static constexpr bool kIsTrivialMatcher = true;
438 };
439
440 // A LayoutPattern implementation that matches only if the layout equals a
441 // Layout proto.
442 class LayoutPatternEqualImpl {
443 public:
444 explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
445 : layout_(layout) {}
446
447 bool Match(const ::xla::Layout* layout, MatchOption option) const {
448 if (!LayoutUtil::Equal(*layout_, *layout)) {
449 EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
450 << " is not equal to expected "
451 << LayoutUtil::HumanString(*layout_);
452 return false;
453 }
454 return true;
455 }
456
457 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
458 *os << "equal to " << LayoutUtil::HumanString(*layout_);
459 }
460
461 private:
462 const ::xla::Layout* layout_;
463 };
464
465 // A pattern that matches Layouts.
466 template <typename LayoutType, typename Impl>
467 class LayoutPattern {
468 private:
469 template <typename NewImpl>
470 auto AppendImpl(NewImpl new_impl) const {
471 auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl));
472 return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
473 matched_layout_);
474 }
475
476 public:
477 explicit constexpr LayoutPattern(const Impl& impl,
478 LayoutType** matched_layout)
479 : impl_(impl), matched_layout_(matched_layout) {}
480
481 // Returns true and captures the layout iff it matches the pattern.
482 bool Match(const ::xla::Layout* layout, MatchOption option) const {
483 if (impl_.Match(layout, option)) {
484 if (option.capture && matched_layout_) {
485 *matched_layout_ = layout;
486 }
487 return true;
488 }
489 return false;
490 }
491
492 // Returns true and captures the layout iff it matches the pattern.
493 bool Match(::xla::Layout* layout, MatchOption option) const {
494 if (impl_.Match(layout, option)) {
495 if (option.capture && matched_layout_) {
496 *matched_layout_ = layout;
497 }
498 return true;
499 }
500 return false;
501 }
502
503 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
504 impl_.DescribeTo(os, indent);
505 }
506
507 // Modifies the pattern to match only if the layout equals the given proto.
508 // The layout must outlive the returned pattern.
509 constexpr auto EqualTo(const ::xla::Layout* layout) const {
510 return AppendImpl(LayoutPatternEqualImpl(layout));
511 }
512
513 private:
514 Impl impl_;
515 LayoutType** matched_layout_;
516 };
517
518 template <typename Item, typename... Patterns>
519 class AnyOfPattern {
520 public:
521 explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
522
523 bool Match(const Item* item, MatchOption option) const {
524 return MatchImpl(item, option);
525 }
526
527 bool Match(Item* item, MatchOption option) const {
528 return MatchImpl(item, option);
529 }
530
531 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
532 *os << "any of:";
533 Indent(os, indent);
534 DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
535 }
536
537 private:
538 template <typename ItemType>
539 bool MatchImpl(ItemType* item, MatchOption option) const {
540 // If we're generating an explanation, buffer it until we know we failed.
541 std::optional<std::stringstream> explanation;
542 MatchOption new_option = option;
543 if (option.explain_os) {
544 new_option.explain_os = &explanation.emplace();
545 }
546 bool rv = MatchRecursiveImpl(item, new_option,
547 std::integral_constant<size_t, 0>());
548 if (!rv && option.explain_os) {
549 EXPLAIN << "None of the following matchers succeeded:";
550 EXPLAIN << explanation->str();
551 }
552 return rv;
553 }
554
555 template <typename ItemType, size_t index>
556 bool MatchRecursiveImpl(ItemType* item, MatchOption option,
557 std::integral_constant<size_t, index>) const {
558 auto new_option = option;
559 new_option.capture = false;
560
561 std::optional<std::stringstream> explanation;
562 if (option.explain_os) {
563 new_option.explain_os = &explanation.emplace();
564 }
565
566 // Try to match the sub-pattern without capturing behavior.
567 if (std::get<index>(patterns_).Match(item, new_option)) {
568 // Capture the branch.
569 if (option.capture) {
570 // TODO(timshen): Currently the behavior can be exponential. Optimize it
571 // with memoization or recording the matched sub-pattern index, if it
572 // takes too long to run.
573 //
574 // Specifically, the "memoization" approach is to create an empty
575 // container with the key (pattern, instruction), and value as whether
576 // matched or not.
577 //
578 // Alternatively, we may run the pattern matching with captures off, but
579 // instead record a "trace" somewhere, indicating how exactly the
580 // pattern matches the input. For example, the trace information for
581 // AnyOf will be a runtime number indicate which sub-pattern is matched.
582 // Then we run another pass to do captures only with the help of the
583 // trace.
584 bool matched = std::get<index>(patterns_).Match(item, option);
585 DCHECK(matched);
586 }
587 return true;
588 }
589 if (option.explain_os) {
590 EXPLAIN << "\nMatcher #" << index + 1;
591 EXPLAIN << "\n - ";
592 std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
593 EXPLAIN << "\nfailed with";
594 EXPLAIN << "\n - ";
595 EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n "}});
596 }
597 return MatchRecursiveImpl(item, option,
598 std::integral_constant<size_t, index + 1>());
599 }
600
601 template <typename ItemType>
602 bool MatchRecursiveImpl(
603 ItemType* item, MatchOption option,
604 std::integral_constant<size_t, sizeof...(Patterns)>) const {
605 return false;
606 }
607
608 template <size_t index>
609 void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
610 int64_t indent) const {
611 *os << " - ";
612 std::get<index>(patterns_).DescribeTo(os, indent + 3);
613 if (index != sizeof...(Patterns) - 1) {
614 *os << " OR";
615 Indent(os, indent);
616 }
617 DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
618 }
619
620 void DescribeToImpl(std::ostream* os,
621 std::integral_constant<size_t, sizeof...(Patterns)>,
622 int64_t indent) const {}
623
624 std::tuple<Patterns...> patterns_;
625 };
626
627 } // namespace detail
628
629 // Returns a pattern that represents the logical disjunction of the input
630 // patterns. The returned pattern matches from left to right, and stops on the
631 // first match.
632 template <typename Item, typename... Patterns>
633 auto AnyOf(const Patterns&... patterns) {
634 return detail::AnyOfPattern<typename std::remove_const<Item>::type,
635 Patterns...>(patterns...);
636 }
637
638 // Creates a layout pattern that will capture the matched layout in the
639 // argument.
640 inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) {
641 return detail::LayoutPattern<const ::xla::Layout,
642 detail::LayoutPatternBaseImpl>(
643 detail::LayoutPatternBaseImpl(), matched_layout);
644 }
645
646 // Creates a layout pattern that will capture the matched layout in the
647 // argument.
648 inline constexpr auto Layout(::xla::Layout** matched_layout) {
649 return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
650 detail::LayoutPatternBaseImpl(), matched_layout);
651 }
652
653 namespace detail {
654
655 template <typename ShapeType, typename Impl>
656 class ShapePattern;
657
658 // The base ShapePattern implementation. Matches only if the shape is not
659 // nullptr.
660 class ShapePatternBaseImpl {
661 public:
662 bool Match(const ::xla::Shape* shape, MatchOption option) const {
663 if (shape == nullptr) {
664 EXPLAIN << "Shape is null";
665 }
666 return shape != nullptr;
667 }
668
669 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
670 *os << "a shape";
671 }
672
673 static constexpr bool kIsTrivialMatcher = true;
674 };
675
676 // A ShapePattern implementation that matches only if the shape equals a Shape
677 // proto.
678 class ShapePatternEqualImpl {
679 public:
680 explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
681 : shape_(shape) {}
682
683 bool Match(const ::xla::Shape* shape, MatchOption option) const {
684 if (!ShapeUtil::Equal(*shape_, *shape)) {
685 EXPLAIN << "Shape not equal to "
686 << ShapeUtil::HumanStringWithLayout(*shape_);
687 return false;
688 }
689 return true;
690 }
691
692 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
693 *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
694 }
695
696 private:
697 const ::xla::Shape* shape_;
698 };
699
700 // A ShapePattern implementation that matches only if the shape is compatible to
701 // a Shape proto.
702 class ShapePatternCompatibleImpl {
703 public:
704 explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
705 : shape_(shape) {}
706
707 bool Match(const ::xla::Shape* shape, MatchOption option) const {
708 if (!ShapeUtil::Compatible(*shape_, *shape)) {
709 EXPLAIN << "Shape not compatible with "
710 << ShapeUtil::HumanString(*shape_);
711 return false;
712 }
713 return true;
714 }
715
716 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
717 *os << "compatible with " << ShapeUtil::HumanString(*shape_);
718 }
719
720 private:
721 const ::xla::Shape* shape_;
722 };
723
724 // A ShapePattern implementation that matches only if the shape has a given
725 // element type.
726 class ShapePatternElementTypeImpl {
727 public:
728 explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
729 : element_type_(element_type) {}
730
731 bool Match(const ::xla::Shape* shape, MatchOption option) const {
732 if (shape->element_type() != element_type_) {
733 EXPLAIN << "Shape does not have element type "
734 << PrimitiveType_Name(element_type_);
735 return false;
736 }
737 return true;
738 }
739
740 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
741 *os << "with element type " << PrimitiveType_Name(element_type_);
742 }
743
744 private:
745 PrimitiveType element_type_;
746 };
747
748 // A ShapePattern implementation that matches only if the shape has a given
749 // list of dimensions.
750 class ShapePatternDimsImpl {
751 public:
752 explicit ShapePatternDimsImpl(absl::Span<const int64_t> dims)
753 : dims_(dims.begin(), dims.end()) {}
754
755 bool Match(const ::xla::Shape* shape, MatchOption option) const {
756 if (shape->dimensions() != dims_) {
757 EXPLAIN << "Shape does not have dimensions [" << absl::StrJoin(dims_, ",")
758 << "]";
759 return false;
760 }
761 return true;
762 }
763
764 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
765 *os << "with dimensions [" << absl::StrJoin(dims_, ",") << "]";
766 }
767
768 private:
769 absl::InlinedVector<int64_t, 8> dims_;
770 };
771
772 // A ShapePattern implementation that matches only if the shape is scalar.
773 class ShapePatternIsScalarImpl {
774 public:
775 explicit constexpr ShapePatternIsScalarImpl() {}
776
777 bool Match(const ::xla::Shape* shape, MatchOption option) const {
778 if (!ShapeUtil::IsScalar(*shape)) {
779 EXPLAIN << "Shape is not a scalar";
780 return false;
781 }
782 return true;
783 }
784
785 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
786 *os << "that represents a scalar";
787 }
788 };
789
790 // A ShapePattern implementation that matches only if the shape is an array
791 class ShapePatternIsArrayImpl {
792 public:
793 explicit constexpr ShapePatternIsArrayImpl() {}
794
795 bool Match(const ::xla::Shape* shape, MatchOption option) const {
796 if (!shape->IsArray()) {
797 EXPLAIN << "Shape is not an array";
798 return false;
799 }
800 return true;
801 }
802
803 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
804 *os << "that represents an array";
805 }
806 };
807
808 // A ShapePattern implementation that matches only if the shape is an array
809 class ShapePatternIsDenseArrayImpl {
810 public:
811 explicit constexpr ShapePatternIsDenseArrayImpl() {}
812
813 bool Match(const ::xla::Shape* shape, MatchOption option) const {
814 if (!LayoutUtil::IsDenseArray(*shape)) {
815 EXPLAIN << "Shape is not a dense array";
816 return false;
817 }
818 return true;
819 }
820
821 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
822 *os << "that represents a dense array";
823 }
824 };
825
826 // A ShapePattern implementation that matches only if the shape is a tuple.
827 class ShapePatternIsTupleImpl {
828 public:
829 explicit constexpr ShapePatternIsTupleImpl() {}
830
831 bool Match(const ::xla::Shape* shape, MatchOption option) const {
832 if (!shape->IsTuple()) {
833 EXPLAIN << "Shape is not a tuple";
834 return false;
835 }
836 return true;
837 }
838
839 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
840 *os << "that represents a tuple";
841 }
842 };
843
844 // A ShapePattern implementation that matches only if the shape is an effective
845 // scalar.
846 class ShapePatternEffectiveScalarImpl {
847 public:
848 explicit constexpr ShapePatternEffectiveScalarImpl() {}
849
850 bool Match(const ::xla::Shape* shape, MatchOption option) const {
851 if (!ShapeUtil::IsEffectiveScalar(*shape)) {
852 EXPLAIN << "Shape is not an effective scalar";
853 return false;
854 }
855 return true;
856 }
857
858 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
859 *os << "that is an effective scalar";
860 }
861 };
862
863 // A ShapePattern implementation that matches only if the shape has a given
864 // rank.
865 class ShapePatternRankImpl {
866 public:
867 explicit constexpr ShapePatternRankImpl(int64_t rank) : rank_(rank) {}
868
869 bool Match(const ::xla::Shape* shape, MatchOption option) const {
870 if (shape->rank() != rank_) {
871 if (rank_ == 0) {
872 EXPLAIN << "Shape is not a scalar";
873 } else {
874 EXPLAIN << "Shape does not have rank " << rank_;
875 }
876 return false;
877 }
878 return true;
879 }
880
881 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
882 if (rank_ == 0) {
883 *os << "that is a scalar";
884 } else {
885 *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
886 }
887 }
888
889 private:
890 int64_t rank_;
891 };
892
893 // A ShapePattern implementation that matches only if the shape has a layout
894 // that matches a given pattern.
895 template <typename LayoutType, typename LayoutImpl>
896 class ShapePatternLayoutImpl {
897 public:
898 explicit constexpr ShapePatternLayoutImpl(
899 const LayoutPattern<LayoutType, LayoutImpl>& layout)
900 : layout_(layout) {}
901
902 bool Match(const ::xla::Shape* shape, MatchOption option) const {
903 return LayoutUtil::HasLayout(*shape) &&
904 layout_.Match(&shape->layout(), option);
905 }
906
907 bool Match(::xla::Shape* shape, MatchOption option) const {
908 if (!LayoutUtil::HasLayout(*shape)) {
909 EXPLAIN << "Shape does not have a layout";
910 return false;
911 }
912 if (!layout_.Match(shape->mutable_layout(), option)) {
913 EXPLAIN << "\nin layout";
914 return false;
915 }
916 return true;
917 }
918
919 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
920 *os << "with";
921 Indent(os, indent + kIndentInc);
922 layout_.DescribeTo(os, indent + kIndentInc);
923 }
924
925 private:
926 LayoutPattern<LayoutType, LayoutImpl> layout_;
927 };
928
929 // A ShapePattern implementation that matches only if the shape has a subshape
930 // that matches a given pattern.
931 template <typename SubshapeType, typename SubshapeImpl>
932 class ShapePatternSubshapeImpl {
933 public:
934 explicit ShapePatternSubshapeImpl(
935 ShapeIndexView index,
936 const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
937 : index_(index), subshape_(subshape) {}
938
939 bool Match(const ::xla::Shape* shape, MatchOption option) const {
940 return MatchImpl(shape, option);
941 }
942
943 bool Match(::xla::Shape* shape, MatchOption option) const {
944 return MatchImpl(shape, option);
945 }
946
947 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
948 *os << "with subshape at index " << ShapeIndex(index_) << " which is";
949 Indent(os, indent + kIndentInc);
950 subshape_.DescribeTo(os, indent + kIndentInc);
951 }
952
953 private:
954 ::xla::Shape* GetSubshape(::xla::Shape* shape) const {
955 return ShapeUtil::GetMutableSubshape(shape, index_);
956 }
957 const ::xla::Shape* GetSubshape(const ::xla::Shape* shape) const {
958 return &ShapeUtil::GetSubshape(*shape, index_);
959 }
960
961 template <typename ShapeType>
962 bool MatchImpl(ShapeType* shape, MatchOption option) const {
963 if (!ShapeUtil::IndexIsValid(*shape, index_)) {
964 EXPLAIN << "No subshape at " << ShapeIndex(index_);
965 return false;
966 }
967 if (!subshape_.Match(GetSubshape(shape), option)) {
968 EXPLAIN << "\nin subshape at " << ShapeIndex(index_);
969 return false;
970 }
971 return true;
972 }
973
974 ShapeIndexView index_;
975 ShapePattern<SubshapeType, SubshapeImpl> subshape_;
976 };
977
978 // A pattern that matches Shapes.
979 template <typename ShapeType, typename Impl>
980 class ShapePattern {
981 private:
982 template <typename NewImpl>
983 auto AppendImpl(NewImpl new_impl) const {
984 auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl));
985 return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
986 matched_shape_);
987 }
988
989 public:
990 explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
991 : impl_(impl), matched_shape_(matched_shape) {}
992
993 // Returns true and captures the shape iff it matches the pattern.
994 bool Match(const ::xla::Shape* shape, MatchOption option) const {
995 if (impl_.Match(shape, option)) {
996 if (option.capture && matched_shape_) {
997 *matched_shape_ = shape;
998 }
999 return true;
1000 }
1001 if (shape) {
1002 EXPLAIN << "\nin "
1003 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
1004 : ShapeUtil::HumanString(*shape));
1005 }
1006 return false;
1007 }
1008
1009 // Returns true and captures the shape iff it matches the pattern.
1010 bool Match(::xla::Shape* shape, MatchOption option) const {
1011 if (impl_.Match(shape, option)) {
1012 if (option.capture && matched_shape_) {
1013 *matched_shape_ = shape;
1014 }
1015 return true;
1016 }
1017 EXPLAIN << "\nin "
1018 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
1019 : ShapeUtil::HumanString(*shape));
1020 return false;
1021 }
1022
1023 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1024 return impl_.DescribeTo(os, indent);
1025 }
1026
1027 // Modifies the pattern to match only if the shape equals the given proto.
1028 // The layout must outlive the returned pattern.
1029 constexpr auto EqualTo(const ::xla::Shape* shape) const {
1030 return AppendImpl(ShapePatternEqualImpl(shape));
1031 }
1032
1033 // Modifies the pattern to match only if the shape is compatible to the given
1034 // proto. The layout must outlive the returned pattern.
1035 constexpr auto CompatibleTo(const ::xla::Shape* shape) const {
1036 return AppendImpl(ShapePatternCompatibleImpl(shape));
1037 }
1038
1039 // Modifies the pattern to match only if the shape has the given element type.
1040 constexpr auto WithElementType(PrimitiveType element_type) const {
1041 return AppendImpl(ShapePatternElementTypeImpl(element_type));
1042 }
1043
1044 constexpr auto WithDims(absl::Span<const int64_t> dims) const {
1045 return AppendImpl(ShapePatternDimsImpl(dims));
1046 }
1047
1048 // Modifies the pattern to match only if the shape is scalar.
1049 constexpr auto IsScalar() const {
1050 return AppendImpl(ShapePatternIsScalarImpl());
1051 }
1052
1053 // Modifies the pattern to match only if the shape is an array.
1054 constexpr auto IsArray() const {
1055 return AppendImpl(ShapePatternIsArrayImpl());
1056 }
1057
1058 // Modifies the pattern to match only if the shape is a tuple.
1059 constexpr auto IsTuple() const {
1060 return AppendImpl(ShapePatternIsTupleImpl());
1061 }
1062
1063 constexpr auto IsEffectiveScalar() const {
1064 return AppendImpl(ShapePatternEffectiveScalarImpl());
1065 }
1066
1067 // Modifies the pattern to match only if the shape has the given rank.
1068 constexpr auto WithRank(int64_t rank) const {
1069 return AppendImpl(ShapePatternRankImpl(rank));
1070 }
1071
1072 // Modifies the pattern to match only if the shape has a layout that matches
1073 // the given pattern.
1074 template <typename LayoutType, typename LayoutImpl>
1075 auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
1076 return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1077 }
1078
1079 constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const {
1080 return WithLayout(Layout().EqualTo(layout));
1081 }
1082
1083 // Modifies the pattern to match only if the shape is a dense array.
1084 constexpr auto IsDenseArray() const {
1085 return AppendImpl(ShapePatternIsDenseArrayImpl());
1086 }
1087
1088 // Modifies the pattern to match only if the shape has a subshape that matches
1089 // the given pattern.
1090 template <typename SubshapeType, typename SubshapeImpl>
1091 auto WithSubshape(
1092 ShapeIndexView index,
1093 const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
1094 return AppendImpl(
1095 ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1096 }
1097
1098 ShapePattern<ShapeType,
1099 AllOfPattern<::xla::Shape, Impl,
1100 ShapePatternSubshapeImpl<
1101 const ::xla::Shape,
1102 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1103 ShapePatternEqualImpl>>>>
1104 WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1105 return WithSubshape(index,
1106 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1107 ShapePatternBaseImpl(), nullptr)
1108 .EqualTo(shape));
1109 }
1110
1111 ShapePattern<ShapeType,
1112 AllOfPattern<::xla::Shape, Impl,
1113 ShapePatternSubshapeImpl<
1114 const ::xla::Shape,
1115 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1116 ShapePatternCompatibleImpl>>>>
1117 WithSubshapeCompatibleTo(ShapeIndexView index,
1118 const ::xla::Shape* shape) const {
1119 return WithSubshape(index,
1120 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1121 ShapePatternBaseImpl(), nullptr)
1122 .CompatibleTo(shape));
1123 }
1124
1125 private:
1126 Impl impl_;
1127 ShapeType** matched_shape_;
1128 };
1129
1130 } // namespace detail
1131
1132 // Creates a shape pattern that will capture the matched layout in the argument.
1133 inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) {
1134 return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1135 detail::ShapePatternBaseImpl(), matched_shape);
1136 }
1137
1138 // Creates a shape pattern that will capture the matched layout in the argument.
1139 inline constexpr auto Shape(::xla::Shape** matched_shape) {
1140 return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1141 detail::ShapePatternBaseImpl(), matched_shape);
1142 }
1143
1144 namespace detail {
1145
1146 // Overloads to get a const or non-const operand out of an instruction.
1147 inline HloInstruction* HloOperand(HloInstruction* instr, int64_t idx) {
1148 return instr->mutable_operand(idx);
1149 }
1150 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1151 int64_t idx) {
1152 return instr->operand(idx);
1153 }
1154
1155 // Pretty-printer for HloInstruction. Sort of like ToShortString, but with
1156 // fewer %s and more shapes.
1157 inline std::string InstToString(const HloInstruction* inst) {
1158 return inst->ToString(
1159 HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1160 }
1161
1162 template <typename HloInstructionType, typename Impl>
1163 class HloInstructionPattern;
1164
1165 // The base HloInstructionPattern implementation. Matches only if the
1166 // instruction is not nullptr.
1167 class HloInstructionPatternBaseImpl {
1168 public:
1169 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1170 if (inst == nullptr) {
1171 EXPLAIN << "HloInstruction* is null";
1172 return false;
1173 }
1174 return true;
1175 }
1176
1177 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1178 *os << "an HloInstruction";
1179 }
1180
1181 static constexpr bool kIsTrivialMatcher = true;
1182 };
1183
1184 // An HloInstructionPattern implementation that matches only if the instruction
1185 // has a given name.
1186 class HloInstructionPatternNameImpl {
1187 public:
1188 explicit HloInstructionPatternNameImpl(absl::string_view name)
1189 : name_(name) {}
1190
1191 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1192 if (inst->name() != name_) {
1193 EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1194 return false;
1195 }
1196 return true;
1197 }
1198
1199 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1200 *os << "named \"" << name_ << "\"";
1201 }
1202
1203 private:
1204 absl::string_view name_;
1205 };
1206
1207 // An HloInstructionPattern implementation that matches only if the instruction
1208 // equals a particular pointer.
1209 class HloInstructionIsImpl {
1210 public:
1211 explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1212
1213 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1214 if (inst != inst_) {
1215 EXPLAIN << "HloInstruction " << std::hex << std::nouppercase
1216 << std::showbase << reinterpret_cast<uint64_t>(inst) << " is not "
1217 << reinterpret_cast<uint64_t>(inst_) << " ("
1218 << InstToString(inst_) << ")";
1219 return false;
1220 }
1221 return true;
1222 }
1223
1224 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1225 *os << "which is " << std::hex << std::nouppercase << std::showbase
1226 << reinterpret_cast<uint64_t>(inst_) << " (" << InstToString(inst_)
1227 << ")";
1228 }
1229
1230 private:
1231 const HloInstruction* inst_;
1232 };
1233
1234 // An HloInstructionPattern implementation that matches only if the instruction
1235 // has a given opcode.
1236 class HloInstructionPatternOpcodeImpl {
1237 public:
1238 explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1239 bool invert)
1240 : opcode_(opcode), invert_(invert) {}
1241
1242 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1243 if (invert_ && inst->opcode() == opcode_) {
1244 EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1245 << ", expected anything else";
1246 return false;
1247 }
1248 if (!invert_ && inst->opcode() != opcode_) {
1249 EXPLAIN << "HloInstruction doesn't have opcode "
1250 << HloOpcodeString(opcode_);
1251 return false;
1252 }
1253 return true;
1254 }
1255
1256 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1257 if (!invert_) {
1258 *os << "with opcode " << HloOpcodeString(opcode_);
1259 } else {
1260 *os << "with any opcode other than " << HloOpcodeString(opcode_);
1261 }
1262 }
1263
1264 private:
1265 HloOpcode opcode_;
1266 bool invert_;
1267 };
1268
1269 // An HloInstructionPattern implementation that matches only if the instruction
1270 // has one of a given list of custom call targets.
1271 class HloInstructionCustomCallTargetImpl {
1272 public:
1273 explicit HloInstructionCustomCallTargetImpl(
1274 absl::Span<const absl::string_view> custom_call_targets)
1275 : custom_call_targets_(custom_call_targets.begin(),
1276 custom_call_targets.end()) {}
1277
1278 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1279 if (inst->opcode() != HloOpcode::kCustomCall ||
1280 !absl::c_linear_search(custom_call_targets_,
1281 inst->custom_call_target())) {
1282 if (custom_call_targets_.size() == 1) {
1283 EXPLAIN << "HloInstruction is not a custom call with a target '"
1284 << custom_call_targets_.front() << "'";
1285 } else {
1286 EXPLAIN << "HloInstruction is not a custom call with a target in {"
1287 << absl::StrJoin(custom_call_targets_, ", ") << "}";
1288 }
1289 return false;
1290 }
1291 return true;
1292 }
1293
1294 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1295 if (custom_call_targets_.size() == 1) {
1296 *os << "custom call with target '" << custom_call_targets_.front() << "'";
1297 } else {
1298 *os << "custom call with target in {"
1299 << absl::StrJoin(custom_call_targets_, ", ") << "}";
1300 }
1301 }
1302
1303 private:
1304 absl::InlinedVector<std::string, 1> custom_call_targets_;
1305 };
1306
1307 // An HloInstructionPattern implementation that matches only if the instruction
1308 // has the given number of operands.
1309 class HloInstructionPatternNumOperandsImpl {
1310 public:
1311 explicit constexpr HloInstructionPatternNumOperandsImpl(int64_t num_operands)
1312 : num_operands_(num_operands) {}
1313
1314 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1315 if (inst->operand_count() != num_operands_) {
1316 EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1317 return false;
1318 }
1319 return true;
1320 }
1321
1322 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1323 *os << "with " << num_operands_ << " operand"
1324 << (num_operands_ != 1 ? "s" : "");
1325 }
1326
1327 private:
1328 int64_t num_operands_;
1329 };
1330
1331 // An HloInstructionPattern implementation that matches only if the instruction
1332 // has a shape that matches a given pattern.
1333 template <typename ShapeType, typename ShapeImpl>
1334 class HloInstructionPatternShapeImpl {
1335 public:
1336 explicit constexpr HloInstructionPatternShapeImpl(
1337 const ShapePattern<ShapeType, ShapeImpl>& shape)
1338 : shape_(shape) {}
1339
1340 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1341 if (!shape_.Match(&inst->shape(), option)) {
1342 EXPLAIN << "\nin output shape";
1343 return false;
1344 }
1345 return true;
1346 }
1347
1348 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1349 if (!shape_.Match(inst->mutable_shape(), option)) {
1350 EXPLAIN << "\nin output shape";
1351 return false;
1352 }
1353 return true;
1354 }
1355
1356 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1357 *os << "outputting";
1358 Indent(os, indent + kIndentInc);
1359 shape_.DescribeTo(os, indent + kIndentInc);
1360 }
1361
1362 private:
1363 ShapePattern<ShapeType, ShapeImpl> shape_;
1364 };
1365
1366 // An HloInstructionPattern implementation that matches only if the instruction
1367 // has an operand that matches a given pattern.
1368 template <typename OperandType, typename OperandImpl>
1369 class HloInstructionPatternOperandImpl {
1370 public:
1371 explicit constexpr HloInstructionPatternOperandImpl(
1372 int64_t operand_index,
1373 const HloInstructionPattern<OperandType, OperandImpl>& operand)
1374 : operand_index_(operand_index), operand_(operand) {}
1375
1376 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1377 return MatchImpl(inst, option);
1378 }
1379
1380 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1381 return MatchImpl(inst, option);
1382 }
1383
1384 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1385 *os << "with operand " << operand_index_ << " which is:";
1386 Indent(os, indent + kIndentInc);
1387 operand_.DescribeTo(os, indent + kIndentInc);
1388 }
1389
1390 private:
1391 template <typename HloInstructionType>
1392 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1393 if (operand_index_ >= inst->operand_count()) {
1394 EXPLAIN << "desired operand index " << operand_index_
1395 << " is out of bounds";
1396 return false;
1397 }
1398 if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1399 EXPLAIN << "\nin operand " << operand_index_;
1400 return false;
1401 }
1402 return true;
1403 }
1404
1405 int64_t operand_index_;
1406 HloInstructionPattern<OperandType, OperandImpl> operand_;
1407 };
1408
1409 // An HloInstructionPattern implementation that matches if the instruction has
1410 // fewer than i+1 operands, or if the i'th operand matches a given pattern.
1411 template <typename OperandType, typename OperandImpl>
1412 class HloInstructionPatternOperandIfPresentImpl {
1413 public:
1414 explicit constexpr HloInstructionPatternOperandIfPresentImpl(
1415 int64_t operand_index,
1416 const HloInstructionPattern<OperandType, OperandImpl>& operand)
1417 : operand_index_(operand_index), operand_(operand) {}
1418
1419 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1420 return MatchImpl(inst, option);
1421 }
1422
1423 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1424 return MatchImpl(inst, option);
1425 }
1426
1427 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1428 *os << "either with fewer than " << operand_index_ + 1 << " operand"
1429 << (operand_index_ + 1 != 1 ? "s" : "") << ", or with an operand "
1430 << operand_index_ << " which is:";
1431 Indent(os, indent + kIndentInc);
1432 operand_.DescribeTo(os, indent + kIndentInc);
1433 }
1434
1435 private:
1436 template <typename HloInstructionType>
1437 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1438 if (operand_index_ >= inst->operand_count()) {
1439 return true;
1440 }
1441 if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1442 EXPLAIN << "\nin operand " << operand_index_;
1443 return false;
1444 }
1445 return true;
1446 }
1447
1448 int64_t operand_index_;
1449 HloInstructionPattern<OperandType, OperandImpl> operand_;
1450 };
1451
// Matches a binary instruction whose operands come in any order.
//
// The match succeeds iff the instruction has exactly two operands and either
// (op1 matches operand 0 and op2 matches operand 1) or (op1 matches operand 1
// and op2 matches operand 0). The (0, 1) assignment is tried first. Trial
// matches run with capture disabled. Only the winning assignment is re-run
// with the caller's option, so a failed ordering never leaves captures behind.
template <typename OperandType1, typename OperandImpl1, typename OperandType2,
          typename OperandImpl2>
class HloInstructionPatternBinaryOperandsAnyOrderImpl {
 public:
  explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
      const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
      const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
      : op1_(op1), op2_(op2) {}

  bool Match(::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  // Lists both sub-matchers as " - " bullets, each preceded by
  // Indent(os, indent). Sub-descriptions use indent + 3 so their continuation
  // lines line up past the 3-character bullet (assumes Indent starts a new
  // line at the given indentation; confirm against its definition).
  void DescribeTo(std::ostream* os, int64_t indent = 0) const {
    *os << "with two operands in either order:";
    Indent(os, indent);
    *os << " - ";
    op1_.DescribeTo(os, indent + 3);
    Indent(os, indent);
    *os << " - ";
    op2_.DescribeTo(os, indent + 3);
  }

 private:
  // Constness-preserving operand accessors (same bodies as the free
  // HloOperand() overloads).
  HloInstruction* operand(HloInstruction* inst, int64_t idx) const {
    return inst->mutable_operand(idx);
  }
  const HloInstruction* operand(const HloInstruction* inst, int64_t idx) const {
    return inst->operand(idx);
  }

  template <typename HloInstructionType>
  bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
    // We could implement this using AnyOf and AllOf matchers, but the templates
    // get pretty difficult to debug, since any compile error herein becomes
    // not-an-error via SFINAE. Also this way lets us give better messages on
    // failure.
    if (inst->operand_count() != 2) {
      EXPLAIN << "HloInstruction did not have two operands";
      return false;
    }

    // If we're not generating explanations, this is pretty simple.
    if (!option.explain_os) {
      // Tries op1 on operand idx1 and op2 on operand idx2. The first pass runs
      // with capture disabled. Only a successful assignment is re-run with the
      // caller's option so that captures reflect the winning order.
      auto try_match = [&](int64_t idx1, int64_t idx2) {
        MatchOption new_option = option;
        new_option.capture = false;
        if (op1_.Match(operand(inst, idx1), new_option) &&
            op2_.Match(operand(inst, idx2), new_option)) {
          if (option.capture) {
            // Matchers that succeeded once must succeed again on re-run.
            bool matched = op1_.Match(operand(inst, idx1), option) &&
                           op2_.Match(operand(inst, idx2), option);
            DCHECK(matched);
          }
          return true;
        }
        return false;
      };
      return try_match(0, 1) || try_match(1, 0);
    }

    // If we are generating explanations, we have some work to do in order to
    // generate a helpful error.
    //
    // First, try all four operand/matcher combinations, recording the
    // failure explanations separately from option.explain_os. matches[i][j]
    // tells us if matcher_i matches operand j.
    bool matches[/*matcher*/ 2][/*operand*/ 2];
    std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 2; ++j) {
        MatchOption new_option = option;
        new_option.capture = false;
        new_option.explain_os = &explanations[i][j];
        matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
                               : op2_.Match(operand(inst, j), new_option);
      }
    }

    // Check if the match succeeded.
    for (int i = 0; i < 2; ++i) {
      if (matches[0][i] && matches[1][(i + 1) % 2]) {
        // Rerun the matches with capture enabled if necessary.
        if (option.capture) {
          auto* operand1 = operand(inst, i);
          auto* operand2 = operand(inst, (i + 1) % 2);
          bool matched =
              op1_.Match(operand1, option) && op2_.Match(operand2, option);
          DCHECK(matched);
        }
        return true;
      }
    }

    // Writes a bullet describing matcher `matcher_idx`, followed by the
    // recorded explanation for each operand it failed to match. Continuation
    // lines of the nested explanation are re-indented under the bullet.
    auto describe_matcher = [&](int matcher_idx) {
      EXPLAIN << "\n - ";
      if (matcher_idx == 0) {
        op1_.DescribeTo(option.explain_os, /*indent=*/3);
      } else {
        CHECK_EQ(matcher_idx, 1);
        op2_.DescribeTo(option.explain_os, /*indent=*/3);
      }
      for (int i = 0; i < 2; ++i) {
        if (matches[matcher_idx][/*operand*/ i]) {
          continue;
        }
        EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
        EXPLAIN << " - ";
        EXPLAIN << absl::StrReplaceAll(
            explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
      }
    };

    // If we failed to match, one of the following is true:
    //  1. op1 (op2) matches neither LHS nor RHS, or
    //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
    // We print different explanations depending on which case we're in.
    // (If both matchers hit at least one operand and the match still failed,
    // they must both hit exactly the same single operand, which is case 2.)

    // Case 1.
    bool wrote_explanation = false;
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (!matches[i][0] && !matches[i][1]) {
        EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
                << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
        describe_matcher(i);
        wrote_explanation = true;
      }
    }

    // Case 2.
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (matches[/*matcher*/ 0][/*operand*/ i] &&
          matches[/*matcher*/ 1][/*operand*/ i]) {
        CHECK(!matches[0][(i + 1) % 2]);
        CHECK(!matches[1][(i + 1) % 2]);
        CHECK(!wrote_explanation);
        // Both matchers hit operand i, so the *other* operand is the one that
        // neither matched: i == 1 means LHS is unmatched, i == 0 means RHS.
        EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
                << " operand did not match either of the two matchers.  "
                   "Specifically,";
        describe_matcher(0);
        EXPLAIN << "\nand";
        describe_matcher(1);
        wrote_explanation = true;
      }
    }

    CHECK(wrote_explanation);
    return false;
  }

  HloInstructionPattern<OperandType1, OperandImpl1> op1_;
  HloInstructionPattern<OperandType2, OperandImpl2> op2_;
};
1610
1611 // An HloInstructionPattern implementation that matches only if the instruction
1612 // is a fusion node with a particular kind.
1613 class HloInstructionPatternFusionKindImpl {
1614 public:
1615 explicit constexpr HloInstructionPatternFusionKindImpl(
1616 ::xla::HloInstruction::FusionKind kind)
1617 : kind_(kind) {}
1618
1619 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1620 return MatchImpl(inst, option);
1621 }
1622
1623 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1624 return MatchImpl(inst, option);
1625 }
1626
1627 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1628 *os << "with fusion kind " << ToString(kind_);
1629 }
1630
1631 private:
1632 template <typename HloInstructionType>
1633 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1634 if (inst->opcode() != HloOpcode::kFusion) {
1635 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1636 << "; it's not a fusion";
1637 return false;
1638 }
1639 if (inst->fusion_kind() != kind_) {
1640 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1641 return false;
1642 }
1643 return true;
1644 }
1645
1646 ::xla::HloInstruction::FusionKind kind_;
1647 };
1648
1649 // An HloInstructionPattern implementation that matches only if the instruction
1650 // is a kGetTupleElement with a particular tuple index.
1651 class HloInstructionPatternTupleIndexImpl {
1652 public:
1653 explicit constexpr HloInstructionPatternTupleIndexImpl(int64_t tuple_index)
1654 : tuple_index_(tuple_index) {}
1655
1656 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1657 return MatchImpl(inst, option);
1658 }
1659
1660 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1661 return MatchImpl(inst, option);
1662 }
1663
1664 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1665 *os << "which is a GTE with index " << tuple_index_;
1666 }
1667
1668 private:
1669 template <typename HloInstructionType>
1670 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1671 if (inst->opcode() != HloOpcode::kGetTupleElement) {
1672 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1673 << "; it's not a GTE at all";
1674 return false;
1675 }
1676 if (inst->tuple_index() != tuple_index_) {
1677 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1678 return false;
1679 }
1680 return true;
1681 }
1682
1683 int64_t tuple_index_;
1684 };
1685
1686 class HloInstructionPatternParameterNumImpl {
1687 public:
1688 explicit constexpr HloInstructionPatternParameterNumImpl(
1689 int64_t parameter_num)
1690 : parameter_num_(parameter_num) {}
1691
1692 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1693 return MatchImpl(inst, option);
1694 }
1695
1696 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1697 return MatchImpl(inst, option);
1698 }
1699
1700 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1701 *os << "which is parameter " << parameter_num_;
1702 }
1703
1704 private:
1705 template <typename HloInstructionType>
1706 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1707 if (inst->opcode() != HloOpcode::kParameter ||
1708 inst->parameter_number() != parameter_num_) {
1709 EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1710 return false;
1711 }
1712 return true;
1713 }
1714
1715 int64_t parameter_num_;
1716 };
1717
1718 // Superclass that contains common code used by Op::WithOneUse() and
1719 // Op::WithOneUser().
1720 class HloInstructionPatternOneUseOrUserImpl {
1721 protected:
1722 bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1723 if (inst->user_count() != 1) {
1724 EXPLAIN << "HloInstruction has " << inst->user_count()
1725 << " users, but expected exactly one.";
1726 if (inst->user_count() > 1) {
1727 EXPLAIN << "\nAll users:";
1728 for (const HloInstruction* user : inst->users()) {
1729 EXPLAIN << "\n - " << InstToString(user);
1730 }
1731 }
1732 return false;
1733 }
1734 return true;
1735 }
1736 };
1737
1738 class HloInstructionPatternOneUseImpl
1739 : public HloInstructionPatternOneUseOrUserImpl {
1740 public:
1741 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1742 if (!MatchOneUser(inst, option)) {
1743 return false;
1744 }
1745
1746 int64_t use_count = absl::c_count_if(
1747 inst->users()[0]->operands(),
1748 [&](const HloInstruction* operand) { return operand == inst; });
1749 if (use_count != 1) {
1750 EXPLAIN << "HloInstruction is used " << use_count
1751 << " times by its user, but is expected to be used just once: "
1752 << InstToString(inst->users()[0]);
1753 return false;
1754 }
1755 return true;
1756 }
1757
1758 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1759 *os << "which has exactly one use";
1760 }
1761 };
1762
1763 class HloInstructionPatternOneUserImpl
1764 : public HloInstructionPatternOneUseOrUserImpl {
1765 public:
1766 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1767 return MatchOneUser(inst, option);
1768 }
1769
1770 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1771 *os << "which has exactly one user (but possibly is used multiple times by "
1772 "that instruction)";
1773 }
1774 };
1775
1776 class HloInstructionPatternComparisonDirectionImpl {
1777 public:
1778 explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1779 ComparisonDirection direction)
1780 : direction_(direction) {}
1781
1782 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1783 return MatchImpl(inst, option);
1784 }
1785
1786 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1787 return MatchImpl(inst, option);
1788 }
1789
1790 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1791 *os << "which has comparison direction "
1792 << ComparisonDirectionToString(direction_);
1793 }
1794
1795 private:
1796 template <typename HloInstructionType>
1797 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1798 if (inst->opcode() != HloOpcode::kCompare ||
1799 inst->comparison_direction() != direction_) {
1800 EXPLAIN << "HloInstruction is not comparison "
1801 << ComparisonDirectionToString(direction_);
1802 return false;
1803 }
1804 return true;
1805 }
1806
1807 ComparisonDirection direction_;
1808 };
1809
1810 class HloInstructionPredicateImpl {
1811 public:
1812 explicit HloInstructionPredicateImpl(HloPredicate fn) : fn_(std::move(fn)) {}
1813
1814 bool Match(const HloInstruction* inst, MatchOption option) const {
1815 bool match = fn_(inst);
1816 if (!match) {
1817 EXPLAIN << "HloInstruction does not match user-specified predicate";
1818 }
1819 return match;
1820 }
1821
1822 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1823 *os << "which matches a user-specified predicate";
1824 }
1825
1826 private:
1827 HloPredicate fn_;
1828 };
1829
1830 // Matches a constant scalar or effective scalar, optionally with a given value.
1831 template <typename ScalarTy>
1832 class HloConstantScalarImpl {
1833 public:
1834 explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1835 : val_(std::nullopt), match_effective_scalar_(match_effective_scalar) {}
1836
1837 constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1838 : val_(val), match_effective_scalar_(match_effective_scalar) {}
1839
1840 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1841 return MatchImpl(inst, option);
1842 }
1843
1844 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1845 return MatchImpl(inst, option);
1846 }
1847
1848 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1849 *os << "which is a constant "
1850 << (match_effective_scalar_ ? "effective " : "") << "scalar";
1851 if (val_.has_value()) {
1852 *os << " with value " << *val_;
1853 }
1854 }
1855
1856 private:
1857 template <typename InstTy>
1858 bool MatchImpl(InstTy* inst, MatchOption option) const {
1859 const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1860 if (!const_inst) {
1861 EXPLAIN << "HloInstruction is not a constant";
1862 return false;
1863 }
1864 if (match_effective_scalar_ &&
1865 !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1866 EXPLAIN << "HloInstruction is not an effective scalar";
1867 return false;
1868 }
1869 if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1870 EXPLAIN << "HloInstruction is not a scalar";
1871 return false;
1872 }
1873 if (!val_.has_value()) {
1874 return true;
1875 }
1876
1877 auto const_inst_scalar_or = const_inst->literal().Reshape({});
1878 if (!const_inst_scalar_or.ok()) {
1879 EXPLAIN << "could not convert matched literal to effective scalar";
1880 return false;
1881 }
1882 Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie();
1883 if (!const_inst_scalar.IsEqualAt({}, *val_)) {
1884 EXPLAIN << "HloInstruction's constant value "
1885 << const_inst_scalar.ToStringWithoutShape()
1886 << " did not match expected value " << *val_;
1887 return false;
1888 }
1889 return true;
1890 }
1891
1892 std::optional<ScalarTy> val_;
1893 bool match_effective_scalar_;
1894 };
1895
1896 // A pattern that matches HloInstructions.
1897 template <typename HloInstructionType, typename Impl>
1898 class HloInstructionPattern {
1899 private:
1900 template <typename NewImpl>
1901 auto AppendImpl(NewImpl new_impl) const {
1902 auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl));
1903 return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
1904 std::move(new_allof), matched_inst_);
1905 }
1906
1907 public:
1908 explicit constexpr HloInstructionPattern(const Impl& impl,
1909 HloInstructionType** matched_inst)
1910 : impl_(impl), matched_inst_(matched_inst) {}
1911
1912 // Returns true and captures the instruction iff it matches the pattern.
1913 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1914 if (impl_.Match(inst, option)) {
1915 if (option.capture && matched_inst_) {
1916 *matched_inst_ = inst;
1917 }
1918 return true;
1919 }
1920 if (inst != nullptr) {
1921 EXPLAIN << "\nin " << InstToString(inst);
1922 }
1923 return false;
1924 }
1925
1926 // Returns true and captures the instruction iff it matches the pattern.
1927 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1928 if (impl_.Match(inst, option)) {
1929 if (option.capture && matched_inst_) {
1930 *matched_inst_ = inst;
1931 }
1932 return true;
1933 }
1934 EXPLAIN << "\nin " << InstToString(inst);
1935 return false;
1936 }
1937
1938 // Modifies the pattern to match only if the instruction has the given name.
1939 auto WithName(absl::string_view name) const {
1940 return AppendImpl(HloInstructionPatternNameImpl(name));
1941 }
1942
1943 // Modifies the pattern to match only if the instruction has the given opcode.
1944 auto WithOpcode(HloOpcode opcode) const {
1945 return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
1946 }
1947
1948 // Modifies the pattern to match only the custom call with a given target.
1949 auto WithCustomCallTarget(absl::string_view custom_call_target) const {
1950 return AppendImpl(HloInstructionCustomCallTargetImpl({custom_call_target}));
1951 }
1952
1953 // Modifies the pattern to match a custom call with one of the given targets.
1954 auto WithCustomCallTarget(
1955 absl::Span<const absl::string_view> custom_call_targets) const {
1956 return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_targets));
1957 }
1958
1959 auto WithNumOperands(int64_t num_operands) const {
1960 return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
1961 }
1962
1963 // Modifies the pattern to match only if the instruction does not have the
1964 // given opcode.
1965 auto WithoutOpcode(HloOpcode opcode) const {
1966 return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
1967 }
1968
1969 constexpr auto Is(const HloInstruction* instr) const {
1970 return AppendImpl(HloInstructionIsImpl(instr));
1971 }
1972
1973 // Modifies the pattern to match only if the instruction is a constant.
1974 constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); }
1975
1976 constexpr auto IsConstantScalar() const {
1977 return AppendImpl(
1978 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
1979 }
1980
1981 // This does not check that T has the same type as the instruction, so e.g.
1982 // IsConstantScalar(1.0) may match a constant of shape int32_t[].
1983 template <typename ScalarTy>
1984 constexpr auto IsConstantScalar(const ScalarTy& val) const {
1985 return AppendImpl(
1986 HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
1987 }
1988
1989 constexpr auto IsConstantEffectiveScalar() const {
1990 return AppendImpl(
1991 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
1992 }
1993
1994 template <typename ScalarTy>
1995 constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const {
1996 return AppendImpl(
1997 HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
1998 }
1999
2000 // Modifies the pattern to match only if the instruction is not a constant.
2001 constexpr auto IsNonConstant() const {
2002 return WithoutOpcode(HloOpcode::kConstant);
2003 }
2004
2005 // Modifies the pattern to match only if the instruction has a shape that
2006 // matches the given pattern.
2007 template <typename ShapeType, typename ShapeImpl>
2008 constexpr auto WithShape(
2009 const ShapePattern<ShapeType, ShapeImpl>& shape) const {
2010 return AppendImpl(
2011 HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
2012 }
2013
2014 // Because we only specify the shape's element type and dims, this is
2015 // effectivley checking shape-compatible-to, not shape-equal-to. Perhaps this
2016 // function should be called WithShapeCompatibleTo, but the short name is
2017 // nice, and there's no ambiguity because there's no layout in the args!
2018 constexpr auto WithShape(PrimitiveType ty, absl::Span<const int64_t> dims) {
2019 return WithShape(Shape().WithElementType(ty).WithDims(dims));
2020 }
2021
2022 // Make this a templated function to work around gcc 4.9.4 template infinite
2023 // recursion bug.
2024 template <typename Dummy = void>
2025 constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const {
2026 return WithShape(Shape().EqualTo(shape));
2027 }
2028
2029 // Make this a templated function to work around gcc 4.9.4 template infinite
2030 // recursion bug.
2031 template <typename Dummy = void>
2032 constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const {
2033 return WithShape(Shape().CompatibleTo(shape));
2034 }
2035
2036 // Modifies the pattern to match only if the instruction's shape's element
2037 // type matches the given pattern.
2038 constexpr auto WithElementType(PrimitiveType ty) {
2039 return WithShape(Shape().WithElementType(ty));
2040 }
2041
2042 // Modifies the pattern to match only if the instruction has an operand that
2043 // matches the given pattern.
2044 template <typename OperandType, typename OperandImpl>
2045 constexpr auto WithOperand(
2046 int64_t operand_index,
2047 const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
2048 return AppendImpl(
2049 HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
2050 operand_index, operand));
2051 }
2052
2053 // Modifies the pattern to match only if
2054 // - the instruction has fewer than i+1 operands, or
2055 // - the i'th operand matches the given pattern.
2056 template <typename OperandType, typename OperandImpl>
2057 constexpr auto WithOperandIfPresent(
2058 int64_t operand_index,
2059 const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
2060 return AppendImpl(
2061 HloInstructionPatternOperandIfPresentImpl<OperandType, OperandImpl>(
2062 operand_index, operand));
2063 }
2064
2065 template <typename OperandType1, typename OperandImpl1, typename OperandType2,
2066 typename OperandImpl2>
2067 constexpr auto WithBinaryOperandsAnyOrder(
2068 const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
2069 const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const {
2070 return AppendImpl(
2071 HloInstructionPatternBinaryOperandsAnyOrderImpl<
2072 OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
2073 }
2074
2075 // Modifies the pattern to match only if the instruction is a fusion node with
2076 // the given kind.
2077 constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const {
2078 return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
2079 }
2080
2081 // Modifies the pattern to match only if the instruction is a
2082 // get-tuple-element with the given tuple index.
2083 constexpr auto WithTupleIndex(int64_t tuple_index) const {
2084 return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
2085 }
2086
2087 // Modifies the pattern to match only if the instruction is a parameter
2088 // with the given parameter number.
2089 constexpr auto WithParameterNum(int64_t parameter_num) const {
2090 return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
2091 }
2092
2093 // Modifies the pattern to match if the instruction is used exactly once.
2094 // Does not match if the instruction is used twice by the same user (e.g.
2095 // multiply(x,x)).
2096 constexpr auto WithOneUse() const {
2097 return AppendImpl(HloInstructionPatternOneUseImpl());
2098 }
2099
2100 // Modifies the pattern to match if the instruction is used by exactly one
2101 // other instruction. Will match if the instruction is used twice, so long as
2102 // it's by the same user (e.g. multiply(x,x)).
2103 constexpr auto WithOneUser() const {
2104 return AppendImpl(HloInstructionPatternOneUserImpl());
2105 }
2106
2107 // Modifies the pattern to match only if the instruction has the given
2108 // comparison direction.
2109 auto WithComparisonDirection(ComparisonDirection direction) const {
2110 return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
2111 }
2112
2113 auto WithPredicate(HloPredicate fn) const {
2114 return AppendImpl(HloInstructionPredicateImpl(std::move(fn)));
2115 }
2116
2117 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
2118 impl_.DescribeTo(os, indent);
2119 }
2120
2121 private:
2122 Impl impl_;
2123 HloInstructionType** matched_inst_;
2124 };
2125
2126 } // namespace detail
2127
2128 // Creates an instruction pattern that will capture the matched instruction in
2129 // the argument.
2130 inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) {
2131 return detail::HloInstructionPattern<const ::xla::HloInstruction,
2132 detail::HloInstructionPatternBaseImpl>(
2133 detail::HloInstructionPatternBaseImpl(), matched_inst);
2134 }
2135
2136 // Creates an instruction pattern that will capture the matched instruction in
2137 // the argument.
2138 inline constexpr auto Op(::xla::HloInstruction** matched_inst) {
2139 return detail::HloInstructionPattern<::xla::HloInstruction,
2140 detail::HloInstructionPatternBaseImpl>(
2141 detail::HloInstructionPatternBaseImpl(), matched_inst);
2142 }
2143
2144 // Helpers for nullary instructions.
2145 #define XLA_NULLOP_PATTERN(NAME) \
2146 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2147 \
2148 template <typename HloInstructionType> \
2149 inline auto NAME(HloInstructionType** matched_inst) { \
2150 return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \
2151 }
2152 XLA_NULLOP_PATTERN(Constant)
2153 XLA_NULLOP_PATTERN(Parameter)
2154 XLA_NULLOP_PATTERN(Iota)
2155 XLA_NULLOP_PATTERN(Rng)
2156 XLA_NULLOP_PATTERN(PartitionId)
2157 XLA_NULLOP_PATTERN(ReplicaId)
2158 #undef XLA_NULLOP_PATTERN
2159
2160 // Helpers for unary instructions.
2161 #define XLA_UNOP_PATTERN(NAME) \
2162 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2163 \
2164 template <typename Arg> \
2165 inline auto NAME(Arg&& arg) { \
2166 return Op() \
2167 .WithOpcode(HloOpcode::k##NAME) \
2168 .WithOperand(0, std::forward<Arg>(arg)); \
2169 } \
2170 \
2171 template <typename HloInstructionType, typename Arg> \
2172 inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) { \
2173 return Op(matched_inst) \
2174 .WithOpcode(HloOpcode::k##NAME) \
2175 .WithOperand(0, std::forward<Arg>(arg)); \
2176 }
2177 XLA_UNOP_PATTERN(Abs)
2178 XLA_UNOP_PATTERN(RoundNearestAfz)
2179 XLA_UNOP_PATTERN(Bitcast)
2180 XLA_UNOP_PATTERN(BitcastConvert)
2181 XLA_UNOP_PATTERN(Broadcast)
2182 XLA_UNOP_PATTERN(Ceil)
2183 XLA_UNOP_PATTERN(Convert)
2184 XLA_UNOP_PATTERN(Copy)
2185 XLA_UNOP_PATTERN(Cos)
2186 XLA_UNOP_PATTERN(AllReduce)
2187 XLA_UNOP_PATTERN(Exp)
2188 XLA_UNOP_PATTERN(Fft)
2189 XLA_UNOP_PATTERN(Floor)
2190 XLA_UNOP_PATTERN(GetTupleElement)
2191 XLA_UNOP_PATTERN(Imag)
2192 XLA_UNOP_PATTERN(Infeed)
2193 XLA_UNOP_PATTERN(IsFinite)
2194 XLA_UNOP_PATTERN(Log)
2195 XLA_UNOP_PATTERN(Not)
2196 XLA_UNOP_PATTERN(Negate)
2197 XLA_UNOP_PATTERN(Real)
2198 XLA_UNOP_PATTERN(Recv)
2199 XLA_UNOP_PATTERN(RecvDone)
2200 XLA_UNOP_PATTERN(ReducePrecision)
2201 XLA_UNOP_PATTERN(Reshape)
2202 XLA_UNOP_PATTERN(Reverse)
2203 XLA_UNOP_PATTERN(Rsqrt)
2204 XLA_UNOP_PATTERN(SendDone)
2205 XLA_UNOP_PATTERN(Sign)
2206 XLA_UNOP_PATTERN(Sin)
2207 XLA_UNOP_PATTERN(Slice)
2208 XLA_UNOP_PATTERN(Sqrt)
2209 XLA_UNOP_PATTERN(Tanh)
2210 XLA_UNOP_PATTERN(Transpose)
2211 #undef XLA_UNOP_PATTERN
2212
2213 // Helpers for binary instructions.
2214 #define XLA_BINOP_PATTERN(NAME) \
2215 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2216 \
2217 template <typename Lhs, typename Rhs> \
2218 inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \
2219 return Op() \
2220 .WithOpcode(HloOpcode::k##NAME) \
2221 .WithOperand(0, std::forward<Lhs>(lhs)) \
2222 .WithOperand(1, std::forward<Rhs>(rhs)); \
2223 } \
2224 \
2225 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2226 inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2227 return Op(matched_inst) \
2228 .WithOpcode(HloOpcode::k##NAME) \
2229 .WithOperand(0, std::forward<Lhs>(lhs)) \
2230 .WithOperand(1, std::forward<Rhs>(rhs)); \
2231 }
2232
2233 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
2234 XLA_BINOP_PATTERN(NAME) \
2235 \
2236 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2237 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2238 Rhs&& rhs) { \
2239 return Op(matched_inst) \
2240 .WithOpcode(HloOpcode::k##NAME) \
2241 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
2242 std::forward<Rhs>(rhs)); \
2243 } \
2244 template <typename Lhs, typename Rhs> \
2245 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \
2246 return NAME##AnyOrder<const HloInstruction>( \
2247 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
2248 }
2249 XLA_COMMUTATIVE_BINOP_PATTERN(Add)
2250 XLA_BINOP_PATTERN(Atan2)
2251 XLA_BINOP_PATTERN(Divide)
2252 XLA_BINOP_PATTERN(Complex)
2253 XLA_BINOP_PATTERN(Compare)
2254 XLA_BINOP_PATTERN(Convolution)
2255 XLA_BINOP_PATTERN(Dot)
2256 XLA_BINOP_PATTERN(Gather)
2257 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
2258 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
2259 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
2260 XLA_BINOP_PATTERN(Outfeed)
2261 XLA_BINOP_PATTERN(Pad)
2262 XLA_BINOP_PATTERN(Power)
2263 XLA_BINOP_PATTERN(Remainder)
2264 XLA_BINOP_PATTERN(Send)
2265 XLA_BINOP_PATTERN(Subtract)
2266 XLA_COMMUTATIVE_BINOP_PATTERN(And)
2267 XLA_COMMUTATIVE_BINOP_PATTERN(Or)
2268 XLA_BINOP_PATTERN(ShiftLeft)
2269 XLA_BINOP_PATTERN(ShiftRightArithmetic)
2270 XLA_BINOP_PATTERN(ShiftRightLogical)
2271 #undef XLA_COMMUTATIVE_BINOP_PATTERN
2272 #undef XLA_BINOP_PATTERN
2273
2274 // Helpers for ternary instructions.
2275 #define XLA_TERNOP_PATTERN(NAME) \
2276 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2277 \
2278 template <typename Arg0, typename Arg1, typename Arg2> \
2279 inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { \
2280 return Op() \
2281 .WithOpcode(HloOpcode::k##NAME) \
2282 .WithOperand(0, std::forward<Arg0>(arg0)) \
2283 .WithOperand(1, std::forward<Arg1>(arg1)) \
2284 .WithOperand(2, std::forward<Arg2>(arg2)); \
2285 } \
2286 \
2287 template <typename HloInstructionType, typename Arg0, typename Arg1, \
2288 typename Arg2> \
2289 inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \
2290 Arg1&& arg1, Arg2&& arg2) { \
2291 return Op(matched_inst) \
2292 .WithOpcode(HloOpcode::k##NAME) \
2293 .WithOperand(0, std::forward<Arg0>(arg0)) \
2294 .WithOperand(1, std::forward<Arg1>(arg1)) \
2295 .WithOperand(2, std::forward<Arg2>(arg2)); \
2296 }
2297 XLA_TERNOP_PATTERN(Clamp);
2298 XLA_TERNOP_PATTERN(Select);
2299 XLA_TERNOP_PATTERN(SelectAndScatter);
2300 #undef XLA_TERNOP_PATTERN
2301
namespace detail {
// Applies a list of operand patterns to `pattern`, one per consecutive operand
// index starting at `index`, by chaining WithOperand calls.

// Terminal step: a single remaining operand pattern.
template <typename Pattern, typename Operand>
inline auto WithOperands(Pattern&& pattern, int64_t index, Operand&& operand) {
  return pattern.WithOperand(index, std::forward<Operand>(operand));
}

// Recursive step: attach the current operand pattern, then recurse on the rest
// with the next operand index.
template <typename Pattern, typename Operand, typename... Rest>
inline auto WithOperands(Pattern&& pattern, int64_t index, Operand&& operand,
                         Rest&&... rest) {
  auto with_this_operand =
      pattern.WithOperand(index, std::forward<Operand>(operand));
  return WithOperands(std::move(with_this_operand), index + 1,
                      std::forward<Rest>(rest)...);
}
}  // namespace detail
2317
// Helpers for variadic instructions.
//
// NAME() matches any instruction with opcode k##NAME. NAME(args...) also
// applies WithNumOperands(sizeof...(args)) and then requires operand i to
// match the i'th pattern (see detail::WithOperands, which starts at operand
// 0). NAME(&inst, args...) does the same and captures the matched instruction.
#define XLA_VARIADIC_OP_PATTERN(NAME)                                         \
  inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }         \
                                                                              \
  template <typename... Args>                                                 \
  inline auto NAME(Args&&... args) {                                          \
    return detail::WithOperands(                                              \
        Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
        /*operand_num=*/0, std::forward<Args>(args)...);                      \
  }                                                                           \
                                                                              \
  template <typename HloInstructionType, typename... Args>                    \
  inline auto NAME(HloInstructionType** matched_inst, Args&&... args) {       \
    return detail::WithOperands(Op(matched_inst)                              \
                                    .WithOpcode(HloOpcode::k##NAME)           \
                                    .WithNumOperands(sizeof...(Args)),        \
                                /*operand_num=*/0,                            \
                                std::forward<Args>(args)...);                 \
  }

// We could implement all ops as "variadic" ops, but it would make the
// already-bad compile errors even worse.
XLA_VARIADIC_OP_PATTERN(AfterAll);
XLA_VARIADIC_OP_PATTERN(Concatenate);
XLA_VARIADIC_OP_PATTERN(Conditional);
XLA_VARIADIC_OP_PATTERN(DynamicSlice)
XLA_VARIADIC_OP_PATTERN(DynamicUpdateSlice)
XLA_VARIADIC_OP_PATTERN(Fusion);
XLA_VARIADIC_OP_PATTERN(Map)
XLA_VARIADIC_OP_PATTERN(Reduce);
XLA_VARIADIC_OP_PATTERN(ReduceWindow)
XLA_VARIADIC_OP_PATTERN(Scatter);
XLA_VARIADIC_OP_PATTERN(Sort);
XLA_VARIADIC_OP_PATTERN(Tuple);
XLA_VARIADIC_OP_PATTERN(Call);
2352
// CustomCall doesn't use the XLA_VARIADIC_OP_PATTERN macro so that you can
// optionally pass a string_view for the custom_call_target before the other
// operands.
//
// Matches any custom-call instruction, regardless of target or operands.
inline auto CustomCall() { return Op().WithOpcode(HloOpcode::kCustomCall); }

// As above, capturing the matched instruction into *matched_inst.
template <typename HloInstructionType>
auto CustomCall(HloInstructionType** matched_inst) {
  return Op(matched_inst).WithOpcode(HloOpcode::kCustomCall);
}

// Matches a custom call with WithNumOperands(sizeof...(Args) + 1) whose i'th
// operand matches the i'th pattern argument. The overload is disabled when the
// first argument is convertible to absl::string_view (the target-taking
// overload below handles that case) or to a (const) HloInstruction** (the
// capturing overloads handle that case).
template <
    typename Arg0, typename... Args,
    typename std::enable_if<
        !std::is_convertible<Arg0, absl::string_view>::value &&
        !std::is_convertible<Arg0, HloInstruction**>::value &&
        !std::is_convertible<Arg0, const HloInstruction**>::value>::type* =
        nullptr>
auto CustomCall(Arg0&& arg0, Args&&... args) {
  return detail::WithOperands(CustomCall().WithNumOperands(sizeof...(Args) + 1),
                              /*operand_num=*/0, std::forward<Arg0>(arg0),
                              std::forward<Args>(args)...);
}
// Same as the overload chosen for `args...`, additionally requiring the custom
// call target to be `custom_call_target`.
template <typename... Args>
auto CustomCall(absl::string_view custom_call_target, Args&&... args) {
  return CustomCall(std::forward<Args>(args)...)
      .WithCustomCallTarget(custom_call_target);
}

// Capturing variant of the operand-matching overload. Disabled when the first
// operand pattern is convertible to absl::string_view, so that
// CustomCall(&inst, "target", ...) selects the target-taking overload below.
template <typename HloInstructionType, typename Arg0, typename... Args,
          typename std::enable_if<!std::is_convertible<
              Arg0, absl::string_view>::value>::type* = nullptr>
auto CustomCall(HloInstructionType** matched_inst, Arg0&& arg0,
                Args&&... args) {
  return detail::WithOperands(
      CustomCall(matched_inst).WithNumOperands(sizeof...(Args) + 1),
      /*operand_num=*/0, std::forward<Arg0>(arg0), std::forward<Args>(args)...);
}
// Capturing variant that also requires the custom call target to be
// `custom_call_target`; with no further `args` it checks only opcode + target.
template <typename HloInstructionType, typename... Args>
auto CustomCall(HloInstructionType** matched_inst,
                absl::string_view custom_call_target, Args&&... args) {
  return CustomCall(matched_inst, std::forward<Args>(args)...)
      .WithCustomCallTarget(custom_call_target);
}
2396
2397 // Helpers for comparison instructions.
2398 #define XLA_COMPARE_PATTERN(NAME) \
2399 inline auto NAME() { \
2400 return Op() \
2401 .WithOpcode(HloOpcode::kCompare) \
2402 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2403 } \
2404 \
2405 template <typename Lhs, typename Rhs> \
2406 inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \
2407 return Op() \
2408 .WithOpcode(HloOpcode::kCompare) \
2409 .WithOperand(0, std::forward<Lhs>(lhs)) \
2410 .WithOperand(1, std::forward<Rhs>(rhs)) \
2411 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2412 } \
2413 \
2414 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2415 inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2416 return Op(matched_inst) \
2417 .WithOpcode(HloOpcode::kCompare) \
2418 .WithOperand(0, std::forward<Lhs>(lhs)) \
2419 .WithOperand(1, std::forward<Rhs>(rhs)) \
2420 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2421 }
2422
2423 #define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \
2424 XLA_COMPARE_PATTERN(NAME) \
2425 \
2426 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2427 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2428 Rhs&& rhs) { \
2429 return Op(matched_inst) \
2430 .WithOpcode(HloOpcode::kCompare) \
2431 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
2432 std::forward<Rhs>(rhs)); \
2433 } \
2434 template <typename Lhs, typename Rhs> \
2435 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \
2436 return NAME##AnyOrder<const HloInstruction>( \
2437 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
2438 }
2439
2440 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
2441 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
2442 XLA_COMPARE_PATTERN(Ge);
2443 XLA_COMPARE_PATTERN(Gt);
2444 XLA_COMPARE_PATTERN(Le);
2445 XLA_COMPARE_PATTERN(Lt);
2446
2447 // Helpers for matching non-constant instructions.
2448 inline auto NonConstant() { return Op().IsNonConstant(); }
2449
2450 template <typename HloInstructionType>
2451 inline auto NonConstant(HloInstructionType** matched_inst) {
2452 return Op(matched_inst).IsNonConstant();
2453 }
2454
2455 // Add overloads for GetTupleElement which take a int64_t specifying which tuple
2456 // element is selected.
2457 template <typename Arg>
2458 inline auto GetTupleElement(Arg&& arg, int64_t tuple_index) {
2459 return Op()
2460 .WithOpcode(HloOpcode::kGetTupleElement)
2461 .WithOperand(0, std::forward<Arg>(arg))
2462 .WithTupleIndex(tuple_index);
2463 }
2464
2465 template <typename HloInstructionType, typename Arg>
2466 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2467 int64_t tuple_index) {
2468 return Op(matched_inst)
2469 .WithOpcode(HloOpcode::kGetTupleElement)
2470 .WithOperand(0, std::forward<Arg>(arg))
2471 .WithTupleIndex(tuple_index);
2472 }
2473
2474 // Add overloads for Parameter which take an int64_t specifying the parameter
2475 // number.
2476 inline auto Parameter(int64_t parameter_num) {
2477 return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2478 }
2479 template <typename HloInstructionType>
2480 inline auto Parameter(HloInstructionType** matched_inst,
2481 int64_t parameter_num) {
2482 return Op(matched_inst)
2483 .WithOpcode(HloOpcode::kParameter)
2484 .WithParameterNum(parameter_num);
2485 }
2486
// Matches a constant whose shape satisfies ShapeUtil::IsScalar, with any value.
inline auto ConstantScalar() { return Op().IsConstantScalar(); }

// As above, capturing the matched constant into *matched_inst. When an
// argument of pointer-to-pointer type is passed, this overload is preferred
// over ConstantScalar(ScalarTy) by template partial ordering.
template <typename HloInstructionType>
inline auto ConstantScalar(HloInstructionType** matched_inst) {
  return Op(matched_inst).IsConstantScalar();
}

// Matches a scalar constant whose value compares equal to `val`. The
// constant's element type is not checked against ScalarTy, so e.g.
// ConstantScalar(1.0) may match an integer constant.
template <typename ScalarTy>
inline auto ConstantScalar(ScalarTy val) {
  return Op().IsConstantScalar(val);
}

// As above, capturing the matched constant into *matched_inst.
template <typename HloInstructionType, typename ScalarTy>
inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) {
  return Op(matched_inst).IsConstantScalar(val);
}

// Matches a constant whose shape satisfies ShapeUtil::IsEffectiveScalar, with
// any value.
inline auto ConstantEffectiveScalar() {
  return Op().IsConstantEffectiveScalar();
}

// As above, capturing the matched constant into *matched_inst.
template <typename HloInstructionType>
inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) {
  return Op(matched_inst).IsConstantEffectiveScalar();
}

// Matches an effective-scalar constant whose single value compares equal to
// `val` (same element-type caveat as ConstantScalar(val)).
template <typename ScalarTy>
inline auto ConstantEffectiveScalar(ScalarTy val) {
  return Op().IsConstantEffectiveScalar(val);
}

// As above, capturing the matched constant into *matched_inst.
template <typename HloInstructionType, typename ScalarTy>
inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
                                    ScalarTy val) {
  return Op(matched_inst).IsConstantEffectiveScalar(val);
}
2523
2524 } // namespace match
2525
2526 } // namespace xla
2527
2528 #undef EXPLAIN
2529 #pragma pop_macro("EXPLAIN")
2530 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
2531