1 //==-- llvm/Support/FileCheck.h ---------------------------*- C++ -*-==//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9
10 // API modified from llvm::FileCheck
11
12 #include <c10/util/Exception.h>
13 #include <c10/util/StringUtil.h>
14 #include <c10/util/irange.h>
15 #include <torch/csrc/Export.h>
16 #include <torch/csrc/jit/frontend/source_range.h>
17 #include <torch/csrc/jit/ir/ir.h>
18 #include <torch/csrc/jit/testing/file_check.h>
19 #include <optional>
20
21 #include <algorithm>
22 #include <iostream>
23 #include <sstream>
24 #include <string>
25
26 namespace torch {
27 namespace jit {
28
29 namespace testing {
30
// The directive kinds FileCheck understands. Each one corresponds to a
// "CHECK..." spelling accepted in a checks string and to a FileCheck::check_*
// builder method. Values are spelled out explicitly; they match the implicit
// numbering (0..7).
enum CheckType {
  // str appears at or after the end of the previous match.
  CHECK = 0,
  // str appears on the line following the previous match.
  CHECK_NEXT = 1,
  // str appears after the previous match with no newline in between.
  CHECK_SAME = 2,
  // str does not appear between the surrounding matches.
  CHECK_NOT = 3,
  // str appears count_ times, each occurrence after the previous one.
  CHECK_COUNT = 4,
  // Consecutive CHECK_DAG strings appear after the previous match, in any order.
  CHECK_DAG = 5,
  // str appears with a '~' underline directly beneath it on the next line;
  // does not advance the search position.
  CHECK_SOURCE_HIGHLIGHTED = 6,
  // str is a regular expression searched for after the previous match.
  CHECK_REGEX = 7,
};
41
42 struct Check {
Checktorch::jit::testing::Check43 Check(
44 CheckType type,
45 std::string str,
46 std::optional<size_t> count = std::nullopt)
47 : type_(type), count_(count), search_str_(std::move(str)) {}
48
Checktorch::jit::testing::Check49 Check(
50 CheckType type,
51 c10::string_view str,
52 std::optional<size_t> count = std::nullopt)
53 : Check(type, std::string(str.begin(), str.end()), count) {}
54
55 CheckType type_;
56 std::optional<size_t> count_;
57 const std::string search_str_;
58
59 friend std::ostream& operator<<(std::ostream& out, const Check& c);
60 };
61
operator <<(std::ostream & out,const Check & c)62 std::ostream& operator<<(std::ostream& out, const Check& c) {
63 switch (c.type_) {
64 case CHECK:
65 out << "CHECK";
66 break;
67 case CHECK_NEXT:
68 out << "CHECK-NEXT";
69 break;
70 case CHECK_SAME:
71 out << "CHECK-SAME";
72 break;
73 case CHECK_NOT:
74 out << "CHECK-NOT";
75 break;
76 case CHECK_DAG:
77 out << "CHECK-DAG";
78 break;
79 case CHECK_COUNT:
80 out << "CHECK-COUNT-" << *c.count_;
81 break;
82 case CHECK_SOURCE_HIGHLIGHTED:
83 out << "CHECK-SOURCE-HIGHLIGHTED";
84 break;
85 case CHECK_REGEX:
86 out << "CHECK-REGEX";
87 break;
88 }
89 out << ": " << c.search_str_;
90 return out;
91 };
92
93 namespace {
94
assertFind(const SourceRange & search_range,const std::string & sub,const std::function<void (std::ostream & out)> & extra_msg=nullptr)95 size_t assertFind(
96 const SourceRange& search_range,
97 const std::string& sub,
98 const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
99 auto pos = search_range.source()->text_str().find(sub, search_range.start());
100 if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
101 auto found_range =
102 SourceRange(search_range.source(), search_range.start(), sub.size());
103 std::stringstream ss;
104 ss << "Expected to find ";
105 c10::printQuotedString(ss, sub);
106 ss << " but did not find it" << std::endl;
107 ss << "Searched string:" << std::endl;
108 found_range.highlight(ss);
109 if (extra_msg) {
110 extra_msg(ss);
111 }
112 throw std::runtime_error(ss.str());
113 }
114 return pos;
115 }
116
assertFind(const SourceRange & search_range,const std::string & sub,const Check & check)117 size_t assertFind(
118 const SourceRange& search_range,
119 const std::string& sub,
120 const Check& check) {
121 return assertFind(search_range, sub, [&](std::ostream& out) {
122 out << "From " << check << "\n";
123 });
124 }
125
assertFind(const std::shared_ptr<Source> & source,const std::string & sub,size_t start,const Check & check)126 size_t assertFind(
127 const std::shared_ptr<Source>& source,
128 const std::string& sub,
129 size_t start,
130 const Check& check) {
131 return assertFind(SourceRange(source, start, source->size()), sub, check);
132 }
133
assertFindRegex(const SourceRange & search_range,const std::string & sub,const std::function<void (std::ostream & out)> & extra_msg=nullptr)134 size_t assertFindRegex(
135 const SourceRange& search_range,
136 const std::string& sub,
137 const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
138 auto pos =
139 search_range.source()->text_str().find_regex(sub, search_range.start());
140
141 if (pos == std::string::npos) {
142 std::stringstream ss;
143 ss << "Expected to find regex ";
144 c10::printQuotedString(ss, sub);
145 ss << " but did not find it" << std::endl;
146 ss << "Searched string:" << std::endl;
147 if (extra_msg) {
148 extra_msg(ss);
149 }
150 throw std::runtime_error(ss.str());
151
152 return std::string::npos;
153 }
154 return pos;
155 }
156
assertFindRegex(const SourceRange & search_range,const std::string & sub,const Check & check)157 size_t assertFindRegex(
158 const SourceRange& search_range,
159 const std::string& sub,
160 const Check& check) {
161 return assertFindRegex(search_range, sub, [&](std::ostream& out) {
162 out << "From " << check << "\n";
163 });
164 }
165
assertFindRegex(const std::shared_ptr<Source> & source,const std::string & sub,size_t start,const Check & check)166 size_t assertFindRegex(
167 const std::shared_ptr<Source>& source,
168 const std::string& sub,
169 size_t start,
170 const Check& check) {
171 return assertFindRegex(
172 SourceRange(source, start, source->size()), sub, check);
173 }
174
assertNotFind(const SourceRange & search_range,const std::string & sub,const Check & check)175 void assertNotFind(
176 const SourceRange& search_range,
177 const std::string& sub,
178 const Check& check) {
179 auto pos = search_range.source()->text_str().find(sub, search_range.start());
180 if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
181 auto found_range =
182 SourceRange(search_range.source(), pos, sub.size() + pos);
183 std::stringstream ss;
184 ss << "Expected to not find ";
185 c10::printQuotedString(ss, sub);
186 ss << " but found it\n";
187 found_range.highlight(ss);
188 ss << "From " << check << "\n";
189 throw std::runtime_error(ss.str());
190 }
191 }
192
193 } // namespace
194
195 struct FileCheckImpl {
196 TORCH_API explicit FileCheckImpl() = default;
197
runtorch::jit::testing::FileCheckImpl198 TORCH_API void run(const std::string& test_file) {
199 has_run = true;
200
201 if (groups.empty() || groups[0].empty()) {
202 throw std::runtime_error(
203 "No checks have been added to this instance of"
204 "Filecheck! Check for bad input.");
205 }
206
207 doChecks(std::make_shared<Source>(test_file));
208 }
209
runtorch::jit::testing::FileCheckImpl210 TORCH_API void run(
211 const std::string& checks_file,
212 const std::string& test_file) {
213 auto source = std::make_shared<Source>(checks_file);
214 parseStrings(source);
215 run(test_file);
216 }
217
addChecktorch::jit::testing::FileCheckImpl218 TORCH_API void addCheck(const Check& check) {
219 // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
220 if (groups.empty() ||
221 (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) {
222 groups.push_back({check});
223 } else {
224 auto& last_group = groups.back();
225 if (last_group.at(0).type_ == check.type_) {
226 last_group.push_back(check);
227 } else {
228 groups.push_back({check});
229 }
230 }
231 has_run = false;
232 }
233
addChecktorch::jit::testing::FileCheckImpl234 TORCH_API void addCheck(
235 CheckType type,
236 const std::string& s,
237 std::optional<size_t> count = std::nullopt) {
238 addCheck(Check(type, s, count));
239 }
240
241 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
242 bool has_run = false;
243
244 friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
245
246 private:
parseSingleChecktorch::jit::testing::FileCheckImpl247 bool parseSingleCheck(const std::shared_ptr<Source>& source, size_t* start) {
248 const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
249 {CHECK, ": "},
250 {CHECK_NEXT, "-NEXT: "},
251 {CHECK_SAME, "-SAME: "},
252 {CHECK_NOT, "-NOT: "},
253 {CHECK_DAG, "-DAG: "},
254 {CHECK_COUNT, "-COUNT-"}, // needs special parsing
255 {CHECK_SOURCE_HIGHLIGHTED, "-SOURCE-HIGHLIGHTED: "},
256 {CHECK_REGEX, "-REGEX: "},
257 };
258
259 for (const auto& check_pair : check_pairs) {
260 const std::string& check_suffix = check_pair.second;
261 auto suffix_pos = source->text_str().find(check_suffix, *start);
262 if (suffix_pos != *start) {
263 continue;
264 }
265 size_t end_check_string = suffix_pos + check_suffix.size();
266 CheckType type = check_pair.first;
267 std::optional<size_t> count = std::nullopt;
268 auto end_line = source->text_str().find("\n", end_check_string);
269 bool exactly = false;
270 if (type == CHECK_COUNT) {
271 const std::string exact = "EXACTLY-";
272 if (source->text_str().find(exact, end_check_string) ==
273 end_check_string) {
274 exactly = true;
275 end_check_string += exact.size();
276 }
277 size_t end =
278 assertFind(SourceRange(source, end_check_string, end_line), ":");
279 auto count_view = source->text_str()
280 .substr(end_check_string, end - end_check_string)
281 .str();
282 count = std::stoll(std::string(count_view.begin(), count_view.end()));
283 end_check_string = end + 2; // add ':' and the space
284 }
285 auto check = Check(
286 type,
287 source->text_str()
288 .substr(end_check_string, end_line - end_check_string)
289 .str(),
290 count);
291 addCheck(check);
292 if (exactly) {
293 addCheck(CHECK_NOT, check.search_str_);
294 }
295 *start = end_line;
296 return true;
297 }
298 return false;
299 }
300
findNextStarttorch::jit::testing::FileCheckImpl301 size_t findNextStart(const std::shared_ptr<Source>& source, size_t prev_end) {
302 size_t start = source->text_str().find("#", prev_end);
303 if (start == std::string::npos) {
304 return start;
305 }
306 start += 1;
307 static constexpr size_t max_whitespace = 6;
308 size_t i = 0;
309 while (start + i < source->size() && i < max_whitespace) {
310 auto c = source->char_at(start + i);
311 if (c != ' ' && c != '\t') {
312 break;
313 }
314 i++;
315 }
316 static const std::string check = "CHECK";
317 if (source->text_str().substr(start + i, check.size()) == check) {
318 return start + i + check.size();
319 } else {
320 return findNextStart(source, start + i + 1);
321 }
322 }
323
parseStringstorch::jit::testing::FileCheckImpl324 void parseStrings(const std::shared_ptr<Source>& source) {
325 size_t start = 0;
326 start = findNextStart(source, 0);
327 while (start != std::string::npos) {
328 bool found_match = parseSingleCheck(source, &start);
329 if (!found_match) {
330 std::ostringstream ss;
331 ss << "Could not parse check at:\n";
332 SourceRange(source, start, start + 1).highlight(ss);
333 ss << "Check for bad input.";
334 has_run = true;
335 throw std::runtime_error(ss.str());
336 }
337 start = findNextStart(source, start);
338 }
339 }
340
doCheckNottorch::jit::testing::FileCheckImpl341 void doCheckNot(
342 const std::vector<Check>& nots,
343 const std::shared_ptr<Source>& source,
344 const SourceRange& prev,
345 const SourceRange& next) {
346 auto start = prev.end(); // inclusive
347 auto end = next.start(); // exclusive
348 if (end < start) {
349 return;
350 }
351 for (const auto& check : nots) {
352 AT_ASSERT(check.type_ == CHECK_NOT);
353 assertNotFind(SourceRange(source, start, end), check.search_str_, check);
354 }
355 }
356
357 // Checks that source token is highlighted, does not advance search range.
doCheckSourceHighlightedtorch::jit::testing::FileCheckImpl358 void doCheckSourceHighlighted(
359 const Check& check,
360 const std::shared_ptr<Source>& source,
361 size_t start_offset) {
362 auto construct_error_and_throw = [&](size_t error_start_pos) {
363 SourceRange error_range(
364 source, error_start_pos, check.search_str_.size());
365 std::stringstream ss;
366 ss << "Expected to find ";
367 c10::printQuotedString(ss, check.search_str_);
368 ss << "highlighted but it is not." << std::endl;
369 error_range.highlight(ss);
370 throw std::runtime_error(ss.str());
371 };
372
373 size_t search_start_offset = start_offset;
374 bool found_token_at_least_once = false;
375 size_t pos = search_start_offset;
376 while (pos < source->size()) {
377 pos = source->text_str().find(check.search_str_, search_start_offset);
378 if (pos == std::string::npos) {
379 break;
380 }
381
382 found_token_at_least_once = true;
383
384 auto lineno = source->lineno_for_offset(pos);
385 auto col = pos - source->offset_for_line(lineno);
386 auto highlight_lineno = lineno + 1;
387
388 if (highlight_lineno >= source->num_lines()) {
389 construct_error_and_throw(pos);
390 }
391
392 auto highlight_start_offset =
393 source->offset_for_line(highlight_lineno) + col;
394 auto highlight_end_offset = std::min(
395 highlight_start_offset + check.search_str_.size(), source->size());
396
397 if (highlight_end_offset >= source->size()) {
398 construct_error_and_throw(pos);
399 }
400
401 bool found_highlight = true;
402 for (const auto posi :
403 c10::irange(highlight_start_offset, highlight_end_offset)) {
404 if (source->char_at(posi) != '~') {
405 found_highlight = false;
406 }
407 }
408
409 if (found_highlight) {
410 assertNotFind(
411 SourceRange(
412 source, highlight_start_offset - 1, highlight_start_offset),
413 "~",
414 check);
415 assertNotFind(
416 SourceRange(source, highlight_end_offset, highlight_end_offset + 1),
417 "~",
418 check);
419 return;
420 }
421
422 search_start_offset = pos + 1;
423 }
424
425 if (!found_token_at_least_once) {
426 // Guaranteed to fail to generate error message.
427 assertFind(source, check.search_str_, start_offset, check);
428 }
429
430 construct_error_and_throw(start_offset);
431 }
432
matchDagGrouptorch::jit::testing::FileCheckImpl433 SourceRange matchDagGroup(
434 const std::vector<Check>& group,
435 const std::shared_ptr<Source>& source,
436 const SourceRange& prev) {
437 size_t group_beg = std::string::npos;
438 size_t group_end = 0;
439
440 AT_ASSERT(!groups.empty());
441 for (const auto& check : group) {
442 AT_ASSERT(check.type_ == group[0].type_);
443 auto pos = assertFind(source, check.search_str_, prev.end(), check);
444 group_beg = std::min(pos, group_beg);
445 group_end = std::max(pos + check.search_str_.size(), group_end);
446 }
447
448 return SourceRange(source, group_beg, group_end);
449 }
450
matchGrouptorch::jit::testing::FileCheckImpl451 SourceRange matchGroup(
452 const std::vector<Check>& group,
453 const std::shared_ptr<Source>& source,
454 const SourceRange& prev) {
455 AT_ASSERT(!group.empty());
456 CheckType type = group[0].type_;
457
458 if (type == CHECK_DAG) {
459 return matchDagGroup(group, source, prev);
460 }
461 AT_ASSERT(type != CHECK_NOT);
462 AT_ASSERT(group.size() == 1);
463
464 const auto& check = group[0];
465 size_t start_range = prev.end();
466 size_t end_range = start_range;
467
468 switch (check.type_) {
469 case CHECK: {
470 start_range = assertFind(source, check.search_str_, start_range, check);
471 end_range = start_range + check.search_str_.size();
472 } break;
473 case CHECK_SAME: {
474 auto pos = assertFind(source, check.search_str_, start_range, check);
475 assertNotFind(SourceRange(source, prev.end(), pos), "\n", check);
476 start_range = pos;
477 end_range = pos + check.search_str_.size();
478 } break;
479 case CHECK_NEXT: {
480 auto line_end = assertFind(source, "\n", start_range, check);
481 auto pos = assertFind(source, check.search_str_, line_end + 1, check);
482 assertNotFind(SourceRange(source, line_end + 1, pos), "\n", check);
483 start_range = pos;
484 end_range = pos + check.search_str_.size();
485 } break;
486 case CHECK_COUNT: {
487 auto group_start_range = std::string::npos;
488 AT_ASSERT(check.count_ && *check.count_ != 0);
489 for (size_t i = 0; i < *check.count_; ++i) {
490 start_range =
491 assertFind(source, check.search_str_, start_range, check);
492 group_start_range = std::min(start_range, group_start_range);
493 end_range = start_range + check.search_str_.size();
494 start_range = end_range;
495 }
496 start_range = group_start_range;
497 } break;
498 case CHECK_SOURCE_HIGHLIGHTED: {
499 doCheckSourceHighlighted(check, source, start_range);
500 break;
501 }
502 case CHECK_REGEX: {
503 start_range =
504 assertFindRegex(source, check.search_str_, start_range, check);
505 end_range = start_range + check.search_str_.size();
506 break;
507 }
508 case CHECK_DAG: {
509 AT_ERROR();
510 } break;
511 case CHECK_NOT: {
512 AT_ERROR();
513 } break;
514 }
515 return SourceRange(source, start_range, end_range);
516 }
517
doCheckstorch::jit::testing::FileCheckImpl518 void doChecks(const std::shared_ptr<Source>& source) {
519 SourceRange prev(source, 0, 0);
520 for (size_t i = 0; i < groups.size(); i++) {
521 const auto& curr_group = groups[i];
522 CheckType type = curr_group.at(0).type_;
523 if (type != CHECK_NOT) {
524 prev = matchGroup(curr_group, source, prev);
525 } else {
526 if (i + 1 < groups.size()) {
527 const auto& next_group = groups[i + 1];
528 AT_ASSERT(next_group.at(0).type_ != CHECK_NOT);
529 SourceRange after_not = matchGroup(next_group, source, prev);
530 doCheckNot(curr_group, source, prev, after_not);
531 prev = after_not;
532 ++i; // already checked the group after
533 } else {
534 SourceRange end_of_file(
535 source, source->size() + 1, source->size() + 1);
536 doCheckNot(curr_group, source, prev, end_of_file);
537 }
538 }
539 }
540 }
541
542 std::vector<Check> checks;
543 std::vector<std::vector<Check>> groups;
544 };
545
FileCheck()546 FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};
547
operator <<(std::ostream & out,const FileCheckImpl & fc)548 std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) {
549 out << "FileCheck checks:\n";
550 for (const Check& c : fc.checks) {
551 out << "\t" << c << "\n";
552 }
553 return out;
554 };
555
~FileCheck()556 FileCheck::~FileCheck() {
557 if (!fcImpl->has_run) {
558 std::cout << "You have not run this instance of FileCheck!\n";
559 std::cout << *fcImpl;
560 }
561 fcImpl.reset();
562 };
563
run(const std::string & test_file)564 void FileCheck::run(const std::string& test_file) {
565 fcImpl->run(test_file);
566 };
567
run(const Graph & graph)568 void FileCheck::run(const Graph& graph) {
569 std::stringstream graph_str;
570 graph_str << graph;
571 fcImpl->run(graph_str.str());
572 };
573
run(const std::string & input_checks_string,const std::string & test_string)574 void FileCheck::run(
575 const std::string& input_checks_string,
576 const std::string& test_string) {
577 fcImpl->run(input_checks_string, test_string);
578 }
579
run(const std::string & input_checks_string,const Graph & graph)580 void FileCheck::run(
581 const std::string& input_checks_string,
582 const Graph& graph) {
583 std::stringstream graph_str;
584 graph_str << graph;
585 fcImpl->run(input_checks_string, graph_str.str());
586 }
587
check(const std::string & str)588 FileCheck* FileCheck::check(const std::string& str) {
589 fcImpl->addCheck(CHECK, str);
590 return this;
591 }
592
check_not(const std::string & str)593 FileCheck* FileCheck::check_not(const std::string& str) {
594 fcImpl->addCheck(CHECK_NOT, str);
595 return this;
596 }
597
check_same(const std::string & str)598 FileCheck* FileCheck::check_same(const std::string& str) {
599 fcImpl->addCheck(CHECK_SAME, str);
600 return this;
601 }
602
check_next(const std::string & str)603 FileCheck* FileCheck::check_next(const std::string& str) {
604 fcImpl->addCheck(CHECK_NEXT, str);
605 return this;
606 }
607
check_count(const std::string & str,size_t count,bool exactly)608 FileCheck* FileCheck::check_count(
609 const std::string& str,
610 size_t count,
611 bool exactly) {
612 TORCH_INTERNAL_ASSERT(
613 count != 0 || exactly, "Count == 0 && !exactly doesn't do anything");
614 if (count) {
615 fcImpl->addCheck(CHECK_COUNT, str, count);
616 }
617 if (exactly) {
618 fcImpl->addCheck(CHECK_NOT, str);
619 }
620 return this;
621 }
622
check_dag(const std::string & str)623 FileCheck* FileCheck::check_dag(const std::string& str) {
624 fcImpl->addCheck(CHECK_DAG, str);
625 return this;
626 }
627
check_source_highlighted(const std::string & str)628 FileCheck* FileCheck::check_source_highlighted(const std::string& str) {
629 fcImpl->addCheck(CHECK_SOURCE_HIGHLIGHTED, str);
630 return this;
631 }
632
check_regex(const std::string & str)633 FileCheck* FileCheck::check_regex(const std::string& str) {
634 fcImpl->addCheck(CHECK_REGEX, str);
635 return this;
636 }
637
638 } // namespace testing
639 } // namespace jit
640 } // namespace torch
641