#pragma once

namespace at::mps {

static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
#include <metal_atomic>

using namespace metal;

struct IndexAB {
    constant int64_t* indexArray;
};

template<typename T, typename OffsetsT>
kernel void index_select(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
        constant int64_t* indexArray = indexAB[i].indexArray;
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
    constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y + offset);
    *out = *in;
}

template<typename T, typename OffsetsT>
void index_put_impl(
    constant IndexAB  * indexAB,
    constant int64_t  * index_sizes,
    constant int64_t  * index_strides,
    constant OffsetsT * offsets,
    constant void     * inputData,
    device   void     * outputData,
    constant uint32_t & num_indices,
    uint thread_index) {
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
        constant int64_t* indexArray = indexAB[i].indexArray;
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];

        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
    constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
    *out = *in;
}

template<typename T, typename OffsetsT>
kernel void index_put_serial(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    constant uint     * numIters     [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {

    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;

    for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
        index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, iter_i);
    }
}

template<typename T, typename OffsetsT>
kernel void index_put(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {

    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
}
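// The REGISTER_* macros below explicitly instantiate the index_select / index_put kernel
// templates above for every supported element size and offset width, and expose each
// instantiation under a host-visible name of the form "index_<op>_<element size>_<idx32|idx64>".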
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB   * indexAB      [[buffer(0)]], \
    constant void      * indexSizes   [[buffer(1)]], \
    constant void      * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets      [[buffer(3)]], \
    constant void      * inputData    [[buffer(4)]], \
    device   void      * outputData   [[buffer(5)]], \
    constant uint32_t  & num_indices  [[buffer(6)]], \
    uint thread_index [[thread_position_in_grid]]);

#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
    REGISTER_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3); \
    REGISTER_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3); \
    REGISTER_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);

REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);

#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
    constant IndexAB   * indexAB      [[buffer(0)]], \
    constant void      * indexSizes   [[buffer(1)]], \
    constant void      * indexStrides [[buffer(2)]], \
    constant IDX_DTYPE * offsets      [[buffer(3)]], \
    constant void      * inputData    [[buffer(4)]], \
    device   void      * outputData   [[buffer(5)]], \
    constant uint32_t  & num_indices  [[buffer(6)]], \
    constant uint      * numIters     [[buffer(7)]], \
    uint thread_index [[thread_position_in_grid]]);

#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx32, char,  INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(8bit,  idx64, char,  INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int,   INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int,   INDEX_OP_TYPE, ulong3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long,  INDEX_OP_TYPE, uint3); \
    REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long,  INDEX_OP_TYPE, ulong3);

REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
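// kernel_index_offsets decomposes a linear thread index into per-dimension coordinates
// using the iterator shape, and accumulates coordinate * stride for each dimension into
// a per-thread offset vector that the indexing kernels above consume.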
template<typename StridesT, typename DataT>
kernel void kernel_index_offsets(constant StridesT * strides        [[buffer(0)]],
                                 device   DataT    * data_offsets   [[buffer(1)]],
                                 constant uint     * iter_shape     [[buffer(2)]],
                                 constant uint     & num_dimensions [[buffer(3)]],
                                 uint thread_index [[thread_position_in_grid]]) {
    data_offsets[thread_index] = 0;
    uint32_t idx = thread_index;
    for (uint32_t dim = 0; dim < num_dimensions; dim++) {
        uint32_t remainder = idx % iter_shape[dim];
        idx /= iter_shape[dim];

        data_offsets[thread_index] += remainder * DataT(strides[dim]);
    }
}

template
[[host_name("kernel_index_offsets_32")]]
kernel void kernel_index_offsets<packed_uint3, uint3>(
    constant packed_uint3 * strides        [[buffer(0)]],
    device   uint3        * data_offsets   [[buffer(1)]],
    constant uint         * iter_shape     [[buffer(2)]],
    constant uint         & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("kernel_index_offsets_64")]]
kernel void kernel_index_offsets<packed_uint3, ulong3>(
    constant packed_uint3 * strides        [[buffer(0)]],
    device   ulong3       * data_offsets   [[buffer(1)]],
    constant uint         * iter_shape     [[buffer(2)]],
    constant uint         & num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);
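// Two accumulation paths follow: index_put_accumulate_native_dtypes relies on Metal's
// native atomic types (e.g. atomic_int), while atomic_index_put_accumulate emulates an
// atomic add for types without a native atomic (float) via a compare-and-exchange loop.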
template<typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
        constant int64_t* indexArray = indexAB[i].indexArray;
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device T * out = (device T*)((device char*)outputData + offsets[thread_index].x + offset);
    constant E * in = (constant E*)((constant char*)inputData + offsets[thread_index].y);
    atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}

template<typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * addr, T value) {
    device atomic_uint* uintAddr = (device atomic_uint*)addr;
    uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
    T updated = as_type<T>(expected) + value;
    while (!atomic_compare_exchange_weak_explicit(uintAddr, &expected, as_type<uint>(updated), memory_order_relaxed, memory_order_relaxed)) {
        updated = as_type<T>(expected) + value;
    }
}

template<typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant OffsetsT * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
    constant int64_t * index_sizes   = (constant int64_t *)indexSizes;
    constant int64_t * index_strides = (constant int64_t *)indexStrides;
    int64_t offset = 0;
    for (uint32_t i = 0; i < num_indices; i++) {
        constant int64_t* indexArray = indexAB[i].indexArray;
        int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
        if (index < 0) {
            index += index_sizes[i];
        }
        offset += index * index_strides[i];
    }
    device void * out = (device void*)((device char*)outputData + offsets[thread_index].x + offset);
    constant T * in = (constant T*)((constant char*)inputData + offsets[thread_index].y);
    atomic_fetch_add_relaxed<T>(out, *in);
}
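// Host-visible instantiations: float accumulation goes through the CAS-based kernel,
// int accumulation through the native atomic_int kernel.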
template
[[host_name("index_put_accumulate_32bit_float_idx32")]]
kernel void atomic_index_put_accumulate<float, uint3>(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant uint3    * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_float_idx64")]]
kernel void atomic_index_put_accumulate<float, ulong3>(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant ulong3   * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx32")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant uint3    * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_put_accumulate_32bit_int_idx64")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
    constant IndexAB  * indexAB      [[buffer(0)]],
    constant void     * indexSizes   [[buffer(1)]],
    constant void     * indexStrides [[buffer(2)]],
    constant ulong3   * offsets      [[buffer(3)]],
    constant void     * inputData    [[buffer(4)]],
    device   void     * outputData   [[buffer(5)]],
    constant uint32_t & num_indices  [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
)INDEX_METAL";
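// The scatter/gather sources below are format templates rather than finished Metal code:
// {0} and {1} are filled in with the source and destination scalar types and {2} with the
// casting expression when the shader source is generated; doubled braces ("{{", "}}")
// stand for literal braces in the formatted output.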
static const char *SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
struct __attribute__ ((packed)) packed_uint5{{
    uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
    return {2};
}}

kernel void scatter_kernel_5(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint5 & size   [[buffer(2)]],
                             constant packed_uint5 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint5 local_index;
    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
    local_index.z = linear_index / (size.u * size.w) % size.z;
    local_index.w = linear_index / size.u % size.w;
    local_index.u = linear_index % size.u;

    packed_uint5 strided_index;
    strided_index.x = local_index.x * stride.x;
    strided_index.y = local_index.y * stride.y;
    strided_index.z = local_index.z * stride.z;
    strided_index.w = local_index.w * stride.w;
    strided_index.u = local_index.u * stride.u;

    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_4(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint4 & size   [[buffer(2)]],
                             constant packed_uint4 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint4 local_index;
    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
    local_index.y = linear_index / (size[3] * size[2]) % size[1];
    local_index.z = linear_index / size[3] % size[2];
    local_index.w = linear_index % size[3];

    const packed_uint4 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_3(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint3 & size   [[buffer(2)]],
                             constant packed_uint3 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint3 local_index;
    local_index.x = linear_index / (size[2] * size[1]) % size[0];
    local_index.y = linear_index / size[2] % size[1];
    local_index.z = linear_index % size[2];

    const packed_uint3 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_2(uint linear_index              [[thread_position_in_grid]],
                             constant void * src_           [[buffer(0)]],
                             device void * dst_             [[buffer(1)]],
                             constant packed_uint2 & size   [[buffer(2)]],
                             constant packed_uint2 & stride [[buffer(3)]],
                             constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint2 local_index;
    local_index.x = linear_index / size[1] % size[0];
    local_index.y = linear_index % size[1];

    const packed_uint2 strided_index = local_index * stride;
    dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}

kernel void scatter_kernel_1(uint linear_index         [[thread_position_in_grid]],
                             constant void * src_      [[buffer(0)]],
                             device void * dst_        [[buffer(1)]],
                             constant int & size       [[buffer(2)]],
                             constant int & stride     [[buffer(3)]],
                             constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    const int local_index = linear_index % size;
    const int strided_index = local_index * stride;
    dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";
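// The gather kernels mirror the scatter kernels above: they read from the strided source
// location and write densely to the linear destination index.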
static const char *GATHER_OPS_TEMPLATE = R"METAL_GATHER(
struct __attribute__ ((packed)) packed_uint5{{
    uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
}};

template<typename Y, typename X>
Y cast(const X x);

template<>
{1} cast<{1}, {0}>(const {0} x) {{
    return {2};
}}

kernel void gather_kernel_5(uint linear_index              [[thread_position_in_grid]],
                            constant void * src_           [[buffer(0)]],
                            device void * dst_             [[buffer(1)]],
                            constant packed_uint5 & size   [[buffer(2)]],
                            constant packed_uint5 & stride [[buffer(3)]],
                            constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint5 local_index;
    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
    local_index.z = linear_index / (size.u * size.w) % size.z;
    local_index.w = linear_index / size.u % size.w;
    local_index.u = linear_index % size.u;

    packed_uint5 strided_index;
    strided_index.x = local_index.x * stride.x;
    strided_index.y = local_index.y * stride.y;
    strided_index.z = local_index.z * stride.z;
    strided_index.w = local_index.w * stride.w;
    strided_index.u = local_index.u * stride.u;

    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
}}

kernel void gather_kernel_4(uint linear_index              [[thread_position_in_grid]],
                            constant void * src_           [[buffer(0)]],
                            device void * dst_             [[buffer(1)]],
                            constant packed_uint4 & size   [[buffer(2)]],
                            constant packed_uint4 & stride [[buffer(3)]],
                            constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint4 local_index;
    local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
    local_index.y = linear_index / (size[3] * size[2]) % size[1];
    local_index.z = linear_index / size[3] % size[2];
    local_index.w = linear_index % size[3];

    const packed_uint4 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}

kernel void gather_kernel_3(uint linear_index              [[thread_position_in_grid]],
                            constant void * src_           [[buffer(0)]],
                            device void * dst_             [[buffer(1)]],
                            constant packed_uint3 & size   [[buffer(2)]],
                            constant packed_uint3 & stride [[buffer(3)]],
                            constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint3 local_index;
    local_index.x = linear_index / (size[2] * size[1]) % size[0];
    local_index.y = linear_index / size[2] % size[1];
    local_index.z = linear_index % size[2];

    const packed_uint3 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}

kernel void gather_kernel_2(uint linear_index              [[thread_position_in_grid]],
                            constant void * src_           [[buffer(0)]],
                            device void * dst_             [[buffer(1)]],
                            constant packed_uint2 & size   [[buffer(2)]],
                            constant packed_uint2 & stride [[buffer(3)]],
                            constant uint32_t & numel      [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    packed_uint2 local_index;
    local_index.x = linear_index / size[1] % size[0];
    local_index.y = linear_index % size[1];

    const packed_uint2 strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}

kernel void gather_kernel_1(uint linear_index         [[thread_position_in_grid]],
                            constant void * src_      [[buffer(0)]],
                            device void * dst_        [[buffer(1)]],
                            constant int & size       [[buffer(2)]],
                            constant int & stride     [[buffer(3)]],
                            constant uint32_t & numel [[buffer(4)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    const int local_index = linear_index % size;
    const int strided_index = local_index * stride;
    dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";
} // namespace at::mps