xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_parser_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/strings/ascii.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
29 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/window_util.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/status_matchers.h"
37 #include "tensorflow/core/platform/statusor.h"
38 #include "tensorflow/core/platform/test.h"
39 
40 namespace xla {
41 namespace {
42 
43 namespace m = ::xla::match;
44 
45 using ::absl::string_view;
46 using ::testing::ElementsAre;
47 using ::testing::HasSubstr;
48 
// One round-trip test case: a named HLO module in text form plus the knobs
// used when running it.
struct TestData {
  // Short identifier for the case (e.g. "AxpyParam"); returned verbatim by
  // TestDataToString.
  std::string test_name;
  // HLO module text for the case.
  std::string module_string;
  // Defaults to 1. NOTE(review): presumably the replica count the module is
  // parsed/verified with -- the consumer is not visible here; confirm.
  int64_t replica_count = 1;
  // Defaults to true. NOTE(review): presumably toggles HLO verification of the
  // parsed module -- the consumer is not visible here; confirm.
  bool enable_verification = true;
};
55 
TestDataToString(const::testing::TestParamInfo<TestData> & data)56 std::string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
57   return data.param.test_name;
58 }
59 
60 // Tests where the input module string doesn't match the output.
61 //
62 // In general we want to avoid these because we want HLO text to be
63 // round-trippable!  But nested instructions, e.g. add(sqrt(x), y), cannot be
64 // round-tripped without modification.
struct NonRoundtripTestData {
  // Short identifier for the case; returned verbatim by
  // NonRoundtripTestDataToString.
  std::string test_name;
  // HLO module text handed to the parser.
  std::string input_module_string;
  // Text the parsed module is expected to print as; differs from the input.
  std::string output_module_string;
};
70 
NonRoundtripTestDataToString(const::testing::TestParamInfo<NonRoundtripTestData> & data)71 std::string NonRoundtripTestDataToString(
72     const ::testing::TestParamInfo<NonRoundtripTestData>& data) {
73   return data.param.test_name;
74 }
75 
76 // For each string below, we check that:
77 //  - we parse it to an HloModule successfully, and
78 //  - the stringification of the resulting HloModule is equal to our original
79 //    string.
CreateTestCases()80 std::vector<TestData> CreateTestCases() {
81   // clang-format off
82   return std::vector<TestData>({
83 // ax + y
84 {
85 "AxpyParam",
86 R"(HloModule axpy_module, entry_computation_layout={(f32[],f32[2,4]{1,0},f32[2,4]{1,0})->f32[2,4]{1,0}}
87 
88 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
89   %alpha = f32[] parameter(0)
90   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
91   %x = f32[2,4]{1,0} parameter(1)
92   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
93   %y = f32[2,4]{1,0} parameter(2)
94   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
95 }
96 
97 )"
98 },
99 // parameter replication
100 {
101 "ParamReplication",
102 R"(HloModule param_replication_module, entry_computation_layout={(f32[],(f32[2,4]{1,0}, (f32[2,4]{1,0})))->(f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0})))}
103 
104 ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) {
105   %a = f32[] parameter(0), parameter_replication={true}
106   %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true}
107   ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b)
108 }
109 
110 )"
111 },
112 // pred constant
113 {
114 "ConstantPred",
115 R"(HloModule constant_pred_module, entry_computation_layout={()->pred[]}
116 
117 ENTRY %constant_pred () -> pred[] {
118   ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
119 }
120 
121 )"
122 },
123 // pred array constant
124 {
125 "ConstantPredArray",
126 R"(HloModule module, entry_computation_layout={()->pred[2,3]{1,0}}
127 
128 ENTRY %constant_pred_array () -> pred[2,3] {
129   ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } })
130 }
131 
132 )"
133 },
134 
135 // s32 constant
136 {
137 "ConstantS32",
138 R"(HloModule constant_s32_module, entry_computation_layout={()->s32[]}
139 
140 ENTRY %constant_s32 () -> s32[] {
141   ROOT %constant = s32[] constant(-42)
142 }
143 
144 )"
145 },
146 // f32 constant, but the value is not a decimal and there is a backend
147 // configuration
148 {
149 "ConstantF32",
150 R"(HloModule ConstantF32_module, entry_computation_layout={()->f32[]}
151 
152 ENTRY %ConstantF32.v4 () -> f32[] {
153   ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
154 }
155 
156 )"
157 },
158 // f32 constant, rank 1 empty array.
159 {
160 "ConstantF32R1Empty",
161 R"(HloModule ConstantF32Empty_module, entry_computation_layout={()->f32[0]{0}}
162 
163 ENTRY %ConstantF32Empty.v4 () -> f32[0] {
164   ROOT %constant = f32[0]{0} constant({})
165 }
166 
167 )"
168 },
169 // f32 constant, rank 4 empty array.
170 {
171 "ConstantF32R4Empty",
172 R"(HloModule ConstantF32R4Empty_module, entry_computation_layout={()->f32[2,0,4,3]{3,2,1,0}}
173 
174 ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
175   ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } })
176 }
177 
178 )"
179 },
180 // constant 4D
181 {
182 "Constant4D",
183 R"(HloModule Small_3x2x1x1_module, entry_computation_layout={()->f32[3,2,1,1]{3,2,1,0}}
184 
185 ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
186   ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
187 }
188 
189 )"
190 },
191 // non-finite constants: nan, inf, -inf
192 {
193 "ConstantNonFinite",
194 R"(HloModule IsFiniteR1F32s_module, entry_computation_layout={()->pred[6]{0}}
195 
196 ENTRY %IsFiniteR1F32s.v2 () -> pred[6] {
197   %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf})
198   ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant)
199 }
200 
201 )"
202 },
203 // constant f16
204 {
205 "ConstantF16",
206 R"(HloModule ConstantF16_module, entry_computation_layout={()->f16[]}
207 
208 ENTRY %ConstantF16.v4 () -> f16[] {
209   ROOT %constant = f16[] constant(500)
210 }
211 
212 )"
213 },
214 // bf16
215 {
216 "BF16",
217 R"(HloModule BF16, entry_computation_layout={()->bf16[]}
218 
219 ENTRY %BF16.v4 () -> bf16[] {
220   ROOT %constant = bf16[] constant(500)
221 }
222 
223 )"
224 },
225 // constant + constant
226 {
227 "AddConstants",
228 R"(HloModule add_constants_module, entry_computation_layout={()->f32[]}
229 
230 ENTRY %add_constants () -> f32[] {
231   %constant = f32[] constant(3.14)
232   ROOT %add = f32[] add(f32[] %constant, f32[] %constant)
233 }
234 
235 )"
236 },
237 // tuple constant
238 {
239 "TupleConstant",
240 R"(HloModule TupleConstant_module, entry_computation_layout={()->(f32[2,1]{1,0}, f32[2]{0})}
241 
242 ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
243   ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} ))
244 }
245 
246 )"
247 },
248 // v1 > v2 ? v1 : v2
249 {
250 "SelectR1F32",
251 R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module, entry_computation_layout={(f32[4]{0},f32[4]{0})->f32[4]{0}}
252 
253 ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
254   %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
255   %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
256   %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, type=TOTALORDER, sharding={replicated}
257   ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
258 }
259 
260 )"
261 },
262 // empty tuple
263 {
264 "EmptyTupleCreate",
265 R"(HloModule EmptyTupleCreate_module, entry_computation_layout={()->()}
266 
267 ENTRY %EmptyTupleCreate.v1 () -> () {
268   ROOT %tuple = () tuple()
269 }
270 
271 )"
272 },
273 // tuple
274 {
275 "TupleCreate",
276 R"(HloModule TupleCreate_module, entry_computation_layout={(f32[],f32[3]{0},f32[2,3]{1,0})->(f32[], f32[3]{0}, f32[2,3]{1,0})}
277 
278 ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
279   %v1 = f32[] parameter(0)
280   %v2 = f32[3]{0} parameter(1)
281   %v3 = f32[2,3]{1,0} parameter(2)
282   ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
283 }
284 
285 )"
286 },
287 // tuple
288 {
289 "LargeTupleRoundTrip",
290 R"(HloModule LargeTupleRoundTrip_module, entry_computation_layout={(f32[])->(f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[])}
291 
292 ENTRY %TupleCreate.v4 (v: f32[]) -> (f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[]) {
293   %v = f32[] parameter(0)
294   ROOT %tuple = (f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[]) tuple(f32[] %v, f32[] %v, f32[] %v, f32[] %v, f32[] %v, /*index=5*/f32[] %v)
295 }
296 
297 )"
298 },
299 {
300 "ShardedTupleCreate",
301 R"(HloModule ShardedTupleCreate_module, entry_computation_layout={(f32[],f32[3]{0},f32[2,3]{1,0})->(f32[], f32[3]{0}, f32[2,3]{1,0})}
302 
303 ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
304   %v1 = f32[] parameter(0), sharding={manual}
305   %v2 = f32[3]{0} parameter(1)
306   %v3 = f32[2,3]{1,0} parameter(2)
307   ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{manual}, {maximal device=0}, {replicated}}
308 }
309 
310 )"
311 },
312 {
313 "DomainParsing",
314 R"(HloModule DomainParsing_module, entry_computation_layout={(f32[])->f32[]}
315 
316 ENTRY %DomainParsing (v1: f32[]) -> f32[] {
317   %v1 = f32[] parameter(0)
318   ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
319 }
320 
321 )"
322 },
323 // int32_t result = 0;
324 // while (result < 5) { result = result + 1; }
325 {
326 "WhileWithScalarS32Result",
327 R"(HloModule WhileWithScalarS32Result_module, entry_computation_layout={()->s32[]}
328 
329 %body.v3 (prev.1: s32[]) -> s32[] {
330   %constant = s32[] constant(1)
331   %prev.1 = s32[] parameter(0)
332   ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
333 }
334 
335 %condition.v3 (prev.2: s32[]) -> pred[] {
336   %constant.1 = s32[] constant(5)
337   %prev.2 = s32[] parameter(0)
338   ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
339 }
340 
341 ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
342   %constant.2 = s32[] constant(0)
343   ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
344 }
345 
346 )"
347 },
348 // copy-start and copy-done
349 {
350 "CopyStartAndCopyDone",
351 
352 R"(HloModule CopyStartAndCopyDone_module, entry_computation_layout={(f32[],f32[2,3]{1,0:S(1)})->(f32[], f32[2,3]{1,0:S(2)})}
353 
354 ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
355   %v1 = f32[] parameter(0)
356   %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1), is_cross_program_prefetch=true
357   %copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
358   %v2 = f32[2,3]{1,0:S(1)} parameter(1)
359   %copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
360   %copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) %copy-start.2)
361   ROOT %tuple = (f32[], f32[2,3]{1,0:S(2)}) tuple(f32[] %copy-done.1, f32[2,3]{1,0:S(2)} %copy-done.2)
362 }
363 
364 )"
365 },
366 // send and recv
367 {
368 "SendRecv",
369 R"(HloModule TwoSendRecvBothWayRecvFist_module, entry_computation_layout={()->(f32[], token[])}
370 
371 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
372   %token0 = token[] after-all()
373   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
374   ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
375   %constant = f32[] constant(2.1), sharding={maximal device=0}
376   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
377   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
378 }
379 
380 )"
381 },
382 {
383 "SendRecvWithHostTransfer",
384 R"(HloModule HostTransferSendRecv_module, entry_computation_layout={()->(f32[], token[])}
385 
386 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
387   %token0 = token[] after-all()
388   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true
389   ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true
390   %constant = f32[] constant(2.1), sharding={maximal device=0}
391   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true
392   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true
393 }
394 
395 )"
396 },
397 // get-tuple-element
398 {
399 "GetTupleElement",
400 R"(HloModule GetTupleElement_module, entry_computation_layout={()->s32[2,3]{1,0}}
401 
402 ENTRY %GetTupleElement.v4 () -> s32[2,3] {
403   %constant = f32[3]{0} constant({1, 2, 3})
404   %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
405   %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
406   ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
407 }
408 
409 )"
410 },
411 // call
412 {
413 "Call",
414 R"(HloModule CallR0F32IdentityScalar_module, entry_computation_layout={()->f32[]}
415 
416 %Identity.v1 (x: f32[]) -> f32[] {
417   ROOT %x = f32[] parameter(0)
418 }
419 
420 ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
421   %constant = f32[] constant(42)
422   ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
423 }
424 
425 )"
426 },
427 // CustomCall with backend_config.
428 {
429 "CustomCallWithOpaque",
430 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
431 
432 ENTRY %CustomCall () -> f32[1,2,3] {
433   %constant = f32[1]{0} constant({12345})
434   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
435 }
436 
437 )"
438 },
439 
440 // CustomCall with literal.
441 {
442 "CustomCallWithLiteral",
443 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
444 
445 ENTRY %CustomCall () -> f32[1,2,3] {
446   %constant = f32[1]{0} constant({12345})
447   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=s32[2]{0} {1, 2}
448 }
449 
450 )"
451 },
452 
453 // CustomCall with literal tuple.
454 {
455 "CustomCallWithLiteralTuple",
456 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
457 
458 ENTRY %CustomCall () -> f32[1,2,3] {
459   %constant = f32[1]{0} constant({12345})
460   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=( s32[4]{0} {4, 128, 128, 3}, pred[4]{0} {1, 0, 0, 0} )
461 }
462 
463 )"
464 },
465 
466 // CustomCall with literal R0.
467 {
468 "CustomCallWithLiteralR0",
469 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
470 
471 ENTRY %CustomCall () -> f32[1,2,3] {
472   %constant = f32[1]{0} constant({12345})
473   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=f32[] 0.1
474 }
475 
476 )"
477 },
478 // reduce window
479 {
480 "ReduceWindow",
481 R"(HloModule R4UnitWindow_module, entry_computation_layout={(f32[13,12,8,15]{0,3,2,1})->f32[13,3,8,15]{0,3,2,1}}
482 
483 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
484   %lhs = f32[] parameter(0)
485   %rhs = f32[] parameter(1)
486   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
487 }
488 
489 ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
490   %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
491   %constant = f32[] constant(0)
492   ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
493 }
494 
495 )"
496 },
497 // reduce window on scalar
498 {
499 "ReduceWindowScalar",
500 R"(HloModule reduce_window_scalar, entry_computation_layout={()->f32[]}
501 
502 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
503   %lhs = f32[] parameter(0)
504   %rhs = f32[] parameter(1)
505   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
506 }
507 
508 ENTRY %R4UnitWindowScalar () -> f32[] {
509   %constant = f32[] constant(42)
510   %constant.1 = f32[] constant(1)
511   ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
512 }
513 
514 )"
515 },
516 // variadic reduce window on scalars
517 {
518 "ReduceWindowVariadic",
519 R"(HloModule reduce_window_variadic, entry_computation_layout={()->(f32[], f32[])}
520 
521 %add_F32.v3 (lhs1: f32[], lhs2: f32[], rhs1: f32[], rhs2: f32[]) -> (f32[], f32[]) {
522   %lhs1 = f32[] parameter(0)
523   %rhs1 = f32[] parameter(2)
524   %add1 = f32[] add(f32[] %lhs1, f32[] %rhs1)
525   %lhs2 = f32[] parameter(1)
526   %rhs2 = f32[] parameter(3)
527   %add2 = f32[] add(f32[] %lhs2, f32[] %rhs2)
528   ROOT %tuple1 = (f32[], f32[]) tuple(f32[] %add1, f32[] %add2)
529 }
530 
531 ENTRY %R4UnitWindowScalar () -> (f32[], f32[]) {
532   %constant = f32[] constant(42)
533   %constant.1 = f32[] constant(1)
534   ROOT %reduce-window = (f32[], f32[]) reduce-window(f32[] %constant, f32[] %constant, f32[] %constant.1, f32[] %constant.1), to_apply=%add_F32.v3
535 }
536 
537 )"
538 },
539 // convolution
540 {
541 "Convolution",
542 R"(HloModule Convolve1D1Window_0_module, entry_computation_layout={(f32[1,2,1]{2,1,0},f32[1,1,1]{2,1,0})->f32[1,2,1]{2,0,1}}
543 
544 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
545   %input = f32[1,2,1]{2,1,0} parameter(0)
546   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
547   %filter = f32[1,1,1]{2,1,0} parameter(1)
548   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
549 }
550 
551 )"
552 },
553 // convolution dynamic
554 {
555 "ConvolutionDynamic",
556 R"(HloModule Convolve1D1Window_0_module, entry_computation_layout={(f32[1,2,1]{2,1,0},f32[1,1,1]{2,1,0})->f32[1,2,1]{2,0,1}}
557 
558 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
559   %input = f32[1,2,1]{2,1,0} parameter(0)
560   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
561   %filter = f32[1,1,1]{2,1,0} parameter(1)
562   ROOT %custom-call.52 = f32[1,2,1]{2,0,1} custom-call(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}, custom_call_target="DynamicConvolutionForward", metadata={op_type="Conv2D" op_name="conv1d"}
563 }
564 
565 )"
566 },
567 // convolution rank 2
568 {
569 "ConvolutionR2",
570 R"(HloModule ConvolveR2_module, entry_computation_layout={(f32[1,2]{1,0},f32[2,2]{1,0})->f32[1,2]{0,1}}
571 
572 ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[2,2]) -> f32[1,2] {
573   %input = f32[1,2]{1,0} parameter(0)
574   %filter = f32[2,2]{1,0} parameter(1)
575   ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[2,2]{1,0} %filter), dim_labels=bf_io->bf
576 }
577 
578 )"
579 },
580 // convolution backward
581 {
582 "ConvolutionBackward",
583 R"(HloModule ConvolveBackward_module, entry_computation_layout={(f32[128,7,7,512]{0,3,2,1},f32[3,3,512,512]{3,2,1,0})->f32[128,14,14,512]{0,3,2,1}}
584 
585 ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
586   %input = f32[128,7,7,512]{0,3,2,1} parameter(0)
587   %filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
588   ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
589 }
590 
591 )"
592 },
593 // reverse(constant)
594 {
595 "Reverse4D",
596 R"(HloModule Reverse4DFloatArrayOnDim01_module, entry_computation_layout={()->f32[4,3,2,1]{0,1,2,3}}
597 
598 ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
599   %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
600   ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
601 }
602 
603 )"
604 },
605 // concat
606 {
607 "Concat",
608 R"(HloModule Concat2x3With2x5_module, entry_computation_layout={()->f32[2,8]{1,0}}
609 
610 ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
611   %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } })
612   %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
613   ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
614 }
615 
616 )"
617 },
618 // select and scatter
619 {
620 "SelectAndScatter",
621 R"(HloModule R4F32OverlapSmall_module, entry_computation_layout={()->f32[4,5,1,1]{3,2,1,0}}
622 
623 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
624   %lhs = f32[] parameter(0)
625   %rhs = f32[] parameter(1)
626   ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER
627 }
628 
629 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
630   %lhs.1 = f32[] parameter(0)
631   %rhs.1 = f32[] parameter(1)
632   ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
633 }
634 
635 ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
636   %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
637   %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
638   %constant.2 = f32[] constant(0)
639   ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
640 }
641 
642 )"
643 },
644 // select and scatter on scalar
645 {
646 "SelectAndScatterScalar",
647 R"(HloModule select_and_scatter_scalar, entry_computation_layout={()->f32[]}
648 
649 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
650   %lhs = f32[] parameter(0)
651   %rhs = f32[] parameter(1)
652   ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
653 }
654 
655 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
656   %lhs.1 = f32[] parameter(0)
657   %rhs.1 = f32[] parameter(1)
658   ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
659 }
660 
661 ENTRY %SelectAndScatterScalar () -> f32[] {
662   %constant = f32[] constant(42)
663   %constant.1 = f32[] constant(1)
664   %constant.2 = f32[] constant(2)
665   ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3
666 }
667 
668 )"
669 },
670 // slice
671 {
672 "Slice",
673 R"(HloModule slice_module, entry_computation_layout={(f32[3,3,4,4]{3,2,1,0})->f32[3,3,2,4]{3,2,1,0}}
674 
675 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
676   %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
677   ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
678 }
679 
680 )"
681 },
682 // slice, no stride
683 {
684 "SliceNoStride",
685 R"(HloModule Slice3x3x3_To_1x3x3_F32_module, entry_computation_layout={()->f32[1,3,3]{2,1,0}}
686 
687 ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
688   %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
689   ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
690 }
691 
692 )"
693 },
694 // slice R0
695 {
696 "SliceR0",
697 R"(HloModule SliceR0_module, entry_computation_layout={()->s32[]}
698 
699 ENTRY %SliceR0.v2 () -> s32[] {
700   %constant = s32[] constant(1)
701   ROOT %slice = s32[] slice(s32[] %constant), slice={}
702 }
703 
704 )"
705 },
706 // transpose
707 {
708 "Transpose",
709 R"(HloModule Transpose_module, entry_computation_layout={()->s32[1,2,3]{2,1,0}}
710 
711 ENTRY %Transpose.v2 () -> s32[1,2,3] {
712   %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } })
713   ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
714 }
715 
716 )"
717 },
718 {
719 "TransposeC128",
720 R"(HloModule TransposeC128_module, entry_computation_layout={(c128[1,2,3]{2,1,0})->c128[1,2,3]{2,1,0}}
721 
722 ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
723   %input = c128[1,2,3]{2,1,0} parameter(0)
724   ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
725 }
726 
727 )"
728 },
729 // Triangular solve
730 {
731 "TriangularSolve",
732 R"(HloModule TriangularSolve_module, entry_computation_layout={(f32[4,4]{1,0},f32[3,4]{1,0})->f32[3,4]{1,0}}
733 
734 ENTRY %SimpleRightLowerNotranspose.4 (a.1: f32[4,4], b.2: f32[3,4]) -> f32[3,4] {
735   %a.1 = f32[4,4]{1,0} parameter(0)
736   %b.2 = f32[3,4]{1,0} parameter(1)
737   ROOT %triangular-solve.3 = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a.1, f32[3,4]{1,0} %b.2), lower=true, transpose_a=NO_TRANSPOSE
738 }
739 
740 )"
741 },
742 // Dynamic slice
743 {
744 "DynamicSlice",
745 R"(HloModule DynamicSlice_module, entry_computation_layout={(s32[2,2,258]{2,1,0},s32[1]{0})->s32[2,2,258]{2,1,0}}
746 
747 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
748   %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
749   %constant = s32[1]{0} constant({0})
750   %start_index = s32[1]{0} parameter(1)
751   %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
752   ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
753 }
754 
755 )"
756 },
757 // Dynamic slice with scalar indices
758 {
759 "DynamicSliceScalarIndices",
760 R"(HloModule DynamicSlice_module, entry_computation_layout={(s32[2,2,258]{2,1,0},s32[])->s32[2,2,258]{2,1,0}}
761 
762 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
763   %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
764   %constant = s32[] constant(0)
765   %start_index = s32[] parameter(1)
766   ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
767 }
768 
769 )"
770 },
771 // Dynamic update slice
772 {
773 "DynamicUpdateSlice",
774 R"(HloModule DynamicSlice_module, entry_computation_layout={(s32[1,1,25,1]{3,2,1,0},s32[1,1,2,1]{3,2,1,0},s32[4]{0})->s32[1,1,25,1]{3,2,1,0}}
775 
776 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
777   %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
778   %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
779   %start_indices = s32[4]{0} parameter(2)
780   ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
781 }
782 
783 )"
784 },
785 // Dynamic update slice with scalar indices
786 {
787 "DynamicUpdateSliceScalarIndex",
788 R"(HloModule DynamicUpdateSlice_module, entry_computation_layout={(s32[1,1,25,1]{3,2,1,0},s32[1,1,2,1]{3,2,1,0},s32[],s32[],s32[],s32[])->s32[1,1,25,1]{3,2,1,0}}
789 
790 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
791   %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
792   %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
793   %start_index.0 = s32[] parameter(2)
794   %start_index.1 = s32[] parameter(3)
795   %start_index.2 = s32[] parameter(4)
796   %start_index.3 = s32[] parameter(5)
797   ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, /*index=5*/s32[] %start_index.3)
798 }
799 
800 )"
801 },
802 // batch norm training
803 {
804 "BatchNormTraining",
805 R"(HloModule BasicTraining_module, entry_computation_layout={()->(f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})}
806 
807 ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
808   %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
809   %constant.1 = f32[2]{0} constant({2, 3})
810   %constant.2 = f32[2]{0} constant({1, 2})
811   ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
812 }
813 
814 )"
815 },
816 // batch norm inference
817 {
818 "BatchNormInference",
819 R"(HloModule BatchNormInference_module, entry_computation_layout={(f32[2,2,2,2]{3,2,1,0},f32[2]{0},f32[2]{0},f32[2]{0},f32[2]{0})->f32[2,2,2,2]{3,2,1,0}}
820 
821 ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
822   %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
823   %offset = f32[2]{0} parameter(1)
824   %scale = f32[2]{0} parameter(2)
825   %mean = f32[2]{0} parameter(3)
826   %variance = f32[2]{0} parameter(4)
827   ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
828 }
829 
830 )"
831 },
832 // batch norm grad
833 {
834 "BatchNormGrad",
835 R"(HloModule BatchNormGrad_module, entry_computation_layout={(f32[2,2,2,2]{3,2,1,0},f32[2]{0},f32[2]{0},f32[2]{0},f32[2,2,2,2]{3,2,1,0})->(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})}
836 
837 ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
838   %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
839   %scale = f32[2]{0} parameter(1)
840   %mean = f32[2]{0} parameter(2)
841   %variance = f32[2]{0} parameter(3)
842   %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
843   ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
844 }
845 
846 )"
847 },
848 // fft
849 {
850 "Fft",
851 R"(HloModule Fft_module, entry_computation_layout={(c64[8,32]{1,0})->c64[8,32]{1,0}}
852 
853 ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
854   %input = c64[8,32]{1,0} parameter(0)
855   ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
856 }
857 
858 )"
859 },
860 // ifft2d
861 {
862 "Ifft2d",
863 R"(HloModule Ifft2d_module, entry_computation_layout={(c64[5,8,32]{2,1,0})->c64[5,8,32]{2,1,0}}
864 
865 ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
866   %input = c64[5,8,32]{2,1,0} parameter(0)
867   ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
868 }
869 
870 )"
871 },
872 // rfft2d
873 {
874 "Rfft2d",
875 R"(HloModule Rfft2d_module, entry_computation_layout={(f32[5,64,32]{2,1,0})->c64[5,64,17]{2,1,0}}
876 
877 ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
878   %input = f32[5,64,32]{2,1,0} parameter(0)
879   ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
880 }
881 
882 )"
883 },
884 // irfft3d
885 {
886 "Irfft3d",
887 R"(HloModule Irfft3d_module, entry_computation_layout={(c64[5,64,128,33]{3,2,1,0})->f32[5,64,128,64]{3,2,1,0}}
888 
889 ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
890   %input = c64[5,64,128,33]{3,2,1,0} parameter(0)
891   ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
892 }
893 
894 )"
895 },
896 // pad
897 {
898 "Pad",
899 R"(HloModule Pad1DS3Array_module, entry_computation_layout={()->f32[7]{0}}
900 
901 ENTRY %Pad1DS3Array.v3 () -> f32[7] {
902   %constant = f32[3]{0} constant({1, 2, 3})
903   %constant.1 = f32[] constant(0.1)
904   ROOT %pad = f32[7]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
905 }
906 
907 )"
908 },
909 // pad has interior
910 {
911 "PadHasInterior",
912 R"(HloModule PadHasInterior_module, entry_computation_layout={(f32[1,25,7,7]{3,2,1,0})->f32[1,25,17,11]{3,2,1,0}}
913 
914 ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
915   %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
916   %constant = f32[] constant(-5.123)
917   ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
918 }
919 
920 )"
921 },
922 // round to nearest even
923 {
924 "RoundNearestEven",
925 R"(HloModule RoundNearestEven_module, entry_computation_layout={(f32[2,2]{1,0})->f32[2,2]{1,0}}
926 
927 ENTRY %RoundNearestEven (input: f32[2,2]) -> f32[2,2] {
928   %input = f32[2,2]{1,0} parameter(0)
929   ROOT %round-nearest-even = f32[2,2]{1,0} round-nearest-even(f32[2,2]{1,0} %input)
930 }
931 
932 )"
933 },
934 // Negative padding
935 {
936 "PadHasNegativePadding",
937 R"(HloModule PadHasNegativePadding_module, entry_computation_layout={(f32[1,25,7,7,10]{4,3,2,1,0})->f32[1,15,6,3,35]{4,3,2,1,0}}
938 
939 ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,35] {
940   %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
941   %constant = f32[] constant(-5.123)
942   ROOT %pad = f32[1,15,6,3,35]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
943 }
944 
945 )"
946 },
947 // fusion
948 {
949 "Fusion",
950 R"(HloModule fusion_module, entry_computation_layout={()->f32[3,2,1,1]{3,2,1,0}}
951 
952 %fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
953   %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
954   %constant.1.param_1 = f32[2]{0} parameter(1)
955   %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
956   ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
957 }
958 
959 ENTRY %fusion.v3 () -> f32[3,2,1,1] {
960   %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
961   %constant.1 = f32[2]{0} constant({3.14, 4.25})
962   ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
963 }
964 
965 )"
966 },
967 {
968 "Gather",
969 R"(HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}}
970 
971 ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
972   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
973   %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
974   ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
975 }
976 
977 )"
978 },
979 {
980 "SortedGather",
981 R"(HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}}
982 
983 ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
984   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
985   %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
986   ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}, indices_are_sorted=true
987 }
988 
989 )"
990 },
991 {
992 "Scatter",
993 R"(HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0},f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46]{4,3,2,1,0}}
994 
995 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
996   %lhs = f32[] parameter(0)
997   %rhs = f32[] parameter(1)
998   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
999 }
1000 
1001 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
1002   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1003   %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1004   %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
1005   ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
1006 }
1007 
1008 )"
1009 },
1010 {
1011 "TupleScatter",
1012 R"(HloModule TupleScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},bf16[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0},f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0},bf16[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->(f32[50,49,48,47,46]{4,3,2,1,0}, bf16[50,49,48,47,46]{4,3,2,1,0})}
1013 
1014 %add_F32_mul_BF16 (lhs_0: f32[], lhs_1: bf16[], rhs_0: f32[], rhs_1: bf16[]) -> (f32[], bf16[]) {
1015   %lhs_0 = f32[] parameter(0)
1016   %rhs_0 = f32[] parameter(2)
1017   %add = f32[] add(f32[] %lhs_0, f32[] %rhs_0)
1018   %lhs_1 = bf16[] parameter(1)
1019   %rhs_1 = bf16[] parameter(3)
1020   %mul = bf16[] multiply(bf16[] %lhs_1, bf16[] %rhs_1)
1021   ROOT %tuple = (f32[], bf16[]) tuple(f32[] %add, bf16[] %mul)
1022 }
1023 
1024 ENTRY %Scatter (input_0: f32[50,49,48,47,46], input_1: bf16[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates_0: f32[10,9,8,7,30,29,28,27,26], updates_1: bf16[10,9,8,7,30,29,28,27,26]) -> (f32[50,49,48,47,46], bf16[50,49,48,47,46]) {
1025   %input_0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1026   %input_1 = bf16[50,49,48,47,46]{4,3,2,1,0} parameter(1)
1027   %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(2)
1028   %updates_0 = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(3)
1029   %updates_1 = bf16[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(4)
1030   ROOT %scatter = (f32[50,49,48,47,46]{4,3,2,1,0}, bf16[50,49,48,47,46]{4,3,2,1,0}) scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_0, bf16[50,49,48,47,46]{4,3,2,1,0} %input_1, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates_0, bf16[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates_1), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32_mul_BF16
1031 }
1032 
1033 )"
1034 },
1035 {
1036 "SortedScatter",
1037 R"(HloModule StringifySortedScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0},f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46]{4,3,2,1,0}}
1038 
1039 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
1040   %lhs = f32[] parameter(0)
1041   %rhs = f32[] parameter(1)
1042   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
1043 }
1044 
1045 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
1046   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1047   %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1048   %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
1049   ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3
1050 }
1051 
1052 )"
1053 },
1054 {
1055 "UniqueIndicesScatter",
1056 R"(HloModule StringifyUniqueIndicesScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0},f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46]{4,3,2,1,0}}
1057 
1058 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
1059   %lhs = f32[] parameter(0)
1060   %rhs = f32[] parameter(1)
1061   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
1062 }
1063 
1064 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
1065   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1066   %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1067   %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
1068   ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, unique_indices=true, to_apply=%add_F32.v3
1069 }
1070 
1071 )"
1072 },
1073 {
1074   "ConstantUnsignedNoUnderflow",
1075   R"(HloModule ConstantUnsignedNoUnderflow_module, entry_computation_layout={()->u64[]}
1076 
1077 ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
1078   ROOT %constant = u64[] constant(1)
1079 }
1080 
1081 )"
1082 },
1083 
1084 {
1085   "ConstantUnsignedNoOverflow",
1086   R"(HloModule ConstantUnsignedNoOverflow_module, entry_computation_layout={()->u64[]}
1087 
1088 ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
1089   ROOT %constant = u64[] constant(9223372036854775807)
1090 }
1091 
1092 )"
1093 },
1094 // CustomCallWithLayoutConstraints
1095 {
1096 "CustomCallWithLayoutConstraints",
1097 R"(HloModule CustomCallWithLayoutConstraints, entry_computation_layout={(f32[42,2,3]{0,1,2},f32[123,4]{0,1})->f32[1,2,3]{0,2,1}}
1098 
1099 ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
1100   %p0 = f32[42,2,3]{0,1,2} parameter(0)
1101   %p1 = f32[123,4]{0,1} parameter(1)
1102   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
1103 }
1104 
1105 )"
1106 },
1107 // CustomCallWithLayoutConstraintsNoOperands
1108 {
1109 "CustomCallWithLayoutConstraintsNoOperands",
1110 R"(HloModule CustomCallWithLayoutConstraintsNoOperands, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1111 
1112 ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
1113   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
1114 }
1115 
1116 )"
1117 },
1118 // CustomCallWithLayoutConstraintsTupleShapes
1119 {
1120 "CustomCallWithLayoutConstraintsTupleShapes",
1121 R"(HloModule CustomCallWithLayoutConstraintsTupleShapes, entry_computation_layout={((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}),f32[123,4]{0,1})->(f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0})}
1122 
1123 ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
1124   %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
1125   %p1 = f32[123,4]{0,1} parameter(1)
1126   ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
1127 }
1128 
1129 )"
1130 },
1131 // CustomCallWithHasSideEffect
1132 {
1133 "CustomCallWithHasSideEffect",
1134 R"(HloModule CustomCallWithHasSideEffect, entry_computation_layout={((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}),f32[123,4]{0,1})->(f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0})}
1135 
1136 ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
1137   %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
1138   %p1 = f32[123,4]{0,1} parameter(1)
1139   ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
1140 }
1141 
1142 )"
1143 },
1144 // CustomCallWithAliasing
1145 {
1146 "CustomCallWithAliasing",
1147 R"(HloModule CustomCallWithAliasing, entry_computation_layout={((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}),f32[123,4]{0,1})->(f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2})}
1148 
1149 ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) {
1150   %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
1151   %p1 = f32[123,4]{0,1} parameter(1)
1152   ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
1153 }
1154 
1155 )"
1156 },
1157 // CustomCall with schedule.
1158 {
1159 "CustomCallWithSchedule",
1160 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1161 
1162 ENTRY %CustomCall () -> f32[1,2,3] {
1163   %constant = f32[1]{0} constant({12345})
1164   %custom-call.0 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", schedule=SCHEDULE_EARLIEST
1165   ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1,2,3]{0,2,1} %custom-call.0), custom_call_target="bar", schedule=SCHEDULE_LATEST
1166 }
1167 
1168 )"
1169 },
1170 // CustomCall that returns a status.
1171 {
1172 "CustomCallWithStatusReturningVersion",
1173 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1174 
1175 ENTRY %CustomCall () -> f32[1,2,3] {
1176   %constant = f32[1]{0} constant({12345})
1177   ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", api_version=API_VERSION_STATUS_RETURNING
1178 }
1179 
1180 )"
1181 },
1182 // Parse c64 literal
1183 {
1184 "ParseC64Literal",
1185 R"(HloModule ParseC64Literal, entry_computation_layout={()->c64[2]{0}}
1186 
1187 ENTRY %ParseC64Literal () -> c64[2] {
1188   ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)})
1189 }
1190 
1191 )"
1192 },
1193 // Parse c128 literal
1194 {
1195 "ParseC128Literal",
1196 R"(HloModule ParseC128Literal, entry_computation_layout={()->c128[2]{0}}
1197 
1198 ENTRY %ParseC128Literal () -> c128[2] {
1199   ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)})
1200 }
1201 
1202 )"
1203 },
1204 // Indexed Conditional
1205 {
1206 "IndexedConditional",
1207 R"(HloModule indexed_conditional, entry_computation_layout={()->f32[]}
1208 
1209 %Negate (x: f32[]) -> f32[] {
1210   %x = f32[] parameter(0)
1211   ROOT %negate = f32[] negate(f32[] %x)
1212 }
1213 
1214 %Identity (y: f32[]) -> f32[] {
1215   %y = f32[] parameter(0)
1216   ROOT %copy = f32[] copy(f32[] %y)
1217 }
1218 
1219 %Floor (z: f32[]) -> f32[] {
1220   %z = f32[] parameter(0)
1221   ROOT %floor = f32[] floor(f32[] %z)
1222 }
1223 
1224 ENTRY %Parameters1.v4 () -> f32[] {
1225   %constant = s32[] constant(1)
1226   %constant.1 = f32[] constant(56)
1227   %constant.2 = f32[] constant(12)
1228   %constant.3 = f32[] constant(13)
1229   ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
1230 }
1231 
1232 )"
1233 },
1234 // rng-get-and-update-state
1235 {
1236 "RngGetAndUpdateState",
1237 R"(HloModule rng_get_and_update_state, entry_computation_layout={()->u64[2]{0}}
1238 
1239 ENTRY %RngGetAndUpdateState () -> u64[2] {
1240   ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=4096
1241 }
1242 
1243 )"
1244 },
1245 {
1246 "RngBitGenerator",
1247 R"(HloModule gng_bit_generator, entry_computation_layout={(u64[2]{0})->(u64[2]{0}, u32[11,17]{1,0})}
1248 
1249 ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[11,17]) {
1250   %p0 = u64[2]{0} parameter(0)
1251   ROOT %rand = (u64[2]{0}, u32[11,17]{1,0}) rng-bit-generator(u64[2]{0} %p0), algorithm=rng_three_fry
1252 }
1253 
1254 )"
1255 },
1256 // Async ops with syntax sugar.
1257 {
1258 "AsyncOpsWithSyntaxSugar",
1259 R"(HloModule AsyncOpsWithSyntaxSugar, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
1260 
1261 ENTRY %Entry (p0: f32[10]) -> f32[20] {
1262   %p0 = f32[10]{0} parameter(0)
1263   %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), custom_call_target="foo"
1264   %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), custom_call_target="foo"
1265   ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), custom_call_target="foo"
1266 }
1267 
1268 )"
1269 },
1270 {
1271 "AsyncOpsWithSyntaxSugarAndGroupId",
1272 R"(HloModule AsyncOpsWithSyntaxSugarAndGroupId, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
1273 
1274 ENTRY %Entry (p0: f32[10]) -> f32[20] {
1275   %p0 = f32[10]{0} parameter(0)
1276   %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_group_id=3, custom_call_target="foo"
1277   %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_group_id=3, custom_call_target="foo"
1278   ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_group_id=3, custom_call_target="foo"
1279 }
1280 
1281 )"
1282 },
1283 // Async ops with syntax sugar and async thread name.
1284 {
1285 "AsyncOpsWithSyntaxSugarAndThreadName",
1286 R"(HloModule AsyncOpsWithSyntaxSugarAndThreadName, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
1287 
1288 ENTRY %Entry (p0: f32[10]) -> f32[20] {
1289   %p0 = f32[10]{0} parameter(0)
1290   %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo"
1291   %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
1292   ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", custom_call_target="foo"
1293 }
1294 
1295 )"
1296 },
1297 // HloComputation with thread name as attribute.
1298 {
1299 "HloComputationWithParallelThreadName",
1300 R"(HloModule HloComputationWithParallelThreadName, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
1301 
1302 ENTRY %Entry (p0: f32[10]) -> f32[20] {
1303   %p0 = f32[10]{0} parameter(0)
1304   %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo"
1305   %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
1306   ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", custom_call_target="foo"
1307 }, execution_thread="main_thread"
1308 
1309 )"
1310   },
1311 });
1312   // clang-format on
1313 }
1314 
1315 std::vector<TestData> CreateShortTestCases() {
1316   // clang-format off
1317   return std::vector<TestData>({
1318 // map
1319 {
1320 "Map",
1321 R"(HloModule MapBinaryAdder_module, entry_computation_layout={(f32[4]{0},f32[4]{0})->f32[4]{0}}
1322 
1323 add_F32.v3 {
1324   lhs = f32[] parameter(0)
1325   rhs = f32[] parameter(1)
1326   ROOT add = f32[] add(lhs, rhs)
1327 }
1328 
1329 ENTRY MapBinaryAdder.v3 {
1330   param0 = f32[4]{0} parameter(0)
1331   param1 = f32[4]{0} parameter(1)
1332   ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
1333 }
1334 
1335 )"
1336 },
1337 // reduce
1338 {
1339 "Reduce",
1340 R"(HloModule ReduceR3ToR2_module, entry_computation_layout={(f32[8,16,256]{2,1,0})->f32[8,16]{1,0}}
1341 
1342 add_F32.v3 {
1343   lhs = f32[] parameter(0)
1344   rhs = f32[] parameter(1)
1345   ROOT add = f32[] add(lhs, rhs)
1346 }
1347 
1348 ENTRY ReduceR3ToR2.v3 {
1349   input = f32[8,16,256]{2,1,0} parameter(0)
1350   constant = f32[] constant(0)
1351   ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
1352 }
1353 
1354 )"
1355 },
1356 // tuple reduce
1357 {
1358 "TupleReduce",
1359 R"(HloModule TupleReduce, entry_computation_layout={(f32[1024]{0},s32[1024]{0})->(f32[], s32[])}
1360 
1361 max_argmax {
1362   value = f32[] parameter(2)
1363   prev_max = f32[] parameter(0)
1364   is_next_larger = pred[] compare(value, prev_max), direction=GE
1365   max = f32[] select(is_next_larger, value, prev_max)
1366   index = s32[] parameter(3)
1367   prev_argmax = s32[] parameter(1)
1368   argmax = s32[] select(is_next_larger, index, prev_argmax)
1369   ROOT pair = (f32[], s32[]) tuple(max, argmax)
1370 }
1371 
1372 ENTRY reduce_entry {
1373   values = f32[1024]{0} parameter(0)
1374   indices = s32[1024]{0} parameter(1)
1375   init_value = f32[] constant(-inf)
1376   init_index = s32[] constant(-1)
1377   ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
1378 }
1379 
1380 )"
1381 },
1382 // infeed/outfeed
1383 {
1384 "InfeedOutfeed",
1385 R"(HloModule outfeed_module, entry_computation_layout={()->((u32[3]{0}, pred[]), token[])}
1386 
1387 ENTRY InfeedToOutfeed {
1388   token0 = token[] after-all()
1389   infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
1390   infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
1391   outfeed = token[] outfeed(infeed.data, token0), outfeed_shape=(u32[3]{0}, pred[])
1392   ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
1393   infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
1394   infeed.1.token = token[] get-tuple-element(infeed.1), index=1
1395   outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token), outfeed_shape=(u32[3]{0}, pred[])
1396 }
1397 
1398 )"
1399 },
1400 // Rng
1401 {
1402 "Rng",
1403 R"(HloModule rng_module, entry_computation_layout={()->f32[8]{0}}
1404 
1405 ENTRY Rng {
1406   constant = f32[] constant(0)
1407   constant.1 = f32[] constant(1)
1408   ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform
1409 }
1410 
1411 )"
1412 },
1413 // Reduce precision
1414 {
1415 "ReducePrecision",
1416 R"(HloModule reduce_precision, entry_computation_layout={()->f32[1]{0}}
1417 
1418 ENTRY ReducePrecision {
1419   constant = f32[1]{0} constant({3.14159})
1420   ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10
1421 }
1422 
1423 )"
1424 },
1425 // Sort (Key)
1426 {
1427 "SortKey",
1428 R"(HloModule sort, entry_computation_layout={(f32[1024]{0})->f32[1024]{0}}
1429 
1430 compare {
1431   p.0.lhs = f32[] parameter(0)
1432   p.0.rhs = f32[] parameter(1)
1433   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1434 }
1435 
1436 ENTRY Sort {
1437   x = f32[1024]{0} parameter(0)
1438   ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare
1439 }
1440 
1441 )"
1442 },
1443 // Sort (Key, Value)
1444 {
1445 "SortKeyValue",
1446 R"(HloModule sort, entry_computation_layout={(f32[1024]{0},s32[1024]{0})->(f32[1024]{0}, s32[1024]{0})}
1447 
1448 compare {
1449   p.1.lhs = s32[] parameter(2)
1450   p.1.rhs = s32[] parameter(3)
1451   p.0.lhs = f32[] parameter(0)
1452   p.0.rhs = f32[] parameter(1)
1453   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1454 }
1455 
1456 ENTRY Sort {
1457   keys = f32[1024]{0} parameter(0)
1458   values = s32[1024]{0} parameter(1)
1459   ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
1460 }
1461 
1462 )"
1463 },
1464 // R2 Sort (Key)
1465 {
1466 "SortKeyR2",
1467 R"(HloModule sort, entry_computation_layout={(f32[1024,16]{0,1})->f32[1024,16]{0,1}}
1468 
1469 compare {
1470   p.0.lhs = f32[] parameter(0)
1471   p.0.rhs = f32[] parameter(1)
1472   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1473 }
1474 
1475 ENTRY Sort {
1476   x = f32[1024,16]{0,1} parameter(0)
1477   ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare
1478 }
1479 
1480 )"
1481 },
1482 // R2 Sort (Key, Value)
1483 {
1484 "SortKeyValueR2",
1485 R"(HloModule sort, entry_computation_layout={(f32[1024,16]{0,1},s32[1024,16]{0,1})->(f32[1024,16]{0,1}, s32[1024,16]{0,1})}
1486 
1487 compare {
1488   p.1.lhs = s32[] parameter(2)
1489   p.1.rhs = s32[] parameter(3)
1490   p.0.lhs = f32[] parameter(0)
1491   p.0.rhs = f32[] parameter(1)
1492   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1493 }
1494 
1495 ENTRY Sort {
1496   keys = f32[1024,16]{0,1} parameter(0)
1497   values = s32[1024,16]{0,1} parameter(1)
1498   ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare
1499 }
1500 
1501 )"
1502 },
1503 // Sort (Key, Value, Value, Value)
1504 {
1505 "SortManyValues",
1506 R"(HloModule sort, entry_computation_layout={(f32[1024,16]{0,1},s32[1024,16]{0,1},u32[1024,16]{0,1},f32[1024,16]{0,1})->(f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1})}
1507 
1508 compare {
1509   p.1.lhs = s32[] parameter(2)
1510   p.1.rhs = s32[] parameter(3)
1511   p.2.lhs = u32[] parameter(4)
1512   p.2.rhs = u32[] parameter(5)
1513   p.3.lhs = f32[] parameter(6)
1514   p.3.rhs = f32[] parameter(7)
1515   p.0.lhs = f32[] parameter(0)
1516   p.0.rhs = f32[] parameter(1)
1517   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1518 }
1519 
1520 ENTRY Sort {
1521   keys = f32[1024,16]{0,1} parameter(0)
1522   values.0 = s32[1024,16]{0,1} parameter(1)
1523   values.1 = u32[1024,16]{0,1} parameter(2)
1524   values.2 = f32[1024,16]{0,1} parameter(3)
1525   ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
1526 }
1527 
1528 )"
1529 },
1530 // Sort (Key) is_stable=true
1531 {
1532 "SortKeyStable",
1533 R"(HloModule sort, entry_computation_layout={(f32[1024]{0})->f32[1024]{0}}
1534 
1535 compare {
1536   p.0.lhs = f32[] parameter(0)
1537   p.0.rhs = f32[] parameter(1)
1538   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1539 }
1540 
1541 ENTRY Sort {
1542   x = f32[1024]{0} parameter(0)
1543   ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
1544 }
1545 
1546 )"
1547 },
1548 // Indexed Conditional
1549 {
1550 "IndexedConditional",
1551 R"(HloModule indexed_conditional, entry_computation_layout={()->f32[]}
1552 
1553 Negate {
1554   x = f32[] parameter(0)
1555   ROOT negate = f32[] negate(x)
1556 }
1557 
1558 Identity {
1559   y = f32[] parameter(0)
1560   ROOT copy = f32[] copy(y)
1561 }
1562 
1563 Floor {
1564   z = f32[] parameter(0)
1565   ROOT floor = f32[] floor(z)
1566 }
1567 
1568 ENTRY Parameters1.v4 {
1569   constant = s32[] constant(1)
1570   constant.1 = f32[] constant(56)
1571   constant.2 = f32[] constant(12)
1572   constant.3 = f32[] constant(13)
1573   ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor}
1574 }
1575 
1576 )"
1577 },
1578 // Predicated Conditional
1579 {
1580 "PredicatedConditional",
1581 R"(HloModule pred_conditional, entry_computation_layout={()->f32[]}
1582 
1583 Negate {
1584   x = f32[] parameter(0)
1585   ROOT negate = f32[] negate(x)
1586 }
1587 
1588 Identity {
1589   y = f32[] parameter(0)
1590   ROOT copy = f32[] copy(y)
1591 }
1592 
1593 ENTRY Parameters1.v4 {
1594   constant = pred[] constant(true)
1595   constant.1 = f32[] constant(56)
1596   constant.2 = f32[] constant(12)
1597   ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity
1598 }
1599 
1600 )"
1601 },
1602 // CustomCall
1603 {
1604 "CustomCall",
1605 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1606 
1607 ENTRY CustomCall {
1608   constant = f32[1]{0} constant({12345})
1609   ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
1610 }
1611 
1612 )"
1613 },
1614 // CustomCall with single computation.
1615 {
1616 "CustumCallSingleComp",
1617 R"(HloModule custom_call_with_comp, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1618 
1619 max_F32 {
1620   lhs = f32[] parameter(0)
1621   rhs = f32[] parameter(1)
1622   ROOT maximum = f32[] maximum(lhs, rhs)
1623 }
1624 
1625 ENTRY CustomCall {
1626   constant = f32[1]{0} constant({12345})
1627   ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32}
1628 }
1629 
1630 )"
1631 },
1632 // CustomCall with multiple computations.
1633 {
1634 "CustumCallMultipleComps",
1635 R"(HloModule custom_call_with_comps, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1636 
1637 max_F32 {
1638   lhs = f32[] parameter(0)
1639   rhs = f32[] parameter(1)
1640   ROOT maximum = f32[] maximum(lhs, rhs)
1641 }
1642 
1643 ENTRY CustomCall {
1644   constant = f32[1]{0} constant({12345})
1645   ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32, max_F32}
1646 }
1647 
1648 )"
1649 },
1650 // Variables with non-default names
1651 {
1652 "NonDefaultNames",
1653 R"(HloModule add_constants_module, entry_computation_layout={()->f32[]}
1654 
1655 ENTRY add_constants {
1656   foo = f32[] constant(3.14)
1657   ROOT bar = f32[] add(foo, foo)
1658 }
1659 
1660 )"
1661 },
1662 {
1663 "Dot",
1664 R"(HloModule dot, entry_computation_layout={(f32[2,10]{1,0},f32[10,2]{1,0})->f32[2]{0}}
1665 
1666 ENTRY dot {
1667   a = f32[2,10]{1,0} parameter(0)
1668   b = f32[10,2]{1,0} parameter(1)
1669   ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
1670 }
1671 
1672 )"
1673 },
1674 {
1675 "gather",
1676 R"(HloModule gather, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}}
1677 
1678 ENTRY Gather {
1679   input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1680   start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1681   ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
1682 }
1683 
1684 )"
1685 },
1686 // all-reduce
1687 {
1688 "AllReduce",
1689 R"(HloModule CRS, entry_computation_layout={(f32[8]{0})->f32[8]{0}}
1690 
1691 add {
1692   lhs = f32[] parameter(0)
1693   rhs = f32[] parameter(1)
1694   ROOT add = f32[] add(lhs, rhs)
1695 }
1696 
1697 ENTRY CRS {
1698   input = f32[8]{0} parameter(0)
1699   ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add
1700 }
1701 
1702 )"
1703 },
1704 // all-reduce with subgroups
1705 {
1706 "AllReduceWithSubgroups",
1707 R"(HloModule CRS_Subgroups, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}
1708 
1709 add {
1710   lhs = f32[] parameter(0)
1711   rhs = f32[] parameter(1)
1712   ROOT add = f32[] add(lhs, rhs)
1713 }
1714 
1715 ENTRY AllReduceWithSubgroups {
1716   input = f32[128,32]{0,1} parameter(0)
1717   ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
1718 }
1719 
1720 )",
1721 /*replica_count=*/4,
1722 },
1723 // all-reduce with constrained layout
1724 {
1725 "AllReduceWithLayout",
1726 R"(HloModule CRS, entry_computation_layout={(f32[8]{0})->f32[8]{0}}
1727 
1728 add {
1729   lhs = f32[] parameter(0)
1730   rhs = f32[] parameter(1)
1731   ROOT add = f32[] add(lhs, rhs)
1732 }
1733 
1734 ENTRY CRS {
1735   input = f32[8]{0} parameter(0)
1736   ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, constrain_layout=true, to_apply=add
1737 }
1738 
1739 )"
1740 },
1741 // all-reduce with channel-id
1742 {
1743 "AllReduceAllReduce",
1744 R"(HloModule CRS, entry_computation_layout={(f32[8]{0})->f32[8]{0}}
1745 
1746 add {
1747   lhs = f32[] parameter(0)
1748   rhs = f32[] parameter(1)
1749   ROOT add = f32[] add(lhs, rhs)
1750 }
1751 
1752 ENTRY CRS {
1753   input = f32[8]{0} parameter(0)
1754   crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
1755   ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
1756 }
1757 
1758 )"
1759 },
1760 // all-reduce start and done
1761 {
1762 "AllReduceStartAndDone",
1763 R"(HloModule CRS, entry_computation_layout={(f32[8]{0})->f32[8]{0}}
1764 
1765 add {
1766   lhs = f32[] parameter(0)
1767   rhs = f32[] parameter(1)
1768   ROOT add = f32[] add(lhs, rhs)
1769 }
1770 
1771 ENTRY CRS {
1772   input = f32[8]{0} parameter(0)
1773   crs = f32[8]{0} all-reduce-start(input), replica_groups={}, to_apply=add
1774   ROOT done = f32[8]{0} all-reduce-done(crs)
1775 }
1776 
1777 )"
1778 },
1779 // reduce-scatter
1780 {
1781 "ReduceScatter",
1782 R"(HloModule RS, entry_computation_layout={(f32[8]{0})->f32[4]{0}}
1783 
1784 add {
1785   lhs = f32[] parameter(0)
1786   rhs = f32[] parameter(1)
1787   ROOT add = f32[] add(lhs, rhs)
1788 }
1789 
1790 ENTRY CRS {
1791   input = f32[8]{0} parameter(0)
1792   ROOT ars = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}}, dimensions={0}, to_apply=add
1793 }
1794 
1795 )"
1796 },
1797 // all-gather
1798 {
1799 "AllGather",
1800 R"(HloModule AllGather, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}
1801 
1802 ENTRY AllGather {
1803   input = f32[128,32]{0,1} parameter(0)
1804   ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, dimensions={1}
1805 }
1806 
1807 )"
1808 },
1809 // all-gather with constrained layout
1810 {
1811 "AllGatherWithLayout",
1812 R"(HloModule AllGather, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}
1813 
1814 ENTRY AllGather {
1815   input = f32[128,32]{0,1} parameter(0)
1816   ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, constrain_layout=true, dimensions={1}
1817 }
1818 
1819 )"
1820 },
1821 // all-gather with subgroups
1822 {
1823 "AllGatherWithSubgroups",
1824 R"(HloModule AllGatherWithSubgroups, entry_computation_layout={(f32[128,32]{0,1})->f32[128,64]{0,1}}
1825 
1826 ENTRY AllGatherWithSubgroups {
1827   input = f32[128,32]{0,1} parameter(0)
1828   ROOT ag = f32[128,64]{0,1} all-gather(input), replica_groups={{0,1},{2,3}}, dimensions={1}
1829 }
1830 
1831 )",
1832 /*replica_count=*/4,
1833 },
1834 // all-to-all
1835 {
1836 "AllToAll",
1837 R"(HloModule AllToAll, entry_computation_layout={(f32[128,32]{0,1})->(f32[128,32]{0,1})}
1838 
1839 ENTRY AllToAll {
1840   input = f32[128,32]{0,1} parameter(0)
1841   ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
1842 }
1843 
1844 )"
1845 },
1846 // all-to-all with subgroups
1847 {
1848 "AllToAllWithSubgroups",
1849 R"(HloModule AllToAllWithSubgroups, entry_computation_layout={(f32[128,32]{0,1},f32[128,32]{0,1})->(f32[128,32]{0,1}, f32[128,32]{0,1})}
1850 
1851 ENTRY AllToAllWithSubgroups {
1852   p0 = f32[128,32]{0,1} parameter(0)
1853   p1 = f32[128,32]{0,1} parameter(1)
1854   ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
1855 }
1856 
1857 )",
1858 /*replica_count=*/4,
1859 },
1860 // collective-permute
1861 {
1862 "CollectivePermute",
1863 R"(HloModule CollectivePermute, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}
1864 
1865 ENTRY CollectivePermute {
1866   input = f32[128,32]{0,1} parameter(0)
1867   ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
1868 }
1869 
1870 )",
1871 /*replica_count=*/4
1872 },
1873 // collective-permute with in-place updates
1874 {
1875 "CollectivePermuteInPlaceUpdate",
1876 R"(HloModule CollectivePermuteInPlaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}
1877 
1878 ENTRY CollectivePermuteInPlaceUpdate {
1879   input = f32[128,32]{0,1} parameter(0)
1880   constant = f32[] constant(1)
1881   output = f32[128,128]{0,1} broadcast(constant), dimensions={}
1882   constant.1 = s32[] constant(0)
1883   tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
1884   constant.2 = s32[] constant(64)
1885   tuple.2 = (s32[], s32[]) tuple(constant.1, constant.2)
1886   ROOT root = f32[128,128]{0,1} collective-permute(input, output, tuple.1, tuple.2), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{128,32}}
1887 }
1888 
1889 )",
1890 /*replica_count=*/4
1891 },
1892 // collective-permute with in-place updates with multiple targets per source
1893 {
1894 "CollectivePermuteInPlaceUpdateMultipleReadWrite",
1895 R"(HloModule CollectivePermuteInPlaceUpdateMultipleReadWrite, entry_computation_layout={(f32[8,8,128]{2,1,0})->f32[8,8,128]{2,1,0}}
1896 
1897 ENTRY CollectivePermuteInPlaceUpdate {
1898   constant.3 = s32[] constant(2)
1899   constant.1 = s32[] constant(0)
1900   output_offset.3 = (s32[], s32[], s32[]) tuple(constant.3, constant.1, constant.1)
1901   constant.4 = s32[] constant(3)
1902   output_offset.4 = (s32[], s32[], s32[]) tuple(constant.4, constant.1, constant.1)
1903   input = f32[8,8,128]{2,1,0} parameter(0)
1904   constant = f32[] constant(1)
1905   output = f32[8,8,128]{2,1,0} broadcast(constant), dimensions={}
1906   input_offset.1 = (s32[], s32[], s32[]) tuple(constant.1, constant.1, constant.1)
1907   constant.2 = s32[] constant(1)
1908   input_offset.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.1, constant.1)
1909   input_offset = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(input_offset.1, input_offset.2)
1910   output_offset = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(input_offset.1, input_offset.2)
1911   ROOT root = f32[8,8,128]{2,1,0} collective-permute(input, output, input_offset, output_offset), source_target_pairs={{0,1},{1,2},{2,3},{0,3},{2,1},{3,2}}, slice_sizes={{1,8,128},{1,8,128}}
1912 }
1913 
1914 )",
1915 /*replica_count=*/4
1916 },
1917 {
1918 "CollectivePermuteInPlaceUpdateTupleMultipleReadWrite",
1919 R"(HloModule hlo_runner_test_0.1, entry_computation_layout={()->(u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)})}
1920 
1921 ENTRY hlo_runner_test_0.1 {
1922   replica_id = u32[] replica-id()
1923   broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(replica_id), dimensions={}
1924   tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(broadcast.0, broadcast.0)
1925   constant.1 = u32[] constant(1000)
1926   broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(constant.1), dimensions={}
1927   broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(constant.1), dimensions={}
1928   tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(broadcast.1, broadcast.2)
1929   constant.2 = s32[] constant(0)
1930   tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.2, constant.2)
1931   constant.3 = s32[] constant(1)
1932   tuple.3 = (s32[], s32[], s32[]) tuple(constant.3, constant.2, constant.2)
1933   tuple.4 = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(tuple.2, tuple.3)
1934   tuple.7 = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(tuple.2, tuple.2)
1935   tuple.8 = (((s32[], s32[], s32[]), (s32[], s32[], s32[])), ((s32[], s32[], s32[]), (s32[], s32[], s32[]))) tuple(tuple.4, tuple.7)
1936   constant.4 = s32[] constant(2)
1937   tuple.5 = (s32[], s32[], s32[]) tuple(constant.4, constant.2, constant.2)
1938   tuple.6 = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(tuple.2, tuple.5)
1939   tuple.9 = (((s32[], s32[], s32[]), (s32[], s32[], s32[])), ((s32[], s32[], s32[]), (s32[], s32[], s32[]))) tuple(tuple.4, tuple.6)
1940   ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute(tuple.input, tuple.output, tuple.8, tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
1941 }
1942 
1943 )",
1944 /*replica_count=*/4
1945 },
1946 
1947 // collective-permute tuple with in-place updates
1948 {
1949 "CollectivePermuteTupleInPlaceUpdate",
1950 R"(HloModule CollectivePermuteTupleInPlaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->(f32[128,128]{0,1}, f32[128,128]{0,1})}
1951 
1952 ENTRY CollectivePermuteInPlaceUpdate {
1953   input = f32[128,32]{0,1} parameter(0)
1954   tuple.input = (f32[128,32]{0,1}, f32[128,32]{0,1}) tuple(input, input)
1955   constant = f32[] constant(1)
1956   output = f32[128,128]{0,1} broadcast(constant), dimensions={}
1957   tuple.output = (f32[128,128]{0,1}, f32[128,128]{0,1}) tuple(output, output)
1958   constant.1 = s32[] constant(0)
1959   tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
1960   constant.2 = s32[] constant(64)
1961   tuple.2 = (s32[], s32[]) tuple(constant.2, constant.1)
1962   tuple.3 = ((s32[], s32[]), (s32[], s32[])) tuple(tuple.1, tuple.2)
1963   tuple.4 = (s32[], s32[]) tuple(constant.1, constant.1)
1964   tuple.5 = (s32[], s32[]) tuple(constant.2, constant.2)
1965   tuple.6 = ((s32[], s32[]), (s32[], s32[])) tuple(tuple.4, tuple.5)
1966   ROOT root = (f32[128,128]{0,1}, f32[128,128]{0,1}) collective-permute(tuple.input, tuple.output, tuple.3, tuple.6), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{64,32},{64,32}}
1967 }
1968 
1969 )",
1970 /*replica_count=*/4
1971 },
// collective-permute-start and -done
1973 {
1974 "CollectivePermuteStartAndDone",
1975 R"(HloModule CollectivePermuteStartAndDone, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}
1976 
1977 ENTRY CollectivePermuteStartAndDone {
1978   input = f32[128,32]{0,1} parameter(0)
1979   collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}}
1980   ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1)
1981 }
1982 
1983 )",
1984 /*replica_count=*/4
1985 },
// collective-permute-start and -done with in-place update
1987 {
1988 "CollectivePermuteStartAndDoneInplaceUpdate",
1989 R"(HloModule CollectivePermuteStartAndDoneInplaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}
1990 
1991 ENTRY CollectivePermuteStartAndDoneInplaceUpdate {
1992   input = f32[128,32]{0,1} parameter(0)
1993   constant = f32[] constant(1)
1994   output = f32[128,128]{0,1} broadcast(constant), dimensions={}
1995   constant.1 = s32[] constant(0)
1996   tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
1997   constant.2 = s32[] constant(64)
1998   tuple.2 = (s32[], s32[]) tuple(constant.1, constant.2)
1999   collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,128]{0,1}, u32[], u32[]) collective-permute-start(input, output, tuple.1, tuple.2), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{64,32}}
2000   ROOT collective-permute-done.1 = f32[128,128]{0,1} collective-permute-done(collective-permute-start.1)
2001 }
2002 
2003 )",
2004 /*replica_count=*/4
2005 },
2006 // replica-id
2007 {
2008 "ReplicaId",
2009 R"(HloModule replica-id, entry_computation_layout={()->u32[]}
2010 
2011 ENTRY Replica-id {
2012   ROOT replica-id = u32[] replica-id()
2013 }
2014 
2015 )"
2016 },
2017 // partition-id
2018 {
2019 "PartitionId",
2020 R"(HloModule partition-id, entry_computation_layout={()->u32[]}
2021 
2022 ENTRY PartitionId {
2023   ROOT id = u32[] partition-id()
2024 }
2025 
2026 )"
2027 },
2028 // Iota
2029 {
2030 "Iota",
2031 R"(HloModule iota, entry_computation_layout={()->f32[100]{0}}
2032 
2033 ENTRY Iota {
2034   ROOT iota = f32[100]{0} iota(), iota_dimension=0
2035 }
2036 
2037 )"
2038 },
2039 // custom-call with window, dim_labels and feature_group_count
2040 {
2041 "CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
2042 R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount, entry_computation_layout={()->f32[100]{0}}
2043 
2044 ENTRY Computation {
2045   ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
2046 }
2047 
2048 )"
2049     },
2050 // custom-call with unknown dim labels.
2051 {
2052 "CustomCallWithUnknownDimLabels",
2053 R"(HloModule CustomCallWithUnknownDimLabels, entry_computation_layout={()->f32[100]{0}}
2054 
2055 ENTRY Computation {
2056   ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=?b01f_0?1io->b01?f, custom_call_target="target"
2057 }
2058 
2059 )"
2060     },
2061 // is_scheduled=true attribute
2062 {
2063 "ScheduledModule",
2064 R"(HloModule scheduled_module, is_scheduled=true, entry_computation_layout={(f32[1024]{0},s32[1024]{0})->(f32[1024]{0}, s32[1024]{0})}
2065 
2066 compare {
2067   p.1.lhs = s32[] parameter(2)
2068   p.1.rhs = s32[] parameter(3)
2069   p.0.lhs = f32[] parameter(0)
2070   p.0.rhs = f32[] parameter(1)
2071   ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
2072 }
2073 
2074 ENTRY Sort {
2075   keys = f32[1024]{0} parameter(0)
2076   values = s32[1024]{0} parameter(1)
2077   ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
2078 }
2079 
2080 )"
2081     },
2082 // AfterAll with multiple operands
2083 {
2084 "AfterAllWithMultipleOperands",
2085 R"(HloModule AfterAllWithMultipleOperands, entry_computation_layout={(f32[])->token[]}
2086 
2087 ENTRY AfterAllWithMultipleOperands {
2088   p0 = f32[] parameter(0)
2089   token0 = token[] after-all()
2090   token1 = token[] after-all()
2091   ROOT after-all = token[] after-all(p0, token0, token1)
2092 }
2093 
2094 )"
2095 },
2096 // AddDependency
2097 // A dependency chain is created from 'neg' to 'exp' using tokens.
2098 {
2099 "AddDependency",
2100 R"(HloModule AddDependency, entry_computation_layout={(f32[])->f32[]}
2101 
2102 ENTRY AddDependency {
2103   p = f32[] parameter(0)
2104   neg = f32[] negate(p)
2105   token0 = token[] after-all(neg)
2106   p_after_token = f32[] add-dependency(p, token0)
2107   exp = f32[] exponential(p_after_token)
2108   ROOT sum = f32[] add(neg, exp)
2109 }
2110 
2111 )"
2112 },
2113 // A module containing constants equal to the min/max values of various data
2114 // types.
2115 {
2116 "MinMaxValues",
2117 R"(HloModule MinMaxValues, entry_computation_layout={()->c128[2]{0}}
2118 
2119 ENTRY MinMaxValues {
2120   x.s8 = s8[2]{0} constant({-128, 127})
2121   x.s16 = s16[2]{0} constant({-32768, 32767})
2122   x.s32 = s32[2]{0} constant({-2147483648, 2147483647})
2123   x.u8 = u8[2]{0} constant({0, 255})
2124   x.u16 = u16[2]{0} constant({0, 65535})
2125   x.u32 = u32[2]{0} constant({0, 4294967295})
2126   x.f16 = f16[2]{0} constant({-65504, 65504})
2127   x.bf16 = bf16[2]{0} constant({-3.39e+38, 3.39e+38})
2128   x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38})
2129   x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308})
2130   x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)})
2131   ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)})
2132 }
2133 
2134 )"
2135 },
2136 
2137 // Bitcast-convert usage
2138 {
2139 "BitcastConvert",
2140 R"(HloModule BitcastConvert, entry_computation_layout={(f32[100]{0})->u32[100]{0}}
2141 
2142 ENTRY BitcastConvertUsage {
2143   p = f32[100]{0} parameter(0)
2144   ROOT out = u32[100]{0} bitcast-convert(p)
2145 }
2146 
2147 )"
2148 },
2149 {
2150 "OuterDimensionPartitions",
2151 R"(HloModule OuterDimensionPartitions, entry_computation_layout={(f32[100]{0})->f32[100]{0}}
2152 
2153 ENTRY Test {
2154   ROOT foo = f32[100]{0} parameter(0), outer_dimension_partitions={0,10,20}
2155 }
2156 
2157 )"
2158 },
2159 });
2160   // clang-format on
2161 }
2162 
2163 std::vector<NonRoundtripTestData> CreateNonRoundtripTestCases() {
2164   // clang-format off
2165 return std::vector<NonRoundtripTestData>({
2166 {
2167 "SimpleNesting",
2168 R"(HloModule test
2169 
2170 ENTRY test {
2171     ROOT root = add(f32[10] parameter(0), multiply(f32[10] parameter(1), f32[10] parameter(2)))
2172 })",
2173 R"(HloModule test, entry_computation_layout={(f32[10]{0},f32[10]{0},f32[10]{0})->f32[10]{0}}
2174 
2175 ENTRY test {
2176   parameter.anon = f32[10]{0} parameter(0)
2177   parameter.anon.1 = f32[10]{0} parameter(1)
2178   parameter.anon.2 = f32[10]{0} parameter(2)
2179   multiply.anon = f32[10]{0} multiply(parameter.anon.1, parameter.anon.2)
2180   ROOT root = f32[10]{0} add(parameter.anon, multiply.anon)
2181 })"
2182 },
2183 
2184 {
2185 "AmbiguousNames",
2186 R"(HloModule test
2187 ENTRY test {
2188   add = add(f32[10] parameter(0), f32[10] parameter(1))
2189   ROOT add2 = add(add, add(add, add))
2190 })",
2191 R"(HloModule test, entry_computation_layout={(f32[10]{0},f32[10]{0})->f32[10]{0}}
2192 
2193 ENTRY test {
2194   parameter.anon = f32[10]{0} parameter(0)
2195   parameter.anon.1 = f32[10]{0} parameter(1)
2196   add = f32[10]{0} add(parameter.anon, parameter.anon.1)
2197   add.anon = f32[10]{0} add(add, add)
2198   ROOT add2 = f32[10]{0} add(add, add.anon)
2199 })"
2200 },
2201 
2202 {
2203 "TupleShapeInsideAnonymousInstr",
2204 R"(HloModule test
2205 
2206 ENTRY test {
2207   ROOT root = get-tuple-element(
2208     (f32[10], f16[10]) tuple(f32[10] parameter(0), f16[10] parameter(1))
2209   ), index=0
2210 })",
2211 R"(HloModule test, entry_computation_layout={(f32[10]{0},f16[10]{0})->f32[10]{0}}
2212 
2213 ENTRY test {
2214   parameter.anon = f32[10]{0} parameter(0)
2215   parameter.anon.1 = f16[10]{0} parameter(1)
2216   tuple.anon = (f32[10]{0}, f16[10]{0}) tuple(parameter.anon, parameter.anon.1)
2217   ROOT root = f32[10]{0} get-tuple-element(tuple.anon), index=0
2218 })"
2219 },
2220 
2221 {
2222 "MixAnonAndNonAnonOperands",
2223 R"(HloModule test
2224 
2225 ENTRY test {
2226   add = add(f32[10] parameter(0), f32[10] parameter(1))
2227   ROOT root = tuple(add, add(add, add), add)
2228 })",
2229 R"(HloModule test, entry_computation_layout={(f32[10]{0},f32[10]{0})->(f32[10]{0}, f32[10]{0}, f32[10]{0})}
2230 
2231 ENTRY test {
2232   parameter.anon = f32[10]{0} parameter(0)
2233   parameter.anon.1 = f32[10]{0} parameter(1)
2234   add = f32[10]{0} add(parameter.anon, parameter.anon.1)
2235   add.anon = f32[10]{0} add(add, add)
2236   ROOT root = (f32[10]{0}, f32[10]{0}, f32[10]{0}) tuple(add, add.anon, add)
2237 })"
2238 },
2239 
2240 {
2241 "BroadcastOfScalarDoesntNeedDimensionsAttr",
2242 R"(HloModule test
2243 
2244 ENTRY test {
2245   ROOT root = sqrt(f32[10,10] broadcast(f32[] parameter(0)))
2246 })",
2247 R"(HloModule test, entry_computation_layout={(f32[])->f32[10,10]{1,0}}
2248 
2249 ENTRY test {
2250   parameter.anon = f32[] parameter(0)
2251   broadcast.anon = f32[10,10]{1,0} broadcast(parameter.anon), dimensions={}
2252   ROOT root = f32[10,10]{1,0} sqrt(broadcast.anon)
2253 })"
2254 },
2255 
2256 {
2257 "SparseShape",
2258 R"(HloModule test
2259 
2260 ENTRY test {
2261   ROOT root = f32[10,10]{1,0:D(D,C)} parameter(0)
2262 })",
2263 R"(HloModule test, entry_computation_layout={(f32[10,10]{1,0:D(D,C)})->f32[10,10]{1,0:D(D,C)}}
2264 
2265 ENTRY test {
2266   ROOT root = f32[10,10]{1,0:D(D,C)} parameter(0)
2267 })",
2268 }
2269 });
2270   // clang-format on
2271 }
2272 
2273 // The test class for those tests defined above which round-trip through the
2274 // parser and ToString is templatized on two bool parameters:
2275 //
2276 //  short_form : used for the "short" test cases which use the ShortParsable
2277 //    output form.
2278 //  proto_round_trip : whether the module should also be round-tripped through
2279 //    HloProto form. This provides much better coverage for the proto
2280 //    serialization/deserialization.
2281 //
2282 // The proto_round_trip=true case also technically covers the Parser->ToString
2283 // roundtrip as well, but separating out the Parser->ToString roundtrip as its
2284 // own test provides better isolation and could conceivably catch weirdo bugs
2285 // which are hidden by interaction between the textual and proto roundtripping.
2286 template <bool short_form, bool proto_round_trip>
2287 class HloParameterizedParserTest
2288     : public ::testing::Test,
2289       public ::testing::WithParamInterface<TestData> {
2290  protected:
2291   // Expects "ToString(ParseHloModule(std::string)) == string", that is, parses
2292   // the string, asserts that it succeeded, stringifies the parsed module, and
2293   // checks that it equals the original string.
2294   void ExpectEqual() {
2295     std::unique_ptr<HloModule> module;
2296     const std::string& original = GetParam().module_string;
2297     HloModuleConfig config;
2298     config.set_replica_count(GetParam().replica_count);
2299     if (GetParam().enable_verification) {
2300       auto verified_module = std::make_unique<VerifiedHloModule>(
2301           GetParam().test_name, config,
2302           /*verifier_layout_sensitive=*/false,
2303           /*allow_mixed_precision_in_hlo_verifier=*/true,
2304           ShapeUtil::ByteSizeOfElements);
2305       TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
2306       module = std::move(verified_module);
2307     } else {
2308       TF_ASSERT_OK_AND_ASSIGN(module,
2309                               ParseAndReturnUnverifiedModule(original, config));
2310     }
2311     if (proto_round_trip) {
2312       TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
2313                                           module->ToProto(), module->config()));
2314     }
2315     if (short_form) {
2316       EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
2317     } else {
2318       EXPECT_EQ(
2319           original,
2320           module->ToString(HloPrintOptions().set_print_large_constants(true)));
2321     }
2322   }
2323 };
2324 
2325 // These using shenanigans are required because the TEST_P macro doesn't like
2326 // template instantiations which contain commas.
2327 using HloParserTestLong = HloParameterizedParserTest<false, false>;
2328 using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
2329 using HloParserTestShort = HloParameterizedParserTest<true, false>;
2330 using HloParserTestShortProto = HloParameterizedParserTest<true, true>;
2331 
2332 TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
2333 TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
2334 TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
2335 TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }
2336 
2337 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
2338                          ::testing::ValuesIn(CreateTestCases()),
2339                          TestDataToString);
2340 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
2341                          HloParserTestLongProto,
2342                          ::testing::ValuesIn(CreateTestCases()),
2343                          TestDataToString);
2344 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
2345                          ::testing::ValuesIn(CreateShortTestCases()),
2346                          TestDataToString);
2347 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
2348                          HloParserTestShortProto,
2349                          ::testing::ValuesIn(CreateShortTestCases()),
2350                          TestDataToString);
2351 
2352 class HloNonRoundtripParserTest
2353     : public ::testing::TestWithParam<NonRoundtripTestData> {};
2354 TEST_P(HloNonRoundtripParserTest, Run) {
2355   auto module = std::make_unique<VerifiedHloModule>(
2356       GetParam().test_name, HloModuleConfig{},
2357       /*verifier_layout_sensitive=*/false,
2358       /*allow_mixed_precision_in_hlo_verifier=*/true,
2359       ShapeUtil::ByteSizeOfElements);
2360   TF_ASSERT_OK(
2361       module->ParseHloStringAndVerifyModule(GetParam().input_module_string));
2362   EXPECT_EQ(absl::StripAsciiWhitespace(GetParam().output_module_string),
2363             absl::StripAsciiWhitespace(
2364                 module->ToString(HloPrintOptions::ShortParsable())));
2365 }
2366 
2367 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
2368                          HloNonRoundtripParserTest,
2369                          ::testing::ValuesIn(CreateNonRoundtripTestCases()),
2370                          NonRoundtripTestDataToString);
2371 
2372 class HloParserTest : public ::testing::Test {
2373  protected:
2374   static void ExpectHasSubstr(string_view s, string_view expected) {
2375     EXPECT_TRUE(absl::StrContains(s, expected))
2376         << "'" << s << "' does not contain '" << expected << "'";
2377   }
2378   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
2379       absl::string_view hlo_text) {
2380     auto module = std::make_unique<VerifiedHloModule>(
2381         ::testing::UnitTest::GetInstance()->current_test_info()->name(),
2382         HloModuleConfig(),
2383         /*verifier_layout_sensitive=*/false,
2384         /*allow_mixed_precision_in_hlo_verifier=*/true,
2385         ShapeUtil::ByteSizeOfElements);
2386     TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
2387     return std::move(module);
2388   }
2389 };
2390 
2391 TEST_F(HloParserTest, Empty) {
2392   const std::string original = "";
2393   auto result = ParseAndReturnUnverifiedModule(original);
2394   EXPECT_NE(OkStatus(), result.status());
2395 }
2396 
2397 TEST_F(HloParserTest, Garbage) {
2398   const std::string original =
2399       "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
2400   auto result = ParseAndReturnUnverifiedModule(original);
2401   EXPECT_NE(OkStatus(), result.status());
2402 }
2403 
2404 TEST_F(HloParserTest, WrongOpcode) {
2405   const std::string original = R"(HloModule wrong_opcode:
2406 
2407 ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
2408   %x = f32[]{} parameter(0)
2409   %y = f32[]{} parameter(1)
2410   %le = pred[]{} le(f32[]{} %x, f32[]{} %y)
2411 }
2412 
2413 )";
2414   auto result = ParseAndReturnUnverifiedModule(original);
2415   EXPECT_NE(OkStatus(), result.status());
2416 }
2417 
2418 TEST_F(HloParserTest, MetadataWithCholesky) {
2419   const std::string original = R"(HloModule metadata_with_cholesky
2420 ENTRY %blabla (a: f32[1,291,291]) -> f32[1,291,291] {
2421   %a = f32[1,291,291] parameter(0)
2422   %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true, metadata={op_type="Cholesky" op_name="Cholesky" profile_type={1}}
2423 }
2424 )";
2425   auto result = ParseAndReturnVerifiedModule(original);
2426   EXPECT_EQ(OkStatus(), result.status());
2427   EXPECT_EQ("Cholesky", result.ValueOrDie()
2428                             ->entry_computation()
2429                             ->root_instruction()
2430                             ->metadata()
2431                             .op_name());
2432   EXPECT_EQ("Cholesky", result.ValueOrDie()
2433                             ->entry_computation()
2434                             ->root_instruction()
2435                             ->metadata()
2436                             .op_type());
2437   EXPECT_EQ(WINDOW, *result.ValueOrDie()
2438                          ->entry_computation()
2439                          ->root_instruction()
2440                          ->metadata()
2441                          .profile_type()
2442                          .begin());
2443 }
2444 
2445 TEST_F(HloParserTest, WrongShape) {
2446   const std::string original = R"(HloModule wrong_opcode:
2447 
2448 ENTRY %blabla (x: g32[]) -> g32[] {
2449   %x = g32[]{} parameter(0)
2450 }
2451 
2452 )";
2453   auto result = ParseAndReturnUnverifiedModule(original);
2454   EXPECT_NE(OkStatus(), result.status());
2455 }
2456 
2457 TEST_F(HloParserTest, WrongOperandsSize) {
2458   const std::string original = R"(HloModule wrong_opcode:
2459 
2460 ENTRY %blabla (x: f32[]) -> pred[] {
2461   %x = f32[]{} parameter(0)
2462   %eq = pred[]{} compare(f32[]{} %x), direction=EQ
2463 }
2464 
2465 )";
2466   auto result = ParseAndReturnUnverifiedModule(original);
2467   EXPECT_NE(OkStatus(), result.status());
2468 }
2469 
// Parsing must fail when an instruction references an operand (`%y`) that is
// never defined in the module text. Only a non-OK status is required.
2470 TEST_F(HloParserTest, OperandNotFound) {
2471   const std::string original = R"(HloModule operand_not_found:
2472 ENTRY %blabla (x: f32[]) -> pred[] {
2473   %x = f32[]{} parameter(0)
2474   %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
2475 }
2476 )";
2477   auto result = ParseAndReturnUnverifiedModule(original);
2478   EXPECT_NE(OkStatus(), result.status());
2479 }
2480 
// Parses and verifies a module made of three scalar constants (one carrying a
// sharding attribute) feeding a `select`. Only the returned status is checked;
// the parsed module is not compared against the input text.
2481 TEST_F(HloParserTest, MoreConstants) {
2482   const std::string original = R"(HloModule SelectScalarS32True_module
2483 
2484 ENTRY %SelectScalarS32True.v4 () -> s32[] {
2485   %constant.2 = pred[] constant(true)
2486   %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4}
2487   %constant = s32[] constant(42)
2488   %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
2489 }
2490 
2491 )";
2492   auto result = ParseAndReturnVerifiedModule(original);
2493   TF_EXPECT_OK(result.status());
2494   // Constant instructions have no name. The string will be parsed successfully
2495   // but the constant names will not be exactly the same.
2496 }
2497 
// Checks that a `backend_config="foo bar"` attribute on the root instruction
// is returned unchanged by raw_backend_config_string() after parsing.
2498 TEST_F(HloParserTest, ConfigurationField) {
2499   const std::string original = R"(HloModule AModule
2500 ENTRY %configuration_test() -> s32[] {
2501   %constant = s32[] constant(42), backend_config="foo bar"
2502 })";
2503   auto result = ParseAndReturnVerifiedModule(original);
  // Fatal assertion: the ValueOrDie() call below needs a successful parse.
2504   TF_ASSERT_OK(result.status());
2505   EXPECT_EQ("foo bar", result.ValueOrDie()
2506                            ->entry_computation()
2507                            ->root_instruction()
2508                            ->raw_backend_config_string());
2509 }
2510 
// A rank-1 `f32[2]` constant whose literal contains a nested `{2}` must be
// rejected, and the error message must contain the "expects nested array in
// rank 1, but sees larger" diagnostic.
2511 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
2512   const std::string original = R"(HloModule some_2_module
2513 
2514 ENTRY %some_2 () -> f32[2] {
2515   ROOT %constant = f32[2]{0} constant({1,{2}})
2516 }
2517 
2518 )";
2519   auto result = ParseAndReturnUnverifiedModule(original);
2520   EXPECT_NE(OkStatus(), result.status());
2521   ExpectHasSubstr(result.status().error_message(),
2522                   "expects nested array in rank 1, but sees larger");
2523 }
2524 
// A rank-2 `f32[2,3]` constant written as a flat, un-nested list of six values
// must be rejected with the "expects nested array in rank 2, but sees 1"
// diagnostic.
2525 TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
2526   const std::string original = R"(HloModule some_2x3_module
2527 
2528 ENTRY %some_2x3 () -> f32[2,3] {
2529   ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6})
2530 }
2531 
2532 )";
2533   auto result = ParseAndReturnUnverifiedModule(original);
2534   EXPECT_NE(OkStatus(), result.status());
2535   ExpectHasSubstr(result.status().error_message(),
2536                   "expects nested array in rank 2, but sees 1");
2537 }
2538 
// An `f32[2,3,2]` constant whose literal has the right nesting depth but six
// pairs inside a single outer element must be rejected with the "expects 3
// elements in the [0]th element" diagnostic.
2539 TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
2540   const std::string original = R"(HloModule some_2x3x2_module
2541 
2542 ENTRY %some_2x3x2 () -> f32[2,3,2] {
2543   ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
2544 }
2545 
2546 )";
2547   auto result = ParseAndReturnUnverifiedModule(original);
2548   EXPECT_NE(OkStatus(), result.status());
2549   ExpectHasSubstr(result.status().error_message(),
2550                   "expects 3 elements in the [0]th element");
2551 }
2552 
// The scalar f16 constant -65520 must be rejected, and the error message must
// report that the value is out of range for primitive type F16.
2553 TEST_F(HloParserTest, ConstantF16Overflow) {
2554   const std::string original =
2555       R"(HloModule ConstantF16Overflow_module
2556 
2557 ENTRY %ConstantF16Overflow.v4 () -> f16[] {
2558   ROOT %constant = f16[] constant(-65520)
2559 }
2560 
2561 )";
2562   auto result = ParseAndReturnUnverifiedModule(original);
2563   EXPECT_NE(OkStatus(), result.status());
2564   ExpectHasSubstr(result.status().error_message(),
2565                   "is out of range for literal's primitive type F16");
2566 }
2567 
// Counterpart to the f16 overflow case: a bf16 scalar constant of similar
// magnitude must parse and verify successfully.
2568 TEST_F(HloParserTest, ConstantBf16NoOverflow) {
2569   // -65505 is in range for bf16.
2570   const std::string original = R"(
2571   HloModule test_module
2572   ENTRY test {
2573     ROOT c = bf16[] constant(-65505)
2574   })";
2575   EXPECT_EQ(OkStatus(), ParseAndReturnVerifiedModule(original).status());
2576 }
2577 
// A bf16 scalar constant of 1e100 must produce an error whose message contains
// "out of range".
2578 TEST_F(HloParserTest, ConstantBf16Overflow) {
2579   // 1e100 is out of range for bf16.
2580   const std::string original = R"(
2581   HloModule test_module
2582   ENTRY test {
2583     ROOT c = bf16[] constant(1e100)
2584   })";
2585   ExpectHasSubstr(
2586       ParseAndReturnUnverifiedModule(original).status().error_message(),
2587       "out of range");
2588 }
2589 
// Pins the current behavior for a negative literal in an unsigned type:
// `u64[] constant(-1)` is expected to parse with an OK status.
// NOTE(review): despite the "Underflow" name, this test asserts success, not
// an error -- confirm that accepting -1 for u64 is intended.
2590 TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
2591   const std::string original = R"(
2592       HloModule ConstantUnsignedUnderflow_module
2593       ENTRY %ConstantUnsignedUnderflow () -> u64[] {
2594         ROOT %constant = u64[] constant(-1)
2595       })";
2596   auto result = ParseAndReturnUnverifiedModule(original);
2597   EXPECT_EQ(OkStatus(), result.status());
2598 }
2599 
// A u32 scalar constant of 4294967296 (2^32, one past the largest u32 value)
// must be rejected with an out-of-range error naming primitive type U32.
2600 TEST_F(HloParserTest, ConstantUnsignedOverflow) {
2601   const std::string original = R"(
2602       HloModule ConstantUnsignedOverflow_module
2603       ENTRY %ConstantUnsignedOverflow () -> u32[] {
2604         ROOT %constant = u32[] constant(4294967296)
2605       })";
2606   auto result = ParseAndReturnUnverifiedModule(original);
2607   EXPECT_NE(OkStatus(), result.status());
2608   ExpectHasSubstr(result.status().error_message(),
2609                   "is out of range for literal's primitive type U32");
2610 }
2611 
// A u64 scalar constant of 9223372036854775808 (2^63, one past the largest
// signed 64-bit value) is expected to parse with an OK status.
2612 TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
2613   const std::string original = R"(
2614       HloModule ConstantUnsignedOverflow_module
2615       ENTRY %ConstantUnsignedOverflow () -> u64[] {
2616         ROOT %constant = u64[] constant(9223372036854775808)
2617       })";
2618   auto result = ParseAndReturnUnverifiedModule(original);
2619   EXPECT_EQ(OkStatus(), result.status());
2620 }
2621 
// A c64 constant whose real part is 1e100 must fail to parse. Only a non-OK
// status is required; the error message itself is not inspected.
2622 TEST_F(HloParserTest, ConstantC64Overflow) {
2623   const std::string original = R"(
2624       HloModule test_module
2625       ENTRY test () -> c64[] {
2626         ROOT c = c64[] constant((1e100, 0))
2627       })";
2628   auto result = ParseAndReturnUnverifiedModule(original);
2629   EXPECT_NE(OkStatus(), result.status());
2630 }
2631 
// A c64 constant whose imaginary part is -1e100 must fail to parse. Only a
// non-OK status is required; the error message itself is not inspected.
2632 TEST_F(HloParserTest, ConstantC64Underflow) {
2633   const std::string original = R"(
2634       HloModule test_module
2635       ENTRY test () -> c64[] {
2636         ROOT c = c64[] constant((0, -1e100))
2637       })";
2638   auto result = ParseAndReturnUnverifiedModule(original);
2639   EXPECT_NE(OkStatus(), result.status());
2640 }
2641 
// An f64 scalar constant of 1.8e308 must fail to parse. Only a non-OK status
// is required; the error message itself is not inspected.
2642 TEST_F(HloParserTest, ConstantF64Overflow) {
2643   const std::string original = R"(
2644       HloModule test_module
2645       ENTRY test {
2646         ROOT c = f64[] constant(1.8e308)
2647       })";
2648   auto result = ParseAndReturnUnverifiedModule(original);
2649   EXPECT_NE(OkStatus(), result.status());
2650 }
2651 
// An f64 scalar constant of -1.8e308 must fail to parse. Only a non-OK status
// is required; the error message itself is not inspected.
2652 TEST_F(HloParserTest, ConstantF64Underflow) {
2653   const std::string original = R"(
2654       HloModule test_module
2655       ENTRY test {
2656         ROOT c = f64[] constant(-1.8e308)
2657       })";
2658   auto result = ParseAndReturnUnverifiedModule(original);
2659   EXPECT_NE(OkStatus(), result.status());
2660 }
2661 
// A float constant written in exponent notation (`3e+2`) must parse and verify
// successfully. Only the status is checked, for the reason given at the end of
// the test body.
2662 TEST_F(HloParserTest, ConstantWithExp) {
2663   const std::string original = R"(HloModule ConstantWithExp_module
2664 
2665 ENTRY %ConstantWithExp.v4 () -> f32[] {
2666   %constant.1 = f32[] constant(3e+2)
2667 }
2668 
2669 )";
2670   auto result = ParseAndReturnVerifiedModule(original);
2671   TF_EXPECT_OK(result.status());
2672   // The string will be parsed successfully but the output strings are not
2673   // exactly the same, because "3e2" is parsed into value 300 and will be
2674   // printed as "300".
2675 }
2676 
// Round-trips a module whose large constant is written in the elided `{...}`
// form: the text must parse and verify, and printing the module with default
// HloPrintOptions must reproduce the input exactly.
2677 TEST_F(HloParserTest, ShortConstant) {
2678   const std::string original =
2679       R"(HloModule ShortConstant_module, entry_computation_layout={()->f32[67,89]{1,0}}
2680 
2681 ENTRY %ShortConstant.v4 () -> f32[67,89] {
2682   ROOT %constant.1 = f32[67,89]{1,0} constant({...})
2683 }
2684 
2685 )";
2686   auto result = ParseAndReturnVerifiedModule(original);
  // Fatal assertion: ValueOrDie() below must not run on a failed parse.
2687   TF_ASSERT_OK(result.status());
2688   EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
2689 }
2690 
// Round-trips a bf16 constant containing `-nan` values: the text must parse,
// and printing the module with default HloPrintOptions must reproduce the
// input exactly.
2691 TEST_F(HloParserTest, NegativeNan) {
2692   const std::string original =
2693       R"(HloModule NegativeNan_module, entry_computation_layout={()->bf16[2]{0}}
2694 
2695 ENTRY %NegativeNan () -> bf16[2] {
2696   ROOT %constant = bf16[2]{0} constant({-nan, -nan})
2697 }
2698 
2699 )";
2700   auto result = ParseAndReturnUnverifiedModule(original);
  // Fatal assertion: ValueOrDie() below must not run on a failed parse.
2701   TF_ASSERT_OK(result.status());
2702   EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
2703 }
2704 
// Round-trips a bf16 constant whose NaNs carry explicit payloads
// (`-nan(0x7f)`, `-nan(0x3f)`): the text must parse, and printing the module
// with default HloPrintOptions must reproduce the input exactly.
2705 TEST_F(HloParserTest, NanPayload) {
2706   const std::string original =
2707       R"(HloModule NanPayload_module, entry_computation_layout={()->bf16[2]{0}}
2708 
2709 ENTRY %NanPayload () -> bf16[2] {
2710   ROOT %constant = bf16[2]{0} constant({-nan(0x7f), -nan(0x3f)})
2711 }
2712 
2713 )";
2714   auto result = ParseAndReturnUnverifiedModule(original);
  // Fatal assertion: ValueOrDie() below must not run on a failed parse.
2715   TF_ASSERT_OK(result.status());
2716   EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
2717 }
2718 
// A convolution whose attributes (feature_group_count, sharding,
// backend_config, dim_labels, window) are listed in this particular order must
// parse and verify successfully.
2719 TEST_F(HloParserTest, AttributesAnyOrder) {
2720   const std::string original = R"(HloModule any_order_module
2721 
2722 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,4,1] {
2723   %input = f32[1,2,1]{2,1,0} parameter(0)
2724   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2725   %filter = f32[1,1,1]{2,1,0} parameter(1)
2726   ROOT %convolution = f32[1,4,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=1}
2727 }
2728 
2729 )";
2730   TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2731 }
2732 
// Feeds three malformed `dim_labels` attributes to the same convolution and
// checks the diagnostic produced for each. `prefix` holds the module text up
// to and including the convolution's `window` attribute, `suffix` closes the
// computation, and each case splices its own `,dim_labels=...` in between.
2733 TEST_F(HloParserTest, InvalidDimLabels) {
2734   std::string prefix = R"(HloModule invalid_dim_labels_module
2735 
2736 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
2737   %input = f32[1,2,1]{2,1,0} parameter(0)
2738   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2739   %filter = f32[1,1,1]{2,1,0} parameter(1)
2740   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
2741   std::string suffix = R"(
2742 }
2743 
2744 )";
2745 
  // Case 1: `00_01->10` repeats a label within a group.
2746   ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2747                       absl::StrCat(prefix, ",dim_labels=00_01->10", suffix))
2748                       .status()
2749                       .error_message(),
2750                   "expects unique");
2751 
  // Case 2: `012_0123->210` has groups of different lengths.
2752   ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2753                       absl::StrCat(prefix, ",dim_labels=012_0123->210", suffix))
2754                       .status()
2755                       .error_message(),
2756                   "must have same number of spatial dimensions");
2757 
  // Case 3: `013_0123->210` uses a character outside the set quoted in the
  // expected message.
2758   ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2759                       absl::StrCat(prefix, ",dim_labels=013_0123->210", suffix))
2760                       .status()
2761                       .error_message(),
2762                   "expects [0-2bf?]");
2763 }
2764 
// A `send` instruction carrying a `calls=%recv` attribute must be rejected
// with an error message that names the unexpected attribute "calls".
2765 TEST_F(HloParserTest, UnexpectedAttribute) {
2766   const std::string original = R"(HloModule unexpected_attr_module
2767 
2768 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2769   %token0 = token[] after-all()
2770   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2771   %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2772   ROOT %constant = f32[] constant(2.1)
2773   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv
2774   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2775 }
2776 
2777 )";
2778   ExpectHasSubstr(
2779       ParseAndReturnUnverifiedModule(original).status().error_message(),
2780       "unexpected attribute \"calls\"");
2781 }
2782 
TEST_F(HloParserTest,MissingAttribute)2783 TEST_F(HloParserTest, MissingAttribute) {
2784   const std::string original = R"(HloModule missing_attr_module
2785 
2786 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2787   %token0 = token[] after-all()
2788   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2789   %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2790   ROOT %constant = f32[] constant(-2.1)
2791   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
2792   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2793 }
2794 
2795 )";
2796   ExpectHasSubstr(
2797       ParseAndReturnUnverifiedModule(original).status().error_message(),
2798       "attribute channel_id is expected but not seen");
2799 }
2800 
TEST_F(HloParserTest,PredecessorUndefined)2801 TEST_F(HloParserTest, PredecessorUndefined) {
2802   const std::string original = R"(HloModule pre_not_found_module
2803 
2804 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2805   %token0 = token[] after-all()
2806   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2807   %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2808   ROOT %constant = f32[] constant(2.1)
2809   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done}
2810   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2811 }
2812 
2813 )";
2814   ExpectHasSubstr(
2815       ParseAndReturnUnverifiedModule(original).status().error_message(),
2816       "'done' is not defined");
2817 }
2818 
TEST_F(HloParserTest,SliceAllowOmitStride1)2819 TEST_F(HloParserTest, SliceAllowOmitStride1) {
2820   const std::string original = R"(HloModule slice_module
2821 
2822 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
2823   %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
2824   ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
2825 }
2826 
2827 )";
2828   TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2829 }
2830 
TEST_F(HloParserTest,PaddingConfigIsNotWindowPad)2831 TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
2832   const std::string original = R"(HloModule window_pad_module
2833 
2834 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
2835   %input = f32[1,2,1]{2,1,0} parameter(0)
2836   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2837   %filter = f32[1,1,1]{2,1,0} parameter(1)
2838   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
2839 }
2840 
2841 )";
2842   ExpectHasSubstr(
2843       ParseAndReturnUnverifiedModule(original).status().error_message(),
2844       "expects padding_low and padding_high separated by '_'");
2845 }
2846 
TEST_F(HloParserTest,CommaBetweenSubAttributes)2847 TEST_F(HloParserTest, CommaBetweenSubAttributes) {
2848   const std::string original = R"(HloModule test_comma_module
2849 
2850 ENTRY %test_comma.v4 () -> f32[] {
2851   ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
2852 }
2853 
2854 )";
2855   TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2856 }
2857 
TEST_F(HloParserTest,ComputationShapeDoesNotMatchRootShape)2858 TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
2859   const std::string original = R"(HloModule custom_call:
2860 
2861 ENTRY %CustomCall () -> f32[1] {
2862   %constant = f32[1]{0} constant({12345})
2863   ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
2864 })";
2865   ExpectHasSubstr(
2866       ParseAndReturnUnverifiedModule(original).status().error_message(),
2867       "Shape of computation CustomCall, f32[1], is not compatible "
2868       "with that of its root instruction foo, f32[1,2,3]");
2869 }
2870 
TEST_F(HloParserTest,EntryComputationWithLayout)2871 TEST_F(HloParserTest, EntryComputationWithLayout) {
2872   const std::string original = R"(HloModule layout:
2873 add_F32.v3 {
2874   lhs = f32[] parameter(0)
2875   rhs = f32[] parameter(1)
2876   ROOT add = f32[] add(lhs, rhs)
2877 }
2878 
2879 ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
2880   input = f32[8,16,256]{0,1,2} parameter(0)
2881   constant = f32[] constant(0)
2882   ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
2883 })";
2884 
2885   auto module = ParseAndReturnVerifiedModule(original);
2886   TF_ASSERT_OK(module.status());
2887   auto program_layout = module.ValueOrDie()->entry_computation_layout();
2888   ASSERT_EQ(program_layout.parameter_count(), 1);
2889   auto param_layout = program_layout.parameter_layout(0).layout();
2890   auto result_layout = program_layout.result_layout().layout();
2891   EXPECT_TRUE(
2892       LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout))
2893       << "actual layout of parameter(0) is "
2894       << LayoutUtil::HumanString(param_layout);
2895   EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout))
2896       << "actual layout of result is "
2897       << LayoutUtil::HumanString(result_layout);
2898 }
2899 
TEST_F(HloParserTest,NoEntry)2900 TEST_F(HloParserTest, NoEntry) {
2901   const std::string original = R"(HloModule no_entry:
2902 c1 {
2903   const1 = f32[1]{0} constant({12345})
2904 }
2905 c2 {
2906   const2 = f32[1]{0} constant({67890})
2907 })";
2908   auto module = ParseAndReturnVerifiedModule(original);
2909   TF_ASSERT_OK(module.status());
2910   EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
2911 }
2912 
TEST_F(HloParserTest,NoRoot)2913 TEST_F(HloParserTest, NoRoot) {
2914   const std::string original = R"(HloModule no_root:
2915 ENTRY consts {
2916   first = f32[1]{0} constant({12345})
2917   last = f32[1]{0} constant({67890})
2918 })";
2919   auto module = ParseAndReturnVerifiedModule(original);
2920   TF_ASSERT_OK(module.status());
2921   EXPECT_EQ(
2922       module.ValueOrDie()->entry_computation()->root_instruction()->name(),
2923       "last");
2924 }
2925 
TEST_F(HloParserTest,Comments)2926 TEST_F(HloParserTest, Comments) {
2927   const std::string original = R"(/* module description. */
2928 HloModule comments:
2929 
2930 ENTRY /*comment*/ c1 {
2931   /* blah */
2932   ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
2933   /* comment */
2934 }
2935 
2936 /* something else */
2937 
2938 )";
2939   auto module = ParseAndReturnVerifiedModule(original);
2940   TF_ASSERT_OK(module.status());
2941 }
2942 
TEST_F(HloParserTest,MultilineComments)2943 TEST_F(HloParserTest, MultilineComments) {
2944   const std::string original = R"(HloModule multiline_comment:
2945 ENTRY c1 {
2946   /*
2947      ROOT foo = f32[1]{0} constant({12345})
2948   */
2949   ROOT const1 = f32[1]{0} constant({12345})
2950 /*
2951 a
2952 b
2953 c
2954 d
2955 
2956 */
2957 })";
2958   auto module = ParseAndReturnVerifiedModule(original);
2959   TF_ASSERT_OK(module.status());
2960 }
2961 
TEST_F(HloParserTest,UnterminatedComment)2962 TEST_F(HloParserTest, UnterminatedComment) {
2963   const std::string original = R"(HloModule unterminated_comment:
2964 ENTRY c1 {
2965 /* unterminated
2966   ROOT const1 = f32[1]{0} constant({12345})
2967 })";
2968   // Verify that the error message points to the beginning of the unterminated
2969   // comment.
2970   ExpectHasSubstr(
2971       ParseAndReturnUnverifiedModule(original).status().error_message(),
2972       "/* unterminated\n^");
2973 }
2974 
// `//` line comments on their own line and trailing an instruction must not
// prevent the module from parsing and verifying.
2975 TEST_F(HloParserTest, SlashSlashComments) {
2976   const std::string original = R"(HloModule slash_slash_comment:
2977 // Garbage
2978 ENTRY c1 {
2979   // Foo bar
2980   ROOT const1 = f32[1]{0} constant({12345}) // Something else
2981 })";
2982   auto module = ParseAndReturnVerifiedModule(original);
2983   TF_ASSERT_OK(module.status());
2984 }
2985 
// Same scenario as the `//` comment test, but with every line terminated by
// "\r\n"; the module must still parse and verify.
2986 TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
2987   const std::string original =
2988       "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
2989       "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
2990   auto module = ParseAndReturnVerifiedModule(original);
2991   TF_ASSERT_OK(module.status());
2992 }
2993 
// Same scenario as the `//` comment test, but with every line terminated by a
// bare "\r"; the module must still parse and verify.
2994 TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
2995   const std::string original =
2996       "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
2997       "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
2998   auto module = ParseAndReturnVerifiedModule(original);
2999   TF_ASSERT_OK(module.status());
3000 }
3001 
// A module in which two computations are both marked ENTRY must be rejected
// with the "expects only one ENTRY" diagnostic.
3002 TEST_F(HloParserTest, MultipleEntries) {
3003   const std::string original = R"(HloModule multiple_entries:
3004 ENTRY c1 {
3005   const1 = f32[1]{0} constant({12345})
3006 }
3007 ENTRY c2 {
3008   const2 = f32[1]{0} constant({67890})
3009 })";
3010   ExpectHasSubstr(
3011       ParseAndReturnUnverifiedModule(original).status().error_message(),
3012       "expects only one ENTRY");
3013 }
3014 
// Parses a module header with two `input_output_alias` entries and checks the
// resulting alias config: parameter 0 index {0} aliases output {0} and is
// must-alias; parameter 0 index {1} aliases output {1} and is not must-alias.
3015 TEST_F(HloParserTest, SimpleAliasing) {
3016   const std::string original = R"(
3017 HloModule Module, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) }
3018 
3019 ENTRY entry {
3020   %p = (f32[], f32[]) parameter(0)
3021   %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3022   %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3023   ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3024 }
3025   )";
3026   auto module = ParseAndReturnVerifiedModule(original);
3027   TF_ASSERT_OK(module.status());
3028   std::unique_ptr<HloModule> parsed_module = std::move(module).value();
3029   EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}),
3030             ShapeIndex{0});
3031 
3032   EXPECT_TRUE(
3033       parsed_module->input_output_alias_config().ParameterMustAlias(0, {0}));
3034   EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}),
3035             ShapeIndex{1});
3036   EXPECT_FALSE(
3037       parsed_module->input_output_alias_config().ParameterMustAlias(0, {1}));
3038 }
3039 
// Parses `input_output_alias` entries whose output indices are two levels deep
// ({0, 0} and {1, 1}) and checks that parameter 0 indices {0} and {1} map to
// those nested output indices.
3040 TEST_F(HloParserTest, NestedAliasing) {
3041   const std::string original = R"(
3042 HloModule Module, input_output_alias={ {0, 0}: (0, {0}), {1, 1}: (0, {1}) }
3043 
3044 ENTRY entry {
3045   %p = (f32[], f32[]) parameter(0)
3046   %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3047   %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3048   %t0 = (f32[], f32[]) tuple(%p0, %p1)
3049   %t1 = (f32[], f32[]) tuple(%p0, %p1)
3050   ROOT %out = ((f32[], f32[]), (f32[], f32[])) tuple(%t0, %t1)
3051 }
3052   )";
3053   auto module = ParseAndReturnVerifiedModule(original);
3054   TF_ASSERT_OK(module.status());
3055   std::unique_ptr<HloModule> parsed_module = std::move(module).value();
3056   EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}),
3057             ShapeIndex({0, 0}));
3058   EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}),
3059             ShapeIndex({1, 1}));
3060 }
3061 
// An `input_output_alias` entry whose output index is missing its closing
// brace (`{0 :`) must be rejected with the "Expects '}' at the end of
// ShapeIndex" diagnostic.
3062 TEST_F(HloParserTest, AliasingWrongIndex) {
3063   const std::string original = R"(
3064 HloModule Module, input_output_alias={ {0 : (0, {0}), {1}: (0, {1}) }
3065 
3066 ENTRY entry {
3067   %p = (f32[], f32[]) parameter(0)
3068   %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3069   %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3070   ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3071 }
3072   )";
3073   ExpectHasSubstr(
3074       ParseAndReturnUnverifiedModule(original).status().error_message(),
3075       "Expects '}' at the end of ShapeIndex");
3076 }
3077 
// An `input_output_alias` output index containing a non-numeric element
// (`{0, a}`) must be rejected with an "expects integer" diagnostic.
3078 TEST_F(HloParserTest, AliasingShapeIndexNotNumerical) {
3079   const std::string original = R"(
3080 HloModule Module, input_output_alias={ {0, a}: (0, {0}), {1}: (0, {1}) }
3081 
3082 ENTRY entry {
3083   %p = (f32[], f32[]) parameter(0)
3084   %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3085   %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3086   ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3087 }
3088   )";
3089   ExpectHasSubstr(
3090       ParseAndReturnUnverifiedModule(original).status().error_message(),
3091       "expects integer");
3092 }
3093 
// The second `input_output_alias` entry is just `(0, {1})`, with no leading
// output index and colon; parsing must fail with the "Expects '{' at the
// start of ShapeIndex" diagnostic.
3094 TEST_F(HloParserTest, AliasingWrongFormatNoColon) {
3095   const std::string original = R"(
3096 HloModule Module, input_output_alias={ {0, 0}: (0, {0}), (0, {1}) }
3097 
3098 ENTRY entry {
3099   %p = (f32[], f32[]) parameter(0)
3100   %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3101   %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3102   ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3103 }
3104   )";
3105   ExpectHasSubstr(
3106       ParseAndReturnUnverifiedModule(original).status().error_message(),
3107       "Expects '{' at the start of ShapeIndex");
3108 }
3109 
// The first `input_output_alias` entry is followed by a second colon and
// another index (`{0}: (0, {0}): {0, 1}`); parsing must fail with the
// "Expects '}' at the end of aliasing description" diagnostic.
3110 TEST_F(HloParserTest, AliasingWrongFormatTwoColons) {
3111   const std::string original = R"(
3112 HloModule Module, input_output_alias={ {0}: (0, {0}): {0, 1}, {1}: (0, {1}) }
3113 
3114 ENTRY entry {
3115   %p = (f32[], f32[]) parameter(0)
3116   %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3117   %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3118   ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3119 }
3120   )";
3121   ExpectHasSubstr(
3122       ParseAndReturnUnverifiedModule(original).status().error_message(),
3123       "Expects '}' at the end of aliasing description");
3124 }
3125 
// An `input_output_alias` entry whose parameter number is the word `zero`
// must be rejected with an "expects integer" diagnostic.
// NOTE(review): the same entry also has a non-numeric output index (`{0, a}`),
// so this assertion cannot tell which of the two tokens triggered the error --
// confirm whether the output index was meant to be valid here.
3126 TEST_F(HloParserTest, AliasingWrongFormatAlphaParam) {
3127   const std::string original = R"(
3128 HloModule Module, input_output_alias={ {0, a}: (zero, {0}), {1}: (0, {1}) }
3129 
3130 ENTRY entry {
3131   %p = (f32[], f32[]) parameter(0)
3132   %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3133   %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3134   ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3135 }
3136   )";
3137   ExpectHasSubstr(
3138       ParseAndReturnUnverifiedModule(original).status().error_message(),
3139       "expects integer");
3140 }
3141 
// A computation with two instructions marked ROOT must be rejected with the
// "one computation should have only one ROOT" diagnostic.
3142 TEST_F(HloParserTest, MultipleRoots) {
3143   const std::string original = R"(HloModule multiple_roots:
3144 ENTRY consts {
3145   ROOT const1 = f32[1]{0} constant({12345})
3146   ROOT const2 = f32[1]{0} constant({12345})
3147 })";
3148   ExpectHasSubstr(
3149       ParseAndReturnUnverifiedModule(original).status().error_message(),
3150       "one computation should have only one ROOT");
3151 }
3152 
// Defining two computations with the same name (`comp`) must fail, and the
// error message must point back to the first definition (position 2:1) with
// the source line and a caret.
3153 TEST_F(HloParserTest, ComputationExists) {
3154   const std::string original = R"(HloModule comp_exists
3155 comp {
3156   const1 = f32[1]{0} constant({12345})
3157 }
3158 comp {
3159   const2 = f32[1]{0} constant({67890})
3160 })";
3161   ExpectHasSubstr(
3162       ParseAndReturnUnverifiedModule(original).status().error_message(),
3163       R"(was parsing 2:1: error: computation previously defined here
3164 comp {
3165 ^)");
3166 }
3167 
// An instruction in `tcallb` uses `aparam`, which is defined only in the
// separate computation `tcalla`; parsing must fail with an "instruction does
// not exist: aparam" error reported at position 8:39.
3168 TEST_F(HloParserTest, CrossComputationLookup) {
3169   const std::string original = R"(HloModule cross_computation_lookup:
3170 tcalla (a: (s32[], s32[])) -> (s32[], s32[]) {
3171   ROOT aparam = (s32[], s32[]) parameter(0)
3172 }
3173 
3174 tcallb (b: (s32[], s32[])) -> s32[] {
3175   rparam = (s32[], s32[]) parameter(0)
3176   ROOT gte0 = s32[] get-tuple-element(aparam), index=0
3177 }
3178 
3179 ENTRY entry {
3180   param = (s32[], s32[]) parameter(0)
3181   call0 = (s32[], s32[]) call(param), to_apply=tcalla
3182   ROOT call1 = s32[] call(param), to_apply=tcallb
3183 })";
3184   ExpectHasSubstr(
3185       ParseAndReturnUnverifiedModule(original).status().error_message(),
3186       "was parsing 8:39: error: instruction does not exist: aparam");
3187 }
3188 
// The instruction names `p0`, `p1` and `result` are reused in two different
// computations; the module must still parse and verify, and the entry
// computation's root must match a reduce instruction.
3189 TEST_F(HloParserTest, SameNameDiffComputations) {
3190   const std::string original = R"(HloModule same_names:
3191 add {
3192   p0 = f32[] parameter(0)
3193   p1 = f32[] parameter(1)
3194   ROOT result = f32[] add(p0, p1)
3195 }
3196 
3197 ENTRY ReduceR3ToR2 {
3198   p0 = f32[8,16,256]{2,1,0} parameter(0)
3199   p1 = f32[] constant(0)
3200   ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
3201 }
3202 )";
3203   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
3204   ASSERT_NE(module->entry_computation(), nullptr);
3205   EXPECT_THAT(module->entry_computation()->root_instruction(),
3206               GmockMatch(m::Reduce()));
3207 }
3208 
// Round-trip: ParseSharding() on "{maximal device=42}" must succeed, and
// ToString() on the result must reproduce the input.
3209 TEST_F(HloParserTest, ParseSharding) {
3210   const std::string original = "{maximal device=42}";
3211   TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
3212   EXPECT_EQ(sharding.ToString(), original);
3213 }
3214 
// Round-trips a sharding string that uses `last_tile_dim_replicate`, then
// checks that a sharding built directly with HloSharding::PartialTile() from
// two device groups ({0,1} and {2,3}) prints as the same string.
3215 TEST_F(HloParserTest, ParseShardingPartialReplication) {
3216   const std::string original = "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}";
3217   TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
3218   EXPECT_EQ(sharding.ToString(), original);
3219   Array<int64_t> group_tiling({2});
3220   group_tiling(0) = 0;
3221   group_tiling(1) = 1;
3222   std::vector<int64_t> group0_members({0, 1});
3223   std::vector<int64_t> group1_members({2, 3});
3224   EXPECT_EQ(
3225       HloSharding::PartialTile(group_tiling, {group0_members, group1_members})
3226           .ToString(),
3227       original);
3228 }
3229 
// Round-trips a sharding string that uses `last_tile_dims={manual,
// replicated}`, then checks that a sharding built directly with
// HloSharding::Subgroup() from an iota-filled 2x2x2x2 tile assignment and the
// MANUAL/REPLICATED subgroup types prints as the same string.
3230 TEST_F(HloParserTest, ParseShardingSubGroup) {
3231   const std::string original =
3232       "{devices=[2,2,2,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 "
3233       "last_tile_dims={manual, replicated}}";
3234   TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
3235   EXPECT_EQ(sharding.ToString(), original);
3236   Array<int64_t> tile_assignment({2, 2, 2, 2});
3237   tile_assignment.FillIota(0);
3238   std::vector<OpSharding::Type> subgroup_types = {OpSharding::MANUAL,
3239                                                   OpSharding::REPLICATED};
3240   EXPECT_EQ(HloSharding::Subgroup(tile_assignment, subgroup_types).ToString(),
3241             original);
3242 }
3243 
// Round-trip: ParseFrontendAttributes() on a four-entry attribute map must
// succeed, and FrontendAttributesToString() on the result must reproduce the
// input.
3244 TEST_F(HloParserTest, ParseFrontendAttributes) {
3245   const std::string original =
3246       R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})";
3247   TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
3248                           ParseFrontendAttributes(original));
3249   EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
3250 }
3251 
// Round-trip: a Window built by window_util::MakeWindow({1, 2, 3}) is
// converted to a string and parsed back with ParseWindow(); the parsed window
// must print as the same string as the original.
3252 TEST_F(HloParserTest, ParseWindow) {
3253   Window original = window_util::MakeWindow({1, 2, 3});
3254   TF_ASSERT_OK_AND_ASSIGN(Window parsed,
3255                           ParseWindow(window_util::ToString(original)));
3256   EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
3257 }
3258 
// Round-trip: ParseConvolutionDimensionNumbers() on "b0f_0io->b0f" must
// succeed, and ConvolutionDimensionNumbersToString() on the result must
// reproduce the input.
3259 TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
3260   const std::string original = "b0f_0io->b0f";
3261   TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
3262                           ParseConvolutionDimensionNumbers(original));
3263   EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
3264 }
3265 
// Round-trip of a dimension-numbers string that contains `?` placeholders in
// all three groups; the printed form must reproduce the input.
3266 TEST_F(HloParserTest, ParseConvolutionDimensionNumbersWithUnknownDims) {
3267   const std::string original = "b0?f_?0?io->?b?0?f";
3268   TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
3269                           ParseConvolutionDimensionNumbers(original));
3270   EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
3271 }
3272 
// Round-trip: ParseReplicaGroupsOnly() on "{{0,1},{2,3}}" must succeed, and
// ReplicaGroupsToString() on the result must reproduce the input.
3273 TEST_F(HloParserTest, ParseReplicaGroups) {
3274   const std::string original = "{{0,1},{2,3}}";
3275   TF_ASSERT_OK_AND_ASSIGN(std::vector<ReplicaGroup> replica_groups,
3276                           ParseReplicaGroupsOnly(original));
3277   EXPECT_EQ(original, ReplicaGroupsToString(replica_groups));
3278 }
3279 
// Round-trip of a padding config in which every dimension has only two
// components ("0_1x2_3"); the printed form must reproduce the input.
3280 TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
3281   const std::string original = "0_1x2_3";
3282   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
3283   EXPECT_EQ(original, PaddingConfigToString(dnums));
3284 }
3285 
3286 TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
3287   const std::string original = "0_1_0x2_3_4";
3288   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
3289   EXPECT_EQ(original, PaddingConfigToString(dnums));
3290 }
3291 
3292 TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
3293   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4"));
3294   // The extra "_0" gets added to the canonical string because the other dim has
3295   // interior padding.
3296   EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums));
3297 }
3298 
3299 TEST_F(HloParserTest, NontupleInfeed) {
3300   const std::string original = R"(HloModule nontuple_infeed:
3301 ENTRY nontuple_infeed {
3302   token0 = token[] after-all()
3303   ROOT infeed = pred[] infeed(token0)
3304 })";
3305   ExpectHasSubstr(
3306       ParseAndReturnUnverifiedModule(original).status().error_message(),
3307       "infeed must have a non-empty tuple shape");
3308 }
3309 
3310 TEST(HloParserSingleOpTest, SingleOp) {
3311   const std::string text =
3312       "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
3313       "f32[2,4]{1,0} %x)";
3314   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3315   const HloComputation* computation = module->entry_computation();
3316   ASSERT_NE(computation, nullptr);
3317   EXPECT_THAT(computation->root_instruction(),
3318               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
3319 }
3320 
3321 TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
3322   const std::string text =
3323       "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
3324   StatusOr<std::unique_ptr<HloModule>> module =
3325       ParseAndReturnUnverifiedModule(text);
3326   ASSERT_TRUE(!module.status().ok());
3327   LOG(INFO) << "Status: " << module.status();
3328   EXPECT_THAT(module.status().ToString(),
3329               HasSubstr("expects '=' in instruction"));
3330 }
3331 
3332 TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
3333   const std::string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
3334   StatusOr<std::unique_ptr<HloModule>> module =
3335       ParseAndReturnUnverifiedModule(text);
3336   ASSERT_TRUE(!module.status().ok());
3337   LOG(INFO) << "Status: " << module.status();
3338   EXPECT_THAT(module.status().ToString(),
3339               HasSubstr("Operand had no shape in HLO text"));
3340 }
3341 
3342 TEST(HloParserSingleOpTest, SingleOpNoNames) {
3343   const std::string text =
3344       "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
3345   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3346   const HloComputation* computation = module->entry_computation();
3347   ASSERT_NE(computation, nullptr);
3348   EXPECT_THAT(computation->root_instruction(),
3349               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
3350 }
3351 
3352 TEST(HloParserSingleOpTest, CanonicalOp) {
3353   const std::string text =
3354       "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
3355   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3356   const HloComputation* computation = module->entry_computation();
3357   ASSERT_NE(computation, nullptr);
3358   EXPECT_THAT(computation->root_instruction(),
3359               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
3360   EXPECT_EQ(
3361       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
3362       text);
3363 }
3364 
3365 TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
3366   const std::string text =
3367       R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
3368 {
3369   tmp_0 = f32[5,10]{1,0} parameter(0)
3370   tmp_1 = f32[20,10]{1,0} parameter(1)
3371   ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
3372   {
3373     tmp_0 = f32[5,10]{1,0} parameter(0)
3374     tmp_1 = f32[20,10]{1,0} parameter(1)
3375     tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
3376     ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
3377   }
3378 }, body=
3379 {
3380   tmp_0 = f32[5,10]{1,0} parameter(0)
3381   tmp_1 = f32[20,10]{1,0} parameter(1)
3382   ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
3383   {
3384     tmp_0 = f32[5,10]{1,0} parameter(0)
3385     tmp_1 = f32[20,10]{1,0} parameter(1)
3386     tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
3387     ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
3388   }
3389 })";
3390 
3391   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3392   const HloComputation* computation = module->entry_computation();
3393   ASSERT_NE(computation, nullptr);
3394   EXPECT_EQ(
3395       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
3396       text);
3397 }
3398 
3399 TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) {
3400   const std::string text =
3401       R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={
3402 {
3403   tmp_0 = f32[5,10]{1,0} parameter(0)
3404   ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0)
3405 },
3406 {
3407   tmp_0 = f32[5,10]{1,0} parameter(0)
3408   ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0)
3409 },
3410 {
3411   tmp_0 = f32[5,10]{1,0} parameter(0)
3412   ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0)
3413 }
3414 })";
3415 
3416   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3417   const HloComputation* computation = module->entry_computation();
3418   ASSERT_NE(computation, nullptr);
3419   EXPECT_EQ(
3420       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
3421       text);
3422 }
3423 
3424 TEST(HloParserSingleOpTest, SingleOpWithNested) {
3425   const std::string text =
3426       R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
3427 {
3428   %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
3429   %param_1 = f32[2]{0} parameter(1)
3430   %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
3431   ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
3432 })";
3433 
3434   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3435   const HloComputation* computation = module->entry_computation();
3436   ASSERT_NE(computation, nullptr);
3437   EXPECT_THAT(computation->root_instruction(),
3438               GmockMatch(m::Op()
3439                              .WithOpcode(HloOpcode::kFusion)
3440                              .WithNumOperands(2)
3441                              .WithOperand(0, m::Parameter(0))
3442                              .WithOperand(1, m::Parameter(1))));
3443 }
3444 
3445 TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
3446   const std::string text =
3447       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
3448 {
3449   result = f32[] add(f32[] x, f32[] y)
3450 })";
3451   auto status = ParseAndReturnUnverifiedModule(text).status();
3452   ASSERT_FALSE(status.ok());
3453   EXPECT_THAT(status.error_message(), HasSubstr("does not exist: x"));
3454 }
3455 
3456 TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
3457   const std::string text =
3458       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
3459 {
3460   f32[] add(f32[] x, f32[] y)
3461 })";
3462   auto status = ParseAndReturnUnverifiedModule(text).status();
3463   ASSERT_FALSE(status.ok());
3464   EXPECT_THAT(status.error_message(), HasSubstr("expects name"));
3465 }
3466 
3467 TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
3468   const std::string text =
3469       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
3470 {
3471   result = f32[] add(f32[], f32[])
3472 })";
3473   auto status = ParseAndReturnUnverifiedModule(text).status();
3474   ASSERT_FALSE(status.ok());
3475   EXPECT_THAT(status.error_message(), HasSubstr("expects name"));
3476 }
3477 
3478 TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
3479   const std::string text =
3480       R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
3481   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3482   const HloComputation* computation = module->entry_computation();
3483   ASSERT_NE(computation, nullptr);
3484   EXPECT_THAT(computation->root_instruction(),
3485               GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
3486   auto* convolution =
3487       Cast<HloConvolutionInstruction>(computation->root_instruction());
3488   EXPECT_EQ(convolution->feature_group_count(), 1);
3489 }
3490 
3491 TEST(HloParserSingleOpTest, MultipleOpsProducesError) {
3492   const std::string text = R"(
3493     param = f32[2,5,1,3] parameter(0)
3494     transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
3495   )";
3496   auto status = ParseAndReturnUnverifiedModule(text).status();
3497   ASSERT_FALSE(status.ok());
3498   EXPECT_THAT(status.error_message(), HasSubstr("Expected eof"));
3499 }
3500 
3501 TEST_F(HloParserTest, IsScheduledIsFalse) {
3502   const std::string text = R"(
3503 HloModule axpy_module, is_scheduled=false
3504 
3505 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
3506   %alpha = f32[] parameter(0)
3507   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
3508   %x = f32[2,4]{1,0} parameter(1)
3509   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
3510   %y = f32[2,4]{1,0} parameter(2)
3511   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
3512 }
3513 )";
3514   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3515   ASSERT_FALSE(module->has_schedule());
3516 }
3517 
3518 TEST_F(HloParserTest, IsScheduledNotPresent) {
3519   const std::string text = R"(
3520 HloModule axpy_module
3521 
3522 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
3523   %alpha = f32[] parameter(0)
3524   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
3525   %x = f32[2,4]{1,0} parameter(1)
3526   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
3527   %y = f32[2,4]{1,0} parameter(2)
3528   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
3529 }
3530 )";
3531   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3532   ASSERT_FALSE(module->has_schedule());
3533 }
3534 
3535 TEST_F(HloParserTest, IsScheduledIsTrue) {
3536   const std::string text = R"(
3537 HloModule axpy_module, is_scheduled=true
3538 
3539 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
3540   %alpha = f32[] parameter(0)
3541   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
3542   %x = f32[2,4]{1,0} parameter(1)
3543   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
3544   %y = f32[2,4]{1,0} parameter(2)
3545   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
3546 }
3547 )";
3548   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3549   ASSERT_TRUE(module->has_schedule());
3550   TF_ASSERT_OK(module->schedule().Verify());
3551   EXPECT_EQ(module->schedule().sequences().size(), 1);
3552   ASSERT_TRUE(
3553       module->schedule().is_computation_scheduled(module->entry_computation()));
3554   EXPECT_THAT(
3555       module->schedule().sequence(module->entry_computation()).instructions(),
3556       ElementsAre(GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
3557                   GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
3558                   GmockMatch(m::Parameter()), GmockMatch(m::Add())));
3559 }
3560 
3561 TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
3562   // As above but in with a different schedule order.
3563   const std::string text = R"(
3564 HloModule axpy_module, is_scheduled=true
3565 
3566 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
3567   %alpha = f32[] parameter(0)
3568   %x = f32[2,4]{1,0} parameter(1)
3569   %y = f32[2,4]{1,0} parameter(2)
3570   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
3571   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
3572   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
3573 }
3574 )";
3575   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3576   ASSERT_TRUE(module->has_schedule());
3577   TF_ASSERT_OK(module->schedule().Verify());
3578   EXPECT_EQ(module->schedule().sequences().size(), 1);
3579   ASSERT_TRUE(
3580       module->schedule().is_computation_scheduled(module->entry_computation()));
3581   EXPECT_THAT(
3582       module->schedule().sequence(module->entry_computation()).instructions(),
3583       ElementsAre(GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
3584                   GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
3585                   GmockMatch(m::Multiply()), GmockMatch(m::Add())));
3586 }
3587 
3588 TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
3589   const std::string original =
3590       R"(HloModule CustomCallWrongNumberofOperandConstraints
3591 
3592 ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
3593   %p0 = f32[42,2,3]{0,1,2} parameter(0)
3594   %p1 = f32[123,4]{0,1} parameter(1)
3595   ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
3596 }
3597 
3598 )";
3599   ExpectHasSubstr(
3600       ParseAndReturnUnverifiedModule(original).status().error_message(),
3601       "Expected 2 operand layout constraints, 1 given");
3602 }
3603 
3604 TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
3605   const std::string original =
3606       R"(HloModule CustomCallIncompatibleOperandConstraints
3607 
3608 ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
3609   %p0 = f32[42,2,3]{0,1,2} parameter(0)
3610   %p1 = f32[123,4]{0,1} parameter(1)
3611   ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
3612 }
3613 
3614 )";
3615   ExpectHasSubstr(
3616       ParseAndReturnUnverifiedModule(original).status().error_message(),
3617       "operand 1 is not compatible with operand shape");
3618 }
3619 
3620 TEST_F(HloParserTest, CustomCallWithNonexistentVersion) {
3621   const std::string original = R"(HloModule custom_call
3622 
3623 ENTRY %CustomCall () -> f32[1,2,3] {
3624   %constant = f32[1]{0} constant({12345})
3625   ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", api_version=API_VERSION_THAT_DOESNT_EXIST
3626 }
3627 
3628 )";
3629   ExpectHasSubstr(
3630       ParseAndReturnUnverifiedModule(original).status().error_message(),
3631       "Unknown API version");
3632 }
3633 
3634 TEST_F(HloParserTest, CustomCallWithUnspecifiedVersion) {
3635   const std::string original = R"(HloModule custom_call
3636 
3637 ENTRY %CustomCall () -> f32[1,2,3] {
3638   %constant = f32[1]{0} constant({12345})
3639   ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", api_version=API_VERSION_UNSPECIFIED
3640 }
3641 
3642 )";
3643   ExpectHasSubstr(
3644       ParseAndReturnUnverifiedModule(original).status().error_message(),
3645       "Invalid API version");
3646 }
3647 
3648 TEST_F(HloParserTest, AllowShapeWhitespace) {
3649   const std::string text = R"(
3650 HloModule module
3651 
3652 ENTRY entry {
3653   ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
3654 }
3655 )";
3656   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3657 }
3658 
3659 TEST_F(HloParserTest, ShapeMismatchInOperand) {
3660   const std::string text = R"(
3661 HloModule foobar
3662 
3663 ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
3664   %p = f32[2,2] parameter(0)
3665   %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}})
3666   ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1)
3667 }
3668 )";
3669 
3670   ExpectHasSubstr(ParseAndReturnUnverifiedModule(text).status().error_message(),
3671                   "The declared operand shape f32[2,5]{1,0} is not compatible"
3672                   " with the shape of the operand instruction f32[2,2]{1,0}.");
3673 }
3674 
3675 TEST_F(HloParserTest, ParseShapeStringR2F32) {
3676   std::string shape_string = "f32[123,456]";
3677   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3678   Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
3679   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3680       << "expected: " << ShapeUtil::HumanString(expected)
3681       << "actual:   " << ShapeUtil::HumanString(actual);
3682 }
3683 
3684 TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
3685   std::string shape_string = "(f32[1572864],s8[5120,1024])";
3686   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3687   Shape expected =
3688       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
3689                                  ShapeUtil::MakeShape(S8, {5120, 1024})});
3690   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3691       << "expected: " << ShapeUtil::HumanString(expected)
3692       << "actual:   " << ShapeUtil::HumanString(actual);
3693 }
3694 
3695 TEST_F(HloParserTest, ParseShapeStringNestedTuple) {
3696   std::string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
3697   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3698   Shape expected = ShapeUtil::MakeTupleShape({
3699       ShapeUtil::MakeShape(F32, {1}),
3700       ShapeUtil::MakeTupleShape(
3701           {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
3702       ShapeUtil::MakeOpaqueShape(),
3703       ShapeUtil::MakeShape(F32, {3}),
3704   });
3705   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3706       << "expected: " << ShapeUtil::HumanString(expected)
3707       << "actual:   " << ShapeUtil::HumanString(actual);
3708 }
3709 
3710 TEST_F(HloParserTest, ParseShapeStringWithLayout) {
3711   std::string shape_string = "f32[123,456]{0,1}";
3712   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3713   Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
3714   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3715       << "expected: " << ShapeUtil::HumanString(expected)
3716       << "actual:   " << ShapeUtil::HumanString(actual);
3717 }
3718 
3719 TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
3720   // One tile.
3721   std::string shape_string = "f32[123,456]{0,1:T(2,128)}";
3722   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3723   Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {},
3724                                                   {Tile({2, 128})});
3725   EXPECT_EQ(expected, actual)
3726       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3727       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
3728 
3729   // Tile with negative dimension size for combining dimensions.
3730   shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}";
3731   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3732   expected =
3733       ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2}, {},
3734                                      {Tile({2, Tile::kCombineDimension, 128})});
3735   EXPECT_EQ(expected, actual)
3736       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3737       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
3738 
3739   // Two tiles.
3740   shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}";
3741   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3742   expected = ShapeUtil::MakeShapeWithLayout(
3743       BF16, {123, 456, 789}, {2, 1, 0}, {},
3744       {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})});
3745   EXPECT_EQ(expected, actual)
3746       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3747       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
3748 
3749   // Tile with element size in bits.
3750   shape_string = "pred[123,456]{1,0:T(2,128)E(1)}";
3751   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3752   expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {},
3753                                             {Tile({2, 128})}, 1);
3754   EXPECT_EQ(expected, actual)
3755       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3756       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
3757 
3758   // Element size in bits without tile.
3759   shape_string = "pred[123,456]{1,0:E(1)}";
3760   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3761   expected =
3762       ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, {}, 1);
3763   EXPECT_EQ(expected, actual)
3764       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3765       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
3766 
3767   // Wrong minor_to_major.
3768   shape_string = "f32[123,456,789]{1:T(2, * , 128)}";
3769   auto result = ParseShape(shape_string);
3770   ExpectHasSubstr(result.status().error_message(),
3771                   "Dimensions size is 3, but minor to major size is 1.");
3772 }
3773 
3774 TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) {
3775   // Tile, element size, and memory space.
3776   std::string shape_string = "pred[123,456]{1,0:T(2,128)E(1)S(3)}";
3777   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3778   Shape expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {},
3779                                                   {Tile({2, 128})}, 1, 3);
3780   EXPECT_EQ(expected, actual)
3781       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3782       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
3783 
3784   // Element size and memory space.
3785   shape_string = "pred[123,456]{1,0:E(1)S(3)}";
3786   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3787   expected =
3788       ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, {}, 1, 3);
3789   EXPECT_EQ(expected, actual)
3790       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3791       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
3792 
3793   // Memory space only.
3794   shape_string = "pred[123,456]{1,0:S(3)}";
3795   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3796   expected =
3797       ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, {}, 0, 3);
3798   EXPECT_EQ(expected, actual)
3799       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3800       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
3801 }
3802 
3803 TEST_F(HloParserTest, ParseOpaqueType) {
3804   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]"));
3805   Shape expected = ShapeUtil::MakeOpaqueShape();
3806   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3807       << "expected: " << ShapeUtil::HumanString(expected)
3808       << "actual:   " << ShapeUtil::HumanString(actual);
3809 }
3810 
3811 TEST_F(HloParserTest, ParseTokenType) {
3812   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]"));
3813   Shape expected = ShapeUtil::MakeTokenShape();
3814   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3815       << "expected: " << ShapeUtil::HumanString(expected)
3816       << "actual:   " << ShapeUtil::HumanString(actual);
3817 }
3818 
3819 TEST_F(HloParserTest, ParseInvalidShapeString) {
3820   std::string shape_strings[] = {"f32[123,456]foobar{0,1}", "f32[123,456]{foo}",
3821                                  "f32[123,456]dense{foo}"};
3822   for (const std::string& shape_string : shape_strings) {
3823     StatusOr<Shape> result = ParseShape(shape_string);
3824     ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
3825   }
3826 }
3827 
3828 TEST_F(HloParserTest, ParseDynamicArray) {
3829   std::string shape_string = "f32[123,<=456]";
3830   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3831   Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true});
3832   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3833       << "expected: " << ShapeUtil::HumanString(expected)
3834       << "actual:   " << ShapeUtil::HumanString(actual);
3835 }
3836 
3837 TEST_F(HloParserTest, ParseDynamicTuple) {
3838   std::string shape_string = "(f32[42], u32[<=123,<=456])";
3839   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3840   Shape expected = ShapeUtil::MakeTupleShape(
3841       {ShapeUtil::MakeShape(F32, {42}),
3842        ShapeUtil::MakeShape(U32, {123, 456}, {true, true})});
3843   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3844       << "expected: " << ShapeUtil::HumanString(expected)
3845       << "actual:   " << ShapeUtil::HumanString(actual);
3846 }
3847 
3848 TEST_F(HloParserTest, NegativeParameterNumber) {
3849   const std::string hlo_string = "par0 = f32[3,5] parameter(-1)";
3850   auto result = ParseAndReturnUnverifiedModule(hlo_string);
3851   ASSERT_FALSE(result.status().ok());
3852   EXPECT_THAT(result.status().error_message(),
3853               HasSubstr("parameter number must be >= 0"));
3854 }
3855 
3856 TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) {
3857   const std::string hlo_string =
3858       "par0 = (f32[3,5], f32[]) parameter(0), "
3859       "parameter_replication={true,false,true}";
3860   auto result = ParseAndReturnUnverifiedModule(hlo_string);
3861   ASSERT_FALSE(result.status().ok());
3862   EXPECT_THAT(result.status().error_message(),
3863               HasSubstr("parameter has 2 leaf buffers, but "
3864                         "parameter_replication has 3 elements"));
3865 }
3866 
3867 TEST_F(HloParserTest, CheckIndexedConditionalDimension) {
3868   const char* const hlo_string = R"(
3869   HloModule Module
3870 
3871   branch0 {
3872     tparam = f32[4] parameter(0)
3873     ROOT tgte1 = f32[4] ceil(tparam)
3874   }
3875 
3876   branch1 {
3877     fparam = f32[4] parameter(0)
3878     ROOT fgte1 = f32[4] floor(fparam)
3879   }
3880 
3881   ENTRY entry {
3882     p0 = f32[4] parameter(0)
3883     b0 = s32[2] parameter(1)
3884     ROOT conditional = f32[4] conditional(b0, p0, p0),
3885       branch_computations={branch0, branch1}
3886   }
3887   )";
3888   auto result = ParseAndReturnUnverifiedModule(hlo_string);
3889   EXPECT_NE(OkStatus(), result.status());
3890   EXPECT_THAT(result.status().error_message(),
3891               HasSubstr("The first operand must be a scalar"));
3892 }
3893 
3894 TEST_F(HloParserTest, CheckIndexedConditionalElementType) {
3895   const char* const hlo_string = R"(
3896   HloModule Module
3897 
3898   branch0 {
3899     tparam = f32[4] parameter(0)
3900     ROOT tgte1 = f32[4] ceil(tparam)
3901   }
3902 
3903   branch1 {
3904     fparam = f32[4] parameter(0)
3905     ROOT fgte1 = f32[4] floor(fparam)
3906   }
3907 
3908   ENTRY entry {
3909     p0 = f32[4] parameter(0)
3910     b0 = f32[] parameter(1)
3911     ROOT conditional = f32[4] conditional(b0, p0, p0),
3912       branch_computations={branch0, branch1}
3913   }
3914   )";
3915   auto result = ParseAndReturnUnverifiedModule(hlo_string);
3916   EXPECT_NE(OkStatus(), result.status());
3917   EXPECT_THAT(result.status().error_message(),
3918               HasSubstr("The first operand must be a scalar of PRED or S32"));
3919 }
3920 
3921 TEST_F(HloParserTest,
3922        CheckPredicatedConditionalRequiresTrueAndFalseComputation) {
3923   const char* const hlo_string = R"(
3924   HloModule Module
3925 
3926   branch0 {
3927     tparam = f32[4] parameter(0)
3928     ROOT tgte1 = f32[4] ceil(tparam)
3929   }
3930 
3931   branch1 {
3932     fparam = f32[4] parameter(0)
3933     ROOT fgte1 = f32[4] floor(fparam)
3934   }
3935 
3936   ENTRY entry {
3937     p0 = f32[4] parameter(0)
3938     b0 = pred[] parameter(1)
3939     ROOT conditional = f32[4] conditional(b0, p0, p0),
3940       branch_computations={branch0, branch1}
3941   }
3942   )";
3943   auto result = ParseAndReturnUnverifiedModule(hlo_string);
3944   EXPECT_NE(OkStatus(), result.status());
3945   EXPECT_THAT(result.status().error_message(),
3946               HasSubstr("unexpected attribute \"branch_computations\""));
3947 }
3948 
// Result shape inference test cases.
3950 TEST_F(HloParserTest, InferUnaryShape) {
3951   constexpr char text[] = R"(HloModule InferUnaryShapeTest
3952 ENTRY InferUnaryShape {
3953   a = f32[2,10]{1,0} parameter(0)
3954   ROOT v = abs(a)
3955 }
3956 )";
3957   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3958 }
3959 
3960 TEST_F(HloParserTest, InferBinaryShape) {
3961   constexpr char text[] = R"(HloModule InferBinaryShapeTest
3962 ENTRY InferBinaryShape {
3963   a = f32[2,10]{1,0} parameter(0)
3964   b = f32[2,10]{1,0} parameter(1)
3965   ROOT sum = add(a, b)
3966 }
3967 )";
3968   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3969   EXPECT_TRUE(ShapeUtil::Equal(
3970       module->entry_computation()->ComputeProgramShape().result(),
3971       ShapeUtil::MakeShapeWithLayout(F32, {2, 10}, {1, 0})));
3972 }
3973 
3974 TEST_F(HloParserTest, InferTernaryShape) {
3975   constexpr char text[] = R"(HloModule InferTernaryShapeTest
3976 ENTRY InferTernaryShape {
3977   p = pred[] constant(true)
3978   f = s32[] constant(-42)
3979   t = s32[] constant(42)
3980   ROOT select = select(p, f, t)
3981 }
3982 )";
3983   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3984   EXPECT_TRUE(ShapeUtil::Equal(
3985       module->entry_computation()->ComputeProgramShape().result(),
3986       ShapeUtil::MakeScalarShape(S32)));
3987 }
3988 
3989 TEST_F(HloParserTest, InferDotShape) {
3990   constexpr char text[] = R"(HloModule InferDotShapeTest
3991 ENTRY InferDotShape {
3992   a = f32[2,10]{1,0} parameter(0)
3993   b = f32[10,2]{1,0} parameter(1)
3994   ROOT dot = dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
3995 }
3996 )";
3997   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3998   EXPECT_TRUE(ShapeUtil::Equal(
3999       module->entry_computation()->ComputeProgramShape().result(),
4000       ShapeUtil::MakeShape(F32, {2}, {0})));
4001 }
4002 
4003 TEST_F(HloParserTest, InferTupleShape) {
4004   constexpr char text[] = R"(HloModule InferTupleShapeTest
4005 ENTRY InferTupleShape () -> s32[2,3] {
4006   c0 = f32[3]{0} constant({1, 2, 3})
4007   c1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
4008   tuple = tuple(c0, c1)
4009   ROOT get = get-tuple-element(tuple), index=1, sharding={maximal device=0}
4010 }
4011 )";
4012   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
4013   EXPECT_TRUE(ShapeUtil::Equal(
4014       module->entry_computation()->ComputeProgramShape().result(),
4015       ShapeUtil::MakeShapeWithLayout(S32, {2, 3}, {1, 0})));
4016 }
4017 
4018 TEST_F(HloParserTest, InferShapeMixedExplicitShape) {
4019   constexpr char text[] = R"(HloModule InferUnaryShapeTest
4020 Negate {
4021   x = f32[] parameter(0)
4022   ROOT negate = negate(x)
4023 }
4024 
4025 Identity {
4026   y = f32[] parameter(0)
4027   ROOT copy = copy(y)
4028 }
4029 
4030 ENTRY InferUnaryShape {
4031   a = f32[] parameter(0)
4032   b = f32[] parameter(1)
4033   p = pred[] parameter(2)
4034   c = f32[] add(a, b)
4035   ROOT conditional = conditional(p, a, c), true_computation=Negate, false_computation=Identity
4036 }
4037 )";
4038   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
4039   EXPECT_TRUE(ShapeUtil::Equal(
4040       module->entry_computation()->ComputeProgramShape().result(),
4041       ShapeUtil::MakeScalarShape(F32)));
4042 }
4043 
4044 TEST_F(HloParserTest, CheckAliasPassthroughParams) {
4045   const char* const hlo_string = R"(
4046 HloModule TestModule, alias_passthrough_params=true
4047 
4048 ENTRY TestComputation {
4049     p0 = f16[2048,1024] parameter(0)
4050     p1 = f16[2048,1024] parameter(1)
4051     ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p1)
4052 }
4053 )";
4054   auto result = ParseAndReturnVerifiedModule(hlo_string);
4055   TF_EXPECT_OK(result.status());
4056   EXPECT_TRUE(result.ValueOrDie()->config().alias_passthrough_params());
4057 }
4058 
4059 TEST_F(HloParserTest, NestedBroadcastWithoutDimensionsAttribute) {
4060   const char* const hlo_string = R"(
4061 HloModule test
4062 ENTRY test {
4063     ROOT root = sqrt(f32[10,10] broadcast(f32[10] parameter(0)))
4064 }
4065 )";
4066   auto result = ParseAndReturnVerifiedModule(hlo_string);
4067   EXPECT_NE(OkStatus(), result.status());
4068   EXPECT_THAT(result.status().error_message(), HasSubstr("dimensions"));
4069 }
4070 
4071 TEST_F(HloParserTest, InvalidDimLevelType) {
4072   const std::string original = R"(HloModule test
4073 
4074 ENTRY test {
4075   ROOT root = f32[10,10]{1,0:D(X,C)} parameter(0)
4076 })";
4077   EXPECT_THAT(ParseAndReturnUnverifiedModule(original).status(),
4078               tensorflow::testing::StatusIs(
4079                   tensorflow::error::INVALID_ARGUMENT,
4080                   HasSubstr("expected a DimLevelType abbreviation")));
4081 }
4082 
4083 TEST_F(HloParserTest, InvalidDimLevelTypeCount) {
4084   const std::string original = R"(HloModule test
4085 
4086 ENTRY test {
4087   ROOT root = f32[10,10]{1,0:D(C)} parameter(0)
4088 })";
4089   EXPECT_THAT(
4090       ParseAndReturnUnverifiedModule(original).status(),
4091       tensorflow::testing::StatusIs(
4092           tensorflow::error::INVALID_ARGUMENT,
4093           HasSubstr("Dimensions size is 2, but dim level types size is 1")));
4094 }
4095 
4096 TEST_F(HloParserTest, RejectSparseTiles) {
4097   const std::string original = R"(HloModule test
4098 
4099 ENTRY test {
4100   ROOT root = f32[10,10]{1,0:D(D,C)T(128,8)} parameter(0)
4101 })";
4102   EXPECT_THAT(ParseAndReturnUnverifiedModule(original).status(),
4103               tensorflow::testing::StatusIs(
4104                   tensorflow::error::INVALID_ARGUMENT,
4105                   HasSubstr("Layout has tiles, but is for a sparse array")));
4106 }
4107 
4108 }  // namespace
4109 }  // namespace xla
4110