1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12
13 #include <executorch/backends/vulkan/runtime/api/api.h>
14
15 #include <executorch/backends/vulkan/runtime/graph/containers/Constant.h>
16 #include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
17 #include <executorch/backends/vulkan/runtime/graph/containers/Types.h>
18
19 namespace vkcompute {
20
// Signed 32-bit handle used to refer to a value stored in the graph.
using ValueRef = int32_t;

// Sentinel handle that refers to no value; it is negative, so is_valid()
// rejects it.
constexpr ValueRef kDummyValueRef{-1};
24
is_valid(ValueRef value_ref)25 inline bool is_valid(ValueRef value_ref) {
26 return value_ref >= 0;
27 }
28
29 struct IOValueRef {
30 ValueRef value;
31 ValueRef staging;
32
33 // Custom cast to ValueRef
ValueRefIOValueRef34 operator ValueRef() const {
35 return value;
36 };
37 };
38
39 /*
40 * This class is modelled after c10::IValue; however, it is simplified and does
41 * not support as many types. However, the core design is the same; it is a
42 * tagged union over the types supported by the Vulkan Graph type.
43 */
44 struct Value final {
45 private:
46 /*
47 * The union type which is used to store the value of the Value.
48 */
49 union Payload {
50 /*
51 * Similar to IValue::Payload, trivially copyable types are nested in their
52 * own union.
53 */
54 union TriviallyCopyablePayload {
TriviallyCopyablePayload()55 TriviallyCopyablePayload() : as_int(0) {}
56 int64_t as_int;
57 double as_double;
58 bool as_bool;
59 } u;
60
61 api::vTensor as_tensor;
62 api::StagingBuffer as_staging;
63 TensorRef as_tensorref;
64
65 std::vector<int64_t> as_int_list;
66 std::vector<double> as_double_list;
67 std::vector<bool> as_bool_list;
68
69 // The below is a special type that is used to represent a list of other
70 // values stored in the graph. One application of the type is to represent
71 // a list of tensors or a list of optional tensors.
72 std::vector<ValueRef> as_value_list;
73
74 std::string as_string;
75
76 SymInt as_symint;
77
Payload()78 Payload() : u() {}
79 // NOLINTNEXTLINE
~Payload()80 ~Payload(){};
81 };
82
83 public:
84 //
85 // Copy constructor and assignment (disabled)
86 //
87
88 Value(const Value& rhs) = delete;
89 Value& operator=(const Value&) = delete;
90
91 //
92 // Move constructor and assignment; Move assignment is disabled but
93 // construction is implemented to allow for use in container types.
94 //
95
96 Value& operator=(Value&&) = delete;
97
98 #define CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(type_tag, member_name) \
99 case type_tag: \
100 payload.u.member_name = rhs.payload.u.member_name; \
101 break;
102
103 #define CASE_MOVE_MOVEABLE_TYPE(type_tag, type, member_name, dtor_name) \
104 case type_tag: \
105 new (&payload.member_name) type(std::move(rhs.payload.member_name)); \
106 rhs.payload.member_name.~dtor_name(); \
107 break;
108
Valuefinal109 Value(Value&& rhs) noexcept : tag(rhs.tag) {
110 switch (tag) {
111 // Scalar types
112 CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::INT, as_int);
113 CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::DOUBLE, as_double);
114 CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::BOOL, as_bool);
115 // Tensor and tensor adjacent types
116 CASE_MOVE_MOVEABLE_TYPE(
117 TypeTag::TENSOR, api::vTensor, as_tensor, vTensor);
118 CASE_MOVE_MOVEABLE_TYPE(
119 TypeTag::STAGING, api::StagingBuffer, as_staging, StagingBuffer);
120 CASE_MOVE_MOVEABLE_TYPE(
121 TypeTag::TENSORREF, TensorRef, as_tensorref, TensorRef);
122 // Scalar lists
123 CASE_MOVE_MOVEABLE_TYPE(
124 TypeTag::INTLIST, std::vector<int64_t>, as_int_list, vector);
125 CASE_MOVE_MOVEABLE_TYPE(
126 TypeTag::DOUBLELIST, std::vector<double>, as_double_list, vector);
127 CASE_MOVE_MOVEABLE_TYPE(
128 TypeTag::BOOLLIST, std::vector<bool>, as_bool_list, vector);
129 // Special types
130 CASE_MOVE_MOVEABLE_TYPE(
131 TypeTag::VALUELIST, std::vector<ValueRef>, as_value_list, vector);
132 CASE_MOVE_MOVEABLE_TYPE(
133 TypeTag::STRING, std::string, as_string, basic_string);
134 CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt);
135
136 case TypeTag::NONE:
137 clearToNone();
138 break;
139 }
140 rhs.clearToNone();
141 }
142
143 #undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE
144 #undef CASE_MOVE_MOVEABLE_TYPE
145
146 //
147 // Accessors
148 //
149
typefinal150 inline TypeTag type() const {
151 return tag;
152 }
153
154 //
155 // Destructor
156 //
157
~Valuefinal158 ~Value() {
159 switch (tag) {
160 case TypeTag::TENSOR:
161 payload.as_tensor.~vTensor();
162 break;
163 case TypeTag::STAGING:
164 payload.as_staging.~StagingBuffer();
165 break;
166 case TypeTag::TENSORREF:
167 payload.as_tensorref.~TensorRef();
168 break;
169 case TypeTag::INTLIST:
170 payload.as_int_list.~vector();
171 break;
172 case TypeTag::DOUBLELIST:
173 payload.as_double_list.~vector();
174 break;
175 case TypeTag::BOOLLIST:
176 payload.as_bool_list.~vector();
177 break;
178 case TypeTag::VALUELIST:
179 payload.as_value_list.~vector();
180 break;
181 case TypeTag::STRING:
182 payload.as_string.~basic_string();
183 break;
184 case TypeTag::SYMINT:
185 payload.as_symint.~SymInt();
186 break;
187 // Manually list out the types so that if a type here is added later and
188 // not handled the compiler can catch it.
189 case TypeTag::NONE:
190 case TypeTag::INT:
191 case TypeTag::DOUBLE:
192 case TypeTag::BOOL:
193 break;
194 }
195 }
196
197 //
198 // Constructors, isType(), toType()
199 //
200
Valuefinal201 Value() : tag(TypeTag::NONE) {}
202
isNonefinal203 inline bool isNone() const {
204 return tag == TypeTag::NONE;
205 }
206
207 #define SUPPORT_TRIVIALLY_COPYABLE_TYPE( \
208 type, type_name, type_tag, member_name) \
209 explicit Value(type t) : tag(type_tag) { \
210 payload.u.member_name = t; \
211 } \
212 inline bool is##type_name() const { \
213 return tag == type_tag; \
214 } \
215 inline const type& to##type_name() const { \
216 VK_CHECK_COND( \
217 is##type_name(), \
218 "Expected value to have type " #type_name ", got ", \
219 tag, \
220 " instead."); \
221 return payload.u.member_name; \
222 }
223
224 SUPPORT_TRIVIALLY_COPYABLE_TYPE(int64_t, Int, TypeTag::INT, as_int);
225 SUPPORT_TRIVIALLY_COPYABLE_TYPE(double, Double, TypeTag::DOUBLE, as_double);
226 SUPPORT_TRIVIALLY_COPYABLE_TYPE(bool, Bool, TypeTag::BOOL, as_bool);
227
228 #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
229
230 #define SUPPORT_TRIVIALLY_MOVEABLE_TYPE( \
231 type, type_name, type_tag, member_name) \
232 explicit Value(type&& t) : tag(type_tag) { \
233 new (&payload.member_name) type(std::move(t)); \
234 } \
235 inline bool is##type_name() const { \
236 return tag == type_tag; \
237 } \
238 inline type& to##type_name() { \
239 VK_CHECK_COND( \
240 is##type_name(), \
241 "Expected value to have type " #type_name ", got ", \
242 tag, \
243 " instead."); \
244 return payload.member_name; \
245 } \
246 inline const type& toConst##type_name() const { \
247 VK_CHECK_COND( \
248 is##type_name(), \
249 "Expected value to have type " #type_name ", got ", \
250 tag, \
251 " instead."); \
252 return payload.member_name; \
253 }
254
255 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
256 api::vTensor,
257 Tensor,
258 TypeTag::TENSOR,
259 as_tensor);
260
261 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
262 api::StagingBuffer,
263 Staging,
264 TypeTag::STAGING,
265 as_staging);
266
267 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
268 TensorRef,
269 TensorRef,
270 TypeTag::TENSORREF,
271 as_tensorref);
272
273 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
274 std::vector<int64_t>,
275 IntList,
276 TypeTag::INTLIST,
277 as_int_list);
278
279 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
280 std::vector<double>,
281 DoubleList,
282 TypeTag::DOUBLELIST,
283 as_double_list);
284
285 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
286 std::vector<bool>,
287 BoolList,
288 TypeTag::BOOLLIST,
289 as_bool_list);
290
291 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
292 std::vector<ValueRef>,
293 ValueList,
294 TypeTag::VALUELIST,
295 as_value_list);
296
297 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
298 std::string,
299 String,
300 TypeTag::STRING,
301 as_string);
302
303 SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint);
304
305 #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
306 #undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE
307
308 private:
309 Payload payload;
310 TypeTag tag;
311
312 //
313 // Utility Functions
314 //
315
clearToNonefinal316 inline void clearToNone() noexcept {
317 payload.u.as_int = -1;
318 tag = TypeTag::NONE;
319 }
320 };
321
322 } // namespace vkcompute
323