xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/operator_upgraders/upgraders.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/ir/irparser.h>
5 #include <mutex>
6 #include <string>
7 #include <unordered_map>
8 
9 namespace torch::jit {
10 
11 static UpgradersMap upgradersMap;
12 
set_content(std::unordered_map<std::string,std::shared_ptr<Graph>> && content)13 void UpgradersMap::set_content(
14     std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
15   // make sure we populate the map only once
16   std::lock_guard<std::mutex> _(lock);
17   if (isPopulated) {
18     return;
19   }
20 
21   content_ = std::move(content);
22   isPopulated = true;
23 }
24 
count()25 int UpgradersMap::count() {
26   std::lock_guard<std::mutex> _(lock);
27   return content_.size();
28 }
29 
is_populated()30 bool UpgradersMap::is_populated() {
31   std::lock_guard<std::mutex> _(lock);
32   return isPopulated;
33 }
34 
35 const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
get_content()36     get_content() {
37   std::lock_guard<std::mutex> _(lock);
38   return content_;
39 }
40 
test_only_set_content(const std::unordered_map<std::string,std::string> & content)41 void UpgradersMap::test_only_set_content(
42     const std::unordered_map<std::string, std::string>& content) {
43   std::lock_guard<std::mutex> _(lock);
44   for (const auto& entry : content) {
45     auto graph = std::make_shared<Graph>();
46     torch::jit::parseIR(entry.second, graph.get());
47     content_.insert(std::make_pair(entry.first, graph));
48   }
49 }
test_only_remove_content(const std::unordered_map<std::string,std::string> & content)50 void UpgradersMap::test_only_remove_content(
51     const std::unordered_map<std::string, std::string>& content) {
52   std::lock_guard<std::mutex> _(lock);
53   for (const auto& entry : content) {
54     content_.erase(entry.first);
55   }
56 }
57 
populate_upgraders_map(std::unordered_map<std::string,std::shared_ptr<Graph>> && content)58 void populate_upgraders_map(
59     std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
60   upgradersMap.set_content(std::move(content));
61 }
62 
get_upgraders_map_size()63 int get_upgraders_map_size() {
64   return upgradersMap.count();
65 }
66 
is_upgraders_map_populated()67 bool is_upgraders_map_populated() {
68   return upgradersMap.is_populated();
69 }
70 
71 const std::unordered_map<std::string, std::shared_ptr<Graph>>&
dump_upgraders_map()72 dump_upgraders_map() {
73   return upgradersMap.get_content();
74 }
75 
test_only_populate_upgraders(const std::unordered_map<std::string,std::string> & content)76 void test_only_populate_upgraders(
77     const std::unordered_map<std::string, std::string>& content) {
78   upgradersMap.test_only_set_content(content);
79 }
80 
test_only_remove_upgraders(const std::unordered_map<std::string,std::string> & content)81 void test_only_remove_upgraders(
82     const std::unordered_map<std::string, std::string>& content) {
83   upgradersMap.test_only_remove_content(content);
84 }
85 
86 } // namespace torch::jit
87