1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
18
19 #include <optional>
20 #include <string>
21 #include <utility>
22
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/test.h"
26
27 namespace xla {
28 namespace testing {
29
30 class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
31 public:
HloMatcher(HloOpcode opcode,std::vector<::testing::Matcher<const HloInstruction * >> operands)32 HloMatcher(HloOpcode opcode,
33 std::vector<::testing::Matcher<const HloInstruction*>> operands)
34 : opcode_(opcode), operands_(operands) {}
35
36 bool MatchAndExplain(const HloInstruction* instruction,
37 ::testing::MatchResultListener* listener) const override;
38
39 void DescribeTo(::std::ostream* os) const override;
40
41 private:
42 HloOpcode opcode_;
43 std::vector<::testing::Matcher<const HloInstruction*>> operands_;
44 };
45
46 // Custom matcher for parameters, which accepts a parameter number.
47 class HloParameterMatcher : public HloMatcher {
48 public:
HloParameterMatcher(int64_t parameter_number)49 explicit HloParameterMatcher(int64_t parameter_number)
50 : HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
51 parameter_number_(parameter_number) {}
52
53 bool MatchAndExplain(const HloInstruction* instruction,
54 ::testing::MatchResultListener* listener) const override;
55
56 private:
57 int64_t parameter_number_;
58 };
59
60 // Custom matcher for comparisons, which accepts a comparison direction.
61 class HloComparisonMatcher : public HloMatcher {
62 public:
HloComparisonMatcher(ComparisonDirection direction,std::vector<::testing::Matcher<const HloInstruction * >> operands)63 explicit HloComparisonMatcher(
64 ComparisonDirection direction,
65 std::vector<::testing::Matcher<const HloInstruction*>> operands)
66 : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
67
68 bool MatchAndExplain(const HloInstruction* instruction,
69 ::testing::MatchResultListener* listener) const override;
70
71 private:
72 ComparisonDirection direction_;
73 };
74
75 // Custom matcher for get-tuple-element instructions, which accepts a tuple
76 // index to match.
77 class HloGetTupleElementMatcher : public HloMatcher {
78 public:
HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction * > operand,int64_t tuple_index)79 HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
80 int64_t tuple_index)
81 : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
82 tuple_index_(tuple_index) {}
83
84 bool MatchAndExplain(const HloInstruction* instruction,
85 ::testing::MatchResultListener* listener) const override;
86
87 private:
88 int64_t tuple_index_;
89 };
90
91 // Custom matcher for custom-call instructions, which accepts a matcher for its
92 // call target.
93 class HloCustomCallMatcher : public HloMatcher {
94 public:
HloCustomCallMatcher(::testing::Matcher<std::string> call_target_matcher,std::vector<::testing::Matcher<const HloInstruction * >> operands)95 HloCustomCallMatcher(
96 ::testing::Matcher<std::string> call_target_matcher,
97 std::vector<::testing::Matcher<const HloInstruction*>> operands)
98 : HloMatcher(HloOpcode::kCustomCall, operands),
99 call_target_matcher_(call_target_matcher) {}
100
101 bool MatchAndExplain(const HloInstruction* instruction,
102 ::testing::MatchResultListener* listener) const override;
103 void DescribeTo(std::ostream* os) const override;
104
105 private:
106 ::testing::Matcher<std::string> call_target_matcher_;
107 };
108
109 class HloShapeMatcher
110 : public ::testing::MatcherInterface<const HloInstruction*> {
111 public:
HloShapeMatcher(const Shape & shape)112 explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {}
113
114 bool MatchAndExplain(const HloInstruction* instruction,
115 ::testing::MatchResultListener* listener) const override;
116 void DescribeTo(std::ostream* os) const override;
117
118 private:
119 Shape shape_;
120 };
121
122 class HloShapeAndLayoutMatcher
123 : public ::testing::MatcherInterface<const HloInstruction*> {
124 public:
125 explicit HloShapeAndLayoutMatcher(const Shape& shape,
126 bool minor_to_major_only = false)
shape_(shape)127 : shape_(shape), minor_to_major_only_(minor_to_major_only) {}
128
129 bool MatchAndExplain(const HloInstruction* instruction,
130 ::testing::MatchResultListener* listener) const override;
131 void DescribeTo(std::ostream* os) const override;
132
133 private:
134 Shape shape_;
135 bool minor_to_major_only_;
136 };
137
138 // Verify the sharding of an instruction against the provided HloSharding. If a
139 // nullopt is provided for the expected sharding then it checks that no sharding
140 // is present for an instruction.
141 class HloShardingMatcher
142 : public ::testing::MatcherInterface<const HloInstruction*> {
143 public:
HloShardingMatcher(const std::optional<HloSharding> & sharding)144 explicit HloShardingMatcher(const std::optional<HloSharding>& sharding)
145 : sharding_(sharding) {}
146
147 bool MatchAndExplain(const HloInstruction* instruction,
148 ::testing::MatchResultListener* listener) const override;
149 void DescribeTo(std::ostream* os) const override;
150
151 private:
152 std::optional<HloSharding> sharding_;
153 };
154
155 // Matches a Dot HLO instruction with specific LHS and RHS contracting
156 // dimensions.
157 class HloDotWithContractingDimsMatcher : public HloMatcher {
158 public:
HloDotWithContractingDimsMatcher(::testing::Matcher<const HloInstruction * > lhs,::testing::Matcher<const HloInstruction * > rhs,int64_t lhs_contracting_dim,int64_t rhs_contracting_dim)159 explicit HloDotWithContractingDimsMatcher(
160 ::testing::Matcher<const HloInstruction*> lhs,
161 ::testing::Matcher<const HloInstruction*> rhs,
162 int64_t lhs_contracting_dim, int64_t rhs_contracting_dim)
163 : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}),
164 lhs_contracting_dim_(lhs_contracting_dim),
165 rhs_contracting_dim_(rhs_contracting_dim) {}
166
167 bool MatchAndExplain(const HloInstruction* instruction,
168 ::testing::MatchResultListener* listener) const override;
169 void DescribeTo(std::ostream* os) const override;
170
171 private:
172 int64_t lhs_contracting_dim_;
173 int64_t rhs_contracting_dim_;
174 };
175
176 // Custom matcher for asynchronous copy (CopyStart/CopyDone pair) with specified
177 // source and destination memory spaces.
178 class HloAsyncCopyMatcher : public HloMatcher {
179 public:
HloAsyncCopyMatcher(int64_t to_space,int64_t from_space,::testing::Matcher<const HloInstruction * > operand)180 HloAsyncCopyMatcher(int64_t to_space, int64_t from_space,
181 ::testing::Matcher<const HloInstruction*> operand)
182 : HloMatcher(HloOpcode::kCopyDone,
183 {::testing::MakeMatcher(
184 new HloMatcher(HloOpcode::kCopyStart, {operand}))}),
185 to_space_(to_space),
186 from_space_(from_space) {}
187
188 bool MatchAndExplain(const HloInstruction* instruction,
189 ::testing::MatchResultListener* listener) const override;
190 void DescribeTo(std::ostream* os) const override;
191
192 private:
193 int64_t to_space_;
194 int64_t from_space_;
195 };
196
197 class HloConstantMatcher : public HloMatcher {
198 public:
HloConstantMatcher(Literal literal)199 explicit HloConstantMatcher(Literal literal)
200 : HloMatcher(HloOpcode::kConstant, /*operands=*/{}),
201 literal_(std::move(literal)) {}
202 bool MatchAndExplain(const HloInstruction* instruction,
203 ::testing::MatchResultListener* listener) const override;
204 void DescribeTo(std::ostream* os) const override;
205
206 private:
207 Literal literal_;
208 };
209
210 class HloReplicaGroupsMatcher
211 : public ::testing::MatcherInterface<const HloInstruction*> {
212 public:
HloReplicaGroupsMatcher(std::vector<std::vector<int64_t>> replica_groups)213 explicit HloReplicaGroupsMatcher(
214 std::vector<std::vector<int64_t>> replica_groups)
215 : replica_groups_(std::move(replica_groups)) {}
216
217 bool MatchAndExplain(const HloInstruction* instruction,
218 ::testing::MatchResultListener* listener) const override;
219 void DescribeTo(std::ostream* os) const override;
220
221 private:
222 std::vector<std::vector<int64_t>> replica_groups_;
223 };
224
// HloInstruction* matchers for opcode and operands. Example:
//   namespace op = xla::testing::opcode_matchers;
//   EXPECT_THAT(instruction,
//               op::Add(op::Reshape(), op::Add(op::Reshape(), _)));
229 namespace opcode_matchers {
// Defines opcode_matchers::<opcode_name>(operand matchers...), which builds an
// HloMatcher for HloOpcode::k<opcode_name> with the given operand matchers.
#define HLO_MATCHER(opcode_name)                                            \
  template <typename... M>                                                  \
  ::testing::Matcher<const ::xla::HloInstruction*> opcode_name(             \
      M... operand_matchers) {                                              \
    return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(           \
        ::xla::HloOpcode::k##opcode_name, {operand_matchers...}));          \
  }
236 HLO_MATCHER(Abs);
237 HLO_MATCHER(Add);
238 HLO_MATCHER(AddDependency);
239 HLO_MATCHER(AfterAll);
240 HLO_MATCHER(AsyncStart);
241 HLO_MATCHER(AsyncUpdate);
242 HLO_MATCHER(AsyncDone);
243 HLO_MATCHER(AllGather);
244 HLO_MATCHER(AllReduce);
245 HLO_MATCHER(AllToAll);
246 HLO_MATCHER(And);
247 HLO_MATCHER(BatchNormGrad);
248 HLO_MATCHER(Bitcast);
249 HLO_MATCHER(BitcastConvert);
250 HLO_MATCHER(Broadcast);
251 HLO_MATCHER(Call);
252 HLO_MATCHER(Ceil);
253 HLO_MATCHER(Clamp);
254 HLO_MATCHER(CollectivePermute);
255 HLO_MATCHER(CollectivePermuteStart);
256 HLO_MATCHER(CollectivePermuteDone);
257 HLO_MATCHER(Compare);
258 HLO_MATCHER(Concatenate);
259 HLO_MATCHER(Conditional);
260 HLO_MATCHER(Convert);
261 HLO_MATCHER(Convolution);
262 HLO_MATCHER(Copy);
263 HLO_MATCHER(CopyDone);
264 HLO_MATCHER(CopyStart);
265 HLO_MATCHER(Divide);
266 HLO_MATCHER(Domain);
267 HLO_MATCHER(DynamicSlice);
268 HLO_MATCHER(DynamicUpdateSlice);
269 HLO_MATCHER(Exp);
270 HLO_MATCHER(Fft);
271 HLO_MATCHER(Floor);
272 HLO_MATCHER(Fusion);
273 HLO_MATCHER(Gather);
274 HLO_MATCHER(GetDimensionSize);
275 HLO_MATCHER(Infeed);
276 HLO_MATCHER(Iota);
277 HLO_MATCHER(IsFinite);
278 HLO_MATCHER(Log);
279 HLO_MATCHER(Map);
280 HLO_MATCHER(Maximum);
281 HLO_MATCHER(Minimum);
282 HLO_MATCHER(Multiply);
283 HLO_MATCHER(Negate);
284 HLO_MATCHER(Not);
285 HLO_MATCHER(Or);
286 HLO_MATCHER(Outfeed);
287 HLO_MATCHER(Pad);
288 HLO_MATCHER(PartitionId);
289 HLO_MATCHER(Power);
290 HLO_MATCHER(Recv);
291 HLO_MATCHER(RecvDone);
292 HLO_MATCHER(Reduce);
293 HLO_MATCHER(ReducePrecision);
294 HLO_MATCHER(ReduceScatter);
295 HLO_MATCHER(ReduceWindow);
296 HLO_MATCHER(Remainder);
297 HLO_MATCHER(ReplicaId);
298 HLO_MATCHER(Reshape);
299 HLO_MATCHER(Reverse);
300 HLO_MATCHER(Rng);
301 HLO_MATCHER(RngBitGenerator);
302 HLO_MATCHER(RngGetAndUpdateState);
303 HLO_MATCHER(Scatter);
304 HLO_MATCHER(Select);
305 HLO_MATCHER(SelectAndScatter);
306 HLO_MATCHER(Send);
307 HLO_MATCHER(SendDone);
308 HLO_MATCHER(SetDimensionSize);
309 HLO_MATCHER(ShiftLeft);
310 HLO_MATCHER(ShiftRightArithmetic);
311 HLO_MATCHER(ShiftRightLogical);
312 HLO_MATCHER(Sign);
313 HLO_MATCHER(Slice);
314 HLO_MATCHER(Sort);
315 HLO_MATCHER(Subtract);
316 HLO_MATCHER(Tanh);
317 HLO_MATCHER(Transpose);
318 HLO_MATCHER(Tuple);
319 HLO_MATCHER(While);
320 HLO_MATCHER(Xor);
321 HLO_MATCHER(OptimizationBarrier);
322
323 #define HLO_MATCHER_VECTOR_OPERANDS(opcode) \
324 template <> \
325 inline ::testing::Matcher<const ::xla::HloInstruction*> opcode( \
326 std::vector<::testing::Matcher<const HloInstruction*>> operands) { \
327 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( \
328 ::xla::HloOpcode::k##opcode, operands)); \
329 }
330
331 HLO_MATCHER_VECTOR_OPERANDS(DynamicSlice);
332
333 // The special cases below let you check additional information about the
334 // HloInstruction, beyond just its opcode and operands. In all cases you can
335 // still use the generic matcher which doesn't check this info.
336 //
337 // Feel free to add additional custom matchers below.
338
339 // - Parameter(N) matches parameter number N.
340 // - Parameter() matches any parameter.
Parameter(int64_t parameter_number)341 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
342 int64_t parameter_number) {
343 return ::testing::MakeMatcher(
344 new ::xla::testing::HloParameterMatcher(parameter_number));
345 }
Parameter()346 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
347 return ::testing::MakeMatcher(
348 new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
349 }
350
351 // Comparison matchers below do not require any additional arguments.
352 template <typename... M>
Eq(M...operands)353 inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
354 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
355 ComparisonDirection::kEq, {operands...}));
356 }
357 template <typename... M>
Ne(M...operands)358 inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
359 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
360 ComparisonDirection::kNe, {operands...}));
361 }
362 template <typename... M>
Ge(M...operands)363 inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
364 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
365 ComparisonDirection::kGe, {operands...}));
366 }
367 template <typename... M>
Gt(M...operands)368 inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
369 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
370 ComparisonDirection::kGt, {operands...}));
371 }
372 template <typename... M>
Le(M...operands)373 inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
374 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
375 ComparisonDirection::kLe, {operands...}));
376 }
377 template <typename... M>
Lt(M...operands)378 inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
379 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
380 ComparisonDirection::kLt, {operands...}));
381 }
382
383 // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
384 // tuple element of operand, while GetTupleElement(operand) matches any GTE
385 // operation on operand, and GetTupleElement() matches any GTE operation at all.
GetTupleElement(::testing::Matcher<const HloInstruction * > operand,int64_t tuple_index)386 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
387 ::testing::Matcher<const HloInstruction*> operand, int64_t tuple_index) {
388 return ::testing::MakeMatcher(
389 new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
390 }
GetTupleElement(::testing::Matcher<const HloInstruction * > operand)391 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
392 ::testing::Matcher<const HloInstruction*> operand) {
393 return ::testing::MakeMatcher(
394 new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
395 }
GetTupleElement()396 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
397 return ::testing::MakeMatcher(
398 new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
399 }
400
401 // - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
402 // target T and the given operands.
403 //
404 // - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
405 // given operands.
406 //
407 // - CustomCall() matches any CustomCall HLO at all.
408 template <typename... M>
CustomCall(::testing::Matcher<std::string> call_target_matcher,M...operands)409 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
410 ::testing::Matcher<std::string> call_target_matcher, M... operands) {
411 return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
412 call_target_matcher, {operands...}));
413 }
414 // This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
415 // ::testing::Matcher<std::string>. In that case, we want to prefer the
416 // overload above.
417 template <
418 typename FirstM, typename... M,
419 typename Dummy = typename std::enable_if<
420 !std::is_convertible<FirstM, ::testing::Matcher<std::string>>::value,
421 void>::type*>
CustomCall(FirstM operands_first,M...operands_rest)422 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
423 FirstM operands_first, M... operands_rest) {
424 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
425 HloOpcode::kCustomCall, {operands_first, operands_rest...}));
426 }
CustomCall()427 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
428 return ::testing::MakeMatcher(
429 new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
430 }
431
432 // Verifies the shape or the shape and the layout of an HLO instruction against
433 // the provided shape object.
Shape(const class Shape & shape)434 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
435 const class Shape& shape) {
436 return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
437 }
Shape(absl::string_view shape)438 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
439 absl::string_view shape) {
440 return ::testing::MakeMatcher(
441 new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie()));
442 }
ShapeWithLayout(const class Shape & shape)443 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
444 const class Shape& shape) {
445 return ::testing::MakeMatcher(
446 new ::xla::testing::HloShapeAndLayoutMatcher(shape));
447 }
448 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
449 absl::string_view shape, bool minor_to_major_only = false) {
450 return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
451 ParseShape(shape).ValueOrDie(), minor_to_major_only));
452 }
453
454 // Verifies the value of the HloSharing against the provided sharding object.
Sharding(const HloSharding & sharding)455 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
456 const HloSharding& sharding) {
457 return ::testing::MakeMatcher(
458 new ::xla::testing::HloShardingMatcher(sharding));
459 }
460 // Matcher for Sharding from sharding string
Sharding(absl::string_view sharding)461 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
462 absl::string_view sharding) {
463 return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
464 ParseSharding(sharding).ValueOrDie()));
465 }
466 // Verifies that no HloSharding is set for an HLO instruction.
NoSharding()467 inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
468 return ::testing::MakeMatcher(
469 new ::xla::testing::HloShardingMatcher(std::nullopt));
470 }
471
Dot()472 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot() {
473 return ::testing::MakeMatcher(
474 new ::xla::testing::HloMatcher(::xla::HloOpcode::kDot, {}));
475 }
476
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher)477 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
478 ::testing::Matcher<const HloInstruction*> lhs_matcher,
479 ::testing::Matcher<const HloInstruction*> rhs_matcher) {
480 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
481 ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher}));
482 }
483
484 // Matches a Dot HLO instruction if it has exactly one lhs contracting dimension
485 // equal to `lhs_contracting_dim` and exactly one rhs contracting dimension
486 // equal to `rhs_contracting_dim`.
487 //
488 // Currently the HLO verifier rejects Dot operations with more than one
489 // contracting dimension (even though we can represent these in the
490 // DotDimensionNumbers proto) so there is no need to generalize this to support
491 // multiple contracting dimensions.
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher,int64_t lhs_contracting_dim,int64_t rhs_contracting_dim)492 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
493 ::testing::Matcher<const HloInstruction*> lhs_matcher,
494 ::testing::Matcher<const HloInstruction*> rhs_matcher,
495 int64_t lhs_contracting_dim, int64_t rhs_contracting_dim) {
496 return ::testing::MakeMatcher(
497 new ::xla::testing::HloDotWithContractingDimsMatcher(
498 lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim));
499 }
500
501 // Matcher for asynchronous copies from one memory space to another. Implies
502 // CopyDone(CopyStart(...)) where from_space and to_space is the source and
503 // destination memory spaces, respectively.
AsyncCopy(int64_t to_space,int64_t from_space,::testing::Matcher<const HloInstruction * > operand_matcher)504 inline ::testing::Matcher<const ::xla::HloInstruction*> AsyncCopy(
505 int64_t to_space, int64_t from_space,
506 ::testing::Matcher<const HloInstruction*> operand_matcher) {
507 return ::testing::MakeMatcher(new ::xla::testing::HloAsyncCopyMatcher(
508 to_space, from_space, operand_matcher));
509 }
510
511 // - Constant() matches any constant.
512 // - Constant(V) matches a constant with the given value.
Constant()513 inline ::testing::Matcher<const ::xla::HloInstruction*> Constant() {
514 return ::testing::MakeMatcher(
515 new ::xla::testing::HloMatcher(HloOpcode::kConstant, {}));
516 }
Constant(Literal value)517 inline ::testing::Matcher<const ::xla::HloInstruction*> Constant(
518 Literal value) {
519 return ::testing::MakeMatcher(
520 new ::xla::testing::HloConstantMatcher(std::move(value)));
521 }
522
ReplicaGroups(std::vector<std::vector<int64_t>> replica_groups)523 inline ::testing::Matcher<const ::xla::HloInstruction*> ReplicaGroups(
524 std::vector<std::vector<int64_t>> replica_groups) {
525 return ::testing::MakeMatcher(
526 new ::xla::testing::HloReplicaGroupsMatcher(std::move(replica_groups)));
527 }
528
529 #undef HLO_MATCHER
530 } // namespace opcode_matchers
531
532 // Helper to convert smart to raw pointers for matching.
533 template <typename Container>
Pointers(const Container & container)534 std::vector<const HloInstruction*> Pointers(const Container& container) {
535 std::vector<const HloInstruction*> result;
536 result.reserve(container.size());
537 for (const auto& entry : container) result.push_back(entry.get());
538 return result;
539 }
540
541 } // namespace testing
542
543 // Tell GMock to print HloInstruction* by value, so error messages are nice.
544 // Has to be in the same namespace as 'HloInstruction'.
545 void PrintTo(const HloInstruction* inst, ::std::ostream* os);
546
547 } // namespace xla
548
549 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
550