#include <cstdlib>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
7
8 #include <ATen/core/function.h>
9 #include <c10/util/Exception.h>
10 #include <c10/util/StringUtil.h>
11 #include <torch/csrc/jit/api/function_impl.h>
12 #include <torch/csrc/jit/jit_opt_limit.h>
13
14 namespace torch::jit {
15
// Accessor for the process-wide table that records how many optimizations
// each pass (keyed by its normalized name) has been allowed to perform so far.
// The table is a function-local static, so it is created on first use and
// every call returns a reference to the same object.
static std::unordered_map<std::string, int64_t>& passes_to_current_counter() {
  using CounterTable = std::unordered_map<std::string, int64_t>;
  static CounterTable counters_by_pass;
  return counters_by_pass;
}
20
// Parses the numeric limit text taken from one PYTORCH_JIT_OPT_LIMIT entry.
//
// @param opt_limit  text expected to hold a base-10 integer. Parsing is
//                   lenient in the same way as std::stoll: leading
//                   whitespace is skipped and trailing non-digits are
//                   ignored, so "12abc" yields 12.
// @return the parsed value as int64_t, matching the int64_t values stored in
//         the limits table. Returns -1 when the text contains no number or
//         the number does not fit in a long long.
//
// Why stoll rather than stoi: with stoi, any limit above INT_MAX threw
// out_of_range and collapsed to -1. A -1 limit disables the pass entirely,
// which is the opposite of the large limit the user asked for.
static int64_t parseOptLimit(const std::string& opt_limit) {
  try {
    return static_cast<int64_t>(std::stoll(opt_limit));
  } catch (const std::invalid_argument&) {
    // No digits could be parsed at all.
    return -1;
  } catch (const std::out_of_range&) {
    // Digits were present but the value overflows long long.
    return -1;
  }
}
28
parseJITOptLimitOption(const char * option)29 static std::unordered_map<std::string, int64_t> parseJITOptLimitOption(
30 const char* option) {
31 std::stringstream in_ss;
32 if (option) {
33 in_ss << option;
34 }
35 std::unordered_map<std::string, int64_t> passes_to_opt_limits;
36 std::string line;
37 while (std::getline(in_ss, line, ':')) {
38 if (line.empty()) {
39 continue;
40 }
41 auto index_at = line.find_last_of('=');
42 auto pass_name = line.substr(0, index_at);
43 pass_name = c10::detail::ExcludeFileExtension(pass_name);
44 auto opt_limit = parseOptLimit(line.substr(index_at + 1));
45 passes_to_opt_limits.insert({pass_name, opt_limit});
46 }
47
48 return passes_to_opt_limits;
49 }
50
opt_limit(const char * pass_name)51 bool opt_limit(const char* pass_name) {
52 static const char* opt_limit = std::getenv("PYTORCH_JIT_OPT_LIMIT");
53 // if nothing is provided, let's allow everything
54 if (!opt_limit) {
55 return true;
56 }
57
58 static const std::unordered_map<std::string, int64_t> passes_to_opt_limits =
59 parseJITOptLimitOption(opt_limit);
60 std::string pass{pass_name};
61 pass = c10::detail::StripBasename(pass);
62 pass = c10::detail::ExcludeFileExtension(pass);
63
64 auto opt_limit_it = passes_to_opt_limits.find(pass);
65 if (opt_limit_it == passes_to_opt_limits.end()) {
66 return true;
67 }
68
69 auto current_count_it = passes_to_current_counter().find(pass);
70 if (current_count_it == passes_to_current_counter().end()) {
71 passes_to_current_counter().insert({pass, 0});
72 }
73
74 current_count_it = passes_to_current_counter().find(pass);
75 if (current_count_it->second >= opt_limit_it->second) {
76 return false;
77 }
78
79 current_count_it->second++;
80 return true;
81 }
82
83 } // namespace torch::jit
84