// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/jit_opt_limit.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <cstdlib>
2 #include <iomanip>
3 #include <sstream>
4 #include <string>
5 #include <utility>
6 #include <vector>
7 
8 #include <ATen/core/function.h>
9 #include <c10/util/Exception.h>
10 #include <c10/util/StringUtil.h>
11 #include <torch/csrc/jit/api/function_impl.h>
12 #include <torch/csrc/jit/jit_opt_limit.h>
13 
14 namespace torch::jit {
15 
// Returns the table that maps a pass name to the number of optimizations that
// pass has been granted so far. The table is a function-local static, so it is
// constructed on first use and every call hands back a reference to the same
// instance.
static std::unordered_map<std::string, int64_t>& passes_to_current_counter() {
  using CounterTable = std::unordered_map<std::string, int64_t>;
  static CounterTable counters;
  return counters;
}
20 
// Converts the textual limit `opt_limit` to an int using std::stoi.
// Returns the parsed value, or -1 when std::stoi throws (for example on
// non-numeric text or a value that does not fit in an int).
static int parseOptLimit(const std::string& opt_limit) {
  constexpr int kInvalidLimit = -1;
  int parsed_limit = kInvalidLimit;
  try {
    parsed_limit = std::stoi(opt_limit);
  } catch (...) {
    // Any parse failure collapses to the sentinel value.
    parsed_limit = kInvalidLimit;
  }
  return parsed_limit;
}
28 
parseJITOptLimitOption(const char * option)29 static std::unordered_map<std::string, int64_t> parseJITOptLimitOption(
30     const char* option) {
31   std::stringstream in_ss;
32   if (option) {
33     in_ss << option;
34   }
35   std::unordered_map<std::string, int64_t> passes_to_opt_limits;
36   std::string line;
37   while (std::getline(in_ss, line, ':')) {
38     if (line.empty()) {
39       continue;
40     }
41     auto index_at = line.find_last_of('=');
42     auto pass_name = line.substr(0, index_at);
43     pass_name = c10::detail::ExcludeFileExtension(pass_name);
44     auto opt_limit = parseOptLimit(line.substr(index_at + 1));
45     passes_to_opt_limits.insert({pass_name, opt_limit});
46   }
47 
48   return passes_to_opt_limits;
49 }
50 
opt_limit(const char * pass_name)51 bool opt_limit(const char* pass_name) {
52   static const char* opt_limit = std::getenv("PYTORCH_JIT_OPT_LIMIT");
53   // if nothing is provided, let's allow everything
54   if (!opt_limit) {
55     return true;
56   }
57 
58   static const std::unordered_map<std::string, int64_t> passes_to_opt_limits =
59       parseJITOptLimitOption(opt_limit);
60   std::string pass{pass_name};
61   pass = c10::detail::StripBasename(pass);
62   pass = c10::detail::ExcludeFileExtension(pass);
63 
64   auto opt_limit_it = passes_to_opt_limits.find(pass);
65   if (opt_limit_it == passes_to_opt_limits.end()) {
66     return true;
67   }
68 
69   auto current_count_it = passes_to_current_counter().find(pass);
70   if (current_count_it == passes_to_current_counter().end()) {
71     passes_to_current_counter().insert({pass, 0});
72   }
73 
74   current_count_it = passes_to_current_counter().find(pass);
75   if (current_count_it->second >= opt_limit_it->second) {
76     return false;
77   }
78 
79   current_count_it->second++;
80   return true;
81 }
82 
83 } // namespace torch::jit
84