1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
17
18 #include <limits>
19 #include <optional>
20 #include <string>
21
22 #include "absl/base/casts.h"
23 #include "absl/strings/ascii.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/numbers.h"
26 #include "absl/strings/str_split.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/platform/regexp.h"
32
33 namespace xla {
34 namespace {
35
using absl::string_view;

// Sentinel values returned by HloLexer::PeekCurrentChar() and
// HloLexer::GetNextChar(). Ordinary bytes are returned as values in [0, 255]
// (cast through unsigned char), so these negative values never collide with a
// character from the buffer.
//   kEOF   - the read position is at the end of the input buffer.
//   kError - the read position is at an embedded '\0' byte.
constexpr int kEOF = -1;
constexpr int kError = -2;
40
41 // [a-zA-Z0-9_.-]
IsIdentifierChar(char c)42 bool IsIdentifierChar(char c) {
43 return absl::ascii_isalnum(static_cast<unsigned char>(c)) || c == '-' ||
44 c == '.' || c == '_';
45 }
46
47 } // namespace
48
GetNextChar()49 int HloLexer::GetNextChar() {
50 int current_char = PeekCurrentChar();
51 if (current_char != kEOF && current_char != kError) {
52 current_ptr_++;
53 }
54 return current_char;
55 }
56
PeekCurrentChar() const57 int HloLexer::PeekCurrentChar() const {
58 if (current_ptr_ == buf_.data() + buf_.size()) {
59 return kEOF;
60 }
61 char current_char = *current_ptr_;
62 if (current_char == 0) {
63 // '\0' should not appear in the middle of the string.
64 return kError;
65 }
66 return static_cast<unsigned char>(current_char);
67 }
68
CanDereference(const char * ptr) const69 bool HloLexer::CanDereference(const char* ptr) const {
70 return (ptr < buf_.data() + buf_.size()) && ptr >= buf_.data();
71 }
72
StringViewFromPointers(const char * begin,const char * end) const73 absl::string_view HloLexer::StringViewFromPointers(const char* begin,
74 const char* end) const {
75 CHECK(begin <= end);
76 CHECK((begin == buf_.data() + buf_.size()) || CanDereference(begin));
77 CHECK((end == buf_.data() + buf_.size()) || CanDereference(end));
78 return absl::string_view(begin, end - begin);
79 }
80
LookAhead()81 TokKind HloLexer::LookAhead() {
82 if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
83 return GetKind();
84 }
85
86 const char* old_current_ptr = current_ptr_;
87 TokenState old_token_state = token_state_;
88 Lex();
89 TokKind kind = GetKind();
90 token_state_ = old_token_state;
91 current_ptr_ = old_current_ptr;
92 return kind;
93 }
94
// Lexes the next token starting at current_ptr_ and returns its kind.
//
// Whitespace (' ', '\t', '\n', '\r') and comments ('/* ... */' and
// '// ...' to end of line) are skipped by looping; they never produce a
// token. On each iteration token_state_.token_start is set to the position of
// the first character of the candidate token. That character is consumed by
// GetNextChar() before dispatching, so the Lex* helpers are entered with
// current_ptr_ one past token_start.
TokKind HloLexer::LexToken() {
  while (true) {
    token_state_.token_start = current_ptr_;

    int current_char = GetNextChar();
    switch (current_char) {
      default:
        // [a-zA-Z_]
        if (absl::ascii_isalpha(static_cast<unsigned char>(current_char)) ||
            current_char == '_') {
          return LexIdentifier();
        }
        return TokKind::kError;
      case kEOF:
        // Hit the end of the input buffer.
        return TokKind::kEof;
      case kError:
        // Hit an invalid character in the input buffer.
        return TokKind::kError;
      case ' ':
      case '\t':
      case '\n':
      case '\r':
        // Ignore whitespace.
        continue;
      case '0':
      case '1':
      case '2':
      case '3':
      case '4':
      case '5':
      case '6':
      case '7':
      case '8':
      case '9':
      case '-':
      case '?':
        // Digits, '-' and '?' may start a number, -inf/-nan, or one of the
        // dim-labels / dxd / pad patterns ('?' appears in dim labels). The
        // only exception is "->", which is the arrow token.
        if (current_char == '-' && PeekCurrentChar() == '>') {
          current_ptr_++;
          return TokKind::kArrow;
        }
        return LexNumberOrPattern();
      case '=':
        return TokKind::kEqual;
      case '<':
        // Only "<=" is a valid token beginning with '<'.
        if (current_char == '<' && PeekCurrentChar() == '=') {
          current_ptr_++;
          return TokKind::kLeq;
        }
        return TokKind::kError;
      case ',':
        return TokKind::kComma;
      case '%':
        return LexPercent();
      case ':':
        return TokKind::kColon;
      case '*':
        return TokKind::kAsterisk;
      case '[':
        return TokKind::kLsquare;
      case ']':
        return TokKind::kRsquare;
      case '{':
        return TokKind::kLbrace;
      case '}':
        return TokKind::kRbrace;
      case '(':
        return TokKind::kLparen;
      case ')':
        return TokKind::kRparen;
      case '/': {
        if (PeekCurrentChar() == '*') {
          // This is the start of a /*...*/ delimited comment. Save the current
          // location (the '*' of the opening "/*") in case the comment is
          // unterminated, so the read position is rewound to the comment
          // rather than left at the end of the buffer.
          const char* comment_start = current_ptr_;
          current_ptr_++;
          // Advance until '*/' is found.
          while (true) {
            int current = GetNextChar();
            if (current == '*' && PeekCurrentChar() == '/') {
              // End of comment.
              current_ptr_++;
              break;
            }
            if (current == kEOF) {
              // Unterminated comment.
              current_ptr_ = comment_start;
              return TokKind::kError;
            }
            if (current == kError) {
              // Embedded '\0' inside the comment; GetNextChar() does not
              // advance past it.
              return TokKind::kError;
            }
          }
          // Return no token for the comment. Keep lexing.
          continue;
        } else if (PeekCurrentChar() == '/') {
          // This is the start of a '//' delimited comment. Throw away
          // everything until end of line or file. The end-of-line character(s)
          // are left unlexed in the buffer which is harmless because these are
          // skipped later by the lexer. This approach enables support for
          // different end-of-line encodings.
          while (true) {
            int current = PeekCurrentChar();
            if (current == kEOF || current == '\n' || current == '\r') {
              break;
            }
            if (current == kError) {
              return TokKind::kError;
            }
            current_ptr_++;
          }
          continue;
        }
        // A lone '/' is an error.
        return TokKind::kError;
      }
      case '.':
        // Only "..." is valid; a partial match ("." or "..") is an error.
        if (PeekCurrentChar() == '.') {
          current_ptr_++;
          if (PeekCurrentChar() == '.') {
            current_ptr_++;
            return TokKind::kDots;
          }
        }
        return TokKind::kError;
      case '"':
        return LexString();
    }
  }
}
226
LexNanPayload(absl::string_view & consumable)227 std::optional<int64_t> HloLexer::LexNanPayload(absl::string_view& consumable) {
228 static LazyRE2 payload_pattern = {R"(\(0x[0-9a-fA-F]+\))"};
229 if (!RE2::Consume(&consumable, *payload_pattern)) {
230 return std::nullopt;
231 }
232 auto slice = StringViewFromPointers(current_ptr_, consumable.data());
233 current_ptr_ = consumable.data();
234 CHECK(absl::StartsWith(slice, "(0x"));
235 slice.remove_prefix(std::strlen("(0x"));
236 CHECK(absl::EndsWith(slice, ")"));
237 slice.remove_suffix(std::strlen(")"));
238 uint64_t payload_value;
239 if (tensorflow::strings::HexStringToUint64(slice, &payload_value)) {
240 if (payload_value <= 0 || payload_value > NanPayloadBitMask<double>()) {
241 LOG(ERROR) << "NaN payload out of range: " << payload_value;
242 return std::nullopt;
243 }
244 return payload_value;
245 }
246 return std::nullopt;
247 }
248
// Lexes a token that begins with [a-zA-Z_]. On entry the first character has
// already been consumed, and token_state_.token_start points at it. The
// remaining [a-zA-Z0-9_.-]* characters are scanned, and then the token is
// classified in this order:
//
// name               ::= [a-zA-Z_][a-zA-Z0-9_.-]*:  -> kName (':' consumed)
// attribute_name     ::= [a-zA-Z_][a-zA-Z0-9_.-]*=  -> kAttributeName
//                                                      ('=' consumed)
// primitive type     ::= any name accepted by primitive_util::
//                        IsPrimitiveTypeName other than 'tuple'
//                                                   -> kPrimitiveType
// nan                ::= nan | nan(0x<hex>)         -> kDecimal
// keyword            ::= true, false, inf, HloModule, ENTRY, ROOT, ...
//                                                   -> kw_*
// dim_labels_pattern ::= [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,}
//                                                   -> kDimLabels
// identifier         ::= any other [a-zA-Z_][a-zA-Z0-9_.-]*  -> kIdent
//
// For kName, kAttributeName, kDimLabels and kIdent, the token text (without
// the trailing ':' or '=') is stored in token_state_.str_val.
TokKind HloLexer::LexIdentifier() {
  while (IsIdentifierChar(PeekCurrentChar())) {
    current_ptr_++;
  }

  // If followed by ':', it's a name.
  if (PeekCurrentChar() == ':') {
    token_state_.str_val.assign(token_state_.token_start, current_ptr_);
    current_ptr_++;  // skip ':'
    return TokKind::kName;
  }

  // If followed by '=', it's an attribute name.
  if (PeekCurrentChar() == '=') {
    token_state_.str_val.assign(token_state_.token_start, current_ptr_);
    current_ptr_++;  // skip '='
    return TokKind::kAttributeName;
  }

  absl::string_view identifier =
      StringViewFromPointers(token_state_.token_start, current_ptr_);

  // Primitive type strings are reserved words. The exception is 'tuple' whose
  // type is represented using nested parentheses without the string 'tuple'.
  if (primitive_util::IsPrimitiveTypeName(identifier)) {
    PrimitiveType primitive_type =
        primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
    if (primitive_type != TUPLE) {
      token_state_.primitive_type_val = primitive_type;
      return TokKind::kPrimitiveType;
    }
  }

  // "nan", optionally followed by an explicit payload "(0x...)". Without a
  // payload, QuietNanWithoutPayload<double>() supplies the default.
  if (identifier == "nan") {
    std::optional<int64_t> payload;
    if (PeekCurrentChar() == '(') {
      absl::string_view consumable =
          StringViewFromPointers(current_ptr_, buf_.data() + buf_.size());
      payload = LexNanPayload(consumable);
      if (!payload.has_value()) {
        return TokKind::kError;
      }
    }
    token_state_.decimal_val = NanWithSignAndPayload<double>(
        /*sign=*/false, payload.value_or(QuietNanWithoutPayload<double>()));
    return TokKind::kDecimal;
  }

  // See if this is a keyword.
#define KEYWORD(STR)            \
  do {                          \
    if (identifier == #STR) {   \
      return TokKind::kw_##STR; \
    }                           \
  } while (false)

  KEYWORD(true);
  KEYWORD(false);
  KEYWORD(inf);
  KEYWORD(HloModule);
  KEYWORD(ENTRY);
  KEYWORD(ROOT);
  KEYWORD(maximal);
  KEYWORD(replicated);
  KEYWORD(manual);
  KEYWORD(last_tile_dim_replicate);

#undef KEYWORD

  // Dim labels that start with 'b' or 'f' arrive here. The identifier scan
  // above stops at the '>' of "->", so the pattern is re-matched from
  // token_start rather than from current_ptr_.
  {
    absl::string_view consumable = StringViewFromPointers(
        token_state_.token_start, buf_.data() + buf_.size());
    static LazyRE2 dim_labels_pattern = {
        R"([0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,})"};
    if (RE2::Consume(&consumable, *dim_labels_pattern)) {
      current_ptr_ = consumable.data();
      token_state_.str_val.assign(token_state_.token_start, current_ptr_);
      return TokKind::kDimLabels;
    }
  }

  token_state_.str_val = std::string(identifier);
  return TokKind::kIdent;
}
342
343 // Lex names after a % character.
344 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
LexPercent()345 TokKind HloLexer::LexPercent() {
346 const char* name_start = current_ptr_;
347 if (absl::ascii_isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
348 PeekCurrentChar() == '_') {
349 current_ptr_++;
350 while (IsIdentifierChar(PeekCurrentChar())) {
351 current_ptr_++;
352 }
353 token_state_.str_val.assign(name_start, current_ptr_);
354 return TokKind::kName;
355 }
356 return TokKind::kError;
357 }
358
// Lexes a token that starts with a digit, '-' or '?' (the first character was
// already consumed; token_state_.token_start points at it). Matching is
// re-run from token_start, and the alternatives are tried in the order below.
// The order matters because earlier patterns take priority over later ones:
// a floating-point value, dim labels, dxd and pad patterns must all win over
// a plain integer prefix.
//
// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
// dim_labels_pattern ::= [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,}
// dxd_pattern ::= [0-9]+(x[0-9]+)+
// pad_pattern ::=
//   [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
// int ::= [-]?[0-9]+
// negative inf ::= '-inf'
// negative nan ::= '-nan' | '-nan(0x<hex>)'
TokKind HloLexer::LexNumberOrPattern() {
  absl::string_view consumable = StringViewFromPointers(
      token_state_.token_start, buf_.data() + buf_.size());
  static LazyRE2 float_pattern = {
      R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
  if (RE2::Consume(&consumable, *float_pattern)) {
    current_ptr_ = consumable.data();
    CHECK(absl::SimpleAtod(std::string(token_state_.token_start, current_ptr_),
                           &token_state_.decimal_val));
    return TokKind::kDecimal;
  }

  static LazyRE2 dim_labels_pattern = {
      R"([0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,})"};
  static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
  static LazyRE2 pad_pattern = {
      R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};

  // RE2::Consume leaves `consumable` untouched on a failed match, so each
  // attempt below starts again at token_start.
  if (RE2::Consume(&consumable, *dim_labels_pattern)) {
    current_ptr_ = consumable.data();
    token_state_.str_val.assign(token_state_.token_start, current_ptr_);
    return TokKind::kDimLabels;
  }

  if (RE2::Consume(&consumable, *dxd_pattern)) {
    current_ptr_ = consumable.data();
    token_state_.str_val.assign(token_state_.token_start, current_ptr_);
    return TokKind::kDxD;
  }

  if (RE2::Consume(&consumable, *pad_pattern)) {
    current_ptr_ = consumable.data();
    token_state_.str_val.assign(token_state_.token_start, current_ptr_);
    return TokKind::kPad;
  }

  static LazyRE2 int_pattern = {R"([-]?\d+)"};
  if (RE2::Consume(&consumable, *int_pattern)) {
    current_ptr_ = consumable.data();
    auto slice = StringViewFromPointers(token_state_.token_start, current_ptr_);
    if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
      return TokKind::kInt;
    }
    // Values above INT64_MAX that still fit in 64 unsigned bits are accepted
    // and stored with their bit pattern reinterpreted as int64_t.
    uint64_t uint64_val;
    if (absl::SimpleAtoi(slice, &uint64_val)) {
      token_state_.int64_val = absl::bit_cast<int64_t>(uint64_val);
      return TokKind::kInt;
    }
    LOG(ERROR) << "Failed to parse int literal: " << slice;
    return TokKind::kError;
  }

  static LazyRE2 neg_inf = {"-inf"};
  if (RE2::Consume(&consumable, *neg_inf)) {
    current_ptr_ = consumable.data();
    return TokKind::kNegInf;
  }

  static LazyRE2 neg_nan = {"-nan"};
  if (RE2::Consume(&consumable, *neg_nan)) {
    current_ptr_ = consumable.data();

    // `consumable` now starts right after "-nan", matching current_ptr_.
    std::optional<int64_t> payload;
    if (PeekCurrentChar() == '(') {
      payload = LexNanPayload(consumable);
      if (!payload.has_value()) {
        return TokKind::kError;
      }
    }
    token_state_.decimal_val = NanWithSignAndPayload<double>(
        /*sign=*/true, payload.value_or(QuietNanWithoutPayload<double>()));
    return TokKind::kDecimal;
  }

  return TokKind::kError;
}
446
GetLineAndColumn(LocTy location) const447 std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
448 unsigned line_no = 1;
449 const char* start = buf_.data();
450 const char* ptr = start;
451 if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) &&
452 line_no_cache_.last_query <= location) {
453 ptr = line_no_cache_.last_query;
454 line_no = line_no_cache_.line_no_of_query;
455 }
456 for (; ptr != location; ptr++) {
457 CHECK_LT(ptr, buf_.data() + buf_.size());
458 if (*ptr == '\n') {
459 line_no++;
460 }
461 }
462
463 // Update the line number cache.
464 line_no_cache_.last_query = ptr;
465 line_no_cache_.line_no_of_query = line_no;
466 size_t line_offset = StringViewFromPointers(start, ptr).rfind('\n');
467 if (line_offset == absl::string_view::npos) {
468 line_offset = 0;
469 }
470 return {line_no, ptr - start - line_offset};
471 }
472
GetLine(LocTy loc) const473 absl::string_view HloLexer::GetLine(LocTy loc) const {
474 if (!CanDereference(loc)) {
475 return "LINE OUT OF RANGE";
476 }
477 size_t line_start = StringViewFromPointers(buf_.data(), loc + 1).rfind('\n');
478 const char* start = line_start == absl::string_view::npos
479 ? buf_.data()
480 : buf_.data() + line_start + 1;
481 size_t line_end =
482 StringViewFromPointers(loc, buf_.data() + buf_.size()).find('\n');
483 const char* end = line_end == absl::string_view::npos
484 ? buf_.data() + buf_.size()
485 : loc + line_end;
486
487 return StringViewFromPointers(start, end);
488 }
489
490 // Lexes quoted string with escaping characters. If matched, the quoted string
491 // will be unescaped and stored to token_state_.str_val.
LexString()492 TokKind HloLexer::LexString() {
493 absl::string_view consumable = StringViewFromPointers(
494 token_state_.token_start, buf_.data() + buf_.size());
495 static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
496 if (RE2::Consume(&consumable, *escaping_pattern)) {
497 current_ptr_ = consumable.data();
498 absl::string_view raw =
499 StringViewFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
500 std::string error;
501 if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
502 LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
503 return TokKind::kError;
504 }
505 return TokKind::kString;
506 }
507 return TokKind::kError;
508 }
509
510 std::string TokKindToString(TokKind kind) {
511 switch (kind) {
512 case TokKind::kEof:
513 return "kEof";
514 case TokKind::kError:
515 return "kError";
516 case TokKind::kEqual:
517 return "kEqaul";
518 case TokKind::kComma:
519 return "kComma";
520 case TokKind::kColon:
521 return "kColon";
522 case TokKind::kAsterisk:
523 return "kAsterisk";
524 case TokKind::kLsquare:
525 return "kLsquare";
526 case TokKind::kRsquare:
527 return "kRsquare";
528 case TokKind::kLbrace:
529 return "kLbrace";
530 case TokKind::kRbrace:
531 return "kRbrace";
532 case TokKind::kLparen:
533 return "kLparen";
534 case TokKind::kRparen:
535 return "kRparen";
536 case TokKind::kArrow:
537 return "kArrow";
538 case TokKind::kLeq:
539 return "kLeq";
540 case TokKind::kw_HloModule:
541 return "kw_HloModule";
542 case TokKind::kw_ENTRY:
543 return "kw_ENTRY";
544 case TokKind::kw_ROOT:
545 return "kw_ROOT";
546 case TokKind::kw_true:
547 return "kw_true";
548 case TokKind::kw_false:
549 return "kw_false";
550 case TokKind::kw_maximal:
551 return "kw_maximal";
552 case TokKind::kw_replicated:
553 return "kw_replicated";
554 case TokKind::kw_manual:
555 return "kw_manual";
556 case TokKind::kw_last_tile_dim_replicate:
557 return "kw_last_tile_dim_replicate";
558 case TokKind::kw_inf:
559 return "kw_inf";
560 case TokKind::kNegInf:
561 return "kNegInf";
562 case TokKind::kPrimitiveType:
563 return "kPrimitiveType";
564 case TokKind::kName:
565 return "kName";
566 case TokKind::kAttributeName:
567 return "kAttributeName";
568 case TokKind::kDimLabels:
569 return "kDimLabels";
570 case TokKind::kDxD:
571 return "kDxD";
572 case TokKind::kPad:
573 return "kPad";
574 case TokKind::kIdent:
575 return "kIdent";
576 case TokKind::kString:
577 return "kString";
578 case TokKind::kInt:
579 return "kInt";
580 case TokKind::kDecimal:
581 return "kDecimal";
582 case TokKind::kDots:
583 return "kDots";
584 }
585 }
586
587 } // namespace xla
588