1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_RUNTIME_CUSTOM_CALL_H_ 17 #define XLA_RUNTIME_CUSTOM_CALL_H_ 18 19 #include <any> 20 #include <cstddef> 21 #include <cstdint> 22 #include <functional> 23 #include <iterator> 24 #include <numeric> 25 #include <string> 26 #include <tuple> 27 #include <type_traits> 28 #include <utility> 29 #include <vector> 30 31 #include "absl/base/dynamic_annotations.h" 32 #include "third_party/eigen3/Eigen/Core" 33 #include "llvm/ADT/Any.h" 34 #include "llvm/ADT/ArrayRef.h" 35 #include "llvm/ADT/STLExtras.h" 36 #include "llvm/ADT/StringExtras.h" 37 #include "llvm/ADT/StringMap.h" 38 #include "llvm/ADT/StringRef.h" 39 #include "llvm/Support/Compiler.h" 40 #include "llvm/Support/Error.h" 41 #include "tensorflow/compiler/xla/runtime/diagnostics.h" 42 #include "tensorflow/compiler/xla/runtime/logical_result.h" 43 #include "tensorflow/compiler/xla/runtime/type_id.h" 44 #include "tfrt/dtype/dtype.h" // from @tf_runtime 45 #include "tfrt/support/map_by_type.h" // from @tf_runtime 46 47 namespace xla { 48 namespace runtime { 49 50 // Forward declare. 51 struct KernelContext; 52 53 // Forward declare template defined below. 54 template <typename... Ts> 55 class CustomCallBinding; 56 57 // Registers mappings from TypeIDs supported by the custom calls to their unique 58 // names in the given registry. 59 void PopulateCustomCallTypeIdNames(TypeIDNameRegistry& registry); 60 61 class CustomCall { 62 public: 63 // Container for passing data between XLA user and the custom call handler. 64 using UserData = tfrt::PtrMapByType<CustomCall>; 65 66 // A type for matching all remaining custom call arguments. 67 class RemainingArgs; 68 69 // A type for passing an argument of different types at the same position, 70 // and the handler will do the decoding. 71 class VariantArg; 72 class VariantAttr; 73 74 // A type for representing tensors with shapes. 75 template <typename T> 76 struct TensorRef { 77 llvm::ArrayRef<int64_t> shape; 78 llvm::ArrayRef<T> data; 79 }; 80 81 // Custom call handler can check arguments and attributes types and names 82 // at runtime, however this comes at extra cost and can be optionally 83 // disabled. If the version of the compiler that generated the XLA executable 84 // doesn't match the custom call handler, it can lead to undefined behavior. 85 enum class RuntimeChecks : uint8_t { 86 // Check arguments and attributes types, also check attribute names. It is 87 // safe to pass extra arguments to the custom call handler when name 88 // checking is enabled, because it will safely skip irrelevant attributes. 89 kDefault = 0, 90 91 // Check only the types of the arguments and attributes. If an attribute 92 // with the same type but different name is passed to the custom call 93 // handler, 94 // it will happily proceed ignoring the name mismatch. 95 kTypes = 1, 96 97 // Do not check the number of arguments and attributes and their types, and 98 // do not check that the user data was passed to the custom call. This is 99 // the most dangerous option, because it blindly reinterprets opaque memory 100 // passed to the handler, and can easily lead to segfaults if the data 101 // doesn't match the expected custom call signature. 102 kNone = 2 103 }; 104 105 // Allows to bind custom calls to handlers with optional arguments without 106 // spelling the full type. 107 // 108 // Example: 109 // 110 // LogicalResult MyCustomCall(Optional<int32_t> version); 111 // 112 // CustomCall::Bind("api").Value(CustomCall::None).To(MyCustomCall); 113 // 114 // Works around the fact that llvm::Optional can't store an instance of 115 // llvm::NoneType (llvm::Optional<llvm::NoneType> has ambiguous constructor). 116 struct NoneType { 117 template <typename T> 118 operator llvm::Optional<T>() const { // NOLINT 119 return llvm::None; 120 } 121 }; 122 123 static constexpr NoneType None = {}; // NOLINT 124 CheckNames(RuntimeChecks checks)125 static constexpr bool CheckNames(RuntimeChecks checks) { 126 return checks == RuntimeChecks::kDefault; 127 } 128 CheckTypes(RuntimeChecks checks)129 static constexpr bool CheckTypes(RuntimeChecks checks) { 130 return checks != RuntimeChecks::kNone; 131 } 132 CheckUserData(RuntimeChecks checks)133 static constexpr bool CheckUserData(RuntimeChecks checks) { 134 return checks != RuntimeChecks::kNone; 135 } 136 137 template <typename T> CheckType(RuntimeChecks checks,TypeID type_id)138 static bool CheckType(RuntimeChecks checks, TypeID type_id) { 139 return !CheckTypes(checks) || type_id == TypeID::get<T>(); 140 } 141 142 virtual ~CustomCall() = default; 143 144 virtual llvm::StringRef name() const = 0; 145 virtual LogicalResult call(void** args, void** attrs, 146 const UserData* user_data, 147 const DiagnosticEngine* diagnostic) const = 0; 148 149 static CustomCallBinding<> Bind(std::string callee); 150 }; 151 152 // Direct custom call is a custom call that can be linked directly with the 153 // compiled executable, and doesn't have to go through the custom call look up 154 // by name at run time (see CustomCallRegistry). 155 // 156 // Direct custom call is a preffered way of implemenenting custom calls with 157 // low run time overheads, as they will become just an indirect function calls 158 // once LLVM ORC links them with the executable. 159 // 160 // See `GetSymbolsBinding` to convert custom call library to symbols binding. 161 class DirectCustomCallLibrary { 162 public: 163 // Function type corresponding to the direct custom call (custom calls 164 // linked directly with the compiled executable). 165 using DirectCustomCall = bool (*)(KernelContext* kernel_context, void** args, 166 void** attrs); 167 Insert(llvm::StringRef name,DirectCustomCall custom_call)168 void Insert(llvm::StringRef name, DirectCustomCall custom_call) { 169 lib_.try_emplace(name, custom_call); 170 } 171 ForEach(std::function<void (llvm::StringRef,DirectCustomCall)> f)172 void ForEach(std::function<void(llvm::StringRef, DirectCustomCall)> f) const { 173 for (auto& kv : lib_) f(kv.first(), kv.second); 174 } 175 176 private: 177 llvm::StringMap<DirectCustomCall> lib_; 178 }; 179 180 // Forward declare template defined below. 181 template <CustomCall::RuntimeChecks checks, typename Fn, typename... Ts> 182 class CustomCallHandler; 183 184 namespace internal { 185 186 // A type tag to distinguish arguments tied to the attributes in the 187 // `CustomCallBinding` variadic template argument. 188 template <typename T> 189 struct Attr {}; 190 191 // A type tag to distinguish arguments tied to the user data in the 192 // `CustomCallBinding` variadic template argument. 193 template <typename T> 194 struct UserData {}; 195 196 // A type tag to distinguish arguments tied to the constant values in the 197 // `CustomCallBinding` variadic template argument. 198 template <typename T> 199 struct Value {}; 200 201 // A template for checking if type is a wrapped attribute or user data. 202 template <typename> 203 struct IsWrapped : std::false_type {}; 204 205 template <typename T> 206 struct IsWrapped<internal::Attr<T>> : std::true_type {}; 207 208 template <typename T> 209 struct IsWrapped<internal::UserData<T>> : std::true_type {}; 210 211 template <typename T> 212 struct IsWrapped<internal::Value<T>> : std::true_type {}; 213 214 // Checks if remaining arguments are in the parameter pack. 215 template <typename... Ts> 216 using HasRemainingArgs = 217 std::disjunction<std::is_same<CustomCall::RemainingArgs, Ts>...>; 218 219 } // namespace internal 220 221 // Custom call binding describes the function signature of the expected custom 222 // call handler using its variadic template parameter. 223 // 224 // Custom call binding: 225 // CustomCallBinding<int32_t, MemrefView> 226 // 227 // Function signature: 228 // LogicalResult MyHandle(int32_t algo, MemrefView memref); 229 // 230 template <typename... Ts> 231 class CustomCallBinding { 232 public: 233 using RuntimeChecks = CustomCall::RuntimeChecks; 234 235 template <typename T> 236 CustomCallBinding<Ts..., T> Arg() && { 237 return {std::move(*this)}; 238 } 239 240 CustomCallBinding<Ts..., CustomCall::RemainingArgs> RemainingArgs() && { 241 static_assert(!internal::HasRemainingArgs<Ts...>::value, 242 "remaining arguments can be passed just once"); 243 return {std::move(*this)}; 244 } 245 246 template <typename T> 247 CustomCallBinding<Ts..., internal::Attr<T>> Attr(std::string attr) && { 248 attrs_.push_back(std::move(attr)); 249 return {std::move(*this)}; 250 } 251 252 template <typename T> 253 CustomCallBinding<Ts..., internal::UserData<T>> UserData() && { 254 static_assert(std::is_pointer<T>::value, "user data must be a pointer"); 255 return {std::move(*this)}; 256 } 257 258 template <typename T> 259 CustomCallBinding<Ts..., internal::Value<T>> Value(T value) && { 260 values_.push_back(std::move(value)); 261 return {std::move(*this)}; 262 } 263 264 template <RuntimeChecks checks = RuntimeChecks::kDefault, typename Fn> 265 std::unique_ptr<CustomCallHandler<checks, Fn, Ts...>> To(Fn fn) { 266 return std::unique_ptr<CustomCallHandler<checks, Fn, Ts...>>( 267 new CustomCallHandler<checks, Fn, Ts...>( 268 std::forward<Fn>(fn), std::move(callee_), std::move(attrs_), 269 std::move(values_))); 270 } 271 272 private: 273 template <typename...> 274 friend class CustomCallBinding; 275 friend class CustomCall; 276 277 explicit CustomCallBinding(std::string callee) : callee_(std::move(callee)) { 278 static_assert(sizeof...(Ts) == 0, "custom call arguments must be empty"); 279 } 280 281 template <typename... TTs> 282 CustomCallBinding(CustomCallBinding<TTs...>&& other) // NOLINT 283 : callee_(std::move(other.callee_)), 284 attrs_(std::move(other.attrs_)), 285 values_(std::move(other.values_)) {} 286 287 CustomCallBinding(CustomCallBinding&) = delete; 288 289 std::string callee_; // custom call target 290 std::vector<std::string> attrs_; // names of bound attributes 291 std::vector<llvm::Any> values_; // values bound to arguments 292 }; 293 294 inline CustomCallBinding<> CustomCall::Bind(std::string callee) { 295 return CustomCallBinding<>(std::move(callee)); 296 } 297 298 // Custom call arguments decoding must be defined by specializing this template. 299 // 300 // Example: decoding for the `MyType` arguments 301 // 302 // template <CustomCall::RuntimeChecks checks> 303 // struct CustomCallArgDecoding<MyType, checks> { 304 // static FailureOr<MyType> Decode(TypeID type_id, void* value); 305 // }; 306 // 307 template <typename T, CustomCall::RuntimeChecks> 308 struct CustomCallArgDecoding; 309 310 // Custom call attribute decoding must be defined by specializing this template. 311 // 312 // Example: decoding for the `MyType` attributes 313 // 314 // template <CustomCall::RuntimeChecks checks> 315 // struct CustomCallAttrDecoding<MyType, checks> { 316 // static FailureOr<MyType> Decode(llvm::StringRef name, 317 // TypeID type_id, void* value); 318 // } 319 // 320 template <typename T, CustomCall::RuntimeChecks> 321 struct CustomCallAttrDecoding; 322 323 // A type tag to declare MLIR TypeID specializations for types passed to the 324 // custom calls. We don't want to declare specializations for scalar types 325 // directly in this translation unit, so we rely on a tag to wrap them. 326 // 327 // See explicit TypeID declarations at the end of this file. 328 template <typename T> 329 struct Tagged {}; 330 331 // A type tag to represent empty arrays of unknown element type. 332 struct EmptyArrayRef {}; 333 334 //===----------------------------------------------------------------------===// 335 // C structures corresponding to the `rt-to-llvm` pass LLVM structs encoding 336 // various types of arguments/attributes. 337 338 namespace internal { 339 340 struct EncodedMemref { 341 uint8_t dtype; 342 uint8_t rank; 343 void* data; 344 int64_t dims[]; 345 }; 346 347 template <typename T> 348 struct EncodedArray { 349 int64_t size; 350 const T* data; 351 }; 352 353 template <typename T> 354 struct EncodedDenseElements { 355 struct EncodedArray<T> payload; 356 int64_t rank; 357 int64_t shape[]; 358 }; 359 360 } // namespace internal 361 362 //===----------------------------------------------------------------------===// 363 // Helpers for decoding opaque arguments and attributes memory. 364 365 namespace internal { 366 367 // Decoded pair of an argument type and opaque value. 368 struct DecodedArg { 369 TypeID type_id; 370 void* value; 371 }; 372 373 // Decoded triple of an attribute name, type and opaque value. 374 struct DecodedAttr { 375 llvm::StringRef name; 376 TypeID type_id; 377 void* value; 378 }; 379 380 // A convenience wrapper around opaque arguments memory. 381 class DecodedArgs { 382 public: 383 explicit DecodedArgs(void** args) 384 : args_(args), num_args_(*reinterpret_cast<int64_t*>(args_[0])) {} 385 386 LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t size() const { return num_args_; } 387 388 LLVM_ATTRIBUTE_ALWAYS_INLINE DecodedArg operator[](size_t i) const { 389 void** arg_base = args_ + 1 + i * 2; 390 391 DecodedArg arg; 392 arg.type_id = TypeID::getFromOpaquePointer(arg_base[0]); 393 arg.value = arg_base[1]; 394 395 return arg; 396 } 397 398 private: 399 void** args_; 400 int64_t num_args_; 401 }; 402 403 // A convenience wrapper around opaque attributes memory. 404 class DecodedAttrs { 405 public: 406 explicit DecodedAttrs(void** attrs) 407 : attrs_(attrs), num_attrs_(*reinterpret_cast<int64_t*>(attrs_[0])) {} 408 409 LLVM_ATTRIBUTE_ALWAYS_INLINE int64_t size() const { return num_attrs_; } 410 411 LLVM_ATTRIBUTE_ALWAYS_INLINE DecodedAttr operator[](size_t i) const { 412 void** attr_base = attrs_ + 1 + i * 3; 413 414 DecodedAttr attr; 415 auto* name = reinterpret_cast<internal::EncodedArray<char>*>(attr_base[0]); 416 attr.name = llvm::StringRef(name->data, name->size); 417 attr.type_id = TypeID::getFromOpaquePointer(attr_base[1]); 418 attr.value = attr_base[2]; 419 420 return attr; 421 } 422 423 private: 424 void** attrs_; 425 int64_t num_attrs_; 426 }; 427 428 } // namespace internal 429 430 //===----------------------------------------------------------------------===// 431 // CustomCall remaining arguments wraps the type-erased `DecodedArg` container, 432 // and provides a type-safe API for accessing individual arguments. 433 434 class CustomCall::RemainingArgs { 435 public: 436 using RuntimeChecks = CustomCall::RuntimeChecks; 437 438 RemainingArgs(internal::DecodedArgs args, size_t offset) 439 : args_(args), offset_(offset) { 440 assert(offset <= args_.size() && "illegal remaining args offset"); 441 } 442 443 size_t size() const { return args_.size() - offset_; } 444 bool empty() const { return size() == 0; } 445 446 template <typename T> 447 bool isa(size_t index) const { 448 return args_[index + offset_].type_id == TypeID::get<Tagged<T>>(); 449 } 450 451 template <typename T, RuntimeChecks checks = RuntimeChecks::kDefault> 452 FailureOr<T> get(size_t index) const { 453 internal::DecodedArg arg = args_[index + offset_]; 454 return CustomCallArgDecoding<T, checks>::Decode(arg.type_id, arg.value); 455 } 456 457 private: 458 internal::DecodedArgs args_; 459 size_t offset_; 460 }; 461 462 class CustomCall::VariantArg { 463 public: 464 using RuntimeChecks = CustomCall::RuntimeChecks; 465 466 VariantArg(internal::DecodedArgs args, size_t offset) 467 : args_(args), offset_(offset) { 468 assert(offset <= args_.size() && "illegal remaining args offset"); 469 } 470 471 template <typename T> 472 bool isa() const { 473 return args_[offset_].type_id == TypeID::get<Tagged<T>>(); 474 } 475 476 template <typename T, RuntimeChecks checks = RuntimeChecks::kDefault> 477 FailureOr<T> get() const { 478 internal::DecodedArg arg = args_[offset_]; 479 return CustomCallArgDecoding<T, checks>::Decode(arg.type_id, arg.value); 480 } 481 482 private: 483 internal::DecodedArgs args_; 484 size_t offset_; 485 }; 486 487 class CustomCall::VariantAttr { 488 public: 489 using RuntimeChecks = CustomCall::RuntimeChecks; 490 491 VariantAttr(llvm::StringRef name, TypeID type_id, void* value) 492 : name_(name), type_id_(type_id), value_(value) {} 493 494 template <typename T> 495 bool isa() const { 496 return type_id_ == TypeID::get<Tagged<T>>(); 497 } 498 499 template <typename T, RuntimeChecks checks = RuntimeChecks::kDefault> 500 FailureOr<T> get() const { 501 return CustomCallAttrDecoding<T, checks>::Decode(name_, type_id_, value_); 502 } 503 504 private: 505 llvm::StringRef name_; 506 TypeID type_id_; 507 void* value_; 508 }; 509 510 //===----------------------------------------------------------------------===// 511 // A little bit of template metaprogramming to implement type safe binding 512 // of custom calls to C++ functions. This is internal implementation details, 513 // and must not be relied on in any of the client code. 514 515 namespace internal { 516 517 // A helper struct to extract the type of the handler argument. 518 template <typename T> 519 struct FnArgType { 520 using Type = T; 521 }; 522 523 // Extracts the underlying type from the attribute type tag. 524 template <typename T> 525 struct FnArgType<internal::Attr<T>> { 526 using Type = T; 527 }; 528 529 // Extracts the underlying type from the user data type tag. 530 template <typename T> 531 struct FnArgType<internal::UserData<T>> { 532 using Type = T; 533 }; 534 535 // Extracts the underlying type from the value type tag. 536 template <typename T> 537 struct FnArgType<internal::Value<T>> { 538 using Type = T; 539 }; 540 541 // A template for counting regular arguments in the Ts pack. 542 template <typename T, typename... Ts> 543 struct NumArgs { 544 static constexpr int64_t value = !IsWrapped<T>::value + NumArgs<Ts...>::value; 545 }; 546 547 template <typename T> 548 struct NumArgs<T> { 549 static constexpr int64_t value = !IsWrapped<T>::value; 550 }; 551 552 // When decoding input data we need to keep track of how many arguments and 553 // attributes we decoded so far to index into the correct data strucuture. 554 struct DecodingOffsets { 555 int64_t args = 0; 556 int64_t attrs = 0; 557 int64_t values = 0; 558 }; 559 560 template <typename T, CustomCall::RuntimeChecks checks> 561 struct Decode { 562 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> call( 563 DecodingOffsets& offsets, internal::DecodedArgs args, 564 llvm::ArrayRef<std::string> attrs_names, llvm::ArrayRef<size_t> attrs_idx, 565 internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values, 566 const CustomCall::UserData* user_data) { 567 internal::DecodedArg arg = args[offsets.args++]; 568 return CustomCallArgDecoding<T, checks>::Decode(arg.type_id, arg.value); 569 } 570 }; 571 572 template <typename T, CustomCall::RuntimeChecks checks> 573 struct Decode<internal::Attr<T>, checks> { 574 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> call( 575 DecodingOffsets& offsets, internal::DecodedArgs args, 576 llvm::ArrayRef<std::string> attrs_names, llvm::ArrayRef<size_t> attrs_idx, 577 internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values, 578 const CustomCall::UserData* user_data) { 579 // Find decoded attribute corresponding for the given attribute index. 580 int64_t idx = offsets.attrs++; 581 582 // Do not check the attribute name, and decode attribute at the given index. 583 if (!CustomCall::CheckNames(checks)) { 584 size_t i = attrs_idx[idx]; 585 return CustomCallAttrDecoding<T, checks>::Decode( 586 attrs[i].name, attrs[i].type_id, attrs[i].value); 587 } 588 589 llvm::StringRef attr = attrs_names[idx]; 590 591 // Given that attributes are passed to the custom call handler 592 // lexicographically sorted by name, we can find the attribute we are 593 // looking for only between the `attrs_idx` offset and the end of the 594 // attributes array. 595 for (size_t i = attrs_idx[idx]; i < attrs.size(); ++i) { 596 if (LLVM_LIKELY(attrs[i].name == attr)) 597 return CustomCallAttrDecoding<T, checks>::Decode( 598 attrs[i].name, attrs[i].type_id, attrs[i].value); 599 } 600 601 // Attribute we were looking for was not passed as an argument. 602 return mlir::failure(); 603 } 604 }; 605 606 template <typename T, CustomCall::RuntimeChecks checks> 607 struct Decode<internal::UserData<T>, checks> { 608 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> call( 609 DecodingOffsets& offsets, internal::DecodedArgs args, 610 llvm::ArrayRef<std::string> attrs_names, llvm::ArrayRef<size_t> attrs_idx, 611 internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values, 612 const CustomCall::UserData* user_data) { 613 using UserDataT = std::remove_pointer_t<T>; 614 615 if (!CustomCall::CheckUserData(checks)) return user_data->get<UserDataT>(); 616 617 // TODO(ezhulenev): Add an option to request nullable user data, because 618 // right now we do not distinguish between a user data pointer that doesn't 619 // exist, and a null pointer passed by the user. 620 621 // Get the requested value if user data was passed to the custom call. 622 auto* ptr = user_data ? user_data->getIfExists<UserDataT>() : nullptr; 623 if (LLVM_UNLIKELY(!ptr)) return mlir::failure(); 624 return ptr; 625 } 626 }; 627 628 template <typename T, CustomCall::RuntimeChecks checks> 629 struct Decode<internal::Value<T>, checks> { 630 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> call( 631 DecodingOffsets& offsets, internal::DecodedArgs args, 632 llvm::ArrayRef<std::string> attrs_names, llvm::ArrayRef<size_t> attrs_idx, 633 internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values, 634 const CustomCall::UserData* user_data) { 635 return llvm::any_cast<T>(values[offsets.values++]); 636 } 637 }; 638 639 template <CustomCall::RuntimeChecks checks> 640 struct Decode<CustomCall::RemainingArgs, checks> { 641 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<CustomCall::RemainingArgs> call( 642 DecodingOffsets& offsets, internal::DecodedArgs args, 643 llvm::ArrayRef<std::string> attr_names, llvm::ArrayRef<size_t> attrs_idx, 644 internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values, 645 const CustomCall::UserData* user_data) { 646 return CustomCall::RemainingArgs(args, offsets.args); 647 } 648 }; 649 650 template <CustomCall::RuntimeChecks checks> 651 struct Decode<CustomCall::VariantArg, checks> { 652 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<CustomCall::VariantArg> call( 653 DecodingOffsets& offsets, internal::DecodedArgs args, 654 llvm::ArrayRef<std::string> attr_names, llvm::ArrayRef<size_t> attrs_idx, 655 internal::DecodedAttrs attrs, llvm::ArrayRef<llvm::Any> values, 656 const CustomCall::UserData* user_data) { 657 return CustomCall::VariantArg(args, offsets.args++); 658 } 659 }; 660 661 } // namespace internal 662 663 // Custom call handler binds concrete custom call implementation of type `Fn` to 664 // the custom call function signature. `Fn` can be a function pointer, or a 665 // lambda. 666 // 667 // Custom call handler uses the variadic template parameter `Ts` to decode the 668 // opaque pointers passed to the `call` function into the C++ types that are 669 // forwarded to the custom call implementation. 670 template <CustomCall::RuntimeChecks checks, typename Fn, typename... Ts> 671 class CustomCallHandler : public CustomCall { 672 static constexpr int64_t kSize = sizeof...(Ts); 673 static constexpr int64_t kNumArgs = internal::NumArgs<Ts...>::value; 674 675 template <typename T> 676 using FnArgType = typename internal::FnArgType<T>::Type; 677 678 // Custom call can signal error using a LogicalError result. 679 static constexpr bool kIsLogicalErr = 680 std::is_invocable_r_v<LogicalResult, Fn, FnArgType<Ts>...>; 681 682 // Custom call can signal error together with a detailed error message. 683 static constexpr bool kIsDetailedErr = 684 std::is_invocable_r_v<llvm::Error, Fn, FnArgType<Ts>...>; 685 686 static_assert(kIsLogicalErr || kIsDetailedErr, 687 "incompatible custom call handler types"); 688 689 public: 690 llvm::StringRef name() const final { return callee_; } 691 692 LLVM_ATTRIBUTE_ALWAYS_INLINE LogicalResult 693 call(void** args, void** attrs, const UserData* user_data, 694 const DiagnosticEngine* diagnostic) const final { 695 // Unpoison the first pointer to get the args and attrs sizes. 696 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(args, sizeof(void*)); 697 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(attrs, sizeof(void*)); 698 699 // Decode arguments and attributes from the opaque pointers. 700 internal::DecodedArgs decoded_args(args); 701 internal::DecodedAttrs decoded_attrs(attrs); 702 703 int64_t num_args = decoded_args.size(); 704 int64_t num_attrs = decoded_attrs.size(); 705 706 // Unpoison the rest of the of args and attrs data. 707 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(args, 708 (1 + 2 * num_args) * sizeof(void*)); 709 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(attrs, 710 (1 + 3 * num_attrs) * sizeof(void*)); 711 712 if (LLVM_UNLIKELY(diagnostic == nullptr)) 713 diagnostic = DiagnosticEngine::DefaultDiagnosticEngine(); 714 715 // If all runtime checks are disabled we are just reinterpreting opaque 716 // `args` and `attrs` memory acording to the requested handler signature. 717 if (checks != RuntimeChecks::kNone) { 718 // Check that the number of passed arguments matches the signature. Each 719 // individual argument decoding will check the actual type. 720 if (internal::HasRemainingArgs<Ts...>::value) { 721 if (LLVM_UNLIKELY(num_args < kNumArgs - 1)) 722 return diagnostic->EmitError() 723 << "Wrong number of arguments: expected at least " 724 << (kNumArgs - 1) << " got " << num_args; 725 } else { 726 if (LLVM_UNLIKELY(num_args != kNumArgs)) 727 return diagnostic->EmitError() 728 << "Wrong number of arguments: expected " << kNumArgs 729 << " got " << num_args; 730 } 731 732 // Check that we have enough attributes passed to the custom call. Each 733 // individual attribute decoding will check the name and the type. 734 if (LLVM_UNLIKELY(num_attrs < attrs_.size())) 735 return diagnostic->EmitError() 736 << "Wrong number of attributes: expected at least " 737 << attrs_.size() << " got " << num_attrs; 738 } 739 740 return call(decoded_args, decoded_attrs, user_data, diagnostic, 741 std::make_index_sequence<kSize>{}); 742 } 743 744 template <size_t... Is> 745 LLVM_ATTRIBUTE_ALWAYS_INLINE LogicalResult 746 call(internal::DecodedArgs args, internal::DecodedAttrs attrs, 747 const UserData* user_data, const DiagnosticEngine* diagnostic, 748 std::index_sequence<Is...>) const { 749 // A helper structure to allow each decoder find the correct offset in the 750 // arguments or attributes. 751 internal::DecodingOffsets offsets; 752 753 // Check if all arguments and attributes were decoded. 754 bool all_decoded = true; 755 auto check_all_decoded = [&](auto result) { 756 all_decoded &= mlir::succeeded(result); 757 return std::move(result); 758 }; 759 760 // Decode all arguments into FailureOr containers. It is guaranteed 761 // that initializer list will be evaluated left-to-right, and we can rely 762 // on correct offsets computation. 763 std::tuple<FailureOr<FnArgType<Ts>>...> fn_args = { 764 check_all_decoded(internal::Decode<Ts, checks>::call( 765 offsets, args, attrs_, attrs_idx_, attrs, values_, user_data))...}; 766 if (LLVM_UNLIKELY(!all_decoded)) 767 return diagnostic->EmitError() 768 << "Failed to decode all custom call arguments and attributes"; 769 770 // Custom call returns logical result to signal failures. 771 if constexpr (kIsLogicalErr) 772 return fn_(std::move(*std::get<Is>(fn_args))...); 773 774 // Custom call returns detailed error to signal failures. 775 if constexpr (kIsDetailedErr) { 776 if (auto err = fn_(std::move(*std::get<Is>(fn_args))...)) 777 return diagnostic->EmitError() << std::move(err); 778 return mlir::success(); 779 } 780 781 llvm_unreachable("unexpected custom call type"); 782 } 783 784 private: 785 template <typename...> 786 friend class CustomCallBinding; 787 788 CustomCallHandler(Fn fn, std::string callee, std::vector<std::string> attrs, 789 std::vector<llvm::Any> values) 790 : fn_(std::move(fn)), 791 callee_(std::move(callee)), 792 attrs_(std::move(attrs)), 793 values_(std::move(values)), 794 attrs_idx_(attrs_.size()) { 795 // Sort attributes names. 796 std::vector<std::string> sorted = attrs_; 797 llvm::sort(sorted); 798 799 // Find index or every attribute in the sorted attributes vector. 800 for (size_t i = 0; i < attrs_.size(); ++i) { 801 const std::string& attr = attrs_[i]; 802 attrs_idx_[i] = std::distance(sorted.begin(), llvm::find(sorted, attr)); 803 } 804 } 805 806 Fn fn_; 807 std::string callee_; 808 std::vector<std::string> attrs_; 809 std::vector<llvm::Any> values_; 810 // A mapping from the attribute index to its index in the lexicographically 811 // sorter vector of attribute names. Attributes passed in the custom call 812 // handler sorted by the name, we use this index to efficiently find the 813 // decoded attribute entry. 814 std::vector<size_t> attrs_idx_; 815 }; 816 817 template <CustomCall::RuntimeChecks checks, typename Fn, typename... Ts> 818 constexpr int64_t CustomCallHandler<checks, Fn, Ts...>::kSize; 819 820 template <CustomCall::RuntimeChecks checks, typename Fn, typename... Ts> 821 constexpr int64_t CustomCallHandler<checks, Fn, Ts...>::kNumArgs; 822 823 //===----------------------------------------------------------------------===// 824 // Custom arguments attributes decoding. 825 826 // A view into the memref argument. Corresponds to the MemrefDesc, however it 827 // doesn't own the sizes/strides vectors, and cheap to pass around. Memrefs with 828 // non-identity layouts can be decoded only as a StridedMemrefView. 829 struct StridedMemrefView { 830 tfrt::DType dtype; 831 void* data; 832 llvm::ArrayRef<int64_t> sizes; 833 llvm::ArrayRef<int64_t> strides; 834 }; 835 836 // A view into the memref argument with an identity (row major) layout. 837 struct MemrefView { 838 tfrt::DType dtype; 839 void* data; 840 llvm::ArrayRef<int64_t> sizes; 841 }; 842 843 // A flat view into memref argument with an identity (row major) layout. If the 844 // memref shape and strides are not required for the custom call, it's cheaper 845 // to pass the flat view. 846 struct FlatMemrefView { 847 tfrt::DType dtype; 848 void* data; 849 int64_t size_in_bytes; 850 }; 851 852 llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const StridedMemrefView&); 853 llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const MemrefView&); 854 llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const FlatMemrefView&); 855 856 template <CustomCall::RuntimeChecks checks> 857 struct CustomCallArgDecoding<StridedMemrefView, checks> { 858 using EncodedMemref = internal::EncodedMemref; 859 860 LLVM_ATTRIBUTE_ALWAYS_INLINE 861 static FailureOr<StridedMemrefView> Decode(TypeID type_id, void* value) { 862 if (!(CustomCall::CheckType<Tagged<MemrefView>>(checks, type_id) || 863 CustomCall::CheckType<Tagged<StridedMemrefView>>(checks, type_id))) 864 return mlir::failure(); 865 866 auto* encoded = reinterpret_cast<EncodedMemref*>(value); 867 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(encoded, sizeof(EncodedMemref)); 868 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED( 869 encoded, sizeof(EncodedMemref) + encoded->rank * sizeof(int64_t)); 870 871 tfrt::DType dtype = static_cast<tfrt::DType>(encoded->dtype); 872 return StridedMemrefView{dtype, 873 encoded->data, 874 {encoded->dims, encoded->rank}, 875 {encoded->dims + encoded->rank, encoded->rank}}; 876 } 877 }; 878 879 template <CustomCall::RuntimeChecks checks> 880 struct CustomCallArgDecoding<MemrefView, checks> { 881 using EncodedMemref = internal::EncodedMemref; 882 883 LLVM_ATTRIBUTE_ALWAYS_INLINE 884 static FailureOr<MemrefView> Decode(TypeID type_id, void* value) { 885 if (!CustomCall::CheckType<Tagged<MemrefView>>(checks, type_id)) 886 return mlir::failure(); 887 888 auto* encoded = reinterpret_cast<EncodedMemref*>(value); 889 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(encoded, sizeof(EncodedMemref)); 890 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED( 891 encoded, sizeof(EncodedMemref) + encoded->rank * sizeof(int64_t)); 892 893 tfrt::DType dtype = static_cast<tfrt::DType>(encoded->dtype); 894 return MemrefView{dtype, encoded->data, {encoded->dims, encoded->rank}}; 895 } 896 }; 897 898 template <CustomCall::RuntimeChecks checks> 899 struct CustomCallArgDecoding<FlatMemrefView, checks> { 900 using EncodedMemref = internal::EncodedMemref; 901 902 LLVM_ATTRIBUTE_ALWAYS_INLINE 903 static FailureOr<FlatMemrefView> Decode(TypeID type_id, void* value) { 904 if (!CustomCall::CheckType<Tagged<MemrefView>>(checks, type_id)) 905 return mlir::failure(); 906 907 auto* encoded = reinterpret_cast<EncodedMemref*>(value); 908 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(encoded, sizeof(EncodedMemref)); 909 ABSL_ANNOTATE_MEMORY_IS_INITIALIZED( 910 encoded, sizeof(EncodedMemref) + encoded->rank * sizeof(int64_t)); 911 912 tfrt::DType dtype = static_cast<tfrt::DType>(encoded->dtype); 913 int64_t size_in_bytes = GetHostSize(dtype); 914 for (int d = 0; d < encoded->rank; ++d) size_in_bytes *= encoded->dims[d]; 915 return FlatMemrefView{dtype, encoded->data, size_in_bytes}; 916 } 917 }; 918 919 #define XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(T) \ 920 template <CustomCall::RuntimeChecks checks> \ 921 struct CustomCallArgDecoding<T, checks> { \ 922 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> Decode(TypeID type_id, \ 923 void* value) { \ 924 if (!CustomCall::CheckType<Tagged<T>>(checks, type_id)) \ 925 return mlir::failure(); \ 926 \ 927 return *reinterpret_cast<T*>(value); \ 928 } \ 929 } 930 931 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(bool); 932 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(int32_t); 933 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(int64_t); 934 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(float); 935 XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING(double); 936 937 #undef XLA_RUNTIME_REGISTER_SCALAR_ARG_DECODING 938 939 template <CustomCall::RuntimeChecks checks> 940 struct CustomCallArgDecoding<Eigen::half, checks> { 941 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<Eigen::half> Decode( 942 TypeID type_id, void* value) { 943 if (!CustomCall::CheckType<Tagged<Eigen::half>>(checks, type_id)) 944 return mlir::failure(); 945 946 auto* src = reinterpret_cast<uint16_t*>(value); 947 return Eigen::numext::bit_cast<Eigen::half>(*src); 948 } 949 }; 950 951 //===----------------------------------------------------------------------===// 952 // Custom call attributes decoding. 953 954 template <CustomCall::RuntimeChecks checks> 955 struct CustomCallAttrDecoding<llvm::StringRef, checks> { 956 using StringRef = llvm::StringRef; 957 958 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<StringRef> Decode( 959 llvm::StringRef name, TypeID type_id, void* value) { 960 if (!CustomCall::CheckType<Tagged<StringRef>>(checks, type_id)) 961 return mlir::failure(); 962 963 auto* encoded = reinterpret_cast<internal::EncodedArray<char>*>(value); 964 return StringRef(encoded->data, encoded->size); 965 } 966 }; 967 968 template <CustomCall::RuntimeChecks checks> 969 struct CustomCallAttrDecoding<CustomCall::VariantAttr, checks> { 970 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<CustomCall::VariantAttr> Decode( 971 llvm::StringRef name, TypeID type_id, void* value) { 972 return CustomCall::VariantAttr(name, type_id, value); 973 } 974 }; 975 976 #define XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(T) \ 977 template <CustomCall::RuntimeChecks checks> \ 978 struct CustomCallAttrDecoding<T, checks> { \ 979 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> Decode( \ 980 llvm::StringRef name, TypeID type_id, void* value) { \ 981 if (!CustomCall::CheckType<Tagged<T>>(checks, type_id)) \ 982 return mlir::failure(); \ 983 \ 984 return *reinterpret_cast<T*>(value); \ 985 } \ 986 } 987 988 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(bool); 989 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(int32_t); 990 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(int64_t); 991 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(float); 992 XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING(double); 993 994 #undef XLA_RUNTIME_REGISTER_SCALAR_ATTR_DECODING 995 996 // Both EncodedArray and 1-D EncodedDenseElements can be decoded as an 997 // llvm::ArrayRef. Pointers to both EncodedArray and 1-D EncodedDenseElements 998 // can be dereferenced as a pointer to EncodedArray. 999 #define XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(T) \ 1000 template <CustomCall::RuntimeChecks checks> \ 1001 struct CustomCallAttrDecoding<llvm::ArrayRef<T>, checks> { \ 1002 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<llvm::ArrayRef<T>> Decode( \ 1003 llvm::StringRef name, TypeID type_id, void* value) { \ 1004 if ((!CustomCall::CheckType<Tagged<llvm::ArrayRef<T>>>(checks, \ 1005 type_id)) && \ 1006 (!CustomCall::CheckType<Tagged<CustomCall::TensorRef<T>>>( \ 1007 checks, type_id)) && \ 1008 (!CustomCall::CheckType<Tagged<EmptyArrayRef>>(checks, type_id))) { \ 1009 return mlir::failure(); \ 1010 } \ 1011 \ 1012 auto* encoded = reinterpret_cast<internal::EncodedArray<T>*>(value); \ 1013 return llvm::ArrayRef<T>(encoded->data, encoded->size); \ 1014 } \ 1015 } 1016 1017 XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(int32_t); 1018 XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(int64_t); 1019 XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(float); 1020 XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING(double); 1021 1022 #undef XLA_RUNTIME_REGISTER_ARRAY_ATTR_DECODING 1023 1024 #define XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(T) \ 1025 template <CustomCall::RuntimeChecks checks> \ 1026 struct CustomCallAttrDecoding<CustomCall::TensorRef<T>, checks> { \ 1027 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<CustomCall::TensorRef<T>> \ 1028 Decode(llvm::StringRef name, TypeID type_id, void* value) { \ 1029 if (!CustomCall::CheckType<Tagged<CustomCall::TensorRef<T>>>(checks, \ 1030 type_id)) \ 1031 return mlir::failure(); \ 1032 \ 1033 auto* encoded = \ 1034 reinterpret_cast<internal::EncodedDenseElements<T>*>(value); \ 1035 auto payload = encoded->payload; \ 1036 llvm::ArrayRef<T> data(payload.data, payload.size); \ 1037 llvm::ArrayRef<int64_t> shape(encoded->shape, encoded->rank); \ 1038 return CustomCall::TensorRef<T>({shape, data}); \ 1039 } \ 1040 } 1041 1042 XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(int32_t); 1043 XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(int64_t); 1044 XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(float); 1045 XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING(double); 1046 1047 #undef XLA_RUNTIME_REGISTER_DENSE_ELEMENTS_ATTR_DECODING 1048 1049 //===----------------------------------------------------------------------===// 1050 // Register an XLA custom call attribute decoding for enum class. At runtime the 1051 // value should be passed as the underlying enum type. 1052 //===----------------------------------------------------------------------===// 1053 1054 // Example: register decoding for a user-defined enum class 1055 // 1056 // enum class MyEnumType { kFoo, kBar, kBaz }; 1057 // 1058 // XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(MyEnumType); 1059 // 1060 #define XLA_RUNTIME_REGISTER_ENUM_ATTR_DECODING(T) \ 1061 template <CustomCall::RuntimeChecks checks> \ 1062 struct CustomCallAttrDecoding<T, checks> { \ 1063 static_assert(std::is_enum<T>::value, "expected enum class"); \ 1064 using U = std::underlying_type_t<T>; \ 1065 \ 1066 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> Decode( \ 1067 llvm::StringRef name, TypeID type_id, void* value) { \ 1068 if (!CustomCall::CheckType<Tagged<T>>(checks, type_id)) \ 1069 return mlir::failure(); \ 1070 \ 1071 return static_cast<T>(*reinterpret_cast<U*>(value)); \ 1072 } \ 1073 } 1074 1075 //===----------------------------------------------------------------------===// 1076 // Register an XLA custom call attribute decoding for aggregate attributes. 1077 //===----------------------------------------------------------------------===// 1078 1079 // A workaround for passing braced initializers to macro. 1080 #define XLA_RUNTIME_AGGREGATE_FIELDS(...) \ 1081 { __VA_ARGS__ } 1082 1083 // Example: register decoding for a user-defined struct 1084 // 1085 // struct PairOfI64 { int64_t a; int64_t b; }; 1086 // 1087 // XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING( 1088 // PairOfI64, XLA_RUNTIME_AGGREGATE_FIELDS("a", "b"), 1089 // int64_t, int64_t); 1090 // 1091 #define XLA_RUNTIME_REGISTER_AGGREGATE_ATTR_DECODING(T, NAMES, ...) \ 1092 template <CustomCall::RuntimeChecks checks> \ 1093 struct CustomCallAttrDecoding<T, checks> { \ 1094 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> Decode( \ 1095 llvm::StringRef name, TypeID type_id, void* value) { \ 1096 if (!CustomCall::CheckType<Tagged<T>>(checks, type_id)) \ 1097 return mlir::failure(); \ 1098 \ 1099 using Impl = internal::DecodeAggregateAttr<T, checks, __VA_ARGS__>; \ 1100 return Impl::Decode(reinterpret_cast<void**>(value), NAMES); \ 1101 } \ 1102 } 1103 1104 namespace internal { 1105 // Decodes aggregate attribute into the object of type `T` that must be 1106 // constructible from the `Ts` types. 1107 template <typename T, CustomCall::RuntimeChecks checks, typename... Ts> 1108 struct DecodeAggregateAttr { 1109 static constexpr size_t kSize = sizeof...(Ts); 1110 1111 using RuntimeChecks = CustomCall::RuntimeChecks; 1112 1113 LLVM_ATTRIBUTE_ALWAYS_INLINE 1114 static FailureOr<T> Decode(void** value, 1115 std::array<llvm::StringRef, kSize> names) { 1116 internal::DecodedAttrs attrs(value); 1117 return Decode(attrs, names, std::make_index_sequence<kSize>{}); 1118 } 1119 1120 template <size_t... Is> 1121 LLVM_ATTRIBUTE_ALWAYS_INLINE static FailureOr<T> Decode( 1122 internal::DecodedAttrs attrs, std::array<llvm::StringRef, kSize> names, 1123 std::index_sequence<Is...>) { 1124 // Check that the number of encoded attributes matches the signature. 1125 if (checks != RuntimeChecks::kNone && kSize != attrs.size()) 1126 return mlir::failure(); 1127 1128 // Check that aggregate member names match the expected names. 1129 if (CustomCall::CheckNames(checks)) { 1130 for (unsigned i = 0; i < kSize; ++i) 1131 if (attrs[i].name != names[i]) return mlir::failure(); 1132 } 1133 1134 // Check if all members were decoded. 1135 bool all_decoded = true; 1136 auto check_all_decoded = [&](auto result) { 1137 all_decoded &= mlir::succeeded(result); 1138 return std::move(result); 1139 }; 1140 1141 // Decode all arguments into FailureOr containers. It is guaranteed 1142 // that initializer list will be evaluated left-to-right, and we can rely 1143 // on correct offsets computation. 1144 std::tuple<FailureOr<Ts>...> members = { 1145 check_all_decoded(CustomCallAttrDecoding<Ts, checks>::Decode( 1146 attrs[Is].name, attrs[Is].type_id, attrs[Is].value))...}; 1147 if (LLVM_UNLIKELY(!all_decoded)) return mlir::failure(); 1148 1149 // Forward unpacked members to the type constructor. 1150 return T{std::move(*std::get<Is>(members))...}; 1151 } 1152 }; 1153 } // namespace internal 1154 1155 // Declare/define an explicit specialialization for TypeID for types used 1156 // by the custom calls. This forces the compiler to emit a strong definition for 1157 // a class and controls which translation unit and shared object will actually 1158 // have it. 1159 // 1160 // See TypeID for more documentation. 1161 // 1162 // Because custom calls do not "own" the types passed across the function 1163 // boundary, we declare/define specializations for tagged types to avoid 1164 // potential conflicts with other libraries. 1165 #define XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(T) \ 1166 MLIR_DECLARE_EXPLICIT_TYPE_ID(::xla::runtime::Tagged<T>) 1167 1168 #define XLA_RUNTIME_DEFINE_EXPLICIT_TYPE_ID(T) \ 1169 MLIR_DEFINE_EXPLICIT_TYPE_ID(::xla::runtime::Tagged<T>) 1170 1171 } // namespace runtime 1172 } // namespace xla 1173 1174 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::StringRef); 1175 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(xla::runtime::StridedMemrefView); 1176 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(xla::runtime::MemrefView); 1177 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(xla::runtime::FlatMemrefView); 1178 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(int32_t); 1179 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(int64_t); 1180 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(float); 1181 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(double); 1182 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::ArrayRef<int32_t>); 1183 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::ArrayRef<int64_t>); 1184 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::ArrayRef<float>); 1185 XLA_RUNTIME_DECLARE_EXPLICIT_TYPE_ID(llvm::ArrayRef<double>); 1186 1187 #endif // XLA_RUNTIME_CUSTOM_CALL_H_ 1188