/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// KernelBase is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase; they see typed kernels, analogous to a
// typed function pointer. TypedKernels express their argument types via
// template parameters like so:
//
//  TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//  void(int*, int);
//
// And for a const memory region:
//
//  TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//  void(const int*, int)
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'ied from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety that arrays
// of minimum sizes are passed when those minimum sizes are contractually
// expected by the kernel.
//
// For user-defined types whose definitions are appropriately shared between the
// host code doing the launching and the kernel code being launched, the
// user-defined types are similarly permitted to be expressed as residing in
// device memory:
//
//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types will also be
// able to be passed by value; for example, it is a common idiom to specify a
// bunch of options simultaneously with a structure:
//
//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//  void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full,
// since it will be typedef'd by automatically generated code; for example, see
// stream_executor::executor_sample::VecReduceAddKernel.
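//
// As a purely illustrative sketch (the kernel signature, the loading step, the
// launch dimensions, and the variable names below are assumptions rather than
// definitions from this header), a typed kernel is normally launched through a
// Stream, e.g. via Stream::ThenLaunch:
//
//  TypedKernel<const DeviceMemory<float> &, int> my_kernel(executor);
//  // ... load the kernel implementation into my_kernel ...
//  stream->ThenLaunch(ThreadDim(256), BlockDim(num_blocks), my_kernel,
//                     device_input, element_count);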

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_KERNEL_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_KERNEL_H_

#include <array>
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/kernel_cache_config.h"
#include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/core/platform/logging.h"

namespace stream_executor {

class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;
class StreamExecutor;

namespace internal {
class KernelInterface;
}  // namespace internal

// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
// registers allocated, shared memory used, etc.
// Not all platforms support reporting of all information, so each accessor
// returns false if the associated field is not populated in the underlying
// platform.
class KernelMetadata {
 public:
  KernelMetadata()
      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}

  // Returns the number of registers used per thread executing this kernel.
  bool registers_per_thread(int *registers_per_thread) const;

  // Sets the number of registers used per thread executing this kernel.
  void set_registers_per_thread(int registers_per_thread);

  // Returns the amount of [static] shared memory used per block executing this
  // kernel. Note that dynamic shared memory allocations are not (and can not)
  // be reported here (since they're not specified until kernel launch time).
  bool shared_memory_bytes(int *shared_memory_bytes) const;

  // Sets the amount of [static] shared memory used per block executing this
  // kernel.
  void set_shared_memory_bytes(int shared_memory_bytes);

 private:
  // Holds the value returned by registers_per_thread above.
  bool has_registers_per_thread_;
  int registers_per_thread_;

  // Holds the value returned by shared_memory_bytes above.
  bool has_shared_memory_bytes_;
  int64_t shared_memory_bytes_;
};
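// Example query against KernelMetadata (a sketch; `kernel` stands for some
// loaded KernelBase and is not defined in this header):
//
//   int registers = 0;
//   if (kernel.metadata().registers_per_thread(&registers)) {
//     // Use `registers`; a false return means the platform did not report it.
//   }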

// A data-parallel kernel (code entity) for launching via the StreamExecutor,
// analogous to a void* device function pointer. See TypedKernel for the typed
// variant.
//
// Thread-compatible.
class KernelBase {
 public:
  KernelBase(KernelBase &&from);

  // Constructs an "empty" (not-yet-loaded) kernel instance.
  //
  // parent is the StreamExecutor that will be responsible for loading the
  // implementation of this kernel. It must not be null.
  explicit KernelBase(StreamExecutor *parent);

  // Test-only constructor that can take a mock KernelInterface implementation.
  KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);

  // Releases resources associated with the kernel instance (i.e.
  // platform-specific implementation).
  ~KernelBase();

  // Returns the number of parameters that this kernel accepts. (Arity refers to
  // nullary, unary, ...).
  unsigned Arity() const;

  // Returns the StreamExecutor that represents the platform this kernel
  // executes upon.
  StreamExecutor *parent() const { return parent_; }

  // Returns a const pointer to the (opaque) platform-dependent implementation.
  const internal::KernelInterface *implementation() const {
    return implementation_.get();
  }

  // Returns a non-const pointer to the (opaque) platform-dependent
  // implementation.
  internal::KernelInterface *implementation() { return implementation_.get(); }

  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }

  const KernelMetadata &metadata() const { return metadata_; }

  // Sets the preferred cache configuration for a kernel. This is just a
  // suggestion to the runtime, and may not be honored during execution.
  void SetPreferredCacheConfig(KernelCacheConfig config);

  // Gets the preferred cache configuration for a kernel.
  KernelCacheConfig GetPreferredCacheConfig() const;

  void set_name(absl::string_view name);
  const std::string &name() const { return name_; }
  const std::string &demangled_name() const { return demangled_name_; }

 private:
  // The StreamExecutor that loads this kernel object.
  StreamExecutor *parent_;

  // Implementation delegated to for platform-specific functionality.
  std::unique_ptr<internal::KernelInterface> implementation_;

  std::string name_;
  std::string demangled_name_;

  KernelMetadata metadata_;

  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
};

// Whether T is a DeviceMemory-family pointer.
template <typename T>
struct IsDeviceMemoryPointer {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
  static constexpr bool value = true;
};

template <>
struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
  static constexpr bool value = true;
};

// Whether T is a DeviceMemory-family value-like thing (which includes a
// reference). This trait is useful because we pack values in the same manner as
// references.
template <typename T>
struct IsDeviceMemoryValueLike {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
  static constexpr bool value = true;
};

// We need to treat SharedDeviceMemory types differently than other DeviceMemory
// types (since they maintain no allocations), hence these specializations.
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory {
  static constexpr bool value = false;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
  static constexpr bool value = true;
};
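// For illustration, the following checks hold directly by the specializations
// above (shown as a sketch, not as declarations from the original header):
//
//   static_assert(IsDeviceMemoryPointer<DeviceMemory<float> *>::value, "");
//   static_assert(IsDeviceMemoryValueLike<DeviceMemory<float> &>::value, "");
//   static_assert(!IsDeviceMemoryValueLike<SharedDeviceMemory<float> &>::value,
//                 "");
//   static_assert(IsSharedDeviceMemory<SharedDeviceMemory<float>>::value, "");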

// Basic data about a kernel argument.
struct KernelArg {
  bool is_shared;
  const void *address;
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray.
class KernelArgIterator {
 public:
  KernelArgIterator(int number_of_argument_addresses,
                    int number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        number_of_arguments_(number_of_argument_addresses +
                             number_of_shared_memory_arguments),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator.
  bool has_next() { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator.
  //
  // Returns a default-constructed KernelArg if there is no next argument.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    } else if ((shmem_indices_iter_ != shmem_indices_end_) &&
               (arg_index_ == *shmem_indices_iter_)) {
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  size_t arg_index_;
  size_t number_of_arguments_;
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  const size_t *const shmem_indices_end_;
};
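// Typical traversal (a sketch; `args` stands for some KernelArgsArrayBase and
// is not defined here):
//
//   KernelArgIterator it = args.arg_iterator();
//   while (it.has_next()) {
//     KernelArg arg = it.next();
//     // arg.is_shared distinguishes shared-memory sizes from ordinary
//     // arguments; for shared arguments arg.address is nullptr.
//   }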

// Base class for KernelArgsArray.
//
// Supports all the getter methods that do not depend on the compile-time number
// of arguments template parameter.
//
// This class exists as a way to pass kernel arguments to
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
// be templated to accept any KernelArgsArray type; therefore, a reference to
// this base type is passed instead.
//
// Performance is not a concern here because each of these methods will be
// called at most once per kernel launch. Past performance concerns with
// KernelArgsArray have been in reference to the argument packing routines which
// are called once per kernel argument. Those packing routines are now handled
// by the templated KernelArgsArray subclass of this class where they can take
// advantage of compile-time knowledge of the number of arguments in order to be
// very efficient.
class KernelArgsArrayBase {
 public:
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far.
  virtual uint64_t number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses.
  virtual port::ArraySlice<const void *> argument_addresses()  // non-absl ok
      const = 0;

  // Gets an iterator to the arguments in the array.
  virtual KernelArgIterator arg_iterator() const = 0;
};

// A list of arguments for a kernel call.
//
// The template parameter kNumArgs is the maximum number of arguments which can
// be stored in the list.
//
// Contains a list of addresses for non-shared-memory arguments and a list of
// sizes for shared-memory arguments. Since the shared-memory arguments may be
// interspersed with the non-shared-memory arguments, it also stores a list of
// the indices at which the shared-memory arguments appeared.
//
// For example, if the argument address list contains {a, b, c, d, e}, the
// shared-memory arguments list contains the sizes of {A, B, C}, and the
// shared-memory indices list contains {0, 3, 5}, then the original list of
// arguments was {A, a, b, B, c, C, d, e}.
//
// This way of storing the arguments makes CUDA kernel calls efficient because
// they only require the argument address list and the total number of shared
// bytes, but it also supports OpenCL kernel calls, which depend on the location
// and size of each shared-memory argument.
//
// Note that the code for adding arguments has been identified as a performance
// hotspot in some real-world applications, so this structure has been optimized
// for the performance of argument adding.
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
 public:
  static constexpr int kMaxGenericArgSize = 8;

  // Adds an argument to the list.
  template <typename T>
  void add_argument(const T &arg) {
    static_assert(sizeof(T) <= kMaxGenericArgSize,
                  "Please adjust kMaxGenericArgSize");
    static_assert(std::is_pod<T>::value, "Only pod types supported!");
    char *generic_arg_storage =
        &generic_arguments_[number_of_generic_arguments_++ *
                            kMaxGenericArgSize];

    CHECK_EQ(reinterpret_cast<uintptr_t>(generic_arg_storage) % alignof(T), 0);
    std::memcpy(generic_arg_storage, &arg, sizeof(T));

    argument_addresses_[number_of_argument_addresses_] = generic_arg_storage;
    argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
    ++number_of_argument_addresses_;
  }

  // Adds a device memory argument to the list.
  void add_device_memory_argument(const DeviceMemoryBase &arg) {
    const void **copy_ptr =
        &device_memory_opaque_pointers_[number_of_argument_addresses_];
    *copy_ptr = arg.opaque();
    argument_addresses_[number_of_argument_addresses_] = copy_ptr;
    argument_sizes_[number_of_argument_addresses_] = sizeof(void *);
    ++number_of_argument_addresses_;
  }

  // Adds a shared memory argument to the list.
  //
  // The only significant information about a shared argument is its size, so
  // that is the only parameter in this function.
  void add_shared_bytes(size_t number_of_bytes) {
    shared_memory_indices_[number_of_shared_memory_arguments_] =
        number_of_argument_addresses_ + number_of_shared_memory_arguments_;
    shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
    ++number_of_shared_memory_arguments_;
    total_shared_memory_bytes_ += number_of_bytes;
  }

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  size_t number_of_arguments() const override {
    return number_of_argument_addresses_ + number_of_shared_memory_arguments_;
  }

  // Gets the total number of shared memory bytes added so far.
  uint64_t number_of_shared_bytes() const override {
    return total_shared_memory_bytes_;
  }

  // Gets the list of argument addresses.
  port::ArraySlice<const void *> argument_addresses()  // non-absl ok
      const override {
    return port::ArraySlice<const void *>(  // non-absl ok
        argument_addresses_.data(), number_of_argument_addresses_);
  }

  // Gets an iterator to the arguments in the array.
  KernelArgIterator arg_iterator() const override {
    return KernelArgIterator(
        number_of_argument_addresses_, number_of_shared_memory_arguments_,
        argument_addresses_.data(), argument_sizes_.data(),
        shared_memory_bytes_.data(), shared_memory_indices_.data());
  }

 private:
  // A place to store copies of opaque pointers from device memory arguments.
  std::array<const void *, kNumArgs> device_memory_opaque_pointers_;

  // Addresses for non-shared-memory arguments.
  std::array<const void *, kNumArgs> argument_addresses_;

  // Storage for arguments of templated type.
  alignas(kMaxGenericArgSize)
      std::array<char, kNumArgs * kMaxGenericArgSize> generic_arguments_;

  // Sizes for non-shared-memory arguments.
  std::array<size_t, kNumArgs> argument_sizes_;

  // Size in bytes for each shared memory argument.
  std::array<size_t, kNumArgs> shared_memory_bytes_;

  // Indices in the arguments array for shared memory arguments.
  std::array<size_t, kNumArgs> shared_memory_indices_;

  // Total of all shared memory sizes.
  size_t total_shared_memory_bytes_ = 0;

  // Number of significant entries in argument_addresses_ and argument_sizes_.
  size_t number_of_argument_addresses_ = 0;

  // Number of significant entries in shared_memory_bytes_ and
  // shared_memory_indices_.
  size_t number_of_shared_memory_arguments_ = 0;

  // The number of generic arguments that have been added to generic_arguments_.
  size_t number_of_generic_arguments_ = 0;
};
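// Example of populating an argument array by hand (a sketch; `output_buffer`
// and the literal values are assumptions for illustration, and in practice
// TypedKernel::PackParams fills the array):
//
//   KernelArgsArray<3> args;
//   args.add_device_memory_argument(output_buffer);  // a DeviceMemoryBase
//   args.add_argument(42);                           // a POD value
//   args.add_shared_bytes(1024);                     // dynamic shared memory
//   // args.number_of_arguments() == 3 and
//   // args.number_of_shared_bytes() == 1024 at this point.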

// Typed variant of KernelBase, like a typed device function pointer. See the
// file comment for details and example usage.
//
// This class contains template metaprogramming magic to type check that the
// parameters passed to a kernel launch are acceptable, and subsequently pack
// them into a form which can be used by the StreamExecutorInterface
// implementation (i.e., CUDA and OpenCL both bind void*s with associated
// sizes as kernel arguments).
//
// Thread-compatible.
template <typename... Params>
class TypedKernel : public KernelBase {
 public:
  static constexpr size_t kNumberOfParameters = sizeof...(Params);

  // Delegates to KernelBase::KernelBase(); see that constructor.
  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}

  // Test-only constructor that can take a mock KernelInterface implementation.
  // Takes ownership of implementation; it should not be null.
  TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
      : KernelBase(parent, implementation) {}

 private:
  // Stream needs access to the specific parameter-packing functionality that
  // the TypedKernel provides for its corresponding type signature (and no other
  // type signatures).
  friend class Stream;

  // This is the main entry point into the magic. Packs the parameters (which
  // must type check against the class template) into the args and sizes
  // arrays.
  //
  // Const refs are taken as parameters on all of the handlers to avoid
  // implicit type promotion of integers.
  //
  // WARNING: as a performance optimization this method may store pointers to
  // some of the input parameters in the kernel args structure, so any params
  // passed into this method must live at least as long as the kernel args
  // structure.
  void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                  Params &... params) const {
    PackOneParamFromList(args, params...);
  }

  template <typename T, typename... RestOfParams>
  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args,
                            const T &arg, const RestOfParams &... rest) const {
    PackOneParam(args, arg);
    PackOneParamFromList(args, rest...);
  }

  // Base case for variadic template expansion - nothing to do!
  void PackOneParamFromList(KernelArgsArray<kNumberOfParameters> *args) const {}

  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
  // separate implementation below.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
                              !IsDeviceMemoryPointer<T>::value &&
                              !IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    static_assert(!std::is_pointer<T>::value,
                  "cannot pass raw pointer to the device");
    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
                  "cannot pass device memory as a normal value");
    args->add_argument(arg);
  }

  // DeviceMemoryBase family reference override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
          nullptr) const {
    args->add_device_memory_argument(arg);
  }

  // DeviceMemoryBase family pointer override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
          nullptr) const {
    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
    args->add_device_memory_argument(*ptr);
  }

  // Dynamic shared device memory has a size, but no associated allocation on
  // the host; internally, the device will allocate storage.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    args->add_shared_bytes(arg.size());
  }

  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};
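// As an illustration of the packing rules above (a sketch; `executor`, `mem`,
// and `count` are assumptions, and PackParams is private -- in practice Stream
// invokes it on the caller's behalf):
//
//   TypedKernel<const DeviceMemory<int> &, int> kernel(executor);
//   KernelArgsArray<2> args;
//   kernel.PackParams(&args, mem, count);
//   // One device-memory argument (via add_device_memory_argument) and one POD
//   // argument (via add_argument) were added; args.number_of_arguments() == 2.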

// Template metaprogramming helper type that helps us produce better error
// messages at compile time when there are mismatches between the parameter
// type list and the argument type list.
template <typename ParamTuple, typename ArgTuple>
struct KernelInvocationChecker {
  // Whether the parameter tuple and argument tuple match in length.
  static constexpr bool kLengthMatches =
      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;

  // The (matching) length of the parameters and arguments type lists.
  static constexpr int kTupleLength =
      static_cast<int>(std::tuple_size<ArgTuple>::value);

  // Helper trait to say whether the parameter wants a DeviceMemory-reference
  // compatible type. This is for inexact type matches, so that it doesn't have
  // to be precisely a const DeviceMemory<T>&, but can also be a value that
  // represents the same.
  template <typename ParamType, typename ArgType>
  struct IsCompatibleDeviceMemoryRef {
    static constexpr bool value = false;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
                                     SharedDeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // Returns whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing without any assert functionality.
  template <typename ParamT, typename ArgT>
  static constexpr bool CompatibleNoAssert() {
    return std::is_same<typename std::remove_const<ParamT>::type,
                        ArgT>::value ||
           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
  }

  // Checks whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing. kArgumentNumber is unused; it is just for error display.
  //
  // NOTE: if you encounter an error here, you can see the mismatch by looking
  // at the end of the last error message, which will be of the form:
  //
  //    ...::Compatible<const stream_executor::DeviceMemory<OneThing> &,
  //                    stream_executor::DeviceMemory<AnotherThing>, true,
  //                    0>'
  //    requested here
  //
  // This means that the 0th argument you passed to the kernel invocation should
  // have been DeviceMemory<OneThing> but was observed to be
  // DeviceMemory<AnotherThing>.
  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
            int kArgumentNumber>
  static constexpr bool Compatible() {
    static_assert(
        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
        "parameter type (LHS) is not compatible with argument type (RHS)");
    return CompatibleNoAssert<ParamT, ArgT>();
  }

  // Checks the parameter/argument match at kArgumentNumber for an out of bounds
  // argument number.
  //
  // This is the base case: we've run out of arguments to check, so we're all
  // good.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
    return true;
  }

  // Checks the parameter/argument match at kArgumentNumber.
  // kShouldStaticAssert determines whether to assert out on a mismatch, or just
  // yield the constexpr boolean value.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
        ParamT;
    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
  }

  // Checks the parameters/arguments for a match, but doesn't static assert out.
  // This is useful for inspecting whether a set of parameters matches, e.g. in
  // tests.
  static constexpr bool CheckAllNoStaticAssert() {
    return kLengthMatches && CheckParam<kTupleLength - 1, false>();
  }

  // Checks the parameters and static asserts out with a helpful error message
  // (and useful template parameters in the instantiation stack) if there is an
  // error.
  static constexpr bool CheckAllStaticAssert() {
    static_assert(kLengthMatches,
                  "argument length mismatched against typed kernel parameters");
    return kLengthMatches && CheckParam<kTupleLength - 1, true>();
  }
};

// This is a convenience type for checking whether a typed kernel matches
// against a type list.
template <typename KernelT, typename... Params>
struct KernelParamsOk {
  static constexpr bool kResult = false;
};

// See above.
template <typename... Params, typename... Args>
struct KernelParamsOk<TypedKernel<Params...>, Args...> {
  static constexpr bool kResult = KernelInvocationChecker<
      std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
};
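// For example, the check below holds by the compatibility rules above
// (a sketch, not a declaration from the original header):
//
//   static_assert(
//       KernelParamsOk<TypedKernel<const DeviceMemory<int> &, int>,
//                      DeviceMemory<int>, int>::kResult,
//       "kernel parameter/argument mismatch");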

}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_KERNEL_H_