#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/IndexKernel.h>
#include <ATen/native/IndexKernel.h>

#include <type_traits>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/cub.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/native/quantized/IndexKernel.h>

#include <c10/core/Scalar.h>

namespace at::native {

static constexpr int launch_bound2 = 4;

static constexpr int launch_size_nd = 128;

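// index_elementwise_kernel: each block covers nt * vt elements; every thread
// walks up to vt elements spaced nt apart, guarded by the bound N.
// launch_bound2 is presumably the min-blocks-per-multiprocessor hint passed to
// C10_LAUNCH_BOUNDS_2.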
template<int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, launch_bound2)
__global__ void index_elementwise_kernel(const int64_t N, const func_t f) {
  const auto tid = threadIdx.x;
  const auto nv = nt * vt;
  auto idx = nv * blockIdx.x + tid;
  #pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < N) {
      f(idx);
      idx += nt;
    }
  }
}

template<int nt, int vt, typename func_t>
static void launch_kernel(const int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  const dim3 block(nt);
  const dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  const auto stream = at::cuda::getCurrentCUDAStream();
  index_elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

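// Shared driver for the index/index_put kernels. For each element of the
// iterator it reads one int64 entry from every index tensor (operands 2..N),
// wraps negative indices, accumulates index * stride into a byte offset, and
// calls f(out_ptr, in_ptr, offset). Iterators too large for 32-bit indexing
// are split and handled recursively.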
template <typename func_t>
void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const func_t& f) {
  const auto num_indices = index_size.size();
  AT_ASSERT(num_indices == index_stride.size());
  AT_ASSERT(static_cast<int64_t>(num_indices) == iter.ntensors() - 2);

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_index_kernel(sub_iter, index_size, index_stride, f);
    }
    return;
  }

  auto sizes = at::detail::Array<int64_t, MAX_DIMS>(0);
  auto strides = at::detail::Array<int64_t, MAX_DIMS>(0);
  auto index_ptrs = at::detail::Array<char*, MAX_DIMS>(nullptr);
  for (unsigned i = 0; i < num_indices; i++) {
    sizes[i] = index_size[i];
    strides[i] = index_stride[i];
    index_ptrs[i] = (char*)iter.data_ptr(i + 2);
  }

  char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
  char* const in_ptr = static_cast<char*>(iter.data_ptr(1));

  auto offset_calc = make_offset_calculator<3>(iter);
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), [=]__device__(int idx) {
    const auto offsets = offset_calc.get(idx);
    char* const out_data = out_ptr + offsets[0];
    const char* const in_data = in_ptr + offsets[1];

    int64_t offset = 0;
    #pragma unroll
    for (int i = 0; i < num_indices; i++) {
      int64_t index = *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
      CUDA_KERNEL_ASSERT(-sizes[i] <= index && index < sizes[i] && "index out of bounds");
      if (index < 0) {
        index += sizes[i];
      }
      offset += index * strides[i];
    }

    f(out_data, in_data, offset);
  });
}

// The kernels are templated on an opaque, self-aligned type of the correct
// size to avoid redundant kernels for different types of the same size.
template <int N> struct alignas(N) OpaqueType { char data[N]; };

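// index_fill: for each (self, index) pair produced by the iterator, writes
// fill_val at idx * self_dim_stride, wrapping negative indices into
// [0, self_dim_size).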
template <typename scalar_t>
void index_fill_kernel_impl(
  TensorIterator& iter,
  const int64_t dim,
  const int64_t self_dim_size,
  const int64_t self_dim_stride,
  const scalar_t fill_val) {
  if (0 == iter.numel()) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      index_fill_kernel_impl(sub_iter, dim, self_dim_size, self_dim_stride, fill_val);
    }
    return;
  }

  char* const __restrict__ self_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* const __restrict__ idx_ptr = reinterpret_cast<char*>(iter.data_ptr(1));

  const auto offset_calc = make_offset_calculator<2>(iter);

  const auto loop = [=]C10_DEVICE(int i) {
    const auto offsets = offset_calc.get(i);

    auto* __restrict__ self_data = reinterpret_cast<scalar_t*>(self_ptr + offsets[0]);
    auto idx = *reinterpret_cast<int64_t*>(idx_ptr + offsets[1]);
    CUDA_KERNEL_ASSERT(idx >= -self_dim_size && idx < self_dim_size && "index out of bounds");
    if (idx < 0) {
      idx += self_dim_size;
    }

    self_data[idx * self_dim_stride] = fill_val;
  };
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), loop);
}

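// index_copy: like index_fill above, but copies the matching source element
// instead of a constant; negative indices are rejected by the device assert.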
template <typename scalar_t>
void index_copy_kernel_impl(
  TensorIterator& iter,
  const int64_t dim,
  const int64_t self_dim_size,
  const int64_t self_dim_stride) {
  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      index_copy_kernel_impl<scalar_t>(sub_iter, dim, self_dim_size, self_dim_stride);
    }
    return;
  }

  char* const __restrict__ self_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* const __restrict__ idx_ptr = reinterpret_cast<char*>(iter.data_ptr(1));
  char* const __restrict__ source_ptr = reinterpret_cast<char*>(iter.data_ptr(2));

  const auto offset_calc = make_offset_calculator<3>(iter);

  const auto loop = [=]C10_DEVICE(int i) {
    const auto offsets = offset_calc.get(i);

    auto* const __restrict__ self_data = reinterpret_cast<scalar_t*>(self_ptr + offsets[0]);
    auto idx = *reinterpret_cast<int64_t*>(idx_ptr + offsets[1]);
    const auto* const __restrict__ source_data = reinterpret_cast<scalar_t*>(source_ptr + offsets[2]);
    CUDA_KERNEL_ASSERT(idx >= 0 && idx < self_dim_size && "index_copy_(): index out of bounds");

    self_data[idx * self_dim_stride] = *source_data;
  };
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), loop);
}

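// Gather flavor of gpu_index_kernel: the computed offset addresses the input,
// i.e. out[i] = in[offset].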
template <typename scalar_t>
void index_kernel_impl(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
  gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
    *reinterpret_cast<scalar_t*>(out_data) = *reinterpret_cast<const scalar_t*>(in_data + offset);
  });
}

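// Scatter flavor: the computed offset addresses the output,
// i.e. out[offset] = in[i].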
template <typename scalar_t>
void index_put_kernel_impl(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
  gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
    *reinterpret_cast<scalar_t*>(out_data + offset) = *reinterpret_cast<const scalar_t*>(in_data);
  });
}

static void index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, iter.dtype(), "index_cuda", [&] {
    using dtype = OpaqueType<sizeof(scalar_t)>;
    index_kernel_impl<dtype>(iter, index_size, index_stride);
  });
}

static void index_fill_kernel(
  TensorIterator& iter,
  const int64_t dim,
  const int64_t self_dim_size,
  const int64_t self_dim_stride,
  const Scalar& source) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, kComplexHalf,
    iter.dtype(), "index_fill_cuda", [&] {
    using dtype = OpaqueType<sizeof(scalar_t)>;
    const auto fill_val = source.to<scalar_t>();
    const auto fill_val_opaque = *reinterpret_cast<const dtype*>(&fill_val);
    index_fill_kernel_impl<dtype>(iter, dim, self_dim_size, self_dim_stride, fill_val_opaque);
  });
}

static void index_copy_kernel(
  TensorIterator& iter,
  const int64_t dim,
  const int64_t self_dim_size,
  const int64_t self_dim_stride) {
  // See note [Writing Nondeterministic Operations]
  // Nondeterministic when index contains duplicate entries;
  // this kernel will not be called when torch.use_deterministic_algorithms(True) is set.
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, kComplexHalf,
    iter.dtype(), "index_copy_cuda", [&] {
    using dtype = OpaqueType<sizeof(scalar_t)>;
    index_copy_kernel_impl<dtype>(iter, dim, self_dim_size, self_dim_stride);
  });
}


static void index_put_kernel(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const bool accumulate) {
  TORCH_CHECK(!accumulate, "index_put does not support accumulate=true");
  AT_DISPATCH_V2(
    iter.dtype(),
    "index_put",
    AT_WRAP([&] {
      using dtype = OpaqueType<sizeof(scalar_t)>;
      index_put_kernel_impl<dtype>(iter, index_size, index_stride);
    }),
    AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
    AT_EXPAND(AT_FLOAT8_TYPES),
    kComplexHalf,
    kHalf,
    kBool,
    kBFloat16);
}

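// Quantized index_put: input values arrive as float and are quantized on the
// fly (divide by scale, round with nearbyintf, add zero_point, clamp to the
// underlying integer type's range) before being scattered.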
void index_put_kernel_quantized_cuda(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const bool accumulate, const double scale, const int zero_point) {
  TORCH_CHECK(!accumulate, "index_put does not support accumulate=true");
  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "index_put", [&] {
    constexpr int64_t qmin = std::numeric_limits<typename scalar_t::underlying>::min();
    constexpr int64_t qmax = std::numeric_limits<typename scalar_t::underlying>::max();
    const float inv_scale = 1.0f / static_cast<float>(scale);

    gpu_index_kernel(iter, index_size, index_stride, [inv_scale, zero_point, qmin, qmax]C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
      int64_t qvalue = static_cast<int64_t>(zero_point + nearbyintf(*(float*)in_data * inv_scale));
      // See https://github.com/pytorch/pytorch/issues/127666
      // and https://github.com/pytorch/pytorch/issues/128253.
      // With hip-clang, std::clamp calls the __glibcxx_assert_fail host function when
      // building on Fedora 40/gcc 14. The expression below replaces std::clamp(qvalue, qmin, qmax)
      // and works for both CUDA and ROCm, since it generates the same PTX as std::clamp
      // (see https://godbolt.org/z/Wde9KW3v4). Using #ifdef USE_ROCM to differentiate
      // caused Windows build failures.
      qvalue = (qvalue < qmin) ? qmin : (qmax < qvalue) ? qmax : qvalue;
      *(scalar_t*)(out_data + offset) = static_cast<scalar_t>(qvalue);
    });
  });
}

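// Common driver for take/put. `iter` pairs the iterated tensor with a tensor
// of linear indices into `indexed`; each index is bounds-checked, wrapped if
// negative, and, when `indexed` is non-contiguous, translated into a memory
// offset by an OffsetCalculator built from the reversed sizes/strides.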
template <typename scalar_t, typename index_t, typename func_t>
void cuda_take_put_kernel(
  TensorIterator& iter,
  const TensorBase& indexed,
  const func_t& f) {
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      cuda_take_put_kernel<scalar_t, index_t>(sub_iter, indexed, f);
    }
    return;
  }

  const auto numel = indexed.numel();
  const bool is_contiguous = indexed.is_contiguous();

  char* const __restrict__ iterated_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* const __restrict__ idx_ptr = reinterpret_cast<char*>(iter.data_ptr(1));

  const auto offset_calc = make_offset_calculator<2>(iter);
  using uindex_t = std::make_unsigned_t<index_t>;

  // OffsetCalculator needs the sizes and strides reversed
  const auto indexed_sizes = std::vector<int64_t>(indexed.sizes().rbegin(), indexed.sizes().rend());
  const auto indexed_strides = std::vector<int64_t>(indexed.strides().rbegin(), indexed.strides().rend());
  const auto* indexed_strides_data = indexed_strides.data();
  const auto offset_indexed = OffsetCalculator<1, uindex_t>(indexed.dim(),
                                                            indexed_sizes.data(),
                                                            &indexed_strides_data);

  const auto loop = [=]C10_DEVICE(int i) {
    const auto offsets = offset_calc.get(i);

    auto& iterated = *reinterpret_cast<scalar_t*>(iterated_ptr + offsets[0]);
    const auto idx = *reinterpret_cast<int64_t*>(idx_ptr + offsets[1]);
    CUDA_KERNEL_ASSERT(idx < numel && idx >= -numel && "cuda_take_put_kernel() index out of bounds");
    index_t offset = static_cast<index_t>(idx);
    if (offset < 0) {
      offset += numel;
    }
    if (!is_contiguous) {
      offset = offset_indexed.get(offset)[0];
    }

    f(iterated, offset);
  };
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), loop);
}

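// put: writes the iterated values into `output` at the given linear indices.
// With accumulate=true the writes go through fastSpecializedAtomicAdd so that
// duplicate indices sum; the index type is Int or Long depending on whether
// `output` fits 32-bit index math.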
void put_kernel(TensorIterator& iter, const TensorBase& output, const bool accumulate) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "put_cuda", [&] {
    // Cannot use `OpaqueType`, as we need the actual type for `fastSpecializedAtomicAdd`
    AT_DISPATCH_INDEX_TYPES(cuda::detail::canUse32BitIndexMath(output) ? ScalarType::Int : ScalarType::Long,
        "put_cuda_index", [&] {
           auto* __restrict__ indexed_ptr = output.template data_ptr<scalar_t>();
           if (accumulate) {
             index_t numel = output.numel();
             cuda_take_put_kernel<scalar_t, index_t>(iter, output,
                 [numel, indexed_ptr] __device__(scalar_t& iterated, const index_t offset) {
                   fastSpecializedAtomicAdd(indexed_ptr, offset, numel, iterated);
                 });
           }
           else {
             cuda_take_put_kernel<scalar_t, index_t>(iter, output,
                 [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) {
                   indexed_ptr[offset] = iterated;
                 });
           }
    });
  });
}

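// take: the inverse of put; reads input[index] into the iterated output.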
void take_kernel(
  TensorIterator& iter,
  const TensorBase& input) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "take_cuda", [&] {
    // Cannot use `OpaqueType`, as Tensor::data_ptr<OpaqueType<N>> is not implemented
    AT_DISPATCH_INDEX_TYPES(cuda::detail::canUse32BitIndexMath(input) ? ScalarType::Int : ScalarType::Long,
      "take_cuda_index", [&] {
         const auto* __restrict__ indexed_ptr = input.template const_data_ptr<scalar_t>();
         cuda_take_put_kernel<scalar_t, index_t>(iter, input,
            [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) {
               iterated = indexed_ptr[offset];
             });
     });
  });
}

namespace {

__global__ void masked_scatter_size_check(
  const int64_t* const mask_exclusive_sum,
  const bool* const mask,
  const int64_t srcSize) {
  // Convert exclusive sum to inclusive sum
  const auto totalElements = *mask_exclusive_sum + *mask;
  CUDA_KERNEL_ASSERT(totalElements <= srcSize);
}

} // anonymous namespace

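// masked_scatter: an exclusive prefix sum over the boolean mask gives each
// selected position its offset into `source`; the single-thread kernel above
// then asserts (asynchronously) that the mask does not select more elements
// than `source` provides, before the element-wise copy runs.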
void launch_masked_scatter_kernel(
    const TensorBase &self, const TensorBase &mask,
    const TensorBase &maskPrefixSum, const TensorBase &source) {
  const auto srcSize = source.numel();
  const auto mask_cont = mask.contiguous();
  const auto mask_numel = mask.numel();

  // Use a prefix sum to determine the output locations of the masked elements
  auto maskPrefixSum_data = maskPrefixSum.mutable_data_ptr<int64_t>();
  auto mask_data = mask_cont.const_data_ptr<bool>();

  at::cuda::cub::mask_exclusive_sum(
      mask_data, maskPrefixSum_data, mask_numel);

  // Asynchronously check that the number of `1` elements present in the mask
  // is <= the number of elements available in `src`.
  masked_scatter_size_check<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
      &maskPrefixSum_data[mask_numel - 1], &mask_data[mask_numel - 1], srcSize);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // We are getting elements from `src` based on an offset from
  // `maskPrefixSum`, so that should be made contiguous too
  auto source_contig = source.contiguous();

  auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self)
      .add_input(self)
      .add_const_input(mask_cont)
      .add_input(maskPrefixSum)
      .build();

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool,
      ScalarType::BFloat16,
      ScalarType::Half,
      self.scalar_type(),
      "masked_scatter_",
      [&]() {
        auto source_ptr = source_contig.const_data_ptr<scalar_t>();
        gpu_kernel(
            iter, [=] GPU_LAMBDA(const scalar_t a, const bool mask, const int64_t maskPrefixSum) -> scalar_t {
              if (mask) {
                return source_ptr[maskPrefixSum];
              }
              return a;
            });
        AT_CUDA_CHECK(cudaGetLastError());
      });
}

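// flip: an element-wise copy where the offset calculator is built with signed
// strides, so the input can be traversed backwards along flipped dimensions.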
template <typename scalar_t>
void flip_kernel_impl(TensorIterator& iter) {
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      flip_kernel_impl<scalar_t>(sub_iter);
    }
    return;
  }

  char* const __restrict__ out_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  const char* const __restrict__ in_ptr = reinterpret_cast<const char*>(iter.data_ptr(1));

  const auto offset_calc = make_offset_calculator<2, /*signed_strides=*/true>(iter);

  const auto loop = [=]C10_DEVICE(const int i) {
    const auto offsets = offset_calc.get(i);
    // offsets can be negative here, but it's fine
    scalar_t* const __restrict__ out_data = reinterpret_cast<scalar_t*>(out_ptr + offsets[0]);
    const scalar_t* const __restrict__ in_data = reinterpret_cast<const scalar_t*>(in_ptr + offsets[1]);
    *out_data = *in_data;
  };
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), loop);
}

void flip_kernel(TensorIterator& iter, const bool quantized) {
  if (quantized) {
    AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_cuda",
    [&] {
      using dtype = OpaqueType<sizeof(scalar_t)>;
      flip_kernel_impl<dtype>(iter);
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
                                           iter.dtype(), "flip_cuda",
    [&] {
      using dtype = OpaqueType<sizeof(scalar_t)>;
      flip_kernel_impl<dtype>(iter);
    });
  }
}


REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_fill_stub, &index_fill_kernel);
REGISTER_DISPATCH(index_copy_stub, &index_copy_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
REGISTER_DISPATCH(put_stub, &put_kernel);
REGISTER_DISPATCH(take_stub, &take_kernel);
REGISTER_DISPATCH(flip_stub, &flip_kernel);

REGISTER_CUDA_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_cuda);

} // namespace at::native