xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/pattern_matcher.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
18 
#include <cstdint>
#include <functional>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/utility/utility.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
35 
36 namespace xla {
37 
38 // A pattern matcher for HloInstructions, Shapes, and Layouts.
39 //
40 // The Match function's first argument must be HloInstruction*, Shape*, or
41 // Layout*. The second argument is a pattern that will be matched against the
42 // first argument, as described below.
43 //
44 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
45 // functions. By default, the returned patterns will match any HloInstruction,
46 // Shape, or Layout, respectively. However the match can be made more specific
47 // by using the pattern's modifier methods, for example:
48 //
49 //   match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
50 //     0, match::Op().WithOpcode(HloOpcode::kConstant))
51 //
52 // This pattern will match Add instructions whose first operand is a constant.
53 //
54 // Each pattern type has the following modifiers, which are described where
55 // nontrivial.
56 //
57 //   Op():
58 //     - Is: is the given HloInstruction* (i.e. pointer equality)
59 //     - WithName
60 //     - WithOpcode
61 //     - WithoutOpcode: anything other than the given opcode
62 //     - WithShape: instr's shape matches the given pattern
63 //     - WithShapeEqualTo: instr's shape is equal to the given Shape
64 //     - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
65 //     - WithElementType: instr.shape().element_type() matches the given type
66 //     - WithNumOperands
67 //     - WithOperand: operand at the given index matches the given pattern
68 //     - WithOperandIfPresent: instr has fewer than i operands or the i'th one
69 //       matches the given pattern
70 //     - IsConstant
71 //     - IsNonConstant
72 //     - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value,
73 //       e.g. IsConstantScalar() or IsConstantScalar(42).
74 //     - WithFusionKind
75 //     - WithTupleIndex: get-tuple-element operations with the given tuple index
76 //     - WithOneUse: Instruction is used as an operand exactly once.
77 //     - WithOneUser: Instruction is used by exactly one other instruction, but
78 //       is possibly used more than once as an operand (e.g. multiply(x,x)).
79 //     - WithComparisonDirection: instr has the given direction
80 //     - WithPredicate: Instruction matches an arbitrary function you pass.
81 //       Function must have signature `bool(const HloInstruction*)`.
82 //
83 //   Shape():
84 //     - EqualTo
85 //     - CompatibleTo
86 //     - IsScalar/IsEffectiveScalar/IsArray/IsTuple
87 //     - IsDenseArray
//     - WithLayout: shape's layout matches the given Layout() pattern
89 //     - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
90 //       Layout, but not the result of Layout().foo())
91 //     - WithSubshape: shape is a tuple whose subshape matches the given pattern
92 //       (e.g. Shape().IsScalar()).
93 //     - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
94 //       (i.e. another Shape, but not the result of Shape().foo())
95 //     - WithElementType: shape is an array/scalar with the given elem type
96 //     - WithRank: shape is an array/scalar with the given rank
97 //
98 //  Layout():
99 //     - EqualTo
100 //
101 // Op(), Shape(), and Layout() may be passed an argument of type
102 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
103 // these pointers. If the pattern is matched, the address of the matched value
104 // will be "captured" and stored at this location.
105 //
106 // For example:
107 //   HloInstruction* foo = ...;
108 //   HloInstruction* matched_operand;
109 //   CHECK(Match(foo,
110 //               match::Op().WithOperand(0, match::Op(&matched_operand))));
111 //
112 // Helpers are provided for most HLO instructions. These helpers can be called
113 // with no arguments, in which case they will match any instruction matching the
114 // opcode. They may also be called with matches for the operands and with an
115 // optional capture. (The capture must be the first argument.) Some examples of
116 // these helpers and their equivalents are provided below.
117 
118 // Example nullary instruction:
119 //   Parameter()                    == Op().WithOpcode(HloOpcode::kParameter)
120 //   Parameter(&a)                  == Op(&a).WithOpcode(HloOpcode::kParameter)
121 //
122 // Example unary instruction:
123 //   Abs()                          == Op().WithOpcode(HloOpcode::kAbs)
124 //   Abs(Op(&a))                    == Op().WithOpcode(HloOpcode::kAbs)
125 //                                         .WithOperand(0, Op(&a)))
126 //   Abs(&a, Op(&b))                == Op(&a).WithOpcode(HloOpcode::kAbs)
127 //                                           .WithOperand(0, Op(&b))
128 //
129 // Commutative binary instructions have a special form that accepts either order
130 // of args, e.g.:
131 //
132 //   AddAnyOrder(Parameter(1), Abs()) ==
133 //     Op().WithOpcode(HloOpcode::kAdd)
134 //         .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
135 //
136 //   MultiplyAnyOrder(&a, Parameter(), Abs())  // Captures the mul in `a`.
137 //
138 // The following additional helpers are provided.  In all cases, `&a` is
139 // optional.
140 //
141 //   ConstantScalar(&a)               == Op(&a).IsConstantScalar();
142 //   ConstantScalar(&a, v)            == Op(&a).IsConstantScalar(v);
143 //   ConstantEffectiveScalar(&a)      == Op(&a).IsConstantEffectiveScalar();
//   ConstantEffectiveScalar(&a, v)   == Op(&a).IsConstantEffectiveScalar(v)
145 //   NonConstant(&a)                  == Op(&a).IsNonConstant()
146 //   GetTupleElement(&a, b, index)    == Op(&a).WithTupleIndex(index)
147 //                                             .WithOperand(0, b);
148 //   Parameter(&a, n)                 == Op(&a).WithParameterNum(n);
149 
// Options that control a single matching pass.
//
// Both members have default initializers so that a default-initialized
// `MatchOption opt;` is well defined (no capture, no explanation stream)
// instead of holding indeterminate values. The defaults equal what
// value-initialization (`MatchOption{}`) already produced, and aggregate
// initialization such as `{/*capture=*/true, /*explain_os=*/nullptr}` keeps
// working unchanged.
struct MatchOption {
  // If true, actually capture matched item into the user pointer.
  bool capture = false;

  // An explanation for why we failed to match is streamed here, if not-null.
  std::ostream* explain_os = nullptr;
};
157 
158 template <typename Value, typename Pattern>
159 bool Match(Value* value, const Pattern& pattern,
160            MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
161   if (option.capture) {
162     auto new_option = option;
163     new_option.capture = false;
164     if (!pattern.Match(value, new_option)) {
165       return false;
166     }
167   }
168   return pattern.Match(value, option);
169 }
170 
171 // If `enable_logging` is false, this is identical to Match(instr, pattern).
172 //
173 // If `enable_logging` is true and the match fails, we try to
174 // Match(instr, filter_pattern). If this is true, then we log an explanation for
175 // why the original Match(instr, pattern) failed.
176 //
177 // This function can be used aid in debugging passes with complex matchers.
178 // For example, in the following snippet we're trying to match
179 // m::Slice(m::Reshape(m::Pad())). Every time we encounter a slice that
180 // doesn't match the larger pattern, we will log an explanation for why it
181 // didn't match the larger pattern.
182 //
183 // if (MatchAndLogIfFailed(instr, "slice of reshape of pad",
184 //                         m::Slice(m::Reshape(m::Pad())),
185 //                         VLOG_IS_ON(3), m::Slice())
186 //
187 // TODO(jlebar): Log the caller's absl::SourceLocation once that's in OSS.
188 template <typename FilterPattern, typename Pattern>
MatchAndLogIfFailed(HloInstruction * instr,absl::string_view desc,const Pattern & pattern,bool enable_logging,const FilterPattern & filter_pattern)189 bool MatchAndLogIfFailed(HloInstruction* instr, absl::string_view desc,
190                          const Pattern& pattern, bool enable_logging,
191                          const FilterPattern& filter_pattern) {
192   bool matched = Match(instr, pattern);
193   if (matched || !enable_logging || !Match(instr, filter_pattern)) {
194     return matched;
195   }
196   std::stringstream os;
197   CHECK(!Match(instr, pattern, {/*capture=*/false, /*explain_os=*/&os}));
198   LOG(ERROR) << "Failed to match " << desc << ":\n" << os.str();
199   return false;
200 }
201 
202 namespace match {
203 
204 namespace detail {
205 
// Macro for streaming to option.explain_os if it's not null.
//
//   EXPLAIN << "value of foo(): " << foo()
//
// The macro expects a variable named `option` with an `explain_os` member to
// be in scope at the point of use.
//
// The expansion is written as `if (!cond) {} else stream` rather than
// `if (cond) stream` so that the hidden `if` already owns an `else`. This way
// an `else` that follows an unbraced `EXPLAIN << ...;` binds to the caller's
// own `if` instead of being silently captured by the macro.
#pragma push_macro("EXPLAIN")
#define EXPLAIN             \
  if (!option.explain_os) { \
  } else                    \
    *option.explain_os
213 
// Number of extra spaces added each time the pretty-printers increase the
// indent "by one".
enum { kIndentInc = 2 };
219 
// Starts a new line on `os` and pads it with `indent` spaces. A non-positive
// `indent` emits only the newline.
//
// The pretty-printers in this file follow an unintuitive convention: the
// *caller* performs the indenting, not the callee.  For example, to print
//
//   foo:
//    - bar
//
// you'd write:
//
//  Foo::DescribeTo(std::ostream* os, int64_t indent) {
//    *os << "foo:";
//    Indent(os, indent);  // Create a newline at the *current* indent level.
//    *os << " - ";
//    bar.DescribeTo(os, indent + 3);  // + 3 because strlen(" - ") == 3.
//  }
//
//  Bar::DescribeTo(std::ostream* os, int64_t indent) { *os << "bar"; }
//
// Bar::DescribeTo() never calls Indent itself; Foo does it on Bar's behalf.
// Leaving the decision to the caller lets it choose whether a matcher is
// preceded by a newline at all, which matters e.g. for the AllOf matcher.
//
// (Incidentally, indenting in Match's explanations is handled differently.
// Indents are a common case in DescribeTo [we're printing a whole tree], but
// they're a special case in Match [we're printing only a path through the tree
// that encounters a failing node]. Indents in Match only appear when we
// encounter a failing disjunction, so we just handle them as a special case
// there.)
inline void Indent(std::ostream* os, int64_t indent) {
  *os << "\n";
  for (int64_t spaces_left = indent; spaces_left > 0; --spaces_left) {
    *os << " ";
  }
}
257 
// Type trait whose `value` is true iff `T::kIsTrivialMatcher` is a valid
// constant expression that evaluates to true. Types without such a member (or
// with a false one) fall back to the primary template and yield false.
//
// Trivial matchers get special treatment.  For example, when printing
// a conjunction of matchers, we don't print "and" after a trivial matcher. This
// yields e.g.
//    "a shape compatible with f32[1,2]"
// rather than
//    "a shape AND compatible with f32[1,2]"
template <typename T, typename Dummy = void>
struct IsTrivialMatcher : std::false_type {};

// Selected through SFINAE only when T::kIsTrivialMatcher is true.
template <typename T>
struct IsTrivialMatcher<T, std::enable_if_t<T::kIsTrivialMatcher>>
    : std::true_type {};
276 
277 template <typename Item, typename... Patterns>
278 class AllOfPattern {
279  public:
280   explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
281 
282   bool Match(const Item* item, MatchOption option) const {
283     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
284     // This invariant is guaranteed by the top-level Match and AnyOf.
285     DCHECK(matched || !option.capture);
286     return matched;
287   }
288 
289   bool Match(Item* item, MatchOption option) const {
290     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
291     // This invariant is guaranteed by the top-level Match and AnyOf.
292     DCHECK(matched || !option.capture);
293     return matched;
294   }
295 
296   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
297     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
298   }
299 
300   // Accessor for patterns_.  Please don't use this outside of this file.
301   const std::tuple<Patterns...>& patterns() const { return patterns_; }
302 
303  private:
304   template <typename ItemType, size_t index>
305   bool MatchImpl(ItemType* item, MatchOption option,
306                  std::integral_constant<size_t, index>) const {
307     // We don't need to do any EXPLAINing here; it's all correctly handled by
308     // our sub-matchers (if any fail).
309     return std::get<index>(patterns_).Match(item, option) &&
310            MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
311   }
312 
313   template <typename ItemType>
314   bool MatchImpl(ItemType* item, MatchOption option,
315                  std::integral_constant<size_t, sizeof...(Patterns)>) const {
316     return true;
317   }
318 
319   // Pretty-printing a conjunction has some special cases to make it easy to
320   // read in the simple (common) case.
321   //
322   // If sizeof...(Patterns) == 1, prints as e.g.
323   //
324   //   a shape
325   //
326   // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
327   // shape") prints as
328   //
329   //   a shape compatible with f32[1,2]
330   //
331   // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
332   //
333   //   a shape:
334   //    * compatible with f32[1,2] AND
335   //    * that represents a scalar
336   //
337   // Otherwise prints as:
338   //
339   //   all of:
340   //    * foo AND
341   //    * bar
342   //
343   template <size_t index>
344   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
345                       int64_t indent) const {
346     constexpr bool first_is_trivial =
347         IsTrivialMatcher<typename std::remove_reference<decltype(std::get<0>(
348             patterns_))>::type>::value;
349     constexpr bool is_last = index == sizeof...(Patterns) - 1;
350     const auto& submatcher = std::get<index>(patterns_);
351 
352     auto print_bulleted_item = [&] {
353       *os << " * ";
354       submatcher.DescribeTo(os, indent + 3);
355       if (!is_last) {
356         *os << " AND";
357         Indent(os, indent);
358       }
359     };
360 
361     if (index == 0) {
362       if (first_is_trivial || is_last) {
363         submatcher.DescribeTo(os, indent + kIndentInc);
364         if (sizeof...(Patterns) > 2) {
365           *os << ":";
366           Indent(os, indent);
367         }
368       } else {
369         *os << "all of:";
370         Indent(os, indent);
371         print_bulleted_item();
372       }
373     } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
374       *os << " ";
375       submatcher.DescribeTo(os, indent);
376     } else {
377       print_bulleted_item();
378     }
379     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
380   }
381 
382   void DescribeToImpl(std::ostream* os,
383                       std::integral_constant<size_t, sizeof...(Patterns)>,
384                       int64_t indent) const {}
385 
386   std::tuple<Patterns...> patterns_;
387 };
388 
389 }  // namespace detail
390 
391 // Returns a pattern that represents the conjunction of all input patterns. All
392 // patterns need to match in order to have the AllOf pattern match.
393 template <typename Item, typename... Patterns>
394 auto AllOf(const Patterns&... patterns) {
395   return detail::AllOfPattern<typename std::remove_const<Item>::type,
396                               Patterns...>(patterns...);
397 }
398 
399 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
400 //
401 // This transformation is necessary for good pretty-printing.
402 template <typename Item, typename... InnerPs, typename... OuterPs>
403 auto AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
404            const OuterPs&... outer_ps) {
405   // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
406   auto make_all_of = [](const InnerPs&... inner_ps,
407                         const OuterPs&... outer_ps) {
408     return detail::AllOfPattern<typename std::remove_const<Item>::type,
409                                 InnerPs..., OuterPs...>(inner_ps...,
410                                                         outer_ps...);
411   };
412   return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
413                                                  std::make_tuple(outer_ps...)));
414 }
415 
416 namespace detail {
417 
418 template <typename LayoutType, typename Impl>
419 class LayoutPattern;
420 
421 // The base LayoutPattern implementation. Matches only if the layout is not
422 // nullptr.
423 class LayoutPatternBaseImpl {
424  public:
425   bool Match(const ::xla::Layout* layout, MatchOption option) const {
426     if (layout == nullptr) {
427       EXPLAIN << "Layout is null";
428       return false;
429     }
430     return true;
431   }
432 
433   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
434     *os << "a layout";
435   }
436 
437   static constexpr bool kIsTrivialMatcher = true;
438 };
439 
440 // A LayoutPattern implementation that matches only if the layout equals a
441 // Layout proto.
442 class LayoutPatternEqualImpl {
443  public:
444   explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
445       : layout_(layout) {}
446 
447   bool Match(const ::xla::Layout* layout, MatchOption option) const {
448     if (!LayoutUtil::Equal(*layout_, *layout)) {
449       EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
450               << " is not equal to expected "
451               << LayoutUtil::HumanString(*layout_);
452       return false;
453     }
454     return true;
455   }
456 
457   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
458     *os << "equal to " << LayoutUtil::HumanString(*layout_);
459   }
460 
461  private:
462   const ::xla::Layout* layout_;
463 };
464 
465 // A pattern that matches Layouts.
466 template <typename LayoutType, typename Impl>
467 class LayoutPattern {
468  private:
469   template <typename NewImpl>
470   auto AppendImpl(NewImpl new_impl) const {
471     auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl));
472     return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
473                                                           matched_layout_);
474   }
475 
476  public:
477   explicit constexpr LayoutPattern(const Impl& impl,
478                                    LayoutType** matched_layout)
479       : impl_(impl), matched_layout_(matched_layout) {}
480 
481   // Returns true and captures the layout iff it matches the pattern.
482   bool Match(const ::xla::Layout* layout, MatchOption option) const {
483     if (impl_.Match(layout, option)) {
484       if (option.capture && matched_layout_) {
485         *matched_layout_ = layout;
486       }
487       return true;
488     }
489     return false;
490   }
491 
492   // Returns true and captures the layout iff it matches the pattern.
493   bool Match(::xla::Layout* layout, MatchOption option) const {
494     if (impl_.Match(layout, option)) {
495       if (option.capture && matched_layout_) {
496         *matched_layout_ = layout;
497       }
498       return true;
499     }
500     return false;
501   }
502 
503   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
504     impl_.DescribeTo(os, indent);
505   }
506 
507   // Modifies the pattern to match only if the layout equals the given proto.
508   // The layout must outlive the returned pattern.
509   constexpr auto EqualTo(const ::xla::Layout* layout) const {
510     return AppendImpl(LayoutPatternEqualImpl(layout));
511   }
512 
513  private:
514   Impl impl_;
515   LayoutType** matched_layout_;
516 };
517 
518 template <typename Item, typename... Patterns>
519 class AnyOfPattern {
520  public:
521   explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
522 
523   bool Match(const Item* item, MatchOption option) const {
524     return MatchImpl(item, option);
525   }
526 
527   bool Match(Item* item, MatchOption option) const {
528     return MatchImpl(item, option);
529   }
530 
531   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
532     *os << "any of:";
533     Indent(os, indent);
534     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
535   }
536 
537  private:
538   template <typename ItemType>
539   bool MatchImpl(ItemType* item, MatchOption option) const {
540     // If we're generating an explanation, buffer it until we know we failed.
541     std::optional<std::stringstream> explanation;
542     MatchOption new_option = option;
543     if (option.explain_os) {
544       new_option.explain_os = &explanation.emplace();
545     }
546     bool rv = MatchRecursiveImpl(item, new_option,
547                                  std::integral_constant<size_t, 0>());
548     if (!rv && option.explain_os) {
549       EXPLAIN << "None of the following matchers succeeded:";
550       EXPLAIN << explanation->str();
551     }
552     return rv;
553   }
554 
555   template <typename ItemType, size_t index>
556   bool MatchRecursiveImpl(ItemType* item, MatchOption option,
557                           std::integral_constant<size_t, index>) const {
558     auto new_option = option;
559     new_option.capture = false;
560 
561     std::optional<std::stringstream> explanation;
562     if (option.explain_os) {
563       new_option.explain_os = &explanation.emplace();
564     }
565 
566     // Try to match the sub-pattern without capturing behavior.
567     if (std::get<index>(patterns_).Match(item, new_option)) {
568       // Capture the branch.
569       if (option.capture) {
570         // TODO(timshen): Currently the behavior can be exponential. Optimize it
571         // with memoization or recording the matched sub-pattern index, if it
572         // takes too long to run.
573         //
574         // Specifically, the "memoization" approach is to create an empty
575         // container with the key (pattern, instruction), and value as whether
576         // matched or not.
577         //
578         // Alternatively, we may run the pattern matching with captures off, but
579         // instead record a "trace" somewhere, indicating how exactly the
580         // pattern matches the input. For example, the trace information for
581         // AnyOf will be a runtime number indicate which sub-pattern is matched.
582         // Then we run another pass to do captures only with the help of the
583         // trace.
584         bool matched = std::get<index>(patterns_).Match(item, option);
585         DCHECK(matched);
586       }
587       return true;
588     }
589     if (option.explain_os) {
590       EXPLAIN << "\nMatcher #" << index + 1;
591       EXPLAIN << "\n - ";
592       std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
593       EXPLAIN << "\nfailed with";
594       EXPLAIN << "\n - ";
595       EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n   "}});
596     }
597     return MatchRecursiveImpl(item, option,
598                               std::integral_constant<size_t, index + 1>());
599   }
600 
601   template <typename ItemType>
602   bool MatchRecursiveImpl(
603       ItemType* item, MatchOption option,
604       std::integral_constant<size_t, sizeof...(Patterns)>) const {
605     return false;
606   }
607 
608   template <size_t index>
609   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
610                       int64_t indent) const {
611     *os << " - ";
612     std::get<index>(patterns_).DescribeTo(os, indent + 3);
613     if (index != sizeof...(Patterns) - 1) {
614       *os << " OR";
615       Indent(os, indent);
616     }
617     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
618   }
619 
620   void DescribeToImpl(std::ostream* os,
621                       std::integral_constant<size_t, sizeof...(Patterns)>,
622                       int64_t indent) const {}
623 
624   std::tuple<Patterns...> patterns_;
625 };
626 
627 }  // namespace detail
628 
629 // Returns a pattern that represents the logical disjunction of the input
630 // patterns. The returned pattern matches from left to right, and stops on the
631 // first match.
632 template <typename Item, typename... Patterns>
633 auto AnyOf(const Patterns&... patterns) {
634   return detail::AnyOfPattern<typename std::remove_const<Item>::type,
635                               Patterns...>(patterns...);
636 }
637 
638 // Creates a layout pattern that will capture the matched layout in the
639 // argument.
640 inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) {
641   return detail::LayoutPattern<const ::xla::Layout,
642                                detail::LayoutPatternBaseImpl>(
643       detail::LayoutPatternBaseImpl(), matched_layout);
644 }
645 
646 // Creates a layout pattern that will capture the matched layout in the
647 // argument.
648 inline constexpr auto Layout(::xla::Layout** matched_layout) {
649   return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
650       detail::LayoutPatternBaseImpl(), matched_layout);
651 }
652 
653 namespace detail {
654 
655 template <typename ShapeType, typename Impl>
656 class ShapePattern;
657 
658 // The base ShapePattern implementation. Matches only if the shape is not
659 // nullptr.
660 class ShapePatternBaseImpl {
661  public:
662   bool Match(const ::xla::Shape* shape, MatchOption option) const {
663     if (shape == nullptr) {
664       EXPLAIN << "Shape is null";
665     }
666     return shape != nullptr;
667   }
668 
669   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
670     *os << "a shape";
671   }
672 
673   static constexpr bool kIsTrivialMatcher = true;
674 };
675 
676 // A ShapePattern implementation that matches only if the shape equals a Shape
677 // proto.
678 class ShapePatternEqualImpl {
679  public:
680   explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
681       : shape_(shape) {}
682 
683   bool Match(const ::xla::Shape* shape, MatchOption option) const {
684     if (!ShapeUtil::Equal(*shape_, *shape)) {
685       EXPLAIN << "Shape not equal to "
686               << ShapeUtil::HumanStringWithLayout(*shape_);
687       return false;
688     }
689     return true;
690   }
691 
692   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
693     *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
694   }
695 
696  private:
697   const ::xla::Shape* shape_;
698 };
699 
700 // A ShapePattern implementation that matches only if the shape is compatible to
701 // a Shape proto.
702 class ShapePatternCompatibleImpl {
703  public:
704   explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
705       : shape_(shape) {}
706 
707   bool Match(const ::xla::Shape* shape, MatchOption option) const {
708     if (!ShapeUtil::Compatible(*shape_, *shape)) {
709       EXPLAIN << "Shape not compatible with "
710               << ShapeUtil::HumanString(*shape_);
711       return false;
712     }
713     return true;
714   }
715 
716   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
717     *os << "compatible with " << ShapeUtil::HumanString(*shape_);
718   }
719 
720  private:
721   const ::xla::Shape* shape_;
722 };
723 
724 // A ShapePattern implementation that matches only if the shape has a given
725 // element type.
726 class ShapePatternElementTypeImpl {
727  public:
728   explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
729       : element_type_(element_type) {}
730 
731   bool Match(const ::xla::Shape* shape, MatchOption option) const {
732     if (shape->element_type() != element_type_) {
733       EXPLAIN << "Shape does not have element type "
734               << PrimitiveType_Name(element_type_);
735       return false;
736     }
737     return true;
738   }
739 
740   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
741     *os << "with element type " << PrimitiveType_Name(element_type_);
742   }
743 
744  private:
745   PrimitiveType element_type_;
746 };
747 
748 // A ShapePattern implementation that matches only if the shape has a given
749 // list of dimensions.
750 class ShapePatternDimsImpl {
751  public:
752   explicit ShapePatternDimsImpl(absl::Span<const int64_t> dims)
753       : dims_(dims.begin(), dims.end()) {}
754 
755   bool Match(const ::xla::Shape* shape, MatchOption option) const {
756     if (shape->dimensions() != dims_) {
757       EXPLAIN << "Shape does not have dimensions [" << absl::StrJoin(dims_, ",")
758               << "]";
759       return false;
760     }
761     return true;
762   }
763 
764   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
765     *os << "with dimensions [" << absl::StrJoin(dims_, ",") << "]";
766   }
767 
768  private:
769   absl::InlinedVector<int64_t, 8> dims_;
770 };
771 
772 // A ShapePattern implementation that matches only if the shape is scalar.
773 class ShapePatternIsScalarImpl {
774  public:
775   explicit constexpr ShapePatternIsScalarImpl() {}
776 
777   bool Match(const ::xla::Shape* shape, MatchOption option) const {
778     if (!ShapeUtil::IsScalar(*shape)) {
779       EXPLAIN << "Shape is not a scalar";
780       return false;
781     }
782     return true;
783   }
784 
785   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
786     *os << "that represents a scalar";
787   }
788 };
789 
790 // A ShapePattern implementation that matches only if the shape is an array
791 class ShapePatternIsArrayImpl {
792  public:
793   explicit constexpr ShapePatternIsArrayImpl() {}
794 
795   bool Match(const ::xla::Shape* shape, MatchOption option) const {
796     if (!shape->IsArray()) {
797       EXPLAIN << "Shape is not an array";
798       return false;
799     }
800     return true;
801   }
802 
803   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
804     *os << "that represents an array";
805   }
806 };
807 
808 // A ShapePattern implementation that matches only if the shape is an array
809 class ShapePatternIsDenseArrayImpl {
810  public:
811   explicit constexpr ShapePatternIsDenseArrayImpl() {}
812 
813   bool Match(const ::xla::Shape* shape, MatchOption option) const {
814     if (!LayoutUtil::IsDenseArray(*shape)) {
815       EXPLAIN << "Shape is not a dense array";
816       return false;
817     }
818     return true;
819   }
820 
821   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
822     *os << "that represents a dense array";
823   }
824 };
825 
826 // A ShapePattern implementation that matches only if the shape is a tuple.
827 class ShapePatternIsTupleImpl {
828  public:
829   explicit constexpr ShapePatternIsTupleImpl() {}
830 
831   bool Match(const ::xla::Shape* shape, MatchOption option) const {
832     if (!shape->IsTuple()) {
833       EXPLAIN << "Shape is not a tuple";
834       return false;
835     }
836     return true;
837   }
838 
839   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
840     *os << "that represents a tuple";
841   }
842 };
843 
844 // A ShapePattern implementation that matches only if the shape is an effective
845 // scalar.
846 class ShapePatternEffectiveScalarImpl {
847  public:
848   explicit constexpr ShapePatternEffectiveScalarImpl() {}
849 
850   bool Match(const ::xla::Shape* shape, MatchOption option) const {
851     if (!ShapeUtil::IsEffectiveScalar(*shape)) {
852       EXPLAIN << "Shape is not an effective scalar";
853       return false;
854     }
855     return true;
856   }
857 
858   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
859     *os << "that is an effective scalar";
860   }
861 };
862 
863 // A ShapePattern implementation that matches only if the shape has a given
864 // rank.
865 class ShapePatternRankImpl {
866  public:
867   explicit constexpr ShapePatternRankImpl(int64_t rank) : rank_(rank) {}
868 
869   bool Match(const ::xla::Shape* shape, MatchOption option) const {
870     if (shape->rank() != rank_) {
871       if (rank_ == 0) {
872         EXPLAIN << "Shape is not a scalar";
873       } else {
874         EXPLAIN << "Shape does not have rank " << rank_;
875       }
876       return false;
877     }
878     return true;
879   }
880 
881   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
882     if (rank_ == 0) {
883       *os << "that is a scalar";
884     } else {
885       *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
886     }
887   }
888 
889  private:
890   int64_t rank_;
891 };
892 
893 // A ShapePattern implementation that matches only if the shape has a layout
894 // that matches a given pattern.
895 template <typename LayoutType, typename LayoutImpl>
896 class ShapePatternLayoutImpl {
897  public:
898   explicit constexpr ShapePatternLayoutImpl(
899       const LayoutPattern<LayoutType, LayoutImpl>& layout)
900       : layout_(layout) {}
901 
902   bool Match(const ::xla::Shape* shape, MatchOption option) const {
903     return LayoutUtil::HasLayout(*shape) &&
904            layout_.Match(&shape->layout(), option);
905   }
906 
907   bool Match(::xla::Shape* shape, MatchOption option) const {
908     if (!LayoutUtil::HasLayout(*shape)) {
909       EXPLAIN << "Shape does not have a layout";
910       return false;
911     }
912     if (!layout_.Match(shape->mutable_layout(), option)) {
913       EXPLAIN << "\nin layout";
914       return false;
915     }
916     return true;
917   }
918 
919   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
920     *os << "with";
921     Indent(os, indent + kIndentInc);
922     layout_.DescribeTo(os, indent + kIndentInc);
923   }
924 
925  private:
926   LayoutPattern<LayoutType, LayoutImpl> layout_;
927 };
928 
929 // A ShapePattern implementation that matches only if the shape has a subshape
930 // that matches a given pattern.
931 template <typename SubshapeType, typename SubshapeImpl>
932 class ShapePatternSubshapeImpl {
933  public:
934   explicit ShapePatternSubshapeImpl(
935       ShapeIndexView index,
936       const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
937       : index_(index), subshape_(subshape) {}
938 
939   bool Match(const ::xla::Shape* shape, MatchOption option) const {
940     return MatchImpl(shape, option);
941   }
942 
943   bool Match(::xla::Shape* shape, MatchOption option) const {
944     return MatchImpl(shape, option);
945   }
946 
947   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
948     *os << "with subshape at index " << ShapeIndex(index_) << " which is";
949     Indent(os, indent + kIndentInc);
950     subshape_.DescribeTo(os, indent + kIndentInc);
951   }
952 
953  private:
954   ::xla::Shape* GetSubshape(::xla::Shape* shape) const {
955     return ShapeUtil::GetMutableSubshape(shape, index_);
956   }
957   const ::xla::Shape* GetSubshape(const ::xla::Shape* shape) const {
958     return &ShapeUtil::GetSubshape(*shape, index_);
959   }
960 
961   template <typename ShapeType>
962   bool MatchImpl(ShapeType* shape, MatchOption option) const {
963     if (!ShapeUtil::IndexIsValid(*shape, index_)) {
964       EXPLAIN << "No subshape at " << ShapeIndex(index_);
965       return false;
966     }
967     if (!subshape_.Match(GetSubshape(shape), option)) {
968       EXPLAIN << "\nin subshape at " << ShapeIndex(index_);
969       return false;
970     }
971     return true;
972   }
973 
974   ShapeIndexView index_;
975   ShapePattern<SubshapeType, SubshapeImpl> subshape_;
976 };
977 
978 // A pattern that matches Shapes.
979 template <typename ShapeType, typename Impl>
980 class ShapePattern {
981  private:
982   template <typename NewImpl>
983   auto AppendImpl(NewImpl new_impl) const {
984     auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl));
985     return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
986                                                          matched_shape_);
987   }
988 
989  public:
990   explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
991       : impl_(impl), matched_shape_(matched_shape) {}
992 
993   // Returns true and captures the shape iff it matches the pattern.
994   bool Match(const ::xla::Shape* shape, MatchOption option) const {
995     if (impl_.Match(shape, option)) {
996       if (option.capture && matched_shape_) {
997         *matched_shape_ = shape;
998       }
999       return true;
1000     }
1001     if (shape) {
1002       EXPLAIN << "\nin "
1003               << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
1004                                       : ShapeUtil::HumanString(*shape));
1005     }
1006     return false;
1007   }
1008 
1009   // Returns true and captures the shape iff it matches the pattern.
1010   bool Match(::xla::Shape* shape, MatchOption option) const {
1011     if (impl_.Match(shape, option)) {
1012       if (option.capture && matched_shape_) {
1013         *matched_shape_ = shape;
1014       }
1015       return true;
1016     }
1017     EXPLAIN << "\nin "
1018             << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
1019                                     : ShapeUtil::HumanString(*shape));
1020     return false;
1021   }
1022 
1023   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1024     return impl_.DescribeTo(os, indent);
1025   }
1026 
1027   // Modifies the pattern to match only if the shape equals the given proto.
1028   // The layout must outlive the returned pattern.
1029   constexpr auto EqualTo(const ::xla::Shape* shape) const {
1030     return AppendImpl(ShapePatternEqualImpl(shape));
1031   }
1032 
1033   // Modifies the pattern to match only if the shape is compatible to the given
1034   // proto. The layout must outlive the returned pattern.
1035   constexpr auto CompatibleTo(const ::xla::Shape* shape) const {
1036     return AppendImpl(ShapePatternCompatibleImpl(shape));
1037   }
1038 
1039   // Modifies the pattern to match only if the shape has the given element type.
1040   constexpr auto WithElementType(PrimitiveType element_type) const {
1041     return AppendImpl(ShapePatternElementTypeImpl(element_type));
1042   }
1043 
1044   constexpr auto WithDims(absl::Span<const int64_t> dims) const {
1045     return AppendImpl(ShapePatternDimsImpl(dims));
1046   }
1047 
1048   // Modifies the pattern to match only if the shape is scalar.
1049   constexpr auto IsScalar() const {
1050     return AppendImpl(ShapePatternIsScalarImpl());
1051   }
1052 
1053   // Modifies the pattern to match only if the shape is an array.
1054   constexpr auto IsArray() const {
1055     return AppendImpl(ShapePatternIsArrayImpl());
1056   }
1057 
1058   // Modifies the pattern to match only if the shape is a tuple.
1059   constexpr auto IsTuple() const {
1060     return AppendImpl(ShapePatternIsTupleImpl());
1061   }
1062 
1063   constexpr auto IsEffectiveScalar() const {
1064     return AppendImpl(ShapePatternEffectiveScalarImpl());
1065   }
1066 
1067   // Modifies the pattern to match only if the shape has the given rank.
1068   constexpr auto WithRank(int64_t rank) const {
1069     return AppendImpl(ShapePatternRankImpl(rank));
1070   }
1071 
1072   // Modifies the pattern to match only if the shape has a layout that matches
1073   // the given pattern.
1074   template <typename LayoutType, typename LayoutImpl>
1075   auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
1076     return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1077   }
1078 
1079   constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const {
1080     return WithLayout(Layout().EqualTo(layout));
1081   }
1082 
1083   // Modifies the pattern to match only if the shape is a dense array.
1084   constexpr auto IsDenseArray() const {
1085     return AppendImpl(ShapePatternIsDenseArrayImpl());
1086   }
1087 
1088   // Modifies the pattern to match only if the shape has a subshape that matches
1089   // the given pattern.
1090   template <typename SubshapeType, typename SubshapeImpl>
1091   auto WithSubshape(
1092       ShapeIndexView index,
1093       const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
1094     return AppendImpl(
1095         ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1096   }
1097 
1098   ShapePattern<ShapeType,
1099                AllOfPattern<::xla::Shape, Impl,
1100                             ShapePatternSubshapeImpl<
1101                                 const ::xla::Shape,
1102                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1103                                              ShapePatternEqualImpl>>>>
1104   WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1105     return WithSubshape(index,
1106                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1107                             ShapePatternBaseImpl(), nullptr)
1108                             .EqualTo(shape));
1109   }
1110 
1111   ShapePattern<ShapeType,
1112                AllOfPattern<::xla::Shape, Impl,
1113                             ShapePatternSubshapeImpl<
1114                                 const ::xla::Shape,
1115                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1116                                              ShapePatternCompatibleImpl>>>>
1117   WithSubshapeCompatibleTo(ShapeIndexView index,
1118                            const ::xla::Shape* shape) const {
1119     return WithSubshape(index,
1120                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1121                             ShapePatternBaseImpl(), nullptr)
1122                             .CompatibleTo(shape));
1123   }
1124 
1125  private:
1126   Impl impl_;
1127   ShapeType** matched_shape_;
1128 };
1129 
1130 }  // namespace detail
1131 
1132 // Creates a shape pattern that will capture the matched layout in the argument.
1133 inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) {
1134   return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1135       detail::ShapePatternBaseImpl(), matched_shape);
1136 }
1137 
1138 // Creates a shape pattern that will capture the matched layout in the argument.
1139 inline constexpr auto Shape(::xla::Shape** matched_shape) {
1140   return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1141       detail::ShapePatternBaseImpl(), matched_shape);
1142 }
1143 
1144 namespace detail {
1145 
1146 // Overloads to get a const or non-const operand out of an instruction.
1147 inline HloInstruction* HloOperand(HloInstruction* instr, int64_t idx) {
1148   return instr->mutable_operand(idx);
1149 }
1150 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1151                                         int64_t idx) {
1152   return instr->operand(idx);
1153 }
1154 
1155 // Pretty-printer for HloInstruction.  Sort of like ToShortString, but with
1156 // fewer %s and more shapes.
1157 inline std::string InstToString(const HloInstruction* inst) {
1158   return inst->ToString(
1159       HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1160 }
1161 
1162 template <typename HloInstructionType, typename Impl>
1163 class HloInstructionPattern;
1164 
1165 // The base HloInstructionPattern implementation. Matches only if the
1166 // instruction is not nullptr.
1167 class HloInstructionPatternBaseImpl {
1168  public:
1169   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1170     if (inst == nullptr) {
1171       EXPLAIN << "HloInstruction* is null";
1172       return false;
1173     }
1174     return true;
1175   }
1176 
1177   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1178     *os << "an HloInstruction";
1179   }
1180 
1181   static constexpr bool kIsTrivialMatcher = true;
1182 };
1183 
1184 // An HloInstructionPattern implementation that matches only if the instruction
1185 // has a given name.
1186 class HloInstructionPatternNameImpl {
1187  public:
1188   explicit HloInstructionPatternNameImpl(absl::string_view name)
1189       : name_(name) {}
1190 
1191   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1192     if (inst->name() != name_) {
1193       EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1194       return false;
1195     }
1196     return true;
1197   }
1198 
1199   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1200     *os << "named \"" << name_ << "\"";
1201   }
1202 
1203  private:
1204   absl::string_view name_;
1205 };
1206 
1207 // An HloInstructionPattern implementation that matches only if the instruction
1208 // equals a particular pointer.
1209 class HloInstructionIsImpl {
1210  public:
1211   explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1212 
1213   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1214     if (inst != inst_) {
1215       EXPLAIN << "HloInstruction " << std::hex << std::nouppercase
1216               << std::showbase << reinterpret_cast<uint64_t>(inst) << " is not "
1217               << reinterpret_cast<uint64_t>(inst_) << " ("
1218               << InstToString(inst_) << ")";
1219       return false;
1220     }
1221     return true;
1222   }
1223 
1224   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1225     *os << "which is " << std::hex << std::nouppercase << std::showbase
1226         << reinterpret_cast<uint64_t>(inst_) << " (" << InstToString(inst_)
1227         << ")";
1228   }
1229 
1230  private:
1231   const HloInstruction* inst_;
1232 };
1233 
1234 // An HloInstructionPattern implementation that matches only if the instruction
1235 // has a given opcode.
1236 class HloInstructionPatternOpcodeImpl {
1237  public:
1238   explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1239                                                      bool invert)
1240       : opcode_(opcode), invert_(invert) {}
1241 
1242   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1243     if (invert_ && inst->opcode() == opcode_) {
1244       EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1245               << ", expected anything else";
1246       return false;
1247     }
1248     if (!invert_ && inst->opcode() != opcode_) {
1249       EXPLAIN << "HloInstruction doesn't have opcode "
1250               << HloOpcodeString(opcode_);
1251       return false;
1252     }
1253     return true;
1254   }
1255 
1256   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1257     if (!invert_) {
1258       *os << "with opcode " << HloOpcodeString(opcode_);
1259     } else {
1260       *os << "with any opcode other than " << HloOpcodeString(opcode_);
1261     }
1262   }
1263 
1264  private:
1265   HloOpcode opcode_;
1266   bool invert_;
1267 };
1268 
1269 // An HloInstructionPattern implementation that matches only if the instruction
1270 // has one of a given list of custom call targets.
1271 class HloInstructionCustomCallTargetImpl {
1272  public:
1273   explicit HloInstructionCustomCallTargetImpl(
1274       absl::Span<const absl::string_view> custom_call_targets)
1275       : custom_call_targets_(custom_call_targets.begin(),
1276                              custom_call_targets.end()) {}
1277 
1278   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1279     if (inst->opcode() != HloOpcode::kCustomCall ||
1280         !absl::c_linear_search(custom_call_targets_,
1281                                inst->custom_call_target())) {
1282       if (custom_call_targets_.size() == 1) {
1283         EXPLAIN << "HloInstruction is not a custom call with a target '"
1284                 << custom_call_targets_.front() << "'";
1285       } else {
1286         EXPLAIN << "HloInstruction is not a custom call with a target in {"
1287                 << absl::StrJoin(custom_call_targets_, ", ") << "}";
1288       }
1289       return false;
1290     }
1291     return true;
1292   }
1293 
1294   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1295     if (custom_call_targets_.size() == 1) {
1296       *os << "custom call with target '" << custom_call_targets_.front() << "'";
1297     } else {
1298       *os << "custom call with target in {"
1299           << absl::StrJoin(custom_call_targets_, ", ") << "}";
1300     }
1301   }
1302 
1303  private:
1304   absl::InlinedVector<std::string, 1> custom_call_targets_;
1305 };
1306 
1307 // An HloInstructionPattern implementation that matches only if the instruction
1308 // has the given number of operands.
1309 class HloInstructionPatternNumOperandsImpl {
1310  public:
1311   explicit constexpr HloInstructionPatternNumOperandsImpl(int64_t num_operands)
1312       : num_operands_(num_operands) {}
1313 
1314   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1315     if (inst->operand_count() != num_operands_) {
1316       EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1317       return false;
1318     }
1319     return true;
1320   }
1321 
1322   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1323     *os << "with " << num_operands_ << " operand"
1324         << (num_operands_ != 1 ? "s" : "");
1325   }
1326 
1327  private:
1328   int64_t num_operands_;
1329 };
1330 
1331 // An HloInstructionPattern implementation that matches only if the instruction
1332 // has a shape that matches a given pattern.
1333 template <typename ShapeType, typename ShapeImpl>
1334 class HloInstructionPatternShapeImpl {
1335  public:
1336   explicit constexpr HloInstructionPatternShapeImpl(
1337       const ShapePattern<ShapeType, ShapeImpl>& shape)
1338       : shape_(shape) {}
1339 
1340   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1341     if (!shape_.Match(&inst->shape(), option)) {
1342       EXPLAIN << "\nin output shape";
1343       return false;
1344     }
1345     return true;
1346   }
1347 
1348   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1349     if (!shape_.Match(inst->mutable_shape(), option)) {
1350       EXPLAIN << "\nin output shape";
1351       return false;
1352     }
1353     return true;
1354   }
1355 
1356   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1357     *os << "outputting";
1358     Indent(os, indent + kIndentInc);
1359     shape_.DescribeTo(os, indent + kIndentInc);
1360   }
1361 
1362  private:
1363   ShapePattern<ShapeType, ShapeImpl> shape_;
1364 };
1365 
1366 // An HloInstructionPattern implementation that matches only if the instruction
1367 // has an operand that matches a given pattern.
1368 template <typename OperandType, typename OperandImpl>
1369 class HloInstructionPatternOperandImpl {
1370  public:
1371   explicit constexpr HloInstructionPatternOperandImpl(
1372       int64_t operand_index,
1373       const HloInstructionPattern<OperandType, OperandImpl>& operand)
1374       : operand_index_(operand_index), operand_(operand) {}
1375 
1376   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1377     return MatchImpl(inst, option);
1378   }
1379 
1380   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1381     return MatchImpl(inst, option);
1382   }
1383 
1384   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1385     *os << "with operand " << operand_index_ << " which is:";
1386     Indent(os, indent + kIndentInc);
1387     operand_.DescribeTo(os, indent + kIndentInc);
1388   }
1389 
1390  private:
1391   template <typename HloInstructionType>
1392   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1393     if (operand_index_ >= inst->operand_count()) {
1394       EXPLAIN << "desired operand index " << operand_index_
1395               << " is out of bounds";
1396       return false;
1397     }
1398     if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1399       EXPLAIN << "\nin operand " << operand_index_;
1400       return false;
1401     }
1402     return true;
1403   }
1404 
1405   int64_t operand_index_;
1406   HloInstructionPattern<OperandType, OperandImpl> operand_;
1407 };
1408 
1409 // An HloInstructionPattern implementation that matches if the instruction has
1410 // fewer than i+1 operands, or if the i'th operand matches a given pattern.
1411 template <typename OperandType, typename OperandImpl>
1412 class HloInstructionPatternOperandIfPresentImpl {
1413  public:
1414   explicit constexpr HloInstructionPatternOperandIfPresentImpl(
1415       int64_t operand_index,
1416       const HloInstructionPattern<OperandType, OperandImpl>& operand)
1417       : operand_index_(operand_index), operand_(operand) {}
1418 
1419   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1420     return MatchImpl(inst, option);
1421   }
1422 
1423   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1424     return MatchImpl(inst, option);
1425   }
1426 
1427   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1428     *os << "either with fewer than " << operand_index_ + 1 << " operand"
1429         << (operand_index_ + 1 != 1 ? "s" : "") << ", or with an operand "
1430         << operand_index_ << " which is:";
1431     Indent(os, indent + kIndentInc);
1432     operand_.DescribeTo(os, indent + kIndentInc);
1433   }
1434 
1435  private:
1436   template <typename HloInstructionType>
1437   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1438     if (operand_index_ >= inst->operand_count()) {
1439       return true;
1440     }
1441     if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1442       EXPLAIN << "\nin operand " << operand_index_;
1443       return false;
1444     }
1445     return true;
1446   }
1447 
1448   int64_t operand_index_;
1449   HloInstructionPattern<OperandType, OperandImpl> operand_;
1450 };
1451 
1452 // Matches a binary instruction whose operands come in any order.
1453 template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1454           typename OperandImpl2>
1455 class HloInstructionPatternBinaryOperandsAnyOrderImpl {
1456  public:
1457   explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
1458       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1459       const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
1460       : op1_(op1), op2_(op2) {}
1461 
1462   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1463     return MatchImpl(inst, option);
1464   }
1465 
1466   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1467     return MatchImpl(inst, option);
1468   }
1469 
1470   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1471     *os << "with two operands in either order:";
1472     Indent(os, indent);
1473     *os << " - ";
1474     op1_.DescribeTo(os, indent + 3);
1475     Indent(os, indent);
1476     *os << " - ";
1477     op2_.DescribeTo(os, indent + 3);
1478   }
1479 
1480  private:
1481   HloInstruction* operand(HloInstruction* inst, int64_t idx) const {
1482     return inst->mutable_operand(idx);
1483   }
1484   const HloInstruction* operand(const HloInstruction* inst, int64_t idx) const {
1485     return inst->operand(idx);
1486   }
1487 
1488   template <typename HloInstructionType>
1489   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1490     // We could implement this using AnyOf and AllOf matchers, but the templates
1491     // get pretty difficult to debug, since any compile error herein becomes
1492     // not-an-error via SFINAE.  Also this way lets us give better messages on
1493     // failure.
1494     if (inst->operand_count() != 2) {
1495       EXPLAIN << "HloInstruction did not have two operands";
1496       return false;
1497     }
1498 
1499     // If we're not generating explanations, this is pretty simple.
1500     if (!option.explain_os) {
1501       auto try_match = [&](int64_t idx1, int64_t idx2) {
1502         MatchOption new_option = option;
1503         new_option.capture = false;
1504         if (op1_.Match(operand(inst, idx1), new_option) &&
1505             op2_.Match(operand(inst, idx2), new_option)) {
1506           if (option.capture) {
1507             bool matched = op1_.Match(operand(inst, idx1), option) &&
1508                            op2_.Match(operand(inst, idx2), option);
1509             DCHECK(matched);
1510           }
1511           return true;
1512         }
1513         return false;
1514       };
1515       return try_match(0, 1) || try_match(1, 0);
1516     }
1517 
1518     // If we are generating explanations, we have some work to do in order to
1519     // generate a helpful error.
1520     //
1521     // First, try all four operand/matcher combinations, recording the
1522     // failure explanations separately from option.explain_os. matches[i][j]
1523     // tells us if matcher_i matches operand j.
1524     bool matches[/*matcher*/ 2][/*operand*/ 2];
1525     std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
1526     for (int i = 0; i < 2; ++i) {
1527       for (int j = 0; j < 2; ++j) {
1528         MatchOption new_option = option;
1529         new_option.capture = false;
1530         new_option.explain_os = &explanations[i][j];
1531         matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
1532                                : op2_.Match(operand(inst, j), new_option);
1533       }
1534     }
1535 
1536     // Check if the match succeeded.
1537     for (int i = 0; i < 2; ++i) {
1538       if (matches[0][i] && matches[1][(i + 1) % 2]) {
1539         // Rerun the matches with capture enabled if necessary.
1540         if (option.capture) {
1541           auto* operand1 = operand(inst, i);
1542           auto* operand2 = operand(inst, (i + 1) % 2);
1543           bool matched =
1544               op1_.Match(operand1, option) && op2_.Match(operand2, option);
1545           DCHECK(matched);
1546         }
1547         return true;
1548       }
1549     }
1550 
1551     auto describe_matcher = [&](int matcher_idx) {
1552       EXPLAIN << "\n - ";
1553       if (matcher_idx == 0) {
1554         op1_.DescribeTo(option.explain_os, /*indent=*/3);
1555       } else {
1556         CHECK_EQ(matcher_idx, 1);
1557         op2_.DescribeTo(option.explain_os, /*indent=*/3);
1558       }
1559       for (int i = 0; i < 2; ++i) {
1560         if (matches[matcher_idx][/*operand*/ i]) {
1561           continue;
1562         }
1563         EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
1564         EXPLAIN << " - ";
1565         EXPLAIN << absl::StrReplaceAll(
1566             explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
1567       }
1568     };
1569 
1570     // If we failed to match, one of the following is true:
1571     //  1. op1 (op2) matches neither LHS nor RHS, or
1572     //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
1573     // We print different explanations depending on which case we're in.
1574 
1575     // Case 1.
1576     bool wrote_explanation = false;
1577     for (int i = 0; !wrote_explanation && i < 2; ++i) {
1578       if (!matches[i][0] && !matches[i][1]) {
1579         EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
1580                 << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
1581         describe_matcher(i);
1582         wrote_explanation = true;
1583       }
1584     }
1585 
1586     // Case 2.
1587     for (int i = 0; !wrote_explanation && i < 2; ++i) {
1588       if (matches[/*matcher*/ 0][/*operand*/ i] &&
1589           matches[/*matcher*/ 1][/*operand*/ i]) {
1590         CHECK(!matches[0][(i + 1) % 2]);
1591         CHECK(!matches[1][(i + 1) % 2]);
1592         CHECK(!wrote_explanation);
1593         EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
1594                 << " operand did not match either of the two matchers.  "
1595                    "Specifically,";
1596         describe_matcher(0);
1597         EXPLAIN << "\nand";
1598         describe_matcher(1);
1599         wrote_explanation = true;
1600       }
1601     }
1602 
1603     CHECK(wrote_explanation);
1604     return false;
1605   }
1606 
1607   HloInstructionPattern<OperandType1, OperandImpl1> op1_;
1608   HloInstructionPattern<OperandType2, OperandImpl2> op2_;
1609 };
1610 
1611 // An HloInstructionPattern implementation that matches only if the instruction
1612 // is a fusion node with a particular kind.
1613 class HloInstructionPatternFusionKindImpl {
1614  public:
1615   explicit constexpr HloInstructionPatternFusionKindImpl(
1616       ::xla::HloInstruction::FusionKind kind)
1617       : kind_(kind) {}
1618 
1619   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1620     return MatchImpl(inst, option);
1621   }
1622 
1623   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1624     return MatchImpl(inst, option);
1625   }
1626 
1627   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1628     *os << "with fusion kind " << ToString(kind_);
1629   }
1630 
1631  private:
1632   template <typename HloInstructionType>
1633   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1634     if (inst->opcode() != HloOpcode::kFusion) {
1635       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1636               << "; it's not a fusion";
1637       return false;
1638     }
1639     if (inst->fusion_kind() != kind_) {
1640       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1641       return false;
1642     }
1643     return true;
1644   }
1645 
1646   ::xla::HloInstruction::FusionKind kind_;
1647 };
1648 
1649 // An HloInstructionPattern implementation that matches only if the instruction
1650 // is a kGetTupleElement with a particular tuple index.
1651 class HloInstructionPatternTupleIndexImpl {
1652  public:
1653   explicit constexpr HloInstructionPatternTupleIndexImpl(int64_t tuple_index)
1654       : tuple_index_(tuple_index) {}
1655 
1656   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1657     return MatchImpl(inst, option);
1658   }
1659 
1660   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1661     return MatchImpl(inst, option);
1662   }
1663 
1664   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1665     *os << "which is a GTE with index " << tuple_index_;
1666   }
1667 
1668  private:
1669   template <typename HloInstructionType>
1670   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1671     if (inst->opcode() != HloOpcode::kGetTupleElement) {
1672       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1673               << "; it's not a GTE at all";
1674       return false;
1675     }
1676     if (inst->tuple_index() != tuple_index_) {
1677       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1678       return false;
1679     }
1680     return true;
1681   }
1682 
1683   int64_t tuple_index_;
1684 };
1685 
1686 class HloInstructionPatternParameterNumImpl {
1687  public:
1688   explicit constexpr HloInstructionPatternParameterNumImpl(
1689       int64_t parameter_num)
1690       : parameter_num_(parameter_num) {}
1691 
1692   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1693     return MatchImpl(inst, option);
1694   }
1695 
1696   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1697     return MatchImpl(inst, option);
1698   }
1699 
1700   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1701     *os << "which is parameter " << parameter_num_;
1702   }
1703 
1704  private:
1705   template <typename HloInstructionType>
1706   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1707     if (inst->opcode() != HloOpcode::kParameter ||
1708         inst->parameter_number() != parameter_num_) {
1709       EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1710       return false;
1711     }
1712     return true;
1713   }
1714 
1715   int64_t parameter_num_;
1716 };
1717 
1718 // Superclass that contains common code used by Op::WithOneUse() and
1719 // Op::WithOneUser().
1720 class HloInstructionPatternOneUseOrUserImpl {
1721  protected:
1722   bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1723     if (inst->user_count() != 1) {
1724       EXPLAIN << "HloInstruction has " << inst->user_count()
1725               << " users, but expected exactly one.";
1726       if (inst->user_count() > 1) {
1727         EXPLAIN << "\nAll users:";
1728         for (const HloInstruction* user : inst->users()) {
1729           EXPLAIN << "\n - " << InstToString(user);
1730         }
1731       }
1732       return false;
1733     }
1734     return true;
1735   }
1736 };
1737 
1738 class HloInstructionPatternOneUseImpl
1739     : public HloInstructionPatternOneUseOrUserImpl {
1740  public:
1741   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1742     if (!MatchOneUser(inst, option)) {
1743       return false;
1744     }
1745 
1746     int64_t use_count = absl::c_count_if(
1747         inst->users()[0]->operands(),
1748         [&](const HloInstruction* operand) { return operand == inst; });
1749     if (use_count != 1) {
1750       EXPLAIN << "HloInstruction is used " << use_count
1751               << " times by its user, but is expected to be used just once: "
1752               << InstToString(inst->users()[0]);
1753       return false;
1754     }
1755     return true;
1756   }
1757 
1758   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1759     *os << "which has exactly one use";
1760   }
1761 };
1762 
1763 class HloInstructionPatternOneUserImpl
1764     : public HloInstructionPatternOneUseOrUserImpl {
1765  public:
1766   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1767     return MatchOneUser(inst, option);
1768   }
1769 
1770   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1771     *os << "which has exactly one user (but possibly is used multiple times by "
1772            "that instruction)";
1773   }
1774 };
1775 
1776 class HloInstructionPatternComparisonDirectionImpl {
1777  public:
1778   explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1779       ComparisonDirection direction)
1780       : direction_(direction) {}
1781 
1782   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1783     return MatchImpl(inst, option);
1784   }
1785 
1786   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1787     return MatchImpl(inst, option);
1788   }
1789 
1790   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1791     *os << "which has comparison direction "
1792         << ComparisonDirectionToString(direction_);
1793   }
1794 
1795  private:
1796   template <typename HloInstructionType>
1797   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1798     if (inst->opcode() != HloOpcode::kCompare ||
1799         inst->comparison_direction() != direction_) {
1800       EXPLAIN << "HloInstruction is not comparison "
1801               << ComparisonDirectionToString(direction_);
1802       return false;
1803     }
1804     return true;
1805   }
1806 
1807   ComparisonDirection direction_;
1808 };
1809 
1810 class HloInstructionPredicateImpl {
1811  public:
1812   explicit HloInstructionPredicateImpl(HloPredicate fn) : fn_(std::move(fn)) {}
1813 
1814   bool Match(const HloInstruction* inst, MatchOption option) const {
1815     bool match = fn_(inst);
1816     if (!match) {
1817       EXPLAIN << "HloInstruction does not match user-specified predicate";
1818     }
1819     return match;
1820   }
1821 
1822   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1823     *os << "which matches a user-specified predicate";
1824   }
1825 
1826  private:
1827   HloPredicate fn_;
1828 };
1829 
1830 // Matches a constant scalar or effective scalar, optionally with a given value.
1831 template <typename ScalarTy>
1832 class HloConstantScalarImpl {
1833  public:
1834   explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1835       : val_(std::nullopt), match_effective_scalar_(match_effective_scalar) {}
1836 
1837   constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1838       : val_(val), match_effective_scalar_(match_effective_scalar) {}
1839 
1840   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1841     return MatchImpl(inst, option);
1842   }
1843 
1844   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1845     return MatchImpl(inst, option);
1846   }
1847 
1848   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1849     *os << "which is a constant "
1850         << (match_effective_scalar_ ? "effective " : "") << "scalar";
1851     if (val_.has_value()) {
1852       *os << " with value " << *val_;
1853     }
1854   }
1855 
1856  private:
1857   template <typename InstTy>
1858   bool MatchImpl(InstTy* inst, MatchOption option) const {
1859     const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1860     if (!const_inst) {
1861       EXPLAIN << "HloInstruction is not a constant";
1862       return false;
1863     }
1864     if (match_effective_scalar_ &&
1865         !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1866       EXPLAIN << "HloInstruction is not an effective scalar";
1867       return false;
1868     }
1869     if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1870       EXPLAIN << "HloInstruction is not a scalar";
1871       return false;
1872     }
1873     if (!val_.has_value()) {
1874       return true;
1875     }
1876 
1877     auto const_inst_scalar_or = const_inst->literal().Reshape({});
1878     if (!const_inst_scalar_or.ok()) {
1879       EXPLAIN << "could not convert matched literal to effective scalar";
1880       return false;
1881     }
1882     Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie();
1883     if (!const_inst_scalar.IsEqualAt({}, *val_)) {
1884       EXPLAIN << "HloInstruction's constant value "
1885               << const_inst_scalar.ToStringWithoutShape()
1886               << " did not match expected value " << *val_;
1887       return false;
1888     }
1889     return true;
1890   }
1891 
1892   std::optional<ScalarTy> val_;
1893   bool match_effective_scalar_;
1894 };
1895 
1896 // A pattern that matches HloInstructions.
1897 template <typename HloInstructionType, typename Impl>
1898 class HloInstructionPattern {
1899  private:
1900   template <typename NewImpl>
1901   auto AppendImpl(NewImpl new_impl) const {
1902     auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl));
1903     return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
1904         std::move(new_allof), matched_inst_);
1905   }
1906 
1907  public:
1908   explicit constexpr HloInstructionPattern(const Impl& impl,
1909                                            HloInstructionType** matched_inst)
1910       : impl_(impl), matched_inst_(matched_inst) {}
1911 
1912   // Returns true and captures the instruction iff it matches the pattern.
1913   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1914     if (impl_.Match(inst, option)) {
1915       if (option.capture && matched_inst_) {
1916         *matched_inst_ = inst;
1917       }
1918       return true;
1919     }
1920     if (inst != nullptr) {
1921       EXPLAIN << "\nin " << InstToString(inst);
1922     }
1923     return false;
1924   }
1925 
1926   // Returns true and captures the instruction iff it matches the pattern.
1927   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1928     if (impl_.Match(inst, option)) {
1929       if (option.capture && matched_inst_) {
1930         *matched_inst_ = inst;
1931       }
1932       return true;
1933     }
1934     EXPLAIN << "\nin " << InstToString(inst);
1935     return false;
1936   }
1937 
1938   // Modifies the pattern to match only if the instruction has the given name.
1939   auto WithName(absl::string_view name) const {
1940     return AppendImpl(HloInstructionPatternNameImpl(name));
1941   }
1942 
1943   // Modifies the pattern to match only if the instruction has the given opcode.
1944   auto WithOpcode(HloOpcode opcode) const {
1945     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
1946   }
1947 
1948   // Modifies the pattern to match only the custom call with a given target.
1949   auto WithCustomCallTarget(absl::string_view custom_call_target) const {
1950     return AppendImpl(HloInstructionCustomCallTargetImpl({custom_call_target}));
1951   }
1952 
1953   // Modifies the pattern to match a custom call with one of the given targets.
1954   auto WithCustomCallTarget(
1955       absl::Span<const absl::string_view> custom_call_targets) const {
1956     return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_targets));
1957   }
1958 
1959   auto WithNumOperands(int64_t num_operands) const {
1960     return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
1961   }
1962 
1963   // Modifies the pattern to match only if the instruction does not have the
1964   // given opcode.
1965   auto WithoutOpcode(HloOpcode opcode) const {
1966     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
1967   }
1968 
1969   constexpr auto Is(const HloInstruction* instr) const {
1970     return AppendImpl(HloInstructionIsImpl(instr));
1971   }
1972 
1973   // Modifies the pattern to match only if the instruction is a constant.
1974   constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); }
1975 
1976   constexpr auto IsConstantScalar() const {
1977     return AppendImpl(
1978         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
1979   }
1980 
1981   // This does not check that T has the same type as the instruction, so e.g.
1982   // IsConstantScalar(1.0) may match a constant of shape int32_t[].
1983   template <typename ScalarTy>
1984   constexpr auto IsConstantScalar(const ScalarTy& val) const {
1985     return AppendImpl(
1986         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
1987   }
1988 
1989   constexpr auto IsConstantEffectiveScalar() const {
1990     return AppendImpl(
1991         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
1992   }
1993 
1994   template <typename ScalarTy>
1995   constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const {
1996     return AppendImpl(
1997         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
1998   }
1999 
2000   // Modifies the pattern to match only if the instruction is not a constant.
2001   constexpr auto IsNonConstant() const {
2002     return WithoutOpcode(HloOpcode::kConstant);
2003   }
2004 
2005   // Modifies the pattern to match only if the instruction has a shape that
2006   // matches the given pattern.
2007   template <typename ShapeType, typename ShapeImpl>
2008   constexpr auto WithShape(
2009       const ShapePattern<ShapeType, ShapeImpl>& shape) const {
2010     return AppendImpl(
2011         HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
2012   }
2013 
2014   // Because we only specify the shape's element type and dims, this is
2015   // effectivley checking shape-compatible-to, not shape-equal-to.  Perhaps this
2016   // function should be called WithShapeCompatibleTo, but the short name is
2017   // nice, and there's no ambiguity because there's no layout in the args!
2018   constexpr auto WithShape(PrimitiveType ty, absl::Span<const int64_t> dims) {
2019     return WithShape(Shape().WithElementType(ty).WithDims(dims));
2020   }
2021 
2022   // Make this a templated function to work around gcc 4.9.4 template infinite
2023   // recursion bug.
2024   template <typename Dummy = void>
2025   constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const {
2026     return WithShape(Shape().EqualTo(shape));
2027   }
2028 
2029   // Make this a templated function to work around gcc 4.9.4 template infinite
2030   // recursion bug.
2031   template <typename Dummy = void>
2032   constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const {
2033     return WithShape(Shape().CompatibleTo(shape));
2034   }
2035 
2036   // Modifies the pattern to match only if the instruction's shape's element
2037   // type matches the given pattern.
2038   constexpr auto WithElementType(PrimitiveType ty) {
2039     return WithShape(Shape().WithElementType(ty));
2040   }
2041 
2042   // Modifies the pattern to match only if the instruction has an operand that
2043   // matches the given pattern.
2044   template <typename OperandType, typename OperandImpl>
2045   constexpr auto WithOperand(
2046       int64_t operand_index,
2047       const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
2048     return AppendImpl(
2049         HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
2050             operand_index, operand));
2051   }
2052 
2053   // Modifies the pattern to match only if
2054   //  - the instruction has fewer than i+1 operands, or
2055   //  - the i'th operand matches the given pattern.
2056   template <typename OperandType, typename OperandImpl>
2057   constexpr auto WithOperandIfPresent(
2058       int64_t operand_index,
2059       const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
2060     return AppendImpl(
2061         HloInstructionPatternOperandIfPresentImpl<OperandType, OperandImpl>(
2062             operand_index, operand));
2063   }
2064 
2065   template <typename OperandType1, typename OperandImpl1, typename OperandType2,
2066             typename OperandImpl2>
2067   constexpr auto WithBinaryOperandsAnyOrder(
2068       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
2069       const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const {
2070     return AppendImpl(
2071         HloInstructionPatternBinaryOperandsAnyOrderImpl<
2072             OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
2073   }
2074 
2075   // Modifies the pattern to match only if the instruction is a fusion node with
2076   // the given kind.
2077   constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const {
2078     return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
2079   }
2080 
2081   // Modifies the pattern to match only if the instruction is a
2082   // get-tuple-element with the given tuple index.
2083   constexpr auto WithTupleIndex(int64_t tuple_index) const {
2084     return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
2085   }
2086 
2087   // Modifies the pattern to match only if the instruction is a parameter
2088   // with the given parameter number.
2089   constexpr auto WithParameterNum(int64_t parameter_num) const {
2090     return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
2091   }
2092 
2093   // Modifies the pattern to match if the instruction is used exactly once.
2094   // Does not match if the instruction is used twice by the same user (e.g.
2095   // multiply(x,x)).
2096   constexpr auto WithOneUse() const {
2097     return AppendImpl(HloInstructionPatternOneUseImpl());
2098   }
2099 
2100   // Modifies the pattern to match if the instruction is used by exactly one
2101   // other instruction.  Will match if the instruction is used twice, so long as
2102   // it's by the same user (e.g.  multiply(x,x)).
2103   constexpr auto WithOneUser() const {
2104     return AppendImpl(HloInstructionPatternOneUserImpl());
2105   }
2106 
2107   // Modifies the pattern to match only if the instruction has the given
2108   // comparison direction.
2109   auto WithComparisonDirection(ComparisonDirection direction) const {
2110     return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
2111   }
2112 
2113   auto WithPredicate(HloPredicate fn) const {
2114     return AppendImpl(HloInstructionPredicateImpl(std::move(fn)));
2115   }
2116 
2117   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
2118     impl_.DescribeTo(os, indent);
2119   }
2120 
2121  private:
2122   Impl impl_;
2123   HloInstructionType** matched_inst_;
2124 };
2125 
2126 }  // namespace detail
2127 
2128 // Creates an instruction pattern that will capture the matched instruction in
2129 // the argument.
2130 inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) {
2131   return detail::HloInstructionPattern<const ::xla::HloInstruction,
2132                                        detail::HloInstructionPatternBaseImpl>(
2133       detail::HloInstructionPatternBaseImpl(), matched_inst);
2134 }
2135 
2136 // Creates an instruction pattern that will capture the matched instruction in
2137 // the argument.
2138 inline constexpr auto Op(::xla::HloInstruction** matched_inst) {
2139   return detail::HloInstructionPattern<::xla::HloInstruction,
2140                                        detail::HloInstructionPatternBaseImpl>(
2141       detail::HloInstructionPatternBaseImpl(), matched_inst);
2142 }
2143 
2144 // Helpers for nullary instructions.
2145 #define XLA_NULLOP_PATTERN(NAME)                                     \
2146   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2147                                                                      \
2148   template <typename HloInstructionType>                             \
2149   inline auto NAME(HloInstructionType** matched_inst) {              \
2150     return Op(matched_inst).WithOpcode(HloOpcode::k##NAME);          \
2151   }
2152 XLA_NULLOP_PATTERN(Constant)
2153 XLA_NULLOP_PATTERN(Parameter)
2154 XLA_NULLOP_PATTERN(Iota)
2155 XLA_NULLOP_PATTERN(Rng)
2156 XLA_NULLOP_PATTERN(PartitionId)
2157 XLA_NULLOP_PATTERN(ReplicaId)
2158 #undef XLA_NULLOP_PATTERN
2159 
2160 // Helpers for unary instructions.
2161 #define XLA_UNOP_PATTERN(NAME)                                       \
2162   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2163                                                                      \
2164   template <typename Arg>                                            \
2165   inline auto NAME(Arg&& arg) {                                      \
2166     return Op()                                                      \
2167         .WithOpcode(HloOpcode::k##NAME)                              \
2168         .WithOperand(0, std::forward<Arg>(arg));                     \
2169   }                                                                  \
2170                                                                      \
2171   template <typename HloInstructionType, typename Arg>               \
2172   inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) {   \
2173     return Op(matched_inst)                                          \
2174         .WithOpcode(HloOpcode::k##NAME)                              \
2175         .WithOperand(0, std::forward<Arg>(arg));                     \
2176   }
2177 XLA_UNOP_PATTERN(Abs)
2178 XLA_UNOP_PATTERN(RoundNearestAfz)
2179 XLA_UNOP_PATTERN(Bitcast)
2180 XLA_UNOP_PATTERN(BitcastConvert)
2181 XLA_UNOP_PATTERN(Broadcast)
2182 XLA_UNOP_PATTERN(Ceil)
2183 XLA_UNOP_PATTERN(Convert)
2184 XLA_UNOP_PATTERN(Copy)
2185 XLA_UNOP_PATTERN(Cos)
2186 XLA_UNOP_PATTERN(AllReduce)
2187 XLA_UNOP_PATTERN(Exp)
2188 XLA_UNOP_PATTERN(Fft)
2189 XLA_UNOP_PATTERN(Floor)
2190 XLA_UNOP_PATTERN(GetTupleElement)
2191 XLA_UNOP_PATTERN(Imag)
2192 XLA_UNOP_PATTERN(Infeed)
2193 XLA_UNOP_PATTERN(IsFinite)
2194 XLA_UNOP_PATTERN(Log)
2195 XLA_UNOP_PATTERN(Not)
2196 XLA_UNOP_PATTERN(Negate)
2197 XLA_UNOP_PATTERN(Real)
2198 XLA_UNOP_PATTERN(Recv)
2199 XLA_UNOP_PATTERN(RecvDone)
2200 XLA_UNOP_PATTERN(ReducePrecision)
2201 XLA_UNOP_PATTERN(Reshape)
2202 XLA_UNOP_PATTERN(Reverse)
2203 XLA_UNOP_PATTERN(Rsqrt)
2204 XLA_UNOP_PATTERN(SendDone)
2205 XLA_UNOP_PATTERN(Sign)
2206 XLA_UNOP_PATTERN(Sin)
2207 XLA_UNOP_PATTERN(Slice)
2208 XLA_UNOP_PATTERN(Sqrt)
2209 XLA_UNOP_PATTERN(Tanh)
2210 XLA_UNOP_PATTERN(Transpose)
2211 #undef XLA_UNOP_PATTERN
2212 
2213 // Helpers for binary instructions.
2214 #define XLA_BINOP_PATTERN(NAME)                                               \
2215   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }          \
2216                                                                               \
2217   template <typename Lhs, typename Rhs>                                       \
2218   inline auto NAME(Lhs&& lhs, Rhs&& rhs) {                                    \
2219     return Op()                                                               \
2220         .WithOpcode(HloOpcode::k##NAME)                                       \
2221         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2222         .WithOperand(1, std::forward<Rhs>(rhs));                              \
2223   }                                                                           \
2224                                                                               \
2225   template <typename HloInstructionType, typename Lhs, typename Rhs>          \
2226   inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2227     return Op(matched_inst)                                                   \
2228         .WithOpcode(HloOpcode::k##NAME)                                       \
2229         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2230         .WithOperand(1, std::forward<Rhs>(rhs));                              \
2231   }
2232 
2233 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME)                                \
2234   XLA_BINOP_PATTERN(NAME)                                                  \
2235                                                                            \
2236   template <typename HloInstructionType, typename Lhs, typename Rhs>       \
2237   inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2238                              Rhs&& rhs) {                                  \
2239     return Op(matched_inst)                                                \
2240         .WithOpcode(HloOpcode::k##NAME)                                    \
2241         .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
2242                                     std::forward<Rhs>(rhs));               \
2243   }                                                                        \
2244   template <typename Lhs, typename Rhs>                                    \
2245   inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) {                       \
2246     return NAME##AnyOrder<const HloInstruction>(                           \
2247         nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
2248   }
2249 XLA_COMMUTATIVE_BINOP_PATTERN(Add)
2250 XLA_BINOP_PATTERN(Atan2)
2251 XLA_BINOP_PATTERN(Divide)
2252 XLA_BINOP_PATTERN(Complex)
2253 XLA_BINOP_PATTERN(Compare)
2254 XLA_BINOP_PATTERN(Convolution)
2255 XLA_BINOP_PATTERN(Dot)
2256 XLA_BINOP_PATTERN(Gather)
2257 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
2258 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
2259 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
2260 XLA_BINOP_PATTERN(Outfeed)
2261 XLA_BINOP_PATTERN(Pad)
2262 XLA_BINOP_PATTERN(Power)
2263 XLA_BINOP_PATTERN(Remainder)
2264 XLA_BINOP_PATTERN(Send)
2265 XLA_BINOP_PATTERN(Subtract)
2266 XLA_COMMUTATIVE_BINOP_PATTERN(And)
2267 XLA_COMMUTATIVE_BINOP_PATTERN(Or)
2268 XLA_BINOP_PATTERN(ShiftLeft)
2269 XLA_BINOP_PATTERN(ShiftRightArithmetic)
2270 XLA_BINOP_PATTERN(ShiftRightLogical)
2271 #undef XLA_COMMUTATIVE_BINOP_PATTERN
2272 #undef XLA_BINOP_PATTERN
2273 
2274 // Helpers for ternary instructions.
2275 #define XLA_TERNOP_PATTERN(NAME)                                       \
2276   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }   \
2277                                                                        \
2278   template <typename Arg0, typename Arg1, typename Arg2>               \
2279   inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) {            \
2280     return Op()                                                        \
2281         .WithOpcode(HloOpcode::k##NAME)                                \
2282         .WithOperand(0, std::forward<Arg0>(arg0))                      \
2283         .WithOperand(1, std::forward<Arg1>(arg1))                      \
2284         .WithOperand(2, std::forward<Arg2>(arg2));                     \
2285   }                                                                    \
2286                                                                        \
2287   template <typename HloInstructionType, typename Arg0, typename Arg1, \
2288             typename Arg2>                                             \
2289   inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0,     \
2290                    Arg1&& arg1, Arg2&& arg2) {                         \
2291     return Op(matched_inst)                                            \
2292         .WithOpcode(HloOpcode::k##NAME)                                \
2293         .WithOperand(0, std::forward<Arg0>(arg0))                      \
2294         .WithOperand(1, std::forward<Arg1>(arg1))                      \
2295         .WithOperand(2, std::forward<Arg2>(arg2));                     \
2296   }
2297 XLA_TERNOP_PATTERN(Clamp);
2298 XLA_TERNOP_PATTERN(Select);
2299 XLA_TERNOP_PATTERN(SelectAndScatter);
2300 #undef XLA_TERNOP_PATTERN
2301 
namespace detail {
// Base case: constrains operand `operand_num` of `m` with the single remaining
// pattern and returns the resulting matcher.
template <typename Matcher, typename FirstArg>
inline auto WithOperands(Matcher&& m, int64_t operand_num,
                         FirstArg&& first_arg) {
  return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
}

// Recursive case: constrains operand `operand_num` with `first_arg`, then
// applies the remaining patterns to the following operand indices, so the
// i'th pattern supplied ends up on operand `operand_num + i`.
template <typename Matcher, typename FirstArg, typename... Args>
inline auto WithOperands(Matcher&& m, int64_t operand_num, FirstArg&& first_arg,
                         Args&&... args) {
  auto with_first =
      m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
  const int64_t next_operand = operand_num + 1;
  return WithOperands(std::move(with_first), next_operand,
                      std::forward<Args>(args)...);
}
}  // namespace detail
2317 
2318 #define XLA_VARIADIC_OP_PATTERN(NAME)                                         \
2319   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }          \
2320                                                                               \
2321   template <typename... Args>                                                 \
2322   inline auto NAME(Args&&... args) {                                          \
2323     return detail::WithOperands(                                              \
2324         Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
2325         /*operand_num=*/0, std::forward<Args>(args)...);                      \
2326   }                                                                           \
2327                                                                               \
2328   template <typename HloInstructionType, typename... Args>                    \
2329   inline auto NAME(HloInstructionType** matched_inst, Args&&... args) {       \
2330     return detail::WithOperands(Op(matched_inst)                              \
2331                                     .WithOpcode(HloOpcode::k##NAME)           \
2332                                     .WithNumOperands(sizeof...(Args)),        \
2333                                 /*operand_num=*/0,                            \
2334                                 std::forward<Args>(args)...);                 \
2335   }
2336 
2337 // We could implement all ops as "variadic" ops, but it would make the
2338 // already-bad compile errors even worse.
2339 XLA_VARIADIC_OP_PATTERN(AfterAll);
2340 XLA_VARIADIC_OP_PATTERN(Concatenate);
2341 XLA_VARIADIC_OP_PATTERN(Conditional);
2342 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
2343 XLA_VARIADIC_OP_PATTERN(DynamicUpdateSlice)
2344 XLA_VARIADIC_OP_PATTERN(Fusion);
2345 XLA_VARIADIC_OP_PATTERN(Map)
2346 XLA_VARIADIC_OP_PATTERN(Reduce);
2347 XLA_VARIADIC_OP_PATTERN(ReduceWindow)
2348 XLA_VARIADIC_OP_PATTERN(Scatter);
2349 XLA_VARIADIC_OP_PATTERN(Sort);
2350 XLA_VARIADIC_OP_PATTERN(Tuple);
2351 XLA_VARIADIC_OP_PATTERN(Call);
2352 
2353 // CustomCall doesn't use the XLA_VARIADIC_OP_PATTERN macro so that you can
2354 // optionally pass a string_view for the custom_call_target before the other
2355 // operands.
2356 inline auto CustomCall() { return Op().WithOpcode(HloOpcode::kCustomCall); }
2357 
2358 template <typename HloInstructionType>
2359 auto CustomCall(HloInstructionType** matched_inst) {
2360   return Op(matched_inst).WithOpcode(HloOpcode::kCustomCall);
2361 }
2362 
2363 template <
2364     typename Arg0, typename... Args,
2365     typename std::enable_if<
2366         !std::is_convertible<Arg0, absl::string_view>::value &&
2367         !std::is_convertible<Arg0, HloInstruction**>::value &&
2368         !std::is_convertible<Arg0, const HloInstruction**>::value>::type* =
2369         nullptr>
2370 auto CustomCall(Arg0&& arg0, Args&&... args) {
2371   return detail::WithOperands(CustomCall().WithNumOperands(sizeof...(Args) + 1),
2372                               /*operand_num=*/0, std::forward<Arg0>(arg0),
2373                               std::forward<Args>(args)...);
2374 }
2375 template <typename... Args>
2376 auto CustomCall(absl::string_view custom_call_target, Args&&... args) {
2377   return CustomCall(std::forward<Args>(args)...)
2378       .WithCustomCallTarget(custom_call_target);
2379 }
2380 
2381 template <typename HloInstructionType, typename Arg0, typename... Args,
2382           typename std::enable_if<!std::is_convertible<
2383               Arg0, absl::string_view>::value>::type* = nullptr>
2384 auto CustomCall(HloInstructionType** matched_inst, Arg0&& arg0,
2385                 Args&&... args) {
2386   return detail::WithOperands(
2387       CustomCall(matched_inst).WithNumOperands(sizeof...(Args) + 1),
2388       /*operand_num=*/0, std::forward<Arg0>(arg0), std::forward<Args>(args)...);
2389 }
2390 template <typename HloInstructionType, typename... Args>
2391 auto CustomCall(HloInstructionType** matched_inst,
2392                 absl::string_view custom_call_target, Args&&... args) {
2393   return CustomCall(matched_inst, std::forward<Args>(args)...)
2394       .WithCustomCallTarget(custom_call_target);
2395 }
2396 
// Helpers for comparison instructions.
//
// XLA_COMPARE_PATTERN(NAME) defines three overloads of match::NAME. Each one
// requires opcode kCompare and comparison direction
// ComparisonDirection::kNAME:
//   NAME()                        no constraint on the operands.
//   NAME(lhs, rhs)                operand 0 must match lhs, operand 1 rhs.
//   NAME(matched_inst, lhs, rhs)  as above, with matched_inst passed to Op().
#define XLA_COMPARE_PATTERN(NAME)                                             \
  inline auto NAME() {                                                        \
    auto compare_op = Op().WithOpcode(HloOpcode::kCompare);                   \
    return compare_op.WithComparisonDirection(ComparisonDirection::k##NAME);  \
  }                                                                           \
                                                                              \
  template <typename Lhs, typename Rhs>                                       \
  inline auto NAME(Lhs&& lhs, Rhs&& rhs) {                                    \
    auto compare_op = Op().WithOpcode(HloOpcode::kCompare);                   \
    auto with_operands = compare_op.WithOperand(0, std::forward<Lhs>(lhs))    \
                             .WithOperand(1, std::forward<Rhs>(rhs));         \
    return with_operands.WithComparisonDirection(                             \
        ComparisonDirection::k##NAME);                                        \
  }                                                                           \
                                                                              \
  template <typename HloInstructionType, typename Lhs, typename Rhs>          \
  inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
    auto compare_op = Op(matched_inst).WithOpcode(HloOpcode::kCompare);       \
    auto with_operands = compare_op.WithOperand(0, std::forward<Lhs>(lhs))    \
                             .WithOperand(1, std::forward<Rhs>(rhs));         \
    return with_operands.WithComparisonDirection(                             \
        ComparisonDirection::k##NAME);                                        \
  }
2422 
// XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) defines everything
// XLA_COMPARE_PATTERN(NAME) does, plus two NAME##AnyOrder overloads (with and
// without `matched_inst`) that check the operand patterns through
// WithBinaryOperandsAnyOrder instead of fixed operand positions.
//
// The AnyOrder overloads also require ComparisonDirection::kNAME, exactly like
// the ordered overloads. Without that constraint, EqAnyOrder(a, b) would match
// any kCompare over a and b (for example an Lt comparison), not just Eq.
#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME)                              \
  XLA_COMPARE_PATTERN(NAME)                                                \
                                                                           \
  template <typename HloInstructionType, typename Lhs, typename Rhs>       \
  inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
                             Rhs&& rhs) {                                  \
    return Op(matched_inst)                                                \
        .WithOpcode(HloOpcode::kCompare)                                   \
        .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
                                    std::forward<Rhs>(rhs))                \
        .WithComparisonDirection(ComparisonDirection::k##NAME);            \
  }                                                                        \
  template <typename Lhs, typename Rhs>                                    \
  inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) {                       \
    return NAME##AnyOrder<const HloInstruction>(                           \
        nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
  }
2439 
2440 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
2441 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
2442 XLA_COMPARE_PATTERN(Ge);
2443 XLA_COMPARE_PATTERN(Gt);
2444 XLA_COMPARE_PATTERN(Le);
2445 XLA_COMPARE_PATTERN(Lt);
2446 
2447 // Helpers for matching non-constant instructions.
2448 inline auto NonConstant() { return Op().IsNonConstant(); }
2449 
// Same as NonConstant(), with `matched_inst` passed through to Op().
template <typename InstT>
inline auto NonConstant(InstT** matched_inst) {
  auto any_op = Op(matched_inst);
  return any_op.IsNonConstant();
}
2454 
2455 // Add overloads for GetTupleElement which take a int64_t specifying which tuple
2456 // element is selected.
2457 template <typename Arg>
2458 inline auto GetTupleElement(Arg&& arg, int64_t tuple_index) {
2459   return Op()
2460       .WithOpcode(HloOpcode::kGetTupleElement)
2461       .WithOperand(0, std::forward<Arg>(arg))
2462       .WithTupleIndex(tuple_index);
2463 }
2464 
2465 template <typename HloInstructionType, typename Arg>
2466 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2467                             int64_t tuple_index) {
2468   return Op(matched_inst)
2469       .WithOpcode(HloOpcode::kGetTupleElement)
2470       .WithOperand(0, std::forward<Arg>(arg))
2471       .WithTupleIndex(tuple_index);
2472 }
2473 
2474 // Add overloads for Parameter which take an int64_t specifying the parameter
2475 // number.
2476 inline auto Parameter(int64_t parameter_num) {
2477   return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2478 }
2479 template <typename HloInstructionType>
2480 inline auto Parameter(HloInstructionType** matched_inst,
2481                       int64_t parameter_num) {
2482   return Op(matched_inst)
2483       .WithOpcode(HloOpcode::kParameter)
2484       .WithParameterNum(parameter_num);
2485 }
2486 
2487 inline auto ConstantScalar() { return Op().IsConstantScalar(); }
2488 
// Same as ConstantScalar(), with `matched_inst` passed through to Op().
template <typename InstT>
inline auto ConstantScalar(InstT** matched_inst) {
  auto any_op = Op(matched_inst);
  return any_op.IsConstantScalar();
}
2493 
2494 template <typename ScalarTy>
2495 inline auto ConstantScalar(ScalarTy val) {
2496   return Op().IsConstantScalar(val);
2497 }
2498 
// Same as ConstantScalar(val), with `matched_inst` passed through to Op().
template <typename InstT, typename ScalarTy>
inline auto ConstantScalar(InstT** matched_inst, ScalarTy val) {
  auto any_op = Op(matched_inst);
  return any_op.IsConstantScalar(val);
}
2503 
2504 inline auto ConstantEffectiveScalar() {
2505   return Op().IsConstantEffectiveScalar();
2506 }
2507 
// Same as ConstantEffectiveScalar(), with `matched_inst` passed through to
// Op().
template <typename InstT>
inline auto ConstantEffectiveScalar(InstT** matched_inst) {
  auto any_op = Op(matched_inst);
  return any_op.IsConstantEffectiveScalar();
}
2512 
2513 template <typename ScalarTy>
2514 inline auto ConstantEffectiveScalar(ScalarTy val) {
2515   return Op().IsConstantEffectiveScalar(val);
2516 }
2517 
// Same as ConstantEffectiveScalar(val), with `matched_inst` passed through to
// Op().
template <typename InstT, typename ScalarTy>
inline auto ConstantEffectiveScalar(InstT** matched_inst, ScalarTy val) {
  auto any_op = Op(matched_inst);
  return any_op.IsConstantEffectiveScalar(val);
}
2523 
2524 }  // namespace match
2525 
2526 }  // namespace xla
2527 
2528 #undef EXPLAIN
2529 #pragma pop_macro("EXPLAIN")
2530 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
2531