1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17
18 #include <functional>
19 #include <iterator>
20 #include <memory>
21 #include <string>
22 #include <type_traits>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/base/casts.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_split.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "absl/types/variant.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/literal_util.h"
41 #include "tensorflow/compiler/xla/primitive_util.h"
42 #include "tensorflow/compiler/xla/service/computation_layout.h"
43 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
44 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
45 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
46 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
47 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
48 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
49 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
50 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
51 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
52 #include "tensorflow/compiler/xla/service/shape_inference.h"
53 #include "tensorflow/compiler/xla/shape_util.h"
54 #include "tensorflow/compiler/xla/util.h"
55 #include "tensorflow/compiler/xla/xla_data.pb.h"
56 #include "tensorflow/core/lib/gtl/map_util.h"
57 #include "tensorflow/core/platform/protobuf.h"
58
59 namespace xla {
60
61 namespace {
62
63 using absl::StrAppend;
64 using absl::StrCat;
65 using absl::StrFormat;
66 using absl::StrJoin;
67 using std::nullopt;
68 using std::optional;
69
70 // Creates and returns a schedule created using the order of the instructions in
71 // the HloComputation::instructions() vectors in the module.
ScheduleFromInstructionOrder(HloModule * module)72 HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
73 HloSchedule schedule(module);
74 for (HloComputation* computation : module->computations()) {
75 if (!computation->IsFusionComputation()) {
76 for (HloInstruction* instruction : computation->instructions()) {
77 schedule.GetOrCreateSequence(computation).push_back(instruction);
78 }
79 }
80 }
81 return schedule;
82 }
83
// Returns true if the result shape of an instruction with opcode `code` can
// be inferred from its operands/attributes, i.e. the HLO text is allowed to
// omit the explicit result shape for it. The switch is intentionally
// exhaustive with no default so that adding a new HloOpcode triggers a
// compiler warning here.
bool CanInferShape(HloOpcode code) {
  switch (code) {
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAfterAll:
    case HloOpcode::kAtan2:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCeil:
    case HloOpcode::kCholesky:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConditional:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kOptimizationBarrier:
    case HloOpcode::kDivide:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kGather:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kAnd:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kPad:
    case HloOpcode::kPartitionId:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReduce:
    case HloOpcode::kRemainder:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRoundNearestEven:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTranspose:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    // Technically the following ops do not require an explicit result shape,
    // but we made it so that we always write the shapes explicitly.
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCopyStart:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kSlice:
    // The following ops require an explicit result shape.
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kConstant:
    case HloOpcode::kConvert:
    case HloOpcode::kCustomCall:
    case HloOpcode::kFusion:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
      return false;
  }
}
209
// Parser for the HloModule::ToString() format text.
class HloParserImpl : public HloParser {
 public:
  using LocTy = HloLexer::LocTy;

  explicit HloParserImpl(absl::string_view str) : lexer_(str) {}

  // Runs the parser and constructs the resulting HLO in the given (empty)
  // HloModule. Returns the error status in case an error occurred.
  Status Run(HloModule* module) override;

  // Returns the error information, one recorded error per entry, joined by
  // newlines.
  std::string GetError() const { return StrJoin(error_, "\n"); }

  // Stand alone parsing utils for various aggregate data types.
  StatusOr<Shape> ParseShapeOnly();
  StatusOr<HloSharding> ParseShardingOnly();
  StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
  StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
  StatusOr<Window> ParseWindowOnly();
  StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
  StatusOr<PaddingConfig> ParsePaddingConfigOnly();
  StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();

 private:
  // Types of attributes.
  enum class AttrTy {
    kBool,
    kInt64,
    kInt32,
    kFloat,
    kString,
    kLiteral,
    kBracedInt64List,
    kBracedInt64ListList,
    kHloComputation,
    kBracedHloComputationList,
    kFftType,
    kPaddingType,
    kComparisonDirection,
    kComparisonType,
    kWindow,
    kConvolutionDimensionNumbers,
    kSharding,
    kFrontendAttributes,
    kParameterReplication,
    kInstructionList,
    kSliceRanges,
    kPaddingConfig,
    kMetadata,
    kFusionKind,
    kDistribution,
    kDomain,
    kPrecisionList,
    kShape,
    kShapeList,
    kEnum,
    kRandomAlgorithm,
    kAliasing,
    kComputationLayout,
    kInstructionAliasing,
    kCustomCallSchedule,
    kCustomCallApiVersion,
  };

  // Describes how a single named attribute should be parsed.
  struct AttrConfig {
    bool required;     // whether it's required or optional
    AttrTy attr_type;  // what type it is
    void* result;      // where to store the parsed result.
  };

  using InstrNameTable =
      absl::flat_hash_map<std::string, std::pair<HloInstruction*, LocTy>>;

  // Returns the map from the instruction name to the instruction itself and
  // its location in the current scope.
  InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }

  // Locates an instruction with the given name in the current_name_table() or
  // returns nullptr.
  //
  // When the name is not found or name is empty, if create_missing_instruction_
  // hook is registered and a "shape" is provided, the hook will be called to
  // create an instruction. This is useful when we reify parameters as they're
  // resolved; i.e. for ParseSingleInstruction.
  std::pair<HloInstruction*, LocTy>* FindInstruction(
      const std::string& name, const optional<Shape>& shape = nullopt);

  // Parse a single instruction worth of text.
  bool ParseSingleInstruction(HloModule* module);

  // Parses a module, returning false if an error occurred.
  bool ParseHloModule(HloModule* module);

  bool ParseComputations(HloModule* module);
  bool ParseComputation(HloComputation** entry_computation);
  bool ParseInstructionList(HloComputation** computation,
                            const std::string& computation_name);
  bool ParseInstruction(HloComputation::Builder* builder,
                        std::string* root_name);
  bool ParseInstructionRhs(HloComputation::Builder* builder, std::string name,
                           LocTy name_loc, bool allow_attributes = true);
  bool ParseControlPredecessors(HloInstruction* instruction);
  bool ParseLiteral(Literal* literal);
  bool ParseLiteral(Literal* literal, const Shape& shape);
  bool ParseTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseDenseLiteral(Literal* literal, const Shape& shape);

  // Parses and creates instruction given name, shape, opcode etc. This is
  // refactored out from ParseInstructionRhs to allow recursion of wrapped
  // async instructions to allow parsing for wrapped-op-specific attributes.
  HloInstruction* CreateInstruction(
      HloComputation::Builder* builder, absl::string_view name,
      std::optional<Shape> shape, HloOpcode opcode,
      std::optional<HloOpcode> async_wrapped_opcode,
      absl::flat_hash_map<std::string, AttrConfig>& attrs,
      bool allow_attributes,
      std::vector<HloInstruction*>* preset_operands = nullptr);

  // Sets the sub-value of literal at the given linear index to the
  // given value. If the literal is dense, it must have the default layout.
  //
  // `loc` should be the source location of the value.
  bool SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, double value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, bool value, int64_t index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64_t index,
                         Literal* literal);
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64_t index,
                               Literal* literal);

  // Checks whether the given value is within the range of LiteralNativeT.
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
  template <typename LiteralNativeT>
  bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);

  bool ParseOperands(std::vector<HloInstruction*>* operands,
                     HloComputation::Builder* builder);
  // Fills parsed operands into 'operands' and expects a certain number of
  // operands.
  bool ParseOperands(std::vector<HloInstruction*>* operands,
                     HloComputation::Builder* builder, const int expected_size);

  // Describes the start, limit, and stride on every dimension of the operand
  // being sliced.
  struct SliceRanges {
    std::vector<int64_t> starts;
    std::vector<int64_t> limits;
    std::vector<int64_t> strides;
  };

  // The data parsed for the kDomain instruction.
  struct DomainData {
    std::unique_ptr<DomainMetadata> entry_metadata;
    std::unique_ptr<DomainMetadata> exit_metadata;
  };

  // attributes ::= (',' attribute)*
  //
  // Parses attributes given names and configs of the attributes. Each parsed
  // result is passed back through the result pointer in corresponding
  // AttrConfig. Note that the result pointer must point to an optional<T>
  // typed variable which outlives this function. Returns false on error. You
  // should not use any of the results if this function failed.
  //
  // If allow_attributes is false, returns an error if any attributes are
  // present. This is used for contexts in which attributes are not allowed but
  // e.g. we *also* want to raise an error if any required attributes are
  // missing.
  //
  // Example usage:
  //
  //  absl::flat_hash_map<std::string, AttrConfig> attrs;
  //  optional<int64_t> foo;
  //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
  //  optional<Window> bar;
  //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
  //  if (!ParseAttributes(attrs)) {
  //    return false; // Do not use 'foo' 'bar' if failed.
  //  }
  //  // Do something with 'bar'.
  //  if (foo) { // If attr foo is seen, do something with 'foo'. }
  //
  bool ParseAttributes(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs,
      bool allow_attributes = true);

  // sub_attributes ::= '{' (','? attribute)* '}'
  //
  // Usage is the same as ParseAttributes. See immediately above.
  bool ParseSubAttributes(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs);

  // Parses one attribute. If it has already been seen, return error. Returns
  // true and adds to seen_attrs on success.
  //
  // Do not call this except in ParseAttributes or ParseSubAttributes.
  bool ParseAttributeHelper(
      const absl::flat_hash_map<std::string, AttrConfig>& attrs,
      absl::flat_hash_set<std::string>* seen_attrs);

  // Copy attributes from `attrs` to `message`, unless the attribute name is in
  // `non_proto_attrs`.
  bool CopyAttributeToProtoMessage(
      absl::flat_hash_set<std::string> non_proto_attrs,
      const absl::flat_hash_map<std::string, AttrConfig>& attrs,
      tensorflow::protobuf::Message* message);

  // Parses an attribute string into a protocol buffer `message`.
  // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
  // set of mandatory attributes.
  // `non_proto_attrs` specifies attributes that are not written to the proto,
  // but added to the HloInstruction.
  bool ParseAttributesAsProtoMessage(
      const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
      tensorflow::protobuf::Message* message);

  // Parses a name and finds the corresponding hlo computation.
  bool ParseComputationName(HloComputation** value);
  // Parses a list of names and finds the corresponding hlo instructions.
  bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
  // Pass expect_outer_curlies == true when parsing a Window in the context of a
  // larger computation. Pass false when parsing a stand-alone Window string.
  bool ParseWindow(Window* window, bool expect_outer_curlies);
  bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
  bool ParsePaddingConfig(PaddingConfig* padding);
  bool ParseMetadata(OpMetadata* metadata);
  bool ParseSingleOrListMetadata(
      tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata);
  bool ParseOpShardingType(OpSharding::Type* type);
  bool ParseListShardingType(std::vector<OpSharding::Type>* types);
  bool ParseSharding(OpSharding* sharding);
  bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
  bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
  bool ParseParameterReplication(ParameterReplication* parameter_replication);
  bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);

  // Parses the metadata behind a kDomain instruction.
  bool ParseDomain(DomainData* domain);

  // Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
  bool ParseDxD(const std::string& name, std::vector<int64_t>* result);
  // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
  bool ParseWindowPad(std::vector<std::vector<int64_t>>* pad);

  bool ParseSliceRanges(SliceRanges* result);
  bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
  bool ParseHloComputation(HloComputation** result);
  bool ParseHloComputationList(std::vector<HloComputation*>* result);
  bool ParseShapeList(std::vector<Shape>* result);
  bool ParseInt64List(const TokKind start, const TokKind end,
                      const TokKind delim, std::vector<int64_t>* result);
  bool ParseInt64ListList(const TokKind start, const TokKind end,
                          const TokKind delim,
                          std::vector<std::vector<int64_t>>* result);
  // 'parse_and_add_item' is a lambda to parse an element in the list and add
  // the parsed element to the result. It's supposed to capture the result.
  bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
                 const std::function<bool()>& parse_and_add_item);

  bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
  bool ParseParamList();
  bool ParseName(std::string* result);
  bool ParseAttributeName(std::string* result);
  bool ParseString(std::string* result);
  bool ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
                           std::vector<bool>* dynamic_dimensions);
  bool ParseShape(Shape* result);
  bool ParseLayout(Layout* layout);
  bool ParseLayoutIntAttribute(int64_t* attr_value,
                               absl::string_view attr_description);
  bool ParseDimLevelTypes(std::vector<DimLevelType>* dim_level_types);
  bool ParseTiles(std::vector<Tile>* tiles);
  bool ParseOpcode(HloOpcode* opcode,
                   std::optional<HloOpcode>* async_wrapped_opcode);
  bool ParseFftType(FftType* result);
  bool ParsePaddingType(PaddingType* result);
  bool ParseComparisonDirection(ComparisonDirection* result);
  bool ParseComparisonType(Comparison::Type* result);
  bool ParseFusionKind(HloInstruction::FusionKind* result);
  bool ParseRandomDistribution(RandomDistribution* result);
  bool ParseRandomAlgorithm(RandomAlgorithm* result);
  bool ParsePrecision(PrecisionConfig::Precision* result);
  bool ParseInt64(int64_t* result);
  bool ParseDouble(double* result);
  bool ParseComplex(std::complex<double>* result);
  bool ParseBool(bool* result);
  bool ParseToken(TokKind kind, const std::string& msg);

  using AliasingData =
      absl::flat_hash_map<ShapeIndex, HloInputOutputAliasConfig::Alias>;

  // Parses the aliasing information from string `s`, returns `false` if it
  // fails.
  bool ParseAliasing(AliasingData* data);

  // Parses the entry computation layout.
  bool ParseComputationLayout(ComputationLayout* computation_layout);

  // Parses the per-instruction aliasing information from string `s`, returns
  // `false` if it fails.
  bool ParseInstructionOutputOperandAliasing(
      std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>*
          aliasing_output_operand_pairs);

  bool ParseCustomCallSchedule(CustomCallSchedule* result);
  bool ParseCustomCallApiVersion(CustomCallApiVersion* result);
  bool ParseShapeIndex(ShapeIndex* out);

  // Returns true if the current token is the beginning of a shape.
  bool CanBeShape();
  // Returns true if the current token is the beginning of a
  // param_list_to_shape.
  bool CanBeParamListToShape();

  // Logs the current parsing line and the given message. Always returns false.
  bool TokenError(absl::string_view msg);
  bool Error(LocTy loc, absl::string_view msg);

  // If the current token is 'kind', eats it (i.e. lexes the next token) and
  // returns true.
  bool EatIfPresent(TokKind kind);

  // Adds the instruction to the pool. Returns false and emits an error if the
  // instruction already exists.
  bool AddInstruction(const std::string& name, HloInstruction* instruction,
                      LocTy name_loc);
  // Adds the computation to the pool. Returns false and emits an error if the
  // computation already exists.
  bool AddComputation(const std::string& name, HloComputation* computation,
                      LocTy name_loc);

  // Lexer that tokenizes the input text; the parser consumes tokens from it.
  HloLexer lexer_;

  // A stack for the instruction names. The top of the stack stores the
  // instruction name table for the current scope.
  //
  // An instruction's name is unique among its scope (i.e. its parent
  // computation), but it's not necessarily unique among all computations in the
  // module. When there are multiple levels of nested computations, the same
  // name could appear in both an outer computation and an inner computation. So
  // we need a stack to make sure a name is only visible within its scope.
  std::vector<InstrNameTable> scoped_name_tables_;

  // A helper class which pushes and pops to an InstrNameTable stack via RAII.
  class Scope {
   public:
    explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
        : scoped_name_tables_(scoped_name_tables) {
      scoped_name_tables_->emplace_back();
    }
    ~Scope() { scoped_name_tables_->pop_back(); }

   private:
    std::vector<InstrNameTable>* scoped_name_tables_;
  };

  // Map from the computation name to the computation itself and its location.
  absl::flat_hash_map<std::string, std::pair<HloComputation*, LocTy>>
      computation_pool_;

  // Owns the computations created while parsing.
  std::vector<std::unique_ptr<HloComputation>> computations_;
  // Accumulated error reports; see GetError().
  std::vector<std::string> error_;

  // When an operand name cannot be resolved, this function is called to create
  // a parameter instruction with the given name and shape. It registers the
  // name, instruction, and a placeholder location in the name table. It returns
  // the newly-created instruction and the placeholder location. If `name` is
  // empty, this should create the parameter with a generated name. This is
  // supposed to be set and used only in ParseSingleInstruction.
  std::function<std::pair<HloInstruction*, LocTy>*(const std::string& name,
                                                   const Shape& shape)>
      create_missing_instruction_;

  // Used to generate names for anonymous instructions.
  NameUniquer name_uniquer_{/*separator=*/"."};
};
595
SplitToInt64s(absl::string_view s,char delim,std::vector<int64_t> * out)596 bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64_t>* out) {
597 for (const auto& split : absl::StrSplit(s, delim)) {
598 int64_t val;
599 if (!absl::SimpleAtoi(split, &val)) {
600 return false;
601 }
602 out->push_back(val);
603 }
604 return true;
605 }
606
607 // Creates replica groups from the provided nested array. groups[i] represents
608 // the replica ids for group 'i'.
CreateReplicaGroups(absl::Span<const std::vector<int64_t>> groups)609 std::vector<ReplicaGroup> CreateReplicaGroups(
610 absl::Span<const std::vector<int64_t>> groups) {
611 std::vector<ReplicaGroup> replica_groups;
612 absl::c_transform(groups, std::back_inserter(replica_groups),
613 [](const std::vector<int64_t>& ids) {
614 ReplicaGroup group;
615 *group.mutable_replica_ids() = {ids.begin(), ids.end()};
616 return group;
617 });
618 return replica_groups;
619 }
620
Error(LocTy loc,absl::string_view msg)621 bool HloParserImpl::Error(LocTy loc, absl::string_view msg) {
622 auto line_col = lexer_.GetLineAndColumn(loc);
623 const unsigned line = line_col.first;
624 const unsigned col = line_col.second;
625 std::vector<std::string> error_lines;
626 error_lines.push_back(
627 StrCat("was parsing ", line, ":", col, ": error: ", msg));
628 error_lines.emplace_back(lexer_.GetLine(loc));
629 error_lines.push_back(col == 0 ? "" : StrCat(std::string(col - 1, ' '), "^"));
630
631 error_.push_back(StrJoin(error_lines, "\n"));
632 VLOG(1) << "Error: " << error_.back();
633 return false;
634 }
635
TokenError(absl::string_view msg)636 bool HloParserImpl::TokenError(absl::string_view msg) {
637 return Error(lexer_.GetLoc(), msg);
638 }
639
Run(HloModule * module)640 Status HloParserImpl::Run(HloModule* module) {
641 lexer_.Lex();
642 if (lexer_.GetKind() == TokKind::kw_HloModule) {
643 // This means that the text contains a full HLO module.
644 if (!ParseHloModule(module)) {
645 return InvalidArgument(
646 "Syntax error when trying to parse the text as a HloModule:\n%s",
647 GetError());
648 }
649 return OkStatus();
650 }
651 // This means that the text is a single HLO instruction.
652 if (!ParseSingleInstruction(module)) {
653 return InvalidArgument(
654 "Syntax error when trying to parse the text as a single "
655 "HloInstruction:\n%s",
656 GetError());
657 }
658 return OkStatus();
659 }
660
661 std::pair<HloInstruction*, HloParserImpl::LocTy>*
FindInstruction(const std::string & name,const optional<Shape> & shape)662 HloParserImpl::FindInstruction(const std::string& name,
663 const optional<Shape>& shape) {
664 std::pair<HloInstruction*, LocTy>* instr = nullptr;
665 if (!name.empty()) {
666 instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
667 }
668
669 // Potentially call the missing instruction hook.
670 if (instr == nullptr && create_missing_instruction_ != nullptr &&
671 scoped_name_tables_.size() == 1) {
672 if (!shape.has_value()) {
673 Error(lexer_.GetLoc(),
674 "Operand had no shape in HLO text; cannot create parameter for "
675 "single-instruction module.");
676 return nullptr;
677 }
678 return create_missing_instruction_(name, *shape);
679 }
680
681 if (instr != nullptr && shape.has_value() &&
682 !ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
683 Error(
684 lexer_.GetLoc(),
685 StrCat("The declared operand shape ",
686 ShapeUtil::HumanStringWithLayout(shape.value()),
687 " is not compatible with the shape of the operand instruction ",
688 ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
689 return nullptr;
690 }
691
692 return instr;
693 }
694
ParseShapeIndex(ShapeIndex * out)695 bool HloParserImpl::ParseShapeIndex(ShapeIndex* out) {
696 if (!ParseToken(TokKind::kLbrace, "Expects '{' at the start of ShapeIndex")) {
697 return false;
698 }
699
700 std::vector<int64_t> idxs;
701 while (lexer_.GetKind() != TokKind::kRbrace) {
702 int64_t idx;
703 if (!ParseInt64(&idx)) {
704 return false;
705 }
706 idxs.push_back(idx);
707 if (!EatIfPresent(TokKind::kComma)) {
708 break;
709 }
710 }
711 if (!ParseToken(TokKind::kRbrace, "Expects '}' at the end of ShapeIndex")) {
712 return false;
713 }
714 *out = ShapeIndex(idxs.begin(), idxs.end());
715 return true;
716 }
717
ParseAliasing(AliasingData * data)718 bool HloParserImpl::ParseAliasing(AliasingData* data) {
719 if (!ParseToken(TokKind::kLbrace,
720 "Expects '{' at the start of aliasing description")) {
721 return false;
722 }
723
724 while (lexer_.GetKind() != TokKind::kRbrace) {
725 ShapeIndex out;
726 if (!ParseShapeIndex(&out)) {
727 return false;
728 }
729 std::string errmsg =
730 "Expected format: <output_shape_index>: (<input_param>, "
731 "<input_param_shape_index>) OR <output_shape_index>: <input_param>";
732 if (!ParseToken(TokKind::kColon, errmsg)) {
733 return false;
734 }
735
736 if (!ParseToken(TokKind::kLparen, errmsg)) {
737 return false;
738 }
739 int64_t param_num;
740 ParseInt64(¶m_num);
741 if (!ParseToken(TokKind::kComma, errmsg)) {
742 return false;
743 }
744 ShapeIndex param_idx;
745 if (!ParseShapeIndex(¶m_idx)) {
746 return false;
747 }
748
749 HloInputOutputAliasConfig::AliasKind alias_kind =
750 HloInputOutputAliasConfig::kMayAlias;
751 if (EatIfPresent(TokKind::kComma)) {
752 std::string type;
753 ParseName(&type);
754 if (type == "must-alias") {
755 alias_kind = HloInputOutputAliasConfig::kMustAlias;
756 } else if (type == "may-alias") {
757 alias_kind = HloInputOutputAliasConfig::kMayAlias;
758 } else {
759 return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
760 }
761 }
762
763 data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
764 std::forward_as_tuple(param_num, param_idx, alias_kind));
765 if (!ParseToken(TokKind::kRparen, errmsg)) {
766 return false;
767 }
768
769 if (!EatIfPresent(TokKind::kComma)) {
770 break;
771 }
772 }
773 if (!ParseToken(TokKind::kRbrace,
774 "Expects '}' at the end of aliasing description")) {
775 return false;
776 }
777 return true;
778 }
779
ParseComputationLayout(ComputationLayout * computation_layout)780 bool HloParserImpl::ParseComputationLayout(
781 ComputationLayout* computation_layout) {
782 if (!ParseToken(TokKind::kLbrace,
783 "Expects '{' at the start of aliasing description")) {
784 return false;
785 }
786 if (!ParseToken(TokKind::kLparen, "Expects ( before parameter shape list")) {
787 return false;
788 }
789 while (lexer_.GetKind() != TokKind::kRparen) {
790 Shape param;
791 if (!ParseShape(¶m)) {
792 return false;
793 }
794 computation_layout->add_parameter_layout(ShapeLayout(param));
795 if (lexer_.GetKind() == TokKind::kRparen) {
796 break;
797 }
798 if (!ParseToken(TokKind::kComma, "Expects , between parameter shapes")) {
799 return false;
800 }
801 }
802
803 if (!ParseToken(TokKind::kRparen,
804 "Expects ) at end of parameter shape list")) {
805 return false;
806 }
807 if (!ParseToken(TokKind::kArrow, "Expects -> before result shape")) {
808 return false;
809 }
810 Shape result;
811 if (!ParseShape(&result)) {
812 return false;
813 }
814 *computation_layout->mutable_result_layout() = ShapeLayout(result);
815 if (!ParseToken(TokKind::kRbrace,
816 "Expects '}' at the end of computation layouts")) {
817 return false;
818 }
819 return true;
820 }
821
ParseInstructionOutputOperandAliasing(std::vector<std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> * aliasing_output_operand_pairs)822 bool HloParserImpl::ParseInstructionOutputOperandAliasing(
823 std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>*
824 aliasing_output_operand_pairs) {
825 if (!ParseToken(
826 TokKind::kLbrace,
827 "Expects '{' at the start of instruction aliasing description")) {
828 return false;
829 }
830
831 while (lexer_.GetKind() != TokKind::kRbrace) {
832 ShapeIndex out;
833 if (!ParseShapeIndex(&out)) {
834 return false;
835 }
836 std::string errmsg =
837 "Expected format: <output_shape_index>: (<operand_index>, "
838 "<operand_shape_index>)";
839 if (!ParseToken(TokKind::kColon, errmsg)) {
840 return false;
841 }
842
843 if (!ParseToken(TokKind::kLparen, errmsg)) {
844 return false;
845 }
846 int64_t operand_index;
847 ParseInt64(&operand_index);
848 if (!ParseToken(TokKind::kComma, errmsg)) {
849 return false;
850 }
851 ShapeIndex operand_shape_index;
852 if (!ParseShapeIndex(&operand_shape_index)) {
853 return false;
854 }
855
856 aliasing_output_operand_pairs->emplace_back(
857 out,
858 std::pair<int64_t, ShapeIndex>{operand_index, operand_shape_index});
859 if (!ParseToken(TokKind::kRparen, errmsg)) {
860 return false;
861 }
862
863 if (!EatIfPresent(TokKind::kComma)) {
864 break;
865 }
866 }
867 if (!ParseToken(
868 TokKind::kRbrace,
869 "Expects '}' at the end of instruction aliasing description")) {
870 return false;
871 }
872 return true;
873 }
874
ParseCustomCallSchedule(CustomCallSchedule * result)875 bool HloParserImpl::ParseCustomCallSchedule(CustomCallSchedule* result) {
876 VLOG(3) << "ParseCustomCallSchedule";
877 if (lexer_.GetKind() != TokKind::kIdent) {
878 return TokenError("expects custom-call schedule");
879 }
880 std::string val = lexer_.GetStrVal();
881 auto status_or_result = StringToCustomCallSchedule(val);
882 if (!status_or_result.ok()) {
883 return TokenError(
884 StrFormat("expects custom-call schedule but sees: %s, error: %s", val,
885 status_or_result.status().error_message()));
886 }
887 *result = status_or_result.ValueOrDie();
888 lexer_.Lex();
889 return true;
890 }
891
ParseCustomCallApiVersion(CustomCallApiVersion * result)892 bool HloParserImpl::ParseCustomCallApiVersion(CustomCallApiVersion* result) {
893 VLOG(3) << "ParseCustomCallApiVersion";
894 if (lexer_.GetKind() != TokKind::kIdent) {
895 return TokenError("expects custom-call API version");
896 }
897 std::string val = lexer_.GetStrVal();
898 auto status_or_result = StringToCustomCallApiVersion(val);
899 if (!status_or_result.ok()) {
900 return TokenError(
901 StrFormat("expects custom-call API version but sees: %s, error: %s",
902 val, status_or_result.status().error_message()));
903 }
904 *result = status_or_result.ValueOrDie();
905 lexer_.Lex();
906 return true;
907 }
908
// ::= 'HloModule' name computations
//
// Parses an entire module: the 'HloModule' keyword, the module name, an
// optional set of module-level attributes, and the computations. Module
// configuration (schedule, aliasing, entry layout) is applied after the
// computations have been parsed.
bool HloParserImpl::ParseHloModule(HloModule* module) {
  if (lexer_.GetKind() != TokKind::kw_HloModule) {
    return TokenError("expects HloModule");
  }
  // Eat 'HloModule'
  lexer_.Lex();

  std::string name;
  if (!ParseName(&name)) {
    return false;
  }

  // Optional module-level attributes that may follow the module name.
  std::optional<bool> is_scheduled;
  std::optional<AliasingData> aliasing_data;
  std::optional<bool> alias_passthrough_params;
  absl::flat_hash_map<std::string, AttrConfig> attrs;
  std::optional<ComputationLayout> entry_computation_layout;

  attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
  attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing,
                                 &aliasing_data};
  attrs["alias_passthrough_params"] = {/*required=*/false, AttrTy::kBool,
                                       &alias_passthrough_params};
  attrs["entry_computation_layout"] = {/*required=*/false,
                                       AttrTy::kComputationLayout,
                                       &entry_computation_layout};
  if (!ParseAttributes(attrs)) {
    return false;
  }
  module->set_name(name);
  if (!ParseComputations(module)) {
    return false;
  }

  // A schedule derived from the textual instruction order is attached only
  // when the module was explicitly marked is_scheduled=true.
  if (is_scheduled.has_value() && *is_scheduled) {
    TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
  }
  // Replace the module's config only if a parsed attribute actually deviates
  // from the defaults, so an untouched config object is left as-is.
  HloModuleConfig config = module->config();
  bool default_config = true;
  if (alias_passthrough_params.has_value() && *alias_passthrough_params) {
    config.set_alias_passthrough_params(true);
    default_config = false;
  }
  if (entry_computation_layout.has_value()) {
    *config.mutable_entry_computation_layout() = *entry_computation_layout;
    default_config = false;
  }
  if (!default_config) {
    module->set_config(config);
  }
  // Input/output aliasing is applied after the computations are parsed,
  // since building the config requires the module's result shape.
  if (aliasing_data) {
    HloInputOutputAliasConfig alias_config(module->result_shape());
    for (auto& p : *aliasing_data) {
      Status st =
          alias_config.SetUpAlias(p.first, p.second.parameter_number,
                                  p.second.parameter_index, p.second.kind);
      if (!st.ok()) {
        return TokenError(st.error_message());
      }
    }
    module->input_output_alias_config() = alias_config;
  }

  return true;
}
975
// computations ::= (computation)+
//
// Parses all computations until EOF, then installs them on `module`:
// exactly one becomes the entry computation and the rest are added as
// embedded computations.
bool HloParserImpl::ParseComputations(HloModule* module) {
  HloComputation* entry_computation = nullptr;
  do {
    if (!ParseComputation(&entry_computation)) {
      return false;
    }
  } while (lexer_.GetKind() != TokKind::kEof);

  for (int i = 0; i < computations_.size(); i++) {
    // If entry_computation is not nullptr, it means the computation it pointed
    // to is marked with "ENTRY"; otherwise, no computation is marked with
    // "ENTRY", and we use the last computation as the entry computation. We
    // add the non-entry computations as embedded computations to the module.
    if ((entry_computation != nullptr &&
         computations_[i].get() != entry_computation) ||
        (entry_computation == nullptr && i != computations_.size() - 1)) {
      module->AddEmbeddedComputation(std::move(computations_[i]));
      continue;
    }
    auto computation = module->AddEntryComputation(std::move(computations_[i]));
    // The parameters and result layouts were set to default layout. Here we
    // set the layouts to what the hlo text says.
    for (int p = 0; p < computation->num_parameters(); p++) {
      const Shape& param_shape = computation->parameter_instruction(p)->shape();
      TF_CHECK_OK(module->mutable_entry_computation_layout()
                      ->mutable_parameter_layout(p)
                      ->CopyLayoutFromShape(param_shape));
    }
    // Likewise copy the parsed root shape's layout into the entry layout.
    const Shape& result_shape = computation->root_instruction()->shape();
    TF_CHECK_OK(module->mutable_entry_computation_layout()
                    ->mutable_result_layout()
                    ->CopyLayoutFromShape(result_shape));
  }
  return true;
}
1012
// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list(,
// 'execution_thread='execution_thread)?
//
// Parses one computation. If it is marked ENTRY, records it in
// `*entry_computation` (at most one computation may be so marked).
bool HloParserImpl::ParseComputation(HloComputation** entry_computation) {
  LocTy maybe_entry_loc = lexer_.GetLoc();
  const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);

  std::string name;
  LocTy name_loc = lexer_.GetLoc();
  if (!ParseName(&name)) {
    return false;
  }

  // Optional "(params) -> shape" annotation preceding the instruction list;
  // shape_loc doubles as the "was it present?" flag.
  LocTy shape_loc = nullptr;
  Shape shape;
  if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
    return false;
  }

  HloComputation* computation = nullptr;
  if (!ParseInstructionList(&computation, name)) {
    return false;
  }

  // If param_list_to_shape was present, check compatibility.
  if (shape_loc != nullptr &&
      !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
    return Error(
        shape_loc,
        StrCat(
            "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
            ", is not compatible with that of its root instruction ",
            computation->root_instruction()->name(), ", ",
            ShapeUtil::HumanString(computation->root_instruction()->shape())));
  }
  // Optional computation-level attributes; execution_thread defaults to the
  // main execution thread when not specified.
  absl::flat_hash_map<std::string, AttrConfig> attrs;
  optional<std::string> execution_thread = HloInstruction::kMainExecutionThread;
  attrs["execution_thread"] = {/*required=*/false, AttrTy::kString,
                               &execution_thread};
  if (!ParseAttributes(attrs)) {
    return false;
  }
  computation->SetExecutionThread(*execution_thread);
  if (is_entry_computation) {
    // Only one computation in a module may carry the ENTRY marker.
    if (*entry_computation != nullptr) {
      return Error(maybe_entry_loc, "expects only one ENTRY");
    }
    *entry_computation = computation;
  }

  return AddComputation(name, computation, name_loc);
}
1064
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
//
// Parses a brace-delimited list of instructions into a new computation and
// returns it via `*computation`. The built computation is also appended to
// computations_, which owns it.
bool HloParserImpl::ParseInstructionList(HloComputation** computation,
                                         const std::string& computation_name) {
  // New name scope: instruction names are local to this computation.
  Scope scope(&scoped_name_tables_);
  HloComputation::Builder builder(computation_name);
  if (!ParseToken(TokKind::kLbrace,
                  "expects '{' at the beginning of instruction list.")) {
    return false;
  }
  // root_name stays empty unless some instruction is marked ROOT.
  std::string root_name;
  do {
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  } while (lexer_.GetKind() != TokKind::kRbrace);
  if (!ParseToken(TokKind::kRbrace,
                  "expects '}' at the end of instruction list.")) {
    return false;
  }
  HloInstruction* root = nullptr;
  if (!root_name.empty()) {
    std::pair<HloInstruction*, LocTy>* root_node =
        tensorflow::gtl::FindOrNull(current_name_table(), root_name);

    // This means some instruction was marked as ROOT but we didn't find it in
    // the pool, which should not happen.
    if (root_node == nullptr) {
      // LOG(FATAL) crashes the program by calling abort().
      LOG(FATAL) << "instruction " << root_name
                 << " was marked as ROOT but the parser has not seen it before";
    }
    root = root_node->first;
  }

  // Now root can be either an existing instruction or a nullptr. If it's a
  // nullptr, the implementation of Builder will set the last instruction as
  // the root instruction.
  computations_.emplace_back(builder.Build(root));
  *computation = computations_.back().get();
  return true;
}
1107
1108 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,std::string * root_name)1109 bool HloParserImpl::ParseInstruction(HloComputation::Builder* builder,
1110 std::string* root_name) {
1111 std::string name;
1112 LocTy maybe_root_loc = lexer_.GetLoc();
1113 bool is_root = EatIfPresent(TokKind::kw_ROOT);
1114
1115 const LocTy name_loc = lexer_.GetLoc();
1116 if (!ParseName(&name) ||
1117 !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
1118 return false;
1119 }
1120
1121 if (is_root) {
1122 if (!root_name->empty()) {
1123 return Error(maybe_root_loc, "one computation should have only one ROOT");
1124 }
1125 *root_name = name;
1126 }
1127
1128 return ParseInstructionRhs(builder, name, name_loc);
1129 }
1130
// Parses everything after "name =" for one instruction: an optional shape,
// the opcode, operands, and attributes. On success the created instruction
// is added to `builder` and registered under `name`; `name_loc` is used for
// error reporting.
bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
                                        std::string name, LocTy name_loc,
                                        bool allow_attributes) {
  Shape shape;
  HloOpcode opcode;
  std::optional<HloOpcode> async_wrapped_opcode;
  std::vector<HloInstruction*> operands;

  // The shape is optional in the text; when absent, it must be inferable
  // from the opcode (checked just below) and the operands (done later in
  // CreateInstruction).
  const bool parse_shape = CanBeShape();
  if ((parse_shape && !ParseShape(&shape)) ||
      !ParseOpcode(&opcode, &async_wrapped_opcode)) {
    return false;
  }
  if (!parse_shape && !CanInferShape(opcode)) {
    return TokenError(StrFormat("cannot infer shape for opcode: %s",
                                HloOpcodeString(opcode)));
  }

  // Add optional attributes. These are added to any HloInstruction type if
  // present.
  absl::flat_hash_map<std::string, AttrConfig> attrs;
  optional<OpSharding> sharding;
  optional<FrontendAttributes> frontend_attributes;
  attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
  attrs["frontend_attributes"] = {
      /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
  optional<ParameterReplication> parameter_replication;
  attrs["parameter_replication"] = {/*required=*/false,
                                    AttrTy::kParameterReplication,
                                    &parameter_replication};
  optional<std::vector<HloInstruction*>> predecessors;
  attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
                                   &predecessors};
  optional<OpMetadata> metadata;
  attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};

  optional<std::string> backend_config;
  attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
                             &backend_config};
  optional<std::vector<int64_t>> outer_dimension_partitions;
  attrs["outer_dimension_partitions"] = {/*required=*/false,
                                         AttrTy::kBracedInt64List,
                                         &outer_dimension_partitions};

  // Wrap the shape in an optional so CreateInstruction knows whether it was
  // written explicitly or needs to be inferred.
  std::optional<Shape> maybe_shape;
  if (parse_shape) {
    maybe_shape = shape;
  }
  // CreateInstruction consumes the operands/attributes from the token stream
  // and reports its own errors, so a nullptr here means "already diagnosed".
  HloInstruction* instruction =
      CreateInstruction(builder, name, maybe_shape, opcode,
                        async_wrapped_opcode, attrs, allow_attributes);
  if (instruction == nullptr) {
    return false;
  }

  // Generate a unique name if the name is empty. This is used for nested
  // instructions (e.g. the `max` in add(max(x, y), z)).
  //
  // Otherwise, register the given name with the name uniquer.
  if (name.empty()) {
    name = name_uniquer_.GetUniqueName(
        absl::StrCat(HloOpcodeString(instruction->opcode()), ".anon"));
  } else {
    name_uniquer_.GetUniqueName(name);
  }

  // SetAndSanitizeName may alter an illegal name; if it did, reject the
  // input and suggest the sanitized spelling.
  instruction->SetAndSanitizeName(name);
  if (instruction->name() != name) {
    return Error(name_loc,
                 StrCat("illegal instruction name: ", name,
                        "; suggest renaming to: ", instruction->name()));
  }

  // Add shared attributes like metadata to the instruction, if they were seen.
  if (sharding) {
    instruction->set_sharding(
        HloSharding::FromProto(sharding.value()).ValueOrDie());
  }
  if (parameter_replication) {
    // parameter_replication must supply exactly one flag per leaf buffer of
    // the (possibly nested tuple) parameter shape.
    int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
    const auto& replicated =
        parameter_replication->replicated_at_leaf_buffers();
    if (leaf_count != replicated.size()) {
      return Error(lexer_.GetLoc(),
                   StrCat("parameter has ", leaf_count,
                          " leaf buffers, but parameter_replication has ",
                          replicated.size(), " elements."));
    }
    instruction->set_parameter_replicated_at_leaf_buffers(replicated);
  }
  if (predecessors) {
    // Wire up control dependencies from each listed predecessor.
    for (auto* pre : *predecessors) {
      Status status = pre->AddControlDependencyTo(instruction);
      if (!status.ok()) {
        return Error(name_loc, StrCat("error adding control dependency for: ",
                                      name, " status: ", status.ToString()));
      }
    }
  }
  if (metadata) {
    instruction->set_metadata(*metadata);
  }
  if (backend_config) {
    instruction->set_raw_backend_config_string(std::move(*backend_config));
  }
  if (outer_dimension_partitions) {
    instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
  }
  if (frontend_attributes) {
    instruction->set_frontend_attributes(*frontend_attributes);
  }
  return AddInstruction(name, instruction, name_loc);
}
1244
CreateInstruction(HloComputation::Builder * builder,absl::string_view name,std::optional<Shape> shape,HloOpcode opcode,std::optional<HloOpcode> async_wrapped_opcode,absl::flat_hash_map<std::string,AttrConfig> & attrs,bool allow_attributes,std::vector<HloInstruction * > * preset_operands)1245 HloInstruction* HloParserImpl::CreateInstruction( // NOLINT
1246 HloComputation::Builder* builder, absl::string_view name,
1247 std::optional<Shape> shape, HloOpcode opcode,
1248 std::optional<HloOpcode> async_wrapped_opcode,
1249 absl::flat_hash_map<std::string, AttrConfig>& attrs, bool allow_attributes,
1250 std::vector<HloInstruction*>* preset_operands) {
1251 std::vector<HloInstruction*> operands;
1252 if (preset_operands) {
1253 operands = *preset_operands;
1254 }
1255 const auto maybe_infer_shape =
1256 [&](const std::function<StatusOr<Shape>()>& infer) {
1257 if (shape.has_value()) {
1258 return true;
1259 }
1260 auto inferred = infer();
1261 if (!inferred.ok()) {
1262 return TokenError(StrFormat(
1263 "failed to infer shape for opcode: %s, error: %s",
1264 HloOpcodeString(opcode), inferred.status().error_message()));
1265 }
1266 shape = std::move(inferred).ValueOrDie();
1267 return true;
1268 };
1269
1270 switch (opcode) {
1271 case HloOpcode::kParameter: {
1272 int64_t parameter_number;
1273 if (!ParseToken(TokKind::kLparen,
1274 "expects '(' before parameter number") ||
1275 !ParseInt64(¶meter_number)) {
1276 return nullptr;
1277 }
1278 if (parameter_number < 0) {
1279 Error(lexer_.GetLoc(), "parameter number must be >= 0");
1280 return nullptr;
1281 }
1282 if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
1283 !ParseAttributes(attrs, allow_attributes)) {
1284 return nullptr;
1285 }
1286 std::string param_name(name);
1287 return builder->AddInstruction(HloInstruction::CreateParameter(
1288 parameter_number, *shape, param_name));
1289 }
1290 case HloOpcode::kConstant: {
1291 Literal literal;
1292 if (!ParseToken(TokKind::kLparen,
1293 "expects '(' before constant literal") ||
1294 !ParseLiteral(&literal, *shape) ||
1295 !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
1296 !ParseAttributes(attrs, allow_attributes)) {
1297 return nullptr;
1298 }
1299 return builder->AddInstruction(
1300 HloInstruction::CreateConstant(std::move(literal)));
1301 }
1302 case HloOpcode::kIota: {
1303 optional<int64_t> iota_dimension;
1304 attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
1305 &iota_dimension};
1306 if ((!preset_operands &&
1307 !ParseOperands(&operands, builder, /*expected_size=*/0)) ||
1308 !ParseAttributes(attrs, allow_attributes)) {
1309 return nullptr;
1310 }
1311 return builder->AddInstruction(
1312 HloInstruction::CreateIota(*shape, *iota_dimension));
1313 }
1314 // Unary ops.
1315 case HloOpcode::kAbs:
1316 case HloOpcode::kAllGatherDone:
1317 case HloOpcode::kAllReduceDone:
1318 case HloOpcode::kRoundNearestAfz:
1319 case HloOpcode::kRoundNearestEven:
1320 case HloOpcode::kBitcast:
1321 case HloOpcode::kCeil:
1322 case HloOpcode::kClz:
1323 case HloOpcode::kCollectivePermuteDone:
1324 case HloOpcode::kCopy:
1325 case HloOpcode::kCopyDone:
1326 case HloOpcode::kCos:
1327 case HloOpcode::kOptimizationBarrier:
1328 case HloOpcode::kExp:
1329 case HloOpcode::kExpm1:
1330 case HloOpcode::kImag:
1331 case HloOpcode::kIsFinite:
1332 case HloOpcode::kFloor:
1333 case HloOpcode::kLog:
1334 case HloOpcode::kLog1p:
1335 case HloOpcode::kLogistic:
1336 case HloOpcode::kNot:
1337 case HloOpcode::kNegate:
1338 case HloOpcode::kPopulationCount:
1339 case HloOpcode::kReal:
1340 case HloOpcode::kRsqrt:
1341 case HloOpcode::kSign:
1342 case HloOpcode::kSin:
1343 case HloOpcode::kSqrt:
1344 case HloOpcode::kCbrt:
1345 case HloOpcode::kTanh: {
1346 if ((!preset_operands &&
1347 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1348 !ParseAttributes(attrs, allow_attributes)) {
1349 return nullptr;
1350 }
1351 if (!maybe_infer_shape([&] {
1352 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
1353 })) {
1354 return nullptr;
1355 }
1356 return builder->AddInstruction(
1357 HloInstruction::CreateUnary(*shape, opcode, operands[0]));
1358 }
1359 // Binary ops.
1360 case HloOpcode::kAdd:
1361 case HloOpcode::kDivide:
1362 case HloOpcode::kMultiply:
1363 case HloOpcode::kSubtract:
1364 case HloOpcode::kAtan2:
1365 case HloOpcode::kComplex:
1366 case HloOpcode::kMaximum:
1367 case HloOpcode::kMinimum:
1368 case HloOpcode::kPower:
1369 case HloOpcode::kRemainder:
1370 case HloOpcode::kAnd:
1371 case HloOpcode::kOr:
1372 case HloOpcode::kXor:
1373 case HloOpcode::kShiftLeft:
1374 case HloOpcode::kShiftRightArithmetic:
1375 case HloOpcode::kShiftRightLogical: {
1376 if ((!preset_operands &&
1377 !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
1378 !ParseAttributes(attrs, allow_attributes)) {
1379 return nullptr;
1380 }
1381 if (!maybe_infer_shape([&] {
1382 return ShapeInference::InferBinaryOpShape(opcode, operands[0],
1383 operands[1]);
1384 })) {
1385 return nullptr;
1386 }
1387 return builder->AddInstruction(HloInstruction::CreateBinary(
1388 *shape, opcode, operands[0], operands[1]));
1389 }
1390 // Ternary ops.
1391 case HloOpcode::kClamp:
1392 case HloOpcode::kSelect: {
1393 if ((!preset_operands &&
1394 !ParseOperands(&operands, builder, /*expected_size=*/3)) ||
1395 !ParseAttributes(attrs, allow_attributes)) {
1396 return nullptr;
1397 }
1398 if (!maybe_infer_shape([&] {
1399 return ShapeInference::InferTernaryOpShape(
1400 opcode, operands[0], operands[1], operands[2]);
1401 })) {
1402 return nullptr;
1403 }
1404 return builder->AddInstruction(HloInstruction::CreateTernary(
1405 *shape, opcode, operands[0], operands[1], operands[2]));
1406 }
1407 // Other supported ops.
1408 case HloOpcode::kConvert: {
1409 if ((!preset_operands &&
1410 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1411 !ParseAttributes(attrs, allow_attributes)) {
1412 return nullptr;
1413 }
1414 return builder->AddInstruction(
1415 HloInstruction::CreateConvert(*shape, operands[0]));
1416 }
1417 case HloOpcode::kBitcastConvert: {
1418 if ((!preset_operands &&
1419 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1420 !ParseAttributes(attrs, allow_attributes)) {
1421 return nullptr;
1422 }
1423 return builder->AddInstruction(
1424 HloInstruction::CreateBitcastConvert(*shape, operands[0]));
1425 }
1426 case HloOpcode::kAllGather:
1427 case HloOpcode::kAllGatherStart: {
1428 optional<std::vector<std::vector<int64_t>>> tmp_groups;
1429 optional<std::vector<int64_t>> replica_group_ids;
1430 optional<int64_t> channel_id;
1431 optional<std::vector<int64_t>> dimensions;
1432 optional<bool> constrain_layout;
1433 optional<bool> use_global_device_ids;
1434 attrs["replica_groups"] = {/*required=*/false,
1435 AttrTy::kBracedInt64ListList, &tmp_groups};
1436 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1437 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1438 &dimensions};
1439 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1440 &constrain_layout};
1441 attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1442 &use_global_device_ids};
1443 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1444 !ParseAttributes(attrs, allow_attributes)) {
1445 return nullptr;
1446 }
1447 std::vector<ReplicaGroup> replica_groups;
1448 if (tmp_groups) {
1449 replica_groups = CreateReplicaGroups(*tmp_groups);
1450 }
1451 if (opcode == HloOpcode::kAllGather) {
1452 return builder->AddInstruction(HloInstruction::CreateAllGather(
1453 *shape, operands, dimensions->at(0), replica_groups,
1454 constrain_layout ? *constrain_layout : false, channel_id,
1455 use_global_device_ids ? *use_global_device_ids : false));
1456 }
1457 return builder->AddInstruction(HloInstruction::CreateAllGatherStart(
1458 *shape, operands, dimensions->at(0), replica_groups,
1459 constrain_layout ? *constrain_layout : false, channel_id,
1460 use_global_device_ids ? *use_global_device_ids : false));
1461 }
1462 case HloOpcode::kAllReduce:
1463 case HloOpcode::kAllReduceStart:
1464 case HloOpcode::kReduceScatter: {
1465 optional<std::vector<std::vector<int64_t>>> tmp_groups;
1466 optional<HloComputation*> to_apply;
1467 optional<std::vector<int64_t>> replica_group_ids;
1468 optional<int64_t> channel_id;
1469 optional<bool> constrain_layout;
1470 optional<bool> use_global_device_ids;
1471 optional<std::vector<int64_t>> dimensions;
1472 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1473 &to_apply};
1474 attrs["replica_groups"] = {/*required=*/false,
1475 AttrTy::kBracedInt64ListList, &tmp_groups};
1476 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1477 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1478 &constrain_layout};
1479 attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1480 &use_global_device_ids};
1481 if (opcode == HloOpcode::kReduceScatter) {
1482 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1483 &dimensions};
1484 }
1485 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1486 !ParseAttributes(attrs, allow_attributes)) {
1487 return nullptr;
1488 }
1489 std::vector<ReplicaGroup> replica_groups;
1490 if (tmp_groups) {
1491 replica_groups = CreateReplicaGroups(*tmp_groups);
1492 }
1493 if (opcode == HloOpcode::kAllReduce) {
1494 return builder->AddInstruction(HloInstruction::CreateAllReduce(
1495 *shape, operands, *to_apply, replica_groups,
1496 constrain_layout ? *constrain_layout : false, channel_id,
1497 use_global_device_ids ? *use_global_device_ids : false));
1498 } else if (opcode == HloOpcode::kReduceScatter) {
1499 return builder->AddInstruction(HloInstruction::CreateReduceScatter(
1500 *shape, operands, *to_apply, replica_groups,
1501 constrain_layout ? *constrain_layout : false, channel_id,
1502 use_global_device_ids ? *use_global_device_ids : false,
1503 dimensions->at(0)));
1504 }
1505 return builder->AddInstruction(HloInstruction::CreateAllReduceStart(
1506 *shape, operands, *to_apply, replica_groups,
1507 constrain_layout ? *constrain_layout : false, channel_id,
1508 use_global_device_ids ? *use_global_device_ids : false));
1509 }
1510 case HloOpcode::kAllToAll: {
1511 optional<std::vector<std::vector<int64_t>>> tmp_groups;
1512 attrs["replica_groups"] = {/*required=*/false,
1513 AttrTy::kBracedInt64ListList, &tmp_groups};
1514 optional<int64_t> channel_id;
1515 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1516 optional<std::vector<int64_t>> dimensions;
1517 attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
1518 &dimensions};
1519 optional<bool> constrain_layout;
1520 attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1521 &constrain_layout};
1522 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1523 !ParseAttributes(attrs, allow_attributes) ||
1524 (dimensions && dimensions->size() != 1)) {
1525 return nullptr;
1526 }
1527 std::vector<ReplicaGroup> replica_groups;
1528 if (tmp_groups) {
1529 replica_groups = CreateReplicaGroups(*tmp_groups);
1530 }
1531 optional<int64_t> split_dimension;
1532 if (dimensions) {
1533 split_dimension = dimensions->at(0);
1534 }
1535 return builder->AddInstruction(HloInstruction::CreateAllToAll(
1536 *shape, operands, replica_groups,
1537 constrain_layout ? *constrain_layout : false, channel_id,
1538 split_dimension));
1539 }
1540 case HloOpcode::kCollectivePermute:
1541 case HloOpcode::kCollectivePermuteStart: {
1542 optional<std::vector<std::vector<int64_t>>> source_targets;
1543 attrs["source_target_pairs"] = {
1544 /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
1545 optional<int64_t> channel_id;
1546 attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1547 optional<std::vector<std::vector<int64_t>>> slice_sizes;
1548 attrs["slice_sizes"] = {/*required=*/false, AttrTy::kBracedInt64ListList,
1549 &slice_sizes};
1550 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1551 !ParseAttributes(attrs, allow_attributes)) {
1552 return nullptr;
1553 }
1554 std::vector<std::pair<int64_t, int64_t>> pairs(source_targets->size());
1555 for (int i = 0; i < pairs.size(); i++) {
1556 if ((*source_targets)[i].size() != 2) {
1557 TokenError("expects 'source_target_pairs=' to be a list of pairs");
1558 return nullptr;
1559 }
1560 pairs[i].first = (*source_targets)[i][0];
1561 pairs[i].second = (*source_targets)[i][1];
1562 }
1563 if (!slice_sizes.has_value()) {
1564 if (operands.size() != 1) {
1565 TokenError(
1566 "CollectivePermute and CollectivePermuteStart must "
1567 "have exactly one operand (input buffer) unless "
1568 "it performs dynamic-slice and in-place update.");
1569 return nullptr;
1570 }
1571 if (opcode == HloOpcode::kCollectivePermute) {
1572 return builder->AddInstruction(
1573 HloInstruction::CreateCollectivePermute(*shape, operands[0],
1574 pairs, channel_id));
1575 }
1576 if (opcode == HloOpcode::kCollectivePermuteStart) {
1577 return builder->AddInstruction(
1578 HloInstruction::CreateCollectivePermuteStart(*shape, operands[0],
1579 pairs, channel_id));
1580 }
1581 LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1582 "CollectivePermuteStart, but got "
1583 << HloOpcodeString(opcode);
1584 }
1585 if (operands.size() != 4) {
1586 TokenError(
1587 "CollectivePermute and CollectivePermuteStart must "
1588 "have exactly four operands for dynamic-slice and "
1589 "in-place update.");
1590 return nullptr;
1591 }
1592 if (opcode == HloOpcode::kCollectivePermute) {
1593 return builder->AddInstruction(HloInstruction::CreateCollectivePermute(
1594 *shape, operands[0], operands[1], operands[2], operands[3], pairs,
1595 *slice_sizes, channel_id));
1596 }
1597 if (opcode == HloOpcode::kCollectivePermuteStart) {
1598 return builder->AddInstruction(
1599 HloInstruction::CreateCollectivePermuteStart(
1600 *shape, operands[0], operands[1], operands[2], operands[3],
1601 pairs, *slice_sizes, channel_id));
1602 }
1603 LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1604 "CollectivePermuteStart, but got "
1605 << HloOpcodeString(opcode);
1606 }
1607 case HloOpcode::kAsyncStart:
1608 case HloOpcode::kAsyncUpdate:
1609 case HloOpcode::kAsyncDone: {
1610 std::optional<HloComputation*> async_computation;
1611 if (!preset_operands && !ParseOperands(&operands, builder)) {
1612 return nullptr;
1613 }
1614 auto is_async_shape_correct = [](const Shape& shape) {
1615 return shape.IsTuple() && shape.tuple_shapes_size() >= 2 &&
1616 shape.tuple_shapes(0).IsTuple();
1617 };
1618 optional<int64_t> async_group_id;
1619 attrs["async_group_id"] = {/*required=*/false, AttrTy::kInt64,
1620 &async_group_id};
1621 optional<std::string> async_execution_thread =
1622 HloInstruction::kMainExecutionThread;
1623 attrs["async_execution_thread"] = {/*required=*/false, AttrTy::kString,
1624 &async_execution_thread};
1625 if (async_wrapped_opcode) {
1626 std::vector<HloInstruction*> async_wrapped_operands;
1627 std::vector<Shape> async_wrapped_operand_shapes;
1628 Shape async_wrapped_root_shape;
1629 if (opcode == HloOpcode::kAsyncStart) {
1630 for (const HloInstruction* operand : operands) {
1631 async_wrapped_operand_shapes.push_back(operand->shape());
1632 }
1633 } else {
1634 if (operands.size() != 1 ||
1635 !is_async_shape_correct(operands[0]->shape())) {
1636 TokenError(
1637 "AsyncUpdate and AsyncDone expect a single operand in the form "
1638 "of "
1639 "((async-operands), async-outputs, state).");
1640 return nullptr;
1641 }
1642 async_wrapped_operand_shapes =
1643 operands[0]->shape().tuple_shapes(0).tuple_shapes();
1644 }
1645
1646 if (opcode == HloOpcode::kAsyncDone) {
1647 async_wrapped_root_shape = *shape;
1648 } else {
1649 if (!is_async_shape_correct(*shape)) {
1650 TokenError(
1651 "AsyncStart and AsyncUpdate expect the op shape to be in the "
1652 "form of "
1653 "((async-operands), async-outputs, state).");
1654 return nullptr;
1655 }
1656 async_wrapped_root_shape = shape->tuple_shapes(1);
1657 }
1658 HloComputation::Builder async_wrapped_builder("async_wrapped");
1659 async_wrapped_operands.reserve(async_wrapped_operand_shapes.size());
1660 for (int i = 0; i < async_wrapped_operand_shapes.size(); ++i) {
1661 async_wrapped_operands.push_back(async_wrapped_builder.AddInstruction(
1662 HloInstruction::CreateParameter(
1663 i, async_wrapped_operand_shapes.at(i), "async_param")));
1664 }
1665 HloInstruction* root =
1666 CreateInstruction(&async_wrapped_builder, "async_op",
1667 async_wrapped_root_shape, *async_wrapped_opcode,
1668 /*async_wrapped_opcode=*/std::nullopt, attrs,
1669 allow_attributes, &async_wrapped_operands);
1670 if (!root) {
1671 return nullptr;
1672 }
1673 computations_.emplace_back(async_wrapped_builder.Build(root));
1674 async_computation = computations_.back().get();
1675
1676 } else {
1677 attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
1678 &async_computation};
1679 if (!ParseAttributes(attrs, allow_attributes)) {
1680 return nullptr;
1681 }
1682 }
1683 if (opcode == HloOpcode::kAsyncStart) {
1684 return builder->AddInstruction(HloInstruction::CreateAsyncStart(
1685 *shape, operands, *async_computation, async_group_id,
1686 *async_execution_thread));
1687 }
1688 if (opcode == HloOpcode::kAsyncUpdate) {
1689 return builder->AddInstruction(HloInstruction::CreateAsyncUpdate(
1690 *shape, operands[0], *async_computation, async_group_id,
1691 *async_execution_thread));
1692 }
1693 return builder->AddInstruction(HloInstruction::CreateAsyncDone(
1694 *shape, operands[0], *async_computation, async_group_id,
1695 *async_execution_thread));
1696 }
1697 case HloOpcode::kCopyStart: {
1698 // If the is_cross_program_prefetch attribute is not present then default
1699 // to false.
1700 optional<bool> is_cross_program_prefetch = false;
1701 attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
1702 &is_cross_program_prefetch};
1703 if ((!preset_operands &&
1704 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1705 !ParseAttributes(attrs, allow_attributes)) {
1706 return nullptr;
1707 }
1708 return builder->AddInstruction(HloInstruction::CreateCopyStart(
1709 *shape, operands[0], *is_cross_program_prefetch));
1710 }
1711 case HloOpcode::kReplicaId: {
1712 if ((!preset_operands &&
1713 !ParseOperands(&operands, builder, /*expected_size=*/0)) ||
1714 !ParseAttributes(attrs, allow_attributes)) {
1715 return nullptr;
1716 }
1717 return builder->AddInstruction(HloInstruction::CreateReplicaId());
1718 }
1719 case HloOpcode::kPartitionId: {
1720 if ((!preset_operands &&
1721 !ParseOperands(&operands, builder, /*expected_size=*/0)) ||
1722 !ParseAttributes(attrs, allow_attributes)) {
1723 return nullptr;
1724 }
1725 return builder->AddInstruction(HloInstruction::CreatePartitionId());
1726 }
1727 case HloOpcode::kDynamicReshape: {
1728 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1729 !ParseAttributes(attrs, allow_attributes)) {
1730 return nullptr;
1731 }
1732 return builder->AddInstruction(HloInstruction::CreateDynamicReshape(
1733 *shape, operands[0],
1734 absl::Span<HloInstruction* const>(operands).subspan(1)));
1735 }
1736 case HloOpcode::kReshape: {
1737 optional<int64_t> inferred_dimension;
1738 attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
1739 &inferred_dimension};
1740 if ((!preset_operands &&
1741 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1742 !ParseAttributes(attrs, allow_attributes)) {
1743 return nullptr;
1744 }
1745 return builder->AddInstruction(HloInstruction::CreateReshape(
1746 *shape, operands[0], inferred_dimension.value_or(-1)));
1747 }
1748 case HloOpcode::kAfterAll: {
1749 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1750 !ParseAttributes(attrs, allow_attributes)) {
1751 return nullptr;
1752 }
1753 if (operands.empty()) {
1754 return builder->AddInstruction(HloInstruction::CreateToken());
1755 }
1756 return builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
1757 }
1758 case HloOpcode::kAddDependency: {
1759 if ((!preset_operands &&
1760 !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
1761 !ParseAttributes(attrs, allow_attributes)) {
1762 return nullptr;
1763 }
1764 return builder->AddInstruction(
1765 HloInstruction::CreateAddDependency(operands[0], operands[1]));
1766 }
    case HloOpcode::kSort: {
      // sort requires exactly one entry in `dimensions` (the sort dimension)
      // and a comparator computation in `to_apply`; `is_stable` defaults to
      // false when absent.
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      optional<bool> is_stable = false;
      attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      // dimensions->size() is only evaluated after ParseAttributes succeeds
      // (short-circuit), so the required optional is populated by then.
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes) ||
          dimensions->size() != 1) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateSort(*shape, dimensions->at(0), operands,
                                     to_apply.value(), is_stable.value()));
    }
    case HloOpcode::kTuple: {
      // tuple(...): variadic, any number of operands (including zero).
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
          })) {
        return nullptr;
      }
      // HloInstruction::CreateTuple() infers the shape of the tuple from
      // operands and should not be used here; CreateVariadic keeps the parsed
      // shape verbatim.
      return builder->AddInstruction(
          HloInstruction::CreateVariadic(*shape, HloOpcode::kTuple, operands));
    }
    case HloOpcode::kWhile: {
      // while(condition=..., body=...): one operand that is both the initial
      // loop state and the parameter of both computations.
      optional<HloComputation*> condition;
      optional<HloComputation*> body;
      attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
                            &condition};
      attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferWhileShape(
                condition.value()->ComputeProgramShape(),
                body.value()->ComputeProgramShape(), operands[0]->shape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateWhile(
          *shape, *condition, *body, /*init=*/operands[0]));
    }
    case HloOpcode::kRecv: {
      optional<int64_t> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // If the is_host_transfer attribute is not present then default to false.
      // NOTE(review): assumes *shape is a tuple whose element 0 is the
      // received data shape — a non-tuple shape would check-fail in
      // tuple_shapes(0); confirm upstream validation guarantees this.
      return builder->AddInstruction(HloInstruction::CreateRecv(
          shape->tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
    }
    case HloOpcode::kRecvDone: {
      optional<int64_t> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // The operand must be a channel instruction (e.g. the matching recv)
      // with the same channel_id. These two checks fail silently (no
      // TokenError), so the caller surfaces a generic parse error.
      if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
        return nullptr;
      }
      if (channel_id != operands[0]->channel_id()) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
    }
    case HloOpcode::kSend: {
      optional<int64_t> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      // send takes (data, token).
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateSend(
          operands[0], operands[1], *channel_id, *is_host_transfer));
    }
    case HloOpcode::kSendDone: {
      optional<int64_t> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // As in kRecvDone: operand must be a channel instruction whose
      // channel_id matches; both checks fail without an explicit TokenError.
      if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
        return nullptr;
      }
      if (channel_id != operands[0]->channel_id()) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
    }
    case HloOpcode::kGetTupleElement: {
      // get-tuple-element %tuple, index=N
      optional<int64_t> index;
      attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeUtil::GetTupleElementShape(operands[0]->shape(),
                                                   *index);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateGetTupleElement(*shape, operands[0], *index));
    }
    case HloOpcode::kCall: {
      // call(...) to_apply=%computation: variadic operands forwarded as the
      // callee's parameters.
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferCallShape(
                arg_shapes, to_apply.value()->ComputeProgramShape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateCall(*shape, operands, *to_apply));
    }
    case HloOpcode::kReduceWindow: {
      // reduce-window: the first half of the operands are the inputs, the
      // second half the matching init values; `window` defaults to an empty
      // (identity) window when absent.
      optional<HloComputation*> reduce_computation;
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!window) {
        window.emplace();
      }
      if (operands.size() % 2) {
        TokenError(StrCat("expects an even number of operands, but has ",
                          operands.size(), " operands"));
        return nullptr;
      }
      // NOTE(review): inference only looks at the first input/init pair even
      // for variadic reduce-window; confirm this matches InferReduceWindowShape
      // expectations for the multi-operand form.
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferReduceWindowShape(
                operands[0]->shape(), operands[1]->shape(), *window,
                reduce_computation.value()->ComputeProgramShape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateReduceWindow(
          *shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *window, *reduce_computation));
    }
    case HloOpcode::kConvolution: {
      // convolution %lhs, %rhs with window, dim_labels and optional group
      // counts / operand precision.
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64_t> feature_group_count;
      optional<int64_t> batch_group_count;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/true,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Defaults: empty window, ungrouped convolution.
      if (!window) {
        window.emplace();
      }
      if (!feature_group_count) {
        feature_group_count = 1;
      }
      if (!batch_group_count) {
        batch_group_count = 1;
      }
      // Missing operand_precision means DEFAULT precision for every operand.
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferConvolveShape(
                operands[0]->shape(), operands[1]->shape(),
                *feature_group_count, *batch_group_count, *window, *dnums,
                /*preferred_element_type=*/std::nullopt);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateConvolve(
          *shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
          feature_group_count.value(), batch_group_count.value(), *window,
          *dnums, precision_config));
    }
    case HloOpcode::kFft: {
      // fft %operand, fft_type=..., fft_length={...}
      optional<FftType> fft_type;
      optional<std::vector<int64_t>> fft_length;
      attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
      attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &fft_length};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferFftShape(operands[0]->shape(),
                                                 *fft_type, *fft_length);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateFft(
          *shape, operands[0], *fft_type, *fft_length));
    }
    case HloOpcode::kTriangularSolve: {
      // triangular-solve %a, %b: attributes are parsed directly into the
      // TriangularSolveOptions proto (only when attributes are allowed).
      TriangularSolveOptions options;
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          (allow_attributes && !ParseAttributesAsProtoMessage(
                                   /*non_proto_attrs=*/attrs, &options))) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferTriangularSolveShape(
                operands[0]->shape(), operands[1]->shape(), options);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateTriangularSolve(
          *shape, operands[0], operands[1], options));
    }
    case HloOpcode::kCompare: {
      // compare %lhs, %rhs, direction=...: the comparison type is optional
      // and passed through as an optional to CreateCompare.
      optional<ComparisonDirection> direction;
      optional<Comparison::Type> type;
      attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
                            &direction};
      attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBinaryOpShape(opcode, operands[0],
                                                      operands[1]);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateCompare(
          *shape, operands[0], operands[1], *direction, type));
    }
    case HloOpcode::kCholesky: {
      // cholesky %a: attributes are parsed directly into the CholeskyOptions
      // proto (only when attributes are allowed).
      CholeskyOptions options;
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          (allow_attributes && !ParseAttributesAsProtoMessage(
                                   /*non_proto_attrs=*/attrs, &options))) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferCholeskyShape(operands[0]->shape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateCholesky(*shape, operands[0], options));
    }
    case HloOpcode::kBroadcast: {
      // Operands must be parsed first: whether `dimensions` is required
      // depends on the operand's shape.
      if (!preset_operands &&
          !ParseOperands(&operands, builder, /*expected_size=*/1)) {
        return nullptr;
      }

      // The `dimensions` attr is optional if the broadcasted operand is a
      // scalar; in that case we can infer it to be {}.
      bool operand_is_scalar = ShapeUtil::IsScalar(operands[0]->shape());
      optional<std::vector<int64_t>> broadcast_dimensions;
      attrs["dimensions"] = {/*required=*/!operand_is_scalar,
                             AttrTy::kBracedInt64List, &broadcast_dimensions};
      if (!ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Scalar broadcast with no explicit dimensions: use the empty list.
      if (operand_is_scalar && !broadcast_dimensions.has_value()) {
        broadcast_dimensions.emplace();
      }

      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBroadcastShape(operands[0]->shape(),
                                                       *broadcast_dimensions);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateBroadcast(
          *shape, operands[0], *broadcast_dimensions));
    }
    case HloOpcode::kConcatenate: {
      // concatenate(...): variadic, requires exactly one concat dimension in
      // `dimensions`. The size check is short-circuited after ParseAttributes
      // succeeds, so the required optional is populated when dereferenced.
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes) ||
          dimensions->size() != 1) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferConcatOpShape(arg_shapes,
                                                      dimensions->at(0));
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateConcatenate(
          *shape, operands, dimensions->at(0)));
    }
    case HloOpcode::kMap: {
      // map(...) to_apply=%computation: variadic operands mapped elementwise.
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                             &dimensions};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // NOTE(review): `dimensions` is optional but dereferenced
      // unconditionally inside the inference lambda — this relies on shape
      // inference not being triggered when the attribute is absent; confirm.
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferMapShape(
                arg_shapes, to_apply.value()->ComputeProgramShape(),
                *dimensions);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateMap(*shape, operands, *to_apply));
    }
    case HloOpcode::kReduce: {
      // reduce: first half of the operands are inputs, second half the
      // matching init values; `dimensions` lists the reduced dimensions.
      optional<HloComputation*> reduce_computation;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      optional<std::vector<int64_t>> dimensions_to_reduce;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions_to_reduce};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (operands.size() % 2) {
        TokenError(StrCat("expects an even number of operands, but has ",
                          operands.size(), " operands"));
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferReduceShape(
                arg_shapes, *dimensions_to_reduce,
                reduce_computation.value()->ComputeProgramShape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateReduce(
          *shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *dimensions_to_reduce, *reduce_computation));
    }
    case HloOpcode::kReverse: {
      // reverse %operand, dimensions={...}
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferReverseShape(operands[0]->shape(),
                                                     *dimensions);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateReverse(*shape, operands[0], *dimensions));
    }
    case HloOpcode::kSelectAndScatter: {
      // select-and-scatter %operand, %source, %init_value with select and
      // scatter computations; `window` defaults to empty when absent.
      optional<HloComputation*> select;
      attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
      optional<HloComputation*> scatter;
      attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/3)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!window) {
        window.emplace();
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferSelectAndScatterShape(
                operands[0]->shape(), select.value()->ComputeProgramShape(),
                *window, operands[1]->shape(), operands[2]->shape(),
                scatter.value()->ComputeProgramShape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
          *shape, /*operand=*/operands[0], *select, *window,
          /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
    }
2264 case HloOpcode::kSlice: {
2265 optional<SliceRanges> slice_ranges;
2266 attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
2267 if ((!preset_operands &&
2268 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
2269 !ParseAttributes(attrs, allow_attributes)) {
2270 return nullptr;
2271 }
2272 return builder->AddInstruction(HloInstruction::CreateSlice(
2273 *shape, operands[0], slice_ranges->starts, slice_ranges->limits,
2274 slice_ranges->strides));
2275 }
    case HloOpcode::kDynamicSlice: {
      // dynamic-slice %operand, %start_indices...: start indices are either a
      // single rank-1 vector or one scalar per operand dimension.
      optional<std::vector<int64_t>> dynamic_slice_sizes;
      attrs["dynamic_slice_sizes"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (operands.empty()) {
        TokenError("Expected at least one operand.");
        return nullptr;
      }
      // Accept either the legacy single-vector form (2 operands, second of
      // rank 1) or the scalar-per-dimension form (1 + rank operands).
      if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
          operands.size() != 1 + operands[0]->shape().rank()) {
        TokenError("Wrong number of operands.");
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateDynamicSlice(
          *shape, /*operand=*/operands[0],
          /*start_indices=*/absl::MakeSpan(operands).subspan(1),
          *dynamic_slice_sizes));
    }
    case HloOpcode::kDynamicUpdateSlice: {
      // dynamic-update-slice %operand, %update, %start_indices...
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (operands.size() < 2) {
        TokenError("Expected at least two operands.");
        return nullptr;
      }
      // Accept either the legacy single-vector form (3 operands, third of
      // rank 1) or the scalar-per-dimension form (2 + rank operands).
      if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
          operands.size() != 2 + operands[0]->shape().rank()) {
        TokenError("Wrong number of operands.");
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          *shape, /*operand=*/operands[0], /*update=*/operands[1],
          /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
    }
    case HloOpcode::kTranspose: {
      // transpose %operand, dimensions={...}: `dimensions` is the permutation.
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferTransposeShape(operands[0]->shape(),
                                                       *dimensions);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateTranspose(*shape, operands[0], *dimensions));
    }
    case HloOpcode::kBatchNormTraining: {
      // batch-norm-training %operand, %scale, %offset
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64_t> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/3)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBatchNormTrainingShape(
                operands[0]->shape(), operands[1]->shape(),
                operands[2]->shape(), *feature_index);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
          *shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*offset=*/operands[2], *epsilon, *feature_index));
    }
    case HloOpcode::kBatchNormInference: {
      // batch-norm-inference %operand, %scale, %offset, %mean, %variance
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64_t> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/5)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBatchNormInferenceShape(
                operands[0]->shape(), operands[1]->shape(),
                operands[2]->shape(), operands[3]->shape(),
                operands[4]->shape(), *feature_index);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateBatchNormInference(
          *shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*offset=*/operands[2], /*mean=*/operands[3],
          /*variance=*/operands[4], *epsilon, *feature_index));
    }
    case HloOpcode::kBatchNormGrad: {
      // batch-norm-grad %operand, %scale, %mean, %variance, %grad_output
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64_t> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/5)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBatchNormGradShape(
                operands[0]->shape(), operands[1]->shape(),
                operands[2]->shape(), operands[3]->shape(),
                operands[4]->shape(), *feature_index);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
          *shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*mean=*/operands[2], /*variance=*/operands[3],
          /*grad_output=*/operands[4], *epsilon, *feature_index));
    }
    case HloOpcode::kPad: {
      // pad %operand, %padding_value, padding=...
      optional<PaddingConfig> padding;
      attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferPadShape(
                operands[0]->shape(), operands[1]->shape(), *padding);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreatePad(
          *shape, operands[0], /*padding_value=*/operands[1], *padding));
    }
2421 case HloOpcode::kFusion: {
2422 optional<HloComputation*> fusion_computation;
2423 attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
2424 &fusion_computation};
2425 optional<HloInstruction::FusionKind> fusion_kind;
2426 attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
2427 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
2428 !ParseAttributes(attrs, allow_attributes)) {
2429 return nullptr;
2430 }
2431 return builder->AddInstruction(HloInstruction::CreateFusion(
2432 *shape, *fusion_kind, operands, *fusion_computation));
2433 }
2434 case HloOpcode::kInfeed: {
2435 optional<std::string> config;
2436 attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2437 if ((!preset_operands &&
2438 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
2439 !ParseAttributes(attrs, allow_attributes)) {
2440 return nullptr;
2441 }
2442 // We need to know the infeed data shape to construct the infeed
2443 // instruction. This is the zero-th element of the tuple-shaped output of
2444 // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
2445 // if the shape is not a non-empty tuple, so add guard so an error message
2446 // can be emitted instead of a check fail
2447 if (!shape->IsTuple() && !ShapeUtil::IsEmptyTuple(*shape)) {
2448 TokenError("infeed must have a non-empty tuple shape");
2449 return nullptr;
2450 }
2451 return builder->AddInstruction(HloInstruction::CreateInfeed(
2452 ShapeUtil::GetTupleElementShape(*shape, 0), operands[0],
2453 config ? *config : ""));
2454 }
2455 case HloOpcode::kOutfeed: {
2456 optional<std::string> config;
2457 optional<Shape> outfeed_shape;
2458 attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2459 attrs["outfeed_shape"] = {/*required=*/false, AttrTy::kShape,
2460 &outfeed_shape};
2461 if ((!preset_operands &&
2462 !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
2463 !ParseAttributes(attrs, allow_attributes)) {
2464 return nullptr;
2465 }
2466 HloInstruction* const outfeed_input = operands[0];
2467 HloInstruction* const outfeed_token = operands[1];
2468 const Shape shape =
2469 outfeed_shape.has_value() ? *outfeed_shape : outfeed_input->shape();
2470 return builder->AddInstruction(HloInstruction::CreateOutfeed(
2471 shape, outfeed_input, outfeed_token, config ? *config : ""));
2472 }
2473 case HloOpcode::kRng: {
2474 optional<RandomDistribution> distribution;
2475 attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
2476 &distribution};
2477 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
2478 !ParseAttributes(attrs, allow_attributes)) {
2479 return nullptr;
2480 }
2481 return builder->AddInstruction(
2482 HloInstruction::CreateRng(*shape, *distribution, operands));
2483 }
2484 case HloOpcode::kRngGetAndUpdateState: {
2485 optional<int64_t> delta;
2486 attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
2487 if ((!preset_operands &&
2488 !ParseOperands(&operands, builder, /*expected_size=*/0)) ||
2489 !ParseAttributes(attrs, allow_attributes)) {
2490 return nullptr;
2491 }
2492 return builder->AddInstruction(
2493 HloInstruction::CreateRngGetAndUpdateState(*shape, *delta));
2494 }
2495 case HloOpcode::kRngBitGenerator: {
2496 optional<RandomAlgorithm> algorithm;
2497 attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
2498 &algorithm};
2499 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
2500 !ParseAttributes(attrs, allow_attributes)) {
2501 return nullptr;
2502 }
2503 return builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
2504 *shape, operands[0], *algorithm));
2505 }
2506 case HloOpcode::kReducePrecision: {
2507 optional<int64_t> exponent_bits;
2508 optional<int64_t> mantissa_bits;
2509 attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
2510 &exponent_bits};
2511 attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
2512 &mantissa_bits};
2513 if ((!preset_operands &&
2514 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
2515 !ParseAttributes(attrs, allow_attributes)) {
2516 return nullptr;
2517 }
2518 return builder->AddInstruction(HloInstruction::CreateReducePrecision(
2519 *shape, operands[0], static_cast<int>(*exponent_bits),
2520 static_cast<int>(*mantissa_bits)));
2521 }
2522 case HloOpcode::kConditional: {
2523 optional<HloComputation*> true_computation;
2524 optional<HloComputation*> false_computation;
2525 optional<std::vector<HloComputation*>> branch_computations;
2526 if (!preset_operands && !ParseOperands(&operands, builder)) {
2527 return nullptr;
2528 }
2529 if (!ShapeUtil::IsScalar(operands[0]->shape())) {
2530 TokenError("The first operand must be a scalar");
2531 return nullptr;
2532 }
2533 const bool branch_index_is_bool =
2534 operands[0]->shape().element_type() == PRED;
2535 if (branch_index_is_bool) {
2536 attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
2537 &true_computation};
2538 attrs["false_computation"] = {
2539 /*required=*/true, AttrTy::kHloComputation, &false_computation};
2540 } else {
2541 if (operands[0]->shape().element_type() != S32) {
2542 TokenError("The first operand must be a scalar of PRED or S32");
2543 return nullptr;
2544 }
2545 attrs["branch_computations"] = {/*required=*/true,
2546 AttrTy::kBracedHloComputationList,
2547 &branch_computations};
2548 }
2549 if (!ParseAttributes(attrs, allow_attributes)) {
2550 return nullptr;
2551 }
2552 if (branch_index_is_bool) {
2553 branch_computations.emplace({*true_computation, *false_computation});
2554 }
2555 if (branch_computations->empty() ||
2556 operands.size() != branch_computations->size() + 1) {
2557 return nullptr;
2558 }
2559 if (!maybe_infer_shape([&] {
2560 absl::InlinedVector<ProgramShape, 2> branch_computation_shapes;
2561 branch_computation_shapes.reserve(branch_computations->size());
2562 for (auto* computation : *branch_computations) {
2563 branch_computation_shapes.push_back(
2564 computation->ComputeProgramShape());
2565 }
2566 absl::InlinedVector<Shape, 2> branch_operand_shapes;
2567 branch_operand_shapes.reserve(operands.size() - 1);
2568 for (int i = 1; i < operands.size(); ++i) {
2569 branch_operand_shapes.push_back(operands[i]->shape());
2570 }
2571 return ShapeInference::InferConditionalShape(
2572 operands[0]->shape(), branch_computation_shapes,
2573 branch_operand_shapes);
2574 })) {
2575 return nullptr;
2576 }
2577 return builder->AddInstruction(HloInstruction::CreateConditional(
2578 *shape, /*branch_index=*/operands[0],
2579 absl::MakeSpan(*branch_computations),
2580 absl::MakeSpan(operands).subspan(1)));
2581 }
2582 case HloOpcode::kCustomCall: {
2583 optional<std::string> custom_call_target;
2584 optional<Window> window;
2585 optional<ConvolutionDimensionNumbers> dnums;
2586 optional<int64_t> feature_group_count;
2587 optional<int64_t> batch_group_count;
2588 optional<std::vector<Shape>> operand_layout_constraints;
2589 optional<bool> custom_call_has_side_effect;
2590 optional<HloComputation*> to_apply;
2591 optional<
2592 std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>>
2593 output_to_operand_aliasing;
2594 optional<PaddingType> padding_type;
2595 optional<std::vector<HloComputation*>> called_computations;
2596 optional<CustomCallSchedule> custom_call_schedule;
2597 optional<CustomCallApiVersion> api_version;
2598 attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
2599 &custom_call_target};
2600 attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
2601 attrs["dim_labels"] = {/*required=*/false,
2602 AttrTy::kConvolutionDimensionNumbers, &dnums};
2603 attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
2604 &feature_group_count};
2605 attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
2606 &batch_group_count};
2607 attrs["operand_layout_constraints"] = {
2608 /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
2609 attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
2610 &custom_call_has_side_effect};
2611 attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
2612 &to_apply};
2613 attrs["called_computations"] = {/*required=*/false,
2614 AttrTy::kBracedHloComputationList,
2615 &called_computations};
2616 attrs["output_to_operand_aliasing"] = {/*required=*/false,
2617 AttrTy::kInstructionAliasing,
2618 &output_to_operand_aliasing};
2619
2620 attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
2621 &padding_type};
2622
2623 optional<Literal> literal;
2624 attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
2625 optional<std::vector<PrecisionConfig::Precision>> operand_precision;
2626 attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
2627 &operand_precision};
2628 HloInstruction* instruction;
2629 if (called_computations.has_value() && to_apply.has_value()) {
2630 TokenError(
2631 "A single instruction can't have both to_apply and "
2632 "calls field");
2633 return nullptr;
2634 }
2635 attrs["schedule"] = {/*required=*/false, AttrTy::kCustomCallSchedule,
2636 &custom_call_schedule};
2637 attrs["api_version"] = {/*required=*/false, AttrTy::kCustomCallApiVersion,
2638 &api_version};
2639 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
2640 !ParseAttributes(attrs, allow_attributes)) {
2641 return nullptr;
2642 }
2643
2644 if (api_version.has_value() &&
2645 *api_version == CustomCallApiVersion::API_VERSION_UNSPECIFIED) {
2646 TokenError(StrCat("Invalid API version: ",
2647 CustomCallApiVersion_Name(*api_version)));
2648 return nullptr;
2649 }
2650 if (operand_layout_constraints.has_value()) {
2651 if (!LayoutUtil::HasLayout(*shape)) {
2652 TokenError("Layout must be set on layout-constrained custom call");
2653 return nullptr;
2654 }
2655 if (operands.size() != operand_layout_constraints->size()) {
2656 TokenError(StrCat("Expected ", operands.size(),
2657 " operand layout constraints, ",
2658 operand_layout_constraints->size(), " given"));
2659 return nullptr;
2660 }
2661 for (int64_t i = 0; i < operands.size(); ++i) {
2662 const Shape& operand_shape_with_layout =
2663 (*operand_layout_constraints)[i];
2664 if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
2665 TokenError(StrCat(
2666 "Operand layout constraint shape ",
2667 ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
2668 " for operand ", i, " does not have a layout"));
2669 return nullptr;
2670 }
2671 if (!ShapeUtil::Compatible(operand_shape_with_layout,
2672 operands[i]->shape())) {
2673 TokenError(StrCat(
2674 "Operand layout constraint shape ",
2675 ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
2676 " for operand ", i, " is not compatible with operand shape ",
2677 ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
2678 return nullptr;
2679 }
2680 }
2681 instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
2682 *shape, operands, *custom_call_target, *operand_layout_constraints,
2683 ""));
2684 } else {
2685 if (to_apply.has_value()) {
2686 instruction =
2687 builder->AddInstruction(HloInstruction::CreateCustomCall(
2688 *shape, operands, *to_apply, *custom_call_target, ""));
2689 } else if (called_computations.has_value()) {
2690 instruction =
2691 builder->AddInstruction(HloInstruction::CreateCustomCall(
2692 *shape, operands, *called_computations, *custom_call_target,
2693 ""));
2694 } else {
2695 instruction =
2696 builder->AddInstruction(HloInstruction::CreateCustomCall(
2697 *shape, operands, *custom_call_target, ""));
2698 }
2699 }
2700 auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
2701 if (window.has_value()) {
2702 custom_call_instr->set_window(*window);
2703 }
2704 if (dnums.has_value()) {
2705 custom_call_instr->set_convolution_dimension_numbers(*dnums);
2706 }
2707 if (feature_group_count.has_value()) {
2708 custom_call_instr->set_feature_group_count(*feature_group_count);
2709 }
2710 if (batch_group_count.has_value()) {
2711 custom_call_instr->set_batch_group_count(*batch_group_count);
2712 }
2713 if (padding_type.has_value()) {
2714 custom_call_instr->set_padding_type(*padding_type);
2715 }
2716 if (custom_call_has_side_effect.has_value()) {
2717 custom_call_instr->set_custom_call_has_side_effect(
2718 *custom_call_has_side_effect);
2719 }
2720 if (custom_call_schedule.has_value()) {
2721 custom_call_instr->set_custom_call_schedule(*custom_call_schedule);
2722 }
2723 if (api_version.has_value()) {
2724 custom_call_instr->set_api_version(*api_version);
2725 }
2726 if (output_to_operand_aliasing.has_value()) {
2727 custom_call_instr->set_output_to_operand_aliasing(
2728 std::move(*output_to_operand_aliasing));
2729 }
2730 if (literal.has_value()) {
2731 custom_call_instr->set_literal(std::move(*literal));
2732 }
2733 PrecisionConfig precision_config;
2734 if (operand_precision) {
2735 *precision_config.mutable_operand_precision() = {
2736 operand_precision->begin(), operand_precision->end()};
2737 } else {
2738 precision_config.mutable_operand_precision()->Resize(
2739 operands.size(), PrecisionConfig::DEFAULT);
2740 }
2741 *custom_call_instr->mutable_precision_config() = precision_config;
2742 return instruction;
2743 }
2744 case HloOpcode::kDot: {
2745 optional<std::vector<int64_t>> lhs_contracting_dims;
2746 attrs["lhs_contracting_dims"] = {
2747 /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
2748 optional<std::vector<int64_t>> rhs_contracting_dims;
2749 attrs["rhs_contracting_dims"] = {
2750 /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
2751 optional<std::vector<int64_t>> lhs_batch_dims;
2752 attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
2753 &lhs_batch_dims};
2754 optional<std::vector<int64_t>> rhs_batch_dims;
2755 attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
2756 &rhs_batch_dims};
2757 optional<std::vector<PrecisionConfig::Precision>> operand_precision;
2758 attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
2759 &operand_precision};
2760
2761 if ((!preset_operands &&
2762 !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
2763 !ParseAttributes(attrs, allow_attributes)) {
2764 return nullptr;
2765 }
2766
2767 DotDimensionNumbers dnum;
2768 if (lhs_contracting_dims) {
2769 *dnum.mutable_lhs_contracting_dimensions() = {
2770 lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
2771 }
2772 if (rhs_contracting_dims) {
2773 *dnum.mutable_rhs_contracting_dimensions() = {
2774 rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
2775 }
2776 if (lhs_batch_dims) {
2777 *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
2778 lhs_batch_dims->end()};
2779 }
2780 if (rhs_batch_dims) {
2781 *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
2782 rhs_batch_dims->end()};
2783 }
2784
2785 PrecisionConfig precision_config;
2786 if (operand_precision) {
2787 *precision_config.mutable_operand_precision() = {
2788 operand_precision->begin(), operand_precision->end()};
2789 } else {
2790 precision_config.mutable_operand_precision()->Resize(
2791 operands.size(), PrecisionConfig::DEFAULT);
2792 }
2793 if (!maybe_infer_shape([&] {
2794 return ShapeInference::InferDotOpShape(
2795 operands[0]->shape(), operands[1]->shape(), dnum,
2796 /*preferred_element_type=*/std::nullopt);
2797 })) {
2798 return nullptr;
2799 }
2800 return builder->AddInstruction(HloInstruction::CreateDot(
2801 *shape, operands[0], operands[1], dnum, precision_config));
2802 }
2803 case HloOpcode::kGather: {
2804 optional<std::vector<int64_t>> offset_dims;
2805 attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
2806 &offset_dims};
2807 optional<std::vector<int64_t>> collapsed_slice_dims;
2808 attrs["collapsed_slice_dims"] = {
2809 /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
2810 optional<std::vector<int64_t>> start_index_map;
2811 attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
2812 &start_index_map};
2813 optional<int64_t> index_vector_dim;
2814 attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2815 &index_vector_dim};
2816 optional<std::vector<int64_t>> slice_sizes;
2817 attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
2818 &slice_sizes};
2819 optional<bool> indices_are_sorted = false;
2820 attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2821 &indices_are_sorted};
2822
2823 if ((!preset_operands &&
2824 !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
2825 !ParseAttributes(attrs, allow_attributes)) {
2826 return nullptr;
2827 }
2828
2829 GatherDimensionNumbers dim_numbers =
2830 HloGatherInstruction::MakeGatherDimNumbers(
2831 /*offset_dims=*/*offset_dims,
2832 /*collapsed_slice_dims=*/*collapsed_slice_dims,
2833 /*start_index_map=*/*start_index_map,
2834 /*index_vector_dim=*/*index_vector_dim);
2835 if (!maybe_infer_shape([&] {
2836 return ShapeInference::InferGatherShape(operands[0]->shape(),
2837 operands[1]->shape(),
2838 dim_numbers, *slice_sizes);
2839 })) {
2840 return nullptr;
2841 }
2842 return builder->AddInstruction(HloInstruction::CreateGather(
2843 *shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
2844 dim_numbers, *slice_sizes, indices_are_sorted.value()));
2845 }
2846 case HloOpcode::kScatter: {
2847 optional<std::vector<int64_t>> update_window_dims;
2848 attrs["update_window_dims"] = {
2849 /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
2850 optional<std::vector<int64_t>> inserted_window_dims;
2851 attrs["inserted_window_dims"] = {
2852 /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
2853 optional<std::vector<int64_t>> scatter_dims_to_operand_dims;
2854 attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
2855 AttrTy::kBracedInt64List,
2856 &scatter_dims_to_operand_dims};
2857 optional<int64_t> index_vector_dim;
2858 attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2859 &index_vector_dim};
2860
2861 optional<HloComputation*> update_computation;
2862 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
2863 &update_computation};
2864 optional<bool> indices_are_sorted = false;
2865 attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2866 &indices_are_sorted};
2867 optional<bool> unique_indices = false;
2868 attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
2869 &unique_indices};
2870
2871 if ((!preset_operands && !ParseOperands(&operands, builder)) ||
2872 !ParseAttributes(attrs, allow_attributes)) {
2873 return nullptr;
2874 }
2875
2876 if (operands.size() % 2 == 0) {
2877 TokenError(StrCat("expects an odd number of operands, but has ",
2878 operands.size(), " operands"));
2879 return nullptr;
2880 }
2881
2882 ScatterDimensionNumbers dim_numbers =
2883 HloScatterInstruction::MakeScatterDimNumbers(
2884 /*update_window_dims=*/*update_window_dims,
2885 /*inserted_window_dims=*/*inserted_window_dims,
2886 /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
2887 /*index_vector_dim=*/*index_vector_dim);
2888
2889 if (!maybe_infer_shape([&] {
2890 absl::InlinedVector<const Shape*, 3> arg_shapes;
2891 arg_shapes.reserve(operands.size());
2892 for (auto* operand : operands) {
2893 arg_shapes.push_back(&operand->shape());
2894 }
2895 return ShapeInference::InferScatterShape(
2896 arg_shapes, update_computation.value()->ComputeProgramShape(),
2897 dim_numbers);
2898 })) {
2899 return nullptr;
2900 }
2901 auto input_count = operands.size() / 2;
2902 auto operand_span = absl::MakeConstSpan(operands);
2903 return builder->AddInstruction(HloInstruction::CreateScatter(
2904 *shape, operand_span.first(input_count), operands[input_count],
2905 operand_span.last(input_count), *update_computation, dim_numbers,
2906 indices_are_sorted.value(), unique_indices.value()));
2907 }
2908 case HloOpcode::kDomain: {
2909 DomainData domain;
2910 attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
2911 if ((!preset_operands &&
2912 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
2913 !ParseAttributes(attrs, allow_attributes)) {
2914 return nullptr;
2915 }
2916 if (!maybe_infer_shape([&] {
2917 return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
2918 })) {
2919 return nullptr;
2920 }
2921 return builder->AddInstruction(HloInstruction::CreateDomain(
2922 *shape, operands[0], std::move(domain.exit_metadata),
2923 std::move(domain.entry_metadata)));
2924 }
2925 case HloOpcode::kGetDimensionSize: {
2926 optional<std::vector<int64_t>> dimensions;
2927 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2928 &dimensions};
2929 if ((!preset_operands &&
2930 !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
2931 !ParseAttributes(attrs, allow_attributes)) {
2932 return nullptr;
2933 }
2934 if (!maybe_infer_shape([&] {
2935 return ShapeInference::InferGetDimensionSizeShape(
2936 operands[0]->shape(), dimensions->at(0));
2937 })) {
2938 return nullptr;
2939 }
2940 return builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
2941 *shape, operands[0], (*dimensions)[0]));
2942 }
2943 case HloOpcode::kSetDimensionSize: {
2944 optional<std::vector<int64_t>> dimensions;
2945 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2946 &dimensions};
2947 if ((!preset_operands &&
2948 !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
2949 !ParseAttributes(attrs, allow_attributes)) {
2950 return nullptr;
2951 }
2952 if (!maybe_infer_shape([&] {
2953 return ShapeInference::InferSetDimensionSizeShape(
2954 operands[0]->shape(), operands[1]->shape(), dimensions->at(0));
2955 })) {
2956 return nullptr;
2957 }
2958 return builder->AddInstruction(HloInstruction::CreateSetDimensionSize(
2959 *shape, operands[0], operands[1], (*dimensions)[0]));
2960 }
2961 }
2962 return nullptr;
2963 } // NOLINT(readability/fn_size)
2964
// sharding ::= '{' (single_sharding | tuple_sharding) '}'
//
// tuple_sharding ::= /*empty*/ | single_sharding (',' single_sharding)*
ParseSharding(OpSharding * sharding)2968 bool HloParserImpl::ParseSharding(OpSharding* sharding) {
2969 // A single sharding starts with '{' and is not followed by '{'.
2970 // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
2971 // an empty tuple.
2972 if (!ParseToken(TokKind::kLbrace,
2973 "expected '{' to start sharding attribute")) {
2974 return false;
2975 }
2976
2977 if (lexer_.GetKind() != TokKind::kLbrace &&
2978 lexer_.GetKind() != TokKind::kRbrace) {
2979 return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
2980 }
2981
2982 // Tuple sharding.
2983 // Allow empty tuple shardings.
2984 if (lexer_.GetKind() != TokKind::kRbrace) {
2985 do {
2986 if (!ParseSingleSharding(sharding->add_tuple_shardings(),
2987 /*lbrace_pre_lexed=*/false)) {
2988 return false;
2989 }
2990 } while (EatIfPresent(TokKind::kComma));
2991 }
2992 sharding->set_type(OpSharding::TUPLE);
2993
2994 return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
2995 }
2996
2997 // frontend_attributes ::= '{' attributes '}'
2998 // attributes
2999 // ::= /*empty*/
3000 // ::= attribute '=' value (',' attribute '=' value)*
ParseFrontendAttributes(FrontendAttributes * frontend_attributes)3001 bool HloParserImpl::ParseFrontendAttributes(
3002 FrontendAttributes* frontend_attributes) {
3003 CHECK(frontend_attributes != nullptr);
3004 if (!ParseToken(TokKind::kLbrace,
3005 "expected '{' to start frontend attributes")) {
3006 return false;
3007 }
3008 if (lexer_.GetKind() == TokKind::kRbrace) {
3009 // empty
3010 } else {
3011 do {
3012 std::string attribute;
3013 if (!ParseAttributeName(&attribute)) {
3014 return false;
3015 }
3016 if (lexer_.GetKind() != TokKind::kString) {
3017 return false;
3018 }
3019 (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
3020 lexer_.Lex();
3021 } while (EatIfPresent(TokKind::kComma));
3022 }
3023 return ParseToken(TokKind::kRbrace,
3024 "expects '}' at the end of frontend attributes");
3025 }
3026
3027 // ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
3028 // ('devices=' ('[' dims ']')* device_list)?
3029 // ('metadata=' metadata)* '}'
3030 //
// dims ::= int_list
// device_list ::= int_list
3032 // metadata ::= single_metadata |
3033 // ('{' [single_metadata (',' single_metadata)*] '}')
3034 // last_tile_dims ::= sharding_type_list
// Parses a single (non-tuple) sharding and writes it into `sharding`. When
// `lbrace_pre_lexed` is true, the caller has already consumed the opening
// '{' of this sharding.
bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
                                        bool lbrace_pre_lexed) {
  if (!lbrace_pre_lexed &&
      !ParseToken(TokKind::kLbrace,
                  "expected '{' to start sharding attribute")) {
    return false;
  }

  LocTy loc = lexer_.GetLoc();
  bool maximal = false;
  bool replicated = false;
  bool manual = false;
  bool last_tile_dim_replicate = false;
  bool last_tile_dims = false;
  std::vector<int64_t> devices;
  std::vector<int64_t> tile_assignment_dimensions;
  std::vector<OpSharding::Type> subgroup_types;
  // First accumulate every token up to the closing '}' into the locals
  // above; consistency is validated after the loop.
  while (lexer_.GetKind() != TokKind::kRbrace) {
    switch (lexer_.GetKind()) {
      case TokKind::kw_maximal:
        maximal = true;
        lexer_.Lex();
        break;
      case TokKind::kw_replicated:
        replicated = true;
        lexer_.Lex();
        break;
      case TokKind::kw_manual:
        manual = true;
        lexer_.Lex();
        break;
      case TokKind::kAttributeName: {
        if (lexer_.GetStrVal() == "device") {
          // device=<int>: a single device for a maximal sharding.
          if (lexer_.Lex() != TokKind::kInt) {
            return TokenError("device= attribute must be an integer");
          }
          devices = {lexer_.GetInt64Val()};
          lexer_.Lex();
        } else if (lexer_.GetStrVal() == "devices") {
          // devices=[dims]device_list: tile assignment dimensions in square
          // brackets, then the flat list of device ids.
          lexer_.Lex();
          if (!ParseToken(TokKind::kLsquare,
                          "expected '[' to start sharding devices shape")) {
            return false;
          }

          do {
            int64_t dim;
            if (!ParseInt64(&dim)) {
              return false;
            }
            tile_assignment_dimensions.push_back(dim);
          } while (EatIfPresent(TokKind::kComma));

          if (!ParseToken(TokKind::kRsquare,
                          "expected ']' to start sharding devices shape")) {
            return false;
          }
          do {
            int64_t device;
            if (!ParseInt64(&device)) {
              return false;
            }
            devices.push_back(device);
          } while (EatIfPresent(TokKind::kComma));
        } else if (lexer_.GetStrVal() == "metadata") {
          // metadata=<single metadata or braced list of metadata>.
          lexer_.Lex();
          if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) {
            return false;
          }
        } else if (lexer_.GetStrVal() == "last_tile_dims") {
          // last_tile_dims=<sharding type list>: subgroup types for the
          // trailing tile-assignment dimensions.
          last_tile_dims = true;
          lexer_.Lex();
          if (!ParseListShardingType(&subgroup_types)) {
            return false;
          }
        } else {
          return TokenError(
              "unknown attribute in sharding: expected device=, devices= "
              "metadata= or last_tile_dims= ");
        }
        break;
      }
      case TokKind::kw_last_tile_dim_replicate:
        last_tile_dim_replicate = true;
        lexer_.Lex();
        break;
      case TokKind::kRbrace:
        break;
      default:
        return TokenError("unexpected token");
    }
  }

  // Translate the accumulated flags/lists into the OpSharding proto,
  // rejecting inconsistent combinations.
  if (replicated) {
    if (!devices.empty()) {
      return Error(loc,
                   "replicated shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::REPLICATED);
  } else if (maximal) {
    if (devices.size() != 1) {
      return Error(loc,
                   "maximal shardings should have exactly one device assigned");
    }
    sharding->set_type(OpSharding::MAXIMAL);
    sharding->add_tile_assignment_devices(devices[0]);
  } else if (manual) {
    if (!devices.empty()) {
      return Error(loc,
                   "manual shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::MANUAL);
  } else {
    // Tiled (OTHER) sharding: requires both tile dimensions and more than
    // one device.
    if (devices.size() <= 1) {
      return Error(
          loc, "non-maximal shardings must have more than one device assigned");
    }
    if (tile_assignment_dimensions.empty()) {
      return Error(
          loc,
          "non-maximal shardings must have a tile assignment list including "
          "dimensions");
    }
    sharding->set_type(OpSharding::OTHER);
    for (int64_t dim : tile_assignment_dimensions) {
      sharding->add_tile_assignment_dimensions(dim);
    }
    for (int64_t device : devices) {
      sharding->add_tile_assignment_devices(device);
    }

    if (last_tile_dims) {
      for (OpSharding::Type type : subgroup_types) {
        sharding->add_last_tile_dims(type);
      }
    } else {
      sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate);
    }
  }

  // Consume the closing '}'.
  lexer_.Lex();
  return true;
}
3178
3179 // parameter_replication ::=
3180 // '{' ('true' | 'false')* (',' ('true' | 'false'))* '}'
ParseParameterReplication(ParameterReplication * parameter_replication)3181 bool HloParserImpl::ParseParameterReplication(
3182 ParameterReplication* parameter_replication) {
3183 if (!ParseToken(TokKind::kLbrace,
3184 "expected '{' to start parameter_replication attribute")) {
3185 return false;
3186 }
3187
3188 if (lexer_.GetKind() != TokKind::kRbrace) {
3189 do {
3190 if (lexer_.GetKind() == TokKind::kw_true) {
3191 parameter_replication->add_replicated_at_leaf_buffers(true);
3192 } else if (lexer_.GetKind() == TokKind::kw_false) {
3193 parameter_replication->add_replicated_at_leaf_buffers(false);
3194 } else {
3195 return false;
3196 }
3197 lexer_.Lex();
3198 } while (EatIfPresent(TokKind::kComma));
3199 }
3200
3201 return ParseToken(TokKind::kRbrace,
3202 "expected '}' to end parameter_replication attribute");
3203 }
3204
3205 // replica_groups ::='{' int64_tlist_elements '}'
3206 // int64_tlist_elements
3207 // ::= /*empty*/
3208 // ::= int64_tlist (',' int64_tlist)*
3209 // int64_tlist ::= '{' int64_elements '}'
3210 // int64_elements
3211 // ::= /*empty*/
3212 // ::= int64_val (',' int64_val)*
ParseReplicaGroupsOnly(std::vector<ReplicaGroup> * replica_groups)3213 bool HloParserImpl::ParseReplicaGroupsOnly(
3214 std::vector<ReplicaGroup>* replica_groups) {
3215 std::vector<std::vector<int64_t>> result;
3216 if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3217 &result)) {
3218 return false;
3219 }
3220 *replica_groups = CreateReplicaGroups(result);
3221 return true;
3222 }
3223
3224 // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
3225 // 'exit=' exit_sharding '}'
ParseDomain(DomainData * domain)3226 bool HloParserImpl::ParseDomain(DomainData* domain) {
3227 absl::flat_hash_map<std::string, AttrConfig> attrs;
3228 optional<std::string> kind;
3229 optional<OpSharding> entry_sharding;
3230 optional<OpSharding> exit_sharding;
3231 attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
3232 attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
3233 attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
3234 if (!ParseSubAttributes(attrs)) {
3235 return false;
3236 }
3237 if (*kind == ShardingMetadata::KindName()) {
3238 auto entry_sharding_ptr = std::make_unique<HloSharding>(
3239 HloSharding::FromProto(*entry_sharding).ValueOrDie());
3240 auto exit_sharding_ptr = std::make_unique<HloSharding>(
3241 HloSharding::FromProto(*exit_sharding).ValueOrDie());
3242 domain->entry_metadata =
3243 std::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
3244 domain->exit_metadata =
3245 std::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
3246 } else {
3247 return TokenError(StrCat("unsupported domain kind: ", *kind));
3248 }
3249 return true;
3250 }
3251
3252 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)3253 bool HloParserImpl::ParseInstructionNames(
3254 std::vector<HloInstruction*>* instructions) {
3255 if (!ParseToken(TokKind::kLbrace,
3256 "expects '{' at the beginning of instruction name list")) {
3257 return false;
3258 }
3259 LocTy loc = lexer_.GetLoc();
3260 do {
3261 std::string name;
3262 if (!ParseName(&name)) {
3263 return Error(loc, "expects a instruction name");
3264 }
3265 std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
3266 if (!instr) {
3267 return TokenError(StrFormat("instruction '%s' is not defined", name));
3268 }
3269 instructions->push_back(instr->first);
3270 } while (EatIfPresent(TokKind::kComma));
3271
3272 return ParseToken(TokKind::kRbrace,
3273 "expects '}' at the end of instruction name list");
3274 }
3275
// Writes the parsed integer `value` into `literal` at linear element index
// `index`, dispatching to the SetValueInLiteralHelper instantiation that
// matches the literal's integral element type. `loc` is used for error
// reporting by the helper. Dies on non-integral element types.
bool HloParserImpl::SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
                                      Literal* literal) {
  const Shape& shape = literal->shape();
  switch (shape.element_type()) {
    case S8:
      return SetValueInLiteralHelper<int8_t>(loc, value, index, literal);
    case S16:
      return SetValueInLiteralHelper<int16_t>(loc, value, index, literal);
    case S32:
      return SetValueInLiteralHelper<int32_t>(loc, value, index, literal);
    case S64:
      return SetValueInLiteralHelper<int64_t>(loc, value, index, literal);
    case U8:
      return SetValueInLiteralHelper<uint8_t>(loc, value, index, literal);
    case U16:
      return SetValueInLiteralHelper<uint16_t>(loc, value, index, literal);
    case U32:
      return SetValueInLiteralHelper<uint32_t>(loc, value, index, literal);
    case U64:
      return SetValueInLiteralHelper<uint64_t>(loc, value, index, literal);
    case PRED:
      // Bool type literals with rank >= 1 are printed in 0s and 1s.
      return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
                                           literal);
    default:
      LOG(FATAL) << "unknown integral primitive type "
                 << PrimitiveType_Name(shape.element_type());
  }
}
3305
// Writes the parsed floating-point `value` into `literal` at linear element
// index `index`, dispatching to the SetValueInLiteralHelper instantiation
// that matches the literal's floating-point element type. Dies on
// non-floating-point element types.
bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64_t index,
                                      Literal* literal) {
  const Shape& shape = literal->shape();
  switch (shape.element_type()) {
    case F16:
      return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
    case BF16:
      return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
                                                           literal);
    case F32:
      return SetValueInLiteralHelper<float>(loc, value, index, literal);
    case F64:
      return SetValueInLiteralHelper<double>(loc, value, index, literal);
    default:
      LOG(FATAL) << "unknown floating point primitive type "
                 << PrimitiveType_Name(shape.element_type());
  }
}
3324
SetValueInLiteral(LocTy loc,bool value,int64_t index,Literal * literal)3325 bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64_t index,
3326 Literal* literal) {
3327 const Shape& shape = literal->shape();
3328 switch (shape.element_type()) {
3329 case PRED:
3330 return SetValueInLiteralHelper<bool>(loc, value, index, literal);
3331 default:
3332 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3333 << " is not PRED type";
3334 }
3335 }
3336
SetValueInLiteral(LocTy loc,std::complex<double> value,int64_t index,Literal * literal)3337 bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
3338 int64_t index, Literal* literal) {
3339 const Shape& shape = literal->shape();
3340 switch (shape.element_type()) {
3341 case C64:
3342 return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
3343 literal);
3344 case C128:
3345 return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
3346 literal);
3347 default:
3348 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3349 << " is not a complex type";
3350 }
3351 }
3352
// Renders a literal element value for use in error messages. The generic
// overload relies on StrCat's default formatting.
template <typename T>
std::string StringifyValue(T val) {
  return StrCat(val);
}
3357 template <>
StringifyValue(std::complex<double> val)3358 std::string StringifyValue(std::complex<double> val) {
3359 return StrFormat("(%f, %f)", std::real(val), std::imag(val));
3360 }
3361
// Evaluates to V when T == U; otherwise SFINAEs away. Used below to select
// between overloads based on the exact scalar type.
template <typename T, typename U, typename V>
using EnableIfSameWithType = std::enable_if_t<std::is_same<T, U>::value, V>;
3365
// NaN payloads are meaningless for booleans; always report "no payload".
template <class T, EnableIfSameWithType<T, bool, bool> = false>
uint64_t GetNanPayload(T val) {
  return 0;
}
3370
// NaN payloads are meaningless for integers; always report "no payload".
template <class T, EnableIfSameWithType<T, int64_t, bool> = false>
uint64_t GetNanPayload(T val) {
  return 0;
}
3375
3376 template <class T, EnableIfSameWithType<T, double, bool> = false>
GetNanPayload(T val)3377 uint64_t GetNanPayload(T val) {
3378 auto rep = absl::bit_cast<uint64_t>(val);
3379 if (auto payload = rep & NanPayloadBitMask<double>()) {
3380 return payload;
3381 }
3382 return QuietNanWithoutPayload<double>();
3383 }
3384
// Assembles a literal value from real/imaginary components. This overload is
// selected when the native type is itself the (real) component type; the
// imaginary component is ignored.
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, LiteralComponentT, LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
  return real;
}
3390
// Overload selected when the native type is std::complex of the component
// type; both components are used.
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, std::complex<LiteralComponentT>,
                     LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
  return LiteralNativeT(real, imag);
}
3397
// ComponentType<T>::Type is the scalar component type of T: T itself for real
// types (see the std::complex specialization below).
template <typename T>
struct ComponentType {
  using Type = T;
};
3402
// For std::complex<T>, the component type is the underlying real type T.
template <typename T>
struct ComponentType<std::complex<T>> {
  using Type = T;
};
3407
// Returns the real part of a scalar; a non-complex value is its own real
// part.
template <typename T>
T GetReal(T value) {
  T real_component = value;
  return real_component;
}
3412
// Returns the real component of a complex value.
template <typename T>
T GetReal(std::complex<T> value) {
  return std::real(value);
}
3417
// Returns the imaginary part of a scalar; non-complex values have none, so
// this is always zero.
template <typename T>
T GetImag(T value) {
  return T{};
}
3422
// Returns the imaginary component of a complex value.
template <typename T>
T GetImag(std::complex<T> value) {
  return std::imag(value);
}
3427
3428 template <typename LiteralNativeT, typename ParsedElemT>
SetValueInLiteralHelper(LocTy loc,ParsedElemT value,int64_t index,Literal * literal)3429 bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
3430 int64_t index, Literal* literal) {
3431 if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
3432 return false;
3433 }
3434
3435 // Check that the index is in range and assign into the literal
3436 if (index >= ShapeUtil::ElementsIn(literal->shape())) {
3437 return Error(loc, StrCat("tries to set value ", StringifyValue(value),
3438 " to a literal in shape ",
3439 ShapeUtil::HumanString(literal->shape()),
3440 " at linear index ", index,
3441 ", but the index is out of range"));
3442 }
3443 using ParsedElemComponentT = typename ComponentType<ParsedElemT>::Type;
3444 using LiteralNativeComponentT = typename ComponentType<LiteralNativeT>::Type;
3445 const auto handle_nan = [this, literal, index, loc](
3446 ParsedElemComponentT parsed_value_component,
3447 LiteralNativeComponentT*
3448 literal_value_component) {
3449 if (!std::isnan(static_cast<double>(parsed_value_component))) {
3450 return true;
3451 }
3452 auto nan_payload = GetNanPayload(parsed_value_component);
3453 if (nan_payload == QuietNanWithoutPayload<double>()) {
3454 nan_payload = QuietNanWithoutPayload<LiteralNativeComponentT>();
3455 }
3456 const auto kLargestPayload = NanPayloadBitMask<LiteralNativeComponentT>();
3457 if (nan_payload > kLargestPayload) {
3458 return Error(
3459 loc,
3460 StrCat("tries to set NaN payload 0x", absl::Hex(nan_payload),
3461 " to a literal in shape ",
3462 ShapeUtil::HumanString(literal->shape()), " at linear index ",
3463 index, ", but the NaN payload is out of range (0x",
3464 absl::Hex(kLargestPayload), ")"));
3465 }
3466 *literal_value_component = NanWithSignAndPayload<LiteralNativeComponentT>(
3467 /*sign=*/std::signbit(static_cast<double>(parsed_value_component)),
3468 /*nan_payload=*/nan_payload);
3469 return true;
3470 };
3471 const ParsedElemComponentT parsed_real_value = GetReal(value);
3472 auto literal_real_value =
3473 static_cast<LiteralNativeComponentT>(parsed_real_value);
3474 if (std::is_floating_point<ParsedElemT>::value ||
3475 std::is_same<ParsedElemT, std::complex<double>>::value) {
3476 if (!handle_nan(parsed_real_value, &literal_real_value)) {
3477 return false;
3478 }
3479 }
3480 const ParsedElemComponentT parsed_imag_value = GetImag(value);
3481 auto literal_imag_value =
3482 static_cast<LiteralNativeComponentT>(parsed_imag_value);
3483 if (std::is_same<ParsedElemT, std::complex<double>>::value) {
3484 if (!handle_nan(parsed_real_value, &literal_imag_value)) {
3485 return false;
3486 }
3487 }
3488 literal->data<LiteralNativeT>().at(index) =
3489 LiteralNativeFromRealImag<LiteralNativeT>(literal_real_value,
3490 literal_imag_value);
3491 return true;
3492 }
3493
3494 // Similar to ParseLiteral(Literal* literal, const Shape& shape), but parse the
3495 // shape instead of accepting one as argument.
ParseLiteral(Literal * literal)3496 bool HloParserImpl::ParseLiteral(Literal* literal) {
3497 if (lexer_.GetKind() == TokKind::kLparen) {
3498 // Consume Lparen
3499 lexer_.Lex();
3500 std::vector<Literal> elements;
3501 while (lexer_.GetKind() != TokKind::kRparen) {
3502 Literal element;
3503 if (!ParseLiteral(&element)) {
3504 return TokenError("Fails when parsing tuple element");
3505 }
3506 elements.emplace_back(std::move(element));
3507 if (lexer_.GetKind() != TokKind::kRparen) {
3508 ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3509 }
3510 }
3511
3512 *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3513 // Consume Rparen
3514 return ParseToken(TokKind::kRparen, "expects ')' to close a tuple literal");
3515 }
3516 Shape literal_shape;
3517 if (!ParseShape(&literal_shape)) {
3518 return false;
3519 }
3520 return ParseLiteral(literal, literal_shape);
3521 }
3522
3523 // literal
3524 // ::= tuple
3525 // ::= non_tuple
ParseLiteral(Literal * literal,const Shape & shape)3526 bool HloParserImpl::ParseLiteral(Literal* literal, const Shape& shape) {
3527 return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
3528 : ParseNonTupleLiteral(literal, shape);
3529 }
3530
3531 // tuple
3532 // ::= shape '(' literal_list ')'
3533 // literal_list
3534 // ::= /*empty*/
3535 // ::= literal (',' literal)*
ParseTupleLiteral(Literal * literal,const Shape & shape)3536 bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
3537 if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
3538 return false;
3539 }
3540 std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
3541
3542 if (lexer_.GetKind() == TokKind::kRparen) {
3543 // empty
3544 } else {
3545 // literal, (',' literal)*
3546 for (int i = 0; i < elements.size(); i++) {
3547 if (i > 0) {
3548 ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3549 }
3550 if (!ParseLiteral(&elements[i],
3551 ShapeUtil::GetTupleElementShape(shape, i))) {
3552 return TokenError(StrCat("expects the ", i, "th element"));
3553 }
3554 }
3555 }
3556 *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3557 return ParseToken(TokKind::kRparen,
3558 StrCat("expects ')' at the end of the tuple with ",
3559 ShapeUtil::TupleElementCount(shape), "elements"));
3560 }
3561
// non_tuple
//   ::= rank01
//   ::= rank2345
// rank2345 ::= shape nested_array
// Only dense-array shapes can be parsed here; anything else is a
// caller/parser bug, hence the CHECK rather than a parse error.
bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
  CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
  return ParseDenseLiteral(literal, shape);
}
3570
ParseDenseLiteral(Literal * literal,const Shape & shape)3571 bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
3572 // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
3573 // `rank` is an int64_t, that's an implicit narrowing conversion, which is
3574 // implementation-defined behavior.
3575 const int rank = static_cast<int>(shape.rank());
3576
3577 // Create a literal with the given shape in default layout.
3578 *literal = LiteralUtil::CreateFromDimensions(shape.element_type(),
3579 shape.dimensions());
3580 int64_t nest_level = 0;
3581 int64_t linear_index = 0;
3582 // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
3583 // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
3584 // when we are parsing the 2nd '{' (right before '1'), we are seeing a
3585 // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
3586 // the first '}' (right after '3'), it means the sub-array ends, and the
3587 // sub-array is supposed to contain exactly 3 elements, so check if
3588 // elems_seen_per_dim[1] is 3.
3589 std::vector<int64_t> elems_seen_per_dim(rank);
3590 auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string {
3591 std::vector<int64_t> elems_seen_until_dim(elems_seen_per_dim.begin(),
3592 elems_seen_per_dim.begin() + dim);
3593 return StrCat("[",
3594 StrJoin(elems_seen_until_dim, ",",
3595 [](std::string* out, const int64_t num_elems) {
3596 StrAppend(out, num_elems - 1);
3597 }),
3598 "]");
3599 };
3600
3601 auto add_one_elem_seen = [&] {
3602 if (rank > 0) {
3603 if (nest_level != rank) {
3604 return TokenError(absl::StrFormat(
3605 "expects nested array in rank %d, but sees %d", rank, nest_level));
3606 }
3607 elems_seen_per_dim[rank - 1]++;
3608 if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
3609 return TokenError(absl::StrFormat(
3610 "expects %d elements on the minor-most dimension, but "
3611 "sees more",
3612 shape.dimensions(rank - 1)));
3613 }
3614 }
3615 return true;
3616 };
3617
3618 do {
3619 switch (lexer_.GetKind()) {
3620 default:
3621 return TokenError("unexpected token type in a literal");
3622 case TokKind::kLbrace: {
3623 nest_level++;
3624 if (nest_level > rank) {
3625 return TokenError(absl::StrFormat(
3626 "expects nested array in rank %d, but sees larger", rank));
3627 }
3628 if (nest_level > 1) {
3629 elems_seen_per_dim[nest_level - 2]++;
3630 if (elems_seen_per_dim[nest_level - 2] >
3631 shape.dimensions(nest_level - 2)) {
3632 return TokenError(absl::StrFormat(
3633 "expects %d elements in the %sth element, but sees more",
3634 shape.dimensions(nest_level - 2),
3635 get_index_str(nest_level - 2)));
3636 }
3637 }
3638 lexer_.Lex();
3639 break;
3640 }
3641 case TokKind::kRbrace: {
3642 nest_level--;
3643 if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
3644 return TokenError(absl::StrFormat(
3645 "expects %d elements in the %sth element, but sees %d",
3646 shape.dimensions(nest_level), get_index_str(nest_level),
3647 elems_seen_per_dim[nest_level]));
3648 }
3649 elems_seen_per_dim[nest_level] = 0;
3650 lexer_.Lex();
3651 break;
3652 }
3653 case TokKind::kLparen: {
3654 if (!primitive_util::IsComplexType(shape.element_type())) {
3655 return TokenError(
3656 absl::StrFormat("unexpected '(' in literal. Parens are only "
3657 "valid for complex literals"));
3658 }
3659
3660 std::complex<double> value;
3661 LocTy loc = lexer_.GetLoc();
3662 if (!add_one_elem_seen() || !ParseComplex(&value) ||
3663 !SetValueInLiteral(loc, value, linear_index++, literal)) {
3664 return false;
3665 }
3666 break;
3667 }
3668 case TokKind::kDots: {
3669 if (nest_level != 1) {
3670 return TokenError(absl::StrFormat(
3671 "expects `...` at nest level 1, but sees it at nest level %d",
3672 nest_level));
3673 }
3674 elems_seen_per_dim[0] = shape.dimensions(0);
3675 lexer_.Lex();
3676 // Fill data with deterministic (garbage) values. Use static to avoid
3677 // creating identical constants which could potentially got CSE'ed
3678 // away. This is a best-effort approach to make sure replaying a HLO
3679 // gives us same optimized HLO graph.
3680 static uint32_t data = 0;
3681 uint32_t* raw_data = static_cast<uint32_t*>(literal->untyped_data());
3682 for (int64_t i = 0; i < literal->size_bytes() / 4; ++i) {
3683 raw_data[i] = data++;
3684 }
3685 uint8_t* raw_data_int8 = static_cast<uint8_t*>(literal->untyped_data());
3686 static uint8_t data_int8 = 0;
3687 for (int64_t i = 0; i < literal->size_bytes() % 4; ++i) {
3688 raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++;
3689 }
3690 break;
3691 }
3692 case TokKind::kComma:
3693 // Skip.
3694 lexer_.Lex();
3695 break;
3696 case TokKind::kw_true:
3697 case TokKind::kw_false:
3698 case TokKind::kInt:
3699 case TokKind::kDecimal:
3700 case TokKind::kw_inf:
3701 case TokKind::kNegInf: {
3702 add_one_elem_seen();
3703 if (lexer_.GetKind() == TokKind::kw_true ||
3704 lexer_.GetKind() == TokKind::kw_false) {
3705 if (!SetValueInLiteral(lexer_.GetLoc(),
3706 lexer_.GetKind() == TokKind::kw_true,
3707 linear_index++, literal)) {
3708 return false;
3709 }
3710 lexer_.Lex();
3711 } else if (primitive_util::IsIntegralType(shape.element_type()) ||
3712 shape.element_type() == PRED) {
3713 LocTy loc = lexer_.GetLoc();
3714 int64_t value;
3715 if (!ParseInt64(&value)) {
3716 return Error(loc, StrCat("expects integer for primitive type: ",
3717 PrimitiveType_Name(shape.element_type())));
3718 }
3719 if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
3720 return false;
3721 }
3722 } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
3723 LocTy loc = lexer_.GetLoc();
3724 double value;
3725 if (!ParseDouble(&value)) {
3726 return Error(
3727 loc, StrCat("expect floating point value for primitive type: ",
3728 PrimitiveType_Name(shape.element_type())));
3729 }
3730 if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
3731 return false;
3732 }
3733 } else {
3734 return TokenError(StrCat("unsupported primitive type ",
3735 PrimitiveType_Name(shape.element_type())));
3736 }
3737 break;
3738 }
3739 } // end of switch
3740 } while (nest_level > 0);
3741
3742 *literal = literal->Relayout(shape.layout());
3743 return true;
3744 }
3745
// MinMaxFiniteValue is a type-traits helper used by
// HloParserImpl::CheckParsedValueIsInRange: it reports the smallest and
// largest finite values representable by T.
template <typename T>
struct MinMaxFiniteValue {
  static T min() { return std::numeric_limits<T>::lowest(); }
  static T max() { return std::numeric_limits<T>::max(); }
};
3753
3754 template <>
3755 struct MinMaxFiniteValue<Eigen::half> {
maxxla::__anon2255e8ac0111::MinMaxFiniteValue3756 static double max() {
3757 // Sadly this is not constexpr, so this forces `value` to be a method.
3758 return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
3759 }
minxla::__anon2255e8ac0111::MinMaxFiniteValue3760 static double min() { return -max(); }
3761 };
3762
3763 template <>
3764 struct MinMaxFiniteValue<bfloat16> {
maxxla::__anon2255e8ac0111::MinMaxFiniteValue3765 static double max() {
3766 return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
3767 }
minxla::__anon2255e8ac0111::MinMaxFiniteValue3768 static double min() { return -max(); }
3769 };
3770
// MSVC's standard C++ library does not define isnan/isfinite for integer
// types. To work around that we provide our own overload set; floating-point
// arguments are handled here. A value is finite iff it is neither NaN nor an
// infinity.
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsFinite(T val) {
  return !(std::isnan(val) || std::isinf(val));
}
// Floating-point NaN test; NaN is the only value that does not compare equal
// to itself.
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsNaN(T val) {
  return val != val;
}
// Integer overload: widen to double and defer to the standard predicate.
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsFinite(T val) {
  const double widened = static_cast<double>(val);
  return std::isfinite(widened);
}
// Integer overload: widen to double and defer to the standard predicate
// (always false for integers).
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsNaN(T val) {
  const double widened = static_cast<double>(val);
  return std::isnan(widened);
}
3789
// Checks that the parsed `value` is representable by the literal's native
// type LiteralNativeT, producing a parser error at `loc` otherwise. NaN and
// infinities are exempt from the range check.
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
  if (std::is_floating_point<ParsedElemT>::value) {
    // Round-trip the value through the destination type. If the original was
    // already non-finite, or the round-trip stayed finite, adopt the
    // round-tripped value so the checks below see what would actually be
    // stored (e.g. values that merely lose precision are accepted).
    auto value_as_native_t = static_cast<LiteralNativeT>(value);
    auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t);
    if (!IsFinite(value) || IsFinite(value_double_converted)) {
      value = value_double_converted;
    }
  }
  PrimitiveType literal_ty =
      primitive_util::NativeToPrimitiveType<LiteralNativeT>();
  if (IsNaN(value) ||
      (std::numeric_limits<ParsedElemT>::has_infinity &&
       (std::numeric_limits<ParsedElemT>::infinity() == value ||
        -std::numeric_limits<ParsedElemT>::infinity() == value))) {
    // Skip range checking for non-finite value.
  } else if (std::is_unsigned<LiteralNativeT>::value) {
    // Unsigned destination: compare in the uint64_t domain against [0, max].
    // Only int64_t and bool parsed values are expected to reach this branch.
    CHECK((std::is_same<ParsedElemT, int64_t>::value ||
           std::is_same<ParsedElemT, bool>::value))
        << "Unimplemented checking for ParsedElemT";

    const uint64_t unsigned_value = value;
    const uint64_t upper_bound =
        static_cast<uint64_t>(std::numeric_limits<LiteralNativeT>::max());
    if (unsigned_value > upper_bound) {
      // Value is out of range for LiteralNativeT.
      return Error(loc, StrCat("value ", value,
                               " is out of range for literal's primitive type ",
                               PrimitiveType_Name(literal_ty), " namely [0, ",
                               upper_bound, "]."));
    }
  } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
             value < MinMaxFiniteValue<LiteralNativeT>::min()) {
    // Value is out of range for LiteralNativeT.
    return Error(loc, StrCat("value ", value,
                             " is out of range for literal's primitive type ",
                             PrimitiveType_Name(literal_ty), " namely [",
                             MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
                             MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
  }
  return true;
}
3832
// Complex overload: range-checks the real and imaginary parts independently
// against the destination's component type, so that failures name the
// offending component.
template <typename LiteralNativeT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc,
                                              std::complex<double> value) {
  // e.g. `float` for std::complex<float>
  using LiteralComplexComponentT =
      decltype(std::real(std::declval<LiteralNativeT>()));

  // We could do simply
  //
  //   return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
  //          CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
  //
  // but this would give bad error messages on failure.

  auto check_component = [&](absl::string_view name, double v) {
    if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
        v == -std::numeric_limits<double>::infinity()) {
      // Skip range-checking for non-finite values.
      return true;
    }

    double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
    double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
    if (v < min || v > max) {
      // Value is out of range for LiteralComplexComponentT.
      return Error(
          loc,
          StrCat(name, " part ", v,
                 " is out of range for literal's primitive type ",
                 PrimitiveType_Name(
                     primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
                 ", namely [", min, ", ", max, "]."));
    }
    return true;
  };
  return check_component("real", std::real(value)) &&
         check_component("imaginary", std::imag(value));
}
3871
// operands ::= '(' operands1 ')'
// operands1
//   ::= /*empty*/
//   ::= operand (, operand)*
// operand ::= (shape)? name
//         ::= (shape)? opcode operands
bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
                                  HloComputation::Builder* builder) {
  CHECK(operands != nullptr);
  if (!ParseToken(TokKind::kLparen,
                  "expects '(' at the beginning of operands")) {
    return false;
  }
  if (lexer_.GetKind() == TokKind::kRparen) {
    // empty
  } else {
    do {
      // Try to parse the operand as a name with an optional shape. If that
      // doesn't work, try again parsing it as a nested instruction.
      //
      // (Trying nested instructions second is important here: If you have a
      // giant HLO dump, it likely doesn't have any nested instructions, but
      // likely has tons of non-nested operands. Generating an error is slow --
      // O(n) as of writing -- so we only want to hit the error branch in the
      // uncommon case.)
      //
      // Snapshot the lexer and stash the accumulated errors so we can
      // backtrack cleanly if the first attempt fails.
      HloLexer lexer_copy = lexer_;
      std::vector<std::string> saved_errors;
      std::swap(saved_errors, error_);
      bool is_normal_operand = [&] {
        LocTy loc = lexer_.GetLoc();
        std::string name;
        optional<Shape> shape;
        if (CanBeShape()) {
          shape.emplace();
          if (!ParseShape(&shape.value())) {
            return false;
          }
        }
        if (!ParseName(&name)) {
          // When parsing a single instruction (as opposed to a whole module),
          // an HLO may have one or more operands with a shape but no name:
          //
          //  foo = add(f32[10], f32[10])
          //
          // create_missing_instruction_ is always non-null when parsing a
          // single instruction, and is responsible for creating kParameter
          // instructions for these operands.
          if (shape.has_value() && create_missing_instruction_ != nullptr &&
              scoped_name_tables_.size() == 1) {
            name = "";
          } else {
            return false;
          }
        }
        std::pair<HloInstruction*, LocTy>* instruction =
            FindInstruction(name, shape);
        if (instruction == nullptr) {
          return Error(loc, StrCat("instruction does not exist: ", name));
        }

        // If this is a regular named operand, it must be followed by a comma or
        // a close-paren. If not, it has to be a named instruction. Don't
        // output an error here -- if it fails to parse as a named instruction
        // too, we'll just use that set of errors.
        auto next = lexer_.GetKind();
        if (next != TokKind::kComma && next != TokKind::kRparen) {
          return false;
        }

        operands->push_back(instruction->first);
        return true;
      }();

      if (is_normal_operand) {
        // Success: restore the pre-existing errors and drop any recorded
        // while speculatively parsing.
        error_ = std::move(saved_errors);
        continue;
      }

      // If parsing as a normal operand failed, try parsing as a nested
      // instruction.
      std::vector<std::string> normal_operand_errors;
      std::swap(error_, normal_operand_errors);
      lexer_ = lexer_copy;

      // Nested instructions can't have attributes because it's ambiguous
      // whether the comma separates an instruction from its attribute, or
      // whether the comma separates two instructions.
      LocTy loc = lexer_.GetLoc();
      bool is_nested_instruction = ParseInstructionRhs(
          builder, /*name=*/"", loc, /*allow_attributes=*/false);
      if (is_nested_instruction) {
        operands->push_back(builder->last_added_instruction());
        error_ = std::move(saved_errors);
        continue;
      }

      // If neither parsing as a normal operand nor parsing as a nested
      // instruction worked, fail. Return both sets of errors.
      std::vector<std::string> nested_instruction_errors;
      std::swap(error_, nested_instruction_errors);
      error_ = std::move(saved_errors);
      Error(loc,
            "cannot parse as an instruction name or as a nested instruction:");
      error_.insert(error_.end(),
                    std::make_move_iterator(normal_operand_errors.begin()),
                    std::make_move_iterator(normal_operand_errors.end()));
      error_.insert(error_.end(),
                    std::make_move_iterator(nested_instruction_errors.begin()),
                    std::make_move_iterator(nested_instruction_errors.end()));
    } while (EatIfPresent(TokKind::kComma));
  }
  return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
3985
ParseOperands(std::vector<HloInstruction * > * operands,HloComputation::Builder * builder,const int expected_size)3986 bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
3987 HloComputation::Builder* builder,
3988 const int expected_size) {
3989 CHECK(operands != nullptr);
3990 LocTy loc = lexer_.GetLoc();
3991 if (!ParseOperands(operands, builder)) {
3992 return false;
3993 }
3994 if (expected_size != operands->size()) {
3995 return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
3996 operands->size(), " operands"));
3997 }
3998 return true;
3999 }
4000
4001 // sub_attributes ::= '{' (','? attribute)* '}'
ParseSubAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs)4002 bool HloParserImpl::ParseSubAttributes(
4003 const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
4004 LocTy loc = lexer_.GetLoc();
4005 if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
4006 return false;
4007 }
4008 absl::flat_hash_set<std::string> seen_attrs;
4009 if (lexer_.GetKind() == TokKind::kRbrace) {
4010 // empty
4011 } else {
4012 do {
4013 EatIfPresent(TokKind::kComma);
4014 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
4015 return false;
4016 }
4017 } while (lexer_.GetKind() != TokKind::kRbrace);
4018 }
4019 // Check that all required attrs were seen.
4020 for (const auto& attr_it : attrs) {
4021 if (attr_it.second.required &&
4022 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
4023 return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
4024 attr_it.first));
4025 }
4026 }
4027 return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
4028 }
4029
4030 // attributes ::= (',' attribute)*
ParseAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs,bool allow_attributes)4031 bool HloParserImpl::ParseAttributes(
4032 const absl::flat_hash_map<std::string, AttrConfig>& attrs,
4033 bool allow_attributes) {
4034 LocTy loc = lexer_.GetLoc();
4035 absl::flat_hash_set<std::string> seen_attrs;
4036 if (allow_attributes) {
4037 while (EatIfPresent(TokKind::kComma)) {
4038 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
4039 return false;
4040 }
4041 }
4042 }
4043
4044 // Check that all required attrs were seen.
4045 for (const auto& attr_it : attrs) {
4046 if (attr_it.second.required &&
4047 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
4048 return Error(loc, StrFormat("attribute %s is expected but not seen",
4049 attr_it.first));
4050 }
4051 }
4052 return true;
4053 }
4054
ParseAttributeHelper(const absl::flat_hash_map<std::string,AttrConfig> & attrs,absl::flat_hash_set<std::string> * seen_attrs)4055 bool HloParserImpl::ParseAttributeHelper(
4056 const absl::flat_hash_map<std::string, AttrConfig>& attrs,
4057 absl::flat_hash_set<std::string>* seen_attrs) {
4058 LocTy loc = lexer_.GetLoc();
4059 std::string name;
4060 if (!ParseAttributeName(&name)) {
4061 return Error(loc, "error parsing attributes");
4062 }
4063 VLOG(3) << "Parsing attribute " << name;
4064 if (!seen_attrs->insert(name).second) {
4065 return Error(loc, StrFormat("attribute %s already exists", name));
4066 }
4067 auto attr_it = attrs.find(name);
4068 if (attr_it == attrs.end()) {
4069 std::string allowed_attrs;
4070 if (attrs.empty()) {
4071 allowed_attrs = "No attributes are allowed here.";
4072 } else {
4073 allowed_attrs =
4074 StrCat("Allowed attributes: ",
4075 StrJoin(attrs, ", ",
4076 [&](std::string* out,
4077 const std::pair<std::string, AttrConfig>& kv) {
4078 StrAppend(out, kv.first);
4079 }));
4080 }
4081 return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
4082 allowed_attrs));
4083 }
4084 AttrTy attr_type = attr_it->second.attr_type;
4085 void* attr_out_ptr = attr_it->second.result;
4086 bool success = [&] {
4087 LocTy attr_loc = lexer_.GetLoc();
4088 switch (attr_type) {
4089 case AttrTy::kBool: {
4090 bool result;
4091 if (!ParseBool(&result)) {
4092 return false;
4093 }
4094 static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
4095 return true;
4096 }
4097 case AttrTy::kInt64: {
4098 int64_t result;
4099 if (!ParseInt64(&result)) {
4100 return false;
4101 }
4102 static_cast<optional<int64_t>*>(attr_out_ptr)->emplace(result);
4103 return true;
4104 }
4105 case AttrTy::kInt32: {
4106 int64_t result;
4107 if (!ParseInt64(&result)) {
4108 return false;
4109 }
4110 if (result != static_cast<int32_t>(result)) {
4111 return Error(attr_loc, "value out of range for int32_t");
4112 }
4113 static_cast<optional<int32_t>*>(attr_out_ptr)
4114 ->emplace(static_cast<int32_t>(result));
4115 return true;
4116 }
4117 case AttrTy::kFloat: {
4118 double result;
4119 if (!ParseDouble(&result)) {
4120 return false;
4121 }
4122 if (result > std::numeric_limits<float>::max() ||
4123 result < std::numeric_limits<float>::lowest()) {
4124 return Error(attr_loc, "value out of range for float");
4125 }
4126 static_cast<optional<float>*>(attr_out_ptr)
4127 ->emplace(static_cast<float>(result));
4128 return true;
4129 }
4130 case AttrTy::kHloComputation: {
4131 HloComputation* result = nullptr;
4132 if (!ParseHloComputation(&result)) {
4133 return false;
4134 }
4135 static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
4136 return true;
4137 }
4138 case AttrTy::kBracedHloComputationList: {
4139 std::vector<HloComputation*> result;
4140 if (!ParseHloComputationList(&result)) {
4141 return false;
4142 }
4143 static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
4144 ->emplace(result);
4145 return true;
4146 }
4147 case AttrTy::kFftType: {
4148 FftType result;
4149 if (!ParseFftType(&result)) {
4150 return false;
4151 }
4152 static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
4153 return true;
4154 }
4155 case AttrTy::kPaddingType: {
4156 PaddingType result;
4157 if (!ParsePaddingType(&result)) {
4158 return false;
4159 }
4160 static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
4161 return true;
4162 }
4163 case AttrTy::kComparisonDirection: {
4164 ComparisonDirection result;
4165 if (!ParseComparisonDirection(&result)) {
4166 return false;
4167 }
4168 static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
4169 ->emplace(result);
4170 return true;
4171 }
4172 case AttrTy::kComparisonType: {
4173 Comparison::Type result;
4174 if (!ParseComparisonType(&result)) {
4175 return false;
4176 }
4177 static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
4178 return true;
4179 }
4180 case AttrTy::kEnum: {
4181 if (lexer_.GetKind() != TokKind::kIdent) {
4182 return TokenError("expects an enumeration value");
4183 }
4184 std::string result = lexer_.GetStrVal();
4185 lexer_.Lex();
4186 static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
4187 return true;
4188 }
4189 case AttrTy::kWindow: {
4190 Window result;
4191 if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
4192 return false;
4193 }
4194 static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
4195 return true;
4196 }
4197 case AttrTy::kConvolutionDimensionNumbers: {
4198 ConvolutionDimensionNumbers result;
4199 if (!ParseConvolutionDimensionNumbers(&result)) {
4200 return false;
4201 }
4202 static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
4203 ->emplace(result);
4204 return true;
4205 }
4206 case AttrTy::kSharding: {
4207 OpSharding sharding;
4208 if (!ParseSharding(&sharding)) {
4209 return false;
4210 }
4211 static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
4212 return true;
4213 }
4214 case AttrTy::kFrontendAttributes: {
4215 FrontendAttributes frontend_attributes;
4216 if (!ParseFrontendAttributes(&frontend_attributes)) {
4217 return false;
4218 }
4219 static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
4220 ->emplace(frontend_attributes);
4221 return true;
4222 }
4223 case AttrTy::kParameterReplication: {
4224 ParameterReplication parameter_replication;
4225 if (!ParseParameterReplication(¶meter_replication)) {
4226 return false;
4227 }
4228 static_cast<optional<ParameterReplication>*>(attr_out_ptr)
4229 ->emplace(parameter_replication);
4230 return true;
4231 }
4232 case AttrTy::kInstructionList: {
4233 std::vector<HloInstruction*> result;
4234 if (!ParseInstructionNames(&result)) {
4235 return false;
4236 }
4237 static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
4238 ->emplace(result);
4239 return true;
4240 }
4241 case AttrTy::kFusionKind: {
4242 HloInstruction::FusionKind result;
4243 if (!ParseFusionKind(&result)) {
4244 return false;
4245 }
4246 static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
4247 ->emplace(result);
4248 return true;
4249 }
4250 case AttrTy::kBracedInt64List: {
4251 std::vector<int64_t> result;
4252 if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4253 &result)) {
4254 return false;
4255 }
4256 static_cast<optional<std::vector<int64_t>>*>(attr_out_ptr)
4257 ->emplace(result);
4258 return true;
4259 }
4260 case AttrTy::kBracedInt64ListList: {
4261 std::vector<std::vector<int64_t>> result;
4262 if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
4263 TokKind::kComma, &result)) {
4264 return false;
4265 }
4266 static_cast<optional<std::vector<std::vector<int64_t>>>*>(attr_out_ptr)
4267 ->emplace(result);
4268 return true;
4269 }
4270 case AttrTy::kSliceRanges: {
4271 SliceRanges result;
4272 if (!ParseSliceRanges(&result)) {
4273 return false;
4274 }
4275 static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
4276 return true;
4277 }
4278 case AttrTy::kPaddingConfig: {
4279 PaddingConfig result;
4280 if (!ParsePaddingConfig(&result)) {
4281 return false;
4282 }
4283 static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
4284 return true;
4285 }
4286 case AttrTy::kString: {
4287 std::string result;
4288 if (!ParseString(&result)) {
4289 return false;
4290 }
4291 static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
4292 return true;
4293 }
4294 case AttrTy::kMetadata: {
4295 OpMetadata result;
4296 if (!ParseMetadata(&result)) {
4297 return false;
4298 }
4299 static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
4300 return true;
4301 }
4302 case AttrTy::kDistribution: {
4303 RandomDistribution result;
4304 if (!ParseRandomDistribution(&result)) {
4305 return false;
4306 }
4307 static_cast<optional<RandomDistribution>*>(attr_out_ptr)
4308 ->emplace(result);
4309 return true;
4310 }
4311 case AttrTy::kDomain: {
4312 return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
4313 }
4314 case AttrTy::kPrecisionList: {
4315 std::vector<PrecisionConfig::Precision> result;
4316 if (!ParsePrecisionList(&result)) {
4317 return false;
4318 }
4319 static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
4320 attr_out_ptr)
4321 ->emplace(result);
4322 return true;
4323 }
4324 case AttrTy::kShape: {
4325 Shape result;
4326 if (!ParseShape(&result)) {
4327 return false;
4328 }
4329 static_cast<optional<Shape>*>(attr_out_ptr)->emplace(result);
4330 return true;
4331 }
4332 case AttrTy::kShapeList: {
4333 std::vector<Shape> result;
4334 if (!ParseShapeList(&result)) {
4335 return false;
4336 }
4337 static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
4338 ->emplace(result);
4339 return true;
4340 }
4341 case AttrTy::kRandomAlgorithm: {
4342 RandomAlgorithm result;
4343 if (!ParseRandomAlgorithm(&result)) {
4344 return false;
4345 }
4346 static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
4347 return true;
4348 }
4349 case AttrTy::kAliasing: {
4350 AliasingData aliasing_data;
4351 if (!ParseAliasing(&aliasing_data)) {
4352 return false;
4353 }
4354 static_cast<optional<AliasingData>*>(attr_out_ptr)
4355 ->emplace(aliasing_data);
4356 return true;
4357 }
4358 case AttrTy::kComputationLayout: {
4359 ComputationLayout computation_layout(ShapeLayout(Shape{}));
4360 if (!ParseComputationLayout(&computation_layout)) {
4361 return false;
4362 }
4363 static_cast<optional<ComputationLayout>*>(attr_out_ptr)
4364 ->emplace(computation_layout);
4365 return true;
4366 }
4367 case AttrTy::kInstructionAliasing: {
4368 std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
4369 aliasing_output_operand_pairs;
4370 if (!ParseInstructionOutputOperandAliasing(
4371 &aliasing_output_operand_pairs)) {
4372 return false;
4373 }
4374 static_cast<optional<std::vector<
4375 std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>>*>(
4376 attr_out_ptr)
4377 ->emplace(std::move(aliasing_output_operand_pairs));
4378 return true;
4379 }
4380 case AttrTy::kLiteral: {
4381 Literal result;
4382 if (!ParseLiteral(&result)) {
4383 return false;
4384 }
4385 static_cast<optional<Literal>*>(attr_out_ptr)
4386 ->emplace(std::move(result));
4387 return true;
4388 }
4389 case AttrTy::kCustomCallSchedule: {
4390 CustomCallSchedule result;
4391 if (!ParseCustomCallSchedule(&result)) {
4392 return false;
4393 }
4394 static_cast<optional<CustomCallSchedule>*>(attr_out_ptr)
4395 ->emplace(result);
4396 return true;
4397 }
4398 case AttrTy::kCustomCallApiVersion: {
4399 CustomCallApiVersion result;
4400 if (!ParseCustomCallApiVersion(&result)) {
4401 return false;
4402 }
4403 static_cast<optional<CustomCallApiVersion>*>(attr_out_ptr)
4404 ->emplace(result);
4405 return true;
4406 }
4407 }
4408 }();
4409 if (!success) {
4410 return Error(loc, StrFormat("error parsing attribute %s", name));
4411 }
4412 return true;
4413 }
4414
// Copies previously-parsed attribute values into the same-named fields of
// `message` using protobuf reflection.
//
// `non_proto_attrs` names attributes that are handled outside the proto and
// are skipped here. Every other entry of `attrs` must match a field of
// `message` by name; only singular (non-repeated) bool and enum fields are
// supported. Attributes whose optional<> storage holds no value leave the
// field at its proto default. Returns false — with a parser error recorded
// via TokenError — for an unknown attribute name or an unsupported field
// type.
bool HloParserImpl::CopyAttributeToProtoMessage(
    absl::flat_hash_set<std::string> non_proto_attrs,
    const absl::flat_hash_map<std::string, AttrConfig>& attrs,
    tensorflow::protobuf::Message* message) {
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  const tensorflow::protobuf::Reflection* reflection = message->GetReflection();

  for (const auto& p : attrs) {
    const std::string& name = p.first;
    // Attributes explicitly marked as non-proto are stored elsewhere.
    if (non_proto_attrs.find(name) != non_proto_attrs.end()) {
      continue;
    }
    const tensorflow::protobuf::FieldDescriptor* fd =
        descriptor->FindFieldByName(name);
    if (!fd) {
      // Unknown attribute: build an error message listing every field the
      // proto does accept, to make typos easy to diagnose.
      std::string allowed_attrs = "Allowed attributes: ";

      for (int i = 0; i < descriptor->field_count(); ++i) {
        if (i == 0) {
          absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
        } else {
          absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
        }
      }
      return TokenError(
          StrFormat("unexpected attribute \"%s\". %s", name, allowed_attrs));
    }

    CHECK(!fd->is_repeated());  // Repeated fields not implemented.
    bool success = [&] {
      switch (fd->type()) {
        case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
          // p.second.result points at the optional<bool> that ParseAttributes
          // filled in (AttrTy::kBool).
          auto attr_value = static_cast<optional<bool>*>(p.second.result);
          if (attr_value->has_value()) {
            reflection->SetBool(message, fd, **attr_value);
          }
          return true;
        }
        case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
          // Enum attributes were parsed as strings (AttrTy::kEnum); resolve
          // the name against the field's enum descriptor.
          auto attr_value =
              static_cast<optional<std::string>*>(p.second.result);
          if (attr_value->has_value()) {
            const tensorflow::protobuf::EnumValueDescriptor* evd =
                fd->enum_type()->FindValueByName(**attr_value);
            reflection->SetEnum(message, fd, evd);
          }
          return true;
        }
        default:
          // Only bool and enum fields are supported; anything else is an
          // error reported below.
          return false;
      }
    }();

    if (!success) {
      return TokenError(StrFormat("error parsing attribute %s", name));
    }
  }

  return true;
}
4475
4476 // attributes ::= (',' attribute)*
// attributes ::= (',' attribute)*
//
// Parses a comma-separated attribute list and writes the parsed values into
// the fields of `message` (via CopyAttributeToProtoMessage). The set of
// accepted attribute names is derived from the proto descriptor itself
// (bool and enum fields only), plus the extra entries in `non_proto_attrs`.
// If a name appears both in `non_proto_attrs` and as a proto field, the
// proto field wins.
bool HloParserImpl::ParseAttributesAsProtoMessage(
    const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
    tensorflow::protobuf::Message* message) {
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  absl::flat_hash_map<std::string, AttrConfig> attrs;

  // Storage for attributes.
  std::vector<optional<bool>> bool_params;
  std::vector<optional<std::string>> string_params;
  // Reserve enough capacity to make sure that the vector is not growing, so we
  // can rely on the pointers to stay valid.
  bool_params.reserve(descriptor->field_count());
  string_params.reserve(descriptor->field_count());

  // Populate the storage of expected attributes from the protobuf description.
  for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
    const tensorflow::protobuf::FieldDescriptor* fd =
        descriptor->field(field_idx);
    const std::string& field_name = fd->name();
    switch (fd->type()) {
      case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
        bool_params.emplace_back(std::nullopt);
        attrs[field_name] = {/*is_required*/ false, AttrTy::kBool,
                             &bool_params.back()};
        break;
      }
      case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
        // Enum values are parsed as identifier strings and resolved later.
        string_params.emplace_back(std::nullopt);
        attrs[field_name] = {/*is_required*/ false, AttrTy::kEnum,
                             &string_params.back()};
        break;
      }
      default:
        return TokenError(absl::StrFormat(
            "Unexpected protocol buffer type: %s ", fd->DebugString()));
    }
  }

  absl::flat_hash_set<std::string> non_proto_attrs_names;
  non_proto_attrs_names.reserve(non_proto_attrs.size());
  for (const auto& p : non_proto_attrs) {
    const std::string& attr_name = p.first;
    // If an attribute is both specified within 'non_proto_attrs' and an
    // attribute of the proto message, we prefer the attribute of the proto
    // message.
    if (attrs.find(attr_name) == attrs.end()) {
      non_proto_attrs_names.insert(attr_name);
      attrs[attr_name] = p.second;
    }
  }

  if (!ParseAttributes(attrs)) {
    return false;
  }

  // Transfer the parsed values into the proto; non-proto attributes were
  // written directly through their AttrConfig result pointers.
  return CopyAttributeToProtoMessage(non_proto_attrs_names, attrs, message);
}
4534
ParseComputationName(HloComputation ** value)4535 bool HloParserImpl::ParseComputationName(HloComputation** value) {
4536 std::string name;
4537 LocTy loc = lexer_.GetLoc();
4538 if (!ParseName(&name)) {
4539 return Error(loc, "expects computation name");
4540 }
4541 std::pair<HloComputation*, LocTy>* computation =
4542 tensorflow::gtl::FindOrNull(computation_pool_, name);
4543 if (computation == nullptr) {
4544 return Error(loc, StrCat("computation does not exist: ", name));
4545 }
4546 *value = computation->first;
4547 return true;
4548 }
4549
4550 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
4551 // The subattributes can appear in any order. 'size=' is required, others are
4552 // optional.
ParseWindow(Window * window,bool expect_outer_curlies)4553 bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
4554 LocTy loc = lexer_.GetLoc();
4555 if (expect_outer_curlies &&
4556 !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
4557 return false;
4558 }
4559
4560 std::vector<int64_t> size;
4561 std::vector<int64_t> stride;
4562 std::vector<std::vector<int64_t>> pad;
4563 std::vector<int64_t> lhs_dilate;
4564 std::vector<int64_t> rhs_dilate;
4565 std::vector<int64_t> rhs_reversal;
4566 const auto end_token =
4567 expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
4568 while (lexer_.GetKind() != end_token) {
4569 LocTy attr_loc = lexer_.GetLoc();
4570 std::string field_name;
4571 if (!ParseAttributeName(&field_name)) {
4572 return Error(attr_loc, "expects sub-attributes in window");
4573 }
4574 bool ok = [&] {
4575 if (field_name == "size") {
4576 return ParseDxD("size", &size);
4577 }
4578 if (field_name == "stride") {
4579 return ParseDxD("stride", &stride);
4580 }
4581 if (field_name == "lhs_dilate") {
4582 return ParseDxD("lhs_dilate", &lhs_dilate);
4583 }
4584 if (field_name == "rhs_dilate") {
4585 return ParseDxD("rls_dilate", &rhs_dilate);
4586 }
4587 if (field_name == "pad") {
4588 return ParseWindowPad(&pad);
4589 }
4590 if (field_name == "rhs_reversal") {
4591 return ParseDxD("rhs_reversal", &rhs_reversal);
4592 }
4593 return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
4594 }();
4595 if (!ok) {
4596 return false;
4597 }
4598 }
4599
4600 if (!stride.empty() && stride.size() != size.size()) {
4601 return Error(loc, "expects 'stride=' has the same size as 'size='");
4602 }
4603 if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
4604 return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
4605 }
4606 if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
4607 return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
4608 }
4609 if (!pad.empty() && pad.size() != size.size()) {
4610 return Error(loc, "expects 'pad=' has the same size as 'size='");
4611 }
4612
4613 for (int i = 0; i < size.size(); i++) {
4614 window->add_dimensions()->set_size(size[i]);
4615 if (!pad.empty()) {
4616 window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
4617 window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
4618 }
4619 // If some field is not present, it has the default value.
4620 window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
4621 window->mutable_dimensions(i)->set_base_dilation(
4622 lhs_dilate.empty() ? 1 : lhs_dilate[i]);
4623 window->mutable_dimensions(i)->set_window_dilation(
4624 rhs_dilate.empty() ? 1 : rhs_dilate[i]);
4625 window->mutable_dimensions(i)->set_window_reversal(
4626 rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
4627 }
4628 return !expect_outer_curlies ||
4629 ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
4630 }
4631
4632 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
4633 // The string looks like "dim_labels=0bf_0io->0bf".
4634 //
4635 // '?' dims don't appear in ConvolutionDimensionNumbers. There can be more than
4636 // one '?' dim.
ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers * dnums)4637 bool HloParserImpl::ParseConvolutionDimensionNumbers(
4638 ConvolutionDimensionNumbers* dnums) {
4639 if (lexer_.GetKind() != TokKind::kDimLabels) {
4640 return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
4641 }
4642 std::string str = lexer_.GetStrVal();
4643
4644 // The str is expected to have 3 items, lhs, rhs, out, and it must look like
4645 // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
4646 std::vector<std::string> split1 = absl::StrSplit(str, '_');
4647 if (split1.size() != 2) {
4648 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4649 << str;
4650 }
4651 std::vector<std::string> split2 = absl::StrSplit(split1[1], "->");
4652 if (split2.size() != 2) {
4653 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
4654 << str;
4655 }
4656 absl::string_view lhs = split1[0];
4657 absl::string_view rhs = split2[0];
4658 absl::string_view out = split2[1];
4659
4660 auto is_unique = [](absl::string_view str) -> bool {
4661 absl::flat_hash_set<char> chars;
4662 for (char c : str) {
4663 // '?' dims are skipped.
4664 if (c == '?') {
4665 continue;
4666 }
4667 if (!chars.insert(c).second) {
4668 return false;
4669 }
4670 }
4671 return true;
4672 };
4673
4674 // lhs
4675 {
4676 if (!is_unique(lhs)) {
4677 return TokenError(
4678 StrCat("expects unique lhs dimension numbers, but sees ", lhs));
4679 }
4680 // Count number of spatial dimensions.
4681 for (char c : lhs) {
4682 if (c != 'b' && c != 'f' && c != '?') {
4683 dnums->add_input_spatial_dimensions(-1);
4684 }
4685 }
4686 for (int i = 0; i < lhs.size(); i++) {
4687 char c = lhs[i];
4688 if (c == '?') {
4689 continue;
4690 } else if (c == 'b') {
4691 dnums->set_input_batch_dimension(i);
4692 } else if (c == 'f') {
4693 dnums->set_input_feature_dimension(i);
4694 } else if (c < '0' + lhs.size() && c >= '0') {
4695 dnums->set_input_spatial_dimensions(c - '0', i);
4696 } else {
4697 return TokenError(StrFormat(
4698 "expects [0-%dbf?] in lhs dimension numbers", lhs.size() - 1));
4699 }
4700 }
4701 }
4702 // rhs
4703 {
4704 if (!is_unique(rhs)) {
4705 return TokenError(
4706 StrCat("expects unique rhs dimension numbers, but sees ", rhs));
4707 }
4708 // Count number of spatial dimensions.
4709 for (char c : rhs) {
4710 if (c != 'i' && c != 'o' && c != '?') {
4711 dnums->add_kernel_spatial_dimensions(-1);
4712 }
4713 }
4714 for (int i = 0; i < rhs.size(); i++) {
4715 char c = rhs[i];
4716 if (c == '?') {
4717 continue;
4718 } else if (c == 'i') {
4719 dnums->set_kernel_input_feature_dimension(i);
4720 } else if (c == 'o') {
4721 dnums->set_kernel_output_feature_dimension(i);
4722 } else if (c < '0' + rhs.size() && c >= '0') {
4723 dnums->set_kernel_spatial_dimensions(c - '0', i);
4724 } else {
4725 return TokenError(StrFormat(
4726 "expects [0-%dio?] in rhs dimension numbers", rhs.size() - 1));
4727 }
4728 }
4729 }
4730 // output
4731 {
4732 if (!is_unique(out)) {
4733 return TokenError(
4734 StrCat("expects unique output dimension numbers, but sees ", out));
4735 }
4736 // Count number of spatial dimensions.
4737 for (char c : out) {
4738 if (c != 'b' && c != 'f' && c != '?') {
4739 dnums->add_output_spatial_dimensions(-1);
4740 }
4741 }
4742 for (int i = 0; i < out.size(); i++) {
4743 char c = out[i];
4744 if (c == '?') {
4745 continue;
4746 } else if (c == 'b') {
4747 dnums->set_output_batch_dimension(i);
4748 } else if (c == 'f') {
4749 dnums->set_output_feature_dimension(i);
4750 } else if (c < '0' + out.size() && c >= '0') {
4751 dnums->set_output_spatial_dimensions(c - '0', i);
4752 } else {
4753 return TokenError(StrFormat(
4754 "expects [0-%dbf?] in output dimension numbers", out.size() - 1));
4755 }
4756 }
4757 }
4758
4759 // lhs, rhs, and output should have the same number of spatial dimensions.
4760 if (dnums->input_spatial_dimensions_size() !=
4761 dnums->output_spatial_dimensions_size() ||
4762 dnums->input_spatial_dimensions_size() !=
4763 dnums->kernel_spatial_dimensions_size()) {
4764 return TokenError(
4765 StrFormat("input, kernel, and output must have same number of spatial "
4766 "dimensions, but got %d, %d, %d, respectively.",
4767 dnums->input_spatial_dimensions_size(),
4768 dnums->kernel_spatial_dimensions_size(),
4769 dnums->output_spatial_dimensions_size()));
4770 }
4771
4772 lexer_.Lex();
4773 return true;
4774 }
4775
4776 // ::= '{' ranges '}'
4777 // ::= /*empty*/
4778 // ::= range (',' range)*
4779 // range ::= '[' start ':' limit (':' stride)? ']'
4780 //
4781 // The slice ranges are printed as:
4782 //
4783 // {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
4784 //
4785 // This function extracts the starts, limits, and strides as 3 vectors to the
4786 // result. If stride is not present, stride is 1. For example, if the slice
4787 // ranges is printed as:
4788 //
4789 // {[2:3:4], [5:6:7], [8:9]}
4790 //
4791 // The parsed result will be:
4792 //
4793 // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
4794 //
ParseSliceRanges(SliceRanges * result)4795 bool HloParserImpl::ParseSliceRanges(SliceRanges* result) {
4796 if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
4797 return false;
4798 }
4799 std::vector<std::vector<int64_t>> ranges;
4800 if (lexer_.GetKind() == TokKind::kRbrace) {
4801 // empty
4802 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4803 }
4804 do {
4805 LocTy loc = lexer_.GetLoc();
4806 ranges.emplace_back();
4807 if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
4808 &ranges.back())) {
4809 return false;
4810 }
4811 const auto& range = ranges.back();
4812 if (range.size() != 2 && range.size() != 3) {
4813 return Error(loc,
4814 StrFormat("expects [start:limit:step] or [start:limit], "
4815 "but sees %d elements.",
4816 range.size()));
4817 }
4818 } while (EatIfPresent(TokKind::kComma));
4819
4820 for (const auto& range : ranges) {
4821 result->starts.push_back(range[0]);
4822 result->limits.push_back(range[1]);
4823 result->strides.push_back(range.size() == 3 ? range[2] : 1);
4824 }
4825 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
4826 }
4827
4828 // precisionlist ::= start precision_elements end
4829 // precision_elements
4830 // ::= /*empty*/
4831 // ::= precision_val (delim precision_val)*
ParsePrecisionList(std::vector<PrecisionConfig::Precision> * result)4832 bool HloParserImpl::ParsePrecisionList(
4833 std::vector<PrecisionConfig::Precision>* result) {
4834 auto parse_and_add_item = [&]() {
4835 PrecisionConfig::Precision item;
4836 if (!ParsePrecision(&item)) {
4837 return false;
4838 }
4839 result->push_back(item);
4840 return true;
4841 };
4842 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4843 parse_and_add_item);
4844 }
4845
ParseHloComputation(HloComputation ** result)4846 bool HloParserImpl::ParseHloComputation(HloComputation** result) {
4847 if (lexer_.GetKind() == TokKind::kLbrace) {
4848 // This means it is a nested computation.
4849 return ParseInstructionList(result, /*computation_name=*/"_");
4850 }
4851 // This means it is a computation name.
4852 return ParseComputationName(result);
4853 }
4854
ParseHloComputationList(std::vector<HloComputation * > * result)4855 bool HloParserImpl::ParseHloComputationList(
4856 std::vector<HloComputation*>* result) {
4857 auto parse_and_add_item = [&]() {
4858 HloComputation* computation;
4859 if (!ParseHloComputation(&computation)) {
4860 return false;
4861 }
4862 VLOG(3) << "parsed computation " << computation->name();
4863 result->push_back(computation);
4864 return true;
4865 };
4866 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4867 parse_and_add_item);
4868 }
4869
// shapelist ::= '{' shapes '}'
// shapes
//   ::= /*empty*/
//   ::= shape (',' shape)*
ParseShapeList(std::vector<Shape> * result)4874 bool HloParserImpl::ParseShapeList(std::vector<Shape>* result) {
4875 auto parse_and_add_item = [&]() {
4876 Shape shape;
4877 if (!ParseShape(&shape)) {
4878 return false;
4879 }
4880 result->push_back(std::move(shape));
4881 return true;
4882 };
4883 return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4884 parse_and_add_item);
4885 }
4886
4887 // int64_tlist ::= start int64_elements end
4888 // int64_elements
4889 // ::= /*empty*/
4890 // ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64_t> * result)4891 bool HloParserImpl::ParseInt64List(const TokKind start, const TokKind end,
4892 const TokKind delim,
4893 std::vector<int64_t>* result) {
4894 auto parse_and_add_item = [&]() {
4895 int64_t i;
4896 if (!ParseInt64(&i)) {
4897 return false;
4898 }
4899 result->push_back(i);
4900 return true;
4901 };
4902 return ParseList(start, end, delim, parse_and_add_item);
4903 }
4904
4905 // int64_tlistlist ::= start int64_tlist_elements end
4906 // int64_tlist_elements
4907 // ::= /*empty*/
4908 // ::= int64_tlist (delim int64_tlist)*
4909 // int64_tlist ::= start int64_elements end
4910 // int64_elements
4911 // ::= /*empty*/
4912 // ::= int64_val (delim int64_val)*
ParseInt64ListList(const TokKind start,const TokKind end,const TokKind delim,std::vector<std::vector<int64_t>> * result)4913 bool HloParserImpl::ParseInt64ListList(
4914 const TokKind start, const TokKind end, const TokKind delim,
4915 std::vector<std::vector<int64_t>>* result) {
4916 auto parse_and_add_item = [&]() {
4917 std::vector<int64_t> item;
4918 if (!ParseInt64List(start, end, delim, &item)) {
4919 return false;
4920 }
4921 result->push_back(item);
4922 return true;
4923 };
4924 return ParseList(start, end, delim, parse_and_add_item);
4925 }
4926
ParseList(const TokKind start,const TokKind end,const TokKind delim,const std::function<bool ()> & parse_and_add_item)4927 bool HloParserImpl::ParseList(const TokKind start, const TokKind end,
4928 const TokKind delim,
4929 const std::function<bool()>& parse_and_add_item) {
4930 if (!ParseToken(start, StrCat("expects a list starting with ",
4931 TokKindToString(start)))) {
4932 return false;
4933 }
4934 if (lexer_.GetKind() == end) {
4935 // empty
4936 } else {
4937 do {
4938 if (!parse_and_add_item()) {
4939 return false;
4940 }
4941 } while (EatIfPresent(delim));
4942 }
4943 return ParseToken(
4944 end, StrCat("expects a list to end with ", TokKindToString(end)));
4945 }
4946
4947 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)4948 bool HloParserImpl::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
4949 if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
4950 return false;
4951 }
4952 *shape_loc = lexer_.GetLoc();
4953 return ParseShape(shape);
4954 }
4955
CanBeParamListToShape()4956 bool HloParserImpl::CanBeParamListToShape() {
4957 return lexer_.GetKind() == TokKind::kLparen;
4958 }
4959
4960 // param_list ::= '(' param_list1 ')'
4961 // param_list1
4962 // ::= /*empty*/
4963 // ::= param (',' param)*
4964 // param ::= name shape
ParseParamList()4965 bool HloParserImpl::ParseParamList() {
4966 if (!ParseToken(TokKind::kLparen,
4967 "expects '(' at the beginning of param list")) {
4968 return false;
4969 }
4970
4971 if (lexer_.GetKind() == TokKind::kRparen) {
4972 // empty
4973 } else {
4974 do {
4975 Shape shape;
4976 std::string name;
4977 if (!ParseName(&name) || !ParseShape(&shape)) {
4978 return false;
4979 }
4980 } while (EatIfPresent(TokKind::kComma));
4981 }
4982 return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
4983 }
4984
// dimension_sizes ::= '[' dimension_list ']'
// dimension_list
//   ::= /*empty*/
//   ::= <=? int64_t (',' <=? int64_t)*
ParseDimensionSizes(std::vector<int64_t> * dimension_sizes,std::vector<bool> * dynamic_dimensions)4990 bool HloParserImpl::ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
4991 std::vector<bool>* dynamic_dimensions) {
4992 auto parse_and_add_item = [&]() {
4993 int64_t i;
4994 bool is_dynamic = false;
4995 if (lexer_.GetKind() == TokKind::kLeq) {
4996 is_dynamic = true;
4997 lexer_.Lex();
4998 }
4999 if (!ParseInt64(&i)) {
5000 return false;
5001 }
5002 dimension_sizes->push_back(i);
5003 dynamic_dimensions->push_back(is_dynamic);
5004 return true;
5005 };
5006 return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
5007 parse_and_add_item);
5008 }
5009
5010 // dim_level_types
5011 // ::= /* empty */
5012 // ::= 'D' '(' dim_level_type_list ')'
5013 // dim_level_type_list
5014 // ::= /* empty */
//   ::= dim_level_type (',' dim_level_type)*
5016 // dim_level_type
5017 // ::= 'D'
5018 // ::= 'C'
5019 // ::= 'S'
ParseDimLevelTypes(std::vector<DimLevelType> * dim_level_types)5020 bool HloParserImpl::ParseDimLevelTypes(
5021 std::vector<DimLevelType>* dim_level_types) {
5022 auto parse_and_add_item = [&]() {
5023 if (lexer_.GetKind() == TokKind::kIdent) {
5024 if (lexer_.GetStrVal() == "D") {
5025 lexer_.Lex();
5026 dim_level_types->push_back(DIM_DENSE);
5027 return true;
5028 } else if (lexer_.GetStrVal() == "C") {
5029 dim_level_types->push_back(DIM_COMPRESSED);
5030 lexer_.Lex();
5031 return true;
5032 } else if (lexer_.GetStrVal() == "S") {
5033 dim_level_types->push_back(DIM_SINGLETON);
5034 lexer_.Lex();
5035 return true;
5036 }
5037 }
5038 return Error(lexer_.GetLoc(),
5039 "expected a DimLevelType abbreviation (D, C, or S)");
5040 };
5041 return ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
5042 parse_and_add_item);
5043 }
5044
5045 // tiles
5046 // ::= /*empty*/
5047 // ::= 'T' '(' dim_list ')'
5048 // dim_list
5049 // ::= /*empty*/
5050 // ::= (int64_t | '*') (',' (int64_t | '*'))*
ParseTiles(std::vector<Tile> * tiles)5051 bool HloParserImpl::ParseTiles(std::vector<Tile>* tiles) {
5052 auto parse_and_add_tile_dimension = [&]() {
5053 int64_t i;
5054 if (ParseInt64(&i)) {
5055 tiles->back().add_dimensions(i);
5056 return true;
5057 }
5058 if (lexer_.GetKind() == TokKind::kAsterisk) {
5059 tiles->back().add_dimensions(Tile::kCombineDimension);
5060 lexer_.Lex();
5061 return true;
5062 }
5063 return false;
5064 };
5065
5066 do {
5067 tiles->push_back(Tile());
5068 if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
5069 parse_and_add_tile_dimension)) {
5070 return false;
5071 }
5072 } while (lexer_.GetKind() == TokKind::kLparen);
5073 return true;
5074 }
5075
5076 // int_attribute
5077 // ::= /*empty*/
5078 // ::= attr_token '(' attr_value ')'
5079 // attr_token
5080 // ::= 'E' | 'S'
5081 // attr_value
5082 // ::= int64_t
ParseLayoutIntAttribute(int64_t * attr_value,absl::string_view attr_description)5083 bool HloParserImpl::ParseLayoutIntAttribute(
5084 int64_t* attr_value, absl::string_view attr_description) {
5085 if (!ParseToken(TokKind::kLparen,
5086 StrCat("expects ", attr_description, " to start with ",
5087 TokKindToString(TokKind::kLparen)))) {
5088 return false;
5089 }
5090 if (!ParseInt64(attr_value)) {
5091 return false;
5092 }
5093 if (!ParseToken(TokKind::kRparen,
5094 StrCat("expects ", attr_description, " to end with ",
5095 TokKindToString(TokKind::kRparen)))) {
5096 return false;
5097 }
5098 return true;
5099 }
5100
5101 // layout
5102 // ::= '{' int64_list
5103 // (':' dim_level_types tiles element_size_in_bits memory_space)?
5104 // '}'
5105 // element_size_in_bits
5106 // ::= /*empty*/
5107 // ::= 'E' '(' int64_t ')'
5108 // memory_space
5109 // ::= /*empty*/
5110 // ::= 'S' '(' int64_t ')'
ParseLayout(Layout * layout)5111 bool HloParserImpl::ParseLayout(Layout* layout) {
5112 std::vector<int64_t> minor_to_major;
5113 std::vector<DimLevelType> dim_level_types;
5114 std::vector<Tile> tiles;
5115 int64_t element_size_in_bits = 0;
5116 int64_t memory_space = 0;
5117
5118 auto parse_and_add_item = [&]() {
5119 int64_t i;
5120 if (!ParseInt64(&i)) {
5121 return false;
5122 }
5123 minor_to_major.push_back(i);
5124 return true;
5125 };
5126
5127 if (!ParseToken(TokKind::kLbrace,
5128 StrCat("expects layout to start with ",
5129 TokKindToString(TokKind::kLbrace)))) {
5130 return false;
5131 }
5132 if (lexer_.GetKind() != TokKind::kRbrace) {
5133 if (lexer_.GetKind() == TokKind::kInt) {
5134 // Parse minor to major.
5135 do {
5136 if (!parse_and_add_item()) {
5137 return false;
5138 }
5139 } while (EatIfPresent(TokKind::kComma));
5140 }
5141
5142 if (lexer_.GetKind() == TokKind::kColon) {
5143 lexer_.Lex();
5144
5145 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "D") {
5146 lexer_.Lex();
5147 ParseDimLevelTypes(&dim_level_types);
5148 }
5149
5150 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
5151 lexer_.Lex();
5152 ParseTiles(&tiles);
5153 }
5154
5155 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
5156 lexer_.Lex();
5157 ParseLayoutIntAttribute(&element_size_in_bits, "element size in bits");
5158 }
5159
5160 if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "S") {
5161 lexer_.Lex();
5162 ParseLayoutIntAttribute(&memory_space, "memory space");
5163 }
5164 }
5165 }
5166 if (!ParseToken(TokKind::kRbrace,
5167 StrCat("expects layout to end with ",
5168 TokKindToString(TokKind::kRbrace)))) {
5169 return false;
5170 }
5171
5172 std::vector<Tile> vec_tiles(tiles.size());
5173 for (int i = 0; i < tiles.size(); i++) {
5174 vec_tiles[i] = Tile(tiles[i]);
5175 }
5176 *layout = LayoutUtil::MakeLayout(minor_to_major, dim_level_types, vec_tiles,
5177 element_size_in_bits, memory_space);
5178 return true;
5179 }
5180
5181 // shape ::= shape_val_
5182 // shape ::= '(' tuple_elements ')'
5183 // tuple_elements
5184 // ::= /*empty*/
5185 // ::= shape (',' shape)*
bool HloParserImpl::ParseShape(Shape* result) {
  // Tuple shape: '(' [shape (',' shape)*] ')'. Recurses for each element.
  if (EatIfPresent(TokKind::kLparen)) {  // Tuple
    std::vector<Shape> shapes;
    if (lexer_.GetKind() == TokKind::kRparen) {
      /*empty*/
    } else {
      // shape (',' shape)*
      do {
        shapes.emplace_back();
        if (!ParseShape(&shapes.back())) {
          return false;
        }
      } while (EatIfPresent(TokKind::kComma));
    }
    *result = ShapeUtil::MakeTupleShape(shapes);
    return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
  }

  // Non-tuple shape: primitive type followed by dimension sizes.
  if (lexer_.GetKind() != TokKind::kPrimitiveType) {
    return TokenError(absl::StrCat("expected primitive type, saw ",
                                   TokKindToString(lexer_.GetKind())));
  }
  PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
  lexer_.Lex();

  // Each element contains a dimension size and a bool indicating whether this
  // is a dynamic dimension.
  std::vector<int64_t> dimension_sizes;
  std::vector<bool> dynamic_dimensions;
  if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
    return false;
  }
  result->set_element_type(primitive_type);
  for (int i = 0; i < dimension_sizes.size(); ++i) {
    result->add_dimensions(dimension_sizes[i]);
    result->set_dynamic_dimension(i, dynamic_dimensions[i]);
  }
  // Install a default layout first; it is overwritten below if the text
  // carries an explicit layout.
  LayoutUtil::SetToDefaultLayout(result);
  // We need to lookahead to see if a following open brace is the start of a
  // layout. The specific problematic case is:
  //
  // ENTRY %foo (x: f32[42]) -> f32[123] {
  //  ...
  // }
  //
  // The open brace could either be the start of a computation or the start of a
  // layout for the f32[123] shape. We consider it the start of a layout if the
  // next token after the open brace is an integer or a colon.
  if (lexer_.GetKind() == TokKind::kLbrace &&
      (lexer_.LookAhead() == TokKind::kInt ||
       lexer_.LookAhead() == TokKind::kColon)) {
    Layout layout;
    if (!ParseLayout(&layout)) {
      return false;
    }
    // Validate the parsed layout against the shape's rank before adopting it.
    if (layout.dim_level_types_size() != 0 &&
        layout.dim_level_types_size() != result->rank()) {
      return Error(
          lexer_.GetLoc(),
          StrFormat("Dimensions size is %ld, but dim level types size is %ld.",
                    result->rank(), layout.dim_level_types_size()));
    }
    if (layout.minor_to_major_size() != result->rank()) {
      return Error(
          lexer_.GetLoc(),
          StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
                    result->rank(), layout.minor_to_major_size()));
    }
    // Tiling is meaningless for sparse arrays, so reject the combination.
    if (LayoutUtil::IsSparse(layout) && layout.tiles_size() > 0) {
      return Error(lexer_.GetLoc(),
                   StrFormat("Layout has tiles, but is for a sparse array: %s",
                             layout.ToString()));
    }
    *result->mutable_layout() = layout;
  }
  return true;
}
5263
CanBeShape()5264 bool HloParserImpl::CanBeShape() {
5265 // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
5266 // with '('.
5267 return lexer_.GetKind() == TokKind::kPrimitiveType ||
5268 lexer_.GetKind() == TokKind::kLparen;
5269 }
5270
ParseName(std::string * result)5271 bool HloParserImpl::ParseName(std::string* result) {
5272 VLOG(3) << "ParseName";
5273 if (lexer_.GetKind() != TokKind::kIdent &&
5274 lexer_.GetKind() != TokKind::kName) {
5275 return TokenError("expects name");
5276 }
5277 *result = lexer_.GetStrVal();
5278 lexer_.Lex();
5279 return true;
5280 }
5281
ParseAttributeName(std::string * result)5282 bool HloParserImpl::ParseAttributeName(std::string* result) {
5283 if (lexer_.GetKind() != TokKind::kAttributeName) {
5284 return TokenError("expects attribute name");
5285 }
5286 *result = lexer_.GetStrVal();
5287 lexer_.Lex();
5288 return true;
5289 }
5290
ParseString(std::string * result)5291 bool HloParserImpl::ParseString(std::string* result) {
5292 VLOG(3) << "ParseString";
5293 if (lexer_.GetKind() != TokKind::kString) {
5294 return TokenError("expects string");
5295 }
5296 *result = lexer_.GetStrVal();
5297 lexer_.Lex();
5298 return true;
5299 }
5300
ParseDxD(const std::string & name,std::vector<int64_t> * result)5301 bool HloParserImpl::ParseDxD(const std::string& name,
5302 std::vector<int64_t>* result) {
5303 LocTy loc = lexer_.GetLoc();
5304 if (!result->empty()) {
5305 return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
5306 }
5307 // 1D
5308 if (lexer_.GetKind() == TokKind::kInt) {
5309 int64_t number;
5310 if (!ParseInt64(&number)) {
5311 return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
5312 }
5313 result->push_back(number);
5314 return true;
5315 }
5316 // 2D or higher.
5317 if (lexer_.GetKind() == TokKind::kDxD) {
5318 std::string str = lexer_.GetStrVal();
5319 if (!SplitToInt64s(str, 'x', result)) {
5320 return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
5321 }
5322 lexer_.Lex();
5323 return true;
5324 }
5325 return TokenError("expects token type kInt or kDxD");
5326 }
5327
ParseWindowPad(std::vector<std::vector<int64_t>> * pad)5328 bool HloParserImpl::ParseWindowPad(std::vector<std::vector<int64_t>>* pad) {
5329 LocTy loc = lexer_.GetLoc();
5330 if (!pad->empty()) {
5331 return Error(loc, "sub-attribute 'pad=' already exists");
5332 }
5333 if (lexer_.GetKind() != TokKind::kPad) {
5334 return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
5335 }
5336 std::string str = lexer_.GetStrVal();
5337 for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
5338 std::vector<int64_t> low_high;
5339 if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
5340 low_high.size() != 2) {
5341 return Error(loc,
5342 "expects padding_low and padding_high separated by '_'");
5343 }
5344 pad->push_back(low_high);
5345 }
5346 lexer_.Lex();
5347 return true;
5348 }
5349
5350 // This is the inverse xla::ToString(PaddingConfig). The padding config string
5351 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
5352 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
5353 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)5354 bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) {
5355 if (lexer_.GetKind() != TokKind::kPad) {
5356 return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
5357 }
5358 LocTy loc = lexer_.GetLoc();
5359 std::string str = lexer_.GetStrVal();
5360 for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
5361 std::vector<int64_t> padding_dim;
5362 if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
5363 (padding_dim.size() != 2 && padding_dim.size() != 3)) {
5364 return Error(loc,
5365 "expects padding config pattern like 'low_high_interior' or "
5366 "'low_high'");
5367 }
5368 auto* dim = padding->add_dimensions();
5369 dim->set_edge_padding_low(padding_dim[0]);
5370 dim->set_edge_padding_high(padding_dim[1]);
5371 dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
5372 }
5373 lexer_.Lex();
5374 return true;
5375 }
5376
5377 // '{' metadata_string '}'
ParseMetadata(OpMetadata * metadata)5378 bool HloParserImpl::ParseMetadata(OpMetadata* metadata) {
5379 absl::flat_hash_map<std::string, AttrConfig> attrs;
5380 optional<std::string> op_type;
5381 optional<std::string> op_name;
5382 optional<std::string> source_file;
5383 optional<int32_t> source_line;
5384 optional<std::vector<int64_t>> profile_type;
5385 attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
5386 attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
5387 attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
5388 attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
5389 attrs["profile_type"] = {/*required=*/false, AttrTy::kBracedInt64List,
5390 &profile_type};
5391 if (!ParseSubAttributes(attrs)) {
5392 return false;
5393 }
5394 if (op_type) {
5395 metadata->set_op_type(*op_type);
5396 }
5397 if (op_name) {
5398 metadata->set_op_name(*op_name);
5399 }
5400 if (source_file) {
5401 metadata->set_source_file(*source_file);
5402 }
5403 if (source_line) {
5404 metadata->set_source_line(*source_line);
5405 }
5406 if (profile_type) {
5407 for (const auto& type : *profile_type) {
5408 if (!ProfileType_IsValid(type)) {
5409 return false;
5410 }
5411 metadata->add_profile_type(static_cast<ProfileType>(type));
5412 }
5413 }
5414 return true;
5415 }
5416
5417 // ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}')
ParseSingleOrListMetadata(tensorflow::protobuf::RepeatedPtrField<OpMetadata> * metadata)5418 bool HloParserImpl::ParseSingleOrListMetadata(
5419 tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata) {
5420 if (lexer_.GetKind() == TokKind::kLbrace &&
5421 lexer_.LookAhead() == TokKind::kLbrace) {
5422 if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) {
5423 return false;
5424 }
5425
5426 if (lexer_.GetKind() != TokKind::kRbrace) {
5427 do {
5428 if (!ParseMetadata(metadata->Add())) {
5429 return false;
5430 }
5431 } while (EatIfPresent(TokKind::kComma));
5432 }
5433
5434 return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list");
5435 }
5436
5437 return ParseMetadata(metadata->Add());
5438 }
5439
ParseOpShardingType(OpSharding::Type * type)5440 bool HloParserImpl::ParseOpShardingType(OpSharding::Type* type) {
5441 switch (lexer_.GetKind()) {
5442 case TokKind::kw_maximal:
5443 *type = OpSharding::MAXIMAL;
5444 lexer_.Lex();
5445 break;
5446 case TokKind::kw_replicated:
5447 *type = OpSharding::REPLICATED;
5448 lexer_.Lex();
5449 break;
5450 case TokKind::kw_manual:
5451 *type = OpSharding::MANUAL;
5452 lexer_.Lex();
5453 break;
5454 default:
5455 return false;
5456 }
5457 return true;
5458 }
5459
ParseListShardingType(std::vector<OpSharding::Type> * types)5460 bool HloParserImpl::ParseListShardingType(
5461 std::vector<OpSharding::Type>* types) {
5462 if (!ParseToken(TokKind::kLbrace,
5463 "expected '{' to start sharding type list")) {
5464 return false;
5465 }
5466
5467 if (lexer_.GetKind() != TokKind::kRbrace) {
5468 do {
5469 OpSharding::Type type;
5470 if (!ParseOpShardingType(&type)) {
5471 return false;
5472 }
5473 types->emplace_back(type);
5474 } while (EatIfPresent(TokKind::kComma));
5475 }
5476
5477 return ParseToken(TokKind::kRbrace, "expected '}' to end sharding type list");
5478 }
5479
// Parses an opcode identifier into `*opcode`. Identifiers of the form
// "<op>-start", "<op>-update" or "<op>-done" that are not themselves known
// opcodes are treated as async ops: `*opcode` is set to the corresponding
// async opcode and `*async_wrapped_opcode` to the wrapped opcode "<op>".
bool HloParserImpl::ParseOpcode(
    HloOpcode* opcode, std::optional<HloOpcode>* async_wrapped_opcode) {
  VLOG(3) << "ParseOpcode";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects opcode");
  }
  std::string val = lexer_.GetStrVal();
  auto status_or_result = StringToHloOpcode(val);
  if (!status_or_result.ok()) {
    // Not a plain opcode; try to interpret it as "<wrapped opcode><suffix>".
    // On a suffix match, the lambda overwrites the captured
    // `status_or_result` with the parse of the wrapped opcode.
    auto try_parsing_async_op = [&](absl::string_view suffix,
                                    HloOpcode async_opcode) {
      absl::string_view wrapped_opcode_view(val);
      if (absl::ConsumeSuffix(&wrapped_opcode_view, suffix)) {
        *opcode = async_opcode;
        std::string wrapped_opcode(wrapped_opcode_view);
        status_or_result = StringToHloOpcode(wrapped_opcode);
        return true;
      }
      return false;
    };
    if (try_parsing_async_op("-start", HloOpcode::kAsyncStart) ||
        try_parsing_async_op("-update", HloOpcode::kAsyncUpdate) ||
        try_parsing_async_op("-done", HloOpcode::kAsyncDone)) {
      // A suffix matched but the wrapped opcode itself failed to parse.
      if (!status_or_result.ok()) {
        return TokenError(
            StrFormat("expects async wrapped opcode but sees: %s, error: %s",
                      val, status_or_result.status().error_message()));
      }
      *async_wrapped_opcode = status_or_result.ValueOrDie();
    } else {
      return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
                                  status_or_result.status().error_message()));
    }
  } else {
    *opcode = status_or_result.ValueOrDie();
  }
  lexer_.Lex();
  return true;
}
5519
ParseFftType(FftType * result)5520 bool HloParserImpl::ParseFftType(FftType* result) {
5521 VLOG(3) << "ParseFftType";
5522 if (lexer_.GetKind() != TokKind::kIdent) {
5523 return TokenError("expects fft type");
5524 }
5525 std::string val = lexer_.GetStrVal();
5526 if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
5527 return TokenError(StrFormat("expects fft type but sees: %s", val));
5528 }
5529 lexer_.Lex();
5530 return true;
5531 }
5532
ParsePaddingType(PaddingType * result)5533 bool HloParserImpl::ParsePaddingType(PaddingType* result) {
5534 VLOG(3) << "ParsePaddingType";
5535 if (lexer_.GetKind() != TokKind::kIdent) {
5536 return TokenError("expects padding type");
5537 }
5538 std::string val = lexer_.GetStrVal();
5539 if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
5540 return TokenError(StrFormat("expects padding type but sees: %s", val));
5541 }
5542 lexer_.Lex();
5543 return true;
5544 }
5545
ParseComparisonDirection(ComparisonDirection * result)5546 bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
5547 VLOG(3) << "ParseComparisonDirection";
5548 if (lexer_.GetKind() != TokKind::kIdent) {
5549 return TokenError("expects comparison direction");
5550 }
5551 std::string val = lexer_.GetStrVal();
5552 auto status_or_result = StringToComparisonDirection(val);
5553 if (!status_or_result.ok()) {
5554 return TokenError(
5555 StrFormat("expects comparison direction but sees: %s", val));
5556 }
5557 *result = status_or_result.ValueOrDie();
5558 lexer_.Lex();
5559 return true;
5560 }
5561
ParseComparisonType(Comparison::Type * result)5562 bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
5563 VLOG(1) << "ParseComparisonType";
5564 if (lexer_.GetKind() != TokKind::kIdent) {
5565 return TokenError("expects comparison type");
5566 }
5567 std::string val = lexer_.GetStrVal();
5568 auto status_or_result = StringToComparisonType(val);
5569 if (!status_or_result.ok()) {
5570 return TokenError(StrFormat("expects comparison type but sees: %s", val));
5571 }
5572 *result = status_or_result.ValueOrDie();
5573 lexer_.Lex();
5574 return true;
5575 }
5576
ParseFusionKind(HloInstruction::FusionKind * result)5577 bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
5578 VLOG(3) << "ParseFusionKind";
5579 if (lexer_.GetKind() != TokKind::kIdent) {
5580 return TokenError("expects fusion kind");
5581 }
5582 std::string val = lexer_.GetStrVal();
5583 auto status_or_result = StringToFusionKind(val);
5584 if (!status_or_result.ok()) {
5585 return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
5586 val,
5587 status_or_result.status().error_message()));
5588 }
5589 *result = status_or_result.ValueOrDie();
5590 lexer_.Lex();
5591 return true;
5592 }
5593
ParseRandomDistribution(RandomDistribution * result)5594 bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
5595 VLOG(3) << "ParseRandomDistribution";
5596 if (lexer_.GetKind() != TokKind::kIdent) {
5597 return TokenError("expects random distribution");
5598 }
5599 std::string val = lexer_.GetStrVal();
5600 auto status_or_result = StringToRandomDistribution(val);
5601 if (!status_or_result.ok()) {
5602 return TokenError(
5603 StrFormat("expects random distribution but sees: %s, error: %s", val,
5604 status_or_result.status().error_message()));
5605 }
5606 *result = status_or_result.ValueOrDie();
5607 lexer_.Lex();
5608 return true;
5609 }
5610
ParseRandomAlgorithm(RandomAlgorithm * result)5611 bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
5612 VLOG(3) << "ParseRandomAlgorithm";
5613 if (lexer_.GetKind() != TokKind::kIdent) {
5614 return TokenError("expects random algorithm");
5615 }
5616 std::string val = lexer_.GetStrVal();
5617 auto status_or_result = StringToRandomAlgorithm(val);
5618 if (!status_or_result.ok()) {
5619 return TokenError(
5620 StrFormat("expects random algorithm but sees: %s, error: %s", val,
5621 status_or_result.status().error_message()));
5622 }
5623 *result = status_or_result.ValueOrDie();
5624 lexer_.Lex();
5625 return true;
5626 }
5627
ParsePrecision(PrecisionConfig::Precision * result)5628 bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
5629 VLOG(3) << "ParsePrecision";
5630 if (lexer_.GetKind() != TokKind::kIdent) {
5631 return TokenError("expects random distribution");
5632 }
5633 std::string val = lexer_.GetStrVal();
5634 auto status_or_result = StringToPrecision(val);
5635 if (!status_or_result.ok()) {
5636 return TokenError(StrFormat("expects precision but sees: %s, error: %s",
5637 val,
5638 status_or_result.status().error_message()));
5639 }
5640 *result = status_or_result.ValueOrDie();
5641 lexer_.Lex();
5642 return true;
5643 }
5644
ParseInt64(int64_t * result)5645 bool HloParserImpl::ParseInt64(int64_t* result) {
5646 VLOG(3) << "ParseInt64";
5647 if (lexer_.GetKind() != TokKind::kInt) {
5648 return TokenError("expects integer");
5649 }
5650 *result = lexer_.GetInt64Val();
5651 lexer_.Lex();
5652 return true;
5653 }
5654
ParseDouble(double * result)5655 bool HloParserImpl::ParseDouble(double* result) {
5656 switch (lexer_.GetKind()) {
5657 case TokKind::kDecimal: {
5658 double val = lexer_.GetDecimalVal();
5659 // If GetDecimalVal returns +/-inf, that means that we overflowed
5660 // `double`.
5661 if (std::isinf(val)) {
5662 return TokenError(StrCat("Constant is out of range for double (+/-",
5663 std::numeric_limits<double>::max(),
5664 ") and so is unparsable."));
5665 }
5666 *result = val;
5667 break;
5668 }
5669 case TokKind::kInt:
5670 *result = static_cast<double>(lexer_.GetInt64Val());
5671 break;
5672 case TokKind::kw_inf:
5673 *result = std::numeric_limits<double>::infinity();
5674 break;
5675 case TokKind::kNegInf:
5676 *result = -std::numeric_limits<double>::infinity();
5677 break;
5678 default:
5679 return TokenError("expects decimal or integer");
5680 }
5681 lexer_.Lex();
5682 return true;
5683 }
5684
ParseComplex(std::complex<double> * result)5685 bool HloParserImpl::ParseComplex(std::complex<double>* result) {
5686 if (lexer_.GetKind() != TokKind::kLparen) {
5687 return TokenError("expects '(' before complex number");
5688 }
5689 lexer_.Lex();
5690
5691 double real;
5692 LocTy loc = lexer_.GetLoc();
5693 if (!ParseDouble(&real)) {
5694 return Error(loc,
5695 "expect floating-point value for real part of complex number");
5696 }
5697
5698 if (lexer_.GetKind() != TokKind::kComma) {
5699 return TokenError(
5700 absl::StrFormat("expect comma after real part of complex literal"));
5701 }
5702 lexer_.Lex();
5703
5704 double imag;
5705 loc = lexer_.GetLoc();
5706 if (!ParseDouble(&imag)) {
5707 return Error(
5708 loc,
5709 "expect floating-point value for imaginary part of complex number");
5710 }
5711
5712 if (lexer_.GetKind() != TokKind::kRparen) {
5713 return TokenError(absl::StrFormat("expect ')' after complex number"));
5714 }
5715
5716 *result = std::complex<double>(real, imag);
5717 lexer_.Lex();
5718 return true;
5719 }
5720
ParseBool(bool * result)5721 bool HloParserImpl::ParseBool(bool* result) {
5722 if (lexer_.GetKind() != TokKind::kw_true &&
5723 lexer_.GetKind() != TokKind::kw_false) {
5724 return TokenError("expects true or false");
5725 }
5726 *result = lexer_.GetKind() == TokKind::kw_true;
5727 lexer_.Lex();
5728 return true;
5729 }
5730
ParseToken(TokKind kind,const std::string & msg)5731 bool HloParserImpl::ParseToken(TokKind kind, const std::string& msg) {
5732 VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
5733 if (lexer_.GetKind() != kind) {
5734 return TokenError(msg);
5735 }
5736 lexer_.Lex();
5737 return true;
5738 }
5739
EatIfPresent(TokKind kind)5740 bool HloParserImpl::EatIfPresent(TokKind kind) {
5741 if (lexer_.GetKind() != kind) {
5742 return false;
5743 }
5744 lexer_.Lex();
5745 return true;
5746 }
5747
AddInstruction(const std::string & name,HloInstruction * instruction,LocTy name_loc)5748 bool HloParserImpl::AddInstruction(const std::string& name,
5749 HloInstruction* instruction,
5750 LocTy name_loc) {
5751 auto result = current_name_table().insert({name, {instruction, name_loc}});
5752 if (!result.second) {
5753 Error(name_loc, StrCat("instruction already exists: ", name));
5754 return Error(/*loc=*/result.first->second.second,
5755 "instruction previously defined here");
5756 }
5757 return true;
5758 }
5759
AddComputation(const std::string & name,HloComputation * computation,LocTy name_loc)5760 bool HloParserImpl::AddComputation(const std::string& name,
5761 HloComputation* computation,
5762 LocTy name_loc) {
5763 auto result = computation_pool_.insert({name, {computation, name_loc}});
5764 if (!result.second) {
5765 Error(name_loc, StrCat("computation already exists: ", name));
5766 return Error(/*loc=*/result.first->second.second,
5767 "computation previously defined here");
5768 }
5769 return true;
5770 }
5771
ParseShapeOnly()5772 StatusOr<Shape> HloParserImpl::ParseShapeOnly() {
5773 lexer_.Lex();
5774 Shape shape;
5775 if (!ParseShape(&shape)) {
5776 return InvalidArgument("Syntax error:\n%s", GetError());
5777 }
5778 if (lexer_.GetKind() != TokKind::kEof) {
5779 return InvalidArgument("Syntax error:\nExtra content after shape");
5780 }
5781 return shape;
5782 }
5783
ParseShardingOnly()5784 StatusOr<HloSharding> HloParserImpl::ParseShardingOnly() {
5785 lexer_.Lex();
5786 OpSharding op_sharding;
5787 if (!ParseSharding(&op_sharding)) {
5788 return InvalidArgument("Syntax error:\n%s", GetError());
5789 }
5790 if (lexer_.GetKind() != TokKind::kEof) {
5791 return InvalidArgument("Syntax error:\nExtra content after sharding");
5792 }
5793 return HloSharding::FromProto(op_sharding);
5794 }
5795
ParseFrontendAttributesOnly()5796 StatusOr<FrontendAttributes> HloParserImpl::ParseFrontendAttributesOnly() {
5797 lexer_.Lex();
5798 FrontendAttributes attributes;
5799 if (!ParseFrontendAttributes(&attributes)) {
5800 return InvalidArgument("Syntax error:\n%s", GetError());
5801 }
5802 if (lexer_.GetKind() != TokKind::kEof) {
5803 return InvalidArgument(
5804 "Syntax error:\nExtra content after frontend attributes");
5805 }
5806 return attributes;
5807 }
5808
ParseParameterReplicationOnly()5809 StatusOr<std::vector<bool>> HloParserImpl::ParseParameterReplicationOnly() {
5810 lexer_.Lex();
5811 ParameterReplication parameter_replication;
5812 if (!ParseParameterReplication(¶meter_replication)) {
5813 return InvalidArgument("Syntax error:\n%s", GetError());
5814 }
5815 if (lexer_.GetKind() != TokKind::kEof) {
5816 return InvalidArgument(
5817 "Syntax error:\nExtra content after parameter replication");
5818 }
5819 return std::vector<bool>(
5820 parameter_replication.replicated_at_leaf_buffers().begin(),
5821 parameter_replication.replicated_at_leaf_buffers().end());
5822 }
5823
ParseReplicaGroupsOnly()5824 StatusOr<std::vector<ReplicaGroup>> HloParserImpl::ParseReplicaGroupsOnly() {
5825 lexer_.Lex();
5826 std::vector<ReplicaGroup> replica_groups;
5827 if (!ParseReplicaGroupsOnly(&replica_groups)) {
5828 return InvalidArgument("Syntax error:\n%s", GetError());
5829 }
5830 if (lexer_.GetKind() != TokKind::kEof) {
5831 return InvalidArgument("Syntax error:\nExtra content after replica groups");
5832 }
5833 return replica_groups;
5834 }
5835
ParseWindowOnly()5836 StatusOr<Window> HloParserImpl::ParseWindowOnly() {
5837 lexer_.Lex();
5838 Window window;
5839 if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
5840 return InvalidArgument("Syntax error:\n%s", GetError());
5841 }
5842 if (lexer_.GetKind() != TokKind::kEof) {
5843 return InvalidArgument("Syntax error:\nExtra content after window");
5844 }
5845 return window;
5846 }
5847
5848 StatusOr<ConvolutionDimensionNumbers>
ParseConvolutionDimensionNumbersOnly()5849 HloParserImpl::ParseConvolutionDimensionNumbersOnly() {
5850 lexer_.Lex();
5851 ConvolutionDimensionNumbers dnums;
5852 if (!ParseConvolutionDimensionNumbers(&dnums)) {
5853 return InvalidArgument("Syntax error:\n%s", GetError());
5854 }
5855 if (lexer_.GetKind() != TokKind::kEof) {
5856 return InvalidArgument(
5857 "Syntax error:\nExtra content after convolution dnums");
5858 }
5859 return dnums;
5860 }
5861
ParsePaddingConfigOnly()5862 StatusOr<PaddingConfig> HloParserImpl::ParsePaddingConfigOnly() {
5863 lexer_.Lex();
5864 PaddingConfig padding_config;
5865 if (!ParsePaddingConfig(&padding_config)) {
5866 return InvalidArgument("Syntax error:\n%s", GetError());
5867 }
5868 if (lexer_.GetKind() != TokKind::kEof) {
5869 return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
5870 }
5871 return padding_config;
5872 }
5873
// Parses the input as a single instruction (optionally without its left-hand
// side) into a fresh entry computation of `module`. Operands that are not
// defined in the input are materialized on the fly as parameters via the
// `create_missing_instruction_` hook. Must be called on a pristine parser.
bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
  if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
    LOG(FATAL) << "Parser state is not clean. Please do not call any other "
                  "methods before calling ParseSingleInstruction.";
  }
  HloComputation::Builder builder(module->name());

  // The missing instruction hook we register creates the shaped instruction on
  // the fly as a parameter and returns it.
  int64_t parameter_count = 0;
  create_missing_instruction_ =
      [this, &builder, &parameter_count](
          const std::string& name,
          const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
    // Unnamed operands get synthetic names "_0", "_1", ...
    std::string new_name = name.empty() ? StrCat("_", parameter_count) : name;
    HloInstruction* parameter = builder.AddInstruction(
        HloInstruction::CreateParameter(parameter_count++, shape, new_name));
    current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
    return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
  };

  // Parse the instruction with the registered hook.
  Scope scope(&scoped_name_tables_);
  if (CanBeShape()) {
    // This means that the instruction's left-hand side is probably omitted,
    // e.g.
    //
    //  f32[10] fusion(...), calls={...}
    if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
      return false;
    }
  } else {
    // This means that the instruction's left-hand side might exist, e.g.
    //
    //  foo = f32[10] fusion(...), calls={...}
    std::string root_name;
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  }

  if (lexer_.GetKind() != TokKind::kEof) {
    Error(
        lexer_.GetLoc(),
        "Syntax error:\nExpected eof after parsing single instruction. Did "
        "you mean to write an HLO module and forget the \"HloModule\" header?");
    return false;
  }

  // Install the parsed instruction as the entry computation, plus any
  // computations (e.g. fusion bodies) created while parsing it.
  module->AddEntryComputation(builder.Build());
  for (auto& comp : computations_) {
    module->AddEmbeddedComputation(std::move(comp));
  }
  TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
  return true;
}
5930
5931 } // namespace
5932
ParseAndReturnUnverifiedModule(absl::string_view str,const HloModuleConfig & config)5933 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5934 absl::string_view str, const HloModuleConfig& config) {
5935 auto module = std::make_unique<HloModule>(/*name=*/"_", config);
5936 HloParserImpl parser(str);
5937 TF_RETURN_IF_ERROR(parser.Run(module.get()));
5938 return std::move(module);
5939 }
5940
ParseAndReturnUnverifiedModule(absl::string_view str)5941 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5942 absl::string_view str) {
5943 return ParseAndReturnUnverifiedModule(str, HloModuleConfig());
5944 }
5945
ParseSharding(absl::string_view str)5946 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
5947 HloParserImpl parser(str);
5948 return parser.ParseShardingOnly();
5949 }
5950
ParseFrontendAttributes(absl::string_view str)5951 StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
5952 HloParserImpl parser(str);
5953 return parser.ParseFrontendAttributesOnly();
5954 }
5955
ParseParameterReplication(absl::string_view str)5956 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
5957 HloParserImpl parser(str);
5958 return parser.ParseParameterReplicationOnly();
5959 }
5960
ParseReplicaGroupsOnly(absl::string_view str)5961 StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
5962 absl::string_view str) {
5963 HloParserImpl parser(str);
5964 return parser.ParseReplicaGroupsOnly();
5965 }
5966
ParseWindow(absl::string_view str)5967 StatusOr<Window> ParseWindow(absl::string_view str) {
5968 HloParserImpl parser(str);
5969 return parser.ParseWindowOnly();
5970 }
5971
ParseConvolutionDimensionNumbers(absl::string_view str)5972 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
5973 absl::string_view str) {
5974 HloParserImpl parser(str);
5975 return parser.ParseConvolutionDimensionNumbersOnly();
5976 }
5977
ParsePaddingConfig(absl::string_view str)5978 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
5979 HloParserImpl parser(str);
5980 return parser.ParsePaddingConfigOnly();
5981 }
5982
ParseShape(absl::string_view str)5983 StatusOr<Shape> ParseShape(absl::string_view str) {
5984 HloParserImpl parser(str);
5985 return parser.ParseShapeOnly();
5986 }
5987
CreateHloParserForTests(absl::string_view str)5988 std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
5989 absl::string_view str) {
5990 return std::make_unique<HloParserImpl>(str);
5991 }
5992
5993 } // namespace xla
5994