// xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/containers/Value.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #pragma once
10 
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12 
13 #include <executorch/backends/vulkan/runtime/api/api.h>
14 
15 #include <executorch/backends/vulkan/runtime/graph/containers/Constant.h>
16 #include <executorch/backends/vulkan/runtime/graph/containers/SymInt.h>
17 #include <executorch/backends/vulkan/runtime/graph/containers/Types.h>
18 
19 namespace vkcompute {
20 
// Handle used to refer to a value stored in the graph. Negative handles are
// rejected by is_valid().
using ValueRef = int32_t;

// Placeholder ValueRef that refers to nothing; since it is negative,
// is_valid() reports it as invalid.
constexpr ValueRef kDummyValueRef{-1};
24 
is_valid(ValueRef value_ref)25 inline bool is_valid(ValueRef value_ref) {
26   return value_ref >= 0;
27 }
28 
29 struct IOValueRef {
30   ValueRef value;
31   ValueRef staging;
32 
33   // Custom cast to ValueRef
ValueRefIOValueRef34   operator ValueRef() const {
35     return value;
36   };
37 };
38 
39 /*
40  * This class is modelled after c10::IValue; however, it is simplified and does
41  * not support as many types. However, the core design is the same; it is a
42  * tagged union over the types supported by the Vulkan Graph type.
43  */
44 struct Value final {
45  private:
46   /*
47    * The union type which is used to store the value of the Value.
48    */
49   union Payload {
50     /*
51      * Similar to IValue::Payload, trivially copyable types are nested in their
52      * own union.
53      */
54     union TriviallyCopyablePayload {
TriviallyCopyablePayload()55       TriviallyCopyablePayload() : as_int(0) {}
56       int64_t as_int;
57       double as_double;
58       bool as_bool;
59     } u;
60 
61     api::vTensor as_tensor;
62     api::StagingBuffer as_staging;
63     TensorRef as_tensorref;
64 
65     std::vector<int64_t> as_int_list;
66     std::vector<double> as_double_list;
67     std::vector<bool> as_bool_list;
68 
69     // The below is a special type that is used to represent a list of other
70     // values stored in the graph. One application of the type is to represent
71     // a list of tensors or a list of optional tensors.
72     std::vector<ValueRef> as_value_list;
73 
74     std::string as_string;
75 
76     SymInt as_symint;
77 
Payload()78     Payload() : u() {}
79     // NOLINTNEXTLINE
~Payload()80     ~Payload(){};
81   };
82 
83  public:
84   //
85   // Copy constructor and assignment (disabled)
86   //
87 
88   Value(const Value& rhs) = delete;
89   Value& operator=(const Value&) = delete;
90 
91   //
92   // Move constructor and assignment; Move assignment is disabled but
93   // construction is implemented to allow for use in container types.
94   //
95 
96   Value& operator=(Value&&) = delete;
97 
98 #define CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(type_tag, member_name) \
99   case type_tag:                                                 \
100     payload.u.member_name = rhs.payload.u.member_name;           \
101     break;
102 
103 #define CASE_MOVE_MOVEABLE_TYPE(type_tag, type, member_name, dtor_name)  \
104   case type_tag:                                                         \
105     new (&payload.member_name) type(std::move(rhs.payload.member_name)); \
106     rhs.payload.member_name.~dtor_name();                                \
107     break;
108 
Valuefinal109   Value(Value&& rhs) noexcept : tag(rhs.tag) {
110     switch (tag) {
111       // Scalar types
112       CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::INT, as_int);
113       CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::DOUBLE, as_double);
114       CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::BOOL, as_bool);
115       // Tensor and tensor adjacent types
116       CASE_MOVE_MOVEABLE_TYPE(
117           TypeTag::TENSOR, api::vTensor, as_tensor, vTensor);
118       CASE_MOVE_MOVEABLE_TYPE(
119           TypeTag::STAGING, api::StagingBuffer, as_staging, StagingBuffer);
120       CASE_MOVE_MOVEABLE_TYPE(
121           TypeTag::TENSORREF, TensorRef, as_tensorref, TensorRef);
122       // Scalar lists
123       CASE_MOVE_MOVEABLE_TYPE(
124           TypeTag::INTLIST, std::vector<int64_t>, as_int_list, vector);
125       CASE_MOVE_MOVEABLE_TYPE(
126           TypeTag::DOUBLELIST, std::vector<double>, as_double_list, vector);
127       CASE_MOVE_MOVEABLE_TYPE(
128           TypeTag::BOOLLIST, std::vector<bool>, as_bool_list, vector);
129       // Special types
130       CASE_MOVE_MOVEABLE_TYPE(
131           TypeTag::VALUELIST, std::vector<ValueRef>, as_value_list, vector);
132       CASE_MOVE_MOVEABLE_TYPE(
133           TypeTag::STRING, std::string, as_string, basic_string);
134       CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt);
135 
136       case TypeTag::NONE:
137         clearToNone();
138         break;
139     }
140     rhs.clearToNone();
141   }
142 
143 #undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE
144 #undef CASE_MOVE_MOVEABLE_TYPE
145 
146   //
147   // Accessors
148   //
149 
typefinal150   inline TypeTag type() const {
151     return tag;
152   }
153 
154   //
155   // Destructor
156   //
157 
~Valuefinal158   ~Value() {
159     switch (tag) {
160       case TypeTag::TENSOR:
161         payload.as_tensor.~vTensor();
162         break;
163       case TypeTag::STAGING:
164         payload.as_staging.~StagingBuffer();
165         break;
166       case TypeTag::TENSORREF:
167         payload.as_tensorref.~TensorRef();
168         break;
169       case TypeTag::INTLIST:
170         payload.as_int_list.~vector();
171         break;
172       case TypeTag::DOUBLELIST:
173         payload.as_double_list.~vector();
174         break;
175       case TypeTag::BOOLLIST:
176         payload.as_bool_list.~vector();
177         break;
178       case TypeTag::VALUELIST:
179         payload.as_value_list.~vector();
180         break;
181       case TypeTag::STRING:
182         payload.as_string.~basic_string();
183         break;
184       case TypeTag::SYMINT:
185         payload.as_symint.~SymInt();
186         break;
187       // Manually list out the types so that if a type here is added later and
188       // not handled the compiler can catch it.
189       case TypeTag::NONE:
190       case TypeTag::INT:
191       case TypeTag::DOUBLE:
192       case TypeTag::BOOL:
193         break;
194     }
195   }
196 
197   //
198   // Constructors, isType(), toType()
199   //
200 
Valuefinal201   Value() : tag(TypeTag::NONE) {}
202 
isNonefinal203   inline bool isNone() const {
204     return tag == TypeTag::NONE;
205   }
206 
207 #define SUPPORT_TRIVIALLY_COPYABLE_TYPE(                    \
208     type, type_name, type_tag, member_name)                 \
209   explicit Value(type t) : tag(type_tag) {                  \
210     payload.u.member_name = t;                              \
211   }                                                         \
212   inline bool is##type_name() const {                       \
213     return tag == type_tag;                                 \
214   }                                                         \
215   inline const type& to##type_name() const {                \
216     VK_CHECK_COND(                                          \
217         is##type_name(),                                    \
218         "Expected value to have type " #type_name ", got ", \
219         tag,                                                \
220         " instead.");                                       \
221     return payload.u.member_name;                           \
222   }
223 
224   SUPPORT_TRIVIALLY_COPYABLE_TYPE(int64_t, Int, TypeTag::INT, as_int);
225   SUPPORT_TRIVIALLY_COPYABLE_TYPE(double, Double, TypeTag::DOUBLE, as_double);
226   SUPPORT_TRIVIALLY_COPYABLE_TYPE(bool, Bool, TypeTag::BOOL, as_bool);
227 
228 #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
229 
230 #define SUPPORT_TRIVIALLY_MOVEABLE_TYPE(                    \
231     type, type_name, type_tag, member_name)                 \
232   explicit Value(type&& t) : tag(type_tag) {                \
233     new (&payload.member_name) type(std::move(t));          \
234   }                                                         \
235   inline bool is##type_name() const {                       \
236     return tag == type_tag;                                 \
237   }                                                         \
238   inline type& to##type_name() {                            \
239     VK_CHECK_COND(                                          \
240         is##type_name(),                                    \
241         "Expected value to have type " #type_name ", got ", \
242         tag,                                                \
243         " instead.");                                       \
244     return payload.member_name;                             \
245   }                                                         \
246   inline const type& toConst##type_name() const {           \
247     VK_CHECK_COND(                                          \
248         is##type_name(),                                    \
249         "Expected value to have type " #type_name ", got ", \
250         tag,                                                \
251         " instead.");                                       \
252     return payload.member_name;                             \
253   }
254 
255   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
256       api::vTensor,
257       Tensor,
258       TypeTag::TENSOR,
259       as_tensor);
260 
261   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
262       api::StagingBuffer,
263       Staging,
264       TypeTag::STAGING,
265       as_staging);
266 
267   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
268       TensorRef,
269       TensorRef,
270       TypeTag::TENSORREF,
271       as_tensorref);
272 
273   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
274       std::vector<int64_t>,
275       IntList,
276       TypeTag::INTLIST,
277       as_int_list);
278 
279   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
280       std::vector<double>,
281       DoubleList,
282       TypeTag::DOUBLELIST,
283       as_double_list);
284 
285   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
286       std::vector<bool>,
287       BoolList,
288       TypeTag::BOOLLIST,
289       as_bool_list);
290 
291   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
292       std::vector<ValueRef>,
293       ValueList,
294       TypeTag::VALUELIST,
295       as_value_list);
296 
297   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(
298       std::string,
299       String,
300       TypeTag::STRING,
301       as_string);
302 
303   SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint);
304 
305 #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE
306 #undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE
307 
308  private:
309   Payload payload;
310   TypeTag tag;
311 
312   //
313   // Utility Functions
314   //
315 
clearToNonefinal316   inline void clearToNone() noexcept {
317     payload.u.as_int = -1;
318     tag = TypeTag::NONE;
319   }
320 };
321 
322 } // namespace vkcompute
323