xref: /aosp_15_r20/external/pytorch/test/edge/Evalue.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
/**
 * WARNING: EValue is a class used by Executorch for its boxed operators. It
 * contains logic similar to `IValue` in PyTorch, providing APIs to convert
 * boxed values to unboxed values.
 *
 * It mirrors an fbcode-internal source file
 * [`EValue.h`](https://www.internalfb.com/code/fbsource/xplat/executorch/core/values/Evalue.h).
 *
 * We mirror this class to make sure there is CI job coverage of the torchgen
 * logic, given that torchgen is used by both Executorch and PyTorch.
 *
 * If any of the logic here needs to be changed, please update the fbcode
 * version of `Evalue.h` as well. These two versions will be merged as soon as
 * Executorch is in OSS (hopefully by Q2 2023).
 */
20 namespace torch {
21 namespace executor {
22 
// ET_CHECK_MSG(cond, msg): reports an error carrying `msg` when `cond` is
// false, by delegating to TORCH_CHECK.
// NOTE: this must map to TORCH_CHECK, not TORCH_CHECK_MSG. TORCH_CHECK_MSG is
// c10's internal helper that only *builds* the error-message string, so
// mapping to it turned every ET_CHECK_MSG in this file into a discarded
// expression that never checked anything.
#define ET_CHECK_MSG TORCH_CHECK

// X-macro listing every kind of value an EValue can hold. Each entry becomes
// an enumerator of `Tag` below, in exactly this order.
#define EXECUTORCH_FORALL_TAGS(_) \
  _(None)                         \
  _(Tensor)                       \
  _(String)                       \
  _(Double)                       \
  _(Int)                          \
  _(Bool)                         \
  _(ListBool)                     \
  _(ListDouble)                   \
  _(ListInt)                      \
  _(ListTensor)                   \
  _(ListScalar)                   \
  _(ListOptionalTensor)

// Discriminator recording which member of EValue::payload is active.
enum class Tag : uint32_t {
#define DEFINE_TAG(x) x,
  EXECUTORCH_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
};
43 
44 struct EValue;
45 
46 template <typename T>
47 struct evalue_to_const_ref_overload_return {
48   using type = T;
49 };
50 
51 template <>
52 struct evalue_to_const_ref_overload_return<at::Tensor> {
53   using type = const at::Tensor&;
54 };
55 
56 template <typename T>
57 struct evalue_to_ref_overload_return {
58   using type = T;
59 };
60 
61 template <>
62 struct evalue_to_ref_overload_return<at::Tensor> {
63   using type = at::Tensor&;
64 };
65 
66 /*
67  * Helper class used to correlate EValues in the executor table, with the
68  * unwrapped list of the proper type. Because values in the runtime's values
69  * table can change during execution, we cannot statically allocate list of
70  * objects at deserialization. Imagine the serialized list says index 0 in the
71  * value table is element 2 in the list, but during execution the value in
72  * element 2 changes (in the case of tensor this means the TensorImpl* stored in
73  * the tensor changes). To solve this instead they must be created dynamically
74  * whenever they are used.
75  */
76 template <typename T>
77 class EValObjectList {
78  public:
79   EValObjectList() = default;
80   /*
81    * Wrapped_vals is a list of pointers into the values table of the runtime
82    * whose destinations correlate with the elements of the list, unwrapped_vals
83    * is a container of the same size whose serves as memory to construct the
84    * unwrapped vals.
85    */
86   EValObjectList(EValue** wrapped_vals, T* unwrapped_vals, int size)
87       : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
88   /*
89    * Constructs and returns the list of T specified by the EValue pointers
90    */
91   at::ArrayRef<T> get() const;
92 
93  private:
94   // Source of truth for the list
95   at::ArrayRef<EValue*> wrapped_vals_;
96   // Same size as wrapped_vals
97   mutable T* unwrapped_vals_;
98 };
99 
100 // Aggregate typing system similar to IValue only slimmed down with less
101 // functionality, no dependencies on atomic, and fewer supported types to better
102 // suit embedded systems (ie no intrusive ptr)
103 struct EValue {
104   union Payload {
105     // When in ATen mode at::Tensor is not trivially copyable, this nested union
106     // lets us handle tensor as a special case while leaving the rest of the
107     // fields in a simple state instead of requiring a switch on tag everywhere.
108     union TriviallyCopyablePayload {
109       TriviallyCopyablePayload() : as_int(0) {}
110       // Scalar supported through these 3 types
111       int64_t as_int;
112       double as_double;
113       bool as_bool;
114       // TODO(jakeszwe): convert back to pointers to optimize size of this
115       // struct
116       at::ArrayRef<char> as_string;
117       at::ArrayRef<int64_t> as_int_list;
118       at::ArrayRef<double> as_double_list;
119       at::ArrayRef<bool> as_bool_list;
120       EValObjectList<at::Tensor> as_tensor_list;
121       EValObjectList<std::optional<at::Tensor>> as_list_optional_tensor;
122     } copyable_union;
123 
124     // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor*
125     // here.
126     at::Tensor as_tensor;
127 
128     Payload() {}
129     ~Payload() {}
130   };
131 
132   // Data storage and type tag
133   Payload payload;
134   Tag tag;
135 
136   // Basic ctors and assignments
137   EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {}
138 
139   EValue(EValue&& rhs) noexcept : tag(rhs.tag) {
140     moveFrom(std::move(rhs));
141   }
142 
143   EValue& operator=(EValue&& rhs) & noexcept {
144     if (&rhs == this) {
145       return *this;
146     }
147 
148     destroy();
149     moveFrom(std::move(rhs));
150     return *this;
151   }
152 
153   EValue& operator=(EValue const& rhs) & {
154     // Define copy assignment through copy ctor and move assignment
155     *this = EValue(rhs);
156     return *this;
157   }
158 
159   ~EValue() {
160     destroy();
161   }
162 
163   /****** None Type ******/
164   EValue() : tag(Tag::None) {
165     payload.copyable_union.as_int = 0;
166   }
167 
168   bool isNone() const {
169     return tag == Tag::None;
170   }
171 
172   /****** Int Type ******/
173   /*implicit*/ EValue(int64_t i) : tag(Tag::Int) {
174     payload.copyable_union.as_int = i;
175   }
176 
177   bool isInt() const {
178     return tag == Tag::Int;
179   }
180 
181   int64_t toInt() const {
182     ET_CHECK_MSG(isInt(), "EValue is not an int.");
183     return payload.copyable_union.as_int;
184   }
185 
186   /****** Double Type ******/
187   /*implicit*/ EValue(double d) : tag(Tag::Double) {
188     payload.copyable_union.as_double = d;
189   }
190 
191   bool isDouble() const {
192     return tag == Tag::Double;
193   }
194 
195   double toDouble() const {
196     ET_CHECK_MSG(isDouble(), "EValue is not a Double.");
197     return payload.copyable_union.as_double;
198   }
199 
200   /****** Bool Type ******/
201   /*implicit*/ EValue(bool b) : tag(Tag::Bool) {
202     payload.copyable_union.as_bool = b;
203   }
204 
205   bool isBool() const {
206     return tag == Tag::Bool;
207   }
208 
209   bool toBool() const {
210     ET_CHECK_MSG(isBool(), "EValue is not a Bool.");
211     return payload.copyable_union.as_bool;
212   }
213 
214   /****** Scalar Type ******/
215   /// Construct an EValue using the implicit value of a Scalar.
216   /*implicit*/ EValue(at::Scalar s) {
217     if (s.isIntegral(false)) {
218       tag = Tag::Int;
219       payload.copyable_union.as_int = s.to<int64_t>();
220     } else if (s.isFloatingPoint()) {
221       tag = Tag::Double;
222       payload.copyable_union.as_double = s.to<double>();
223     } else if (s.isBoolean()) {
224       tag = Tag::Bool;
225       payload.copyable_union.as_bool = s.to<bool>();
226     } else {
227       ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized.");
228     }
229   }
230 
231   bool isScalar() const {
232     return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool;
233   }
234 
235   at::Scalar toScalar() const {
236     // Convert from implicit value to Scalar using implicit constructors.
237 
238     if (isDouble()) {
239       return toDouble();
240     } else if (isInt()) {
241       return toInt();
242     } else if (isBool()) {
243       return toBool();
244     } else {
245       ET_CHECK_MSG(false, "EValue is not a Scalar.");
246       return c10::Scalar();
247     }
248   }
249 
250   /****** Tensor Type ******/
251   /*implicit*/ EValue(at::Tensor t) : tag(Tag::Tensor) {
252     // When built in aten mode, at::Tensor has a non trivial constructor
253     // destructor, so regular assignment to a union field is UB. Instead we must
254     // go through placement new (which causes a refcount bump).
255     new (&payload.as_tensor) at::Tensor(t);
256   }
257 
258   bool isTensor() const {
259     return tag == Tag::Tensor;
260   }
261 
262   at::Tensor toTensor() && {
263     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
264     return std::move(payload.as_tensor);
265   }
266 
267   at::Tensor& toTensor() & {
268     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
269     return payload.as_tensor;
270   }
271 
272   const at::Tensor& toTensor() const& {
273     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
274     return payload.as_tensor;
275   }
276 
277   /****** String Type ******/
278   /*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) {
279     payload.copyable_union.as_string = at::ArrayRef<char>(s, size);
280   }
281 
282   bool isString() const {
283     return tag == Tag::String;
284   }
285 
286   at::string_view toString() const {
287     ET_CHECK_MSG(isString(), "EValue is not a String.");
288     return at::string_view(
289         payload.copyable_union.as_string.data(),
290         payload.copyable_union.as_string.size());
291   }
292 
293   /****** Int List Type ******/
294   /*implicit*/ EValue(at::ArrayRef<int64_t> i) : tag(Tag::ListInt) {
295     payload.copyable_union.as_int_list = i;
296   }
297 
298   bool isIntList() const {
299     return tag == Tag::ListInt;
300   }
301 
302   at::ArrayRef<int64_t> toIntList() const {
303     ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
304     return payload.copyable_union.as_int_list;
305   }
306 
307   /****** Bool List Type ******/
308   /*implicit*/ EValue(at::ArrayRef<bool> b) : tag(Tag::ListBool) {
309     payload.copyable_union.as_bool_list = b;
310   }
311 
312   bool isBoolList() const {
313     return tag == Tag::ListBool;
314   }
315 
316   at::ArrayRef<bool> toBoolList() const {
317     ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
318     return payload.copyable_union.as_bool_list;
319   }
320 
321   /****** Double List Type ******/
322   /*implicit*/ EValue(at::ArrayRef<double> d) : tag(Tag::ListDouble) {
323     payload.copyable_union.as_double_list = d;
324   }
325 
326   bool isDoubleList() const {
327     return tag == Tag::ListDouble;
328   }
329 
330   at::ArrayRef<double> toDoubleList() const {
331     ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
332     return payload.copyable_union.as_double_list;
333   }
334 
335   /****** Tensor List Type ******/
336   /*implicit*/ EValue(EValObjectList<at::Tensor> t) : tag(Tag::ListTensor) {
337     payload.copyable_union.as_tensor_list = t;
338   }
339 
340   bool isTensorList() const {
341     return tag == Tag::ListTensor;
342   }
343 
344   at::ArrayRef<at::Tensor> toTensorList() const {
345     ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
346     return payload.copyable_union.as_tensor_list.get();
347   }
348 
349   /****** List Optional Tensor Type ******/
350   /*implicit*/ EValue(EValObjectList<std::optional<at::Tensor>> t)
351       : tag(Tag::ListOptionalTensor) {
352     payload.copyable_union.as_list_optional_tensor = t;
353   }
354 
355   bool isListOptionalTensor() const {
356     return tag == Tag::ListOptionalTensor;
357   }
358 
359   at::ArrayRef<std::optional<at::Tensor>> toListOptionalTensor() {
360     return payload.copyable_union.as_list_optional_tensor.get();
361   }
362 
363   /****** ScalarType Type ******/
364   at::ScalarType toScalarType() const {
365     ET_CHECK_MSG(isInt(), "EValue is not a ScalarType.");
366     return static_cast<at::ScalarType>(payload.copyable_union.as_int);
367   }
368 
369   /****** MemoryFormat Type ******/
370   at::MemoryFormat toMemoryFormat() const {
371     ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat.");
372     return static_cast<at::MemoryFormat>(payload.copyable_union.as_int);
373   }
374 
375   template <typename T>
376   T to() &&;
377 
378   template <typename T>
379   typename evalue_to_ref_overload_return<T>::type to() &;
380 
381   /**
382    * Converts the EValue to an optional object that can represent both T and
383    * an uninitialized state.
384    */
385   template <typename T>
386   inline std::optional<T> toOptional() {
387     if (this->isNone()) {
388       return std::nullopt;
389     }
390     return this->to<T>();
391   }
392 
393  private:
394   // Pre cond: the payload value has had its destructor called
395   void clearToNone() noexcept {
396     payload.copyable_union.as_int = 0;
397     tag = Tag::None;
398   }
399 
400   // Shared move logic
401   void moveFrom(EValue&& rhs) noexcept {
402     if (rhs.isTensor()) {
403       new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
404       rhs.payload.as_tensor.~Tensor();
405     } else {
406       payload.copyable_union = rhs.payload.copyable_union;
407     }
408     tag = rhs.tag;
409     rhs.clearToNone();
410   }
411 
412   // Destructs stored tensor if there is one
413   void destroy() {
414     // Necessary for ATen tensor to refcount decrement the intrusive_ptr to
415     // tensorimpl that got a refcount increment when we placed it in the evalue,
416     // no-op if executorch tensor #ifdef could have a
417     // minor performance bump for a code maintainability hit
418     if (isTensor()) {
419       payload.as_tensor.~Tensor();
420     } else if (isTensorList()) {
421       for (auto& tensor : toTensorList()) {
422         tensor.~Tensor();
423       }
424     } else if (isListOptionalTensor()) {
425       for (auto& optional_tensor : toListOptionalTensor()) {
426         optional_tensor.~optional();
427       }
428     }
429   }
430 
431   EValue(const Payload& p, Tag t) : tag(t) {
432     if (isTensor()) {
433       new (&payload.as_tensor) at::Tensor(p.as_tensor);
434     } else {
435       payload.copyable_union = p.copyable_union;
436     }
437   }
438 };
439 
440 #define EVALUE_DEFINE_TO(T, method_name)                           \
441   template <>                                                      \
442   inline evalue_to_ref_overload_return<T>::type EValue::to<T>()& { \
443     return static_cast<T>(this->method_name());                    \
444   }
445 
446 template <>
447 inline at::Tensor& EValue::to<at::Tensor>() & {
448   return this->toTensor();
449 }
450 
451 EVALUE_DEFINE_TO(at::Scalar, toScalar)
452 EVALUE_DEFINE_TO(int64_t, toInt)
453 EVALUE_DEFINE_TO(bool, toBool)
454 EVALUE_DEFINE_TO(double, toDouble)
455 EVALUE_DEFINE_TO(at::string_view, toString)
456 EVALUE_DEFINE_TO(at::ScalarType, toScalarType)
457 EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat)
458 EVALUE_DEFINE_TO(std::optional<at::Tensor>, toOptional<at::Tensor>)
459 EVALUE_DEFINE_TO(at::ArrayRef<int64_t>, toIntList)
460 EVALUE_DEFINE_TO(
461     std::optional<at::ArrayRef<int64_t>>,
462     toOptional<at::ArrayRef<int64_t>>)
463 EVALUE_DEFINE_TO(
464     std::optional<at::ArrayRef<double>>,
465     toOptional<at::ArrayRef<double>>)
466 EVALUE_DEFINE_TO(at::ArrayRef<std::optional<at::Tensor>>, toListOptionalTensor)
467 EVALUE_DEFINE_TO(at::ArrayRef<double>, toDoubleList)
468 #undef EVALUE_DEFINE_TO
469 
470 template <typename T>
471 at::ArrayRef<T> EValObjectList<T>::get() const {
472   for (size_t i = 0; i < wrapped_vals_.size(); i++) {
473     unwrapped_vals_[i] = wrapped_vals_[i]->template to<T>();
474   }
475   return at::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};
476 }
477 
478 } // namespace executor
479 } // namespace torch
480