xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/testing/file_check.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //==-- llvm/Support/FileCheck.h ---------------------------*- C++ -*-==//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 // API modified from llvm::FileCheck
11 
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/testing/file_check.h>

#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
25 
26 namespace torch {
27 namespace jit {
28 
29 namespace testing {
30 
// Kinds of directives FileCheck understands. The enumerator values are listed
// explicitly; they follow declaration order exactly as before.
enum CheckType {
  // "CHECK": the text appears somewhere after the previous match.
  CHECK = 0,
  // "CHECK-NEXT": the text appears on the line after the previous match.
  CHECK_NEXT = 1,
  // "CHECK-SAME": the text appears on the same line as the previous match.
  CHECK_SAME = 2,
  // "CHECK-NOT": the text must not appear between the surrounding matches.
  CHECK_NOT = 3,
  // "CHECK-COUNT-N": the text appears N times after the previous match.
  CHECK_COUNT = 4,
  // "CHECK-DAG": consecutive DAG checks may match in any order.
  CHECK_DAG = 5,
  // "CHECK-SOURCE-HIGHLIGHTED": an occurrence is underlined with '~'.
  CHECK_SOURCE_HIGHLIGHTED = 6,
  // "CHECK-REGEX": a regular expression is found after the previous match.
  CHECK_REGEX = 7,
};
41 
42 struct Check {
Checktorch::jit::testing::Check43   Check(
44       CheckType type,
45       std::string str,
46       std::optional<size_t> count = std::nullopt)
47       : type_(type), count_(count), search_str_(std::move(str)) {}
48 
Checktorch::jit::testing::Check49   Check(
50       CheckType type,
51       c10::string_view str,
52       std::optional<size_t> count = std::nullopt)
53       : Check(type, std::string(str.begin(), str.end()), count) {}
54 
55   CheckType type_;
56   std::optional<size_t> count_;
57   const std::string search_str_;
58 
59   friend std::ostream& operator<<(std::ostream& out, const Check& c);
60 };
61 
operator <<(std::ostream & out,const Check & c)62 std::ostream& operator<<(std::ostream& out, const Check& c) {
63   switch (c.type_) {
64     case CHECK:
65       out << "CHECK";
66       break;
67     case CHECK_NEXT:
68       out << "CHECK-NEXT";
69       break;
70     case CHECK_SAME:
71       out << "CHECK-SAME";
72       break;
73     case CHECK_NOT:
74       out << "CHECK-NOT";
75       break;
76     case CHECK_DAG:
77       out << "CHECK-DAG";
78       break;
79     case CHECK_COUNT:
80       out << "CHECK-COUNT-" << *c.count_;
81       break;
82     case CHECK_SOURCE_HIGHLIGHTED:
83       out << "CHECK-SOURCE-HIGHLIGHTED";
84       break;
85     case CHECK_REGEX:
86       out << "CHECK-REGEX";
87       break;
88   }
89   out << ": " << c.search_str_;
90   return out;
91 };
92 
93 namespace {
94 
assertFind(const SourceRange & search_range,const std::string & sub,const std::function<void (std::ostream & out)> & extra_msg=nullptr)95 size_t assertFind(
96     const SourceRange& search_range,
97     const std::string& sub,
98     const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
99   auto pos = search_range.source()->text_str().find(sub, search_range.start());
100   if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
101     auto found_range =
102         SourceRange(search_range.source(), search_range.start(), sub.size());
103     std::stringstream ss;
104     ss << "Expected to find ";
105     c10::printQuotedString(ss, sub);
106     ss << " but did not find it" << std::endl;
107     ss << "Searched string:" << std::endl;
108     found_range.highlight(ss);
109     if (extra_msg) {
110       extra_msg(ss);
111     }
112     throw std::runtime_error(ss.str());
113   }
114   return pos;
115 }
116 
assertFind(const SourceRange & search_range,const std::string & sub,const Check & check)117 size_t assertFind(
118     const SourceRange& search_range,
119     const std::string& sub,
120     const Check& check) {
121   return assertFind(search_range, sub, [&](std::ostream& out) {
122     out << "From " << check << "\n";
123   });
124 }
125 
assertFind(const std::shared_ptr<Source> & source,const std::string & sub,size_t start,const Check & check)126 size_t assertFind(
127     const std::shared_ptr<Source>& source,
128     const std::string& sub,
129     size_t start,
130     const Check& check) {
131   return assertFind(SourceRange(source, start, source->size()), sub, check);
132 }
133 
assertFindRegex(const SourceRange & search_range,const std::string & sub,const std::function<void (std::ostream & out)> & extra_msg=nullptr)134 size_t assertFindRegex(
135     const SourceRange& search_range,
136     const std::string& sub,
137     const std::function<void(std::ostream& out)>& extra_msg = nullptr) {
138   auto pos =
139       search_range.source()->text_str().find_regex(sub, search_range.start());
140 
141   if (pos == std::string::npos) {
142     std::stringstream ss;
143     ss << "Expected to find regex ";
144     c10::printQuotedString(ss, sub);
145     ss << " but did not find it" << std::endl;
146     ss << "Searched string:" << std::endl;
147     if (extra_msg) {
148       extra_msg(ss);
149     }
150     throw std::runtime_error(ss.str());
151 
152     return std::string::npos;
153   }
154   return pos;
155 }
156 
assertFindRegex(const SourceRange & search_range,const std::string & sub,const Check & check)157 size_t assertFindRegex(
158     const SourceRange& search_range,
159     const std::string& sub,
160     const Check& check) {
161   return assertFindRegex(search_range, sub, [&](std::ostream& out) {
162     out << "From " << check << "\n";
163   });
164 }
165 
assertFindRegex(const std::shared_ptr<Source> & source,const std::string & sub,size_t start,const Check & check)166 size_t assertFindRegex(
167     const std::shared_ptr<Source>& source,
168     const std::string& sub,
169     size_t start,
170     const Check& check) {
171   return assertFindRegex(
172       SourceRange(source, start, source->size()), sub, check);
173 }
174 
assertNotFind(const SourceRange & search_range,const std::string & sub,const Check & check)175 void assertNotFind(
176     const SourceRange& search_range,
177     const std::string& sub,
178     const Check& check) {
179   auto pos = search_range.source()->text_str().find(sub, search_range.start());
180   if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
181     auto found_range =
182         SourceRange(search_range.source(), pos, sub.size() + pos);
183     std::stringstream ss;
184     ss << "Expected to not find ";
185     c10::printQuotedString(ss, sub);
186     ss << " but found it\n";
187     found_range.highlight(ss);
188     ss << "From " << check << "\n";
189     throw std::runtime_error(ss.str());
190   }
191 }
192 
193 } // namespace
194 
195 struct FileCheckImpl {
196   TORCH_API explicit FileCheckImpl() = default;
197 
runtorch::jit::testing::FileCheckImpl198   TORCH_API void run(const std::string& test_file) {
199     has_run = true;
200 
201     if (groups.empty() || groups[0].empty()) {
202       throw std::runtime_error(
203           "No checks have been added to this instance of"
204           "Filecheck! Check for bad input.");
205     }
206 
207     doChecks(std::make_shared<Source>(test_file));
208   }
209 
runtorch::jit::testing::FileCheckImpl210   TORCH_API void run(
211       const std::string& checks_file,
212       const std::string& test_file) {
213     auto source = std::make_shared<Source>(checks_file);
214     parseStrings(source);
215     run(test_file);
216   }
217 
addChecktorch::jit::testing::FileCheckImpl218   TORCH_API void addCheck(const Check& check) {
219     // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
220     if (groups.empty() ||
221         (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) {
222       groups.push_back({check});
223     } else {
224       auto& last_group = groups.back();
225       if (last_group.at(0).type_ == check.type_) {
226         last_group.push_back(check);
227       } else {
228         groups.push_back({check});
229       }
230     }
231     has_run = false;
232   }
233 
addChecktorch::jit::testing::FileCheckImpl234   TORCH_API void addCheck(
235       CheckType type,
236       const std::string& s,
237       std::optional<size_t> count = std::nullopt) {
238     addCheck(Check(type, s, count));
239   }
240 
241   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
242   bool has_run = false;
243 
244   friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
245 
246  private:
parseSingleChecktorch::jit::testing::FileCheckImpl247   bool parseSingleCheck(const std::shared_ptr<Source>& source, size_t* start) {
248     const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
249         {CHECK, ": "},
250         {CHECK_NEXT, "-NEXT: "},
251         {CHECK_SAME, "-SAME: "},
252         {CHECK_NOT, "-NOT: "},
253         {CHECK_DAG, "-DAG: "},
254         {CHECK_COUNT, "-COUNT-"}, // needs special parsing
255         {CHECK_SOURCE_HIGHLIGHTED, "-SOURCE-HIGHLIGHTED: "},
256         {CHECK_REGEX, "-REGEX: "},
257     };
258 
259     for (const auto& check_pair : check_pairs) {
260       const std::string& check_suffix = check_pair.second;
261       auto suffix_pos = source->text_str().find(check_suffix, *start);
262       if (suffix_pos != *start) {
263         continue;
264       }
265       size_t end_check_string = suffix_pos + check_suffix.size();
266       CheckType type = check_pair.first;
267       std::optional<size_t> count = std::nullopt;
268       auto end_line = source->text_str().find("\n", end_check_string);
269       bool exactly = false;
270       if (type == CHECK_COUNT) {
271         const std::string exact = "EXACTLY-";
272         if (source->text_str().find(exact, end_check_string) ==
273             end_check_string) {
274           exactly = true;
275           end_check_string += exact.size();
276         }
277         size_t end =
278             assertFind(SourceRange(source, end_check_string, end_line), ":");
279         auto count_view = source->text_str()
280                               .substr(end_check_string, end - end_check_string)
281                               .str();
282         count = std::stoll(std::string(count_view.begin(), count_view.end()));
283         end_check_string = end + 2; // add ':' and the space
284       }
285       auto check = Check(
286           type,
287           source->text_str()
288               .substr(end_check_string, end_line - end_check_string)
289               .str(),
290           count);
291       addCheck(check);
292       if (exactly) {
293         addCheck(CHECK_NOT, check.search_str_);
294       }
295       *start = end_line;
296       return true;
297     }
298     return false;
299   }
300 
findNextStarttorch::jit::testing::FileCheckImpl301   size_t findNextStart(const std::shared_ptr<Source>& source, size_t prev_end) {
302     size_t start = source->text_str().find("#", prev_end);
303     if (start == std::string::npos) {
304       return start;
305     }
306     start += 1;
307     static constexpr size_t max_whitespace = 6;
308     size_t i = 0;
309     while (start + i < source->size() && i < max_whitespace) {
310       auto c = source->char_at(start + i);
311       if (c != ' ' && c != '\t') {
312         break;
313       }
314       i++;
315     }
316     static const std::string check = "CHECK";
317     if (source->text_str().substr(start + i, check.size()) == check) {
318       return start + i + check.size();
319     } else {
320       return findNextStart(source, start + i + 1);
321     }
322   }
323 
parseStringstorch::jit::testing::FileCheckImpl324   void parseStrings(const std::shared_ptr<Source>& source) {
325     size_t start = 0;
326     start = findNextStart(source, 0);
327     while (start != std::string::npos) {
328       bool found_match = parseSingleCheck(source, &start);
329       if (!found_match) {
330         std::ostringstream ss;
331         ss << "Could not parse check at:\n";
332         SourceRange(source, start, start + 1).highlight(ss);
333         ss << "Check for bad input.";
334         has_run = true;
335         throw std::runtime_error(ss.str());
336       }
337       start = findNextStart(source, start);
338     }
339   }
340 
doCheckNottorch::jit::testing::FileCheckImpl341   void doCheckNot(
342       const std::vector<Check>& nots,
343       const std::shared_ptr<Source>& source,
344       const SourceRange& prev,
345       const SourceRange& next) {
346     auto start = prev.end(); // inclusive
347     auto end = next.start(); // exclusive
348     if (end < start) {
349       return;
350     }
351     for (const auto& check : nots) {
352       AT_ASSERT(check.type_ == CHECK_NOT);
353       assertNotFind(SourceRange(source, start, end), check.search_str_, check);
354     }
355   }
356 
357   // Checks that source token is highlighted, does not advance search range.
doCheckSourceHighlightedtorch::jit::testing::FileCheckImpl358   void doCheckSourceHighlighted(
359       const Check& check,
360       const std::shared_ptr<Source>& source,
361       size_t start_offset) {
362     auto construct_error_and_throw = [&](size_t error_start_pos) {
363       SourceRange error_range(
364           source, error_start_pos, check.search_str_.size());
365       std::stringstream ss;
366       ss << "Expected to find ";
367       c10::printQuotedString(ss, check.search_str_);
368       ss << "highlighted but it is not." << std::endl;
369       error_range.highlight(ss);
370       throw std::runtime_error(ss.str());
371     };
372 
373     size_t search_start_offset = start_offset;
374     bool found_token_at_least_once = false;
375     size_t pos = search_start_offset;
376     while (pos < source->size()) {
377       pos = source->text_str().find(check.search_str_, search_start_offset);
378       if (pos == std::string::npos) {
379         break;
380       }
381 
382       found_token_at_least_once = true;
383 
384       auto lineno = source->lineno_for_offset(pos);
385       auto col = pos - source->offset_for_line(lineno);
386       auto highlight_lineno = lineno + 1;
387 
388       if (highlight_lineno >= source->num_lines()) {
389         construct_error_and_throw(pos);
390       }
391 
392       auto highlight_start_offset =
393           source->offset_for_line(highlight_lineno) + col;
394       auto highlight_end_offset = std::min(
395           highlight_start_offset + check.search_str_.size(), source->size());
396 
397       if (highlight_end_offset >= source->size()) {
398         construct_error_and_throw(pos);
399       }
400 
401       bool found_highlight = true;
402       for (const auto posi :
403            c10::irange(highlight_start_offset, highlight_end_offset)) {
404         if (source->char_at(posi) != '~') {
405           found_highlight = false;
406         }
407       }
408 
409       if (found_highlight) {
410         assertNotFind(
411             SourceRange(
412                 source, highlight_start_offset - 1, highlight_start_offset),
413             "~",
414             check);
415         assertNotFind(
416             SourceRange(source, highlight_end_offset, highlight_end_offset + 1),
417             "~",
418             check);
419         return;
420       }
421 
422       search_start_offset = pos + 1;
423     }
424 
425     if (!found_token_at_least_once) {
426       // Guaranteed to fail to generate error message.
427       assertFind(source, check.search_str_, start_offset, check);
428     }
429 
430     construct_error_and_throw(start_offset);
431   }
432 
matchDagGrouptorch::jit::testing::FileCheckImpl433   SourceRange matchDagGroup(
434       const std::vector<Check>& group,
435       const std::shared_ptr<Source>& source,
436       const SourceRange& prev) {
437     size_t group_beg = std::string::npos;
438     size_t group_end = 0;
439 
440     AT_ASSERT(!groups.empty());
441     for (const auto& check : group) {
442       AT_ASSERT(check.type_ == group[0].type_);
443       auto pos = assertFind(source, check.search_str_, prev.end(), check);
444       group_beg = std::min(pos, group_beg);
445       group_end = std::max(pos + check.search_str_.size(), group_end);
446     }
447 
448     return SourceRange(source, group_beg, group_end);
449   }
450 
matchGrouptorch::jit::testing::FileCheckImpl451   SourceRange matchGroup(
452       const std::vector<Check>& group,
453       const std::shared_ptr<Source>& source,
454       const SourceRange& prev) {
455     AT_ASSERT(!group.empty());
456     CheckType type = group[0].type_;
457 
458     if (type == CHECK_DAG) {
459       return matchDagGroup(group, source, prev);
460     }
461     AT_ASSERT(type != CHECK_NOT);
462     AT_ASSERT(group.size() == 1);
463 
464     const auto& check = group[0];
465     size_t start_range = prev.end();
466     size_t end_range = start_range;
467 
468     switch (check.type_) {
469       case CHECK: {
470         start_range = assertFind(source, check.search_str_, start_range, check);
471         end_range = start_range + check.search_str_.size();
472       } break;
473       case CHECK_SAME: {
474         auto pos = assertFind(source, check.search_str_, start_range, check);
475         assertNotFind(SourceRange(source, prev.end(), pos), "\n", check);
476         start_range = pos;
477         end_range = pos + check.search_str_.size();
478       } break;
479       case CHECK_NEXT: {
480         auto line_end = assertFind(source, "\n", start_range, check);
481         auto pos = assertFind(source, check.search_str_, line_end + 1, check);
482         assertNotFind(SourceRange(source, line_end + 1, pos), "\n", check);
483         start_range = pos;
484         end_range = pos + check.search_str_.size();
485       } break;
486       case CHECK_COUNT: {
487         auto group_start_range = std::string::npos;
488         AT_ASSERT(check.count_ && *check.count_ != 0);
489         for (size_t i = 0; i < *check.count_; ++i) {
490           start_range =
491               assertFind(source, check.search_str_, start_range, check);
492           group_start_range = std::min(start_range, group_start_range);
493           end_range = start_range + check.search_str_.size();
494           start_range = end_range;
495         }
496         start_range = group_start_range;
497       } break;
498       case CHECK_SOURCE_HIGHLIGHTED: {
499         doCheckSourceHighlighted(check, source, start_range);
500         break;
501       }
502       case CHECK_REGEX: {
503         start_range =
504             assertFindRegex(source, check.search_str_, start_range, check);
505         end_range = start_range + check.search_str_.size();
506         break;
507       }
508       case CHECK_DAG: {
509         AT_ERROR();
510       } break;
511       case CHECK_NOT: {
512         AT_ERROR();
513       } break;
514     }
515     return SourceRange(source, start_range, end_range);
516   }
517 
doCheckstorch::jit::testing::FileCheckImpl518   void doChecks(const std::shared_ptr<Source>& source) {
519     SourceRange prev(source, 0, 0);
520     for (size_t i = 0; i < groups.size(); i++) {
521       const auto& curr_group = groups[i];
522       CheckType type = curr_group.at(0).type_;
523       if (type != CHECK_NOT) {
524         prev = matchGroup(curr_group, source, prev);
525       } else {
526         if (i + 1 < groups.size()) {
527           const auto& next_group = groups[i + 1];
528           AT_ASSERT(next_group.at(0).type_ != CHECK_NOT);
529           SourceRange after_not = matchGroup(next_group, source, prev);
530           doCheckNot(curr_group, source, prev, after_not);
531           prev = after_not;
532           ++i; // already checked the group after
533         } else {
534           SourceRange end_of_file(
535               source, source->size() + 1, source->size() + 1);
536           doCheckNot(curr_group, source, prev, end_of_file);
537         }
538       }
539     }
540   }
541 
542   std::vector<Check> checks;
543   std::vector<std::vector<Check>> groups;
544 };
545 
FileCheck()546 FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};
547 
operator <<(std::ostream & out,const FileCheckImpl & fc)548 std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) {
549   out << "FileCheck checks:\n";
550   for (const Check& c : fc.checks) {
551     out << "\t" << c << "\n";
552   }
553   return out;
554 };
555 
~FileCheck()556 FileCheck::~FileCheck() {
557   if (!fcImpl->has_run) {
558     std::cout << "You have not run this instance of FileCheck!\n";
559     std::cout << *fcImpl;
560   }
561   fcImpl.reset();
562 };
563 
run(const std::string & test_file)564 void FileCheck::run(const std::string& test_file) {
565   fcImpl->run(test_file);
566 };
567 
run(const Graph & graph)568 void FileCheck::run(const Graph& graph) {
569   std::stringstream graph_str;
570   graph_str << graph;
571   fcImpl->run(graph_str.str());
572 };
573 
run(const std::string & input_checks_string,const std::string & test_string)574 void FileCheck::run(
575     const std::string& input_checks_string,
576     const std::string& test_string) {
577   fcImpl->run(input_checks_string, test_string);
578 }
579 
run(const std::string & input_checks_string,const Graph & graph)580 void FileCheck::run(
581     const std::string& input_checks_string,
582     const Graph& graph) {
583   std::stringstream graph_str;
584   graph_str << graph;
585   fcImpl->run(input_checks_string, graph_str.str());
586 }
587 
check(const std::string & str)588 FileCheck* FileCheck::check(const std::string& str) {
589   fcImpl->addCheck(CHECK, str);
590   return this;
591 }
592 
check_not(const std::string & str)593 FileCheck* FileCheck::check_not(const std::string& str) {
594   fcImpl->addCheck(CHECK_NOT, str);
595   return this;
596 }
597 
check_same(const std::string & str)598 FileCheck* FileCheck::check_same(const std::string& str) {
599   fcImpl->addCheck(CHECK_SAME, str);
600   return this;
601 }
602 
check_next(const std::string & str)603 FileCheck* FileCheck::check_next(const std::string& str) {
604   fcImpl->addCheck(CHECK_NEXT, str);
605   return this;
606 }
607 
check_count(const std::string & str,size_t count,bool exactly)608 FileCheck* FileCheck::check_count(
609     const std::string& str,
610     size_t count,
611     bool exactly) {
612   TORCH_INTERNAL_ASSERT(
613       count != 0 || exactly, "Count == 0 && !exactly doesn't do anything");
614   if (count) {
615     fcImpl->addCheck(CHECK_COUNT, str, count);
616   }
617   if (exactly) {
618     fcImpl->addCheck(CHECK_NOT, str);
619   }
620   return this;
621 }
622 
check_dag(const std::string & str)623 FileCheck* FileCheck::check_dag(const std::string& str) {
624   fcImpl->addCheck(CHECK_DAG, str);
625   return this;
626 }
627 
check_source_highlighted(const std::string & str)628 FileCheck* FileCheck::check_source_highlighted(const std::string& str) {
629   fcImpl->addCheck(CHECK_SOURCE_HIGHLIGHTED, str);
630   return this;
631 }
632 
check_regex(const std::string & str)633 FileCheck* FileCheck::check_regex(const std::string& str) {
634   fcImpl->addCheck(CHECK_REGEX, str);
635   return this;
636 }
637 
638 } // namespace testing
639 } // namespace jit
640 } // namespace torch
641