xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/constant_map.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/constant_map.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
9 
10 namespace torch::jit {
11 
12 // Meyer’s Singleton for C++ 14
getInstance()13 ConstantValueMap& ConstantValueMap::getInstance() {
14   static ConstantValueMap s;
15   return s;
16 }
17 
SetRank(const std::string & tensorName,size_t rankValue)18 void ConstantValueMap::SetRank(
19     const std::string& tensorName,
20     size_t rankValue) {
21   ConstantValueMap::getInstance().rankMap[tensorName] = rankValue;
22   ConstantValueMap::getInstance().useInferredTypeMap[tensorName] = true;
23 }
24 
HasRank(const std::string & tensorName)25 bool ConstantValueMap::HasRank(const std::string& tensorName) {
26   return ConstantValueMap::getInstance().rankMap.find(tensorName) !=
27       ConstantValueMap::getInstance().rankMap.end();
28 }
29 
GetRank(const std::string & tensorName)30 std::optional<size_t> ConstantValueMap::GetRank(const std::string& tensorName) {
31   if (!HasRank(tensorName)) {
32     return std::nullopt;
33   }
34   return ConstantValueMap::getInstance().rankMap[tensorName];
35 }
36 
SetAllGraphInputsStatic(bool all_static)37 void ConstantValueMap::SetAllGraphInputsStatic(bool all_static) {
38   ConstantValueMap::getInstance().allGraphInputsStatic =
39       std::make_optional(all_static);
40 }
41 
GetAllGraphInputsStatic()42 std::optional<bool> ConstantValueMap::GetAllGraphInputsStatic() {
43   return ConstantValueMap::getInstance().allGraphInputsStatic;
44 }
45 
SetAllGraphInputsReliableComputed(bool computed)46 void ConstantValueMap::SetAllGraphInputsReliableComputed(bool computed) {
47   ConstantValueMap::getInstance().allGraphInputsReliableComputed = computed;
48 }
49 
GetAllGraphInputsReliableComputed()50 bool ConstantValueMap::GetAllGraphInputsReliableComputed() {
51   return ConstantValueMap::getInstance().allGraphInputsReliableComputed;
52 }
53 
SetShape(const std::string & tensorName,const c10::SymbolicShape & shapeValue)54 void ConstantValueMap::SetShape(
55     const std::string& tensorName,
56     const c10::SymbolicShape& shapeValue) {
57   ConstantValueMap::getInstance().shapeMap[tensorName] = shapeValue;
58   ConstantValueMap::getInstance().useInferredTypeMap[tensorName] = true;
59 }
60 
HasShape(const std::string & tensorName)61 bool ConstantValueMap::HasShape(const std::string& tensorName) {
62   return ConstantValueMap::getInstance().shapeMap.find(tensorName) !=
63       ConstantValueMap::getInstance().shapeMap.end();
64 }
65 
GetShape(const std::string & tensorName)66 std::optional<c10::SymbolicShape> ConstantValueMap::GetShape(
67     const std::string& tensorName) {
68   if (!HasShape(tensorName)) {
69     return std::nullopt;
70   }
71   return ConstantValueMap::getInstance().shapeMap[tensorName];
72 }
73 
SetValue(const std::string & tensorName,const at::Tensor & value)74 void ConstantValueMap::SetValue(
75     const std::string& tensorName,
76     const at::Tensor& value) {
77   ConstantValueMap::getInstance().tensorValueMap[tensorName] = value;
78 }
79 
HasValue(const std::string & tensorName)80 bool ConstantValueMap::HasValue(const std::string& tensorName) {
81   return ConstantValueMap::getInstance().tensorValueMap.find(tensorName) !=
82       ConstantValueMap::getInstance().tensorValueMap.end();
83 }
84 
GetValue(const std::string & tensorName)85 std::optional<at::Tensor> ConstantValueMap::GetValue(
86     const std::string& tensorName) {
87   if (!HasValue(tensorName)) {
88     return std::nullopt;
89   }
90   return ConstantValueMap::getInstance().tensorValueMap[tensorName];
91 }
92 
EraseValue(const std::string & tensorName)93 void ConstantValueMap::EraseValue(const std::string& tensorName) {
94   ConstantValueMap::getInstance().tensorValueMap.erase(tensorName);
95 }
96 
GetCompleteShapeInto1DInt64Vector(const c10::SymbolicShape & shape)97 std::vector<int64_t> ConstantValueMap::GetCompleteShapeInto1DInt64Vector(
98     const c10::SymbolicShape& shape) {
99   TORCH_INTERNAL_ASSERT(shape.isComplete());
100   std::vector<int64_t> shape_value;
101   auto shape_symbol_list = shape.sizes().value();
102   shape_value.reserve(shape_symbol_list.size());
103   for (const auto& v : shape_symbol_list) {
104     shape_value.emplace_back(v.static_size());
105   }
106   return shape_value;
107 }
108 
GetShapeInto1DInt64Vector(const std::string & value_name)109 std::optional<std::vector<int64_t>> ConstantValueMap::GetShapeInto1DInt64Vector(
110     const std::string& value_name) {
111   if (ConstantValueMap::HasShape(value_name)) {
112     auto shape_size = ConstantValueMap::GetShape(value_name).value();
113     if (shape_size.isComplete()) {
114       auto shape_value =
115           ConstantValueMap::GetCompleteShapeInto1DInt64Vector(shape_size);
116       return shape_value;
117     }
118   }
119   return std::nullopt;
120 }
121 
122 std::optional<std::vector<int64_t>> ConstantValueMap::
GetShapeInto1DInt64VectorWithOneUnknown(const std::string & value_name)123     GetShapeInto1DInt64VectorWithOneUnknown(const std::string& value_name) {
124   if (ConstantValueMap::HasShape(value_name)) {
125     auto shape_size = ConstantValueMap::GetShape(value_name).value();
126     std::vector<int64_t> shape_value;
127     if (shape_size.isComplete()) {
128       shape_value =
129           ConstantValueMap::GetCompleteShapeInto1DInt64Vector(shape_size);
130       return shape_value;
131     } else {
132       size_t count_unknown = 0;
133       auto shape_size_sizes = shape_size.sizes();
134       if (shape_size_sizes.has_value()) {
135         auto shape_symbol_list = shape_size_sizes.value();
136         for (const auto& v : shape_symbol_list) {
137           if (v.is_static()) {
138             shape_value.emplace_back(v.static_size());
139           } else {
140             shape_value.emplace_back(-1);
141             count_unknown += 1;
142           }
143         }
144         if (count_unknown == 1) {
145           return shape_value;
146         }
147       }
148     }
149   }
150   return std::nullopt;
151 }
152 
153 // accessor<int64_t, 1> for 1DInt64 case.
GetValueInto1DInt64Vector(const std::string & value_name)154 std::vector<int64_t> ConstantValueMap::GetValueInto1DInt64Vector(
155     const std::string& value_name) {
156   auto value = ConstantValueMap::GetValue(value_name).value();
157   auto value_int64_t = value.toType(at::ScalarType::Long);
158   std::vector<int64_t> value_vector;
159   value_vector.reserve(value_int64_t.size(0));
160   auto value_size_a = value_int64_t.accessor<int64_t, 1>();
161   for (const auto i : c10::irange(value_int64_t.size(0))) {
162     value_vector.emplace_back(static_cast<int64_t>(value_size_a[i]));
163   }
164   return value_vector;
165 }
166 
SetTypeReliable(const std::string & tensorName,bool value)167 void ConstantValueMap::SetTypeReliable(
168     const std::string& tensorName,
169     bool value) {
170   ConstantValueMap::getInstance().typeReliableMap[tensorName] = value;
171 }
172 
HasTypeReliable(const std::string & tensorName)173 bool ConstantValueMap::HasTypeReliable(const std::string& tensorName) {
174   return ConstantValueMap::getInstance().typeReliableMap.find(tensorName) !=
175       ConstantValueMap::getInstance().typeReliableMap.end();
176 }
177 
GetTypeReliable(const std::string & tensorName)178 std::optional<bool> ConstantValueMap::GetTypeReliable(
179     const std::string& tensorName) {
180   if (!HasTypeReliable(tensorName)) {
181     return std::nullopt;
182   }
183   return ConstantValueMap::getInstance().typeReliableMap[tensorName];
184 }
185 
SetUseInferredType(const std::string & tensorName,bool value)186 void ConstantValueMap::SetUseInferredType(
187     const std::string& tensorName,
188     bool value) {
189   ConstantValueMap::getInstance().useInferredTypeMap[tensorName] = value;
190 }
191 
HasUseInferredType(const std::string & tensorName)192 bool ConstantValueMap::HasUseInferredType(const std::string& tensorName) {
193   return ConstantValueMap::getInstance().useInferredTypeMap.find(tensorName) !=
194       ConstantValueMap::getInstance().useInferredTypeMap.end();
195 }
196 
GetUseInferredType(const std::string & tensorName)197 std::optional<bool> ConstantValueMap::GetUseInferredType(
198     const std::string& tensorName) {
199   if (!HasUseInferredType(tensorName)) {
200     return std::nullopt;
201   }
202   return ConstantValueMap::getInstance().useInferredTypeMap[tensorName];
203 }
204 
SetShapeValue(const std::string & tensorName,const c10::SymbolicShape & shapeValue)205 void ConstantValueMap::SetShapeValue(
206     const std::string& tensorName,
207     const c10::SymbolicShape& shapeValue) {
208   ConstantValueMap::getInstance().shapeValueMap[tensorName] = shapeValue;
209 }
210 
HasShapeValue(const std::string & tensorName)211 bool ConstantValueMap::HasShapeValue(const std::string& tensorName) {
212   return ConstantValueMap::getInstance().shapeValueMap.find(tensorName) !=
213       ConstantValueMap::getInstance().shapeValueMap.end();
214 }
215 
GetShapeValue(const std::string & tensorName)216 std::optional<c10::SymbolicShape> ConstantValueMap::GetShapeValue(
217     const std::string& tensorName) {
218   if (!HasShapeValue(tensorName)) {
219     return std::nullopt;
220   }
221   return ConstantValueMap::getInstance().shapeValueMap[tensorName];
222 }
223 
224 // Gets the inferredShapeData which is obtained by ONNX data propagation
GetInferredShapeData()225 ShapeDataMap& ConstantValueMap::GetInferredShapeData() {
226   return ConstantValueMap::getInstance().inferredShapeData;
227 }
228 
GetSymbolDimMap()229 SymbolDimMap& ConstantValueMap::GetSymbolDimMap() {
230   return ConstantValueMap::getInstance().symbolDimMap;
231 }
232 
GetDimSymbolMap()233 DimSymbolMap& ConstantValueMap::GetDimSymbolMap() {
234   return ConstantValueMap::getInstance().dimSymbolMap;
235 }
236 
237 template <typename Map>
UpdateStrKey(Map & map,const std::string & old_key,const std::string & new_key)238 void UpdateStrKey(
239     Map& map,
240     const std::string& old_key,
241     const std::string& new_key) {
242   TORCH_INTERNAL_ASSERT(old_key != new_key);
243   if (map.find(old_key) == map.end()) {
244     return;
245   }
246   map[new_key] = map[old_key];
247   map.erase(old_key);
248 }
249 
UpdateValueName(const std::string & old_name,const std::string & new_name)250 void ConstantValueMap::UpdateValueName(
251     const std::string& old_name,
252     const std::string& new_name) {
253   if (old_name == new_name) {
254     return;
255   }
256   UpdateStrKey<decltype(rankMap)>(
257       ConstantValueMap::getInstance().rankMap, old_name, new_name);
258   UpdateStrKey<decltype(shapeMap)>(
259       ConstantValueMap::getInstance().shapeMap, old_name, new_name);
260   UpdateStrKey<decltype(tensorValueMap)>(
261       ConstantValueMap::getInstance().tensorValueMap, old_name, new_name);
262   UpdateStrKey<decltype(typeReliableMap)>(
263       ConstantValueMap::getInstance().typeReliableMap, old_name, new_name);
264   UpdateStrKey<decltype(useInferredTypeMap)>(
265       ConstantValueMap::getInstance().useInferredTypeMap, old_name, new_name);
266   UpdateStrKey<decltype(shapeValueMap)>(
267       ConstantValueMap::getInstance().shapeValueMap, old_name, new_name);
268   UpdateStrKey<decltype(inferredShapeData)>(
269       ConstantValueMap::getInstance().inferredShapeData, old_name, new_name);
270 }
271 
ClearMaps()272 void ConstantValueMap::ClearMaps() {
273   ConstantValueMap::getInstance().rankMap.clear();
274   ConstantValueMap::getInstance().shapeMap.clear();
275   ConstantValueMap::getInstance().tensorValueMap.clear();
276   ConstantValueMap::getInstance().typeReliableMap.clear();
277   ConstantValueMap::getInstance().useInferredTypeMap.clear();
278   ConstantValueMap::getInstance().shapeValueMap.clear();
279   ConstantValueMap::getInstance().inferredShapeData.clear();
280   ConstantValueMap::getInstance().symbolDimMap.clear();
281   ConstantValueMap::getInstance().dimSymbolMap.clear();
282   ConstantValueMap::getInstance().allGraphInputsStatic = std::nullopt;
283   ConstantValueMap::getInstance().allGraphInputsReliableComputed = false;
284 }
285 
286 // For debug only.
PrintMaps()287 void ConstantValueMap::PrintMaps() {
288   std::cout << "Rank/Shape Map:" << '\n';
289   for (const auto& x : ConstantValueMap::getInstance().rankMap) {
290     std::stringstream ss;
291     if (ConstantValueMap::getInstance().shapeMap.find(x.first) !=
292         ConstantValueMap::getInstance().shapeMap.end()) {
293       auto shape_symbols =
294           ConstantValueMap::getInstance().shapeMap[x.first].sizes();
295       if (shape_symbols.has_value()) {
296         for (const auto& shape_symbol : shape_symbols.value()) {
297           if (shape_symbol.is_static()) {
298             ss << shape_symbol.static_size() << ", ";
299           } else {
300             ss << "*, ";
301           }
302         }
303       }
304     }
305     ss << " (rank = " << x.second << ")";
306     std::cout << "node " << x.first << ": " << ss.str() << '\n';
307   }
308   std::cout << '\n';
309   std::cout << "Value Map:" << '\n';
310   for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) {
311     std::cout << "node " << x.first << ": " << x.second << '\n';
312   }
313   std::cout << '\n';
314   std::cout << "TypeReliable Map:" << '\n';
315   size_t count = 0;
316   for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) {
317     std::cout << "(node " << x.first << ": " << x.second << "), ";
318     count++;
319     if (count % 10 == 0) {
320       std::cout << '\n';
321     }
322   }
323   std::cout << '\n';
324   std::cout << "UseInferredType Map:" << '\n';
325   count = 0;
326   for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) {
327     std::cout << "(node " << x.first << ": " << x.second << "), ";
328     count++;
329     if (count % 10 == 0) {
330       std::cout << '\n';
331     }
332   }
333   std::cout << '\n';
334   std::cout << "ShapeValue Map:" << '\n';
335   count = 0;
336   for (const auto& x : ConstantValueMap::getInstance().shapeValueMap) {
337     std::cout << "(node " << x.first << ": " << x.second << "), ";
338     count++;
339     if (count % 10 == 0) {
340       std::cout << '\n';
341     }
342   }
343   std::cout << '\n';
344   std::cout << "InferredShape Map:" << '\n';
345   count = 0;
346   for (const auto& x : ConstantValueMap::getInstance().inferredShapeData) {
347     std::cout << "(node " << x.first << ": ";
348     for (const auto& dim : x.second.dim()) {
349       if (dim.has_dim_param()) {
350         std::cout << dim.dim_param() << " ";
351       } else {
352         std::cout << dim.dim_value() << " ";
353       }
354     }
355     std::cout << "), ";
356     count++;
357     if (count % 10 == 0) {
358       std::cout << '\n';
359     }
360   }
361   std::cout << '\n';
362   std::cout << "SymbolDim Map:" << '\n';
363   count = 0;
364   for (const auto& x : ConstantValueMap::getInstance().symbolDimMap) {
365     std::cout << "(" << x.first << ": " << x.second << "), ";
366     count++;
367     if (count % 10 == 0) {
368       std::cout << '\n';
369     }
370   }
371   std::cout << "DimSymbol Map:" << '\n';
372   count = 0;
373   for (const auto& x : ConstantValueMap::getInstance().dimSymbolMap) {
374     std::cout << "(" << x.first << ": " << x.second << "), ";
375     count++;
376     if (count % 10 == 0) {
377       std::cout << '\n';
378     }
379   }
380 }
381 
382 } // namespace torch::jit
383