xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/constant_map.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 
5 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
6 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wnewline-eof")
7 #include <onnx/shape_inference/implementation.h>
C10_DIAGNOSTIC_POP()8 C10_DIAGNOSTIC_POP()
9 C10_DIAGNOSTIC_POP()
10 
11 #include <torch/csrc/jit/ir/ir.h>
12 #include <torch/csrc/jit/serialization/export.h>
13 #include <mutex>
14 #include <unordered_map>
15 
16 namespace torch::jit {
17 
18 using ShapeDataMap =
19     std::unordered_map<std::string, ::ONNX_NAMESPACE::TensorShapeProto>;
20 
21 class ConstantValueMap {
22  public:
23   static ConstantValueMap& getInstance();
24   static void SetRank(const std::string& tensorName, size_t rankValue);
25   static bool HasRank(const std::string& tensorName);
26   static std::optional<size_t> GetRank(const std::string& tensorName);
27 
28   static void SetAllGraphInputsStatic(bool all_static);
29   static std::optional<bool> GetAllGraphInputsStatic();
30 
31   static void SetAllGraphInputsReliableComputed(bool computed);
32   static bool GetAllGraphInputsReliableComputed();
33 
34   static void SetShape(
35       const std::string& tensorName,
36       const c10::SymbolicShape& shapeValue);
37   static bool HasShape(const std::string& tensorName);
38   static std::optional<c10::SymbolicShape> GetShape(
39       const std::string& tensorName);
40 
41   static void SetValue(const std::string& tensorName, const at::Tensor& value);
42   static bool HasValue(const std::string& tensorName);
43   static std::optional<at::Tensor> GetValue(const std::string& tensorName);
44   static void EraseValue(const std::string& tensorName);
45 
46   static std::vector<int64_t> GetCompleteShapeInto1DInt64Vector(
47       const c10::SymbolicShape& shape);
48   static std::optional<std::vector<int64_t>> GetShapeInto1DInt64Vector(
49       const std::string& value_name);
50   static std::optional<std::vector<int64_t>>
51   GetShapeInto1DInt64VectorWithOneUnknown(const std::string& value_name);
52   static std::vector<int64_t> GetValueInto1DInt64Vector(
53       const std::string& value_name);
54 
55   static void SetTypeReliable(const std::string& tensorName, bool reliable);
56   static bool HasTypeReliable(const std::string& tensorName);
57   static std::optional<bool> GetTypeReliable(const std::string& tensorName);
58 
59   static void SetUseInferredType(
60       const std::string& tensorName,
61       bool useInferredType);
62   static bool HasUseInferredType(const std::string& tensorName);
63   static std::optional<bool> GetUseInferredType(const std::string& tensorName);
64 
65   static void SetShapeValue(
66       const std::string& tensorName,
67       const c10::SymbolicShape& shapeValue);
68   static bool HasShapeValue(const std::string& tensorName);
69   static std::optional<c10::SymbolicShape> GetShapeValue(
70       const std::string& tensorName);
71 
72   static ShapeDataMap& GetInferredShapeData();
73 
74   static SymbolDimMap& GetSymbolDimMap();
75   static DimSymbolMap& GetDimSymbolMap();
76 
77   static void UpdateValueName(
78       const std::string& old_name,
79       const std::string& new_name);
80 
81   static void PrintMaps();
82   static void ClearMaps();
83   ~ConstantValueMap() = default;
84 
85   ConstantValueMap& operator=(const ConstantValueMap&) = delete;
86 
87  private:
88   ConstantValueMap() = default;
89 
90   std::unordered_map<std::string, size_t> rankMap;
91   std::unordered_map<std::string, c10::SymbolicShape> shapeMap;
92   std::unordered_map<std::string, at::Tensor> tensorValueMap;
93   // This map indicates whether the current type is reliably estimated or not.
94   std::unordered_map<std::string, bool> typeReliableMap;
95   // This map indicates whether the current type is estimated through inference
96   // or tracer.
97   std::unordered_map<std::string, bool> useInferredTypeMap;
98   // This map indicates a tensor value which represents a shape.
99   // We assume that the rank of the tensor value <= 1, and we ensure this when
100   // we write the processing logic for the operators. When the rank > 1, we
101   // should be able to rewrite the model so that the rank <= 1. The difference
102   // between shapeMap and shapeValueMap: shapeMap stores the shape of the tensor
103   // from a node. shapeValueMap stores the value of the tensor from a node when
104   // this tensor represents a shape.
105   std::unordered_map<std::string, c10::SymbolicShape> shapeValueMap;
106   // Stores earlier data propagation results so that they are accessible
107   // during future node-level shape inference.
108   ShapeDataMap inferredShapeData;
109   SymbolDimMap symbolDimMap;
110   DimSymbolMap dimSymbolMap;
111   // Stores if all graph-level inputs have static shape
112   std::optional<bool> allGraphInputsStatic;
113   // True if reliable has been computed for all graph inputs
114   bool allGraphInputsReliableComputed{};
115 };
116 
117 } // namespace torch::jit
118