xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_lexer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
17 
18 #include <limits>
19 #include <optional>
20 #include <string>
21 
22 #include "absl/base/casts.h"
23 #include "absl/strings/ascii.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_split.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/platform/regexp.h"
32 
33 namespace xla {
34 namespace {
35 
36 using absl::string_view;
37 
38 constexpr int kEOF = -1;
39 constexpr int kError = -2;
40 
41 // [a-zA-Z0-9_.-]
IsIdentifierChar(char c)42 bool IsIdentifierChar(char c) {
43   return absl::ascii_isalnum(static_cast<unsigned char>(c)) || c == '-' ||
44          c == '.' || c == '_';
45 }
46 
47 }  // namespace
48 
GetNextChar()49 int HloLexer::GetNextChar() {
50   int current_char = PeekCurrentChar();
51   if (current_char != kEOF && current_char != kError) {
52     current_ptr_++;
53   }
54   return current_char;
55 }
56 
PeekCurrentChar() const57 int HloLexer::PeekCurrentChar() const {
58   if (current_ptr_ == buf_.data() + buf_.size()) {
59     return kEOF;
60   }
61   char current_char = *current_ptr_;
62   if (current_char == 0) {
63     // '\0' should not appear in the middle of the string.
64     return kError;
65   }
66   return static_cast<unsigned char>(current_char);
67 }
68 
CanDereference(const char * ptr) const69 bool HloLexer::CanDereference(const char* ptr) const {
70   return (ptr < buf_.data() + buf_.size()) && ptr >= buf_.data();
71 }
72 
StringViewFromPointers(const char * begin,const char * end) const73 absl::string_view HloLexer::StringViewFromPointers(const char* begin,
74                                                    const char* end) const {
75   CHECK(begin <= end);
76   CHECK((begin == buf_.data() + buf_.size()) || CanDereference(begin));
77   CHECK((end == buf_.data() + buf_.size()) || CanDereference(end));
78   return absl::string_view(begin, end - begin);
79 }
80 
LookAhead()81 TokKind HloLexer::LookAhead() {
82   if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
83     return GetKind();
84   }
85 
86   const char* old_current_ptr = current_ptr_;
87   TokenState old_token_state = token_state_;
88   Lex();
89   TokKind kind = GetKind();
90   token_state_ = old_token_state;
91   current_ptr_ = old_current_ptr;
92   return kind;
93 }
94 
LexToken()95 TokKind HloLexer::LexToken() {
96   while (true) {
97     token_state_.token_start = current_ptr_;
98 
99     int current_char = GetNextChar();
100     switch (current_char) {
101       default:
102         // [a-zA-Z_]
103         if (absl::ascii_isalpha(static_cast<unsigned char>(current_char)) ||
104             current_char == '_') {
105           return LexIdentifier();
106         }
107         return TokKind::kError;
108       case kEOF:
109         // Hit the end of the input buffer.
110         return TokKind::kEof;
111       case kError:
112         // Hit an invalid character in the input buffer.
113         return TokKind::kError;
114       case ' ':
115       case '\t':
116       case '\n':
117       case '\r':
118         // Ignore whitespace.
119         continue;
120       case '0':
121       case '1':
122       case '2':
123       case '3':
124       case '4':
125       case '5':
126       case '6':
127       case '7':
128       case '8':
129       case '9':
130       case '-':
131       case '?':
132         if (current_char == '-' && PeekCurrentChar() == '>') {
133           current_ptr_++;
134           return TokKind::kArrow;
135         }
136         return LexNumberOrPattern();
137       case '=':
138         return TokKind::kEqual;
139       case '<':
140         if (current_char == '<' && PeekCurrentChar() == '=') {
141           current_ptr_++;
142           return TokKind::kLeq;
143         }
144         return TokKind::kError;
145       case ',':
146         return TokKind::kComma;
147       case '%':
148         return LexPercent();
149       case ':':
150         return TokKind::kColon;
151       case '*':
152         return TokKind::kAsterisk;
153       case '[':
154         return TokKind::kLsquare;
155       case ']':
156         return TokKind::kRsquare;
157       case '{':
158         return TokKind::kLbrace;
159       case '}':
160         return TokKind::kRbrace;
161       case '(':
162         return TokKind::kLparen;
163       case ')':
164         return TokKind::kRparen;
165       case '/': {
166         if (PeekCurrentChar() == '*') {
167           // This is the start of a /*...*/ delimited comment. Save the current
168           // location in case the comment is unterminated so the error message
169           // will point to the beginning of the comment.
170           const char* comment_start = current_ptr_;
171           current_ptr_++;
172           // Advance until '*/' is found.
173           while (true) {
174             int current = GetNextChar();
175             if (current == '*' && PeekCurrentChar() == '/') {
176               // End of comment.
177               current_ptr_++;
178               break;
179             }
180             if (current == kEOF) {
181               // Unterminated comment.
182               current_ptr_ = comment_start;
183               return TokKind::kError;
184             }
185             if (current == kError) {
186               return TokKind::kError;
187             }
188           }
189           // Return no token for the comment. Keep lexing.
190           continue;
191         } else if (PeekCurrentChar() == '/') {
192           // This is the start of a '//' delimited comment. Throw away
193           // everything until end of line or file. The end-of-line character(s)
194           // are left unlexed in the buffer which is harmless because these are
195           // skipped later by the lexer. This approach enables support for
196           // different end-of-line encodings.
197           while (true) {
198             int current = PeekCurrentChar();
199             if (current == kEOF || current == '\n' || current == '\r') {
200               break;
201             }
202             if (current == kError) {
203               return TokKind::kError;
204             }
205             current_ptr_++;
206           }
207           continue;
208         }
209         // A lone '/' is an error.
210         return TokKind::kError;
211       }
212       case '.':
213         if (PeekCurrentChar() == '.') {
214           current_ptr_++;
215           if (PeekCurrentChar() == '.') {
216             current_ptr_++;
217             return TokKind::kDots;
218           }
219         }
220         return TokKind::kError;
221       case '"':
222         return LexString();
223     }
224   }
225 }
226 
LexNanPayload(absl::string_view & consumable)227 std::optional<int64_t> HloLexer::LexNanPayload(absl::string_view& consumable) {
228   static LazyRE2 payload_pattern = {R"(\(0x[0-9a-fA-F]+\))"};
229   if (!RE2::Consume(&consumable, *payload_pattern)) {
230     return std::nullopt;
231   }
232   auto slice = StringViewFromPointers(current_ptr_, consumable.data());
233   current_ptr_ = consumable.data();
234   CHECK(absl::StartsWith(slice, "(0x"));
235   slice.remove_prefix(std::strlen("(0x"));
236   CHECK(absl::EndsWith(slice, ")"));
237   slice.remove_suffix(std::strlen(")"));
238   uint64_t payload_value;
239   if (tensorflow::strings::HexStringToUint64(slice, &payload_value)) {
240     if (payload_value <= 0 || payload_value > NanPayloadBitMask<double>()) {
241       LOG(ERROR) << "NaN payload out of range: " << payload_value;
242       return std::nullopt;
243     }
244     return payload_value;
245   }
246   return std::nullopt;
247 }
248 
249 // Lex a shape, name, keyword, attribute name, the dim labels pattern, and
250 // other identifiers.
251 //
252 // shape    ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
253 // name     ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
254 // keyword  ::= HloModule, ENTRY, ...
255 // attribute_name ::= condition, body, dimensions, ...
256 // dim_labels_pattern ::= [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,}
257 // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
LexIdentifier()258 TokKind HloLexer::LexIdentifier() {
259   while (IsIdentifierChar(PeekCurrentChar())) {
260     current_ptr_++;
261   }
262 
263   // If followed by ':', it's a name.
264   if (PeekCurrentChar() == ':') {
265     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
266     current_ptr_++;  // skip ':'
267     return TokKind::kName;
268   }
269 
270   // If followed by '=', it's a attribute name.
271   if (PeekCurrentChar() == '=') {
272     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
273     current_ptr_++;  // skip '='
274     return TokKind::kAttributeName;
275   }
276 
277   absl::string_view identifier =
278       StringViewFromPointers(token_state_.token_start, current_ptr_);
279 
280   // Primitive type strings are reserved words. The exception is 'tuple' whose
281   // type is represented using nested parentheses without the string 'tuple'.
282   if (primitive_util::IsPrimitiveTypeName(identifier)) {
283     PrimitiveType primitive_type =
284         primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
285     if (primitive_type != TUPLE) {
286       token_state_.primitive_type_val = primitive_type;
287       return TokKind::kPrimitiveType;
288     }
289   }
290 
291   if (identifier == "nan") {
292     std::optional<int64_t> payload;
293     if (PeekCurrentChar() == '(') {
294       absl::string_view consumable =
295           StringViewFromPointers(current_ptr_, buf_.data() + buf_.size());
296       payload = LexNanPayload(consumable);
297       if (!payload.has_value()) {
298         return TokKind::kError;
299       }
300     }
301     token_state_.decimal_val = NanWithSignAndPayload<double>(
302         /*sign=*/false, payload.value_or(QuietNanWithoutPayload<double>()));
303     return TokKind::kDecimal;
304   }
305 
306   // See if this is a keyword.
307 #define KEYWORD(STR)            \
308   do {                          \
309     if (identifier == #STR) {   \
310       return TokKind::kw_##STR; \
311     }                           \
312   } while (false)
313 
314   KEYWORD(true);
315   KEYWORD(false);
316   KEYWORD(inf);
317   KEYWORD(HloModule);
318   KEYWORD(ENTRY);
319   KEYWORD(ROOT);
320   KEYWORD(maximal);
321   KEYWORD(replicated);
322   KEYWORD(manual);
323   KEYWORD(last_tile_dim_replicate);
324 
325 #undef KEYWORD
326 
327   {
328     absl::string_view consumable = StringViewFromPointers(
329         token_state_.token_start, buf_.data() + buf_.size());
330     static LazyRE2 dim_labels_pattern = {
331         R"([0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,})"};
332     if (RE2::Consume(&consumable, *dim_labels_pattern)) {
333       current_ptr_ = consumable.data();
334       token_state_.str_val.assign(token_state_.token_start, current_ptr_);
335       return TokKind::kDimLabels;
336     }
337   }
338 
339   token_state_.str_val = std::string(identifier);
340   return TokKind::kIdent;
341 }
342 
343 // Lex names after a % character.
344 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
LexPercent()345 TokKind HloLexer::LexPercent() {
346   const char* name_start = current_ptr_;
347   if (absl::ascii_isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
348       PeekCurrentChar() == '_') {
349     current_ptr_++;
350     while (IsIdentifierChar(PeekCurrentChar())) {
351       current_ptr_++;
352     }
353     token_state_.str_val.assign(name_start, current_ptr_);
354     return TokKind::kName;
355   }
356   return TokKind::kError;
357 }
358 
359 // Lex integer and floating-point values, -inf, and patterns for dim labels,
360 // dxd (e.g. 1x2x3), and pad.
361 //
362 // fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
363 // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
364 // dim_labels_pattern ::= [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,}
365 // dxd_pattern ::= [0-9]+(x[0-9]+)+
366 // pad_pattern ::=
367 //   [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
368 // int ::=  [-]?[0-9]+
369 // negative inf ::= '-inf'
LexNumberOrPattern()370 TokKind HloLexer::LexNumberOrPattern() {
371   absl::string_view consumable = StringViewFromPointers(
372       token_state_.token_start, buf_.data() + buf_.size());
373   static LazyRE2 float_pattern = {
374       R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
375   if (RE2::Consume(&consumable, *float_pattern)) {
376     current_ptr_ = consumable.data();
377     CHECK(absl::SimpleAtod(std::string(token_state_.token_start, current_ptr_),
378                            &token_state_.decimal_val));
379     return TokKind::kDecimal;
380   }
381 
382   static LazyRE2 dim_labels_pattern = {
383       R"([0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,})"};
384   static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
385   static LazyRE2 pad_pattern = {
386       R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};
387 
388   if (RE2::Consume(&consumable, *dim_labels_pattern)) {
389     current_ptr_ = consumable.data();
390     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
391     return TokKind::kDimLabels;
392   }
393 
394   if (RE2::Consume(&consumable, *dxd_pattern)) {
395     current_ptr_ = consumable.data();
396     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
397     return TokKind::kDxD;
398   }
399 
400   if (RE2::Consume(&consumable, *pad_pattern)) {
401     current_ptr_ = consumable.data();
402     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
403     return TokKind::kPad;
404   }
405 
406   static LazyRE2 int_pattern = {R"([-]?\d+)"};
407   if (RE2::Consume(&consumable, *int_pattern)) {
408     current_ptr_ = consumable.data();
409     auto slice = StringViewFromPointers(token_state_.token_start, current_ptr_);
410     if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
411       return TokKind::kInt;
412     }
413     uint64_t uint64_val;
414     if (absl::SimpleAtoi(slice, &uint64_val)) {
415       token_state_.int64_val = absl::bit_cast<int64_t>(uint64_val);
416       return TokKind::kInt;
417     }
418     LOG(ERROR) << "Failed to parse int literal: " << slice;
419     return TokKind::kError;
420   }
421 
422   static LazyRE2 neg_inf = {"-inf"};
423   if (RE2::Consume(&consumable, *neg_inf)) {
424     current_ptr_ = consumable.data();
425     return TokKind::kNegInf;
426   }
427 
428   static LazyRE2 neg_nan = {"-nan"};
429   if (RE2::Consume(&consumable, *neg_nan)) {
430     current_ptr_ = consumable.data();
431 
432     std::optional<int64_t> payload;
433     if (PeekCurrentChar() == '(') {
434       payload = LexNanPayload(consumable);
435       if (!payload.has_value()) {
436         return TokKind::kError;
437       }
438     }
439     token_state_.decimal_val = NanWithSignAndPayload<double>(
440         /*sign=*/true, payload.value_or(QuietNanWithoutPayload<double>()));
441     return TokKind::kDecimal;
442   }
443 
444   return TokKind::kError;
445 }
446 
GetLineAndColumn(LocTy location) const447 std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
448   unsigned line_no = 1;
449   const char* start = buf_.data();
450   const char* ptr = start;
451   if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) &&
452       line_no_cache_.last_query <= location) {
453     ptr = line_no_cache_.last_query;
454     line_no = line_no_cache_.line_no_of_query;
455   }
456   for (; ptr != location; ptr++) {
457     CHECK_LT(ptr, buf_.data() + buf_.size());
458     if (*ptr == '\n') {
459       line_no++;
460     }
461   }
462 
463   // Update the line number cache.
464   line_no_cache_.last_query = ptr;
465   line_no_cache_.line_no_of_query = line_no;
466   size_t line_offset = StringViewFromPointers(start, ptr).rfind('\n');
467   if (line_offset == absl::string_view::npos) {
468     line_offset = 0;
469   }
470   return {line_no, ptr - start - line_offset};
471 }
472 
GetLine(LocTy loc) const473 absl::string_view HloLexer::GetLine(LocTy loc) const {
474   if (!CanDereference(loc)) {
475     return "LINE OUT OF RANGE";
476   }
477   size_t line_start = StringViewFromPointers(buf_.data(), loc + 1).rfind('\n');
478   const char* start = line_start == absl::string_view::npos
479                           ? buf_.data()
480                           : buf_.data() + line_start + 1;
481   size_t line_end =
482       StringViewFromPointers(loc, buf_.data() + buf_.size()).find('\n');
483   const char* end = line_end == absl::string_view::npos
484                         ? buf_.data() + buf_.size()
485                         : loc + line_end;
486 
487   return StringViewFromPointers(start, end);
488 }
489 
490 // Lexes quoted string with escaping characters. If matched, the quoted string
491 // will be unescaped and stored to token_state_.str_val.
LexString()492 TokKind HloLexer::LexString() {
493   absl::string_view consumable = StringViewFromPointers(
494       token_state_.token_start, buf_.data() + buf_.size());
495   static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
496   if (RE2::Consume(&consumable, *escaping_pattern)) {
497     current_ptr_ = consumable.data();
498     absl::string_view raw =
499         StringViewFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
500     std::string error;
501     if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
502       LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
503       return TokKind::kError;
504     }
505     return TokKind::kString;
506   }
507   return TokKind::kError;
508 }
509 
510 std::string TokKindToString(TokKind kind) {
511   switch (kind) {
512     case TokKind::kEof:
513       return "kEof";
514     case TokKind::kError:
515       return "kError";
516     case TokKind::kEqual:
517       return "kEqaul";
518     case TokKind::kComma:
519       return "kComma";
520     case TokKind::kColon:
521       return "kColon";
522     case TokKind::kAsterisk:
523       return "kAsterisk";
524     case TokKind::kLsquare:
525       return "kLsquare";
526     case TokKind::kRsquare:
527       return "kRsquare";
528     case TokKind::kLbrace:
529       return "kLbrace";
530     case TokKind::kRbrace:
531       return "kRbrace";
532     case TokKind::kLparen:
533       return "kLparen";
534     case TokKind::kRparen:
535       return "kRparen";
536     case TokKind::kArrow:
537       return "kArrow";
538     case TokKind::kLeq:
539       return "kLeq";
540     case TokKind::kw_HloModule:
541       return "kw_HloModule";
542     case TokKind::kw_ENTRY:
543       return "kw_ENTRY";
544     case TokKind::kw_ROOT:
545       return "kw_ROOT";
546     case TokKind::kw_true:
547       return "kw_true";
548     case TokKind::kw_false:
549       return "kw_false";
550     case TokKind::kw_maximal:
551       return "kw_maximal";
552     case TokKind::kw_replicated:
553       return "kw_replicated";
554     case TokKind::kw_manual:
555       return "kw_manual";
556     case TokKind::kw_last_tile_dim_replicate:
557       return "kw_last_tile_dim_replicate";
558     case TokKind::kw_inf:
559       return "kw_inf";
560     case TokKind::kNegInf:
561       return "kNegInf";
562     case TokKind::kPrimitiveType:
563       return "kPrimitiveType";
564     case TokKind::kName:
565       return "kName";
566     case TokKind::kAttributeName:
567       return "kAttributeName";
568     case TokKind::kDimLabels:
569       return "kDimLabels";
570     case TokKind::kDxD:
571       return "kDxD";
572     case TokKind::kPad:
573       return "kPad";
574     case TokKind::kIdent:
575       return "kIdent";
576     case TokKind::kString:
577       return "kString";
578     case TokKind::kInt:
579       return "kInt";
580     case TokKind::kDecimal:
581       return "kDecimal";
582     case TokKind::kDots:
583       return "kDots";
584   }
585 }
586 
587 }  // namespace xla
588