xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/thread_constants.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <c10/macros/Macros.h>

// Marks a lambda as executable on both the host and device. The __host__
// attribute is important so that we can access static type information from
// the host, even if the function is typically only executed on the device.
#ifndef GPU_LAMBDA
#define GPU_LAMBDA __host__ __device__
#endif
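
// Illustrative sketch (not part of the original header): because GPU_LAMBDA
// expands to __host__ __device__, a lambda's static type information can be
// queried from host code even if its body only ever runs on the device:
//
//   auto square = [] GPU_LAMBDA (float x) -> float { return x * x; };
//   static_assert(std::is_same<decltype(square(0.f)), float>::value,
//                 "the return type is visible on the host");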

#if defined(USE_ROCM)
constexpr int num_threads() {
  return 256;
}
#else
constexpr uint32_t num_threads() {
  return C10_WARP_SIZE * 4;
}
#endif
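
// On CUDA builds C10_WARP_SIZE is 32 (see c10/macros/Macros.h), so
// num_threads() evaluates to 128. The ROCm branch pins the block size to
// 256 threads instead.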

constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
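
// Illustrative sketch (hypothetical launcher, not part of this header): with
// the CUDA values above, one block covers 4 * 128 = 512 elements, so a
// launcher built on these constants would size its grid as:
//
//   int64_t grid = (numel + block_work_size() - 1) / block_work_size();
//   elementwise_kernel<<<grid, num_threads()>>>(numel, f);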