#pragma once

#include <cstdint>
#include <type_traits>
#include <c10/core/DynamicCast.h>
#include <c10/util/Exception.h>
#include <c10/util/TypeCast.h>
#include <c10/macros/Macros.h>
#include <ATen/core/Array.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/thread_constants.h>

#include <thrust/tuple.h>

// References:
// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/

namespace at { namespace native { namespace memory {

namespace detail {

// What does `static_unroll` do?
//
// We want to do something like:
//
//    using args_t = typename traits::ArgsTuple;
//    args_t args;
//    #pragma unroll
//    for (int i = 0; i < traits::arity; i++) {
//      std::get<i>(args) = ....
//    }
//
// but unfortunately the above code does not work because the index passed to
// `std::get` must be a compile-time constant. `static_unroll` therefore
// simulates `#pragma unroll` with template metaprogramming: it recursively
// instantiates `func<0>::apply, func<1>::apply, ..., func<end-1>::apply`.

template<template<int i> typename func, int end, int current=0>
struct static_unroll {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args&&... args) {
    func<current>::apply(std::forward<Args>(args)...);
    static_unroll<func, end, current+1>::with_args(args...);
  }
};

template<template<int i> typename func, int end>
struct static_unroll<func, end, end> {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args... args) {}
};
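
// Illustrative usage sketch (not part of this header): the loop body is written
// as a functor templated on the loop index, so the index is a compile-time
// constant inside `apply`:
//
//    template<int i>
//    struct fill_with_index {  // hypothetical helper, for illustration only
//      template <typename args_t>
//      static C10_HOST_DEVICE void apply(args_t &args) {
//        std::get<i>(args) = i;
//      }
//    };
//
//    // equivalent to the `#pragma unroll` loop sketched above:
//    static_unroll<fill_with_index, traits::arity>::with_args(args);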

// helper structs to be used with static_unroll to load arguments
// one by one

template<int arg_index>
struct vectorized_load_helper {
  template <typename args_t, typename policy_t>
  static __device__ void apply(policy_t &self, args_t *args, int idx) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` holds the data pointers for the tensors [output, input0, input1, ...],
    // so we need a +1 offset to get the input
    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
    auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
    self.load_single_arg(args_accessor, ptr);
  }
};

template<int arg_index>
struct unroll_load_helper {
  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` holds the data pointers for the tensors [output0, ..., input0, input1, ...],
    // so we need a +num_outputs offset to get the inputs
    std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
  }
};

template <int current>
struct multi_outputs_store_helper {
  template<int ntensors, int num_outputs, typename ...Args>
  C10_HOST_DEVICE static void apply(
      at::detail::Array<char*, ntensors> data,
      at::detail::Array<uint32_t, num_outputs> offsets,
      thrust::tuple<Args...> ret) {
    using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
    T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
    *to = thrust::get<current>(ret);
  }
};

}  // namespace detail

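// Loaders and storers used by the policies below. The *WithoutCast variants
// read/write `scalar_t` directly; the *WithCast variants go through
// c10::fetch_and_cast / c10::cast_and_store using the runtime dtypes and
// element sizes recorded from a TensorIterator, so the in-memory dtype can
// differ from the compute type.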
struct LoadWithoutCast {
  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
  }
};

template <int N>
struct LoadWithCast {
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;
  size_array_t element_sizes;

  LoadWithCast(const TensorIteratorBase& iter) {
    CUDA_KERNEL_ASSERT(iter.ninputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i + iter.noutputs());
      element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
    }
  }

  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
  }
};

struct StoreWithoutCast {
  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
  }
};

template <int N = 1>
struct StoreWithCast {
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;
  size_array_t element_sizes;

  StoreWithCast(const TensorIteratorBase& iter) {
    CUDA_KERNEL_ASSERT(iter.noutputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i);
      element_sizes[i] = c10::elementSize(iter.dtype(i));
    }
  }

  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
  }
};
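
// Illustrative sketch (assumption about the call site, not part of this header):
// a launcher holding a TensorIteratorBase `iter` whose dtypes do not match the
// functor's types would construct the casting variants and hand them to the
// `unroll` policy defined below, e.g.
//
//    auto loader = memory::LoadWithCast<2>(iter);   // 2 inputs
//    auto storer = memory::StoreWithCast<1>(iter);  // 1 output
//    auto policy = memory::policies::unroll<decltype(data), decltype(ic),
//        decltype(oc), decltype(loader), decltype(storer)>(
//        data, remaining, ic, oc, loader, storer);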

// aligned vector generates vectorized load/store on CUDA
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
  using vec_t = aligned_vector<scalar_t, vec_size>;
  auto *from = reinterpret_cast<const vec_t *>(base_ptr);
  return from[offset];
}

template <int vec_size>
__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
  // See NOTE [Loading boolean values]
  auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
  aligned_vector<bool, vec_size> ret;
  for (int i = 0; i < vec_size; ++i) {
    ret.val[i] = bool(tmp.val[i]);
  }
  return ret;
}
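
// For example, aligned_vector<float, 4> is 16 bytes and 16-byte aligned, so
// `from[offset]` above reads 4 floats in one transaction (the compiler can emit
// a single 128-bit vectorized load instead of four scalar loads). The same
// applies to stores through aligned_vector in the policies below.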

namespace policies {

// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
struct unroll {

  data_t data;
  int remaining;
  inp_calc_t input_offset_calculator;
  out_calc_t output_offset_calculator;
  loader_t loader;
  storer_t storer;

  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}

  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }

  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      int offset = output_offset_calculator.get(linear_idx)[0];
      storer.store(from[i], data[0], offset);
      thread_idx += num_threads();
    }
  }
};
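
// Illustrative sketch (assumption about the call site, not part of this header):
// an elementwise kernel typically drives a policy like `unroll` as
//
//    args_t args[thread_work_size()];
//    return_t results[thread_work_size()];
//    policy.load(args, blockIdx.x);            // gather inputs for this block
//    #pragma unroll
//    for (int i = 0; i < thread_work_size(); i++) {
//      if (policy.check_inbounds(i)) {
//        results[i] = c10::guts::apply(f, args[i]);  // expand args into the functor
//      }
//    }
//    policy.store(results, blockIdx.x);        // scatter outputs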

// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
// Note:
// Functions in the vectorized policy do not do boundary checks. They assume the
// whole block has work to do, so any remainder must be handled by the caller.
template <int vec_size, typename data_t>  // vec_size: number of scalars, can be 1, 2, or 4.
struct vectorized {

  static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
  static constexpr int loop_size = thread_work_size() / vec_size;

  data_t data;

  __device__ vectorized(data_t data) : data(data) {}

  __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
    return true;
  }

  template<typename accessor_t, typename scalar_t>
  __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      auto v = load_vector<vec_size>(from, index);
      #pragma unroll
      for (int j = 0; j < vec_size; j++) {
        to(vec_size * i + j) = v.val[j];
      }
    }
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
  }

  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    using vec_t = aligned_vector<scalar_t, vec_size>;
    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
    vec_t *to_ = reinterpret_cast<vec_t *>(to);
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      vec_t v;
      for (int j = 0; j < vec_size; j++) {
        v.val[j] = from[vec_size * i + j];
      }
      to_[index] = v;
    }
  }
};
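
// Illustrative sketch (assumption about the launcher, not part of this header):
// only blocks with a full block_work_size() of elements take the vectorized
// path; the tail block falls back to the boundary-checked `unroll` policy:
//
//    int remaining = N - block_work_size() * blockIdx.x;
//    if (remaining < block_work_size()) {
//      // partial block: use `unroll`, which checks `remaining` on every step
//    } else {
//      // full block: use vectorized<vec_size, data_t>(data)
//    }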

template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
struct multi_outputs_unroll {
  // The members and the check_inbounds and load methods of multi_outputs_unroll
  // are copy-pasted from the unroll struct above; we don't use inheritance
  // because of a compiler bug in CUDA 10.2+.
  data_t data;
  int remaining;
  inp_calc_t input_offset_calculator;
  out_calc_t output_offset_calculator;
  LoadWithoutCast loader;
  StoreWithoutCast storer;

  __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
  data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}

  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((int)(threadIdx.x + thread_work_elem*num_threads()) < remaining);
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }

  template <typename return_t>
  __device__ inline void store(return_t *from, int idx) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= this->remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offsets = this->output_offset_calculator.get(linear_idx);
      memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
      thread_idx += num_threads();
    }
  }
};

}  // namespace policies

// This is only used on the host side, but it gets wrapped into templates that
// are C10_HOST_DEVICE, so it has to be C10_HOST_DEVICE as well in order to
// compile.
template<typename scalar_t>
inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
  constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
  if (address % vec4_alignment == 0) {
    return 4;
  } else if (address % vec2_alignment == 0) {
    return 2;
  }
  return 1;
}
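
// For example, with scalar_t = float: vec4_alignment is 16 and vec2_alignment
// is 8, so a pointer at address 0x...40 returns 4, one at 0x...48 returns 2,
// and one at 0x...44 returns 1 (only scalar loads are safe there).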

template<typename scalar_t>
inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
  return can_vectorize_up_to<scalar_t>(static_cast<const char*>(pointer));
}

template<int i>
struct can_vectorize_up_to_helper {
  template <typename array_t, typename traits>
  static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
    using arg_t = typename traits::template arg<i>::type;
    // `pointers` holds the data pointers for the tensors [output, input0, input1, ...],
    // so we need a +1 offset to get the input
    result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
  }
};

template<typename func_t, typename array_t>
inline int can_vectorize_up_to(array_t pointers) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  constexpr int arity = traits::arity;
  int result = can_vectorize_up_to<return_t>(pointers[0]);
  // We need to get the type for each argument of `func_t`, this can only
  // be done at compile time.
  detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
  return result;
}
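
// For example (illustrative): for a binary functor with signature
// float(float, float) and pointers = {out, in0, in1}, the result is the minimum
// of can_vectorize_up_to<float> over all three pointers, so a single misaligned
// tensor disables vectorization for the whole kernel.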

}}} // namespace at::native::memory