#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/cuda/IndexKernel.h>
#include <ATen/native/IndexKernel.h>

#include <type_traits>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/cub.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/native/quantized/IndexKernel.h>

#include <c10/core/Scalar.h>

namespace at::native {

static constexpr int launch_bound2 = 4;

static constexpr int launch_size_nd = 128;

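// index_elementwise_kernel: each block covers nt * vt elements. Thread t of a
// block handles vt elements starting at nt * vt * blockIdx.x + t and spaced
// nt apart, so consecutive threads touch consecutive elements (coalesced).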
template<int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, launch_bound2)
__global__ void index_elementwise_kernel(const int64_t N, const func_t f) {
  const auto tid = threadIdx.x;
  const auto nv = nt * vt;
  auto idx = nv * blockIdx.x + tid;
  #pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < N) {
      f(idx);
      idx += nt;
    }
  }
}

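// launch_kernel: launches index_elementwise_kernel with nt threads per block
// and one block per nt * vt elements (grid size rounded up). N must fit in
// int32_t; callers split larger iterators into 32-bit-indexable chunks.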
template<int nt, int vt, typename func_t>
static void launch_kernel(const int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  const dim3 block(nt);
  const dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  const auto stream = at::cuda::getCurrentCUDAStream();
  index_elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

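// gpu_index_kernel: common driver for the advanced-indexing kernels. The
// iterator is laid out as (output, input, index_0, ..., index_{n-1}); each
// index tensor holds int64_t entries that may be negative (wrapped into range
// below). For every output element the indices are combined into a single
// byte offset (sum of index[i] * index_stride[i]) and passed to f together
// with the raw output and input pointers, so f decides whether the offset
// applies to the source (gather) or the destination (scatter).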
template <typename func_t>
void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const func_t& f) {
  const auto num_indices = index_size.size();
  AT_ASSERT(num_indices == index_stride.size());
  AT_ASSERT(static_cast<int64_t>(num_indices) == iter.ntensors() - 2);

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_index_kernel(sub_iter, index_size, index_stride, f);
    }
    return;
  }

  auto sizes = at::detail::Array<int64_t, MAX_DIMS>(0);
  auto strides = at::detail::Array<int64_t, MAX_DIMS>(0);
  auto index_ptrs = at::detail::Array<char*, MAX_DIMS>(nullptr);
  for (unsigned i = 0; i < num_indices; i++) {
    sizes[i] = index_size[i];
    strides[i] = index_stride[i];
    index_ptrs[i] = (char*)iter.data_ptr(i + 2);
  }

  char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
  char* const in_ptr = static_cast<char*>(iter.data_ptr(1));

  auto offset_calc = make_offset_calculator<3>(iter);
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), [=]__device__(int idx) {
    const auto offsets = offset_calc.get(idx);
    char* const out_data = out_ptr + offsets[0];
    const char* const in_data = in_ptr + offsets[1];

    int64_t offset = 0;
    #pragma unroll
    for (int i = 0; i < num_indices; i++) {
      int64_t index = *reinterpret_cast<int64_t*>(index_ptrs[i] + offsets[2]);
      CUDA_KERNEL_ASSERT(-sizes[i] <= index && index < sizes[i] && "index out of bounds");
      if (index < 0) {
        index += sizes[i];
      }
      offset += index * strides[i];
    }

    f(out_data, in_data, offset);
  });
}

// The kernels are templated on an opaque, self-aligned type of the correct
// size to avoid redundant kernels for different types of the same size.
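// For example, float and int32_t (both 4 bytes) map to OpaqueType<4>, so a
// plain copy such as index_cuda compiles one kernel per element size rather
// than one per dtype. This only works for kernels that move bytes around and
// never inspect the values.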
template <int N> struct alignas(N) OpaqueType { char data[N]; };

template <typename scalar_t>
void index_fill_kernel_impl(
    TensorIterator& iter,
    const int64_t dim,
    const int64_t self_dim_size,
    const int64_t self_dim_stride,
    const scalar_t fill_val) {
  if (0 == iter.numel()) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      index_fill_kernel_impl(sub_iter, dim, self_dim_size, self_dim_stride, fill_val);
    }
    return;
  }

  char* const __restrict__ self_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* const __restrict__ idx_ptr = reinterpret_cast<char*>(iter.data_ptr(1));

  const auto offset_calc = make_offset_calculator<2>(iter);

  const auto loop = [=]C10_DEVICE(int i) {
    const auto offsets = offset_calc.get(i);

    auto* __restrict__ self_data = reinterpret_cast<scalar_t*>(self_ptr + offsets[0]);
    auto idx = *reinterpret_cast<int64_t*>(idx_ptr + offsets[1]);
    CUDA_KERNEL_ASSERT(idx >= -self_dim_size && idx < self_dim_size && "index out of bounds");
    if (idx < 0) {
      idx += self_dim_size;
    }

    self_data[idx * self_dim_stride] = fill_val;
  };
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), loop);
}

template <typename scalar_t>
void index_copy_kernel_impl(
    TensorIterator& iter,
    const int64_t dim,
    const int64_t self_dim_size,
    const int64_t self_dim_stride) {
  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      index_copy_kernel_impl<scalar_t>(sub_iter, dim, self_dim_size, self_dim_stride);
    }
    return;
  }

  char* const __restrict__ self_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* const __restrict__ idx_ptr = reinterpret_cast<char*>(iter.data_ptr(1));
  char* const __restrict__ source_ptr = reinterpret_cast<char*>(iter.data_ptr(2));

  const auto offset_calc = make_offset_calculator<3>(iter);

  const auto loop = [=]C10_DEVICE(int i) {
    const auto offsets = offset_calc.get(i);

    auto* const __restrict__ self_data = reinterpret_cast<scalar_t*>(self_ptr + offsets[0]);
    auto idx = *reinterpret_cast<int64_t*>(idx_ptr + offsets[1]);
    const auto* const __restrict__ source_data = reinterpret_cast<scalar_t*>(source_ptr + offsets[2]);
    CUDA_KERNEL_ASSERT(idx >= 0 && idx < self_dim_size && "index_copy_(): index out of bounds");

    self_data[idx * self_dim_stride] = *source_data;
  };
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), loop);
}

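// index_kernel_impl implements the gather direction of advanced indexing: the
// combined index offset is applied to the input pointer and the element is
// copied to the output. index_put_kernel_impl below is the scatter direction:
// the same offset is applied to the output pointer instead.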
template <typename scalar_t>
void index_kernel_impl(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
  gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
    *reinterpret_cast<scalar_t*>(out_data) = *reinterpret_cast<const scalar_t*>(in_data + offset);
  });
}

template <typename scalar_t>
void index_put_kernel_impl(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
  gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
    *reinterpret_cast<scalar_t*>(out_data + offset) = *reinterpret_cast<const scalar_t*>(in_data);
  });
}

static void index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, iter.dtype(), "index_cuda", [&] {
    using dtype = OpaqueType<sizeof(scalar_t)>;
    index_kernel_impl<dtype>(iter, index_size, index_stride);
  });
}

static void index_fill_kernel(
    TensorIterator& iter,
    const int64_t dim,
    const int64_t self_dim_size,
    const int64_t self_dim_stride,
    const Scalar& source) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, kComplexHalf,
    iter.dtype(), "index_fill_cuda", [&] {
      using dtype = OpaqueType<sizeof(scalar_t)>;
      const auto fill_val = source.to<scalar_t>();
      const auto fill_val_opaque = *reinterpret_cast<const dtype*>(&fill_val);
      index_fill_kernel_impl<dtype>(iter, dim, self_dim_size, self_dim_stride, fill_val_opaque);
    });
}

static void index_copy_kernel(
    TensorIterator& iter,
    const int64_t dim,
    const int64_t self_dim_size,
    const int64_t self_dim_stride) {
  // See note [Writing Nondeterministic Operations]
  // Nondeterministic when index contains duplicate entries.
  // This kernel will not be called when torch.use_deterministic_algorithms(True) is set.
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, kComplexHalf,
    iter.dtype(), "index_copy_cuda", [&] {
      using dtype = OpaqueType<sizeof(scalar_t)>;
      index_copy_kernel_impl<dtype>(iter, dim, self_dim_size, self_dim_stride);
    });
}


static void index_put_kernel(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const bool accumulate) {
  TORCH_CHECK(!accumulate, "index_put does not support accumulate=true");
  AT_DISPATCH_V2(
    iter.dtype(),
    "index_put",
    AT_WRAP([&] {
      using dtype = OpaqueType<sizeof(scalar_t)>;
      index_put_kernel_impl<dtype>(iter, index_size, index_stride);
    }),
    AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
    AT_EXPAND(AT_FLOAT8_TYPES),
    kComplexHalf,
    kHalf,
    kBool,
    kBFloat16);
}

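// index_put on a quantized tensor: the incoming values are float, so each one
// is quantized on the fly as q = clamp(round(value / scale) + zero_point,
// qmin, qmax) before being scattered into the output.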
void index_put_kernel_quantized_cuda(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const bool accumulate, const double scale, const int zero_point) {
  TORCH_CHECK(!accumulate, "index_put does not support accumulate=true");
  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "index_put", [&] {
    constexpr int64_t qmin = std::numeric_limits<typename scalar_t::underlying>::min();
    constexpr int64_t qmax = std::numeric_limits<typename scalar_t::underlying>::max();
    const float inv_scale = 1.0f / static_cast<float>(scale);

    gpu_index_kernel(iter, index_size, index_stride, [inv_scale, zero_point, qmin, qmax]C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
      int64_t qvalue = static_cast<int64_t>(zero_point + nearbyintf(*(float*)in_data * inv_scale));
      // See https://github.com/pytorch/pytorch/issues/127666
      // and https://github.com/pytorch/pytorch/issues/128253.
      // With hip-clang, std::clamp pulls in __glibcxx_assert_fail, a host
      // function, when building on Fedora 40 / gcc 14. The ternary below
      // replaces std::clamp(qvalue, qmin, qmax) and is a viable solution for
      // both CUDA and ROCm, since it generates the same PTX as std::clamp
      // (see https://godbolt.org/z/Wde9KW3v4). Differentiating with
      // #ifdef USE_ROCM instead caused Windows build failures.
      qvalue = (qvalue < qmin) ? qmin : (qmax < qvalue) ? qmax : qvalue;
      *(scalar_t*)(out_data + offset) = static_cast<scalar_t>(qvalue);
    });
  });
}

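// cuda_take_put_kernel: shared implementation for take (gather by flat index)
// and put (scatter by flat index). The iterator supplies the iterated element
// and an int64_t flat index into `indexed`; negative indices wrap around, and
// for non-contiguous `indexed` the flat index is translated into a real
// memory offset with an OffsetCalculator built from its (reversed) sizes and
// strides. The functor f receives a reference to the iterated element plus
// that offset.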
template <typename scalar_t, typename index_t, typename func_t>
void cuda_take_put_kernel(
    TensorIterator& iter,
    const TensorBase& indexed,
    const func_t& f) {
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      cuda_take_put_kernel<scalar_t, index_t>(sub_iter, indexed, f);
    }
    return;
  }

  const auto numel = indexed.numel();
  const bool is_contiguous = indexed.is_contiguous();

  char* const __restrict__ iterated_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  char* const __restrict__ idx_ptr = reinterpret_cast<char*>(iter.data_ptr(1));

  const auto offset_calc = make_offset_calculator<2>(iter);
  using uindex_t = std::make_unsigned_t<index_t>;

  // OffsetCalculator needs the sizes and strides reversed
  const auto indexed_sizes = std::vector<int64_t>(indexed.sizes().rbegin(), indexed.sizes().rend());
  const auto indexed_strides = std::vector<int64_t>(indexed.strides().rbegin(), indexed.strides().rend());
  const auto* indexed_strides_data = indexed_strides.data();
  const auto offset_indexed = OffsetCalculator<1, uindex_t>(indexed.dim(),
                                                            indexed_sizes.data(),
                                                            &indexed_strides_data);

  const auto loop = [=]C10_DEVICE(int i) {
    const auto offsets = offset_calc.get(i);

    auto& iterated = *reinterpret_cast<scalar_t*>(iterated_ptr + offsets[0]);
    const auto idx = *reinterpret_cast<int64_t*>(idx_ptr + offsets[1]);
    CUDA_KERNEL_ASSERT(idx < numel && idx >= -numel && "cuda_take_put_kernel() index out of bounds");
    index_t offset = static_cast<index_t>(idx);
    if (offset < 0) {
      offset += numel;
    }
    if (!is_contiguous) {
      offset = offset_indexed.get(offset)[0];
    }

    f(iterated, offset);
  };
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), loop);
}

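// put_kernel: writes iterated values into `output` at the given flat indices.
// With accumulate=true collisions are resolved with atomic adds
// (fastSpecializedAtomicAdd); otherwise the write is a plain store. The index
// type is Int when 32-bit indexing suffices and Long otherwise.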
void put_kernel(TensorIterator& iter, const TensorBase& output, const bool accumulate) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "put_cuda", [&] {
    // Cannot use `OpaqueType`, as we need the actual type for `fastSpecializedAtomicAdd`
    AT_DISPATCH_INDEX_TYPES(cuda::detail::canUse32BitIndexMath(output) ? ScalarType::Int : ScalarType::Long,
      "put_cuda_index", [&] {
        auto* __restrict__ indexed_ptr = output.template data_ptr<scalar_t>();
        if (accumulate) {
          index_t numel = output.numel();
          cuda_take_put_kernel<scalar_t, index_t>(iter, output,
              [numel, indexed_ptr] __device__(scalar_t& iterated, const index_t offset) {
                fastSpecializedAtomicAdd(indexed_ptr, offset, numel, iterated);
              });
        }
        else {
          cuda_take_put_kernel<scalar_t, index_t>(iter, output,
              [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) {
                indexed_ptr[offset] = iterated;
              });
        }
    });
  });
}

void take_kernel(
    TensorIterator& iter,
    const TensorBase& input) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "take_cuda", [&] {
    // Cannot use `OpaqueType`, as Tensor::data_ptr<OpaqueType<N>> is not implemented
    AT_DISPATCH_INDEX_TYPES(cuda::detail::canUse32BitIndexMath(input) ? ScalarType::Int : ScalarType::Long,
      "take_cuda_index", [&] {
        const auto* __restrict__ indexed_ptr = input.template const_data_ptr<scalar_t>();
        cuda_take_put_kernel<scalar_t, index_t>(iter, input,
            [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) {
              iterated = indexed_ptr[offset];
            });
    });
  });
}

namespace {

__global__ void masked_scatter_size_check(
    const int64_t* const mask_exclusive_sum,
    const bool* const mask,
    const int64_t srcSize) {
  // Convert exclusive sum to inclusive sum
  const auto totalElements = *mask_exclusive_sum + *mask;
  CUDA_KERNEL_ASSERT(totalElements <= srcSize);
}

} // anonymous namespace

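// launch_masked_scatter_kernel: masked_scatter_ copies source elements, in
// order, into the positions of `self` where `mask` is true. An exclusive
// prefix sum over the (contiguous) mask gives, for each true position, its
// index into `source`; a single-thread kernel then asynchronously checks that
// the mask does not select more elements than `source` provides, and the main
// gpu_kernel gathers source[maskPrefixSum] wherever the mask is set.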
void launch_masked_scatter_kernel(
    const TensorBase &self, const TensorBase &mask,
    const TensorBase &maskPrefixSum, const TensorBase &source) {
  const auto srcSize = source.numel();
  const auto mask_cont = mask.contiguous();
  const auto mask_numel = mask.numel();

  // Use a prefix sum to determine the output locations of the masked elements
  auto maskPrefixSum_data = maskPrefixSum.mutable_data_ptr<int64_t>();
  auto mask_data = mask_cont.const_data_ptr<bool>();

  at::cuda::cub::mask_exclusive_sum(
      mask_data, maskPrefixSum_data, mask_numel);

  // Asynchronously check that the number of `1` elements present in the mask
  // is at most the number of elements available in `source`.
  masked_scatter_size_check<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
      &maskPrefixSum_data[mask_numel - 1], &mask_data[mask_numel - 1], srcSize);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // We are getting elements from `source` based on an offset from
  // `maskPrefixSum`, so `source` should be made contiguous too
  auto source_contig = source.contiguous();

  auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self)
      .add_input(self)
      .add_const_input(mask_cont)
      .add_input(maskPrefixSum)
      .build();

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool,
      ScalarType::BFloat16,
      ScalarType::Half,
      self.scalar_type(),
      "masked_scatter_",
      [&]() {
        auto source_ptr = source_contig.const_data_ptr<scalar_t>();
        gpu_kernel(
            iter, [=] GPU_LAMBDA(const scalar_t a, const bool mask, const int64_t maskPrefixSum) -> scalar_t {
              if (mask) {
                return source_ptr[maskPrefixSum];
              }
              return a;
            });
        AT_CUDA_CHECK(cudaGetLastError());
      });
}

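// flip_kernel_impl: copies input to output element by element. The offset
// calculator is built with signed strides, so the flipped dimensions of the
// input are traversed with negative byte offsets (walking those dimensions
// backwards) while the output is written in its normal order.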
template <typename scalar_t>
void flip_kernel_impl(TensorIterator& iter) {
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      flip_kernel_impl<scalar_t>(sub_iter);
    }
    return;
  }

  char* const __restrict__ out_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  const char* const __restrict__ in_ptr = reinterpret_cast<const char*>(iter.data_ptr(1));

  const auto offset_calc = make_offset_calculator<2, /*signed_strides=*/true>(iter);

  const auto loop = [=]C10_DEVICE(const int i) {
    const auto offsets = offset_calc.get(i);
    // offsets can be negative here, but it's fine
    scalar_t* const __restrict__ out_data = reinterpret_cast<scalar_t*>(out_ptr + offsets[0]);
    const scalar_t* const __restrict__ in_data = reinterpret_cast<const scalar_t*>(in_ptr + offsets[1]);
    *out_data = *in_data;
  };
  launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), loop);
}

void flip_kernel(TensorIterator& iter, const bool quantized) {
  if (quantized) {
    AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_cuda",
      [&] {
        using dtype = OpaqueType<sizeof(scalar_t)>;
        flip_kernel_impl<dtype>(iter);
      });
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(), "flip_cuda",
      [&] {
        using dtype = OpaqueType<sizeof(scalar_t)>;
        flip_kernel_impl<dtype>(iter);
      });
  }
}


REGISTER_DISPATCH(index_stub, &index_kernel);
REGISTER_DISPATCH(index_fill_stub, &index_fill_kernel);
REGISTER_DISPATCH(index_copy_stub, &index_copy_kernel);
REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
REGISTER_DISPATCH(put_stub, &put_kernel);
REGISTER_DISPATCH(take_stub, &take_kernel);
REGISTER_DISPATCH(flip_stub, &flip_kernel);

REGISTER_CUDA_DISPATCH(index_put_kernel_quantized_stub, &index_put_kernel_quantized_cuda);

} // namespace at::native