1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <ATen/ATen.h> 4*da0073e9SAndroid Build Coastguard Worker /** 5*da0073e9SAndroid Build Coastguard Worker * WARNING: EValue is a class used by Executorch, for its boxed operators. It 6*da0073e9SAndroid Build Coastguard Worker * contains similar logic as `IValue` in PyTorch, by providing APIs to convert 7*da0073e9SAndroid Build Coastguard Worker * boxed values to unboxed values. 8*da0073e9SAndroid Build Coastguard Worker * 9*da0073e9SAndroid Build Coastguard Worker * It's mirroring a fbcode internal source file 10*da0073e9SAndroid Build Coastguard Worker * [`EValue.h`](https://www.internalfb.com/code/fbsource/xplat/executorch/core/values/Evalue.h). 11*da0073e9SAndroid Build Coastguard Worker * 12*da0073e9SAndroid Build Coastguard Worker * The reason why we are mirroring this class, is to make sure we have CI job 13*da0073e9SAndroid Build Coastguard Worker * coverage on torchgen logic, given that torchgen is used for both Executorch 14*da0073e9SAndroid Build Coastguard Worker * and PyTorch. 15*da0073e9SAndroid Build Coastguard Worker * 16*da0073e9SAndroid Build Coastguard Worker * If any of the logic here needs to be changed, please update fbcode version of 17*da0073e9SAndroid Build Coastguard Worker * `Evalue.h` as well. These two versions will be merged as soon as Executorch 18*da0073e9SAndroid Build Coastguard Worker * is in OSS (hopefully by Q2 2023). 19*da0073e9SAndroid Build Coastguard Worker */ 20*da0073e9SAndroid Build Coastguard Worker namespace torch { 21*da0073e9SAndroid Build Coastguard Worker namespace executor { 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker #define ET_CHECK_MSG TORCH_CHECK_MSG 24*da0073e9SAndroid Build Coastguard Worker #define EXECUTORCH_FORALL_TAGS(_) \ 25*da0073e9SAndroid Build Coastguard Worker _(None) \ 26*da0073e9SAndroid Build Coastguard Worker _(Tensor) \ 27*da0073e9SAndroid Build Coastguard Worker _(String) \ 28*da0073e9SAndroid Build Coastguard Worker _(Double) \ 29*da0073e9SAndroid Build Coastguard Worker _(Int) \ 30*da0073e9SAndroid Build Coastguard Worker _(Bool) \ 31*da0073e9SAndroid Build Coastguard Worker _(ListBool) \ 32*da0073e9SAndroid Build Coastguard Worker _(ListDouble) \ 33*da0073e9SAndroid Build Coastguard Worker _(ListInt) \ 34*da0073e9SAndroid Build Coastguard Worker _(ListTensor) \ 35*da0073e9SAndroid Build Coastguard Worker _(ListScalar) \ 36*da0073e9SAndroid Build Coastguard Worker _(ListOptionalTensor) 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker enum class Tag : uint32_t { 39*da0073e9SAndroid Build Coastguard Worker #define DEFINE_TAG(x) x, 40*da0073e9SAndroid Build Coastguard Worker EXECUTORCH_FORALL_TAGS(DEFINE_TAG) 41*da0073e9SAndroid Build Coastguard Worker #undef DEFINE_TAG 42*da0073e9SAndroid Build Coastguard Worker }; 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker struct EValue; 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker template <typename T> 47*da0073e9SAndroid Build Coastguard Worker struct evalue_to_const_ref_overload_return { 48*da0073e9SAndroid Build Coastguard Worker using type = T; 49*da0073e9SAndroid Build Coastguard Worker }; 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker template <> 52*da0073e9SAndroid Build Coastguard Worker struct evalue_to_const_ref_overload_return<at::Tensor> { 53*da0073e9SAndroid Build Coastguard Worker using type = const at::Tensor&; 54*da0073e9SAndroid Build Coastguard Worker }; 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker template <typename T> 57*da0073e9SAndroid Build Coastguard Worker struct evalue_to_ref_overload_return { 58*da0073e9SAndroid Build Coastguard Worker using type = T; 59*da0073e9SAndroid Build Coastguard Worker }; 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker template <> 62*da0073e9SAndroid Build Coastguard Worker struct evalue_to_ref_overload_return<at::Tensor> { 63*da0073e9SAndroid Build Coastguard Worker using type = at::Tensor&; 64*da0073e9SAndroid Build Coastguard Worker }; 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker /* 67*da0073e9SAndroid Build Coastguard Worker * Helper class used to correlate EValues in the executor table, with the 68*da0073e9SAndroid Build Coastguard Worker * unwrapped list of the proper type. Because values in the runtime's values 69*da0073e9SAndroid Build Coastguard Worker * table can change during execution, we cannot statically allocate list of 70*da0073e9SAndroid Build Coastguard Worker * objects at deserialization. Imagine the serialized list says index 0 in the 71*da0073e9SAndroid Build Coastguard Worker * value table is element 2 in the list, but during execution the value in 72*da0073e9SAndroid Build Coastguard Worker * element 2 changes (in the case of tensor this means the TensorImpl* stored in 73*da0073e9SAndroid Build Coastguard Worker * the tensor changes). To solve this instead they must be created dynamically 74*da0073e9SAndroid Build Coastguard Worker * whenever they are used. 75*da0073e9SAndroid Build Coastguard Worker */ 76*da0073e9SAndroid Build Coastguard Worker template <typename T> 77*da0073e9SAndroid Build Coastguard Worker class EValObjectList { 78*da0073e9SAndroid Build Coastguard Worker public: 79*da0073e9SAndroid Build Coastguard Worker EValObjectList() = default; 80*da0073e9SAndroid Build Coastguard Worker /* 81*da0073e9SAndroid Build Coastguard Worker * Wrapped_vals is a list of pointers into the values table of the runtime 82*da0073e9SAndroid Build Coastguard Worker * whose destinations correlate with the elements of the list, unwrapped_vals 83*da0073e9SAndroid Build Coastguard Worker * is a container of the same size whose serves as memory to construct the 84*da0073e9SAndroid Build Coastguard Worker * unwrapped vals. 85*da0073e9SAndroid Build Coastguard Worker */ 86*da0073e9SAndroid Build Coastguard Worker EValObjectList(EValue** wrapped_vals, T* unwrapped_vals, int size) 87*da0073e9SAndroid Build Coastguard Worker : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {} 88*da0073e9SAndroid Build Coastguard Worker /* 89*da0073e9SAndroid Build Coastguard Worker * Constructs and returns the list of T specified by the EValue pointers 90*da0073e9SAndroid Build Coastguard Worker */ 91*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<T> get() const; 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker private: 94*da0073e9SAndroid Build Coastguard Worker // Source of truth for the list 95*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<EValue*> wrapped_vals_; 96*da0073e9SAndroid Build Coastguard Worker // Same size as wrapped_vals 97*da0073e9SAndroid Build Coastguard Worker mutable T* unwrapped_vals_; 98*da0073e9SAndroid Build Coastguard Worker }; 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker // Aggregate typing system similar to IValue only slimmed down with less 101*da0073e9SAndroid Build Coastguard Worker // functionality, no dependencies on atomic, and fewer supported types to better 102*da0073e9SAndroid Build Coastguard Worker // suit embedded systems (ie no intrusive ptr) 103*da0073e9SAndroid Build Coastguard Worker struct EValue { 104*da0073e9SAndroid Build Coastguard Worker union Payload { 105*da0073e9SAndroid Build Coastguard Worker // When in ATen mode at::Tensor is not trivially copyable, this nested union 106*da0073e9SAndroid Build Coastguard Worker // lets us handle tensor as a special case while leaving the rest of the 107*da0073e9SAndroid Build Coastguard Worker // fields in a simple state instead of requiring a switch on tag everywhere. 108*da0073e9SAndroid Build Coastguard Worker union TriviallyCopyablePayload { 109*da0073e9SAndroid Build Coastguard Worker TriviallyCopyablePayload() : as_int(0) {} 110*da0073e9SAndroid Build Coastguard Worker // Scalar supported through these 3 types 111*da0073e9SAndroid Build Coastguard Worker int64_t as_int; 112*da0073e9SAndroid Build Coastguard Worker double as_double; 113*da0073e9SAndroid Build Coastguard Worker bool as_bool; 114*da0073e9SAndroid Build Coastguard Worker // TODO(jakeszwe): convert back to pointers to optimize size of this 115*da0073e9SAndroid Build Coastguard Worker // struct 116*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<char> as_string; 117*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<int64_t> as_int_list; 118*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<double> as_double_list; 119*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<bool> as_bool_list; 120*da0073e9SAndroid Build Coastguard Worker EValObjectList<at::Tensor> as_tensor_list; 121*da0073e9SAndroid Build Coastguard Worker EValObjectList<std::optional<at::Tensor>> as_list_optional_tensor; 122*da0073e9SAndroid Build Coastguard Worker } copyable_union; 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor* 125*da0073e9SAndroid Build Coastguard Worker // here. 126*da0073e9SAndroid Build Coastguard Worker at::Tensor as_tensor; 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker Payload() {} 129*da0073e9SAndroid Build Coastguard Worker ~Payload() {} 130*da0073e9SAndroid Build Coastguard Worker }; 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker // Data storage and type tag 133*da0073e9SAndroid Build Coastguard Worker Payload payload; 134*da0073e9SAndroid Build Coastguard Worker Tag tag; 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker // Basic ctors and assignments 137*da0073e9SAndroid Build Coastguard Worker EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {} 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker EValue(EValue&& rhs) noexcept : tag(rhs.tag) { 140*da0073e9SAndroid Build Coastguard Worker moveFrom(std::move(rhs)); 141*da0073e9SAndroid Build Coastguard Worker } 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker EValue& operator=(EValue&& rhs) & noexcept { 144*da0073e9SAndroid Build Coastguard Worker if (&rhs == this) { 145*da0073e9SAndroid Build Coastguard Worker return *this; 146*da0073e9SAndroid Build Coastguard Worker } 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker destroy(); 149*da0073e9SAndroid Build Coastguard Worker moveFrom(std::move(rhs)); 150*da0073e9SAndroid Build Coastguard Worker return *this; 151*da0073e9SAndroid Build Coastguard Worker } 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker EValue& operator=(EValue const& rhs) & { 154*da0073e9SAndroid Build Coastguard Worker // Define copy assignment through copy ctor and move assignment 155*da0073e9SAndroid Build Coastguard Worker *this = EValue(rhs); 156*da0073e9SAndroid Build Coastguard Worker return *this; 157*da0073e9SAndroid Build Coastguard Worker } 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker ~EValue() { 160*da0073e9SAndroid Build Coastguard Worker destroy(); 161*da0073e9SAndroid Build Coastguard Worker } 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker /****** None Type ******/ 164*da0073e9SAndroid Build Coastguard Worker EValue() : tag(Tag::None) { 165*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_int = 0; 166*da0073e9SAndroid Build Coastguard Worker } 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker bool isNone() const { 169*da0073e9SAndroid Build Coastguard Worker return tag == Tag::None; 170*da0073e9SAndroid Build Coastguard Worker } 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker /****** Int Type ******/ 173*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(int64_t i) : tag(Tag::Int) { 174*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_int = i; 175*da0073e9SAndroid Build Coastguard Worker } 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker bool isInt() const { 178*da0073e9SAndroid Build Coastguard Worker return tag == Tag::Int; 179*da0073e9SAndroid Build Coastguard Worker } 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker int64_t toInt() const { 182*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isInt(), "EValue is not an int."); 183*da0073e9SAndroid Build Coastguard Worker return payload.copyable_union.as_int; 184*da0073e9SAndroid Build Coastguard Worker } 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker /****** Double Type ******/ 187*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(double d) : tag(Tag::Double) { 188*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_double = d; 189*da0073e9SAndroid Build Coastguard Worker } 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker bool isDouble() const { 192*da0073e9SAndroid Build Coastguard Worker return tag == Tag::Double; 193*da0073e9SAndroid Build Coastguard Worker } 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker double toDouble() const { 196*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isDouble(), "EValue is not a Double."); 197*da0073e9SAndroid Build Coastguard Worker return payload.copyable_union.as_double; 198*da0073e9SAndroid Build Coastguard Worker } 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker /****** Bool Type ******/ 201*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(bool b) : tag(Tag::Bool) { 202*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_bool = b; 203*da0073e9SAndroid Build Coastguard Worker } 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker bool isBool() const { 206*da0073e9SAndroid Build Coastguard Worker return tag == Tag::Bool; 207*da0073e9SAndroid Build Coastguard Worker } 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker bool toBool() const { 210*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isBool(), "EValue is not a Bool."); 211*da0073e9SAndroid Build Coastguard Worker return payload.copyable_union.as_bool; 212*da0073e9SAndroid Build Coastguard Worker } 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker /****** Scalar Type ******/ 215*da0073e9SAndroid Build Coastguard Worker /// Construct an EValue using the implicit value of a Scalar. 216*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(at::Scalar s) { 217*da0073e9SAndroid Build Coastguard Worker if (s.isIntegral(false)) { 218*da0073e9SAndroid Build Coastguard Worker tag = Tag::Int; 219*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_int = s.to<int64_t>(); 220*da0073e9SAndroid Build Coastguard Worker } else if (s.isFloatingPoint()) { 221*da0073e9SAndroid Build Coastguard Worker tag = Tag::Double; 222*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_double = s.to<double>(); 223*da0073e9SAndroid Build Coastguard Worker } else if (s.isBoolean()) { 224*da0073e9SAndroid Build Coastguard Worker tag = Tag::Bool; 225*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_bool = s.to<bool>(); 226*da0073e9SAndroid Build Coastguard Worker } else { 227*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized."); 228*da0073e9SAndroid Build Coastguard Worker } 229*da0073e9SAndroid Build Coastguard Worker } 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker bool isScalar() const { 232*da0073e9SAndroid Build Coastguard Worker return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool; 233*da0073e9SAndroid Build Coastguard Worker } 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker at::Scalar toScalar() const { 236*da0073e9SAndroid Build Coastguard Worker // Convert from implicit value to Scalar using implicit constructors. 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker if (isDouble()) { 239*da0073e9SAndroid Build Coastguard Worker return toDouble(); 240*da0073e9SAndroid Build Coastguard Worker } else if (isInt()) { 241*da0073e9SAndroid Build Coastguard Worker return toInt(); 242*da0073e9SAndroid Build Coastguard Worker } else if (isBool()) { 243*da0073e9SAndroid Build Coastguard Worker return toBool(); 244*da0073e9SAndroid Build Coastguard Worker } else { 245*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(false, "EValue is not a Scalar."); 246*da0073e9SAndroid Build Coastguard Worker return c10::Scalar(); 247*da0073e9SAndroid Build Coastguard Worker } 248*da0073e9SAndroid Build Coastguard Worker } 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker /****** Tensor Type ******/ 251*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(at::Tensor t) : tag(Tag::Tensor) { 252*da0073e9SAndroid Build Coastguard Worker // When built in aten mode, at::Tensor has a non trivial constructor 253*da0073e9SAndroid Build Coastguard Worker // destructor, so regular assignment to a union field is UB. Instead we must 254*da0073e9SAndroid Build Coastguard Worker // go through placement new (which causes a refcount bump). 255*da0073e9SAndroid Build Coastguard Worker new (&payload.as_tensor) at::Tensor(t); 256*da0073e9SAndroid Build Coastguard Worker } 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker bool isTensor() const { 259*da0073e9SAndroid Build Coastguard Worker return tag == Tag::Tensor; 260*da0073e9SAndroid Build Coastguard Worker } 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker at::Tensor toTensor() && { 263*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); 264*da0073e9SAndroid Build Coastguard Worker return std::move(payload.as_tensor); 265*da0073e9SAndroid Build Coastguard Worker } 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker at::Tensor& toTensor() & { 268*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); 269*da0073e9SAndroid Build Coastguard Worker return payload.as_tensor; 270*da0073e9SAndroid Build Coastguard Worker } 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker const at::Tensor& toTensor() const& { 273*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); 274*da0073e9SAndroid Build Coastguard Worker return payload.as_tensor; 275*da0073e9SAndroid Build Coastguard Worker } 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker /****** String Type ******/ 278*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) { 279*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_string = at::ArrayRef<char>(s, size); 280*da0073e9SAndroid Build Coastguard Worker } 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker bool isString() const { 283*da0073e9SAndroid Build Coastguard Worker return tag == Tag::String; 284*da0073e9SAndroid Build Coastguard Worker } 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker at::string_view toString() const { 287*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isString(), "EValue is not a String."); 288*da0073e9SAndroid Build Coastguard Worker return at::string_view( 289*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_string.data(), 290*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_string.size()); 291*da0073e9SAndroid Build Coastguard Worker } 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker /****** Int List Type ******/ 294*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(at::ArrayRef<int64_t> i) : tag(Tag::ListInt) { 295*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_int_list = i; 296*da0073e9SAndroid Build Coastguard Worker } 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker bool isIntList() const { 299*da0073e9SAndroid Build Coastguard Worker return tag == Tag::ListInt; 300*da0073e9SAndroid Build Coastguard Worker } 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<int64_t> toIntList() const { 303*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isIntList(), "EValue is not an Int List."); 304*da0073e9SAndroid Build Coastguard Worker return payload.copyable_union.as_int_list; 305*da0073e9SAndroid Build Coastguard Worker } 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker /****** Bool List Type ******/ 308*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(at::ArrayRef<bool> b) : tag(Tag::ListBool) { 309*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_bool_list = b; 310*da0073e9SAndroid Build Coastguard Worker } 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker bool isBoolList() const { 313*da0073e9SAndroid Build Coastguard Worker return tag == Tag::ListBool; 314*da0073e9SAndroid Build Coastguard Worker } 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<bool> toBoolList() const { 317*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List."); 318*da0073e9SAndroid Build Coastguard Worker return payload.copyable_union.as_bool_list; 319*da0073e9SAndroid Build Coastguard Worker } 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker /****** Double List Type ******/ 322*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(at::ArrayRef<double> d) : tag(Tag::ListDouble) { 323*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_double_list = d; 324*da0073e9SAndroid Build Coastguard Worker } 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker bool isDoubleList() const { 327*da0073e9SAndroid Build Coastguard Worker return tag == Tag::ListDouble; 328*da0073e9SAndroid Build Coastguard Worker } 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<double> toDoubleList() const { 331*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List."); 332*da0073e9SAndroid Build Coastguard Worker return payload.copyable_union.as_double_list; 333*da0073e9SAndroid Build Coastguard Worker } 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker /****** Tensor List Type ******/ 336*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(EValObjectList<at::Tensor> t) : tag(Tag::ListTensor) { 337*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_tensor_list = t; 338*da0073e9SAndroid Build Coastguard Worker } 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker bool isTensorList() const { 341*da0073e9SAndroid Build Coastguard Worker return tag == Tag::ListTensor; 342*da0073e9SAndroid Build Coastguard Worker } 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<at::Tensor> toTensorList() const { 345*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List."); 346*da0073e9SAndroid Build Coastguard Worker return payload.copyable_union.as_tensor_list.get(); 347*da0073e9SAndroid Build Coastguard Worker } 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker /****** List Optional Tensor Type ******/ 350*da0073e9SAndroid Build Coastguard Worker /*implicit*/ EValue(EValObjectList<std::optional<at::Tensor>> t) 351*da0073e9SAndroid Build Coastguard Worker : tag(Tag::ListOptionalTensor) { 352*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_list_optional_tensor = t; 353*da0073e9SAndroid Build Coastguard Worker } 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker bool isListOptionalTensor() const { 356*da0073e9SAndroid Build Coastguard Worker return tag == Tag::ListOptionalTensor; 357*da0073e9SAndroid Build Coastguard Worker } 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<std::optional<at::Tensor>> toListOptionalTensor() { 360*da0073e9SAndroid Build Coastguard Worker return payload.copyable_union.as_list_optional_tensor.get(); 361*da0073e9SAndroid Build Coastguard Worker } 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker /****** ScalarType Type ******/ 364*da0073e9SAndroid Build Coastguard Worker at::ScalarType toScalarType() const { 365*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); 366*da0073e9SAndroid Build Coastguard Worker return static_cast<at::ScalarType>(payload.copyable_union.as_int); 367*da0073e9SAndroid Build Coastguard Worker } 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker /****** MemoryFormat Type ******/ 370*da0073e9SAndroid Build Coastguard Worker at::MemoryFormat toMemoryFormat() const { 371*da0073e9SAndroid Build Coastguard Worker ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); 372*da0073e9SAndroid Build Coastguard Worker return static_cast<at::MemoryFormat>(payload.copyable_union.as_int); 373*da0073e9SAndroid Build Coastguard Worker } 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker template <typename T> 376*da0073e9SAndroid Build Coastguard Worker T to() &&; 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker template <typename T> 379*da0073e9SAndroid Build Coastguard Worker typename evalue_to_ref_overload_return<T>::type to() &; 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker /** 382*da0073e9SAndroid Build Coastguard Worker * Converts the EValue to an optional object that can represent both T and 383*da0073e9SAndroid Build Coastguard Worker * an uninitialized state. 384*da0073e9SAndroid Build Coastguard Worker */ 385*da0073e9SAndroid Build Coastguard Worker template <typename T> 386*da0073e9SAndroid Build Coastguard Worker inline std::optional<T> toOptional() { 387*da0073e9SAndroid Build Coastguard Worker if (this->isNone()) { 388*da0073e9SAndroid Build Coastguard Worker return std::nullopt; 389*da0073e9SAndroid Build Coastguard Worker } 390*da0073e9SAndroid Build Coastguard Worker return this->to<T>(); 391*da0073e9SAndroid Build Coastguard Worker } 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker private: 394*da0073e9SAndroid Build Coastguard Worker // Pre cond: the payload value has had its destructor called 395*da0073e9SAndroid Build Coastguard Worker void clearToNone() noexcept { 396*da0073e9SAndroid Build Coastguard Worker payload.copyable_union.as_int = 0; 397*da0073e9SAndroid Build Coastguard Worker tag = Tag::None; 398*da0073e9SAndroid Build Coastguard Worker } 399*da0073e9SAndroid Build Coastguard Worker 400*da0073e9SAndroid Build Coastguard Worker // Shared move logic 401*da0073e9SAndroid Build Coastguard Worker void moveFrom(EValue&& rhs) noexcept { 402*da0073e9SAndroid Build Coastguard Worker if (rhs.isTensor()) { 403*da0073e9SAndroid Build Coastguard Worker new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); 404*da0073e9SAndroid Build Coastguard Worker rhs.payload.as_tensor.~Tensor(); 405*da0073e9SAndroid Build Coastguard Worker } else { 406*da0073e9SAndroid Build Coastguard Worker payload.copyable_union = rhs.payload.copyable_union; 407*da0073e9SAndroid Build Coastguard Worker } 408*da0073e9SAndroid Build Coastguard Worker tag = rhs.tag; 409*da0073e9SAndroid Build Coastguard Worker rhs.clearToNone(); 410*da0073e9SAndroid Build Coastguard Worker } 411*da0073e9SAndroid Build Coastguard Worker 412*da0073e9SAndroid Build Coastguard Worker // Destructs stored tensor if there is one 413*da0073e9SAndroid Build Coastguard Worker void destroy() { 414*da0073e9SAndroid Build Coastguard Worker // Necessary for ATen tensor to refcount decrement the intrusive_ptr to 415*da0073e9SAndroid Build Coastguard Worker // tensorimpl that got a refcount increment when we placed it in the evalue, 416*da0073e9SAndroid Build Coastguard Worker // no-op if executorch tensor #ifdef could have a 417*da0073e9SAndroid Build Coastguard Worker // minor performance bump for a code maintainability hit 418*da0073e9SAndroid Build Coastguard Worker if (isTensor()) { 419*da0073e9SAndroid Build Coastguard Worker payload.as_tensor.~Tensor(); 420*da0073e9SAndroid Build Coastguard Worker } else if (isTensorList()) { 421*da0073e9SAndroid Build Coastguard Worker for (auto& tensor : toTensorList()) { 422*da0073e9SAndroid Build Coastguard Worker tensor.~Tensor(); 423*da0073e9SAndroid Build Coastguard Worker } 424*da0073e9SAndroid Build Coastguard Worker } else if (isListOptionalTensor()) { 425*da0073e9SAndroid Build Coastguard Worker for (auto& optional_tensor : toListOptionalTensor()) { 426*da0073e9SAndroid Build Coastguard Worker optional_tensor.~optional(); 427*da0073e9SAndroid Build Coastguard Worker } 428*da0073e9SAndroid Build Coastguard Worker } 429*da0073e9SAndroid Build Coastguard Worker } 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker EValue(const Payload& p, Tag t) : tag(t) { 432*da0073e9SAndroid Build Coastguard Worker if (isTensor()) { 433*da0073e9SAndroid Build Coastguard Worker new (&payload.as_tensor) at::Tensor(p.as_tensor); 434*da0073e9SAndroid Build Coastguard Worker } else { 435*da0073e9SAndroid Build Coastguard Worker payload.copyable_union = p.copyable_union; 436*da0073e9SAndroid Build Coastguard Worker } 437*da0073e9SAndroid Build Coastguard Worker } 438*da0073e9SAndroid Build Coastguard Worker }; 439*da0073e9SAndroid Build Coastguard Worker 440*da0073e9SAndroid Build Coastguard Worker #define EVALUE_DEFINE_TO(T, method_name) \ 441*da0073e9SAndroid Build Coastguard Worker template <> \ 442*da0073e9SAndroid Build Coastguard Worker inline evalue_to_ref_overload_return<T>::type EValue::to<T>()& { \ 443*da0073e9SAndroid Build Coastguard Worker return static_cast<T>(this->method_name()); \ 444*da0073e9SAndroid Build Coastguard Worker } 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Worker template <> 447*da0073e9SAndroid Build Coastguard Worker inline at::Tensor& EValue::to<at::Tensor>() & { 448*da0073e9SAndroid Build Coastguard Worker return this->toTensor(); 449*da0073e9SAndroid Build Coastguard Worker } 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::Scalar, toScalar) 452*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(int64_t, toInt) 453*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(bool, toBool) 454*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(double, toDouble) 455*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::string_view, toString) 456*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::ScalarType, toScalarType) 457*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat) 458*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(std::optional<at::Tensor>, toOptional<at::Tensor>) 459*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::ArrayRef<int64_t>, toIntList) 460*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO( 461*da0073e9SAndroid Build Coastguard Worker std::optional<at::ArrayRef<int64_t>>, 462*da0073e9SAndroid Build Coastguard Worker toOptional<at::ArrayRef<int64_t>>) 463*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO( 464*da0073e9SAndroid Build Coastguard Worker std::optional<at::ArrayRef<double>>, 465*da0073e9SAndroid Build Coastguard Worker toOptional<at::ArrayRef<double>>) 466*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::ArrayRef<std::optional<at::Tensor>>, toListOptionalTensor) 467*da0073e9SAndroid Build Coastguard Worker EVALUE_DEFINE_TO(at::ArrayRef<double>, toDoubleList) 468*da0073e9SAndroid Build Coastguard Worker #undef EVALUE_DEFINE_TO 469*da0073e9SAndroid Build Coastguard Worker 470*da0073e9SAndroid Build Coastguard Worker template <typename T> 471*da0073e9SAndroid Build Coastguard Worker at::ArrayRef<T> EValObjectList<T>::get() const { 472*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < wrapped_vals_.size(); i++) { 473*da0073e9SAndroid Build Coastguard Worker unwrapped_vals_[i] = wrapped_vals_[i]->template to<T>(); 474*da0073e9SAndroid Build Coastguard Worker } 475*da0073e9SAndroid Build Coastguard Worker return at::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()}; 476*da0073e9SAndroid Build Coastguard Worker } 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build Coastguard Worker } // namespace executor 479*da0073e9SAndroid Build Coastguard Worker } // namespace torch 480