1 #pragma once
2
3 #include <c10/macros/Macros.h>
4
5 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
6 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wnewline-eof")
7 #include <onnx/shape_inference/implementation.h>
8 C10_DIAGNOSTIC_POP()
9 C10_DIAGNOSTIC_POP()
10
11 #include <torch/csrc/jit/ir/ir.h>
12 #include <torch/csrc/jit/serialization/export.h>
13 #include <mutex>
14 #include <unordered_map>
15
16 namespace torch::jit {
17
18 using ShapeDataMap =
19 std::unordered_map<std::string, ::ONNX_NAMESPACE::TensorShapeProto>;
20
21 class ConstantValueMap {
22 public:
23 static ConstantValueMap& getInstance();
24 static void SetRank(const std::string& tensorName, size_t rankValue);
25 static bool HasRank(const std::string& tensorName);
26 static std::optional<size_t> GetRank(const std::string& tensorName);
27
28 static void SetAllGraphInputsStatic(bool all_static);
29 static std::optional<bool> GetAllGraphInputsStatic();
30
31 static void SetAllGraphInputsReliableComputed(bool computed);
32 static bool GetAllGraphInputsReliableComputed();
33
34 static void SetShape(
35 const std::string& tensorName,
36 const c10::SymbolicShape& shapeValue);
37 static bool HasShape(const std::string& tensorName);
38 static std::optional<c10::SymbolicShape> GetShape(
39 const std::string& tensorName);
40
41 static void SetValue(const std::string& tensorName, const at::Tensor& value);
42 static bool HasValue(const std::string& tensorName);
43 static std::optional<at::Tensor> GetValue(const std::string& tensorName);
44 static void EraseValue(const std::string& tensorName);
45
46 static std::vector<int64_t> GetCompleteShapeInto1DInt64Vector(
47 const c10::SymbolicShape& shape);
48 static std::optional<std::vector<int64_t>> GetShapeInto1DInt64Vector(
49 const std::string& value_name);
50 static std::optional<std::vector<int64_t>>
51 GetShapeInto1DInt64VectorWithOneUnknown(const std::string& value_name);
52 static std::vector<int64_t> GetValueInto1DInt64Vector(
53 const std::string& value_name);
54
55 static void SetTypeReliable(const std::string& tensorName, bool reliable);
56 static bool HasTypeReliable(const std::string& tensorName);
57 static std::optional<bool> GetTypeReliable(const std::string& tensorName);
58
59 static void SetUseInferredType(
60 const std::string& tensorName,
61 bool useInferredType);
62 static bool HasUseInferredType(const std::string& tensorName);
63 static std::optional<bool> GetUseInferredType(const std::string& tensorName);
64
65 static void SetShapeValue(
66 const std::string& tensorName,
67 const c10::SymbolicShape& shapeValue);
68 static bool HasShapeValue(const std::string& tensorName);
69 static std::optional<c10::SymbolicShape> GetShapeValue(
70 const std::string& tensorName);
71
72 static ShapeDataMap& GetInferredShapeData();
73
74 static SymbolDimMap& GetSymbolDimMap();
75 static DimSymbolMap& GetDimSymbolMap();
76
77 static void UpdateValueName(
78 const std::string& old_name,
79 const std::string& new_name);
80
81 static void PrintMaps();
82 static void ClearMaps();
83 ~ConstantValueMap() = default;
84
85 ConstantValueMap& operator=(const ConstantValueMap&) = delete;
86
87 private:
88 ConstantValueMap() = default;
89
90 std::unordered_map<std::string, size_t> rankMap;
91 std::unordered_map<std::string, c10::SymbolicShape> shapeMap;
92 std::unordered_map<std::string, at::Tensor> tensorValueMap;
93 // This map indicates whether the current type is reliably estimated or not.
94 std::unordered_map<std::string, bool> typeReliableMap;
95 // This map indicates whether the current type is estimated through inference
96 // or tracer.
97 std::unordered_map<std::string, bool> useInferredTypeMap;
98 // This map indicates a tensor value which represents a shape.
99 // We assume that the rank of the tensor value <= 1, and we ensure this when
100 // we write the processing logic for the operators. When the rank > 1, we
101 // should be able to rewrite the model so that the rank <= 1. The difference
102 // between shapeMap and shapeValueMap: shapeMap stores the shape of the tensor
103 // from a node. shapeValueMap stores the value of the tensor from a node when
104 // this tensor represents a shape.
105 std::unordered_map<std::string, c10::SymbolicShape> shapeValueMap;
106 // Stores earlier data propagation results so that they are accessible
107 // during future node-level shape inference.
108 ShapeDataMap inferredShapeData;
109 SymbolDimMap symbolDimMap;
110 DimSymbolMap dimSymbolMap;
111 // Stores if all graph-level inputs have static shape
112 std::optional<bool> allGraphInputsStatic;
113 // True if reliable has been computed for all graph inputs
114 bool allGraphInputsReliableComputed{};
115 };
116
117 } // namespace torch::jit
118