#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/constant_map.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
9
10 namespace torch::jit {
11
12 // Meyer’s Singleton for C++ 14
getInstance()13 ConstantValueMap& ConstantValueMap::getInstance() {
14 static ConstantValueMap s;
15 return s;
16 }
17
SetRank(const std::string & tensorName,size_t rankValue)18 void ConstantValueMap::SetRank(
19 const std::string& tensorName,
20 size_t rankValue) {
21 ConstantValueMap::getInstance().rankMap[tensorName] = rankValue;
22 ConstantValueMap::getInstance().useInferredTypeMap[tensorName] = true;
23 }
24
HasRank(const std::string & tensorName)25 bool ConstantValueMap::HasRank(const std::string& tensorName) {
26 return ConstantValueMap::getInstance().rankMap.find(tensorName) !=
27 ConstantValueMap::getInstance().rankMap.end();
28 }
29
GetRank(const std::string & tensorName)30 std::optional<size_t> ConstantValueMap::GetRank(const std::string& tensorName) {
31 if (!HasRank(tensorName)) {
32 return std::nullopt;
33 }
34 return ConstantValueMap::getInstance().rankMap[tensorName];
35 }
36
SetAllGraphInputsStatic(bool all_static)37 void ConstantValueMap::SetAllGraphInputsStatic(bool all_static) {
38 ConstantValueMap::getInstance().allGraphInputsStatic =
39 std::make_optional(all_static);
40 }
41
GetAllGraphInputsStatic()42 std::optional<bool> ConstantValueMap::GetAllGraphInputsStatic() {
43 return ConstantValueMap::getInstance().allGraphInputsStatic;
44 }
45
SetAllGraphInputsReliableComputed(bool computed)46 void ConstantValueMap::SetAllGraphInputsReliableComputed(bool computed) {
47 ConstantValueMap::getInstance().allGraphInputsReliableComputed = computed;
48 }
49
GetAllGraphInputsReliableComputed()50 bool ConstantValueMap::GetAllGraphInputsReliableComputed() {
51 return ConstantValueMap::getInstance().allGraphInputsReliableComputed;
52 }
53
SetShape(const std::string & tensorName,const c10::SymbolicShape & shapeValue)54 void ConstantValueMap::SetShape(
55 const std::string& tensorName,
56 const c10::SymbolicShape& shapeValue) {
57 ConstantValueMap::getInstance().shapeMap[tensorName] = shapeValue;
58 ConstantValueMap::getInstance().useInferredTypeMap[tensorName] = true;
59 }
60
HasShape(const std::string & tensorName)61 bool ConstantValueMap::HasShape(const std::string& tensorName) {
62 return ConstantValueMap::getInstance().shapeMap.find(tensorName) !=
63 ConstantValueMap::getInstance().shapeMap.end();
64 }
65
GetShape(const std::string & tensorName)66 std::optional<c10::SymbolicShape> ConstantValueMap::GetShape(
67 const std::string& tensorName) {
68 if (!HasShape(tensorName)) {
69 return std::nullopt;
70 }
71 return ConstantValueMap::getInstance().shapeMap[tensorName];
72 }
73
SetValue(const std::string & tensorName,const at::Tensor & value)74 void ConstantValueMap::SetValue(
75 const std::string& tensorName,
76 const at::Tensor& value) {
77 ConstantValueMap::getInstance().tensorValueMap[tensorName] = value;
78 }
79
HasValue(const std::string & tensorName)80 bool ConstantValueMap::HasValue(const std::string& tensorName) {
81 return ConstantValueMap::getInstance().tensorValueMap.find(tensorName) !=
82 ConstantValueMap::getInstance().tensorValueMap.end();
83 }
84
GetValue(const std::string & tensorName)85 std::optional<at::Tensor> ConstantValueMap::GetValue(
86 const std::string& tensorName) {
87 if (!HasValue(tensorName)) {
88 return std::nullopt;
89 }
90 return ConstantValueMap::getInstance().tensorValueMap[tensorName];
91 }
92
EraseValue(const std::string & tensorName)93 void ConstantValueMap::EraseValue(const std::string& tensorName) {
94 ConstantValueMap::getInstance().tensorValueMap.erase(tensorName);
95 }
96
GetCompleteShapeInto1DInt64Vector(const c10::SymbolicShape & shape)97 std::vector<int64_t> ConstantValueMap::GetCompleteShapeInto1DInt64Vector(
98 const c10::SymbolicShape& shape) {
99 TORCH_INTERNAL_ASSERT(shape.isComplete());
100 std::vector<int64_t> shape_value;
101 auto shape_symbol_list = shape.sizes().value();
102 shape_value.reserve(shape_symbol_list.size());
103 for (const auto& v : shape_symbol_list) {
104 shape_value.emplace_back(v.static_size());
105 }
106 return shape_value;
107 }
108
GetShapeInto1DInt64Vector(const std::string & value_name)109 std::optional<std::vector<int64_t>> ConstantValueMap::GetShapeInto1DInt64Vector(
110 const std::string& value_name) {
111 if (ConstantValueMap::HasShape(value_name)) {
112 auto shape_size = ConstantValueMap::GetShape(value_name).value();
113 if (shape_size.isComplete()) {
114 auto shape_value =
115 ConstantValueMap::GetCompleteShapeInto1DInt64Vector(shape_size);
116 return shape_value;
117 }
118 }
119 return std::nullopt;
120 }
121
122 std::optional<std::vector<int64_t>> ConstantValueMap::
GetShapeInto1DInt64VectorWithOneUnknown(const std::string & value_name)123 GetShapeInto1DInt64VectorWithOneUnknown(const std::string& value_name) {
124 if (ConstantValueMap::HasShape(value_name)) {
125 auto shape_size = ConstantValueMap::GetShape(value_name).value();
126 std::vector<int64_t> shape_value;
127 if (shape_size.isComplete()) {
128 shape_value =
129 ConstantValueMap::GetCompleteShapeInto1DInt64Vector(shape_size);
130 return shape_value;
131 } else {
132 size_t count_unknown = 0;
133 auto shape_size_sizes = shape_size.sizes();
134 if (shape_size_sizes.has_value()) {
135 auto shape_symbol_list = shape_size_sizes.value();
136 for (const auto& v : shape_symbol_list) {
137 if (v.is_static()) {
138 shape_value.emplace_back(v.static_size());
139 } else {
140 shape_value.emplace_back(-1);
141 count_unknown += 1;
142 }
143 }
144 if (count_unknown == 1) {
145 return shape_value;
146 }
147 }
148 }
149 }
150 return std::nullopt;
151 }
152
153 // accessor<int64_t, 1> for 1DInt64 case.
GetValueInto1DInt64Vector(const std::string & value_name)154 std::vector<int64_t> ConstantValueMap::GetValueInto1DInt64Vector(
155 const std::string& value_name) {
156 auto value = ConstantValueMap::GetValue(value_name).value();
157 auto value_int64_t = value.toType(at::ScalarType::Long);
158 std::vector<int64_t> value_vector;
159 value_vector.reserve(value_int64_t.size(0));
160 auto value_size_a = value_int64_t.accessor<int64_t, 1>();
161 for (const auto i : c10::irange(value_int64_t.size(0))) {
162 value_vector.emplace_back(static_cast<int64_t>(value_size_a[i]));
163 }
164 return value_vector;
165 }
166
SetTypeReliable(const std::string & tensorName,bool value)167 void ConstantValueMap::SetTypeReliable(
168 const std::string& tensorName,
169 bool value) {
170 ConstantValueMap::getInstance().typeReliableMap[tensorName] = value;
171 }
172
HasTypeReliable(const std::string & tensorName)173 bool ConstantValueMap::HasTypeReliable(const std::string& tensorName) {
174 return ConstantValueMap::getInstance().typeReliableMap.find(tensorName) !=
175 ConstantValueMap::getInstance().typeReliableMap.end();
176 }
177
GetTypeReliable(const std::string & tensorName)178 std::optional<bool> ConstantValueMap::GetTypeReliable(
179 const std::string& tensorName) {
180 if (!HasTypeReliable(tensorName)) {
181 return std::nullopt;
182 }
183 return ConstantValueMap::getInstance().typeReliableMap[tensorName];
184 }
185
SetUseInferredType(const std::string & tensorName,bool value)186 void ConstantValueMap::SetUseInferredType(
187 const std::string& tensorName,
188 bool value) {
189 ConstantValueMap::getInstance().useInferredTypeMap[tensorName] = value;
190 }
191
HasUseInferredType(const std::string & tensorName)192 bool ConstantValueMap::HasUseInferredType(const std::string& tensorName) {
193 return ConstantValueMap::getInstance().useInferredTypeMap.find(tensorName) !=
194 ConstantValueMap::getInstance().useInferredTypeMap.end();
195 }
196
GetUseInferredType(const std::string & tensorName)197 std::optional<bool> ConstantValueMap::GetUseInferredType(
198 const std::string& tensorName) {
199 if (!HasUseInferredType(tensorName)) {
200 return std::nullopt;
201 }
202 return ConstantValueMap::getInstance().useInferredTypeMap[tensorName];
203 }
204
SetShapeValue(const std::string & tensorName,const c10::SymbolicShape & shapeValue)205 void ConstantValueMap::SetShapeValue(
206 const std::string& tensorName,
207 const c10::SymbolicShape& shapeValue) {
208 ConstantValueMap::getInstance().shapeValueMap[tensorName] = shapeValue;
209 }
210
HasShapeValue(const std::string & tensorName)211 bool ConstantValueMap::HasShapeValue(const std::string& tensorName) {
212 return ConstantValueMap::getInstance().shapeValueMap.find(tensorName) !=
213 ConstantValueMap::getInstance().shapeValueMap.end();
214 }
215
GetShapeValue(const std::string & tensorName)216 std::optional<c10::SymbolicShape> ConstantValueMap::GetShapeValue(
217 const std::string& tensorName) {
218 if (!HasShapeValue(tensorName)) {
219 return std::nullopt;
220 }
221 return ConstantValueMap::getInstance().shapeValueMap[tensorName];
222 }
223
224 // Gets the inferredShapeData which is obtained by ONNX data propagation
GetInferredShapeData()225 ShapeDataMap& ConstantValueMap::GetInferredShapeData() {
226 return ConstantValueMap::getInstance().inferredShapeData;
227 }
228
GetSymbolDimMap()229 SymbolDimMap& ConstantValueMap::GetSymbolDimMap() {
230 return ConstantValueMap::getInstance().symbolDimMap;
231 }
232
GetDimSymbolMap()233 DimSymbolMap& ConstantValueMap::GetDimSymbolMap() {
234 return ConstantValueMap::getInstance().dimSymbolMap;
235 }
236
237 template <typename Map>
UpdateStrKey(Map & map,const std::string & old_key,const std::string & new_key)238 void UpdateStrKey(
239 Map& map,
240 const std::string& old_key,
241 const std::string& new_key) {
242 TORCH_INTERNAL_ASSERT(old_key != new_key);
243 if (map.find(old_key) == map.end()) {
244 return;
245 }
246 map[new_key] = map[old_key];
247 map.erase(old_key);
248 }
249
UpdateValueName(const std::string & old_name,const std::string & new_name)250 void ConstantValueMap::UpdateValueName(
251 const std::string& old_name,
252 const std::string& new_name) {
253 if (old_name == new_name) {
254 return;
255 }
256 UpdateStrKey<decltype(rankMap)>(
257 ConstantValueMap::getInstance().rankMap, old_name, new_name);
258 UpdateStrKey<decltype(shapeMap)>(
259 ConstantValueMap::getInstance().shapeMap, old_name, new_name);
260 UpdateStrKey<decltype(tensorValueMap)>(
261 ConstantValueMap::getInstance().tensorValueMap, old_name, new_name);
262 UpdateStrKey<decltype(typeReliableMap)>(
263 ConstantValueMap::getInstance().typeReliableMap, old_name, new_name);
264 UpdateStrKey<decltype(useInferredTypeMap)>(
265 ConstantValueMap::getInstance().useInferredTypeMap, old_name, new_name);
266 UpdateStrKey<decltype(shapeValueMap)>(
267 ConstantValueMap::getInstance().shapeValueMap, old_name, new_name);
268 UpdateStrKey<decltype(inferredShapeData)>(
269 ConstantValueMap::getInstance().inferredShapeData, old_name, new_name);
270 }
271
ClearMaps()272 void ConstantValueMap::ClearMaps() {
273 ConstantValueMap::getInstance().rankMap.clear();
274 ConstantValueMap::getInstance().shapeMap.clear();
275 ConstantValueMap::getInstance().tensorValueMap.clear();
276 ConstantValueMap::getInstance().typeReliableMap.clear();
277 ConstantValueMap::getInstance().useInferredTypeMap.clear();
278 ConstantValueMap::getInstance().shapeValueMap.clear();
279 ConstantValueMap::getInstance().inferredShapeData.clear();
280 ConstantValueMap::getInstance().symbolDimMap.clear();
281 ConstantValueMap::getInstance().dimSymbolMap.clear();
282 ConstantValueMap::getInstance().allGraphInputsStatic = std::nullopt;
283 ConstantValueMap::getInstance().allGraphInputsReliableComputed = false;
284 }
285
286 // For debug only.
PrintMaps()287 void ConstantValueMap::PrintMaps() {
288 std::cout << "Rank/Shape Map:" << '\n';
289 for (const auto& x : ConstantValueMap::getInstance().rankMap) {
290 std::stringstream ss;
291 if (ConstantValueMap::getInstance().shapeMap.find(x.first) !=
292 ConstantValueMap::getInstance().shapeMap.end()) {
293 auto shape_symbols =
294 ConstantValueMap::getInstance().shapeMap[x.first].sizes();
295 if (shape_symbols.has_value()) {
296 for (const auto& shape_symbol : shape_symbols.value()) {
297 if (shape_symbol.is_static()) {
298 ss << shape_symbol.static_size() << ", ";
299 } else {
300 ss << "*, ";
301 }
302 }
303 }
304 }
305 ss << " (rank = " << x.second << ")";
306 std::cout << "node " << x.first << ": " << ss.str() << '\n';
307 }
308 std::cout << '\n';
309 std::cout << "Value Map:" << '\n';
310 for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) {
311 std::cout << "node " << x.first << ": " << x.second << '\n';
312 }
313 std::cout << '\n';
314 std::cout << "TypeReliable Map:" << '\n';
315 size_t count = 0;
316 for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) {
317 std::cout << "(node " << x.first << ": " << x.second << "), ";
318 count++;
319 if (count % 10 == 0) {
320 std::cout << '\n';
321 }
322 }
323 std::cout << '\n';
324 std::cout << "UseInferredType Map:" << '\n';
325 count = 0;
326 for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) {
327 std::cout << "(node " << x.first << ": " << x.second << "), ";
328 count++;
329 if (count % 10 == 0) {
330 std::cout << '\n';
331 }
332 }
333 std::cout << '\n';
334 std::cout << "ShapeValue Map:" << '\n';
335 count = 0;
336 for (const auto& x : ConstantValueMap::getInstance().shapeValueMap) {
337 std::cout << "(node " << x.first << ": " << x.second << "), ";
338 count++;
339 if (count % 10 == 0) {
340 std::cout << '\n';
341 }
342 }
343 std::cout << '\n';
344 std::cout << "InferredShape Map:" << '\n';
345 count = 0;
346 for (const auto& x : ConstantValueMap::getInstance().inferredShapeData) {
347 std::cout << "(node " << x.first << ": ";
348 for (const auto& dim : x.second.dim()) {
349 if (dim.has_dim_param()) {
350 std::cout << dim.dim_param() << " ";
351 } else {
352 std::cout << dim.dim_value() << " ";
353 }
354 }
355 std::cout << "), ";
356 count++;
357 if (count % 10 == 0) {
358 std::cout << '\n';
359 }
360 }
361 std::cout << '\n';
362 std::cout << "SymbolDim Map:" << '\n';
363 count = 0;
364 for (const auto& x : ConstantValueMap::getInstance().symbolDimMap) {
365 std::cout << "(" << x.first << ": " << x.second << "), ";
366 count++;
367 if (count % 10 == 0) {
368 std::cout << '\n';
369 }
370 }
371 std::cout << "DimSymbol Map:" << '\n';
372 count = 0;
373 for (const auto& x : ConstantValueMap::getInstance().dimSymbolMap) {
374 std::cout << "(" << x.first << ": " << x.second << "), ";
375 count++;
376 if (count % 10 == 0) {
377 std::cout << '\n';
378 }
379 }
380 }
381
382 } // namespace torch::jit
383