1 #pragma once 2 3 #include <ATen/ATen.h> 4 /** 5 * WARNING: EValue is a class used by Executorch, for its boxed operators. It 6 * contains similar logic as `IValue` in PyTorch, by providing APIs to convert 7 * boxed values to unboxed values. 8 * 9 * It's mirroring a fbcode internal source file 10 * [`EValue.h`](https://www.internalfb.com/code/fbsource/xplat/executorch/core/values/Evalue.h). 11 * 12 * The reason why we are mirroring this class, is to make sure we have CI job 13 * coverage on torchgen logic, given that torchgen is used for both Executorch 14 * and PyTorch. 15 * 16 * If any of the logic here needs to be changed, please update fbcode version of 17 * `Evalue.h` as well. These two versions will be merged as soon as Executorch 18 * is in OSS (hopefully by Q2 2023). 19 */ 20 namespace torch { 21 namespace executor { 22 23 #define ET_CHECK_MSG TORCH_CHECK_MSG 24 #define EXECUTORCH_FORALL_TAGS(_) \ 25 _(None) \ 26 _(Tensor) \ 27 _(String) \ 28 _(Double) \ 29 _(Int) \ 30 _(Bool) \ 31 _(ListBool) \ 32 _(ListDouble) \ 33 _(ListInt) \ 34 _(ListTensor) \ 35 _(ListScalar) \ 36 _(ListOptionalTensor) 37 38 enum class Tag : uint32_t { 39 #define DEFINE_TAG(x) x, 40 EXECUTORCH_FORALL_TAGS(DEFINE_TAG) 41 #undef DEFINE_TAG 42 }; 43 44 struct EValue; 45 46 template <typename T> 47 struct evalue_to_const_ref_overload_return { 48 using type = T; 49 }; 50 51 template <> 52 struct evalue_to_const_ref_overload_return<at::Tensor> { 53 using type = const at::Tensor&; 54 }; 55 56 template <typename T> 57 struct evalue_to_ref_overload_return { 58 using type = T; 59 }; 60 61 template <> 62 struct evalue_to_ref_overload_return<at::Tensor> { 63 using type = at::Tensor&; 64 }; 65 66 /* 67 * Helper class used to correlate EValues in the executor table, with the 68 * unwrapped list of the proper type. Because values in the runtime's values 69 * table can change during execution, we cannot statically allocate list of 70 * objects at deserialization. Imagine the serialized list says index 0 in the 71 * value table is element 2 in the list, but during execution the value in 72 * element 2 changes (in the case of tensor this means the TensorImpl* stored in 73 * the tensor changes). To solve this instead they must be created dynamically 74 * whenever they are used. 75 */ 76 template <typename T> 77 class EValObjectList { 78 public: 79 EValObjectList() = default; 80 /* 81 * Wrapped_vals is a list of pointers into the values table of the runtime 82 * whose destinations correlate with the elements of the list, unwrapped_vals 83 * is a container of the same size whose serves as memory to construct the 84 * unwrapped vals. 85 */ 86 EValObjectList(EValue** wrapped_vals, T* unwrapped_vals, int size) 87 : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {} 88 /* 89 * Constructs and returns the list of T specified by the EValue pointers 90 */ 91 at::ArrayRef<T> get() const; 92 93 private: 94 // Source of truth for the list 95 at::ArrayRef<EValue*> wrapped_vals_; 96 // Same size as wrapped_vals 97 mutable T* unwrapped_vals_; 98 }; 99 100 // Aggregate typing system similar to IValue only slimmed down with less 101 // functionality, no dependencies on atomic, and fewer supported types to better 102 // suit embedded systems (ie no intrusive ptr) 103 struct EValue { 104 union Payload { 105 // When in ATen mode at::Tensor is not trivially copyable, this nested union 106 // lets us handle tensor as a special case while leaving the rest of the 107 // fields in a simple state instead of requiring a switch on tag everywhere. 108 union TriviallyCopyablePayload { 109 TriviallyCopyablePayload() : as_int(0) {} 110 // Scalar supported through these 3 types 111 int64_t as_int; 112 double as_double; 113 bool as_bool; 114 // TODO(jakeszwe): convert back to pointers to optimize size of this 115 // struct 116 at::ArrayRef<char> as_string; 117 at::ArrayRef<int64_t> as_int_list; 118 at::ArrayRef<double> as_double_list; 119 at::ArrayRef<bool> as_bool_list; 120 EValObjectList<at::Tensor> as_tensor_list; 121 EValObjectList<std::optional<at::Tensor>> as_list_optional_tensor; 122 } copyable_union; 123 124 // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor* 125 // here. 126 at::Tensor as_tensor; 127 128 Payload() {} 129 ~Payload() {} 130 }; 131 132 // Data storage and type tag 133 Payload payload; 134 Tag tag; 135 136 // Basic ctors and assignments 137 EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {} 138 139 EValue(EValue&& rhs) noexcept : tag(rhs.tag) { 140 moveFrom(std::move(rhs)); 141 } 142 143 EValue& operator=(EValue&& rhs) & noexcept { 144 if (&rhs == this) { 145 return *this; 146 } 147 148 destroy(); 149 moveFrom(std::move(rhs)); 150 return *this; 151 } 152 153 EValue& operator=(EValue const& rhs) & { 154 // Define copy assignment through copy ctor and move assignment 155 *this = EValue(rhs); 156 return *this; 157 } 158 159 ~EValue() { 160 destroy(); 161 } 162 163 /****** None Type ******/ 164 EValue() : tag(Tag::None) { 165 payload.copyable_union.as_int = 0; 166 } 167 168 bool isNone() const { 169 return tag == Tag::None; 170 } 171 172 /****** Int Type ******/ 173 /*implicit*/ EValue(int64_t i) : tag(Tag::Int) { 174 payload.copyable_union.as_int = i; 175 } 176 177 bool isInt() const { 178 return tag == Tag::Int; 179 } 180 181 int64_t toInt() const { 182 ET_CHECK_MSG(isInt(), "EValue is not an int."); 183 return payload.copyable_union.as_int; 184 } 185 186 /****** Double Type ******/ 187 /*implicit*/ EValue(double d) : tag(Tag::Double) { 188 payload.copyable_union.as_double = d; 189 } 190 191 bool isDouble() const { 192 return tag == Tag::Double; 193 } 194 195 double toDouble() const { 196 ET_CHECK_MSG(isDouble(), "EValue is not a Double."); 197 return payload.copyable_union.as_double; 198 } 199 200 /****** Bool Type ******/ 201 /*implicit*/ EValue(bool b) : tag(Tag::Bool) { 202 payload.copyable_union.as_bool = b; 203 } 204 205 bool isBool() const { 206 return tag == Tag::Bool; 207 } 208 209 bool toBool() const { 210 ET_CHECK_MSG(isBool(), "EValue is not a Bool."); 211 return payload.copyable_union.as_bool; 212 } 213 214 /****** Scalar Type ******/ 215 /// Construct an EValue using the implicit value of a Scalar. 216 /*implicit*/ EValue(at::Scalar s) { 217 if (s.isIntegral(false)) { 218 tag = Tag::Int; 219 payload.copyable_union.as_int = s.to<int64_t>(); 220 } else if (s.isFloatingPoint()) { 221 tag = Tag::Double; 222 payload.copyable_union.as_double = s.to<double>(); 223 } else if (s.isBoolean()) { 224 tag = Tag::Bool; 225 payload.copyable_union.as_bool = s.to<bool>(); 226 } else { 227 ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized."); 228 } 229 } 230 231 bool isScalar() const { 232 return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool; 233 } 234 235 at::Scalar toScalar() const { 236 // Convert from implicit value to Scalar using implicit constructors. 237 238 if (isDouble()) { 239 return toDouble(); 240 } else if (isInt()) { 241 return toInt(); 242 } else if (isBool()) { 243 return toBool(); 244 } else { 245 ET_CHECK_MSG(false, "EValue is not a Scalar."); 246 return c10::Scalar(); 247 } 248 } 249 250 /****** Tensor Type ******/ 251 /*implicit*/ EValue(at::Tensor t) : tag(Tag::Tensor) { 252 // When built in aten mode, at::Tensor has a non trivial constructor 253 // destructor, so regular assignment to a union field is UB. Instead we must 254 // go through placement new (which causes a refcount bump). 255 new (&payload.as_tensor) at::Tensor(t); 256 } 257 258 bool isTensor() const { 259 return tag == Tag::Tensor; 260 } 261 262 at::Tensor toTensor() && { 263 ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); 264 return std::move(payload.as_tensor); 265 } 266 267 at::Tensor& toTensor() & { 268 ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); 269 return payload.as_tensor; 270 } 271 272 const at::Tensor& toTensor() const& { 273 ET_CHECK_MSG(isTensor(), "EValue is not a Tensor."); 274 return payload.as_tensor; 275 } 276 277 /****** String Type ******/ 278 /*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) { 279 payload.copyable_union.as_string = at::ArrayRef<char>(s, size); 280 } 281 282 bool isString() const { 283 return tag == Tag::String; 284 } 285 286 at::string_view toString() const { 287 ET_CHECK_MSG(isString(), "EValue is not a String."); 288 return at::string_view( 289 payload.copyable_union.as_string.data(), 290 payload.copyable_union.as_string.size()); 291 } 292 293 /****** Int List Type ******/ 294 /*implicit*/ EValue(at::ArrayRef<int64_t> i) : tag(Tag::ListInt) { 295 payload.copyable_union.as_int_list = i; 296 } 297 298 bool isIntList() const { 299 return tag == Tag::ListInt; 300 } 301 302 at::ArrayRef<int64_t> toIntList() const { 303 ET_CHECK_MSG(isIntList(), "EValue is not an Int List."); 304 return payload.copyable_union.as_int_list; 305 } 306 307 /****** Bool List Type ******/ 308 /*implicit*/ EValue(at::ArrayRef<bool> b) : tag(Tag::ListBool) { 309 payload.copyable_union.as_bool_list = b; 310 } 311 312 bool isBoolList() const { 313 return tag == Tag::ListBool; 314 } 315 316 at::ArrayRef<bool> toBoolList() const { 317 ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List."); 318 return payload.copyable_union.as_bool_list; 319 } 320 321 /****** Double List Type ******/ 322 /*implicit*/ EValue(at::ArrayRef<double> d) : tag(Tag::ListDouble) { 323 payload.copyable_union.as_double_list = d; 324 } 325 326 bool isDoubleList() const { 327 return tag == Tag::ListDouble; 328 } 329 330 at::ArrayRef<double> toDoubleList() const { 331 ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List."); 332 return payload.copyable_union.as_double_list; 333 } 334 335 /****** Tensor List Type ******/ 336 /*implicit*/ EValue(EValObjectList<at::Tensor> t) : tag(Tag::ListTensor) { 337 payload.copyable_union.as_tensor_list = t; 338 } 339 340 bool isTensorList() const { 341 return tag == Tag::ListTensor; 342 } 343 344 at::ArrayRef<at::Tensor> toTensorList() const { 345 ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List."); 346 return payload.copyable_union.as_tensor_list.get(); 347 } 348 349 /****** List Optional Tensor Type ******/ 350 /*implicit*/ EValue(EValObjectList<std::optional<at::Tensor>> t) 351 : tag(Tag::ListOptionalTensor) { 352 payload.copyable_union.as_list_optional_tensor = t; 353 } 354 355 bool isListOptionalTensor() const { 356 return tag == Tag::ListOptionalTensor; 357 } 358 359 at::ArrayRef<std::optional<at::Tensor>> toListOptionalTensor() { 360 return payload.copyable_union.as_list_optional_tensor.get(); 361 } 362 363 /****** ScalarType Type ******/ 364 at::ScalarType toScalarType() const { 365 ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); 366 return static_cast<at::ScalarType>(payload.copyable_union.as_int); 367 } 368 369 /****** MemoryFormat Type ******/ 370 at::MemoryFormat toMemoryFormat() const { 371 ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); 372 return static_cast<at::MemoryFormat>(payload.copyable_union.as_int); 373 } 374 375 template <typename T> 376 T to() &&; 377 378 template <typename T> 379 typename evalue_to_ref_overload_return<T>::type to() &; 380 381 /** 382 * Converts the EValue to an optional object that can represent both T and 383 * an uninitialized state. 384 */ 385 template <typename T> 386 inline std::optional<T> toOptional() { 387 if (this->isNone()) { 388 return std::nullopt; 389 } 390 return this->to<T>(); 391 } 392 393 private: 394 // Pre cond: the payload value has had its destructor called 395 void clearToNone() noexcept { 396 payload.copyable_union.as_int = 0; 397 tag = Tag::None; 398 } 399 400 // Shared move logic 401 void moveFrom(EValue&& rhs) noexcept { 402 if (rhs.isTensor()) { 403 new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); 404 rhs.payload.as_tensor.~Tensor(); 405 } else { 406 payload.copyable_union = rhs.payload.copyable_union; 407 } 408 tag = rhs.tag; 409 rhs.clearToNone(); 410 } 411 412 // Destructs stored tensor if there is one 413 void destroy() { 414 // Necessary for ATen tensor to refcount decrement the intrusive_ptr to 415 // tensorimpl that got a refcount increment when we placed it in the evalue, 416 // no-op if executorch tensor #ifdef could have a 417 // minor performance bump for a code maintainability hit 418 if (isTensor()) { 419 payload.as_tensor.~Tensor(); 420 } else if (isTensorList()) { 421 for (auto& tensor : toTensorList()) { 422 tensor.~Tensor(); 423 } 424 } else if (isListOptionalTensor()) { 425 for (auto& optional_tensor : toListOptionalTensor()) { 426 optional_tensor.~optional(); 427 } 428 } 429 } 430 431 EValue(const Payload& p, Tag t) : tag(t) { 432 if (isTensor()) { 433 new (&payload.as_tensor) at::Tensor(p.as_tensor); 434 } else { 435 payload.copyable_union = p.copyable_union; 436 } 437 } 438 }; 439 440 #define EVALUE_DEFINE_TO(T, method_name) \ 441 template <> \ 442 inline evalue_to_ref_overload_return<T>::type EValue::to<T>()& { \ 443 return static_cast<T>(this->method_name()); \ 444 } 445 446 template <> 447 inline at::Tensor& EValue::to<at::Tensor>() & { 448 return this->toTensor(); 449 } 450 451 EVALUE_DEFINE_TO(at::Scalar, toScalar) 452 EVALUE_DEFINE_TO(int64_t, toInt) 453 EVALUE_DEFINE_TO(bool, toBool) 454 EVALUE_DEFINE_TO(double, toDouble) 455 EVALUE_DEFINE_TO(at::string_view, toString) 456 EVALUE_DEFINE_TO(at::ScalarType, toScalarType) 457 EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat) 458 EVALUE_DEFINE_TO(std::optional<at::Tensor>, toOptional<at::Tensor>) 459 EVALUE_DEFINE_TO(at::ArrayRef<int64_t>, toIntList) 460 EVALUE_DEFINE_TO( 461 std::optional<at::ArrayRef<int64_t>>, 462 toOptional<at::ArrayRef<int64_t>>) 463 EVALUE_DEFINE_TO( 464 std::optional<at::ArrayRef<double>>, 465 toOptional<at::ArrayRef<double>>) 466 EVALUE_DEFINE_TO(at::ArrayRef<std::optional<at::Tensor>>, toListOptionalTensor) 467 EVALUE_DEFINE_TO(at::ArrayRef<double>, toDoubleList) 468 #undef EVALUE_DEFINE_TO 469 470 template <typename T> 471 at::ArrayRef<T> EValObjectList<T>::get() const { 472 for (size_t i = 0; i < wrapped_vals_.size(); i++) { 473 unwrapped_vals_[i] = wrapped_vals_[i]->template to<T>(); 474 } 475 return at::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()}; 476 } 477 478 } // namespace executor 479 } // namespace torch 480