xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_parser.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17 
18 #include <functional>
19 #include <iterator>
20 #include <memory>
21 #include <string>
22 #include <type_traits>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/base/casts.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_split.h"
36 #include "absl/strings/string_view.h"
37 #include "absl/types/span.h"
38 #include "absl/types/variant.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/literal_util.h"
41 #include "tensorflow/compiler/xla/primitive_util.h"
42 #include "tensorflow/compiler/xla/service/computation_layout.h"
43 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
44 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
45 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
46 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
47 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
48 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
49 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
50 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
51 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
52 #include "tensorflow/compiler/xla/service/shape_inference.h"
53 #include "tensorflow/compiler/xla/shape_util.h"
54 #include "tensorflow/compiler/xla/util.h"
55 #include "tensorflow/compiler/xla/xla_data.pb.h"
56 #include "tensorflow/core/lib/gtl/map_util.h"
57 #include "tensorflow/core/platform/protobuf.h"
58 
59 namespace xla {
60 
61 namespace {
62 
63 using absl::StrAppend;
64 using absl::StrCat;
65 using absl::StrFormat;
66 using absl::StrJoin;
67 using std::nullopt;
68 using std::optional;
69 
70 // Creates and returns a schedule created using the order of the instructions in
71 // the HloComputation::instructions() vectors in the module.
ScheduleFromInstructionOrder(HloModule * module)72 HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
73   HloSchedule schedule(module);
74   for (HloComputation* computation : module->computations()) {
75     if (!computation->IsFusionComputation()) {
76       for (HloInstruction* instruction : computation->instructions()) {
77         schedule.GetOrCreateSequence(computation).push_back(instruction);
78       }
79     }
80   }
81   return schedule;
82 }
83 
CanInferShape(HloOpcode code)84 bool CanInferShape(HloOpcode code) {
85   switch (code) {
86     case HloOpcode::kAbs:
87     case HloOpcode::kAdd:
88     case HloOpcode::kAddDependency:
89     case HloOpcode::kAfterAll:
90     case HloOpcode::kAtan2:
91     case HloOpcode::kBatchNormGrad:
92     case HloOpcode::kBatchNormInference:
93     case HloOpcode::kBatchNormTraining:
94     case HloOpcode::kBroadcast:
95     case HloOpcode::kCall:
96     case HloOpcode::kCeil:
97     case HloOpcode::kCholesky:
98     case HloOpcode::kClamp:
99     case HloOpcode::kClz:
100     case HloOpcode::kCompare:
101     case HloOpcode::kComplex:
102     case HloOpcode::kConcatenate:
103     case HloOpcode::kConditional:
104     case HloOpcode::kConvolution:
105     case HloOpcode::kCopy:
106     case HloOpcode::kCos:
107     case HloOpcode::kOptimizationBarrier:
108     case HloOpcode::kDivide:
109     case HloOpcode::kDomain:
110     case HloOpcode::kDot:
111     case HloOpcode::kExp:
112     case HloOpcode::kExpm1:
113     case HloOpcode::kFft:
114     case HloOpcode::kFloor:
115     case HloOpcode::kGather:
116     case HloOpcode::kGetDimensionSize:
117     case HloOpcode::kSetDimensionSize:
118     case HloOpcode::kGetTupleElement:
119     case HloOpcode::kImag:
120     case HloOpcode::kIsFinite:
121     case HloOpcode::kLog:
122     case HloOpcode::kLog1p:
123     case HloOpcode::kLogistic:
124     case HloOpcode::kAnd:
125     case HloOpcode::kNot:
126     case HloOpcode::kOr:
127     case HloOpcode::kXor:
128     case HloOpcode::kMap:
129     case HloOpcode::kMaximum:
130     case HloOpcode::kMinimum:
131     case HloOpcode::kMultiply:
132     case HloOpcode::kNegate:
133     case HloOpcode::kPad:
134     case HloOpcode::kPartitionId:
135     case HloOpcode::kPopulationCount:
136     case HloOpcode::kPower:
137     case HloOpcode::kReal:
138     case HloOpcode::kReduce:
139     case HloOpcode::kRemainder:
140     case HloOpcode::kReplicaId:
141     case HloOpcode::kReverse:
142     case HloOpcode::kRoundNearestAfz:
143     case HloOpcode::kRoundNearestEven:
144     case HloOpcode::kRsqrt:
145     case HloOpcode::kScatter:
146     case HloOpcode::kSelect:
147     case HloOpcode::kShiftLeft:
148     case HloOpcode::kShiftRightArithmetic:
149     case HloOpcode::kShiftRightLogical:
150     case HloOpcode::kSign:
151     case HloOpcode::kSin:
152     case HloOpcode::kSqrt:
153     case HloOpcode::kCbrt:
154     case HloOpcode::kReduceWindow:
155     case HloOpcode::kSelectAndScatter:
156     case HloOpcode::kSort:
157     case HloOpcode::kSubtract:
158     case HloOpcode::kTanh:
159     case HloOpcode::kTranspose:
160     case HloOpcode::kTriangularSolve:
161     case HloOpcode::kTuple:
162     case HloOpcode::kWhile:
163       return true;
164     // Technically the following ops do not require an explicit result shape,
165     // but we made it so that we always write the shapes explicitly.
166     case HloOpcode::kAsyncStart:
167     case HloOpcode::kAsyncUpdate:
168     case HloOpcode::kAsyncDone:
169     case HloOpcode::kAllGather:
170     case HloOpcode::kAllGatherStart:
171     case HloOpcode::kAllGatherDone:
172     case HloOpcode::kAllReduce:
173     case HloOpcode::kAllReduceStart:
174     case HloOpcode::kAllReduceDone:
175     case HloOpcode::kAllToAll:
176     case HloOpcode::kCollectivePermute:
177     case HloOpcode::kCollectivePermuteStart:
178     case HloOpcode::kCollectivePermuteDone:
179     case HloOpcode::kCopyDone:
180     case HloOpcode::kCopyStart:
181     case HloOpcode::kDynamicReshape:
182     case HloOpcode::kDynamicSlice:
183     case HloOpcode::kDynamicUpdateSlice:
184     case HloOpcode::kRecv:
185     case HloOpcode::kRecvDone:
186     case HloOpcode::kReduceScatter:
187     case HloOpcode::kSend:
188     case HloOpcode::kSendDone:
189     case HloOpcode::kSlice:
190     // The following ops require an explicit result shape.
191     case HloOpcode::kBitcast:
192     case HloOpcode::kBitcastConvert:
193     case HloOpcode::kConstant:
194     case HloOpcode::kConvert:
195     case HloOpcode::kCustomCall:
196     case HloOpcode::kFusion:
197     case HloOpcode::kInfeed:
198     case HloOpcode::kIota:
199     case HloOpcode::kOutfeed:
200     case HloOpcode::kParameter:
201     case HloOpcode::kReducePrecision:
202     case HloOpcode::kReshape:
203     case HloOpcode::kRng:
204     case HloOpcode::kRngBitGenerator:
205     case HloOpcode::kRngGetAndUpdateState:
206       return false;
207   }
208 }
209 
210 // Parser for the HloModule::ToString() format text.
211 class HloParserImpl : public HloParser {
212  public:
213   using LocTy = HloLexer::LocTy;
214 
HloParserImpl(absl::string_view str)215   explicit HloParserImpl(absl::string_view str) : lexer_(str) {}
216 
217   // Runs the parser and constructs the resulting HLO in the given (empty)
218   // HloModule. Returns the error status in case an error occurred.
219   Status Run(HloModule* module) override;
220 
221   // Returns the error information.
GetError() const222   std::string GetError() const { return StrJoin(error_, "\n"); }
223 
224   // Stand alone parsing utils for various aggregate data types.
225   StatusOr<Shape> ParseShapeOnly();
226   StatusOr<HloSharding> ParseShardingOnly();
227   StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
228   StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
229   StatusOr<Window> ParseWindowOnly();
230   StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
231   StatusOr<PaddingConfig> ParsePaddingConfigOnly();
232   StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();
233 
234  private:
235   // Types of attributes.
236   enum class AttrTy {
237     kBool,
238     kInt64,
239     kInt32,
240     kFloat,
241     kString,
242     kLiteral,
243     kBracedInt64List,
244     kBracedInt64ListList,
245     kHloComputation,
246     kBracedHloComputationList,
247     kFftType,
248     kPaddingType,
249     kComparisonDirection,
250     kComparisonType,
251     kWindow,
252     kConvolutionDimensionNumbers,
253     kSharding,
254     kFrontendAttributes,
255     kParameterReplication,
256     kInstructionList,
257     kSliceRanges,
258     kPaddingConfig,
259     kMetadata,
260     kFusionKind,
261     kDistribution,
262     kDomain,
263     kPrecisionList,
264     kShape,
265     kShapeList,
266     kEnum,
267     kRandomAlgorithm,
268     kAliasing,
269     kComputationLayout,
270     kInstructionAliasing,
271     kCustomCallSchedule,
272     kCustomCallApiVersion,
273   };
274 
275   struct AttrConfig {
276     bool required;     // whether it's required or optional
277     AttrTy attr_type;  // what type it is
278     void* result;      // where to store the parsed result.
279   };
280 
281   using InstrNameTable =
282       absl::flat_hash_map<std::string, std::pair<HloInstruction*, LocTy>>;
283 
284   // Returns the map from the instruction name to the instruction itself and its
285   // location in the current scope.
current_name_table()286   InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
287 
288   // Locates an instruction with the given name in the current_name_table() or
289   // returns nullptr.
290   //
291   // When the name is not found or name is empty, if create_missing_instruction_
292   // hook is registered and a "shape" is provided, the hook will be called to
293   // create an instruction. This is useful when we reify parameters as they're
294   // resolved; i.e. for ParseSingleInstruction.
295   std::pair<HloInstruction*, LocTy>* FindInstruction(
296       const std::string& name, const optional<Shape>& shape = nullopt);
297 
298   // Parse a single instruction worth of text.
299   bool ParseSingleInstruction(HloModule* module);
300 
301   // Parses a module, returning false if an error occurred.
302   bool ParseHloModule(HloModule* module);
303 
304   bool ParseComputations(HloModule* module);
305   bool ParseComputation(HloComputation** entry_computation);
306   bool ParseInstructionList(HloComputation** computation,
307                             const std::string& computation_name);
308   bool ParseInstruction(HloComputation::Builder* builder,
309                         std::string* root_name);
310   bool ParseInstructionRhs(HloComputation::Builder* builder, std::string name,
311                            LocTy name_loc, bool allow_attributes = true);
312   bool ParseControlPredecessors(HloInstruction* instruction);
313   bool ParseLiteral(Literal* literal);
314   bool ParseLiteral(Literal* literal, const Shape& shape);
315   bool ParseTupleLiteral(Literal* literal, const Shape& shape);
316   bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
317   bool ParseDenseLiteral(Literal* literal, const Shape& shape);
318 
319   // Parses and creates instruction given name, shape, opcode etc. This is
320   // refactored out from ParseInstructionRhs to allow recursion of wrapped
321   // async instructions to allow parsing for wrapped-op-specific attributes.
322   HloInstruction* CreateInstruction(
323       HloComputation::Builder* builder, absl::string_view name,
324       std::optional<Shape> shape, HloOpcode opcode,
325       std::optional<HloOpcode> async_wrapped_opcode,
326       absl::flat_hash_map<std::string, AttrConfig>& attrs,
327       bool allow_attributes,
328       std::vector<HloInstruction*>* preset_operands = nullptr);
329 
330   // Sets the sub-value of literal at the given linear index to the
331   // given value. If the literal is dense, it must have the default layout.
332   //
333   // `loc` should be the source location of the value.
334   bool SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
335                          Literal* literal);
336   bool SetValueInLiteral(LocTy loc, double value, int64_t index,
337                          Literal* literal);
338   bool SetValueInLiteral(LocTy loc, bool value, int64_t index,
339                          Literal* literal);
340   bool SetValueInLiteral(LocTy loc, std::complex<double> value, int64_t index,
341                          Literal* literal);
342   // `loc` should be the source location of the value.
343   template <typename LiteralNativeT, typename ParsedElemT>
344   bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, int64_t index,
345                                Literal* literal);
346 
347   // Checks whether the given value is within the range of LiteralNativeT.
348   // `loc` should be the source location of the value.
349   template <typename LiteralNativeT, typename ParsedElemT>
350   bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
351   template <typename LiteralNativeT>
352   bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);
353 
354   bool ParseOperands(std::vector<HloInstruction*>* operands,
355                      HloComputation::Builder* builder);
356   // Fills parsed operands into 'operands' and expects a certain number of
357   // operands.
358   bool ParseOperands(std::vector<HloInstruction*>* operands,
359                      HloComputation::Builder* builder, const int expected_size);
360 
361   // Describes the start, limit, and stride on every dimension of the operand
362   // being sliced.
363   struct SliceRanges {
364     std::vector<int64_t> starts;
365     std::vector<int64_t> limits;
366     std::vector<int64_t> strides;
367   };
368 
369   // The data parsed for the kDomain instruction.
370   struct DomainData {
371     std::unique_ptr<DomainMetadata> entry_metadata;
372     std::unique_ptr<DomainMetadata> exit_metadata;
373   };
374 
375   // attributes ::= (',' attribute)*
376   //
377   // Parses attributes given names and configs of the attributes. Each parsed
378   // result is passed back through the result pointer in corresponding
379   // AttrConfig. Note that the result pointer must point to an optional<T> typed
380   // variable which outlives this function. Returns false on error. You should
381   // not use any of the results if this function failed.
382   //
383   // If allow_attributes is false, returns an error if any attributes are
384   // present.  This is used for contexts in which attributes are not allowed but
385   // e.g. we *also* want to raise an error if any required attributes are
386   // missing.
387   //
388   // Example usage:
389   //
390   //  absl::flat_hash_map<std::string, AttrConfig> attrs;
391   //  optional<int64_t> foo;
392   //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
393   //  optional<Window> bar;
394   //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
395   //  if (!ParseAttributes(attrs)) {
396   //    return false; // Do not use 'foo' 'bar' if failed.
397   //  }
398   //  // Do something with 'bar'.
399   //  if (foo) { // If attr foo is seen, do something with 'foo'. }
400   //
401   bool ParseAttributes(
402       const absl::flat_hash_map<std::string, AttrConfig>& attrs,
403       bool allow_attributes = true);
404 
405   // sub_attributes ::= '{' (','? attribute)* '}'
406   //
407   // Usage is the same as ParseAttributes. See immediately above.
408   bool ParseSubAttributes(
409       const absl::flat_hash_map<std::string, AttrConfig>& attrs);
410 
411   // Parses one attribute. If it has already been seen, return error. Returns
412   // true and adds to seen_attrs on success.
413   //
414   // Do not call this except in ParseAttributes or ParseSubAttributes.
415   bool ParseAttributeHelper(
416       const absl::flat_hash_map<std::string, AttrConfig>& attrs,
417       absl::flat_hash_set<std::string>* seen_attrs);
418 
419   // Copy attributes from `attrs` to `message`, unless the attribute name is in
420   // `non_proto_attrs`.
421   bool CopyAttributeToProtoMessage(
422       absl::flat_hash_set<std::string> non_proto_attrs,
423       const absl::flat_hash_map<std::string, AttrConfig>& attrs,
424       tensorflow::protobuf::Message* message);
425 
426   // Parses an attribute string into a protocol buffer `message`.
427   // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
428   // set of mandatory attributes.
429   // `non_proto_attrs` specifies attributes that are not written to the proto,
430   // but added to the HloInstruction.
431   bool ParseAttributesAsProtoMessage(
432       const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
433       tensorflow::protobuf::Message* message);
434 
435   // Parses a name and finds the corresponding hlo computation.
436   bool ParseComputationName(HloComputation** value);
437   // Parses a list of names and finds the corresponding hlo instructions.
438   bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
439   // Pass expect_outer_curlies == true when parsing a Window in the context of a
440   // larger computation.  Pass false when parsing a stand-alone Window string.
441   bool ParseWindow(Window* window, bool expect_outer_curlies);
442   bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
443   bool ParsePaddingConfig(PaddingConfig* padding);
444   bool ParseMetadata(OpMetadata* metadata);
445   bool ParseSingleOrListMetadata(
446       tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata);
447   bool ParseOpShardingType(OpSharding::Type* type);
448   bool ParseListShardingType(std::vector<OpSharding::Type>* types);
449   bool ParseSharding(OpSharding* sharding);
450   bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
451   bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
452   bool ParseParameterReplication(ParameterReplication* parameter_replication);
453   bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
454 
455   // Parses the metadata behind a kDomain instruction.
456   bool ParseDomain(DomainData* domain);
457 
458   // Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
459   bool ParseDxD(const std::string& name, std::vector<int64_t>* result);
460   // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
461   bool ParseWindowPad(std::vector<std::vector<int64_t>>* pad);
462 
463   bool ParseSliceRanges(SliceRanges* result);
464   bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
465   bool ParseHloComputation(HloComputation** result);
466   bool ParseHloComputationList(std::vector<HloComputation*>* result);
467   bool ParseShapeList(std::vector<Shape>* result);
468   bool ParseInt64List(const TokKind start, const TokKind end,
469                       const TokKind delim, std::vector<int64_t>* result);
470   bool ParseInt64ListList(const TokKind start, const TokKind end,
471                           const TokKind delim,
472                           std::vector<std::vector<int64_t>>* result);
473   // 'parse_and_add_item' is a lambda to parse an element in the list and add
474   // the parsed element to the result. It's supposed to capture the result.
475   bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
476                  const std::function<bool()>& parse_and_add_item);
477 
478   bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
479   bool ParseParamList();
480   bool ParseName(std::string* result);
481   bool ParseAttributeName(std::string* result);
482   bool ParseString(std::string* result);
483   bool ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
484                            std::vector<bool>* dynamic_dimensions);
485   bool ParseShape(Shape* result);
486   bool ParseLayout(Layout* layout);
487   bool ParseLayoutIntAttribute(int64_t* attr_value,
488                                absl::string_view attr_description);
489   bool ParseDimLevelTypes(std::vector<DimLevelType>* dim_level_types);
490   bool ParseTiles(std::vector<Tile>* tiles);
491   bool ParseOpcode(HloOpcode* opcode,
492                    std::optional<HloOpcode>* async_wrapped_opcode);
493   bool ParseFftType(FftType* result);
494   bool ParsePaddingType(PaddingType* result);
495   bool ParseComparisonDirection(ComparisonDirection* result);
496   bool ParseComparisonType(Comparison::Type* result);
497   bool ParseFusionKind(HloInstruction::FusionKind* result);
498   bool ParseRandomDistribution(RandomDistribution* result);
499   bool ParseRandomAlgorithm(RandomAlgorithm* result);
500   bool ParsePrecision(PrecisionConfig::Precision* result);
501   bool ParseInt64(int64_t* result);
502   bool ParseDouble(double* result);
503   bool ParseComplex(std::complex<double>* result);
504   bool ParseBool(bool* result);
505   bool ParseToken(TokKind kind, const std::string& msg);
506 
507   using AliasingData =
508       absl::flat_hash_map<ShapeIndex, HloInputOutputAliasConfig::Alias>;
509 
510   // Parses the aliasing information from string `s`, returns `false` if it
511   // fails.
512   bool ParseAliasing(AliasingData* data);
513 
514   // Parses the entry computation layout.
515   bool ParseComputationLayout(ComputationLayout* computation_layout);
516 
517   // Parses the per-instruction aliasing information from string `s`, returns
518   // `false` if it fails.
519   bool ParseInstructionOutputOperandAliasing(
520       std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>*
521           aliasing_output_operand_pairs);
522 
523   bool ParseCustomCallSchedule(CustomCallSchedule* result);
524   bool ParseCustomCallApiVersion(CustomCallApiVersion* result);
525   bool ParseShapeIndex(ShapeIndex* out);
526 
527   // Returns true if the current token is the beginning of a shape.
528   bool CanBeShape();
529   // Returns true if the current token is the beginning of a
530   // param_list_to_shape.
531   bool CanBeParamListToShape();
532 
533   // Logs the current parsing line and the given message. Always returns false.
534   bool TokenError(absl::string_view msg);
535   bool Error(LocTy loc, absl::string_view msg);
536 
537   // If the current token is 'kind', eats it (i.e. lexes the next token) and
538   // returns true.
539   bool EatIfPresent(TokKind kind);
540 
541   // Adds the instruction to the pool. Returns false and emits an error if the
542   // instruction already exists.
543   bool AddInstruction(const std::string& name, HloInstruction* instruction,
544                       LocTy name_loc);
545   // Adds the computation to the pool. Returns false and emits an error if the
546   // computation already exists.
547   bool AddComputation(const std::string& name, HloComputation* computation,
548                       LocTy name_loc);
549 
550   HloLexer lexer_;
551 
552   // A stack for the instruction names. The top of the stack stores the
553   // instruction name table for the current scope.
554   //
555   // An instruction's name is unique among its scope (i.e. its parent
556   // computation), but it's not necessarily unique among all computations in the
557   // module. When there are multiple levels of nested computations, the same
558   // name could appear in both an outer computation and an inner computation. So
559   // we need a stack to make sure a name is only visible within its scope.
560   std::vector<InstrNameTable> scoped_name_tables_;
561 
562   // A helper class which pushes and pops to an InstrNameTable stack via RAII.
563   class Scope {
564    public:
Scope(std::vector<InstrNameTable> * scoped_name_tables)565     explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
566         : scoped_name_tables_(scoped_name_tables) {
567       scoped_name_tables_->emplace_back();
568     }
~Scope()569     ~Scope() { scoped_name_tables_->pop_back(); }
570 
571    private:
572     std::vector<InstrNameTable>* scoped_name_tables_;
573   };
574 
575   // Map from the computation name to the computation itself and its location.
576   absl::flat_hash_map<std::string, std::pair<HloComputation*, LocTy>>
577       computation_pool_;
578 
579   std::vector<std::unique_ptr<HloComputation>> computations_;
580   std::vector<std::string> error_;
581 
582   // When an operand name cannot be resolved, this function is called to create
583   // a parameter instruction with the given name and shape. It registers the
584   // name, instruction, and a placeholder location in the name table. It returns
585   // the newly-created instruction and the placeholder location. If `name` is
586   // empty, this should create the parameter with a generated name. This is
587   // supposed to be set and used only in ParseSingleInstruction.
588   std::function<std::pair<HloInstruction*, LocTy>*(const std::string& name,
589                                                    const Shape& shape)>
590       create_missing_instruction_;
591 
592   // Used to generate names for anonymous instructions.
593   NameUniquer name_uniquer_{/*separator=*/"."};
594 };
595 
SplitToInt64s(absl::string_view s,char delim,std::vector<int64_t> * out)596 bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64_t>* out) {
597   for (const auto& split : absl::StrSplit(s, delim)) {
598     int64_t val;
599     if (!absl::SimpleAtoi(split, &val)) {
600       return false;
601     }
602     out->push_back(val);
603   }
604   return true;
605 }
606 
607 // Creates replica groups from the provided nested array. groups[i] represents
608 // the replica ids for group 'i'.
CreateReplicaGroups(absl::Span<const std::vector<int64_t>> groups)609 std::vector<ReplicaGroup> CreateReplicaGroups(
610     absl::Span<const std::vector<int64_t>> groups) {
611   std::vector<ReplicaGroup> replica_groups;
612   absl::c_transform(groups, std::back_inserter(replica_groups),
613                     [](const std::vector<int64_t>& ids) {
614                       ReplicaGroup group;
615                       *group.mutable_replica_ids() = {ids.begin(), ids.end()};
616                       return group;
617                     });
618   return replica_groups;
619 }
620 
Error(LocTy loc,absl::string_view msg)621 bool HloParserImpl::Error(LocTy loc, absl::string_view msg) {
622   auto line_col = lexer_.GetLineAndColumn(loc);
623   const unsigned line = line_col.first;
624   const unsigned col = line_col.second;
625   std::vector<std::string> error_lines;
626   error_lines.push_back(
627       StrCat("was parsing ", line, ":", col, ": error: ", msg));
628   error_lines.emplace_back(lexer_.GetLine(loc));
629   error_lines.push_back(col == 0 ? "" : StrCat(std::string(col - 1, ' '), "^"));
630 
631   error_.push_back(StrJoin(error_lines, "\n"));
632   VLOG(1) << "Error: " << error_.back();
633   return false;
634 }
635 
TokenError(absl::string_view msg)636 bool HloParserImpl::TokenError(absl::string_view msg) {
637   return Error(lexer_.GetLoc(), msg);
638 }
639 
Run(HloModule * module)640 Status HloParserImpl::Run(HloModule* module) {
641   lexer_.Lex();
642   if (lexer_.GetKind() == TokKind::kw_HloModule) {
643     // This means that the text contains a full HLO module.
644     if (!ParseHloModule(module)) {
645       return InvalidArgument(
646           "Syntax error when trying to parse the text as a HloModule:\n%s",
647           GetError());
648     }
649     return OkStatus();
650   }
651   // This means that the text is a single HLO instruction.
652   if (!ParseSingleInstruction(module)) {
653     return InvalidArgument(
654         "Syntax error when trying to parse the text as a single "
655         "HloInstruction:\n%s",
656         GetError());
657   }
658   return OkStatus();
659 }
660 
661 std::pair<HloInstruction*, HloParserImpl::LocTy>*
FindInstruction(const std::string & name,const optional<Shape> & shape)662 HloParserImpl::FindInstruction(const std::string& name,
663                                const optional<Shape>& shape) {
664   std::pair<HloInstruction*, LocTy>* instr = nullptr;
665   if (!name.empty()) {
666     instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
667   }
668 
669   // Potentially call the missing instruction hook.
670   if (instr == nullptr && create_missing_instruction_ != nullptr &&
671       scoped_name_tables_.size() == 1) {
672     if (!shape.has_value()) {
673       Error(lexer_.GetLoc(),
674             "Operand had no shape in HLO text; cannot create parameter for "
675             "single-instruction module.");
676       return nullptr;
677     }
678     return create_missing_instruction_(name, *shape);
679   }
680 
681   if (instr != nullptr && shape.has_value() &&
682       !ShapeUtil::Compatible(instr->first->shape(), shape.value())) {
683     Error(
684         lexer_.GetLoc(),
685         StrCat("The declared operand shape ",
686                ShapeUtil::HumanStringWithLayout(shape.value()),
687                " is not compatible with the shape of the operand instruction ",
688                ShapeUtil::HumanStringWithLayout(instr->first->shape()), "."));
689     return nullptr;
690   }
691 
692   return instr;
693 }
694 
ParseShapeIndex(ShapeIndex * out)695 bool HloParserImpl::ParseShapeIndex(ShapeIndex* out) {
696   if (!ParseToken(TokKind::kLbrace, "Expects '{' at the start of ShapeIndex")) {
697     return false;
698   }
699 
700   std::vector<int64_t> idxs;
701   while (lexer_.GetKind() != TokKind::kRbrace) {
702     int64_t idx;
703     if (!ParseInt64(&idx)) {
704       return false;
705     }
706     idxs.push_back(idx);
707     if (!EatIfPresent(TokKind::kComma)) {
708       break;
709     }
710   }
711   if (!ParseToken(TokKind::kRbrace, "Expects '}' at the end of ShapeIndex")) {
712     return false;
713   }
714   *out = ShapeIndex(idxs.begin(), idxs.end());
715   return true;
716 }
717 
ParseAliasing(AliasingData * data)718 bool HloParserImpl::ParseAliasing(AliasingData* data) {
719   if (!ParseToken(TokKind::kLbrace,
720                   "Expects '{' at the start of aliasing description")) {
721     return false;
722   }
723 
724   while (lexer_.GetKind() != TokKind::kRbrace) {
725     ShapeIndex out;
726     if (!ParseShapeIndex(&out)) {
727       return false;
728     }
729     std::string errmsg =
730         "Expected format: <output_shape_index>: (<input_param>, "
731         "<input_param_shape_index>) OR <output_shape_index>: <input_param>";
732     if (!ParseToken(TokKind::kColon, errmsg)) {
733       return false;
734     }
735 
736     if (!ParseToken(TokKind::kLparen, errmsg)) {
737       return false;
738     }
739     int64_t param_num;
740     ParseInt64(&param_num);
741     if (!ParseToken(TokKind::kComma, errmsg)) {
742       return false;
743     }
744     ShapeIndex param_idx;
745     if (!ParseShapeIndex(&param_idx)) {
746       return false;
747     }
748 
749     HloInputOutputAliasConfig::AliasKind alias_kind =
750         HloInputOutputAliasConfig::kMayAlias;
751     if (EatIfPresent(TokKind::kComma)) {
752       std::string type;
753       ParseName(&type);
754       if (type == "must-alias") {
755         alias_kind = HloInputOutputAliasConfig::kMustAlias;
756       } else if (type == "may-alias") {
757         alias_kind = HloInputOutputAliasConfig::kMayAlias;
758       } else {
759         return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
760       }
761     }
762 
763     data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
764                   std::forward_as_tuple(param_num, param_idx, alias_kind));
765     if (!ParseToken(TokKind::kRparen, errmsg)) {
766       return false;
767     }
768 
769     if (!EatIfPresent(TokKind::kComma)) {
770       break;
771     }
772   }
773   if (!ParseToken(TokKind::kRbrace,
774                   "Expects '}' at the end of aliasing description")) {
775     return false;
776   }
777   return true;
778 }
779 
ParseComputationLayout(ComputationLayout * computation_layout)780 bool HloParserImpl::ParseComputationLayout(
781     ComputationLayout* computation_layout) {
782   if (!ParseToken(TokKind::kLbrace,
783                   "Expects '{' at the start of aliasing description")) {
784     return false;
785   }
786   if (!ParseToken(TokKind::kLparen, "Expects ( before parameter shape list")) {
787     return false;
788   }
789   while (lexer_.GetKind() != TokKind::kRparen) {
790     Shape param;
791     if (!ParseShape(&param)) {
792       return false;
793     }
794     computation_layout->add_parameter_layout(ShapeLayout(param));
795     if (lexer_.GetKind() == TokKind::kRparen) {
796       break;
797     }
798     if (!ParseToken(TokKind::kComma, "Expects , between parameter shapes")) {
799       return false;
800     }
801   }
802 
803   if (!ParseToken(TokKind::kRparen,
804                   "Expects ) at end of parameter shape list")) {
805     return false;
806   }
807   if (!ParseToken(TokKind::kArrow, "Expects -> before result shape")) {
808     return false;
809   }
810   Shape result;
811   if (!ParseShape(&result)) {
812     return false;
813   }
814   *computation_layout->mutable_result_layout() = ShapeLayout(result);
815   if (!ParseToken(TokKind::kRbrace,
816                   "Expects '}' at the end of computation layouts")) {
817     return false;
818   }
819   return true;
820 }
821 
ParseInstructionOutputOperandAliasing(std::vector<std::pair<ShapeIndex,std::pair<int64_t,ShapeIndex>>> * aliasing_output_operand_pairs)822 bool HloParserImpl::ParseInstructionOutputOperandAliasing(
823     std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>*
824         aliasing_output_operand_pairs) {
825   if (!ParseToken(
826           TokKind::kLbrace,
827           "Expects '{' at the start of instruction aliasing description")) {
828     return false;
829   }
830 
831   while (lexer_.GetKind() != TokKind::kRbrace) {
832     ShapeIndex out;
833     if (!ParseShapeIndex(&out)) {
834       return false;
835     }
836     std::string errmsg =
837         "Expected format: <output_shape_index>: (<operand_index>, "
838         "<operand_shape_index>)";
839     if (!ParseToken(TokKind::kColon, errmsg)) {
840       return false;
841     }
842 
843     if (!ParseToken(TokKind::kLparen, errmsg)) {
844       return false;
845     }
846     int64_t operand_index;
847     ParseInt64(&operand_index);
848     if (!ParseToken(TokKind::kComma, errmsg)) {
849       return false;
850     }
851     ShapeIndex operand_shape_index;
852     if (!ParseShapeIndex(&operand_shape_index)) {
853       return false;
854     }
855 
856     aliasing_output_operand_pairs->emplace_back(
857         out,
858         std::pair<int64_t, ShapeIndex>{operand_index, operand_shape_index});
859     if (!ParseToken(TokKind::kRparen, errmsg)) {
860       return false;
861     }
862 
863     if (!EatIfPresent(TokKind::kComma)) {
864       break;
865     }
866   }
867   if (!ParseToken(
868           TokKind::kRbrace,
869           "Expects '}' at the end of instruction aliasing description")) {
870     return false;
871   }
872   return true;
873 }
874 
ParseCustomCallSchedule(CustomCallSchedule * result)875 bool HloParserImpl::ParseCustomCallSchedule(CustomCallSchedule* result) {
876   VLOG(3) << "ParseCustomCallSchedule";
877   if (lexer_.GetKind() != TokKind::kIdent) {
878     return TokenError("expects custom-call schedule");
879   }
880   std::string val = lexer_.GetStrVal();
881   auto status_or_result = StringToCustomCallSchedule(val);
882   if (!status_or_result.ok()) {
883     return TokenError(
884         StrFormat("expects custom-call schedule but sees: %s, error: %s", val,
885                   status_or_result.status().error_message()));
886   }
887   *result = status_or_result.ValueOrDie();
888   lexer_.Lex();
889   return true;
890 }
891 
ParseCustomCallApiVersion(CustomCallApiVersion * result)892 bool HloParserImpl::ParseCustomCallApiVersion(CustomCallApiVersion* result) {
893   VLOG(3) << "ParseCustomCallApiVersion";
894   if (lexer_.GetKind() != TokKind::kIdent) {
895     return TokenError("expects custom-call API version");
896   }
897   std::string val = lexer_.GetStrVal();
898   auto status_or_result = StringToCustomCallApiVersion(val);
899   if (!status_or_result.ok()) {
900     return TokenError(
901         StrFormat("expects custom-call API version but sees: %s, error: %s",
902                   val, status_or_result.status().error_message()));
903   }
904   *result = status_or_result.ValueOrDie();
905   lexer_.Lex();
906   return true;
907 }
908 
909 // ::= 'HloModule' name computations
// Parses a whole module: the 'HloModule' keyword, the module name, optional
// module-level attributes, and all of its computations. Returns false (with
// an error recorded) on any parse failure.
bool HloParserImpl::ParseHloModule(HloModule* module) {
  if (lexer_.GetKind() != TokKind::kw_HloModule) {
    return TokenError("expects HloModule");
  }
  // Eat 'HloModule'
  lexer_.Lex();

  std::string name;
  if (!ParseName(&name)) {
    return false;
  }

  // Optional module-level attributes that may follow the module name.
  std::optional<bool> is_scheduled;
  std::optional<AliasingData> aliasing_data;
  std::optional<bool> alias_passthrough_params;
  absl::flat_hash_map<std::string, AttrConfig> attrs;
  std::optional<ComputationLayout> entry_computation_layout;

  attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
  attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing,
                                 &aliasing_data};
  attrs["alias_passthrough_params"] = {/*required=*/false, AttrTy::kBool,
                                       &alias_passthrough_params};
  attrs["entry_computation_layout"] = {/*required=*/false,
                                       AttrTy::kComputationLayout,
                                       &entry_computation_layout};
  if (!ParseAttributes(attrs)) {
    return false;
  }
  module->set_name(name);
  // Parse the computation bodies; this also selects the entry computation.
  if (!ParseComputations(module)) {
    return false;
  }

  // is_scheduled=true freezes the textual instruction order as the schedule.
  if (is_scheduled.has_value() && *is_scheduled) {
    TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
  }
  // Only write the config back if one of the parsed attributes changed it,
  // so an untouched module keeps its original config object.
  HloModuleConfig config = module->config();
  bool default_config = true;
  if (alias_passthrough_params.has_value() && *alias_passthrough_params) {
    config.set_alias_passthrough_params(true);
    default_config = false;
  }
  if (entry_computation_layout.has_value()) {
    *config.mutable_entry_computation_layout() = *entry_computation_layout;
    default_config = false;
  }
  if (!default_config) {
    module->set_config(config);
  }
  // Install the parsed input/output aliasing pairs on the module; this must
  // run after ParseComputations so module->result_shape() is available.
  if (aliasing_data) {
    HloInputOutputAliasConfig alias_config(module->result_shape());
    for (auto& p : *aliasing_data) {
      Status st =
          alias_config.SetUpAlias(p.first, p.second.parameter_number,
                                  p.second.parameter_index, p.second.kind);
      if (!st.ok()) {
        return TokenError(st.error_message());
      }
    }
    module->input_output_alias_config() = alias_config;
  }

  return true;
}
975 
976 // computations ::= (computation)+
// Parses computations until EOF, adds them all to the module (one as the
// entry computation, the rest embedded), and copies the parsed layouts into
// the module's entry computation layout.
bool HloParserImpl::ParseComputations(HloModule* module) {
  HloComputation* entry_computation = nullptr;
  do {
    // ParseComputation sets entry_computation when it sees 'ENTRY'.
    if (!ParseComputation(&entry_computation)) {
      return false;
    }
  } while (lexer_.GetKind() != TokKind::kEof);

  for (int i = 0; i < computations_.size(); i++) {
    // If entry_computation is not nullptr, it means the computation it pointed
    // to is marked with "ENTRY"; otherwise, no computation is marked with
    // "ENTRY", and we use the last computation as the entry computation. We
    // add the non-entry computations as embedded computations to the module.
    if ((entry_computation != nullptr &&
         computations_[i].get() != entry_computation) ||
        (entry_computation == nullptr && i != computations_.size() - 1)) {
      module->AddEmbeddedComputation(std::move(computations_[i]));
      continue;
    }
    auto computation = module->AddEntryComputation(std::move(computations_[i]));
    // The parameters and result layouts were set to default layout. Here we
    // set the layouts to what the hlo text says.
    for (int p = 0; p < computation->num_parameters(); p++) {
      const Shape& param_shape = computation->parameter_instruction(p)->shape();
      TF_CHECK_OK(module->mutable_entry_computation_layout()
                      ->mutable_parameter_layout(p)
                      ->CopyLayoutFromShape(param_shape));
    }
    const Shape& result_shape = computation->root_instruction()->shape();
    TF_CHECK_OK(module->mutable_entry_computation_layout()
                    ->mutable_result_layout()
                    ->CopyLayoutFromShape(result_shape));
  }
  return true;
}
1012 
1013 // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list(,
1014 // 'execution_thread='execution_thread)?
// Parses one computation: optional 'ENTRY' marker, name, optional declared
// param-list-to-shape, the instruction list, and optional attributes. On
// success the computation is registered via AddComputation; if it was marked
// ENTRY, *entry_computation is set to it.
bool HloParserImpl::ParseComputation(HloComputation** entry_computation) {
  LocTy maybe_entry_loc = lexer_.GetLoc();
  const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);

  std::string name;
  LocTy name_loc = lexer_.GetLoc();
  if (!ParseName(&name)) {
    return false;
  }

  // shape_loc doubles as a "was a shape declared?" flag below.
  LocTy shape_loc = nullptr;
  Shape shape;
  if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
    return false;
  }

  HloComputation* computation = nullptr;
  if (!ParseInstructionList(&computation, name)) {
    return false;
  }

  // If param_list_to_shape was present, check compatibility.
  if (shape_loc != nullptr &&
      !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
    return Error(
        shape_loc,
        StrCat(
            "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
            ", is not compatible with that of its root instruction ",
            computation->root_instruction()->name(), ", ",
            ShapeUtil::HumanString(computation->root_instruction()->shape())));
  }
  // An optional 'execution_thread' attribute may follow the instruction list;
  // it defaults to the main execution thread.
  absl::flat_hash_map<std::string, AttrConfig> attrs;
  optional<std::string> execution_thread = HloInstruction::kMainExecutionThread;
  attrs["execution_thread"] = {/*required=*/false, AttrTy::kString,
                               &execution_thread};
  if (!ParseAttributes(attrs)) {
    return false;
  }
  computation->SetExecutionThread(*execution_thread);
  if (is_entry_computation) {
    // Only a single computation in the module may be marked ENTRY.
    if (*entry_computation != nullptr) {
      return Error(maybe_entry_loc, "expects only one ENTRY");
    }
    *entry_computation = computation;
  }

  return AddComputation(name, computation, name_loc);
}
1064 
1065 // instruction_list ::= '{' instruction_list1 '}'
1066 // instruction_list1 ::= (instruction)+
// Parses a brace-delimited list of instructions into a new computation,
// appends it to computations_, and returns it via *computation. The ROOT
// marker (recorded by ParseInstruction into root_name) selects the root; if
// absent, the builder uses the last instruction.
bool HloParserImpl::ParseInstructionList(HloComputation** computation,
                                         const std::string& computation_name) {
  // Scope gives this computation its own instruction-name table; it is
  // popped when this function returns.
  Scope scope(&scoped_name_tables_);
  HloComputation::Builder builder(computation_name);
  if (!ParseToken(TokKind::kLbrace,
                  "expects '{' at the beginning of instruction list.")) {
    return false;
  }
  std::string root_name;
  do {
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  } while (lexer_.GetKind() != TokKind::kRbrace);
  if (!ParseToken(TokKind::kRbrace,
                  "expects '}' at the end of instruction list.")) {
    return false;
  }
  HloInstruction* root = nullptr;
  if (!root_name.empty()) {
    std::pair<HloInstruction*, LocTy>* root_node =
        tensorflow::gtl::FindOrNull(current_name_table(), root_name);

    // This means some instruction was marked as ROOT but we didn't find it in
    // the pool, which should not happen.
    if (root_node == nullptr) {
      // LOG(FATAL) crashes the program by calling abort().
      LOG(FATAL) << "instruction " << root_name
                 << " was marked as ROOT but the parser has not seen it before";
    }
    root = root_node->first;
  }

  // Now root can be either an existing instruction or a nullptr. If it's a
  // nullptr, the implementation of Builder will set the last instruction as
  // the root instruction.
  computations_.emplace_back(builder.Build(root));
  *computation = computations_.back().get();
  return true;
}
1107 
1108 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,std::string * root_name)1109 bool HloParserImpl::ParseInstruction(HloComputation::Builder* builder,
1110                                      std::string* root_name) {
1111   std::string name;
1112   LocTy maybe_root_loc = lexer_.GetLoc();
1113   bool is_root = EatIfPresent(TokKind::kw_ROOT);
1114 
1115   const LocTy name_loc = lexer_.GetLoc();
1116   if (!ParseName(&name) ||
1117       !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
1118     return false;
1119   }
1120 
1121   if (is_root) {
1122     if (!root_name->empty()) {
1123       return Error(maybe_root_loc, "one computation should have only one ROOT");
1124     }
1125     *root_name = name;
1126   }
1127 
1128   return ParseInstructionRhs(builder, name, name_loc);
1129 }
1130 
// Parses everything to the right of 'name =' for one instruction: an optional
// shape, the opcode, the opcode-specific operands/attributes (delegated to
// CreateInstruction), and the attributes shared by all instruction kinds
// (sharding, metadata, control-predecessors, etc.), which are applied here.
bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
                                        std::string name, LocTy name_loc,
                                        bool allow_attributes) {
  Shape shape;
  HloOpcode opcode;
  std::optional<HloOpcode> async_wrapped_opcode;
  std::vector<HloInstruction*> operands;

  // The shape is optional only when it can be inferred from the operands.
  const bool parse_shape = CanBeShape();
  if ((parse_shape && !ParseShape(&shape)) ||
      !ParseOpcode(&opcode, &async_wrapped_opcode)) {
    return false;
  }
  if (!parse_shape && !CanInferShape(opcode)) {
    return TokenError(StrFormat("cannot infer shape for opcode: %s",
                                HloOpcodeString(opcode)));
  }

  // Add optional attributes. These are added to any HloInstruction type if
  // present.
  absl::flat_hash_map<std::string, AttrConfig> attrs;
  optional<OpSharding> sharding;
  optional<FrontendAttributes> frontend_attributes;
  attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
  attrs["frontend_attributes"] = {
      /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
  optional<ParameterReplication> parameter_replication;
  attrs["parameter_replication"] = {/*required=*/false,
                                    AttrTy::kParameterReplication,
                                    &parameter_replication};
  optional<std::vector<HloInstruction*>> predecessors;
  attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
                                   &predecessors};
  optional<OpMetadata> metadata;
  attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};

  optional<std::string> backend_config;
  attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
                             &backend_config};
  optional<std::vector<int64_t>> outer_dimension_partitions;
  attrs["outer_dimension_partitions"] = {/*required=*/false,
                                         AttrTy::kBracedInt64List,
                                         &outer_dimension_partitions};

  // CreateInstruction consumes the opcode-specific operands/attributes and
  // also fills in the shared attrs declared above.
  std::optional<Shape> maybe_shape;
  if (parse_shape) {
    maybe_shape = shape;
  }
  HloInstruction* instruction =
      CreateInstruction(builder, name, maybe_shape, opcode,
                        async_wrapped_opcode, attrs, allow_attributes);
  if (instruction == nullptr) {
    return false;
  }

  // Generate a unique name if the name is empty.  This is used for nested
  // instructions (e.g. the `max` in add(max(x, y), z)).
  //
  // Otherwise, register the given name with the name uniquer.
  if (name.empty()) {
    name = name_uniquer_.GetUniqueName(
        absl::StrCat(HloOpcodeString(instruction->opcode()), ".anon"));
  } else {
    name_uniquer_.GetUniqueName(name);
  }

  // Reject names that SetAndSanitizeName had to rewrite, suggesting the
  // sanitized version.
  instruction->SetAndSanitizeName(name);
  if (instruction->name() != name) {
    return Error(name_loc,
                 StrCat("illegal instruction name: ", name,
                        "; suggest renaming to: ", instruction->name()));
  }

  // Add shared attributes like metadata to the instruction, if they were seen.
  if (sharding) {
    instruction->set_sharding(
        HloSharding::FromProto(sharding.value()).ValueOrDie());
  }
  if (parameter_replication) {
    // parameter_replication must supply exactly one flag per leaf buffer of
    // the instruction's shape.
    int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
    const auto& replicated =
        parameter_replication->replicated_at_leaf_buffers();
    if (leaf_count != replicated.size()) {
      return Error(lexer_.GetLoc(),
                   StrCat("parameter has ", leaf_count,
                          " leaf buffers, but parameter_replication has ",
                          replicated.size(), " elements."));
    }
    instruction->set_parameter_replicated_at_leaf_buffers(replicated);
  }
  if (predecessors) {
    // Wire up explicit control dependencies from each predecessor.
    for (auto* pre : *predecessors) {
      Status status = pre->AddControlDependencyTo(instruction);
      if (!status.ok()) {
        return Error(name_loc, StrCat("error adding control dependency for: ",
                                      name, " status: ", status.ToString()));
      }
    }
  }
  if (metadata) {
    instruction->set_metadata(*metadata);
  }
  if (backend_config) {
    instruction->set_raw_backend_config_string(std::move(*backend_config));
  }
  if (outer_dimension_partitions) {
    instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
  }
  if (frontend_attributes) {
    instruction->set_frontend_attributes(*frontend_attributes);
  }
  return AddInstruction(name, instruction, name_loc);
}
1244 
CreateInstruction(HloComputation::Builder * builder,absl::string_view name,std::optional<Shape> shape,HloOpcode opcode,std::optional<HloOpcode> async_wrapped_opcode,absl::flat_hash_map<std::string,AttrConfig> & attrs,bool allow_attributes,std::vector<HloInstruction * > * preset_operands)1245 HloInstruction* HloParserImpl::CreateInstruction(  // NOLINT
1246     HloComputation::Builder* builder, absl::string_view name,
1247     std::optional<Shape> shape, HloOpcode opcode,
1248     std::optional<HloOpcode> async_wrapped_opcode,
1249     absl::flat_hash_map<std::string, AttrConfig>& attrs, bool allow_attributes,
1250     std::vector<HloInstruction*>* preset_operands) {
1251   std::vector<HloInstruction*> operands;
1252   if (preset_operands) {
1253     operands = *preset_operands;
1254   }
1255   const auto maybe_infer_shape =
1256       [&](const std::function<StatusOr<Shape>()>& infer) {
1257         if (shape.has_value()) {
1258           return true;
1259         }
1260         auto inferred = infer();
1261         if (!inferred.ok()) {
1262           return TokenError(StrFormat(
1263               "failed to infer shape for opcode: %s, error: %s",
1264               HloOpcodeString(opcode), inferred.status().error_message()));
1265         }
1266         shape = std::move(inferred).ValueOrDie();
1267         return true;
1268       };
1269 
1270   switch (opcode) {
1271     case HloOpcode::kParameter: {
1272       int64_t parameter_number;
1273       if (!ParseToken(TokKind::kLparen,
1274                       "expects '(' before parameter number") ||
1275           !ParseInt64(&parameter_number)) {
1276         return nullptr;
1277       }
1278       if (parameter_number < 0) {
1279         Error(lexer_.GetLoc(), "parameter number must be >= 0");
1280         return nullptr;
1281       }
1282       if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
1283           !ParseAttributes(attrs, allow_attributes)) {
1284         return nullptr;
1285       }
1286       std::string param_name(name);
1287       return builder->AddInstruction(HloInstruction::CreateParameter(
1288           parameter_number, *shape, param_name));
1289     }
1290     case HloOpcode::kConstant: {
1291       Literal literal;
1292       if (!ParseToken(TokKind::kLparen,
1293                       "expects '(' before constant literal") ||
1294           !ParseLiteral(&literal, *shape) ||
1295           !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
1296           !ParseAttributes(attrs, allow_attributes)) {
1297         return nullptr;
1298       }
1299       return builder->AddInstruction(
1300           HloInstruction::CreateConstant(std::move(literal)));
1301     }
1302     case HloOpcode::kIota: {
1303       optional<int64_t> iota_dimension;
1304       attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
1305                                  &iota_dimension};
1306       if ((!preset_operands &&
1307            !ParseOperands(&operands, builder, /*expected_size=*/0)) ||
1308           !ParseAttributes(attrs, allow_attributes)) {
1309         return nullptr;
1310       }
1311       return builder->AddInstruction(
1312           HloInstruction::CreateIota(*shape, *iota_dimension));
1313     }
1314     // Unary ops.
1315     case HloOpcode::kAbs:
1316     case HloOpcode::kAllGatherDone:
1317     case HloOpcode::kAllReduceDone:
1318     case HloOpcode::kRoundNearestAfz:
1319     case HloOpcode::kRoundNearestEven:
1320     case HloOpcode::kBitcast:
1321     case HloOpcode::kCeil:
1322     case HloOpcode::kClz:
1323     case HloOpcode::kCollectivePermuteDone:
1324     case HloOpcode::kCopy:
1325     case HloOpcode::kCopyDone:
1326     case HloOpcode::kCos:
1327     case HloOpcode::kOptimizationBarrier:
1328     case HloOpcode::kExp:
1329     case HloOpcode::kExpm1:
1330     case HloOpcode::kImag:
1331     case HloOpcode::kIsFinite:
1332     case HloOpcode::kFloor:
1333     case HloOpcode::kLog:
1334     case HloOpcode::kLog1p:
1335     case HloOpcode::kLogistic:
1336     case HloOpcode::kNot:
1337     case HloOpcode::kNegate:
1338     case HloOpcode::kPopulationCount:
1339     case HloOpcode::kReal:
1340     case HloOpcode::kRsqrt:
1341     case HloOpcode::kSign:
1342     case HloOpcode::kSin:
1343     case HloOpcode::kSqrt:
1344     case HloOpcode::kCbrt:
1345     case HloOpcode::kTanh: {
1346       if ((!preset_operands &&
1347            !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1348           !ParseAttributes(attrs, allow_attributes)) {
1349         return nullptr;
1350       }
1351       if (!maybe_infer_shape([&] {
1352             return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
1353           })) {
1354         return nullptr;
1355       }
1356       return builder->AddInstruction(
1357           HloInstruction::CreateUnary(*shape, opcode, operands[0]));
1358     }
1359     // Binary ops.
1360     case HloOpcode::kAdd:
1361     case HloOpcode::kDivide:
1362     case HloOpcode::kMultiply:
1363     case HloOpcode::kSubtract:
1364     case HloOpcode::kAtan2:
1365     case HloOpcode::kComplex:
1366     case HloOpcode::kMaximum:
1367     case HloOpcode::kMinimum:
1368     case HloOpcode::kPower:
1369     case HloOpcode::kRemainder:
1370     case HloOpcode::kAnd:
1371     case HloOpcode::kOr:
1372     case HloOpcode::kXor:
1373     case HloOpcode::kShiftLeft:
1374     case HloOpcode::kShiftRightArithmetic:
1375     case HloOpcode::kShiftRightLogical: {
1376       if ((!preset_operands &&
1377            !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
1378           !ParseAttributes(attrs, allow_attributes)) {
1379         return nullptr;
1380       }
1381       if (!maybe_infer_shape([&] {
1382             return ShapeInference::InferBinaryOpShape(opcode, operands[0],
1383                                                       operands[1]);
1384           })) {
1385         return nullptr;
1386       }
1387       return builder->AddInstruction(HloInstruction::CreateBinary(
1388           *shape, opcode, operands[0], operands[1]));
1389     }
1390     // Ternary ops.
1391     case HloOpcode::kClamp:
1392     case HloOpcode::kSelect: {
1393       if ((!preset_operands &&
1394            !ParseOperands(&operands, builder, /*expected_size=*/3)) ||
1395           !ParseAttributes(attrs, allow_attributes)) {
1396         return nullptr;
1397       }
1398       if (!maybe_infer_shape([&] {
1399             return ShapeInference::InferTernaryOpShape(
1400                 opcode, operands[0], operands[1], operands[2]);
1401           })) {
1402         return nullptr;
1403       }
1404       return builder->AddInstruction(HloInstruction::CreateTernary(
1405           *shape, opcode, operands[0], operands[1], operands[2]));
1406     }
1407     // Other supported ops.
1408     case HloOpcode::kConvert: {
1409       if ((!preset_operands &&
1410            !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1411           !ParseAttributes(attrs, allow_attributes)) {
1412         return nullptr;
1413       }
1414       return builder->AddInstruction(
1415           HloInstruction::CreateConvert(*shape, operands[0]));
1416     }
1417     case HloOpcode::kBitcastConvert: {
1418       if ((!preset_operands &&
1419            !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1420           !ParseAttributes(attrs, allow_attributes)) {
1421         return nullptr;
1422       }
1423       return builder->AddInstruction(
1424           HloInstruction::CreateBitcastConvert(*shape, operands[0]));
1425     }
1426     case HloOpcode::kAllGather:
1427     case HloOpcode::kAllGatherStart: {
1428       optional<std::vector<std::vector<int64_t>>> tmp_groups;
1429       optional<std::vector<int64_t>> replica_group_ids;
1430       optional<int64_t> channel_id;
1431       optional<std::vector<int64_t>> dimensions;
1432       optional<bool> constrain_layout;
1433       optional<bool> use_global_device_ids;
1434       attrs["replica_groups"] = {/*required=*/false,
1435                                  AttrTy::kBracedInt64ListList, &tmp_groups};
1436       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1437       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1438                              &dimensions};
1439       attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1440                                    &constrain_layout};
1441       attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1442                                         &use_global_device_ids};
1443       if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1444           !ParseAttributes(attrs, allow_attributes)) {
1445         return nullptr;
1446       }
1447       std::vector<ReplicaGroup> replica_groups;
1448       if (tmp_groups) {
1449         replica_groups = CreateReplicaGroups(*tmp_groups);
1450       }
1451       if (opcode == HloOpcode::kAllGather) {
1452         return builder->AddInstruction(HloInstruction::CreateAllGather(
1453             *shape, operands, dimensions->at(0), replica_groups,
1454             constrain_layout ? *constrain_layout : false, channel_id,
1455             use_global_device_ids ? *use_global_device_ids : false));
1456       }
1457       return builder->AddInstruction(HloInstruction::CreateAllGatherStart(
1458           *shape, operands, dimensions->at(0), replica_groups,
1459           constrain_layout ? *constrain_layout : false, channel_id,
1460           use_global_device_ids ? *use_global_device_ids : false));
1461     }
1462     case HloOpcode::kAllReduce:
1463     case HloOpcode::kAllReduceStart:
1464     case HloOpcode::kReduceScatter: {
1465       optional<std::vector<std::vector<int64_t>>> tmp_groups;
1466       optional<HloComputation*> to_apply;
1467       optional<std::vector<int64_t>> replica_group_ids;
1468       optional<int64_t> channel_id;
1469       optional<bool> constrain_layout;
1470       optional<bool> use_global_device_ids;
1471       optional<std::vector<int64_t>> dimensions;
1472       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
1473                            &to_apply};
1474       attrs["replica_groups"] = {/*required=*/false,
1475                                  AttrTy::kBracedInt64ListList, &tmp_groups};
1476       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1477       attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
1478                                    &constrain_layout};
1479       attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool,
1480                                         &use_global_device_ids};
1481       if (opcode == HloOpcode::kReduceScatter) {
1482         attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
1483                                &dimensions};
1484       }
1485       if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1486           !ParseAttributes(attrs, allow_attributes)) {
1487         return nullptr;
1488       }
1489       std::vector<ReplicaGroup> replica_groups;
1490       if (tmp_groups) {
1491         replica_groups = CreateReplicaGroups(*tmp_groups);
1492       }
1493       if (opcode == HloOpcode::kAllReduce) {
1494         return builder->AddInstruction(HloInstruction::CreateAllReduce(
1495             *shape, operands, *to_apply, replica_groups,
1496             constrain_layout ? *constrain_layout : false, channel_id,
1497             use_global_device_ids ? *use_global_device_ids : false));
1498       } else if (opcode == HloOpcode::kReduceScatter) {
1499         return builder->AddInstruction(HloInstruction::CreateReduceScatter(
1500             *shape, operands, *to_apply, replica_groups,
1501             constrain_layout ? *constrain_layout : false, channel_id,
1502             use_global_device_ids ? *use_global_device_ids : false,
1503             dimensions->at(0)));
1504       }
1505       return builder->AddInstruction(HloInstruction::CreateAllReduceStart(
1506           *shape, operands, *to_apply, replica_groups,
1507           constrain_layout ? *constrain_layout : false, channel_id,
1508           use_global_device_ids ? *use_global_device_ids : false));
1509     }
    case HloOpcode::kAllToAll: {
      // Optional replica_groups partitions the participating replicas; when
      // absent, an empty group list is passed to CreateAllToAll.
      optional<std::vector<std::vector<int64_t>>> tmp_groups;
      attrs["replica_groups"] = {/*required=*/false,
                                 AttrTy::kBracedInt64ListList, &tmp_groups};
      optional<int64_t> channel_id;
      attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
      // When present, `dimensions` must contain exactly one element: the
      // split dimension for the array (non-tuple) all-to-all variant.
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                             &dimensions};
      optional<bool> constrain_layout;
      attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool,
                                   &constrain_layout};
      // Fail if operand/attribute parsing fails or `dimensions` was given with
      // a length other than one.
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes) ||
          (dimensions && dimensions->size() != 1)) {
        return nullptr;
      }
      std::vector<ReplicaGroup> replica_groups;
      if (tmp_groups) {
        replica_groups = CreateReplicaGroups(*tmp_groups);
      }
      optional<int64_t> split_dimension;
      if (dimensions) {
        split_dimension = dimensions->at(0);
      }
      return builder->AddInstruction(HloInstruction::CreateAllToAll(
          *shape, operands, replica_groups,
          constrain_layout ? *constrain_layout : false, channel_id,
          split_dimension));
    }
1540     case HloOpcode::kCollectivePermute:
1541     case HloOpcode::kCollectivePermuteStart: {
1542       optional<std::vector<std::vector<int64_t>>> source_targets;
1543       attrs["source_target_pairs"] = {
1544           /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
1545       optional<int64_t> channel_id;
1546       attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
1547       optional<std::vector<std::vector<int64_t>>> slice_sizes;
1548       attrs["slice_sizes"] = {/*required=*/false, AttrTy::kBracedInt64ListList,
1549                               &slice_sizes};
1550       if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1551           !ParseAttributes(attrs, allow_attributes)) {
1552         return nullptr;
1553       }
1554       std::vector<std::pair<int64_t, int64_t>> pairs(source_targets->size());
1555       for (int i = 0; i < pairs.size(); i++) {
1556         if ((*source_targets)[i].size() != 2) {
1557           TokenError("expects 'source_target_pairs=' to be a list of pairs");
1558           return nullptr;
1559         }
1560         pairs[i].first = (*source_targets)[i][0];
1561         pairs[i].second = (*source_targets)[i][1];
1562       }
1563       if (!slice_sizes.has_value()) {
1564         if (operands.size() != 1) {
1565           TokenError(
1566               "CollectivePermute and CollectivePermuteStart must "
1567               "have exactly  one operand (input buffer) unless "
1568               "it performs dynamic-slice and in-place update.");
1569           return nullptr;
1570         }
1571         if (opcode == HloOpcode::kCollectivePermute) {
1572           return builder->AddInstruction(
1573               HloInstruction::CreateCollectivePermute(*shape, operands[0],
1574                                                       pairs, channel_id));
1575         }
1576         if (opcode == HloOpcode::kCollectivePermuteStart) {
1577           return builder->AddInstruction(
1578               HloInstruction::CreateCollectivePermuteStart(*shape, operands[0],
1579                                                            pairs, channel_id));
1580         }
1581         LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1582                       "CollectivePermuteStart, but got "
1583                    << HloOpcodeString(opcode);
1584       }
1585       if (operands.size() != 4) {
1586         TokenError(
1587             "CollectivePermute and CollectivePermuteStart must "
1588             "have exactly four operands for dynamic-slice and "
1589             "in-place update.");
1590         return nullptr;
1591       }
1592       if (opcode == HloOpcode::kCollectivePermute) {
1593         return builder->AddInstruction(HloInstruction::CreateCollectivePermute(
1594             *shape, operands[0], operands[1], operands[2], operands[3], pairs,
1595             *slice_sizes, channel_id));
1596       }
1597       if (opcode == HloOpcode::kCollectivePermuteStart) {
1598         return builder->AddInstruction(
1599             HloInstruction::CreateCollectivePermuteStart(
1600                 *shape, operands[0], operands[1], operands[2], operands[3],
1601                 pairs, *slice_sizes, channel_id));
1602       }
1603       LOG(FATAL) << "Expect opcode to be CollectivePermute or "
1604                     "CollectivePermuteStart, but got "
1605                  << HloOpcodeString(opcode);
1606     }
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone: {
      std::optional<HloComputation*> async_computation;
      if (!preset_operands && !ParseOperands(&operands, builder)) {
        return nullptr;
      }
      // An async wrapper shape has the form
      // ((async-operands), async-outputs, state): a tuple with at least two
      // elements whose first element is itself a tuple.
      auto is_async_shape_correct = [](const Shape& shape) {
        return shape.IsTuple() && shape.tuple_shapes_size() >= 2 &&
               shape.tuple_shapes(0).IsTuple();
      };
      optional<int64_t> async_group_id;
      attrs["async_group_id"] = {/*required=*/false, AttrTy::kInt64,
                                 &async_group_id};
      // Execution thread defaults to the main thread when not specified.
      optional<std::string> async_execution_thread =
          HloInstruction::kMainExecutionThread;
      attrs["async_execution_thread"] = {/*required=*/false, AttrTy::kString,
                                         &async_execution_thread};
      if (async_wrapped_opcode) {
        // Short form: the wrapped op is written inline, so synthesize the
        // wrapped computation here instead of looking up a "calls" attribute.
        std::vector<HloInstruction*> async_wrapped_operands;
        std::vector<Shape> async_wrapped_operand_shapes;
        Shape async_wrapped_root_shape;
        if (opcode == HloOpcode::kAsyncStart) {
          // async-start: wrapped parameters mirror this op's operands.
          for (const HloInstruction* operand : operands) {
            async_wrapped_operand_shapes.push_back(operand->shape());
          }
        } else {
          if (operands.size() != 1 ||
              !is_async_shape_correct(operands[0]->shape())) {
            TokenError(
                "AsyncUpdate and AsyncDone expect a single operand in the form "
                "of "
                "((async-operands), async-outputs, state).");
            return nullptr;
          }
          // Recover the wrapped op's operand shapes from the async state
          // tuple's first element.
          async_wrapped_operand_shapes =
              operands[0]->shape().tuple_shapes(0).tuple_shapes();
        }

        if (opcode == HloOpcode::kAsyncDone) {
          // async-done produces the wrapped root value directly.
          async_wrapped_root_shape = *shape;
        } else {
          if (!is_async_shape_correct(*shape)) {
            TokenError(
                "AsyncStart and AsyncUpdate expect the op shape to be in the "
                "form of "
                "((async-operands), async-outputs, state).");
            return nullptr;
          }
          // Element 1 of the wrapper tuple holds the wrapped root's shape.
          async_wrapped_root_shape = shape->tuple_shapes(1);
        }
        HloComputation::Builder async_wrapped_builder("async_wrapped");
        async_wrapped_operands.reserve(async_wrapped_operand_shapes.size());
        for (int i = 0; i < async_wrapped_operand_shapes.size(); ++i) {
          async_wrapped_operands.push_back(async_wrapped_builder.AddInstruction(
              HloInstruction::CreateParameter(
                  i, async_wrapped_operand_shapes.at(i), "async_param")));
        }
        // Recursively create the wrapped instruction; it becomes the root of
        // the freshly built computation.
        HloInstruction* root =
            CreateInstruction(&async_wrapped_builder, "async_op",
                              async_wrapped_root_shape, *async_wrapped_opcode,
                              /*async_wrapped_opcode=*/std::nullopt, attrs,
                              allow_attributes, &async_wrapped_operands);
        if (!root) {
          return nullptr;
        }
        computations_.emplace_back(async_wrapped_builder.Build(root));
        async_computation = computations_.back().get();

      } else {
        // Long form: the wrapped computation is named via "calls".
        attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
                          &async_computation};
        if (!ParseAttributes(attrs, allow_attributes)) {
          return nullptr;
        }
      }
      if (opcode == HloOpcode::kAsyncStart) {
        return builder->AddInstruction(HloInstruction::CreateAsyncStart(
            *shape, operands, *async_computation, async_group_id,
            *async_execution_thread));
      }
      if (opcode == HloOpcode::kAsyncUpdate) {
        return builder->AddInstruction(HloInstruction::CreateAsyncUpdate(
            *shape, operands[0], *async_computation, async_group_id,
            *async_execution_thread));
      }
      return builder->AddInstruction(HloInstruction::CreateAsyncDone(
          *shape, operands[0], *async_computation, async_group_id,
          *async_execution_thread));
    }
1697     case HloOpcode::kCopyStart: {
1698       // If the is_cross_program_prefetch attribute is not present then default
1699       // to false.
1700       optional<bool> is_cross_program_prefetch = false;
1701       attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
1702                                             &is_cross_program_prefetch};
1703       if ((!preset_operands &&
1704            !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1705           !ParseAttributes(attrs, allow_attributes)) {
1706         return nullptr;
1707       }
1708       return builder->AddInstruction(HloInstruction::CreateCopyStart(
1709           *shape, operands[0], *is_cross_program_prefetch));
1710     }
1711     case HloOpcode::kReplicaId: {
1712       if ((!preset_operands &&
1713            !ParseOperands(&operands, builder, /*expected_size=*/0)) ||
1714           !ParseAttributes(attrs, allow_attributes)) {
1715         return nullptr;
1716       }
1717       return builder->AddInstruction(HloInstruction::CreateReplicaId());
1718     }
1719     case HloOpcode::kPartitionId: {
1720       if ((!preset_operands &&
1721            !ParseOperands(&operands, builder, /*expected_size=*/0)) ||
1722           !ParseAttributes(attrs, allow_attributes)) {
1723         return nullptr;
1724       }
1725       return builder->AddInstruction(HloInstruction::CreatePartitionId());
1726     }
1727     case HloOpcode::kDynamicReshape: {
1728       if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1729           !ParseAttributes(attrs, allow_attributes)) {
1730         return nullptr;
1731       }
1732       return builder->AddInstruction(HloInstruction::CreateDynamicReshape(
1733           *shape, operands[0],
1734           absl::Span<HloInstruction* const>(operands).subspan(1)));
1735     }
1736     case HloOpcode::kReshape: {
1737       optional<int64_t> inferred_dimension;
1738       attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
1739                                      &inferred_dimension};
1740       if ((!preset_operands &&
1741            !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
1742           !ParseAttributes(attrs, allow_attributes)) {
1743         return nullptr;
1744       }
1745       return builder->AddInstruction(HloInstruction::CreateReshape(
1746           *shape, operands[0], inferred_dimension.value_or(-1)));
1747     }
1748     case HloOpcode::kAfterAll: {
1749       if ((!preset_operands && !ParseOperands(&operands, builder)) ||
1750           !ParseAttributes(attrs, allow_attributes)) {
1751         return nullptr;
1752       }
1753       if (operands.empty()) {
1754         return builder->AddInstruction(HloInstruction::CreateToken());
1755       }
1756       return builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
1757     }
1758     case HloOpcode::kAddDependency: {
1759       if ((!preset_operands &&
1760            !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
1761           !ParseAttributes(attrs, allow_attributes)) {
1762         return nullptr;
1763       }
1764       return builder->AddInstruction(
1765           HloInstruction::CreateAddDependency(operands[0], operands[1]));
1766     }
    case HloOpcode::kSort: {
      // `dimensions` must contain exactly one entry (the sort dimension);
      // `to_apply` is the required comparator computation.
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      // Stability defaults to false when the attribute is absent.
      optional<bool> is_stable = false;
      attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes) ||
          dimensions->size() != 1) {
        return nullptr;
      }
      // Cross-check (or infer) the result shape from the operand shapes.
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateSort(*shape, dimensions->at(0), operands,
                                     to_apply.value(), is_stable.value()));
    }
    case HloOpcode::kTuple: {
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Cross-check (or infer) the tuple shape from the operand shapes.
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferVariadicOpShape(opcode, arg_shapes);
          })) {
        return nullptr;
      }
      // HloInstruction::CreateTuple() infers the shape of the tuple from
      // operands and should not be used here: the shape parsed from the text
      // must be preserved as-is.
      return builder->AddInstruction(
          HloInstruction::CreateVariadic(*shape, HloOpcode::kTuple, operands));
    }
    case HloOpcode::kWhile: {
      // Both the loop `condition` and `body` computations are required; the
      // single operand is the loop's initial value.
      optional<HloComputation*> condition;
      optional<HloComputation*> body;
      attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
                            &condition};
      attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Validate the while shape against the condition/body signatures and
      // the init operand's shape.
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferWhileShape(
                condition.value()->ComputeProgramShape(),
                body.value()->ComputeProgramShape(), operands[0]->shape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateWhile(
          *shape, *condition, *body, /*init=*/operands[0]));
    }
    case HloOpcode::kRecv: {
      optional<int64_t> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // The parsed shape is a tuple; element 0 is the shape of the received
      // data, which is what CreateRecv expects.
      return builder->AddInstruction(HloInstruction::CreateRecv(
          shape->tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
    }
    case HloOpcode::kRecvDone: {
      optional<int64_t> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // The operand must be a channel instruction whose channel_id matches
      // this instruction's.
      if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
        return nullptr;
      }
      if (channel_id != operands[0]->channel_id()) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
    }
    case HloOpcode::kSend: {
      optional<int64_t> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      // Operands are (data, token).
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateSend(
          operands[0], operands[1], *channel_id, *is_host_transfer));
    }
    case HloOpcode::kSendDone: {
      optional<int64_t> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // The operand must be a channel instruction whose channel_id matches
      // this instruction's.
      if (dynamic_cast<const HloChannelInstruction*>(operands[0]) == nullptr) {
        return nullptr;
      }
      if (channel_id != operands[0]->channel_id()) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
    }
    case HloOpcode::kGetTupleElement: {
      // `index` selects which element of the tuple operand to extract.
      optional<int64_t> index;
      attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Cross-check (or infer) the element shape against the operand's tuple
      // shape at `index`.
      if (!maybe_infer_shape([&] {
            return ShapeUtil::GetTupleElementShape(operands[0]->shape(),
                                                   *index);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateGetTupleElement(*shape, operands[0], *index));
    }
    case HloOpcode::kCall: {
      // `to_apply` is the required callee computation.
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Cross-check (or infer) the call result shape from the operand shapes
      // and the callee's program shape.
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferCallShape(
                arg_shapes, to_apply.value()->ComputeProgramShape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateCall(*shape, operands, *to_apply));
    }
    case HloOpcode::kReduceWindow: {
      // Operands are N inputs followed by N init values; `to_apply` is the
      // required reduction computation. `window` defaults to an empty Window.
      optional<HloComputation*> reduce_computation;
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!window) {
        window.emplace();
      }
      // NOTE(review): an empty operand list passes this parity check (0 is
      // even) yet operands[0]/operands[1] are read below — confirm
      // ParseOperands guarantees a non-empty list here.
      if (operands.size() % 2) {
        TokenError(StrCat("expects an even number of operands, but has ",
                          operands.size(), " operands"));
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferReduceWindowShape(
                operands[0]->shape(), operands[1]->shape(), *window,
                reduce_computation.value()->ComputeProgramShape());
          })) {
        return nullptr;
      }
      // First half of operands are the inputs, second half the init values.
      return builder->AddInstruction(HloInstruction::CreateReduceWindow(
          *shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *window, *reduce_computation));
    }
    case HloOpcode::kConvolution: {
      // Operands are (lhs, rhs). `dim_labels` is required; window and the
      // group counts are optional with defaults applied below.
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64_t> feature_group_count;
      optional<int64_t> batch_group_count;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/true,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Defaults: empty window, ungrouped convolution.
      if (!window) {
        window.emplace();
      }
      if (!feature_group_count) {
        feature_group_count = 1;
      }
      if (!batch_group_count) {
        batch_group_count = 1;
      }
      // Absent operand_precision means DEFAULT precision for every operand.
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferConvolveShape(
                operands[0]->shape(), operands[1]->shape(),
                *feature_group_count, *batch_group_count, *window, *dnums,
                /*preferred_element_type=*/std::nullopt);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateConvolve(
          *shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
          feature_group_count.value(), batch_group_count.value(), *window,
          *dnums, precision_config));
    }
    case HloOpcode::kFft: {
      // Both the transform type and the per-dimension FFT lengths are
      // required attributes.
      optional<FftType> fft_type;
      optional<std::vector<int64_t>> fft_length;
      attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
      attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &fft_length};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Cross-check (or infer) the result shape from the operand and the FFT
      // parameters.
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferFftShape(operands[0]->shape(),
                                                 *fft_type, *fft_length);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateFft(
          *shape, operands[0], *fft_type, *fft_length));
    }
    case HloOpcode::kTriangularSolve: {
      // Options are parsed directly into the proto message (fields of
      // TriangularSolveOptions), alongside any non-proto attributes in
      // `attrs`.
      TriangularSolveOptions options;
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          (allow_attributes && !ParseAttributesAsProtoMessage(
                                   /*non_proto_attrs=*/attrs, &options))) {
        return nullptr;
      }
      // Cross-check (or infer) the result shape from (a, b) and the options.
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferTriangularSolveShape(
                operands[0]->shape(), operands[1]->shape(), options);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateTriangularSolve(
          *shape, operands[0], operands[1], options));
    }
    case HloOpcode::kCompare: {
      // `direction` (EQ/NE/LT/...) is required; the comparison `type` is
      // optional and passed through as-is when absent.
      optional<ComparisonDirection> direction;
      optional<Comparison::Type> type;
      attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
                            &direction};
      attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Cross-check (or infer) the result shape via binary-op inference.
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBinaryOpShape(opcode, operands[0],
                                                      operands[1]);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateCompare(
          *shape, operands[0], operands[1], *direction, type));
    }
    case HloOpcode::kCholesky: {
      // Options are parsed directly into the CholeskyOptions proto message,
      // alongside any non-proto attributes in `attrs`.
      CholeskyOptions options;
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          (allow_attributes && !ParseAttributesAsProtoMessage(
                                   /*non_proto_attrs=*/attrs, &options))) {
        return nullptr;
      }
      // Cross-check (or infer) the result shape from the operand shape.
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferCholeskyShape(operands[0]->shape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateCholesky(*shape, operands[0], options));
    }
    case HloOpcode::kBroadcast: {
      // Operands must be parsed before the attributes so the requiredness of
      // `dimensions` can depend on whether the operand is a scalar.
      if (!preset_operands &&
          !ParseOperands(&operands, builder, /*expected_size=*/1)) {
        return nullptr;
      }

      // The `dimensions` attr is optional if the broadcasted operand is a
      // scalar; in that case we can infer it to be {}.
      bool operand_is_scalar = ShapeUtil::IsScalar(operands[0]->shape());
      optional<std::vector<int64_t>> broadcast_dimensions;
      attrs["dimensions"] = {/*required=*/!operand_is_scalar,
                             AttrTy::kBracedInt64List, &broadcast_dimensions};
      if (!ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (operand_is_scalar && !broadcast_dimensions.has_value()) {
        broadcast_dimensions.emplace();
      }

      // Cross-check (or infer) the result shape from the operand shape and
      // the broadcast dimensions.
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBroadcastShape(operands[0]->shape(),
                                                       *broadcast_dimensions);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateBroadcast(
          *shape, operands[0], *broadcast_dimensions));
    }
    case HloOpcode::kConcatenate: {
      // `dimensions` must contain exactly one entry: the concatenation
      // dimension.
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes) ||
          dimensions->size() != 1) {
        return nullptr;
      }
      // Cross-check (or infer) the result shape from the operand shapes and
      // the concat dimension.
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferConcatOpShape(arg_shapes,
                                                      dimensions->at(0));
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateConcatenate(
          *shape, operands, dimensions->at(0)));
    }
    case HloOpcode::kMap: {
      // `to_apply` is the required mapped computation; `dimensions` is
      // accepted but only used for shape validation below.
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                             &dimensions};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // NOTE(review): `*dimensions` is dereferenced below although the
      // attribute is optional — confirm the inference lambda is never run
      // when `dimensions` is absent.
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferMapShape(
                arg_shapes, to_apply.value()->ComputeProgramShape(),
                *dimensions);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateMap(*shape, operands, *to_apply));
    }
    case HloOpcode::kReduce: {
      // Variadic reduce: the operand list is [inputs..., init_values...], so
      // it must contain an even number of entries.
      optional<HloComputation*> reduce_computation;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      optional<std::vector<int64_t>> dimensions_to_reduce;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions_to_reduce};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (operands.size() % 2) {
        TokenError(StrCat("expects an even number of operands, but has ",
                          operands.size(), " operands"));
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<const Shape*, 2> arg_shapes;
            arg_shapes.reserve(operands.size());
            for (auto* operand : operands) {
              arg_shapes.push_back(&operand->shape());
            }
            return ShapeInference::InferReduceShape(
                arg_shapes, *dimensions_to_reduce,
                reduce_computation.value()->ComputeProgramShape());
          })) {
        return nullptr;
      }
      // First half of the operands are the inputs, second half the matching
      // init values.
      return builder->AddInstruction(HloInstruction::CreateReduce(
          *shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *dimensions_to_reduce, *reduce_computation));
    }
    case HloOpcode::kReverse: {
      // Reverse flips its single operand along each dimension listed in the
      // required "dimensions" attribute.
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferReverseShape(operands[0]->shape(),
                                                     *dimensions);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateReverse(*shape, operands[0], *dimensions));
    }
    case HloOpcode::kSelectAndScatter: {
      // Operands are [operand, source, init_value]; the `select` computation
      // chooses the window element and `scatter` combines scattered values.
      optional<HloComputation*> select;
      attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
      optional<HloComputation*> scatter;
      attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/3)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // An omitted "window" attribute means a default-constructed Window.
      if (!window) {
        window.emplace();
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferSelectAndScatterShape(
                operands[0]->shape(), select.value()->ComputeProgramShape(),
                *window, operands[1]->shape(), operands[2]->shape(),
                scatter.value()->ComputeProgramShape());
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
          *shape, /*operand=*/operands[0], *select, *window,
          /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
    }
    case HloOpcode::kSlice: {
      // The required "slice" attribute carries per-dimension starts, limits
      // and strides. No shape inference is attempted for slice.
      optional<SliceRanges> slice_ranges;
      attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateSlice(
          *shape, operands[0], slice_ranges->starts, slice_ranges->limits,
          slice_ranges->strides));
    }
    case HloOpcode::kDynamicSlice: {
      optional<std::vector<int64_t>> dynamic_slice_sizes;
      attrs["dynamic_slice_sizes"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (operands.empty()) {
        TokenError("Expected at least one operand.");
        return nullptr;
      }
      // Two accepted operand forms: [operand, rank-1 start_indices vector],
      // or [operand, one scalar start index per operand dimension].
      if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
          operands.size() != 1 + operands[0]->shape().rank()) {
        TokenError("Wrong number of operands.");
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateDynamicSlice(
          *shape, /*operand=*/operands[0],
          /*start_indices=*/absl::MakeSpan(operands).subspan(1),
          *dynamic_slice_sizes));
    }
    case HloOpcode::kDynamicUpdateSlice: {
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (operands.size() < 2) {
        TokenError("Expected at least two operands.");
        return nullptr;
      }
      // Two accepted operand forms: [operand, update, rank-1 start_indices
      // vector], or [operand, update, one scalar start index per operand
      // dimension].
      if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
          operands.size() != 2 + operands[0]->shape().rank()) {
        TokenError("Wrong number of operands.");
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          *shape, /*operand=*/operands[0], /*update=*/operands[1],
          /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
    }
    case HloOpcode::kTranspose: {
      // Transpose permutes the single operand's dimensions according to the
      // required "dimensions" attribute.
      optional<std::vector<int64_t>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferTransposeShape(operands[0]->shape(),
                                                       *dimensions);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateTranspose(*shape, operands[0], *dimensions));
    }
    case HloOpcode::kBatchNormTraining: {
      // Operands are [operand, scale, offset]; "epsilon" and "feature_index"
      // are both required attributes.
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64_t> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/3)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBatchNormTrainingShape(
                operands[0]->shape(), operands[1]->shape(),
                operands[2]->shape(), *feature_index);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
          *shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*offset=*/operands[2], *epsilon, *feature_index));
    }
    case HloOpcode::kBatchNormInference: {
      // Operands are [operand, scale, offset, mean, variance]; "epsilon" and
      // "feature_index" are both required attributes.
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64_t> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/5)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBatchNormInferenceShape(
                operands[0]->shape(), operands[1]->shape(),
                operands[2]->shape(), operands[3]->shape(),
                operands[4]->shape(), *feature_index);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateBatchNormInference(
          *shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*offset=*/operands[2], /*mean=*/operands[3],
          /*variance=*/operands[4], *epsilon, *feature_index));
    }
    case HloOpcode::kBatchNormGrad: {
      // Operands are [operand, scale, mean, variance, grad_output]; "epsilon"
      // and "feature_index" are both required attributes.
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64_t> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/5)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferBatchNormGradShape(
                operands[0]->shape(), operands[1]->shape(),
                operands[2]->shape(), operands[3]->shape(),
                operands[4]->shape(), *feature_index);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
          *shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*mean=*/operands[2], /*variance=*/operands[3],
          /*grad_output=*/operands[4], *epsilon, *feature_index));
    }
    case HloOpcode::kPad: {
      // Operands are [operand, padding_value]; the required "padding"
      // attribute supplies the per-dimension PaddingConfig.
      optional<PaddingConfig> padding;
      attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferPadShape(
                operands[0]->shape(), operands[1]->shape(), *padding);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreatePad(
          *shape, operands[0], /*padding_value=*/operands[1], *padding));
    }
    case HloOpcode::kFusion: {
      // Fusion wraps the required "calls" computation with the required
      // fusion "kind" (e.g. loop/input/output). No shape inference here.
      optional<HloComputation*> fusion_computation;
      attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
                        &fusion_computation};
      optional<HloInstruction::FusionKind> fusion_kind;
      attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateFusion(
          *shape, *fusion_kind, operands, *fusion_computation));
    }
2434     case HloOpcode::kInfeed: {
2435       optional<std::string> config;
2436       attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
2437       if ((!preset_operands &&
2438            !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
2439           !ParseAttributes(attrs, allow_attributes)) {
2440         return nullptr;
2441       }
2442       // We need to know the infeed data shape to construct the infeed
2443       // instruction. This is the zero-th element of the tuple-shaped output of
2444       // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
2445       // if the shape is not a non-empty tuple, so add guard so an error message
2446       // can be emitted instead of a check fail
2447       if (!shape->IsTuple() && !ShapeUtil::IsEmptyTuple(*shape)) {
2448         TokenError("infeed must have a non-empty tuple shape");
2449         return nullptr;
2450       }
2451       return builder->AddInstruction(HloInstruction::CreateInfeed(
2452           ShapeUtil::GetTupleElementShape(*shape, 0), operands[0],
2453           config ? *config : ""));
2454     }
    case HloOpcode::kOutfeed: {
      // Operands are [data, token]. If "outfeed_shape" is omitted, the data
      // operand's shape is used as the outfeed shape.
      optional<std::string> config;
      optional<Shape> outfeed_shape;
      attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
      attrs["outfeed_shape"] = {/*required=*/false, AttrTy::kShape,
                                &outfeed_shape};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      HloInstruction* const outfeed_input = operands[0];
      HloInstruction* const outfeed_token = operands[1];
      // Note: this local `shape` shadows the parsed result-shape variable from
      // the enclosing scope for the remainder of this case.
      const Shape shape =
          outfeed_shape.has_value() ? *outfeed_shape : outfeed_input->shape();
      return builder->AddInstruction(HloInstruction::CreateOutfeed(
          shape, outfeed_input, outfeed_token, config ? *config : ""));
    }
    case HloOpcode::kRng: {
      // Rng takes the distribution's parameters as operands and the required
      // "distribution" attribute (e.g. uniform/normal).
      optional<RandomDistribution> distribution;
      attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
                               &distribution};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateRng(*shape, *distribution, operands));
    }
    case HloOpcode::kRngGetAndUpdateState: {
      // Takes no operands; the required "delta" attribute is the amount the
      // RNG state is advanced by.
      optional<int64_t> delta;
      attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/0)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      return builder->AddInstruction(
          HloInstruction::CreateRngGetAndUpdateState(*shape, *delta));
    }
    case HloOpcode::kRngBitGenerator: {
      // Takes an initial-state operand plus the required "algorithm"
      // attribute selecting the RNG algorithm.
      optional<RandomAlgorithm> algorithm;
      attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
                            &algorithm};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // NOTE(review): operands[0] is used without an operand-count check, and
      // ParseOperands is called without /*expected_size=*/1 here — confirm an
      // empty operand list cannot reach this point.
      return builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
          *shape, operands[0], *algorithm));
    }
    case HloOpcode::kReducePrecision: {
      // Reduces the operand's float precision to the given exponent/mantissa
      // bit widths; both attributes are required.
      optional<int64_t> exponent_bits;
      optional<int64_t> mantissa_bits;
      attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &exponent_bits};
      attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &mantissa_bits};
      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateReducePrecision(
          *shape, operands[0], static_cast<int>(*exponent_bits),
          static_cast<int>(*mantissa_bits)));
    }
    case HloOpcode::kConditional: {
      // Two accepted syntaxes, selected by the branch-index operand's element
      // type: a PRED scalar with true_computation/false_computation, or an
      // S32 scalar with a braced branch_computations list.
      optional<HloComputation*> true_computation;
      optional<HloComputation*> false_computation;
      optional<std::vector<HloComputation*>> branch_computations;
      if (!preset_operands && !ParseOperands(&operands, builder)) {
        return nullptr;
      }
      if (!ShapeUtil::IsScalar(operands[0]->shape())) {
        TokenError("The first operand must be a scalar");
        return nullptr;
      }
      const bool branch_index_is_bool =
          operands[0]->shape().element_type() == PRED;
      if (branch_index_is_bool) {
        attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
                                     &true_computation};
        attrs["false_computation"] = {
            /*required=*/true, AttrTy::kHloComputation, &false_computation};
      } else {
        if (operands[0]->shape().element_type() != S32) {
          TokenError("The first operand must be a scalar of PRED or S32");
          return nullptr;
        }
        attrs["branch_computations"] = {/*required=*/true,
                                        AttrTy::kBracedHloComputationList,
                                        &branch_computations};
      }
      if (!ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }
      // Canonicalize the PRED form into the branch-list form as
      // {true_computation, false_computation}.
      if (branch_index_is_bool) {
        branch_computations.emplace({*true_computation, *false_computation});
      }
      // Require one operand per branch plus the branch index itself.
      // NOTE(review): this failure path returns nullptr without emitting a
      // TokenError, unlike the checks above.
      if (branch_computations->empty() ||
          operands.size() != branch_computations->size() + 1) {
        return nullptr;
      }
      if (!maybe_infer_shape([&] {
            absl::InlinedVector<ProgramShape, 2> branch_computation_shapes;
            branch_computation_shapes.reserve(branch_computations->size());
            for (auto* computation : *branch_computations) {
              branch_computation_shapes.push_back(
                  computation->ComputeProgramShape());
            }
            absl::InlinedVector<Shape, 2> branch_operand_shapes;
            branch_operand_shapes.reserve(operands.size() - 1);
            for (int i = 1; i < operands.size(); ++i) {
              branch_operand_shapes.push_back(operands[i]->shape());
            }
            return ShapeInference::InferConditionalShape(
                operands[0]->shape(), branch_computation_shapes,
                branch_operand_shapes);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateConditional(
          *shape, /*branch_index=*/operands[0],
          absl::MakeSpan(*branch_computations),
          absl::MakeSpan(operands).subspan(1)));
    }
    case HloOpcode::kCustomCall: {
      // Custom-call accepts a large set of optional attributes on top of the
      // required "custom_call_target". One of four CreateCustomCall overloads
      // is chosen below depending on which of operand_layout_constraints,
      // to_apply, or called_computations was supplied; remaining attributes
      // are applied to the created instruction afterwards.
      optional<std::string> custom_call_target;
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64_t> feature_group_count;
      optional<int64_t> batch_group_count;
      optional<std::vector<Shape>> operand_layout_constraints;
      optional<bool> custom_call_has_side_effect;
      optional<HloComputation*> to_apply;
      optional<
          std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>>
          output_to_operand_aliasing;
      optional<PaddingType> padding_type;
      optional<std::vector<HloComputation*>> called_computations;
      optional<CustomCallSchedule> custom_call_schedule;
      optional<CustomCallApiVersion> api_version;
      attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
                                     &custom_call_target};
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/false,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      attrs["operand_layout_constraints"] = {
          /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
      attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
                                              &custom_call_has_side_effect};
      attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
                           &to_apply};
      attrs["called_computations"] = {/*required=*/false,
                                      AttrTy::kBracedHloComputationList,
                                      &called_computations};
      attrs["output_to_operand_aliasing"] = {/*required=*/false,
                                             AttrTy::kInstructionAliasing,
                                             &output_to_operand_aliasing};

      attrs["padding_type"] = {/*required=*/false, AttrTy::kPaddingType,
                               &padding_type};

      optional<Literal> literal;
      attrs["literal"] = {/*required=*/false, AttrTy::kLiteral, &literal};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      HloInstruction* instruction;
      // "to_apply" and "called_computations" are mutually exclusive.
      if (called_computations.has_value() && to_apply.has_value()) {
        TokenError(
            "A single instruction can't have both to_apply and "
            "calls field");
        return nullptr;
      }
      attrs["schedule"] = {/*required=*/false, AttrTy::kCustomCallSchedule,
                           &custom_call_schedule};
      attrs["api_version"] = {/*required=*/false, AttrTy::kCustomCallApiVersion,
                              &api_version};
      if ((!preset_operands && !ParseOperands(&operands, builder)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }

      if (api_version.has_value() &&
          *api_version == CustomCallApiVersion::API_VERSION_UNSPECIFIED) {
        TokenError(StrCat("Invalid API version: ",
                          CustomCallApiVersion_Name(*api_version)));
        return nullptr;
      }
      if (operand_layout_constraints.has_value()) {
        // Layout-constrained form: the result shape and every per-operand
        // constraint must carry a layout, and each constraint must be
        // compatible with the corresponding operand's shape.
        if (!LayoutUtil::HasLayout(*shape)) {
          TokenError("Layout must be set on layout-constrained custom call");
          return nullptr;
        }
        if (operands.size() != operand_layout_constraints->size()) {
          TokenError(StrCat("Expected ", operands.size(),
                            " operand layout constraints, ",
                            operand_layout_constraints->size(), " given"));
          return nullptr;
        }
        for (int64_t i = 0; i < operands.size(); ++i) {
          const Shape& operand_shape_with_layout =
              (*operand_layout_constraints)[i];
          if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
            TokenError(StrCat(
                "Operand layout constraint shape ",
                ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
                " for operand ", i, " does not have a layout"));
            return nullptr;
          }
          if (!ShapeUtil::Compatible(operand_shape_with_layout,
                                     operands[i]->shape())) {
            TokenError(StrCat(
                "Operand layout constraint shape ",
                ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
                " for operand ", i, " is not compatible with operand shape ",
                ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
            return nullptr;
          }
        }
        instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
            *shape, operands, *custom_call_target, *operand_layout_constraints,
            ""));
      } else {
        if (to_apply.has_value()) {
          instruction =
              builder->AddInstruction(HloInstruction::CreateCustomCall(
                  *shape, operands, *to_apply, *custom_call_target, ""));
        } else if (called_computations.has_value()) {
          instruction =
              builder->AddInstruction(HloInstruction::CreateCustomCall(
                  *shape, operands, *called_computations, *custom_call_target,
                  ""));
        } else {
          instruction =
              builder->AddInstruction(HloInstruction::CreateCustomCall(
                  *shape, operands, *custom_call_target, ""));
        }
      }
      // Propagate the remaining optional attributes onto the freshly created
      // custom-call instruction.
      auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
      if (window.has_value()) {
        custom_call_instr->set_window(*window);
      }
      if (dnums.has_value()) {
        custom_call_instr->set_convolution_dimension_numbers(*dnums);
      }
      if (feature_group_count.has_value()) {
        custom_call_instr->set_feature_group_count(*feature_group_count);
      }
      if (batch_group_count.has_value()) {
        custom_call_instr->set_batch_group_count(*batch_group_count);
      }
      if (padding_type.has_value()) {
        custom_call_instr->set_padding_type(*padding_type);
      }
      if (custom_call_has_side_effect.has_value()) {
        custom_call_instr->set_custom_call_has_side_effect(
            *custom_call_has_side_effect);
      }
      if (custom_call_schedule.has_value()) {
        custom_call_instr->set_custom_call_schedule(*custom_call_schedule);
      }
      if (api_version.has_value()) {
        custom_call_instr->set_api_version(*api_version);
      }
      if (output_to_operand_aliasing.has_value()) {
        custom_call_instr->set_output_to_operand_aliasing(
            std::move(*output_to_operand_aliasing));
      }
      if (literal.has_value()) {
        custom_call_instr->set_literal(std::move(*literal));
      }
      // Default to DEFAULT precision for every operand when the attribute is
      // absent.
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      *custom_call_instr->mutable_precision_config() = precision_config;
      return instruction;
    }
    case HloOpcode::kDot: {
      // Dot takes exactly two operands. All dimension-number attributes are
      // optional; absent lists leave the corresponding DotDimensionNumbers
      // fields empty.
      optional<std::vector<int64_t>> lhs_contracting_dims;
      attrs["lhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
      optional<std::vector<int64_t>> rhs_contracting_dims;
      attrs["rhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
      optional<std::vector<int64_t>> lhs_batch_dims;
      attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &lhs_batch_dims};
      optional<std::vector<int64_t>> rhs_batch_dims;
      attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &rhs_batch_dims};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};

      if ((!preset_operands &&
           !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
          !ParseAttributes(attrs, allow_attributes)) {
        return nullptr;
      }

      DotDimensionNumbers dnum;
      if (lhs_contracting_dims) {
        *dnum.mutable_lhs_contracting_dimensions() = {
            lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
      }
      if (rhs_contracting_dims) {
        *dnum.mutable_rhs_contracting_dimensions() = {
            rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
      }
      if (lhs_batch_dims) {
        *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
                                                lhs_batch_dims->end()};
      }
      if (rhs_batch_dims) {
        *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
                                                rhs_batch_dims->end()};
      }

      // Default to DEFAULT precision for both operands when the attribute is
      // absent.
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      if (!maybe_infer_shape([&] {
            return ShapeInference::InferDotOpShape(
                operands[0]->shape(), operands[1]->shape(), dnum,
                /*preferred_element_type=*/std::nullopt);
          })) {
        return nullptr;
      }
      return builder->AddInstruction(HloInstruction::CreateDot(
          *shape, operands[0], operands[1], dnum, precision_config));
    }
2803     case HloOpcode::kGather: {
2804       optional<std::vector<int64_t>> offset_dims;
2805       attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
2806                               &offset_dims};
2807       optional<std::vector<int64_t>> collapsed_slice_dims;
2808       attrs["collapsed_slice_dims"] = {
2809           /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
2810       optional<std::vector<int64_t>> start_index_map;
2811       attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
2812                                   &start_index_map};
2813       optional<int64_t> index_vector_dim;
2814       attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2815                                    &index_vector_dim};
2816       optional<std::vector<int64_t>> slice_sizes;
2817       attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
2818                               &slice_sizes};
2819       optional<bool> indices_are_sorted = false;
2820       attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2821                                      &indices_are_sorted};
2822 
2823       if ((!preset_operands &&
2824            !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
2825           !ParseAttributes(attrs, allow_attributes)) {
2826         return nullptr;
2827       }
2828 
2829       GatherDimensionNumbers dim_numbers =
2830           HloGatherInstruction::MakeGatherDimNumbers(
2831               /*offset_dims=*/*offset_dims,
2832               /*collapsed_slice_dims=*/*collapsed_slice_dims,
2833               /*start_index_map=*/*start_index_map,
2834               /*index_vector_dim=*/*index_vector_dim);
2835       if (!maybe_infer_shape([&] {
2836             return ShapeInference::InferGatherShape(operands[0]->shape(),
2837                                                     operands[1]->shape(),
2838                                                     dim_numbers, *slice_sizes);
2839           })) {
2840         return nullptr;
2841       }
2842       return builder->AddInstruction(HloInstruction::CreateGather(
2843           *shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
2844           dim_numbers, *slice_sizes, indices_are_sorted.value()));
2845     }
2846     case HloOpcode::kScatter: {
2847       optional<std::vector<int64_t>> update_window_dims;
2848       attrs["update_window_dims"] = {
2849           /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
2850       optional<std::vector<int64_t>> inserted_window_dims;
2851       attrs["inserted_window_dims"] = {
2852           /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
2853       optional<std::vector<int64_t>> scatter_dims_to_operand_dims;
2854       attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
2855                                                AttrTy::kBracedInt64List,
2856                                                &scatter_dims_to_operand_dims};
2857       optional<int64_t> index_vector_dim;
2858       attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
2859                                    &index_vector_dim};
2860 
2861       optional<HloComputation*> update_computation;
2862       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
2863                            &update_computation};
2864       optional<bool> indices_are_sorted = false;
2865       attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
2866                                      &indices_are_sorted};
2867       optional<bool> unique_indices = false;
2868       attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
2869                                  &unique_indices};
2870 
2871       if ((!preset_operands && !ParseOperands(&operands, builder)) ||
2872           !ParseAttributes(attrs, allow_attributes)) {
2873         return nullptr;
2874       }
2875 
2876       if (operands.size() % 2 == 0) {
2877         TokenError(StrCat("expects an odd number of operands, but has ",
2878                           operands.size(), " operands"));
2879         return nullptr;
2880       }
2881 
2882       ScatterDimensionNumbers dim_numbers =
2883           HloScatterInstruction::MakeScatterDimNumbers(
2884               /*update_window_dims=*/*update_window_dims,
2885               /*inserted_window_dims=*/*inserted_window_dims,
2886               /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
2887               /*index_vector_dim=*/*index_vector_dim);
2888 
2889       if (!maybe_infer_shape([&] {
2890             absl::InlinedVector<const Shape*, 3> arg_shapes;
2891             arg_shapes.reserve(operands.size());
2892             for (auto* operand : operands) {
2893               arg_shapes.push_back(&operand->shape());
2894             }
2895             return ShapeInference::InferScatterShape(
2896                 arg_shapes, update_computation.value()->ComputeProgramShape(),
2897                 dim_numbers);
2898           })) {
2899         return nullptr;
2900       }
2901       auto input_count = operands.size() / 2;
2902       auto operand_span = absl::MakeConstSpan(operands);
2903       return builder->AddInstruction(HloInstruction::CreateScatter(
2904           *shape, operand_span.first(input_count), operands[input_count],
2905           operand_span.last(input_count), *update_computation, dim_numbers,
2906           indices_are_sorted.value(), unique_indices.value()));
2907     }
2908     case HloOpcode::kDomain: {
2909       DomainData domain;
2910       attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
2911       if ((!preset_operands &&
2912            !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
2913           !ParseAttributes(attrs, allow_attributes)) {
2914         return nullptr;
2915       }
2916       if (!maybe_infer_shape([&] {
2917             return ShapeInference::InferUnaryOpShape(opcode, operands[0]);
2918           })) {
2919         return nullptr;
2920       }
2921       return builder->AddInstruction(HloInstruction::CreateDomain(
2922           *shape, operands[0], std::move(domain.exit_metadata),
2923           std::move(domain.entry_metadata)));
2924     }
2925     case HloOpcode::kGetDimensionSize: {
2926       optional<std::vector<int64_t>> dimensions;
2927       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2928                              &dimensions};
2929       if ((!preset_operands &&
2930            !ParseOperands(&operands, builder, /*expected_size=*/1)) ||
2931           !ParseAttributes(attrs, allow_attributes)) {
2932         return nullptr;
2933       }
2934       if (!maybe_infer_shape([&] {
2935             return ShapeInference::InferGetDimensionSizeShape(
2936                 operands[0]->shape(), dimensions->at(0));
2937           })) {
2938         return nullptr;
2939       }
2940       return builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
2941           *shape, operands[0], (*dimensions)[0]));
2942     }
2943     case HloOpcode::kSetDimensionSize: {
2944       optional<std::vector<int64_t>> dimensions;
2945       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
2946                              &dimensions};
2947       if ((!preset_operands &&
2948            !ParseOperands(&operands, builder, /*expected_size=*/2)) ||
2949           !ParseAttributes(attrs, allow_attributes)) {
2950         return nullptr;
2951       }
2952       if (!maybe_infer_shape([&] {
2953             return ShapeInference::InferSetDimensionSizeShape(
2954                 operands[0]->shape(), operands[1]->shape(), dimensions->at(0));
2955           })) {
2956         return nullptr;
2957       }
2958       return builder->AddInstruction(HloInstruction::CreateSetDimensionSize(
2959           *shape, operands[0], operands[1], (*dimensions)[0]));
2960     }
2961   }
2962   return nullptr;
2963 }  // NOLINT(readability/fn_size)
2964 
2965 // ::= '{' (single_sharding | tuple_sharding) '}'
2966 //
2967 // tuple_sharding ::= /*empty*/ | single_sharding (',' single_sharding)*
ParseSharding(OpSharding * sharding)2968 bool HloParserImpl::ParseSharding(OpSharding* sharding) {
2969   // A single sharding starts with '{' and is not followed by '{'.
2970   // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
2971   // an empty tuple.
2972   if (!ParseToken(TokKind::kLbrace,
2973                   "expected '{' to start sharding attribute")) {
2974     return false;
2975   }
2976 
2977   if (lexer_.GetKind() != TokKind::kLbrace &&
2978       lexer_.GetKind() != TokKind::kRbrace) {
2979     return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
2980   }
2981 
2982   // Tuple sharding.
2983   // Allow empty tuple shardings.
2984   if (lexer_.GetKind() != TokKind::kRbrace) {
2985     do {
2986       if (!ParseSingleSharding(sharding->add_tuple_shardings(),
2987                                /*lbrace_pre_lexed=*/false)) {
2988         return false;
2989       }
2990     } while (EatIfPresent(TokKind::kComma));
2991   }
2992   sharding->set_type(OpSharding::TUPLE);
2993 
2994   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
2995 }
2996 
2997 // frontend_attributes ::= '{' attributes '}'
2998 // attributes
2999 //   ::= /*empty*/
3000 //   ::= attribute '=' value (',' attribute '=' value)*
ParseFrontendAttributes(FrontendAttributes * frontend_attributes)3001 bool HloParserImpl::ParseFrontendAttributes(
3002     FrontendAttributes* frontend_attributes) {
3003   CHECK(frontend_attributes != nullptr);
3004   if (!ParseToken(TokKind::kLbrace,
3005                   "expected '{' to start frontend attributes")) {
3006     return false;
3007   }
3008   if (lexer_.GetKind() == TokKind::kRbrace) {
3009     // empty
3010   } else {
3011     do {
3012       std::string attribute;
3013       if (!ParseAttributeName(&attribute)) {
3014         return false;
3015       }
3016       if (lexer_.GetKind() != TokKind::kString) {
3017         return false;
3018       }
3019       (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
3020       lexer_.Lex();
3021     } while (EatIfPresent(TokKind::kComma));
3022   }
3023   return ParseToken(TokKind::kRbrace,
3024                     "expects '}' at the end of frontend attributes");
3025 }
3026 
3027 // ::= '{' 'replicated'? 'manual'? 'maximal'? ('device=' int)? shape?
3028 //         ('devices=' ('[' dims ']')* device_list)?
3029 //         ('metadata=' metadata)* '}'
3030 //
3031 // dims ::= int_list
3031 // device_list ::= int_list
3032 // metadata ::= single_metadata |
3033 //              ('{' [single_metadata (',' single_metadata)*] '}')
3034 // last_tile_dims ::= sharding_type_list
// Parses one non-tuple sharding into `sharding`. When `lbrace_pre_lexed` is
// true, the caller (ParseSharding) has already consumed the opening '{'.
// First collects all flags/attributes up to the closing '}', then validates
// the combination and populates the OpSharding proto.
bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,
                                        bool lbrace_pre_lexed) {
  if (!lbrace_pre_lexed &&
      !ParseToken(TokKind::kLbrace,
                  "expected '{' to start sharding attribute")) {
    return false;
  }

  // Remember the start location so validation errors point at the sharding.
  LocTy loc = lexer_.GetLoc();
  bool maximal = false;
  bool replicated = false;
  bool manual = false;
  bool last_tile_dim_replicate = false;
  bool last_tile_dims = false;
  std::vector<int64_t> devices;
  std::vector<int64_t> tile_assignment_dimensions;
  std::vector<OpSharding::Type> subgroup_types;
  // Accumulation pass: consume tokens until the closing '}'.
  while (lexer_.GetKind() != TokKind::kRbrace) {
    switch (lexer_.GetKind()) {
      case TokKind::kw_maximal:
        maximal = true;
        lexer_.Lex();
        break;
      case TokKind::kw_replicated:
        replicated = true;
        lexer_.Lex();
        break;
      case TokKind::kw_manual:
        manual = true;
        lexer_.Lex();
        break;
      case TokKind::kAttributeName: {
        if (lexer_.GetStrVal() == "device") {
          // device=<int>: single device for a maximal sharding.
          if (lexer_.Lex() != TokKind::kInt) {
            return TokenError("device= attribute must be an integer");
          }
          devices = {lexer_.GetInt64Val()};
          lexer_.Lex();
        } else if (lexer_.GetStrVal() == "devices") {
          // devices=[dims]ids: tile assignment dimensions in brackets,
          // followed by the flat comma-separated device id list.
          lexer_.Lex();
          if (!ParseToken(TokKind::kLsquare,
                          "expected '[' to start sharding devices shape")) {
            return false;
          }

          do {
            int64_t dim;
            if (!ParseInt64(&dim)) {
              return false;
            }
            tile_assignment_dimensions.push_back(dim);
          } while (EatIfPresent(TokKind::kComma));

          if (!ParseToken(TokKind::kRsquare,
                          "expected ']' to start sharding devices shape")) {
            return false;
          }
          do {
            int64_t device;
            if (!ParseInt64(&device)) {
              return false;
            }
            devices.push_back(device);
          } while (EatIfPresent(TokKind::kComma));
        } else if (lexer_.GetStrVal() == "metadata") {
          // metadata=...: a single metadata entry or a braced list of them.
          lexer_.Lex();
          if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) {
            return false;
          }
        } else if (lexer_.GetStrVal() == "last_tile_dims") {
          last_tile_dims = true;
          lexer_.Lex();
          if (!ParseListShardingType(&subgroup_types)) {
            return false;
          }
        } else {
          return TokenError(
              "unknown attribute in sharding: expected device=, devices= "
              "metadata= or last_tile_dims= ");
        }
        break;
      }
      case TokKind::kw_last_tile_dim_replicate:
        last_tile_dim_replicate = true;
        lexer_.Lex();
        break;
      case TokKind::kRbrace:
        break;
      default:
        return TokenError("unexpected token");
    }
  }

  // Validation pass: the flags are mutually exclusive sharding kinds.
  if (replicated) {
    if (!devices.empty()) {
      return Error(loc,
                   "replicated shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::REPLICATED);
  } else if (maximal) {
    if (devices.size() != 1) {
      return Error(loc,
                   "maximal shardings should have exactly one device assigned");
    }
    sharding->set_type(OpSharding::MAXIMAL);
    sharding->add_tile_assignment_devices(devices[0]);
  } else if (manual) {
    if (!devices.empty()) {
      return Error(loc,
                   "manual shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::MANUAL);
  } else {
    // Tiled (OTHER) sharding: requires tile dimensions and >= 2 devices.
    if (devices.size() <= 1) {
      return Error(
          loc, "non-maximal shardings must have more than one device assigned");
    }
    if (tile_assignment_dimensions.empty()) {
      return Error(
          loc,
          "non-maximal shardings must have a tile assignment list including "
          "dimensions");
    }
    sharding->set_type(OpSharding::OTHER);
    for (int64_t dim : tile_assignment_dimensions) {
      sharding->add_tile_assignment_dimensions(dim);
    }
    for (int64_t device : devices) {
      sharding->add_tile_assignment_devices(device);
    }

    if (last_tile_dims) {
      for (OpSharding::Type type : subgroup_types) {
        sharding->add_last_tile_dims(type);
      }
    } else {
      sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate);
    }
  }

  // Consume the closing '}'.
  lexer_.Lex();
  return true;
}
3178 
3179 // parameter_replication ::=
3180 //   '{' ('true' | 'false')* (',' ('true' | 'false'))*  '}'
ParseParameterReplication(ParameterReplication * parameter_replication)3181 bool HloParserImpl::ParseParameterReplication(
3182     ParameterReplication* parameter_replication) {
3183   if (!ParseToken(TokKind::kLbrace,
3184                   "expected '{' to start parameter_replication attribute")) {
3185     return false;
3186   }
3187 
3188   if (lexer_.GetKind() != TokKind::kRbrace) {
3189     do {
3190       if (lexer_.GetKind() == TokKind::kw_true) {
3191         parameter_replication->add_replicated_at_leaf_buffers(true);
3192       } else if (lexer_.GetKind() == TokKind::kw_false) {
3193         parameter_replication->add_replicated_at_leaf_buffers(false);
3194       } else {
3195         return false;
3196       }
3197       lexer_.Lex();
3198     } while (EatIfPresent(TokKind::kComma));
3199   }
3200 
3201   return ParseToken(TokKind::kRbrace,
3202                     "expected '}' to end parameter_replication attribute");
3203 }
3204 
3205 // replica_groups ::='{' int64_tlist_elements '}'
3206 // int64_tlist_elements
3207 //   ::= /*empty*/
3208 //   ::= int64_tlist (',' int64_tlist)*
3209 // int64_tlist ::= '{' int64_elements '}'
3210 // int64_elements
3211 //   ::= /*empty*/
3212 //   ::= int64_val (',' int64_val)*
ParseReplicaGroupsOnly(std::vector<ReplicaGroup> * replica_groups)3213 bool HloParserImpl::ParseReplicaGroupsOnly(
3214     std::vector<ReplicaGroup>* replica_groups) {
3215   std::vector<std::vector<int64_t>> result;
3216   if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
3217                           &result)) {
3218     return false;
3219   }
3220   *replica_groups = CreateReplicaGroups(result);
3221   return true;
3222 }
3223 
3224 // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
3225 //            'exit=' exit_sharding '}'
ParseDomain(DomainData * domain)3226 bool HloParserImpl::ParseDomain(DomainData* domain) {
3227   absl::flat_hash_map<std::string, AttrConfig> attrs;
3228   optional<std::string> kind;
3229   optional<OpSharding> entry_sharding;
3230   optional<OpSharding> exit_sharding;
3231   attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
3232   attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
3233   attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
3234   if (!ParseSubAttributes(attrs)) {
3235     return false;
3236   }
3237   if (*kind == ShardingMetadata::KindName()) {
3238     auto entry_sharding_ptr = std::make_unique<HloSharding>(
3239         HloSharding::FromProto(*entry_sharding).ValueOrDie());
3240     auto exit_sharding_ptr = std::make_unique<HloSharding>(
3241         HloSharding::FromProto(*exit_sharding).ValueOrDie());
3242     domain->entry_metadata =
3243         std::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
3244     domain->exit_metadata =
3245         std::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
3246   } else {
3247     return TokenError(StrCat("unsupported domain kind: ", *kind));
3248   }
3249   return true;
3250 }
3251 
3252 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)3253 bool HloParserImpl::ParseInstructionNames(
3254     std::vector<HloInstruction*>* instructions) {
3255   if (!ParseToken(TokKind::kLbrace,
3256                   "expects '{' at the beginning of instruction name list")) {
3257     return false;
3258   }
3259   LocTy loc = lexer_.GetLoc();
3260   do {
3261     std::string name;
3262     if (!ParseName(&name)) {
3263       return Error(loc, "expects a instruction name");
3264     }
3265     std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
3266     if (!instr) {
3267       return TokenError(StrFormat("instruction '%s' is not defined", name));
3268     }
3269     instructions->push_back(instr->first);
3270   } while (EatIfPresent(TokKind::kComma));
3271 
3272   return ParseToken(TokKind::kRbrace,
3273                     "expects '}' at the end of instruction name list");
3274 }
3275 
SetValueInLiteral(LocTy loc,int64_t value,int64_t index,Literal * literal)3276 bool HloParserImpl::SetValueInLiteral(LocTy loc, int64_t value, int64_t index,
3277                                       Literal* literal) {
3278   const Shape& shape = literal->shape();
3279   switch (shape.element_type()) {
3280     case S8:
3281       return SetValueInLiteralHelper<int8_t>(loc, value, index, literal);
3282     case S16:
3283       return SetValueInLiteralHelper<int16_t>(loc, value, index, literal);
3284     case S32:
3285       return SetValueInLiteralHelper<int32_t>(loc, value, index, literal);
3286     case S64:
3287       return SetValueInLiteralHelper<int64_t>(loc, value, index, literal);
3288     case U8:
3289       return SetValueInLiteralHelper<uint8_t>(loc, value, index, literal);
3290     case U16:
3291       return SetValueInLiteralHelper<uint16_t>(loc, value, index, literal);
3292     case U32:
3293       return SetValueInLiteralHelper<uint32_t>(loc, value, index, literal);
3294     case U64:
3295       return SetValueInLiteralHelper<uint64_t>(loc, value, index, literal);
3296     case PRED:
3297       // Bool type literals with rank >= 1 are printed in 0s and 1s.
3298       return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
3299                                            literal);
3300     default:
3301       LOG(FATAL) << "unknown integral primitive type "
3302                  << PrimitiveType_Name(shape.element_type());
3303   }
3304 }
3305 
SetValueInLiteral(LocTy loc,double value,int64_t index,Literal * literal)3306 bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, int64_t index,
3307                                       Literal* literal) {
3308   const Shape& shape = literal->shape();
3309   switch (shape.element_type()) {
3310     case F16:
3311       return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
3312     case BF16:
3313       return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
3314                                                            literal);
3315     case F32:
3316       return SetValueInLiteralHelper<float>(loc, value, index, literal);
3317     case F64:
3318       return SetValueInLiteralHelper<double>(loc, value, index, literal);
3319     default:
3320       LOG(FATAL) << "unknown floating point primitive type "
3321                  << PrimitiveType_Name(shape.element_type());
3322   }
3323 }
3324 
SetValueInLiteral(LocTy loc,bool value,int64_t index,Literal * literal)3325 bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, int64_t index,
3326                                       Literal* literal) {
3327   const Shape& shape = literal->shape();
3328   switch (shape.element_type()) {
3329     case PRED:
3330       return SetValueInLiteralHelper<bool>(loc, value, index, literal);
3331     default:
3332       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3333                  << " is not PRED type";
3334   }
3335 }
3336 
SetValueInLiteral(LocTy loc,std::complex<double> value,int64_t index,Literal * literal)3337 bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex<double> value,
3338                                       int64_t index, Literal* literal) {
3339   const Shape& shape = literal->shape();
3340   switch (shape.element_type()) {
3341     case C64:
3342       return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
3343                                                           literal);
3344     case C128:
3345       return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
3346                                                            literal);
3347     default:
3348       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
3349                  << " is not a complex type";
3350   }
3351 }
3352 
// Renders a parsed element value for inclusion in error messages.
template <typename T>
std::string StringifyValue(T value) {
  return StrCat(value);
}
3357 template <>
StringifyValue(std::complex<double> val)3358 std::string StringifyValue(std::complex<double> val) {
3359   return StrFormat("(%f, %f)", std::real(val), std::imag(val));
3360 }
3361 
// Evaluates to V when T == U; used below to select overloads via SFINAE.
template <typename T, typename U, typename V>
using EnableIfSameWithType = std::enable_if_t<std::is_same<T, U>::value, V>;
3365 
// NaN-payload extraction, bool overload: booleans carry no payload.
template <class T, EnableIfSameWithType<T, bool, bool> = false>
uint64_t GetNanPayload(T val) {
  return 0;
}
3370 
// NaN-payload extraction, integer overload: integers carry no payload.
template <class T, EnableIfSameWithType<T, int64_t, bool> = false>
uint64_t GetNanPayload(T val) {
  return 0;
}
3375 
3376 template <class T, EnableIfSameWithType<T, double, bool> = false>
GetNanPayload(T val)3377 uint64_t GetNanPayload(T val) {
3378   auto rep = absl::bit_cast<uint64_t>(val);
3379   if (auto payload = rep & NanPayloadBitMask<double>()) {
3380     return payload;
3381   }
3382   return QuietNanWithoutPayload<double>();
3383 }
3384 
// Assembles a literal value from real/imag components. This overload is
// selected when the literal's native type equals the component type (i.e. a
// real type): the imaginary component is discarded.
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, LiteralComponentT, LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
  return real;
}
3390 
// Assembles a literal value from real/imag components. This overload is
// selected when the literal's native type is std::complex of the component
// type: both components are used.
template <typename LiteralNativeT, typename LiteralComponentT>
EnableIfSameWithType<LiteralNativeT, std::complex<LiteralComponentT>,
                     LiteralNativeT>
LiteralNativeFromRealImag(LiteralComponentT real, LiteralComponentT imag) {
  return LiteralNativeT(real, imag);
}
3397 
// ComponentType<T>::Type is the scalar component type of T. Primary template:
// a real type is its own component type.
template <typename T>
struct ComponentType {
  using Type = T;
};
3402 
// Specialization: the component type of std::complex<T> is T.
template <typename T>
struct ComponentType<std::complex<T>> {
  using Type = T;
};
3407 
// Real component of a non-complex value: the value itself.
template <typename T>
T GetReal(T value) {
  return value;
}
3412 
// Real component of a complex value.
template <typename T>
T GetReal(std::complex<T> value) {
  return std::real(value);
}
3417 
// Imaginary component of a non-complex value: always zero.
template <typename T>
T GetImag(T value) {
  return 0;
}
3422 
// Imaginary component of a complex value.
template <typename T>
T GetImag(std::complex<T> value) {
  return std::imag(value);
}
3427 
3428 template <typename LiteralNativeT, typename ParsedElemT>
SetValueInLiteralHelper(LocTy loc,ParsedElemT value,int64_t index,Literal * literal)3429 bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
3430                                             int64_t index, Literal* literal) {
3431   if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
3432     return false;
3433   }
3434 
3435   // Check that the index is in range and assign into the literal
3436   if (index >= ShapeUtil::ElementsIn(literal->shape())) {
3437     return Error(loc, StrCat("tries to set value ", StringifyValue(value),
3438                              " to a literal in shape ",
3439                              ShapeUtil::HumanString(literal->shape()),
3440                              " at linear index ", index,
3441                              ", but the index is out of range"));
3442   }
3443   using ParsedElemComponentT = typename ComponentType<ParsedElemT>::Type;
3444   using LiteralNativeComponentT = typename ComponentType<LiteralNativeT>::Type;
3445   const auto handle_nan = [this, literal, index, loc](
3446                               ParsedElemComponentT parsed_value_component,
3447                               LiteralNativeComponentT*
3448                                   literal_value_component) {
3449     if (!std::isnan(static_cast<double>(parsed_value_component))) {
3450       return true;
3451     }
3452     auto nan_payload = GetNanPayload(parsed_value_component);
3453     if (nan_payload == QuietNanWithoutPayload<double>()) {
3454       nan_payload = QuietNanWithoutPayload<LiteralNativeComponentT>();
3455     }
3456     const auto kLargestPayload = NanPayloadBitMask<LiteralNativeComponentT>();
3457     if (nan_payload > kLargestPayload) {
3458       return Error(
3459           loc,
3460           StrCat("tries to set NaN payload 0x", absl::Hex(nan_payload),
3461                  " to a literal in shape ",
3462                  ShapeUtil::HumanString(literal->shape()), " at linear index ",
3463                  index, ", but the NaN payload is out of range (0x",
3464                  absl::Hex(kLargestPayload), ")"));
3465     }
3466     *literal_value_component = NanWithSignAndPayload<LiteralNativeComponentT>(
3467         /*sign=*/std::signbit(static_cast<double>(parsed_value_component)),
3468         /*nan_payload=*/nan_payload);
3469     return true;
3470   };
3471   const ParsedElemComponentT parsed_real_value = GetReal(value);
3472   auto literal_real_value =
3473       static_cast<LiteralNativeComponentT>(parsed_real_value);
3474   if (std::is_floating_point<ParsedElemT>::value ||
3475       std::is_same<ParsedElemT, std::complex<double>>::value) {
3476     if (!handle_nan(parsed_real_value, &literal_real_value)) {
3477       return false;
3478     }
3479   }
3480   const ParsedElemComponentT parsed_imag_value = GetImag(value);
3481   auto literal_imag_value =
3482       static_cast<LiteralNativeComponentT>(parsed_imag_value);
3483   if (std::is_same<ParsedElemT, std::complex<double>>::value) {
3484     if (!handle_nan(parsed_real_value, &literal_imag_value)) {
3485       return false;
3486     }
3487   }
3488   literal->data<LiteralNativeT>().at(index) =
3489       LiteralNativeFromRealImag<LiteralNativeT>(literal_real_value,
3490                                                 literal_imag_value);
3491   return true;
3492 }
3493 
3494 // Similar to ParseLiteral(Literal* literal, const Shape& shape), but parse the
3495 // shape instead of accepting one as argument.
ParseLiteral(Literal * literal)3496 bool HloParserImpl::ParseLiteral(Literal* literal) {
3497   if (lexer_.GetKind() == TokKind::kLparen) {
3498     // Consume Lparen
3499     lexer_.Lex();
3500     std::vector<Literal> elements;
3501     while (lexer_.GetKind() != TokKind::kRparen) {
3502       Literal element;
3503       if (!ParseLiteral(&element)) {
3504         return TokenError("Fails when parsing tuple element");
3505       }
3506       elements.emplace_back(std::move(element));
3507       if (lexer_.GetKind() != TokKind::kRparen) {
3508         ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3509       }
3510     }
3511 
3512     *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3513     // Consume Rparen
3514     return ParseToken(TokKind::kRparen, "expects ')' to close a tuple literal");
3515   }
3516   Shape literal_shape;
3517   if (!ParseShape(&literal_shape)) {
3518     return false;
3519   }
3520   return ParseLiteral(literal, literal_shape);
3521 }
3522 
3523 // literal
3524 //  ::= tuple
3525 //  ::= non_tuple
ParseLiteral(Literal * literal,const Shape & shape)3526 bool HloParserImpl::ParseLiteral(Literal* literal, const Shape& shape) {
3527   return shape.IsTuple() ? ParseTupleLiteral(literal, shape)
3528                          : ParseNonTupleLiteral(literal, shape);
3529 }
3530 
3531 // tuple
3532 //  ::= shape '(' literal_list ')'
3533 // literal_list
3534 //  ::= /*empty*/
3535 //  ::= literal (',' literal)*
ParseTupleLiteral(Literal * literal,const Shape & shape)3536 bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) {
3537   if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
3538     return false;
3539   }
3540   std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
3541 
3542   if (lexer_.GetKind() == TokKind::kRparen) {
3543     // empty
3544   } else {
3545     // literal, (',' literal)*
3546     for (int i = 0; i < elements.size(); i++) {
3547       if (i > 0) {
3548         ParseToken(TokKind::kComma, "expects ',' to separate tuple elements");
3549       }
3550       if (!ParseLiteral(&elements[i],
3551                         ShapeUtil::GetTupleElementShape(shape, i))) {
3552         return TokenError(StrCat("expects the ", i, "th element"));
3553       }
3554     }
3555   }
3556   *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
3557   return ParseToken(TokKind::kRparen,
3558                     StrCat("expects ')' at the end of the tuple with ",
3559                            ShapeUtil::TupleElementCount(shape), "elements"));
3560 }
3561 
3562 // non_tuple
3563 //   ::= rank01
3564 //   ::= rank2345
3565 // rank2345 ::= shape nested_array
// Parses a non-tuple literal of the given shape. Only dense-array layouts are
// supported; anything else is a programming error (CHECK), not a parse error.
bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
  CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
  return ParseDenseLiteral(literal, shape);
}
3570 
// Parses a dense array literal body (nested '{...}' of scalars) into
// `literal`. Elements are written in flat order into a default-layout buffer
// and relaid-out to `shape`'s layout at the end.
bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) {
  // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
  // `rank` is an int64_t, that's an implicit narrowing conversion, which is
  // implementation-defined behavior.
  const int rank = static_cast<int>(shape.rank());

  // Create a literal with the given shape in default layout.
  *literal = LiteralUtil::CreateFromDimensions(shape.element_type(),
                                               shape.dimensions());
  // Current '{' nesting depth; scalars are only legal at depth == rank.
  int64_t nest_level = 0;
  // Next flat position to write into the literal's buffer.
  int64_t linear_index = 0;
  // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
  // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
  // when we are parsing the 2nd '{' (right before '1'), we are seeing a
  // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
  // the first '}' (right after '3'), it means the sub-array ends, and the
  // sub-array is supposed to contain exactly 3 elements, so check if
  // elems_seen_per_dim[1] is 3.
  std::vector<int64_t> elems_seen_per_dim(rank);
  // Renders the multi-index (up to, but excluding, dimension `dim`) of the
  // element currently being parsed, e.g. "[0,2]", for error messages.
  auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string {
    std::vector<int64_t> elems_seen_until_dim(elems_seen_per_dim.begin(),
                                              elems_seen_per_dim.begin() + dim);
    return StrCat("[",
                  StrJoin(elems_seen_until_dim, ",",
                          [](std::string* out, const int64_t num_elems) {
                            // Seen-counts are one ahead of the 0-based index.
                            StrAppend(out, num_elems - 1);
                          }),
                  "]");
  };

  // Records that one scalar was consumed at the current position, verifying
  // we are at the deepest nesting level and within the minor-most dimension's
  // bound. Returns false (after recording a TokenError) on violation.
  auto add_one_elem_seen = [&] {
    if (rank > 0) {
      if (nest_level != rank) {
        return TokenError(absl::StrFormat(
            "expects nested array in rank %d, but sees %d", rank, nest_level));
      }
      elems_seen_per_dim[rank - 1]++;
      if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
        return TokenError(absl::StrFormat(
            "expects %d elements on the minor-most dimension, but "
            "sees more",
            shape.dimensions(rank - 1)));
      }
    }
    return true;
  };

  do {
    switch (lexer_.GetKind()) {
      default:
        return TokenError("unexpected token type in a literal");
      case TokKind::kLbrace: {
        // Opening a sub-array one level deeper.
        nest_level++;
        if (nest_level > rank) {
          return TokenError(absl::StrFormat(
              "expects nested array in rank %d, but sees larger", rank));
        }
        if (nest_level > 1) {
          // This sub-array counts as one more element of its parent dimension.
          elems_seen_per_dim[nest_level - 2]++;
          if (elems_seen_per_dim[nest_level - 2] >
              shape.dimensions(nest_level - 2)) {
            return TokenError(absl::StrFormat(
                "expects %d elements in the %sth element, but sees more",
                shape.dimensions(nest_level - 2),
                get_index_str(nest_level - 2)));
          }
        }
        lexer_.Lex();
        break;
      }
      case TokKind::kRbrace: {
        // Closing the current sub-array: it must be exactly full, then its
        // counter resets for the next sibling.
        // NOTE(review): a stray '}' at nest_level 0 would index
        // elems_seen_per_dim[-1]; presumably unreachable via the grammar —
        // confirm.
        nest_level--;
        if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
          return TokenError(absl::StrFormat(
              "expects %d elements in the %sth element, but sees %d",
              shape.dimensions(nest_level), get_index_str(nest_level),
              elems_seen_per_dim[nest_level]));
        }
        elems_seen_per_dim[nest_level] = 0;
        lexer_.Lex();
        break;
      }
      case TokKind::kLparen: {
        // '(' real, imag ')' — only legal for complex element types.
        if (!primitive_util::IsComplexType(shape.element_type())) {
          return TokenError(
              absl::StrFormat("unexpected '(' in literal.  Parens are only "
                              "valid for complex literals"));
        }

        std::complex<double> value;
        LocTy loc = lexer_.GetLoc();
        if (!add_one_elem_seen() || !ParseComplex(&value) ||
            !SetValueInLiteral(loc, value, linear_index++, literal)) {
          return false;
        }
        break;
      }
      case TokKind::kDots: {
        // "..." means "don't care about the data": mark dimension 0 as fully
        // consumed and fill the buffer with deterministic filler.
        if (nest_level != 1) {
          return TokenError(absl::StrFormat(
              "expects `...` at nest level 1, but sees it at nest level %d",
              nest_level));
        }
        elems_seen_per_dim[0] = shape.dimensions(0);
        lexer_.Lex();
        // Fill data with deterministic (garbage) values. Use static to avoid
        // creating identical constants which could potentially got CSE'ed
        // away. This is a best-effort approach to make sure replaying a HLO
        // gives us same optimized HLO graph.
        static uint32_t data = 0;
        uint32_t* raw_data = static_cast<uint32_t*>(literal->untyped_data());
        for (int64_t i = 0; i < literal->size_bytes() / 4; ++i) {
          raw_data[i] = data++;
        }
        // Fill the trailing bytes that don't make up a whole 4-byte word.
        // NOTE(review): the byte offset `size_bytes() / 4 + i` looks like it
        // was meant to be `(size_bytes() / 4) * 4 + i`; harmless since the
        // values are garbage either way — confirm.
        uint8_t* raw_data_int8 = static_cast<uint8_t*>(literal->untyped_data());
        static uint8_t data_int8 = 0;
        for (int64_t i = 0; i < literal->size_bytes() % 4; ++i) {
          raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++;
        }
        break;
      }
      case TokKind::kComma:
        // Skip.
        lexer_.Lex();
        break;
      case TokKind::kw_true:
      case TokKind::kw_false:
      case TokKind::kInt:
      case TokKind::kDecimal:
      case TokKind::kw_inf:
      case TokKind::kNegInf: {
        // NOTE(review): unlike the kLparen case above, the result of
        // add_one_elem_seen() is ignored here, so a nesting/bounds error is
        // recorded but parsing continues — confirm whether intentional.
        add_one_elem_seen();
        if (lexer_.GetKind() == TokKind::kw_true ||
            lexer_.GetKind() == TokKind::kw_false) {
          if (!SetValueInLiteral(lexer_.GetLoc(),
                                 lexer_.GetKind() == TokKind::kw_true,
                                 linear_index++, literal)) {
            return false;
          }
          lexer_.Lex();
        } else if (primitive_util::IsIntegralType(shape.element_type()) ||
                   shape.element_type() == PRED) {
          LocTy loc = lexer_.GetLoc();
          int64_t value;
          if (!ParseInt64(&value)) {
            return Error(loc, StrCat("expects integer for primitive type: ",
                                     PrimitiveType_Name(shape.element_type())));
          }
          if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
            return false;
          }
        } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
          LocTy loc = lexer_.GetLoc();
          double value;
          if (!ParseDouble(&value)) {
            return Error(
                loc, StrCat("expect floating point value for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
          }
          if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
            return false;
          }
        } else {
          return TokenError(StrCat("unsupported primitive type ",
                                   PrimitiveType_Name(shape.element_type())));
        }
        break;
      }
    }  // end of switch
  } while (nest_level > 0);

  // Parsing filled a default-layout buffer; convert to the layout that
  // `shape` actually requests.
  *literal = literal->Relayout(shape.layout());
  return true;
}
3745 
3746 // MaxFiniteValue is a type-traits helper used by
3747 // HloParserImpl::CheckParsedValueIsInRange.
template <typename T>
struct MinMaxFiniteValue {
  // Smallest (most negative) finite value representable in T.
  static T min() { return std::numeric_limits<T>::lowest(); }
  // Largest finite value representable in T.
  static T max() { return std::numeric_limits<T>::max(); }
};
3753 
3754 template <>
3755 struct MinMaxFiniteValue<Eigen::half> {
maxxla::__anon2255e8ac0111::MinMaxFiniteValue3756   static double max() {
3757     // Sadly this is not constexpr, so this forces `value` to be a method.
3758     return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
3759   }
minxla::__anon2255e8ac0111::MinMaxFiniteValue3760   static double min() { return -max(); }
3761 };
3762 
3763 template <>
3764 struct MinMaxFiniteValue<bfloat16> {
maxxla::__anon2255e8ac0111::MinMaxFiniteValue3765   static double max() {
3766     return static_cast<double>(Eigen::NumTraits<Eigen::bfloat16>::highest());
3767   }
minxla::__anon2255e8ac0111::MinMaxFiniteValue3768   static double min() { return -max(); }
3769 };
3770 
3771 // MSVC's standard C++ library does not define isnan/isfinite for integer types.
3772 // To work around that we will need to provide our own.
// Floating-point overloads defer directly to the standard library.
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsFinite(T v) {
  return std::isfinite(v);
}
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> IsNaN(T v) {
  return std::isnan(v);
}
// Integer overloads widen to double first so they compile with MSVC's
// standard library, which lacks integral isnan/isfinite overloads.
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsFinite(T v) {
  return std::isfinite(static_cast<double>(v));
}
template <typename T>
std::enable_if_t<std::is_integral<T>::value, bool> IsNaN(T v) {
  return std::isnan(static_cast<double>(v));
}
3789 
// Checks that the parsed `value` fits in the native literal element type
// LiteralNativeT; on failure records an error at `loc` and returns false.
// NaN and +/-infinity are exempt from the range check.
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
  if (std::is_floating_point<ParsedElemT>::value) {
    // Round-trip the value through the literal type and, when that keeps it
    // finite (or the original was already non-finite), compare against the
    // round-tripped value — presumably so the check reflects the value that
    // would actually be stored. NOTE(review): confirm the intent of the
    // `|| IsFinite(value_double_converted)` disjunct.
    auto value_as_native_t = static_cast<LiteralNativeT>(value);
    auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t);
    if (!IsFinite(value) || IsFinite(value_double_converted)) {
      value = value_double_converted;
    }
  }
  PrimitiveType literal_ty =
      primitive_util::NativeToPrimitiveType<LiteralNativeT>();
  if (IsNaN(value) ||
      (std::numeric_limits<ParsedElemT>::has_infinity &&
       (std::numeric_limits<ParsedElemT>::infinity() == value ||
        -std::numeric_limits<ParsedElemT>::infinity() == value))) {
    // Skip range checking for non-finite value.
  } else if (std::is_unsigned<LiteralNativeT>::value) {
    // Unsigned targets: compare in uint64 space against [0, max].
    CHECK((std::is_same<ParsedElemT, int64_t>::value ||
           std::is_same<ParsedElemT, bool>::value))
        << "Unimplemented checking for ParsedElemT";

    const uint64_t unsigned_value = value;
    const uint64_t upper_bound =
        static_cast<uint64_t>(std::numeric_limits<LiteralNativeT>::max());
    if (unsigned_value > upper_bound) {
      // Value is out of range for LiteralNativeT.
      return Error(loc, StrCat("value ", value,
                               " is out of range for literal's primitive type ",
                               PrimitiveType_Name(literal_ty), " namely [0, ",
                               upper_bound, "]."));
    }
  } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
             value < MinMaxFiniteValue<LiteralNativeT>::min()) {
    // Value is out of range for LiteralNativeT.
    return Error(loc, StrCat("value ", value,
                             " is out of range for literal's primitive type ",
                             PrimitiveType_Name(literal_ty), " namely [",
                             MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
                             MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
  }
  return true;
}
3832 
// Complex-valued variant: range-checks the real and imaginary parts
// independently against the component type of LiteralNativeT, giving a
// per-component error message on failure.
template <typename LiteralNativeT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc,
                                              std::complex<double> value) {
  // e.g. `float` for std::complex<float>
  using LiteralComplexComponentT =
      decltype(std::real(std::declval<LiteralNativeT>()));

  // We could do simply
  //
  //   return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
  //          CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
  //
  // but this would give bad error messages on failure.

  auto check_component = [&](absl::string_view name, double v) {
    if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
        v == -std::numeric_limits<double>::infinity()) {
      // Skip range-checking for non-finite values.
      return true;
    }

    double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
    double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
    if (v < min || v > max) {
      // Value is out of range for LiteralComplexComponentT.
      return Error(
          loc,
          StrCat(name, " part ", v,
                 " is out of range for literal's primitive type ",
                 PrimitiveType_Name(
                     primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
                 ", namely [", min, ", ", max, "]."));
    }
    return true;
  };
  return check_component("real", std::real(value)) &&
         check_component("imaginary", std::imag(value));
}
3871 
3872 // operands ::= '(' operands1 ')'
3873 // operands1
3874 //   ::= /*empty*/
3875 //   ::= operand (, operand)*
3876 // operand ::= (shape)? name
3877 //         ::= (shape)? opcode operands
// Parses a parenthesized, comma-separated operand list into `operands`.
// Each operand is either a (shape-qualified) reference to an existing
// instruction, or a nested instruction built into `builder`. Uses
// save/restore of both the lexer and the accumulated error list to backtrack
// between the two alternatives.
bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
                                  HloComputation::Builder* builder) {
  CHECK(operands != nullptr);
  if (!ParseToken(TokKind::kLparen,
                  "expects '(' at the beginning of operands")) {
    return false;
  }
  if (lexer_.GetKind() == TokKind::kRparen) {
    // empty
  } else {
    do {
      // Try to parse the operand as a name with an optional shape.  If that
      // doesn't work, try again parsing it as a nested instruction.
      //
      // (Trying nested instructions second is important here: If you have a
      // giant HLO dump, it likely doesn't have any nested instructions, but
      // likely has tons of non-nested operands.  Generating an error is slow --
      // O(n) as of writing -- so we only want to hit the error branch in the
      // uncommon case.)
      //
      // Snapshot the lexer and stash the already-accumulated errors so a
      // failed first attempt can be rolled back cleanly.
      HloLexer lexer_copy = lexer_;
      std::vector<std::string> saved_errors;
      std::swap(saved_errors, error_);
      bool is_normal_operand = [&] {
        LocTy loc = lexer_.GetLoc();
        std::string name;
        optional<Shape> shape;
        if (CanBeShape()) {
          shape.emplace();
          if (!ParseShape(&shape.value())) {
            return false;
          }
        }
        if (!ParseName(&name)) {
          // When parsing a single instruction (as opposed to a whole module),
          // an HLO may have one or more operands with a shape but no name:
          //
          //  foo = add(f32[10], f32[10])
          //
          // create_missing_instruction_ is always non-null when parsing a
          // single instruction, and is responsible for creating kParameter
          // instructions for these operands.
          if (shape.has_value() && create_missing_instruction_ != nullptr &&
              scoped_name_tables_.size() == 1) {
            name = "";
          } else {
            return false;
          }
        }
        std::pair<HloInstruction*, LocTy>* instruction =
            FindInstruction(name, shape);
        if (instruction == nullptr) {
          return Error(loc, StrCat("instruction does not exist: ", name));
        }

        // If this is a regular named operand, it must be followed by a comma or
        // a close-paren.  If not, it has to be a named instruction.  Don't
        // output an error here -- if it fails to parse as a named instruction
        // too, we'll just use that set of errors.
        auto next = lexer_.GetKind();
        if (next != TokKind::kComma && next != TokKind::kRparen) {
          return false;
        }

        operands->push_back(instruction->first);
        return true;
      }();

      if (is_normal_operand) {
        // Success: drop any errors recorded during the attempt and restore
        // the previously accumulated ones.
        error_ = std::move(saved_errors);
        continue;
      }

      // If parsing as a normal operand failed, try parsing as a nested
      // instruction.
      std::vector<std::string> normal_operand_errors;
      std::swap(error_, normal_operand_errors);
      lexer_ = lexer_copy;

      // Nested instructions can't have attributes because it's ambiguous
      // whether the comma separates an instruction from its attribute, or
      // whether the comma separates two instructions.
      LocTy loc = lexer_.GetLoc();
      bool is_nested_instruction = ParseInstructionRhs(
          builder, /*name=*/"", loc, /*allow_attributes=*/false);
      if (is_nested_instruction) {
        operands->push_back(builder->last_added_instruction());
        error_ = std::move(saved_errors);
        continue;
      }

      // If neither parsing as a normal operand nor parsing as a nested
      // instruction worked, fail.  Return both sets of errors.
      std::vector<std::string> nested_instruction_errors;
      std::swap(error_, nested_instruction_errors);
      error_ = std::move(saved_errors);
      Error(loc,
            "cannot parse as an instruction name or as a nested instruction:");
      error_.insert(error_.end(),
                    std::make_move_iterator(normal_operand_errors.begin()),
                    std::make_move_iterator(normal_operand_errors.end()));
      error_.insert(error_.end(),
                    std::make_move_iterator(nested_instruction_errors.begin()),
                    std::make_move_iterator(nested_instruction_errors.end()));
    } while (EatIfPresent(TokKind::kComma));
  }
  return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
3985 
ParseOperands(std::vector<HloInstruction * > * operands,HloComputation::Builder * builder,const int expected_size)3986 bool HloParserImpl::ParseOperands(std::vector<HloInstruction*>* operands,
3987                                   HloComputation::Builder* builder,
3988                                   const int expected_size) {
3989   CHECK(operands != nullptr);
3990   LocTy loc = lexer_.GetLoc();
3991   if (!ParseOperands(operands, builder)) {
3992     return false;
3993   }
3994   if (expected_size != operands->size()) {
3995     return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
3996                              operands->size(), " operands"));
3997   }
3998   return true;
3999 }
4000 
4001 // sub_attributes ::= '{' (','? attribute)* '}'
// Parses a brace-delimited attribute list ('{' (','? attribute)* '}') and
// verifies that every attribute marked required in `attrs` appeared.
bool HloParserImpl::ParseSubAttributes(
    const absl::flat_hash_map<std::string, AttrConfig>& attrs) {
  // Remember where the list started so missing-attribute errors point there.
  LocTy loc = lexer_.GetLoc();
  if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
    return false;
  }
  absl::flat_hash_set<std::string> seen_attrs;
  if (lexer_.GetKind() == TokKind::kRbrace) {
    // empty
  } else {
    do {
      // Commas between attributes are optional here.
      EatIfPresent(TokKind::kComma);
      if (!ParseAttributeHelper(attrs, &seen_attrs)) {
        return false;
      }
    } while (lexer_.GetKind() != TokKind::kRbrace);
  }
  // Check that all required attrs were seen.
  for (const auto& attr_it : attrs) {
    if (attr_it.second.required &&
        seen_attrs.find(attr_it.first) == seen_attrs.end()) {
      return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
                                  attr_it.first));
    }
  }
  return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
}
4029 
4030 // attributes ::= (',' attribute)*
ParseAttributes(const absl::flat_hash_map<std::string,AttrConfig> & attrs,bool allow_attributes)4031 bool HloParserImpl::ParseAttributes(
4032     const absl::flat_hash_map<std::string, AttrConfig>& attrs,
4033     bool allow_attributes) {
4034   LocTy loc = lexer_.GetLoc();
4035   absl::flat_hash_set<std::string> seen_attrs;
4036   if (allow_attributes) {
4037     while (EatIfPresent(TokKind::kComma)) {
4038       if (!ParseAttributeHelper(attrs, &seen_attrs)) {
4039         return false;
4040       }
4041     }
4042   }
4043 
4044   // Check that all required attrs were seen.
4045   for (const auto& attr_it : attrs) {
4046     if (attr_it.second.required &&
4047         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
4048       return Error(loc, StrFormat("attribute %s is expected but not seen",
4049                                   attr_it.first));
4050     }
4051   }
4052   return true;
4053 }
4054 
ParseAttributeHelper(const absl::flat_hash_map<std::string,AttrConfig> & attrs,absl::flat_hash_set<std::string> * seen_attrs)4055 bool HloParserImpl::ParseAttributeHelper(
4056     const absl::flat_hash_map<std::string, AttrConfig>& attrs,
4057     absl::flat_hash_set<std::string>* seen_attrs) {
4058   LocTy loc = lexer_.GetLoc();
4059   std::string name;
4060   if (!ParseAttributeName(&name)) {
4061     return Error(loc, "error parsing attributes");
4062   }
4063   VLOG(3) << "Parsing attribute " << name;
4064   if (!seen_attrs->insert(name).second) {
4065     return Error(loc, StrFormat("attribute %s already exists", name));
4066   }
4067   auto attr_it = attrs.find(name);
4068   if (attr_it == attrs.end()) {
4069     std::string allowed_attrs;
4070     if (attrs.empty()) {
4071       allowed_attrs = "No attributes are allowed here.";
4072     } else {
4073       allowed_attrs =
4074           StrCat("Allowed attributes: ",
4075                  StrJoin(attrs, ", ",
4076                          [&](std::string* out,
4077                              const std::pair<std::string, AttrConfig>& kv) {
4078                            StrAppend(out, kv.first);
4079                          }));
4080     }
4081     return Error(loc, StrFormat("unexpected attribute \"%s\".  %s", name,
4082                                 allowed_attrs));
4083   }
4084   AttrTy attr_type = attr_it->second.attr_type;
4085   void* attr_out_ptr = attr_it->second.result;
4086   bool success = [&] {
4087     LocTy attr_loc = lexer_.GetLoc();
4088     switch (attr_type) {
4089       case AttrTy::kBool: {
4090         bool result;
4091         if (!ParseBool(&result)) {
4092           return false;
4093         }
4094         static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
4095         return true;
4096       }
4097       case AttrTy::kInt64: {
4098         int64_t result;
4099         if (!ParseInt64(&result)) {
4100           return false;
4101         }
4102         static_cast<optional<int64_t>*>(attr_out_ptr)->emplace(result);
4103         return true;
4104       }
4105       case AttrTy::kInt32: {
4106         int64_t result;
4107         if (!ParseInt64(&result)) {
4108           return false;
4109         }
4110         if (result != static_cast<int32_t>(result)) {
4111           return Error(attr_loc, "value out of range for int32_t");
4112         }
4113         static_cast<optional<int32_t>*>(attr_out_ptr)
4114             ->emplace(static_cast<int32_t>(result));
4115         return true;
4116       }
4117       case AttrTy::kFloat: {
4118         double result;
4119         if (!ParseDouble(&result)) {
4120           return false;
4121         }
4122         if (result > std::numeric_limits<float>::max() ||
4123             result < std::numeric_limits<float>::lowest()) {
4124           return Error(attr_loc, "value out of range for float");
4125         }
4126         static_cast<optional<float>*>(attr_out_ptr)
4127             ->emplace(static_cast<float>(result));
4128         return true;
4129       }
4130       case AttrTy::kHloComputation: {
4131         HloComputation* result = nullptr;
4132         if (!ParseHloComputation(&result)) {
4133           return false;
4134         }
4135         static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
4136         return true;
4137       }
4138       case AttrTy::kBracedHloComputationList: {
4139         std::vector<HloComputation*> result;
4140         if (!ParseHloComputationList(&result)) {
4141           return false;
4142         }
4143         static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
4144             ->emplace(result);
4145         return true;
4146       }
4147       case AttrTy::kFftType: {
4148         FftType result;
4149         if (!ParseFftType(&result)) {
4150           return false;
4151         }
4152         static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
4153         return true;
4154       }
4155       case AttrTy::kPaddingType: {
4156         PaddingType result;
4157         if (!ParsePaddingType(&result)) {
4158           return false;
4159         }
4160         static_cast<optional<PaddingType>*>(attr_out_ptr)->emplace(result);
4161         return true;
4162       }
4163       case AttrTy::kComparisonDirection: {
4164         ComparisonDirection result;
4165         if (!ParseComparisonDirection(&result)) {
4166           return false;
4167         }
4168         static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
4169             ->emplace(result);
4170         return true;
4171       }
4172       case AttrTy::kComparisonType: {
4173         Comparison::Type result;
4174         if (!ParseComparisonType(&result)) {
4175           return false;
4176         }
4177         static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
4178         return true;
4179       }
4180       case AttrTy::kEnum: {
4181         if (lexer_.GetKind() != TokKind::kIdent) {
4182           return TokenError("expects an enumeration value");
4183         }
4184         std::string result = lexer_.GetStrVal();
4185         lexer_.Lex();
4186         static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
4187         return true;
4188       }
4189       case AttrTy::kWindow: {
4190         Window result;
4191         if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
4192           return false;
4193         }
4194         static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
4195         return true;
4196       }
4197       case AttrTy::kConvolutionDimensionNumbers: {
4198         ConvolutionDimensionNumbers result;
4199         if (!ParseConvolutionDimensionNumbers(&result)) {
4200           return false;
4201         }
4202         static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
4203             ->emplace(result);
4204         return true;
4205       }
4206       case AttrTy::kSharding: {
4207         OpSharding sharding;
4208         if (!ParseSharding(&sharding)) {
4209           return false;
4210         }
4211         static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
4212         return true;
4213       }
4214       case AttrTy::kFrontendAttributes: {
4215         FrontendAttributes frontend_attributes;
4216         if (!ParseFrontendAttributes(&frontend_attributes)) {
4217           return false;
4218         }
4219         static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
4220             ->emplace(frontend_attributes);
4221         return true;
4222       }
4223       case AttrTy::kParameterReplication: {
4224         ParameterReplication parameter_replication;
4225         if (!ParseParameterReplication(&parameter_replication)) {
4226           return false;
4227         }
4228         static_cast<optional<ParameterReplication>*>(attr_out_ptr)
4229             ->emplace(parameter_replication);
4230         return true;
4231       }
4232       case AttrTy::kInstructionList: {
4233         std::vector<HloInstruction*> result;
4234         if (!ParseInstructionNames(&result)) {
4235           return false;
4236         }
4237         static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
4238             ->emplace(result);
4239         return true;
4240       }
4241       case AttrTy::kFusionKind: {
4242         HloInstruction::FusionKind result;
4243         if (!ParseFusionKind(&result)) {
4244           return false;
4245         }
4246         static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
4247             ->emplace(result);
4248         return true;
4249       }
4250       case AttrTy::kBracedInt64List: {
4251         std::vector<int64_t> result;
4252         if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4253                             &result)) {
4254           return false;
4255         }
4256         static_cast<optional<std::vector<int64_t>>*>(attr_out_ptr)
4257             ->emplace(result);
4258         return true;
4259       }
4260       case AttrTy::kBracedInt64ListList: {
4261         std::vector<std::vector<int64_t>> result;
4262         if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
4263                                 TokKind::kComma, &result)) {
4264           return false;
4265         }
4266         static_cast<optional<std::vector<std::vector<int64_t>>>*>(attr_out_ptr)
4267             ->emplace(result);
4268         return true;
4269       }
4270       case AttrTy::kSliceRanges: {
4271         SliceRanges result;
4272         if (!ParseSliceRanges(&result)) {
4273           return false;
4274         }
4275         static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
4276         return true;
4277       }
4278       case AttrTy::kPaddingConfig: {
4279         PaddingConfig result;
4280         if (!ParsePaddingConfig(&result)) {
4281           return false;
4282         }
4283         static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
4284         return true;
4285       }
4286       case AttrTy::kString: {
4287         std::string result;
4288         if (!ParseString(&result)) {
4289           return false;
4290         }
4291         static_cast<optional<std::string>*>(attr_out_ptr)->emplace(result);
4292         return true;
4293       }
4294       case AttrTy::kMetadata: {
4295         OpMetadata result;
4296         if (!ParseMetadata(&result)) {
4297           return false;
4298         }
4299         static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
4300         return true;
4301       }
4302       case AttrTy::kDistribution: {
4303         RandomDistribution result;
4304         if (!ParseRandomDistribution(&result)) {
4305           return false;
4306         }
4307         static_cast<optional<RandomDistribution>*>(attr_out_ptr)
4308             ->emplace(result);
4309         return true;
4310       }
4311       case AttrTy::kDomain: {
4312         return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
4313       }
4314       case AttrTy::kPrecisionList: {
4315         std::vector<PrecisionConfig::Precision> result;
4316         if (!ParsePrecisionList(&result)) {
4317           return false;
4318         }
4319         static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
4320             attr_out_ptr)
4321             ->emplace(result);
4322         return true;
4323       }
4324       case AttrTy::kShape: {
4325         Shape result;
4326         if (!ParseShape(&result)) {
4327           return false;
4328         }
4329         static_cast<optional<Shape>*>(attr_out_ptr)->emplace(result);
4330         return true;
4331       }
4332       case AttrTy::kShapeList: {
4333         std::vector<Shape> result;
4334         if (!ParseShapeList(&result)) {
4335           return false;
4336         }
4337         static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
4338             ->emplace(result);
4339         return true;
4340       }
4341       case AttrTy::kRandomAlgorithm: {
4342         RandomAlgorithm result;
4343         if (!ParseRandomAlgorithm(&result)) {
4344           return false;
4345         }
4346         static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
4347         return true;
4348       }
4349       case AttrTy::kAliasing: {
4350         AliasingData aliasing_data;
4351         if (!ParseAliasing(&aliasing_data)) {
4352           return false;
4353         }
4354         static_cast<optional<AliasingData>*>(attr_out_ptr)
4355             ->emplace(aliasing_data);
4356         return true;
4357       }
4358       case AttrTy::kComputationLayout: {
4359         ComputationLayout computation_layout(ShapeLayout(Shape{}));
4360         if (!ParseComputationLayout(&computation_layout)) {
4361           return false;
4362         }
4363         static_cast<optional<ComputationLayout>*>(attr_out_ptr)
4364             ->emplace(computation_layout);
4365         return true;
4366       }
4367       case AttrTy::kInstructionAliasing: {
4368         std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
4369             aliasing_output_operand_pairs;
4370         if (!ParseInstructionOutputOperandAliasing(
4371                 &aliasing_output_operand_pairs)) {
4372           return false;
4373         }
4374         static_cast<optional<std::vector<
4375             std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>>*>(
4376             attr_out_ptr)
4377             ->emplace(std::move(aliasing_output_operand_pairs));
4378         return true;
4379       }
4380       case AttrTy::kLiteral: {
4381         Literal result;
4382         if (!ParseLiteral(&result)) {
4383           return false;
4384         }
4385         static_cast<optional<Literal>*>(attr_out_ptr)
4386             ->emplace(std::move(result));
4387         return true;
4388       }
4389       case AttrTy::kCustomCallSchedule: {
4390         CustomCallSchedule result;
4391         if (!ParseCustomCallSchedule(&result)) {
4392           return false;
4393         }
4394         static_cast<optional<CustomCallSchedule>*>(attr_out_ptr)
4395             ->emplace(result);
4396         return true;
4397       }
4398       case AttrTy::kCustomCallApiVersion: {
4399         CustomCallApiVersion result;
4400         if (!ParseCustomCallApiVersion(&result)) {
4401           return false;
4402         }
4403         static_cast<optional<CustomCallApiVersion>*>(attr_out_ptr)
4404             ->emplace(result);
4405         return true;
4406       }
4407     }
4408   }();
4409   if (!success) {
4410     return Error(loc, StrFormat("error parsing attribute %s", name));
4411   }
4412   return true;
4413 }
4414 
// Copies already-parsed attribute values from `attrs` into the matching
// fields of `message` using protobuf reflection.
//
// Attributes whose names appear in `non_proto_attrs` are skipped (they are
// handled outside the proto message). Only non-repeated bool and enum fields
// are supported; any other field type yields an error.
bool HloParserImpl::CopyAttributeToProtoMessage(
    absl::flat_hash_set<std::string> non_proto_attrs,
    const absl::flat_hash_map<std::string, AttrConfig>& attrs,
    tensorflow::protobuf::Message* message) {
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  const tensorflow::protobuf::Reflection* reflection = message->GetReflection();

  for (const auto& p : attrs) {
    const std::string& name = p.first;
    if (non_proto_attrs.contains(name)) {
      continue;
    }
    const tensorflow::protobuf::FieldDescriptor* fd =
        descriptor->FindFieldByName(name);
    if (!fd) {
      // Unknown attribute: report the full list of proto field names so the
      // user can see what is accepted.
      std::string allowed_attrs = "Allowed attributes: ";

      for (int i = 0; i < descriptor->field_count(); ++i) {
        if (i == 0) {
          absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
        } else {
          absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
        }
      }
      return TokenError(
          StrFormat("unexpected attribute \"%s\".  %s", name, allowed_attrs));
    }

    CHECK(!fd->is_repeated());  // Repeated fields not implemented.
    bool success = [&] {
      switch (fd->type()) {
        case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
          auto attr_value = static_cast<optional<bool>*>(p.second.result);
          if (attr_value->has_value()) {
            reflection->SetBool(message, fd, **attr_value);
          }
          return true;
        }
        case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
          auto attr_value =
              static_cast<optional<std::string>*>(p.second.result);
          if (attr_value->has_value()) {
            const tensorflow::protobuf::EnumValueDescriptor* evd =
                fd->enum_type()->FindValueByName(**attr_value);
            // FindValueByName returns nullptr for an unknown enum value name;
            // passing nullptr to SetEnum is undefined behavior, so fail and
            // let the caller report "error parsing attribute" instead.
            if (evd == nullptr) {
              return false;
            }
            reflection->SetEnum(message, fd, evd);
          }
          return true;
        }
        default:
          return false;
      }
    }();

    if (!success) {
      return TokenError(StrFormat("error parsing attribute %s", name));
    }
  }

  return true;
}
4475 
4476 // attributes ::= (',' attribute)*
bool HloParserImpl::ParseAttributesAsProtoMessage(
    const absl::flat_hash_map<std::string, AttrConfig>& non_proto_attrs,
    tensorflow::protobuf::Message* message) {
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  absl::flat_hash_map<std::string, AttrConfig> attrs;

  // Backing storage for the parsed attribute values. Reserving the full field
  // count up front guarantees the vectors never reallocate, so the element
  // pointers stored in `attrs` below remain valid throughout parsing.
  std::vector<optional<bool>> bool_params;
  std::vector<optional<std::string>> string_params;
  bool_params.reserve(descriptor->field_count());
  string_params.reserve(descriptor->field_count());

  // Derive the set of expected attributes from the proto descriptor. Only
  // bool and enum fields are supported.
  for (int i = 0; i < descriptor->field_count(); ++i) {
    const tensorflow::protobuf::FieldDescriptor* fd = descriptor->field(i);
    const std::string& field_name = fd->name();
    switch (fd->type()) {
      case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL:
        bool_params.emplace_back(std::nullopt);
        attrs[field_name] = {/*is_required*/ false, AttrTy::kBool,
                             &bool_params.back()};
        break;
      case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM:
        string_params.emplace_back(std::nullopt);
        attrs[field_name] = {/*is_required*/ false, AttrTy::kEnum,
                             &string_params.back()};
        break;
      default:
        return TokenError(absl::StrFormat(
            "Unexpected protocol buffer type: %s ", fd->DebugString()));
    }
  }

  // Merge in the caller-provided non-proto attributes. When a name is both a
  // proto field and a non-proto attribute, the proto field wins.
  absl::flat_hash_set<std::string> non_proto_attrs_names;
  non_proto_attrs_names.reserve(non_proto_attrs.size());
  for (const auto& [attr_name, attr_config] : non_proto_attrs) {
    if (attrs.find(attr_name) == attrs.end()) {
      non_proto_attrs_names.insert(attr_name);
      attrs[attr_name] = attr_config;
    }
  }

  if (!ParseAttributes(attrs)) {
    return false;
  }

  return CopyAttributeToProtoMessage(non_proto_attrs_names, attrs, message);
}
4534 
// Parses a computation name and resolves it against the computations already
// registered in computation_pool_. Reports an error if the name is missing or
// does not refer to an existing computation.
bool HloParserImpl::ParseComputationName(HloComputation** value) {
  LocTy loc = lexer_.GetLoc();
  std::string computation_name;
  if (!ParseName(&computation_name)) {
    return Error(loc, "expects computation name");
  }
  auto* entry =
      tensorflow::gtl::FindOrNull(computation_pool_, computation_name);
  if (entry == nullptr) {
    return Error(loc, StrCat("computation does not exist: ", computation_name));
  }
  *value = entry->first;
  return true;
}
4549 
4550 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
4551 // The subattributes can appear in any order. 'size=' is required, others are
4552 // optional.
// Parses a window attribute into `window`. Sub-attributes may appear in any
// order; 'size=' is required, all others default (stride/dilation 1, pad 0,
// no reversal). When `expect_outer_curlies` is false the window runs to EOF
// instead of being delimited by '{'...'}'.
bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
  LocTy loc = lexer_.GetLoc();
  if (expect_outer_curlies &&
      !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
    return false;
  }

  std::vector<int64_t> size;
  std::vector<int64_t> stride;
  std::vector<std::vector<int64_t>> pad;
  std::vector<int64_t> lhs_dilate;
  std::vector<int64_t> rhs_dilate;
  std::vector<int64_t> rhs_reversal;
  const auto end_token =
      expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
  while (lexer_.GetKind() != end_token) {
    LocTy attr_loc = lexer_.GetLoc();
    std::string field_name;
    if (!ParseAttributeName(&field_name)) {
      return Error(attr_loc, "expects sub-attributes in window");
    }
    bool ok = [&] {
      if (field_name == "size") {
        return ParseDxD("size", &size);
      }
      if (field_name == "stride") {
        return ParseDxD("stride", &stride);
      }
      if (field_name == "lhs_dilate") {
        return ParseDxD("lhs_dilate", &lhs_dilate);
      }
      if (field_name == "rhs_dilate") {
        // Bug fix: this previously passed the misspelled name "rls_dilate",
        // which leaked into user-facing parse-error messages.
        return ParseDxD("rhs_dilate", &rhs_dilate);
      }
      if (field_name == "pad") {
        return ParseWindowPad(&pad);
      }
      if (field_name == "rhs_reversal") {
        return ParseDxD("rhs_reversal", &rhs_reversal);
      }
      return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
    }();
    if (!ok) {
      return false;
    }
  }

  // Every optional sub-attribute that was given must match the rank of
  // 'size='.
  if (!stride.empty() && stride.size() != size.size()) {
    return Error(loc, "expects 'stride=' has the same size as 'size='");
  }
  if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
    return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
  }
  if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
    return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
  }
  if (!pad.empty() && pad.size() != size.size()) {
    return Error(loc, "expects 'pad=' has the same size as 'size='");
  }

  for (int i = 0; i < size.size(); i++) {
    window->add_dimensions()->set_size(size[i]);
    if (!pad.empty()) {
      window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
      window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
    }
    // If some field is not present, it has the default value.
    window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
    window->mutable_dimensions(i)->set_base_dilation(
        lhs_dilate.empty() ? 1 : lhs_dilate[i]);
    window->mutable_dimensions(i)->set_window_dilation(
        rhs_dilate.empty() ? 1 : rhs_dilate[i]);
    window->mutable_dimensions(i)->set_window_reversal(
        rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
  }
  return !expect_outer_curlies ||
         ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
}
4631 
4632 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
4633 // The string looks like "dim_labels=0bf_0io->0bf".
4634 //
4635 // '?' dims don't appear in ConvolutionDimensionNumbers.  There can be more than
4636 // one '?' dim.
bool HloParserImpl::ParseConvolutionDimensionNumbers(
    ConvolutionDimensionNumbers* dnums) {
  if (lexer_.GetKind() != TokKind::kDimLabels) {
    return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
  }
  std::string str = lexer_.GetStrVal();

  // The str is expected to have 3 items, lhs, rhs, out, and it must look like
  // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
  // Bug fix: malformed input previously hit LOG(FATAL) and aborted the whole
  // process; a parser should report an error and return false instead.
  std::vector<std::string> split1 = absl::StrSplit(str, '_');
  if (split1.size() != 2) {
    return TokenError(
        StrCat("expects 3 items: lhs, rhs, and output dims, but sees ", str));
  }
  std::vector<std::string> split2 = absl::StrSplit(split1[1], "->");
  if (split2.size() != 2) {
    return TokenError(
        StrCat("expects 3 items: lhs, rhs, and output dims, but sees ", str));
  }
  absl::string_view lhs = split1[0];
  absl::string_view rhs = split2[0];
  absl::string_view out = split2[1];

  // Each dimension label may appear at most once per operand; '?' (an
  // unused dim) is exempt and may repeat.
  auto is_unique = [](absl::string_view str) -> bool {
    absl::flat_hash_set<char> chars;
    for (char c : str) {
      // '?' dims are skipped.
      if (c == '?') {
        continue;
      }
      if (!chars.insert(c).second) {
        return false;
      }
    }
    return true;
  };

  // lhs
  {
    if (!is_unique(lhs)) {
      return TokenError(
          StrCat("expects unique lhs dimension numbers, but sees ", lhs));
    }
    // Count number of spatial dimensions.
    for (char c : lhs) {
      if (c != 'b' && c != 'f' && c != '?') {
        dnums->add_input_spatial_dimensions(-1);
      }
    }
    for (int i = 0; i < lhs.size(); i++) {
      char c = lhs[i];
      if (c == '?') {
        continue;
      } else if (c == 'b') {
        dnums->set_input_batch_dimension(i);
      } else if (c == 'f') {
        dnums->set_input_feature_dimension(i);
      } else if (c < '0' + lhs.size() && c >= '0') {
        dnums->set_input_spatial_dimensions(c - '0', i);
      } else {
        return TokenError(StrFormat(
            "expects [0-%dbf?] in lhs dimension numbers", lhs.size() - 1));
      }
    }
  }
  // rhs
  {
    if (!is_unique(rhs)) {
      return TokenError(
          StrCat("expects unique rhs dimension numbers, but sees ", rhs));
    }
    // Count number of spatial dimensions.
    for (char c : rhs) {
      if (c != 'i' && c != 'o' && c != '?') {
        dnums->add_kernel_spatial_dimensions(-1);
      }
    }
    for (int i = 0; i < rhs.size(); i++) {
      char c = rhs[i];
      if (c == '?') {
        continue;
      } else if (c == 'i') {
        dnums->set_kernel_input_feature_dimension(i);
      } else if (c == 'o') {
        dnums->set_kernel_output_feature_dimension(i);
      } else if (c < '0' + rhs.size() && c >= '0') {
        dnums->set_kernel_spatial_dimensions(c - '0', i);
      } else {
        return TokenError(StrFormat(
            "expects [0-%dio?] in rhs dimension numbers", rhs.size() - 1));
      }
    }
  }
  // output
  {
    if (!is_unique(out)) {
      return TokenError(
          StrCat("expects unique output dimension numbers, but sees ", out));
    }
    // Count number of spatial dimensions.
    for (char c : out) {
      if (c != 'b' && c != 'f' && c != '?') {
        dnums->add_output_spatial_dimensions(-1);
      }
    }
    for (int i = 0; i < out.size(); i++) {
      char c = out[i];
      if (c == '?') {
        continue;
      } else if (c == 'b') {
        dnums->set_output_batch_dimension(i);
      } else if (c == 'f') {
        dnums->set_output_feature_dimension(i);
      } else if (c < '0' + out.size() && c >= '0') {
        dnums->set_output_spatial_dimensions(c - '0', i);
      } else {
        return TokenError(StrFormat(
            "expects [0-%dbf?] in output dimension numbers", out.size() - 1));
      }
    }
  }

  // lhs, rhs, and output should have the same number of spatial dimensions.
  if (dnums->input_spatial_dimensions_size() !=
          dnums->output_spatial_dimensions_size() ||
      dnums->input_spatial_dimensions_size() !=
          dnums->kernel_spatial_dimensions_size()) {
    return TokenError(
        StrFormat("input, kernel, and output must have same number of spatial "
                  "dimensions, but got %d, %d, %d, respectively.",
                  dnums->input_spatial_dimensions_size(),
                  dnums->kernel_spatial_dimensions_size(),
                  dnums->output_spatial_dimensions_size()));
  }

  lexer_.Lex();
  return true;
}
4775 
4776 // ::= '{' ranges '}'
4777 //   ::= /*empty*/
4778 //   ::= range (',' range)*
4779 // range ::= '[' start ':' limit (':' stride)? ']'
4780 //
4781 // The slice ranges are printed as:
4782 //
4783 //  {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
4784 //
4785 // This function extracts the starts, limits, and strides as 3 vectors to the
4786 // result. If stride is not present, stride is 1. For example, if the slice
4787 // ranges is printed as:
4788 //
4789 //  {[2:3:4], [5:6:7], [8:9]}
4790 //
4791 // The parsed result will be:
4792 //
4793 //  {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
4794 //
bool HloParserImpl::ParseSliceRanges(SliceRanges* result) {
  if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
    return false;
  }
  if (lexer_.GetKind() == TokKind::kRbrace) {
    // An empty range list: '{}'.
    return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
  }

  // First collect each bracketed range as a raw 2- or 3-element int list,
  // then transpose into the starts/limits/strides vectors.
  std::vector<std::vector<int64_t>> ranges;
  do {
    LocTy range_loc = lexer_.GetLoc();
    ranges.emplace_back();
    if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
                        &ranges.back())) {
      return false;
    }
    const size_t num_elements = ranges.back().size();
    if (num_elements != 2 && num_elements != 3) {
      return Error(range_loc,
                   StrFormat("expects [start:limit:step] or [start:limit], "
                             "but sees %d elements.",
                             num_elements));
    }
  } while (EatIfPresent(TokKind::kComma));

  for (const auto& range : ranges) {
    result->starts.push_back(range[0]);
    result->limits.push_back(range[1]);
    // A missing stride defaults to 1.
    result->strides.push_back(range.size() == 3 ? range[2] : 1);
  }
  return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
}
4827 
4828 // precisionlist ::= start precision_elements end
4829 // precision_elements
4830 //   ::= /*empty*/
4831 //   ::= precision_val (delim precision_val)*
ParsePrecisionList(std::vector<PrecisionConfig::Precision> * result)4832 bool HloParserImpl::ParsePrecisionList(
4833     std::vector<PrecisionConfig::Precision>* result) {
4834   auto parse_and_add_item = [&]() {
4835     PrecisionConfig::Precision item;
4836     if (!ParsePrecision(&item)) {
4837       return false;
4838     }
4839     result->push_back(item);
4840     return true;
4841   };
4842   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4843                    parse_and_add_item);
4844 }
4845 
ParseHloComputation(HloComputation ** result)4846 bool HloParserImpl::ParseHloComputation(HloComputation** result) {
4847   if (lexer_.GetKind() == TokKind::kLbrace) {
4848     // This means it is a nested computation.
4849     return ParseInstructionList(result, /*computation_name=*/"_");
4850   }
4851   // This means it is a computation name.
4852   return ParseComputationName(result);
4853 }
4854 
ParseHloComputationList(std::vector<HloComputation * > * result)4855 bool HloParserImpl::ParseHloComputationList(
4856     std::vector<HloComputation*>* result) {
4857   auto parse_and_add_item = [&]() {
4858     HloComputation* computation;
4859     if (!ParseHloComputation(&computation)) {
4860       return false;
4861     }
4862     VLOG(3) << "parsed computation " << computation->name();
4863     result->push_back(computation);
4864     return true;
4865   };
4866   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4867                    parse_and_add_item);
4868 }
4869 
4870 // shapelist ::= '{' shapes '}'
// shape_elements
4872 //   ::= /*empty*/
4873 //   ::= shape (',' shape)*
ParseShapeList(std::vector<Shape> * result)4874 bool HloParserImpl::ParseShapeList(std::vector<Shape>* result) {
4875   auto parse_and_add_item = [&]() {
4876     Shape shape;
4877     if (!ParseShape(&shape)) {
4878       return false;
4879     }
4880     result->push_back(std::move(shape));
4881     return true;
4882   };
4883   return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
4884                    parse_and_add_item);
4885 }
4886 
4887 // int64_tlist ::= start int64_elements end
4888 // int64_elements
4889 //   ::= /*empty*/
4890 //   ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64_t> * result)4891 bool HloParserImpl::ParseInt64List(const TokKind start, const TokKind end,
4892                                    const TokKind delim,
4893                                    std::vector<int64_t>* result) {
4894   auto parse_and_add_item = [&]() {
4895     int64_t i;
4896     if (!ParseInt64(&i)) {
4897       return false;
4898     }
4899     result->push_back(i);
4900     return true;
4901   };
4902   return ParseList(start, end, delim, parse_and_add_item);
4903 }
4904 
4905 // int64_tlistlist ::= start int64_tlist_elements end
4906 // int64_tlist_elements
4907 //   ::= /*empty*/
4908 //   ::= int64_tlist (delim int64_tlist)*
4909 // int64_tlist ::= start int64_elements end
4910 // int64_elements
4911 //   ::= /*empty*/
4912 //   ::= int64_val (delim int64_val)*
ParseInt64ListList(const TokKind start,const TokKind end,const TokKind delim,std::vector<std::vector<int64_t>> * result)4913 bool HloParserImpl::ParseInt64ListList(
4914     const TokKind start, const TokKind end, const TokKind delim,
4915     std::vector<std::vector<int64_t>>* result) {
4916   auto parse_and_add_item = [&]() {
4917     std::vector<int64_t> item;
4918     if (!ParseInt64List(start, end, delim, &item)) {
4919       return false;
4920     }
4921     result->push_back(item);
4922     return true;
4923   };
4924   return ParseList(start, end, delim, parse_and_add_item);
4925 }
4926 
// Generic delimited-list driver: consumes `start`, repeatedly invokes
// `parse_and_add_item` separated by `delim` (zero items allowed), then
// consumes `end`.
bool HloParserImpl::ParseList(const TokKind start, const TokKind end,
                              const TokKind delim,
                              const std::function<bool()>& parse_and_add_item) {
  if (!ParseToken(start, StrCat("expects a list starting with ",
                                TokKindToString(start)))) {
    return false;
  }
  // An immediate `end` token means the list is empty.
  if (lexer_.GetKind() != end) {
    do {
      if (!parse_and_add_item()) {
        return false;
      }
    } while (EatIfPresent(delim));
  }
  return ParseToken(
      end, StrCat("expects a list to end with ", TokKindToString(end)));
}
4946 
4947 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)4948 bool HloParserImpl::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
4949   if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
4950     return false;
4951   }
4952   *shape_loc = lexer_.GetLoc();
4953   return ParseShape(shape);
4954 }
4955 
CanBeParamListToShape()4956 bool HloParserImpl::CanBeParamListToShape() {
4957   return lexer_.GetKind() == TokKind::kLparen;
4958 }
4959 
4960 // param_list ::= '(' param_list1 ')'
4961 // param_list1
4962 //   ::= /*empty*/
4963 //   ::= param (',' param)*
4964 // param ::= name shape
ParseParamList()4965 bool HloParserImpl::ParseParamList() {
4966   if (!ParseToken(TokKind::kLparen,
4967                   "expects '(' at the beginning of param list")) {
4968     return false;
4969   }
4970 
4971   if (lexer_.GetKind() == TokKind::kRparen) {
4972     // empty
4973   } else {
4974     do {
4975       Shape shape;
4976       std::string name;
4977       if (!ParseName(&name) || !ParseShape(&shape)) {
4978         return false;
4979       }
4980     } while (EatIfPresent(TokKind::kComma));
4981   }
4982   return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
4983 }
4984 
4985 // dimension_sizes ::= '[' dimension_list ']'
4986 // dimension_list
4987 //   ::= /*empty*/
//   ::= <=? int64_t (',' <=? int64_t)*
ParseDimensionSizes(std::vector<int64_t> * dimension_sizes,std::vector<bool> * dynamic_dimensions)4990 bool HloParserImpl::ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
4991                                         std::vector<bool>* dynamic_dimensions) {
4992   auto parse_and_add_item = [&]() {
4993     int64_t i;
4994     bool is_dynamic = false;
4995     if (lexer_.GetKind() == TokKind::kLeq) {
4996       is_dynamic = true;
4997       lexer_.Lex();
4998     }
4999     if (!ParseInt64(&i)) {
5000       return false;
5001     }
5002     dimension_sizes->push_back(i);
5003     dynamic_dimensions->push_back(is_dynamic);
5004     return true;
5005   };
5006   return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
5007                    parse_and_add_item);
5008 }
5009 
5010 // dim_level_types
5011 //   ::=  /* empty */
5012 //   ::= 'D' '(' dim_level_type_list ')'
5013 // dim_level_type_list
5014 //   ::= /* empty */
//   ::= dim_level_type (',' dim_level_type)*
5016 // dim_level_type
5017 //   ::= 'D'
5018 //   ::= 'C'
5019 //   ::= 'S'
ParseDimLevelTypes(std::vector<DimLevelType> * dim_level_types)5020 bool HloParserImpl::ParseDimLevelTypes(
5021     std::vector<DimLevelType>* dim_level_types) {
5022   auto parse_and_add_item = [&]() {
5023     if (lexer_.GetKind() == TokKind::kIdent) {
5024       if (lexer_.GetStrVal() == "D") {
5025         lexer_.Lex();
5026         dim_level_types->push_back(DIM_DENSE);
5027         return true;
5028       } else if (lexer_.GetStrVal() == "C") {
5029         dim_level_types->push_back(DIM_COMPRESSED);
5030         lexer_.Lex();
5031         return true;
5032       } else if (lexer_.GetStrVal() == "S") {
5033         dim_level_types->push_back(DIM_SINGLETON);
5034         lexer_.Lex();
5035         return true;
5036       }
5037     }
5038     return Error(lexer_.GetLoc(),
5039                  "expected a DimLevelType abbreviation (D, C, or S)");
5040   };
5041   return ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
5042                    parse_and_add_item);
5043 }
5044 
5045 // tiles
5046 //   ::= /*empty*/
5047 //   ::= 'T' '(' dim_list ')'
5048 // dim_list
5049 //   ::= /*empty*/
5050 //   ::= (int64_t | '*') (',' (int64_t | '*'))*
ParseTiles(std::vector<Tile> * tiles)5051 bool HloParserImpl::ParseTiles(std::vector<Tile>* tiles) {
5052   auto parse_and_add_tile_dimension = [&]() {
5053     int64_t i;
5054     if (ParseInt64(&i)) {
5055       tiles->back().add_dimensions(i);
5056       return true;
5057     }
5058     if (lexer_.GetKind() == TokKind::kAsterisk) {
5059       tiles->back().add_dimensions(Tile::kCombineDimension);
5060       lexer_.Lex();
5061       return true;
5062     }
5063     return false;
5064   };
5065 
5066   do {
5067     tiles->push_back(Tile());
5068     if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
5069                    parse_and_add_tile_dimension)) {
5070       return false;
5071     }
5072   } while (lexer_.GetKind() == TokKind::kLparen);
5073   return true;
5074 }
5075 
5076 // int_attribute
5077 //   ::= /*empty*/
5078 //   ::= attr_token '(' attr_value ')'
5079 // attr_token
5080 //   ::= 'E' | 'S'
5081 // attr_value
5082 //   ::= int64_t
// Parses a parenthesized int64 layout sub-attribute, i.e. '(' int64 ')'.
// `attr_description` is used only for error messages.
bool HloParserImpl::ParseLayoutIntAttribute(
    int64_t* attr_value, absl::string_view attr_description) {
  // Short-circuiting preserves the original parse/error order:
  // '(' first, then the value, then ')'.
  return ParseToken(TokKind::kLparen,
                    StrCat("expects ", attr_description, " to start with ",
                           TokKindToString(TokKind::kLparen))) &&
         ParseInt64(attr_value) &&
         ParseToken(TokKind::kRparen,
                    StrCat("expects ", attr_description, " to end with ",
                           TokKindToString(TokKind::kRparen)));
}
5100 
5101 // layout
5102 //   ::= '{' int64_list
5103 //       (':' dim_level_types tiles element_size_in_bits memory_space)?
5104 //       '}'
5105 // element_size_in_bits
5106 //   ::= /*empty*/
5107 //   ::= 'E' '(' int64_t ')'
5108 // memory_space
5109 //   ::= /*empty*/
5110 //   ::= 'S' '(' int64_t ')'
ParseLayout(Layout * layout)5111 bool HloParserImpl::ParseLayout(Layout* layout) {
5112   std::vector<int64_t> minor_to_major;
5113   std::vector<DimLevelType> dim_level_types;
5114   std::vector<Tile> tiles;
5115   int64_t element_size_in_bits = 0;
5116   int64_t memory_space = 0;
5117 
5118   auto parse_and_add_item = [&]() {
5119     int64_t i;
5120     if (!ParseInt64(&i)) {
5121       return false;
5122     }
5123     minor_to_major.push_back(i);
5124     return true;
5125   };
5126 
5127   if (!ParseToken(TokKind::kLbrace,
5128                   StrCat("expects layout to start with ",
5129                          TokKindToString(TokKind::kLbrace)))) {
5130     return false;
5131   }
5132   if (lexer_.GetKind() != TokKind::kRbrace) {
5133     if (lexer_.GetKind() == TokKind::kInt) {
5134       // Parse minor to major.
5135       do {
5136         if (!parse_and_add_item()) {
5137           return false;
5138         }
5139       } while (EatIfPresent(TokKind::kComma));
5140     }
5141 
5142     if (lexer_.GetKind() == TokKind::kColon) {
5143       lexer_.Lex();
5144 
5145       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "D") {
5146         lexer_.Lex();
5147         ParseDimLevelTypes(&dim_level_types);
5148       }
5149 
5150       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
5151         lexer_.Lex();
5152         ParseTiles(&tiles);
5153       }
5154 
5155       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
5156         lexer_.Lex();
5157         ParseLayoutIntAttribute(&element_size_in_bits, "element size in bits");
5158       }
5159 
5160       if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "S") {
5161         lexer_.Lex();
5162         ParseLayoutIntAttribute(&memory_space, "memory space");
5163       }
5164     }
5165   }
5166   if (!ParseToken(TokKind::kRbrace,
5167                   StrCat("expects layout to end with ",
5168                          TokKindToString(TokKind::kRbrace)))) {
5169     return false;
5170   }
5171 
5172   std::vector<Tile> vec_tiles(tiles.size());
5173   for (int i = 0; i < tiles.size(); i++) {
5174     vec_tiles[i] = Tile(tiles[i]);
5175   }
5176   *layout = LayoutUtil::MakeLayout(minor_to_major, dim_level_types, vec_tiles,
5177                                    element_size_in_bits, memory_space);
5178   return true;
5179 }
5180 
5181 // shape ::= shape_val_
5182 // shape ::= '(' tuple_elements ')'
5183 // tuple_elements
5184 //   ::= /*empty*/
5185 //   ::= shape (',' shape)*
// Parses either a tuple shape "(shape, ...)" (recursively) or an array shape
// "prim_type[dims]" with an optional trailing layout suffix.
bool HloParserImpl::ParseShape(Shape* result) {
  if (EatIfPresent(TokKind::kLparen)) {  // Tuple
    std::vector<Shape> shapes;
    if (lexer_.GetKind() == TokKind::kRparen) {
      /*empty*/
    } else {
      // shape (',' shape)*
      do {
        shapes.emplace_back();
        if (!ParseShape(&shapes.back())) {
          return false;
        }
      } while (EatIfPresent(TokKind::kComma));
    }
    *result = ShapeUtil::MakeTupleShape(shapes);
    return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
  }

  if (lexer_.GetKind() != TokKind::kPrimitiveType) {
    return TokenError(absl::StrCat("expected primitive type, saw ",
                                   TokKindToString(lexer_.GetKind())));
  }
  PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
  lexer_.Lex();

  // Each element contains a dimension size and a bool indicating whether this
  // is a dynamic dimension.
  std::vector<int64_t> dimension_sizes;
  std::vector<bool> dynamic_dimensions;
  if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
    return false;
  }
  result->set_element_type(primitive_type);
  for (int i = 0; i < dimension_sizes.size(); ++i) {
    result->add_dimensions(dimension_sizes[i]);
    result->set_dynamic_dimension(i, dynamic_dimensions[i]);
  }
  // Install the default layout first; it is overwritten below if an explicit
  // layout suffix follows.
  LayoutUtil::SetToDefaultLayout(result);
  // We need to lookahead to see if a following open brace is the start of a
  // layout. The specific problematic case is:
  //
  // ENTRY %foo (x: f32[42]) -> f32[123] {
  //  ...
  // }
  //
  // The open brace could either be the start of a computation or the start of a
  // layout for the f32[123] shape. We consider it the start of a layout if the
  // next token after the open brace is an integer or a colon.
  if (lexer_.GetKind() == TokKind::kLbrace &&
      (lexer_.LookAhead() == TokKind::kInt ||
       lexer_.LookAhead() == TokKind::kColon)) {
    Layout layout;
    if (!ParseLayout(&layout)) {
      return false;
    }
    // Validate the explicit layout against the dimensions parsed above.
    if (layout.dim_level_types_size() != 0 &&
        layout.dim_level_types_size() != result->rank()) {
      return Error(
          lexer_.GetLoc(),
          StrFormat("Dimensions size is %ld, but dim level types size is %ld.",
                    result->rank(), layout.dim_level_types_size()));
    }
    if (layout.minor_to_major_size() != result->rank()) {
      return Error(
          lexer_.GetLoc(),
          StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
                    result->rank(), layout.minor_to_major_size()));
    }
    if (LayoutUtil::IsSparse(layout) && layout.tiles_size() > 0) {
      return Error(lexer_.GetLoc(),
                   StrFormat("Layout has tiles, but is for a sparse array: %s",
                             layout.ToString()));
    }
    *result->mutable_layout() = layout;
  }
  return true;
}
5263 
CanBeShape()5264 bool HloParserImpl::CanBeShape() {
5265   // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
5266   // with '('.
5267   return lexer_.GetKind() == TokKind::kPrimitiveType ||
5268          lexer_.GetKind() == TokKind::kLparen;
5269 }
5270 
ParseName(std::string * result)5271 bool HloParserImpl::ParseName(std::string* result) {
5272   VLOG(3) << "ParseName";
5273   if (lexer_.GetKind() != TokKind::kIdent &&
5274       lexer_.GetKind() != TokKind::kName) {
5275     return TokenError("expects name");
5276   }
5277   *result = lexer_.GetStrVal();
5278   lexer_.Lex();
5279   return true;
5280 }
5281 
ParseAttributeName(std::string * result)5282 bool HloParserImpl::ParseAttributeName(std::string* result) {
5283   if (lexer_.GetKind() != TokKind::kAttributeName) {
5284     return TokenError("expects attribute name");
5285   }
5286   *result = lexer_.GetStrVal();
5287   lexer_.Lex();
5288   return true;
5289 }
5290 
ParseString(std::string * result)5291 bool HloParserImpl::ParseString(std::string* result) {
5292   VLOG(3) << "ParseString";
5293   if (lexer_.GetKind() != TokKind::kString) {
5294     return TokenError("expects string");
5295   }
5296   *result = lexer_.GetStrVal();
5297   lexer_.Lex();
5298   return true;
5299 }
5300 
ParseDxD(const std::string & name,std::vector<int64_t> * result)5301 bool HloParserImpl::ParseDxD(const std::string& name,
5302                              std::vector<int64_t>* result) {
5303   LocTy loc = lexer_.GetLoc();
5304   if (!result->empty()) {
5305     return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
5306   }
5307   // 1D
5308   if (lexer_.GetKind() == TokKind::kInt) {
5309     int64_t number;
5310     if (!ParseInt64(&number)) {
5311       return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
5312     }
5313     result->push_back(number);
5314     return true;
5315   }
5316   // 2D or higher.
5317   if (lexer_.GetKind() == TokKind::kDxD) {
5318     std::string str = lexer_.GetStrVal();
5319     if (!SplitToInt64s(str, 'x', result)) {
5320       return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
5321     }
5322     lexer_.Lex();
5323     return true;
5324   }
5325   return TokenError("expects token type kInt or kDxD");
5326 }
5327 
ParseWindowPad(std::vector<std::vector<int64_t>> * pad)5328 bool HloParserImpl::ParseWindowPad(std::vector<std::vector<int64_t>>* pad) {
5329   LocTy loc = lexer_.GetLoc();
5330   if (!pad->empty()) {
5331     return Error(loc, "sub-attribute 'pad=' already exists");
5332   }
5333   if (lexer_.GetKind() != TokKind::kPad) {
5334     return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
5335   }
5336   std::string str = lexer_.GetStrVal();
5337   for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
5338     std::vector<int64_t> low_high;
5339     if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
5340         low_high.size() != 2) {
5341       return Error(loc,
5342                    "expects padding_low and padding_high separated by '_'");
5343     }
5344     pad->push_back(low_high);
5345   }
5346   lexer_.Lex();
5347   return true;
5348 }
5349 
5350 // This is the inverse xla::ToString(PaddingConfig). The padding config string
5351 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
5352 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
5353 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)5354 bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) {
5355   if (lexer_.GetKind() != TokKind::kPad) {
5356     return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
5357   }
5358   LocTy loc = lexer_.GetLoc();
5359   std::string str = lexer_.GetStrVal();
5360   for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
5361     std::vector<int64_t> padding_dim;
5362     if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) ||
5363         (padding_dim.size() != 2 && padding_dim.size() != 3)) {
5364       return Error(loc,
5365                    "expects padding config pattern like 'low_high_interior' or "
5366                    "'low_high'");
5367     }
5368     auto* dim = padding->add_dimensions();
5369     dim->set_edge_padding_low(padding_dim[0]);
5370     dim->set_edge_padding_high(padding_dim[1]);
5371     dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
5372   }
5373   lexer_.Lex();
5374   return true;
5375 }
5376 
5377 // '{' metadata_string '}'
ParseMetadata(OpMetadata * metadata)5378 bool HloParserImpl::ParseMetadata(OpMetadata* metadata) {
5379   absl::flat_hash_map<std::string, AttrConfig> attrs;
5380   optional<std::string> op_type;
5381   optional<std::string> op_name;
5382   optional<std::string> source_file;
5383   optional<int32_t> source_line;
5384   optional<std::vector<int64_t>> profile_type;
5385   attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
5386   attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
5387   attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
5388   attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
5389   attrs["profile_type"] = {/*required=*/false, AttrTy::kBracedInt64List,
5390                            &profile_type};
5391   if (!ParseSubAttributes(attrs)) {
5392     return false;
5393   }
5394   if (op_type) {
5395     metadata->set_op_type(*op_type);
5396   }
5397   if (op_name) {
5398     metadata->set_op_name(*op_name);
5399   }
5400   if (source_file) {
5401     metadata->set_source_file(*source_file);
5402   }
5403   if (source_line) {
5404     metadata->set_source_line(*source_line);
5405   }
5406   if (profile_type) {
5407     for (const auto& type : *profile_type) {
5408       if (!ProfileType_IsValid(type)) {
5409         return false;
5410       }
5411       metadata->add_profile_type(static_cast<ProfileType>(type));
5412     }
5413   }
5414   return true;
5415 }
5416 
5417 // ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}')
ParseSingleOrListMetadata(tensorflow::protobuf::RepeatedPtrField<OpMetadata> * metadata)5418 bool HloParserImpl::ParseSingleOrListMetadata(
5419     tensorflow::protobuf::RepeatedPtrField<OpMetadata>* metadata) {
5420   if (lexer_.GetKind() == TokKind::kLbrace &&
5421       lexer_.LookAhead() == TokKind::kLbrace) {
5422     if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) {
5423       return false;
5424     }
5425 
5426     if (lexer_.GetKind() != TokKind::kRbrace) {
5427       do {
5428         if (!ParseMetadata(metadata->Add())) {
5429           return false;
5430         }
5431       } while (EatIfPresent(TokKind::kComma));
5432     }
5433 
5434     return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list");
5435   }
5436 
5437   return ParseMetadata(metadata->Add());
5438 }
5439 
ParseOpShardingType(OpSharding::Type * type)5440 bool HloParserImpl::ParseOpShardingType(OpSharding::Type* type) {
5441   switch (lexer_.GetKind()) {
5442     case TokKind::kw_maximal:
5443       *type = OpSharding::MAXIMAL;
5444       lexer_.Lex();
5445       break;
5446     case TokKind::kw_replicated:
5447       *type = OpSharding::REPLICATED;
5448       lexer_.Lex();
5449       break;
5450     case TokKind::kw_manual:
5451       *type = OpSharding::MANUAL;
5452       lexer_.Lex();
5453       break;
5454     default:
5455       return false;
5456   }
5457   return true;
5458 }
5459 
ParseListShardingType(std::vector<OpSharding::Type> * types)5460 bool HloParserImpl::ParseListShardingType(
5461     std::vector<OpSharding::Type>* types) {
5462   if (!ParseToken(TokKind::kLbrace,
5463                   "expected '{' to start sharding type list")) {
5464     return false;
5465   }
5466 
5467   if (lexer_.GetKind() != TokKind::kRbrace) {
5468     do {
5469       OpSharding::Type type;
5470       if (!ParseOpShardingType(&type)) {
5471         return false;
5472       }
5473       types->emplace_back(type);
5474     } while (EatIfPresent(TokKind::kComma));
5475   }
5476 
5477   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding type list");
5478 }
5479 
// Parses an opcode identifier. An identifier that is not a plain opcode may
// be an async form "<wrapped-opcode>-start|-update|-done"; in that case
// `opcode` receives the async opcode and `async_wrapped_opcode` receives the
// wrapped one.
bool HloParserImpl::ParseOpcode(
    HloOpcode* opcode, std::optional<HloOpcode>* async_wrapped_opcode) {
  VLOG(3) << "ParseOpcode";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects opcode");
  }
  std::string val = lexer_.GetStrVal();
  auto status_or_result = StringToHloOpcode(val);
  if (!status_or_result.ok()) {
    // On suffix match, sets *opcode to the async opcode and overwrites
    // status_or_result with the parse of the remaining (wrapped) opcode name.
    auto try_parsing_async_op = [&](absl::string_view suffix,
                                    HloOpcode async_opcode) {
      absl::string_view wrapped_opcode_view(val);
      if (absl::ConsumeSuffix(&wrapped_opcode_view, suffix)) {
        *opcode = async_opcode;
        std::string wrapped_opcode(wrapped_opcode_view);
        status_or_result = StringToHloOpcode(wrapped_opcode);
        return true;
      }
      return false;
    };
    if (try_parsing_async_op("-start", HloOpcode::kAsyncStart) ||
        try_parsing_async_op("-update", HloOpcode::kAsyncUpdate) ||
        try_parsing_async_op("-done", HloOpcode::kAsyncDone)) {
      // A suffix matched but the wrapped part still failed to parse.
      if (!status_or_result.ok()) {
        return TokenError(
            StrFormat("expects async wrapped opcode but sees: %s, error: %s",
                      val, status_or_result.status().error_message()));
      }
      *async_wrapped_opcode = status_or_result.ValueOrDie();
    } else {
      return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val,
                                  status_or_result.status().error_message()));
    }
  } else {
    *opcode = status_or_result.ValueOrDie();
  }
  lexer_.Lex();
  return true;
}
5519 
ParseFftType(FftType * result)5520 bool HloParserImpl::ParseFftType(FftType* result) {
5521   VLOG(3) << "ParseFftType";
5522   if (lexer_.GetKind() != TokKind::kIdent) {
5523     return TokenError("expects fft type");
5524   }
5525   std::string val = lexer_.GetStrVal();
5526   if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
5527     return TokenError(StrFormat("expects fft type but sees: %s", val));
5528   }
5529   lexer_.Lex();
5530   return true;
5531 }
5532 
ParsePaddingType(PaddingType * result)5533 bool HloParserImpl::ParsePaddingType(PaddingType* result) {
5534   VLOG(3) << "ParsePaddingType";
5535   if (lexer_.GetKind() != TokKind::kIdent) {
5536     return TokenError("expects padding type");
5537   }
5538   std::string val = lexer_.GetStrVal();
5539   if (!PaddingType_Parse(val, result) || !PaddingType_IsValid(*result)) {
5540     return TokenError(StrFormat("expects padding type but sees: %s", val));
5541   }
5542   lexer_.Lex();
5543   return true;
5544 }
5545 
ParseComparisonDirection(ComparisonDirection * result)5546 bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
5547   VLOG(3) << "ParseComparisonDirection";
5548   if (lexer_.GetKind() != TokKind::kIdent) {
5549     return TokenError("expects comparison direction");
5550   }
5551   std::string val = lexer_.GetStrVal();
5552   auto status_or_result = StringToComparisonDirection(val);
5553   if (!status_or_result.ok()) {
5554     return TokenError(
5555         StrFormat("expects comparison direction but sees: %s", val));
5556   }
5557   *result = status_or_result.ValueOrDie();
5558   lexer_.Lex();
5559   return true;
5560 }
5561 
ParseComparisonType(Comparison::Type * result)5562 bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
5563   VLOG(1) << "ParseComparisonType";
5564   if (lexer_.GetKind() != TokKind::kIdent) {
5565     return TokenError("expects comparison type");
5566   }
5567   std::string val = lexer_.GetStrVal();
5568   auto status_or_result = StringToComparisonType(val);
5569   if (!status_or_result.ok()) {
5570     return TokenError(StrFormat("expects comparison type but sees: %s", val));
5571   }
5572   *result = status_or_result.ValueOrDie();
5573   lexer_.Lex();
5574   return true;
5575 }
5576 
ParseFusionKind(HloInstruction::FusionKind * result)5577 bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
5578   VLOG(3) << "ParseFusionKind";
5579   if (lexer_.GetKind() != TokKind::kIdent) {
5580     return TokenError("expects fusion kind");
5581   }
5582   std::string val = lexer_.GetStrVal();
5583   auto status_or_result = StringToFusionKind(val);
5584   if (!status_or_result.ok()) {
5585     return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
5586                                 val,
5587                                 status_or_result.status().error_message()));
5588   }
5589   *result = status_or_result.ValueOrDie();
5590   lexer_.Lex();
5591   return true;
5592 }
5593 
ParseRandomDistribution(RandomDistribution * result)5594 bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
5595   VLOG(3) << "ParseRandomDistribution";
5596   if (lexer_.GetKind() != TokKind::kIdent) {
5597     return TokenError("expects random distribution");
5598   }
5599   std::string val = lexer_.GetStrVal();
5600   auto status_or_result = StringToRandomDistribution(val);
5601   if (!status_or_result.ok()) {
5602     return TokenError(
5603         StrFormat("expects random distribution but sees: %s, error: %s", val,
5604                   status_or_result.status().error_message()));
5605   }
5606   *result = status_or_result.ValueOrDie();
5607   lexer_.Lex();
5608   return true;
5609 }
5610 
ParseRandomAlgorithm(RandomAlgorithm * result)5611 bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
5612   VLOG(3) << "ParseRandomAlgorithm";
5613   if (lexer_.GetKind() != TokKind::kIdent) {
5614     return TokenError("expects random algorithm");
5615   }
5616   std::string val = lexer_.GetStrVal();
5617   auto status_or_result = StringToRandomAlgorithm(val);
5618   if (!status_or_result.ok()) {
5619     return TokenError(
5620         StrFormat("expects random algorithm but sees: %s, error: %s", val,
5621                   status_or_result.status().error_message()));
5622   }
5623   *result = status_or_result.ValueOrDie();
5624   lexer_.Lex();
5625   return true;
5626 }
5627 
ParsePrecision(PrecisionConfig::Precision * result)5628 bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
5629   VLOG(3) << "ParsePrecision";
5630   if (lexer_.GetKind() != TokKind::kIdent) {
5631     return TokenError("expects random distribution");
5632   }
5633   std::string val = lexer_.GetStrVal();
5634   auto status_or_result = StringToPrecision(val);
5635   if (!status_or_result.ok()) {
5636     return TokenError(StrFormat("expects precision but sees: %s, error: %s",
5637                                 val,
5638                                 status_or_result.status().error_message()));
5639   }
5640   *result = status_or_result.ValueOrDie();
5641   lexer_.Lex();
5642   return true;
5643 }
5644 
ParseInt64(int64_t * result)5645 bool HloParserImpl::ParseInt64(int64_t* result) {
5646   VLOG(3) << "ParseInt64";
5647   if (lexer_.GetKind() != TokKind::kInt) {
5648     return TokenError("expects integer");
5649   }
5650   *result = lexer_.GetInt64Val();
5651   lexer_.Lex();
5652   return true;
5653 }
5654 
ParseDouble(double * result)5655 bool HloParserImpl::ParseDouble(double* result) {
5656   switch (lexer_.GetKind()) {
5657     case TokKind::kDecimal: {
5658       double val = lexer_.GetDecimalVal();
5659       // If GetDecimalVal returns +/-inf, that means that we overflowed
5660       // `double`.
5661       if (std::isinf(val)) {
5662         return TokenError(StrCat("Constant is out of range for double (+/-",
5663                                  std::numeric_limits<double>::max(),
5664                                  ") and so is unparsable."));
5665       }
5666       *result = val;
5667       break;
5668     }
5669     case TokKind::kInt:
5670       *result = static_cast<double>(lexer_.GetInt64Val());
5671       break;
5672     case TokKind::kw_inf:
5673       *result = std::numeric_limits<double>::infinity();
5674       break;
5675     case TokKind::kNegInf:
5676       *result = -std::numeric_limits<double>::infinity();
5677       break;
5678     default:
5679       return TokenError("expects decimal or integer");
5680   }
5681   lexer_.Lex();
5682   return true;
5683 }
5684 
ParseComplex(std::complex<double> * result)5685 bool HloParserImpl::ParseComplex(std::complex<double>* result) {
5686   if (lexer_.GetKind() != TokKind::kLparen) {
5687     return TokenError("expects '(' before complex number");
5688   }
5689   lexer_.Lex();
5690 
5691   double real;
5692   LocTy loc = lexer_.GetLoc();
5693   if (!ParseDouble(&real)) {
5694     return Error(loc,
5695                  "expect floating-point value for real part of complex number");
5696   }
5697 
5698   if (lexer_.GetKind() != TokKind::kComma) {
5699     return TokenError(
5700         absl::StrFormat("expect comma after real part of complex literal"));
5701   }
5702   lexer_.Lex();
5703 
5704   double imag;
5705   loc = lexer_.GetLoc();
5706   if (!ParseDouble(&imag)) {
5707     return Error(
5708         loc,
5709         "expect floating-point value for imaginary part of complex number");
5710   }
5711 
5712   if (lexer_.GetKind() != TokKind::kRparen) {
5713     return TokenError(absl::StrFormat("expect ')' after complex number"));
5714   }
5715 
5716   *result = std::complex<double>(real, imag);
5717   lexer_.Lex();
5718   return true;
5719 }
5720 
ParseBool(bool * result)5721 bool HloParserImpl::ParseBool(bool* result) {
5722   if (lexer_.GetKind() != TokKind::kw_true &&
5723       lexer_.GetKind() != TokKind::kw_false) {
5724     return TokenError("expects true or false");
5725   }
5726   *result = lexer_.GetKind() == TokKind::kw_true;
5727   lexer_.Lex();
5728   return true;
5729 }
5730 
ParseToken(TokKind kind,const std::string & msg)5731 bool HloParserImpl::ParseToken(TokKind kind, const std::string& msg) {
5732   VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
5733   if (lexer_.GetKind() != kind) {
5734     return TokenError(msg);
5735   }
5736   lexer_.Lex();
5737   return true;
5738 }
5739 
EatIfPresent(TokKind kind)5740 bool HloParserImpl::EatIfPresent(TokKind kind) {
5741   if (lexer_.GetKind() != kind) {
5742     return false;
5743   }
5744   lexer_.Lex();
5745   return true;
5746 }
5747 
AddInstruction(const std::string & name,HloInstruction * instruction,LocTy name_loc)5748 bool HloParserImpl::AddInstruction(const std::string& name,
5749                                    HloInstruction* instruction,
5750                                    LocTy name_loc) {
5751   auto result = current_name_table().insert({name, {instruction, name_loc}});
5752   if (!result.second) {
5753     Error(name_loc, StrCat("instruction already exists: ", name));
5754     return Error(/*loc=*/result.first->second.second,
5755                  "instruction previously defined here");
5756   }
5757   return true;
5758 }
5759 
AddComputation(const std::string & name,HloComputation * computation,LocTy name_loc)5760 bool HloParserImpl::AddComputation(const std::string& name,
5761                                    HloComputation* computation,
5762                                    LocTy name_loc) {
5763   auto result = computation_pool_.insert({name, {computation, name_loc}});
5764   if (!result.second) {
5765     Error(name_loc, StrCat("computation already exists: ", name));
5766     return Error(/*loc=*/result.first->second.second,
5767                  "computation previously defined here");
5768   }
5769   return true;
5770 }
5771 
ParseShapeOnly()5772 StatusOr<Shape> HloParserImpl::ParseShapeOnly() {
5773   lexer_.Lex();
5774   Shape shape;
5775   if (!ParseShape(&shape)) {
5776     return InvalidArgument("Syntax error:\n%s", GetError());
5777   }
5778   if (lexer_.GetKind() != TokKind::kEof) {
5779     return InvalidArgument("Syntax error:\nExtra content after shape");
5780   }
5781   return shape;
5782 }
5783 
ParseShardingOnly()5784 StatusOr<HloSharding> HloParserImpl::ParseShardingOnly() {
5785   lexer_.Lex();
5786   OpSharding op_sharding;
5787   if (!ParseSharding(&op_sharding)) {
5788     return InvalidArgument("Syntax error:\n%s", GetError());
5789   }
5790   if (lexer_.GetKind() != TokKind::kEof) {
5791     return InvalidArgument("Syntax error:\nExtra content after sharding");
5792   }
5793   return HloSharding::FromProto(op_sharding);
5794 }
5795 
ParseFrontendAttributesOnly()5796 StatusOr<FrontendAttributes> HloParserImpl::ParseFrontendAttributesOnly() {
5797   lexer_.Lex();
5798   FrontendAttributes attributes;
5799   if (!ParseFrontendAttributes(&attributes)) {
5800     return InvalidArgument("Syntax error:\n%s", GetError());
5801   }
5802   if (lexer_.GetKind() != TokKind::kEof) {
5803     return InvalidArgument(
5804         "Syntax error:\nExtra content after frontend attributes");
5805   }
5806   return attributes;
5807 }
5808 
ParseParameterReplicationOnly()5809 StatusOr<std::vector<bool>> HloParserImpl::ParseParameterReplicationOnly() {
5810   lexer_.Lex();
5811   ParameterReplication parameter_replication;
5812   if (!ParseParameterReplication(&parameter_replication)) {
5813     return InvalidArgument("Syntax error:\n%s", GetError());
5814   }
5815   if (lexer_.GetKind() != TokKind::kEof) {
5816     return InvalidArgument(
5817         "Syntax error:\nExtra content after parameter replication");
5818   }
5819   return std::vector<bool>(
5820       parameter_replication.replicated_at_leaf_buffers().begin(),
5821       parameter_replication.replicated_at_leaf_buffers().end());
5822 }
5823 
ParseReplicaGroupsOnly()5824 StatusOr<std::vector<ReplicaGroup>> HloParserImpl::ParseReplicaGroupsOnly() {
5825   lexer_.Lex();
5826   std::vector<ReplicaGroup> replica_groups;
5827   if (!ParseReplicaGroupsOnly(&replica_groups)) {
5828     return InvalidArgument("Syntax error:\n%s", GetError());
5829   }
5830   if (lexer_.GetKind() != TokKind::kEof) {
5831     return InvalidArgument("Syntax error:\nExtra content after replica groups");
5832   }
5833   return replica_groups;
5834 }
5835 
ParseWindowOnly()5836 StatusOr<Window> HloParserImpl::ParseWindowOnly() {
5837   lexer_.Lex();
5838   Window window;
5839   if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
5840     return InvalidArgument("Syntax error:\n%s", GetError());
5841   }
5842   if (lexer_.GetKind() != TokKind::kEof) {
5843     return InvalidArgument("Syntax error:\nExtra content after window");
5844   }
5845   return window;
5846 }
5847 
5848 StatusOr<ConvolutionDimensionNumbers>
ParseConvolutionDimensionNumbersOnly()5849 HloParserImpl::ParseConvolutionDimensionNumbersOnly() {
5850   lexer_.Lex();
5851   ConvolutionDimensionNumbers dnums;
5852   if (!ParseConvolutionDimensionNumbers(&dnums)) {
5853     return InvalidArgument("Syntax error:\n%s", GetError());
5854   }
5855   if (lexer_.GetKind() != TokKind::kEof) {
5856     return InvalidArgument(
5857         "Syntax error:\nExtra content after convolution dnums");
5858   }
5859   return dnums;
5860 }
5861 
ParsePaddingConfigOnly()5862 StatusOr<PaddingConfig> HloParserImpl::ParsePaddingConfigOnly() {
5863   lexer_.Lex();
5864   PaddingConfig padding_config;
5865   if (!ParsePaddingConfig(&padding_config)) {
5866     return InvalidArgument("Syntax error:\n%s", GetError());
5867   }
5868   if (lexer_.GetKind() != TokKind::kEof) {
5869     return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
5870   }
5871   return padding_config;
5872 }
5873 
// Parses a single HLO instruction (with or without a "name =" left-hand
// side) into `module` as its entry computation. Operands that are referenced
// but not defined are synthesized on the fly as parameters via the
// create_missing_instruction_ hook installed below. Requires a fresh parser
// (no prior hook, empty scope stack).
bool HloParserImpl::ParseSingleInstruction(HloModule* module) {
  if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
    LOG(FATAL) << "Parser state is not clean. Please do not call any other "
                  "methods before calling ParseSingleInstruction.";
  }
  HloComputation::Builder builder(module->name());

  // The missing instruction hook we register creates the shaped instruction on
  // the fly as a parameter and returns it.
  int64_t parameter_count = 0;
  create_missing_instruction_ =
      [this, &builder, &parameter_count](
          const std::string& name,
          const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
    // Unnamed operands get synthetic names "_0", "_1", ...
    std::string new_name = name.empty() ? StrCat("_", parameter_count) : name;
    HloInstruction* parameter = builder.AddInstruction(
        HloInstruction::CreateParameter(parameter_count++, shape, new_name));
    current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
    return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
  };

  // Parse the instruction with the registered hook.
  Scope scope(&scoped_name_tables_);
  if (CanBeShape()) {
    // This means that the instruction's left-hand side is probably omitted,
    // e.g.
    //
    //  f32[10] fusion(...), calls={...}
    if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
      return false;
    }
  } else {
    // This means that the instruction's left-hand side might exist, e.g.
    //
    //  foo = f32[10] fusion(...), calls={...}
    std::string root_name;
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  }

  if (lexer_.GetKind() != TokKind::kEof) {
    Error(
        lexer_.GetLoc(),
        "Syntax error:\nExpected eof after parsing single instruction.  Did "
        "you mean to write an HLO module and forget the \"HloModule\" header?");
    return false;
  }

  // Install the built computation as the entry, then move over any
  // computations created while parsing (e.g. fusion bodies).
  module->AddEntryComputation(builder.Build());
  for (auto& comp : computations_) {
    module->AddEmbeddedComputation(std::move(comp));
  }
  TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
  return true;
}
5930 
5931 }  // namespace
5932 
ParseAndReturnUnverifiedModule(absl::string_view str,const HloModuleConfig & config)5933 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5934     absl::string_view str, const HloModuleConfig& config) {
5935   auto module = std::make_unique<HloModule>(/*name=*/"_", config);
5936   HloParserImpl parser(str);
5937   TF_RETURN_IF_ERROR(parser.Run(module.get()));
5938   return std::move(module);
5939 }
5940 
ParseAndReturnUnverifiedModule(absl::string_view str)5941 StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
5942     absl::string_view str) {
5943   return ParseAndReturnUnverifiedModule(str, HloModuleConfig());
5944 }
5945 
ParseSharding(absl::string_view str)5946 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
5947   HloParserImpl parser(str);
5948   return parser.ParseShardingOnly();
5949 }
5950 
ParseFrontendAttributes(absl::string_view str)5951 StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
5952   HloParserImpl parser(str);
5953   return parser.ParseFrontendAttributesOnly();
5954 }
5955 
ParseParameterReplication(absl::string_view str)5956 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
5957   HloParserImpl parser(str);
5958   return parser.ParseParameterReplicationOnly();
5959 }
5960 
ParseReplicaGroupsOnly(absl::string_view str)5961 StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
5962     absl::string_view str) {
5963   HloParserImpl parser(str);
5964   return parser.ParseReplicaGroupsOnly();
5965 }
5966 
ParseWindow(absl::string_view str)5967 StatusOr<Window> ParseWindow(absl::string_view str) {
5968   HloParserImpl parser(str);
5969   return parser.ParseWindowOnly();
5970 }
5971 
ParseConvolutionDimensionNumbers(absl::string_view str)5972 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
5973     absl::string_view str) {
5974   HloParserImpl parser(str);
5975   return parser.ParseConvolutionDimensionNumbersOnly();
5976 }
5977 
ParsePaddingConfig(absl::string_view str)5978 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
5979   HloParserImpl parser(str);
5980   return parser.ParsePaddingConfigOnly();
5981 }
5982 
ParseShape(absl::string_view str)5983 StatusOr<Shape> ParseShape(absl::string_view str) {
5984   HloParserImpl parser(str);
5985   return parser.ParseShapeOnly();
5986 }
5987 
CreateHloParserForTests(absl::string_view str)5988 std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
5989     absl::string_view str) {
5990   return std::make_unique<HloParserImpl>(str);
5991 }
5992 
5993 }  // namespace xla
5994