1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <string>
4 #include <vector>
5
6 #include <ATen/core/jit_type_base.h>
7 #include <ATen/core/symbol.h>
8
9 #include <torch/csrc/Export.h>
10
11 namespace torch::jit {
12
13 using ::c10::Symbol;
14
// NOTE(review): not referenced anywhere in this header; presumably a cap on
// how many tensor elements are shown when an attribute is printed — confirm
// at the use sites.
constexpr int max_tensor_display_size{10};
16
// Tag identifying what kind of payload an attribute value carries.
//
// Single-letter kinds hold one value; the matching "...s" kind holds a
// std::vector of the same element type. toString() indexes its name table
// with the enumerator's integer value, so the declaration order below must be
// kept in step with that table.
enum class AttributeKind {
  f, // double
  fs, // std::vector<double>
  c, // c10::complex<double>
  cs, // std::vector<c10::complex<double>>
  i, // int64_t
  is, // std::vector<int64_t>
  s, // std::string
  ss, // std::vector<std::string>
  t, // at::Tensor
  ts, // std::vector<at::Tensor>
  g, // std::shared_ptr<Graph>
  gs, // std::vector<std::shared_ptr<Graph>>
  ty, // c10::TypePtr
  tys, // std::vector<c10::TypePtr>
  ival // at::IValue
};
toString(AttributeKind kind)34 static inline const char* toString(AttributeKind kind) {
35 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
36 static const char* names[] = {
37 "f",
38 "c",
39 "cs",
40 "fs",
41 "i",
42 "is",
43 "s",
44 "ss",
45 "t",
46 "ts",
47 "g",
48 "gs",
49 "ty",
50 "tys",
51 "ival"};
52 AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(*names));
53 return names[int(kind)];
54 }
55
56 struct AttributeValue {
AttributeValueAttributeValue57 AttributeValue(Symbol name) : name(name) {}
58 using Ptr = std::unique_ptr<AttributeValue>;
59 Symbol name;
60 virtual AttributeKind kind() const = 0;
61 virtual Ptr clone() const = 0;
62 virtual ~AttributeValue() = default;
63 };
64
65 template <typename T, AttributeKind Kind>
66 struct ScalarAttributeValue : public AttributeValue {
67 using ConstructorType = T;
68 using ValueType = T;
ScalarAttributeValueScalarAttributeValue69 ScalarAttributeValue(Symbol name, ConstructorType value_)
70 : AttributeValue(name), value_(std::move(value_)) {}
valueScalarAttributeValue71 ValueType& value() {
72 return value_;
73 }
cloneScalarAttributeValue74 Ptr clone() const override {
75 return Ptr(new ScalarAttributeValue(name, value_));
76 }
kindScalarAttributeValue77 AttributeKind kind() const override {
78 return Kind;
79 }
80
81 private:
82 ValueType value_;
83 };
84
85 template <typename T, AttributeKind Kind>
86 struct VectorAttributeValue : public AttributeValue {
87 using ConstructorType = std::vector<T>;
88 using ValueType = std::vector<T>;
89 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
VectorAttributeValueVectorAttributeValue90 VectorAttributeValue(Symbol name, ConstructorType value_)
91 : AttributeValue(name), value_(std::move(value_)) {}
valueVectorAttributeValue92 ValueType& value() {
93 return value_;
94 }
kindVectorAttributeValue95 AttributeKind kind() const override {
96 return Kind;
97 }
cloneVectorAttributeValue98 std::unique_ptr<AttributeValue> clone() const override {
99 auto copy = value_;
100 return Ptr(new VectorAttributeValue(name, std::move(copy)));
101 }
102
103 private:
104 ValueType value_;
105 };
106
107 using ComplexAttr =
108 ScalarAttributeValue<c10::complex<double>, AttributeKind::c>;
109 using ComplexValsAttr =
110 VectorAttributeValue<c10::complex<double>, AttributeKind::cs>;
111 using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
112 using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
113 using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
114 using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
115 using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
116 using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
117 using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
118 using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
119 using TypeAttr = ScalarAttributeValue<c10::TypePtr, AttributeKind::ty>;
120 using TypesAttr = VectorAttributeValue<c10::TypePtr, AttributeKind::tys>;
121 using IValueAttr = ScalarAttributeValue<at::IValue, AttributeKind::ival>;
122
123 struct Graph;
124
125 // We special case Graph attributes like this because we want to ensure that
126 // Graph::copy() is called when we clone() these attributes.
127 struct TORCH_API GraphAttr : public AttributeValue {
128 using ConstructorType = std::shared_ptr<Graph>;
129 using ValueType = std::shared_ptr<Graph>;
GraphAttrGraphAttr130 GraphAttr(Symbol name, ConstructorType value_)
131 : AttributeValue(name), value_(std::move(value_)) {}
valueGraphAttr132 ValueType& value() {
133 return value_;
134 }
135 Ptr clone() const override;
kindGraphAttr136 AttributeKind kind() const override {
137 return AttributeKind::g;
138 }
139
140 private:
141 std::shared_ptr<Graph> value_;
142 };
143
144 struct TORCH_API GraphsAttr : public AttributeValue {
145 using ConstructorType = std::vector<std::shared_ptr<Graph>>;
146 using ValueType = std::vector<std::shared_ptr<Graph>>;
147 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GraphsAttrGraphsAttr148 GraphsAttr(Symbol name, ConstructorType value_)
149 : AttributeValue(name), value_(std::move(value_)) {}
valueGraphsAttr150 ValueType& value() {
151 return value_;
152 }
kindGraphsAttr153 AttributeKind kind() const override {
154 return AttributeKind::gs;
155 }
156 std::unique_ptr<AttributeValue> clone() const override;
157
158 private:
159 ValueType value_;
160 };
161
162 struct IRAttributeError : public std::exception {
IRAttributeErrorIRAttributeError163 IRAttributeError(Symbol name, bool defined) {
164 std::stringstream ss;
165 // NOLINTNEXTLINE(bugprone-branch-clone)
166 if (!defined) {
167 ss << "required keyword attribute '" << name.toUnqualString()
168 << "' is undefined";
169 } else {
170 ss << "required keyword attribute '" << name.toUnqualString()
171 << "' has the wrong type";
172 }
173 msg = ss.str();
174 }
whatIRAttributeError175 const char* what() const noexcept override {
176 return msg.c_str();
177 }
178
179 private:
180 std::string msg;
181 };
182 } // namespace torch::jit
183