1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21
22 #include "absl/strings/ascii.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
29 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/window_util.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/status_matchers.h"
37 #include "tensorflow/core/platform/statusor.h"
38 #include "tensorflow/core/platform/test.h"
39
40 namespace xla {
41 namespace {
42
43 namespace m = ::xla::match;
44
45 using ::absl::string_view;
46 using ::testing::ElementsAre;
47 using ::testing::HasSubstr;
48
// A single round-trip parser test case: `module_string` is parsed into an
// HloModule and the module's stringification is compared against it.
struct TestData {
  // Identifier for this case; also returned by TestDataToString.
  std::string test_name;
  // HLO text that is parsed and expected to print back unchanged.
  std::string module_string;
  // Replica count for the case; defaults to 1. (Presumably applied to the
  // parsed module's config — confirm against the test fixture.)
  int64_t replica_count{1};
  // Defaults to true. (Presumably toggles HLO verification of the parsed
  // module — confirm against the test fixture.)
  bool enable_verification{true};
};
55
TestDataToString(const::testing::TestParamInfo<TestData> & data)56 std::string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
57 return data.param.test_name;
58 }
59
60 // Tests where the input module string doesn't match the output.
61 //
// In general we want to avoid these because we want HLO text to be
// round-trippable! But nested instructions, e.g. add(sqrt(x), y), cannot be
// round-tripped without modification.
// A parser test case whose printed output is expected to differ from its input,
// e.g. HLO containing nested instructions such as add(sqrt(x), y).
struct NonRoundtripTestData {
  // Identifier for this case; also returned by NonRoundtripTestDataToString.
  std::string test_name{};
  // HLO text handed to the parser.
  std::string input_module_string{};
  // HLO text the parsed module is expected to print as.
  std::string output_module_string{};
};
70
NonRoundtripTestDataToString(const::testing::TestParamInfo<NonRoundtripTestData> & data)71 std::string NonRoundtripTestDataToString(
72 const ::testing::TestParamInfo<NonRoundtripTestData>& data) {
73 return data.param.test_name;
74 }
75
76 // For each string below, we check that:
77 // - we parse it to an HloModule successfully, and
78 // - the stringification of the resulting HloModule is equal to our original
79 // string.
CreateTestCases()80 std::vector<TestData> CreateTestCases() {
81 // clang-format off
82 return std::vector<TestData>({
83 // ax + y
84 {
85 "AxpyParam",
86 R"(HloModule axpy_module, entry_computation_layout={(f32[],f32[2,4]{1,0},f32[2,4]{1,0})->f32[2,4]{1,0}}
87
88 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
89 %alpha = f32[] parameter(0)
90 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
91 %x = f32[2,4]{1,0} parameter(1)
92 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
93 %y = f32[2,4]{1,0} parameter(2)
94 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
95 }
96
97 )"
98 },
99 // parameter replication
100 {
101 "ParamReplication",
102 R"(HloModule param_replication_module, entry_computation_layout={(f32[],(f32[2,4]{1,0}, (f32[2,4]{1,0})))->(f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0})))}
103
104 ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) {
105 %a = f32[] parameter(0), parameter_replication={true}
106 %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true}
107 ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b)
108 }
109
110 )"
111 },
112 // pred constant
113 {
114 "ConstantPred",
115 R"(HloModule constant_pred_module, entry_computation_layout={()->pred[]}
116
117 ENTRY %constant_pred () -> pred[] {
118 ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
119 }
120
121 )"
122 },
123 // pred array constant
124 {
125 "ConstantPredArray",
126 R"(HloModule module, entry_computation_layout={()->pred[2,3]{1,0}}
127
128 ENTRY %constant_pred_array () -> pred[2,3] {
129 ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } })
130 }
131
132 )"
133 },
134
135 // s32 constant
136 {
137 "ConstantS32",
138 R"(HloModule constant_s32_module, entry_computation_layout={()->s32[]}
139
140 ENTRY %constant_s32 () -> s32[] {
141 ROOT %constant = s32[] constant(-42)
142 }
143
144 )"
145 },
146 // f32 constant, but the value is not a decimal and there is a backend
147 // configuration
148 {
149 "ConstantF32",
150 R"(HloModule ConstantF32_module, entry_computation_layout={()->f32[]}
151
152 ENTRY %ConstantF32.v4 () -> f32[] {
153 ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
154 }
155
156 )"
157 },
158 // f32 constant, rank 1 empty array.
159 {
160 "ConstantF32R1Empty",
161 R"(HloModule ConstantF32Empty_module, entry_computation_layout={()->f32[0]{0}}
162
163 ENTRY %ConstantF32Empty.v4 () -> f32[0] {
164 ROOT %constant = f32[0]{0} constant({})
165 }
166
167 )"
168 },
169 // f32 constant, rank 4 empty array.
170 {
171 "ConstantF32R4Empty",
172 R"(HloModule ConstantF32R4Empty_module, entry_computation_layout={()->f32[2,0,4,3]{3,2,1,0}}
173
174 ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
175 ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } })
176 }
177
178 )"
179 },
180 // constant 4D
181 {
182 "Constant4D",
183 R"(HloModule Small_3x2x1x1_module, entry_computation_layout={()->f32[3,2,1,1]{3,2,1,0}}
184
185 ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
186 ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
187 }
188
189 )"
190 },
191 // non-finite constants: nan, inf, -inf
192 {
193 "ConstantNonFinite",
194 R"(HloModule IsFiniteR1F32s_module, entry_computation_layout={()->pred[6]{0}}
195
196 ENTRY %IsFiniteR1F32s.v2 () -> pred[6] {
197 %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf})
198 ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant)
199 }
200
201 )"
202 },
203 // constant f16
204 {
205 "ConstantF16",
206 R"(HloModule ConstantF16_module, entry_computation_layout={()->f16[]}
207
208 ENTRY %ConstantF16.v4 () -> f16[] {
209 ROOT %constant = f16[] constant(500)
210 }
211
212 )"
213 },
214 // bf16
215 {
216 "BF16",
217 R"(HloModule BF16, entry_computation_layout={()->bf16[]}
218
219 ENTRY %BF16.v4 () -> bf16[] {
220 ROOT %constant = bf16[] constant(500)
221 }
222
223 )"
224 },
225 // constant + constant
226 {
227 "AddConstants",
228 R"(HloModule add_constants_module, entry_computation_layout={()->f32[]}
229
230 ENTRY %add_constants () -> f32[] {
231 %constant = f32[] constant(3.14)
232 ROOT %add = f32[] add(f32[] %constant, f32[] %constant)
233 }
234
235 )"
236 },
237 // tuple constant
238 {
239 "TupleConstant",
240 R"(HloModule TupleConstant_module, entry_computation_layout={()->(f32[2,1]{1,0}, f32[2]{0})}
241
242 ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
243 ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} ))
244 }
245
246 )"
247 },
248 // v1 > v2 ? v1 : v2
249 {
250 "SelectR1F32",
251 R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module, entry_computation_layout={(f32[4]{0},f32[4]{0})->f32[4]{0}}
252
253 ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
254 %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
255 %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
256 %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, type=TOTALORDER, sharding={replicated}
257 ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
258 }
259
260 )"
261 },
262 // empty tuple
263 {
264 "EmptyTupleCreate",
265 R"(HloModule EmptyTupleCreate_module, entry_computation_layout={()->()}
266
267 ENTRY %EmptyTupleCreate.v1 () -> () {
268 ROOT %tuple = () tuple()
269 }
270
271 )"
272 },
273 // tuple
274 {
275 "TupleCreate",
276 R"(HloModule TupleCreate_module, entry_computation_layout={(f32[],f32[3]{0},f32[2,3]{1,0})->(f32[], f32[3]{0}, f32[2,3]{1,0})}
277
278 ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
279 %v1 = f32[] parameter(0)
280 %v2 = f32[3]{0} parameter(1)
281 %v3 = f32[2,3]{1,0} parameter(2)
282 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
283 }
284
285 )"
286 },
// large tuple round trip
288 {
289 "LargeTupleRoundTrip",
290 R"(HloModule LargeTupleRoundTrip_module, entry_computation_layout={(f32[])->(f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[])}
291
292 ENTRY %TupleCreate.v4 (v: f32[]) -> (f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[]) {
293 %v = f32[] parameter(0)
294 ROOT %tuple = (f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[]) tuple(f32[] %v, f32[] %v, f32[] %v, f32[] %v, f32[] %v, /*index=5*/f32[] %v)
295 }
296
297 )"
298 },
299 {
300 "ShardedTupleCreate",
301 R"(HloModule ShardedTupleCreate_module, entry_computation_layout={(f32[],f32[3]{0},f32[2,3]{1,0})->(f32[], f32[3]{0}, f32[2,3]{1,0})}
302
303 ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
304 %v1 = f32[] parameter(0), sharding={manual}
305 %v2 = f32[3]{0} parameter(1)
306 %v3 = f32[2,3]{1,0} parameter(2)
307 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{manual}, {maximal device=0}, {replicated}}
308 }
309
310 )"
311 },
312 {
313 "DomainParsing",
314 R"(HloModule DomainParsing_module, entry_computation_layout={(f32[])->f32[]}
315
316 ENTRY %DomainParsing (v1: f32[]) -> f32[] {
317 %v1 = f32[] parameter(0)
318 ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
319 }
320
321 )"
322 },
323 // int32_t result = 0;
324 // while (result < 5) { result = result + 1; }
325 {
326 "WhileWithScalarS32Result",
327 R"(HloModule WhileWithScalarS32Result_module, entry_computation_layout={()->s32[]}
328
329 %body.v3 (prev.1: s32[]) -> s32[] {
330 %constant = s32[] constant(1)
331 %prev.1 = s32[] parameter(0)
332 ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
333 }
334
335 %condition.v3 (prev.2: s32[]) -> pred[] {
336 %constant.1 = s32[] constant(5)
337 %prev.2 = s32[] parameter(0)
338 ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
339 }
340
341 ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
342 %constant.2 = s32[] constant(0)
343 ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
344 }
345
346 )"
347 },
348 // copy-start and copy-done
349 {
350 "CopyStartAndCopyDone",
351
352 R"(HloModule CopyStartAndCopyDone_module, entry_computation_layout={(f32[],f32[2,3]{1,0:S(1)})->(f32[], f32[2,3]{1,0:S(2)})}
353
354 ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
355 %v1 = f32[] parameter(0)
356 %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1), is_cross_program_prefetch=true
357 %copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
358 %v2 = f32[2,3]{1,0:S(1)} parameter(1)
359 %copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
360 %copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) %copy-start.2)
361 ROOT %tuple = (f32[], f32[2,3]{1,0:S(2)}) tuple(f32[] %copy-done.1, f32[2,3]{1,0:S(2)} %copy-done.2)
362 }
363
364 )"
365 },
366 // send and recv
367 {
368 "SendRecv",
369 R"(HloModule TwoSendRecvBothWayRecvFist_module, entry_computation_layout={()->(f32[], token[])}
370
371 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
372 %token0 = token[] after-all()
373 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
374 ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
375 %constant = f32[] constant(2.1), sharding={maximal device=0}
376 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
377 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
378 }
379
380 )"
381 },
382 {
383 "SendRecvWithHostTransfer",
384 R"(HloModule HostTransferSendRecv_module, entry_computation_layout={()->(f32[], token[])}
385
386 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
387 %token0 = token[] after-all()
388 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true
389 ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true
390 %constant = f32[] constant(2.1), sharding={maximal device=0}
391 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true
392 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true
393 }
394
395 )"
396 },
397 // get-tuple-element
398 {
399 "GetTupleElement",
400 R"(HloModule GetTupleElement_module, entry_computation_layout={()->s32[2,3]{1,0}}
401
402 ENTRY %GetTupleElement.v4 () -> s32[2,3] {
403 %constant = f32[3]{0} constant({1, 2, 3})
404 %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
405 %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
406 ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
407 }
408
409 )"
410 },
411 // call
412 {
413 "Call",
414 R"(HloModule CallR0F32IdentityScalar_module, entry_computation_layout={()->f32[]}
415
416 %Identity.v1 (x: f32[]) -> f32[] {
417 ROOT %x = f32[] parameter(0)
418 }
419
420 ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
421 %constant = f32[] constant(42)
422 ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
423 }
424
425 )"
426 },
427 // CustomCall with backend_config.
428 {
429 "CustomCallWithOpaque",
430 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
431
432 ENTRY %CustomCall () -> f32[1,2,3] {
433 %constant = f32[1]{0} constant({12345})
434 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
435 }
436
437 )"
438 },
439
440 // CustomCall with literal.
441 {
442 "CustomCallWithLiteral",
443 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
444
445 ENTRY %CustomCall () -> f32[1,2,3] {
446 %constant = f32[1]{0} constant({12345})
447 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=s32[2]{0} {1, 2}
448 }
449
450 )"
451 },
452
453 // CustomCall with literal tuple.
454 {
455 "CustomCallWithLiteralTuple",
456 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
457
458 ENTRY %CustomCall () -> f32[1,2,3] {
459 %constant = f32[1]{0} constant({12345})
460 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=( s32[4]{0} {4, 128, 128, 3}, pred[4]{0} {1, 0, 0, 0} )
461 }
462
463 )"
464 },
465
466 // CustomCall with literal R0.
467 {
468 "CustomCallWithLiteralR0",
469 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
470
471 ENTRY %CustomCall () -> f32[1,2,3] {
472 %constant = f32[1]{0} constant({12345})
473 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=f32[] 0.1
474 }
475
476 )"
477 },
478 // reduce window
479 {
480 "ReduceWindow",
481 R"(HloModule R4UnitWindow_module, entry_computation_layout={(f32[13,12,8,15]{0,3,2,1})->f32[13,3,8,15]{0,3,2,1}}
482
483 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
484 %lhs = f32[] parameter(0)
485 %rhs = f32[] parameter(1)
486 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
487 }
488
489 ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
490 %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
491 %constant = f32[] constant(0)
492 ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
493 }
494
495 )"
496 },
497 // reduce window on scalar
498 {
499 "ReduceWindowScalar",
500 R"(HloModule reduce_window_scalar, entry_computation_layout={()->f32[]}
501
502 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
503 %lhs = f32[] parameter(0)
504 %rhs = f32[] parameter(1)
505 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
506 }
507
508 ENTRY %R4UnitWindowScalar () -> f32[] {
509 %constant = f32[] constant(42)
510 %constant.1 = f32[] constant(1)
511 ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
512 }
513
514 )"
515 },
// variadic reduce window on scalars
517 {
518 "ReduceWindowVariadic",
519 R"(HloModule reduce_window_variadic, entry_computation_layout={()->(f32[], f32[])}
520
521 %add_F32.v3 (lhs1: f32[], lhs2: f32[], rhs1: f32[], rhs2: f32[]) -> (f32[], f32[]) {
522 %lhs1 = f32[] parameter(0)
523 %rhs1 = f32[] parameter(2)
524 %add1 = f32[] add(f32[] %lhs1, f32[] %rhs1)
525 %lhs2 = f32[] parameter(1)
526 %rhs2 = f32[] parameter(3)
527 %add2 = f32[] add(f32[] %lhs2, f32[] %rhs2)
528 ROOT %tuple1 = (f32[], f32[]) tuple(f32[] %add1, f32[] %add2)
529 }
530
531 ENTRY %R4UnitWindowScalar () -> (f32[], f32[]) {
532 %constant = f32[] constant(42)
533 %constant.1 = f32[] constant(1)
534 ROOT %reduce-window = (f32[], f32[]) reduce-window(f32[] %constant, f32[] %constant, f32[] %constant.1, f32[] %constant.1), to_apply=%add_F32.v3
535 }
536
537 )"
538 },
539 // convolution
540 {
541 "Convolution",
542 R"(HloModule Convolve1D1Window_0_module, entry_computation_layout={(f32[1,2,1]{2,1,0},f32[1,1,1]{2,1,0})->f32[1,2,1]{2,0,1}}
543
544 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
545 %input = f32[1,2,1]{2,1,0} parameter(0)
546 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
547 %filter = f32[1,1,1]{2,1,0} parameter(1)
548 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
549 }
550
551 )"
552 },
553 // convolution dynamic
554 {
555 "ConvolutionDynamic",
556 R"(HloModule Convolve1D1Window_0_module, entry_computation_layout={(f32[1,2,1]{2,1,0},f32[1,1,1]{2,1,0})->f32[1,2,1]{2,0,1}}
557
558 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
559 %input = f32[1,2,1]{2,1,0} parameter(0)
560 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
561 %filter = f32[1,1,1]{2,1,0} parameter(1)
562 ROOT %custom-call.52 = f32[1,2,1]{2,0,1} custom-call(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}, custom_call_target="DynamicConvolutionForward", metadata={op_type="Conv2D" op_name="conv1d"}
563 }
564
565 )"
566 },
567 // convolution rank 2
568 {
569 "ConvolutionR2",
570 R"(HloModule ConvolveR2_module, entry_computation_layout={(f32[1,2]{1,0},f32[2,2]{1,0})->f32[1,2]{0,1}}
571
572 ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[2,2]) -> f32[1,2] {
573 %input = f32[1,2]{1,0} parameter(0)
574 %filter = f32[2,2]{1,0} parameter(1)
575 ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[2,2]{1,0} %filter), dim_labels=bf_io->bf
576 }
577
578 )"
579 },
580 // convolution backward
581 {
582 "ConvolutionBackward",
583 R"(HloModule ConvolveBackward_module, entry_computation_layout={(f32[128,7,7,512]{0,3,2,1},f32[3,3,512,512]{3,2,1,0})->f32[128,14,14,512]{0,3,2,1}}
584
585 ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
586 %input = f32[128,7,7,512]{0,3,2,1} parameter(0)
587 %filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
588 ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
589 }
590
591 )"
592 },
593 // reverse(constant)
594 {
595 "Reverse4D",
596 R"(HloModule Reverse4DFloatArrayOnDim01_module, entry_computation_layout={()->f32[4,3,2,1]{0,1,2,3}}
597
598 ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
599 %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
600 ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
601 }
602
603 )"
604 },
605 // concat
606 {
607 "Concat",
608 R"(HloModule Concat2x3With2x5_module, entry_computation_layout={()->f32[2,8]{1,0}}
609
610 ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
611 %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } })
612 %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
613 ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
614 }
615
616 )"
617 },
618 // select and scatter
619 {
620 "SelectAndScatter",
621 R"(HloModule R4F32OverlapSmall_module, entry_computation_layout={()->f32[4,5,1,1]{3,2,1,0}}
622
623 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
624 %lhs = f32[] parameter(0)
625 %rhs = f32[] parameter(1)
626 ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER
627 }
628
629 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
630 %lhs.1 = f32[] parameter(0)
631 %rhs.1 = f32[] parameter(1)
632 ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
633 }
634
635 ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
636 %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
637 %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
638 %constant.2 = f32[] constant(0)
639 ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
640 }
641
642 )"
643 },
644 // select and scatter on scalar
645 {
646 "SelectAndScatterScalar",
647 R"(HloModule select_and_scatter_scalar, entry_computation_layout={()->f32[]}
648
649 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
650 %lhs = f32[] parameter(0)
651 %rhs = f32[] parameter(1)
652 ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
653 }
654
655 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
656 %lhs.1 = f32[] parameter(0)
657 %rhs.1 = f32[] parameter(1)
658 ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
659 }
660
661 ENTRY %SelectAndScatterScalar () -> f32[] {
662 %constant = f32[] constant(42)
663 %constant.1 = f32[] constant(1)
664 %constant.2 = f32[] constant(2)
665 ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3
666 }
667
668 )"
669 },
670 // slice
671 {
672 "Slice",
673 R"(HloModule slice_module, entry_computation_layout={(f32[3,3,4,4]{3,2,1,0})->f32[3,3,2,4]{3,2,1,0}}
674
675 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
676 %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
677 ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
678 }
679
680 )"
681 },
682 // slice, no stride
683 {
684 "SliceNoStride",
685 R"(HloModule Slice3x3x3_To_1x3x3_F32_module, entry_computation_layout={()->f32[1,3,3]{2,1,0}}
686
687 ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
688 %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
689 ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
690 }
691
692 )"
693 },
694 // slice R0
695 {
696 "SliceR0",
697 R"(HloModule SliceR0_module, entry_computation_layout={()->s32[]}
698
699 ENTRY %SliceR0.v2 () -> s32[] {
700 %constant = s32[] constant(1)
701 ROOT %slice = s32[] slice(s32[] %constant), slice={}
702 }
703
704 )"
705 },
706 // transpose
707 {
708 "Transpose",
709 R"(HloModule Transpose_module, entry_computation_layout={()->s32[1,2,3]{2,1,0}}
710
711 ENTRY %Transpose.v2 () -> s32[1,2,3] {
712 %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } })
713 ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
714 }
715
716 )"
717 },
718 {
719 "TransposeC128",
720 R"(HloModule TransposeC128_module, entry_computation_layout={(c128[1,2,3]{2,1,0})->c128[1,2,3]{2,1,0}}
721
722 ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
723 %input = c128[1,2,3]{2,1,0} parameter(0)
724 ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
725 }
726
727 )"
728 },
729 // Triangular solve
730 {
731 "TriangularSolve",
732 R"(HloModule TriangularSolve_module, entry_computation_layout={(f32[4,4]{1,0},f32[3,4]{1,0})->f32[3,4]{1,0}}
733
734 ENTRY %SimpleRightLowerNotranspose.4 (a.1: f32[4,4], b.2: f32[3,4]) -> f32[3,4] {
735 %a.1 = f32[4,4]{1,0} parameter(0)
736 %b.2 = f32[3,4]{1,0} parameter(1)
737 ROOT %triangular-solve.3 = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a.1, f32[3,4]{1,0} %b.2), lower=true, transpose_a=NO_TRANSPOSE
738 }
739
740 )"
741 },
742 // Dynamic slice
743 {
744 "DynamicSlice",
745 R"(HloModule DynamicSlice_module, entry_computation_layout={(s32[2,2,258]{2,1,0},s32[1]{0})->s32[2,2,258]{2,1,0}}
746
747 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
748 %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
749 %constant = s32[1]{0} constant({0})
750 %start_index = s32[1]{0} parameter(1)
751 %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
752 ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
753 }
754
755 )"
756 },
757 // Dynamic slice with scalar indices
758 {
759 "DynamicSliceScalarIndices",
760 R"(HloModule DynamicSlice_module, entry_computation_layout={(s32[2,2,258]{2,1,0},s32[])->s32[2,2,258]{2,1,0}}
761
762 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
763 %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
764 %constant = s32[] constant(0)
765 %start_index = s32[] parameter(1)
766 ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
767 }
768
769 )"
770 },
771 // Dynamic update slice
772 {
773 "DynamicUpdateSlice",
774 R"(HloModule DynamicSlice_module, entry_computation_layout={(s32[1,1,25,1]{3,2,1,0},s32[1,1,2,1]{3,2,1,0},s32[4]{0})->s32[1,1,25,1]{3,2,1,0}}
775
776 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
777 %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
778 %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
779 %start_indices = s32[4]{0} parameter(2)
780 ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
781 }
782
783 )"
784 },
785 // Dynamic update slice with scalar indices
786 {
787 "DynamicUpdateSliceScalarIndex",
788 R"(HloModule DynamicUpdateSlice_module, entry_computation_layout={(s32[1,1,25,1]{3,2,1,0},s32[1,1,2,1]{3,2,1,0},s32[],s32[],s32[],s32[])->s32[1,1,25,1]{3,2,1,0}}
789
790 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
791 %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
792 %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
793 %start_index.0 = s32[] parameter(2)
794 %start_index.1 = s32[] parameter(3)
795 %start_index.2 = s32[] parameter(4)
796 %start_index.3 = s32[] parameter(5)
797 ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, /*index=5*/s32[] %start_index.3)
798 }
799
800 )"
801 },
802 // batch norm training
803 {
804 "BatchNormTraining",
805 R"(HloModule BasicTraining_module, entry_computation_layout={()->(f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})}
806
807 ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
808 %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
809 %constant.1 = f32[2]{0} constant({2, 3})
810 %constant.2 = f32[2]{0} constant({1, 2})
811 ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
812 }
813
814 )"
815 },
816 // batch norm inference
817 {
818 "BatchNormInference",
819 R"(HloModule BatchNormInference_module, entry_computation_layout={(f32[2,2,2,2]{3,2,1,0},f32[2]{0},f32[2]{0},f32[2]{0},f32[2]{0})->f32[2,2,2,2]{3,2,1,0}}
820
821 ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
822 %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
823 %offset = f32[2]{0} parameter(1)
824 %scale = f32[2]{0} parameter(2)
825 %mean = f32[2]{0} parameter(3)
826 %variance = f32[2]{0} parameter(4)
827 ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
828 }
829
830 )"
831 },
832 // batch norm grad
833 {
834 "BatchNormGrad",
835 R"(HloModule BatchNormGrad_module, entry_computation_layout={(f32[2,2,2,2]{3,2,1,0},f32[2]{0},f32[2]{0},f32[2]{0},f32[2,2,2,2]{3,2,1,0})->(f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0})}
836
837 ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
838 %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
839 %scale = f32[2]{0} parameter(1)
840 %mean = f32[2]{0} parameter(2)
841 %variance = f32[2]{0} parameter(3)
842 %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
843 ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
844 }
845
846 )"
847 },
848 // fft
849 {
850 "Fft",
851 R"(HloModule Fft_module, entry_computation_layout={(c64[8,32]{1,0})->c64[8,32]{1,0}}
852
853 ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
854 %input = c64[8,32]{1,0} parameter(0)
855 ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
856 }
857
858 )"
859 },
// ifft2d
861 {
862 "Ifft2d",
863 R"(HloModule Ifft2d_module, entry_computation_layout={(c64[5,8,32]{2,1,0})->c64[5,8,32]{2,1,0}}
864
865 ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
866 %input = c64[5,8,32]{2,1,0} parameter(0)
867 ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
868 }
869
870 )"
871 },
872 // rfft2d
873 {
874 "Rfft2d",
875 R"(HloModule Rfft2d_module, entry_computation_layout={(f32[5,64,32]{2,1,0})->c64[5,64,17]{2,1,0}}
876
877 ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
878 %input = f32[5,64,32]{2,1,0} parameter(0)
879 ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
880 }
881
882 )"
883 },
884 // irfft3d
885 {
886 "Irfft3d",
887 R"(HloModule Irfft3d_module, entry_computation_layout={(c64[5,64,128,33]{3,2,1,0})->f32[5,64,128,64]{3,2,1,0}}
888
889 ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
890 %input = c64[5,64,128,33]{3,2,1,0} parameter(0)
891 ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
892 }
893
894 )"
895 },
896 // pad
897 {
898 "Pad",
899 R"(HloModule Pad1DS3Array_module, entry_computation_layout={()->f32[7]{0}}
900
901 ENTRY %Pad1DS3Array.v3 () -> f32[7] {
902 %constant = f32[3]{0} constant({1, 2, 3})
903 %constant.1 = f32[] constant(0.1)
904 ROOT %pad = f32[7]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
905 }
906
907 )"
908 },
909 // pad has interior
910 {
911 "PadHasInterior",
912 R"(HloModule PadHasInterior_module, entry_computation_layout={(f32[1,25,7,7]{3,2,1,0})->f32[1,25,17,11]{3,2,1,0}}
913
914 ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
915 %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
916 %constant = f32[] constant(-5.123)
917 ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
918 }
919
920 )"
921 },
922 // round to nearest even
923 {
924 "RoundNearestEven",
925 R"(HloModule RoundNearestEven_module, entry_computation_layout={(f32[2,2]{1,0})->f32[2,2]{1,0}}
926
927 ENTRY %RoundNearestEven (input: f32[2,2]) -> f32[2,2] {
928 %input = f32[2,2]{1,0} parameter(0)
929 ROOT %round-nearest-even = f32[2,2]{1,0} round-nearest-even(f32[2,2]{1,0} %input)
930 }
931
932 )"
933 },
934 // Negative padding
935 {
936 "PadHasNegativePadding",
937 R"(HloModule PadHasNegativePadding_module, entry_computation_layout={(f32[1,25,7,7,10]{4,3,2,1,0})->f32[1,15,6,3,35]{4,3,2,1,0}}
938
939 ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,35] {
940 %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
941 %constant = f32[] constant(-5.123)
942 ROOT %pad = f32[1,15,6,3,35]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
943 }
944
945 )"
946 },
947 // fusion
948 {
949 "Fusion",
950 R"(HloModule fusion_module, entry_computation_layout={()->f32[3,2,1,1]{3,2,1,0}}
951
952 %fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
953 %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
954 %constant.1.param_1 = f32[2]{0} parameter(1)
955 %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
956 ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
957 }
958
959 ENTRY %fusion.v3 () -> f32[3,2,1,1] {
960 %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
961 %constant.1 = f32[2]{0} constant({3.14, 4.25})
962 ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
963 }
964
965 )"
966 },
967 {
968 "Gather",
969 R"(HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}}
970
971 ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
972 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
973 %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
974 ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
975 }
976
977 )"
978 },
979 {
980 "SortedGather",
981 R"(HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}}
982
983 ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
984 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
985 %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
986 ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}, indices_are_sorted=true
987 }
988
989 )"
990 },
991 {
992 "Scatter",
993 R"(HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0},f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46]{4,3,2,1,0}}
994
995 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
996 %lhs = f32[] parameter(0)
997 %rhs = f32[] parameter(1)
998 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
999 }
1000
1001 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
1002 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1003 %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1004 %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
1005 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
1006 }
1007
1008 )"
1009 },
1010 {
1011 "TupleScatter",
1012 R"(HloModule TupleScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},bf16[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0},f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0},bf16[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->(f32[50,49,48,47,46]{4,3,2,1,0}, bf16[50,49,48,47,46]{4,3,2,1,0})}
1013
1014 %add_F32_mul_BF16 (lhs_0: f32[], lhs_1: bf16[], rhs_0: f32[], rhs_1: bf16[]) -> (f32[], bf16[]) {
1015 %lhs_0 = f32[] parameter(0)
1016 %rhs_0 = f32[] parameter(2)
1017 %add = f32[] add(f32[] %lhs_0, f32[] %rhs_0)
1018 %lhs_1 = bf16[] parameter(1)
1019 %rhs_1 = bf16[] parameter(3)
1020 %mul = bf16[] multiply(bf16[] %lhs_1, bf16[] %rhs_1)
1021 ROOT %tuple = (f32[], bf16[]) tuple(f32[] %add, bf16[] %mul)
1022 }
1023
1024 ENTRY %Scatter (input_0: f32[50,49,48,47,46], input_1: bf16[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates_0: f32[10,9,8,7,30,29,28,27,26], updates_1: bf16[10,9,8,7,30,29,28,27,26]) -> (f32[50,49,48,47,46], bf16[50,49,48,47,46]) {
1025 %input_0 = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1026 %input_1 = bf16[50,49,48,47,46]{4,3,2,1,0} parameter(1)
1027 %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(2)
1028 %updates_0 = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(3)
1029 %updates_1 = bf16[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(4)
1030 ROOT %scatter = (f32[50,49,48,47,46]{4,3,2,1,0}, bf16[50,49,48,47,46]{4,3,2,1,0}) scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_0, bf16[50,49,48,47,46]{4,3,2,1,0} %input_1, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates_0, bf16[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates_1), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32_mul_BF16
1031 }
1032
1033 )"
1034 },
1035 {
1036 "SortedScatter",
1037 R"(HloModule StringifySortedScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0},f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46]{4,3,2,1,0}}
1038
1039 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
1040 %lhs = f32[] parameter(0)
1041 %rhs = f32[] parameter(1)
1042 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
1043 }
1044
1045 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
1046 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1047 %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1048 %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
1049 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3
1050 }
1051
1052 )"
1053 },
1054 {
1055 "UniqueIndicesScatter",
1056 R"(HloModule StringifyUniqueIndicesScatter, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0},f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46]{4,3,2,1,0}}
1057
1058 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
1059 %lhs = f32[] parameter(0)
1060 %rhs = f32[] parameter(1)
1061 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
1062 }
1063
1064 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
1065 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1066 %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1067 %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
1068 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, unique_indices=true, to_apply=%add_F32.v3
1069 }
1070
1071 )"
1072 },
1073 {
1074 "ConstantUnsignedNoUnderflow",
1075 R"(HloModule ConstantUnsignedNoUnderflow_module, entry_computation_layout={()->u64[]}
1076
1077 ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
1078 ROOT %constant = u64[] constant(1)
1079 }
1080
1081 )"
1082 },
1083
1084 {
1085 "ConstantUnsignedNoOverflow",
1086 R"(HloModule ConstantUnsignedNoOverflow_module, entry_computation_layout={()->u64[]}
1087
1088 ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
1089 ROOT %constant = u64[] constant(9223372036854775807)
1090 }
1091
1092 )"
1093 },
1094 // CustomCallWithLayoutConstraints
1095 {
1096 "CustomCallWithLayoutConstraints",
1097 R"(HloModule CustomCallWithLayoutConstraints, entry_computation_layout={(f32[42,2,3]{0,1,2},f32[123,4]{0,1})->f32[1,2,3]{0,2,1}}
1098
1099 ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
1100 %p0 = f32[42,2,3]{0,1,2} parameter(0)
1101 %p1 = f32[123,4]{0,1} parameter(1)
1102 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
1103 }
1104
1105 )"
1106 },
1107 // CustomCallWithLayoutConstraintsNoOperands
1108 {
1109 "CustomCallWithLayoutConstraintsNoOperands",
1110 R"(HloModule CustomCallWithLayoutConstraintsNoOperands, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1111
1112 ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
1113 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
1114 }
1115
1116 )"
1117 },
1118 // CustomCallWithLayoutConstraintsTupleShapes
1119 {
1120 "CustomCallWithLayoutConstraintsTupleShapes",
1121 R"(HloModule CustomCallWithLayoutConstraintsTupleShapes, entry_computation_layout={((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}),f32[123,4]{0,1})->(f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0})}
1122
1123 ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
1124 %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
1125 %p1 = f32[123,4]{0,1} parameter(1)
1126 ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
1127 }
1128
1129 )"
1130 },
1131 // CustomCallWithHasSideEffect
1132 {
1133 "CustomCallWithHasSideEffect",
1134 R"(HloModule CustomCallWithHasSideEffect, entry_computation_layout={((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}),f32[123,4]{0,1})->(f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0})}
1135
1136 ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
1137 %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
1138 %p1 = f32[123,4]{0,1} parameter(1)
1139 ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
1140 }
1141
1142 )"
1143 },
1144 // CustomCallWithAliasing
1145 {
1146 "CustomCallWithAliasing",
1147 R"(HloModule CustomCallWithAliasing, entry_computation_layout={((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}),f32[123,4]{0,1})->(f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2})}
1148
1149 ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) {
1150 %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
1151 %p1 = f32[123,4]{0,1} parameter(1)
1152 ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
1153 }
1154
1155 )"
1156 },
1157 // CustomCall with schedule.
1158 {
1159 "CustomCallWithSchedule",
1160 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1161
1162 ENTRY %CustomCall () -> f32[1,2,3] {
1163 %constant = f32[1]{0} constant({12345})
1164 %custom-call.0 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", schedule=SCHEDULE_EARLIEST
1165 ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1,2,3]{0,2,1} %custom-call.0), custom_call_target="bar", schedule=SCHEDULE_LATEST
1166 }
1167
1168 )"
1169 },
1170 // CustomCall that returns a status.
1171 {
1172 "CustomCallWithStatusReturningVersion",
1173 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1174
1175 ENTRY %CustomCall () -> f32[1,2,3] {
1176 %constant = f32[1]{0} constant({12345})
1177 ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", api_version=API_VERSION_STATUS_RETURNING
1178 }
1179
1180 )"
1181 },
1182 // Parse c64 literal
1183 {
1184 "ParseC64Literal",
1185 R"(HloModule ParseC64Literal, entry_computation_layout={()->c64[2]{0}}
1186
1187 ENTRY %ParseC64Literal () -> c64[2] {
1188 ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)})
1189 }
1190
1191 )"
1192 },
1193 // Parse c128 literal
1194 {
1195 "ParseC128Literal",
1196 R"(HloModule ParseC128Literal, entry_computation_layout={()->c128[2]{0}}
1197
1198 ENTRY %ParseC128Literal () -> c128[2] {
1199 ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)})
1200 }
1201
1202 )"
1203 },
1204 // Indexed Conditional
1205 {
1206 "IndexedConditional",
1207 R"(HloModule indexed_conditional, entry_computation_layout={()->f32[]}
1208
1209 %Negate (x: f32[]) -> f32[] {
1210 %x = f32[] parameter(0)
1211 ROOT %negate = f32[] negate(f32[] %x)
1212 }
1213
1214 %Identity (y: f32[]) -> f32[] {
1215 %y = f32[] parameter(0)
1216 ROOT %copy = f32[] copy(f32[] %y)
1217 }
1218
1219 %Floor (z: f32[]) -> f32[] {
1220 %z = f32[] parameter(0)
1221 ROOT %floor = f32[] floor(f32[] %z)
1222 }
1223
1224 ENTRY %Parameters1.v4 () -> f32[] {
1225 %constant = s32[] constant(1)
1226 %constant.1 = f32[] constant(56)
1227 %constant.2 = f32[] constant(12)
1228 %constant.3 = f32[] constant(13)
1229 ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
1230 }
1231
1232 )"
1233 },
1234 // rng-get-and-update-state
1235 {
1236 "RngGetAndUpdateState",
1237 R"(HloModule rng_get_and_update_state, entry_computation_layout={()->u64[2]{0}}
1238
1239 ENTRY %RngGetAndUpdateState () -> u64[2] {
1240 ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=4096
1241 }
1242
1243 )"
1244 },
1245 {
1246 "RngBitGenerator",
1247 R"(HloModule gng_bit_generator, entry_computation_layout={(u64[2]{0})->(u64[2]{0}, u32[11,17]{1,0})}
1248
1249 ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[11,17]) {
1250 %p0 = u64[2]{0} parameter(0)
1251 ROOT %rand = (u64[2]{0}, u32[11,17]{1,0}) rng-bit-generator(u64[2]{0} %p0), algorithm=rng_three_fry
1252 }
1253
1254 )"
1255 },
1256 // Async ops with syntax sugar.
1257 {
1258 "AsyncOpsWithSyntaxSugar",
1259 R"(HloModule AsyncOpsWithSyntaxSugar, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
1260
1261 ENTRY %Entry (p0: f32[10]) -> f32[20] {
1262 %p0 = f32[10]{0} parameter(0)
1263 %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), custom_call_target="foo"
1264 %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), custom_call_target="foo"
1265 ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), custom_call_target="foo"
1266 }
1267
1268 )"
1269 },
1270 {
1271 "AsyncOpsWithSyntaxSugarAndGroupId",
1272 R"(HloModule AsyncOpsWithSyntaxSugarAndGroupId, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
1273
1274 ENTRY %Entry (p0: f32[10]) -> f32[20] {
1275 %p0 = f32[10]{0} parameter(0)
1276 %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_group_id=3, custom_call_target="foo"
1277 %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_group_id=3, custom_call_target="foo"
1278 ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_group_id=3, custom_call_target="foo"
1279 }
1280
1281 )"
1282 },
1283 // Async ops with syntax sugar and async thread name.
1284 {
1285 "AsyncOpsWithSyntaxSugarAndThreadName",
1286 R"(HloModule AsyncOpsWithSyntaxSugarAndThreadName, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
1287
1288 ENTRY %Entry (p0: f32[10]) -> f32[20] {
1289 %p0 = f32[10]{0} parameter(0)
1290 %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo"
1291 %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
1292 ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", custom_call_target="foo"
1293 }
1294
1295 )"
1296 },
1297 // HloComputation with thread name as attribute.
1298 {
1299 "HloComputationWithParallelThreadName",
1300 R"(HloModule HloComputationWithParallelThreadName, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
1301
1302 ENTRY %Entry (p0: f32[10]) -> f32[20] {
1303 %p0 = f32[10]{0} parameter(0)
1304 %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo"
1305 %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
1306 ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", custom_call_target="foo"
1307 }, execution_thread="main_thread"
1308
1309 )"
1310 },
1311 });
1312 // clang-format on
1313 }
1314
1315 std::vector<TestData> CreateShortTestCases() {
1316 // clang-format off
1317 return std::vector<TestData>({
1318 // map
1319 {
1320 "Map",
1321 R"(HloModule MapBinaryAdder_module, entry_computation_layout={(f32[4]{0},f32[4]{0})->f32[4]{0}}
1322
1323 add_F32.v3 {
1324 lhs = f32[] parameter(0)
1325 rhs = f32[] parameter(1)
1326 ROOT add = f32[] add(lhs, rhs)
1327 }
1328
1329 ENTRY MapBinaryAdder.v3 {
1330 param0 = f32[4]{0} parameter(0)
1331 param1 = f32[4]{0} parameter(1)
1332 ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
1333 }
1334
1335 )"
1336 },
1337 // reduce
1338 {
1339 "Reduce",
1340 R"(HloModule ReduceR3ToR2_module, entry_computation_layout={(f32[8,16,256]{2,1,0})->f32[8,16]{1,0}}
1341
1342 add_F32.v3 {
1343 lhs = f32[] parameter(0)
1344 rhs = f32[] parameter(1)
1345 ROOT add = f32[] add(lhs, rhs)
1346 }
1347
1348 ENTRY ReduceR3ToR2.v3 {
1349 input = f32[8,16,256]{2,1,0} parameter(0)
1350 constant = f32[] constant(0)
1351 ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
1352 }
1353
1354 )"
1355 },
1356 // tuple reduce
1357 {
1358 "TupleReduce",
1359 R"(HloModule TupleReduce, entry_computation_layout={(f32[1024]{0},s32[1024]{0})->(f32[], s32[])}
1360
1361 max_argmax {
1362 value = f32[] parameter(2)
1363 prev_max = f32[] parameter(0)
1364 is_next_larger = pred[] compare(value, prev_max), direction=GE
1365 max = f32[] select(is_next_larger, value, prev_max)
1366 index = s32[] parameter(3)
1367 prev_argmax = s32[] parameter(1)
1368 argmax = s32[] select(is_next_larger, index, prev_argmax)
1369 ROOT pair = (f32[], s32[]) tuple(max, argmax)
1370 }
1371
1372 ENTRY reduce_entry {
1373 values = f32[1024]{0} parameter(0)
1374 indices = s32[1024]{0} parameter(1)
1375 init_value = f32[] constant(-inf)
1376 init_index = s32[] constant(-1)
1377 ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
1378 }
1379
1380 )"
1381 },
1382 // infeed/outfeed
1383 {
1384 "InfeedOutfeed",
1385 R"(HloModule outfeed_module, entry_computation_layout={()->((u32[3]{0}, pred[]), token[])}
1386
1387 ENTRY InfeedToOutfeed {
1388 token0 = token[] after-all()
1389 infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
1390 infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
1391 outfeed = token[] outfeed(infeed.data, token0), outfeed_shape=(u32[3]{0}, pred[])
1392 ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
1393 infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
1394 infeed.1.token = token[] get-tuple-element(infeed.1), index=1
1395 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token), outfeed_shape=(u32[3]{0}, pred[])
1396 }
1397
1398 )"
1399 },
1400 // Rng
1401 {
1402 "Rng",
1403 R"(HloModule rng_module, entry_computation_layout={()->f32[8]{0}}
1404
1405 ENTRY Rng {
1406 constant = f32[] constant(0)
1407 constant.1 = f32[] constant(1)
1408 ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform
1409 }
1410
1411 )"
1412 },
1413 // Reduce precision
1414 {
1415 "ReducePrecision",
1416 R"(HloModule reduce_precision, entry_computation_layout={()->f32[1]{0}}
1417
1418 ENTRY ReducePrecision {
1419 constant = f32[1]{0} constant({3.14159})
1420 ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10
1421 }
1422
1423 )"
1424 },
1425 // Sort (Key)
1426 {
1427 "SortKey",
1428 R"(HloModule sort, entry_computation_layout={(f32[1024]{0})->f32[1024]{0}}
1429
1430 compare {
1431 p.0.lhs = f32[] parameter(0)
1432 p.0.rhs = f32[] parameter(1)
1433 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1434 }
1435
1436 ENTRY Sort {
1437 x = f32[1024]{0} parameter(0)
1438 ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare
1439 }
1440
1441 )"
1442 },
1443 // Sort (Key, Value)
1444 {
1445 "SortKeyValue",
1446 R"(HloModule sort, entry_computation_layout={(f32[1024]{0},s32[1024]{0})->(f32[1024]{0}, s32[1024]{0})}
1447
1448 compare {
1449 p.1.lhs = s32[] parameter(2)
1450 p.1.rhs = s32[] parameter(3)
1451 p.0.lhs = f32[] parameter(0)
1452 p.0.rhs = f32[] parameter(1)
1453 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1454 }
1455
1456 ENTRY Sort {
1457 keys = f32[1024]{0} parameter(0)
1458 values = s32[1024]{0} parameter(1)
1459 ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
1460 }
1461
1462 )"
1463 },
1464 // R2 Sort (Key)
1465 {
1466 "SortKeyR2",
1467 R"(HloModule sort, entry_computation_layout={(f32[1024,16]{0,1})->f32[1024,16]{0,1}}
1468
1469 compare {
1470 p.0.lhs = f32[] parameter(0)
1471 p.0.rhs = f32[] parameter(1)
1472 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1473 }
1474
1475 ENTRY Sort {
1476 x = f32[1024,16]{0,1} parameter(0)
1477 ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare
1478 }
1479
1480 )"
1481 },
1482 // R2 Sort (Key, Value)
1483 {
1484 "SortKeyValueR2",
1485 R"(HloModule sort, entry_computation_layout={(f32[1024,16]{0,1},s32[1024,16]{0,1})->(f32[1024,16]{0,1}, s32[1024,16]{0,1})}
1486
1487 compare {
1488 p.1.lhs = s32[] parameter(2)
1489 p.1.rhs = s32[] parameter(3)
1490 p.0.lhs = f32[] parameter(0)
1491 p.0.rhs = f32[] parameter(1)
1492 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1493 }
1494
1495 ENTRY Sort {
1496 keys = f32[1024,16]{0,1} parameter(0)
1497 values = s32[1024,16]{0,1} parameter(1)
1498 ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare
1499 }
1500
1501 )"
1502 },
1503 // Sort (Key, Value, Value, Value)
1504 {
1505 "SortManyValues",
1506 R"(HloModule sort, entry_computation_layout={(f32[1024,16]{0,1},s32[1024,16]{0,1},u32[1024,16]{0,1},f32[1024,16]{0,1})->(f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1})}
1507
1508 compare {
1509 p.1.lhs = s32[] parameter(2)
1510 p.1.rhs = s32[] parameter(3)
1511 p.2.lhs = u32[] parameter(4)
1512 p.2.rhs = u32[] parameter(5)
1513 p.3.lhs = f32[] parameter(6)
1514 p.3.rhs = f32[] parameter(7)
1515 p.0.lhs = f32[] parameter(0)
1516 p.0.rhs = f32[] parameter(1)
1517 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1518 }
1519
1520 ENTRY Sort {
1521 keys = f32[1024,16]{0,1} parameter(0)
1522 values.0 = s32[1024,16]{0,1} parameter(1)
1523 values.1 = u32[1024,16]{0,1} parameter(2)
1524 values.2 = f32[1024,16]{0,1} parameter(3)
1525 ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
1526 }
1527
1528 )"
1529 },
1530 // Sort (Key) is_stable=true
1531 {
1532 "SortKeyStable",
1533 R"(HloModule sort, entry_computation_layout={(f32[1024]{0})->f32[1024]{0}}
1534
1535 compare {
1536 p.0.lhs = f32[] parameter(0)
1537 p.0.rhs = f32[] parameter(1)
1538 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1539 }
1540
1541 ENTRY Sort {
1542 x = f32[1024]{0} parameter(0)
1543 ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
1544 }
1545
1546 )"
1547 },
1548 // Indexed Conditional
1549 {
1550 "IndexedConditional",
1551 R"(HloModule indexed_conditional, entry_computation_layout={()->f32[]}
1552
1553 Negate {
1554 x = f32[] parameter(0)
1555 ROOT negate = f32[] negate(x)
1556 }
1557
1558 Identity {
1559 y = f32[] parameter(0)
1560 ROOT copy = f32[] copy(y)
1561 }
1562
1563 Floor {
1564 z = f32[] parameter(0)
1565 ROOT floor = f32[] floor(z)
1566 }
1567
1568 ENTRY Parameters1.v4 {
1569 constant = s32[] constant(1)
1570 constant.1 = f32[] constant(56)
1571 constant.2 = f32[] constant(12)
1572 constant.3 = f32[] constant(13)
1573 ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor}
1574 }
1575
1576 )"
1577 },
1578 // Predicated Conditional
1579 {
1580 "PredicatedConditional",
1581 R"(HloModule pred_conditional, entry_computation_layout={()->f32[]}
1582
1583 Negate {
1584 x = f32[] parameter(0)
1585 ROOT negate = f32[] negate(x)
1586 }
1587
1588 Identity {
1589 y = f32[] parameter(0)
1590 ROOT copy = f32[] copy(y)
1591 }
1592
1593 ENTRY Parameters1.v4 {
1594 constant = pred[] constant(true)
1595 constant.1 = f32[] constant(56)
1596 constant.2 = f32[] constant(12)
1597 ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity
1598 }
1599
1600 )"
1601 },
1602 // CustomCall
1603 {
1604 "CustomCall",
1605 R"(HloModule custom_call, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1606
1607 ENTRY CustomCall {
1608 constant = f32[1]{0} constant({12345})
1609 ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
1610 }
1611
1612 )"
1613 },
1614 // CustomCall with single computation.
1615 {
1616 "CustumCallSingleComp",
1617 R"(HloModule custom_call_with_comp, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1618
1619 max_F32 {
1620 lhs = f32[] parameter(0)
1621 rhs = f32[] parameter(1)
1622 ROOT maximum = f32[] maximum(lhs, rhs)
1623 }
1624
1625 ENTRY CustomCall {
1626 constant = f32[1]{0} constant({12345})
1627 ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32}
1628 }
1629
1630 )"
1631 },
1632 // CustomCall with multiple computations.
1633 {
1634 "CustumCallMultipleComps",
1635 R"(HloModule custom_call_with_comps, entry_computation_layout={()->f32[1,2,3]{0,2,1}}
1636
1637 max_F32 {
1638 lhs = f32[] parameter(0)
1639 rhs = f32[] parameter(1)
1640 ROOT maximum = f32[] maximum(lhs, rhs)
1641 }
1642
1643 ENTRY CustomCall {
1644 constant = f32[1]{0} constant({12345})
1645 ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32, max_F32}
1646 }
1647
1648 )"
1649 },
1650 // Variables with non-default names
1651 {
1652 "NonDefaultNames",
1653 R"(HloModule add_constants_module, entry_computation_layout={()->f32[]}
1654
1655 ENTRY add_constants {
1656 foo = f32[] constant(3.14)
1657 ROOT bar = f32[] add(foo, foo)
1658 }
1659
1660 )"
1661 },
1662 {
1663 "Dot",
1664 R"(HloModule dot, entry_computation_layout={(f32[2,10]{1,0},f32[10,2]{1,0})->f32[2]{0}}
1665
1666 ENTRY dot {
1667 a = f32[2,10]{1,0} parameter(0)
1668 b = f32[10,2]{1,0} parameter(1)
1669 ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
1670 }
1671
1672 )"
1673 },
1674 {
1675 "gather",
1676 R"(HloModule gather, entry_computation_layout={(f32[50,49,48,47,46]{4,3,2,1,0},s64[10,9,8,7,5]{4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}}
1677
1678 ENTRY Gather {
1679 input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1680 start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1681 ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
1682 }
1683
1684 )"
1685 },
1686 // all-reduce
1687 {
1688 "AllReduce",
1689 R"(HloModule CRS, entry_computation_layout={(f32[8]{0})->f32[8]{0}}
1690
1691 add {
1692 lhs = f32[] parameter(0)
1693 rhs = f32[] parameter(1)
1694 ROOT add = f32[] add(lhs, rhs)
1695 }
1696
1697 ENTRY CRS {
1698 input = f32[8]{0} parameter(0)
1699 ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add
1700 }
1701
1702 )"
1703 },
1704 // all-reduce with subgroups
1705 {
1706 "AllReduceWithSubgroups",
1707 R"(HloModule CRS_Subgroups, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}
1708
1709 add {
1710 lhs = f32[] parameter(0)
1711 rhs = f32[] parameter(1)
1712 ROOT add = f32[] add(lhs, rhs)
1713 }
1714
1715 ENTRY AllReduceWithSubgroups {
1716 input = f32[128,32]{0,1} parameter(0)
1717 ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
1718 }
1719
1720 )",
1721 /*replica_count=*/4,
1722 },
1723 // all-reduce with constrained layout
1724 {
1725 "AllReduceWithLayout",
1726 R"(HloModule CRS, entry_computation_layout={(f32[8]{0})->f32[8]{0}}
1727
1728 add {
1729 lhs = f32[] parameter(0)
1730 rhs = f32[] parameter(1)
1731 ROOT add = f32[] add(lhs, rhs)
1732 }
1733
1734 ENTRY CRS {
1735 input = f32[8]{0} parameter(0)
1736 ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, constrain_layout=true, to_apply=add
1737 }
1738
1739 )"
1740 },
1741 // all-reduce with channel-id
1742 {
1743 "AllReduceAllReduce",
1744 R"(HloModule CRS, entry_computation_layout={(f32[8]{0})->f32[8]{0}}
1745
1746 add {
1747 lhs = f32[] parameter(0)
1748 rhs = f32[] parameter(1)
1749 ROOT add = f32[] add(lhs, rhs)
1750 }
1751
1752 ENTRY CRS {
1753 input = f32[8]{0} parameter(0)
1754 crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
1755 ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
1756 }
1757
1758 )"
1759 },
1760 // all-reduce start and done
1761 {
1762 "AllReduceStartAndDone",
1763 R"(HloModule CRS, entry_computation_layout={(f32[8]{0})->f32[8]{0}}
1764
1765 add {
1766 lhs = f32[] parameter(0)
1767 rhs = f32[] parameter(1)
1768 ROOT add = f32[] add(lhs, rhs)
1769 }
1770
1771 ENTRY CRS {
1772 input = f32[8]{0} parameter(0)
1773 crs = f32[8]{0} all-reduce-start(input), replica_groups={}, to_apply=add
1774 ROOT done = f32[8]{0} all-reduce-done(crs)
1775 }
1776
1777 )"
1778 },
1779 // reduce-scatter
1780 {
1781 "ReduceScatter",
1782 R"(HloModule RS, entry_computation_layout={(f32[8]{0})->f32[4]{0}}
1783
1784 add {
1785 lhs = f32[] parameter(0)
1786 rhs = f32[] parameter(1)
1787 ROOT add = f32[] add(lhs, rhs)
1788 }
1789
1790 ENTRY CRS {
1791 input = f32[8]{0} parameter(0)
1792 ROOT ars = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}}, dimensions={0}, to_apply=add
1793 }
1794
1795 )"
1796 },
1797 // all-gather
1798 {
1799 "AllGather",
1800 R"(HloModule AllGather, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}
1801
1802 ENTRY AllGather {
1803 input = f32[128,32]{0,1} parameter(0)
1804 ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, dimensions={1}
1805 }
1806
1807 )"
1808 },
1809 // all-gather with constrained layout
1810 {
1811 "AllGatherWithLayout",
1812 R"(HloModule AllGather, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}
1813
1814 ENTRY AllGather {
1815 input = f32[128,32]{0,1} parameter(0)
1816 ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, constrain_layout=true, dimensions={1}
1817 }
1818
1819 )"
1820 },
1821 // all-gather with subgroups
1822 {
1823 "AllGatherWithSubgroups",
1824 R"(HloModule AllGatherWithSubgroups, entry_computation_layout={(f32[128,32]{0,1})->f32[128,64]{0,1}}
1825
1826 ENTRY AllGatherWithSubgroups {
1827 input = f32[128,32]{0,1} parameter(0)
1828 ROOT ag = f32[128,64]{0,1} all-gather(input), replica_groups={{0,1},{2,3}}, dimensions={1}
1829 }
1830
1831 )",
1832 /*replica_count=*/4,
1833 },
1834 // all-to-all
1835 {
1836 "AllToAll",
1837 R"(HloModule AllToAll, entry_computation_layout={(f32[128,32]{0,1})->(f32[128,32]{0,1})}
1838
1839 ENTRY AllToAll {
1840 input = f32[128,32]{0,1} parameter(0)
1841 ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
1842 }
1843
1844 )"
1845 },
1846 // all-to-all with subgroups
1847 {
1848 "AllToAllWithSubgroups",
1849 R"(HloModule AllToAllWithSubgroups, entry_computation_layout={(f32[128,32]{0,1},f32[128,32]{0,1})->(f32[128,32]{0,1}, f32[128,32]{0,1})}
1850
1851 ENTRY AllToAllWithSubgroups {
1852 p0 = f32[128,32]{0,1} parameter(0)
1853 p1 = f32[128,32]{0,1} parameter(1)
1854 ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
1855 }
1856
1857 )",
1858 /*replica_count=*/4,
1859 },
1860 // collective-permute
1861 {
1862 "CollectivePermute",
1863 R"(HloModule CollectivePermute, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}
1864
1865 ENTRY CollectivePermute {
1866 input = f32[128,32]{0,1} parameter(0)
1867 ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
1868 }
1869
1870 )",
1871 /*replica_count=*/4
1872 },
1873 // collective-permute with in-place updates
1874 {
1875 "CollectivePermuteInPlaceUpdate",
1876 R"(HloModule CollectivePermuteInPlaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}
1877
1878 ENTRY CollectivePermuteInPlaceUpdate {
1879 input = f32[128,32]{0,1} parameter(0)
1880 constant = f32[] constant(1)
1881 output = f32[128,128]{0,1} broadcast(constant), dimensions={}
1882 constant.1 = s32[] constant(0)
1883 tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
1884 constant.2 = s32[] constant(64)
1885 tuple.2 = (s32[], s32[]) tuple(constant.1, constant.2)
1886 ROOT root = f32[128,128]{0,1} collective-permute(input, output, tuple.1, tuple.2), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{128,32}}
1887 }
1888
1889 )",
1890 /*replica_count=*/4
1891 },
1892 // collective-permute with in-place updates with multiple targets per source
1893 {
1894 "CollectivePermuteInPlaceUpdateMultipleReadWrite",
1895 R"(HloModule CollectivePermuteInPlaceUpdateMultipleReadWrite, entry_computation_layout={(f32[8,8,128]{2,1,0})->f32[8,8,128]{2,1,0}}
1896
1897 ENTRY CollectivePermuteInPlaceUpdate {
1898 constant.3 = s32[] constant(2)
1899 constant.1 = s32[] constant(0)
1900 output_offset.3 = (s32[], s32[], s32[]) tuple(constant.3, constant.1, constant.1)
1901 constant.4 = s32[] constant(3)
1902 output_offset.4 = (s32[], s32[], s32[]) tuple(constant.4, constant.1, constant.1)
1903 input = f32[8,8,128]{2,1,0} parameter(0)
1904 constant = f32[] constant(1)
1905 output = f32[8,8,128]{2,1,0} broadcast(constant), dimensions={}
1906 input_offset.1 = (s32[], s32[], s32[]) tuple(constant.1, constant.1, constant.1)
1907 constant.2 = s32[] constant(1)
1908 input_offset.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.1, constant.1)
1909 input_offset = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(input_offset.1, input_offset.2)
1910 output_offset = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(input_offset.1, input_offset.2)
1911 ROOT root = f32[8,8,128]{2,1,0} collective-permute(input, output, input_offset, output_offset), source_target_pairs={{0,1},{1,2},{2,3},{0,3},{2,1},{3,2}}, slice_sizes={{1,8,128},{1,8,128}}
1912 }
1913
1914 )",
1915 /*replica_count=*/4
1916 },
1917 {
1918 "CollectivePermuteInPlaceUpdateTupleMultipleReadWrite",
1919 R"(HloModule hlo_runner_test_0.1, entry_computation_layout={()->(u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)})}
1920
1921 ENTRY hlo_runner_test_0.1 {
1922 replica_id = u32[] replica-id()
1923 broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(replica_id), dimensions={}
1924 tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(broadcast.0, broadcast.0)
1925 constant.1 = u32[] constant(1000)
1926 broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(constant.1), dimensions={}
1927 broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(constant.1), dimensions={}
1928 tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(broadcast.1, broadcast.2)
1929 constant.2 = s32[] constant(0)
1930 tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.2, constant.2)
1931 constant.3 = s32[] constant(1)
1932 tuple.3 = (s32[], s32[], s32[]) tuple(constant.3, constant.2, constant.2)
1933 tuple.4 = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(tuple.2, tuple.3)
1934 tuple.7 = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(tuple.2, tuple.2)
1935 tuple.8 = (((s32[], s32[], s32[]), (s32[], s32[], s32[])), ((s32[], s32[], s32[]), (s32[], s32[], s32[]))) tuple(tuple.4, tuple.7)
1936 constant.4 = s32[] constant(2)
1937 tuple.5 = (s32[], s32[], s32[]) tuple(constant.4, constant.2, constant.2)
1938 tuple.6 = ((s32[], s32[], s32[]), (s32[], s32[], s32[])) tuple(tuple.2, tuple.5)
1939 tuple.9 = (((s32[], s32[], s32[]), (s32[], s32[], s32[])), ((s32[], s32[], s32[]), (s32[], s32[], s32[]))) tuple(tuple.4, tuple.6)
1940 ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute(tuple.input, tuple.output, tuple.8, tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
1941 }
1942
1943 )",
1944 /*replica_count=*/4
1945 },
1946
1947 // collective-permute tuple with in-place updates
1948 {
1949 "CollectivePermuteTupleInPlaceUpdate",
1950 R"(HloModule CollectivePermuteTupleInPlaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->(f32[128,128]{0,1}, f32[128,128]{0,1})}
1951
1952 ENTRY CollectivePermuteInPlaceUpdate {
1953 input = f32[128,32]{0,1} parameter(0)
1954 tuple.input = (f32[128,32]{0,1}, f32[128,32]{0,1}) tuple(input, input)
1955 constant = f32[] constant(1)
1956 output = f32[128,128]{0,1} broadcast(constant), dimensions={}
1957 tuple.output = (f32[128,128]{0,1}, f32[128,128]{0,1}) tuple(output, output)
1958 constant.1 = s32[] constant(0)
1959 tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
1960 constant.2 = s32[] constant(64)
1961 tuple.2 = (s32[], s32[]) tuple(constant.2, constant.1)
1962 tuple.3 = ((s32[], s32[]), (s32[], s32[])) tuple(tuple.1, tuple.2)
1963 tuple.4 = (s32[], s32[]) tuple(constant.1, constant.1)
1964 tuple.5 = (s32[], s32[]) tuple(constant.2, constant.2)
1965 tuple.6 = ((s32[], s32[]), (s32[], s32[])) tuple(tuple.4, tuple.5)
1966 ROOT root = (f32[128,128]{0,1}, f32[128,128]{0,1}) collective-permute(tuple.input, tuple.output, tuple.3, tuple.6), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{64,32},{64,32}}
1967 }
1968
1969 )",
1970 /*replica_count=*/4
1971 },
1972 // collective-permute-start and -done with inplace update
1973 {
1974 "CollectivePermuteStartAndDone",
1975 R"(HloModule CollectivePermuteStartAndDone, entry_computation_layout={(f32[128,32]{0,1})->f32[128,32]{0,1}}
1976
1977 ENTRY CollectivePermuteStartAndDone {
1978 input = f32[128,32]{0,1} parameter(0)
1979 collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}}
1980 ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1)
1981 }
1982
1983 )",
1984 /*replica_count=*/4
1985 },
1986 // collective-permute-start and -done
1987 {
1988 "CollectivePermuteStartAndDoneInplaceUpdate",
1989 R"(HloModule CollectivePermuteStartAndDoneInplaceUpdate, entry_computation_layout={(f32[128,32]{0,1})->f32[128,128]{0,1}}
1990
1991 ENTRY CollectivePermuteStartAndDoneInplaceUpdate {
1992 input = f32[128,32]{0,1} parameter(0)
1993 constant = f32[] constant(1)
1994 output = f32[128,128]{0,1} broadcast(constant), dimensions={}
1995 constant.1 = s32[] constant(0)
1996 tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
1997 constant.2 = s32[] constant(64)
1998 tuple.2 = (s32[], s32[]) tuple(constant.1, constant.2)
1999 collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,128]{0,1}, u32[], u32[]) collective-permute-start(input, output, tuple.1, tuple.2), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{64,32}}
2000 ROOT collective-permute-done.1 = f32[128,128]{0,1} collective-permute-done(collective-permute-start.1)
2001 }
2002
2003 )",
2004 /*replica_count=*/4
2005 },
2006 // replica-id
2007 {
2008 "ReplicaId",
2009 R"(HloModule replica-id, entry_computation_layout={()->u32[]}
2010
2011 ENTRY Replica-id {
2012 ROOT replica-id = u32[] replica-id()
2013 }
2014
2015 )"
2016 },
2017 // partition-id
2018 {
2019 "PartitionId",
2020 R"(HloModule partition-id, entry_computation_layout={()->u32[]}
2021
2022 ENTRY PartitionId {
2023 ROOT id = u32[] partition-id()
2024 }
2025
2026 )"
2027 },
2028 // Iota
2029 {
2030 "Iota",
2031 R"(HloModule iota, entry_computation_layout={()->f32[100]{0}}
2032
2033 ENTRY Iota {
2034 ROOT iota = f32[100]{0} iota(), iota_dimension=0
2035 }
2036
2037 )"
2038 },
2039 // custom-call with window, dim_labels and feature_group_count
2040 {
2041 "CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
2042 R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount, entry_computation_layout={()->f32[100]{0}}
2043
2044 ENTRY Computation {
2045 ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
2046 }
2047
2048 )"
2049 },
2050 // custom-call with unknown dim labels.
2051 {
2052 "CustomCallWithUnknownDimLabels",
2053 R"(HloModule CustomCallWithUnknownDimLabels, entry_computation_layout={()->f32[100]{0}}
2054
2055 ENTRY Computation {
2056 ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=?b01f_0?1io->b01?f, custom_call_target="target"
2057 }
2058
2059 )"
2060 },
2061 // is_scheduled=true attribute
2062 {
2063 "ScheduledModule",
2064 R"(HloModule scheduled_module, is_scheduled=true, entry_computation_layout={(f32[1024]{0},s32[1024]{0})->(f32[1024]{0}, s32[1024]{0})}
2065
2066 compare {
2067 p.1.lhs = s32[] parameter(2)
2068 p.1.rhs = s32[] parameter(3)
2069 p.0.lhs = f32[] parameter(0)
2070 p.0.rhs = f32[] parameter(1)
2071 ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
2072 }
2073
2074 ENTRY Sort {
2075 keys = f32[1024]{0} parameter(0)
2076 values = s32[1024]{0} parameter(1)
2077 ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
2078 }
2079
2080 )"
2081 },
2082 // AfterAll with multiple operands
2083 {
2084 "AfterAllWithMultipleOperands",
2085 R"(HloModule AfterAllWithMultipleOperands, entry_computation_layout={(f32[])->token[]}
2086
2087 ENTRY AfterAllWithMultipleOperands {
2088 p0 = f32[] parameter(0)
2089 token0 = token[] after-all()
2090 token1 = token[] after-all()
2091 ROOT after-all = token[] after-all(p0, token0, token1)
2092 }
2093
2094 )"
2095 },
2096 // AddDependency
2097 // A dependency chain is created from 'neg' to 'exp' using tokens.
2098 {
2099 "AddDependency",
2100 R"(HloModule AddDependency, entry_computation_layout={(f32[])->f32[]}
2101
2102 ENTRY AddDependency {
2103 p = f32[] parameter(0)
2104 neg = f32[] negate(p)
2105 token0 = token[] after-all(neg)
2106 p_after_token = f32[] add-dependency(p, token0)
2107 exp = f32[] exponential(p_after_token)
2108 ROOT sum = f32[] add(neg, exp)
2109 }
2110
2111 )"
2112 },
2113 // A module containing constants equal to the min/max values of various data
2114 // types.
2115 {
2116 "MinMaxValues",
2117 R"(HloModule MinMaxValues, entry_computation_layout={()->c128[2]{0}}
2118
2119 ENTRY MinMaxValues {
2120 x.s8 = s8[2]{0} constant({-128, 127})
2121 x.s16 = s16[2]{0} constant({-32768, 32767})
2122 x.s32 = s32[2]{0} constant({-2147483648, 2147483647})
2123 x.u8 = u8[2]{0} constant({0, 255})
2124 x.u16 = u16[2]{0} constant({0, 65535})
2125 x.u32 = u32[2]{0} constant({0, 4294967295})
2126 x.f16 = f16[2]{0} constant({-65504, 65504})
2127 x.bf16 = bf16[2]{0} constant({-3.39e+38, 3.39e+38})
2128 x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38})
2129 x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308})
2130 x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)})
2131 ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)})
2132 }
2133
2134 )"
2135 },
2136
2137 // Bitcast-convert usage
2138 {
2139 "BitcastConvert",
2140 R"(HloModule BitcastConvert, entry_computation_layout={(f32[100]{0})->u32[100]{0}}
2141
2142 ENTRY BitcastConvertUsage {
2143 p = f32[100]{0} parameter(0)
2144 ROOT out = u32[100]{0} bitcast-convert(p)
2145 }
2146
2147 )"
2148 },
2149 {
2150 "OuterDimensionPartitions",
2151 R"(HloModule OuterDimensionPartitions, entry_computation_layout={(f32[100]{0})->f32[100]{0}}
2152
2153 ENTRY Test {
2154 ROOT foo = f32[100]{0} parameter(0), outer_dimension_partitions={0,10,20}
2155 }
2156
2157 )"
2158 },
2159 });
2160 // clang-format on
2161 }
2162
// Returns parser test cases whose printed form is NOT expected to match the
// input text verbatim. Each entry holds a test name, an input module string,
// and the expected output module string.
//
// The inputs use the abbreviated syntax: instructions are nested inline as
// operands, and some result shapes are omitted. HloNonRoundtripParserTest
// parses and verifies each input, then compares
// ToString(HloPrintOptions::ShortParsable()) against the expected output.
// The comparison ignores leading and trailing whitespace only, so the
// interior layout of the expected strings is significant.
//
// As the expected outputs show, nested operands are hoisted into their own
// instructions with generated `.anon` names.
std::vector<NonRoundtripTestData> CreateNonRoundtripTestCases() {
// clang-format off
return std::vector<NonRoundtripTestData>({
// Nested operand instructions are flattened into named `.anon` instructions,
// and the omitted result shapes of `add`/`multiply` appear in the output.
{
"SimpleNesting",
R"(HloModule test

ENTRY test {
ROOT root = add(f32[10] parameter(0), multiply(f32[10] parameter(1), f32[10] parameter(2)))
})",
R"(HloModule test, entry_computation_layout={(f32[10]{0},f32[10]{0},f32[10]{0})->f32[10]{0}}

ENTRY test {
parameter.anon = f32[10]{0} parameter(0)
parameter.anon.1 = f32[10]{0} parameter(1)
parameter.anon.2 = f32[10]{0} parameter(2)
multiply.anon = f32[10]{0} multiply(parameter.anon.1, parameter.anon.2)
ROOT root = f32[10]{0} add(parameter.anon, multiply.anon)
})"
},

// An instruction explicitly named `add` coexists with nested anonymous `add`
// ops. Operand references to `add` resolve to the named instruction, while
// the anonymous one is renamed to `add.anon`.
{
"AmbiguousNames",
R"(HloModule test
ENTRY test {
add = add(f32[10] parameter(0), f32[10] parameter(1))
ROOT add2 = add(add, add(add, add))
})",
R"(HloModule test, entry_computation_layout={(f32[10]{0},f32[10]{0})->f32[10]{0}}

ENTRY test {
parameter.anon = f32[10]{0} parameter(0)
parameter.anon.1 = f32[10]{0} parameter(1)
add = f32[10]{0} add(parameter.anon, parameter.anon.1)
add.anon = f32[10]{0} add(add, add)
ROOT add2 = f32[10]{0} add(add, add.anon)
})"
},

// A nested (anonymous) instruction may carry an explicit tuple shape.
{
"TupleShapeInsideAnonymousInstr",
R"(HloModule test

ENTRY test {
ROOT root = get-tuple-element(
(f32[10], f16[10]) tuple(f32[10] parameter(0), f16[10] parameter(1))
), index=0
})",
R"(HloModule test, entry_computation_layout={(f32[10]{0},f16[10]{0})->f32[10]{0}}

ENTRY test {
parameter.anon = f32[10]{0} parameter(0)
parameter.anon.1 = f16[10]{0} parameter(1)
tuple.anon = (f32[10]{0}, f16[10]{0}) tuple(parameter.anon, parameter.anon.1)
ROOT root = f32[10]{0} get-tuple-element(tuple.anon), index=0
})"
},

// Named and anonymous operands can be mixed in a single operand list.
{
"MixAnonAndNonAnonOperands",
R"(HloModule test

ENTRY test {
add = add(f32[10] parameter(0), f32[10] parameter(1))
ROOT root = tuple(add, add(add, add), add)
})",
R"(HloModule test, entry_computation_layout={(f32[10]{0},f32[10]{0})->(f32[10]{0}, f32[10]{0}, f32[10]{0})}

ENTRY test {
parameter.anon = f32[10]{0} parameter(0)
parameter.anon.1 = f32[10]{0} parameter(1)
add = f32[10]{0} add(parameter.anon, parameter.anon.1)
add.anon = f32[10]{0} add(add, add)
ROOT root = (f32[10]{0}, f32[10]{0}, f32[10]{0}) tuple(add, add.anon, add)
})"
},

// Broadcasting a scalar may omit `dimensions=` in the input; the printer
// emits `dimensions={}`.
{
"BroadcastOfScalarDoesntNeedDimensionsAttr",
R"(HloModule test

ENTRY test {
ROOT root = sqrt(f32[10,10] broadcast(f32[] parameter(0)))
})",
R"(HloModule test, entry_computation_layout={(f32[])->f32[10,10]{1,0}}

ENTRY test {
parameter.anon = f32[] parameter(0)
broadcast.anon = f32[10,10]{1,0} broadcast(parameter.anon), dimensions={}
ROOT root = f32[10,10]{1,0} sqrt(broadcast.anon)
})"
},

// The `:D(D,C)` layout suffix on the parameter is printed unchanged and is
// also reflected in entry_computation_layout.
{
"SparseShape",
R"(HloModule test

ENTRY test {
ROOT root = f32[10,10]{1,0:D(D,C)} parameter(0)
})",
R"(HloModule test, entry_computation_layout={(f32[10,10]{1,0:D(D,C)})->f32[10,10]{1,0:D(D,C)}}

ENTRY test {
ROOT root = f32[10,10]{1,0:D(D,C)} parameter(0)
})",
}
});
// clang-format on
}
2272
2273 // The test class for those tests defined above which round-trip through the
2274 // parser and ToString is templatized on two bool parameters:
2275 //
2276 // short_form : used for the "short" test cases which use the ShortParsable
2277 // output form.
2278 // proto_round_trip : whether the module should also be round-tripped through
2279 // HloProto form. This provides much better coverage for the proto
2280 // serialization/deserialization.
2281 //
2282 // The proto_round_trip=true case also technically covers the Parser->ToString
2283 // roundtrip as well, but separating out the Parser->ToString roundtrip as its
2284 // own test provides better isolation and could conceivably catch weirdo bugs
2285 // which are hidden by interaction between the textual and proto roundtripping.
2286 template <bool short_form, bool proto_round_trip>
2287 class HloParameterizedParserTest
2288 : public ::testing::Test,
2289 public ::testing::WithParamInterface<TestData> {
2290 protected:
2291 // Expects "ToString(ParseHloModule(std::string)) == string", that is, parses
2292 // the string, asserts that it succeeded, stringifies the parsed module, and
2293 // checks that it equals the original string.
2294 void ExpectEqual() {
2295 std::unique_ptr<HloModule> module;
2296 const std::string& original = GetParam().module_string;
2297 HloModuleConfig config;
2298 config.set_replica_count(GetParam().replica_count);
2299 if (GetParam().enable_verification) {
2300 auto verified_module = std::make_unique<VerifiedHloModule>(
2301 GetParam().test_name, config,
2302 /*verifier_layout_sensitive=*/false,
2303 /*allow_mixed_precision_in_hlo_verifier=*/true,
2304 ShapeUtil::ByteSizeOfElements);
2305 TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
2306 module = std::move(verified_module);
2307 } else {
2308 TF_ASSERT_OK_AND_ASSIGN(module,
2309 ParseAndReturnUnverifiedModule(original, config));
2310 }
2311 if (proto_round_trip) {
2312 TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
2313 module->ToProto(), module->config()));
2314 }
2315 if (short_form) {
2316 EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
2317 } else {
2318 EXPECT_EQ(
2319 original,
2320 module->ToString(HloPrintOptions().set_print_large_constants(true)));
2321 }
2322 }
2323 };
2324
2325 // These using shenanigans are required because the TEST_P macro doesn't like
2326 // template instantiations which contain commas.
2327 using HloParserTestLong = HloParameterizedParserTest<false, false>;
2328 using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
2329 using HloParserTestShort = HloParameterizedParserTest<true, false>;
2330 using HloParserTestShortProto = HloParameterizedParserTest<true, true>;
2331
2332 TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
2333 TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
2334 TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
2335 TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }
2336
2337 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
2338 ::testing::ValuesIn(CreateTestCases()),
2339 TestDataToString);
2340 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
2341 HloParserTestLongProto,
2342 ::testing::ValuesIn(CreateTestCases()),
2343 TestDataToString);
2344 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
2345 ::testing::ValuesIn(CreateShortTestCases()),
2346 TestDataToString);
2347 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
2348 HloParserTestShortProto,
2349 ::testing::ValuesIn(CreateShortTestCases()),
2350 TestDataToString);
2351
2352 class HloNonRoundtripParserTest
2353 : public ::testing::TestWithParam<NonRoundtripTestData> {};
2354 TEST_P(HloNonRoundtripParserTest, Run) {
2355 auto module = std::make_unique<VerifiedHloModule>(
2356 GetParam().test_name, HloModuleConfig{},
2357 /*verifier_layout_sensitive=*/false,
2358 /*allow_mixed_precision_in_hlo_verifier=*/true,
2359 ShapeUtil::ByteSizeOfElements);
2360 TF_ASSERT_OK(
2361 module->ParseHloStringAndVerifyModule(GetParam().input_module_string));
2362 EXPECT_EQ(absl::StripAsciiWhitespace(GetParam().output_module_string),
2363 absl::StripAsciiWhitespace(
2364 module->ToString(HloPrintOptions::ShortParsable())));
2365 }
2366
2367 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
2368 HloNonRoundtripParserTest,
2369 ::testing::ValuesIn(CreateNonRoundtripTestCases()),
2370 NonRoundtripTestDataToString);
2371
2372 class HloParserTest : public ::testing::Test {
2373 protected:
2374 static void ExpectHasSubstr(string_view s, string_view expected) {
2375 EXPECT_TRUE(absl::StrContains(s, expected))
2376 << "'" << s << "' does not contain '" << expected << "'";
2377 }
2378 StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
2379 absl::string_view hlo_text) {
2380 auto module = std::make_unique<VerifiedHloModule>(
2381 ::testing::UnitTest::GetInstance()->current_test_info()->name(),
2382 HloModuleConfig(),
2383 /*verifier_layout_sensitive=*/false,
2384 /*allow_mixed_precision_in_hlo_verifier=*/true,
2385 ShapeUtil::ByteSizeOfElements);
2386 TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
2387 return std::move(module);
2388 }
2389 };
2390
2391 TEST_F(HloParserTest, Empty) {
2392 const std::string original = "";
2393 auto result = ParseAndReturnUnverifiedModule(original);
2394 EXPECT_NE(OkStatus(), result.status());
2395 }
2396
2397 TEST_F(HloParserTest, Garbage) {
2398 const std::string original =
2399 "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
2400 auto result = ParseAndReturnUnverifiedModule(original);
2401 EXPECT_NE(OkStatus(), result.status());
2402 }
2403
2404 TEST_F(HloParserTest, WrongOpcode) {
2405 const std::string original = R"(HloModule wrong_opcode:
2406
2407 ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
2408 %x = f32[]{} parameter(0)
2409 %y = f32[]{} parameter(1)
2410 %le = pred[]{} le(f32[]{} %x, f32[]{} %y)
2411 }
2412
2413 )";
2414 auto result = ParseAndReturnUnverifiedModule(original);
2415 EXPECT_NE(OkStatus(), result.status());
2416 }
2417
2418 TEST_F(HloParserTest, MetadataWithCholesky) {
2419 const std::string original = R"(HloModule metadata_with_cholesky
2420 ENTRY %blabla (a: f32[1,291,291]) -> f32[1,291,291] {
2421 %a = f32[1,291,291] parameter(0)
2422 %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true, metadata={op_type="Cholesky" op_name="Cholesky" profile_type={1}}
2423 }
2424 )";
2425 auto result = ParseAndReturnVerifiedModule(original);
2426 EXPECT_EQ(OkStatus(), result.status());
2427 EXPECT_EQ("Cholesky", result.ValueOrDie()
2428 ->entry_computation()
2429 ->root_instruction()
2430 ->metadata()
2431 .op_name());
2432 EXPECT_EQ("Cholesky", result.ValueOrDie()
2433 ->entry_computation()
2434 ->root_instruction()
2435 ->metadata()
2436 .op_type());
2437 EXPECT_EQ(WINDOW, *result.ValueOrDie()
2438 ->entry_computation()
2439 ->root_instruction()
2440 ->metadata()
2441 .profile_type()
2442 .begin());
2443 }
2444
2445 TEST_F(HloParserTest, WrongShape) {
2446 const std::string original = R"(HloModule wrong_opcode:
2447
2448 ENTRY %blabla (x: g32[]) -> g32[] {
2449 %x = g32[]{} parameter(0)
2450 }
2451
2452 )";
2453 auto result = ParseAndReturnUnverifiedModule(original);
2454 EXPECT_NE(OkStatus(), result.status());
2455 }
2456
2457 TEST_F(HloParserTest, WrongOperandsSize) {
2458 const std::string original = R"(HloModule wrong_opcode:
2459
2460 ENTRY %blabla (x: f32[]) -> pred[] {
2461 %x = f32[]{} parameter(0)
2462 %eq = pred[]{} compare(f32[]{} %x), direction=EQ
2463 }
2464
2465 )";
2466 auto result = ParseAndReturnUnverifiedModule(original);
2467 EXPECT_NE(OkStatus(), result.status());
2468 }
2469
2470 TEST_F(HloParserTest, OperandNotFound) {
2471 const std::string original = R"(HloModule operand_not_found:
2472 ENTRY %blabla (x: f32[]) -> pred[] {
2473 %x = f32[]{} parameter(0)
2474 %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
2475 }
2476 )";
2477 auto result = ParseAndReturnUnverifiedModule(original);
2478 EXPECT_NE(OkStatus(), result.status());
2479 }
2480
2481 TEST_F(HloParserTest, MoreConstants) {
2482 const std::string original = R"(HloModule SelectScalarS32True_module
2483
2484 ENTRY %SelectScalarS32True.v4 () -> s32[] {
2485 %constant.2 = pred[] constant(true)
2486 %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4}
2487 %constant = s32[] constant(42)
2488 %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
2489 }
2490
2491 )";
2492 auto result = ParseAndReturnVerifiedModule(original);
2493 TF_EXPECT_OK(result.status());
2494 // Constant instructions have no name. The string will be parsed successfully
2495 // but the constant names will not be exactly the same.
2496 }
2497
2498 TEST_F(HloParserTest, ConfigurationField) {
2499 const std::string original = R"(HloModule AModule
2500 ENTRY %configuration_test() -> s32[] {
2501 %constant = s32[] constant(42), backend_config="foo bar"
2502 })";
2503 auto result = ParseAndReturnVerifiedModule(original);
2504 TF_ASSERT_OK(result.status());
2505 EXPECT_EQ("foo bar", result.ValueOrDie()
2506 ->entry_computation()
2507 ->root_instruction()
2508 ->raw_backend_config_string());
2509 }
2510
2511 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
2512 const std::string original = R"(HloModule some_2_module
2513
2514 ENTRY %some_2 () -> f32[2] {
2515 ROOT %constant = f32[2]{0} constant({1,{2}})
2516 }
2517
2518 )";
2519 auto result = ParseAndReturnUnverifiedModule(original);
2520 EXPECT_NE(OkStatus(), result.status());
2521 ExpectHasSubstr(result.status().error_message(),
2522 "expects nested array in rank 1, but sees larger");
2523 }
2524
2525 TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
2526 const std::string original = R"(HloModule some_2x3_module
2527
2528 ENTRY %some_2x3 () -> f32[2,3] {
2529 ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6})
2530 }
2531
2532 )";
2533 auto result = ParseAndReturnUnverifiedModule(original);
2534 EXPECT_NE(OkStatus(), result.status());
2535 ExpectHasSubstr(result.status().error_message(),
2536 "expects nested array in rank 2, but sees 1");
2537 }
2538
2539 TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
2540 const std::string original = R"(HloModule some_2x3x2_module
2541
2542 ENTRY %some_2x3x2 () -> f32[2,3,2] {
2543 ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
2544 }
2545
2546 )";
2547 auto result = ParseAndReturnUnverifiedModule(original);
2548 EXPECT_NE(OkStatus(), result.status());
2549 ExpectHasSubstr(result.status().error_message(),
2550 "expects 3 elements in the [0]th element");
2551 }
2552
2553 TEST_F(HloParserTest, ConstantF16Overflow) {
2554 const std::string original =
2555 R"(HloModule ConstantF16Overflow_module
2556
2557 ENTRY %ConstantF16Overflow.v4 () -> f16[] {
2558 ROOT %constant = f16[] constant(-65520)
2559 }
2560
2561 )";
2562 auto result = ParseAndReturnUnverifiedModule(original);
2563 EXPECT_NE(OkStatus(), result.status());
2564 ExpectHasSubstr(result.status().error_message(),
2565 "is out of range for literal's primitive type F16");
2566 }
2567
2568 TEST_F(HloParserTest, ConstantBf16NoOverflow) {
2569 // 65505 is in range for bf16.
2570 const std::string original = R"(
2571 HloModule test_module
2572 ENTRY test {
2573 ROOT c = bf16[] constant(-65505)
2574 })";
2575 EXPECT_EQ(OkStatus(), ParseAndReturnVerifiedModule(original).status());
2576 }
2577
2578 TEST_F(HloParserTest, ConstantBf16Overflow) {
2579 // 1e100 is out of range for bf16.
2580 const std::string original = R"(
2581 HloModule test_module
2582 ENTRY test {
2583 ROOT c = bf16[] constant(1e100)
2584 })";
2585 ExpectHasSubstr(
2586 ParseAndReturnUnverifiedModule(original).status().error_message(),
2587 "out of range");
2588 }
2589
2590 TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
2591 const std::string original = R"(
2592 HloModule ConstantUnsignedUnderflow_module
2593 ENTRY %ConstantUnsignedUnderflow () -> u64[] {
2594 ROOT %constant = u64[] constant(-1)
2595 })";
2596 auto result = ParseAndReturnUnverifiedModule(original);
2597 EXPECT_EQ(OkStatus(), result.status());
2598 }
2599
2600 TEST_F(HloParserTest, ConstantUnsignedOverflow) {
2601 const std::string original = R"(
2602 HloModule ConstantUnsignedOverflow_module
2603 ENTRY %ConstantUnsignedOverflow () -> u32[] {
2604 ROOT %constant = u32[] constant(4294967296)
2605 })";
2606 auto result = ParseAndReturnUnverifiedModule(original);
2607 EXPECT_NE(OkStatus(), result.status());
2608 ExpectHasSubstr(result.status().error_message(),
2609 "is out of range for literal's primitive type U32");
2610 }
2611
2612 TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
2613 const std::string original = R"(
2614 HloModule ConstantUnsignedOverflow_module
2615 ENTRY %ConstantUnsignedOverflow () -> u64[] {
2616 ROOT %constant = u64[] constant(9223372036854775808)
2617 })";
2618 auto result = ParseAndReturnUnverifiedModule(original);
2619 EXPECT_EQ(OkStatus(), result.status());
2620 }
2621
2622 TEST_F(HloParserTest, ConstantC64Overflow) {
2623 const std::string original = R"(
2624 HloModule test_module
2625 ENTRY test () -> c64[] {
2626 ROOT c = c64[] constant((1e100, 0))
2627 })";
2628 auto result = ParseAndReturnUnverifiedModule(original);
2629 EXPECT_NE(OkStatus(), result.status());
2630 }
2631
2632 TEST_F(HloParserTest, ConstantC64Underflow) {
2633 const std::string original = R"(
2634 HloModule test_module
2635 ENTRY test () -> c64[] {
2636 ROOT c = c64[] constant((0, -1e100))
2637 })";
2638 auto result = ParseAndReturnUnverifiedModule(original);
2639 EXPECT_NE(OkStatus(), result.status());
2640 }
2641
2642 TEST_F(HloParserTest, ConstantF64Overflow) {
2643 const std::string original = R"(
2644 HloModule test_module
2645 ENTRY test {
2646 ROOT c = f64[] constant(1.8e308)
2647 })";
2648 auto result = ParseAndReturnUnverifiedModule(original);
2649 EXPECT_NE(OkStatus(), result.status());
2650 }
2651
2652 TEST_F(HloParserTest, ConstantF64Underflow) {
2653 const std::string original = R"(
2654 HloModule test_module
2655 ENTRY test {
2656 ROOT c = f64[] constant(-1.8e308)
2657 })";
2658 auto result = ParseAndReturnUnverifiedModule(original);
2659 EXPECT_NE(OkStatus(), result.status());
2660 }
2661
2662 TEST_F(HloParserTest, ConstantWithExp) {
2663 const std::string original = R"(HloModule ConstantWithExp_module
2664
2665 ENTRY %ConstantWithExp.v4 () -> f32[] {
2666 %constant.1 = f32[] constant(3e+2)
2667 }
2668
2669 )";
2670 auto result = ParseAndReturnVerifiedModule(original);
2671 TF_EXPECT_OK(result.status());
2672 // The string will be parsed successfully but the output strings are not
2673 // exactly the same, because "3e2" is parsed into value 300 and will be
2674 // printed as "300".
2675 }
2676
2677 TEST_F(HloParserTest, ShortConstant) {
2678 const std::string original =
2679 R"(HloModule ShortConstant_module, entry_computation_layout={()->f32[67,89]{1,0}}
2680
2681 ENTRY %ShortConstant.v4 () -> f32[67,89] {
2682 ROOT %constant.1 = f32[67,89]{1,0} constant({...})
2683 }
2684
2685 )";
2686 auto result = ParseAndReturnVerifiedModule(original);
2687 TF_EXPECT_OK(result.status());
2688 EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
2689 }
2690
2691 TEST_F(HloParserTest, NegativeNan) {
2692 const std::string original =
2693 R"(HloModule NegativeNan_module, entry_computation_layout={()->bf16[2]{0}}
2694
2695 ENTRY %NegativeNan () -> bf16[2] {
2696 ROOT %constant = bf16[2]{0} constant({-nan, -nan})
2697 }
2698
2699 )";
2700 auto result = ParseAndReturnUnverifiedModule(original);
2701 EXPECT_EQ(OkStatus(), result.status());
2702 EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
2703 }
2704
2705 TEST_F(HloParserTest, NanPayload) {
2706 const std::string original =
2707 R"(HloModule NanPayload_module, entry_computation_layout={()->bf16[2]{0}}
2708
2709 ENTRY %NanPayload () -> bf16[2] {
2710 ROOT %constant = bf16[2]{0} constant({-nan(0x7f), -nan(0x3f)})
2711 }
2712
2713 )";
2714 auto result = ParseAndReturnUnverifiedModule(original);
2715 EXPECT_EQ(OkStatus(), result.status());
2716 EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
2717 }
2718
2719 TEST_F(HloParserTest, AttributesAnyOrder) {
2720 const std::string original = R"(HloModule any_order_module
2721
2722 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,4,1] {
2723 %input = f32[1,2,1]{2,1,0} parameter(0)
2724 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2725 %filter = f32[1,1,1]{2,1,0} parameter(1)
2726 ROOT %convolution = f32[1,4,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=1}
2727 }
2728
2729 )";
2730 TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2731 }
2732
2733 TEST_F(HloParserTest, InvalidDimLabels) {
2734 std::string prefix = R"(HloModule invalid_dim_labels_module
2735
2736 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
2737 %input = f32[1,2,1]{2,1,0} parameter(0)
2738 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2739 %filter = f32[1,1,1]{2,1,0} parameter(1)
2740 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
2741 std::string suffix = R"(
2742 }
2743
2744 )";
2745
2746 ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2747 absl::StrCat(prefix, ",dim_labels=00_01->10", suffix))
2748 .status()
2749 .error_message(),
2750 "expects unique");
2751
2752 ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2753 absl::StrCat(prefix, ",dim_labels=012_0123->210", suffix))
2754 .status()
2755 .error_message(),
2756 "must have same number of spatial dimensions");
2757
2758 ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2759 absl::StrCat(prefix, ",dim_labels=013_0123->210", suffix))
2760 .status()
2761 .error_message(),
2762 "expects [0-2bf?]");
2763 }
2764
2765 TEST_F(HloParserTest, UnexpectedAttribute) {
2766 const std::string original = R"(HloModule unexpected_attr_module
2767
2768 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2769 %token0 = token[] after-all()
2770 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2771 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2772 ROOT %constant = f32[] constant(2.1)
2773 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv
2774 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2775 }
2776
2777 )";
2778 ExpectHasSubstr(
2779 ParseAndReturnUnverifiedModule(original).status().error_message(),
2780 "unexpected attribute \"calls\"");
2781 }
2782
// A send instruction without its required channel_id attribute is rejected.
TEST_F(HloParserTest, MissingAttribute) {
  const std::string hlo_text = R"(HloModule missing_attr_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(-2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}

)";
  const std::string error_message =
      ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
  ExpectHasSubstr(error_message,
                  "attribute channel_id is expected but not seen");
}
2800
// A control predecessor that names an undefined instruction is reported.
TEST_F(HloParserTest, PredecessorUndefined) {
  const std::string hlo_text = R"(HloModule pre_not_found_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done}
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}

)";
  const std::string error_message =
      ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
  ExpectHasSubstr(error_message, "'done' is not defined");
}
2818
// Slice ranges may omit the stride when it is 1 (e.g. `[0:3]`).
TEST_F(HloParserTest, SliceAllowOmitStride1) {
  const std::string hlo_text = R"(HloModule slice_module

ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
%p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
}

)";
  const auto parsed = ParseAndReturnVerifiedModule(hlo_text);
  TF_EXPECT_OK(parsed.status());
}
2830
// A window pad with an interior component (`1_1_0`) is not valid window
// syntax and produces a descriptive error.
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
  const std::string hlo_text = R"(HloModule window_pad_module

ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
}

)";
  const std::string error_message =
      ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
  ExpectHasSubstr(error_message,
                  "expects padding_low and padding_high separated by '_'");
}
2846
// Sub-attributes inside `metadata={...}` may be separated by commas.
TEST_F(HloParserTest, CommaBetweenSubAttributes) {
  const std::string hlo_text = R"(HloModule test_comma_module

ENTRY %test_comma.v4 () -> f32[] {
ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
}

)";
  const auto parsed = ParseAndReturnVerifiedModule(hlo_text);
  TF_EXPECT_OK(parsed.status());
}
2857
// A computation whose declared result shape disagrees with its root's shape
// is rejected, and the message names both shapes.
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
  const std::string hlo_text = R"(HloModule custom_call:

ENTRY %CustomCall () -> f32[1] {
%constant = f32[1]{0} constant({12345})
ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
})";
  const std::string error_message =
      ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
  ExpectHasSubstr(error_message,
                  "Shape of computation CustomCall, f32[1], is not compatible "
                  "with that of its root instruction foo, f32[1,2,3]");
}
2870
// The entry computation layout is taken from the layouts written on the
// entry computation's parameter and root instructions.
TEST_F(HloParserTest, EntryComputationWithLayout) {
  const std::string original = R"(HloModule layout:
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
input = f32[8,16,256]{0,1,2} parameter(0)
constant = f32[] constant(0)
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";

  // Fatal on parse failure, replacing the deprecated ValueOrDie() call.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
  // Bind by const reference: `module` outlives this scope's uses, so there is
  // no need to copy the whole ComputationLayout.
  const auto& program_layout = module->entry_computation_layout();
  ASSERT_EQ(program_layout.parameter_count(), 1);
  const auto param_layout = program_layout.parameter_layout(0).layout();
  const auto result_layout = program_layout.result_layout().layout();
  EXPECT_TRUE(
      LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout))
      << "actual layout of parameter(0) is "
      << LayoutUtil::HumanString(param_layout);
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout))
      << "actual layout of result is "
      << LayoutUtil::HumanString(result_layout);
}
2899
// Without an ENTRY marker, the last computation in the text ("c2") becomes
// the entry computation.
TEST_F(HloParserTest, NoEntry) {
  const std::string original = R"(HloModule no_entry:
c1 {
const1 = f32[1]{0} constant({12345})
}
c2 {
const2 = f32[1]{0} constant({67890})
})";
  // Fatal on parse failure, replacing the deprecated ValueOrDie() call.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
  EXPECT_EQ(module->entry_computation()->name(), "c2");
}
2912
// Without a ROOT marker, the last instruction of the computation ("last")
// becomes its root.
TEST_F(HloParserTest, NoRoot) {
  const std::string original = R"(HloModule no_root:
ENTRY consts {
first = f32[1]{0} constant({12345})
last = f32[1]{0} constant({67890})
})";
  // Fatal on parse failure, replacing the deprecated ValueOrDie() call.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
  EXPECT_EQ(module->entry_computation()->root_instruction()->name(), "last");
}
2925
// Block comments are allowed before the module header, between tokens,
// inside literals, and after the last computation.
TEST_F(HloParserTest, Comments) {
  const std::string hlo_text = R"(/* module description. */
HloModule comments:

ENTRY /*comment*/ c1 {
/* blah */
ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
/* comment */
}

/* something else */

)";
  TF_ASSERT_OK(ParseAndReturnVerifiedModule(hlo_text).status());
}
2942
// Block comments spanning several lines (including commented-out
// instructions) are skipped by the parser.
TEST_F(HloParserTest, MultilineComments) {
  const std::string hlo_text = R"(HloModule multiline_comment:
ENTRY c1 {
/*
ROOT foo = f32[1]{0} constant({12345})
*/
ROOT const1 = f32[1]{0} constant({12345})
/*
a
b
c
d

*/
})";
  TF_ASSERT_OK(ParseAndReturnVerifiedModule(hlo_text).status());
}
2961
// An unterminated block comment is reported at the comment's opening `/*`.
TEST_F(HloParserTest, UnterminatedComment) {
  const std::string hlo_text = R"(HloModule unterminated_comment:
ENTRY c1 {
/* unterminated
ROOT const1 = f32[1]{0} constant({12345})
})";
  const std::string error_message =
      ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
  // The caret line must sit directly under the start of the comment.
  ExpectHasSubstr(error_message, "/* unterminated\n^");
}
2974
2975 TEST_F(HloParserTest, SlashSlashComments) {
2976 const std::string original = R"(HloModule slash_slash_comment:
2977 // Garbage
2978 ENTRY c1 {
2979 // Foo bar
2980 ROOT const1 = f32[1]{0} constant({12345}) // Something else
2981 })";
2982 auto module = ParseAndReturnVerifiedModule(original);
2983 TF_ASSERT_OK(module.status());
2984 }
2985
2986 TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
2987 const std::string original =
2988 "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
2989 "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
2990 auto module = ParseAndReturnVerifiedModule(original);
2991 TF_ASSERT_OK(module.status());
2992 }
2993
2994 TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
2995 const std::string original =
2996 "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
2997 "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
2998 auto module = ParseAndReturnVerifiedModule(original);
2999 TF_ASSERT_OK(module.status());
3000 }
3001
3002 TEST_F(HloParserTest, MultipleEntries) {
3003 const std::string original = R"(HloModule multiple_entries:
3004 ENTRY c1 {
3005 const1 = f32[1]{0} constant({12345})
3006 }
3007 ENTRY c2 {
3008 const2 = f32[1]{0} constant({67890})
3009 })";
3010 ExpectHasSubstr(
3011 ParseAndReturnUnverifiedModule(original).status().error_message(),
3012 "expects only one ENTRY");
3013 }
3014
3015 TEST_F(HloParserTest, SimpleAliasing) {
3016 const std::string original = R"(
3017 HloModule Module, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) }
3018
3019 ENTRY entry {
3020 %p = (f32[], f32[]) parameter(0)
3021 %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3022 %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3023 ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3024 }
3025 )";
3026 auto module = ParseAndReturnVerifiedModule(original);
3027 TF_ASSERT_OK(module.status());
3028 std::unique_ptr<HloModule> parsed_module = std::move(module).value();
3029 EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}),
3030 ShapeIndex{0});
3031
3032 EXPECT_TRUE(
3033 parsed_module->input_output_alias_config().ParameterMustAlias(0, {0}));
3034 EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}),
3035 ShapeIndex{1});
3036 EXPECT_FALSE(
3037 parsed_module->input_output_alias_config().ParameterMustAlias(0, {1}));
3038 }
3039
3040 TEST_F(HloParserTest, NestedAliasing) {
3041 const std::string original = R"(
3042 HloModule Module, input_output_alias={ {0, 0}: (0, {0}), {1, 1}: (0, {1}) }
3043
3044 ENTRY entry {
3045 %p = (f32[], f32[]) parameter(0)
3046 %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3047 %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3048 %t0 = (f32[], f32[]) tuple(%p0, %p1)
3049 %t1 = (f32[], f32[]) tuple(%p0, %p1)
3050 ROOT %out = ((f32[], f32[]), (f32[], f32[])) tuple(%t0, %t1)
3051 }
3052 )";
3053 auto module = ParseAndReturnVerifiedModule(original);
3054 TF_ASSERT_OK(module.status());
3055 std::unique_ptr<HloModule> parsed_module = std::move(module).value();
3056 EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}),
3057 ShapeIndex({0, 0}));
3058 EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}),
3059 ShapeIndex({1, 1}));
3060 }
3061
3062 TEST_F(HloParserTest, AliasingWrongIndex) {
3063 const std::string original = R"(
3064 HloModule Module, input_output_alias={ {0 : (0, {0}), {1}: (0, {1}) }
3065
3066 ENTRY entry {
3067 %p = (f32[], f32[]) parameter(0)
3068 %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3069 %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3070 ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3071 }
3072 )";
3073 ExpectHasSubstr(
3074 ParseAndReturnUnverifiedModule(original).status().error_message(),
3075 "Expects '}' at the end of ShapeIndex");
3076 }
3077
3078 TEST_F(HloParserTest, AliasingShapeIndexNotNumerical) {
3079 const std::string original = R"(
3080 HloModule Module, input_output_alias={ {0, a}: (0, {0}), {1}: (0, {1}) }
3081
3082 ENTRY entry {
3083 %p = (f32[], f32[]) parameter(0)
3084 %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3085 %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3086 ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3087 }
3088 )";
3089 ExpectHasSubstr(
3090 ParseAndReturnUnverifiedModule(original).status().error_message(),
3091 "expects integer");
3092 }
3093
3094 TEST_F(HloParserTest, AliasingWrongFormatNoColon) {
3095 const std::string original = R"(
3096 HloModule Module, input_output_alias={ {0, 0}: (0, {0}), (0, {1}) }
3097
3098 ENTRY entry {
3099 %p = (f32[], f32[]) parameter(0)
3100 %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3101 %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3102 ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3103 }
3104 )";
3105 ExpectHasSubstr(
3106 ParseAndReturnUnverifiedModule(original).status().error_message(),
3107 "Expects '{' at the start of ShapeIndex");
3108 }
3109
3110 TEST_F(HloParserTest, AliasingWrongFormatTwoColons) {
3111 const std::string original = R"(
3112 HloModule Module, input_output_alias={ {0}: (0, {0}): {0, 1}, {1}: (0, {1}) }
3113
3114 ENTRY entry {
3115 %p = (f32[], f32[]) parameter(0)
3116 %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3117 %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3118 ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3119 }
3120 )";
3121 ExpectHasSubstr(
3122 ParseAndReturnUnverifiedModule(original).status().error_message(),
3123 "Expects '}' at the end of aliasing description");
3124 }
3125
3126 TEST_F(HloParserTest, AliasingWrongFormatAlphaParam) {
3127 const std::string original = R"(
3128 HloModule Module, input_output_alias={ {0, a}: (zero, {0}), {1}: (0, {1}) }
3129
3130 ENTRY entry {
3131 %p = (f32[], f32[]) parameter(0)
3132 %p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
3133 %p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
3134 ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
3135 }
3136 )";
3137 ExpectHasSubstr(
3138 ParseAndReturnUnverifiedModule(original).status().error_message(),
3139 "expects integer");
3140 }
3141
3142 TEST_F(HloParserTest, MultipleRoots) {
3143 const std::string original = R"(HloModule multiple_roots:
3144 ENTRY consts {
3145 ROOT const1 = f32[1]{0} constant({12345})
3146 ROOT const2 = f32[1]{0} constant({12345})
3147 })";
3148 ExpectHasSubstr(
3149 ParseAndReturnUnverifiedModule(original).status().error_message(),
3150 "one computation should have only one ROOT");
3151 }
3152
3153 TEST_F(HloParserTest, ComputationExists) {
3154 const std::string original = R"(HloModule comp_exists
3155 comp {
3156 const1 = f32[1]{0} constant({12345})
3157 }
3158 comp {
3159 const2 = f32[1]{0} constant({67890})
3160 })";
3161 ExpectHasSubstr(
3162 ParseAndReturnUnverifiedModule(original).status().error_message(),
3163 R"(was parsing 2:1: error: computation previously defined here
3164 comp {
3165 ^)");
3166 }
3167
3168 TEST_F(HloParserTest, CrossComputationLookup) {
3169 const std::string original = R"(HloModule cross_computation_lookup:
3170 tcalla (a: (s32[], s32[])) -> (s32[], s32[]) {
3171 ROOT aparam = (s32[], s32[]) parameter(0)
3172 }
3173
3174 tcallb (b: (s32[], s32[])) -> s32[] {
3175 rparam = (s32[], s32[]) parameter(0)
3176 ROOT gte0 = s32[] get-tuple-element(aparam), index=0
3177 }
3178
3179 ENTRY entry {
3180 param = (s32[], s32[]) parameter(0)
3181 call0 = (s32[], s32[]) call(param), to_apply=tcalla
3182 ROOT call1 = s32[] call(param), to_apply=tcallb
3183 })";
3184 ExpectHasSubstr(
3185 ParseAndReturnUnverifiedModule(original).status().error_message(),
3186 "was parsing 8:39: error: instruction does not exist: aparam");
3187 }
3188
3189 TEST_F(HloParserTest, SameNameDiffComputations) {
3190 const std::string original = R"(HloModule same_names:
3191 add {
3192 p0 = f32[] parameter(0)
3193 p1 = f32[] parameter(1)
3194 ROOT result = f32[] add(p0, p1)
3195 }
3196
3197 ENTRY ReduceR3ToR2 {
3198 p0 = f32[8,16,256]{2,1,0} parameter(0)
3199 p1 = f32[] constant(0)
3200 ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
3201 }
3202 )";
3203 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
3204 ASSERT_NE(module->entry_computation(), nullptr);
3205 EXPECT_THAT(module->entry_computation()->root_instruction(),
3206 GmockMatch(m::Reduce()));
3207 }
3208
3209 TEST_F(HloParserTest, ParseSharding) {
3210 const std::string original = "{maximal device=42}";
3211 TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
3212 EXPECT_EQ(sharding.ToString(), original);
3213 }
3214
3215 TEST_F(HloParserTest, ParseShardingPartialReplication) {
3216 const std::string original = "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}";
3217 TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
3218 EXPECT_EQ(sharding.ToString(), original);
3219 Array<int64_t> group_tiling({2});
3220 group_tiling(0) = 0;
3221 group_tiling(1) = 1;
3222 std::vector<int64_t> group0_members({0, 1});
3223 std::vector<int64_t> group1_members({2, 3});
3224 EXPECT_EQ(
3225 HloSharding::PartialTile(group_tiling, {group0_members, group1_members})
3226 .ToString(),
3227 original);
3228 }
3229
3230 TEST_F(HloParserTest, ParseShardingSubGroup) {
3231 const std::string original =
3232 "{devices=[2,2,2,2]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 "
3233 "last_tile_dims={manual, replicated}}";
3234 TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
3235 EXPECT_EQ(sharding.ToString(), original);
3236 Array<int64_t> tile_assignment({2, 2, 2, 2});
3237 tile_assignment.FillIota(0);
3238 std::vector<OpSharding::Type> subgroup_types = {OpSharding::MANUAL,
3239 OpSharding::REPLICATED};
3240 EXPECT_EQ(HloSharding::Subgroup(tile_assignment, subgroup_types).ToString(),
3241 original);
3242 }
3243
3244 TEST_F(HloParserTest, ParseFrontendAttributes) {
3245 const std::string original =
3246 R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})";
3247 TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
3248 ParseFrontendAttributes(original));
3249 EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
3250 }
3251
3252 TEST_F(HloParserTest, ParseWindow) {
3253 Window original = window_util::MakeWindow({1, 2, 3});
3254 TF_ASSERT_OK_AND_ASSIGN(Window parsed,
3255 ParseWindow(window_util::ToString(original)))
3256 EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
3257 }
3258
3259 TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
3260 const std::string original = "b0f_0io->b0f";
3261 TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
3262 ParseConvolutionDimensionNumbers(original));
3263 EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
3264 }
3265
3266 TEST_F(HloParserTest, ParseConvolutionDimensionNumbersWithUnknownDims) {
3267 const std::string original = "b0?f_?0?io->?b?0?f";
3268 TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
3269 ParseConvolutionDimensionNumbers(original));
3270 EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
3271 }
3272
3273 TEST_F(HloParserTest, ParseReplicaGroups) {
3274 const std::string original = "{{0,1},{2,3}}";
3275 TF_ASSERT_OK_AND_ASSIGN(std::vector<ReplicaGroup> replica_groups,
3276 ParseReplicaGroupsOnly(original));
3277 EXPECT_EQ(original, ReplicaGroupsToString(replica_groups));
3278 }
3279
3280 TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
3281 const std::string original = "0_1x2_3";
3282 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
3283 EXPECT_EQ(original, PaddingConfigToString(dnums));
3284 }
3285
3286 TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
3287 const std::string original = "0_1_0x2_3_4";
3288 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
3289 EXPECT_EQ(original, PaddingConfigToString(dnums));
3290 }
3291
3292 TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
3293 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4"));
3294 // The extra "_0" gets added to the canonical string because the other dim has
3295 // interior padding.
3296 EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums));
3297 }
3298
3299 TEST_F(HloParserTest, NontupleInfeed) {
3300 const std::string original = R"(HloModule nontuple_infeed:
3301 ENTRY nontuple_infeed {
3302 token0 = token[] after-all()
3303 ROOT infeed = pred[] infeed(token0)
3304 })";
3305 ExpectHasSubstr(
3306 ParseAndReturnUnverifiedModule(original).status().error_message(),
3307 "infeed must have a non-empty tuple shape");
3308 }
3309
3310 TEST(HloParserSingleOpTest, SingleOp) {
3311 const std::string text =
3312 "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
3313 "f32[2,4]{1,0} %x)";
3314 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3315 const HloComputation* computation = module->entry_computation();
3316 ASSERT_NE(computation, nullptr);
3317 EXPECT_THAT(computation->root_instruction(),
3318 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
3319 }
3320
3321 TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
3322 const std::string text =
3323 "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
3324 StatusOr<std::unique_ptr<HloModule>> module =
3325 ParseAndReturnUnverifiedModule(text);
3326 ASSERT_TRUE(!module.status().ok());
3327 LOG(INFO) << "Status: " << module.status();
3328 EXPECT_THAT(module.status().ToString(),
3329 HasSubstr("expects '=' in instruction"));
3330 }
3331
3332 TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
3333 const std::string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
3334 StatusOr<std::unique_ptr<HloModule>> module =
3335 ParseAndReturnUnverifiedModule(text);
3336 ASSERT_TRUE(!module.status().ok());
3337 LOG(INFO) << "Status: " << module.status();
3338 EXPECT_THAT(module.status().ToString(),
3339 HasSubstr("Operand had no shape in HLO text"));
3340 }
3341
3342 TEST(HloParserSingleOpTest, SingleOpNoNames) {
3343 const std::string text =
3344 "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
3345 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3346 const HloComputation* computation = module->entry_computation();
3347 ASSERT_NE(computation, nullptr);
3348 EXPECT_THAT(computation->root_instruction(),
3349 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
3350 }
3351
3352 TEST(HloParserSingleOpTest, CanonicalOp) {
3353 const std::string text =
3354 "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
3355 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3356 const HloComputation* computation = module->entry_computation();
3357 ASSERT_NE(computation, nullptr);
3358 EXPECT_THAT(computation->root_instruction(),
3359 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
3360 EXPECT_EQ(
3361 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
3362 text);
3363 }
3364
3365 TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
3366 const std::string text =
3367 R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
3368 {
3369 tmp_0 = f32[5,10]{1,0} parameter(0)
3370 tmp_1 = f32[20,10]{1,0} parameter(1)
3371 ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
3372 {
3373 tmp_0 = f32[5,10]{1,0} parameter(0)
3374 tmp_1 = f32[20,10]{1,0} parameter(1)
3375 tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
3376 ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
3377 }
3378 }, body=
3379 {
3380 tmp_0 = f32[5,10]{1,0} parameter(0)
3381 tmp_1 = f32[20,10]{1,0} parameter(1)
3382 ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
3383 {
3384 tmp_0 = f32[5,10]{1,0} parameter(0)
3385 tmp_1 = f32[20,10]{1,0} parameter(1)
3386 tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
3387 ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
3388 }
3389 })";
3390
3391 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3392 const HloComputation* computation = module->entry_computation();
3393 ASSERT_NE(computation, nullptr);
3394 EXPECT_EQ(
3395 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
3396 text);
3397 }
3398
3399 TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) {
3400 const std::string text =
3401 R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={
3402 {
3403 tmp_0 = f32[5,10]{1,0} parameter(0)
3404 ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0)
3405 },
3406 {
3407 tmp_0 = f32[5,10]{1,0} parameter(0)
3408 ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0)
3409 },
3410 {
3411 tmp_0 = f32[5,10]{1,0} parameter(0)
3412 ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0)
3413 }
3414 })";
3415
3416 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3417 const HloComputation* computation = module->entry_computation();
3418 ASSERT_NE(computation, nullptr);
3419 EXPECT_EQ(
3420 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
3421 text);
3422 }
3423
3424 TEST(HloParserSingleOpTest, SingleOpWithNested) {
3425 const std::string text =
3426 R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
3427 {
3428 %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
3429 %param_1 = f32[2]{0} parameter(1)
3430 %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
3431 ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
3432 })";
3433
3434 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3435 const HloComputation* computation = module->entry_computation();
3436 ASSERT_NE(computation, nullptr);
3437 EXPECT_THAT(computation->root_instruction(),
3438 GmockMatch(m::Op()
3439 .WithOpcode(HloOpcode::kFusion)
3440 .WithNumOperands(2)
3441 .WithOperand(0, m::Parameter(0))
3442 .WithOperand(1, m::Parameter(1))));
3443 }
3444
3445 TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
3446 const std::string text =
3447 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
3448 {
3449 result = f32[] add(f32[] x, f32[] y)
3450 })";
3451 auto status = ParseAndReturnUnverifiedModule(text).status();
3452 ASSERT_FALSE(status.ok());
3453 EXPECT_THAT(status.error_message(), HasSubstr("does not exist: x"));
3454 }
3455
3456 TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
3457 const std::string text =
3458 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
3459 {
3460 f32[] add(f32[] x, f32[] y)
3461 })";
3462 auto status = ParseAndReturnUnverifiedModule(text).status();
3463 ASSERT_FALSE(status.ok());
3464 EXPECT_THAT(status.error_message(), HasSubstr("expects name"));
3465 }
3466
3467 TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
3468 const std::string text =
3469 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
3470 {
3471 result = f32[] add(f32[], f32[])
3472 })";
3473 auto status = ParseAndReturnUnverifiedModule(text).status();
3474 ASSERT_FALSE(status.ok());
3475 EXPECT_THAT(status.error_message(), HasSubstr("expects name"));
3476 }
3477
3478 TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
3479 const std::string text =
3480 R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
3481 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
3482 const HloComputation* computation = module->entry_computation();
3483 ASSERT_NE(computation, nullptr);
3484 EXPECT_THAT(computation->root_instruction(),
3485 GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
3486 auto* convolution =
3487 Cast<HloConvolutionInstruction>(computation->root_instruction());
3488 EXPECT_EQ(convolution->feature_group_count(), 1);
3489 }
3490
3491 TEST(HloParserSingleOpTest, MultipleOpsProducesError) {
3492 const std::string text = R"(
3493 param = f32[2,5,1,3] parameter(0)
3494 transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
3495 )";
3496 auto status = ParseAndReturnUnverifiedModule(text).status();
3497 ASSERT_FALSE(status.ok());
3498 EXPECT_THAT(status.error_message(), HasSubstr("Expected eof"));
3499 }
3500
3501 TEST_F(HloParserTest, IsScheduledIsFalse) {
3502 const std::string text = R"(
3503 HloModule axpy_module, is_scheduled=false
3504
3505 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
3506 %alpha = f32[] parameter(0)
3507 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
3508 %x = f32[2,4]{1,0} parameter(1)
3509 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
3510 %y = f32[2,4]{1,0} parameter(2)
3511 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
3512 }
3513 )";
3514 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3515 ASSERT_FALSE(module->has_schedule());
3516 }
3517
3518 TEST_F(HloParserTest, IsScheduledNotPresent) {
3519 const std::string text = R"(
3520 HloModule axpy_module
3521
3522 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
3523 %alpha = f32[] parameter(0)
3524 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
3525 %x = f32[2,4]{1,0} parameter(1)
3526 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
3527 %y = f32[2,4]{1,0} parameter(2)
3528 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
3529 }
3530 )";
3531 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3532 ASSERT_FALSE(module->has_schedule());
3533 }
3534
3535 TEST_F(HloParserTest, IsScheduledIsTrue) {
3536 const std::string text = R"(
3537 HloModule axpy_module, is_scheduled=true
3538
3539 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
3540 %alpha = f32[] parameter(0)
3541 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
3542 %x = f32[2,4]{1,0} parameter(1)
3543 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
3544 %y = f32[2,4]{1,0} parameter(2)
3545 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
3546 }
3547 )";
3548 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3549 ASSERT_TRUE(module->has_schedule());
3550 TF_ASSERT_OK(module->schedule().Verify());
3551 EXPECT_EQ(module->schedule().sequences().size(), 1);
3552 ASSERT_TRUE(
3553 module->schedule().is_computation_scheduled(module->entry_computation()));
3554 EXPECT_THAT(
3555 module->schedule().sequence(module->entry_computation()).instructions(),
3556 ElementsAre(GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
3557 GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
3558 GmockMatch(m::Parameter()), GmockMatch(m::Add())));
3559 }
3560
3561 TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
3562 // As above but in with a different schedule order.
3563 const std::string text = R"(
3564 HloModule axpy_module, is_scheduled=true
3565
3566 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
3567 %alpha = f32[] parameter(0)
3568 %x = f32[2,4]{1,0} parameter(1)
3569 %y = f32[2,4]{1,0} parameter(2)
3570 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
3571 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
3572 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
3573 }
3574 )";
3575 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3576 ASSERT_TRUE(module->has_schedule());
3577 TF_ASSERT_OK(module->schedule().Verify());
3578 EXPECT_EQ(module->schedule().sequences().size(), 1);
3579 ASSERT_TRUE(
3580 module->schedule().is_computation_scheduled(module->entry_computation()));
3581 EXPECT_THAT(
3582 module->schedule().sequence(module->entry_computation()).instructions(),
3583 ElementsAre(GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
3584 GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
3585 GmockMatch(m::Multiply()), GmockMatch(m::Add())));
3586 }
3587
3588 TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
3589 const std::string original =
3590 R"(HloModule CustomCallWrongNumberofOperandConstraints
3591
3592 ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
3593 %p0 = f32[42,2,3]{0,1,2} parameter(0)
3594 %p1 = f32[123,4]{0,1} parameter(1)
3595 ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
3596 }
3597
3598 )";
3599 ExpectHasSubstr(
3600 ParseAndReturnUnverifiedModule(original).status().error_message(),
3601 "Expected 2 operand layout constraints, 1 given");
3602 }
3603
3604 TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
3605 const std::string original =
3606 R"(HloModule CustomCallIncompatibleOperandConstraints
3607
3608 ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
3609 %p0 = f32[42,2,3]{0,1,2} parameter(0)
3610 %p1 = f32[123,4]{0,1} parameter(1)
3611 ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
3612 }
3613
3614 )";
3615 ExpectHasSubstr(
3616 ParseAndReturnUnverifiedModule(original).status().error_message(),
3617 "operand 1 is not compatible with operand shape");
3618 }
3619
3620 TEST_F(HloParserTest, CustomCallWithNonexistentVersion) {
3621 const std::string original = R"(HloModule custom_call
3622
3623 ENTRY %CustomCall () -> f32[1,2,3] {
3624 %constant = f32[1]{0} constant({12345})
3625 ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", api_version=API_VERSION_THAT_DOESNT_EXIST
3626 }
3627
3628 )";
3629 ExpectHasSubstr(
3630 ParseAndReturnUnverifiedModule(original).status().error_message(),
3631 "Unknown API version");
3632 }
3633
3634 TEST_F(HloParserTest, CustomCallWithUnspecifiedVersion) {
3635 const std::string original = R"(HloModule custom_call
3636
3637 ENTRY %CustomCall () -> f32[1,2,3] {
3638 %constant = f32[1]{0} constant({12345})
3639 ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", api_version=API_VERSION_UNSPECIFIED
3640 }
3641
3642 )";
3643 ExpectHasSubstr(
3644 ParseAndReturnUnverifiedModule(original).status().error_message(),
3645 "Invalid API version");
3646 }
3647
3648 TEST_F(HloParserTest, AllowShapeWhitespace) {
3649 const std::string text = R"(
3650 HloModule module
3651
3652 ENTRY entry {
3653 ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
3654 }
3655 )";
3656 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3657 }
3658
3659 TEST_F(HloParserTest, ShapeMismatchInOperand) {
3660 const std::string text = R"(
3661 HloModule foobar
3662
3663 ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
3664 %p = f32[2,2] parameter(0)
3665 %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}})
3666 ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1)
3667 }
3668 )";
3669
3670 ExpectHasSubstr(ParseAndReturnUnverifiedModule(text).status().error_message(),
3671 "The declared operand shape f32[2,5]{1,0} is not compatible"
3672 " with the shape of the operand instruction f32[2,2]{1,0}.");
3673 }
3674
3675 TEST_F(HloParserTest, ParseShapeStringR2F32) {
3676 std::string shape_string = "f32[123,456]";
3677 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3678 Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
3679 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3680 << "expected: " << ShapeUtil::HumanString(expected)
3681 << "actual: " << ShapeUtil::HumanString(actual);
3682 }
3683
3684 TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
3685 std::string shape_string = "(f32[1572864],s8[5120,1024])";
3686 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3687 Shape expected =
3688 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
3689 ShapeUtil::MakeShape(S8, {5120, 1024})});
3690 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3691 << "expected: " << ShapeUtil::HumanString(expected)
3692 << "actual: " << ShapeUtil::HumanString(actual);
3693 }
3694
3695 TEST_F(HloParserTest, ParseShapeStringNestedTuple) {
3696 std::string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
3697 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3698 Shape expected = ShapeUtil::MakeTupleShape({
3699 ShapeUtil::MakeShape(F32, {1}),
3700 ShapeUtil::MakeTupleShape(
3701 {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
3702 ShapeUtil::MakeOpaqueShape(),
3703 ShapeUtil::MakeShape(F32, {3}),
3704 });
3705 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3706 << "expected: " << ShapeUtil::HumanString(expected)
3707 << "actual: " << ShapeUtil::HumanString(actual);
3708 }
3709
3710 TEST_F(HloParserTest, ParseShapeStringWithLayout) {
3711 std::string shape_string = "f32[123,456]{0,1}";
3712 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3713 Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
3714 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3715 << "expected: " << ShapeUtil::HumanString(expected)
3716 << "actual: " << ShapeUtil::HumanString(actual);
3717 }
3718
3719 TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
3720 // One tile.
3721 std::string shape_string = "f32[123,456]{0,1:T(2,128)}";
3722 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3723 Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {},
3724 {Tile({2, 128})});
3725 EXPECT_EQ(expected, actual)
3726 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3727 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
3728
3729 // Tile with negative dimension size for combining dimensions.
3730 shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}";
3731 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3732 expected =
3733 ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2}, {},
3734 {Tile({2, Tile::kCombineDimension, 128})});
3735 EXPECT_EQ(expected, actual)
3736 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3737 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
3738
3739 // Two tiles.
3740 shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}";
3741 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3742 expected = ShapeUtil::MakeShapeWithLayout(
3743 BF16, {123, 456, 789}, {2, 1, 0}, {},
3744 {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})});
3745 EXPECT_EQ(expected, actual)
3746 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3747 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
3748
3749 // Tile with element size in bits.
3750 shape_string = "pred[123,456]{1,0:T(2,128)E(1)}";
3751 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3752 expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {},
3753 {Tile({2, 128})}, 1);
3754 EXPECT_EQ(expected, actual)
3755 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3756 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
3757
3758 // Element size in bits without tile.
3759 shape_string = "pred[123,456]{1,0:E(1)}";
3760 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3761 expected =
3762 ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, {}, 1);
3763 EXPECT_EQ(expected, actual)
3764 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3765 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
3766
3767 // Wrong minor_to_major.
3768 shape_string = "f32[123,456,789]{1:T(2, * , 128)}";
3769 auto result = ParseShape(shape_string);
3770 ExpectHasSubstr(result.status().error_message(),
3771 "Dimensions size is 3, but minor to major size is 1.");
3772 }
3773
3774 TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) {
3775 // Tile, element size, and memory space.
3776 std::string shape_string = "pred[123,456]{1,0:T(2,128)E(1)S(3)}";
3777 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3778 Shape expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {},
3779 {Tile({2, 128})}, 1, 3);
3780 EXPECT_EQ(expected, actual)
3781 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3782 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
3783
3784 // Element size and memory space.
3785 shape_string = "pred[123,456]{1,0:E(1)S(3)}";
3786 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3787 expected =
3788 ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, {}, 1, 3);
3789 EXPECT_EQ(expected, actual)
3790 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3791 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
3792
3793 // Memory space only.
3794 shape_string = "pred[123,456]{1,0:S(3)}";
3795 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
3796 expected =
3797 ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, {}, 0, 3);
3798 EXPECT_EQ(expected, actual)
3799 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
3800 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
3801 }
3802
3803 TEST_F(HloParserTest, ParseOpaqueType) {
3804 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]"));
3805 Shape expected = ShapeUtil::MakeOpaqueShape();
3806 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3807 << "expected: " << ShapeUtil::HumanString(expected)
3808 << "actual: " << ShapeUtil::HumanString(actual);
3809 }
3810
3811 TEST_F(HloParserTest, ParseTokenType) {
3812 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]"));
3813 Shape expected = ShapeUtil::MakeTokenShape();
3814 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3815 << "expected: " << ShapeUtil::HumanString(expected)
3816 << "actual: " << ShapeUtil::HumanString(actual);
3817 }
3818
3819 TEST_F(HloParserTest, ParseInvalidShapeString) {
3820 std::string shape_strings[] = {"f32[123,456]foobar{0,1}", "f32[123,456]{foo}",
3821 "f32[123,456]dense{foo}"};
3822 for (const std::string& shape_string : shape_strings) {
3823 StatusOr<Shape> result = ParseShape(shape_string);
3824 ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
3825 }
3826 }
3827
3828 TEST_F(HloParserTest, ParseDynamicArray) {
3829 std::string shape_string = "f32[123,<=456]";
3830 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3831 Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true});
3832 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3833 << "expected: " << ShapeUtil::HumanString(expected)
3834 << "actual: " << ShapeUtil::HumanString(actual);
3835 }
3836
3837 TEST_F(HloParserTest, ParseDynamicTuple) {
3838 std::string shape_string = "(f32[42], u32[<=123,<=456])";
3839 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
3840 Shape expected = ShapeUtil::MakeTupleShape(
3841 {ShapeUtil::MakeShape(F32, {42}),
3842 ShapeUtil::MakeShape(U32, {123, 456}, {true, true})});
3843 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
3844 << "expected: " << ShapeUtil::HumanString(expected)
3845 << "actual: " << ShapeUtil::HumanString(actual);
3846 }
3847
3848 TEST_F(HloParserTest, NegativeParameterNumber) {
3849 const std::string hlo_string = "par0 = f32[3,5] parameter(-1)";
3850 auto result = ParseAndReturnUnverifiedModule(hlo_string);
3851 ASSERT_FALSE(result.status().ok());
3852 EXPECT_THAT(result.status().error_message(),
3853 HasSubstr("parameter number must be >= 0"));
3854 }
3855
3856 TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) {
3857 const std::string hlo_string =
3858 "par0 = (f32[3,5], f32[]) parameter(0), "
3859 "parameter_replication={true,false,true}";
3860 auto result = ParseAndReturnUnverifiedModule(hlo_string);
3861 ASSERT_FALSE(result.status().ok());
3862 EXPECT_THAT(result.status().error_message(),
3863 HasSubstr("parameter has 2 leaf buffers, but "
3864 "parameter_replication has 3 elements"));
3865 }
3866
3867 TEST_F(HloParserTest, CheckIndexedConditionalDimension) {
3868 const char* const hlo_string = R"(
3869 HloModule Module
3870
3871 branch0 {
3872 tparam = f32[4] parameter(0)
3873 ROOT tgte1 = f32[4] ceil(tparam)
3874 }
3875
3876 branch1 {
3877 fparam = f32[4] parameter(0)
3878 ROOT fgte1 = f32[4] floor(fparam)
3879 }
3880
3881 ENTRY entry {
3882 p0 = f32[4] parameter(0)
3883 b0 = s32[2] parameter(1)
3884 ROOT conditional = f32[4] conditional(b0, p0, p0),
3885 branch_computations={branch0, branch1}
3886 }
3887 )";
3888 auto result = ParseAndReturnUnverifiedModule(hlo_string);
3889 EXPECT_NE(OkStatus(), result.status());
3890 EXPECT_THAT(result.status().error_message(),
3891 HasSubstr("The first operand must be a scalar"));
3892 }
3893
3894 TEST_F(HloParserTest, CheckIndexedConditionalElementType) {
3895 const char* const hlo_string = R"(
3896 HloModule Module
3897
3898 branch0 {
3899 tparam = f32[4] parameter(0)
3900 ROOT tgte1 = f32[4] ceil(tparam)
3901 }
3902
3903 branch1 {
3904 fparam = f32[4] parameter(0)
3905 ROOT fgte1 = f32[4] floor(fparam)
3906 }
3907
3908 ENTRY entry {
3909 p0 = f32[4] parameter(0)
3910 b0 = f32[] parameter(1)
3911 ROOT conditional = f32[4] conditional(b0, p0, p0),
3912 branch_computations={branch0, branch1}
3913 }
3914 )";
3915 auto result = ParseAndReturnUnverifiedModule(hlo_string);
3916 EXPECT_NE(OkStatus(), result.status());
3917 EXPECT_THAT(result.status().error_message(),
3918 HasSubstr("The first operand must be a scalar of PRED or S32"));
3919 }
3920
3921 TEST_F(HloParserTest,
3922 CheckPredicatedConditionalRequiresTrueAndFalseComputation) {
3923 const char* const hlo_string = R"(
3924 HloModule Module
3925
3926 branch0 {
3927 tparam = f32[4] parameter(0)
3928 ROOT tgte1 = f32[4] ceil(tparam)
3929 }
3930
3931 branch1 {
3932 fparam = f32[4] parameter(0)
3933 ROOT fgte1 = f32[4] floor(fparam)
3934 }
3935
3936 ENTRY entry {
3937 p0 = f32[4] parameter(0)
3938 b0 = pred[] parameter(1)
3939 ROOT conditional = f32[4] conditional(b0, p0, p0),
3940 branch_computations={branch0, branch1}
3941 }
3942 )";
3943 auto result = ParseAndReturnUnverifiedModule(hlo_string);
3944 EXPECT_NE(OkStatus(), result.status());
3945 EXPECT_THAT(result.status().error_message(),
3946 HasSubstr("unexpected attribute \"branch_computations\""));
3947 }
3948
// Result shape inference test cases.
3950 TEST_F(HloParserTest, InferUnaryShape) {
3951 constexpr char text[] = R"(HloModule InferUnaryShapeTest
3952 ENTRY InferUnaryShape {
3953 a = f32[2,10]{1,0} parameter(0)
3954 ROOT v = abs(a)
3955 }
3956 )";
3957 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3958 }
3959
3960 TEST_F(HloParserTest, InferBinaryShape) {
3961 constexpr char text[] = R"(HloModule InferBinaryShapeTest
3962 ENTRY InferBinaryShape {
3963 a = f32[2,10]{1,0} parameter(0)
3964 b = f32[2,10]{1,0} parameter(1)
3965 ROOT sum = add(a, b)
3966 }
3967 )";
3968 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3969 EXPECT_TRUE(ShapeUtil::Equal(
3970 module->entry_computation()->ComputeProgramShape().result(),
3971 ShapeUtil::MakeShapeWithLayout(F32, {2, 10}, {1, 0})));
3972 }
3973
3974 TEST_F(HloParserTest, InferTernaryShape) {
3975 constexpr char text[] = R"(HloModule InferTernaryShapeTest
3976 ENTRY InferTernaryShape {
3977 p = pred[] constant(true)
3978 f = s32[] constant(-42)
3979 t = s32[] constant(42)
3980 ROOT select = select(p, f, t)
3981 }
3982 )";
3983 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3984 EXPECT_TRUE(ShapeUtil::Equal(
3985 module->entry_computation()->ComputeProgramShape().result(),
3986 ShapeUtil::MakeScalarShape(S32)));
3987 }
3988
3989 TEST_F(HloParserTest, InferDotShape) {
3990 constexpr char text[] = R"(HloModule InferDotShapeTest
3991 ENTRY InferDotShape {
3992 a = f32[2,10]{1,0} parameter(0)
3993 b = f32[10,2]{1,0} parameter(1)
3994 ROOT dot = dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
3995 }
3996 )";
3997 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
3998 EXPECT_TRUE(ShapeUtil::Equal(
3999 module->entry_computation()->ComputeProgramShape().result(),
4000 ShapeUtil::MakeShape(F32, {2}, {0})));
4001 }
4002
4003 TEST_F(HloParserTest, InferTupleShape) {
4004 constexpr char text[] = R"(HloModule InferTupleShapeTest
4005 ENTRY InferTupleShape () -> s32[2,3] {
4006 c0 = f32[3]{0} constant({1, 2, 3})
4007 c1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
4008 tuple = tuple(c0, c1)
4009 ROOT get = get-tuple-element(tuple), index=1, sharding={maximal device=0}
4010 }
4011 )";
4012 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
4013 EXPECT_TRUE(ShapeUtil::Equal(
4014 module->entry_computation()->ComputeProgramShape().result(),
4015 ShapeUtil::MakeShapeWithLayout(S32, {2, 3}, {1, 0})));
4016 }
4017
4018 TEST_F(HloParserTest, InferShapeMixedExplicitShape) {
4019 constexpr char text[] = R"(HloModule InferUnaryShapeTest
4020 Negate {
4021 x = f32[] parameter(0)
4022 ROOT negate = negate(x)
4023 }
4024
4025 Identity {
4026 y = f32[] parameter(0)
4027 ROOT copy = copy(y)
4028 }
4029
4030 ENTRY InferUnaryShape {
4031 a = f32[] parameter(0)
4032 b = f32[] parameter(1)
4033 p = pred[] parameter(2)
4034 c = f32[] add(a, b)
4035 ROOT conditional = conditional(p, a, c), true_computation=Negate, false_computation=Identity
4036 }
4037 )";
4038 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
4039 EXPECT_TRUE(ShapeUtil::Equal(
4040 module->entry_computation()->ComputeProgramShape().result(),
4041 ShapeUtil::MakeScalarShape(F32)));
4042 }
4043
4044 TEST_F(HloParserTest, CheckAliasPassthroughParams) {
4045 const char* const hlo_string = R"(
4046 HloModule TestModule, alias_passthrough_params=true
4047
4048 ENTRY TestComputation {
4049 p0 = f16[2048,1024] parameter(0)
4050 p1 = f16[2048,1024] parameter(1)
4051 ROOT root = (f16[2048,1024], f16[2048,1024]) tuple(p0, p1)
4052 }
4053 )";
4054 auto result = ParseAndReturnVerifiedModule(hlo_string);
4055 TF_EXPECT_OK(result.status());
4056 EXPECT_TRUE(result.ValueOrDie()->config().alias_passthrough_params());
4057 }
4058
4059 TEST_F(HloParserTest, NestedBroadcastWithoutDimensionsAttribute) {
4060 const char* const hlo_string = R"(
4061 HloModule test
4062 ENTRY test {
4063 ROOT root = sqrt(f32[10,10] broadcast(f32[10] parameter(0)))
4064 }
4065 )";
4066 auto result = ParseAndReturnVerifiedModule(hlo_string);
4067 EXPECT_NE(OkStatus(), result.status());
4068 EXPECT_THAT(result.status().error_message(), HasSubstr("dimensions"));
4069 }
4070
4071 TEST_F(HloParserTest, InvalidDimLevelType) {
4072 const std::string original = R"(HloModule test
4073
4074 ENTRY test {
4075 ROOT root = f32[10,10]{1,0:D(X,C)} parameter(0)
4076 })";
4077 EXPECT_THAT(ParseAndReturnUnverifiedModule(original).status(),
4078 tensorflow::testing::StatusIs(
4079 tensorflow::error::INVALID_ARGUMENT,
4080 HasSubstr("expected a DimLevelType abbreviation")));
4081 }
4082
4083 TEST_F(HloParserTest, InvalidDimLevelTypeCount) {
4084 const std::string original = R"(HloModule test
4085
4086 ENTRY test {
4087 ROOT root = f32[10,10]{1,0:D(C)} parameter(0)
4088 })";
4089 EXPECT_THAT(
4090 ParseAndReturnUnverifiedModule(original).status(),
4091 tensorflow::testing::StatusIs(
4092 tensorflow::error::INVALID_ARGUMENT,
4093 HasSubstr("Dimensions size is 2, but dim level types size is 1")));
4094 }
4095
4096 TEST_F(HloParserTest, RejectSparseTiles) {
4097 const std::string original = R"(HloModule test
4098
4099 ENTRY test {
4100 ROOT root = f32[10,10]{1,0:D(D,C)T(128,8)} parameter(0)
4101 })";
4102 EXPECT_THAT(ParseAndReturnUnverifiedModule(original).status(),
4103 tensorflow::testing::StatusIs(
4104 tensorflow::error::INVALID_ARGUMENT,
4105 HasSubstr("Layout has tiles, but is for a sparse array")));
4106 }
4107
4108 } // namespace
4109 } // namespace xla
4110