#pragma once

namespace at::mps {

static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
#include <metal_atomic>

using namespace metal;

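// Holds the data pointer of one index tensor; buffer(0) carries an array of
// these, one entry per indexing dimension.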
struct IndexAB {
    constant int64_t* indexArray;
};

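// Gathers one element per thread: the per-dimension indices are combined with
// index_strides into a byte offset on the input side, and the selected value
// is copied to the (already strided) output location. Negative indices wrap
// around via index_sizes.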
template<typename T, typename OffsetsT>
kernel void index_select(
    constant IndexAB  * indexAB           [[buffer(0)]],
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
        constant int64_t* indexArray = indexAB[i].indexArray;
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
    constant T * in  = (constant T*)((constant char*)inputData  + offsets[thread_index].y + offset);
    *out = *in;
}

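// Shared body for the index_put kernels below: the same index arithmetic as
// index_select, but the computed offset is applied to the *output* side, so
// each invocation scatters one input element into the indexed output location.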
template<typename T, typename OffsetsT>
void index_put_impl(
    constant IndexAB  * indexAB,
    constant int64_t  * index_sizes,
    constant int64_t  * index_strides,
    constant OffsetsT * offsets,
    constant void     * inputData,
    device   void     * outputData,
    constant uint32_t & num_indices,
    uint thread_index) {
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
        constant int64_t* indexArray = indexAB[i].indexArray;
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];

        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
    constant T * in  = (constant T*)((constant char*)inputData  + offsets[thread_index].y);
    *out = *in;
}

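// Single-threaded variant: one thread walks all *numIters elements in order,
// which keeps the result deterministic when the same output location is
// indexed more than once. thread_index is unused; the loop counter takes its
// place in the per-element offset lookup.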
template<typename T, typename OffsetsT>
kernel void index_put_serial(
    constant IndexAB  * indexAB           [[buffer(0)]],
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    constant uint     * numIters          [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {

    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;

    for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
        index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
    }
}

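// Parallel variant: one thread per element. Writes are unordered, so this is
// only safe when the indices never select the same output location twice.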
template<typename T, typename OffsetsT>
kernel void index_put(
    constant IndexAB  * indexAB           [[buffer(0)]],
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {

    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
}

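// Explicitly instantiates a kernel for one (element size, offset width) pair
// and exports it under a predictable name such as "index_select_32bit_idx32",
// which the host side can look up (e.g. via MTLLibrary's newFunctionWithName:).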
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)   \
template                                                                           \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]               \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                             \
    constant IndexAB   * indexAB           [[buffer(0)]],                          \
    constant void      * indexSizes        [[buffer(1)]],                          \
    constant void      * indexStrides      [[buffer(2)]],                          \
    constant IDX_DTYPE * offsets           [[buffer(3)]],                          \
    constant void      * inputData         [[buffer(4)]],                          \
    device   void      * outputData        [[buffer(5)]],                          \
    constant uint32_t  & num_indices       [[buffer(6)]],                          \
    uint thread_index [[thread_position_in_grid]]);

#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)     \
    REGISTER_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3);     \
    REGISTER_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3);    \
    REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3);     \
    REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3);    \
    REGISTER_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3);     \
    REGISTER_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3);    \
    REGISTER_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3);     \
    REGISTER_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);

REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);

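// Same registration scheme for the serial kernel, which carries the extra
// numIters buffer; the exported names follow the same
// "index_<op>_<size>_<idx>" pattern with INDEX_OP_TYPE = put_serial.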
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE)   \
template                                                                                           \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]]                               \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(                                             \
    constant IndexAB   * indexAB           [[buffer(0)]],                                          \
    constant void      * indexSizes        [[buffer(1)]],                                          \
    constant void      * indexStrides      [[buffer(2)]],                                          \
    constant IDX_DTYPE * offsets           [[buffer(3)]],                                          \
    constant void      * inputData         [[buffer(4)]],                                          \
    device   void      * outputData        [[buffer(5)]],                                          \
    constant uint32_t  & num_indices       [[buffer(6)]],                                          \
    constant uint      * numIters          [[buffer(7)]],                                          \
    uint thread_index [[thread_position_in_grid]]);

#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE)                   \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3);     \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3);    \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3);     \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3);    \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3);     \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3);    \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3);     \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);

REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);

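// Decomposes a linear thread index into per-dimension coordinates (dimension 0
// varying fastest) and accumulates coordinate * stride per dimension into the
// offset triple consumed by the kernels above (.x: output, .y: input,
// .z: index buffer).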
template<typename StridesT, typename DataT>
kernel void kernel_index_offsets(constant StridesT * strides         [[buffer(0)]],
                                 device DataT      * data_offsets    [[buffer(1)]],
                                 constant uint     * iter_shape      [[buffer(2)]],
                                 constant uint     & num_dimensions  [[buffer(3)]],
                                 uint thread_index [[thread_position_in_grid]]) {
    data_offsets[thread_index] = 0;
    uint32_t idx = thread_index;
    for (uint32_t dim = 0; dim < num_dimensions; dim++) {
        uint32_t remainder = idx % iter_shape[dim];
        idx /= iter_shape[dim];

        data_offsets[thread_index] += remainder * DataT(strides[dim]);
    }
}

template
[[host_name("kernel_index_offsets_32")]]
kernel void kernel_index_offsets<packed_uint3, uint3>(
                constant packed_uint3 * strides         [[buffer(0)]],
                device uint3          * data_offsets    [[buffer(1)]],
                constant uint         * iter_shape      [[buffer(2)]],
                constant uint         & num_dimensions  [[buffer(3)]],
                uint thread_index [[thread_position_in_grid]]);

template
[[host_name("kernel_index_offsets_64")]]
kernel void kernel_index_offsets<packed_uint3, ulong3>(
                constant packed_uint3 * strides         [[buffer(0)]],
                device ulong3         * data_offsets    [[buffer(1)]],
                constant uint         * iter_shape      [[buffer(2)]],
                constant uint         & num_dimensions  [[buffer(3)]],
                uint thread_index [[thread_position_in_grid]]);

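// Accumulating index_put for element types with a native Metal atomic
// counterpart (T is the atomic type, E the element type): the gathered offset
// selects the output slot and the addend is applied with a relaxed atomic add,
// so concurrent threads hitting the same slot still sum correctly.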
template<typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
        constant int64_t* indexArray = indexAB[i].indexArray;
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
    constant E * in  = (constant E*)((constant char*)inputData  + offsets[thread_index].y);
    atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}

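// Emulates an atomic float add with a compare-and-swap loop over the value's
// 32-bit pattern: load the current bits, add in the value domain via as_type
// bit-casts, and retry until the CAS succeeds
// (atomic_compare_exchange_weak_explicit refreshes `expected` with the
// observed value on failure).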
template<typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
    device atomic_uint* uintAddr = (device atomic_uint*)addr;
    uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
    T updated = as_type<T>(expected) + value;
    while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
        updated = as_type<T>(expected) + value;
    }
}

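// Accumulating index_put for 32-bit types routed through the CAS loop above
// (currently instantiated for float); same index arithmetic as index_put.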
template<typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
    constant IndexAB  * indexAB           [[buffer(0)]],
    constant void     * indexSizes        [[buffer(1)]],
    constant void     * indexStrides      [[buffer(2)]],
    constant OffsetsT * offsets           [[buffer(3)]],
    constant void     * inputData         [[buffer(4)]],
    device   void     * outputData        [[buffer(5)]],
    constant uint32_t & num_indices       [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
        constant int64_t* indexArray = indexAB[i].indexArray;
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
    constant T  * in  = (constant T*)((constant char*)inputData + offsets[thread_index].y);
    atomic_fetch_add_relaxed<T>(out, *in);
}

template
[[host_name("index_put_accumulate_32bit_float_idx32")]]
kernel void atomic_index_put_accumulate<float, uint3>(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant uint3    * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_float_idx64")]]
kernel void atomic_index_put_accumulate<float, ulong3>(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant ulong3   * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx32")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant uint3    * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx64")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant ulong3   * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
)INDEX_METAL";

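// The scatter/gather sources below are format-string templates rather than
// final Metal: "{0}" is the source element type, "{1}" the destination type,
// and "{2}" a cast expression, with braces that must survive substitution
// doubled ("{{", "}}"). The host fills these in before compiling, e.g. with
// fmt::format-style substitution; a hypothetical instantiation with
// ("float", "half", "static_cast<half>(x)") would yield a float-to-half
// scatter. packed_uint5 is defined locally because Metal's packed vector
// types stop at four components.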
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
 return {2};
}}

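// Each scatter_kernel_N reads the contiguous source linearly, recovers the
// destination element's N-dimensional coordinate by peeling the linear index
// from the innermost dimension outward, and re-linearizes it with the
// destination strides. For rank 5, x is the outermost dimension and u the
// innermost.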
kernel void scatter_kernel_5(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint5 & size   [[buffer(2)]],
                             constant packed_uint5 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint5 local_index;
    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
    local_index.z = linear_index / (size.u * size.w) % size.z;
    local_index.w = linear_index / size.u % size.w;
    local_index.u = linear_index % size.u;

    packed_uint5 strided_index;
    strided_index.x = local_index.x * stride.x;
    strided_index.y = local_index.y * stride.y;
    strided_index.z = local_index.z * stride.z;
    strided_index.w = local_index.w * stride.w;
    strided_index.u = local_index.u * stride.u;

    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_4(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint4 & size   [[buffer(2)]],
                             constant packed_uint4 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint4 local_index;
    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
    local_index.y = linear_index / (size[3] * size[2]) % size[1];
    local_index.z = linear_index / size[3] % size[2];
    local_index.w = linear_index % size[3];

    const packed_uint4 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_3(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint3 & size   [[buffer(2)]],
                             constant packed_uint3 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint3 local_index;
    local_index.x = linear_index / (size[2] * size[1]) % size[0];
    local_index.y = linear_index / size[2] % size[1];
    local_index.z = linear_index % size[2];

    const packed_uint3 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_2(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint2 & size   [[buffer(2)]],
                             constant packed_uint2 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint2 local_index;
    local_index.x = linear_index / size[1] % size[0];
    local_index.y = linear_index % size[1];

    const packed_uint2 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_1(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant int & size            [[buffer(2)]],
                             constant int & stride          [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    const int local_index = linear_index % size;
    const int strided_index = local_index * stride;
    dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";

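// Mirror image of the scatter template: the destination is written linearly
// and the source is read through the strided coordinate, using the same
// "{0}"/"{1}"/"{2}" placeholders.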
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
 return {2};
}}

kernel void gather_kernel_5(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint5 & size    [[buffer(2)]],
                            constant packed_uint5 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint5 local_index;
    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
    local_index.z = linear_index / (size.u * size.w) % size.z;
    local_index.w = linear_index / size.u % size.w;
    local_index.u = linear_index % size.u;

    packed_uint5 strided_index;
    strided_index.x = local_index.x * stride.x;
    strided_index.y = local_index.y * stride.y;
    strided_index.z = local_index.z * stride.z;
    strided_index.w = local_index.w * stride.w;
    strided_index.u = local_index.u * stride.u;

    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
}}

kernel void gather_kernel_4(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint4 & size    [[buffer(2)]],
                            constant packed_uint4 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint4 local_index;
    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
    local_index.y = linear_index / (size[3] * size[2]) % size[1];
    local_index.z = linear_index / size[3] % size[2];
    local_index.w = linear_index % size[3];

    const packed_uint4 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}

kernel void gather_kernel_3(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint3 & size    [[buffer(2)]],
                            constant packed_uint3 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint3 local_index;
    local_index.x = linear_index / (size[2] * size[1]) % size[0];
    local_index.y = linear_index / size[2] % size[1];
    local_index.z = linear_index % size[2];

    const packed_uint3 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}

kernel void gather_kernel_2(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant packed_uint2 & size    [[buffer(2)]],
                            constant packed_uint2 & stride  [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint2 local_index;
    local_index.x = linear_index / size[1] % size[0];
    local_index.y = linear_index % size[1];

    const packed_uint2 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}

kernel void gather_kernel_1(uint linear_index               [[thread_position_in_grid]],
                            constant void * src_            [[buffer(0)]],
                            device void * dst_              [[buffer(1)]],
                            constant int & size             [[buffer(2)]],
                            constant int & stride           [[buffer(3)]],
                            constant uint32_t & numel       [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    const int local_index = linear_index % size;
    const int strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";
} // namespace at::mps