xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_matchers.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
18 
19 #include <optional>
20 #include <string>
21 #include <utility>
22 
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/test.h"
26 
27 namespace xla {
28 namespace testing {
29 
30 class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
31  public:
HloMatcher(HloOpcode opcode,std::vector<::testing::Matcher<const HloInstruction * >> operands)32   HloMatcher(HloOpcode opcode,
33              std::vector<::testing::Matcher<const HloInstruction*>> operands)
34       : opcode_(opcode), operands_(operands) {}
35 
36   bool MatchAndExplain(const HloInstruction* instruction,
37                        ::testing::MatchResultListener* listener) const override;
38 
39   void DescribeTo(::std::ostream* os) const override;
40 
41  private:
42   HloOpcode opcode_;
43   std::vector<::testing::Matcher<const HloInstruction*>> operands_;
44 };
45 
46 // Custom matcher for parameters, which accepts a parameter number.
47 class HloParameterMatcher : public HloMatcher {
48  public:
HloParameterMatcher(int64_t parameter_number)49   explicit HloParameterMatcher(int64_t parameter_number)
50       : HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
51         parameter_number_(parameter_number) {}
52 
53   bool MatchAndExplain(const HloInstruction* instruction,
54                        ::testing::MatchResultListener* listener) const override;
55 
56  private:
57   int64_t parameter_number_;
58 };
59 
60 // Custom matcher for comparisons, which accepts a comparison direction.
61 class HloComparisonMatcher : public HloMatcher {
62  public:
HloComparisonMatcher(ComparisonDirection direction,std::vector<::testing::Matcher<const HloInstruction * >> operands)63   explicit HloComparisonMatcher(
64       ComparisonDirection direction,
65       std::vector<::testing::Matcher<const HloInstruction*>> operands)
66       : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
67 
68   bool MatchAndExplain(const HloInstruction* instruction,
69                        ::testing::MatchResultListener* listener) const override;
70 
71  private:
72   ComparisonDirection direction_;
73 };
74 
75 // Custom matcher for get-tuple-element instructions, which accepts a tuple
76 // index to match.
77 class HloGetTupleElementMatcher : public HloMatcher {
78  public:
HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction * > operand,int64_t tuple_index)79   HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
80                             int64_t tuple_index)
81       : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
82         tuple_index_(tuple_index) {}
83 
84   bool MatchAndExplain(const HloInstruction* instruction,
85                        ::testing::MatchResultListener* listener) const override;
86 
87  private:
88   int64_t tuple_index_;
89 };
90 
91 // Custom matcher for custom-call instructions, which accepts a matcher for its
92 // call target.
93 class HloCustomCallMatcher : public HloMatcher {
94  public:
HloCustomCallMatcher(::testing::Matcher<std::string> call_target_matcher,std::vector<::testing::Matcher<const HloInstruction * >> operands)95   HloCustomCallMatcher(
96       ::testing::Matcher<std::string> call_target_matcher,
97       std::vector<::testing::Matcher<const HloInstruction*>> operands)
98       : HloMatcher(HloOpcode::kCustomCall, operands),
99         call_target_matcher_(call_target_matcher) {}
100 
101   bool MatchAndExplain(const HloInstruction* instruction,
102                        ::testing::MatchResultListener* listener) const override;
103   void DescribeTo(std::ostream* os) const override;
104 
105  private:
106   ::testing::Matcher<std::string> call_target_matcher_;
107 };
108 
109 class HloShapeMatcher
110     : public ::testing::MatcherInterface<const HloInstruction*> {
111  public:
HloShapeMatcher(const Shape & shape)112   explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {}
113 
114   bool MatchAndExplain(const HloInstruction* instruction,
115                        ::testing::MatchResultListener* listener) const override;
116   void DescribeTo(std::ostream* os) const override;
117 
118  private:
119   Shape shape_;
120 };
121 
122 class HloShapeAndLayoutMatcher
123     : public ::testing::MatcherInterface<const HloInstruction*> {
124  public:
125   explicit HloShapeAndLayoutMatcher(const Shape& shape,
126                                     bool minor_to_major_only = false)
shape_(shape)127       : shape_(shape), minor_to_major_only_(minor_to_major_only) {}
128 
129   bool MatchAndExplain(const HloInstruction* instruction,
130                        ::testing::MatchResultListener* listener) const override;
131   void DescribeTo(std::ostream* os) const override;
132 
133  private:
134   Shape shape_;
135   bool minor_to_major_only_;
136 };
137 
138 // Verify the sharding of an instruction against the provided HloSharding. If a
139 // nullopt is provided for the expected sharding then it checks that no sharding
140 // is present for an instruction.
141 class HloShardingMatcher
142     : public ::testing::MatcherInterface<const HloInstruction*> {
143  public:
HloShardingMatcher(const std::optional<HloSharding> & sharding)144   explicit HloShardingMatcher(const std::optional<HloSharding>& sharding)
145       : sharding_(sharding) {}
146 
147   bool MatchAndExplain(const HloInstruction* instruction,
148                        ::testing::MatchResultListener* listener) const override;
149   void DescribeTo(std::ostream* os) const override;
150 
151  private:
152   std::optional<HloSharding> sharding_;
153 };
154 
155 // Matches a Dot HLO instruction with specific LHS and RHS contracting
156 // dimensions.
157 class HloDotWithContractingDimsMatcher : public HloMatcher {
158  public:
HloDotWithContractingDimsMatcher(::testing::Matcher<const HloInstruction * > lhs,::testing::Matcher<const HloInstruction * > rhs,int64_t lhs_contracting_dim,int64_t rhs_contracting_dim)159   explicit HloDotWithContractingDimsMatcher(
160       ::testing::Matcher<const HloInstruction*> lhs,
161       ::testing::Matcher<const HloInstruction*> rhs,
162       int64_t lhs_contracting_dim, int64_t rhs_contracting_dim)
163       : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}),
164         lhs_contracting_dim_(lhs_contracting_dim),
165         rhs_contracting_dim_(rhs_contracting_dim) {}
166 
167   bool MatchAndExplain(const HloInstruction* instruction,
168                        ::testing::MatchResultListener* listener) const override;
169   void DescribeTo(std::ostream* os) const override;
170 
171  private:
172   int64_t lhs_contracting_dim_;
173   int64_t rhs_contracting_dim_;
174 };
175 
176 // Custom matcher for asynchronous copy (CopyStart/CopyDone pair) with specified
177 // source and destination memory spaces.
178 class HloAsyncCopyMatcher : public HloMatcher {
179  public:
HloAsyncCopyMatcher(int64_t to_space,int64_t from_space,::testing::Matcher<const HloInstruction * > operand)180   HloAsyncCopyMatcher(int64_t to_space, int64_t from_space,
181                       ::testing::Matcher<const HloInstruction*> operand)
182       : HloMatcher(HloOpcode::kCopyDone,
183                    {::testing::MakeMatcher(
184                        new HloMatcher(HloOpcode::kCopyStart, {operand}))}),
185         to_space_(to_space),
186         from_space_(from_space) {}
187 
188   bool MatchAndExplain(const HloInstruction* instruction,
189                        ::testing::MatchResultListener* listener) const override;
190   void DescribeTo(std::ostream* os) const override;
191 
192  private:
193   int64_t to_space_;
194   int64_t from_space_;
195 };
196 
197 class HloConstantMatcher : public HloMatcher {
198  public:
HloConstantMatcher(Literal literal)199   explicit HloConstantMatcher(Literal literal)
200       : HloMatcher(HloOpcode::kConstant, /*operands=*/{}),
201         literal_(std::move(literal)) {}
202   bool MatchAndExplain(const HloInstruction* instruction,
203                        ::testing::MatchResultListener* listener) const override;
204   void DescribeTo(std::ostream* os) const override;
205 
206  private:
207   Literal literal_;
208 };
209 
210 class HloReplicaGroupsMatcher
211     : public ::testing::MatcherInterface<const HloInstruction*> {
212  public:
HloReplicaGroupsMatcher(std::vector<std::vector<int64_t>> replica_groups)213   explicit HloReplicaGroupsMatcher(
214       std::vector<std::vector<int64_t>> replica_groups)
215       : replica_groups_(std::move(replica_groups)) {}
216 
217   bool MatchAndExplain(const HloInstruction* instruction,
218                        ::testing::MatchResultListener* listener) const override;
219   void DescribeTo(std::ostream* os) const override;
220 
221  private:
222   std::vector<std::vector<int64_t>> replica_groups_;
223 };
224 
// HloInstruction* matchers for opcode and operands. Example:
//   namespace op = xla::testing::opcode_matchers;
//   EXPECT_THAT(instruction,
//               op::Add(op::Reshape(), op::Add(op::Reshape(), _)));
229 namespace opcode_matchers {
230 #define HLO_MATCHER(opcode)                                                \
231   template <typename... M>                                                 \
232   ::testing::Matcher<const ::xla::HloInstruction*> opcode(M... operands) { \
233     return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(          \
234         ::xla::HloOpcode::k##opcode, {operands...}));                      \
235   }
236 HLO_MATCHER(Abs);
237 HLO_MATCHER(Add);
238 HLO_MATCHER(AddDependency);
239 HLO_MATCHER(AfterAll);
240 HLO_MATCHER(AsyncStart);
241 HLO_MATCHER(AsyncUpdate);
242 HLO_MATCHER(AsyncDone);
243 HLO_MATCHER(AllGather);
244 HLO_MATCHER(AllReduce);
245 HLO_MATCHER(AllToAll);
246 HLO_MATCHER(And);
247 HLO_MATCHER(BatchNormGrad);
248 HLO_MATCHER(Bitcast);
249 HLO_MATCHER(BitcastConvert);
250 HLO_MATCHER(Broadcast);
251 HLO_MATCHER(Call);
252 HLO_MATCHER(Ceil);
253 HLO_MATCHER(Clamp);
254 HLO_MATCHER(CollectivePermute);
255 HLO_MATCHER(CollectivePermuteStart);
256 HLO_MATCHER(CollectivePermuteDone);
257 HLO_MATCHER(Compare);
258 HLO_MATCHER(Concatenate);
259 HLO_MATCHER(Conditional);
260 HLO_MATCHER(Convert);
261 HLO_MATCHER(Convolution);
262 HLO_MATCHER(Copy);
263 HLO_MATCHER(CopyDone);
264 HLO_MATCHER(CopyStart);
265 HLO_MATCHER(Divide);
266 HLO_MATCHER(Domain);
267 HLO_MATCHER(DynamicSlice);
268 HLO_MATCHER(DynamicUpdateSlice);
269 HLO_MATCHER(Exp);
270 HLO_MATCHER(Fft);
271 HLO_MATCHER(Floor);
272 HLO_MATCHER(Fusion);
273 HLO_MATCHER(Gather);
274 HLO_MATCHER(GetDimensionSize);
275 HLO_MATCHER(Infeed);
276 HLO_MATCHER(Iota);
277 HLO_MATCHER(IsFinite);
278 HLO_MATCHER(Log);
279 HLO_MATCHER(Map);
280 HLO_MATCHER(Maximum);
281 HLO_MATCHER(Minimum);
282 HLO_MATCHER(Multiply);
283 HLO_MATCHER(Negate);
284 HLO_MATCHER(Not);
285 HLO_MATCHER(Or);
286 HLO_MATCHER(Outfeed);
287 HLO_MATCHER(Pad);
288 HLO_MATCHER(PartitionId);
289 HLO_MATCHER(Power);
290 HLO_MATCHER(Recv);
291 HLO_MATCHER(RecvDone);
292 HLO_MATCHER(Reduce);
293 HLO_MATCHER(ReducePrecision);
294 HLO_MATCHER(ReduceScatter);
295 HLO_MATCHER(ReduceWindow);
296 HLO_MATCHER(Remainder);
297 HLO_MATCHER(ReplicaId);
298 HLO_MATCHER(Reshape);
299 HLO_MATCHER(Reverse);
300 HLO_MATCHER(Rng);
301 HLO_MATCHER(RngBitGenerator);
302 HLO_MATCHER(RngGetAndUpdateState);
303 HLO_MATCHER(Scatter);
304 HLO_MATCHER(Select);
305 HLO_MATCHER(SelectAndScatter);
306 HLO_MATCHER(Send);
307 HLO_MATCHER(SendDone);
308 HLO_MATCHER(SetDimensionSize);
309 HLO_MATCHER(ShiftLeft);
310 HLO_MATCHER(ShiftRightArithmetic);
311 HLO_MATCHER(ShiftRightLogical);
312 HLO_MATCHER(Sign);
313 HLO_MATCHER(Slice);
314 HLO_MATCHER(Sort);
315 HLO_MATCHER(Subtract);
316 HLO_MATCHER(Tanh);
317 HLO_MATCHER(Transpose);
318 HLO_MATCHER(Tuple);
319 HLO_MATCHER(While);
320 HLO_MATCHER(Xor);
321 HLO_MATCHER(OptimizationBarrier);
322 
323 #define HLO_MATCHER_VECTOR_OPERANDS(opcode)                              \
324   template <>                                                            \
325   inline ::testing::Matcher<const ::xla::HloInstruction*> opcode(        \
326       std::vector<::testing::Matcher<const HloInstruction*>> operands) { \
327     return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(        \
328         ::xla::HloOpcode::k##opcode, operands));                         \
329   }
330 
331 HLO_MATCHER_VECTOR_OPERANDS(DynamicSlice);
332 
333 // The special cases below let you check additional information about the
334 // HloInstruction, beyond just its opcode and operands.  In all cases you can
335 // still use the generic matcher which doesn't check this info.
336 //
337 // Feel free to add additional custom matchers below.
338 
339 //  - Parameter(N) matches parameter number N.
340 //  - Parameter() matches any parameter.
Parameter(int64_t parameter_number)341 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
342     int64_t parameter_number) {
343   return ::testing::MakeMatcher(
344       new ::xla::testing::HloParameterMatcher(parameter_number));
345 }
Parameter()346 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
347   return ::testing::MakeMatcher(
348       new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
349 }
350 
351 // Comparison matchers below do not require any additional arguments.
352 template <typename... M>
Eq(M...operands)353 inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
354   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
355       ComparisonDirection::kEq, {operands...}));
356 }
357 template <typename... M>
Ne(M...operands)358 inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
359   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
360       ComparisonDirection::kNe, {operands...}));
361 }
362 template <typename... M>
Ge(M...operands)363 inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
364   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
365       ComparisonDirection::kGe, {operands...}));
366 }
367 template <typename... M>
Gt(M...operands)368 inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
369   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
370       ComparisonDirection::kGt, {operands...}));
371 }
372 template <typename... M>
Le(M...operands)373 inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
374   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
375       ComparisonDirection::kLe, {operands...}));
376 }
377 template <typename... M>
Lt(M...operands)378 inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
379   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
380       ComparisonDirection::kLt, {operands...}));
381 }
382 
383 // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
384 // tuple element of operand, while GetTupleElement(operand) matches any GTE
385 // operation on operand, and GetTupleElement() matches any GTE operation at all.
GetTupleElement(::testing::Matcher<const HloInstruction * > operand,int64_t tuple_index)386 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
387     ::testing::Matcher<const HloInstruction*> operand, int64_t tuple_index) {
388   return ::testing::MakeMatcher(
389       new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
390 }
GetTupleElement(::testing::Matcher<const HloInstruction * > operand)391 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
392     ::testing::Matcher<const HloInstruction*> operand) {
393   return ::testing::MakeMatcher(
394       new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
395 }
GetTupleElement()396 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
397   return ::testing::MakeMatcher(
398       new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
399 }
400 
401 // - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
402 //   target T and the given operands.
403 //
404 // - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
405 //   given operands.
406 //
407 // - CustomCall() matches any CustomCall HLO at all.
408 template <typename... M>
CustomCall(::testing::Matcher<std::string> call_target_matcher,M...operands)409 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
410     ::testing::Matcher<std::string> call_target_matcher, M... operands) {
411   return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
412       call_target_matcher, {operands...}));
413 }
414 // This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
415 // ::testing::Matcher<std::string>.  In that case, we want to prefer the
416 // overload above.
417 template <
418     typename FirstM, typename... M,
419     typename Dummy = typename std::enable_if<
420         !std::is_convertible<FirstM, ::testing::Matcher<std::string>>::value,
421         void>::type*>
CustomCall(FirstM operands_first,M...operands_rest)422 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
423     FirstM operands_first, M... operands_rest) {
424   return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
425       HloOpcode::kCustomCall, {operands_first, operands_rest...}));
426 }
CustomCall()427 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
428   return ::testing::MakeMatcher(
429       new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
430 }
431 
432 // Verifies the shape or the shape and the layout of an HLO instruction against
433 // the provided shape object.
Shape(const class Shape & shape)434 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
435     const class Shape& shape) {
436   return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
437 }
Shape(absl::string_view shape)438 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
439     absl::string_view shape) {
440   return ::testing::MakeMatcher(
441       new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie()));
442 }
ShapeWithLayout(const class Shape & shape)443 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
444     const class Shape& shape) {
445   return ::testing::MakeMatcher(
446       new ::xla::testing::HloShapeAndLayoutMatcher(shape));
447 }
448 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
449     absl::string_view shape, bool minor_to_major_only = false) {
450   return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
451       ParseShape(shape).ValueOrDie(), minor_to_major_only));
452 }
453 
454 // Verifies the value of the HloSharing against the provided sharding object.
Sharding(const HloSharding & sharding)455 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
456     const HloSharding& sharding) {
457   return ::testing::MakeMatcher(
458       new ::xla::testing::HloShardingMatcher(sharding));
459 }
460 // Matcher for Sharding from sharding string
Sharding(absl::string_view sharding)461 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
462     absl::string_view sharding) {
463   return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
464       ParseSharding(sharding).ValueOrDie()));
465 }
466 // Verifies that no HloSharding is set for an HLO instruction.
NoSharding()467 inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
468   return ::testing::MakeMatcher(
469       new ::xla::testing::HloShardingMatcher(std::nullopt));
470 }
471 
Dot()472 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot() {
473   return ::testing::MakeMatcher(
474       new ::xla::testing::HloMatcher(::xla::HloOpcode::kDot, {}));
475 }
476 
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher)477 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
478     ::testing::Matcher<const HloInstruction*> lhs_matcher,
479     ::testing::Matcher<const HloInstruction*> rhs_matcher) {
480   return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
481       ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher}));
482 }
483 
484 // Matches a Dot HLO instruction if it has exactly one lhs contracting dimension
485 // equal to `lhs_contracting_dim` and exactly one rhs contracting dimension
486 // equal to `rhs_contracting_dim`.
487 //
488 // Currently the HLO verifier rejects Dot operations with more than one
489 // contracting dimension (even though we can represent these in the
490 // DotDimensionNumbers proto) so there is no need to generalize this to support
491 // multiple contracting dimensions.
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher,int64_t lhs_contracting_dim,int64_t rhs_contracting_dim)492 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
493     ::testing::Matcher<const HloInstruction*> lhs_matcher,
494     ::testing::Matcher<const HloInstruction*> rhs_matcher,
495     int64_t lhs_contracting_dim, int64_t rhs_contracting_dim) {
496   return ::testing::MakeMatcher(
497       new ::xla::testing::HloDotWithContractingDimsMatcher(
498           lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim));
499 }
500 
501 // Matcher for asynchronous copies from one memory space to another. Implies
502 // CopyDone(CopyStart(...)) where from_space and to_space is the source and
503 // destination memory spaces, respectively.
AsyncCopy(int64_t to_space,int64_t from_space,::testing::Matcher<const HloInstruction * > operand_matcher)504 inline ::testing::Matcher<const ::xla::HloInstruction*> AsyncCopy(
505     int64_t to_space, int64_t from_space,
506     ::testing::Matcher<const HloInstruction*> operand_matcher) {
507   return ::testing::MakeMatcher(new ::xla::testing::HloAsyncCopyMatcher(
508       to_space, from_space, operand_matcher));
509 }
510 
511 //  - Constant() matches any constant.
512 //  - Constant(V) matches a constant with the given value.
Constant()513 inline ::testing::Matcher<const ::xla::HloInstruction*> Constant() {
514   return ::testing::MakeMatcher(
515       new ::xla::testing::HloMatcher(HloOpcode::kConstant, {}));
516 }
Constant(Literal value)517 inline ::testing::Matcher<const ::xla::HloInstruction*> Constant(
518     Literal value) {
519   return ::testing::MakeMatcher(
520       new ::xla::testing::HloConstantMatcher(std::move(value)));
521 }
522 
ReplicaGroups(std::vector<std::vector<int64_t>> replica_groups)523 inline ::testing::Matcher<const ::xla::HloInstruction*> ReplicaGroups(
524     std::vector<std::vector<int64_t>> replica_groups) {
525   return ::testing::MakeMatcher(
526       new ::xla::testing::HloReplicaGroupsMatcher(std::move(replica_groups)));
527 }
528 
529 #undef HLO_MATCHER
530 }  // namespace opcode_matchers
531 
532 // Helper to convert smart to raw pointers for matching.
533 template <typename Container>
Pointers(const Container & container)534 std::vector<const HloInstruction*> Pointers(const Container& container) {
535   std::vector<const HloInstruction*> result;
536   result.reserve(container.size());
537   for (const auto& entry : container) result.push_back(entry.get());
538   return result;
539 }
540 
541 }  // namespace testing
542 
543 // Tell GMock to print HloInstruction* by value, so error messages are nice.
544 // Has to be in the same namespace as 'HloInstruction'.
545 void PrintTo(const HloInstruction* inst, ::std::ostream* os);
546 
547 }  // namespace xla
548 
549 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
550