1 #include <torch/csrc/jit/operator_upgraders/version_map.h>
2
3 #include <algorithm>
4 #include <string>
5 #include <unordered_map>
6 #include <vector>
7
8 namespace torch::jit {
9
10 // this flag is used to make sure the elements in the version map
11 // are sorted according to when the upgraders are introduced.
12 static bool isVersionMapSorted = false;
13
14 // Main entry point for all operators that have valid upgraders.
15 // Note for developers: The list of upgraders should be SORTED
16 // by the version number where the upgrader is registered.
17 static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
18 {{"aten::logspace",
19 {{9,
20 "logspace_0_8",
21 "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
22 {"aten::logspace.out",
23 {{9,
24 "logspace_out_0_8",
25 "aten::logspace.out(Scalar start, Scalar end, int? steps=None, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)"}}},
26 {"aten::linspace",
27 {{8,
28 "linspace_0_7",
29 "aten::linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
30 {"aten::linspace.out",
31 {{8,
32 "linspace_out_0_7",
33 "aten::linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!)"}}},
34 {"aten::div.Tensor",
35 {{4,
36 "div_Tensor_0_3",
37 "aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
38 {"aten::div.Tensor_mode",
39 {{4,
40 "div_Tensor_mode_0_3",
41 "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"}}},
42 {"aten::div.Scalar",
43 {{4,
44 "div_Scalar_0_3",
45 "aten::div.Scalar(Tensor self, Scalar other) -> Tensor"}}},
46 {"aten::div.Scalar_mode",
47 {{4,
48 "div_Scalar_mode_0_3",
49 "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor"}}},
50 {"aten::div.out",
51 {{4,
52 "div_out_0_3",
53 "aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"}}},
54 {"aten::div.out_mode",
55 {{4,
56 "div_out_mode_0_3",
57 "aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)"}}},
58 {"aten::div_.Tensor",
59 {{4,
60 "div__Tensor_0_3",
61 "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
62 {"aten::div_.Tensor_mode",
63 {{4,
64 "div__Tensor_mode_0_3",
65 "aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)"}}},
66 {"aten::div_.Scalar",
67 {{4,
68 "div__Scalar_0_3",
69 "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
70 {"aten::div_.Scalar_mode",
71 {{4,
72 "div__Scalar_mode_0_3",
73 "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)"}}},
74 {"aten::full",
75 {{5,
76 "full_0_4",
77 "aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
78 {"aten::full.names",
79 {{5,
80 "full_names_0_4",
81 "aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
82 {"aten::full.out",
83 {{5,
84 "full_out_0_4",
85 "aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}},
86 {"aten::gelu", {{10, "gelu_0_9", "aten::gelu(Tensor self) -> Tensor"}}},
87 {"aten::gelu.out",
88 {{10,
89 "gelu_out_0_9",
90 "aten::gelu.out(Tensor self, *, Tensor(a!) out) -> Tensor"}}}});
91
92 const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
get_operator_version_map()93 get_operator_version_map() {
94 if (!isVersionMapSorted) {
95 for (auto entry : operatorVersionMap) {
96 std::sort(
97 entry.second.begin(),
98 entry.second.end(),
99 [](const auto& a, const auto& b) {
100 return a.bumped_at_version > b.bumped_at_version;
101 });
102 }
103 isVersionMapSorted = true;
104 }
105 return operatorVersionMap;
106 }
107
test_only_add_entry(const std::string & op_name,UpgraderEntry entry)108 void test_only_add_entry(const std::string& op_name, UpgraderEntry entry) {
109 test_only_reset_flag();
110 operatorVersionMap[op_name].emplace_back(std::move(entry));
111 }
112
test_only_remove_entry(const std::string & op_name)113 void test_only_remove_entry(const std::string& op_name) {
114 test_only_reset_flag();
115 operatorVersionMap.erase(op_name);
116 }
117
test_only_reset_flag()118 void test_only_reset_flag() {
119 isVersionMapSorted = false;
120 }
121
122 static bool calculatePackageVersionBasedOnUpgraders = false;
123
calculate_package_version_based_on_upgraders(bool val)124 void calculate_package_version_based_on_upgraders(bool val) {
125 calculatePackageVersionBasedOnUpgraders = val;
126 }
127
get_version_calculator_flag()128 bool get_version_calculator_flag() {
129 return calculatePackageVersionBasedOnUpgraders;
130 }
131
132 } // namespace torch::jit
133