1 #pragma once
2
3 // This file provides two functions to help write GPU elementwise kernels:
4 //
5 // gpu_kernel(TensorIterator iter, <lambda>)
6 // gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
7 //
// gpu_kernel_with_scalars generates specializations that support a
// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
// is lifted to a kernel parameter instead of being copied to device memory.
11 // This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
12 // which is the default for TensorIterator::binary_op. Otherwise, all inputs
13 // and the output must be on the GPU.
14 //
15 // For example, to write a reciprocal kernel for GPU float Tensors:
16 //
17 // gpu_kernel(iter, []GPU_LAMBDA(float a) {
18 // return 1.0f / a;
19 // });
20 //
21 // To write a multiplication kernel for GPU float Tensors where one argument
22 // may be a CPU scalar:
23 //
24 // gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
25 // return a * b;
26 // });
27 //
// See BinaryOpsKernel.cu for complete implementations of kernels like these.
29 //
30
31 #include <iostream>
32 #include <tuple>
33 #include <type_traits>
34
35 #include <ATen/core/Array.h>
36 #include <ATen/cuda/CUDAContext.h>
37 #include <ATen/detail/FunctionTraits.h>
38 #include <ATen/native/TensorIterator.h>
39 #include <c10/core/DynamicCast.h>
40 #include <c10/core/ScalarType.h>
41 #include <c10/macros/Macros.h>
42 #include <c10/util/TypeCast.h>
43
#if defined(__NVCC__)
// Under nvcc, statically check that `type` is the closure type of an extended
// __host__ __device__ lambda, so a host-only lambda is rejected at compile time
// with a readable message.
#define ASSERT_HOST_DEVICE_LAMBDA(type)                       \
  static_assert(                                              \
      __nv_is_extended_host_device_lambda_closure_type(type), \
      #type " must be a __host__ __device__ lambda")
#else
// Without nvcc the intrinsic is unavailable, so the check expands to nothing.
#define ASSERT_HOST_DEVICE_LAMBDA(type)
#endif
52
53 namespace at {
54 namespace native {
55
// Elementwise kernel for contiguous inputs/outputs without dtype casting.
// Block b starts at element block_work_size() * b. A block that has a full
// block_work_size() worth of elements uses vectorized memory access of width
// `vec_size`. The final, partial block (the remainder) falls back to the
// bounds-checked unrolled policy with trivial offsets.
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  int remaining = N - block_work_size() * blockIdx.x;

  // Full tile: every element this block touches is in range, so vectorized
  // loads/stores can be used directly.
  if (remaining >= block_work_size()) {
    elementwise_kernel_helper(
        f, memory::policies::vectorized<vec_size, array_t>(data));
    return;
  }

  // Remainder tile: fewer than block_work_size() elements are left, so use a
  // naive unrolled loop bounded by `remaining`.
  using traits = function_traits<func_t>;
  auto in_calc = TrivialOffsetCalculator<traits::arity>();
  auto out_calc = TrivialOffsetCalculator<1>();
  auto load_op = memory::LoadWithoutCast();
  auto store_op = memory::StoreWithoutCast();
  auto tail_policy = memory::policies::unroll<
      array_t,
      decltype(in_calc),
      decltype(out_calc),
      memory::LoadWithoutCast,
      memory::StoreWithoutCast>(
      data, remaining, in_calc, out_calc, load_op, store_op);
  elementwise_kernel_helper(f, tail_policy);
}
82
// Generic elementwise kernel driven by the unrolled memory policy. Block b
// handles the elements starting at block_work_size() * b. The policy receives
// the count of elements still left from that point so it can bound the last,
// partial tile. The offset calculators (`ic`, `oc`) and loader/storer (`l`, `s`)
// are forwarded unchanged. Callers in this file pass both the non-casting
// and the LoadWithCast/StoreWithCast variants.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  using policy_t = memory::policies::
      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>;
  int elems_left = N - block_work_size() * blockIdx.x;
  elementwise_kernel_helper(f, policy_t(data, elems_left, ic, oc, l, s));
}
105
// Launches an elementwise kernel over N contiguous elements described by
// `data`. This function assumes a trivial 1-d layout and no dynamic casting.
// The width reported by memory::can_vectorize_up_to<func_t>(data) picks the
// kernel:
//   4 or 2 -> vectorized_elementwise_kernel<4> / <2>
//   1      -> unrolled_elementwise_kernel with trivial offset calculators and
//             no-cast loader/storer
// Any other width is an internal error.
//
// N must fit in int32 because the kernels index with `int`. N == 0 is a no-op,
// matching launch_legacy_kernel: a grid of zero blocks is not a valid CUDA
// launch configuration, so an empty range must not reach the launch.
template <typename func_t, typename array_t>
static inline void launch_vectorized_kernel(
    int64_t N,
    const func_t& f,
    array_t data) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  using traits = function_traits<func_t>;
  // One block per block_work_size() elements (ceil-div).
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  int vec_size = memory::can_vectorize_up_to<func_t>(data);

  switch (vec_size) {
    case 4:
      vectorized_elementwise_kernel<4, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 2:
      vectorized_elementwise_kernel<2, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 1: {
      // No vectorization possible: plain unrolled loads/stores.
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      unrolled_elementwise_kernel<func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(
              N, f, data, input_calc, output_calc, loader, storer);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
}
144
// Launches unrolled_elementwise_kernel over N elements on the current CUDA
// stream, with one block per block_work_size() elements. The offset calculators
// (`ic`, `oc`) and loader/storer (`l`, `s`) are forwarded unchanged, so the
// same launcher serves trivial and dtype-casting (LoadWithCast/StoreWithCast)
// configurations.
//
// N must fit in int32 because the kernel indexes with `int`. N == 0 is a no-op,
// matching launch_legacy_kernel: a grid of zero blocks is not a valid CUDA
// launch configuration, so an empty range must not reach the launch.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
static inline void launch_unrolled_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  // One block per block_work_size() elements (ceil-div).
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  unrolled_elementwise_kernel<func_t, array_t>
      <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
167
// Legacy elementwise kernel. Each block of `nt` threads covers nt * vt
// consecutive indices. Thread t of block b visits b * nt * vt + t + k * nt for
// k = 0 .. vt-1 and calls f on every such index that is below N.
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel(int N, func_t f) {
  constexpr int items_per_block = nt * vt;
  const int first = items_per_block * blockIdx.x + threadIdx.x;
#pragma unroll
  for (int k = 0; k < vt; ++k) {
    const int linear_idx = first + k * nt;
    if (linear_idx < N) {
      f(linear_idx);
    }
  }
}
182
// Launches elementwise_kernel<nt, vt> on the current CUDA stream with
// ceil(N / (nt * vt)) blocks of `nt` threads, so each thread handles up to
// `vt` indices. N must fit in int32 because the kernel indexes with `int`.
// An empty range launches nothing.
template <int nt, int vt, typename func_t>
static void launch_legacy_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  constexpr int64_t items_per_block = static_cast<int64_t>(nt) * vt;
  const int64_t num_blocks = (N + items_per_block - 1) / items_per_block;
  auto stream = at::cuda::getCurrentCUDAStream();
  elementwise_kernel<nt, vt, func_t>
      <<<num_blocks, nt, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
195
// Calls `f` with one argument per input. Argument k is read with c10::load as
// traits::arg<k>::type from the byte address data[k] + i * strides[k].
template <typename traits, typename func_t, typename index_t, size_t... Is>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i,
    std::index_sequence<Is...>) {
  // For zero-arity functors the pack is empty and these would be unused.
  static_cast<void>(strides);
  static_cast<void>(i);
  return f(c10::load<typename traits::template arg<Is>::type>(
      data[Is] + i * strides[Is])...);
}
208
// Entry point for the non-casting path: builds an index sequence covering
// the functor's arity and forwards to the matching invoke_impl overload.
template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i) {
  return invoke_impl<traits>(
      f, data, strides, i, std::make_index_sequence<traits::arity>{});
}
221
// Calls `f` with one argument per input. Argument k is produced by
// c10::fetch_and_cast<traits::arg<k>::type> from runtime dtype dtypes[k] at the
// byte address data[k] + i * strides[k].
template <typename traits, typename func_t, typename index_t, size_t... Is>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i,
    std::index_sequence<Is...>) {
  // For zero-arity functors the pack is empty and these would be unused.
  static_cast<void>(strides);
  static_cast<void>(i);
  return f(c10::fetch_and_cast<typename traits::template arg<Is>::type>(
      dtypes[Is], data[Is] + i * strides[Is])...);
}
235
// Entry point for the dtype-casting path: builds an index sequence covering
// the functor's arity and forwards to the invoke_impl overload that takes
// per-input runtime dtypes.
template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i) {
  return invoke_impl<traits>(
      f, data, strides, dtypes, i, std::make_index_sequence<traits::arity>{});
}
249
// Runs `f` elementwise over `iter` when no dynamic dtype casting is needed.
// Asserts 32-bit indexability, exactly one output, traits::arity inputs, and
// the absence of dynamic casting. Contiguous iterators go through
// launch_vectorized_kernel. Otherwise each linear index is mapped to per-tensor
// byte offsets by an offset calculator inside the legacy kernel.
template <typename func_t>
void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  // Slot 0 holds the output; slots 1..arity hold the inputs.
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  at::detail::Array<char*, ntensors> data;
  for (int arg = 0; arg < ntensors; ++arg) {
    data[arg] = (char*)iter.data_ptr(arg);
  }

  const int64_t numel = iter.numel();

  if (iter.is_contiguous()) {
    launch_vectorized_kernel(numel, f, data);
    return;
  }

  // Strided layout: translate each linear index into per-tensor byte offsets.
  auto offset_calc = ::make_offset_calculator<ntensors>(iter);
  // Fewer elements per thread for wider (>= 4-byte) output types.
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
  launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
    auto offsets = offset_calc.get(idx);
    arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
    // Offsets are passed as "strides" with i == 1, so invoke reads input k
    // at data[k] + offsets[k].
    *out = invoke(f, &data.data[1], &offsets.data[1], 1);
  });
}
281
// Runs `f` elementwise over `iter`, going through dtype-casting loads/stores
// (c10::fetch_and_cast / c10::cast_and_store, or LoadWithCast / StoreWithCast)
// when needs_dynamic_casting reports that casting is required. Otherwise it
// defers to gpu_kernel_impl_nocast.
//   * strided iterators: legacy kernel (128 threads, 4 items per thread) with
//     an offset calculator and per-tensor runtime dtypes;
//   * contiguous, ROCm: legacy kernel (512 threads, 1 item per thread) using
//     the iterator's inner strides;
//   * contiguous, otherwise: unrolled kernel with trivial offsets and casting
//     loader/storer.
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  const bool needs_cast = needs_dynamic_casting<func_t>::check(iter);
  if (!needs_cast) {
    gpu_kernel_impl_nocast(iter, f);
    return;
  }

  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  // Slot 0 holds the output; slots 1..arity hold the inputs.
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  at::detail::Array<char*, ntensors> data;
  for (int arg = 0; arg < ntensors; ++arg) {
    data[arg] = (char*)iter.data_ptr(arg);
  }

  const int64_t numel = iter.numel();

  if (!iter.is_contiguous()) {
    // Strided layout: per-tensor dtypes plus an offset calculator.
    at::detail::Array<ScalarType, ntensors> dtypes;
    for (int arg = 0; arg < ntensors; ++arg) {
      dtypes[arg] = iter.dtype(arg);
    }
    auto offset_calc = ::make_offset_calculator<ntensors>(iter);
    launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
      auto offsets = offset_calc.get(idx);
      void* out = data[0] + offsets[0];
      // Offsets act as "strides" with i == 1.
      arg0_t result =
          invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
      c10::cast_and_store<arg0_t>(dtypes[0], out, result);
    });
    return;
  }

#ifdef USE_ROCM
  at::detail::Array<ScalarType, ntensors> dtypes;
  auto inner_strides = iter.get_inner_strides();
  at::detail::Array<int, ntensors> strides;
  for (int arg = 0; arg < ntensors; ++arg) {
    dtypes[arg] = iter.dtype(arg);
    strides[arg] = inner_strides[arg];
  }
  launch_legacy_kernel<512, 1>(numel, [=] GPU_LAMBDA(int idx) {
    void* out = data[0] + strides[0] * idx;
    arg0_t result =
        invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
    c10::cast_and_store<arg0_t>(dtypes[0], out, result);
  });
#else
  auto cast_loader = memory::LoadWithCast<traits::arity>(iter);
  auto cast_storer = memory::StoreWithCast<1>(iter);
  auto in_calc = TrivialOffsetCalculator<traits::arity>();
  auto out_calc = TrivialOffsetCalculator<1>();
  launch_unrolled_kernel(
      numel, f, data, in_calc, out_calc, cast_loader, cast_storer);
#endif
}
346
347 } // namespace native
348 } // namespace at
349