1 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
2
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/ir/irparser.h>
5 #include <mutex>
6 #include <string>
7 #include <unordered_map>
8
9 namespace torch::jit {
10
11 static UpgradersMap upgradersMap;
12
set_content(std::unordered_map<std::string,std::shared_ptr<Graph>> && content)13 void UpgradersMap::set_content(
14 std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
15 // make sure we populate the map only once
16 std::lock_guard<std::mutex> _(lock);
17 if (isPopulated) {
18 return;
19 }
20
21 content_ = std::move(content);
22 isPopulated = true;
23 }
24
count()25 int UpgradersMap::count() {
26 std::lock_guard<std::mutex> _(lock);
27 return content_.size();
28 }
29
is_populated()30 bool UpgradersMap::is_populated() {
31 std::lock_guard<std::mutex> _(lock);
32 return isPopulated;
33 }
34
35 const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
get_content()36 get_content() {
37 std::lock_guard<std::mutex> _(lock);
38 return content_;
39 }
40
test_only_set_content(const std::unordered_map<std::string,std::string> & content)41 void UpgradersMap::test_only_set_content(
42 const std::unordered_map<std::string, std::string>& content) {
43 std::lock_guard<std::mutex> _(lock);
44 for (const auto& entry : content) {
45 auto graph = std::make_shared<Graph>();
46 torch::jit::parseIR(entry.second, graph.get());
47 content_.insert(std::make_pair(entry.first, graph));
48 }
49 }
test_only_remove_content(const std::unordered_map<std::string,std::string> & content)50 void UpgradersMap::test_only_remove_content(
51 const std::unordered_map<std::string, std::string>& content) {
52 std::lock_guard<std::mutex> _(lock);
53 for (const auto& entry : content) {
54 content_.erase(entry.first);
55 }
56 }
57
populate_upgraders_map(std::unordered_map<std::string,std::shared_ptr<Graph>> && content)58 void populate_upgraders_map(
59 std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
60 upgradersMap.set_content(std::move(content));
61 }
62
get_upgraders_map_size()63 int get_upgraders_map_size() {
64 return upgradersMap.count();
65 }
66
is_upgraders_map_populated()67 bool is_upgraders_map_populated() {
68 return upgradersMap.is_populated();
69 }
70
71 const std::unordered_map<std::string, std::shared_ptr<Graph>>&
dump_upgraders_map()72 dump_upgraders_map() {
73 return upgradersMap.get_content();
74 }
75
test_only_populate_upgraders(const std::unordered_map<std::string,std::string> & content)76 void test_only_populate_upgraders(
77 const std::unordered_map<std::string, std::string>& content) {
78 upgradersMap.test_only_set_content(content);
79 }
80
test_only_remove_upgraders(const std::unordered_map<std::string,std::string> & content)81 void test_only_remove_upgraders(
82 const std::unordered_map<std::string, std::string>& content) {
83 upgradersMap.test_only_remove_content(content);
84 }
85
86 } // namespace torch::jit
87