xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/attributes.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <string>
4 #include <vector>
5 
6 #include <ATen/core/jit_type_base.h>
7 #include <ATen/core/symbol.h>
8 
9 #include <torch/csrc/Export.h>
10 
11 namespace torch::jit {
12 
13 using ::c10::Symbol;
14 
15 constexpr int max_tensor_display_size = 10;
16 
/// Tag identifying which payload type an AttributeValue carries.
///
/// A trailing 's' marks the std::vector form of the corresponding single-value
/// kind. No explicit values are assigned, so the enumerators number 0..14 in
/// declaration order; toString() uses that number as an index into its name
/// table, so the two must be kept in the same order.
enum class AttributeKind {
  f, // double
  fs, // std::vector<double>
  c, // c10::complex<double>
  cs, // std::vector<c10::complex<double>>
  i, // int64_t
  is, // std::vector<int64_t>
  s, // std::string
  ss, // std::vector<std::string>
  t, // at::Tensor
  ts, // std::vector<at::Tensor>
  g, // std::shared_ptr<Graph>
  gs, // std::vector<std::shared_ptr<Graph>>
  ty, // c10::TypePtr
  tys, // std::vector<c10::TypePtr>
  ival // at::IValue
};
toString(AttributeKind kind)34 static inline const char* toString(AttributeKind kind) {
35   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
36   static const char* names[] = {
37       "f",
38       "c",
39       "cs",
40       "fs",
41       "i",
42       "is",
43       "s",
44       "ss",
45       "t",
46       "ts",
47       "g",
48       "gs",
49       "ty",
50       "tys",
51       "ival"};
52   AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(*names));
53   return names[int(kind)];
54 }
55 
56 struct AttributeValue {
AttributeValueAttributeValue57   AttributeValue(Symbol name) : name(name) {}
58   using Ptr = std::unique_ptr<AttributeValue>;
59   Symbol name;
60   virtual AttributeKind kind() const = 0;
61   virtual Ptr clone() const = 0;
62   virtual ~AttributeValue() = default;
63 };
64 
65 template <typename T, AttributeKind Kind>
66 struct ScalarAttributeValue : public AttributeValue {
67   using ConstructorType = T;
68   using ValueType = T;
ScalarAttributeValueScalarAttributeValue69   ScalarAttributeValue(Symbol name, ConstructorType value_)
70       : AttributeValue(name), value_(std::move(value_)) {}
valueScalarAttributeValue71   ValueType& value() {
72     return value_;
73   }
cloneScalarAttributeValue74   Ptr clone() const override {
75     return Ptr(new ScalarAttributeValue(name, value_));
76   }
kindScalarAttributeValue77   AttributeKind kind() const override {
78     return Kind;
79   }
80 
81  private:
82   ValueType value_;
83 };
84 
85 template <typename T, AttributeKind Kind>
86 struct VectorAttributeValue : public AttributeValue {
87   using ConstructorType = std::vector<T>;
88   using ValueType = std::vector<T>;
89   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
VectorAttributeValueVectorAttributeValue90   VectorAttributeValue(Symbol name, ConstructorType value_)
91       : AttributeValue(name), value_(std::move(value_)) {}
valueVectorAttributeValue92   ValueType& value() {
93     return value_;
94   }
kindVectorAttributeValue95   AttributeKind kind() const override {
96     return Kind;
97   }
cloneVectorAttributeValue98   std::unique_ptr<AttributeValue> clone() const override {
99     auto copy = value_;
100     return Ptr(new VectorAttributeValue(name, std::move(copy)));
101   }
102 
103  private:
104   ValueType value_;
105 };
106 
107 using ComplexAttr =
108     ScalarAttributeValue<c10::complex<double>, AttributeKind::c>;
109 using ComplexValsAttr =
110     VectorAttributeValue<c10::complex<double>, AttributeKind::cs>;
111 using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
112 using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
113 using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
114 using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
115 using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
116 using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
117 using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
118 using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
119 using TypeAttr = ScalarAttributeValue<c10::TypePtr, AttributeKind::ty>;
120 using TypesAttr = VectorAttributeValue<c10::TypePtr, AttributeKind::tys>;
121 using IValueAttr = ScalarAttributeValue<at::IValue, AttributeKind::ival>;
122 
123 struct Graph;
124 
125 // We special case Graph attributes like this because we want to ensure that
126 // Graph::copy() is called when we clone() these attributes.
127 struct TORCH_API GraphAttr : public AttributeValue {
128   using ConstructorType = std::shared_ptr<Graph>;
129   using ValueType = std::shared_ptr<Graph>;
GraphAttrGraphAttr130   GraphAttr(Symbol name, ConstructorType value_)
131       : AttributeValue(name), value_(std::move(value_)) {}
valueGraphAttr132   ValueType& value() {
133     return value_;
134   }
135   Ptr clone() const override;
kindGraphAttr136   AttributeKind kind() const override {
137     return AttributeKind::g;
138   }
139 
140  private:
141   std::shared_ptr<Graph> value_;
142 };
143 
144 struct TORCH_API GraphsAttr : public AttributeValue {
145   using ConstructorType = std::vector<std::shared_ptr<Graph>>;
146   using ValueType = std::vector<std::shared_ptr<Graph>>;
147   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GraphsAttrGraphsAttr148   GraphsAttr(Symbol name, ConstructorType value_)
149       : AttributeValue(name), value_(std::move(value_)) {}
valueGraphsAttr150   ValueType& value() {
151     return value_;
152   }
kindGraphsAttr153   AttributeKind kind() const override {
154     return AttributeKind::gs;
155   }
156   std::unique_ptr<AttributeValue> clone() const override;
157 
158  private:
159   ValueType value_;
160 };
161 
162 struct IRAttributeError : public std::exception {
IRAttributeErrorIRAttributeError163   IRAttributeError(Symbol name, bool defined) {
164     std::stringstream ss;
165     // NOLINTNEXTLINE(bugprone-branch-clone)
166     if (!defined) {
167       ss << "required keyword attribute '" << name.toUnqualString()
168          << "' is undefined";
169     } else {
170       ss << "required keyword attribute '" << name.toUnqualString()
171          << "' has the wrong type";
172     }
173     msg = ss.str();
174   }
whatIRAttributeError175   const char* what() const noexcept override {
176     return msg.c_str();
177   }
178 
179  private:
180   std::string msg;
181 };
182 } // namespace torch::jit
183