xref: /aosp_15_r20/external/pytorch/test/edge/Evalue.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/ATen.h>
4*da0073e9SAndroid Build Coastguard Worker /**
5*da0073e9SAndroid Build Coastguard Worker  * WARNING: EValue is a class used by Executorch, for its boxed operators. It
6*da0073e9SAndroid Build Coastguard Worker  * contains similar logic as `IValue` in PyTorch, by providing APIs to convert
7*da0073e9SAndroid Build Coastguard Worker  * boxed values to unboxed values.
8*da0073e9SAndroid Build Coastguard Worker  *
9*da0073e9SAndroid Build Coastguard Worker  * It's mirroring a fbcode internal source file
10*da0073e9SAndroid Build Coastguard Worker  * [`EValue.h`](https://www.internalfb.com/code/fbsource/xplat/executorch/core/values/Evalue.h).
11*da0073e9SAndroid Build Coastguard Worker  *
 * We mirror this class to make sure we have CI job coverage of the torchgen
 * logic, given that torchgen is used by both Executorch and PyTorch.
15*da0073e9SAndroid Build Coastguard Worker  *
 * If any of the logic here needs to be changed, please update the fbcode
 * version of `Evalue.h` as well. These two versions will be merged as soon as
 * Executorch is in OSS (hopefully by Q2 2023).
19*da0073e9SAndroid Build Coastguard Worker  */
20*da0073e9SAndroid Build Coastguard Worker namespace torch {
21*da0073e9SAndroid Build Coastguard Worker namespace executor {
22*da0073e9SAndroid Build Coastguard Worker 
// Type-check helper used by the EValue accessors: if the condition is false,
// TORCH_CHECK raises a c10::Error carrying the given message.
// This must alias TORCH_CHECK, not TORCH_CHECK_MSG: the latter only formats
// the failure message string for TORCH_CHECK and performs no check at all,
// which would silently turn every ET_CHECK_MSG into a no-op.
#define ET_CHECK_MSG TORCH_CHECK

// X-macro listing every kind of value an EValue can hold, in Tag order.
#define EXECUTORCH_FORALL_TAGS(_) \
  _(None)                         \
  _(Tensor)                       \
  _(String)                       \
  _(Double)                       \
  _(Int)                          \
  _(Bool)                         \
  _(ListBool)                     \
  _(ListDouble)                   \
  _(ListInt)                      \
  _(ListTensor)                   \
  _(ListScalar)                   \
  _(ListOptionalTensor)

// One enumerator per entry of EXECUTORCH_FORALL_TAGS, numbered from 0 in
// listing order.
enum class Tag : uint32_t {
#define DEFINE_TAG(x) x,
  EXECUTORCH_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
};
43*da0073e9SAndroid Build Coastguard Worker 
44*da0073e9SAndroid Build Coastguard Worker struct EValue;
45*da0073e9SAndroid Build Coastguard Worker 
46*da0073e9SAndroid Build Coastguard Worker template <typename T>
47*da0073e9SAndroid Build Coastguard Worker struct evalue_to_const_ref_overload_return {
48*da0073e9SAndroid Build Coastguard Worker   using type = T;
49*da0073e9SAndroid Build Coastguard Worker };
50*da0073e9SAndroid Build Coastguard Worker 
51*da0073e9SAndroid Build Coastguard Worker template <>
52*da0073e9SAndroid Build Coastguard Worker struct evalue_to_const_ref_overload_return<at::Tensor> {
53*da0073e9SAndroid Build Coastguard Worker   using type = const at::Tensor&;
54*da0073e9SAndroid Build Coastguard Worker };
55*da0073e9SAndroid Build Coastguard Worker 
56*da0073e9SAndroid Build Coastguard Worker template <typename T>
57*da0073e9SAndroid Build Coastguard Worker struct evalue_to_ref_overload_return {
58*da0073e9SAndroid Build Coastguard Worker   using type = T;
59*da0073e9SAndroid Build Coastguard Worker };
60*da0073e9SAndroid Build Coastguard Worker 
61*da0073e9SAndroid Build Coastguard Worker template <>
62*da0073e9SAndroid Build Coastguard Worker struct evalue_to_ref_overload_return<at::Tensor> {
63*da0073e9SAndroid Build Coastguard Worker   using type = at::Tensor&;
64*da0073e9SAndroid Build Coastguard Worker };
65*da0073e9SAndroid Build Coastguard Worker 
66*da0073e9SAndroid Build Coastguard Worker /*
67*da0073e9SAndroid Build Coastguard Worker  * Helper class used to correlate EValues in the executor table, with the
68*da0073e9SAndroid Build Coastguard Worker  * unwrapped list of the proper type. Because values in the runtime's values
69*da0073e9SAndroid Build Coastguard Worker  * table can change during execution, we cannot statically allocate list of
70*da0073e9SAndroid Build Coastguard Worker  * objects at deserialization. Imagine the serialized list says index 0 in the
71*da0073e9SAndroid Build Coastguard Worker  * value table is element 2 in the list, but during execution the value in
72*da0073e9SAndroid Build Coastguard Worker  * element 2 changes (in the case of tensor this means the TensorImpl* stored in
73*da0073e9SAndroid Build Coastguard Worker  * the tensor changes). To solve this instead they must be created dynamically
74*da0073e9SAndroid Build Coastguard Worker  * whenever they are used.
75*da0073e9SAndroid Build Coastguard Worker  */
76*da0073e9SAndroid Build Coastguard Worker template <typename T>
77*da0073e9SAndroid Build Coastguard Worker class EValObjectList {
78*da0073e9SAndroid Build Coastguard Worker  public:
79*da0073e9SAndroid Build Coastguard Worker   EValObjectList() = default;
80*da0073e9SAndroid Build Coastguard Worker   /*
81*da0073e9SAndroid Build Coastguard Worker    * Wrapped_vals is a list of pointers into the values table of the runtime
82*da0073e9SAndroid Build Coastguard Worker    * whose destinations correlate with the elements of the list, unwrapped_vals
83*da0073e9SAndroid Build Coastguard Worker    * is a container of the same size whose serves as memory to construct the
84*da0073e9SAndroid Build Coastguard Worker    * unwrapped vals.
85*da0073e9SAndroid Build Coastguard Worker    */
86*da0073e9SAndroid Build Coastguard Worker   EValObjectList(EValue** wrapped_vals, T* unwrapped_vals, int size)
87*da0073e9SAndroid Build Coastguard Worker       : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
88*da0073e9SAndroid Build Coastguard Worker   /*
89*da0073e9SAndroid Build Coastguard Worker    * Constructs and returns the list of T specified by the EValue pointers
90*da0073e9SAndroid Build Coastguard Worker    */
91*da0073e9SAndroid Build Coastguard Worker   at::ArrayRef<T> get() const;
92*da0073e9SAndroid Build Coastguard Worker 
93*da0073e9SAndroid Build Coastguard Worker  private:
94*da0073e9SAndroid Build Coastguard Worker   // Source of truth for the list
95*da0073e9SAndroid Build Coastguard Worker   at::ArrayRef<EValue*> wrapped_vals_;
96*da0073e9SAndroid Build Coastguard Worker   // Same size as wrapped_vals
97*da0073e9SAndroid Build Coastguard Worker   mutable T* unwrapped_vals_;
98*da0073e9SAndroid Build Coastguard Worker };
99*da0073e9SAndroid Build Coastguard Worker 
100*da0073e9SAndroid Build Coastguard Worker // Aggregate typing system similar to IValue only slimmed down with less
101*da0073e9SAndroid Build Coastguard Worker // functionality, no dependencies on atomic, and fewer supported types to better
102*da0073e9SAndroid Build Coastguard Worker // suit embedded systems (ie no intrusive ptr)
103*da0073e9SAndroid Build Coastguard Worker struct EValue {
104*da0073e9SAndroid Build Coastguard Worker   union Payload {
105*da0073e9SAndroid Build Coastguard Worker     // When in ATen mode at::Tensor is not trivially copyable, this nested union
106*da0073e9SAndroid Build Coastguard Worker     // lets us handle tensor as a special case while leaving the rest of the
107*da0073e9SAndroid Build Coastguard Worker     // fields in a simple state instead of requiring a switch on tag everywhere.
108*da0073e9SAndroid Build Coastguard Worker     union TriviallyCopyablePayload {
109*da0073e9SAndroid Build Coastguard Worker       TriviallyCopyablePayload() : as_int(0) {}
110*da0073e9SAndroid Build Coastguard Worker       // Scalar supported through these 3 types
111*da0073e9SAndroid Build Coastguard Worker       int64_t as_int;
112*da0073e9SAndroid Build Coastguard Worker       double as_double;
113*da0073e9SAndroid Build Coastguard Worker       bool as_bool;
114*da0073e9SAndroid Build Coastguard Worker       // TODO(jakeszwe): convert back to pointers to optimize size of this
115*da0073e9SAndroid Build Coastguard Worker       // struct
116*da0073e9SAndroid Build Coastguard Worker       at::ArrayRef<char> as_string;
117*da0073e9SAndroid Build Coastguard Worker       at::ArrayRef<int64_t> as_int_list;
118*da0073e9SAndroid Build Coastguard Worker       at::ArrayRef<double> as_double_list;
119*da0073e9SAndroid Build Coastguard Worker       at::ArrayRef<bool> as_bool_list;
120*da0073e9SAndroid Build Coastguard Worker       EValObjectList<at::Tensor> as_tensor_list;
121*da0073e9SAndroid Build Coastguard Worker       EValObjectList<std::optional<at::Tensor>> as_list_optional_tensor;
122*da0073e9SAndroid Build Coastguard Worker     } copyable_union;
123*da0073e9SAndroid Build Coastguard Worker 
124*da0073e9SAndroid Build Coastguard Worker     // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor*
125*da0073e9SAndroid Build Coastguard Worker     // here.
126*da0073e9SAndroid Build Coastguard Worker     at::Tensor as_tensor;
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker     Payload() {}
129*da0073e9SAndroid Build Coastguard Worker     ~Payload() {}
130*da0073e9SAndroid Build Coastguard Worker   };
131*da0073e9SAndroid Build Coastguard Worker 
132*da0073e9SAndroid Build Coastguard Worker   // Data storage and type tag
133*da0073e9SAndroid Build Coastguard Worker   Payload payload;
134*da0073e9SAndroid Build Coastguard Worker   Tag tag;
135*da0073e9SAndroid Build Coastguard Worker 
136*da0073e9SAndroid Build Coastguard Worker   // Basic ctors and assignments
137*da0073e9SAndroid Build Coastguard Worker   EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {}
138*da0073e9SAndroid Build Coastguard Worker 
139*da0073e9SAndroid Build Coastguard Worker   EValue(EValue&& rhs) noexcept : tag(rhs.tag) {
140*da0073e9SAndroid Build Coastguard Worker     moveFrom(std::move(rhs));
141*da0073e9SAndroid Build Coastguard Worker   }
142*da0073e9SAndroid Build Coastguard Worker 
143*da0073e9SAndroid Build Coastguard Worker   EValue& operator=(EValue&& rhs) & noexcept {
144*da0073e9SAndroid Build Coastguard Worker     if (&rhs == this) {
145*da0073e9SAndroid Build Coastguard Worker       return *this;
146*da0073e9SAndroid Build Coastguard Worker     }
147*da0073e9SAndroid Build Coastguard Worker 
148*da0073e9SAndroid Build Coastguard Worker     destroy();
149*da0073e9SAndroid Build Coastguard Worker     moveFrom(std::move(rhs));
150*da0073e9SAndroid Build Coastguard Worker     return *this;
151*da0073e9SAndroid Build Coastguard Worker   }
152*da0073e9SAndroid Build Coastguard Worker 
153*da0073e9SAndroid Build Coastguard Worker   EValue& operator=(EValue const& rhs) & {
154*da0073e9SAndroid Build Coastguard Worker     // Define copy assignment through copy ctor and move assignment
155*da0073e9SAndroid Build Coastguard Worker     *this = EValue(rhs);
156*da0073e9SAndroid Build Coastguard Worker     return *this;
157*da0073e9SAndroid Build Coastguard Worker   }
158*da0073e9SAndroid Build Coastguard Worker 
159*da0073e9SAndroid Build Coastguard Worker   ~EValue() {
160*da0073e9SAndroid Build Coastguard Worker     destroy();
161*da0073e9SAndroid Build Coastguard Worker   }
162*da0073e9SAndroid Build Coastguard Worker 
163*da0073e9SAndroid Build Coastguard Worker   /****** None Type ******/
164*da0073e9SAndroid Build Coastguard Worker   EValue() : tag(Tag::None) {
165*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_int = 0;
166*da0073e9SAndroid Build Coastguard Worker   }
167*da0073e9SAndroid Build Coastguard Worker 
168*da0073e9SAndroid Build Coastguard Worker   bool isNone() const {
169*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::None;
170*da0073e9SAndroid Build Coastguard Worker   }
171*da0073e9SAndroid Build Coastguard Worker 
172*da0073e9SAndroid Build Coastguard Worker   /****** Int Type ******/
173*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(int64_t i) : tag(Tag::Int) {
174*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_int = i;
175*da0073e9SAndroid Build Coastguard Worker   }
176*da0073e9SAndroid Build Coastguard Worker 
177*da0073e9SAndroid Build Coastguard Worker   bool isInt() const {
178*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::Int;
179*da0073e9SAndroid Build Coastguard Worker   }
180*da0073e9SAndroid Build Coastguard Worker 
181*da0073e9SAndroid Build Coastguard Worker   int64_t toInt() const {
182*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isInt(), "EValue is not an int.");
183*da0073e9SAndroid Build Coastguard Worker     return payload.copyable_union.as_int;
184*da0073e9SAndroid Build Coastguard Worker   }
185*da0073e9SAndroid Build Coastguard Worker 
186*da0073e9SAndroid Build Coastguard Worker   /****** Double Type ******/
187*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(double d) : tag(Tag::Double) {
188*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_double = d;
189*da0073e9SAndroid Build Coastguard Worker   }
190*da0073e9SAndroid Build Coastguard Worker 
191*da0073e9SAndroid Build Coastguard Worker   bool isDouble() const {
192*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::Double;
193*da0073e9SAndroid Build Coastguard Worker   }
194*da0073e9SAndroid Build Coastguard Worker 
195*da0073e9SAndroid Build Coastguard Worker   double toDouble() const {
196*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isDouble(), "EValue is not a Double.");
197*da0073e9SAndroid Build Coastguard Worker     return payload.copyable_union.as_double;
198*da0073e9SAndroid Build Coastguard Worker   }
199*da0073e9SAndroid Build Coastguard Worker 
200*da0073e9SAndroid Build Coastguard Worker   /****** Bool Type ******/
201*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(bool b) : tag(Tag::Bool) {
202*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_bool = b;
203*da0073e9SAndroid Build Coastguard Worker   }
204*da0073e9SAndroid Build Coastguard Worker 
205*da0073e9SAndroid Build Coastguard Worker   bool isBool() const {
206*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::Bool;
207*da0073e9SAndroid Build Coastguard Worker   }
208*da0073e9SAndroid Build Coastguard Worker 
209*da0073e9SAndroid Build Coastguard Worker   bool toBool() const {
210*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isBool(), "EValue is not a Bool.");
211*da0073e9SAndroid Build Coastguard Worker     return payload.copyable_union.as_bool;
212*da0073e9SAndroid Build Coastguard Worker   }
213*da0073e9SAndroid Build Coastguard Worker 
214*da0073e9SAndroid Build Coastguard Worker   /****** Scalar Type ******/
215*da0073e9SAndroid Build Coastguard Worker   /// Construct an EValue using the implicit value of a Scalar.
216*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(at::Scalar s) {
217*da0073e9SAndroid Build Coastguard Worker     if (s.isIntegral(false)) {
218*da0073e9SAndroid Build Coastguard Worker       tag = Tag::Int;
219*da0073e9SAndroid Build Coastguard Worker       payload.copyable_union.as_int = s.to<int64_t>();
220*da0073e9SAndroid Build Coastguard Worker     } else if (s.isFloatingPoint()) {
221*da0073e9SAndroid Build Coastguard Worker       tag = Tag::Double;
222*da0073e9SAndroid Build Coastguard Worker       payload.copyable_union.as_double = s.to<double>();
223*da0073e9SAndroid Build Coastguard Worker     } else if (s.isBoolean()) {
224*da0073e9SAndroid Build Coastguard Worker       tag = Tag::Bool;
225*da0073e9SAndroid Build Coastguard Worker       payload.copyable_union.as_bool = s.to<bool>();
226*da0073e9SAndroid Build Coastguard Worker     } else {
227*da0073e9SAndroid Build Coastguard Worker       ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized.");
228*da0073e9SAndroid Build Coastguard Worker     }
229*da0073e9SAndroid Build Coastguard Worker   }
230*da0073e9SAndroid Build Coastguard Worker 
231*da0073e9SAndroid Build Coastguard Worker   bool isScalar() const {
232*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool;
233*da0073e9SAndroid Build Coastguard Worker   }
234*da0073e9SAndroid Build Coastguard Worker 
235*da0073e9SAndroid Build Coastguard Worker   at::Scalar toScalar() const {
236*da0073e9SAndroid Build Coastguard Worker     // Convert from implicit value to Scalar using implicit constructors.
237*da0073e9SAndroid Build Coastguard Worker 
238*da0073e9SAndroid Build Coastguard Worker     if (isDouble()) {
239*da0073e9SAndroid Build Coastguard Worker       return toDouble();
240*da0073e9SAndroid Build Coastguard Worker     } else if (isInt()) {
241*da0073e9SAndroid Build Coastguard Worker       return toInt();
242*da0073e9SAndroid Build Coastguard Worker     } else if (isBool()) {
243*da0073e9SAndroid Build Coastguard Worker       return toBool();
244*da0073e9SAndroid Build Coastguard Worker     } else {
245*da0073e9SAndroid Build Coastguard Worker       ET_CHECK_MSG(false, "EValue is not a Scalar.");
246*da0073e9SAndroid Build Coastguard Worker       return c10::Scalar();
247*da0073e9SAndroid Build Coastguard Worker     }
248*da0073e9SAndroid Build Coastguard Worker   }
249*da0073e9SAndroid Build Coastguard Worker 
250*da0073e9SAndroid Build Coastguard Worker   /****** Tensor Type ******/
251*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(at::Tensor t) : tag(Tag::Tensor) {
252*da0073e9SAndroid Build Coastguard Worker     // When built in aten mode, at::Tensor has a non trivial constructor
253*da0073e9SAndroid Build Coastguard Worker     // destructor, so regular assignment to a union field is UB. Instead we must
254*da0073e9SAndroid Build Coastguard Worker     // go through placement new (which causes a refcount bump).
255*da0073e9SAndroid Build Coastguard Worker     new (&payload.as_tensor) at::Tensor(t);
256*da0073e9SAndroid Build Coastguard Worker   }
257*da0073e9SAndroid Build Coastguard Worker 
258*da0073e9SAndroid Build Coastguard Worker   bool isTensor() const {
259*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::Tensor;
260*da0073e9SAndroid Build Coastguard Worker   }
261*da0073e9SAndroid Build Coastguard Worker 
262*da0073e9SAndroid Build Coastguard Worker   at::Tensor toTensor() && {
263*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
264*da0073e9SAndroid Build Coastguard Worker     return std::move(payload.as_tensor);
265*da0073e9SAndroid Build Coastguard Worker   }
266*da0073e9SAndroid Build Coastguard Worker 
267*da0073e9SAndroid Build Coastguard Worker   at::Tensor& toTensor() & {
268*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
269*da0073e9SAndroid Build Coastguard Worker     return payload.as_tensor;
270*da0073e9SAndroid Build Coastguard Worker   }
271*da0073e9SAndroid Build Coastguard Worker 
272*da0073e9SAndroid Build Coastguard Worker   const at::Tensor& toTensor() const& {
273*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
274*da0073e9SAndroid Build Coastguard Worker     return payload.as_tensor;
275*da0073e9SAndroid Build Coastguard Worker   }
276*da0073e9SAndroid Build Coastguard Worker 
277*da0073e9SAndroid Build Coastguard Worker   /****** String Type ******/
278*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) {
279*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_string = at::ArrayRef<char>(s, size);
280*da0073e9SAndroid Build Coastguard Worker   }
281*da0073e9SAndroid Build Coastguard Worker 
282*da0073e9SAndroid Build Coastguard Worker   bool isString() const {
283*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::String;
284*da0073e9SAndroid Build Coastguard Worker   }
285*da0073e9SAndroid Build Coastguard Worker 
286*da0073e9SAndroid Build Coastguard Worker   at::string_view toString() const {
287*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isString(), "EValue is not a String.");
288*da0073e9SAndroid Build Coastguard Worker     return at::string_view(
289*da0073e9SAndroid Build Coastguard Worker         payload.copyable_union.as_string.data(),
290*da0073e9SAndroid Build Coastguard Worker         payload.copyable_union.as_string.size());
291*da0073e9SAndroid Build Coastguard Worker   }
292*da0073e9SAndroid Build Coastguard Worker 
293*da0073e9SAndroid Build Coastguard Worker   /****** Int List Type ******/
294*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(at::ArrayRef<int64_t> i) : tag(Tag::ListInt) {
295*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_int_list = i;
296*da0073e9SAndroid Build Coastguard Worker   }
297*da0073e9SAndroid Build Coastguard Worker 
298*da0073e9SAndroid Build Coastguard Worker   bool isIntList() const {
299*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::ListInt;
300*da0073e9SAndroid Build Coastguard Worker   }
301*da0073e9SAndroid Build Coastguard Worker 
302*da0073e9SAndroid Build Coastguard Worker   at::ArrayRef<int64_t> toIntList() const {
303*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
304*da0073e9SAndroid Build Coastguard Worker     return payload.copyable_union.as_int_list;
305*da0073e9SAndroid Build Coastguard Worker   }
306*da0073e9SAndroid Build Coastguard Worker 
307*da0073e9SAndroid Build Coastguard Worker   /****** Bool List Type ******/
308*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(at::ArrayRef<bool> b) : tag(Tag::ListBool) {
309*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_bool_list = b;
310*da0073e9SAndroid Build Coastguard Worker   }
311*da0073e9SAndroid Build Coastguard Worker 
312*da0073e9SAndroid Build Coastguard Worker   bool isBoolList() const {
313*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::ListBool;
314*da0073e9SAndroid Build Coastguard Worker   }
315*da0073e9SAndroid Build Coastguard Worker 
316*da0073e9SAndroid Build Coastguard Worker   at::ArrayRef<bool> toBoolList() const {
317*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
318*da0073e9SAndroid Build Coastguard Worker     return payload.copyable_union.as_bool_list;
319*da0073e9SAndroid Build Coastguard Worker   }
320*da0073e9SAndroid Build Coastguard Worker 
321*da0073e9SAndroid Build Coastguard Worker   /****** Double List Type ******/
322*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(at::ArrayRef<double> d) : tag(Tag::ListDouble) {
323*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_double_list = d;
324*da0073e9SAndroid Build Coastguard Worker   }
325*da0073e9SAndroid Build Coastguard Worker 
326*da0073e9SAndroid Build Coastguard Worker   bool isDoubleList() const {
327*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::ListDouble;
328*da0073e9SAndroid Build Coastguard Worker   }
329*da0073e9SAndroid Build Coastguard Worker 
330*da0073e9SAndroid Build Coastguard Worker   at::ArrayRef<double> toDoubleList() const {
331*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
332*da0073e9SAndroid Build Coastguard Worker     return payload.copyable_union.as_double_list;
333*da0073e9SAndroid Build Coastguard Worker   }
334*da0073e9SAndroid Build Coastguard Worker 
335*da0073e9SAndroid Build Coastguard Worker   /****** Tensor List Type ******/
336*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(EValObjectList<at::Tensor> t) : tag(Tag::ListTensor) {
337*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_tensor_list = t;
338*da0073e9SAndroid Build Coastguard Worker   }
339*da0073e9SAndroid Build Coastguard Worker 
340*da0073e9SAndroid Build Coastguard Worker   bool isTensorList() const {
341*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::ListTensor;
342*da0073e9SAndroid Build Coastguard Worker   }
343*da0073e9SAndroid Build Coastguard Worker 
344*da0073e9SAndroid Build Coastguard Worker   at::ArrayRef<at::Tensor> toTensorList() const {
345*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
346*da0073e9SAndroid Build Coastguard Worker     return payload.copyable_union.as_tensor_list.get();
347*da0073e9SAndroid Build Coastguard Worker   }
348*da0073e9SAndroid Build Coastguard Worker 
349*da0073e9SAndroid Build Coastguard Worker   /****** List Optional Tensor Type ******/
350*da0073e9SAndroid Build Coastguard Worker   /*implicit*/ EValue(EValObjectList<std::optional<at::Tensor>> t)
351*da0073e9SAndroid Build Coastguard Worker       : tag(Tag::ListOptionalTensor) {
352*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_list_optional_tensor = t;
353*da0073e9SAndroid Build Coastguard Worker   }
354*da0073e9SAndroid Build Coastguard Worker 
355*da0073e9SAndroid Build Coastguard Worker   bool isListOptionalTensor() const {
356*da0073e9SAndroid Build Coastguard Worker     return tag == Tag::ListOptionalTensor;
357*da0073e9SAndroid Build Coastguard Worker   }
358*da0073e9SAndroid Build Coastguard Worker 
359*da0073e9SAndroid Build Coastguard Worker   at::ArrayRef<std::optional<at::Tensor>> toListOptionalTensor() {
360*da0073e9SAndroid Build Coastguard Worker     return payload.copyable_union.as_list_optional_tensor.get();
361*da0073e9SAndroid Build Coastguard Worker   }
362*da0073e9SAndroid Build Coastguard Worker 
363*da0073e9SAndroid Build Coastguard Worker   /****** ScalarType Type ******/
364*da0073e9SAndroid Build Coastguard Worker   at::ScalarType toScalarType() const {
365*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isInt(), "EValue is not a ScalarType.");
366*da0073e9SAndroid Build Coastguard Worker     return static_cast<at::ScalarType>(payload.copyable_union.as_int);
367*da0073e9SAndroid Build Coastguard Worker   }
368*da0073e9SAndroid Build Coastguard Worker 
369*da0073e9SAndroid Build Coastguard Worker   /****** MemoryFormat Type ******/
370*da0073e9SAndroid Build Coastguard Worker   at::MemoryFormat toMemoryFormat() const {
371*da0073e9SAndroid Build Coastguard Worker     ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat.");
372*da0073e9SAndroid Build Coastguard Worker     return static_cast<at::MemoryFormat>(payload.copyable_union.as_int);
373*da0073e9SAndroid Build Coastguard Worker   }
374*da0073e9SAndroid Build Coastguard Worker 
375*da0073e9SAndroid Build Coastguard Worker   template <typename T>
376*da0073e9SAndroid Build Coastguard Worker   T to() &&;
377*da0073e9SAndroid Build Coastguard Worker 
378*da0073e9SAndroid Build Coastguard Worker   template <typename T>
379*da0073e9SAndroid Build Coastguard Worker   typename evalue_to_ref_overload_return<T>::type to() &;
380*da0073e9SAndroid Build Coastguard Worker 
381*da0073e9SAndroid Build Coastguard Worker   /**
382*da0073e9SAndroid Build Coastguard Worker    * Converts the EValue to an optional object that can represent both T and
383*da0073e9SAndroid Build Coastguard Worker    * an uninitialized state.
384*da0073e9SAndroid Build Coastguard Worker    */
385*da0073e9SAndroid Build Coastguard Worker   template <typename T>
386*da0073e9SAndroid Build Coastguard Worker   inline std::optional<T> toOptional() {
387*da0073e9SAndroid Build Coastguard Worker     if (this->isNone()) {
388*da0073e9SAndroid Build Coastguard Worker       return std::nullopt;
389*da0073e9SAndroid Build Coastguard Worker     }
390*da0073e9SAndroid Build Coastguard Worker     return this->to<T>();
391*da0073e9SAndroid Build Coastguard Worker   }
392*da0073e9SAndroid Build Coastguard Worker 
393*da0073e9SAndroid Build Coastguard Worker  private:
394*da0073e9SAndroid Build Coastguard Worker   // Pre cond: the payload value has had its destructor called
395*da0073e9SAndroid Build Coastguard Worker   void clearToNone() noexcept {
396*da0073e9SAndroid Build Coastguard Worker     payload.copyable_union.as_int = 0;
397*da0073e9SAndroid Build Coastguard Worker     tag = Tag::None;
398*da0073e9SAndroid Build Coastguard Worker   }
399*da0073e9SAndroid Build Coastguard Worker 
400*da0073e9SAndroid Build Coastguard Worker   // Shared move logic
401*da0073e9SAndroid Build Coastguard Worker   void moveFrom(EValue&& rhs) noexcept {
402*da0073e9SAndroid Build Coastguard Worker     if (rhs.isTensor()) {
403*da0073e9SAndroid Build Coastguard Worker       new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
404*da0073e9SAndroid Build Coastguard Worker       rhs.payload.as_tensor.~Tensor();
405*da0073e9SAndroid Build Coastguard Worker     } else {
406*da0073e9SAndroid Build Coastguard Worker       payload.copyable_union = rhs.payload.copyable_union;
407*da0073e9SAndroid Build Coastguard Worker     }
408*da0073e9SAndroid Build Coastguard Worker     tag = rhs.tag;
409*da0073e9SAndroid Build Coastguard Worker     rhs.clearToNone();
410*da0073e9SAndroid Build Coastguard Worker   }
411*da0073e9SAndroid Build Coastguard Worker 
412*da0073e9SAndroid Build Coastguard Worker   // Destructs stored tensor if there is one
413*da0073e9SAndroid Build Coastguard Worker   void destroy() {
414*da0073e9SAndroid Build Coastguard Worker     // Necessary for ATen tensor to refcount decrement the intrusive_ptr to
415*da0073e9SAndroid Build Coastguard Worker     // tensorimpl that got a refcount increment when we placed it in the evalue,
416*da0073e9SAndroid Build Coastguard Worker     // no-op if executorch tensor #ifdef could have a
417*da0073e9SAndroid Build Coastguard Worker     // minor performance bump for a code maintainability hit
418*da0073e9SAndroid Build Coastguard Worker     if (isTensor()) {
419*da0073e9SAndroid Build Coastguard Worker       payload.as_tensor.~Tensor();
420*da0073e9SAndroid Build Coastguard Worker     } else if (isTensorList()) {
421*da0073e9SAndroid Build Coastguard Worker       for (auto& tensor : toTensorList()) {
422*da0073e9SAndroid Build Coastguard Worker         tensor.~Tensor();
423*da0073e9SAndroid Build Coastguard Worker       }
424*da0073e9SAndroid Build Coastguard Worker     } else if (isListOptionalTensor()) {
425*da0073e9SAndroid Build Coastguard Worker       for (auto& optional_tensor : toListOptionalTensor()) {
426*da0073e9SAndroid Build Coastguard Worker         optional_tensor.~optional();
427*da0073e9SAndroid Build Coastguard Worker       }
428*da0073e9SAndroid Build Coastguard Worker     }
429*da0073e9SAndroid Build Coastguard Worker   }
430*da0073e9SAndroid Build Coastguard Worker 
431*da0073e9SAndroid Build Coastguard Worker   EValue(const Payload& p, Tag t) : tag(t) {
432*da0073e9SAndroid Build Coastguard Worker     if (isTensor()) {
433*da0073e9SAndroid Build Coastguard Worker       new (&payload.as_tensor) at::Tensor(p.as_tensor);
434*da0073e9SAndroid Build Coastguard Worker     } else {
435*da0073e9SAndroid Build Coastguard Worker       payload.copyable_union = p.copyable_union;
436*da0073e9SAndroid Build Coastguard Worker     }
437*da0073e9SAndroid Build Coastguard Worker   }
438*da0073e9SAndroid Build Coastguard Worker };
439*da0073e9SAndroid Build Coastguard Worker 
440*da0073e9SAndroid Build Coastguard Worker #define EVALUE_DEFINE_TO(T, method_name)                           \
441*da0073e9SAndroid Build Coastguard Worker   template <>                                                      \
442*da0073e9SAndroid Build Coastguard Worker   inline evalue_to_ref_overload_return<T>::type EValue::to<T>()& { \
443*da0073e9SAndroid Build Coastguard Worker     return static_cast<T>(this->method_name());                    \
444*da0073e9SAndroid Build Coastguard Worker   }
445*da0073e9SAndroid Build Coastguard Worker 
446*da0073e9SAndroid Build Coastguard Worker template <>
447*da0073e9SAndroid Build Coastguard Worker inline at::Tensor& EValue::to<at::Tensor>() & {
448*da0073e9SAndroid Build Coastguard Worker   return this->toTensor();
449*da0073e9SAndroid Build Coastguard Worker }
450*da0073e9SAndroid Build Coastguard Worker 
451*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::Scalar, toScalar)
452*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(int64_t, toInt)
453*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(bool, toBool)
454*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(double, toDouble)
455*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::string_view, toString)
456*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::ScalarType, toScalarType)
457*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat)
458*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(std::optional<at::Tensor>, toOptional<at::Tensor>)
459*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::ArrayRef<int64_t>, toIntList)
460*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(
461*da0073e9SAndroid Build Coastguard Worker     std::optional<at::ArrayRef<int64_t>>,
462*da0073e9SAndroid Build Coastguard Worker     toOptional<at::ArrayRef<int64_t>>)
463*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(
464*da0073e9SAndroid Build Coastguard Worker     std::optional<at::ArrayRef<double>>,
465*da0073e9SAndroid Build Coastguard Worker     toOptional<at::ArrayRef<double>>)
466*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::ArrayRef<std::optional<at::Tensor>>, toListOptionalTensor)
467*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::ArrayRef<double>, toDoubleList)
468*da0073e9SAndroid Build Coastguard Worker #undef EVALUE_DEFINE_TO
469*da0073e9SAndroid Build Coastguard Worker 
470*da0073e9SAndroid Build Coastguard Worker template <typename T>
471*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<T> EValObjectList<T>::get() const {
472*da0073e9SAndroid Build Coastguard Worker   for (size_t i = 0; i < wrapped_vals_.size(); i++) {
473*da0073e9SAndroid Build Coastguard Worker     unwrapped_vals_[i] = wrapped_vals_[i]->template to<T>();
474*da0073e9SAndroid Build Coastguard Worker   }
475*da0073e9SAndroid Build Coastguard Worker   return at::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};
476*da0073e9SAndroid Build Coastguard Worker }
477*da0073e9SAndroid Build Coastguard Worker 
478*da0073e9SAndroid Build Coastguard Worker } // namespace executor
479*da0073e9SAndroid Build Coastguard Worker } // namespace torch
480