#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/Utils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/macros/Macros.h>
#include <curand_kernel.h>

#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_masked_scale_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/native_dropout_backward_native.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

namespace {

// Philox generates 128 bits of randomness at a time. The kernel uses this explicitly by putting the suitably transformed result into a float4;
// for all members of the float4 to be consumed, UNROLL has to be 4. Don't change!
// Note: VEC <= 4 (and in most real-world cases will be 4), so the same logic applies.
const int UNROLL = 4;

template <
    typename scalar_t,
    typename accscalar_t,
    typename IndexType,
    int ADims,
    int VEC,
    typename mask_t>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void
fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType> a,
                         at::cuda::detail::TensorInfo<scalar_t, IndexType> b,
                         at::cuda::detail::TensorInfo<mask_t, IndexType> c,
                         IndexType totalElements, accscalar_t p,
                         PhiloxCudaState philox_args) {
  // make sure we don't break the assumption that we can't have > 4 elements / thread
  static_assert(VEC <= 4, "Value of VEC must be in [2, 4]");

  using LoadT = memory::aligned_vector<scalar_t, VEC>;
  using MaskLoadT = memory::aligned_vector<mask_t, VEC>;
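  // aligned_vector<T, VEC> is a VEC-element array type aligned to sizeof(T) * VEC, so each
  // LoadT / MaskLoadT access below moves VEC elements in a single vectorized memory
  // transaction (e.g. 8 bytes when VEC = 4 and scalar_t is half).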

  auto seeds = at::cuda::philox::unpack(philox_args);
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  // Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements
  // in the vec=2 and vec=4 cases.
  bool gridxvec_loop_state = 0;
  accscalar_t scale = 1.0 / p;
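  // Inverted-dropout scaling: p here is the probability of keeping an element, so survivors
  // are multiplied by 1/p to preserve the expected value (e.g. keep probability p = 0.5 ->
  // kept values are doubled, dropped values become zero).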

  float4 rand;

  // Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time
  for (IndexType linearIndex = idx * VEC;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * VEC) {
    // local storage
    scalar_t src[VEC];
    // We'll use this to actually cause vectorized loads later
    LoadT *value = reinterpret_cast<LoadT*>(&src);

    // curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halves, so generate float for everything
    // Note: we need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so we need ceil(VEC / 4)
    // sets of rand.
    if ((VEC == 4) || (gridxvec_loop_state == 0)) {
      rand = curand_uniform4(&state);
    } else {
      // sets up the last two values we generated last iteration to be used this iteration.
      rand.x = rand.z;
      rand.y = rand.w;
      gridxvec_loop_state ^= 1;
    }

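    // (rand < p) evaluates to 1.0f with probability p (the keep probability) and 0.0f
    // otherwise, turning the uniform samples into per-element keep flags.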
    rand.x = rand.x < p;
    rand.y = rand.y < p;
    if (VEC == 4) {
      rand.z = rand.z < p;
      rand.w = rand.w < p;
    }

    // Note: We explicitly check for is_contiguous() before launching the vectorized kernel
    // and replace IndexToOffset call with linearIndex to allow vectorization of NHWC (or other)
    // ordering.
    // Single vectorized load
    *value = *reinterpret_cast<const LoadT*>(&a.data[linearIndex]);

    scalar_t r[VEC];
    mask_t mask[VEC];

    // Perform the actual computation
    #pragma unroll
    for (int ii = 0; ii < VEC; ii++) {
      r[ii] = src[ii]*(&rand.x)[ii]*scale;
      mask[ii] = (mask_t)(&rand.x)[ii];
    }
    // Vectorized writes for both mask & result
    *(reinterpret_cast<LoadT*>(&b.data[linearIndex])) = *reinterpret_cast<LoadT*>(&r[0]);
    *(reinterpret_cast<MaskLoadT*>(&c.data[linearIndex])) = *reinterpret_cast<MaskLoadT*>(&mask[0]);

    __syncthreads();
  }
}

template <
    typename scalar_t,
    typename accscalar_t,
    typename IndexType,
    int ADims,
    int BDims = ADims,
    typename mask_t>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void
fused_dropout_kernel(cuda::detail::TensorInfo<const scalar_t, IndexType> a,
                     cuda::detail::TensorInfo<scalar_t, IndexType> b,
                     cuda::detail::TensorInfo<mask_t, IndexType> c,
                     IndexType totalElements, accscalar_t p,
                     PhiloxCudaState philox_args) {
  auto seeds = at::cuda::philox::unpack(philox_args);
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);
  accscalar_t scale = 1.0 / p;

  IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) *
        blockDim.x * gridDim.x * UNROLL;
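  // Rounding the loop bound up to a multiple of blockDim.x * gridDim.x * UNROLL makes every
  // thread run the same number of iterations, and therefore issue the same number of
  // curand_uniform4 calls, regardless of which tail elements it owns (e.g. totalElements = 1000
  // with blockDim.x * gridDim.x * UNROLL = 1024 gives rounded_size = 1024, one iteration per
  // thread).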
  for (IndexType linearIndex = idx;
       linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x*UNROLL) {
    // curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halves, so generate float for everything
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    rand.x = rand.x < p;
    rand.y = rand.y < p;
    rand.z = rand.z < p;
    rand.w = rand.w < p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        // Convert `linearIndex` into an offset of `a`
        const IndexType aOffset =
            cuda::detail::IndexToOffset<const scalar_t, IndexType, ADims>::get(li, a);
        src[ii] = a.data[aOffset];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        // Convert `linearIndex` into an offset of `b`
        const IndexType bOffset =
            cuda::detail::IndexToOffset<scalar_t, IndexType, BDims>::get(li, b);
        b.data[bOffset] = src[ii]*(&rand.x)[ii]*scale;
        c.data[bOffset] = (mask_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template<typename mask_t, typename scalar_t, typename accscalar_t>
void masked_scale_kernel(at::Tensor& ret, const at::Tensor& src, const at::Tensor& mask, accscalar_t scale){
  auto iter = at::TensorIteratorConfig()
      .check_all_same_dtype(false)
      .add_output(ret)
      .add_const_input(src)
      .add_const_input(mask)
      .build();

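  // Elementwise ret = (float)mask * src * scale; the mask is widened to float in the lambda
  // so bool / uint8 masks multiply cleanly with half and bfloat16 sources.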
  at::native::gpu_kernel(
      iter,
      [=]GPU_LAMBDA(const scalar_t src_val, const mask_t mask_val) -> scalar_t {
        return (float)mask_val * src_val * scale;
      });
}

template <typename scalar_t>
int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
  int vec_size = 4;
  // get the vector size
  if (!self.is_non_overlapping_and_dense() || !ret.is_non_overlapping_and_dense() || !mask.is_non_overlapping_and_dense()) {
    vec_size = 1;
  } else {
    vec_size = memory::can_vectorize_up_to<scalar_t>((const char*)self.const_data_ptr());
  }

  // check that we'd have no remainders - prefer a smaller vector size with no remainders over a larger vector and remainder.
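  // e.g. numel() == 6: vec_size 4 leaves a remainder (6 % 4 != 0), so it is halved to 2,
  // which divides 6 evenly and is returned.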
  bool can_vectorize = true;
  do {
    can_vectorize = self.numel() % vec_size == 0 && ret.numel() % vec_size == 0 && mask.numel() % vec_size == 0;
    if (!can_vectorize) vec_size /= 2;
  } while (vec_size > 1 && !can_vectorize);
  return can_vectorize ? vec_size : 1;
}

template <typename index_type, typename mask_t>
inline void launcher(
    const Tensor& self,
    Tensor& ret,
    Tensor& mask,
    double p,
    const int64_t nelem,
    const PhiloxCudaState rng_engine_inputs,
    dim3 grid,
    dim3 dim_block) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      self.scalar_type(),
      "fused_dropout",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        accscalar_t pa = (accscalar_t)(p);
        auto self_info =
            cuda::detail::getTensorInfo<const scalar_t, index_type>(self);
        auto ret_info =
            cuda::detail::getTensorInfo<scalar_t, index_type>(ret);
        auto mask_info =
            cuda::detail::getTensorInfo<mask_t, index_type>(mask);
        self_info.collapseDims();
        ret_info.collapseDims();
        mask_info.collapseDims(); // ret and mask are collapsed to 1d
                                  // contiguous tensor

        int vec_size = get_vector_size<scalar_t>(self, ret, mask);

        if (vec_size > 1) {
          switch (vec_size) {
            case 4:
              fused_dropout_kernel_vec<
                  scalar_t,
                  accscalar_t,
                  index_type,
                  1,
                  4>
                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                      self_info,
                      ret_info,
                      mask_info,
                      nelem,
                      pa,
                      rng_engine_inputs);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
              break;
            case 2:
              fused_dropout_kernel_vec<
                  scalar_t,
                  accscalar_t,
                  index_type,
                  1,
                  2>
                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                      self_info,
                      ret_info,
                      mask_info,
                      nelem,
                      pa,
                      rng_engine_inputs);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
              break;
          }
        } else {
          switch (self_info.dims) {
            case 1:
              fused_dropout_kernel<scalar_t, accscalar_t, index_type, 1>
                  <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                      self_info,
                      ret_info,
                      mask_info,
                      nelem,
                      pa,
                      rng_engine_inputs);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
              break;
            default:
              if (!self.is_contiguous() && ret.is_contiguous() &&
                  mask.is_contiguous()) {
                fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1, 1>
                    <<<grid,
                       dim_block,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        self_info,
                        ret_info,
                        mask_info,
                        nelem,
                        pa,
                        rng_engine_inputs);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              } else {
                fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1>
                    <<<grid,
                       dim_block,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        self_info,
                        ret_info,
                        mask_info,
                        nelem,
                        pa,
                        rng_engine_inputs);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              }
          }
        }
      });
}

} // anonymous namespace

template <typename mask_t>
std::tuple<Tensor,Tensor>
dropout_cuda(CUDAGeneratorImpl* gen, const Tensor& self, double p){
  Tensor mask = at::empty_like(self, self.options().dtype(c10::CppTypeToScalarType<mask_t>::value));
  const int64_t nelem = self.numel();
  // empty tensors should not get here, but just in case, avoid FPE
  // non-training short-cut
  if (nelem==0) return std::tuple<Tensor,Tensor>(self.clone(), mask);

  Tensor ret = at::empty_like(self);
  const int64_t block_size = 256;
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
  dim3 dim_block(block_size);
  dim3 grid((nelem + block_size -1)/block_size);
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
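  // grid.x is capped at multiProcessorCount * blocks_per_sm; any elements beyond
  // grid.x * block_size are handled by the grid-stride loops inside the kernels.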
  // number of times random numbers will be generated per thread, used to offset the philox counter in the rng state
  int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
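  // counter_offset = (loop iterations per thread) * UNROLL; philox_cuda_state() advances the
  // generator by this amount so the next consumer starts from a fresh part of the sequence.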
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }
  if (cuda::detail::canUse32BitIndexMath(self)){
    launcher<unsigned int, mask_t>(
        self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block);
  } else {
    launcher<uint64_t, mask_t>(
        self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block);
  }
  return std::tuple<Tensor,Tensor>(ret, mask);
}

std::tuple<Tensor,Tensor>
native_dropout_cuda(const Tensor& self, double p, std::optional<bool> train){
  // short-cut for train == false
  if (train.has_value() && !train.value()) {
    return std::make_tuple(self.clone(), at::ones_like(self, self.options().dtype(c10::CppTypeToScalarType<bool>::value)));
  }
  // short-cut
  if (p == 1) {
    // native_dropout_cuda is in derivatives.yaml, so we don't need to add data
    // dependency from output to input for autograd
    auto ret = at::zeros_like(self);
    auto mask = at::zeros_like(self, self.options().dtype(c10::CppTypeToScalarType<bool>::value));
    return std::tuple<Tensor,Tensor>(ret, mask);
  }

  auto gen = get_generator_or_default<CUDAGeneratorImpl>(std::nullopt, cuda::detail::getDefaultCUDAGenerator());
  double p1m = 1. - p;
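  // native_dropout takes the probability of dropping an element; dropout_cuda and the kernels
  // above work with the keep probability, hence the 1 - p passed here.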
  return dropout_cuda<bool>(gen, self, p1m);
}

// TODO: _fused_dropout_cuda is to be removed, see PR #63937
std::tuple<Tensor,Tensor>
fused_dropout_cuda(const Tensor& self, double p, std::optional<Generator> gen_){
  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
  return dropout_cuda<uint8_t>(gen, self, p);
}

template <typename mask_t>
Tensor dropout_backward_cuda(const Tensor& grad, const Tensor& mask, double scale){
  Tensor ret = at::empty_like(grad, grad.suggest_memory_format());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "masked_scale", [&] {
    using accscalar_t = acc_type<scalar_t, true>;
    masked_scale_kernel<mask_t, scalar_t>(ret, grad, mask, (accscalar_t)scale);
  });
  return ret;
}

Tensor native_dropout_backward_cuda(const Tensor& grad, const Tensor& mask, double scale){
  TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool, "Mask should be Bool Scalar Type", mask.scalar_type());
  return dropout_backward_cuda<bool>(grad, mask, scale);
}

// TODO: masked_scale_cuda is to be removed, see PR #63937
Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
  TORCH_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype");
  return dropout_backward_cuda<uint8_t>(self, mask, scale);
}

} // namespace at::native