// torch/csrc/jit/operator_upgraders/version_map.cpp
#include <torch/csrc/jit/operator_upgraders/version_map.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

namespace torch::jit {

// This flag tracks whether the vectors in the version map have already been
// sorted by the version at which each upgrader was introduced.
static bool isVersionMapSorted = false;

// Main entry point for all operators that have valid upgraders.
// Note for developers: the list of upgraders for each operator should be
// kept SORTED by the version number at which each upgrader was introduced.
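// Each UpgraderEntry is a triple: the version at which the operator's
// semantics were bumped, the name of the upgrader graph that reproduces the
// old behavior, and the operator's old schema string. For example, the
// aten::gelu entry below was bumped at version 10, so models serialized at
// versions 0 through 9 are rewritten using the "gelu_0_9" upgrader.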
static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
    {{"aten::logspace",
      {{9,
        "logspace_0_8",
        "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
     {"aten::logspace.out",
      {{9,
        "logspace_out_0_8",
        "aten::logspace.out(Scalar start, Scalar end, int? steps=None, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::linspace",
      {{8,
        "linspace_0_7",
        "aten::linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
     {"aten::linspace.out",
      {{8,
        "linspace_out_0_7",
        "aten::linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::div.Tensor",
      {{4,
        "div_Tensor_0_3",
        "aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
     {"aten::div.Tensor_mode",
      {{4,
        "div_Tensor_mode_0_3",
        "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"}}},
     {"aten::div.Scalar",
      {{4,
        "div_Scalar_0_3",
        "aten::div.Scalar(Tensor self, Scalar other) -> Tensor"}}},
     {"aten::div.Scalar_mode",
      {{4,
        "div_Scalar_mode_0_3",
        "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor"}}},
     {"aten::div.out",
      {{4,
        "div_out_0_3",
        "aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::div.out_mode",
      {{4,
        "div_out_mode_0_3",
        "aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::div_.Tensor",
      {{4,
        "div__Tensor_0_3",
        "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
     {"aten::div_.Tensor_mode",
      {{4,
        "div__Tensor_mode_0_3",
        "aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)"}}},
     {"aten::div_.Scalar",
      {{4,
        "div__Scalar_0_3",
        "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
     {"aten::div_.Scalar_mode",
      {{4,
        "div__Scalar_mode_0_3",
        "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)"}}},
     {"aten::full",
      {{5,
        "full_0_4",
        "aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
     {"aten::full.names",
      {{5,
        "full_names_0_4",
        "aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
     {"aten::full.out",
      {{5,
        "full_out_0_4",
        "aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::gelu", {{10, "gelu_0_9", "aten::gelu(Tensor self) -> Tensor"}}},
     {"aten::gelu.out",
      {{10,
        "gelu_out_0_9",
        "aten::gelu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"}}}});

const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
get_operator_version_map() {
  if (!isVersionMapSorted) {
    // Sort each operator's upgraders in place (note the reference: iterating
    // by value would sort a throwaway copy), in ascending order of the
    // version at which each upgrader was introduced.
    for (auto& entry : operatorVersionMap) {
      std::sort(
          entry.second.begin(),
          entry.second.end(),
          [](const auto& a, const auto& b) {
            return a.bumped_at_version < b.bumped_at_version;
          });
    }
    isVersionMapSorted = true;
  }
  return operatorVersionMap;
}
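
// Read-side sketch (illustrative only; "aten::div.Tensor" is just one of the
// keys registered above):
//
//   const auto& map = get_operator_version_map();
//   auto it = map.find("aten::div.Tensor");
//   if (it != map.end()) {
//     // it->second is sorted ascending by bumped_at_version.
//   }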
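// Test-only helpers. Each one mutates the global map and clears the sorted
// flag, so the next call to get_operator_version_map() re-sorts the entries.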
void test_only_add_entry(const std::string& op_name, UpgraderEntry entry) {
  test_only_reset_flag();
  operatorVersionMap[op_name].emplace_back(std::move(entry));
}

void test_only_remove_entry(const std::string& op_name) {
  test_only_reset_flag();
  operatorVersionMap.erase(op_name);
}

void test_only_reset_flag() {
  isVersionMapSorted = false;
}
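// When true, the version recorded for a serialized package is calculated
// from the operators the model uses and the upgrader map above, rather than
// being fixed at the current default version.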
static bool calculatePackageVersionBasedOnUpgraders = false;

void calculate_package_version_based_on_upgraders(bool val) {
  calculatePackageVersionBasedOnUpgraders = val;
}

bool get_version_calculator_flag() {
  return calculatePackageVersionBasedOnUpgraders;
}

} // namespace torch::jit