/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// Kernel is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase; they see typed kernels, analogous to a
// typed function pointer. TypedKernels express their argument types via
// template parameters like so:
//
//  TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//  void(int*, int);
//
// And for a const memory region:
//
//  TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//  void(const int*, int)
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'ied from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety that
// arrays of minimum sizes are passed when those minimum sizes are
// contractually expected by the kernel.
//
// For user-defined types whose definitions are appropriately shared between
// the host code doing the launching and the kernel code being launched, the
// user-defined types are similarly permitted to be expressed as residing in
// device memory:
//
//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types will also be
// able to be passed by value; for example, it is a common idiom to specify a
// bunch of options simultaneously with a structure:
//
//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//  void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full;
// it will be typedef'd by automatically generated code (for example, see
// stream_executor::executor_sample::VecReduceAddKernel).
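//
// As an illustrative sketch only (the alias, variable names, and launch
// dimensions below are hypothetical; the launch API itself lives in stream.h
// and its exact signature may differ), a kernel that adds a scalar to a float
// vector might be declared and launched roughly as:
//
//  using AddScalarKernel = TypedKernel<DeviceMemory<float>, float>;
//  AddScalarKernel add_scalar(executor);
//  ... load the platform-specific implementation via the StreamExecutor ...
//  stream.ThenLaunch(ThreadDim(threads_per_block), BlockDim(block_count),
//                    add_scalar, device_vector, 42.0f);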

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_KERNEL_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_KERNEL_H_

#include <array>
#include <cstring>
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/kernel_cache_config.h"
#include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/core/platform/logging.h"

namespace stream_executor {

class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;
class StreamExecutor;

namespace internal {
class KernelInterface;
}  // namespace internal

// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such
// as registers allocated, shared memory used, etc.
// Not all platforms support reporting of all information, so each accessor
// returns false if the associated field is not populated in the underlying
// platform.
class KernelMetadata {
 public:
  KernelMetadata()
      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}

  // Returns the number of registers used per thread executing this kernel.
  bool registers_per_thread(int *registers_per_thread) const;

  // Sets the number of registers used per thread executing this kernel.
  void set_registers_per_thread(int registers_per_thread);

  // Returns the amount of [static] shared memory used per block executing
  // this kernel. Note that dynamic shared memory allocations are not (and
  // cannot be) reported here, since they are not specified until kernel
  // launch time.
  bool shared_memory_bytes(int *shared_memory_bytes) const;

  // Sets the amount of [static] shared memory used per block executing this
  // kernel.
  void set_shared_memory_bytes(int shared_memory_bytes);

 private:
  // Holds the value returned by registers_per_thread above.
  bool has_registers_per_thread_;
  int registers_per_thread_;

  // Holds the value returned by shared_memory_bytes above.
  bool has_shared_memory_bytes_;
  int64_t shared_memory_bytes_;
};
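
// For example (illustrative only; "kernel" stands for any loaded KernelBase,
// defined below), callers check the returned bool before using the
// out-parameter:
//
//  int registers = 0;
//  if (kernel.metadata().registers_per_thread(&registers)) {
//    // registers is populated for this kernel on this platform.
//  }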

// A data-parallel kernel (code entity) for launching via the StreamExecutor,
// analogous to a void* device function pointer. See TypedKernel for the typed
// variant.
//
// Thread-compatible.
class KernelBase {
 public:
  KernelBase(KernelBase &&from);

  // Constructs an "empty" (not-yet-loaded) kernel instance.
  //
  // parent is the StreamExecutor that will be responsible for loading the
  // implementation of this kernel. It must not be null.
  explicit KernelBase(StreamExecutor *parent);

  // Test-only constructor that can take a mock KernelInterface implementation.
  KernelBase(StreamExecutor *parent,
             internal::KernelInterface *implementation);

  // Releases resources associated with the kernel instance (i.e.
  // platform-specific implementation).
  ~KernelBase();

  // Returns the number of parameters that this kernel accepts. (Arity refers
  // to nullary, unary, ...).
  unsigned Arity() const;

  // Returns the StreamExecutor that represents the platform this kernel
  // executes upon.
  StreamExecutor *parent() const { return parent_; }

  // Returns a const pointer to the (opaque) platform-dependent implementation.
  const internal::KernelInterface *implementation() const {
    return implementation_.get();
  }

  // Returns a non-const pointer to the (opaque) platform-dependent
  // implementation.
  internal::KernelInterface *implementation() { return implementation_.get(); }

  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }

  const KernelMetadata &metadata() const { return metadata_; }

  // Sets the preferred cache configuration for a kernel. This is just a
  // suggestion to the runtime, and may not be honored during execution.
  void SetPreferredCacheConfig(KernelCacheConfig config);

  // Gets the preferred cache configuration for a kernel.
  KernelCacheConfig GetPreferredCacheConfig() const;

  void set_name(absl::string_view name);
  const std::string &name() const { return name_; }
  const std::string &demangled_name() const { return demangled_name_; }

 private:
  // The StreamExecutor that loads this kernel object.
  StreamExecutor *parent_;

  // Implementation delegated to for platform-specific functionality.
  std::unique_ptr<internal::KernelInterface> implementation_;

  std::string name_;
  std::string demangled_name_;

  KernelMetadata metadata_;

  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
};

// Whether T is a DeviceMemory-family pointer.
template <typename T>
struct IsDeviceMemoryPointer {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
  static constexpr bool value = true;
};

template <>
struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
  static constexpr bool value = true;
};

// Whether T is a DeviceMemory-family value-like thing (which includes a
// reference). This trait is useful because we pack values in the same manner
// as references.
template <typename T>
struct IsDeviceMemoryValueLike {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
  static constexpr bool value = true;
};

// We need to treat SharedDeviceMemory types differently than other
// DeviceMemory types (since they maintain no allocations), hence these
// specializations.
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory {
  static constexpr bool value = false;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
  static constexpr bool value = true;
};

// Basic data about a kernel argument.
struct KernelArg {
  bool is_shared;
  const void *address;
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray.
class KernelArgIterator {
 public:
  KernelArgIterator(int number_of_argument_addresses,
                    int number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        number_of_arguments_(number_of_argument_addresses +
                             number_of_shared_memory_arguments),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator.
  bool has_next() { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator.
  //
  // Returns a default-constructed KernelArg if there is no next argument.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    } else if ((shmem_indices_iter_ != shmem_indices_end_) &&
               (arg_index_ == *shmem_indices_iter_)) {
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  size_t arg_index_;
  size_t number_of_arguments_;
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  const size_t *const shmem_indices_end_;
};
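
// A typical traversal of packed arguments (illustrative only; "args" stands
// for any KernelArgsArrayBase, defined below):
//
//  KernelArgIterator it = args.arg_iterator();
//  while (it.has_next()) {
//    KernelArg arg = it.next();
//    if (arg.is_shared) {
//      // arg.size is a shared-memory byte count; arg.address is null.
//    } else {
//      // arg.address and arg.size describe an ordinary argument.
//    }
//  }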

// Base class for KernelArgsArray.
//
// Supports all the getter methods that do not depend on the compile-time
// number of arguments template parameter.
//
// This class exists as a way to pass kernel arguments to
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
// be templated to accept any KernelArgsArray type; a reference to this base
// type is passed instead.
//
// Performance is not a concern here because each of these methods will be
// called at most once per kernel launch. Past performance concerns with
// KernelArgsArray have been in reference to the argument packing routines,
// which are called once per kernel argument. Those packing routines are now
// handled by the templated KernelArgsArray subclass of this class, where they
// can take advantage of compile-time knowledge of the number of arguments in
// order to be very efficient.
class KernelArgsArrayBase {
 public:
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far.
  virtual uint64_t number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses.
  virtual port::ArraySlice<const void *> argument_addresses()  // non-absl ok
      const = 0;

  // Gets an iterator to the arguments in the array.
  virtual KernelArgIterator arg_iterator() const = 0;
};

// A list of arguments for a kernel call.
//
// The template parameter kNumArgs is the maximum number of arguments which
// can be stored in the list.
//
// Contains a list of addresses for non-shared-memory arguments and a list of
// sizes for shared-memory arguments. Since the shared-memory arguments may be
// interspersed with the non-shared-memory arguments, it also stores a list of
// the indices at which the shared-memory arguments appeared.
//
// For example, if the argument address list contains {a, b, c, d, e}, the
// shared-memory arguments list contains the sizes of {A, B, C}, and the
// shared-memory indices list contains {0, 3, 5}, then the original list of
// arguments was {A, a, b, B, c, C, d, e}.
//
// This way of storing the arguments makes CUDA kernel calls efficient because
// they only require the argument address list and the total number of shared
// bytes; it also makes OpenCL kernel calls possible, because those depend on
// the location of each shared-memory argument and its size.
//
// Note that the code for adding arguments has been identified as a
// performance hotspot in some real-world applications, so this structure has
// been optimized for the performance of argument adding.
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
 public:
  static constexpr int kMaxGenericArgSize = 8;

  // Adds an argument to the list.
  template <typename T>
  void add_argument(const T &arg) {
    static_assert(sizeof(T) <= kMaxGenericArgSize,
                  "Please adjust kMaxGenericArgSize");
    static_assert(std::is_pod<T>::value, "Only POD types are supported!");
    char *generic_arg_storage =
        &generic_arguments_[number_of_generic_arguments_++ *
                            kMaxGenericArgSize];

    CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0);
    std::memcpy(generic_arg_storage, &arg, sizeof(T));

    argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
    argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
    ++number_of_argument_addresses_;
  }

  // Adds a device memory argument to the list.
  void add_device_memory_argument(const DeviceMemoryBase &arg) {
    const void **copy_ptr =
        &device_memory_opaque_pointers_[number_of_argument_addresses_];
    *copy_ptr = arg.opaque();
    argument_addresses_[number_of_argument_addresses_] = copy_ptr;
    argument_sizes_[number_of_argument_addresses_] = sizeof(void *);
    ++number_of_argument_addresses_;
  }

  // Adds a shared memory argument to the list.
  //
  // The only significant information about a shared argument is its size, so
  // that is the only parameter in this function.
  void add_shared_bytes(size_t number_of_bytes) {
    shared_memory_indices_[number_of_shared_memory_arguments_] =
        number_of_argument_addresses_ + number_of_shared_memory_arguments_;
    shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
    ++number_of_shared_memory_arguments_;
    total_shared_memory_bytes_ += number_of_bytes;
  }

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  size_t number_of_arguments() const override {
    return number_of_argument_addresses_ + number_of_shared_memory_arguments_;
  }

  // Gets the total number of shared memory bytes added so far.
  uint64_t number_of_shared_bytes() const override {
    return total_shared_memory_bytes_;
  }

  // Gets the list of argument addresses.
  port::ArraySlice<const void *> argument_addresses()  // non-absl ok
      const override {
    return port::ArraySlice<const void *>(  // non-absl ok
        argument_addresses_.data(), number_of_argument_addresses_);
  }

  // Gets an iterator to the arguments in the array.
  KernelArgIterator arg_iterator() const override {
    return KernelArgIterator(
        number_of_argument_addresses_, number_of_shared_memory_arguments_,
        argument_addresses_.data(), argument_sizes_.data(),
        shared_memory_bytes_.data(), shared_memory_indices_.data());
  }

 private:
  // A place to store copies of opaque pointers from device memory arguments.
  std::array<const void *, kNumArgs> device_memory_opaque_pointers_;

  // Addresses for non-shared-memory arguments.
  std::array<const void *, kNumArgs> argument_addresses_;

  // Storage for arguments of templated type.
  alignas(kMaxGenericArgSize)
      std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_;

  // Sizes for non-shared-memory arguments.
  std::array<size_t, kNumArgs> argument_sizes_;

  // Size in bytes for each shared memory argument.
  std::array<size_t, kNumArgs> shared_memory_bytes_;

  // Indices in the arguments array for shared memory arguments.
  std::array<size_t, kNumArgs> shared_memory_indices_;

  // Total of all shared memory sizes.
  size_t total_shared_memory_bytes_ = 0;

  // Number of significant entries in argument_addresses_ and argument_sizes_.
  size_t number_of_argument_addresses_ = 0;

  // Number of significant entries in shared_memory_bytes_ and
  // shared_memory_indices_.
  size_t number_of_shared_memory_arguments_ = 0;

  // The number of generic arguments that have been added to
  // generic_arguments_.
  size_t number_of_generic_arguments_ = 0;
};
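
// Relating the example in the class comment above to the API (illustrative
// only; A, B, C stand for shared-memory byte counts and a..e for
// DeviceMemoryBase arguments):
//
//  KernelArgsArray<8> args;
//  args.add_shared_bytes(A);
//  args.add_device_memory_argument(a);
//  args.add_device_memory_argument(b);
//  args.add_shared_bytes(B);
//  args.add_device_memory_argument(c);
//  args.add_shared_bytes(C);
//  args.add_device_memory_argument(d);
//  args.add_device_memory_argument(e);
//
// After these calls, argument_addresses() has 5 entries, the shared-memory
// sizes are {A, B, C}, and the shared-memory indices are {0, 3, 5}, matching
// the example above.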

// Typed variant of KernelBase, like a typed device function pointer. See the
// file comment for details and example usage.
//
// This class contains template metaprogramming magic to check at compile time
// that the parameters passed to a kernel launch are acceptable, and
// subsequently pack them into a form which can be used by the
// StreamExecutorInterface implementation. (i.e. CUDA and OpenCL both bind
// void*s with associated sizes as kernel arguments.)
//
// Thread-compatible.
template <typename... Params>
class TypedKernel : public KernelBase {
 public:
  static constexpr size_t kNumberOfParameters = sizeof...(Params);

  // Delegates to KernelBase::KernelBase(), see that constructor.
  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}

  // Test-only constructor that can take a mock KernelInterface implementation.
  // Takes ownership of implementation, which must not be null.
  TypedKernel(StreamExecutor *parent,
              internal::KernelInterface *implementation)
      : KernelBase(parent, implementation) {}

 private:
  // Stream needs access to the specific parameter-packing functionality that
  // the TypedKernel provides for its corresponding type signature (and no
  // other type signatures).
  friend class Stream;

  // This is the main entry point into the magic. Packs the parameters (which
  // must type check against the class template) into the args and sizes
  // arrays.
  //
  // Const refs are taken as parameters on all of the handlers to avoid
  // implicit type promotion of integers.
  //
  // WARNING: as a performance optimization this method may store pointers to
  // some of the input parameters in the kernel args structure, so any params
  // passed into this method must live at least as long as the kernel args
  // structure.
  void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                  Params &... params) const {
    PackOneParamFromList(args, params...);
  }

  template <typename T, typename... RestOfParams>
  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args,
                            const T &arg, const RestOfParams &... rest) const {
    PackOneParam(args, arg);
    PackOneParamFromList(args, rest...);
  }

  // Base case for variadic template expansion - nothing to do!
  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args) const {
  }

  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
  // separate implementation below.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
                              !IsDeviceMemoryPointer<T>::value &&
                              !IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    static_assert(!std::is_pointer<T>::value,
                  "cannot pass raw pointer to the device");
    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
                  "cannot pass device memory as a normal value");
    args->add_argument(arg);
  }

  // DeviceMemoryBase family reference override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
          nullptr) const {
    args->add_device_memory_argument(arg);
  }

  // DeviceMemoryBase family pointer override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
          nullptr) const {
    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
    args->add_device_memory_argument(*ptr);
  }

  // Dynamic shared device memory has a size, but no associated allocation on
  // the host; internally, the device will allocate storage.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    args->add_shared_bytes(arg.size());
  }

  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};
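
// As a compile-time check (illustrative only), the helpers below
// (KernelInvocationChecker and KernelParamsOk) can verify that a set of
// argument types matches a TypedKernel's parameter list, e.g.:
//
//  static_assert(
//      KernelParamsOk<TypedKernel<const DeviceMemory<int> &, int>,
//                     DeviceMemory<int>, int>::kResult,
//      "arguments must be compatible with the kernel's parameter types");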

// Template metaprogramming helper type that helps us produce better error
// messages at compile time when there are mismatches between the parameter
// type list and the argument type list.
template <typename ParamTuple, typename ArgTuple>
struct KernelInvocationChecker {
  // Whether the parameter tuple and argument tuple match in length.
  static constexpr bool kLengthMatches =
      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;

  // The (matching) length of the parameters and arguments type lists.
  static constexpr int kTupleLength =
      static_cast<int>(std::tuple_size<ArgTuple>::value);

  // Helper trait to say whether the parameter wants a DeviceMemory-reference
  // compatible type. This is for inexact type matches, so that it doesn't
  // have to be precisely a const DeviceMemory<T>&, but can also be a value
  // that represents the same.
  template <typename ParamType, typename ArgType>
  struct IsCompatibleDeviceMemoryRef {
    static constexpr bool value = false;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
                                     SharedDeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // Returns whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing without any assert functionality.
  template <typename ParamT, typename ArgT>
  static constexpr bool CompatibleNoAssert() {
    return std::is_same<typename std::remove_const<ParamT>::type,
                        ArgT>::value ||
           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
  }

  // Checks whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing. kArgumentNumber is unused; it is just for error
  // display.
  //
  // NOTE: if you encounter an error here, you can see the mismatch by looking
  // at the end of the last error message, which will be of the form:
  //
  //    ...::Compatible<const stream_executor::DeviceMemory<OneThing> &,
  //                    stream_executor::DeviceMemory<AnotherThing>, true,
  //                    0>'
  //    requested here
  //
  // This means that the 0th argument you passed to the kernel invocation
  // should have been DeviceMemory<OneThing> but was observed to be
  // DeviceMemory<AnotherThing>.
  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
            int kArgumentNumber>
  static constexpr bool Compatible() {
    static_assert(
        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
        "parameter type (LHS) is not compatible with argument type (RHS)");
    return CompatibleNoAssert<ParamT, ArgT>();
  }

  // Checks the parameter/argument match at kArgumentNumber for an out of
  // bounds argument number.
  //
  // This is the base case: we've run out of arguments to check, so we're all
  // good.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
    return true;
  }

  // Checks the parameter/argument match at kArgumentNumber.
  // kShouldStaticAssert determines whether to assert out on a mismatch, or
  // just yield the constexpr boolean value.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
        ParamT;
    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
  }

  // Checks the parameters/arguments for a match, but doesn't static assert
  // out. This is useful for testing/inspecting whether a set of parameters
  // match in things like tests.
  static constexpr bool CheckAllNoStaticAssert() {
    return kLengthMatches && CheckParam<kTupleLength - 1, false>();
  }

  // Checks the parameters and static asserts out with a helpful error message
  // (and useful template parameters in the instantiation stack) if there is
  // an error.
  static constexpr bool CheckAllStaticAssert() {
    static_assert(kLengthMatches,
                  "argument length mismatched against typed kernel parameters");
    return kLengthMatches && CheckParam<kTupleLength - 1, true>();
  }
};

// This is a convenience type for checking whether a typed kernel matches
// against a type list.
template <typename KernelT, typename... Params>
struct KernelParamsOk {
  static constexpr bool kResult = false;
};

// See above.
template <typename... Params, typename... Args>
struct KernelParamsOk<TypedKernel<Params...>, Args...> {
  static constexpr bool kResult = KernelInvocationChecker<
      std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
};

}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_KERNEL_H_